diff options
author | rustdesk <[email protected]> | 2022-05-12 20:07:45 +0800 |
---|---|---|
committer | rustdesk <[email protected]> | 2022-05-12 20:07:45 +0800 |
commit | 03ca2a95177d16999adad254057eb9c27b5e53b6 (patch) | |
tree | 36dab7dab5e015de94b08db67581a3b782599d56 /libs | |
parent | b3f39598a7324dacec0cd84d5e09b95724805cc8 (diff) | |
download | rustdesk-server-03ca2a95177d16999adad254057eb9c27b5e53b6.tar.gz rustdesk-server-03ca2a95177d16999adad254057eb9c27b5e53b6.zip |
missed files
Diffstat (limited to 'libs')
-rw-r--r-- | libs/hbb_common/.gitignore | 4 | ||||
-rw-r--r-- | libs/hbb_common/Cargo.toml | 48 | ||||
-rw-r--r-- | libs/hbb_common/build.rs | 9 | ||||
-rw-r--r-- | libs/hbb_common/protos/message.proto | 481 | ||||
-rw-r--r-- | libs/hbb_common/protos/rendezvous.proto | 171 | ||||
-rw-r--r-- | libs/hbb_common/src/bytes_codec.rs | 274 | ||||
-rw-r--r-- | libs/hbb_common/src/compress.rs | 50 | ||||
-rw-r--r-- | libs/hbb_common/src/config.rs | 876 | ||||
-rw-r--r-- | libs/hbb_common/src/fs.rs | 560 | ||||
-rw-r--r-- | libs/hbb_common/src/lib.rs | 211 | ||||
-rw-r--r-- | libs/hbb_common/src/quic.rs | 135 | ||||
-rw-r--r-- | libs/hbb_common/src/socket_client.rs | 91 | ||||
-rw-r--r-- | libs/hbb_common/src/tcp.rs | 285 | ||||
-rw-r--r-- | libs/hbb_common/src/udp.rs | 165 |
14 files changed, 3360 insertions, 0 deletions
diff --git a/libs/hbb_common/.gitignore b/libs/hbb_common/.gitignore new file mode 100644 index 0000000..b1cf151 --- /dev/null +++ b/libs/hbb_common/.gitignore @@ -0,0 +1,4 @@ +/target +**/*.rs.bk +Cargo.lock +src/protos/ diff --git a/libs/hbb_common/Cargo.toml b/libs/hbb_common/Cargo.toml new file mode 100644 index 0000000..bc31223 --- /dev/null +++ b/libs/hbb_common/Cargo.toml @@ -0,0 +1,48 @@ +[package] +name = "hbb_common" +version = "0.1.0" +authors = ["open-trade <[email protected]>"] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +protobuf = "3.0.0-alpha.2" +tokio = { version = "1.15", features = ["full"] } +tokio-util = { version = "0.6", features = ["full"] } +futures = "0.3" +bytes = "1.1" +log = "0.4" +env_logger = "0.9" +socket2 = { version = "0.3", features = ["reuseport"] } +zstd = "0.9" +quinn = {version = "0.8", optional = true } +anyhow = "1.0" +futures-util = "0.3" +directories-next = "2.0" +rand = "0.8" +serde_derive = "1.0" +serde = "1.0" +lazy_static = "1.4" +confy = { git = "https://github.com/open-trade/confy" } +dirs-next = "2.0" +filetime = "0.2" +sodiumoxide = "0.2" +regex = "1.4" +tokio-socks = { git = "https://github.com/open-trade/tokio-socks" } + +[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] +mac_address = "1.1" + +[features] +quic = [] + +[build-dependencies] +protobuf-codegen-pure = "3.0.0-alpha.2" + +[target.'cfg(target_os = "windows")'.dependencies] +winapi = { version = "0.3", features = ["winuser"] } + +[dev-dependencies] +toml = "0.5" +serde_json = "1.0" diff --git a/libs/hbb_common/build.rs b/libs/hbb_common/build.rs new file mode 100644 index 0000000..99dacb7 --- /dev/null +++ b/libs/hbb_common/build.rs @@ -0,0 +1,9 @@ +fn main() { + std::fs::create_dir_all("src/protos").unwrap(); + protobuf_codegen_pure::Codegen::new() + .out_dir("src/protos") + .inputs(&["protos/rendezvous.proto", "protos/message.proto"]) + .include("protos") + .run() + .expect("Codegen failed."); +} diff --git a/libs/hbb_common/protos/message.proto b/libs/hbb_common/protos/message.proto new file mode 100644 index 0000000..f95007d --- /dev/null +++ b/libs/hbb_common/protos/message.proto @@ -0,0 +1,481 @@ +syntax = "proto3"; +package hbb; + +message VP9 { + bytes data = 1; + bool key = 2; + int64 pts = 3; +} + +message VP9s { repeated VP9 frames = 1; } + +message RGB { bool compress = 1; } + +// planes data send directly in binary for better use arraybuffer on web +message YUV { + bool compress = 1; + int32 stride = 2; +} + +message VideoFrame { + oneof union { + VP9s vp9s = 6; + RGB rgb = 7; + YUV yuv = 8; + } +} + +message IdPk { + string id = 1; + bytes pk = 2; +} + +message DisplayInfo { + sint32 x = 1; + sint32 y = 2; + int32 width = 3; + int32 height = 4; + string name = 5; + bool online = 6; +} + +message PortForward { + string host = 1; + int32 port = 2; +} + +message FileTransfer { + string dir = 1; + bool show_hidden = 2; +} + +message LoginRequest { + string username = 1; + bytes password = 2; + string my_id = 4; + string my_name = 5; + OptionMessage option = 6; + oneof union { + FileTransfer file_transfer = 7; + PortForward port_forward = 8; + } + bool video_ack_required = 9; +} + +message ChatMessage { string text = 1; } + +message PeerInfo { + string username = 1; + string hostname = 2; + string platform = 3; + repeated DisplayInfo displays = 4; + int32 current_display = 5; + bool sas_enabled = 6; + string version = 7; + int32 conn_id = 8; +} + +message LoginResponse { + oneof union { + string error = 1; + PeerInfo peer_info = 2; + } +} + +message MouseEvent { + int32 mask = 1; + sint32 x = 2; + sint32 y = 3; + repeated ControlKey modifiers = 4; +} + +enum ControlKey { + Unknown = 0; + Alt = 1; + Backspace = 2; + CapsLock = 3; + Control = 4; + Delete = 5; + DownArrow = 6; + End = 7; + Escape = 8; + F1 = 9; + F10 = 10; + F11 = 11; + F12 = 12; + F2 = 13; + F3 = 14; + F4 = 15; + F5 = 16; + F6 = 17; + F7 = 18; + F8 = 19; + F9 = 20; + Home = 21; + LeftArrow = 22; + /// meta key (also known as "windows"; "super"; and "command") + Meta = 23; + /// option key on macOS (alt key on Linux and Windows) + Option = 24; // deprecated, use Alt instead + PageDown = 25; + PageUp = 26; + Return = 27; + RightArrow = 28; + Shift = 29; + Space = 30; + Tab = 31; + UpArrow = 32; + Numpad0 = 33; + Numpad1 = 34; + Numpad2 = 35; + Numpad3 = 36; + Numpad4 = 37; + Numpad5 = 38; + Numpad6 = 39; + Numpad7 = 40; + Numpad8 = 41; + Numpad9 = 42; + Cancel = 43; + Clear = 44; + Menu = 45; // deprecated, use Alt instead + Pause = 46; + Kana = 47; + Hangul = 48; + Junja = 49; + Final = 50; + Hanja = 51; + Kanji = 52; + Convert = 53; + Select = 54; + Print = 55; + Execute = 56; + Snapshot = 57; + Insert = 58; + Help = 59; + Sleep = 60; + Separator = 61; + Scroll = 62; + NumLock = 63; + RWin = 64; + Apps = 65; + Multiply = 66; + Add = 67; + Subtract = 68; + Decimal = 69; + Divide = 70; + Equals = 71; + NumpadEnter = 72; + RShift = 73; + RControl = 74; + RAlt = 75; + CtrlAltDel = 100; + LockScreen = 101; +} + +message KeyEvent { + bool down = 1; + bool press = 2; + oneof union { + ControlKey control_key = 3; + uint32 chr = 4; + uint32 unicode = 5; + string seq = 6; + } + repeated ControlKey modifiers = 8; +} + +message CursorData { + uint64 id = 1; + sint32 hotx = 2; + sint32 hoty = 3; + int32 width = 4; + int32 height = 5; + bytes colors = 6; +} + +message CursorPosition { + sint32 x = 1; + sint32 y = 2; +} + +message Hash { + string salt = 1; + string challenge = 2; +} + +message Clipboard { + bool compress = 1; + bytes content = 2; +} + +enum FileType { + Dir = 0; + DirLink = 2; + DirDrive = 3; + File = 4; + FileLink = 5; +} + +message FileEntry { + FileType entry_type = 1; + string name = 2; + bool is_hidden = 3; + uint64 size = 4; + uint64 modified_time = 5; +} + +message FileDirectory { + int32 id = 1; + string path = 2; + repeated FileEntry entries = 3; +} + +message ReadDir { + string path = 1; + bool include_hidden = 2; +} + +message ReadAllFiles { + int32 id = 1; + string path = 2; + bool include_hidden = 3; +} + +message FileAction { + oneof union { + ReadDir read_dir = 1; + FileTransferSendRequest send = 2; + FileTransferReceiveRequest receive = 3; + FileDirCreate create = 4; + FileRemoveDir remove_dir = 5; + FileRemoveFile remove_file = 6; + ReadAllFiles all_files = 7; + FileTransferCancel cancel = 8; + } +} + +message FileTransferCancel { int32 id = 1; } + +message FileResponse { + oneof union { + FileDirectory dir = 1; + FileTransferBlock block = 2; + FileTransferError error = 3; + FileTransferDone done = 4; + } +} + +message FileTransferBlock { + int32 id = 1; + sint32 file_num = 2; + bytes data = 3; + bool compressed = 4; +} + +message FileTransferError { + int32 id = 1; + string error = 2; + sint32 file_num = 3; +} + +message FileTransferSendRequest { + int32 id = 1; + string path = 2; + bool include_hidden = 3; +} + +message FileTransferDone { + int32 id = 1; + sint32 file_num = 2; +} + +message FileTransferReceiveRequest { + int32 id = 1; + string path = 2; // path written to + repeated FileEntry files = 3; +} + +message FileRemoveDir { + int32 id = 1; + string path = 2; + bool recursive = 3; +} + +message FileRemoveFile { + int32 id = 1; + string path = 2; + sint32 file_num = 3; +} + +message FileDirCreate { + int32 id = 1; + string path = 2; +} + +// main logic from freeRDP +message CliprdrMonitorReady { + int32 conn_id = 1; +} + +message CliprdrFormat { + int32 conn_id = 1; + int32 id = 2; + string format = 3; +} + +message CliprdrServerFormatList { + int32 conn_id = 1; + repeated CliprdrFormat formats = 2; +} + +message CliprdrServerFormatListResponse { + int32 conn_id = 1; + int32 msg_flags = 2; +} + +message CliprdrServerFormatDataRequest { + int32 conn_id = 1; + int32 requested_format_id = 2; +} + +message CliprdrServerFormatDataResponse { + int32 conn_id = 1; + int32 msg_flags = 2; + bytes format_data = 3; +} + +message CliprdrFileContentsRequest { + int32 conn_id = 1; + int32 stream_id = 2; + int32 list_index = 3; + int32 dw_flags = 4; + int32 n_position_low = 5; + int32 n_position_high = 6; + int32 cb_requested = 7; + bool have_clip_data_id = 8; + int32 clip_data_id = 9; +} + +message CliprdrFileContentsResponse { + int32 conn_id = 1; + int32 msg_flags = 3; + int32 stream_id = 4; + bytes requested_data = 5; +} + +message Cliprdr { + oneof union { + CliprdrMonitorReady ready = 1; + CliprdrServerFormatList format_list = 2; + CliprdrServerFormatListResponse format_list_response = 3; + CliprdrServerFormatDataRequest format_data_request = 4; + CliprdrServerFormatDataResponse format_data_response = 5; + CliprdrFileContentsRequest file_contents_request = 6; + CliprdrFileContentsResponse file_contents_response = 7; + } +} + +message SwitchDisplay { + int32 display = 1; + sint32 x = 2; + sint32 y = 3; + int32 width = 4; + int32 height = 5; +} + +message PermissionInfo { + enum Permission { + Keyboard = 0; + Clipboard = 2; + Audio = 3; + File = 4; + } + + Permission permission = 1; + bool enabled = 2; +} + +enum ImageQuality { + NotSet = 0; + Low = 2; + Balanced = 3; + Best = 4; +} + +message OptionMessage { + enum BoolOption { + NotSet = 0; + No = 1; + Yes = 2; + } + ImageQuality image_quality = 1; + BoolOption lock_after_session_end = 2; + BoolOption show_remote_cursor = 3; + BoolOption privacy_mode = 4; + BoolOption block_input = 5; + int32 custom_image_quality = 6; + BoolOption disable_audio = 7; + BoolOption disable_clipboard = 8; + BoolOption enable_file_transfer = 9; +} + +message OptionResponse { + OptionMessage opt = 1; + string error = 2; +} + +message TestDelay { + int64 time = 1; + bool from_client = 2; +} + +message PublicKey { + bytes asymmetric_value = 1; + bytes symmetric_value = 2; +} + +message SignedId { bytes id = 1; } + +message AudioFormat { + uint32 sample_rate = 1; + uint32 channels = 2; +} + +message AudioFrame { bytes data = 1; } + +message Misc { + oneof union { + ChatMessage chat_message = 4; + SwitchDisplay switch_display = 5; + PermissionInfo permission_info = 6; + OptionMessage option = 7; + AudioFormat audio_format = 8; + string close_reason = 9; + bool refresh_video = 10; + OptionResponse option_response = 11; + bool video_received = 12; + } +} + +message Message { + oneof union { + SignedId signed_id = 3; + PublicKey public_key = 4; + TestDelay test_delay = 5; + VideoFrame video_frame = 6; + LoginRequest login_request = 7; + LoginResponse login_response = 8; + Hash hash = 9; + MouseEvent mouse_event = 10; + AudioFrame audio_frame = 11; + CursorData cursor_data = 12; + CursorPosition cursor_position = 13; + uint64 cursor_id = 14; + KeyEvent key_event = 15; + Clipboard clipboard = 16; + FileAction file_action = 17; + FileResponse file_response = 18; + Misc misc = 19; + Cliprdr cliprdr = 20; + } +} diff --git a/libs/hbb_common/protos/rendezvous.proto b/libs/hbb_common/protos/rendezvous.proto new file mode 100644 index 0000000..2c5f1b3 --- /dev/null +++ b/libs/hbb_common/protos/rendezvous.proto @@ -0,0 +1,171 @@ +syntax = "proto3"; +package hbb; + +message RegisterPeer { + string id = 1; + int32 serial = 2; +} + +enum ConnType { + DEFAULT_CONN = 0; + FILE_TRANSFER = 1; + PORT_FORWARD = 2; + RDP = 3; +} + +message RegisterPeerResponse { bool request_pk = 2; } + +message PunchHoleRequest { + string id = 1; + NatType nat_type = 2; + string licence_key = 3; + ConnType conn_type = 4; + string token = 5; +} + +message PunchHole { + bytes socket_addr = 1; + string relay_server = 2; + NatType nat_type = 3; +} + +message TestNatRequest { + int32 serial = 1; +} + +// per my test, uint/int has no difference in encoding, int not good for negative, use sint for negative +message TestNatResponse { + int32 port = 1; + ConfigUpdate cu = 2; // for mobile +} + +enum NatType { + UNKNOWN_NAT = 0; + ASYMMETRIC = 1; + SYMMETRIC = 2; +} + +message PunchHoleSent { + bytes socket_addr = 1; + string id = 2; + string relay_server = 3; + NatType nat_type = 4; + string version = 5; +} + +message RegisterPk { + string id = 1; + bytes uuid = 2; + bytes pk = 3; + string old_id = 4; +} + +message RegisterPkResponse { + enum Result { + OK = 0; + UUID_MISMATCH = 2; + ID_EXISTS = 3; + TOO_FREQUENT = 4; + INVALID_ID_FORMAT = 5; + NOT_SUPPORT = 6; + SERVER_ERROR = 7; + } + Result result = 1; +} + +message PunchHoleResponse { + bytes socket_addr = 1; + bytes pk = 2; + enum Failure { + ID_NOT_EXIST = 0; + OFFLINE = 2; + LICENSE_MISMATCH = 3; + LICENSE_OVERUSE = 4; + } + Failure failure = 3; + string relay_server = 4; + oneof union { + NatType nat_type = 5; + bool is_local = 6; + } + string other_failure = 7; +} + +message ConfigUpdate { + int32 serial = 1; + repeated string rendezvous_servers = 2; +} + +message RequestRelay { + string id = 1; + string uuid = 2; + bytes socket_addr = 3; + string relay_server = 4; + bool secure = 5; + string licence_key = 6; + ConnType conn_type = 7; + string token = 8; +} + +message RelayResponse { + bytes socket_addr = 1; + string uuid = 2; + string relay_server = 3; + oneof union { + string id = 4; + bytes pk = 5; + } + string refuse_reason = 6; + string version = 7; +} + +message SoftwareUpdate { string url = 1; } + +// if in same intranet, punch hole won't work both for udp and tcp, +// even some router has below connection error if we connect itself, +// { kind: Other, error: "could not resolve to any address" }, +// so we request local address to connect. +message FetchLocalAddr { + bytes socket_addr = 1; + string relay_server = 2; +} + +message LocalAddr { + bytes socket_addr = 1; + bytes local_addr = 2; + string relay_server = 3; + string id = 4; + string version = 5; +} + +message PeerDiscovery { + string cmd = 1; + string mac = 2; + string id = 3; + string username = 4; + string hostname = 5; + string platform = 6; + string misc = 7; +} + +message RendezvousMessage { + oneof union { + RegisterPeer register_peer = 6; + RegisterPeerResponse register_peer_response = 7; + PunchHoleRequest punch_hole_request = 8; + PunchHole punch_hole = 9; + PunchHoleSent punch_hole_sent = 10; + PunchHoleResponse punch_hole_response = 11; + FetchLocalAddr fetch_local_addr = 12; + LocalAddr local_addr = 13; + ConfigUpdate configure_update = 14; + RegisterPk register_pk = 15; + RegisterPkResponse register_pk_response = 16; + SoftwareUpdate software_update = 17; + RequestRelay request_relay = 18; + RelayResponse relay_response = 19; + TestNatRequest test_nat_request = 20; + TestNatResponse test_nat_response = 21; + PeerDiscovery peer_discovery = 22; + } +} diff --git a/libs/hbb_common/src/bytes_codec.rs b/libs/hbb_common/src/bytes_codec.rs new file mode 100644 index 0000000..e029f1c --- /dev/null +++ b/libs/hbb_common/src/bytes_codec.rs @@ -0,0 +1,274 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::io; +use tokio_util::codec::{Decoder, Encoder}; + +#[derive(Debug, Clone, Copy)] +pub struct BytesCodec { + state: DecodeState, + raw: bool, + max_packet_length: usize, +} + +#[derive(Debug, Clone, Copy)] +enum DecodeState { + Head, + Data(usize), +} + +impl BytesCodec { + pub fn new() -> Self { + Self { + state: DecodeState::Head, + raw: false, + max_packet_length: usize::MAX, + } + } + + pub fn set_raw(&mut self) { + self.raw = true; + } + + pub fn set_max_packet_length(&mut self, n: usize) { + self.max_packet_length = n; + } + + fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> { + if src.is_empty() { + return Ok(None); + } + let head_len = ((src[0] & 0x3) + 1) as usize; + if src.len() < head_len { + return Ok(None); + } + let mut n = src[0] as usize; + if head_len > 1 { + n |= (src[1] as usize) << 8; + } + if head_len > 2 { + n |= (src[2] as usize) << 16; + } + if head_len > 3 { + n |= (src[3] as usize) << 24; + } + n >>= 2; + if n > self.max_packet_length { + return Err(io::Error::new(io::ErrorKind::InvalidData, "Too big packet")); + } + src.advance(head_len); + src.reserve(n); + return Ok(Some(n)); + } + + fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> { + if src.len() < n { + return Ok(None); + } + Ok(Some(src.split_to(n))) + } +} + +impl Decoder for BytesCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> { + if self.raw { + if !src.is_empty() { + let len = src.len(); + return Ok(Some(src.split_to(len))); + } else { + return Ok(None); + } + } + let n = match self.state { + DecodeState::Head => match self.decode_head(src)? { + Some(n) => { + self.state = DecodeState::Data(n); + n + } + None => return Ok(None), + }, + DecodeState::Data(n) => n, + }; + + match self.decode_data(n, src)? { + Some(data) => { + self.state = DecodeState::Head; + Ok(Some(data)) + } + None => Ok(None), + } + } +} + +impl Encoder<Bytes> for BytesCodec { + type Error = io::Error; + + fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> { + if self.raw { + buf.reserve(data.len()); + buf.put(data); + return Ok(()); + } + if data.len() <= 0x3F { + buf.put_u8((data.len() << 2) as u8); + } else if data.len() <= 0x3FFF { + buf.put_u16_le((data.len() << 2) as u16 | 0x1); + } else if data.len() <= 0x3FFFFF { + let h = (data.len() << 2) as u32 | 0x2; + buf.put_u16_le((h & 0xFFFF) as u16); + buf.put_u8((h >> 16) as u8); + } else if data.len() <= 0x3FFFFFFF { + buf.put_u32_le((data.len() << 2) as u32 | 0x3); + } else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "Overflow")); + } + buf.extend(data); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_codec1() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec<u8> = Vec::new(); + bytes.resize(0x3F, 1); + assert!(!codec.encode(bytes.into(), &mut buf).is_err()); + let buf_saved = buf.clone(); + assert_eq!(buf.len(), 0x3F + 1); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3F); + assert_eq!(res[0], 1); + } else { + assert!(false); + } + let mut codec2 = BytesCodec::new(); + let mut buf2 = BytesMut::new(); + if let Ok(None) = codec2.decode(&mut buf2) { + } else { + assert!(false); + } + buf2.extend(&buf_saved[0..1]); + if let Ok(None) = codec2.decode(&mut buf2) { + } else { + assert!(false); + } + buf2.extend(&buf_saved[1..]); + if let Ok(Some(res)) = codec2.decode(&mut buf2) { + assert_eq!(res.len(), 0x3F); + assert_eq!(res[0], 1); + } else { + assert!(false); + } + } + + #[test] + fn test_codec2() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec<u8> = Vec::new(); + assert!(!codec.encode("".into(), &mut buf).is_err()); + assert_eq!(buf.len(), 1); + bytes.resize(0x3F + 1, 2); + assert!(!codec.encode(bytes.into(), &mut buf).is_err()); + assert_eq!(buf.len(), 0x3F + 2 + 2); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0); + } else { + assert!(false); + } + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3F + 1); + assert_eq!(res[0], 2); + } else { + assert!(false); + } + } + + #[test] + fn test_codec3() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec<u8> = Vec::new(); + bytes.resize(0x3F - 1, 3); + assert!(!codec.encode(bytes.into(), &mut buf).is_err()); + assert_eq!(buf.len(), 0x3F + 1 - 1); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3F - 1); + assert_eq!(res[0], 3); + } else { + assert!(false); + } + } + #[test] + fn test_codec4() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec<u8> = Vec::new(); + bytes.resize(0x3FFF, 4); + assert!(!codec.encode(bytes.into(), &mut buf).is_err()); + assert_eq!(buf.len(), 0x3FFF + 2); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3FFF); + assert_eq!(res[0], 4); + } else { + assert!(false); + } + } + + #[test] + fn test_codec5() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec<u8> = Vec::new(); + bytes.resize(0x3FFFFF, 5); + assert!(!codec.encode(bytes.into(), &mut buf).is_err()); + assert_eq!(buf.len(), 0x3FFFFF + 3); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3FFFFF); + assert_eq!(res[0], 5); + } else { + assert!(false); + } + } + + #[test] + fn test_codec6() { + let mut codec = BytesCodec::new(); + let mut buf = BytesMut::new(); + let mut bytes: Vec<u8> = Vec::new(); + bytes.resize(0x3FFFFF + 1, 6); + assert!(!codec.encode(bytes.into(), &mut buf).is_err()); + let buf_saved = buf.clone(); + assert_eq!(buf.len(), 0x3FFFFF + 4 + 1); + if let Ok(Some(res)) = codec.decode(&mut buf) { + assert_eq!(res.len(), 0x3FFFFF + 1); + assert_eq!(res[0], 6); + } else { + assert!(false); + } + let mut codec2 = BytesCodec::new(); + let mut buf2 = BytesMut::new(); + buf2.extend(&buf_saved[0..1]); + if let Ok(None) = codec2.decode(&mut buf2) { + } else { + assert!(false); + } + buf2.extend(&buf_saved[1..6]); + if let Ok(None) = codec2.decode(&mut buf2) { + } else { + assert!(false); + } + buf2.extend(&buf_saved[6..]); + if let Ok(Some(res)) = codec2.decode(&mut buf2) { + assert_eq!(res.len(), 0x3FFFFF + 1); + assert_eq!(res[0], 6); + } else { + assert!(false); + } + } +} diff --git a/libs/hbb_common/src/compress.rs b/libs/hbb_common/src/compress.rs new file mode 100644 index 0000000..a969ccf --- /dev/null +++ b/libs/hbb_common/src/compress.rs @@ -0,0 +1,50 @@ +use std::cell::RefCell; +use zstd::block::{Compressor, Decompressor}; + +thread_local! { + static COMPRESSOR: RefCell<Compressor> = RefCell::new(Compressor::new()); + static DECOMPRESSOR: RefCell<Decompressor> = RefCell::new(Decompressor::new()); +} + +/// The library supports regular compression levels from 1 up to ZSTD_maxCLevel(), +/// which is currently 22. Levels >= 20 +/// Default level is ZSTD_CLEVEL_DEFAULT==3. +/// value 0 means default, which is controlled by ZSTD_CLEVEL_DEFAULT +pub fn compress(data: &[u8], level: i32) -> Vec<u8> { + let mut out = Vec::new(); + COMPRESSOR.with(|c| { + if let Ok(mut c) = c.try_borrow_mut() { + match c.compress(data, level) { + Ok(res) => out = res, + Err(err) => { + crate::log::debug!("Failed to compress: {}", err); + } + } + } + }); + out +} + +pub fn decompress(data: &[u8]) -> Vec<u8> { + let mut out = Vec::new(); + DECOMPRESSOR.with(|d| { + if let Ok(mut d) = d.try_borrow_mut() { + const MAX: usize = 1024 * 1024 * 64; + const MIN: usize = 1024 * 1024; + let mut n = 30 * data.len(); + if n > MAX { + n = MAX; + } + if n < MIN { + n = MIN; + } + match d.decompress(data, n) { + Ok(res) => out = res, + Err(err) => { + crate::log::debug!("Failed to decompress: {}", err); + } + } + } + }); + out +} diff --git a/libs/hbb_common/src/config.rs b/libs/hbb_common/src/config.rs new file mode 100644 index 0000000..a7c1bc6 --- /dev/null +++ b/libs/hbb_common/src/config.rs @@ -0,0 +1,876 @@ +use crate::log; +use directories_next::ProjectDirs; +use rand::Rng; +use serde_derive::{Deserialize, Serialize}; +use sodiumoxide::crypto::sign; +use std::{ + collections::HashMap, + fs, + net::{IpAddr, Ipv4Addr, SocketAddr}, + path::{Path, PathBuf}, + sync::{Arc, Mutex, RwLock}, + time::SystemTime, +}; + +pub const RENDEZVOUS_TIMEOUT: u64 = 12_000; +pub const CONNECT_TIMEOUT: u64 = 18_000; +pub const REG_INTERVAL: i64 = 12_000; +pub const COMPRESS_LEVEL: i32 = 3; +const SERIAL: i32 = 1; +// 128x128 +#[cfg(target_os = "macos")] // 128x128 on 160x160 canvas, then shrink to 128, mac looks better with padding +pub const ICON: &str = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAIAAAACACAMAAAD04JH5AAAAyVBMVEUAAAAAcf8Acf8Acf8Acv8Acf8Acf8Acf8Acf8AcP8Acf8Ab/8AcP8Acf////8AaP/z+f/o8v/k7v/5/v/T5f8AYP/u9v/X6f+hx/+Kuv95pP8Aef/B1/+TwP9xoP8BdP/g6P+Irv9ZmP8Bgf/E3f98q/9sn/+01f+Es/9nm/9Jif8hhv8off/M4P+syP+avP86iP/c7f+xy/9yqf9Om/9hk/9Rjv+60P99tv9fpf88lv8yjf8Tgf8deP+kvP8BiP8NeP8hkP80gP8oj2VLAAAADXRSTlMA7o7qLvnaxZ1FOxYPjH9HWgAABHJJREFUeNrtm+tW4jAQgBfwuu7MtIUWsOUiCCioIIgLiqvr+z/UHq/LJKVkmwTcc/r9E2nzlU4mSTP9lpGRkZGR8VX5cZjfL+yCEXYL+/nDH//U/Pd8DgyTy39Xbv7oIAcWyB0cqbW/sweW2NtRaj8H1sgpGOwUIAH7Bkd7YJW9dXFwAJY5WNP/cmCZQnJvzIN18on5LwfWySXlxEPYAIcad8D6PdiHDbCfIFCADVBIENiFDbCbIACKPPXrZ+cP8E6/0znvP4EymgIEravIRcTxu8HxNSJ60a8W0AYECKrlAN+YwAthCd9wm1Ug6wKzIn5SgRduXfwkqDasCjx0XFzi9PV6zwNcIuhcWBOg+ikySq8C9UD4dEKWBCoOcspvAuLHTo9sCDQiFPHotRM48j8G5gVur1FdAN2uaYEuiz7xFsgEJ2RUoMUakXuBTHHoGxQYOBhHjeUBAefEnMAowFhaLBOKuOemBBbxLRQrH2PBCgMvNCPQGMeevTb9zLrPxz2Mo+QbEaijzPUcOOHMQZkKGRAIPem39+bypREMPTkQW/oCfk866zAkiIFG4yIKRE/aAnfiSd0WrORY6pFdXQEqi9mvAQm0RIOSnoCcZ8vJoz3diCnjRk+g8VP4/fuQDJ2Lxr6WwG0gXs9aTpDzW0vgDBlVUpixR8gYk44AD8FrUKHr8JQJGgIDnoDqoALxmWPQSi9AVVzm8gKUuEPGr/QCvptwJkbSYT/TC4S8C96DGjTj86aHtAI0x2WaBIq0eSYYpRa4EsdWVVwWu9O0Aj6f6dyBMnwEraeOgSYu0wZlauzA47QCbT7DgAQSE+hZWoEBF/BBmWOewNMK3BsSqKUW4MGcWqCSVmDkbvkXGKQOwg6PAUO9oL3xXhA20yaiCjuwYygRVQlUOTWTCf2SuNJTxeFjgaHByGuAIvd8ItdPLTDhS7IuqEE1YSKVOgbayLhSFQhMzYh8hwfBs1r7c505YVIQYEdNoKwxK06MJiyrpUFHiF0NAfCQUVHoiRclIXJIR6C2fqG37pBHvcWpgwzvAtYwkR5UGV2e42UISdBJETl3mg8ouo54Rcnti1/vaT+iuUQBt500Cgo4U10BeHSkk57FB0JjWkKRMWgLUA0lLodtImAQdaMiiri3+gIAPZQoutHNsgKF1aaDMhMyIdBf8Th+Bh8MTjGWCpl5Wv43tDmnF+IUVMrcZgRoiAxhtrloYizNkZaAnF5leglbNhj0wYCAbCDvGb0mP4nib7O7ZlcYQ2m1gPtIZgVgGNNMeaVAaWR+57TrqgtUnm3sHQ+kYeE6fufUubG1ez50FXbPnWgBlgSABmN3TTcsRl2yWkHRrwbiunvk/W2+Mg1hPZplPDeXRbZzStFH15s1QIVd3UImP5z/bHpeeQLvRJ7XLFUffQIlCvqlXETQbgN9/rlYABGosv+Vi9m2Xs639YLGrZd0br+odetlvdsvbN56abfd4vbCzv9Q3v/ygoOV21A4OPpfXvH4Ai+5ZGRkZGRkbJA/t/I0QMzoMiEAAAAASUVORK5CYII= +"; +#[cfg(not(target_os = "macos"))] // 128x128 no padding +pub const ICON: &str = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAIAAAACACAMAAAD04JH5AAAA7VBMVEUAAAAAcf8Acf8Acf8Adf8Acf8Acf8AcP8Acv8AcP8Acf8Acf8Acf8Acv8Acf8Acf8Ab/8AcP8Acf8Acf8Acf/////7/f8Dc/8TfP/1+f/n8v9Hmf/u9v+Uw//Q5f9hp/8Yfv8Qev8Ld/+52P+z1f+s0f81j/8wjP8Hdf/3+/8mh/8fg//x9//h7//H4P9xsP9rrf9oq/8rif/r9P/D3v+92/+Duv9bpP/d7f/U5/9NnP8/lP8jhP/L4v/B3P+OwP9+t/95tf9Rn/8bgf/Z6v+Zx/90sv9lqf85kf+hy/9UoP+Wxf+kzP+dyP+Lvv/H4q8IAAAAFHRSTlMA+u6bB6x5XR4V0+S4i4k5N+a81W8MiAQAAAVcSURBVHjazdvpWtpAGIbhgEutdW3fL2GHsMsiq4KI+66t5384XahF/GbizJAy3j/1Ah5CJhNCxpm1vbryLRrBfxKJrq+sbjtSa5u7WIDdzTVH5PNSBAsSWfrsMJ+iWKDoJ2fW8hIWbGl55vW/YuE2XhUsb8CCr9OCJVix9G//gyWf/o6/KCyJfrbwAfAPYS0CayK/j4mbsGjrV8AXWLTrONuwasdZhVWrzgqsWnG+wap1Jwqrok4EVkUcmKhdVvBaOVnzYEY/oJpMD4mo6ONF/ZSIUsX2FZjQA7xRqUET+y/v2W/Sy59u62DCDMgdJmhqgIk7eqWQBBNWwPhmj147w8QTzTjKVsGEEBBLuzSrhIkivTF8DD/Aa6forQNMHBD/VyXkgHGfuBN5ALln1TADOnESyGCiT8L/1kILqD6Q0BEm9kkofhdSwNUJiV1jQvZ/SnthBNSaJJGZbgGJUnX+gEqCZPpsJ2T2Y/MGVBrE8eOAvCA/X8A4QXLnmEhTgIPqPAG5IQU4fhmkFOT7HAFenwIU8Jd/TUEODQIUtu1eOj/dUD9cknOTpgEDkup3YrOfVStDUomcWcBVisTiNxVw3TPpgCl4RgFFybZ/9iHmn8uS2yYBA8m7qUEu9oOEejH9gHxC+PazCHbcFM8K+gGHJNAs4z2xgnAkVHQDcnG1IzvnCSfvom7AM3EZ9voah4+KXoAvGFJHMSgqEfegF3BBTKoOVfkMMXFfJ8AT7MuXUDeOE9PWCUiKBpKOlmAP1gngH2LChw7vhJgr9YD8Hnt0BxrE27CtHnDJR4AHTX1+KFAP4Ef0LHTxN9HwlAMSbAjmoavKZ8ayakDXYAhwN3wzqgZk2UPvwRjshmeqATeCT09f3mWnEqoBGf4NxAB/moRqADuOtmDiid6KqQVcsQeOYOKW3uqqBRwL5nITj/yrlFpAVrDpTJT5llQLaLMHwshY7UDgvD+VujDC96WWWsBtSAE5FnChFnAeUkDMdAvw88EqTNT5SYXpTlgPaRQM1AIGorkolNnoUS1gJHigCX48SaoF3Asuspg4Mz0U8+FTgIkCG01V09kwBQP8xG5ofD5AXeirkPEJSUlwSVIfP5ykVQNaggvz+k7prTvVgDKF8BnUXP4kqgEe/257E8Ig7EE1gA8g2stBTz7FLxqrB3SIeYaeQ2IG6gE5l2+Cmt5MGOfP4KsGiH8DOYWOoujnDY2ALHF3810goZFOQDVBTFx9Uj7eI6bp6QTgnLjeGGq6KeJuoRUQixN3pDYWyz1Rva8XIL5UPFQZCsmG3gV7R+dieS+Jd3iHLglce7oBuCOhp3zwHLxPQpfQDvBOSKjZqUIml3ZJ6AD6AajFSZJwewWR8ZPsEY26SQDaJOMeZP23w6bTJ6kBjAJQILm9hzqm7otu4G+nhgGxIQUlPLKzL7GhbxqAboMCuN2XXd+lAL0ajAMwclV+FD6jAPEy5ghAlhfwX2FODX445gHKxyN++fs64PUHmDMAbbYN2DlKk2QaScwdgMs4SZxMv4OJJSoIIQBl2Qtk3gk4qiOUANRPJQHB+0A6j5AC4J27QQEZ4eZPAsYBXFk0N/YD7iUrxRBqALxOTzoMC3x8lCFlfkMjuz8iLfk6fzQCQgjg8q3ZEd8RzUVuKelBh96Nzcc3qelL1V+2zfRv1xc56Ino3tpdPT7cd//MspfTrD/7R6p4W4O2qLMObfnyIHvvYcrPtkZjDybW7d/eb32Bg/UlHnYXuXz5CMt8rC90sr7Uy/5iN+vL/ewveLS/5NNKwcbyR1r2a3/h8wdY+v3L2tZC5oUvW2uO1M7qyvp/Xv6/48z4CTxjJEfyjEaMAAAAAElFTkSuQmCC +"; +#[cfg(target_os = "macos")] +lazy_static::lazy_static! { + pub static ref ORG: Arc<RwLock<String>> = Arc::new(RwLock::new("com.carriez".to_owned())); +} + +type Size = (i32, i32, i32, i32); + +lazy_static::lazy_static! { + static ref CONFIG: Arc<RwLock<Config>> = Arc::new(RwLock::new(Config::load())); + static ref CONFIG2: Arc<RwLock<Config2>> = Arc::new(RwLock::new(Config2::load())); + static ref LOCAL_CONFIG: Arc<RwLock<LocalConfig>> = Arc::new(RwLock::new(LocalConfig::load())); + pub static ref ONLINE: Arc<Mutex<HashMap<String, i64>>> = Default::default(); + pub static ref PROD_RENDEZVOUS_SERVER: Arc<RwLock<String>> = Default::default(); + pub static ref APP_NAME: Arc<RwLock<String>> = Arc::new(RwLock::new("RustDesk".to_owned())); +} +#[cfg(any(target_os = "android", target_os = "ios"))] +lazy_static::lazy_static! { + pub static ref APP_DIR: Arc<RwLock<String>> = Default::default(); + pub static ref APP_HOME_DIR: Arc<RwLock<String>> = Default::default(); +} +const CHARS: &'static [char] = &[ + '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', + 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', +]; + +pub const RENDEZVOUS_SERVERS: &'static [&'static str] = &[ + "rs-ny.rustdesk.com", + "rs-sg.rustdesk.com", + "rs-cn.rustdesk.com", +]; +pub const RENDEZVOUS_PORT: i32 = 21116; +pub const RELAY_PORT: i32 = 21117; + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum NetworkType { + Direct, + ProxySocks, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +pub struct Config { + #[serde(default)] + pub id: String, + #[serde(default)] + password: String, + #[serde(default)] + salt: String, + #[serde(default)] + pub key_pair: (Vec<u8>, Vec<u8>), // sk, pk + #[serde(default)] + key_confirmed: bool, + #[serde(default)] + keys_confirmed: HashMap<String, bool>, +} + +#[derive(Debug, Default, PartialEq, Serialize, Deserialize, Clone)] +pub struct Socks5Server { + #[serde(default)] + pub proxy: String, + #[serde(default)] + pub username: String, + #[serde(default)] + pub password: String, +} + +// more variable configs +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +pub struct Config2 { + #[serde(default)] + rendezvous_server: String, + #[serde(default)] + nat_type: i32, + #[serde(default)] + serial: i32, + + #[serde(default)] + socks: Option<Socks5Server>, + + // the other scalar value must before this + #[serde(default)] + pub options: HashMap<String, String>, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct PeerConfig { + #[serde(default)] + pub password: Vec<u8>, + #[serde(default)] + pub size: Size, + #[serde(default)] + pub size_ft: Size, + #[serde(default)] + pub size_pf: Size, + #[serde(default)] + pub view_style: String, // original (default), scale + #[serde(default)] + pub image_quality: String, + #[serde(default)] + pub custom_image_quality: Vec<i32>, + #[serde(default)] + pub show_remote_cursor: bool, + #[serde(default)] + pub lock_after_session_end: bool, + #[serde(default)] + pub privacy_mode: bool, + #[serde(default)] + pub port_forwards: Vec<(i32, String, i32)>, + #[serde(default)] + pub direct_failures: i32, + #[serde(default)] + pub disable_audio: bool, + #[serde(default)] + pub disable_clipboard: bool, + #[serde(default)] + pub enable_file_transfer: bool, + + // the other scalar value must before this + #[serde(default)] + pub options: HashMap<String, String>, + #[serde(default)] + pub info: PeerInfoSerde, +} + +#[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)] +pub struct PeerInfoSerde { + #[serde(default)] + pub username: String, + #[serde(default)] + pub hostname: String, + #[serde(default)] + pub platform: String, +} + +fn patch(path: PathBuf) -> PathBuf { + if let Some(_tmp) = path.to_str() { + #[cfg(windows)] + return _tmp + .replace( + "system32\\config\\systemprofile", + "ServiceProfiles\\LocalService", + ) + .into(); + #[cfg(target_os = "macos")] + return _tmp.replace("Application Support", "Preferences").into(); + #[cfg(target_os = "linux")] + { + if _tmp == "/root" { + if let Ok(output) = std::process::Command::new("whoami").output() { + let user = String::from_utf8_lossy(&output.stdout) + .to_string() + .trim() + .to_owned(); + if user != "root" { + return format!("/home/{}", user).into(); + } + } + } + } + } + path +} + +impl Config2 { + fn load() -> Config2 { + Config::load_::<Config2>("2") + } + + pub fn file() -> PathBuf { + Config::file_("2") + } + + fn store(&self) { + Config::store_(self, "2"); + } + + pub fn get() -> Config2 { + return CONFIG2.read().unwrap().clone(); + } + + pub fn set(cfg: Config2) -> bool { + let mut lock = CONFIG2.write().unwrap(); + if *lock == cfg { + return false; + } + *lock = cfg; + lock.store(); + true + } +} + +pub fn load_path<T: serde::Serialize + serde::de::DeserializeOwned + Default + std::fmt::Debug>( + file: PathBuf, +) -> T { + let cfg = match confy::load_path(&file) { + Ok(config) => config, + Err(err) => { + log::error!("Failed to load config: {}", err); + T::default() + } + }; + cfg +} + +impl Config { + fn load_<T: serde::Serialize + serde::de::DeserializeOwned + Default + std::fmt::Debug>( + suffix: &str, + ) -> T { + let file = Self::file_(suffix); + log::debug!("Configuration path: {}", file.display()); + let cfg = load_path(file); + if suffix.is_empty() { + log::debug!("{:?}", cfg); + } + cfg + } + + fn store_<T: serde::Serialize>(config: &T, suffix: &str) { + let file = Self::file_(suffix); + if let Err(err) = confy::store_path(file, config) { + log::error!("Failed to store config: {}", err); + } + } + + fn load() -> Config { + Config::load_::<Config>("") + } + + fn store(&self) { + Config::store_(self, ""); + } + + pub fn file() -> PathBuf { + Self::file_("") + } + + fn file_(suffix: &str) -> PathBuf { + let name = format!("{}{}", *APP_NAME.read().unwrap(), suffix); + Self::path(name).with_extension("toml") + } + + pub fn get_home() -> PathBuf { + #[cfg(any(target_os = "android", target_os = "ios"))] + return Self::path(APP_HOME_DIR.read().unwrap().as_str()); + if let Some(path) = dirs_next::home_dir() { + patch(path) + } else if let Ok(path) = std::env::current_dir() { + path + } else { + std::env::temp_dir() + } + } + + pub fn path<P: AsRef<Path>>(p: P) -> PathBuf { + #[cfg(any(target_os = "android", target_os = "ios"))] + { + let mut path: PathBuf = APP_DIR.read().unwrap().clone().into(); + path.push(p); + return path; + } + #[cfg(not(target_os = "macos"))] + let org = ""; + #[cfg(target_os = "macos")] + let org = ORG.read().unwrap().clone(); + // /var/root for root + if let Some(project) = ProjectDirs::from("", &org, &*APP_NAME.read().unwrap()) { + let mut path = patch(project.config_dir().to_path_buf()); + path.push(p); + return path; + } + return "".into(); + } + + #[allow(unreachable_code)] + pub fn log_path() -> PathBuf { + #[cfg(target_os = "macos")] + { + if let Some(path) = dirs_next::home_dir().as_mut() { + path.push(format!("Library/Logs/{}", *APP_NAME.read().unwrap())); + return path.clone(); + } + } + #[cfg(target_os = "linux")] + { + let mut path = Self::get_home(); + path.push(format!(".local/share/logs/{}", *APP_NAME.read().unwrap())); + std::fs::create_dir_all(&path).ok(); + return path; + } + if let Some(path) = Self::path("").parent() { + let mut path: PathBuf = path.into(); + path.push("log"); + return path; + } + "".into() + } + + pub fn ipc_path(postfix: &str) -> String { + #[cfg(windows)] + { + // \\ServerName\pipe\PipeName + // where ServerName is either the name of a remote computer or a period, to specify the local computer. + // https://docs.microsoft.com/en-us/windows/win32/ipc/pipe-names + format!( + "\\\\.\\pipe\\{}\\query{}", + *APP_NAME.read().unwrap(), + postfix + ) + } + #[cfg(not(windows))] + { + use std::os::unix::fs::PermissionsExt; + #[cfg(target_os = "android")] + let mut path: PathBuf = + format!("{}/{}", *APP_DIR.read().unwrap(), *APP_NAME.read().unwrap()).into(); + #[cfg(not(target_os = "android"))] + let mut path: PathBuf = format!("/tmp/{}", *APP_NAME.read().unwrap()).into(); + fs::create_dir(&path).ok(); + fs::set_permissions(&path, fs::Permissions::from_mode(0o0777)).ok(); + path.push(format!("ipc{}", postfix)); + path.to_str().unwrap_or("").to_owned() + } + } + + pub fn icon_path() -> PathBuf { + let mut path = Self::path("icons"); + if fs::create_dir_all(&path).is_err() { + path = std::env::temp_dir(); + } + path + } + + #[inline] + pub fn get_any_listen_addr() -> SocketAddr { + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) + } + + pub fn get_rendezvous_server() -> String { + let mut rendezvous_server = Self::get_option("custom-rendezvous-server"); + if rendezvous_server.is_empty() { + rendezvous_server = PROD_RENDEZVOUS_SERVER.read().unwrap().clone(); + } + if rendezvous_server.is_empty() { + rendezvous_server = CONFIG2.read().unwrap().rendezvous_server.clone(); + } + if rendezvous_server.is_empty() { + rendezvous_server = Self::get_rendezvous_servers() + .drain(..) + .next() + .unwrap_or("".to_owned()); + } + if !rendezvous_server.contains(":") { + rendezvous_server = format!("{}:{}", rendezvous_server, RENDEZVOUS_PORT); + } + rendezvous_server + } + + pub fn get_rendezvous_servers() -> Vec<String> { + let s = Self::get_option("custom-rendezvous-server"); + if !s.is_empty() { + return vec![s]; + } + let s = PROD_RENDEZVOUS_SERVER.read().unwrap().clone(); + if !s.is_empty() { + return vec![s]; + } + let serial_obsolute = CONFIG2.read().unwrap().serial > SERIAL; + if serial_obsolute { + let ss: Vec<String> = Self::get_option("rendezvous-servers") + .split(",") + .filter(|x| x.contains(".")) + .map(|x| x.to_owned()) + .collect(); + if !ss.is_empty() { + return ss; + } + } + return RENDEZVOUS_SERVERS.iter().map(|x| x.to_string()).collect(); + } + + pub fn reset_online() { + *ONLINE.lock().unwrap() = Default::default(); + } + + pub fn update_latency(host: &str, latency: i64) { + ONLINE.lock().unwrap().insert(host.to_owned(), latency); + let mut host = "".to_owned(); + let mut delay = i64::MAX; + for (tmp_host, tmp_delay) in ONLINE.lock().unwrap().iter() { + if tmp_delay > &0 && tmp_delay < &delay { + delay = tmp_delay.clone(); + host = tmp_host.to_string(); + } + } + if !host.is_empty() { + let mut config = CONFIG2.write().unwrap(); + if host != config.rendezvous_server { + log::debug!("Update rendezvous_server in config to {}", host); + log::debug!("{:?}", *ONLINE.lock().unwrap()); + config.rendezvous_server = host; + config.store(); + } + } + } + + pub fn set_id(id: &str) { + let mut config = CONFIG.write().unwrap(); + if id == config.id { + return; + } + config.id = id.into(); + config.store(); + } + + pub fn set_nat_type(nat_type: i32) { + let mut config = CONFIG2.write().unwrap(); + if nat_type == config.nat_type { + return; + } + config.nat_type = nat_type; + config.store(); + } + + pub fn get_nat_type() -> i32 { + CONFIG2.read().unwrap().nat_type + } + + pub fn set_serial(serial: i32) { + let mut config = CONFIG2.write().unwrap(); + if serial == config.serial { + return; + } + config.serial = serial; + config.store(); + } + + pub fn get_serial() -> i32 { + std::cmp::max(CONFIG2.read().unwrap().serial, SERIAL) + } + + fn get_auto_id() -> Option<String> { + #[cfg(any(target_os = "android", target_os = "ios"))] + { + return Some( + rand::thread_rng() + .gen_range(1_000_000_000..2_000_000_000) + .to_string(), + ); + } + let mut id = 0u32; + #[cfg(not(any(target_os = "android", target_os = "ios")))] + if let Ok(Some(ma)) = mac_address::get_mac_address() { + for x in &ma.bytes()[2..] { + id = (id << 8) | (*x as u32); + } + id = id & 0x1FFFFFFF; + Some(id.to_string()) + } else { + None + } + } + + pub fn get_auto_password() -> String { + let mut rng = rand::thread_rng(); + (0..6) + .map(|_| CHARS[rng.gen::<usize>() % CHARS.len()]) + .collect() + } + + pub fn get_key_confirmed() -> bool { + CONFIG.read().unwrap().key_confirmed + } + + pub fn set_key_confirmed(v: bool) { + let mut config = CONFIG.write().unwrap(); + if config.key_confirmed == v { + return; + } + config.key_confirmed = v; + if !v { + config.keys_confirmed = Default::default(); + } + config.store(); + } + + pub fn get_host_key_confirmed(host: &str) -> bool { + if let Some(true) = CONFIG.read().unwrap().keys_confirmed.get(host) { + true + } else { + false + } + } + + pub fn set_host_key_confirmed(host: &str, v: bool) { + if Self::get_host_key_confirmed(host) == v { + return; + } + let mut config = CONFIG.write().unwrap(); + config.keys_confirmed.insert(host.to_owned(), v); + config.store(); + } + + pub fn set_key_pair(pair: (Vec<u8>, Vec<u8>)) { + let mut config = CONFIG.write().unwrap(); + if config.key_pair == pair { + return; + } + config.key_pair = pair; + config.store(); + } + + pub fn get_key_pair() -> (Vec<u8>, Vec<u8>) { + // lock here to make sure no gen_keypair more than once + let mut config = CONFIG.write().unwrap(); + if config.key_pair.0.is_empty() { + let (pk, sk) = sign::gen_keypair(); + config.key_pair = (sk.0.to_vec(), pk.0.into()); + config.store(); + } + config.key_pair.clone() + } + + pub fn get_id() -> String { + let mut id = CONFIG.read().unwrap().id.clone(); + if id.is_empty() { + if let Some(tmp) = Config::get_auto_id() { + id = tmp; + Config::set_id(&id); + } + } + id + } + + pub fn get_id_or(b: String) -> String { + let a = CONFIG.read().unwrap().id.clone(); + if a.is_empty() { + b + } else { + a + } + } + + pub fn get_options() -> HashMap<String, String> { + CONFIG2.read().unwrap().options.clone() + } + + pub fn set_options(v: HashMap<String, String>) { + let mut config = CONFIG2.write().unwrap(); + if config.options == v { + return; + } + config.options = v; + config.store(); + } + + pub fn get_option(k: &str) -> String { + if let Some(v) = CONFIG2.read().unwrap().options.get(k) { + v.clone() + } else { + "".to_owned() + } + } + + pub fn set_option(k: String, v: String) { + let mut config = CONFIG2.write().unwrap(); + let v2 = if v.is_empty() { None } else { Some(&v) }; + if v2 != config.options.get(&k) { + if v2.is_none() { + config.options.remove(&k); + } else { + config.options.insert(k, v); + } + config.store(); + } + } + + pub fn update_id() { + // to-do: how about if one ip register a lot of ids? + let id = Self::get_id(); + let mut rng = rand::thread_rng(); + let new_id = rng.gen_range(1_000_000_000..2_000_000_000).to_string(); + Config::set_id(&new_id); + log::info!("id updated from {} to {}", id, new_id); + } + + pub fn set_password(password: &str) { + let mut config = CONFIG.write().unwrap(); + if password == config.password { + return; + } + config.password = password.into(); + config.store(); + } + + pub fn get_password() -> String { + let mut password = CONFIG.read().unwrap().password.clone(); + if password.is_empty() { + password = Config::get_auto_password(); + Config::set_password(&password); + } + password + } + + pub fn set_salt(salt: &str) { + let mut config = CONFIG.write().unwrap(); + if salt == config.salt { + return; + } + config.salt = salt.into(); + config.store(); + } + + pub fn get_salt() -> String { + let mut salt = CONFIG.read().unwrap().salt.clone(); + if salt.is_empty() { + salt = Config::get_auto_password(); + Config::set_salt(&salt); + } + salt + } + + pub fn set_socks(socks: Option<Socks5Server>) { + let mut config = CONFIG2.write().unwrap(); + if config.socks == socks { + return; + } + config.socks = socks; + config.store(); + } + + pub fn get_socks() -> Option<Socks5Server> { + CONFIG2.read().unwrap().socks.clone() + } + + pub fn get_network_type() -> NetworkType { + match &CONFIG2.read().unwrap().socks { + None => NetworkType::Direct, + Some(_) => NetworkType::ProxySocks, + } + } + + pub fn get() -> Config { + return CONFIG.read().unwrap().clone(); + } + + pub fn set(cfg: Config) -> bool { + let mut lock = CONFIG.write().unwrap(); + if *lock == cfg { + return false; + } + *lock = cfg; + lock.store(); + true + } +} + +const PEERS: &str = "peers"; + +impl PeerConfig { + pub fn load(id: &str) -> PeerConfig { + let _ = CONFIG.read().unwrap(); // for lock + match confy::load_path(&Self::path(id)) { + Ok(config) => config, + Err(err) => { + log::error!("Failed to load config: {}", err); + Default::default() + } + } + } + + pub fn store(&self, id: &str) { + let _ = CONFIG.read().unwrap(); // for lock + if let Err(err) = confy::store_path(Self::path(id), self) { + log::error!("Failed to store config: {}", err); + } + } + + pub fn remove(id: &str) { + fs::remove_file(&Self::path(id)).ok(); + } + + fn path(id: &str) -> PathBuf { + let path: PathBuf = [PEERS, id].iter().collect(); + Config::path(path).with_extension("toml") + } + + pub fn peers() -> Vec<(String, SystemTime, PeerConfig)> { + if let Ok(peers) = Config::path(PEERS).read_dir() { + if let Ok(peers) = peers + .map(|res| res.map(|e| e.path())) + .collect::<Result<Vec<_>, _>>() + { + let mut peers: Vec<_> = peers + .iter() + .filter(|p| { + p.is_file() + && p.extension().map(|p| p.to_str().unwrap_or("")) == Some("toml") + }) + .map(|p| { + let t = crate::get_modified_time(&p); + let id = p + .file_stem() + .map(|p| p.to_str().unwrap_or("")) + .unwrap_or("") + .to_owned(); + let c = PeerConfig::load(&id); + if c.info.platform.is_empty() { + fs::remove_file(&p).ok(); + } + (id, t, c) + }) + .filter(|p| !p.2.info.platform.is_empty()) + .collect(); + peers.sort_unstable_by(|a, b| b.1.cmp(&a.1)); + return peers; + } + } + Default::default() + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct LocalConfig { + #[serde(default)] + remote_id: String, // latest used one + #[serde(default)] + size: Size, + #[serde(default)] + pub fav: Vec<String>, + #[serde(default)] + options: HashMap<String, String>, +} + +impl LocalConfig { + fn load() -> LocalConfig { + Config::load_::<LocalConfig>("_local") + } + + fn store(&self) { + Config::store_(self, "_local"); + } + + pub fn get_size() -> Size { + LOCAL_CONFIG.read().unwrap().size + } + + pub fn set_size(x: i32, y: i32, w: i32, h: i32) { + let mut config = LOCAL_CONFIG.write().unwrap(); + let size = (x, y, w, h); + if size == config.size || size.2 < 300 || size.3 < 300 { + return; + } + config.size = size; + config.store(); + } + + pub fn set_remote_id(remote_id: &str) { + let mut config = LOCAL_CONFIG.write().unwrap(); + if remote_id == config.remote_id { + return; + } + config.remote_id = remote_id.into(); + config.store(); + } + + pub fn get_remote_id() -> String { + LOCAL_CONFIG.read().unwrap().remote_id.clone() + } + + pub fn set_fav(fav: Vec<String>) { + let mut lock = LOCAL_CONFIG.write().unwrap(); + if lock.fav == fav { + return; + } + lock.fav = fav; + lock.store(); + } + + pub fn get_fav() -> Vec<String> { + LOCAL_CONFIG.read().unwrap().fav.clone() + } + + pub fn get_option(k: &str) -> String { + if let Some(v) = LOCAL_CONFIG.read().unwrap().options.get(k) { + v.clone() + } else { + "".to_owned() + } + } + + pub fn set_option(k: String, v: String) { + let mut config = LOCAL_CONFIG.write().unwrap(); + let v2 = if v.is_empty() { None } else { Some(&v) }; + if v2 != config.options.get(&k) { + if v2.is_none() { + config.options.remove(&k); + } else { + config.options.insert(k, v); + } + config.store(); + } + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct LanPeers { + #[serde(default)] + pub peers: String, +} + +impl LanPeers { + pub fn load() -> LanPeers { + let _ = CONFIG.read().unwrap(); // for lock + match confy::load_path(&Config::file_("_lan_peers")) { + Ok(peers) => peers, + Err(err) => { + log::error!("Failed to load lan peers: {}", err); + Default::default() + } + } + } + + pub fn store(peers: String) { + let f = LanPeers { peers }; + if let Err(err) = confy::store_path(Config::file_("_lan_peers"), f) { + log::error!("Failed to store lan peers: {}", err); + } + } + + pub fn modify_time() -> crate::ResultType<u64> { + let p = Config::file_("_lan_peers"); + Ok(fs::metadata(p)? + .modified()? + .duration_since(SystemTime::UNIX_EPOCH)? + .as_millis() as _) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_serialize() { + let cfg: Config = Default::default(); + let res = toml::to_string_pretty(&cfg); + assert!(res.is_ok()); + let cfg: PeerConfig = Default::default(); + let res = toml::to_string_pretty(&cfg); + assert!(res.is_ok()); + } +} diff --git a/libs/hbb_common/src/fs.rs b/libs/hbb_common/src/fs.rs new file mode 100644 index 0000000..475f4df --- /dev/null +++ b/libs/hbb_common/src/fs.rs @@ -0,0 +1,560 @@ +use crate::{bail, message_proto::*, ResultType}; +use std::path::{Path, PathBuf}; +// https://doc.rust-lang.org/std/os/windows/fs/trait.MetadataExt.html +use crate::{ + compress::{compress, decompress}, + config::{Config, COMPRESS_LEVEL}, +}; +#[cfg(windows)] +use std::os::windows::prelude::*; +use tokio::{fs::File, io::*}; + +pub fn read_dir(path: &PathBuf, include_hidden: bool) -> ResultType<FileDirectory> { + let mut dir = FileDirectory { + path: get_string(&path), + ..Default::default() + }; + #[cfg(windows)] + if "/" == &get_string(&path) { + let drives = unsafe { winapi::um::fileapi::GetLogicalDrives() }; + for i in 0..32 { + if drives & (1 << i) != 0 { + let name = format!( + "{}:", + std::char::from_u32('A' as u32 + i as u32).unwrap_or('A') + ); + dir.entries.push(FileEntry { + name, + entry_type: FileType::DirDrive.into(), + ..Default::default() + }); + } + } + return Ok(dir); + } + for entry in path.read_dir()? { + if let Ok(entry) = entry { + let p = entry.path(); + let name = p + .file_name() + .map(|p| p.to_str().unwrap_or("")) + .unwrap_or("") + .to_owned(); + if name.is_empty() { + continue; + } + let mut is_hidden = false; + let meta; + if let Ok(tmp) = std::fs::symlink_metadata(&p) { + meta = tmp; + } else { + continue; + } + // docs.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants + #[cfg(windows)] + if meta.file_attributes() & 0x2 != 0 { + is_hidden = true; + } + #[cfg(not(windows))] + if name.find('.').unwrap_or(usize::MAX) == 0 { + is_hidden = true; + } + if is_hidden && !include_hidden { + continue; + } + let (entry_type, size) = { + if p.is_dir() { + if meta.file_type().is_symlink() { + (FileType::DirLink.into(), 0) + } else { + (FileType::Dir.into(), 0) + } + } else { + if meta.file_type().is_symlink() { + (FileType::FileLink.into(), 0) + } else { + (FileType::File.into(), meta.len()) + } + } + }; + let modified_time = meta + .modified() + .map(|x| { + x.duration_since(std::time::SystemTime::UNIX_EPOCH) + .map(|x| x.as_secs()) + .unwrap_or(0) + }) + .unwrap_or(0) as u64; + dir.entries.push(FileEntry { + name: get_file_name(&p), + entry_type, + is_hidden, + size, + modified_time, + ..Default::default() + }); + } + } + Ok(dir) +} + +#[inline] +pub fn get_file_name(p: &PathBuf) -> String { + p.file_name() + .map(|p| p.to_str().unwrap_or("")) + .unwrap_or("") + .to_owned() +} + +#[inline] +pub fn get_string(path: &PathBuf) -> String { + path.to_str().unwrap_or("").to_owned() +} + +#[inline] +pub fn get_path(path: &str) -> PathBuf { + Path::new(path).to_path_buf() +} + +#[inline] +pub fn get_home_as_string() -> String { + get_string(&Config::get_home()) +} + +fn read_dir_recursive( + path: &PathBuf, + prefix: &PathBuf, + include_hidden: bool, +) -> ResultType<Vec<FileEntry>> { + let mut files = Vec::new(); + if path.is_dir() { + // to-do: symbol link handling, cp the link rather than the content + // to-do: file mode, for unix + let fd = read_dir(&path, include_hidden)?; + for entry in fd.entries.iter() { + match entry.entry_type.enum_value() { + Ok(FileType::File) => { + let mut entry = entry.clone(); + entry.name = get_string(&prefix.join(entry.name)); + files.push(entry); + } + Ok(FileType::Dir) => { + if let Ok(mut tmp) = read_dir_recursive( + &path.join(&entry.name), + &prefix.join(&entry.name), + include_hidden, + ) { + for entry in tmp.drain(0..) { + files.push(entry); + } + } + } + _ => {} + } + } + Ok(files) + } else if path.is_file() { + let (size, modified_time) = if let Ok(meta) = std::fs::metadata(&path) { + ( + meta.len(), + meta.modified() + .map(|x| { + x.duration_since(std::time::SystemTime::UNIX_EPOCH) + .map(|x| x.as_secs()) + .unwrap_or(0) + }) + .unwrap_or(0) as u64, + ) + } else { + (0, 0) + }; + files.push(FileEntry { + entry_type: FileType::File.into(), + size, + modified_time, + ..Default::default() + }); + Ok(files) + } else { + bail!("Not exists"); + } +} + +pub fn get_recursive_files(path: &str, include_hidden: bool) -> ResultType<Vec<FileEntry>> { + read_dir_recursive(&get_path(path), &get_path(""), include_hidden) +} + +#[derive(Default)] +pub struct TransferJob { + id: i32, + path: PathBuf, + files: Vec<FileEntry>, + file_num: i32, + file: Option<File>, + total_size: u64, + finished_size: u64, + transferred: u64, +} + +#[inline] +fn get_ext(name: &str) -> &str { + if let Some(i) = name.rfind(".") { + return &name[i + 1..]; + } + "" +} + +#[inline] +fn is_compressed_file(name: &str) -> bool { + let ext = get_ext(name); + ext == "xz" + || ext == "gz" + || ext == "zip" + || ext == "7z" + || ext == "rar" + || ext == "bz2" + || ext == "tgz" + || ext == "png" + || ext == "jpg" +} + +impl TransferJob { + pub fn new_write(id: i32, path: String, files: Vec<FileEntry>) -> Self { + let total_size = files.iter().map(|x| x.size as u64).sum(); + Self { + id, + path: get_path(&path), + files, + total_size, + ..Default::default() + } + } + + pub fn new_read(id: i32, path: String, include_hidden: bool) -> ResultType<Self> { + let files = get_recursive_files(&path, include_hidden)?; + let total_size = files.iter().map(|x| x.size as u64).sum(); + Ok(Self { + id, + path: get_path(&path), + files, + total_size, + ..Default::default() + }) + } + + #[inline] + pub fn files(&self) -> &Vec<FileEntry> { + &self.files + } + + #[inline] + pub fn set_files(&mut self, files: Vec<FileEntry>) { + self.files = files; + } + + #[inline] + pub fn id(&self) -> i32 { + self.id + } + + #[inline] + pub fn total_size(&self) -> u64 { + self.total_size + } + + #[inline] + pub fn finished_size(&self) -> u64 { + self.finished_size + } + + #[inline] + pub fn transferred(&self) -> u64 { + self.transferred + } + + #[inline] + pub fn file_num(&self) -> i32 { + self.file_num + } + + pub fn modify_time(&self) { + let file_num = self.file_num as usize; + if file_num < self.files.len() { + let entry = &self.files[file_num]; + let path = self.join(&entry.name); + let download_path = format!("{}.download", get_string(&path)); + std::fs::rename(&download_path, &path).ok(); + filetime::set_file_mtime( + &path, + filetime::FileTime::from_unix_time(entry.modified_time as _, 0), + ) + .ok(); + } + } + + pub fn remove_download_file(&self) { + let file_num = self.file_num as usize; + if file_num < self.files.len() { + let entry = &self.files[file_num]; + let path = self.join(&entry.name); + let download_path = format!("{}.download", get_string(&path)); + std::fs::remove_file(&download_path).ok(); + } + } + + pub async fn write(&mut self, block: FileTransferBlock, raw: Option<&[u8]>) -> ResultType<()> { + if block.id != self.id { + bail!("Wrong id"); + } + let file_num = block.file_num as usize; + if file_num >= self.files.len() { + bail!("Wrong file number"); + } + if file_num != self.file_num as usize || self.file.is_none() { + self.modify_time(); + if let Some(file) = self.file.as_mut() { + file.sync_all().await?; + } + self.file_num = block.file_num; + let entry = &self.files[file_num]; + let path = self.join(&entry.name); + if let Some(p) = path.parent() { + std::fs::create_dir_all(p).ok(); + } + let path = format!("{}.download", get_string(&path)); + self.file = Some(File::create(&path).await?); + } + let data = if let Some(data) = raw { + data + } else { + &block.data + }; + if block.compressed { + let tmp = decompress(data); + self.file.as_mut().unwrap().write_all(&tmp).await?; + self.finished_size += tmp.len() as u64; + } else { + self.file.as_mut().unwrap().write_all(data).await?; + self.finished_size += data.len() as u64; + } + self.transferred += data.len() as u64; + Ok(()) + } + + #[inline] + fn join(&self, name: &str) -> PathBuf { + if name.is_empty() { + self.path.clone() + } else { + self.path.join(name) + } + } + + pub async fn read(&mut self) -> ResultType<Option<FileTransferBlock>> { + let file_num = self.file_num as usize; + if file_num >= self.files.len() { + self.file.take(); + return Ok(None); + } + let name = &self.files[file_num].name; + if self.file.is_none() { + match File::open(self.join(&name)).await { + Ok(file) => { + self.file = Some(file); + } + Err(err) => { + self.file_num += 1; + return Err(err.into()); + } + } + } + const BUF_SIZE: usize = 128 * 1024; + let mut buf: Vec<u8> = Vec::with_capacity(BUF_SIZE); + unsafe { + buf.set_len(BUF_SIZE); + } + let mut compressed = false; + let mut offset: usize = 0; + loop { + match self.file.as_mut().unwrap().read(&mut buf[offset..]).await { + Err(err) => { + self.file_num += 1; + self.file = None; + return Err(err.into()); + } + Ok(n) => { + offset += n; + if n == 0 || offset == BUF_SIZE { + break; + } + } + } + } + unsafe { buf.set_len(offset) }; + if offset == 0 { + self.file_num += 1; + self.file = None; + } else { + self.finished_size += offset as u64; + if !is_compressed_file(name) { + let tmp = compress(&buf, COMPRESS_LEVEL); + if tmp.len() < buf.len() { + buf = tmp; + compressed = true; + } + } + self.transferred += buf.len() as u64; + } + Ok(Some(FileTransferBlock { + id: self.id, + file_num: file_num as _, + data: buf.into(), + compressed, + ..Default::default() + })) + } +} + +#[inline] +pub fn new_error<T: std::string::ToString>(id: i32, err: T, file_num: i32) -> Message { + let mut resp = FileResponse::new(); + resp.set_error(FileTransferError { + id, + error: err.to_string(), + file_num, + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_response(resp); + msg_out +} + +#[inline] +pub fn new_dir(id: i32, path: String, files: Vec<FileEntry>) -> Message { + let mut resp = FileResponse::new(); + resp.set_dir(FileDirectory { + id, + path, + entries: files.into(), + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_response(resp); + msg_out +} + +#[inline] +pub fn new_block(block: FileTransferBlock) -> Message { + let mut resp = FileResponse::new(); + resp.set_block(block); + let mut msg_out = Message::new(); + msg_out.set_file_response(resp); + msg_out +} + +#[inline] +pub fn new_receive(id: i32, path: String, files: Vec<FileEntry>) -> Message { + let mut action = FileAction::new(); + action.set_receive(FileTransferReceiveRequest { + id, + path, + files: files.into(), + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_action(action); + msg_out +} + +#[inline] +pub fn new_send(id: i32, path: String, include_hidden: bool) -> Message { + let mut action = FileAction::new(); + action.set_send(FileTransferSendRequest { + id, + path, + include_hidden, + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_action(action); + msg_out +} + +#[inline] +pub fn new_done(id: i32, file_num: i32) -> Message { + let mut resp = FileResponse::new(); + resp.set_done(FileTransferDone { + id, + file_num, + ..Default::default() + }); + let mut msg_out = Message::new(); + msg_out.set_file_response(resp); + msg_out +} + +#[inline] +pub fn remove_job(id: i32, jobs: &mut Vec<TransferJob>) { + *jobs = jobs.drain(0..).filter(|x| x.id() != id).collect(); +} + +#[inline] +pub fn get_job(id: i32, jobs: &mut Vec<TransferJob>) -> Option<&mut TransferJob> { + jobs.iter_mut().filter(|x| x.id() == id).next() +} + +pub async fn handle_read_jobs( + jobs: &mut Vec<TransferJob>, + stream: &mut crate::Stream, +) -> ResultType<()> { + let mut finished = Vec::new(); + for job in jobs.iter_mut() { + match job.read().await { + Err(err) => { + stream + .send(&new_error(job.id(), err, job.file_num())) + .await?; + } + Ok(Some(block)) => { + stream.send(&new_block(block)).await?; + } + Ok(None) => { + finished.push(job.id()); + stream.send(&new_done(job.id(), job.file_num())).await?; + } + } + } + for id in finished { + remove_job(id, jobs); + } + Ok(()) +} + +pub fn remove_all_empty_dir(path: &PathBuf) -> ResultType<()> { + let fd = read_dir(path, true)?; + for entry in fd.entries.iter() { + match entry.entry_type.enum_value() { + Ok(FileType::Dir) => { + remove_all_empty_dir(&path.join(&entry.name)).ok(); + } + Ok(FileType::DirLink) | Ok(FileType::FileLink) => { + std::fs::remove_file(&path.join(&entry.name)).ok(); + } + _ => {} + } + } + std::fs::remove_dir(path).ok(); + Ok(()) +} + +#[inline] +pub fn remove_file(file: &str) -> ResultType<()> { + std::fs::remove_file(get_path(file))?; + Ok(()) +} + +#[inline] +pub fn create_dir(dir: &str) -> ResultType<()> { + std::fs::create_dir_all(get_path(dir))?; + Ok(()) +} diff --git a/libs/hbb_common/src/lib.rs b/libs/hbb_common/src/lib.rs new file mode 100644 index 0000000..0a9dace --- /dev/null +++ b/libs/hbb_common/src/lib.rs @@ -0,0 +1,211 @@ +pub mod compress; +#[path = "./protos/message.rs"] +pub mod message_proto; +#[path = "./protos/rendezvous.rs"] +pub mod rendezvous_proto; +pub use bytes; +pub use futures; +pub use protobuf; +use std::{ + fs::File, + io::{self, BufRead}, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + path::Path, + time::{self, SystemTime, UNIX_EPOCH}, +}; +pub use tokio; +pub use tokio_util; +pub mod socket_client; +pub mod tcp; +pub mod udp; +pub use env_logger; +pub use log; +pub mod bytes_codec; +#[cfg(feature = "quic")] +pub mod quic; +pub use anyhow::{self, bail}; +pub use futures_util; +pub mod config; +pub mod fs; +#[cfg(not(any(target_os = "android", target_os = "ios")))] +pub use mac_address; +pub use rand; +pub use regex; +pub use sodiumoxide; +pub use tokio_socks; +pub use tokio_socks::IntoTargetAddr; +pub use tokio_socks::TargetAddr; +pub use lazy_static; + +#[cfg(feature = "quic")] +pub type Stream = quic::Connection; +#[cfg(not(feature = "quic"))] +pub type Stream = tcp::FramedStream; + +#[inline] +pub async fn sleep(sec: f32) { + tokio::time::sleep(time::Duration::from_secs_f32(sec)).await; +} + +#[macro_export] +macro_rules! allow_err { + ($e:expr) => { + if let Err(err) = $e { + log::debug!( + "{:?}, {}:{}:{}:{}", + err, + module_path!(), + file!(), + line!(), + column!() + ); + } else { + } + }; +} + +#[inline] +pub fn timeout<T: std::future::Future>(ms: u64, future: T) -> tokio::time::Timeout<T> { + tokio::time::timeout(std::time::Duration::from_millis(ms), future) +} + +pub type ResultType<F, E = anyhow::Error> = anyhow::Result<F, E>; + +/// Certain router and firewalls scan the packet and if they +/// find an IP address belonging to their pool that they use to do the NAT mapping/translation, so here we mangle the ip address + +pub struct AddrMangle(); + +impl AddrMangle { + pub fn encode(addr: SocketAddr) -> Vec<u8> { + match addr { + SocketAddr::V4(addr_v4) => { + let tm = (SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_micros() as u32) as u128; + let ip = u32::from_le_bytes(addr_v4.ip().octets()) as u128; + let port = addr.port() as u128; + let v = ((ip + tm) << 49) | (tm << 17) | (port + (tm & 0xFFFF)); + let bytes = v.to_le_bytes(); + let mut n_padding = 0; + for i in bytes.iter().rev() { + if i == &0u8 { + n_padding += 1; + } else { + break; + } + } + bytes[..(16 - n_padding)].to_vec() + } + _ => { + panic!("Only support ipv4"); + } + } + } + + pub fn decode(bytes: &[u8]) -> SocketAddr { + let mut padded = [0u8; 16]; + padded[..bytes.len()].copy_from_slice(&bytes); + let number = u128::from_le_bytes(padded); + let tm = (number >> 17) & (u32::max_value() as u128); + let ip = (((number >> 49) - tm) as u32).to_le_bytes(); + let port = (number & 0xFFFFFF) - (tm & 0xFFFF); + SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]), + port as u16, + )) + } +} + +pub fn get_version_from_url(url: &str) -> String { + let n = url.chars().count(); + let a = url + .chars() + .rev() + .enumerate() + .filter(|(_, x)| x == &'-') + .next() + .map(|(i, _)| i); + if let Some(a) = a { + let b = url + .chars() + .rev() + .enumerate() + .filter(|(_, x)| x == &'.') + .next() + .map(|(i, _)| i); + if let Some(b) = b { + if a > b { + if url + .chars() + .skip(n - b) + .collect::<String>() + .parse::<i32>() + .is_ok() + { + return url.chars().skip(n - a).collect(); + } else { + return url.chars().skip(n - a).take(a - b - 1).collect(); + } + } else { + return url.chars().skip(n - a).collect(); + } + } + } + "".to_owned() +} + +pub fn gen_version() { + let mut file = File::create("./src/version.rs").unwrap(); + for line in read_lines("Cargo.toml").unwrap() { + if let Ok(line) = line { + let ab: Vec<&str> = line.split("=").map(|x| x.trim()).collect(); + if ab.len() == 2 && ab[0] == "version" { + use std::io::prelude::*; + file.write_all(format!("pub const VERSION: &str = {};", ab[1]).as_bytes()) + .ok(); + file.sync_all().ok(); + break; + } + } + } +} + +fn read_lines<P>(filename: P) -> io::Result<io::Lines<io::BufReader<File>>> +where + P: AsRef<Path>, +{ + let file = File::open(filename)?; + Ok(io::BufReader::new(file).lines()) +} + +pub fn is_valid_custom_id(id: &str) -> bool { + regex::Regex::new(r"^[a-zA-Z]\w{5,15}$") + .unwrap() + .is_match(id) +} + +pub fn get_version_number(v: &str) -> i64 { + let mut n = 0; + for x in v.split(".") { + n = n * 1000 + x.parse::<i64>().unwrap_or(0); + } + n +} + +pub fn get_modified_time(path: &std::path::Path) -> SystemTime { + std::fs::metadata(&path) + .map(|m| m.modified().unwrap_or(UNIX_EPOCH)) + .unwrap_or(UNIX_EPOCH) +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_mangle() { + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 16, 32), 21116)); + assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + } +} diff --git a/libs/hbb_common/src/quic.rs b/libs/hbb_common/src/quic.rs new file mode 100644 index 0000000..ada2acd --- /dev/null +++ b/libs/hbb_common/src/quic.rs @@ -0,0 +1,135 @@ +use crate::{allow_err, anyhow::anyhow, ResultType}; +use protobuf::Message; +use std::{net::SocketAddr, sync::Arc}; +use tokio::{self, stream::StreamExt, sync::mpsc}; + +const QUIC_HBB: &[&[u8]] = &[b"hbb"]; +const SERVER_NAME: &str = "hbb"; + +type Sender = mpsc::UnboundedSender<Value>; +type Receiver = mpsc::UnboundedReceiver<Value>; + +pub fn new_server(socket: std::net::UdpSocket) -> ResultType<(Server, SocketAddr)> { + let mut transport_config = quinn::TransportConfig::default(); + transport_config.stream_window_uni(0); + let mut server_config = quinn::ServerConfig::default(); + server_config.transport = Arc::new(transport_config); + let mut server_config = quinn::ServerConfigBuilder::new(server_config); + server_config.protocols(QUIC_HBB); + // server_config.enable_keylog(); + // server_config.use_stateless_retry(true); + let mut endpoint = quinn::Endpoint::builder(); + endpoint.listen(server_config.build()); + let (end, incoming) = endpoint.with_socket(socket)?; + Ok((Server { incoming }, end.local_addr()?)) +} + +pub async fn new_client(local_addr: &SocketAddr, peer: &SocketAddr) -> ResultType<Connection> { + let mut endpoint = quinn::Endpoint::builder(); + let mut client_config = quinn::ClientConfigBuilder::default(); + client_config.protocols(QUIC_HBB); + //client_config.enable_keylog(); + endpoint.default_client_config(client_config.build()); + let (endpoint, _) = endpoint.bind(local_addr)?; + let new_conn = endpoint.connect(peer, SERVER_NAME)?.await?; + Connection::new_for_client(new_conn.connection).await +} + +pub struct Server { + incoming: quinn::Incoming, +} + +impl Server { + #[inline] + pub async fn next(&mut self) -> ResultType<Option<Connection>> { + Connection::new_for_server(&mut self.incoming).await + } +} + +pub struct Connection { + conn: quinn::Connection, + tx: quinn::SendStream, + rx: Receiver, +} + +type Value = ResultType<Vec<u8>>; + +impl Connection { + async fn new_for_server(incoming: &mut quinn::Incoming) -> ResultType<Option<Self>> { + if let Some(conn) = incoming.next().await { + let quinn::NewConnection { + connection: conn, + // uni_streams, + mut bi_streams, + .. + } = conn.await?; + let (tx, rx) = mpsc::unbounded_channel::<Value>(); + tokio::spawn(async move { + loop { + let stream = bi_streams.next().await; + if let Some(stream) = stream { + let stream = match stream { + Err(e) => { + tx.send(Err(e.into())).ok(); + break; + } + Ok(s) => s, + }; + let cloned = tx.clone(); + tokio::spawn(async move { + allow_err!(handle_request(stream.1, cloned).await); + }); + } else { + tx.send(Err(anyhow!("Reset by the peer"))).ok(); + break; + } + } + log::info!("Exit connection outer loop"); + }); + let tx = conn.open_uni().await?; + Ok(Some(Self { conn, tx, rx })) + } else { + Ok(None) + } + } + + async fn new_for_client(conn: quinn::Connection) -> ResultType<Self> { + let (tx, rx_quic) = conn.open_bi().await?; + let (tx_mpsc, rx) = mpsc::unbounded_channel::<Value>(); + tokio::spawn(async move { + allow_err!(handle_request(rx_quic, tx_mpsc).await); + }); + Ok(Self { conn, tx, rx }) + } + + #[inline] + pub async fn next(&mut self) -> Option<Value> { + // None is returned when all Sender halves have dropped, + // indicating that no further values can be sent on the channel. + self.rx.recv().await + } + + #[inline] + pub fn remote_address(&self) -> SocketAddr { + self.conn.remote_address() + } + + #[inline] + pub async fn send_raw(&mut self, bytes: &[u8]) -> ResultType<()> { + self.tx.write_all(bytes).await?; + Ok(()) + } + + #[inline] + pub async fn send(&mut self, msg: &dyn Message) -> ResultType<()> { + match msg.write_to_bytes() { + Ok(bytes) => self.send_raw(&bytes).await?, + err => allow_err!(err), + } + Ok(()) + } +} + +async fn handle_request(rx: quinn::RecvStream, tx: Sender) -> ResultType<()> { + Ok(()) +} diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs new file mode 100644 index 0000000..0375b71 --- /dev/null +++ b/libs/hbb_common/src/socket_client.rs @@ -0,0 +1,91 @@ +use crate::{ + config::{Config, NetworkType}, + tcp::FramedStream, + udp::FramedSocket, + ResultType, +}; +use anyhow::Context; +use std::net::SocketAddr; +use tokio::net::ToSocketAddrs; +use tokio_socks::{IntoTargetAddr, TargetAddr}; + +fn to_socket_addr(host: &str) -> ResultType<SocketAddr> { + use std::net::ToSocketAddrs; + host.to_socket_addrs()?.next().context("Failed to solve") +} + +pub fn get_target_addr(host: &str) -> ResultType<TargetAddr<'static>> { + let addr = match Config::get_network_type() { + NetworkType::Direct => to_socket_addr(&host)?.into_target_addr()?, + NetworkType::ProxySocks => host.into_target_addr()?, + } + .to_owned(); + Ok(addr) +} + +pub fn test_if_valid_server(host: &str) -> String { + let mut host = host.to_owned(); + if !host.contains(":") { + host = format!("{}:{}", host, 0); + } + + match Config::get_network_type() { + NetworkType::Direct => match to_socket_addr(&host) { + Err(err) => err.to_string(), + Ok(_) => "".to_owned(), + }, + NetworkType::ProxySocks => match &host.into_target_addr() { + Err(err) => err.to_string(), + Ok(_) => "".to_owned(), + }, + } +} + +pub async fn connect_tcp<'t, T: IntoTargetAddr<'t>>( + target: T, + local: SocketAddr, + ms_timeout: u64, +) -> ResultType<FramedStream> { + let target_addr = target.into_target_addr()?; + + if let Some(conf) = Config::get_socks() { + FramedStream::connect( + conf.proxy.as_str(), + target_addr, + local, + conf.username.as_str(), + conf.password.as_str(), + ms_timeout, + ) + .await + } else { + let addr = std::net::ToSocketAddrs::to_socket_addrs(&target_addr)? + .next() + .context("Invalid target addr")?; + Ok(FramedStream::new(addr, local, ms_timeout).await?) + } +} + +pub async fn new_udp<T: ToSocketAddrs>(local: T, ms_timeout: u64) -> ResultType<FramedSocket> { + match Config::get_socks() { + None => Ok(FramedSocket::new(local).await?), + Some(conf) => { + let socket = FramedSocket::new_proxy( + conf.proxy.as_str(), + local, + conf.username.as_str(), + conf.password.as_str(), + ms_timeout, + ) + .await?; + Ok(socket) + } + } +} + +pub async fn rebind_udp<T: ToSocketAddrs>(local: T) -> ResultType<Option<FramedSocket>> { + match Config::get_network_type() { + NetworkType::Direct => Ok(Some(FramedSocket::new(local).await?)), + _ => Ok(None), + } +} diff --git a/libs/hbb_common/src/tcp.rs b/libs/hbb_common/src/tcp.rs new file mode 100644 index 0000000..7966920 --- /dev/null +++ b/libs/hbb_common/src/tcp.rs @@ -0,0 +1,285 @@ +use crate::{bail, bytes_codec::BytesCodec, ResultType}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{SinkExt, StreamExt}; +use protobuf::Message; +use sodiumoxide::crypto::secretbox::{self, Key, Nonce}; +use std::{ + io::{self, Error, ErrorKind}, + net::SocketAddr, + ops::{Deref, DerefMut}, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::{lookup_host, TcpListener, TcpSocket, ToSocketAddrs}, +}; +use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr, ToProxyAddrs}; +use tokio_util::codec::Framed; + +pub trait TcpStreamTrait: AsyncRead + AsyncWrite + Unpin {} +pub struct DynTcpStream(Box<dyn TcpStreamTrait + Send + Sync>); + +pub struct FramedStream( + Framed<DynTcpStream, BytesCodec>, + SocketAddr, + Option<(Key, u64, u64)>, + u64, +); + +impl Deref for FramedStream { + type Target = Framed<DynTcpStream, BytesCodec>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for FramedStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Deref for DynTcpStream { + type Target = Box<dyn TcpStreamTrait + Send + Sync>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DynTcpStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +fn new_socket(addr: std::net::SocketAddr, reuse: bool) -> Result<TcpSocket, std::io::Error> { + let socket = match addr { + std::net::SocketAddr::V4(..) => TcpSocket::new_v4()?, + std::net::SocketAddr::V6(..) => TcpSocket::new_v6()?, + }; + if reuse { + // windows has no reuse_port, but it's reuse_address + // almost equals to unix's reuse_port + reuse_address, + // though may introduce nondeterministic behavior + #[cfg(unix)] + socket.set_reuseport(true)?; + socket.set_reuseaddr(true)?; + } + socket.bind(addr)?; + Ok(socket) +} + +impl FramedStream { + pub async fn new<T1: ToSocketAddrs, T2: ToSocketAddrs>( + remote_addr: T1, + local_addr: T2, + ms_timeout: u64, + ) -> ResultType<Self> { + for local_addr in lookup_host(&local_addr).await? { + for remote_addr in lookup_host(&remote_addr).await? { + let stream = super::timeout( + ms_timeout, + new_socket(local_addr, true)?.connect(remote_addr), + ) + .await??; + stream.set_nodelay(true).ok(); + let addr = stream.local_addr()?; + return Ok(Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )); + } + } + bail!("could not resolve to any address"); + } + + pub async fn connect<'a, 't, P, T1, T2>( + proxy: P, + target: T1, + local: T2, + username: &'a str, + password: &'a str, + ms_timeout: u64, + ) -> ResultType<Self> + where + P: ToProxyAddrs, + T1: IntoTargetAddr<'t>, + T2: ToSocketAddrs, + { + if let Some(local) = lookup_host(&local).await?.next() { + if let Some(proxy) = proxy.to_proxy_addrs().next().await { + let stream = + super::timeout(ms_timeout, new_socket(local, true)?.connect(proxy?)).await??; + stream.set_nodelay(true).ok(); + let stream = if username.trim().is_empty() { + super::timeout( + ms_timeout, + Socks5Stream::connect_with_socket(stream, target), + ) + .await?? + } else { + super::timeout( + ms_timeout, + Socks5Stream::connect_with_password_and_socket( + stream, target, username, password, + ), + ) + .await?? + }; + let addr = stream.local_addr()?; + return Ok(Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )); + }; + }; + bail!("could not resolve to any address"); + } + + pub fn local_addr(&self) -> SocketAddr { + self.1 + } + + pub fn set_send_timeout(&mut self, ms: u64) { + self.3 = ms; + } + + pub fn from(stream: impl TcpStreamTrait + Send + Sync + 'static, addr: SocketAddr) -> Self { + Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + ) + } + + pub fn set_raw(&mut self) { + self.0.codec_mut().set_raw(); + self.2 = None; + } + + pub fn is_secured(&self) -> bool { + self.2.is_some() + } + + #[inline] + pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> { + self.send_raw(msg.write_to_bytes()?).await + } + + #[inline] + pub async fn send_raw(&mut self, msg: Vec<u8>) -> ResultType<()> { + let mut msg = msg; + if let Some(key) = self.2.as_mut() { + key.1 += 1; + let nonce = Self::get_nonce(key.1); + msg = secretbox::seal(&msg, &nonce, &key.0); + } + self.send_bytes(bytes::Bytes::from(msg)).await?; + Ok(()) + } + + #[inline] + pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { + if self.3 > 0 { + super::timeout(self.3, self.0.send(bytes)).await??; + } else { + self.0.send(bytes).await?; + } + Ok(()) + } + + #[inline] + pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> { + let mut res = self.0.next().await; + if let Some(key) = self.2.as_mut() { + if let Some(Ok(bytes)) = res.as_mut() { + key.2 += 1; + let nonce = Self::get_nonce(key.2); + match secretbox::open(&bytes, &nonce, &key.0) { + Ok(res) => { + bytes.clear(); + bytes.put_slice(&res); + } + Err(()) => { + return Some(Err(Error::new(ErrorKind::Other, "decryption error"))); + } + } + } + } + res + } + + #[inline] + pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> { + if let Ok(res) = super::timeout(ms, self.next()).await { + res + } else { + None + } + } + + pub fn set_key(&mut self, key: Key) { + self.2 = Some((key, 0, 0)); + } + + fn get_nonce(seqnum: u64) -> Nonce { + let mut nonce = Nonce([0u8; secretbox::NONCEBYTES]); + nonce.0[..std::mem::size_of_val(&seqnum)].copy_from_slice(&seqnum.to_le_bytes()); + nonce + } +} + +const DEFAULT_BACKLOG: u32 = 128; + +#[allow(clippy::never_loop)] +pub async fn new_listener<T: ToSocketAddrs>(addr: T, reuse: bool) -> ResultType<TcpListener> { + if !reuse { + Ok(TcpListener::bind(addr).await?) + } else { + for addr in lookup_host(&addr).await? { + let socket = new_socket(addr, true)?; + return Ok(socket.listen(DEFAULT_BACKLOG)?); + } + bail!("could not resolve to any address"); + } +} + +impl Unpin for DynTcpStream {} + +impl AsyncRead for DynTcpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf) + } +} + +impl AsyncWrite for DynTcpStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + AsyncWrite::poll_flush(Pin::new(&mut self.0), cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx) + } +} + +impl<R: AsyncRead + AsyncWrite + Unpin> TcpStreamTrait for R {} diff --git a/libs/hbb_common/src/udp.rs b/libs/hbb_common/src/udp.rs new file mode 100644 index 0000000..0338618 --- /dev/null +++ b/libs/hbb_common/src/udp.rs @@ -0,0 +1,165 @@ +use crate::{bail, ResultType}; +use anyhow::anyhow; +use bytes::{Bytes, BytesMut}; +use futures::{SinkExt, StreamExt}; +use protobuf::Message; +use socket2::{Domain, Socket, Type}; +use std::net::SocketAddr; +use tokio::net::{ToSocketAddrs, UdpSocket}; +use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr, TargetAddr, ToProxyAddrs}; +use tokio_util::{codec::BytesCodec, udp::UdpFramed}; + +pub enum FramedSocket { + Direct(UdpFramed<BytesCodec>), + ProxySocks(Socks5UdpFramed), +} + +fn new_socket(addr: SocketAddr, reuse: bool, buf_size: usize) -> Result<Socket, std::io::Error> { + let socket = match addr { + SocketAddr::V4(..) => Socket::new(Domain::ipv4(), Type::dgram(), None), + SocketAddr::V6(..) => Socket::new(Domain::ipv6(), Type::dgram(), None), + }?; + if reuse { + // windows has no reuse_port, but it's reuse_address + // almost equals to unix's reuse_port + reuse_address, + // though may introduce nondeterministic behavior + #[cfg(unix)] + socket.set_reuse_port(true)?; + socket.set_reuse_address(true)?; + } + if buf_size > 0 { + socket.set_recv_buffer_size(buf_size).ok(); + } + log::info!( + "Receive buf size of udp {}: {:?}", + addr, + socket.recv_buffer_size() + ); + socket.bind(&addr.into())?; + Ok(socket) +} + +impl FramedSocket { + pub async fn new<T: ToSocketAddrs>(addr: T) -> ResultType<Self> { + let socket = UdpSocket::bind(addr).await?; + Ok(Self::Direct(UdpFramed::new(socket, BytesCodec::new()))) + } + + #[allow(clippy::never_loop)] + pub async fn new_reuse<T: std::net::ToSocketAddrs>(addr: T) -> ResultType<Self> { + for addr in addr.to_socket_addrs()? { + let socket = new_socket(addr, true, 0)?.into_udp_socket(); + return Ok(Self::Direct(UdpFramed::new( + UdpSocket::from_std(socket)?, + BytesCodec::new(), + ))); + } + bail!("could not resolve to any address"); + } + + pub async fn new_with_buf_size<T: std::net::ToSocketAddrs>( + addr: T, + buf_size: usize, + ) -> ResultType<Self> { + for addr in addr.to_socket_addrs()? { + return Ok(Self::Direct(UdpFramed::new( + UdpSocket::from_std(new_socket(addr, false, buf_size)?.into_udp_socket())?, + BytesCodec::new(), + ))); + } + bail!("could not resolve to any address"); + } + + pub async fn new_proxy<'a, 't, P: ToProxyAddrs, T: ToSocketAddrs>( + proxy: P, + local: T, + username: &'a str, + password: &'a str, + ms_timeout: u64, + ) -> ResultType<Self> { + let framed = if username.trim().is_empty() { + super::timeout(ms_timeout, Socks5UdpFramed::connect(proxy, Some(local))).await?? + } else { + super::timeout( + ms_timeout, + Socks5UdpFramed::connect_with_password(proxy, Some(local), username, password), + ) + .await?? + }; + log::trace!( + "Socks5 udp connected, local addr: {:?}, target addr: {}", + framed.local_addr(), + framed.socks_addr() + ); + Ok(Self::ProxySocks(framed)) + } + + #[inline] + pub async fn send( + &mut self, + msg: &impl Message, + addr: impl IntoTargetAddr<'_>, + ) -> ResultType<()> { + let addr = addr.into_target_addr()?.to_owned(); + let send_data = Bytes::from(msg.write_to_bytes()?); + let _ = match self { + Self::Direct(f) => match addr { + TargetAddr::Ip(addr) => f.send((send_data, addr)).await?, + _ => {} + }, + Self::ProxySocks(f) => f.send((send_data, addr)).await?, + }; + Ok(()) + } + + // https://stackoverflow.com/a/68733302/1926020 + #[inline] + pub async fn send_raw( + &mut self, + msg: &'static [u8], + addr: impl IntoTargetAddr<'static>, + ) -> ResultType<()> { + let addr = addr.into_target_addr()?.to_owned(); + + let _ = match self { + Self::Direct(f) => match addr { + TargetAddr::Ip(addr) => f.send((Bytes::from(msg), addr)).await?, + _ => {} + }, + Self::ProxySocks(f) => f.send((Bytes::from(msg), addr)).await?, + }; + Ok(()) + } + + #[inline] + pub async fn next(&mut self) -> Option<ResultType<(BytesMut, TargetAddr<'static>)>> { + match self { + Self::Direct(f) => match f.next().await { + Some(Ok((data, addr))) => { + Some(Ok((data, addr.into_target_addr().ok()?.to_owned()))) + } + Some(Err(e)) => Some(Err(anyhow!(e))), + None => None, + }, + Self::ProxySocks(f) => match f.next().await { + Some(Ok((data, _))) => Some(Ok((data.data, data.dst_addr))), + Some(Err(e)) => Some(Err(anyhow!(e))), + None => None, + }, + } + } + + #[inline] + pub async fn next_timeout( + &mut self, + ms: u64, + ) -> Option<ResultType<(BytesMut, TargetAddr<'static>)>> { + if let Ok(res) = + tokio::time::timeout(std::time::Duration::from_millis(ms), self.next()).await + { + res + } else { + None + } + } +} |