diff options
author | Andrew Hauck <[email protected]> | 2024-07-20 14:50:00 -0700 |
---|---|---|
committer | Yuchen Wu <[email protected]> | 2024-07-26 13:35:13 -0700 |
commit | a51874039fd788b89dbc354386620ac9c98d5a59 (patch) | |
tree | b9ccdd118f8600a1143b074bf95fd419da466c58 | |
parent | 7c122e7f36de5c946ac960a1691c5dd41f26e6e6 (diff) | |
download | pingora-a51874039fd788b89dbc354386620ac9c98d5a59.tar.gz pingora-a51874039fd788b89dbc354386620ac9c98d5a59.zip |
Add callback function to Backends update() to address atomicity issue when building selector
-rw-r--r-- | .bleep | 2 | ||||
-rw-r--r-- | pingora-load-balancing/src/lib.rs | 119 |
2 files changed, 100 insertions, 21 deletions
@@ -1 +1 @@ -78a170341a0fb030b8bcb2afe84afb268cdc5b2d
\ No newline at end of file +9fdf48d67b78675c989f51ec18829a81fe6976ef
\ No newline at end of file diff --git a/pingora-load-balancing/src/lib.rs b/pingora-load-balancing/src/lib.rs index c950c85..07062df 100644 --- a/pingora-load-balancing/src/lib.rs +++ b/pingora-load-balancing/src/lib.rs @@ -130,8 +130,17 @@ impl Backends { self.health_check = Some(hc.into()) } - /// Return true when the new is different from the current set of backends - fn do_update(&self, new_backends: BTreeSet<Backend>, enablement: HashMap<u64, bool>) -> bool { + /// Updates backends when the new is different from the current set, + /// the callback will be invoked when the new set of backend is different + /// from the current one so that the caller can update the selector accordingly. + fn do_update<F>( + &self, + new_backends: BTreeSet<Backend>, + enablement: HashMap<u64, bool>, + callback: F, + ) where + F: Fn(Arc<BTreeSet<Backend>>), + { if (**self.backends.load()) != new_backends { let old_health = self.health.load(); let mut health = HashMap::with_capacity(new_backends.len()); @@ -147,10 +156,14 @@ impl Backends { health.insert(hash_key, backend_health); } - // TODO: put backend and health under 1 ArcSwap so that this update is atomic - self.backends.store(Arc::new(new_backends)); + // TODO: put this all under 1 ArcSwap so the update is atomic + // It's important the `callback()` executes first since computing selector backends might + // be expensive. For example, if a caller checks `backends` to see if any are available + // they may encounter false positives if the selector isn't ready yet. + let new_backends = Arc::new(new_backends); + callback(new_backends.clone()); + self.backends.store(new_backends); self.health.store(Arc::new(health)); - true } else { // no backend change, just check enablement for (hash_key, backend_enabled) in enablement.iter() { @@ -160,7 +173,6 @@ impl Backends { backend_health.enable(*backend_enabled); } } - false } } @@ -199,12 +211,15 @@ impl Backends { /// Call the service discovery method to update the collection of backends. /// - /// Return `true` when the new collection is different from the current set of backends. - /// This return value is useful to tell the caller when to rebuild things that are expensive to - /// update, such as consistent hashing rings. - pub async fn update(&self) -> Result<bool> { + /// The callback will be invoked when the new set of backend is different + /// from the current one so that the caller can update the selector accordingly. + pub async fn update<F>(&self, callback: F) -> Result<()> + where + F: Fn(Arc<BTreeSet<Backend>>), + { let (new_backends, enablement) = self.discovery.discover().await?; - Ok(self.do_update(new_backends, enablement)) + self.do_update(new_backends, enablement, callback); + Ok(()) } /// Run health check on all backends if it is set. @@ -320,11 +335,9 @@ where /// This function will be called every `update_frequency` if this [LoadBalancer] instance /// is running as a background service. pub async fn update(&self) -> Result<()> { - if self.backends.update().await? { - self.selector - .store(Arc::new(S::build(&self.backends.get_backend()))) - } - Ok(()) + self.backends + .update(|backends| self.selector.store(Arc::new(S::build(&backends)))) + .await } /// Return the first healthy [Backend] according to the selection algorithm and the @@ -378,6 +391,8 @@ where #[cfg(test)] mod test { + use std::sync::atomic::{AtomicBool, Ordering::Relaxed}; + use super::*; use async_trait::async_trait; @@ -408,10 +423,20 @@ mod test { backends.set_health_check(check); // true: new backend discovered - assert!(backends.update().await.unwrap()); + let updated = AtomicBool::new(false); + backends + .update(|_| updated.store(true, Relaxed)) + .await + .unwrap(); + assert!(updated.load(Relaxed)); // false: no new backend discovered - assert!(!backends.update().await.unwrap()); + let updated = AtomicBool::new(false); + backends + .update(|_| updated.store(true, Relaxed)) + .await + .unwrap(); + assert!(!updated.load(Relaxed)); backends.run_health_check(false).await; @@ -449,7 +474,14 @@ mod test { let discovery = TestDiscovery(discovery); let backends = Backends::new(Box::new(discovery)); - assert!(backends.update().await.unwrap()); + + // true: new backend discovered + let updated = AtomicBool::new(false); + backends + .update(|_| updated.store(true, Relaxed)) + .await + .unwrap(); + assert!(updated.load(Relaxed)); let backend = backends.get_backend(); assert!(backend.contains(&good1)); @@ -476,7 +508,12 @@ mod test { backends.set_health_check(check); // true: new backend discovered - assert!(backends.update().await.unwrap()); + let updated = AtomicBool::new(false); + backends + .update(|_| updated.store(true, Relaxed)) + .await + .unwrap(); + assert!(updated.load(Relaxed)); backends.run_health_check(true).await; @@ -484,4 +521,46 @@ mod test { assert!(backends.ready(&good2)); assert!(!backends.ready(&bad)); } + + mod thread_safety { + use super::*; + + struct MockDiscovery { + expected: usize, + } + #[async_trait] + impl ServiceDiscovery for MockDiscovery { + async fn discover(&self) -> Result<(BTreeSet<Backend>, HashMap<u64, bool>)> { + let mut d = BTreeSet::new(); + let mut m = HashMap::with_capacity(self.expected); + for i in 0..self.expected { + let b = Backend::new(&format!("1.1.1.1:{i}")).unwrap(); + m.insert(i as u64, true); + d.insert(b); + } + Ok((d, m)) + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_consistency() { + let expected = 3000; + let discovery = MockDiscovery { expected }; + let lb = Arc::new(LoadBalancer::<selection::Consistent>::from_backends( + Backends::new(Box::new(discovery)), + )); + let lb2 = lb.clone(); + + tokio::spawn(async move { + assert!(lb2.update().await.is_ok()); + }); + let mut backend_count = 0; + while backend_count == 0 { + let backends = lb.backends(); + backend_count = backends.backends.load_full().len(); + } + assert_eq!(backend_count, expected); + assert!(lb.select_with(b"test", 1, |_, _| true).is_some()); + } + } } |