aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorSucipto <[email protected]>2024-11-08 05:58:31 +0700
committerGitHub <[email protected]>2024-11-07 17:58:31 -0500
commit825fe48e0654dc6e4e065df364a51ea79488e44b (patch)
treefe045e694b6dea8b44904f5a13c5da46ff0ea614
parentb28576396956b3e5642c9f7949b750e2d20b6441 (diff)
downloadcaddy-825fe48e0654dc6e4e065df364a51ea79488e44b.tar.gz
caddy-825fe48e0654dc6e4e065df364a51ea79488e44b.zip
reverseproxy: Allow `0` as weights for `weighted_round_robin` (#6681)
* Allow 0 as weights Change positive to non-negative * reverseproxy: allow 0 as weighted round robin value * test: add more wrr select test --------- Co-authored-by: peanutduck <[email protected]>
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies.go19
-rw-r--r--modules/caddyhttp/reverseproxy/selectionpolicies_test.go52
2 files changed, 65 insertions, 6 deletions
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go
index 293ff75e2..fcf7f90f6 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go
@@ -111,8 +111,8 @@ func (r *WeightedRoundRobinSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser)
if err != nil {
return d.Errf("invalid weight value '%s': %v", weight, err)
}
- if weightInt < 1 {
- return d.Errf("invalid weight value '%s': weight should be non-zero and positive", weight)
+ if weightInt < 0 {
+ return d.Errf("invalid weight value '%s': weight should be non-negative", weight)
}
r.Weights = append(r.Weights, weightInt)
}
@@ -136,8 +136,15 @@ func (r *WeightedRoundRobinSelection) Select(pool UpstreamPool, _ *http.Request,
return pool[0]
}
var index, totalWeight int
+ var weights []int
+
+ for _, w := range r.Weights {
+ if w > 0 {
+ weights = append(weights, w)
+ }
+ }
currentWeight := int(atomic.AddUint32(&r.index, 1)) % r.totalWeight
- for i, weight := range r.Weights {
+ for i, weight := range weights {
totalWeight += weight
if currentWeight < totalWeight {
index = i
@@ -145,9 +152,9 @@ func (r *WeightedRoundRobinSelection) Select(pool UpstreamPool, _ *http.Request,
}
}
- upstreams := make([]*Upstream, 0, len(r.Weights))
- for _, upstream := range pool {
- if !upstream.Available() {
+ upstreams := make([]*Upstream, 0, len(weights))
+ for i, upstream := range pool {
+ if !upstream.Available() || r.Weights[i] == 0 {
continue
}
upstreams = append(upstreams, upstream)
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
index a4701ce86..580abbdde 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
@@ -131,6 +131,58 @@ func TestWeightedRoundRobinPolicy(t *testing.T) {
}
}
+func TestWeightedRoundRobinPolicyWithZeroWeight(t *testing.T) {
+ pool := testPool()
+ wrrPolicy := WeightedRoundRobinSelection{
+ Weights: []int{0, 2, 1},
+ totalWeight: 3,
+ }
+ req, _ := http.NewRequest("GET", "/", nil)
+
+ h := wrrPolicy.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected first weighted round robin host to be second host in the pool.")
+ }
+
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[2] {
+ t.Error("Expected second weighted round robin host to be third host in the pool.")
+ }
+
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expected third weighted round robin host to be second host in the pool.")
+ }
+
+ // mark second host as down
+ pool[1].setHealthy(false)
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[2] {
+ t.Error("Expect select next available host.")
+ }
+
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[2] {
+ t.Error("Expect select only available host.")
+ }
+ // mark second host as up
+ pool[1].setHealthy(true)
+
+ h = wrrPolicy.Select(pool, req, nil)
+ if h != pool[1] {
+ t.Error("Expect select first host on availability.")
+ }
+
+ // test next select in full cycle
+ expected := []*Upstream{pool[1], pool[2], pool[1], pool[1], pool[2], pool[1]}
+ for i, want := range expected {
+ got := wrrPolicy.Select(pool, req, nil)
+ if want != got {
+ t.Errorf("Selection %d: got host[%s], want host[%s]", i+1, got, want)
+ }
+ }
+}
+
func TestLeastConnPolicy(t *testing.T) {
pool := testPool()
lcPolicy := LeastConnSelection{}