aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bleep2
-rw-r--r--pingora-core/src/apps/mod.rs27
-rw-r--r--pingora-core/src/protocols/l4/stream.rs101
-rw-r--r--pingora-core/src/protocols/mod.rs18
-rw-r--r--pingora-core/src/protocols/tls/mod.rs5
-rw-r--r--pingora-proxy/src/subrequest.rs6
-rw-r--r--pingora-proxy/tests/test_basic.rs22
7 files changed, 171 insertions, 10 deletions
diff --git a/.bleep b/.bleep
index 4611476..9840797 100644
--- a/.bleep
+++ b/.bleep
@@ -1 +1 @@
-9d8254e966aca99eeb6623f2bcf5fc93facbb05a \ No newline at end of file
+dc97a9520b124eb464f348b0381991d8669c8d8a \ No newline at end of file
diff --git a/pingora-core/src/apps/mod.rs b/pingora-core/src/apps/mod.rs
index 32fd82f..0786d08 100644
--- a/pingora-core/src/apps/mod.rs
+++ b/pingora-core/src/apps/mod.rs
@@ -28,6 +28,9 @@ use crate::protocols::Digest;
use crate::protocols::Stream;
use crate::protocols::ALPN;
+// https://datatracker.ietf.org/doc/html/rfc9113#section-3.4
+const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
+
#[async_trait]
/// This trait defines the interface of a transport layer (TCP or TLS) application.
pub trait ServerApp {
@@ -102,11 +105,29 @@ where
{
async fn process_new(
self: &Arc<Self>,
- stream: Stream,
+ mut stream: Stream,
shutdown: &ShutdownWatch,
) -> Option<Stream> {
- let h2c = self.server_options().as_ref().map_or(false, |o| o.h2c);
- // TODO: allow h2c and http/1.1 to co-exist
+ let mut h2c = self.server_options().as_ref().map_or(false, |o| o.h2c);
+
+ // try to read h2 preface
+ if h2c {
+ let mut buf = [0u8; H2_PREFACE.len()];
+ let peeked = stream
+ .try_peek(&mut buf)
+ .await
+ .map_err(|e| {
+ // this error is normal when h1 reuse and close the connection
+ debug!("Read error while peeking h2c preface {e}");
+ e
+ })
+ .ok()?;
+ // not all streams support peeking
+ if peeked {
+ // turn off h2c (use h1) if h2 preface doesn't exist
+ h2c = buf == H2_PREFACE;
+ }
+ }
if h2c || matches!(stream.selected_alpn_proto(), Some(ALPN::H2)) {
// create a shared connection digest
let digest = Arc::new(Digest {
diff --git a/pingora-core/src/protocols/l4/stream.rs b/pingora-core/src/protocols/l4/stream.rs
index 30b05a0..8ecb515 100644
--- a/pingora-core/src/protocols/l4/stream.rs
+++ b/pingora-core/src/protocols/l4/stream.rs
@@ -39,8 +39,8 @@ use tokio::net::UnixStream;
use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive};
use crate::protocols::raw_connect::ProxyDigest;
use crate::protocols::{
- GetProxyDigest, GetSocketDigest, GetTimingDigest, Shutdown, SocketDigest, Ssl, TimingDigest,
- UniqueID, UniqueIDType,
+ GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, Shutdown, SocketDigest, Ssl,
+ TimingDigest, UniqueID, UniqueIDType,
};
use crate::upstreams::peer::Tracer;
@@ -350,6 +350,8 @@ const BUF_WRITE_SIZE: usize = 1460;
#[derive(Debug)]
pub struct Stream {
stream: BufStream<RawStreamWrapper>,
+ // the data put back at the front of the read buffer, in order to replay the read
+ rewind_read_buf: Vec<u8>,
buffer_write: bool,
proxy_digest: Option<Arc<ProxyDigest>>,
socket_digest: Option<Arc<SocketDigest>>,
@@ -401,6 +403,13 @@ impl Stream {
pub fn set_rx_timestamp(&mut self) -> io::Result<()> {
Ok(())
}
+
+ /// Put Some data back to the head of the stream to be read again
+ pub(crate) fn rewind(&mut self, data: &[u8]) {
+ if !data.is_empty() {
+ self.rewind_read_buf.extend_from_slice(data);
+ }
+ }
}
impl From<TcpStream> for Stream {
@@ -411,6 +420,7 @@ impl From<TcpStream> for Stream {
BUF_WRITE_SIZE,
RawStreamWrapper::new(RawStream::Tcp(s)),
),
+ rewind_read_buf: Vec::new(),
buffer_write: true,
established_ts: SystemTime::now(),
proxy_digest: None,
@@ -432,6 +442,7 @@ impl From<UnixStream> for Stream {
BUF_WRITE_SIZE,
RawStreamWrapper::new(RawStream::Unix(s)),
),
+ rewind_read_buf: Vec::new(),
buffer_write: true,
established_ts: SystemTime::now(),
proxy_digest: None,
@@ -475,6 +486,17 @@ impl UniqueID for Stream {
impl Ssl for Stream {}
#[async_trait]
+impl Peek for Stream {
+ async fn try_peek(&mut self, buf: &mut [u8]) -> std::io::Result<bool> {
+ use tokio::io::AsyncReadExt;
+ self.read_exact(buf).await?;
+ // rewind regardless of what is read
+ self.rewind(buf);
+ Ok(true)
+ }
+}
+
+#[async_trait]
impl Shutdown for Stream {
async fn shutdown(&mut self) {
AsyncWriteExt::shutdown(self).await.unwrap_or_else(|e| {
@@ -562,7 +584,16 @@ impl AsyncRead for Stream {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
- let result = Pin::new(&mut self.stream).poll_read(cx, buf);
+ let result = if !self.rewind_read_buf.is_empty() {
+ let mut data_to_read = self.rewind_read_buf.as_slice();
+ let result = Pin::new(&mut data_to_read).poll_read(cx, buf);
+ // put the remaining data in another Vec
+ let remaining_buf = Vec::from(data_to_read);
+ let _ = std::mem::replace(&mut self.rewind_read_buf, remaining_buf);
+ result
+ } else {
+ Pin::new(&mut self.stream).poll_read(cx, buf)
+ };
self.read_pending_time.poll_time(&result);
self.rx_ts = self.stream.get_ref().rx_ts;
result
@@ -859,4 +890,68 @@ mod tests {
assert_eq!(n, message.len());
assert!(stream.rx_ts.is_none());
}
+
+ #[tokio::test]
+ async fn test_stream_rewind() {
+ let message = b"hello world";
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let addr = listener.local_addr().unwrap();
+ let notify = Arc::new(Notify::new());
+ let notify2 = notify.clone();
+
+ tokio::spawn(async move {
+ let (mut stream, _) = listener.accept().await.unwrap();
+ notify2.notified().await;
+ stream.write_all(message).await.unwrap();
+ });
+
+ let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
+
+ let rewind_test = b"this is Sparta!";
+ stream.rewind(rewind_test);
+
+ // partially read rewind_test because of the buffer size limit
+ let mut buffer = vec![0u8; message.len()];
+ let n = stream.read(buffer.as_mut_slice()).await.unwrap();
+ assert_eq!(n, message.len());
+ assert_eq!(buffer, rewind_test[..message.len()]);
+
+ // read the rest of rewind_test
+ let n = stream.read(buffer.as_mut_slice()).await.unwrap();
+ assert_eq!(n, rewind_test.len() - message.len());
+ assert_eq!(buffer[..n], rewind_test[message.len()..]);
+
+ // read the actual data
+ notify.notify_one();
+ let n = stream.read(buffer.as_mut_slice()).await.unwrap();
+ assert_eq!(n, message.len());
+ assert_eq!(buffer, message);
+ }
+
+ #[tokio::test]
+ async fn test_stream_peek() {
+ let message = b"hello world";
+ dbg!("try peek");
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let addr = listener.local_addr().unwrap();
+ let notify = Arc::new(Notify::new());
+ let notify2 = notify.clone();
+
+ tokio::spawn(async move {
+ let (mut stream, _) = listener.accept().await.unwrap();
+ notify2.notified().await;
+ stream.write_all(message).await.unwrap();
+ drop(stream);
+ });
+
+ notify.notify_one();
+
+ let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into();
+ let mut buffer = vec![0u8; 5];
+ assert!(stream.try_peek(&mut buffer).await.unwrap());
+ assert_eq!(buffer, message[0..5]);
+ let mut buffer = vec![];
+ stream.read_to_end(&mut buffer).await.unwrap();
+ assert_eq!(buffer, message);
+ }
}
diff --git a/pingora-core/src/protocols/mod.rs b/pingora-core/src/protocols/mod.rs
index 7105d61..62efb9c 100644
--- a/pingora-core/src/protocols/mod.rs
+++ b/pingora-core/src/protocols/mod.rs
@@ -71,6 +71,17 @@ pub trait Ssl {
}
}
+/// The ability peek data before consuming it
+#[async_trait]
+pub trait Peek {
+ /// Peek data but not consuming it. This call should block until some data
+ /// is sent.
+ /// Return `false` if peeking is not supported/allowed.
+ async fn try_peek(&mut self, _buf: &mut [u8]) -> std::io::Result<bool> {
+ Ok(false)
+ }
+}
+
use std::any::Any;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -84,6 +95,7 @@ pub trait IO:
+ GetTimingDigest
+ GetProxyDigest
+ GetSocketDigest
+ + Peek
+ Unpin
+ Debug
+ Send
@@ -104,6 +116,7 @@ impl<
+ GetTimingDigest
+ GetProxyDigest
+ GetSocketDigest
+ + Peek
+ Unpin
+ Debug
+ Send
@@ -154,6 +167,8 @@ mod ext_io_impl {
}
}
+ impl Peek for Mock {}
+
use std::io::Cursor;
#[async_trait]
@@ -181,6 +196,7 @@ mod ext_io_impl {
None
}
}
+ impl<T> Peek for Cursor<T> {}
use tokio::io::DuplexStream;
@@ -209,6 +225,8 @@ mod ext_io_impl {
None
}
}
+
+ impl Peek for DuplexStream {}
}
#[cfg(unix)]
diff --git a/pingora-core/src/protocols/tls/mod.rs b/pingora-core/src/protocols/tls/mod.rs
index 89fe0d3..fceb0a5 100644
--- a/pingora-core/src/protocols/tls/mod.rs
+++ b/pingora-core/src/protocols/tls/mod.rs
@@ -26,7 +26,7 @@ pub use boringssl_openssl::*;
pub mod dummy_tls;
use crate::protocols::digest::TimingDigest;
-use crate::protocols::{Ssl, UniqueID, UniqueIDType};
+use crate::protocols::{Peek, Ssl, UniqueID, UniqueIDType};
use crate::tls::{self, ssl, tokio_ssl::SslStream as InnerSsl};
use log::warn;
use pingora_error::{ErrorType::*, OrErr, Result};
@@ -184,6 +184,9 @@ impl<T> Ssl for SslStream<T> {
}
}
+// TODO: implement Peek if needed
+impl<T> Peek for SslStream<T> {}
+
/// The protocol for Application-Layer Protocol Negotiation
#[derive(Hash, Clone, Debug)]
pub enum ALPN {
diff --git a/pingora-proxy/src/subrequest.rs b/pingora-proxy/src/subrequest.rs
index f9367b7..68e2876 100644
--- a/pingora-proxy/src/subrequest.rs
+++ b/pingora-proxy/src/subrequest.rs
@@ -18,8 +18,8 @@ use core::task::{Context, Poll};
use pingora_cache::lock::WritePermit;
use pingora_core::protocols::raw_connect::ProxyDigest;
use pingora_core::protocols::{
- GetProxyDigest, GetSocketDigest, GetTimingDigest, SocketDigest, Ssl, TimingDigest, UniqueID,
- UniqueIDType,
+ GetProxyDigest, GetSocketDigest, GetTimingDigest, Peek, SocketDigest, Ssl, TimingDigest,
+ UniqueID, UniqueIDType,
};
use std::io::Cursor;
use std::sync::Arc;
@@ -94,6 +94,8 @@ impl GetSocketDigest for DummyIO {
}
}
+impl Peek for DummyIO {}
+
#[async_trait]
impl pingora_core::protocols::Shutdown for DummyIO {
async fn shutdown(&mut self) -> () {}
diff --git a/pingora-proxy/tests/test_basic.rs b/pingora-proxy/tests/test_basic.rs
index 2be27e4..569e51a 100644
--- a/pingora-proxy/tests/test_basic.rs
+++ b/pingora-proxy/tests/test_basic.rs
@@ -167,6 +167,28 @@ async fn test_h2c_to_h2c() {
}
#[tokio::test]
+async fn test_h1_on_h2c_port() {
+ init();
+
+ let client = hyper::client::Client::builder()
+ .http2_only(false)
+ .build_http();
+
+ let mut req = hyper::Request::builder()
+ .uri("http://127.0.0.1:6146")
+ .body(Body::empty())
+ .unwrap();
+ req.headers_mut()
+ .insert("x-h2", HeaderValue::from_bytes(b"true").unwrap());
+ let res = client.request(req).await.unwrap();
+ assert_eq!(res.status(), reqwest::StatusCode::OK);
+ assert_eq!(res.version(), reqwest::Version::HTTP_11);
+
+ let body = res.into_body().data().await.unwrap().unwrap();
+ assert_eq!(body.as_ref(), b"Hello World!\n");
+}
+
+#[tokio::test]
async fn test_h2_to_h2_host_override() {
init();
let client = reqwest::Client::builder()