aboutsummaryrefslogtreecommitdiffhomepage
path: root/libs
diff options
context:
space:
mode:
authorrustdesk <[email protected]>2022-05-12 20:07:45 +0800
committerrustdesk <[email protected]>2022-05-12 20:07:45 +0800
commit03ca2a95177d16999adad254057eb9c27b5e53b6 (patch)
tree36dab7dab5e015de94b08db67581a3b782599d56 /libs
parentb3f39598a7324dacec0cd84d5e09b95724805cc8 (diff)
downloadrustdesk-server-03ca2a95177d16999adad254057eb9c27b5e53b6.tar.gz
rustdesk-server-03ca2a95177d16999adad254057eb9c27b5e53b6.zip
missed files
Diffstat (limited to 'libs')
-rw-r--r--libs/hbb_common/.gitignore4
-rw-r--r--libs/hbb_common/Cargo.toml48
-rw-r--r--libs/hbb_common/build.rs9
-rw-r--r--libs/hbb_common/protos/message.proto481
-rw-r--r--libs/hbb_common/protos/rendezvous.proto171
-rw-r--r--libs/hbb_common/src/bytes_codec.rs274
-rw-r--r--libs/hbb_common/src/compress.rs50
-rw-r--r--libs/hbb_common/src/config.rs876
-rw-r--r--libs/hbb_common/src/fs.rs560
-rw-r--r--libs/hbb_common/src/lib.rs211
-rw-r--r--libs/hbb_common/src/quic.rs135
-rw-r--r--libs/hbb_common/src/socket_client.rs91
-rw-r--r--libs/hbb_common/src/tcp.rs285
-rw-r--r--libs/hbb_common/src/udp.rs165
14 files changed, 3360 insertions, 0 deletions
diff --git a/libs/hbb_common/.gitignore b/libs/hbb_common/.gitignore
new file mode 100644
index 0000000..b1cf151
--- /dev/null
+++ b/libs/hbb_common/.gitignore
@@ -0,0 +1,4 @@
+/target
+**/*.rs.bk
+Cargo.lock
+src/protos/
diff --git a/libs/hbb_common/Cargo.toml b/libs/hbb_common/Cargo.toml
new file mode 100644
index 0000000..bc31223
--- /dev/null
+++ b/libs/hbb_common/Cargo.toml
@@ -0,0 +1,48 @@
+[package]
+name = "hbb_common"
+version = "0.1.0"
+authors = ["open-trade <[email protected]>"]
+edition = "2018"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+protobuf = "3.0.0-alpha.2"
+tokio = { version = "1.15", features = ["full"] }
+tokio-util = { version = "0.6", features = ["full"] }
+futures = "0.3"
+bytes = "1.1"
+log = "0.4"
+env_logger = "0.9"
+socket2 = { version = "0.3", features = ["reuseport"] }
+zstd = "0.9"
+quinn = {version = "0.8", optional = true }
+anyhow = "1.0"
+futures-util = "0.3"
+directories-next = "2.0"
+rand = "0.8"
+serde_derive = "1.0"
+serde = "1.0"
+lazy_static = "1.4"
+confy = { git = "https://github.com/open-trade/confy" }
+dirs-next = "2.0"
+filetime = "0.2"
+sodiumoxide = "0.2"
+regex = "1.4"
+tokio-socks = { git = "https://github.com/open-trade/tokio-socks" }
+
+[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
+mac_address = "1.1"
+
+[features]
+quic = []
+
+[build-dependencies]
+protobuf-codegen-pure = "3.0.0-alpha.2"
+
+[target.'cfg(target_os = "windows")'.dependencies]
+winapi = { version = "0.3", features = ["winuser"] }
+
+[dev-dependencies]
+toml = "0.5"
+serde_json = "1.0"
diff --git a/libs/hbb_common/build.rs b/libs/hbb_common/build.rs
new file mode 100644
index 0000000..99dacb7
--- /dev/null
+++ b/libs/hbb_common/build.rs
@@ -0,0 +1,9 @@
+fn main() {
+ std::fs::create_dir_all("src/protos").unwrap();
+ protobuf_codegen_pure::Codegen::new()
+ .out_dir("src/protos")
+ .inputs(&["protos/rendezvous.proto", "protos/message.proto"])
+ .include("protos")
+ .run()
+ .expect("Codegen failed.");
+}
diff --git a/libs/hbb_common/protos/message.proto b/libs/hbb_common/protos/message.proto
new file mode 100644
index 0000000..f95007d
--- /dev/null
+++ b/libs/hbb_common/protos/message.proto
@@ -0,0 +1,481 @@
+syntax = "proto3";
+package hbb;
+
+message VP9 {
+ bytes data = 1;
+ bool key = 2;
+ int64 pts = 3;
+}
+
+message VP9s { repeated VP9 frames = 1; }
+
+message RGB { bool compress = 1; }
+
+// planes data send directly in binary for better use arraybuffer on web
+message YUV {
+ bool compress = 1;
+ int32 stride = 2;
+}
+
+message VideoFrame {
+ oneof union {
+ VP9s vp9s = 6;
+ RGB rgb = 7;
+ YUV yuv = 8;
+ }
+}
+
+message IdPk {
+ string id = 1;
+ bytes pk = 2;
+}
+
+message DisplayInfo {
+ sint32 x = 1;
+ sint32 y = 2;
+ int32 width = 3;
+ int32 height = 4;
+ string name = 5;
+ bool online = 6;
+}
+
+message PortForward {
+ string host = 1;
+ int32 port = 2;
+}
+
+message FileTransfer {
+ string dir = 1;
+ bool show_hidden = 2;
+}
+
+message LoginRequest {
+ string username = 1;
+ bytes password = 2;
+ string my_id = 4;
+ string my_name = 5;
+ OptionMessage option = 6;
+ oneof union {
+ FileTransfer file_transfer = 7;
+ PortForward port_forward = 8;
+ }
+ bool video_ack_required = 9;
+}
+
+message ChatMessage { string text = 1; }
+
+message PeerInfo {
+ string username = 1;
+ string hostname = 2;
+ string platform = 3;
+ repeated DisplayInfo displays = 4;
+ int32 current_display = 5;
+ bool sas_enabled = 6;
+ string version = 7;
+ int32 conn_id = 8;
+}
+
+message LoginResponse {
+ oneof union {
+ string error = 1;
+ PeerInfo peer_info = 2;
+ }
+}
+
+message MouseEvent {
+ int32 mask = 1;
+ sint32 x = 2;
+ sint32 y = 3;
+ repeated ControlKey modifiers = 4;
+}
+
+enum ControlKey {
+ Unknown = 0;
+ Alt = 1;
+ Backspace = 2;
+ CapsLock = 3;
+ Control = 4;
+ Delete = 5;
+ DownArrow = 6;
+ End = 7;
+ Escape = 8;
+ F1 = 9;
+ F10 = 10;
+ F11 = 11;
+ F12 = 12;
+ F2 = 13;
+ F3 = 14;
+ F4 = 15;
+ F5 = 16;
+ F6 = 17;
+ F7 = 18;
+ F8 = 19;
+ F9 = 20;
+ Home = 21;
+ LeftArrow = 22;
+ /// meta key (also known as "windows"; "super"; and "command")
+ Meta = 23;
+ /// option key on macOS (alt key on Linux and Windows)
+ Option = 24; // deprecated, use Alt instead
+ PageDown = 25;
+ PageUp = 26;
+ Return = 27;
+ RightArrow = 28;
+ Shift = 29;
+ Space = 30;
+ Tab = 31;
+ UpArrow = 32;
+ Numpad0 = 33;
+ Numpad1 = 34;
+ Numpad2 = 35;
+ Numpad3 = 36;
+ Numpad4 = 37;
+ Numpad5 = 38;
+ Numpad6 = 39;
+ Numpad7 = 40;
+ Numpad8 = 41;
+ Numpad9 = 42;
+ Cancel = 43;
+ Clear = 44;
+ Menu = 45; // deprecated, use Alt instead
+ Pause = 46;
+ Kana = 47;
+ Hangul = 48;
+ Junja = 49;
+ Final = 50;
+ Hanja = 51;
+ Kanji = 52;
+ Convert = 53;
+ Select = 54;
+ Print = 55;
+ Execute = 56;
+ Snapshot = 57;
+ Insert = 58;
+ Help = 59;
+ Sleep = 60;
+ Separator = 61;
+ Scroll = 62;
+ NumLock = 63;
+ RWin = 64;
+ Apps = 65;
+ Multiply = 66;
+ Add = 67;
+ Subtract = 68;
+ Decimal = 69;
+ Divide = 70;
+ Equals = 71;
+ NumpadEnter = 72;
+ RShift = 73;
+ RControl = 74;
+ RAlt = 75;
+ CtrlAltDel = 100;
+ LockScreen = 101;
+}
+
+message KeyEvent {
+ bool down = 1;
+ bool press = 2;
+ oneof union {
+ ControlKey control_key = 3;
+ uint32 chr = 4;
+ uint32 unicode = 5;
+ string seq = 6;
+ }
+ repeated ControlKey modifiers = 8;
+}
+
+message CursorData {
+ uint64 id = 1;
+ sint32 hotx = 2;
+ sint32 hoty = 3;
+ int32 width = 4;
+ int32 height = 5;
+ bytes colors = 6;
+}
+
+message CursorPosition {
+ sint32 x = 1;
+ sint32 y = 2;
+}
+
+message Hash {
+ string salt = 1;
+ string challenge = 2;
+}
+
+message Clipboard {
+ bool compress = 1;
+ bytes content = 2;
+}
+
+enum FileType {
+ Dir = 0;
+ DirLink = 2;
+ DirDrive = 3;
+ File = 4;
+ FileLink = 5;
+}
+
+message FileEntry {
+ FileType entry_type = 1;
+ string name = 2;
+ bool is_hidden = 3;
+ uint64 size = 4;
+ uint64 modified_time = 5;
+}
+
+message FileDirectory {
+ int32 id = 1;
+ string path = 2;
+ repeated FileEntry entries = 3;
+}
+
+message ReadDir {
+ string path = 1;
+ bool include_hidden = 2;
+}
+
+message ReadAllFiles {
+ int32 id = 1;
+ string path = 2;
+ bool include_hidden = 3;
+}
+
+message FileAction {
+ oneof union {
+ ReadDir read_dir = 1;
+ FileTransferSendRequest send = 2;
+ FileTransferReceiveRequest receive = 3;
+ FileDirCreate create = 4;
+ FileRemoveDir remove_dir = 5;
+ FileRemoveFile remove_file = 6;
+ ReadAllFiles all_files = 7;
+ FileTransferCancel cancel = 8;
+ }
+}
+
+message FileTransferCancel { int32 id = 1; }
+
+message FileResponse {
+ oneof union {
+ FileDirectory dir = 1;
+ FileTransferBlock block = 2;
+ FileTransferError error = 3;
+ FileTransferDone done = 4;
+ }
+}
+
+message FileTransferBlock {
+ int32 id = 1;
+ sint32 file_num = 2;
+ bytes data = 3;
+ bool compressed = 4;
+}
+
+message FileTransferError {
+ int32 id = 1;
+ string error = 2;
+ sint32 file_num = 3;
+}
+
+message FileTransferSendRequest {
+ int32 id = 1;
+ string path = 2;
+ bool include_hidden = 3;
+}
+
+message FileTransferDone {
+ int32 id = 1;
+ sint32 file_num = 2;
+}
+
+message FileTransferReceiveRequest {
+ int32 id = 1;
+ string path = 2; // path written to
+ repeated FileEntry files = 3;
+}
+
+message FileRemoveDir {
+ int32 id = 1;
+ string path = 2;
+ bool recursive = 3;
+}
+
+message FileRemoveFile {
+ int32 id = 1;
+ string path = 2;
+ sint32 file_num = 3;
+}
+
+message FileDirCreate {
+ int32 id = 1;
+ string path = 2;
+}
+
+// main logic from freeRDP
+message CliprdrMonitorReady {
+ int32 conn_id = 1;
+}
+
+message CliprdrFormat {
+ int32 conn_id = 1;
+ int32 id = 2;
+ string format = 3;
+}
+
+message CliprdrServerFormatList {
+ int32 conn_id = 1;
+ repeated CliprdrFormat formats = 2;
+}
+
+message CliprdrServerFormatListResponse {
+ int32 conn_id = 1;
+ int32 msg_flags = 2;
+}
+
+message CliprdrServerFormatDataRequest {
+ int32 conn_id = 1;
+ int32 requested_format_id = 2;
+}
+
+message CliprdrServerFormatDataResponse {
+ int32 conn_id = 1;
+ int32 msg_flags = 2;
+ bytes format_data = 3;
+}
+
+message CliprdrFileContentsRequest {
+ int32 conn_id = 1;
+ int32 stream_id = 2;
+ int32 list_index = 3;
+ int32 dw_flags = 4;
+ int32 n_position_low = 5;
+ int32 n_position_high = 6;
+ int32 cb_requested = 7;
+ bool have_clip_data_id = 8;
+ int32 clip_data_id = 9;
+}
+
+message CliprdrFileContentsResponse {
+ int32 conn_id = 1;
+ int32 msg_flags = 3;
+ int32 stream_id = 4;
+ bytes requested_data = 5;
+}
+
+message Cliprdr {
+ oneof union {
+ CliprdrMonitorReady ready = 1;
+ CliprdrServerFormatList format_list = 2;
+ CliprdrServerFormatListResponse format_list_response = 3;
+ CliprdrServerFormatDataRequest format_data_request = 4;
+ CliprdrServerFormatDataResponse format_data_response = 5;
+ CliprdrFileContentsRequest file_contents_request = 6;
+ CliprdrFileContentsResponse file_contents_response = 7;
+ }
+}
+
+message SwitchDisplay {
+ int32 display = 1;
+ sint32 x = 2;
+ sint32 y = 3;
+ int32 width = 4;
+ int32 height = 5;
+}
+
+message PermissionInfo {
+ enum Permission {
+ Keyboard = 0;
+ Clipboard = 2;
+ Audio = 3;
+ File = 4;
+ }
+
+ Permission permission = 1;
+ bool enabled = 2;
+}
+
+enum ImageQuality {
+ NotSet = 0;
+ Low = 2;
+ Balanced = 3;
+ Best = 4;
+}
+
+message OptionMessage {
+ enum BoolOption {
+ NotSet = 0;
+ No = 1;
+ Yes = 2;
+ }
+ ImageQuality image_quality = 1;
+ BoolOption lock_after_session_end = 2;
+ BoolOption show_remote_cursor = 3;
+ BoolOption privacy_mode = 4;
+ BoolOption block_input = 5;
+ int32 custom_image_quality = 6;
+ BoolOption disable_audio = 7;
+ BoolOption disable_clipboard = 8;
+ BoolOption enable_file_transfer = 9;
+}
+
+message OptionResponse {
+ OptionMessage opt = 1;
+ string error = 2;
+}
+
+message TestDelay {
+ int64 time = 1;
+ bool from_client = 2;
+}
+
+message PublicKey {
+ bytes asymmetric_value = 1;
+ bytes symmetric_value = 2;
+}
+
+message SignedId { bytes id = 1; }
+
+message AudioFormat {
+ uint32 sample_rate = 1;
+ uint32 channels = 2;
+}
+
+message AudioFrame { bytes data = 1; }
+
+message Misc {
+ oneof union {
+ ChatMessage chat_message = 4;
+ SwitchDisplay switch_display = 5;
+ PermissionInfo permission_info = 6;
+ OptionMessage option = 7;
+ AudioFormat audio_format = 8;
+ string close_reason = 9;
+ bool refresh_video = 10;
+ OptionResponse option_response = 11;
+ bool video_received = 12;
+ }
+}
+
+message Message {
+ oneof union {
+ SignedId signed_id = 3;
+ PublicKey public_key = 4;
+ TestDelay test_delay = 5;
+ VideoFrame video_frame = 6;
+ LoginRequest login_request = 7;
+ LoginResponse login_response = 8;
+ Hash hash = 9;
+ MouseEvent mouse_event = 10;
+ AudioFrame audio_frame = 11;
+ CursorData cursor_data = 12;
+ CursorPosition cursor_position = 13;
+ uint64 cursor_id = 14;
+ KeyEvent key_event = 15;
+ Clipboard clipboard = 16;
+ FileAction file_action = 17;
+ FileResponse file_response = 18;
+ Misc misc = 19;
+ Cliprdr cliprdr = 20;
+ }
+}
diff --git a/libs/hbb_common/protos/rendezvous.proto b/libs/hbb_common/protos/rendezvous.proto
new file mode 100644
index 0000000..2c5f1b3
--- /dev/null
+++ b/libs/hbb_common/protos/rendezvous.proto
@@ -0,0 +1,171 @@
+syntax = "proto3";
+package hbb;
+
+message RegisterPeer {
+ string id = 1;
+ int32 serial = 2;
+}
+
+enum ConnType {
+ DEFAULT_CONN = 0;
+ FILE_TRANSFER = 1;
+ PORT_FORWARD = 2;
+ RDP = 3;
+}
+
+message RegisterPeerResponse { bool request_pk = 2; }
+
+message PunchHoleRequest {
+ string id = 1;
+ NatType nat_type = 2;
+ string licence_key = 3;
+ ConnType conn_type = 4;
+ string token = 5;
+}
+
+message PunchHole {
+ bytes socket_addr = 1;
+ string relay_server = 2;
+ NatType nat_type = 3;
+}
+
+message TestNatRequest {
+ int32 serial = 1;
+}
+
+// per my test, uint/int has no difference in encoding, int not good for negative, use sint for negative
+message TestNatResponse {
+ int32 port = 1;
+ ConfigUpdate cu = 2; // for mobile
+}
+
+enum NatType {
+ UNKNOWN_NAT = 0;
+ ASYMMETRIC = 1;
+ SYMMETRIC = 2;
+}
+
+message PunchHoleSent {
+ bytes socket_addr = 1;
+ string id = 2;
+ string relay_server = 3;
+ NatType nat_type = 4;
+ string version = 5;
+}
+
+message RegisterPk {
+ string id = 1;
+ bytes uuid = 2;
+ bytes pk = 3;
+ string old_id = 4;
+}
+
+message RegisterPkResponse {
+ enum Result {
+ OK = 0;
+ UUID_MISMATCH = 2;
+ ID_EXISTS = 3;
+ TOO_FREQUENT = 4;
+ INVALID_ID_FORMAT = 5;
+ NOT_SUPPORT = 6;
+ SERVER_ERROR = 7;
+ }
+ Result result = 1;
+}
+
+message PunchHoleResponse {
+ bytes socket_addr = 1;
+ bytes pk = 2;
+ enum Failure {
+ ID_NOT_EXIST = 0;
+ OFFLINE = 2;
+ LICENSE_MISMATCH = 3;
+ LICENSE_OVERUSE = 4;
+ }
+ Failure failure = 3;
+ string relay_server = 4;
+ oneof union {
+ NatType nat_type = 5;
+ bool is_local = 6;
+ }
+ string other_failure = 7;
+}
+
+message ConfigUpdate {
+ int32 serial = 1;
+ repeated string rendezvous_servers = 2;
+}
+
+message RequestRelay {
+ string id = 1;
+ string uuid = 2;
+ bytes socket_addr = 3;
+ string relay_server = 4;
+ bool secure = 5;
+ string licence_key = 6;
+ ConnType conn_type = 7;
+ string token = 8;
+}
+
+message RelayResponse {
+ bytes socket_addr = 1;
+ string uuid = 2;
+ string relay_server = 3;
+ oneof union {
+ string id = 4;
+ bytes pk = 5;
+ }
+ string refuse_reason = 6;
+ string version = 7;
+}
+
+message SoftwareUpdate { string url = 1; }
+
+// if in same intranet, punch hole won't work both for udp and tcp,
+// even some router has below connection error if we connect itself,
+// { kind: Other, error: "could not resolve to any address" },
+// so we request local address to connect.
+message FetchLocalAddr {
+ bytes socket_addr = 1;
+ string relay_server = 2;
+}
+
+message LocalAddr {
+ bytes socket_addr = 1;
+ bytes local_addr = 2;
+ string relay_server = 3;
+ string id = 4;
+ string version = 5;
+}
+
+message PeerDiscovery {
+ string cmd = 1;
+ string mac = 2;
+ string id = 3;
+ string username = 4;
+ string hostname = 5;
+ string platform = 6;
+ string misc = 7;
+}
+
+message RendezvousMessage {
+ oneof union {
+ RegisterPeer register_peer = 6;
+ RegisterPeerResponse register_peer_response = 7;
+ PunchHoleRequest punch_hole_request = 8;
+ PunchHole punch_hole = 9;
+ PunchHoleSent punch_hole_sent = 10;
+ PunchHoleResponse punch_hole_response = 11;
+ FetchLocalAddr fetch_local_addr = 12;
+ LocalAddr local_addr = 13;
+ ConfigUpdate configure_update = 14;
+ RegisterPk register_pk = 15;
+ RegisterPkResponse register_pk_response = 16;
+ SoftwareUpdate software_update = 17;
+ RequestRelay request_relay = 18;
+ RelayResponse relay_response = 19;
+ TestNatRequest test_nat_request = 20;
+ TestNatResponse test_nat_response = 21;
+ PeerDiscovery peer_discovery = 22;
+ }
+}
diff --git a/libs/hbb_common/src/bytes_codec.rs b/libs/hbb_common/src/bytes_codec.rs
new file mode 100644
index 0000000..e029f1c
--- /dev/null
+++ b/libs/hbb_common/src/bytes_codec.rs
@@ -0,0 +1,274 @@
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use std::io;
+use tokio_util::codec::{Decoder, Encoder};
+
+#[derive(Debug, Clone, Copy)]
+pub struct BytesCodec {
+ state: DecodeState,
+ raw: bool,
+ max_packet_length: usize,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum DecodeState {
+ Head,
+ Data(usize),
+}
+
+impl BytesCodec {
+ pub fn new() -> Self {
+ Self {
+ state: DecodeState::Head,
+ raw: false,
+ max_packet_length: usize::MAX,
+ }
+ }
+
+ pub fn set_raw(&mut self) {
+ self.raw = true;
+ }
+
+ pub fn set_max_packet_length(&mut self, n: usize) {
+ self.max_packet_length = n;
+ }
+
+ fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
+ if src.is_empty() {
+ return Ok(None);
+ }
+ let head_len = ((src[0] & 0x3) + 1) as usize;
+ if src.len() < head_len {
+ return Ok(None);
+ }
+ let mut n = src[0] as usize;
+ if head_len > 1 {
+ n |= (src[1] as usize) << 8;
+ }
+ if head_len > 2 {
+ n |= (src[2] as usize) << 16;
+ }
+ if head_len > 3 {
+ n |= (src[3] as usize) << 24;
+ }
+ n >>= 2;
+ if n > self.max_packet_length {
+ return Err(io::Error::new(io::ErrorKind::InvalidData, "Too big packet"));
+ }
+ src.advance(head_len);
+ src.reserve(n);
+ return Ok(Some(n));
+ }
+
+ fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
+ if src.len() < n {
+ return Ok(None);
+ }
+ Ok(Some(src.split_to(n)))
+ }
+}
+
+impl Decoder for BytesCodec {
+ type Item = BytesMut;
+ type Error = io::Error;
+
+ fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
+ if self.raw {
+ if !src.is_empty() {
+ let len = src.len();
+ return Ok(Some(src.split_to(len)));
+ } else {
+ return Ok(None);
+ }
+ }
+ let n = match self.state {
+ DecodeState::Head => match self.decode_head(src)? {
+ Some(n) => {
+ self.state = DecodeState::Data(n);
+ n
+ }
+ None => return Ok(None),
+ },
+ DecodeState::Data(n) => n,
+ };
+
+ match self.decode_data(n, src)? {
+ Some(data) => {
+ self.state = DecodeState::Head;
+ Ok(Some(data))
+ }
+ None => Ok(None),
+ }
+ }
+}
+
+impl Encoder<Bytes> for BytesCodec {
+ type Error = io::Error;
+
+ fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
+ if self.raw {
+ buf.reserve(data.len());
+ buf.put(data);
+ return Ok(());
+ }
+ if data.len() <= 0x3F {
+ buf.put_u8((data.len() << 2) as u8);
+ } else if data.len() <= 0x3FFF {
+ buf.put_u16_le((data.len() << 2) as u16 | 0x1);
+ } else if data.len() <= 0x3FFFFF {
+ let h = (data.len() << 2) as u32 | 0x2;
+ buf.put_u16_le((h & 0xFFFF) as u16);
+ buf.put_u8((h >> 16) as u8);
+ } else if data.len() <= 0x3FFFFFFF {
+ buf.put_u32_le((data.len() << 2) as u32 | 0x3);
+ } else {
+ return Err(io::Error::new(io::ErrorKind::InvalidInput, "Overflow"));
+ }
+ buf.extend(data);
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ #[test]
+ fn test_codec1() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3F, 1);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ let buf_saved = buf.clone();
+ assert_eq!(buf.len(), 0x3F + 1);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3F);
+ assert_eq!(res[0], 1);
+ } else {
+ assert!(false);
+ }
+ let mut codec2 = BytesCodec::new();
+ let mut buf2 = BytesMut::new();
+ if let Ok(None) = codec2.decode(&mut buf2) {
+ } else {
+ assert!(false);
+ }
+ buf2.extend(&buf_saved[0..1]);
+ if let Ok(None) = codec2.decode(&mut buf2) {
+ } else {
+ assert!(false);
+ }
+ buf2.extend(&buf_saved[1..]);
+ if let Ok(Some(res)) = codec2.decode(&mut buf2) {
+ assert_eq!(res.len(), 0x3F);
+ assert_eq!(res[0], 1);
+ } else {
+ assert!(false);
+ }
+ }
+
+ #[test]
+ fn test_codec2() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ assert!(!codec.encode("".into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 1);
+ bytes.resize(0x3F + 1, 2);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 0x3F + 2 + 2);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0);
+ } else {
+ assert!(false);
+ }
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3F + 1);
+ assert_eq!(res[0], 2);
+ } else {
+ assert!(false);
+ }
+ }
+
+ #[test]
+ fn test_codec3() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3F - 1, 3);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 0x3F + 1 - 1);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3F - 1);
+ assert_eq!(res[0], 3);
+ } else {
+ assert!(false);
+ }
+ }
+ #[test]
+ fn test_codec4() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3FFF, 4);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 0x3FFF + 2);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3FFF);
+ assert_eq!(res[0], 4);
+ } else {
+ assert!(false);
+ }
+ }
+
+ #[test]
+ fn test_codec5() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3FFFFF, 5);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ assert_eq!(buf.len(), 0x3FFFFF + 3);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3FFFFF);
+ assert_eq!(res[0], 5);
+ } else {
+ assert!(false);
+ }
+ }
+
+ #[test]
+ fn test_codec6() {
+ let mut codec = BytesCodec::new();
+ let mut buf = BytesMut::new();
+ let mut bytes: Vec<u8> = Vec::new();
+ bytes.resize(0x3FFFFF + 1, 6);
+ assert!(!codec.encode(bytes.into(), &mut buf).is_err());
+ let buf_saved = buf.clone();
+ assert_eq!(buf.len(), 0x3FFFFF + 4 + 1);
+ if let Ok(Some(res)) = codec.decode(&mut buf) {
+ assert_eq!(res.len(), 0x3FFFFF + 1);
+ assert_eq!(res[0], 6);
+ } else {
+ assert!(false);
+ }
+ let mut codec2 = BytesCodec::new();
+ let mut buf2 = BytesMut::new();
+ buf2.extend(&buf_saved[0..1]);
+ if let Ok(None) = codec2.decode(&mut buf2) {
+ } else {
+ assert!(false);
+ }
+ buf2.extend(&buf_saved[1..6]);
+ if let Ok(None) = codec2.decode(&mut buf2) {
+ } else {
+ assert!(false);
+ }
+ buf2.extend(&buf_saved[6..]);
+ if let Ok(Some(res)) = codec2.decode(&mut buf2) {
+ assert_eq!(res.len(), 0x3FFFFF + 1);
+ assert_eq!(res[0], 6);
+ } else {
+ assert!(false);
+ }
+ }
+}
diff --git a/libs/hbb_common/src/compress.rs b/libs/hbb_common/src/compress.rs
new file mode 100644
index 0000000..a969ccf
--- /dev/null
+++ b/libs/hbb_common/src/compress.rs
@@ -0,0 +1,50 @@
+use std::cell::RefCell;
+use zstd::block::{Compressor, Decompressor};
+
+thread_local! {
+ static COMPRESSOR: RefCell<Compressor> = RefCell::new(Compressor::new());
+ static DECOMPRESSOR: RefCell<Decompressor> = RefCell::new(Decompressor::new());
+}
+
+/// The library supports regular compression levels from 1 up to ZSTD_maxCLevel(),
+/// which is currently 22. Levels >= 20
+/// Default level is ZSTD_CLEVEL_DEFAULT==3.
+/// value 0 means default, which is controlled by ZSTD_CLEVEL_DEFAULT
+pub fn compress(data: &[u8], level: i32) -> Vec<u8> {
+ let mut out = Vec::new();
+ COMPRESSOR.with(|c| {
+ if let Ok(mut c) = c.try_borrow_mut() {
+ match c.compress(data, level) {
+ Ok(res) => out = res,
+ Err(err) => {
+ crate::log::debug!("Failed to compress: {}", err);
+ }
+ }
+ }
+ });
+ out
+}
+
+pub fn decompress(data: &[u8]) -> Vec<u8> {
+ let mut out = Vec::new();
+ DECOMPRESSOR.with(|d| {
+ if let Ok(mut d) = d.try_borrow_mut() {
+ const MAX: usize = 1024 * 1024 * 64;
+ const MIN: usize = 1024 * 1024;
+ let mut n = 30 * data.len();
+ if n > MAX {
+ n = MAX;
+ }
+ if n < MIN {
+ n = MIN;
+ }
+ match d.decompress(data, n) {
+ Ok(res) => out = res,
+ Err(err) => {
+ crate::log::debug!("Failed to decompress: {}", err);
+ }
+ }
+ }
+ });
+ out
+}
diff --git a/libs/hbb_common/src/config.rs b/libs/hbb_common/src/config.rs
new file mode 100644
index 0000000..a7c1bc6
--- /dev/null
+++ b/libs/hbb_common/src/config.rs
@@ -0,0 +1,876 @@
+use crate::log;
+use directories_next::ProjectDirs;
+use rand::Rng;
+use serde_derive::{Deserialize, Serialize};
+use sodiumoxide::crypto::sign;
+use std::{
+ collections::HashMap,
+ fs,
+ net::{IpAddr, Ipv4Addr, SocketAddr},
+ path::{Path, PathBuf},
+ sync::{Arc, Mutex, RwLock},
+ time::SystemTime,
+};
+
+pub const RENDEZVOUS_TIMEOUT: u64 = 12_000;
+pub const CONNECT_TIMEOUT: u64 = 18_000;
+pub const REG_INTERVAL: i64 = 12_000;
+pub const COMPRESS_LEVEL: i32 = 3;
+const SERIAL: i32 = 1;
+// 128x128
+#[cfg(target_os = "macos")] // 128x128 on 160x160 canvas, then shrink to 128, mac looks better with padding
+pub const ICON: &str = "
+";
+#[cfg(not(target_os = "macos"))] // 128x128 no padding
+pub const ICON: &str = "
+";
+#[cfg(target_os = "macos")]
+lazy_static::lazy_static! {
+ pub static ref ORG: Arc<RwLock<String>> = Arc::new(RwLock::new("com.carriez".to_owned()));
+}
+
+type Size = (i32, i32, i32, i32);
+
+lazy_static::lazy_static! {
+ static ref CONFIG: Arc<RwLock<Config>> = Arc::new(RwLock::new(Config::load()));
+ static ref CONFIG2: Arc<RwLock<Config2>> = Arc::new(RwLock::new(Config2::load()));
+ static ref LOCAL_CONFIG: Arc<RwLock<LocalConfig>> = Arc::new(RwLock::new(LocalConfig::load()));
+ pub static ref ONLINE: Arc<Mutex<HashMap<String, i64>>> = Default::default();
+ pub static ref PROD_RENDEZVOUS_SERVER: Arc<RwLock<String>> = Default::default();
+ pub static ref APP_NAME: Arc<RwLock<String>> = Arc::new(RwLock::new("RustDesk".to_owned()));
+}
+#[cfg(any(target_os = "android", target_os = "ios"))]
+lazy_static::lazy_static! {
+ pub static ref APP_DIR: Arc<RwLock<String>> = Default::default();
+ pub static ref APP_HOME_DIR: Arc<RwLock<String>> = Default::default();
+}
+const CHARS: &'static [char] = &[
+ '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k',
+ 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+];
+
+pub const RENDEZVOUS_SERVERS: &'static [&'static str] = &[
+ "rs-ny.rustdesk.com",
+ "rs-sg.rustdesk.com",
+ "rs-cn.rustdesk.com",
+];
+pub const RENDEZVOUS_PORT: i32 = 21116;
+pub const RELAY_PORT: i32 = 21117;
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum NetworkType {
+ Direct,
+ ProxySocks,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
+pub struct Config {
+ #[serde(default)]
+ pub id: String,
+ #[serde(default)]
+ password: String,
+ #[serde(default)]
+ salt: String,
+ #[serde(default)]
+ pub key_pair: (Vec<u8>, Vec<u8>), // sk, pk
+ #[serde(default)]
+ key_confirmed: bool,
+ #[serde(default)]
+ keys_confirmed: HashMap<String, bool>,
+}
+
+#[derive(Debug, Default, PartialEq, Serialize, Deserialize, Clone)]
+pub struct Socks5Server {
+ #[serde(default)]
+ pub proxy: String,
+ #[serde(default)]
+ pub username: String,
+ #[serde(default)]
+ pub password: String,
+}
+
+// more variable configs
+#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
+pub struct Config2 {
+ #[serde(default)]
+ rendezvous_server: String,
+ #[serde(default)]
+ nat_type: i32,
+ #[serde(default)]
+ serial: i32,
+
+ #[serde(default)]
+ socks: Option<Socks5Server>,
+
+ // the other scalar value must before this
+ #[serde(default)]
+ pub options: HashMap<String, String>,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize, Clone)]
+pub struct PeerConfig {
+ #[serde(default)]
+ pub password: Vec<u8>,
+ #[serde(default)]
+ pub size: Size,
+ #[serde(default)]
+ pub size_ft: Size,
+ #[serde(default)]
+ pub size_pf: Size,
+ #[serde(default)]
+ pub view_style: String, // original (default), scale
+ #[serde(default)]
+ pub image_quality: String,
+ #[serde(default)]
+ pub custom_image_quality: Vec<i32>,
+ #[serde(default)]
+ pub show_remote_cursor: bool,
+ #[serde(default)]
+ pub lock_after_session_end: bool,
+ #[serde(default)]
+ pub privacy_mode: bool,
+ #[serde(default)]
+ pub port_forwards: Vec<(i32, String, i32)>,
+ #[serde(default)]
+ pub direct_failures: i32,
+ #[serde(default)]
+ pub disable_audio: bool,
+ #[serde(default)]
+ pub disable_clipboard: bool,
+ #[serde(default)]
+ pub enable_file_transfer: bool,
+
+ // the other scalar value must before this
+ #[serde(default)]
+ pub options: HashMap<String, String>,
+ #[serde(default)]
+ pub info: PeerInfoSerde,
+}
+
+#[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)]
+pub struct PeerInfoSerde {
+ #[serde(default)]
+ pub username: String,
+ #[serde(default)]
+ pub hostname: String,
+ #[serde(default)]
+ pub platform: String,
+}
+
+fn patch(path: PathBuf) -> PathBuf {
+ if let Some(_tmp) = path.to_str() {
+ #[cfg(windows)]
+ return _tmp
+ .replace(
+ "system32\\config\\systemprofile",
+ "ServiceProfiles\\LocalService",
+ )
+ .into();
+ #[cfg(target_os = "macos")]
+ return _tmp.replace("Application Support", "Preferences").into();
+ #[cfg(target_os = "linux")]
+ {
+ if _tmp == "/root" {
+ if let Ok(output) = std::process::Command::new("whoami").output() {
+ let user = String::from_utf8_lossy(&output.stdout)
+ .to_string()
+ .trim()
+ .to_owned();
+ if user != "root" {
+ return format!("/home/{}", user).into();
+ }
+ }
+ }
+ }
+ }
+ path
+}
+
+impl Config2 {
+ fn load() -> Config2 {
+ Config::load_::<Config2>("2")
+ }
+
+ pub fn file() -> PathBuf {
+ Config::file_("2")
+ }
+
+ fn store(&self) {
+ Config::store_(self, "2");
+ }
+
+ pub fn get() -> Config2 {
+ return CONFIG2.read().unwrap().clone();
+ }
+
+ pub fn set(cfg: Config2) -> bool {
+ let mut lock = CONFIG2.write().unwrap();
+ if *lock == cfg {
+ return false;
+ }
+ *lock = cfg;
+ lock.store();
+ true
+ }
+}
+
+pub fn load_path<T: serde::Serialize + serde::de::DeserializeOwned + Default + std::fmt::Debug>(
+ file: PathBuf,
+) -> T {
+ let cfg = match confy::load_path(&file) {
+ Ok(config) => config,
+ Err(err) => {
+ log::error!("Failed to load config: {}", err);
+ T::default()
+ }
+ };
+ cfg
+}
+
+impl Config {
+ fn load_<T: serde::Serialize + serde::de::DeserializeOwned + Default + std::fmt::Debug>(
+ suffix: &str,
+ ) -> T {
+ let file = Self::file_(suffix);
+ log::debug!("Configuration path: {}", file.display());
+ let cfg = load_path(file);
+ if suffix.is_empty() {
+ log::debug!("{:?}", cfg);
+ }
+ cfg
+ }
+
+ fn store_<T: serde::Serialize>(config: &T, suffix: &str) {
+ let file = Self::file_(suffix);
+ if let Err(err) = confy::store_path(file, config) {
+ log::error!("Failed to store config: {}", err);
+ }
+ }
+
+ fn load() -> Config {
+ Config::load_::<Config>("")
+ }
+
+ fn store(&self) {
+ Config::store_(self, "");
+ }
+
+ pub fn file() -> PathBuf {
+ Self::file_("")
+ }
+
+ fn file_(suffix: &str) -> PathBuf {
+ let name = format!("{}{}", *APP_NAME.read().unwrap(), suffix);
+ Self::path(name).with_extension("toml")
+ }
+
+ pub fn get_home() -> PathBuf {
+ #[cfg(any(target_os = "android", target_os = "ios"))]
+ return Self::path(APP_HOME_DIR.read().unwrap().as_str());
+ if let Some(path) = dirs_next::home_dir() {
+ patch(path)
+ } else if let Ok(path) = std::env::current_dir() {
+ path
+ } else {
+ std::env::temp_dir()
+ }
+ }
+
+ pub fn path<P: AsRef<Path>>(p: P) -> PathBuf {
+ #[cfg(any(target_os = "android", target_os = "ios"))]
+ {
+ let mut path: PathBuf = APP_DIR.read().unwrap().clone().into();
+ path.push(p);
+ return path;
+ }
+ #[cfg(not(target_os = "macos"))]
+ let org = "";
+ #[cfg(target_os = "macos")]
+ let org = ORG.read().unwrap().clone();
+ // /var/root for root
+ if let Some(project) = ProjectDirs::from("", &org, &*APP_NAME.read().unwrap()) {
+ let mut path = patch(project.config_dir().to_path_buf());
+ path.push(p);
+ return path;
+ }
+ return "".into();
+ }
+
+ #[allow(unreachable_code)]
+ pub fn log_path() -> PathBuf {
+ #[cfg(target_os = "macos")]
+ {
+ if let Some(path) = dirs_next::home_dir().as_mut() {
+ path.push(format!("Library/Logs/{}", *APP_NAME.read().unwrap()));
+ return path.clone();
+ }
+ }
+ #[cfg(target_os = "linux")]
+ {
+ let mut path = Self::get_home();
+ path.push(format!(".local/share/logs/{}", *APP_NAME.read().unwrap()));
+ std::fs::create_dir_all(&path).ok();
+ return path;
+ }
+ if let Some(path) = Self::path("").parent() {
+ let mut path: PathBuf = path.into();
+ path.push("log");
+ return path;
+ }
+ "".into()
+ }
+
+ pub fn ipc_path(postfix: &str) -> String {
+ #[cfg(windows)]
+ {
+ // \\ServerName\pipe\PipeName
+ // where ServerName is either the name of a remote computer or a period, to specify the local computer.
+ // https://docs.microsoft.com/en-us/windows/win32/ipc/pipe-names
+ format!(
+ "\\\\.\\pipe\\{}\\query{}",
+ *APP_NAME.read().unwrap(),
+ postfix
+ )
+ }
+ #[cfg(not(windows))]
+ {
+ use std::os::unix::fs::PermissionsExt;
+ #[cfg(target_os = "android")]
+ let mut path: PathBuf =
+ format!("{}/{}", *APP_DIR.read().unwrap(), *APP_NAME.read().unwrap()).into();
+ #[cfg(not(target_os = "android"))]
+ let mut path: PathBuf = format!("/tmp/{}", *APP_NAME.read().unwrap()).into();
+ fs::create_dir(&path).ok();
+ fs::set_permissions(&path, fs::Permissions::from_mode(0o0777)).ok();
+ path.push(format!("ipc{}", postfix));
+ path.to_str().unwrap_or("").to_owned()
+ }
+ }
+
+ pub fn icon_path() -> PathBuf {
+ let mut path = Self::path("icons");
+ if fs::create_dir_all(&path).is_err() {
+ path = std::env::temp_dir();
+ }
+ path
+ }
+
+ #[inline]
+ pub fn get_any_listen_addr() -> SocketAddr {
+ SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)
+ }
+
+ pub fn get_rendezvous_server() -> String {
+ let mut rendezvous_server = Self::get_option("custom-rendezvous-server");
+ if rendezvous_server.is_empty() {
+ rendezvous_server = PROD_RENDEZVOUS_SERVER.read().unwrap().clone();
+ }
+ if rendezvous_server.is_empty() {
+ rendezvous_server = CONFIG2.read().unwrap().rendezvous_server.clone();
+ }
+ if rendezvous_server.is_empty() {
+ rendezvous_server = Self::get_rendezvous_servers()
+ .drain(..)
+ .next()
+ .unwrap_or("".to_owned());
+ }
+ if !rendezvous_server.contains(":") {
+ rendezvous_server = format!("{}:{}", rendezvous_server, RENDEZVOUS_PORT);
+ }
+ rendezvous_server
+ }
+
+ pub fn get_rendezvous_servers() -> Vec<String> {
+ let s = Self::get_option("custom-rendezvous-server");
+ if !s.is_empty() {
+ return vec![s];
+ }
+ let s = PROD_RENDEZVOUS_SERVER.read().unwrap().clone();
+ if !s.is_empty() {
+ return vec![s];
+ }
+ let serial_obsolute = CONFIG2.read().unwrap().serial > SERIAL;
+ if serial_obsolute {
+ let ss: Vec<String> = Self::get_option("rendezvous-servers")
+ .split(",")
+ .filter(|x| x.contains("."))
+ .map(|x| x.to_owned())
+ .collect();
+ if !ss.is_empty() {
+ return ss;
+ }
+ }
+ return RENDEZVOUS_SERVERS.iter().map(|x| x.to_string()).collect();
+ }
+
+ pub fn reset_online() {
+ *ONLINE.lock().unwrap() = Default::default();
+ }
+
+ pub fn update_latency(host: &str, latency: i64) {
+ ONLINE.lock().unwrap().insert(host.to_owned(), latency);
+ let mut host = "".to_owned();
+ let mut delay = i64::MAX;
+ for (tmp_host, tmp_delay) in ONLINE.lock().unwrap().iter() {
+ if tmp_delay > &0 && tmp_delay < &delay {
+ delay = tmp_delay.clone();
+ host = tmp_host.to_string();
+ }
+ }
+ if !host.is_empty() {
+ let mut config = CONFIG2.write().unwrap();
+ if host != config.rendezvous_server {
+ log::debug!("Update rendezvous_server in config to {}", host);
+ log::debug!("{:?}", *ONLINE.lock().unwrap());
+ config.rendezvous_server = host;
+ config.store();
+ }
+ }
+ }
+
+ pub fn set_id(id: &str) {
+ let mut config = CONFIG.write().unwrap();
+ if id == config.id {
+ return;
+ }
+ config.id = id.into();
+ config.store();
+ }
+
+ pub fn set_nat_type(nat_type: i32) {
+ let mut config = CONFIG2.write().unwrap();
+ if nat_type == config.nat_type {
+ return;
+ }
+ config.nat_type = nat_type;
+ config.store();
+ }
+
+ pub fn get_nat_type() -> i32 {
+ CONFIG2.read().unwrap().nat_type
+ }
+
+ pub fn set_serial(serial: i32) {
+ let mut config = CONFIG2.write().unwrap();
+ if serial == config.serial {
+ return;
+ }
+ config.serial = serial;
+ config.store();
+ }
+
+ pub fn get_serial() -> i32 {
+ std::cmp::max(CONFIG2.read().unwrap().serial, SERIAL)
+ }
+
+ fn get_auto_id() -> Option<String> {
+ #[cfg(any(target_os = "android", target_os = "ios"))]
+ {
+ return Some(
+ rand::thread_rng()
+ .gen_range(1_000_000_000..2_000_000_000)
+ .to_string(),
+ );
+ }
+ let mut id = 0u32;
+ #[cfg(not(any(target_os = "android", target_os = "ios")))]
+ if let Ok(Some(ma)) = mac_address::get_mac_address() {
+ for x in &ma.bytes()[2..] {
+ id = (id << 8) | (*x as u32);
+ }
+ id = id & 0x1FFFFFFF;
+ Some(id.to_string())
+ } else {
+ None
+ }
+ }
+
+ pub fn get_auto_password() -> String {
+ let mut rng = rand::thread_rng();
+ (0..6)
+ .map(|_| CHARS[rng.gen::<usize>() % CHARS.len()])
+ .collect()
+ }
+
+ pub fn get_key_confirmed() -> bool {
+ CONFIG.read().unwrap().key_confirmed
+ }
+
+ pub fn set_key_confirmed(v: bool) {
+ let mut config = CONFIG.write().unwrap();
+ if config.key_confirmed == v {
+ return;
+ }
+ config.key_confirmed = v;
+ if !v {
+ config.keys_confirmed = Default::default();
+ }
+ config.store();
+ }
+
+ pub fn get_host_key_confirmed(host: &str) -> bool {
+ if let Some(true) = CONFIG.read().unwrap().keys_confirmed.get(host) {
+ true
+ } else {
+ false
+ }
+ }
+
+ pub fn set_host_key_confirmed(host: &str, v: bool) {
+ if Self::get_host_key_confirmed(host) == v {
+ return;
+ }
+ let mut config = CONFIG.write().unwrap();
+ config.keys_confirmed.insert(host.to_owned(), v);
+ config.store();
+ }
+
+ pub fn set_key_pair(pair: (Vec<u8>, Vec<u8>)) {
+ let mut config = CONFIG.write().unwrap();
+ if config.key_pair == pair {
+ return;
+ }
+ config.key_pair = pair;
+ config.store();
+ }
+
+ pub fn get_key_pair() -> (Vec<u8>, Vec<u8>) {
+ // lock here to make sure no gen_keypair more than once
+ let mut config = CONFIG.write().unwrap();
+ if config.key_pair.0.is_empty() {
+ let (pk, sk) = sign::gen_keypair();
+ config.key_pair = (sk.0.to_vec(), pk.0.into());
+ config.store();
+ }
+ config.key_pair.clone()
+ }
+
+ pub fn get_id() -> String {
+ let mut id = CONFIG.read().unwrap().id.clone();
+ if id.is_empty() {
+ if let Some(tmp) = Config::get_auto_id() {
+ id = tmp;
+ Config::set_id(&id);
+ }
+ }
+ id
+ }
+
+ pub fn get_id_or(b: String) -> String {
+ let a = CONFIG.read().unwrap().id.clone();
+ if a.is_empty() {
+ b
+ } else {
+ a
+ }
+ }
+
+ pub fn get_options() -> HashMap<String, String> {
+ CONFIG2.read().unwrap().options.clone()
+ }
+
+ pub fn set_options(v: HashMap<String, String>) {
+ let mut config = CONFIG2.write().unwrap();
+ if config.options == v {
+ return;
+ }
+ config.options = v;
+ config.store();
+ }
+
+ pub fn get_option(k: &str) -> String {
+ if let Some(v) = CONFIG2.read().unwrap().options.get(k) {
+ v.clone()
+ } else {
+ "".to_owned()
+ }
+ }
+
+ pub fn set_option(k: String, v: String) {
+ let mut config = CONFIG2.write().unwrap();
+ let v2 = if v.is_empty() { None } else { Some(&v) };
+ if v2 != config.options.get(&k) {
+ if v2.is_none() {
+ config.options.remove(&k);
+ } else {
+ config.options.insert(k, v);
+ }
+ config.store();
+ }
+ }
+
+ pub fn update_id() {
+ // to-do: how about if one ip register a lot of ids?
+ let id = Self::get_id();
+ let mut rng = rand::thread_rng();
+ let new_id = rng.gen_range(1_000_000_000..2_000_000_000).to_string();
+ Config::set_id(&new_id);
+ log::info!("id updated from {} to {}", id, new_id);
+ }
+
+ pub fn set_password(password: &str) {
+ let mut config = CONFIG.write().unwrap();
+ if password == config.password {
+ return;
+ }
+ config.password = password.into();
+ config.store();
+ }
+
+ pub fn get_password() -> String {
+ let mut password = CONFIG.read().unwrap().password.clone();
+ if password.is_empty() {
+ password = Config::get_auto_password();
+ Config::set_password(&password);
+ }
+ password
+ }
+
+ pub fn set_salt(salt: &str) {
+ let mut config = CONFIG.write().unwrap();
+ if salt == config.salt {
+ return;
+ }
+ config.salt = salt.into();
+ config.store();
+ }
+
+ pub fn get_salt() -> String {
+ let mut salt = CONFIG.read().unwrap().salt.clone();
+ if salt.is_empty() {
+ salt = Config::get_auto_password();
+ Config::set_salt(&salt);
+ }
+ salt
+ }
+
+ pub fn set_socks(socks: Option<Socks5Server>) {
+ let mut config = CONFIG2.write().unwrap();
+ if config.socks == socks {
+ return;
+ }
+ config.socks = socks;
+ config.store();
+ }
+
+ pub fn get_socks() -> Option<Socks5Server> {
+ CONFIG2.read().unwrap().socks.clone()
+ }
+
+ pub fn get_network_type() -> NetworkType {
+ match &CONFIG2.read().unwrap().socks {
+ None => NetworkType::Direct,
+ Some(_) => NetworkType::ProxySocks,
+ }
+ }
+
+ pub fn get() -> Config {
+ return CONFIG.read().unwrap().clone();
+ }
+
+ pub fn set(cfg: Config) -> bool {
+ let mut lock = CONFIG.write().unwrap();
+ if *lock == cfg {
+ return false;
+ }
+ *lock = cfg;
+ lock.store();
+ true
+ }
+}
+
+const PEERS: &str = "peers";
+
+impl PeerConfig {
+ pub fn load(id: &str) -> PeerConfig {
+ let _ = CONFIG.read().unwrap(); // for lock
+ match confy::load_path(&Self::path(id)) {
+ Ok(config) => config,
+ Err(err) => {
+ log::error!("Failed to load config: {}", err);
+ Default::default()
+ }
+ }
+ }
+
+ pub fn store(&self, id: &str) {
+ let _ = CONFIG.read().unwrap(); // for lock
+ if let Err(err) = confy::store_path(Self::path(id), self) {
+ log::error!("Failed to store config: {}", err);
+ }
+ }
+
+ pub fn remove(id: &str) {
+ fs::remove_file(&Self::path(id)).ok();
+ }
+
+ fn path(id: &str) -> PathBuf {
+ let path: PathBuf = [PEERS, id].iter().collect();
+ Config::path(path).with_extension("toml")
+ }
+
+ pub fn peers() -> Vec<(String, SystemTime, PeerConfig)> {
+ if let Ok(peers) = Config::path(PEERS).read_dir() {
+ if let Ok(peers) = peers
+ .map(|res| res.map(|e| e.path()))
+ .collect::<Result<Vec<_>, _>>()
+ {
+ let mut peers: Vec<_> = peers
+ .iter()
+ .filter(|p| {
+ p.is_file()
+ && p.extension().map(|p| p.to_str().unwrap_or("")) == Some("toml")
+ })
+ .map(|p| {
+ let t = crate::get_modified_time(&p);
+ let id = p
+ .file_stem()
+ .map(|p| p.to_str().unwrap_or(""))
+ .unwrap_or("")
+ .to_owned();
+ let c = PeerConfig::load(&id);
+ if c.info.platform.is_empty() {
+ fs::remove_file(&p).ok();
+ }
+ (id, t, c)
+ })
+ .filter(|p| !p.2.info.platform.is_empty())
+ .collect();
+ peers.sort_unstable_by(|a, b| b.1.cmp(&a.1));
+ return peers;
+ }
+ }
+ Default::default()
+ }
+}
+
+#[derive(Debug, Default, Serialize, Deserialize, Clone)]
+pub struct LocalConfig {
+ #[serde(default)]
+ remote_id: String, // latest used one
+ #[serde(default)]
+ size: Size,
+ #[serde(default)]
+ pub fav: Vec<String>,
+ #[serde(default)]
+ options: HashMap<String, String>,
+}
+
+impl LocalConfig {
+ fn load() -> LocalConfig {
+ Config::load_::<LocalConfig>("_local")
+ }
+
+ fn store(&self) {
+ Config::store_(self, "_local");
+ }
+
+ pub fn get_size() -> Size {
+ LOCAL_CONFIG.read().unwrap().size
+ }
+
+ pub fn set_size(x: i32, y: i32, w: i32, h: i32) {
+ let mut config = LOCAL_CONFIG.write().unwrap();
+ let size = (x, y, w, h);
+ if size == config.size || size.2 < 300 || size.3 < 300 {
+ return;
+ }
+ config.size = size;
+ config.store();
+ }
+
+ pub fn set_remote_id(remote_id: &str) {
+ let mut config = LOCAL_CONFIG.write().unwrap();
+ if remote_id == config.remote_id {
+ return;
+ }
+ config.remote_id = remote_id.into();
+ config.store();
+ }
+
+ pub fn get_remote_id() -> String {
+ LOCAL_CONFIG.read().unwrap().remote_id.clone()
+ }
+
+ pub fn set_fav(fav: Vec<String>) {
+ let mut lock = LOCAL_CONFIG.write().unwrap();
+ if lock.fav == fav {
+ return;
+ }
+ lock.fav = fav;
+ lock.store();
+ }
+
+ pub fn get_fav() -> Vec<String> {
+ LOCAL_CONFIG.read().unwrap().fav.clone()
+ }
+
+ pub fn get_option(k: &str) -> String {
+ if let Some(v) = LOCAL_CONFIG.read().unwrap().options.get(k) {
+ v.clone()
+ } else {
+ "".to_owned()
+ }
+ }
+
+ pub fn set_option(k: String, v: String) {
+ let mut config = LOCAL_CONFIG.write().unwrap();
+ let v2 = if v.is_empty() { None } else { Some(&v) };
+ if v2 != config.options.get(&k) {
+ if v2.is_none() {
+ config.options.remove(&k);
+ } else {
+ config.options.insert(k, v);
+ }
+ config.store();
+ }
+ }
+}
+
+#[derive(Debug, Default, Serialize, Deserialize, Clone)]
+pub struct LanPeers {
+ #[serde(default)]
+ pub peers: String,
+}
+
+impl LanPeers {
+ pub fn load() -> LanPeers {
+ let _ = CONFIG.read().unwrap(); // for lock
+ match confy::load_path(&Config::file_("_lan_peers")) {
+ Ok(peers) => peers,
+ Err(err) => {
+ log::error!("Failed to load lan peers: {}", err);
+ Default::default()
+ }
+ }
+ }
+
+ pub fn store(peers: String) {
+ let f = LanPeers { peers };
+ if let Err(err) = confy::store_path(Config::file_("_lan_peers"), f) {
+ log::error!("Failed to store lan peers: {}", err);
+ }
+ }
+
+ pub fn modify_time() -> crate::ResultType<u64> {
+ let p = Config::file_("_lan_peers");
+ Ok(fs::metadata(p)?
+ .modified()?
+ .duration_since(SystemTime::UNIX_EPOCH)?
+ .as_millis() as _)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ #[test]
+ fn test_serialize() {
+ let cfg: Config = Default::default();
+ let res = toml::to_string_pretty(&cfg);
+ assert!(res.is_ok());
+ let cfg: PeerConfig = Default::default();
+ let res = toml::to_string_pretty(&cfg);
+ assert!(res.is_ok());
+ }
+}
diff --git a/libs/hbb_common/src/fs.rs b/libs/hbb_common/src/fs.rs
new file mode 100644
index 0000000..475f4df
--- /dev/null
+++ b/libs/hbb_common/src/fs.rs
@@ -0,0 +1,560 @@
+use crate::{bail, message_proto::*, ResultType};
+use std::path::{Path, PathBuf};
+// https://doc.rust-lang.org/std/os/windows/fs/trait.MetadataExt.html
+use crate::{
+ compress::{compress, decompress},
+ config::{Config, COMPRESS_LEVEL},
+};
+#[cfg(windows)]
+use std::os::windows::prelude::*;
+use tokio::{fs::File, io::*};
+
+pub fn read_dir(path: &PathBuf, include_hidden: bool) -> ResultType<FileDirectory> {
+ let mut dir = FileDirectory {
+ path: get_string(&path),
+ ..Default::default()
+ };
+ #[cfg(windows)]
+ if "/" == &get_string(&path) {
+ let drives = unsafe { winapi::um::fileapi::GetLogicalDrives() };
+ for i in 0..32 {
+ if drives & (1 << i) != 0 {
+ let name = format!(
+ "{}:",
+ std::char::from_u32('A' as u32 + i as u32).unwrap_or('A')
+ );
+ dir.entries.push(FileEntry {
+ name,
+ entry_type: FileType::DirDrive.into(),
+ ..Default::default()
+ });
+ }
+ }
+ return Ok(dir);
+ }
+ for entry in path.read_dir()? {
+ if let Ok(entry) = entry {
+ let p = entry.path();
+ let name = p
+ .file_name()
+ .map(|p| p.to_str().unwrap_or(""))
+ .unwrap_or("")
+ .to_owned();
+ if name.is_empty() {
+ continue;
+ }
+ let mut is_hidden = false;
+ let meta;
+ if let Ok(tmp) = std::fs::symlink_metadata(&p) {
+ meta = tmp;
+ } else {
+ continue;
+ }
+ // docs.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants
+ #[cfg(windows)]
+ if meta.file_attributes() & 0x2 != 0 {
+ is_hidden = true;
+ }
+ #[cfg(not(windows))]
+ if name.find('.').unwrap_or(usize::MAX) == 0 {
+ is_hidden = true;
+ }
+ if is_hidden && !include_hidden {
+ continue;
+ }
+ let (entry_type, size) = {
+ if p.is_dir() {
+ if meta.file_type().is_symlink() {
+ (FileType::DirLink.into(), 0)
+ } else {
+ (FileType::Dir.into(), 0)
+ }
+ } else {
+ if meta.file_type().is_symlink() {
+ (FileType::FileLink.into(), 0)
+ } else {
+ (FileType::File.into(), meta.len())
+ }
+ }
+ };
+ let modified_time = meta
+ .modified()
+ .map(|x| {
+ x.duration_since(std::time::SystemTime::UNIX_EPOCH)
+ .map(|x| x.as_secs())
+ .unwrap_or(0)
+ })
+ .unwrap_or(0) as u64;
+ dir.entries.push(FileEntry {
+ name: get_file_name(&p),
+ entry_type,
+ is_hidden,
+ size,
+ modified_time,
+ ..Default::default()
+ });
+ }
+ }
+ Ok(dir)
+}
+
+#[inline]
+pub fn get_file_name(p: &PathBuf) -> String {
+ p.file_name()
+ .map(|p| p.to_str().unwrap_or(""))
+ .unwrap_or("")
+ .to_owned()
+}
+
+#[inline]
+pub fn get_string(path: &PathBuf) -> String {
+ path.to_str().unwrap_or("").to_owned()
+}
+
+#[inline]
+pub fn get_path(path: &str) -> PathBuf {
+ Path::new(path).to_path_buf()
+}
+
+#[inline]
+pub fn get_home_as_string() -> String {
+ get_string(&Config::get_home())
+}
+
+fn read_dir_recursive(
+ path: &PathBuf,
+ prefix: &PathBuf,
+ include_hidden: bool,
+) -> ResultType<Vec<FileEntry>> {
+ let mut files = Vec::new();
+ if path.is_dir() {
+ // to-do: symbol link handling, cp the link rather than the content
+ // to-do: file mode, for unix
+ let fd = read_dir(&path, include_hidden)?;
+ for entry in fd.entries.iter() {
+ match entry.entry_type.enum_value() {
+ Ok(FileType::File) => {
+ let mut entry = entry.clone();
+ entry.name = get_string(&prefix.join(entry.name));
+ files.push(entry);
+ }
+ Ok(FileType::Dir) => {
+ if let Ok(mut tmp) = read_dir_recursive(
+ &path.join(&entry.name),
+ &prefix.join(&entry.name),
+ include_hidden,
+ ) {
+ for entry in tmp.drain(0..) {
+ files.push(entry);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ Ok(files)
+ } else if path.is_file() {
+ let (size, modified_time) = if let Ok(meta) = std::fs::metadata(&path) {
+ (
+ meta.len(),
+ meta.modified()
+ .map(|x| {
+ x.duration_since(std::time::SystemTime::UNIX_EPOCH)
+ .map(|x| x.as_secs())
+ .unwrap_or(0)
+ })
+ .unwrap_or(0) as u64,
+ )
+ } else {
+ (0, 0)
+ };
+ files.push(FileEntry {
+ entry_type: FileType::File.into(),
+ size,
+ modified_time,
+ ..Default::default()
+ });
+ Ok(files)
+ } else {
+ bail!("Not exists");
+ }
+}
+
+pub fn get_recursive_files(path: &str, include_hidden: bool) -> ResultType<Vec<FileEntry>> {
+ read_dir_recursive(&get_path(path), &get_path(""), include_hidden)
+}
+
+#[derive(Default)]
+pub struct TransferJob {
+ id: i32,
+ path: PathBuf,
+ files: Vec<FileEntry>,
+ file_num: i32,
+ file: Option<File>,
+ total_size: u64,
+ finished_size: u64,
+ transferred: u64,
+}
+
+#[inline]
+fn get_ext(name: &str) -> &str {
+ if let Some(i) = name.rfind(".") {
+ return &name[i + 1..];
+ }
+ ""
+}
+
+#[inline]
+fn is_compressed_file(name: &str) -> bool {
+ let ext = get_ext(name);
+ ext == "xz"
+ || ext == "gz"
+ || ext == "zip"
+ || ext == "7z"
+ || ext == "rar"
+ || ext == "bz2"
+ || ext == "tgz"
+ || ext == "png"
+ || ext == "jpg"
+}
+
+impl TransferJob {
+ pub fn new_write(id: i32, path: String, files: Vec<FileEntry>) -> Self {
+ let total_size = files.iter().map(|x| x.size as u64).sum();
+ Self {
+ id,
+ path: get_path(&path),
+ files,
+ total_size,
+ ..Default::default()
+ }
+ }
+
+ pub fn new_read(id: i32, path: String, include_hidden: bool) -> ResultType<Self> {
+ let files = get_recursive_files(&path, include_hidden)?;
+ let total_size = files.iter().map(|x| x.size as u64).sum();
+ Ok(Self {
+ id,
+ path: get_path(&path),
+ files,
+ total_size,
+ ..Default::default()
+ })
+ }
+
+ #[inline]
+ pub fn files(&self) -> &Vec<FileEntry> {
+ &self.files
+ }
+
+ #[inline]
+ pub fn set_files(&mut self, files: Vec<FileEntry>) {
+ self.files = files;
+ }
+
+ #[inline]
+ pub fn id(&self) -> i32 {
+ self.id
+ }
+
+ #[inline]
+ pub fn total_size(&self) -> u64 {
+ self.total_size
+ }
+
+ #[inline]
+ pub fn finished_size(&self) -> u64 {
+ self.finished_size
+ }
+
+ #[inline]
+ pub fn transferred(&self) -> u64 {
+ self.transferred
+ }
+
+ #[inline]
+ pub fn file_num(&self) -> i32 {
+ self.file_num
+ }
+
+ pub fn modify_time(&self) {
+ let file_num = self.file_num as usize;
+ if file_num < self.files.len() {
+ let entry = &self.files[file_num];
+ let path = self.join(&entry.name);
+ let download_path = format!("{}.download", get_string(&path));
+ std::fs::rename(&download_path, &path).ok();
+ filetime::set_file_mtime(
+ &path,
+ filetime::FileTime::from_unix_time(entry.modified_time as _, 0),
+ )
+ .ok();
+ }
+ }
+
+ pub fn remove_download_file(&self) {
+ let file_num = self.file_num as usize;
+ if file_num < self.files.len() {
+ let entry = &self.files[file_num];
+ let path = self.join(&entry.name);
+ let download_path = format!("{}.download", get_string(&path));
+ std::fs::remove_file(&download_path).ok();
+ }
+ }
+
+ pub async fn write(&mut self, block: FileTransferBlock, raw: Option<&[u8]>) -> ResultType<()> {
+ if block.id != self.id {
+ bail!("Wrong id");
+ }
+ let file_num = block.file_num as usize;
+ if file_num >= self.files.len() {
+ bail!("Wrong file number");
+ }
+ if file_num != self.file_num as usize || self.file.is_none() {
+ self.modify_time();
+ if let Some(file) = self.file.as_mut() {
+ file.sync_all().await?;
+ }
+ self.file_num = block.file_num;
+ let entry = &self.files[file_num];
+ let path = self.join(&entry.name);
+ if let Some(p) = path.parent() {
+ std::fs::create_dir_all(p).ok();
+ }
+ let path = format!("{}.download", get_string(&path));
+ self.file = Some(File::create(&path).await?);
+ }
+ let data = if let Some(data) = raw {
+ data
+ } else {
+ &block.data
+ };
+ if block.compressed {
+ let tmp = decompress(data);
+ self.file.as_mut().unwrap().write_all(&tmp).await?;
+ self.finished_size += tmp.len() as u64;
+ } else {
+ self.file.as_mut().unwrap().write_all(data).await?;
+ self.finished_size += data.len() as u64;
+ }
+ self.transferred += data.len() as u64;
+ Ok(())
+ }
+
+ #[inline]
+ fn join(&self, name: &str) -> PathBuf {
+ if name.is_empty() {
+ self.path.clone()
+ } else {
+ self.path.join(name)
+ }
+ }
+
+ pub async fn read(&mut self) -> ResultType<Option<FileTransferBlock>> {
+ let file_num = self.file_num as usize;
+ if file_num >= self.files.len() {
+ self.file.take();
+ return Ok(None);
+ }
+ let name = &self.files[file_num].name;
+ if self.file.is_none() {
+ match File::open(self.join(&name)).await {
+ Ok(file) => {
+ self.file = Some(file);
+ }
+ Err(err) => {
+ self.file_num += 1;
+ return Err(err.into());
+ }
+ }
+ }
+ const BUF_SIZE: usize = 128 * 1024;
+ let mut buf: Vec<u8> = Vec::with_capacity(BUF_SIZE);
+ unsafe {
+ buf.set_len(BUF_SIZE);
+ }
+ let mut compressed = false;
+ let mut offset: usize = 0;
+ loop {
+ match self.file.as_mut().unwrap().read(&mut buf[offset..]).await {
+ Err(err) => {
+ self.file_num += 1;
+ self.file = None;
+ return Err(err.into());
+ }
+ Ok(n) => {
+ offset += n;
+ if n == 0 || offset == BUF_SIZE {
+ break;
+ }
+ }
+ }
+ }
+ unsafe { buf.set_len(offset) };
+ if offset == 0 {
+ self.file_num += 1;
+ self.file = None;
+ } else {
+ self.finished_size += offset as u64;
+ if !is_compressed_file(name) {
+ let tmp = compress(&buf, COMPRESS_LEVEL);
+ if tmp.len() < buf.len() {
+ buf = tmp;
+ compressed = true;
+ }
+ }
+ self.transferred += buf.len() as u64;
+ }
+ Ok(Some(FileTransferBlock {
+ id: self.id,
+ file_num: file_num as _,
+ data: buf.into(),
+ compressed,
+ ..Default::default()
+ }))
+ }
+}
+
+#[inline]
+pub fn new_error<T: std::string::ToString>(id: i32, err: T, file_num: i32) -> Message {
+ let mut resp = FileResponse::new();
+ resp.set_error(FileTransferError {
+ id,
+ error: err.to_string(),
+ file_num,
+ ..Default::default()
+ });
+ let mut msg_out = Message::new();
+ msg_out.set_file_response(resp);
+ msg_out
+}
+
+#[inline]
+pub fn new_dir(id: i32, path: String, files: Vec<FileEntry>) -> Message {
+ let mut resp = FileResponse::new();
+ resp.set_dir(FileDirectory {
+ id,
+ path,
+ entries: files.into(),
+ ..Default::default()
+ });
+ let mut msg_out = Message::new();
+ msg_out.set_file_response(resp);
+ msg_out
+}
+
+#[inline]
+pub fn new_block(block: FileTransferBlock) -> Message {
+ let mut resp = FileResponse::new();
+ resp.set_block(block);
+ let mut msg_out = Message::new();
+ msg_out.set_file_response(resp);
+ msg_out
+}
+
+#[inline]
+pub fn new_receive(id: i32, path: String, files: Vec<FileEntry>) -> Message {
+ let mut action = FileAction::new();
+ action.set_receive(FileTransferReceiveRequest {
+ id,
+ path,
+ files: files.into(),
+ ..Default::default()
+ });
+ let mut msg_out = Message::new();
+ msg_out.set_file_action(action);
+ msg_out
+}
+
+#[inline]
+pub fn new_send(id: i32, path: String, include_hidden: bool) -> Message {
+ let mut action = FileAction::new();
+ action.set_send(FileTransferSendRequest {
+ id,
+ path,
+ include_hidden,
+ ..Default::default()
+ });
+ let mut msg_out = Message::new();
+ msg_out.set_file_action(action);
+ msg_out
+}
+
+#[inline]
+pub fn new_done(id: i32, file_num: i32) -> Message {
+ let mut resp = FileResponse::new();
+ resp.set_done(FileTransferDone {
+ id,
+ file_num,
+ ..Default::default()
+ });
+ let mut msg_out = Message::new();
+ msg_out.set_file_response(resp);
+ msg_out
+}
+
+#[inline]
+pub fn remove_job(id: i32, jobs: &mut Vec<TransferJob>) {
+ *jobs = jobs.drain(0..).filter(|x| x.id() != id).collect();
+}
+
+#[inline]
+pub fn get_job(id: i32, jobs: &mut Vec<TransferJob>) -> Option<&mut TransferJob> {
+ jobs.iter_mut().filter(|x| x.id() == id).next()
+}
+
+pub async fn handle_read_jobs(
+ jobs: &mut Vec<TransferJob>,
+ stream: &mut crate::Stream,
+) -> ResultType<()> {
+ let mut finished = Vec::new();
+ for job in jobs.iter_mut() {
+ match job.read().await {
+ Err(err) => {
+ stream
+ .send(&new_error(job.id(), err, job.file_num()))
+ .await?;
+ }
+ Ok(Some(block)) => {
+ stream.send(&new_block(block)).await?;
+ }
+ Ok(None) => {
+ finished.push(job.id());
+ stream.send(&new_done(job.id(), job.file_num())).await?;
+ }
+ }
+ }
+ for id in finished {
+ remove_job(id, jobs);
+ }
+ Ok(())
+}
+
+pub fn remove_all_empty_dir(path: &PathBuf) -> ResultType<()> {
+ let fd = read_dir(path, true)?;
+ for entry in fd.entries.iter() {
+ match entry.entry_type.enum_value() {
+ Ok(FileType::Dir) => {
+ remove_all_empty_dir(&path.join(&entry.name)).ok();
+ }
+ Ok(FileType::DirLink) | Ok(FileType::FileLink) => {
+ std::fs::remove_file(&path.join(&entry.name)).ok();
+ }
+ _ => {}
+ }
+ }
+ std::fs::remove_dir(path).ok();
+ Ok(())
+}
+
+#[inline]
+pub fn remove_file(file: &str) -> ResultType<()> {
+ std::fs::remove_file(get_path(file))?;
+ Ok(())
+}
+
+#[inline]
+pub fn create_dir(dir: &str) -> ResultType<()> {
+ std::fs::create_dir_all(get_path(dir))?;
+ Ok(())
+}
diff --git a/libs/hbb_common/src/lib.rs b/libs/hbb_common/src/lib.rs
new file mode 100644
index 0000000..0a9dace
--- /dev/null
+++ b/libs/hbb_common/src/lib.rs
@@ -0,0 +1,211 @@
+pub mod compress;
+#[path = "./protos/message.rs"]
+pub mod message_proto;
+#[path = "./protos/rendezvous.rs"]
+pub mod rendezvous_proto;
+pub use bytes;
+pub use futures;
+pub use protobuf;
+use std::{
+ fs::File,
+ io::{self, BufRead},
+ net::{Ipv4Addr, SocketAddr, SocketAddrV4},
+ path::Path,
+ time::{self, SystemTime, UNIX_EPOCH},
+};
+pub use tokio;
+pub use tokio_util;
+pub mod socket_client;
+pub mod tcp;
+pub mod udp;
+pub use env_logger;
+pub use log;
+pub mod bytes_codec;
+#[cfg(feature = "quic")]
+pub mod quic;
+pub use anyhow::{self, bail};
+pub use futures_util;
+pub mod config;
+pub mod fs;
+#[cfg(not(any(target_os = "android", target_os = "ios")))]
+pub use mac_address;
+pub use rand;
+pub use regex;
+pub use sodiumoxide;
+pub use tokio_socks;
+pub use tokio_socks::IntoTargetAddr;
+pub use tokio_socks::TargetAddr;
+pub use lazy_static;
+
+#[cfg(feature = "quic")]
+pub type Stream = quic::Connection;
+#[cfg(not(feature = "quic"))]
+pub type Stream = tcp::FramedStream;
+
+#[inline]
+pub async fn sleep(sec: f32) {
+ tokio::time::sleep(time::Duration::from_secs_f32(sec)).await;
+}
+
+#[macro_export]
+macro_rules! allow_err {
+ ($e:expr) => {
+ if let Err(err) = $e {
+ log::debug!(
+ "{:?}, {}:{}:{}:{}",
+ err,
+ module_path!(),
+ file!(),
+ line!(),
+ column!()
+ );
+ } else {
+ }
+ };
+}
+
+#[inline]
+pub fn timeout<T: std::future::Future>(ms: u64, future: T) -> tokio::time::Timeout<T> {
+ tokio::time::timeout(std::time::Duration::from_millis(ms), future)
+}
+
+pub type ResultType<F, E = anyhow::Error> = anyhow::Result<F, E>;
+
+/// Certain router and firewalls scan the packet and if they
+/// find an IP address belonging to their pool that they use to do the NAT mapping/translation, so here we mangle the ip address
+
+pub struct AddrMangle();
+
+impl AddrMangle {
+ pub fn encode(addr: SocketAddr) -> Vec<u8> {
+ match addr {
+ SocketAddr::V4(addr_v4) => {
+ let tm = (SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_micros() as u32) as u128;
+ let ip = u32::from_le_bytes(addr_v4.ip().octets()) as u128;
+ let port = addr.port() as u128;
+ let v = ((ip + tm) << 49) | (tm << 17) | (port + (tm & 0xFFFF));
+ let bytes = v.to_le_bytes();
+ let mut n_padding = 0;
+ for i in bytes.iter().rev() {
+ if i == &0u8 {
+ n_padding += 1;
+ } else {
+ break;
+ }
+ }
+ bytes[..(16 - n_padding)].to_vec()
+ }
+ _ => {
+ panic!("Only support ipv4");
+ }
+ }
+ }
+
+ pub fn decode(bytes: &[u8]) -> SocketAddr {
+ let mut padded = [0u8; 16];
+ padded[..bytes.len()].copy_from_slice(&bytes);
+ let number = u128::from_le_bytes(padded);
+ let tm = (number >> 17) & (u32::max_value() as u128);
+ let ip = (((number >> 49) - tm) as u32).to_le_bytes();
+ let port = (number & 0xFFFFFF) - (tm & 0xFFFF);
+ SocketAddr::V4(SocketAddrV4::new(
+ Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]),
+ port as u16,
+ ))
+ }
+}
+
+pub fn get_version_from_url(url: &str) -> String {
+ let n = url.chars().count();
+ let a = url
+ .chars()
+ .rev()
+ .enumerate()
+ .filter(|(_, x)| x == &'-')
+ .next()
+ .map(|(i, _)| i);
+ if let Some(a) = a {
+ let b = url
+ .chars()
+ .rev()
+ .enumerate()
+ .filter(|(_, x)| x == &'.')
+ .next()
+ .map(|(i, _)| i);
+ if let Some(b) = b {
+ if a > b {
+ if url
+ .chars()
+ .skip(n - b)
+ .collect::<String>()
+ .parse::<i32>()
+ .is_ok()
+ {
+ return url.chars().skip(n - a).collect();
+ } else {
+ return url.chars().skip(n - a).take(a - b - 1).collect();
+ }
+ } else {
+ return url.chars().skip(n - a).collect();
+ }
+ }
+ }
+ "".to_owned()
+}
+
+pub fn gen_version() {
+ let mut file = File::create("./src/version.rs").unwrap();
+ for line in read_lines("Cargo.toml").unwrap() {
+ if let Ok(line) = line {
+ let ab: Vec<&str> = line.split("=").map(|x| x.trim()).collect();
+ if ab.len() == 2 && ab[0] == "version" {
+ use std::io::prelude::*;
+ file.write_all(format!("pub const VERSION: &str = {};", ab[1]).as_bytes())
+ .ok();
+ file.sync_all().ok();
+ break;
+ }
+ }
+ }
+}
+
+fn read_lines<P>(filename: P) -> io::Result<io::Lines<io::BufReader<File>>>
+where
+ P: AsRef<Path>,
+{
+ let file = File::open(filename)?;
+ Ok(io::BufReader::new(file).lines())
+}
+
+pub fn is_valid_custom_id(id: &str) -> bool {
+ regex::Regex::new(r"^[a-zA-Z]\w{5,15}$")
+ .unwrap()
+ .is_match(id)
+}
+
+pub fn get_version_number(v: &str) -> i64 {
+ let mut n = 0;
+ for x in v.split(".") {
+ n = n * 1000 + x.parse::<i64>().unwrap_or(0);
+ }
+ n
+}
+
+pub fn get_modified_time(path: &std::path::Path) -> SystemTime {
+ std::fs::metadata(&path)
+ .map(|m| m.modified().unwrap_or(UNIX_EPOCH))
+ .unwrap_or(UNIX_EPOCH)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ #[test]
+ fn test_mangle() {
+ let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 16, 32), 21116));
+ assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr)));
+ }
+}
diff --git a/libs/hbb_common/src/quic.rs b/libs/hbb_common/src/quic.rs
new file mode 100644
index 0000000..ada2acd
--- /dev/null
+++ b/libs/hbb_common/src/quic.rs
@@ -0,0 +1,135 @@
+use crate::{allow_err, anyhow::anyhow, ResultType};
+use protobuf::Message;
+use std::{net::SocketAddr, sync::Arc};
+use tokio::{self, stream::StreamExt, sync::mpsc};
+
+const QUIC_HBB: &[&[u8]] = &[b"hbb"];
+const SERVER_NAME: &str = "hbb";
+
+type Sender = mpsc::UnboundedSender<Value>;
+type Receiver = mpsc::UnboundedReceiver<Value>;
+
+pub fn new_server(socket: std::net::UdpSocket) -> ResultType<(Server, SocketAddr)> {
+ let mut transport_config = quinn::TransportConfig::default();
+ transport_config.stream_window_uni(0);
+ let mut server_config = quinn::ServerConfig::default();
+ server_config.transport = Arc::new(transport_config);
+ let mut server_config = quinn::ServerConfigBuilder::new(server_config);
+ server_config.protocols(QUIC_HBB);
+ // server_config.enable_keylog();
+ // server_config.use_stateless_retry(true);
+ let mut endpoint = quinn::Endpoint::builder();
+ endpoint.listen(server_config.build());
+ let (end, incoming) = endpoint.with_socket(socket)?;
+ Ok((Server { incoming }, end.local_addr()?))
+}
+
+pub async fn new_client(local_addr: &SocketAddr, peer: &SocketAddr) -> ResultType<Connection> {
+ let mut endpoint = quinn::Endpoint::builder();
+ let mut client_config = quinn::ClientConfigBuilder::default();
+ client_config.protocols(QUIC_HBB);
+ //client_config.enable_keylog();
+ endpoint.default_client_config(client_config.build());
+ let (endpoint, _) = endpoint.bind(local_addr)?;
+ let new_conn = endpoint.connect(peer, SERVER_NAME)?.await?;
+ Connection::new_for_client(new_conn.connection).await
+}
+
+pub struct Server {
+ incoming: quinn::Incoming,
+}
+
+impl Server {
+ #[inline]
+ pub async fn next(&mut self) -> ResultType<Option<Connection>> {
+ Connection::new_for_server(&mut self.incoming).await
+ }
+}
+
+pub struct Connection {
+ conn: quinn::Connection,
+ tx: quinn::SendStream,
+ rx: Receiver,
+}
+
+type Value = ResultType<Vec<u8>>;
+
+impl Connection {
+ async fn new_for_server(incoming: &mut quinn::Incoming) -> ResultType<Option<Self>> {
+ if let Some(conn) = incoming.next().await {
+ let quinn::NewConnection {
+ connection: conn,
+ // uni_streams,
+ mut bi_streams,
+ ..
+ } = conn.await?;
+ let (tx, rx) = mpsc::unbounded_channel::<Value>();
+ tokio::spawn(async move {
+ loop {
+ let stream = bi_streams.next().await;
+ if let Some(stream) = stream {
+ let stream = match stream {
+ Err(e) => {
+ tx.send(Err(e.into())).ok();
+ break;
+ }
+ Ok(s) => s,
+ };
+ let cloned = tx.clone();
+ tokio::spawn(async move {
+ allow_err!(handle_request(stream.1, cloned).await);
+ });
+ } else {
+ tx.send(Err(anyhow!("Reset by the peer"))).ok();
+ break;
+ }
+ }
+ log::info!("Exit connection outer loop");
+ });
+ let tx = conn.open_uni().await?;
+ Ok(Some(Self { conn, tx, rx }))
+ } else {
+ Ok(None)
+ }
+ }
+
+ async fn new_for_client(conn: quinn::Connection) -> ResultType<Self> {
+ let (tx, rx_quic) = conn.open_bi().await?;
+ let (tx_mpsc, rx) = mpsc::unbounded_channel::<Value>();
+ tokio::spawn(async move {
+ allow_err!(handle_request(rx_quic, tx_mpsc).await);
+ });
+ Ok(Self { conn, tx, rx })
+ }
+
+ #[inline]
+ pub async fn next(&mut self) -> Option<Value> {
+ // None is returned when all Sender halves have dropped,
+ // indicating that no further values can be sent on the channel.
+ self.rx.recv().await
+ }
+
+ #[inline]
+ pub fn remote_address(&self) -> SocketAddr {
+ self.conn.remote_address()
+ }
+
+ #[inline]
+ pub async fn send_raw(&mut self, bytes: &[u8]) -> ResultType<()> {
+ self.tx.write_all(bytes).await?;
+ Ok(())
+ }
+
+ #[inline]
+ pub async fn send(&mut self, msg: &dyn Message) -> ResultType<()> {
+ match msg.write_to_bytes() {
+ Ok(bytes) => self.send_raw(&bytes).await?,
+ err => allow_err!(err),
+ }
+ Ok(())
+ }
+}
+
+async fn handle_request(rx: quinn::RecvStream, tx: Sender) -> ResultType<()> {
+ Ok(())
+}
diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs
new file mode 100644
index 0000000..0375b71
--- /dev/null
+++ b/libs/hbb_common/src/socket_client.rs
@@ -0,0 +1,91 @@
+use crate::{
+ config::{Config, NetworkType},
+ tcp::FramedStream,
+ udp::FramedSocket,
+ ResultType,
+};
+use anyhow::Context;
+use std::net::SocketAddr;
+use tokio::net::ToSocketAddrs;
+use tokio_socks::{IntoTargetAddr, TargetAddr};
+
+fn to_socket_addr(host: &str) -> ResultType<SocketAddr> {
+ use std::net::ToSocketAddrs;
+ host.to_socket_addrs()?.next().context("Failed to solve")
+}
+
+pub fn get_target_addr(host: &str) -> ResultType<TargetAddr<'static>> {
+ let addr = match Config::get_network_type() {
+ NetworkType::Direct => to_socket_addr(&host)?.into_target_addr()?,
+ NetworkType::ProxySocks => host.into_target_addr()?,
+ }
+ .to_owned();
+ Ok(addr)
+}
+
+pub fn test_if_valid_server(host: &str) -> String {
+ let mut host = host.to_owned();
+ if !host.contains(":") {
+ host = format!("{}:{}", host, 0);
+ }
+
+ match Config::get_network_type() {
+ NetworkType::Direct => match to_socket_addr(&host) {
+ Err(err) => err.to_string(),
+ Ok(_) => "".to_owned(),
+ },
+ NetworkType::ProxySocks => match &host.into_target_addr() {
+ Err(err) => err.to_string(),
+ Ok(_) => "".to_owned(),
+ },
+ }
+}
+
+pub async fn connect_tcp<'t, T: IntoTargetAddr<'t>>(
+ target: T,
+ local: SocketAddr,
+ ms_timeout: u64,
+) -> ResultType<FramedStream> {
+ let target_addr = target.into_target_addr()?;
+
+ if let Some(conf) = Config::get_socks() {
+ FramedStream::connect(
+ conf.proxy.as_str(),
+ target_addr,
+ local,
+ conf.username.as_str(),
+ conf.password.as_str(),
+ ms_timeout,
+ )
+ .await
+ } else {
+ let addr = std::net::ToSocketAddrs::to_socket_addrs(&target_addr)?
+ .next()
+ .context("Invalid target addr")?;
+ Ok(FramedStream::new(addr, local, ms_timeout).await?)
+ }
+}
+
+pub async fn new_udp<T: ToSocketAddrs>(local: T, ms_timeout: u64) -> ResultType<FramedSocket> {
+ match Config::get_socks() {
+ None => Ok(FramedSocket::new(local).await?),
+ Some(conf) => {
+ let socket = FramedSocket::new_proxy(
+ conf.proxy.as_str(),
+ local,
+ conf.username.as_str(),
+ conf.password.as_str(),
+ ms_timeout,
+ )
+ .await?;
+ Ok(socket)
+ }
+ }
+}
+
+pub async fn rebind_udp<T: ToSocketAddrs>(local: T) -> ResultType<Option<FramedSocket>> {
+ match Config::get_network_type() {
+ NetworkType::Direct => Ok(Some(FramedSocket::new(local).await?)),
+ _ => Ok(None),
+ }
+}
diff --git a/libs/hbb_common/src/tcp.rs b/libs/hbb_common/src/tcp.rs
new file mode 100644
index 0000000..7966920
--- /dev/null
+++ b/libs/hbb_common/src/tcp.rs
@@ -0,0 +1,285 @@
+use crate::{bail, bytes_codec::BytesCodec, ResultType};
+use bytes::{BufMut, Bytes, BytesMut};
+use futures::{SinkExt, StreamExt};
+use protobuf::Message;
+use sodiumoxide::crypto::secretbox::{self, Key, Nonce};
+use std::{
+ io::{self, Error, ErrorKind},
+ net::SocketAddr,
+ ops::{Deref, DerefMut},
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tokio::{
+ io::{AsyncRead, AsyncWrite, ReadBuf},
+ net::{lookup_host, TcpListener, TcpSocket, ToSocketAddrs},
+};
+use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr, ToProxyAddrs};
+use tokio_util::codec::Framed;
+
+pub trait TcpStreamTrait: AsyncRead + AsyncWrite + Unpin {}
+pub struct DynTcpStream(Box<dyn TcpStreamTrait + Send + Sync>);
+
+pub struct FramedStream(
+ Framed<DynTcpStream, BytesCodec>,
+ SocketAddr,
+ Option<(Key, u64, u64)>,
+ u64,
+);
+
+impl Deref for FramedStream {
+ type Target = Framed<DynTcpStream, BytesCodec>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl DerefMut for FramedStream {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
+impl Deref for DynTcpStream {
+ type Target = Box<dyn TcpStreamTrait + Send + Sync>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl DerefMut for DynTcpStream {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
+fn new_socket(addr: std::net::SocketAddr, reuse: bool) -> Result<TcpSocket, std::io::Error> {
+ let socket = match addr {
+ std::net::SocketAddr::V4(..) => TcpSocket::new_v4()?,
+ std::net::SocketAddr::V6(..) => TcpSocket::new_v6()?,
+ };
+ if reuse {
+ // windows has no reuse_port, but it's reuse_address
+ // almost equals to unix's reuse_port + reuse_address,
+ // though may introduce nondeterministic behavior
+ #[cfg(unix)]
+ socket.set_reuseport(true)?;
+ socket.set_reuseaddr(true)?;
+ }
+ socket.bind(addr)?;
+ Ok(socket)
+}
+
+impl FramedStream {
+ pub async fn new<T1: ToSocketAddrs, T2: ToSocketAddrs>(
+ remote_addr: T1,
+ local_addr: T2,
+ ms_timeout: u64,
+ ) -> ResultType<Self> {
+ for local_addr in lookup_host(&local_addr).await? {
+ for remote_addr in lookup_host(&remote_addr).await? {
+ let stream = super::timeout(
+ ms_timeout,
+ new_socket(local_addr, true)?.connect(remote_addr),
+ )
+ .await??;
+ stream.set_nodelay(true).ok();
+ let addr = stream.local_addr()?;
+ return Ok(Self(
+ Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
+ addr,
+ None,
+ 0,
+ ));
+ }
+ }
+ bail!("could not resolve to any address");
+ }
+
+ pub async fn connect<'a, 't, P, T1, T2>(
+ proxy: P,
+ target: T1,
+ local: T2,
+ username: &'a str,
+ password: &'a str,
+ ms_timeout: u64,
+ ) -> ResultType<Self>
+ where
+ P: ToProxyAddrs,
+ T1: IntoTargetAddr<'t>,
+ T2: ToSocketAddrs,
+ {
+ if let Some(local) = lookup_host(&local).await?.next() {
+ if let Some(proxy) = proxy.to_proxy_addrs().next().await {
+ let stream =
+ super::timeout(ms_timeout, new_socket(local, true)?.connect(proxy?)).await??;
+ stream.set_nodelay(true).ok();
+ let stream = if username.trim().is_empty() {
+ super::timeout(
+ ms_timeout,
+ Socks5Stream::connect_with_socket(stream, target),
+ )
+ .await??
+ } else {
+ super::timeout(
+ ms_timeout,
+ Socks5Stream::connect_with_password_and_socket(
+ stream, target, username, password,
+ ),
+ )
+ .await??
+ };
+ let addr = stream.local_addr()?;
+ return Ok(Self(
+ Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
+ addr,
+ None,
+ 0,
+ ));
+ };
+ };
+ bail!("could not resolve to any address");
+ }
+
+ pub fn local_addr(&self) -> SocketAddr {
+ self.1
+ }
+
+ pub fn set_send_timeout(&mut self, ms: u64) {
+ self.3 = ms;
+ }
+
+ pub fn from(stream: impl TcpStreamTrait + Send + Sync + 'static, addr: SocketAddr) -> Self {
+ Self(
+ Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()),
+ addr,
+ None,
+ 0,
+ )
+ }
+
+ pub fn set_raw(&mut self) {
+ self.0.codec_mut().set_raw();
+ self.2 = None;
+ }
+
+ pub fn is_secured(&self) -> bool {
+ self.2.is_some()
+ }
+
+ #[inline]
+ pub async fn send(&mut self, msg: &impl Message) -> ResultType<()> {
+ self.send_raw(msg.write_to_bytes()?).await
+ }
+
+ #[inline]
+ pub async fn send_raw(&mut self, msg: Vec<u8>) -> ResultType<()> {
+ let mut msg = msg;
+ if let Some(key) = self.2.as_mut() {
+ key.1 += 1;
+ let nonce = Self::get_nonce(key.1);
+ msg = secretbox::seal(&msg, &nonce, &key.0);
+ }
+ self.send_bytes(bytes::Bytes::from(msg)).await?;
+ Ok(())
+ }
+
+ #[inline]
+ pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> {
+ if self.3 > 0 {
+ super::timeout(self.3, self.0.send(bytes)).await??;
+ } else {
+ self.0.send(bytes).await?;
+ }
+ Ok(())
+ }
+
+ #[inline]
+ pub async fn next(&mut self) -> Option<Result<BytesMut, Error>> {
+ let mut res = self.0.next().await;
+ if let Some(key) = self.2.as_mut() {
+ if let Some(Ok(bytes)) = res.as_mut() {
+ key.2 += 1;
+ let nonce = Self::get_nonce(key.2);
+ match secretbox::open(&bytes, &nonce, &key.0) {
+ Ok(res) => {
+ bytes.clear();
+ bytes.put_slice(&res);
+ }
+ Err(()) => {
+ return Some(Err(Error::new(ErrorKind::Other, "decryption error")));
+ }
+ }
+ }
+ }
+ res
+ }
+
+ #[inline]
+ pub async fn next_timeout(&mut self, ms: u64) -> Option<Result<BytesMut, Error>> {
+ if let Ok(res) = super::timeout(ms, self.next()).await {
+ res
+ } else {
+ None
+ }
+ }
+
+ pub fn set_key(&mut self, key: Key) {
+ self.2 = Some((key, 0, 0));
+ }
+
+ fn get_nonce(seqnum: u64) -> Nonce {
+ let mut nonce = Nonce([0u8; secretbox::NONCEBYTES]);
+ nonce.0[..std::mem::size_of_val(&seqnum)].copy_from_slice(&seqnum.to_le_bytes());
+ nonce
+ }
+}
+
+const DEFAULT_BACKLOG: u32 = 128;
+
+#[allow(clippy::never_loop)]
+pub async fn new_listener<T: ToSocketAddrs>(addr: T, reuse: bool) -> ResultType<TcpListener> {
+ if !reuse {
+ Ok(TcpListener::bind(addr).await?)
+ } else {
+ for addr in lookup_host(&addr).await? {
+ let socket = new_socket(addr, true)?;
+ return Ok(socket.listen(DEFAULT_BACKLOG)?);
+ }
+ bail!("could not resolve to any address");
+ }
+}
+
+impl Unpin for DynTcpStream {}
+
+impl AsyncRead for DynTcpStream {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
+ }
+}
+
+impl AsyncWrite for DynTcpStream {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
+ }
+
+ fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
+ }
+}
+
+impl<R: AsyncRead + AsyncWrite + Unpin> TcpStreamTrait for R {}
diff --git a/libs/hbb_common/src/udp.rs b/libs/hbb_common/src/udp.rs
new file mode 100644
index 0000000..0338618
--- /dev/null
+++ b/libs/hbb_common/src/udp.rs
@@ -0,0 +1,165 @@
+use crate::{bail, ResultType};
+use anyhow::anyhow;
+use bytes::{Bytes, BytesMut};
+use futures::{SinkExt, StreamExt};
+use protobuf::Message;
+use socket2::{Domain, Socket, Type};
+use std::net::SocketAddr;
+use tokio::net::{ToSocketAddrs, UdpSocket};
+use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr, TargetAddr, ToProxyAddrs};
+use tokio_util::{codec::BytesCodec, udp::UdpFramed};
+
+pub enum FramedSocket {
+ Direct(UdpFramed<BytesCodec>),
+ ProxySocks(Socks5UdpFramed),
+}
+
+fn new_socket(addr: SocketAddr, reuse: bool, buf_size: usize) -> Result<Socket, std::io::Error> {
+ let socket = match addr {
+ SocketAddr::V4(..) => Socket::new(Domain::ipv4(), Type::dgram(), None),
+ SocketAddr::V6(..) => Socket::new(Domain::ipv6(), Type::dgram(), None),
+ }?;
+ if reuse {
+ // windows has no reuse_port, but it's reuse_address
+ // almost equals to unix's reuse_port + reuse_address,
+ // though may introduce nondeterministic behavior
+ #[cfg(unix)]
+ socket.set_reuse_port(true)?;
+ socket.set_reuse_address(true)?;
+ }
+ if buf_size > 0 {
+ socket.set_recv_buffer_size(buf_size).ok();
+ }
+ log::info!(
+ "Receive buf size of udp {}: {:?}",
+ addr,
+ socket.recv_buffer_size()
+ );
+ socket.bind(&addr.into())?;
+ Ok(socket)
+}
+
+impl FramedSocket {
+ pub async fn new<T: ToSocketAddrs>(addr: T) -> ResultType<Self> {
+ let socket = UdpSocket::bind(addr).await?;
+ Ok(Self::Direct(UdpFramed::new(socket, BytesCodec::new())))
+ }
+
+ #[allow(clippy::never_loop)]
+ pub async fn new_reuse<T: std::net::ToSocketAddrs>(addr: T) -> ResultType<Self> {
+ for addr in addr.to_socket_addrs()? {
+ let socket = new_socket(addr, true, 0)?.into_udp_socket();
+ return Ok(Self::Direct(UdpFramed::new(
+ UdpSocket::from_std(socket)?,
+ BytesCodec::new(),
+ )));
+ }
+ bail!("could not resolve to any address");
+ }
+
+ pub async fn new_with_buf_size<T: std::net::ToSocketAddrs>(
+ addr: T,
+ buf_size: usize,
+ ) -> ResultType<Self> {
+ for addr in addr.to_socket_addrs()? {
+ return Ok(Self::Direct(UdpFramed::new(
+ UdpSocket::from_std(new_socket(addr, false, buf_size)?.into_udp_socket())?,
+ BytesCodec::new(),
+ )));
+ }
+ bail!("could not resolve to any address");
+ }
+
+ pub async fn new_proxy<'a, 't, P: ToProxyAddrs, T: ToSocketAddrs>(
+ proxy: P,
+ local: T,
+ username: &'a str,
+ password: &'a str,
+ ms_timeout: u64,
+ ) -> ResultType<Self> {
+ let framed = if username.trim().is_empty() {
+ super::timeout(ms_timeout, Socks5UdpFramed::connect(proxy, Some(local))).await??
+ } else {
+ super::timeout(
+ ms_timeout,
+ Socks5UdpFramed::connect_with_password(proxy, Some(local), username, password),
+ )
+ .await??
+ };
+ log::trace!(
+ "Socks5 udp connected, local addr: {:?}, target addr: {}",
+ framed.local_addr(),
+ framed.socks_addr()
+ );
+ Ok(Self::ProxySocks(framed))
+ }
+
+ #[inline]
+ pub async fn send(
+ &mut self,
+ msg: &impl Message,
+ addr: impl IntoTargetAddr<'_>,
+ ) -> ResultType<()> {
+ let addr = addr.into_target_addr()?.to_owned();
+ let send_data = Bytes::from(msg.write_to_bytes()?);
+ let _ = match self {
+ Self::Direct(f) => match addr {
+ TargetAddr::Ip(addr) => f.send((send_data, addr)).await?,
+ _ => {}
+ },
+ Self::ProxySocks(f) => f.send((send_data, addr)).await?,
+ };
+ Ok(())
+ }
+
+ // https://stackoverflow.com/a/68733302/1926020
+ #[inline]
+ pub async fn send_raw(
+ &mut self,
+ msg: &'static [u8],
+ addr: impl IntoTargetAddr<'static>,
+ ) -> ResultType<()> {
+ let addr = addr.into_target_addr()?.to_owned();
+
+ let _ = match self {
+ Self::Direct(f) => match addr {
+ TargetAddr::Ip(addr) => f.send((Bytes::from(msg), addr)).await?,
+ _ => {}
+ },
+ Self::ProxySocks(f) => f.send((Bytes::from(msg), addr)).await?,
+ };
+ Ok(())
+ }
+
+ #[inline]
+ pub async fn next(&mut self) -> Option<ResultType<(BytesMut, TargetAddr<'static>)>> {
+ match self {
+ Self::Direct(f) => match f.next().await {
+ Some(Ok((data, addr))) => {
+ Some(Ok((data, addr.into_target_addr().ok()?.to_owned())))
+ }
+ Some(Err(e)) => Some(Err(anyhow!(e))),
+ None => None,
+ },
+ Self::ProxySocks(f) => match f.next().await {
+ Some(Ok((data, _))) => Some(Ok((data.data, data.dst_addr))),
+ Some(Err(e)) => Some(Err(anyhow!(e))),
+ None => None,
+ },
+ }
+ }
+
+ #[inline]
+ pub async fn next_timeout(
+ &mut self,
+ ms: u64,
+ ) -> Option<ResultType<(BytesMut, TargetAddr<'static>)>> {
+ if let Ok(res) =
+ tokio::time::timeout(std::time::Duration::from_millis(ms), self.next()).await
+ {
+ res
+ } else {
+ None
+ }
+ }
+}