diff options
-rw-r--r-- | .bleep | 2 | ||||
-rw-r--r-- | pingora-core/src/apps/mod.rs | 27 | ||||
-rw-r--r-- | pingora-core/src/protocols/l4/stream.rs | 101 | ||||
-rw-r--r-- | pingora-core/src/protocols/mod.rs | 18 | ||||
-rw-r--r-- | pingora-core/src/protocols/tls/mod.rs | 5 | ||||
-rw-r--r-- | pingora-proxy/src/subrequest.rs | 6 | ||||
-rw-r--r-- | pingora-proxy/tests/test_basic.rs | 22 |
7 files changed, 171 insertions, 10 deletions
@@ -1 +1 @@ -9d8254e966aca99eeb6623f2bcf5fc93facbb05a
\ No newline at end of file +dc97a9520b124eb464f348b0381991d8669c8d8a
\ No newline at end of file diff --git a/pingora-core/src/apps/mod.rs b/pingora-core/src/apps/mod.rs index 32fd82f..0786d08 100644 --- a/pingora-core/src/apps/mod.rs +++ b/pingora-core/src/apps/mod.rs @@ -28,6 +28,9 @@ use crate::protocols::Digest; use crate::protocols::Stream; use crate::protocols::ALPN; +// https://datatracker.ietf.org/doc/html/rfc9113#section-3.4 +const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + #[async_trait] /// This trait defines the interface of a transport layer (TCP or TLS) application. pub trait ServerApp { @@ -102,11 +105,29 @@ where { async fn process_new( self: &Arc<Self>, - stream: Stream, + mut stream: Stream, shutdown: &ShutdownWatch, ) -> Option<Stream> { - let h2c = self.server_options().as_ref().map_or(false, |o| o.h2c); - // TODO: allow h2c and http/1.1 to co-exist + let mut h2c = self.server_options().as_ref().map_or(false, |o| o.h2c); + + // try to read h2 preface + if h2c { + let mut buf = [0u8; H2_PREFACE.len()]; + let peeked = stream + .try_peek(&mut buf) + .await + .map_err(|e| { + // this error is normal when h1 reuse and close the connection + debug!("Read error while peeking h2c preface {e}"); + e + }) + .ok()?; + // not all streams support peeking + if peeked { + // turn off h2c (use h1) if h2 preface doesn't exist + h2c = buf == H2_PREFACE; + } + } if h2c || matches!(stream.selected_alpn_proto(), Some(ALPN::H2)) { // create a shared connection digest let digest = Arc::new(Digest { diff --git a/pingora-core/src/protocols/l4/stream.rs b/pingora-core/src/protocols/l4/stream.rs index 30b05a0..8ecb515 100644 --- a/pingora-core/src/protocols/l4/stream.rs +++ b/pingora-core/src/protocols/l4/stream.rs @@ -39,8 +39,8 @@ use tokio::net::UnixStream; use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive}; use crate::protocols::raw_connect::ProxyDigest; use crate::protocols::{ - GetProxyDigest, GetSocketDigest, GetTimingDigest, Shutdown, SocketDigest, Ssl, TimingDigest, - UniqueID, UniqueIDType, + GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl, + TimingDigest, UniqueID, UniqueIDType, }; use crate::upstreams::peer::Tracer; @@ -350,6 +350,8 @@ const BUF_WRITE_SIZE: usize = 1460; #[derive(Debug)] pub struct Stream { stream: BufStream<RawStreamWrapper>, + // the data put back at the front of the read buffer, in order to replay the read + rewind_read_buf: Vec<u8>, buffer_write: bool, proxy_digest: Option<Arc<ProxyDigest>>, socket_digest: Option<Arc<SocketDigest>>, @@ -401,6 +403,13 @@ impl Stream { pub fn set_rx_timestamp(&mut self) -> io::Result<()> { Ok(()) } + + /// Put Some data back to the head of the stream to be read again + pub(crate) fn rewind(&mut self, data: &[u8]) { + if !data.is_empty() { + self.rewind_read_buf.extend_from_slice(data); + } + } } impl From<TcpStream> for Stream { @@ -411,6 +420,7 @@ impl From<TcpStream> for Stream { BUF_WRITE_SIZE, RawStreamWrapper::new(RawStream::Tcp(s)), ), + rewind_read_buf: Vec::new(), buffer_write: true, established_ts: SystemTime::now(), proxy_digest: None, @@ -432,6 +442,7 @@ impl From<UnixStream> for Stream { BUF_WRITE_SIZE, RawStreamWrapper::new(RawStream::Unix(s)), ), + rewind_read_buf: Vec::new(), buffer_write: true, established_ts: SystemTime::now(), proxy_digest: None, @@ -475,6 +486,17 @@ impl UniqueID for Stream { impl Ssl for Stream {} #[async_trait] +impl Peek for Stream { + async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> { + use tokio::io::AsyncReadExt; + self.read_exact(buf).await?; + // rewind regardless of what is read + self.rewind(buf); + Ok(true) + } +} + +#[async_trait] impl Shutdown for Stream { async fn shutdown(&mut self) { AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| { @@ -562,7 +584,16 @@ impl AsyncRead for Stream { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>> { - let result = Pin::new(&mut self.stream).poll_read(cx, buf); + let result = if !self.rewind_read_buf.is_empty() { + let mut data_to_read = self.rewind_read_buf.as_slice(); + let result = Pin::new(&mut data_to_read).poll_read(cx, buf); + // put the remaining data in another Vec + let remaining_buf = Vec::from(data_to_read); + let _ = std::mem::replace(&mut self.rewind_read_buf, remaining_buf); + result + } else { + Pin::new(&mut self.stream).poll_read(cx, buf) + }; self.read_pending_time.poll_time(&result); self.rx_ts = self.stream.get_ref().rx_ts; result @@ -859,4 +890,68 @@ mod tests { assert_eq!(n, message.len()); assert!(stream.rx_ts.is_none()); } + + #[tokio::test] + async fn test_stream_rewind() { + let message = b"hello world"; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let notify = Arc::new(Notify::new()); + let notify2 = notify.clone(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + notify2.notified().await; + stream.write_all(message).await.unwrap(); + }); + + let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into(); + + let rewind_test = b"this is Sparta!"; + stream.rewind(rewind_test); + + // partially read rewind_test because of the buffer size limit + let mut buffer = vec![0u8; message.len()]; + let n = stream.read(buffer.as_mut_slice()).await.unwrap(); + assert_eq!(n, message.len()); + assert_eq!(buffer, rewind_test[..message.len()]); + + // read the rest of rewind_test + let n = stream.read(buffer.as_mut_slice()).await.unwrap(); + assert_eq!(n, rewind_test.len() - message.len()); + assert_eq!(buffer[..n], rewind_test[message.len()..]); + + // read the actual data + notify.notify_one(); + let n = stream.read(buffer.as_mut_slice()).await.unwrap(); + assert_eq!(n, message.len()); + assert_eq!(buffer, message); + } + + #[tokio::test] + async fn test_stream_peek() { + let message = b"hello world"; + dbg!("try peek"); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let notify = Arc::new(Notify::new()); + let notify2 = notify.clone(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + notify2.notified().await; + stream.write_all(message).await.unwrap(); + drop(stream); + }); + + notify.notify_one(); + + let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into(); + let mut buffer = vec![0u8; 5]; + assert!(stream.try_peek(&mut buffer).await.unwrap()); + assert_eq!(buffer, message[0..5]); + let mut buffer = vec![]; + stream.read_to_end(&mut buffer).await.unwrap(); + assert_eq!(buffer, message); + } } diff --git a/pingora-core/src/protocols/mod.rs b/pingora-core/src/protocols/mod.rs index 7105d61..62efb9c 100644 --- a/pingora-core/src/protocols/mod.rs +++ b/pingora-core/src/protocols/mod.rs @@ -71,6 +71,17 @@ pub trait Ssl { } } +/// The ability peek data before consuming it +#[async_trait] +pub trait Peek { + /// Peek data but not consuming it. This call should block until some data + /// is sent. + /// Return `false` if peeking is not supported/allowed. + async fn try_peek(&mut self, _buf: &mut [u8]) -> std::io::Result<bool> { + Ok(false) + } +} + use std::any::Any; use tokio::io::{AsyncRead, AsyncWrite}; @@ -84,6 +95,7 @@ pub trait IO: + GetTimingDigest + GetProxyDigest + GetSocketDigest + + Peek + Unpin + Debug + Send @@ -104,6 +116,7 @@ impl< + GetTimingDigest + GetProxyDigest + GetSocketDigest + + Peek + Unpin + Debug + Send @@ -154,6 +167,8 @@ mod ext_io_impl { } } + impl Peek for Mock {} + use std::io::Cursor; #[async_trait] @@ -181,6 +196,7 @@ mod ext_io_impl { None } } + impl<T> Peek for Cursor<T> {} use tokio::io::DuplexStream; @@ -209,6 +225,8 @@ mod ext_io_impl { None } } + + impl Peek for DuplexStream {} } #[cfg(unix)] diff --git a/pingora-core/src/protocols/tls/mod.rs b/pingora-core/src/protocols/tls/mod.rs index 89fe0d3..fceb0a5 100644 --- a/pingora-core/src/protocols/tls/mod.rs +++ b/pingora-core/src/protocols/tls/mod.rs @@ -26,7 +26,7 @@ pub use boringssl_openssl::*; pub mod dummy_tls; use crate::protocols::digest::TimingDigest; -use crate::protocols::{Ssl, UniqueID, UniqueIDType}; +use crate::protocols::{Peek, Ssl, UniqueID, UniqueIDType}; use crate::tls::{self, ssl, tokio_ssl::SslStream as InnerSsl}; use log::warn; use pingora_error::{ErrorType::*, OrErr, Result}; @@ -184,6 +184,9 @@ impl<T> Ssl for SslStream<T> { } } +// TODO: implement Peek if needed +impl<T> Peek for SslStream<T> {} + /// The protocol for Application-Layer Protocol Negotiation #[derive(Hash, Clone, Debug)] pub enum ALPN { diff --git a/pingora-proxy/src/subrequest.rs b/pingora-proxy/src/subrequest.rs index f9367b7..68e2876 100644 --- a/pingora-proxy/src/subrequest.rs +++ b/pingora-proxy/src/subrequest.rs @@ -18,8 +18,8 @@ use core::task::{Context, Poll}; use pingora_cache::lock::WritePermit; use pingora_core::protocols::raw_connect::ProxyDigest; use pingora_core::protocols::{ - GetProxyDigest, GetSocketDigest, GetTimingDigest, SocketDigest, Ssl, TimingDigest, UniqueID, - UniqueIDType, + GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, SocketDigest, Ssl, TimingDigest, + UniqueID, UniqueIDType, }; use std::io::Cursor; use std::sync::Arc; @@ -94,6 +94,8 @@ impl GetSocketDigest for DummyIO { } } +impl Peek for DummyIO {} + #[async_trait] impl pingora_core::protocols::Shutdown for DummyIO { async fn shutdown(&mut self) -> () {} diff --git a/pingora-proxy/tests/test_basic.rs b/pingora-proxy/tests/test_basic.rs index 2be27e4..569e51a 100644 --- a/pingora-proxy/tests/test_basic.rs +++ b/pingora-proxy/tests/test_basic.rs @@ -167,6 +167,28 @@ async fn test_h2c_to_h2c() { } #[tokio::test] +async fn test_h1_on_h2c_port() { + init(); + + let client = hyper::client::Client::builder() + .http2_only(false) + .build_http(); + + let mut req = hyper::Request::builder() + .uri("http://127.0.0.1:6146") + .body(Body::empty()) + .unwrap(); + req.headers_mut() + .insert("x-h2", HeaderValue::from_bytes(b"true").unwrap()); + let res = client.request(req).await.unwrap(); + assert_eq!(res.status(), reqwest::StatusCode::OK); + assert_eq!(res.version(), reqwest::Version::HTTP_11); + + let body = res.into_body().data().await.unwrap().unwrap(); + assert_eq!(body.as_ref(), b"Hello World!\n"); +} + +#[tokio::test] async fn test_h2_to_h2_host_override() { init(); let client = reqwest::Client::builder() |