summaryrefslogtreecommitdiffhomepage
path: root/modules
diff options
context:
space:
mode:
authorAziz Rmadi <[email protected]>2024-07-09 02:06:30 -0500
committerGitHub <[email protected]>2024-07-09 03:06:30 -0400
commit630c62b3137abf688aa1a698a614fa28c08e43dd (patch)
treead0f014660b790bc2ed38811460bc250d9d456ce /modules
parent9338741ca79a74247ced86bc26e4994138470852 (diff)
downloadcaddy-630c62b3137abf688aa1a698a614fa28c08e43dd.tar.gz
caddy-630c62b3137abf688aa1a698a614fa28c08e43dd.zip
fixed bug in resolving ip version in dynamic upstreams (#6448)
Diffstat (limited to 'modules')
-rw-r--r--modules/caddyhttp/reverseproxy/upstreams.go26
-rw-r--r--modules/caddyhttp/reverseproxy/upstreams_test.go56
2 files changed, 70 insertions, 12 deletions
diff --git a/modules/caddyhttp/reverseproxy/upstreams.go b/modules/caddyhttp/reverseproxy/upstreams.go
index 46e45c646..c8ba930d2 100644
--- a/modules/caddyhttp/reverseproxy/upstreams.go
+++ b/modules/caddyhttp/reverseproxy/upstreams.go
@@ -231,6 +231,19 @@ type IPVersions struct {
IPv6 *bool `json:"ipv6,omitempty"`
}
+func resolveIpVersion(versions *IPVersions) string {
+ resolveIpv4 := versions == nil || (versions.IPv4 == nil && versions.IPv6 == nil) || (versions.IPv4 != nil && *versions.IPv4)
+ resolveIpv6 := versions == nil || (versions.IPv6 == nil && versions.IPv4 == nil) || (versions.IPv6 != nil && *versions.IPv6)
+ switch {
+ case resolveIpv4 && !resolveIpv6:
+ return "ip4"
+ case !resolveIpv4 && resolveIpv6:
+ return "ip6"
+ default:
+ return "ip"
+ }
+}
+
// AUpstreams provides upstreams from A/AAAA lookups.
// Results are cached and refreshed at the configured
// refresh interval.
@@ -313,9 +326,6 @@ func (au *AUpstreams) Provision(ctx caddy.Context) error {
func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
- resolveIpv4 := au.Versions == nil || au.Versions.IPv4 == nil || *au.Versions.IPv4
- resolveIpv6 := au.Versions == nil || au.Versions.IPv6 == nil || *au.Versions.IPv6
-
// Map ipVersion early, so we can use it as part of the cache-key.
// This should be fairly inexpensive and comes and the upside of
// allowing the same dynamic upstream (name + port combination)
@@ -324,15 +334,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// It also forced a cache-miss if a previously cached dynamic
// upstream changes its ip version, e.g. after a config reload,
// while keeping the cache-invalidation as simple as it currently is.
- var ipVersion string
- switch {
- case resolveIpv4 && !resolveIpv6:
- ipVersion = "ip4"
- case !resolveIpv4 && resolveIpv6:
- ipVersion = "ip6"
- default:
- ipVersion = "ip"
- }
+ ipVersion := resolveIpVersion(au.Versions)
auStr := repl.ReplaceAll(au.String()+ipVersion, "")
diff --git a/modules/caddyhttp/reverseproxy/upstreams_test.go b/modules/caddyhttp/reverseproxy/upstreams_test.go
new file mode 100644
index 000000000..48e2d2a63
--- /dev/null
+++ b/modules/caddyhttp/reverseproxy/upstreams_test.go
@@ -0,0 +1,56 @@
+package reverseproxy
+
+import "testing"
+
+func TestResolveIpVersion(t *testing.T) {
+ falseBool := false
+ trueBool := true
+ tests := []struct {
+ Versions *IPVersions
+ expectedIpVersion string
+ }{
+ {
+ Versions: &IPVersions{IPv4: &trueBool},
+ expectedIpVersion: "ip4",
+ },
+ {
+ Versions: &IPVersions{IPv4: &falseBool},
+ expectedIpVersion: "ip",
+ },
+ {
+ Versions: &IPVersions{IPv4: &trueBool, IPv6: &falseBool},
+ expectedIpVersion: "ip4",
+ },
+ {
+ Versions: &IPVersions{IPv6: &trueBool},
+ expectedIpVersion: "ip6",
+ },
+ {
+ Versions: &IPVersions{IPv6: &falseBool},
+ expectedIpVersion: "ip",
+ },
+ {
+ Versions: &IPVersions{IPv6: &trueBool, IPv4: &falseBool},
+ expectedIpVersion: "ip6",
+ },
+ {
+ Versions: &IPVersions{},
+ expectedIpVersion: "ip",
+ },
+ {
+ Versions: &IPVersions{IPv4: &trueBool, IPv6: &trueBool},
+ expectedIpVersion: "ip",
+ },
+ {
+ Versions: &IPVersions{IPv4: &falseBool, IPv6: &falseBool},
+ expectedIpVersion: "ip",
+ },
+ }
+ for _, test := range tests {
+ ipVersion := resolveIpVersion(test.Versions)
+ if ipVersion != test.expectedIpVersion {
+ t.Errorf("resolveIpVersion(): Expected %s got %s", test.expectedIpVersion, ipVersion)
+ }
+ }
+
+}