diff options
23 files changed, 629 insertions, 54 deletions
@@ -1 +1 @@ -aa7c01c0f1c7a33d4ec55cf3531997e4cb2542ab
\ No newline at end of file +ca6a894f27f4448b3d0fe7ab3d84221133e83e5b
\ No newline at end of file diff --git a/pingora-core/src/connectors/l4.rs b/pingora-core/src/connectors/l4.rs index 58bd209..e4f106f 100644 --- a/pingora-core/src/connectors/l4.rs +++ b/pingora-core/src/connectors/l4.rs @@ -17,10 +17,15 @@ use log::debug; use pingora_error::{Context, Error, ErrorType::*, OrErr, Result}; use rand::seq::SliceRandom; use std::net::SocketAddr as InetSocketAddr; +#[cfg(unix)] use std::os::unix::io::AsRawFd; +#[cfg(windows)] +use std::os::windows::io::AsRawSocket; +#[cfg(unix)] +use crate::protocols::l4::ext::connect_uds; use crate::protocols::l4::ext::{ - connect_uds, connect_with as tcp_connect, set_dscp, set_recv_buf, set_tcp_fastopen_connect, + connect_with as tcp_connect, set_dscp, set_recv_buf, set_tcp_fastopen_connect, }; use crate::protocols::l4::socket::SocketAddr; use crate::protocols::l4::stream::Stream; @@ -102,16 +107,21 @@ where match peer_addr { SocketAddr::Inet(addr) => { let connect_future = tcp_connect(addr, bind_to.as_ref(), |socket| { + #[cfg(unix)] + let raw = socket.as_raw_fd(); + #[cfg(windows)] + let raw = socket.as_raw_socket(); + if peer.tcp_fast_open() { - set_tcp_fastopen_connect(socket.as_raw_fd())?; + set_tcp_fastopen_connect(raw)?; } if let Some(recv_buf) = peer.tcp_recv_buf() { debug!("Setting recv buf size"); - set_recv_buf(socket.as_raw_fd(), recv_buf)?; + set_recv_buf(raw, recv_buf)?; } if let Some(dscp) = peer.dscp() { debug!("Setting dscp"); - set_dscp(socket.as_raw_fd(), dscp)?; + set_dscp(raw, dscp)?; } Ok(()) }); @@ -137,6 +147,7 @@ where } } } + #[cfg(unix)] SocketAddr::Unix(addr) => { let connect_future = connect_uds( addr.as_pathname() @@ -179,7 +190,10 @@ where } stream.set_nodelay()?; + #[cfg(unix)] let digest = SocketDigest::from_raw_fd(stream.as_raw_fd()); + #[cfg(windows)] + let digest = SocketDigest::from_raw_socket(stream.as_raw_socket()); digest .peer_addr .set(Some(peer_addr.clone())) @@ -217,6 +231,7 @@ pub(crate) fn bind_to_random<P: Peer>( InetSocketAddr::V4(_) => bind_to_ips(v4_list), InetSocketAddr::V6(_) => bind_to_ips(v6_list), }, + #[cfg(unix)] SocketAddr::Unix(_) => None, }; @@ -235,6 +250,7 @@ pub(crate) fn bind_to_random<P: Peer>( use crate::protocols::raw_connect; +#[cfg(unix)] async fn proxy_connect<P: Peer>(peer: &P) -> Result<Stream> { // safe to unwrap let proxy = peer.get_proxy().unwrap(); @@ -275,6 +291,11 @@ async fn proxy_connect<P: Peer>(peer: &P) -> Result<Stream> { Ok(*stream) } +#[cfg(windows)] +async fn proxy_connect<P: Peer>(peer: &P) -> Result<Stream> { + panic!("peer proxy not supported on windows") +} + #[cfg(test)] mod tests { use super::*; @@ -282,6 +303,7 @@ mod tests { use std::collections::BTreeMap; use std::path::PathBuf; use tokio::io::AsyncWriteExt; + #[cfg(unix)] use tokio::net::UnixListener; #[tokio::test] @@ -359,6 +381,7 @@ mod tests { assert!(new_session.is_ok()); } + #[cfg(unix)] #[tokio::test] async fn test_connect_proxy_fail() { let mut peer = HttpPeer::new("1.1.1.1:80".to_string(), false, "".to_string()); @@ -376,9 +399,11 @@ mod tests { assert!(!e.retry()); } + #[cfg(unix)] const MOCK_UDS_PATH: &str = "/tmp/test_unix_connect_proxy.sock"; // one-off mock server + #[cfg(unix)] async fn mock_connect_server() { let _ = std::fs::remove_file(MOCK_UDS_PATH); let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap(); @@ -410,10 +435,12 @@ mod tests { assert!(new_session.is_ok()); } + #[cfg(unix)] const MOCK_BAD_UDS_PATH: &str = "/tmp/test_unix_bad_connect_proxy.sock"; // one-off mock bad proxy // closes connection upon accepting + #[cfg(unix)] async fn mock_connect_bad_server() { let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH); let listener = UnixListener::bind(MOCK_BAD_UDS_PATH).unwrap(); @@ -424,6 +451,7 @@ mod tests { let _ = std::fs::remove_file(MOCK_BAD_UDS_PATH); } + #[cfg(unix)] #[tokio::test(flavor = "multi_thread")] async fn test_connect_proxy_conn_closed() { tokio::spawn(async { diff --git a/pingora-core/src/connectors/mod.rs b/pingora-core/src/connectors/mod.rs index cbe299f..57c16aa 100644 --- a/pingora-core/src/connectors/mod.rs +++ b/pingora-core/src/connectors/mod.rs @@ -195,11 +195,29 @@ impl TransportConnector { let mut stream = l.into_inner(); // test_reusable_stream: we assume server would never actively send data // first on an idle stream. + #[cfg(unix)] if peer.matches_fd(stream.id()) && test_reusable_stream(&mut stream) { Some(stream) } else { None } + #[cfg(windows)] + { + use std::os::windows::io::{AsRawSocket, RawSocket}; + struct WrappedRawSocket(RawSocket); + impl AsRawSocket for WrappedRawSocket { + fn as_raw_socket(&self) -> RawSocket { + self.0 + } + } + if peer.matches_sock(WrappedRawSocket(stream.id() as RawSocket)) + && test_reusable_stream(&mut stream) + { + Some(stream) + } else { + None + } + } } Err(_) => { error!("failed to acquire reusable stream"); @@ -373,6 +391,7 @@ mod tests { use crate::tls::ssl::SslMethod; use crate::upstreams::peer::BasicPeer; use tokio::io::AsyncWriteExt; + #[cfg(unix)] use tokio::net::UnixListener; // 192.0.2.1 is effectively a black hole @@ -404,9 +423,11 @@ mod tests { assert!(reused); } + #[cfg(unix)] const MOCK_UDS_PATH: &str = "/tmp/test_unix_transport_connector.sock"; // one-off mock server + #[cfg(unix)] async fn mock_connect_server() { let _ = std::fs::remove_file(MOCK_UDS_PATH); let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap(); diff --git a/pingora-core/src/listeners/l4.rs b/pingora-core/src/listeners/l4.rs index 748037a..43c4939 100644 --- a/pingora-core/src/listeners/l4.rs +++ b/pingora-core/src/listeners/l4.rs @@ -20,8 +20,12 @@ use pingora_error::{ use std::fs::Permissions; use std::io::ErrorKind; use std::net::{SocketAddr, ToSocketAddrs}; +#[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd}; +#[cfg(unix)] use std::os::unix::net::UnixListener as StdUnixListener; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket}; use std::time::Duration; use tokio::net::TcpSocket; @@ -29,6 +33,7 @@ use crate::protocols::l4::ext::{set_dscp, set_tcp_fastopen_backlog}; use crate::protocols::l4::listener::Listener; pub use crate::protocols::l4::stream::Stream; use crate::protocols::TcpKeepalive; +#[cfg(unix)] use crate::server::ListenFds; const TCP_LISTENER_MAX_TRY: usize = 30; @@ -40,6 +45,7 @@ const LISTENER_BACKLOG: u32 = 65535; #[derive(Clone, Debug)] pub enum ServerAddress { Tcp(String, Option<TcpSocketOptions>), + #[cfg(unix)] Uds(String, Option<Permissions>), } @@ -47,6 +53,7 @@ impl AsRef<str> for ServerAddress { fn as_ref(&self) -> &str { match &self { Self::Tcp(l, _) => l, + #[cfg(unix)] Self::Uds(l, _) => l, } } @@ -82,6 +89,7 @@ pub struct TcpSocketOptions { // TODO: allow configuring reuseaddr, backlog, etc. from here? } +#[cfg(unix)] mod uds { use super::{OrErr, Result}; use crate::protocols::l4::listener::Listener; @@ -149,19 +157,24 @@ fn apply_tcp_socket_options(sock: &TcpSocket, opt: Option<&TcpSocketOptions>) -> .set_only_v6(ipv6_only) .or_err(BindError, "failed to set IPV6_V6ONLY")?; } + #[cfg(unix)] + let raw = sock.as_raw_fd(); + #[cfg(windows)] + let raw = sock.as_raw_socket(); if let Some(backlog) = opt.tcp_fastopen { - set_tcp_fastopen_backlog(sock.as_raw_fd(), backlog)?; + set_tcp_fastopen_backlog(raw, backlog)?; } if let Some(dscp) = opt.dscp { - set_dscp(sock.as_raw_fd(), dscp)?; + set_dscp(raw, dscp)?; } Ok(()) } fn from_raw_fd(address: &ServerAddress, fd: i32) -> Result<Listener> { match address { + #[cfg(unix)] ServerAddress::Uds(addr, perm) => { let std_listener = unsafe { StdUnixListener::from_raw_fd(fd) }; // set permissions just in case @@ -169,7 +182,10 @@ fn from_raw_fd(address: &ServerAddress, fd: i32) -> Result<Listener> { Ok(uds::set_backlog(std_listener, LISTENER_BACKLOG)?.into()) } ServerAddress::Tcp(_, _) => { + #[cfg(unix)] let std_listener_socket = unsafe { std::net::TcpStream::from_raw_fd(fd) }; + #[cfg(windows)] + let std_listener_socket = unsafe { std::net::TcpStream::from_raw_socket(fd as u64) }; let listener_socket = TcpSocket::from_std_stream(std_listener_socket); // Note that we call listen on an already listening socket // POSIX undefined but on Linux it will update the backlog size @@ -231,6 +247,7 @@ async fn bind_tcp(addr: &str, opt: Option<TcpSocketOptions>) -> Result<Listener> async fn bind(addr: &ServerAddress) -> Result<Listener> { match addr { + #[cfg(unix)] ServerAddress::Uds(l, perm) => uds::bind(l, perm.clone()), ServerAddress::Tcp(l, opt) => bind_tcp(l, opt.clone()).await, } @@ -253,6 +270,7 @@ impl ListenerEndpoint { self.listen_addr.as_ref() } + #[cfg(unix)] pub async fn listen(&mut self, fds: Option<ListenFds>) -> Result<()> { if self.listener.is_some() { return Ok(()); @@ -278,6 +296,12 @@ impl ListenerEndpoint { Ok(()) } + #[cfg(windows)] + pub async fn listen(&mut self) -> Result<()> { + self.listener = Some(bind(&self.listen_addr).await?); + Ok(()) + } + fn apply_stream_settings(&self, stream: &mut Stream) -> Result<()> { // settings are applied based on whether the underlying stream supports it stream.set_nodelay()?; @@ -288,7 +312,10 @@ impl ListenerEndpoint { stream.set_keepalive(ka)?; } if let Some(dscp) = op.dscp { + #[cfg(unix)] set_dscp(stream.as_raw_fd(), dscp)?; + #[cfg(windows)] + set_dscp(stream.as_raw_socket(), dscp)?; } Ok(()) } @@ -315,7 +342,13 @@ mod test { async fn test_listen_tcp() { let addr = "127.0.0.1:7100"; let mut listener = ListenerEndpoint::new(ServerAddress::Tcp(addr.into(), None)); - listener.listen(None).await.unwrap(); + listener + .listen( + #[cfg(unix)] + None, + ) + .await + .unwrap(); tokio::spawn(async move { // just try to accept once listener.accept().await.unwrap(); @@ -332,7 +365,13 @@ mod test { ..Default::default() }); let mut listener = ListenerEndpoint::new(ServerAddress::Tcp("[::]:7101".into(), sock_opt)); - listener.listen(None).await.unwrap(); + listener + .listen( + #[cfg(unix)] + None, + ) + .await + .unwrap(); tokio::spawn(async move { // just try to accept twice listener.accept().await.unwrap(); @@ -346,6 +385,7 @@ mod test { .expect("can connect to v6 addr"); } + #[cfg(unix)] #[tokio::test] async fn test_listen_uds() { let addr = "/tmp/test_listen_uds"; diff --git a/pingora-core/src/listeners/mod.rs b/pingora-core/src/listeners/mod.rs index cb6e2c6..35e30d5 100644 --- a/pingora-core/src/listeners/mod.rs +++ b/pingora-core/src/listeners/mod.rs @@ -18,6 +18,7 @@ mod l4; mod tls; use crate::protocols::Stream; +#[cfg(unix)] use crate::server::ListenFds; use pingora_error::Result; @@ -36,10 +37,11 @@ struct TransportStackBuilder { } impl TransportStackBuilder { - pub fn build(&mut self, upgrade_listeners: Option<ListenFds>) -> TransportStack { + pub fn build(&mut self, #[cfg(unix)] upgrade_listeners: Option<ListenFds>) -> TransportStack { TransportStack { l4: ListenerEndpoint::new(self.l4.clone()), tls: self.tls.take().map(|tls| Arc::new(tls.build())), + #[cfg(unix)] upgrade_listeners, } } @@ -58,7 +60,12 @@ impl TransportStack { } pub async fn listen(&mut self) -> Result<()> { - self.l4.listen(self.upgrade_listeners.take()).await + self.l4 + .listen( + #[cfg(unix)] + self.upgrade_listeners.take(), + ) + .await } pub async fn accept(&mut self) -> Result<UninitializedStream> { @@ -109,6 +116,7 @@ impl Listeners { } /// Create a new [`Listeners`] with a Unix domain socket endpoint from the given string. + #[cfg(unix)] pub fn uds(addr: &str, perm: Option<Permissions>) -> Self { let mut listeners = Self::new(); listeners.add_uds(addr, perm); @@ -136,6 +144,7 @@ impl Listeners { } /// Add a Unix domain socket endpoint to `self`. + #[cfg(unix)] pub fn add_uds(&mut self, addr: &str, perm: Option<Permissions>) { self.add_address(ServerAddress::Uds(addr.into(), perm)); } @@ -168,10 +177,18 @@ impl Listeners { self.stacks.push(TransportStackBuilder { l4, tls }) } - pub(crate) fn build(&mut self, upgrade_listeners: Option<ListenFds>) -> Vec<TransportStack> { + pub(crate) fn build( + &mut self, + #[cfg(unix)] upgrade_listeners: Option<ListenFds>, + ) -> Vec<TransportStack> { self.stacks .iter_mut() - .map(|b| b.build(upgrade_listeners.clone())) + .map(|b| { + b.build( + #[cfg(unix)] + upgrade_listeners.clone(), + ) + }) .collect() } @@ -194,7 +211,10 @@ mod test { let mut listeners = Listeners::tcp(addr1); listeners.add_tcp(addr2); - let listeners = listeners.build(None); + let listeners = listeners.build( + #[cfg(unix)] + None, + ); assert_eq!(listeners.len(), 2); for mut listener in listeners { tokio::spawn(async move { @@ -221,7 +241,13 @@ mod test { let cert_path = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR")); let key_path = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR")); let mut listeners = Listeners::tls(addr, &cert_path, &key_path).unwrap(); - let mut listener = listeners.build(None).pop().unwrap(); + let mut listener = listeners + .build( + #[cfg(unix)] + None, + ) + .pop() + .unwrap(); tokio::spawn(async move { listener.listen().await.unwrap(); diff --git a/pingora-core/src/protocols/digest.rs b/pingora-core/src/protocols/digest.rs index 3150306..88720e5 100644 --- a/pingora-core/src/protocols/digest.rs +++ b/pingora-core/src/protocols/digest.rs @@ -62,7 +62,10 @@ impl Default for TimingDigest { #[derive(Debug)] /// The interface to return socket-related information pub struct SocketDigest { + #[cfg(unix)] raw_fd: std::os::unix::io::RawFd, + #[cfg(windows)] + raw_sock: std::os::windows::io::RawSocket, /// Remote socket address pub peer_addr: OnceCell<Option<SocketAddr>>, /// Local socket address @@ -70,6 +73,7 @@ pub struct SocketDigest { } impl SocketDigest { + #[cfg(unix)] pub fn from_raw_fd(raw_fd: std::os::unix::io::RawFd) -> SocketDigest { SocketDigest { raw_fd, @@ -78,22 +82,48 @@ impl SocketDigest { } } + #[cfg(windows)] + pub fn from_raw_socket(raw_sock: std::os::windows::io::RawSocket) -> SocketDigest { + SocketDigest { + raw_sock, + peer_addr: OnceCell::new(), + local_addr: OnceCell::new(), + } + } + + #[cfg(unix)] pub fn peer_addr(&self) -> Option<&SocketAddr> { self.peer_addr .get_or_init(|| SocketAddr::from_raw_fd(self.raw_fd, true)) .as_ref() } + #[cfg(windows)] + pub fn peer_addr(&self) -> Option<&SocketAddr> { + self.peer_addr + .get_or_init(|| SocketAddr::from_raw_socket(self.raw_sock, true)) + .as_ref() + } + + #[cfg(unix)] pub fn local_addr(&self) -> Option<&SocketAddr> { self.local_addr .get_or_init(|| SocketAddr::from_raw_fd(self.raw_fd, false)) .as_ref() } + #[cfg(windows)] + pub fn local_addr(&self) -> Option<&SocketAddr> { + self.local_addr + .get_or_init(|| SocketAddr::from_raw_socket(self.raw_sock, false)) + .as_ref() + } + fn is_inet(&self) -> bool { self.local_addr().and_then(|p| p.as_inet()).is_some() } + #[cfg(unix)] pub fn tcp_info(&self) -> Option<TCP_INFO> { if self.is_inet() { get_tcp_info(self.raw_fd).ok() @@ -102,6 +132,16 @@ impl SocketDigest { } } + #[cfg(windows)] + pub fn tcp_info(&self) -> Option<TCP_INFO> { + if self.is_inet() { + get_tcp_info(self.raw_sock).ok() + } else { + None + } + } + + #[cfg(unix)] pub fn get_recv_buf(&self) -> Option<usize> { if self.is_inet() { get_recv_buf(self.raw_fd).ok() @@ -109,6 +149,15 @@ impl SocketDigest { None } } + + #[cfg(windows)] + pub fn get_recv_buf(&self) -> Option<usize> { + if self.is_inet() { + get_recv_buf(self.raw_sock).ok() + } else { + None + } + } } /// The interface to return timing information diff --git a/pingora-core/src/protocols/l4/ext.rs b/pingora-core/src/protocols/l4/ext.rs index 4b65f84..d3ab70d 100644 --- a/pingora-core/src/protocols/l4/ext.rs +++ b/pingora-core/src/protocols/l4/ext.rs @@ -16,6 +16,7 @@ #![allow(non_camel_case_types)] +#[cfg(unix)] use libc::socklen_t; #[cfg(target_os = "linux")] use libc::{c_int, c_ulonglong, c_void}; @@ -23,9 +24,14 @@ use pingora_error::{Error, ErrorType::*, OrErr, Result}; use std::io::{self, ErrorKind}; use std::mem; use std::net::SocketAddr; +#[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; use std::time::Duration; -use tokio::net::{TcpSocket, TcpStream, UnixStream}; +#[cfg(unix)] +use tokio::net::UnixStream; +use tokio::net::{TcpSocket, TcpStream}; use crate::connectors::l4::BindTo; @@ -98,9 +104,16 @@ impl TCP_INFO { } /// Return the size of [`TCP_INFO`] + #[cfg(unix)] pub fn len() -> socklen_t { mem::size_of::<Self>() as socklen_t } + + /// Return the size of [`TCP_INFO`] + #[cfg(windows)] + pub fn len() -> usize { + mem::size_of::<Self>() + } } #[cfg(target_os = "linux")] @@ -170,7 +183,7 @@ fn ip_bind_addr_no_port(fd: RawFd, val: bool) -> io::Result<()> { ) } -#[cfg(not(target_os = "linux"))] +#[cfg(all(unix, not(target_os = "linux")))] fn ip_bind_addr_no_port(_fd: RawFd, _val: bool) -> io::Result<()> { Ok(()) } @@ -233,11 +246,16 @@ fn set_keepalive(fd: RawFd, ka: &TcpKeepalive) -> io::Result<()> { set_so_keepalive_count(fd, ka.count) } -#[cfg(not(target_os = "linux"))] +#[cfg(all(unix, not(target_os = "linux")))] fn set_keepalive(_fd: RawFd, _ka: &TcpKeepalive) -> io::Result<()> { Ok(()) } +#[cfg(windows)] +fn set_keepalive(_sock: RawSocket, _ka: &TcpKeepalive) -> io::Result<()> { + Ok(()) +} + /// Get the kernel TCP_INFO for the given FD. #[cfg(target_os = "linux")] pub fn get_tcp_info(fd: RawFd) -> io::Result<TCP_INFO> { @@ -256,21 +274,31 @@ pub fn set_recv_buf(fd: RawFd, val: usize) -> Result<()> { .or_err(ConnectError, "failed to set SO_RCVBUF") } -#[cfg(not(target_os = "linux"))] +#[cfg(all(unix, not(target_os = "linux")))] pub fn set_recv_buf(_fd: RawFd, _: usize) -> Result<()> { Ok(()) } +#[cfg(windows)] +pub fn set_recv_buf(_sock: RawSocket, _: usize) -> Result<()> { + Ok(()) +} + #[cfg(target_os = "linux")] pub fn get_recv_buf(fd: RawFd) -> io::Result<usize> { get_opt_sized::<c_int>(fd, libc::SOL_SOCKET, libc::SO_RCVBUF).map(|v| v as usize) } -#[cfg(not(target_os = "linux"))] +#[cfg(all(unix, not(target_os = "linux")))] pub fn get_recv_buf(_fd: RawFd) -> io::Result<usize> { Ok(0) } +#[cfg(windows)] +pub fn get_recv_buf(_sock: RawSocket) -> io::Result<usize> { + Ok(0) +} + /// Enable client side TCP fast open. #[cfg(target_os = "linux")] pub fn set_tcp_fastopen_connect(fd: RawFd) -> Result<()> { @@ -283,11 +311,16 @@ pub fn set_tcp_fastopen_connect(fd: RawFd) -> Result<()> { .or_err(ConnectError, "failed to set TCP_FASTOPEN_CONNECT") } -#[cfg(not(target_os = "linux"))] +#[cfg(all(unix, not(target_os = "linux")))] pub fn set_tcp_fastopen_connect(_fd: RawFd) -> Result<()> { Ok(()) } +#[cfg(windows)] +pub fn set_tcp_fastopen_connect(_sock: RawSocket) -> Result<()> { + Ok(()) +} + /// Enable server side TCP fast open. #[cfg(target_os = "linux")] pub fn set_tcp_fastopen_backlog(fd: RawFd, backlog: usize) -> Result<()> { @@ -295,11 +328,16 @@ pub fn set_tcp_fastopen_backlog(fd: RawFd, backlog: usize) -> Result<()> { .or_err(ConnectError, "failed to set TCP_FASTOPEN") } -#[cfg(not(target_os = "linux"))] +#[cfg(all(unix, not(target_os = "linux")))] pub fn set_tcp_fastopen_backlog(_fd: RawFd, _backlog: usize) -> Result<()> { Ok(()) } +#[cfg(windows)] +pub fn set_tcp_fastopen_backlog(_sock: RawSocket, _backlog: usize) -> Result<()> { + Ok(()) +} + #[cfg(target_os = "linux")] pub fn set_dscp(fd: RawFd, value: u8) -> Result<()> { use super::socket::SocketAddr; @@ -320,17 +358,22 @@ pub fn set_dscp(fd: RawFd, value: u8) -> Result<()> { } } -#[cfg(not(target_os = "linux"))] +#[cfg(all(unix, not(target_os = "linux")))] pub fn set_dscp(_fd: RawFd, _value: u8) -> Result<()> { Ok(()) } +#[cfg(windows)] +pub fn set_dscp(_sock: RawSocket, _value: u8) -> Result<()> { + Ok(()) +} + #[cfg(target_os = "linux")] pub fn get_socket_cookie(fd: RawFd) -> io::Result<u64> { get_opt_sized::<c_ulonglong>(fd, libc::SOL_SOCKET, libc::SO_COOKIE) } -#[cfg(not(target_os = "linux"))] +#[cfg(all(unix, not(target_os = "linux")))] pub fn get_socket_cookie(_fd: RawFd) -> io::Result<u64> { Ok(0) // SO_COOKIE is a Linux concept } @@ -380,7 +423,8 @@ async fn inner_connect_with<F: FnOnce(&TcpSocket) -> Result<()>>( } .or_err(SocketError, "failed to create socket")?; - if cfg!(target_os = "linux") { + #[cfg(target_os = "linux")] + { ip_bind_addr_no_port(socket.as_raw_fd(), true).or_err( SocketError, "failed to set socket opts IP_BIND_ADDRESS_NO_PORT", @@ -399,6 +443,13 @@ async fn inner_connect_with<F: FnOnce(&TcpSocket) -> Result<()>>( } } } + + #[cfg(windows)] + if let Some(baddr) = bind_to { + socket + .bind(*baddr) + .or_err_with(BindError, || format!("failed to bind to socket {}", *baddr))?; + }; // TODO: add support for bind on other platforms set_socket(&socket)?; @@ -418,6 +469,7 @@ pub async fn connect(addr: &SocketAddr, bind_to: Option<&BindTo>) -> Result<TcpS } /// connect() to the given Unix domain socket +#[cfg(unix)] pub async fn connect_uds(path: &std::path::Path) -> Result<UnixStream> { UnixStream::connect(path) .await @@ -460,9 +512,12 @@ impl std::fmt::Display for TcpKeepalive { /// Apply the given TCP keepalive settings to the given connection pub fn set_tcp_keepalive(stream: &TcpStream, ka: &TcpKeepalive) -> Result<()> { - let fd = stream.as_raw_fd(); + #[cfg(unix)] + let raw = stream.as_raw_fd(); + #[cfg(windows)] + let raw = stream.as_raw_socket(); // TODO: check localhost or if keepalive is already set - set_keepalive(fd, ka).or_err(ConnectError, "failed to set keepalive") + set_keepalive(raw, ka).or_err(ConnectError, "failed to set keepalive") } #[cfg(test)] @@ -473,7 +528,10 @@ mod test { fn test_set_recv_buf() { use tokio::net::TcpSocket; let socket = TcpSocket::new_v4().unwrap(); + #[cfg(unix)] set_recv_buf(socket.as_raw_fd(), 102400).unwrap(); + #[cfg(windows)] + set_recv_buf(socket.as_raw_socket(), 102400).unwrap(); #[cfg(target_os = "linux")] { diff --git a/pingora-core/src/protocols/l4/listener.rs b/pingora-core/src/protocols/l4/listener.rs index 29cc9e9..d62f7f0 100644 --- a/pingora-core/src/protocols/l4/listener.rs +++ b/pingora-core/src/protocols/l4/listener.rs @@ -15,8 +15,13 @@ //! Listeners use std::io; +#[cfg(unix)] use std::os::unix::io::AsRawFd; -use tokio::net::{TcpListener, UnixListener}; +#[cfg(windows)] +use std::os::windows::io::AsRawSocket; +use tokio::net::TcpListener; +#[cfg(unix)] +use tokio::net::UnixListener; use crate::protocols::digest::{GetSocketDigest, SocketDigest}; use crate::protocols::l4::stream::Stream; @@ -25,6 +30,7 @@ use crate::protocols::l4::stream::Stream; #[derive(Debug)] pub enum Listener { Tcp(TcpListener), + #[cfg(unix)] Unix(UnixListener), } @@ -34,12 +40,14 @@ impl From<TcpListener> for Listener { } } +#[cfg(unix)] impl From<UnixListener> for Listener { fn from(s: UnixListener) -> Self { Self::Unix(s) } } +#[cfg(unix)] impl AsRawFd for Listener { fn as_raw_fd(&self) -> std::os::unix::io::RawFd { match &self { @@ -49,13 +57,25 @@ impl AsRawFd for Listener { } } +#[cfg(windows)] +impl AsRawSocket for Listener { + fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + match &self { + Self::Tcp(l) => l.as_raw_socket(), + } + } +} + impl Listener { /// Accept a connection from the listening endpoint pub async fn accept(&self) -> io::Result<Stream> { match &self { Self::Tcp(l) => l.accept().await.map(|(stream, peer_addr)| { let mut s: Stream = stream.into(); + #[cfg(unix)] let digest = SocketDigest::from_raw_fd(s.as_raw_fd()); + #[cfg(windows)] + let digest = SocketDigest::from_raw_socket(s.as_raw_socket()); digest .peer_addr .set(Some(peer_addr.into())) @@ -66,6 +86,7 @@ impl Listener { // and init it in the socket digest here s }), + #[cfg(unix)] Self::Unix(l) => l.accept().await.map(|(stream, peer_addr)| { let mut s: Stream = stream.into(); let digest = SocketDigest::from_raw_fd(s.as_raw_fd()); diff --git a/pingora-core/src/protocols/l4/socket.rs b/pingora-core/src/protocols/l4/socket.rs index 93d334e..64e0e07 100644 --- a/pingora-core/src/protocols/l4/socket.rs +++ b/pingora-core/src/protocols/l4/socket.rs @@ -16,11 +16,14 @@ use crate::{Error, OrErr}; use log::warn; +#[cfg(unix)] use nix::sys::socket::{getpeername, getsockname, SockaddrStorage}; use std::cmp::Ordering; use std::hash::{Hash, Hasher}; use std::net::SocketAddr as StdSockAddr; +#[cfg(unix)] use std::os::unix::net::SocketAddr as StdUnixSockAddr; +#[cfg(unix)] use tokio::net::unix::SocketAddr as TokioUnixSockAddr; /// [`SocketAddr`] is a storage type that contains either a Internet (IP address) @@ -28,6 +31,7 @@ use tokio::net::unix::SocketAddr as TokioUnixSockAddr; #[derive(Debug, Clone)] pub enum SocketAddr { Inet(StdSockAddr), + #[cfg(unix)] Unix(StdUnixSockAddr), } @@ -42,6 +46,7 @@ impl SocketAddr { } /// Get a reference to the Unix domain socket if it is one + #[cfg(unix)] pub fn as_unix(&self) -> Option<&StdUnixSockAddr> { if let SocketAddr::Unix(addr) = self { Some(addr) @@ -57,6 +62,7 @@ impl SocketAddr { } } + #[cfg(unix)] fn from_sockaddr_storage(sock: &SockaddrStorage) -> Option<SocketAddr> { if let Some(v4) = sock.as_sockaddr_in() { return Some(SocketAddr::Inet(StdSockAddr::V4( @@ -77,6 +83,7 @@ impl SocketAddr { )) } + #[cfg(unix)] pub fn from_raw_fd(fd: std::os::unix::io::RawFd, peer_addr: bool) -> Option<SocketAddr> { let sockaddr_storage = if peer_addr { getpeername(fd) @@ -90,12 +97,28 @@ impl SocketAddr { Err(_e) => None, } } + + #[cfg(windows)] + pub fn from_raw_socket( + sock: std::os::windows::io::RawSocket, + is_peer_addr: bool, + ) -> Option<SocketAddr> { + use crate::protocols::windows::{local_addr, peer_addr}; + if is_peer_addr { + peer_addr(sock) + } else { + local_addr(sock) + } + .map(|s| s.into()) + .ok() + } } impl std::fmt::Display for SocketAddr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { SocketAddr::Inet(addr) => write!(f, "{addr}"), + #[cfg(unix)] SocketAddr::Unix(addr) => { if let Some(path) = addr.as_pathname() { write!(f, "{}", path.display()) @@ -111,6 +134,7 @@ impl Hash for SocketAddr { fn hash<H: Hasher>(&self, state: &mut H) { match self { Self::Inet(sockaddr) => sockaddr.hash(state), + #[cfg(unix)] Self::Unix(sockaddr) => { if let Some(path) = sockaddr.as_pathname() { // use the underlying path as the hash @@ -130,6 +154,7 @@ impl PartialEq for SocketAddr { fn eq(&self, other: &Self) -> bool { match self { Self::Inet(addr) => Some(addr) == other.as_inet(), + #[cfg(unix)] Self::Unix(addr) => { let path = addr.as_pathname(); // can only compare UDS with path, assume false on all unnamed UDS @@ -156,6 +181,7 @@ impl Ord for SocketAddr { Ordering::Less } } + #[cfg(unix)] Self::Unix(addr) => { if let Some(o) = other.as_unix() { // NOTE: unnamed UDS are consider the same @@ -175,6 +201,7 @@ impl std::str::FromStr for SocketAddr { type Err = Box<Error>; // This is very basic parsing logic, it might treat invalid IP:PORT str as UDS path + #[cfg(unix)] fn from_str(s: &str) -> Result<Self, Self::Err> { if s.starts_with("unix:") { // format unix:/tmp/server.socket @@ -195,6 +222,12 @@ impl std::str::FromStr for SocketAddr { } } } + + #[cfg(windows)] + fn from_str(s: &str) -> Result<Self, Self::Err> { + let addr = StdSockAddr::from_str(s).or_err(crate::BindError, "invalid socket addr")?; + Ok(SocketAddr::Inet(addr)) + } } impl std::net::ToSocketAddrs for SocketAddr { @@ -213,12 +246,14 @@ impl std::net::ToSocketAddrs for SocketAddr { } } +#[cfg(unix)] impl From<StdSockAddr> for SocketAddr { fn from(sockaddr: StdSockAddr) -> Self { SocketAddr::Inet(sockaddr) } } +#[cfg(unix)] impl From<StdUnixSockAddr> for SocketAddr { fn from(sockaddr: StdUnixSockAddr) -> Self { SocketAddr::Unix(sockaddr) @@ -228,6 +263,7 @@ impl From<StdUnixSockAddr> for SocketAddr { // TODO: ideally mio/tokio will start using the std version of the unix `SocketAddr` // so we can avoid a fallible conversion // https://github.com/tokio-rs/mio/issues/1527 +#[cfg(unix)] impl TryFrom<TokioUnixSockAddr> for SocketAddr { type Error = String; @@ -251,12 +287,14 @@ mod test { assert!(ip.as_inet().is_some()); } + #[cfg(unix)] #[test] fn parse_uds() { let uds: SocketAddr = "/tmp/my.sock".parse().unwrap(); assert!(uds.as_unix().is_some()); } + #[cfg(unix)] #[test] fn parse_uds_with_prefix() { let uds: SocketAddr = "unix:/tmp/my.sock".parse().unwrap(); diff --git a/pingora-core/src/protocols/l4/stream.rs b/pingora-core/src/protocols/l4/stream.rs index 8c5f3cf..898b3e6 100644 --- a/pingora-core/src/protocols/l4/stream.rs +++ b/pingora-core/src/protocols/l4/stream.rs @@ -20,13 +20,18 @@ use log::{debug, error}; use pingora_error::{ErrorType::*, OrErr, Result}; use std::io::IoSliceMut; +#[cfg(unix)] use std::os::unix::io::AsRawFd; +#[cfg(windows)] +use std::os::windows::io::AsRawSocket; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::{Duration, Instant, SystemTime}; use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, Interest, ReadBuf}; -use tokio::net::{TcpStream, UnixStream}; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive}; use crate::protocols::raw_connect::ProxyDigest; @@ -39,6 +44,7 @@ use crate::upstreams::peer::Tracer; #[derive(Debug)] enum RawStream { Tcp(TcpStream), + #[cfg(unix)] Unix(UnixStream), } @@ -52,6 +58,7 @@ impl AsyncRead for RawStream { unsafe { match &mut Pin::get_unchecked_mut(self) { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf), + #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf), } } @@ -64,6 +71,7 @@ impl AsyncWrite for RawStream { unsafe { match &mut Pin::get_unchecked_mut(self) { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf), + #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf), } } @@ -74,6 +82,7 @@ impl AsyncWrite for RawStream { unsafe { match &mut Pin::get_unchecked_mut(self) { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx), + #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx), } } @@ -84,6 +93,7 @@ impl AsyncWrite for RawStream { unsafe { match &mut Pin::get_unchecked_mut(self) { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx), + #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx), } } @@ -98,6 +108,7 @@ impl AsyncWrite for RawStream { unsafe { match &mut Pin::get_unchecked_mut(self) { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), + #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), } } @@ -106,11 +117,13 @@ impl AsyncWrite for RawStream { fn is_write_vectored(&self) -> bool { match self { RawStream::Tcp(s) => s.is_write_vectored(), + #[cfg(unix)] RawStream::Unix(s) => s.is_write_vectored(), } } } +#[cfg(unix)] impl AsRawFd for RawStream { fn as_raw_fd(&self) -> std::os::unix::io::RawFd { match self { @@ -120,6 +133,15 @@ impl AsRawFd for RawStream { } } +#[cfg(windows)] +impl AsRawSocket for RawStream { + fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + match self { + RawStream::Tcp(s) => s.as_raw_socket(), + } + } +} + #[derive(Debug)] struct RawStreamWrapper { pub(crate) stream: RawStream, @@ -162,6 +184,7 @@ impl AsyncRead for RawStreamWrapper { let rs_wrapper = Pin::get_unchecked_mut(self); match &mut rs_wrapper.stream { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf), + #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf), } } @@ -245,6 +268,7 @@ impl AsyncWrite for RawStreamWrapper { unsafe { match &mut Pin::get_unchecked_mut(self).stream { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf), + #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf), } } @@ -255,6 +279,7 @@ impl AsyncWrite for RawStreamWrapper { unsafe { match &mut Pin::get_unchecked_mut(self).stream { RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx), + #[cfg(unix)] RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx), } } @@ -289,12 +314,20 @@ impl AsyncWrite for RawStreamWrapper { } } +#[cfg(unix)] impl AsRawFd for RawStreamWrapper { fn as_raw_fd(&self) -> std::os::unix::io::RawFd { self.stream.as_raw_fd() } } +#[cfg(windows)] +impl AsRawSocket for RawStreamWrapper { + fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + self.stream.as_raw_socket() + } +} + // Large read buffering helps reducing syscalls with little trade-off // Ssl layer always does "small" reads in 16k (TLS record size) so L4 read buffer helps a lot. const BUF_READ_SIZE: usize = 64 * 1024; @@ -384,6 +417,7 @@ impl From<TcpStream> for Stream { } } +#[cfg(unix)] impl From<UnixStream> for Stream { fn from(s: UnixStream) -> Self { Stream { @@ -404,18 +438,34 @@ impl From<UnixStream> for Stream { } } +#[cfg(unix)] impl AsRawFd for Stream { fn as_raw_fd(&self) -> std::os::unix::io::RawFd { self.stream.get_ref().as_raw_fd() } } +#[cfg(windows)] +impl AsRawSocket for Stream { + fn as_raw_socket(&self) -> std::os::windows::io::RawSocket { + self.stream.get_ref().as_raw_socket() + } +} + +#[cfg(unix)] impl UniqueID for Stream { fn id(&self) -> UniqueIDType { self.as_raw_fd() } } +#[cfg(windows)] +impl UniqueID for Stream { + fn id(&self) -> usize { + self.as_raw_socket() as usize + } +} + impl Ssl for Stream {} #[async_trait] @@ -473,6 +523,7 @@ impl Drop for Stream { /* use nodelay/local_addr function to detect socket status */ let ret = match &self.stream.get_ref().stream { RawStream::Tcp(s) => s.nodelay().err(), + #[cfg(unix)] RawStream::Unix(s) => s.local_addr().err(), }; if let Some(e) = ret { diff --git a/pingora-core/src/protocols/mod.rs b/pingora-core/src/protocols/mod.rs index 4c1aa88..7105d61 100644 --- a/pingora-core/src/protocols/mod.rs +++ b/pingora-core/src/protocols/mod.rs @@ -19,6 +19,8 @@ pub mod http; pub mod l4; pub mod raw_connect; pub mod tls; +#[cfg(windows)] +mod windows; pub use digest::{ Digest, GetProxyDigest, GetSocketDigest, GetTimingDigest, ProtoDigest, SocketDigest, @@ -309,3 +311,33 @@ impl ConnFdReusable for InetSocketAddr { } } } + +#[cfg(windows)] +impl ConnSockReusable for InetSocketAddr { + fn check_sock_match<V: AsRawSocket>(&self, sock: V) -> bool { + let sock = sock.as_raw_socket(); + match windows::peer_addr(sock) { + Ok(peer) => { + const ZERO: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); + if self.ip() == ZERO { + // https://www.rfc-editor.org/rfc/rfc1122.html#section-3.2.1.3 + // 0.0.0.0 should only be used as source IP not destination + // However in some systems this destination IP is mapped to 127.0.0.1. + // We just skip this check here to avoid false positive mismatch. + return true; + } + if self == &peer { + debug!("Inet FD to: {self} is reusable"); + true + } else { + error!("Crit: FD mismatch: fd: {sock:?}, addr: {self}, peer: {peer}",); + false + } + } + Err(e) => { + debug!("Idle connection is broken: {e:?}"); + false + } + } + } +} diff --git a/pingora-core/src/protocols/windows.rs b/pingora-core/src/protocols/windows.rs new file mode 100644 index 0000000..07e8f5b --- /dev/null +++ b/pingora-core/src/protocols/windows.rs @@ -0,0 +1,129 @@ +// Copyright 2024 Cloudflare, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Windows specific functionality for calling the WinSock c api +//! +//! Implementations here are based on the implementation in the std library +//! https://github.com/rust-lang/rust/blob/84ac80f/library/std/src/sys_common/net.rs +//! https://github.com/rust-lang/rust/blob/84ac80f/library/std/src/sys/pal/windows/net.rs + +use std::os::windows::io::RawSocket; +use std::{io, mem, net::SocketAddr}; + +use windows_sys::Win32::Networking::WinSock::{ + getpeername, getsockname, AF_INET, AF_INET6, SOCKADDR_IN, SOCKADDR_IN6, SOCKADDR_STORAGE, + SOCKET, +}; + +pub(crate) fn peer_addr(raw_sock: RawSocket) -> io::Result<SocketAddr> { + let mut storage = unsafe { mem::zeroed::<SOCKADDR_STORAGE>() }; + let mut addrlen = mem::size_of_val(&storage) as i32; + + unsafe { + let res = getpeername( + raw_sock as SOCKET, + core::ptr::addr_of_mut!(storage) as *mut _, + &mut addrlen, + ); + if res != 0 { + return Err(io::Error::last_os_error()); + } + } + + sockaddr_to_addr(&storage, addrlen as usize) +} +pub(crate) fn local_addr(raw_sock: RawSocket) -> io::Result<SocketAddr> { + let mut storage = unsafe { mem::zeroed::<SOCKADDR_STORAGE>() }; + let mut addrlen = mem::size_of_val(&storage) as i32; + + unsafe { + let res = getsockname( + raw_sock as libc::SOCKET, + core::ptr::addr_of_mut!(storage) as *mut _, + &mut addrlen, + ); + if res != 0 { + return Err(io::Error::last_os_error()); + } + } + + sockaddr_to_addr(&storage, addrlen as usize) +} + +fn sockaddr_to_addr(storage: &SOCKADDR_STORAGE, len: usize) -> io::Result<SocketAddr> { + match storage.ss_family { + AF_INET => { + assert!(len >= mem::size_of::<SOCKADDR_IN>()); + Ok(SocketAddr::from(unsafe { + let sockaddr = *(storage as *const _ as *const SOCKADDR_IN); + ( + sockaddr.sin_addr.S_un.S_addr.to_ne_bytes(), + sockaddr.sin_port.to_be(), + ) + })) + } + AF_INET6 => { + assert!(len >= mem::size_of::<SOCKADDR_IN6>()); + Ok(SocketAddr::from(unsafe { + let sockaddr = *(storage as *const _ as *const SOCKADDR_IN6); + (sockaddr.sin6_addr.u.Byte, sockaddr.sin6_port.to_be()) + })) + } + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid argument", + )), + } +} + +#[cfg(test)] +mod tests { + use std::os::windows::io::AsRawSocket; + + use crate::protocols::l4::{listener::Listener, stream::Stream}; + + use super::*; + + async fn assert_listener_and_stream(addr: &str) { + let tokio_listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + + let listener_local_addr = tokio_listener.local_addr().unwrap(); + + let tokio_stream = tokio::net::TcpStream::connect(listener_local_addr) + .await + .unwrap(); + + let stream_local_addr = tokio_stream.local_addr().unwrap(); + let stream_peer_addr = tokio_stream.peer_addr().unwrap(); + + let stream: Stream = tokio_stream.into(); + let listener: Listener = tokio_listener.into(); + + let raw_sock = listener.as_raw_socket(); + assert_eq!(listener_local_addr, local_addr(raw_sock).unwrap()); + + let raw_sock = stream.as_raw_socket(); + assert_eq!(stream_peer_addr, peer_addr(raw_sock).unwrap()); + assert_eq!(stream_local_addr, local_addr(raw_sock).unwrap()); + } + + #[tokio::test] + async fn get_v4_addrs_from_raw_socket() { + assert_listener_and_stream("127.0.0.1:0").await + } + #[tokio::test] + async fn get_v6_addrs_from_raw_socket() { + assert_listener_and_stream("[::1]:0").await + } +} diff --git a/pingora-core/src/server/daemon.rs b/pingora-core/src/server/daemon.rs index a026cec..8d6ea3b 100644 --- a/pingora-core/src/server/daemon.rs +++ b/pingora-core/src/server/daemon.rs @@ -54,6 +54,7 @@ unsafe fn gid_for_username(name: &CString) -> Option<libc::gid_t> { } /// Start a server instance as a daemon. +#[cfg(unix)] pub fn daemonize(conf: &ServerConf) { // TODO: customize working dir diff --git a/pingora-core/src/server/mod.rs b/pingora-core/src/server/mod.rs index 0a2977d..d28b00c 100644 --- a/pingora-core/src/server/mod.rs +++ b/pingora-core/src/server/mod.rs @@ -15,9 +15,12 @@ //! Server process and configuration management pub mod configuration; +#[cfg(unix)] mod daemon; +#[cfg(unix)] pub(crate) mod transfer_fd; +#[cfg(unix)] use daemon::daemonize; use log::{debug, error, info, warn}; use pingora_runtime::Runtime; @@ -26,12 +29,14 @@ use pingora_timeout::fast_timeout; use sentry::ClientOptions; use std::sync::Arc; use std::thread; +#[cfg(unix)] use tokio::signal::unix; use tokio::sync::{watch, Mutex}; use tokio::time::{sleep, Duration}; use crate::services::Service; use configuration::{Opt, ServerConf}; +#[cfg(unix)] pub use transfer_fd::Fds; use pingora_error::{Error, ErrorType, Result}; @@ -51,6 +56,7 @@ enum ShutdownType { /// The receiver for server's shutdown event. The value will turn to true once the server starts /// to shutdown pub type ShutdownWatch = watch::Receiver<bool>; +#[cfg(unix)] pub type ListenFds = Arc<Mutex<Fds>>; /// The server object @@ -60,6 +66,7 @@ pub type ListenFds = Arc<Mutex<Fds>>; /// zero downtime upgrade and error reporting. pub struct Server { services: Vec<Box<dyn Service>>, + #[cfg(unix)] listen_fds: Option<ListenFds>, shutdown_watch: watch::Sender<bool>, // TODO: we many want to drop this copy to let sender call closed() @@ -78,6 +85,7 @@ pub struct Server { // TODO: delete the pid when exit impl Server { + #[cfg(unix)] async fn main_loop(&self) -> ShutdownType { // waiting for exit signal // TODO: there should be a signal handling function @@ -145,7 +153,7 @@ impl Server { fn run_service( mut service: Box<dyn Service>, - fds: Option<ListenFds>, + #[cfg(unix)] fds: Option<ListenFds>, shutdown: ShutdownWatch, threads: usize, work_stealing: bool, @@ -155,12 +163,19 @@ impl Server { { let service_runtime = Server::create_runtime(service.name(), threads, work_stealing); service_runtime.get_handle().spawn(async move { - service.start_service(fds, shutdown).await; + service + .start_service( + #[cfg(unix)] + fds, + shutdown, + ) + .await; info!("service exited.") }); service_runtime } + #[cfg(unix)] fn load_fds(&mut self, upgrade: bool) -> Result<(), nix::Error> { let mut fds = Fds::new(); if upgrade { @@ -188,6 +203,7 @@ impl Server { Server { services: vec![], + #[cfg(unix)] listen_fds: None, shutdown_watch: tx, shutdown_recv: rx, @@ -269,6 +285,7 @@ impl Server { } // load fds + #[cfg(unix)] match self.load_fds(self.options.as_ref().map_or(false, |o| o.upgrade)) { Ok(_) => { info!("Bootstrap done"); @@ -296,6 +313,7 @@ impl Server { let conf = self.configuration.as_ref(); + #[cfg(unix)] if conf.daemon { info!("Daemonizing the server"); fast_timeout::pause_for_fork(); @@ -303,6 +321,11 @@ impl Server { fast_timeout::unpause(); } + #[cfg(windows)] + if conf.daemon { + panic!("Daemonizing under windows is not supported"); + } + /* only init sentry in release builds */ #[cfg(not(debug_assertions))] let _guard = self.sentry.as_ref().map(|opts| sentry::init(opts.clone())); @@ -313,6 +336,7 @@ impl Server { let threads = service.threads().unwrap_or(conf.threads); let runtime = Server::run_service( service, + #[cfg(unix)] self.listen_fds.clone(), self.shutdown_recv.clone(), threads, @@ -324,7 +348,10 @@ impl Server { // blocked on main loop so that it runs forever // Only work steal runtime can use block_on() let server_runtime = Server::create_runtime("Server", 1, true); + #[cfg(unix)] let shutdown_type = server_runtime.get_handle().block_on(self.main_loop()); + #[cfg(windows)] + let shutdown_type = ShutdownType::Graceful; if matches!(shutdown_type, ShutdownType::Graceful) { let exit_timeout = self diff --git a/pingora-core/src/services/background.rs b/pingora-core/src/services/background.rs index 2b84532..2ce618f 100644 --- a/pingora-core/src/services/background.rs +++ b/pingora-core/src/services/background.rs @@ -23,7 +23,9 @@ use async_trait::async_trait; use std::sync::Arc; use super::Service; -use crate::server::{ListenFds, ShutdownWatch}; +#[cfg(unix)] +use crate::server::ListenFds; +use crate::server::ShutdownWatch; /// The background service interface #[async_trait] @@ -65,7 +67,11 @@ impl<A> Service for GenBackgroundService<A> where A: BackgroundService + Send + Sync + 'static, { - async fn start_service(&mut self, _fds: Option<ListenFds>, shutdown: ShutdownWatch) { + async fn start_service( + &mut self, + #[cfg(unix)] _fds: Option<ListenFds>, + shutdown: ShutdownWatch, + ) { self.task.start(shutdown).await; } diff --git a/pingora-core/src/services/listening.rs b/pingora-core/src/services/listening.rs index ea81799..b1b2800 100644 --- a/pingora-core/src/services/listening.rs +++ b/pingora-core/src/services/listening.rs @@ -21,7 +21,9 @@ use crate::apps::ServerApp; use crate::listeners::{Listeners, ServerAddress, TcpSocketOptions, TlsSettings, TransportStack}; use crate::protocols::Stream; -use crate::server::{ListenFds, ShutdownWatch}; +#[cfg(unix)] +use crate::server::ListenFds; +use crate::server::ShutdownWatch; use crate::services::Service as ServiceTrait; use async_trait::async_trait; @@ -83,6 +85,7 @@ impl<A> Service<A> { /// /// Optionally take a permission of the socket file. The default is read and write access for /// everyone (0o666). + #[cfg(unix)] pub fn add_uds(&mut self, addr: &str, perm: Option<Permissions>) { self.listeners.add_uds(addr, perm); } @@ -201,9 +204,16 @@ impl<A: ServerApp + Send + Sync + 'static> Service<A> { #[async_trait] impl<A: ServerApp + Send + Sync + 'static> ServiceTrait for Service<A> { - async fn start_service(&mut self, fds: Option<ListenFds>, shutdown: ShutdownWatch) { + async fn start_service( + &mut self, + #[cfg(unix)] fds: Option<ListenFds>, + shutdown: ShutdownWatch, + ) { let runtime = current_handle(); - let endpoints = self.listeners.build(fds); + let endpoints = self.listeners.build( + #[cfg(unix)] + fds, + ); let app_logic = self .app_logic .take() diff --git a/pingora-core/src/services/mod.rs b/pingora-core/src/services/mod.rs index 67e72dc..f708a61 100644 --- a/pingora-core/src/services/mod.rs +++ b/pingora-core/src/services/mod.rs @@ -23,7 +23,9 @@ use async_trait::async_trait; -use crate::server::{ListenFds, ShutdownWatch}; +#[cfg(unix)] +use crate::server::ListenFds; +use crate::server::ShutdownWatch; pub mod background; pub mod listening; @@ -33,13 +35,17 @@ pub mod listening; pub trait Service: Sync + Send { /// This function will be called when the server is ready to start the service. /// - /// - `fds`: a collection of listening file descriptors. During zero downtime restart + /// - `fds` (Unix only): a collection of listening file descriptors. During zero downtime restart /// the `fds` would contain the listening sockets passed from the old service, services should /// take the sockets they need to use then. If the sockets the service looks for don't appear in /// the collection, the service should create its own listening sockets and then put them into /// the collection in order for them to be passed to the next server. /// - `shutdown`: the shutdown signal this server would receive. - async fn start_service(&mut self, fds: Option<ListenFds>, mut shutdown: ShutdownWatch); + async fn start_service( + &mut self, + #[cfg(unix)] fds: Option<ListenFds>, + mut shutdown: ShutdownWatch, + ); /// The name of the service, just for logging and naming the threads assigned to this service /// diff --git a/pingora-core/src/upstreams/peer.rs b/pingora-core/src/upstreams/peer.rs index 7b80857..4030667 100644 --- a/pingora-core/src/upstreams/peer.rs +++ b/pingora-core/src/upstreams/peer.rs @@ -23,14 +23,17 @@ use std::collections::BTreeMap; use std::fmt::{Display, Formatter, Result as FmtResult}; use std::hash::{Hash, Hasher}; use std::net::{IpAddr, SocketAddr as InetSocketAddr, ToSocketAddrs as ToInetSocketAddrs}; -use std::os::unix::net::SocketAddr as UnixSocketAddr; -use std::os::unix::prelude::AsRawFd; +#[cfg(unix)] +use std::os::unix::{net::SocketAddr as UnixSocketAddr, prelude::AsRawFd}; +#[cfg(windows)] +use std::os::windows::io::AsRawSocket; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::Duration; use crate::connectors::{l4::BindTo, L4Connect}; use crate::protocols::l4::socket::SocketAddr; +#[cfg(unix)] use crate::protocols::ConnFdReusable; use crate::protocols::TcpKeepalive; use crate::tls::x509::X509; @@ -186,10 +189,17 @@ pub trait Peer: Display + Clone { .unwrap_or_default() } + #[cfg(unix)] fn matches_fd<V: AsRawFd>(&self, fd: V) -> bool { self.address().check_fd_match(fd) } + #[cfg(windows)] + fn matches_sock<V: AsRawSocket>(&self, sock: V) -> bool { + use crate::protocols::ConnSockReusable; + self.address().check_sock_match(sock) + } + fn get_tracer(&self) -> Option<Tracer> { None } @@ -211,6 +221,7 @@ impl BasicPeer { } /// Create a new [`BasicPeer`] with the given path to a Unix domain socket. + #[cfg(unix)] pub fn new_uds<P: AsRef<Path>>(path: P) -> Result<Self> { let addr = SocketAddr::Unix( UnixSocketAddr::from_pathname(path.as_ref()) @@ -445,6 +456,7 @@ impl HttpPeer { } /// Create a new [`HttpPeer`] with the given path to Unix domain socket and TLS settings. + #[cfg(unix)] pub fn new_uds(path: &str, tls: bool, sni: String) -> Result<Self> { let addr = SocketAddr::Unix( UnixSocketAddr::from_pathname(Path::new(path)).or_err(SocketError, "invalid path")?, @@ -547,6 +559,7 @@ impl Peer for HttpPeer { self.proxy.as_ref() } + #[cfg(unix)] fn matches_fd<V: AsRawFd>(&self, fd: V) -> bool { if let Some(proxy) = self.get_proxy() { proxy.next_hop.check_fd_match(fd) @@ -555,6 +568,17 @@ impl Peer for HttpPeer { } } + #[cfg(windows)] + fn matches_sock<V: AsRawSocket>(&self, sock: V) -> bool { + use crate::protocols::ConnSockReusable; + + if let Some(proxy) = self.get_proxy() { + panic!("windows do not support peers with proxy") + } else { + self.address().check_sock_match(sock) + } + } + fn get_client_cert_key(&self) -> Option<&Arc<CertKey>> { self.client_cert_key.as_ref() } diff --git a/pingora-proxy/src/proxy_h1.rs b/pingora-proxy/src/proxy_h1.rs index 1dbb561..15ce038 100644 --- a/pingora-proxy/src/proxy_h1.rs +++ b/pingora-proxy/src/proxy_h1.rs @@ -116,13 +116,18 @@ impl<SV> HttpProxy<SV> { SV: ProxyHttp + Send + Sync, SV::CTX: Send + Sync, { + #[cfg(windows)] + let raw = client_session.id() as std::os::windows::io::RawSocket; + #[cfg(unix)] + let raw = client_session.id(); + if let Err(e) = self .inner .connected_to_upstream( session, reused, peer, - client_session.id(), + raw, Some(client_session.digest()), ctx, ) diff --git a/pingora-proxy/src/proxy_h2.rs b/pingora-proxy/src/proxy_h2.rs index 0750160..f133c2f 100644 --- a/pingora-proxy/src/proxy_h2.rs +++ b/pingora-proxy/src/proxy_h2.rs @@ -193,16 +193,14 @@ impl<SV> HttpProxy<SV> { SV: ProxyHttp + Send + Sync, SV::CTX: Send + Sync, { + #[cfg(windows)] + let raw = client_session.fd() as std::os::windows::io::RawSocket; + #[cfg(unix)] + let raw = client_session.fd(); + if let Err(e) = self .inner - .connected_to_upstream( - session, - reused, - peer, - client_session.fd(), - client_session.digest(), - ctx, - ) + .connected_to_upstream(session, reused, peer, raw, client_session.digest(), ctx) .await { return (false, Some(e)); diff --git a/pingora-proxy/src/proxy_trait.rs b/pingora-proxy/src/proxy_trait.rs index 029ea58..4f3f627 100644 --- a/pingora-proxy/src/proxy_trait.rs +++ b/pingora-proxy/src/proxy_trait.rs @@ -425,7 +425,8 @@ pub trait ProxyHttp { _session: &mut Session, _reused: bool, _peer: &HttpPeer, - _fd: std::os::unix::io::RawFd, + #[cfg(unix)] _fd: std::os::unix::io::RawFd, + #[cfg(windows)] _sock: std::os::windows::io::RawSocket, _digest: Option<&Digest>, _ctx: &mut Self::CTX, ) -> Result<()> diff --git a/pingora-proxy/src/subrequest.rs b/pingora-proxy/src/subrequest.rs index e75dcc6..f9367b7 100644 --- a/pingora-proxy/src/subrequest.rs +++ b/pingora-proxy/src/subrequest.rs @@ -19,6 +19,7 @@ use pingora_cache::lock::WritePermit; use pingora_core::protocols::raw_connect::ProxyDigest; use pingora_core::protocols::{ GetProxyDigest, GetSocketDigest, GetTimingDigest, SocketDigest, Ssl, TimingDigest, UniqueID, + UniqueIDType, }; use std::io::Cursor; use std::sync::Arc; @@ -68,7 +69,7 @@ impl AsyncWrite for DummyIO { } impl UniqueID for DummyIO { - fn id(&self) -> i32 { + fn id(&self) -> UniqueIDType { 0 // placeholder } } diff --git a/pingora-proxy/tests/utils/server_utils.rs b/pingora-proxy/tests/utils/server_utils.rs index f90a27e..885fcb1 100644 --- a/pingora-proxy/tests/utils/server_utils.rs +++ b/pingora-proxy/tests/utils/server_utils.rs @@ -192,7 +192,8 @@ impl ProxyHttp for ExampleProxyHttps { _http_session: &mut Session, reused: bool, _peer: &HttpPeer, - _fd: std::os::unix::io::RawFd, + #[cfg(unix)] _fd: std::os::unix::io::RawFd, + #[cfg(windows)] _sock: std::os::windows::io::RawSocket, digest: Option<&Digest>, ctx: &mut CTX, ) -> Result<()> { @@ -311,7 +312,8 @@ impl ProxyHttp for ExampleProxyHttp { _http_session: &mut Session, reused: bool, _peer: &HttpPeer, - _fd: std::os::unix::io::RawFd, + #[cfg(unix)] _fd: std::os::unix::io::RawFd, + #[cfg(windows)] _sock: std::os::windows::io::RawSocket, digest: Option<&Digest>, ctx: &mut CTX, ) -> Result<()> { @@ -544,6 +546,7 @@ fn test_main() { let mut proxy_service_http = pingora_proxy::http_proxy_service(&my_server.configuration, ExampleProxyHttp {}); proxy_service_http.add_tcp("0.0.0.0:6147"); + #[cfg(unix)] proxy_service_http.add_uds("/tmp/pingora_proxy.sock", None); let mut proxy_service_h2c = |