// Copyright 2024 Cloudflare, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! Pingora tokio runtime. //! //! Tokio runtime comes in two flavors: a single-threaded runtime //! and a multi-threaded one which provides work stealing. //! Benchmark shows that, compared to the single-threaded runtime, the multi-threaded one //! has some overhead due to its more sophisticated work steal scheduling. //! //! This crate provides a third flavor: a multi-threaded runtime without work stealing. //! This flavor is as efficient as the single-threaded runtime while allows the async //! program to use multiple cores. use once_cell::sync::{Lazy, OnceCell}; use rand::Rng; use std::sync::Arc; use std::thread::JoinHandle; use std::time::Duration; use thread_local::ThreadLocal; use tokio::runtime::{Builder, Handle}; use tokio::sync::oneshot::{channel, Sender}; /// Pingora async multi-threaded runtime /// /// The `Steal` flavor is effectively tokio multi-threaded runtime. /// /// The `NoSteal` flavor is backed by multiple tokio single-threaded runtime. pub enum Runtime { Steal(tokio::runtime::Runtime), NoSteal(NoStealRuntime), } impl Runtime { /// Create a `Steal` flavor runtime. This just a regular tokio runtime pub fn new_steal(threads: usize, name: &str) -> Self { Self::Steal( Builder::new_multi_thread() .enable_all() .worker_threads(threads) .thread_name(name) .build() .unwrap(), ) } /// Create a `NoSteal` flavor runtime. This is backed by multiple tokio current-thread runtime pub fn new_no_steal(threads: usize, name: &str) -> Self { Self::NoSteal(NoStealRuntime::new(threads, name)) } /// Return the &[Handle] of the [Runtime]. /// For `Steal` flavor, it will just return the &[Handle]. /// For `NoSteal` flavor, it will return the &[Handle] of a random thread in its pool. /// So if we want tasks to spawn on all the threads, call this function to get a fresh [Handle] /// for each async task. pub fn get_handle(&self) -> &Handle { match self { Self::Steal(r) => r.handle(), Self::NoSteal(r) => r.get_runtime(), } } /// Call tokio's `shutdown_timeout` of all the runtimes. This function is blocking until /// all runtimes exit. pub fn shutdown_timeout(self, timeout: Duration) { match self { Self::Steal(r) => r.shutdown_timeout(timeout), Self::NoSteal(r) => r.shutdown_timeout(timeout), } } } // only NoStealRuntime set the pools in thread threads static CURRENT_HANDLE: Lazy> = Lazy::new(ThreadLocal::new); /// Return the [Handle] of current runtime. /// If the current thread is under a `Steal` runtime, the current [Handle] is returned. /// If the current thread is under a `NoSteal` runtime, the [Handle] of a random thread /// under this runtime is returned. This function will panic if called outside any runtime. pub fn current_handle() -> Handle { if let Some(pools) = CURRENT_HANDLE.get() { // safety: the CURRENT_HANDLE is set when the pool is being initialized in init_pools() let pools = pools.get().unwrap(); let mut rng = rand::thread_rng(); let index = rng.gen_range(0..pools.len()); pools[index].clone() } else { // not NoStealRuntime, just check the current tokio runtime Handle::current() } } type Control = (Sender, JoinHandle<()>); type Pools = Arc>>; /// Multi-threaded runtime backed by a pool of single threaded tokio runtime pub struct NoStealRuntime { threads: usize, name: String, // Lazily init the runtimes so that they are created after pingora // daemonize itself. Otherwise the runtime threads are lost. pools: Arc>>, controls: OnceCell>, } impl NoStealRuntime { /// Create a new [NoStealRuntime]. Panic if `threads` is 0 pub fn new(threads: usize, name: &str) -> Self { assert!(threads != 0); NoStealRuntime { threads, name: name.to_string(), pools: Arc::new(OnceCell::new()), controls: OnceCell::new(), } } fn init_pools(&self) -> (Box<[Handle]>, Vec) { let mut pools = Vec::with_capacity(self.threads); let mut controls = Vec::with_capacity(self.threads); for _ in 0..self.threads { let rt = Builder::new_current_thread().enable_all().build().unwrap(); let handler = rt.handle().clone(); let (tx, rx) = channel::(); let pools_ref = self.pools.clone(); let join = std::thread::Builder::new() .name(self.name.clone()) .spawn(move || { CURRENT_HANDLE.get_or(|| pools_ref); if let Ok(timeout) = rt.block_on(rx) { rt.shutdown_timeout(timeout); } // else Err(_): tx is dropped, just exit }) .unwrap(); pools.push(handler); controls.push((tx, join)); } (pools.into_boxed_slice(), controls) } /// Return the &[Handle] of a random thread of this runtime pub fn get_runtime(&self) -> &Handle { let mut rng = rand::thread_rng(); let index = rng.gen_range(0..self.threads); self.get_runtime_at(index) } /// Return the number of threads of this runtime pub fn threads(&self) -> usize { self.threads } fn get_pools(&self) -> &[Handle] { if let Some(p) = self.pools.get() { p } else { // TODO: use a mutex to avoid creating a lot threads only to drop them let (pools, controls) = self.init_pools(); // there could be another thread racing with this one to init the pools match self.pools.try_insert(pools) { Ok(p) => { // unwrap to make sure that this is the one that init both pools and controls self.controls.set(controls).unwrap(); p } // another thread already set it, just return it Err((p, _my_pools)) => p, } } } /// Return the &[Handle] of a given thread of this runtime pub fn get_runtime_at(&self, index: usize) -> &Handle { let pools = self.get_pools(); &pools[index] } /// Call tokio's `shutdown_timeout` of all the runtimes. This function is blocking until /// all runtimes exit. pub fn shutdown_timeout(mut self, timeout: Duration) { if let Some(controls) = self.controls.take() { let (txs, joins): (Vec>, Vec>) = controls.into_iter().unzip(); for tx in txs { let _ = tx.send(timeout); // Err() when rx is dropped } for join in joins { let _ = join.join(); // ignore thread error } } // else, the controls and the runtimes are not even init yet, just return; } // TODO: runtime metrics } #[test] fn test_steal_runtime() { use tokio::time::{sleep, Duration}; let rt = Runtime::new_steal(2, "test"); let handle = rt.get_handle(); let ret = handle.block_on(async { sleep(Duration::from_secs(1)).await; let handle = current_handle(); let join = handle.spawn(async { sleep(Duration::from_secs(1)).await; }); join.await.unwrap(); 1 }); assert_eq!(ret, 1); } #[test] fn test_no_steal_runtime() { use tokio::time::{sleep, Duration}; let rt = Runtime::new_no_steal(2, "test"); let handle = rt.get_handle(); let ret = handle.block_on(async { sleep(Duration::from_secs(1)).await; let handle = current_handle(); let join = handle.spawn(async { sleep(Duration::from_secs(1)).await; }); join.await.unwrap(); 1 }); assert_eq!(ret, 1); } #[test] fn test_no_steal_shutdown() { use tokio::time::{sleep, Duration}; let rt = Runtime::new_no_steal(2, "test"); let handle = rt.get_handle(); let ret = handle.block_on(async { sleep(Duration::from_secs(1)).await; let handle = current_handle(); let join = handle.spawn(async { sleep(Duration::from_secs(1)).await; }); join.await.unwrap(); 1 }); assert_eq!(ret, 1); rt.shutdown_timeout(Duration::from_secs(1)); }