aboutsummaryrefslogtreecommitdiffhomepage
path: root/libs/hbb_common/src/bytes_codec.rs
diff options
context:
space:
mode:
Diffstat (limited to 'libs/hbb_common/src/bytes_codec.rs')
-rw-r--r--libs/hbb_common/src/bytes_codec.rs274
1 files changed, 274 insertions, 0 deletions
diff --git a/libs/hbb_common/src/bytes_codec.rs b/libs/hbb_common/src/bytes_codec.rs
new file mode 100644
index 0000000..e029f1c
--- /dev/null
+++ b/libs/hbb_common/src/bytes_codec.rs
@@ -0,0 +1,274 @@
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use std::io;
+use tokio_util::codec::{Decoder, Encoder};
+
+#[derive(Debug, Clone, Copy)]
+pub struct BytesCodec {
+ state: DecodeState,
+ raw: bool,
+ max_packet_length: usize,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum DecodeState {
+ Head,
+ Data(usize),
+}
+
+impl BytesCodec {
+ pub fn new() -> Self {
+ Self {
+ state: DecodeState::Head,
+ raw: false,
+ max_packet_length: usize::MAX,
+ }
+ }
+
+ pub fn set_raw(&mut self) {
+ self.raw = true;
+ }
+
+ pub fn set_max_packet_length(&mut self, n: usize) {
+ self.max_packet_length = n;
+ }
+
+ fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
+ if src.is_empty() {
+ return Ok(None);
+ }
+ let head_len = ((src[0] & 0x3) + 1) as usize;
+ if src.len() < head_len {
+ return Ok(None);
+ }
+ let mut n = src[0] as usize;
+ if head_len > 1 {
+ n |= (src[1] as usize) << 8;
+ }
+ if head_len > 2 {
+ n |= (src[2] as usize) << 16;
+ }
+ if head_len > 3 {
+ n |= (src[3] as usize) << 24;
+ }
+ n >>= 2;
+ if n > self.max_packet_length {
+ return Err(io::Error::new(io::ErrorKind::InvalidData, "Too big packet"));
+ }
+ src.advance(head_len);
+ src.reserve(n);
+ return Ok(Some(n));
+ }
+
+ fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
+ if src.len() < n {
+ return Ok(None);
+ }
+ Ok(Some(src.split_to(n)))
+ }
+}
+
+impl Decoder for BytesCodec {
+ type Item = BytesMut;
+ type Error = io::Error;
+
+ fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
+ if self.raw {
+ if !src.is_empty() {
+ let len = src.len();
+ return Ok(Some(src.split_to(len)));
+ } else {
+ return Ok(None);
+ }
+ }
+ let n = match self.state {
+ DecodeState::Head => match self.decode_head(src)? {
+ Some(n) => {
+ self.state = DecodeState::Data(n);
+ n
+ }
+ None => return Ok(None),
+ },
+ DecodeState::Data(n) => n,
+ };
+
+ match self.decode_data(n, src)? {
+ Some(data) => {
+ self.state = DecodeState::Head;
+ Ok(Some(data))
+ }
+ None => Ok(None),
+ }
+ }
+}
+
+impl Encoder<Bytes> for BytesCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
+ if self.raw {
+ buf.reserve(data.len());
+ buf.put(data);
+ return Ok(());
+ }
+ if data.len() <= 0x3F {
+ buf.put_u8((data.len() << 2) as u8);
+ } else if data.len() <= 0x3FFF {
+ buf.put_u16_le((data.len() << 2) as u16 | 0x1);
+ } else if data.len() <= 0x3FFFFF {
+ let h = (data.len() << 2) as u32 | 0x2;
+ buf.put_u16_le((h & 0xFFFF) as u16);
+ buf.put_u8((h >> 16) as u8);
+ } else if data.len() <= 0x3FFFFFFF {
+ buf.put_u32_le((data.len() << 2) as u32 | 0x3);
+ } else {
+ return Err(io::Error::new(io::ErrorKind::InvalidInput, "Overflow"));
+ }
+ buf.extend(data);
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ #[test]
+ fn test_codec1() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3F, 1);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ let buf_saved = buf.clone();
+ assert_eq!(buf.len(), 0x3F + 1);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3F);
+ assert_eq!(res[0], 1);
+ } else {
+ assert!(false);
+ }
+ let mut codec2 = BytesCodec::new();
+ let mut buf2 = BytesMut::new();
+ if let Ok(None) = codec2.decode(&mut buf2) {
+ } else {
+ assert!(false);
+ }
+ buf2.extend(&buf_saved[0..1]);
+ if let Ok(None) = codec2.decode(&mut buf2) {
+ } else {
+ assert!(false);
+ }
+ buf2.extend(&buf_saved[1..]);
+ if let Ok(Some(res)) = codec2.decode(&mut buf2) {
+ assert_eq!(res.len(), 0x3F);
+ assert_eq!(res[0], 1);
+ } else {
+ assert!(false);
+ }
+ }
+
+ #[test]
+ fn test_codec2() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ assert!(!codec.encode("".into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 1);
+ bytes.resize(0x3F + 1, 2);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 0x3F + 2 + 2);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0);
+ } else {
+ assert!(false);
+ }
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3F + 1);
+ assert_eq!(res[0], 2);
+ } else {
+ assert!(false);
+ }
+ }
+
+ #[test]
+ fn test_codec3() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3F - 1, 3);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 0x3F + 1 - 1);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3F - 1);
+ assert_eq!(res[0], 3);
+ } else {
+ assert!(false);
+ }
+ }
+ #[test]
+ fn test_codec4() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3FFF, 4);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 0x3FFF + 2);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3FFF);
+ assert_eq!(res[0], 4);
+ } else {
+ assert!(false);
+ }
+ }
+
+ #[test]
+ fn test_codec5() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3FFFFF, 5);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 0x3FFFFF + 3);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3FFFFF);
+ assert_eq!(res[0], 5);
+ } else {
+ assert!(false);
+ }
+ }
+
+ #[test]
+ fn test_codec6() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3FFFFF + 1, 6);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ let buf_saved = buf.clone();
+ assert_eq!(buf.len(), 0x3FFFFF + 4 + 1);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3FFFFF + 1);
+ assert_eq!(res[0], 6);
+ } else {
+ assert!(false);
+ }
+ let mut codec2 = BytesCodec::new();
+ let mut buf2 = BytesMut::new();
+ buf2.extend(&buf_saved[0..1]);
+ if let Ok(None) = codec2.decode(&mut buf2) {
+ } else {
+ assert!(false);
+ }
+ buf2.extend(&buf_saved[1..6]);
+ if let Ok(None) = codec2.decode(&mut buf2) {
+ } else {
+ assert!(false);
+ }
+ buf2.extend(&buf_saved[6..]);
+ if let Ok(Some(res)) = codec2.decode(&mut buf2) {
+ assert_eq!(res.len(), 0x3FFFFF + 1);
+ assert_eq!(res[0], 6);
+ } else {
+ assert!(false);
+ }
+ }
+}