diff options
author | Huabing Zhou <[email protected]> | 2023-01-06 10:40:26 +0800 |
---|---|---|
committer | Huabing Zhou <[email protected]> | 2023-01-06 10:40:26 +0800 |
commit | 2314783d4284a94711e48620e8fd9f315d1154dc (patch) | |
tree | bf2db1a45da2d4d231d4fef1c361862896db31a4 /libs | |
parent | 753c774380edb4ae641f9fbaa343e44aea844d7c (diff) | |
download | rustdesk-server-2314783d4284a94711e48620e8fd9f315d1154dc.tar.gz rustdesk-server-2314783d4284a94711e48620e8fd9f315d1154dc.zip |
sync rustdesk's hbb_common here
Diffstat (limited to 'libs')
-rw-r--r-- | libs/hbb_common/Cargo.toml | 7 | ||||
-rw-r--r-- | libs/hbb_common/protos/message.proto | 144 | ||||
-rw-r--r-- | libs/hbb_common/src/config.rs | 498 | ||||
-rw-r--r-- | libs/hbb_common/src/fs.rs | 338 | ||||
-rw-r--r-- | libs/hbb_common/src/lib.rs | 152 | ||||
-rw-r--r-- | libs/hbb_common/src/password_security.rs | 242 | ||||
-rw-r--r-- | libs/hbb_common/src/platform/linux.rs | 157 | ||||
-rw-r--r-- | libs/hbb_common/src/platform/mod.rs | 2 | ||||
-rw-r--r-- | libs/hbb_common/src/socket_client.rs | 194 | ||||
-rw-r--r-- | libs/hbb_common/src/tcp.rs | 144 | ||||
-rw-r--r-- | libs/hbb_common/src/udp.rs | 13 |
11 files changed, 1659 insertions, 232 deletions
diff --git a/libs/hbb_common/Cargo.toml b/libs/hbb_common/Cargo.toml index 4b28fc1..59f0896 100644 --- a/libs/hbb_common/Cargo.toml +++ b/libs/hbb_common/Cargo.toml @@ -11,7 +11,7 @@ protobuf = { version = "3.1", features = ["with-bytes"] } tokio = { version = "1.20", features = ["full"] } tokio-util = { version = "0.7", features = ["full"] } futures = "0.3" -bytes = "1.2" +bytes = { version = "1.2", features = ["serde"] } log = "0.4" env_logger = "0.9" socket2 = { version = "0.3", features = ["reuseport"] } @@ -30,15 +30,18 @@ filetime = "0.2" sodiumoxide = "0.2" regex = "1.4" tokio-socks = { git = "https://github.com/open-trade/tokio-socks" } +chrono = "0.4" [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] mac_address = "1.1" +machine-uid = "0.2" [features] quic = [] +flatpak = [] [build-dependencies] -protobuf-codegen = "3.1" +protobuf-codegen = { version = "3.1" } [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["winuser"] } diff --git a/libs/hbb_common/protos/message.proto b/libs/hbb_common/protos/message.proto index 15ee971..b127ac3 100644 --- a/libs/hbb_common/protos/message.proto +++ b/libs/hbb_common/protos/message.proto @@ -1,13 +1,13 @@ syntax = "proto3"; package hbb; -message VP9 { +message EncodedVideoFrame { bytes data = 1; bool key = 2; int64 pts = 3; } -message VP9s { repeated VP9 frames = 1; } +message EncodedVideoFrames { repeated EncodedVideoFrame frames = 1; } message RGB { bool compress = 1; } @@ -19,9 +19,11 @@ message YUV { message VideoFrame { oneof union { - VP9s vp9s = 6; + EncodedVideoFrames vp9s = 6; RGB rgb = 7; YUV yuv = 8; + EncodedVideoFrames h264s = 10; + EncodedVideoFrames h265s = 11; } int64 timestamp = 9; } @@ -38,6 +40,7 @@ message DisplayInfo { int32 height = 4; string name = 5; bool online = 6; + bool cursor_embedded = 7; } message PortForward { @@ -61,10 +64,21 @@ message LoginRequest { PortForward port_forward = 8; } bool video_ack_required = 9; + uint64 session_id = 10; + string version = 11; } message ChatMessage { string text = 1; } +message Features { + bool privacy_mode = 1; +} + +message SupportedEncoding { + bool h264 = 1; + bool h265 = 2; +} + message PeerInfo { string username = 1; string hostname = 2; @@ -74,6 +88,8 @@ message PeerInfo { bool sas_enabled = 6; string version = 7; int32 conn_id = 8; + Features features = 9; + SupportedEncoding encoding = 10; } message LoginResponse { @@ -90,6 +106,13 @@ message MouseEvent { repeated ControlKey modifiers = 4; } +enum KeyboardMode{ + Legacy = 0; + Map = 1; + Translate = 2; + Auto = 3; +} + enum ControlKey { Unknown = 0; Alt = 1; @@ -183,6 +206,7 @@ message KeyEvent { string seq = 6; } repeated ControlKey modifiers = 8; + KeyboardMode mode = 9; } message CursorData { @@ -252,6 +276,7 @@ message FileAction { FileRemoveFile remove_file = 6; ReadAllFiles all_files = 7; FileTransferCancel cancel = 8; + FileTransferSendConfirmRequest send_confirm = 9; } } @@ -263,14 +288,24 @@ message FileResponse { FileTransferBlock block = 2; FileTransferError error = 3; FileTransferDone done = 4; + FileTransferDigest digest = 5; } } +message FileTransferDigest { + int32 id = 1; + sint32 file_num = 2; + uint64 last_modified = 3; + uint64 file_size = 4; + bool is_upload = 5; +} + message FileTransferBlock { int32 id = 1; sint32 file_num = 2; bytes data = 3; bool compressed = 4; + uint32 blk_id = 5; } message FileTransferError { @@ -283,6 +318,16 @@ message FileTransferSendRequest { int32 id = 1; string path = 2; bool include_hidden = 3; + int32 file_num = 4; +} + +message FileTransferSendConfirmRequest { + int32 id = 1; + sint32 file_num = 2; + oneof union { + bool skip = 3; + uint32 offset_blk = 4; + } } message FileTransferDone { @@ -294,6 +339,7 @@ message FileTransferReceiveRequest { int32 id = 1; string path = 2; // path written to repeated FileEntry files = 3; + int32 file_num = 4; } message FileRemoveDir { @@ -315,38 +361,31 @@ message FileDirCreate { // main logic from freeRDP message CliprdrMonitorReady { - int32 conn_id = 1; } message CliprdrFormat { - int32 conn_id = 1; int32 id = 2; string format = 3; } message CliprdrServerFormatList { - int32 conn_id = 1; repeated CliprdrFormat formats = 2; } message CliprdrServerFormatListResponse { - int32 conn_id = 1; int32 msg_flags = 2; } message CliprdrServerFormatDataRequest { - int32 conn_id = 1; int32 requested_format_id = 2; } message CliprdrServerFormatDataResponse { - int32 conn_id = 1; int32 msg_flags = 2; bytes format_data = 3; } message CliprdrFileContentsRequest { - int32 conn_id = 1; int32 stream_id = 2; int32 list_index = 3; int32 dw_flags = 4; @@ -358,7 +397,6 @@ message CliprdrFileContentsRequest { } message CliprdrFileContentsResponse { - int32 conn_id = 1; int32 msg_flags = 3; int32 stream_id = 4; bytes requested_data = 5; @@ -382,6 +420,7 @@ message SwitchDisplay { sint32 y = 3; int32 width = 4; int32 height = 5; + bool cursor_embedded = 6; } message PermissionInfo { @@ -390,6 +429,8 @@ message PermissionInfo { Clipboard = 2; Audio = 3; File = 4; + Restart = 5; + Recording = 6; } Permission permission = 1; @@ -403,6 +444,20 @@ enum ImageQuality { Best = 4; } +message VideoCodecState { + enum PerferCodec { + Auto = 0; + VPX = 1; + H264 = 2; + H265 = 3; + } + + int32 score_vpx = 1; + int32 score_h264 = 2; + int32 score_h265 = 3; + PerferCodec perfer = 4; +} + message OptionMessage { enum BoolOption { NotSet = 0; @@ -418,16 +473,15 @@ message OptionMessage { BoolOption disable_audio = 7; BoolOption disable_clipboard = 8; BoolOption enable_file_transfer = 9; -} - -message OptionResponse { - OptionMessage opt = 1; - string error = 2; + VideoCodecState video_codec_state = 10; + int32 custom_fps = 11; } message TestDelay { int64 time = 1; bool from_client = 2; + uint32 last_delay = 3; + uint32 target_bitrate = 4; } message PublicKey { @@ -447,6 +501,57 @@ message AudioFrame { int64 timestamp = 2; } +// Notify peer to show message box. +message MessageBox { + // Message type. Refer to flutter/lib/commom.dart/msgBox(). + string msgtype = 1; + string title = 2; + // English + string text = 3; + // If not empty, msgbox provides a button to following the link. + // The link here can't be directly http url. + // It must be the key of http url configed in peer side or "rustdesk://*" (jump in app). + string link = 4; +} + +message BackNotification { + // no need to consider block input by someone else + enum BlockInputState { + BlkStateUnknown = 0; + BlkOnSucceeded = 2; + BlkOnFailed = 3; + BlkOffSucceeded = 4; + BlkOffFailed = 5; + } + enum PrivacyModeState { + PrvStateUnknown = 0; + // Privacy mode on by someone else + PrvOnByOther = 2; + // Privacy mode is not supported on the remote side + PrvNotSupported = 3; + // Privacy mode on by self + PrvOnSucceeded = 4; + // Privacy mode on by self, but denied + PrvOnFailedDenied = 5; + // Some plugins are not found + PrvOnFailedPlugin = 6; + // Privacy mode on by self, but failed + PrvOnFailed = 7; + // Privacy mode off by self + PrvOffSucceeded = 8; + // Ctrl + P + PrvOffByPeer = 9; + // Privacy mode off by self, but failed + PrvOffFailed = 10; + PrvOffUnknown = 11; + } + + oneof union { + PrivacyModeState privacy_mode_state = 1; + BlockInputState block_input_state = 2; + } +} + message Misc { oneof union { ChatMessage chat_message = 4; @@ -456,8 +561,12 @@ message Misc { AudioFormat audio_format = 8; string close_reason = 9; bool refresh_video = 10; - OptionResponse option_response = 11; bool video_received = 12; + BackNotification back_notification = 13; + bool restart_remote_device = 14; + bool uac = 15; + bool foreground_window_elevated = 16; + bool stop_service = 17; } } @@ -481,5 +590,6 @@ message Message { FileResponse file_response = 18; Misc misc = 19; Cliprdr cliprdr = 20; + MessageBox message_box = 21; } } diff --git a/libs/hbb_common/src/config.rs b/libs/hbb_common/src/config.rs index c3a502b..1d427a2 100644 --- a/libs/hbb_common/src/config.rs +++ b/libs/hbb_common/src/config.rs @@ -1,22 +1,35 @@ -use crate::log; -use directories_next::ProjectDirs; -use rand::Rng; -use serde_derive::{Deserialize, Serialize}; -use sodiumoxide::crypto::sign; use std::{ collections::HashMap, fs, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::{Path, PathBuf}, sync::{Arc, Mutex, RwLock}, time::SystemTime, }; +use anyhow::Result; +use rand::Rng; +use regex::Regex; +use serde as de; +use serde_derive::{Deserialize, Serialize}; +use sodiumoxide::base64; +use sodiumoxide::crypto::sign; + +use crate::{ + log, + password_security::{ + decrypt_str_or_original, decrypt_vec_or_original, encrypt_str_or_original, + encrypt_vec_or_original, + }, +}; + pub const RENDEZVOUS_TIMEOUT: u64 = 12_000; pub const CONNECT_TIMEOUT: u64 = 18_000; +pub const READ_TIMEOUT: u64 = 30_000; pub const REG_INTERVAL: i64 = 12_000; pub const COMPRESS_LEVEL: i32 = 3; -const SERIAL: i32 = 1; +const SERIAL: i32 = 3; +const PASSWORD_ENC_VERSION: &'static str = "00"; // 128x128 #[cfg(target_os = "macos")] // 128x128 on 160x160 canvas, then shrink to 128, mac looks better with padding pub const ICON: &str = " @@ -38,12 +51,27 @@ lazy_static::lazy_static! { pub static ref ONLINE: Arc<Mutex<HashMap<String, i64>>> = Default::default(); pub static ref PROD_RENDEZVOUS_SERVER: Arc<RwLock<String>> = Default::default(); pub static ref APP_NAME: Arc<RwLock<String>> = Arc::new(RwLock::new("RustDesk".to_owned())); + static ref KEY_PAIR: Arc<Mutex<Option<(Vec<u8>, Vec<u8>)>>> = Default::default(); + static ref HW_CODEC_CONFIG: Arc<RwLock<HwCodecConfig>> = Arc::new(RwLock::new(HwCodecConfig::load())); } -#[cfg(any(target_os = "android", target_os = "ios"))] + lazy_static::lazy_static! { pub static ref APP_DIR: Arc<RwLock<String>> = Default::default(); +} + +#[cfg(any(target_os = "android", target_os = "ios"))] +lazy_static::lazy_static! { pub static ref APP_HOME_DIR: Arc<RwLock<String>> = Default::default(); } + +// #[cfg(any(target_os = "android", target_os = "ios"))] +lazy_static::lazy_static! { + pub static ref HELPER_URL: HashMap<&'static str, &'static str> = HashMap::from([ + ("rustdesk docs home", "https://rustdesk.com/docs/en/"), + ("rustdesk docs x11-required", "https://rustdesk.com/docs/en/manual/linux/#x11-required"), + ]); +} + const CHARS: &'static [char] = &[ '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', @@ -54,9 +82,30 @@ pub const RENDEZVOUS_SERVERS: &'static [&'static str] = &[ "rs-sg.rustdesk.com", "rs-cn.rustdesk.com", ]; +pub const RS_PUB_KEY: &'static str = "OeVuKk5nlHiXp+APNn0Y3pC1Iwpwn44JGqrQCsWqmBw="; pub const RENDEZVOUS_PORT: i32 = 21116; pub const RELAY_PORT: i32 = 21117; +macro_rules! serde_field_string { + ($default_func:ident, $de_func:ident, $default_expr:expr) => { + fn $default_func() -> String { + $default_expr + } + + fn $de_func<'de, D>(deserializer: D) -> Result<String, D::Error> + where + D: de::Deserializer<'de>, + { + let s: &str = de::Deserialize::deserialize(deserializer)?; + Ok(if s.is_empty() { + Self::$default_func() + } else { + s.to_owned() + }) + } + }; +} + #[derive(Clone, Copy, PartialEq, Eq, Debug)] pub enum NetworkType { Direct, @@ -66,13 +115,15 @@ pub enum NetworkType { #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] pub struct Config { #[serde(default)] - pub id: String, + pub id: String, // use + #[serde(default)] + enc_id: String, // store #[serde(default)] password: String, #[serde(default)] salt: String, #[serde(default)] - pub key_pair: (Vec<u8>, Vec<u8>), // sk, pk + key_pair: (Vec<u8>, Vec<u8>), // sk, pk #[serde(default)] key_confirmed: bool, #[serde(default)] @@ -107,7 +158,7 @@ pub struct Config2 { pub options: HashMap<String, String>, } -#[derive(Debug, Default, Serialize, Deserialize, Clone)] +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] pub struct PeerConfig { #[serde(default)] pub password: Vec<u8>, @@ -117,9 +168,20 @@ pub struct PeerConfig { pub size_ft: Size, #[serde(default)] pub size_pf: Size, - #[serde(default)] - pub view_style: String, // original (default), scale - #[serde(default)] + #[serde( + default = "PeerConfig::default_view_style", + deserialize_with = "PeerConfig::deserialize_view_style" + )] + pub view_style: String, + #[serde( + default = "PeerConfig::default_scroll_style", + deserialize_with = "PeerConfig::deserialize_scroll_style" + )] + pub scroll_style: String, + #[serde( + default = "PeerConfig::default_image_quality", + deserialize_with = "PeerConfig::deserialize_image_quality" + )] pub image_quality: String, #[serde(default)] pub custom_image_quality: Vec<i32>, @@ -139,12 +201,21 @@ pub struct PeerConfig { pub disable_clipboard: bool, #[serde(default)] pub enable_file_transfer: bool, - - // the other scalar value must before this #[serde(default)] + pub show_quality_monitor: bool, + #[serde(default)] + pub keyboard_mode: String, + + // The other scalar value must before this + #[serde(default, deserialize_with = "PeerConfig::deserialize_options")] pub options: HashMap<String, String>, + // Various data for flutter ui + #[serde(default)] + pub ui_flutter: HashMap<String, String>, #[serde(default)] pub info: PeerInfoSerde, + #[serde(default)] + pub transfer: TransferSerde, } #[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)] @@ -157,6 +228,14 @@ pub struct PeerInfoSerde { pub platform: String, } +#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] +pub struct TransferSerde { + #[serde(default)] + pub write_jobs: Vec<String>, + #[serde(default)] + pub read_jobs: Vec<String>, +} + fn patch(path: PathBuf) -> PathBuf { if let Some(_tmp) = path.to_str() { #[cfg(windows)] @@ -188,7 +267,17 @@ fn patch(path: PathBuf) -> PathBuf { impl Config2 { fn load() -> Config2 { - Config::load_::<Config2>("2") + let mut config = Config::load_::<Config2>("2"); + if let Some(mut socks) = config.socks { + let (password, _, store) = + decrypt_str_or_original(&socks.password, PASSWORD_ENC_VERSION); + socks.password = password; + config.socks = Some(socks); + if store { + config.store(); + } + } + config } pub fn file() -> PathBuf { @@ -196,7 +285,12 @@ impl Config2 { } fn store(&self) { - Config::store_(self, "2"); + let mut config = self.clone(); + if let Some(mut socks) = config.socks { + socks.password = encrypt_str_or_original(&socks.password, PASSWORD_ENC_VERSION); + config.socks = Some(socks); + } + Config::store_(&config, "2"); } pub fn get() -> Config2 { @@ -227,6 +321,11 @@ pub fn load_path<T: serde::Serialize + serde::de::DeserializeOwned + Default + s cfg } +#[inline] +pub fn store_path<T: serde::Serialize>(path: PathBuf, cfg: T) -> crate::ResultType<()> { + Ok(confy::store_path(path, cfg)?) +} + impl Config { fn load_<T: serde::Serialize + serde::de::DeserializeOwned + Default + std::fmt::Debug>( suffix: &str, @@ -235,24 +334,68 @@ impl Config { log::debug!("Configuration path: {}", file.display()); let cfg = load_path(file); if suffix.is_empty() { - log::debug!("{:?}", cfg); + log::trace!("{:?}", cfg); } cfg } fn store_<T: serde::Serialize>(config: &T, suffix: &str) { let file = Self::file_(suffix); - if let Err(err) = confy::store_path(file, config) { + if let Err(err) = store_path(file, config) { log::error!("Failed to store config: {}", err); } } fn load() -> Config { - Config::load_::<Config>("") + let mut config = Config::load_::<Config>(""); + let mut store = false; + let (password, _, store1) = decrypt_str_or_original(&config.password, PASSWORD_ENC_VERSION); + config.password = password; + store |= store1; + let mut id_valid = false; + let (id, encrypted, store2) = decrypt_str_or_original(&config.enc_id, PASSWORD_ENC_VERSION); + if encrypted { + config.id = id; + id_valid = true; + store |= store2; + } else { + if crate::get_modified_time(&Self::file_("")) + .checked_sub(std::time::Duration::from_secs(30)) // allow modification during installation + .unwrap_or(crate::get_exe_time()) + < crate::get_exe_time() + { + if !config.id.is_empty() + && config.enc_id.is_empty() + && !decrypt_str_or_original(&config.id, PASSWORD_ENC_VERSION).1 + { + id_valid = true; + store = true; + } + } + } + if !id_valid { + for _ in 0..3 { + if let Some(id) = Config::get_auto_id() { + config.id = id; + store = true; + break; + } else { + log::error!("Failed to generate new id"); + } + } + } + if store { + config.store(); + } + config } fn store(&self) { - Config::store_(self, ""); + let mut config = self.clone(); + config.password = encrypt_str_or_original(&config.password, PASSWORD_ENC_VERSION); + config.enc_id = encrypt_str_or_original(&config.id, PASSWORD_ENC_VERSION); + config.id = "".to_owned(); + Config::store_(&config, ""); } pub fn file() -> PathBuf { @@ -264,15 +407,22 @@ impl Config { Config::with_extension(Self::path(name)) } + pub fn is_empty(&self) -> bool { + (self.id.is_empty() && self.enc_id.is_empty()) || self.key_pair.0.is_empty() + } + pub fn get_home() -> PathBuf { #[cfg(any(target_os = "android", target_os = "ios"))] return Self::path(APP_HOME_DIR.read().unwrap().as_str()); - if let Some(path) = dirs_next::home_dir() { - patch(path) - } else if let Ok(path) = std::env::current_dir() { - path - } else { - std::env::temp_dir() + #[cfg(not(any(target_os = "android", target_os = "ios")))] + { + if let Some(path) = dirs_next::home_dir() { + patch(path) + } else if let Ok(path) = std::env::current_dir() { + path + } else { + std::env::temp_dir() + } } } @@ -283,17 +433,22 @@ impl Config { path.push(p); return path; } - #[cfg(not(target_os = "macos"))] - let org = ""; - #[cfg(target_os = "macos")] - let org = ORG.read().unwrap().clone(); - // /var/root for root - if let Some(project) = ProjectDirs::from("", &org, &*APP_NAME.read().unwrap()) { - let mut path = patch(project.config_dir().to_path_buf()); - path.push(p); - return path; + #[cfg(not(any(target_os = "android", target_os = "ios")))] + { + #[cfg(not(target_os = "macos"))] + let org = ""; + #[cfg(target_os = "macos")] + let org = ORG.read().unwrap().clone(); + // /var/root for root + if let Some(project) = + directories_next::ProjectDirs::from("", &org, &*APP_NAME.read().unwrap()) + { + let mut path = patch(project.config_dir().to_path_buf()); + path.push(p); + return path; + } + return "".into(); } - return "".into(); } #[allow(unreachable_code)] @@ -356,8 +511,12 @@ impl Config { } #[inline] - pub fn get_any_listen_addr() -> SocketAddr { - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) + pub fn get_any_listen_addr(is_ipv4: bool) -> SocketAddr { + if is_ipv4 { + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) + } else { + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0) + } } pub fn get_rendezvous_server() -> String { @@ -472,22 +631,25 @@ impl Config { .to_string(), ); } - let mut id = 0u32; + #[cfg(not(any(target_os = "android", target_os = "ios")))] - if let Ok(Some(ma)) = mac_address::get_mac_address() { - for x in &ma.bytes()[2..] { - id = (id << 8) | (*x as u32); + { + let mut id = 0u32; + if let Ok(Some(ma)) = mac_address::get_mac_address() { + for x in &ma.bytes()[2..] { + id = (id << 8) | (*x as u32); + } + id = id & 0x1FFFFFFF; + Some(id.to_string()) + } else { + None } - id = id & 0x1FFFFFFF; - Some(id.to_string()) - } else { - None } } - pub fn get_auto_password() -> String { + pub fn get_auto_password(length: usize) -> String { let mut rng = rand::thread_rng(); - (0..6) + (0..length) .map(|_| CHARS[rng.gen::<usize>() % CHARS.len()]) .collect() } @@ -525,24 +687,26 @@ impl Config { config.store(); } - pub fn set_key_pair(pair: (Vec<u8>, Vec<u8>)) { - let mut config = CONFIG.write().unwrap(); - if config.key_pair == pair { - return; - } - config.key_pair = pair; - config.store(); - } - pub fn get_key_pair() -> (Vec<u8>, Vec<u8>) { // lock here to make sure no gen_keypair more than once - let mut config = CONFIG.write().unwrap(); + // no use of CONFIG directly here to ensure no recursive calling in Config::load because of password dec which calling this function + let mut lock = KEY_PAIR.lock().unwrap(); + if let Some(p) = lock.as_ref() { + return p.clone(); + } + let mut config = Config::load_::<Config>(""); if config.key_pair.0.is_empty() { let (pk, sk) = sign::gen_keypair(); - config.key_pair = (sk.0.to_vec(), pk.0.into()); - config.store(); + let key_pair = (sk.0.to_vec(), pk.0.into()); + config.key_pair = key_pair.clone(); + std::thread::spawn(|| { + let mut config = CONFIG.write().unwrap(); + config.key_pair = key_pair; + config.store(); + }); } - config.key_pair.clone() + *lock = Some(config.key_pair.clone()); + return config.key_pair; } pub fn get_id() -> String { @@ -608,7 +772,7 @@ impl Config { log::info!("id updated from {} to {}", id, new_id); } - pub fn set_password(password: &str) { + pub fn set_permanent_password(password: &str) { let mut config = CONFIG.write().unwrap(); if password == config.password { return; @@ -617,13 +781,8 @@ impl Config { config.store(); } - pub fn get_password() -> String { - let mut password = CONFIG.read().unwrap().password.clone(); - if password.is_empty() { - password = Config::get_auto_password(); - Config::set_password(&password); - } - password + pub fn get_permanent_password() -> String { + CONFIG.read().unwrap().password.clone() } pub fn set_salt(salt: &str) { @@ -638,7 +797,7 @@ impl Config { pub fn get_salt() -> String { let mut salt = CONFIG.read().unwrap().salt.clone(); if salt.is_empty() { - salt = Config::get_auto_password(); + salt = Config::get_auto_password(6); Config::set_salt(&salt); } salt @@ -693,9 +852,30 @@ const PEERS: &str = "peers"; impl PeerConfig { pub fn load(id: &str) -> PeerConfig { - let _unused = CONFIG.read().unwrap(); // for lock + let _lock = CONFIG.read().unwrap(); match confy::load_path(&Self::path(id)) { - Ok(config) => config, + Ok(config) => { + let mut config: PeerConfig = config; + let mut store = false; + let (password, _, store2) = + decrypt_vec_or_original(&config.password, PASSWORD_ENC_VERSION); + config.password = password; + store = store || store2; + config.options.get_mut("rdp_password").map(|v| { + let (password, _, store2) = decrypt_str_or_original(v, PASSWORD_ENC_VERSION); + *v = password; + store = store || store2; + }); + config.options.get_mut("os-password").map(|v| { + let (password, _, store2) = decrypt_str_or_original(v, PASSWORD_ENC_VERSION); + *v = password; + store = store || store2; + }); + if store { + config.store(id); + } + config + } Err(err) => { log::error!("Failed to load config: {}", err); Default::default() @@ -704,8 +884,18 @@ impl PeerConfig { } pub fn store(&self, id: &str) { - let _unused = CONFIG.read().unwrap(); // for lock - if let Err(err) = confy::store_path(Self::path(id), self) { + let _lock = CONFIG.read().unwrap(); + let mut config = self.clone(); + config.password = encrypt_vec_or_original(&config.password, PASSWORD_ENC_VERSION); + config + .options + .get_mut("rdp_password") + .map(|v| *v = encrypt_str_or_original(v, PASSWORD_ENC_VERSION)); + config + .options + .get_mut("os-password") + .map(|v| *v = encrypt_str_or_original(v, PASSWORD_ENC_VERSION)); + if let Err(err) = store_path(Self::path(id), config) { log::error!("Failed to store config: {}", err); } } @@ -715,7 +905,17 @@ impl PeerConfig { } fn path(id: &str) -> PathBuf { - let path: PathBuf = [PEERS, id].iter().collect(); + let id_encoded: String; + + //If the id contains invalid chars, encode it + let forbidden_paths = Regex::new(r".*[<>:/\\|\?\*].*").unwrap(); + if forbidden_paths.is_match(id) { + id_encoded = + "base64_".to_string() + base64::encode(id, base64::Variant::Original).as_str(); + } else { + id_encoded = id.to_string(); + } + let path: PathBuf = [PEERS, id_encoded.as_str()].iter().collect(); Config::with_extension(Config::path(path)) } @@ -738,11 +938,22 @@ impl PeerConfig { .map(|p| p.to_str().unwrap_or("")) .unwrap_or("") .to_owned(); - let c = PeerConfig::load(&id); + + let id_decoded_string: String; + if id.starts_with("base64_") && id.len() != 7 { + let id_decoded = base64::decode(&id[7..], base64::Variant::Original) + .unwrap_or(Vec::new()); + id_decoded_string = + String::from_utf8_lossy(&id_decoded).as_ref().to_owned(); + } else { + id_decoded_string = id; + } + + let c = PeerConfig::load(&id_decoded_string); if c.info.platform.is_empty() { fs::remove_file(&p).ok(); } - (id, t, c) + (id_decoded_string, t, c) }) .filter(|p| !p.2.info.platform.is_empty()) .collect(); @@ -752,6 +963,33 @@ impl PeerConfig { } Default::default() } + + serde_field_string!( + default_view_style, + deserialize_view_style, + "original".to_owned() + ); + serde_field_string!( + default_scroll_style, + deserialize_scroll_style, + "scrollauto".to_owned() + ); + serde_field_string!( + default_image_quality, + deserialize_image_quality, + "balanced".to_owned() + ); + + fn deserialize_options<'de, D>(deserializer: D) -> Result<HashMap<String, String>, D::Error> + where + D: de::Deserializer<'de>, + { + let mut mp: HashMap<String, String> = de::Deserialize::deserialize(deserializer)?; + if !mp.contains_key("codec-preference") { + mp.insert("codec-preference".to_owned(), "auto".to_owned()); + } + Ok(mp) + } } #[derive(Debug, Default, Serialize, Deserialize, Clone)] @@ -759,11 +997,16 @@ pub struct LocalConfig { #[serde(default)] remote_id: String, // latest used one #[serde(default)] + kb_layout_type: String, + #[serde(default)] size: Size, #[serde(default)] pub fav: Vec<String>, #[serde(default)] options: HashMap<String, String>, + // Various data for flutter ui + #[serde(default)] + ui_flutter: HashMap<String, String>, } impl LocalConfig { @@ -775,6 +1018,16 @@ impl LocalConfig { Config::store_(self, "_local"); } + pub fn get_kb_layout_type() -> String { + LOCAL_CONFIG.read().unwrap().kb_layout_type.clone() + } + + pub fn set_kb_layout_type(kb_layout_type: String) { + let mut config = LOCAL_CONFIG.write().unwrap(); + config.kb_layout_type = kb_layout_type; + config.store(); + } + pub fn get_size() -> Size { LOCAL_CONFIG.read().unwrap().size } @@ -835,17 +1088,59 @@ impl LocalConfig { config.store(); } } + + pub fn get_flutter_config(k: &str) -> String { + if let Some(v) = LOCAL_CONFIG.read().unwrap().ui_flutter.get(k) { + v.clone() + } else { + "".to_owned() + } + } + + pub fn set_flutter_config(k: String, v: String) { + let mut config = LOCAL_CONFIG.write().unwrap(); + let v2 = if v.is_empty() { None } else { Some(&v) }; + if v2 != config.ui_flutter.get(&k) { + if v2.is_none() { + config.ui_flutter.remove(&k); + } else { + config.ui_flutter.insert(k, v); + } + config.store(); + } + } } #[derive(Debug, Default, Serialize, Deserialize, Clone)] -pub struct LanPeers { +pub struct DiscoveryPeer { + #[serde(default)] + pub id: String, + #[serde(default)] + pub username: String, + #[serde(default)] + pub hostname: String, + #[serde(default)] + pub platform: String, #[serde(default)] - pub peers: String, + pub online: bool, + #[serde(default)] + pub ip_mac: HashMap<String, String>, +} + +impl DiscoveryPeer { + pub fn is_same_peer(&self, other: &DiscoveryPeer) -> bool { + self.id == other.id && self.username == other.username + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct LanPeers { + pub peers: Vec<DiscoveryPeer>, } impl LanPeers { pub fn load() -> LanPeers { - let _unused = CONFIG.read().unwrap(); // for lock + let _lock = CONFIG.read().unwrap(); match confy::load_path(&Config::file_("_lan_peers")) { Ok(peers) => peers, Err(err) => { @@ -855,9 +1150,11 @@ impl LanPeers { } } - pub fn store(peers: String) { - let f = LanPeers { peers }; - if let Err(err) = confy::store_path(Config::file_("_lan_peers"), f) { + pub fn store(peers: &Vec<DiscoveryPeer>) { + let f = LanPeers { + peers: peers.clone(), + }; + if let Err(err) = store_path(Config::file_("_lan_peers"), f) { log::error!("Failed to store lan peers: {}", err); } } @@ -871,9 +1168,40 @@ impl LanPeers { } } +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct HwCodecConfig { + #[serde(default)] + pub options: HashMap<String, String>, +} + +impl HwCodecConfig { + pub fn load() -> HwCodecConfig { + Config::load_::<HwCodecConfig>("_hwcodec") + } + + pub fn store(&self) { + Config::store_(self, "_hwcodec"); + } + + pub fn remove() { + std::fs::remove_file(Config::file_("_hwcodec")).ok(); + } + + /// refresh current global HW_CODEC_CONFIG, usually uesd after HwCodecConfig::remove() + pub fn refresh() { + *HW_CODEC_CONFIG.write().unwrap() = HwCodecConfig::load(); + log::debug!("HW_CODEC_CONFIG refreshed successfully"); + } + + pub fn get() -> HwCodecConfig { + return HW_CODEC_CONFIG.read().unwrap().clone(); + } +} + #[cfg(test)] mod tests { use super::*; + #[test] fn test_serialize() { let cfg: Config = Default::default(); diff --git a/libs/hbb_common/src/fs.rs b/libs/hbb_common/src/fs.rs index 69cd348..fec8b86 100644 --- a/libs/hbb_common/src/fs.rs +++ b/libs/hbb_common/src/fs.rs @@ -1,13 +1,17 @@ -use crate::{bail, message_proto::*, ResultType}; +#[cfg(windows)] +use std::os::windows::prelude::*; use std::path::{Path, PathBuf}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use serde_derive::{Deserialize, Serialize}; +use tokio::{fs::File, io::*}; + +use crate::{bail, get_version_number, message_proto::*, ResultType, Stream}; // https://doc.rust-lang.org/std/os/windows/fs/trait.MetadataExt.html use crate::{ compress::{compress, decompress}, config::{Config, COMPRESS_LEVEL}, }; -#[cfg(windows)] -use std::os::windows::prelude::*; -use tokio::{fs::File, io::*}; pub fn read_dir(path: &PathBuf, include_hidden: bool) -> ResultType<FileDirectory> { let mut dir = FileDirectory { @@ -184,16 +188,63 @@ pub fn get_recursive_files(path: &str, include_hidden: bool) -> ResultType<Vec<F read_dir_recursive(&get_path(path), &get_path(""), include_hidden) } +#[inline] +pub fn is_file_exists(file_path: &str) -> bool { + return Path::new(file_path).exists(); +} + +#[inline] +pub fn can_enable_overwrite_detection(version: i64) -> bool { + version >= get_version_number("1.1.10") +} + #[derive(Default)] pub struct TransferJob { - id: i32, - path: PathBuf, - files: Vec<FileEntry>, - file_num: i32, + pub id: i32, + pub remote: String, + pub path: PathBuf, + pub show_hidden: bool, + pub is_remote: bool, + pub is_last_job: bool, + pub file_num: i32, + pub files: Vec<FileEntry>, + file: Option<File>, total_size: u64, finished_size: u64, transferred: u64, + enable_overwrite_detection: bool, + file_confirmed: bool, + // indicating the last file is skipped + file_skipped: bool, + file_is_waiting: bool, + default_overwrite_strategy: Option<bool>, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct TransferJobMeta { + #[serde(default)] + pub id: i32, + #[serde(default)] + pub remote: String, + #[serde(default)] + pub to: String, + #[serde(default)] + pub show_hidden: bool, + #[serde(default)] + pub file_num: i32, + #[serde(default)] + pub is_remote: bool, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct RemoveJobMeta { + #[serde(default)] + pub path: String, + #[serde(default)] + pub is_remote: bool, + #[serde(default)] + pub no_confirm: bool, } #[inline] @@ -219,25 +270,54 @@ fn is_compressed_file(name: &str) -> bool { } impl TransferJob { - pub fn new_write(id: i32, path: String, files: Vec<FileEntry>) -> Self { + pub fn new_write( + id: i32, + remote: String, + path: String, + file_num: i32, + show_hidden: bool, + is_remote: bool, + files: Vec<FileEntry>, + enable_overwrite_detection: bool, + ) -> Self { + log::info!("new write {}", path); let total_size = files.iter().map(|x| x.size as u64).sum(); Self { id, + remote, path: get_path(&path), + file_num, + show_hidden, + is_remote, files, total_size, + enable_overwrite_detection, ..Default::default() } } - pub fn new_read(id: i32, path: String, include_hidden: bool) -> ResultType<Self> { - let files = get_recursive_files(&path, include_hidden)?; + pub fn new_read( + id: i32, + remote: String, + path: String, + file_num: i32, + show_hidden: bool, + is_remote: bool, + enable_overwrite_detection: bool, + ) -> ResultType<Self> { + log::info!("new read {}", path); + let files = get_recursive_files(&path, show_hidden)?; let total_size = files.iter().map(|x| x.size as u64).sum(); Ok(Self { id, + remote, path: get_path(&path), + file_num, + show_hidden, + is_remote, files, total_size, + enable_overwrite_detection, ..Default::default() }) } @@ -302,7 +382,7 @@ impl TransferJob { } } - pub async fn write(&mut self, block: FileTransferBlock, raw: Option<&[u8]>) -> ResultType<()> { + pub async fn write(&mut self, block: FileTransferBlock) -> ResultType<()> { if block.id != self.id { bail!("Wrong id"); } @@ -324,25 +404,20 @@ impl TransferJob { let path = format!("{}.download", get_string(&path)); self.file = Some(File::create(&path).await?); } - let data = if let Some(data) = raw { - data - } else { - &block.data - }; if block.compressed { - let tmp = decompress(data); + let tmp = decompress(&block.data); self.file.as_mut().unwrap().write_all(&tmp).await?; self.finished_size += tmp.len() as u64; } else { - self.file.as_mut().unwrap().write_all(data).await?; - self.finished_size += data.len() as u64; + self.file.as_mut().unwrap().write_all(&block.data).await?; + self.finished_size += block.data.len() as u64; } - self.transferred += data.len() as u64; + self.transferred += block.data.len() as u64; Ok(()) } #[inline] - fn join(&self, name: &str) -> PathBuf { + pub fn join(&self, name: &str) -> PathBuf { if name.is_empty() { self.path.clone() } else { @@ -350,7 +425,7 @@ impl TransferJob { } } - pub async fn read(&mut self) -> ResultType<Option<FileTransferBlock>> { + pub async fn read(&mut self, stream: &mut Stream) -> ResultType<Option<FileTransferBlock>> { let file_num = self.file_num as usize; if file_num >= self.files.len() { self.file.take(); @@ -361,13 +436,26 @@ impl TransferJob { match File::open(self.join(&name)).await { Ok(file) => { self.file = Some(file); + self.file_confirmed = false; + self.file_is_waiting = false; } Err(err) => { self.file_num += 1; + self.file_confirmed = false; + self.file_is_waiting = false; return Err(err.into()); } } } + if self.enable_overwrite_detection { + if !self.file_confirmed() { + if !self.file_is_waiting() { + self.send_current_digest(stream).await?; + self.set_file_is_waiting(true); + } + return Ok(None); + } + } const BUF_SIZE: usize = 128 * 1024; let mut buf: Vec<u8> = Vec::with_capacity(BUF_SIZE); unsafe { @@ -380,6 +468,8 @@ impl TransferJob { Err(err) => { self.file_num += 1; self.file = None; + self.file_confirmed = false; + self.file_is_waiting = false; return Err(err.into()); } Ok(n) => { @@ -394,6 +484,8 @@ impl TransferJob { if offset == 0 { self.file_num += 1; self.file = None; + self.file_confirmed = false; + self.file_is_waiting = false; } else { self.finished_size += offset as u64; if !is_compressed_file(name) { @@ -413,6 +505,139 @@ impl TransferJob { ..Default::default() })) } + + async fn send_current_digest(&mut self, stream: &mut Stream) -> ResultType<()> { + let mut msg = Message::new(); + let mut resp = FileResponse::new(); + let meta = self.file.as_ref().unwrap().metadata().await?; + let last_modified = meta + .modified()? + .duration_since(SystemTime::UNIX_EPOCH)? + .as_secs(); + resp.set_digest(FileTransferDigest { + id: self.id, + file_num: self.file_num, + last_modified, + file_size: meta.len(), + ..Default::default() + }); + msg.set_file_response(resp); + stream.send(&msg).await?; + log::info!( + "id: {}, file_num:{}, digest message is sent. waiting for confirm. msg: {:?}", + self.id, + self.file_num, + msg + ); + Ok(()) + } + + pub fn set_overwrite_strategy(&mut self, overwrite_strategy: Option<bool>) { + self.default_overwrite_strategy = overwrite_strategy; + } + + pub fn default_overwrite_strategy(&self) -> Option<bool> { + self.default_overwrite_strategy + } + + pub fn set_file_confirmed(&mut self, file_confirmed: bool) { + log::info!("id: {}, file_confirmed: {}", self.id, file_confirmed); + self.file_confirmed = file_confirmed; + self.file_skipped = false; + } + + pub fn set_file_is_waiting(&mut self, file_is_waiting: bool) { + self.file_is_waiting = file_is_waiting; + } + + #[inline] + pub fn file_is_waiting(&self) -> bool { + self.file_is_waiting + } + + #[inline] + pub fn file_confirmed(&self) -> bool { + self.file_confirmed + } + + /// Indicating whether the last file is skipped + #[inline] + pub fn file_skipped(&self) -> bool { + self.file_skipped + } + + /// Indicating whether the whole task is skipped + #[inline] + pub fn job_skipped(&self) -> bool { + self.file_skipped() && self.files.len() == 1 + } + + /// Check whether the job is completed after `read` returns `None` + /// This is a helper function which gives additional lifecycle when the job reads `None`. + /// If returns `true`, it means we can delete the job automatically. `False` otherwise. + /// + /// [`Note`] + /// Conditions: + /// 1. Files are not waiting for confirmation by peers. + #[inline] + pub fn job_completed(&self) -> bool { + // has no error, Condition 2 + if !self.enable_overwrite_detection || (!self.file_confirmed && !self.file_is_waiting) { + return true; + } + return false; + } + + /// Get job error message, useful for getting status when job had finished + pub fn job_error(&self) -> Option<String> { + if self.job_skipped() { + return Some("skipped".to_string()); + } + None + } + + pub fn set_file_skipped(&mut self) -> bool { + log::debug!("skip file {} in job {}", self.file_num, self.id); + self.file.take(); + self.set_file_confirmed(false); + self.set_file_is_waiting(false); + self.file_num += 1; + self.file_skipped = true; + true + } + + pub fn confirm(&mut self, r: &FileTransferSendConfirmRequest) -> bool { + if self.file_num() != r.file_num { + log::info!("file num truncated, ignoring"); + } else { + match r.union { + Some(file_transfer_send_confirm_request::Union::Skip(s)) => { + if s { + self.set_file_skipped(); + } else { + self.set_file_confirmed(true); + } + } + Some(file_transfer_send_confirm_request::Union::OffsetBlk(_offset)) => { + self.set_file_confirmed(true); + } + _ => {} + } + } + true + } + + #[inline] + pub fn gen_meta(&self) -> TransferJobMeta { + TransferJobMeta { + id: self.id, + remote: self.remote.to_string(), + to: self.path.to_string_lossy().to_string(), + file_num: self.file_num, + show_hidden: self.show_hidden, + is_remote: self.is_remote, + } + } } #[inline] @@ -453,12 +678,22 @@ pub fn new_block(block: FileTransferBlock) -> Message { } #[inline] -pub fn new_receive(id: i32, path: String, files: Vec<FileEntry>) -> Message { +pub fn new_send_confirm(r: FileTransferSendConfirmRequest) -> Message { + let mut msg_out = Message::new(); + let mut action = FileAction::new(); + action.set_send_confirm(r); + msg_out.set_file_action(action); + msg_out +} + +#[inline] +pub fn new_receive(id: i32, path: String, file_num: i32, files: Vec<FileEntry>) -> Message { let mut action = FileAction::new(); action.set_receive(FileTransferReceiveRequest { id, path, files: files.into(), + file_num, ..Default::default() }); let mut msg_out = Message::new(); @@ -467,12 +702,14 @@ pub fn new_receive(id: i32, path: String, files: Vec<FileEntry>) -> Message { } #[inline] -pub fn new_send(id: i32, path: String, include_hidden: bool) -> Message { +pub fn new_send(id: i32, path: String, file_num: i32, include_hidden: bool) -> Message { + log::info!("new send: {},id : {}", path, id); let mut action = FileAction::new(); action.set_send(FileTransferSendRequest { id, path, include_hidden, + file_num, ..Default::default() }); let mut msg_out = Message::new(); @@ -509,7 +746,10 @@ pub async fn handle_read_jobs( ) -> ResultType<()> { let mut finished = Vec::new(); for job in jobs.iter_mut() { - match job.read().await { + if job.is_last_job { + continue; + } + match job.read(stream).await { Err(err) => { stream .send(&new_error(job.id(), err, job.file_num())) @@ -519,8 +759,19 @@ pub async fn handle_read_jobs( stream.send(&new_block(block)).await?; } Ok(None) => { - finished.push(job.id()); - stream.send(&new_done(job.id(), job.file_num())).await?; + if job.job_completed() { + finished.push(job.id()); + let err = job.job_error(); + if err.is_some() { + stream + .send(&new_error(job.id(), err.unwrap(), job.file_num())) + .await?; + } else { + stream.send(&new_done(job.id(), job.file_num())).await?; + } + } else { + // waiting confirmation. + } } } } @@ -566,3 +817,34 @@ pub fn transform_windows_path(entries: &mut Vec<FileEntry>) { } } +pub enum DigestCheckResult { + IsSame, + NeedConfirm(FileTransferDigest), + NoSuchFile, +} + +#[inline] +pub fn is_write_need_confirmation( + file_path: &str, + digest: &FileTransferDigest, +) -> ResultType<DigestCheckResult> { + let path = Path::new(file_path); + if path.exists() && path.is_file() { + let metadata = std::fs::metadata(path)?; + let modified_time = metadata.modified()?; + let remote_mt = Duration::from_secs(digest.last_modified); + let local_mt = modified_time.duration_since(UNIX_EPOCH)?; + if remote_mt == local_mt && digest.file_size == metadata.len() { + return Ok(DigestCheckResult::IsSame); + } + Ok(DigestCheckResult::NeedConfirm(FileTransferDigest { + id: digest.id, + file_num: digest.file_num, + last_modified: local_mt.as_secs(), + file_size: metadata.len(), + ..Default::default() + })) + } else { + Ok(DigestCheckResult::NoSuchFile) + } +} diff --git a/libs/hbb_common/src/lib.rs b/libs/hbb_common/src/lib.rs index 5f23e46..85e0100 100644 --- a/libs/hbb_common/src/lib.rs +++ b/libs/hbb_common/src/lib.rs @@ -1,15 +1,16 @@ pub mod compress; -#[path = "./protos/message.rs"] -pub mod message_proto; -#[path = "./protos/rendezvous.rs"] -pub mod rendezvous_proto; +pub mod platform; +pub mod protos; pub use bytes; +use config::Config; pub use futures; pub use protobuf; +pub use protos::message as message_proto; +pub use protos::rendezvous as rendezvous_proto; use std::{ fs::File, io::{self, BufRead}, - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, path::Path, time::{self, SystemTime, UNIX_EPOCH}, }; @@ -27,6 +28,7 @@ pub use anyhow::{self, bail}; pub use futures_util; pub mod config; pub mod fs; +pub use lazy_static; #[cfg(not(any(target_os = "android", target_os = "ios")))] pub use mac_address; pub use rand; @@ -35,6 +37,9 @@ pub use sodiumoxide; pub use tokio_socks; pub use tokio_socks::IntoTargetAddr; pub use tokio_socks::TargetAddr; +pub mod password_security; +pub use chrono; +pub use directories_next; #[cfg(feature = "quic")] pub type Stream = quic::Connection; @@ -61,6 +66,21 @@ macro_rules! allow_err { } else { } }; + + ($e:expr, $($arg:tt)*) => { + if let Err(err) = $e { + log::debug!( + "{:?}, {}, {}:{}:{}:{}", + err, + format_args!($($arg)*), + module_path!(), + file!(), + line!(), + column!() + ); + } else { + } + }; } #[inline] @@ -97,13 +117,31 @@ impl AddrMangle { } bytes[..(16 - n_padding)].to_vec() } - _ => { - panic!("Only support ipv4"); + SocketAddr::V6(addr_v6) => { + let mut x = addr_v6.ip().octets().to_vec(); + let port: [u8; 2] = addr_v6.port().to_le_bytes(); + x.push(port[0]); + x.push(port[1]); + x } } } pub fn decode(bytes: &[u8]) -> SocketAddr { + if bytes.len() > 16 { + if bytes.len() != 18 { + return Config::get_any_listen_addr(false); + } + #[allow(invalid_value)] + let mut tmp: [u8; 2] = unsafe { std::mem::MaybeUninit::uninit().assume_init() }; + tmp.copy_from_slice(&bytes[16..]); + let port = u16::from_le_bytes(tmp); + #[allow(invalid_value)] + let mut tmp: [u8; 16] = unsafe { std::mem::MaybeUninit::uninit().assume_init() }; + tmp.copy_from_slice(&bytes[..16]); + let ip = std::net::Ipv6Addr::from(tmp); + return SocketAddr::new(IpAddr::V6(ip), port); + } let mut padded = [0u8; 16]; padded[..bytes.len()].copy_from_slice(&bytes); let number = u128::from_le_bytes(padded); @@ -156,19 +194,23 @@ pub fn get_version_from_url(url: &str) -> String { } pub fn gen_version() { + use std::io::prelude::*; let mut file = File::create("./src/version.rs").unwrap(); for line in read_lines("Cargo.toml").unwrap() { if let Ok(line) = line { let ab: Vec<&str> = line.split("=").map(|x| x.trim()).collect(); if ab.len() == 2 && ab[0] == "version" { - use std::io::prelude::*; - file.write_all(format!("pub const VERSION: &str = {};", ab[1]).as_bytes()) + file.write_all(format!("pub const VERSION: &str = {};\n", ab[1]).as_bytes()) .ok(); - file.sync_all().ok(); break; } } } + // generate build date + let build_date = format!("{}", chrono::Local::now().format("%Y-%m-%d %H:%M")); + file.write_all(format!("pub const BUILD_DATE: &str = \"{}\";", build_date).as_bytes()) + .ok(); + file.sync_all().ok(); } fn read_lines<P>(filename: P) -> io::Result<io::Lines<io::BufReader<File>>> @@ -199,6 +241,40 @@ pub fn get_modified_time(path: &std::path::Path) -> SystemTime { .unwrap_or(UNIX_EPOCH) } +pub fn get_created_time(path: &std::path::Path) -> SystemTime { + std::fs::metadata(&path) + .map(|m| m.created().unwrap_or(UNIX_EPOCH)) + .unwrap_or(UNIX_EPOCH) +} + +pub fn get_exe_time() -> SystemTime { + std::env::current_exe().map_or(UNIX_EPOCH, |path| { + let m = get_modified_time(&path); + let c = get_created_time(&path); + if m > c { + m + } else { + c + } + }) +} + +pub fn get_uuid() -> Vec<u8> { + #[cfg(not(any(target_os = "android", target_os = "ios")))] + if let Ok(id) = machine_uid::get() { + return id.into(); + } + Config::get_key_pair().1 +} + +#[inline] +pub fn get_time() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis()) + .unwrap_or(0) as _ +} + #[cfg(test)] mod tests { use super::*; @@ -206,5 +282,61 @@ mod tests { fn test_mangle() { let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 16, 32), 21116)); assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + + let addr = "[2001:db8::1]:8080".parse::<SocketAddr>().unwrap(); + assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + + let addr = "[2001:db8:ff::1111]:80".parse::<SocketAddr>().unwrap(); + assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + } + + #[test] + fn test_allow_err() { + allow_err!(Err("test err") as Result<(), &str>); + allow_err!( + Err("test err with msg") as Result<(), &str>, + "prompt {}", + "failed" + ); + } +} + +#[inline] +pub fn is_ipv4_str(id: &str) -> bool { + regex::Regex::new(r"^\d+\.\d+\.\d+\.\d+(:\d+)?$") + .unwrap() + .is_match(id) +} + +#[inline] +pub fn is_ipv6_str(id: &str) -> bool { + regex::Regex::new(r"^((([a-fA-F0-9]{1,4}:{1,2})+[a-fA-F0-9]{1,4})|(\[([a-fA-F0-9]{1,4}:{1,2})+[a-fA-F0-9]{1,4}\]:\d+))$") + .unwrap() + .is_match(id) +} + +#[inline] +pub fn is_ip_str(id: &str) -> bool { + is_ipv4_str(id) || is_ipv6_str(id) +} + +#[cfg(test)] +mod test_lib { + use super::*; + + #[test] + fn test_ipv6() { + assert_eq!(is_ipv6_str("1:2:3"), true); + assert_eq!(is_ipv6_str("[ab:2:3]:12"), true); + assert_eq!(is_ipv6_str("[ABEF:2a:3]:12"), true); + assert_eq!(is_ipv6_str("[ABEG:2a:3]:12"), false); + assert_eq!(is_ipv6_str("1[ab:2:3]:12"), false); + assert_eq!(is_ipv6_str("1.1.1.1"), false); + assert_eq!(is_ip_str("1.1.1.1"), true); + assert_eq!(is_ipv6_str("1:2:"), false); + assert_eq!(is_ipv6_str("1:2::0"), true); + assert_eq!(is_ipv6_str("[1:2::0]:1"), true); + assert_eq!(is_ipv6_str("[1:2::0]:"), false); + assert_eq!(is_ipv6_str("1:2::0]:1"), false); } } diff --git a/libs/hbb_common/src/password_security.rs b/libs/hbb_common/src/password_security.rs new file mode 100644 index 0000000..6029069 --- /dev/null +++ b/libs/hbb_common/src/password_security.rs @@ -0,0 +1,242 @@ +use crate::config::Config; +use sodiumoxide::base64; +use std::sync::{Arc, RwLock}; + +lazy_static::lazy_static! { + pub static ref TEMPORARY_PASSWORD:Arc<RwLock<String>> = Arc::new(RwLock::new(Config::get_auto_password(temporary_password_length()))); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum VerificationMethod { + OnlyUseTemporaryPassword, + OnlyUsePermanentPassword, + UseBothPasswords, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ApproveMode { + Both, + Password, + Click, +} + +// Should only be called in server +pub fn update_temporary_password() { + *TEMPORARY_PASSWORD.write().unwrap() = Config::get_auto_password(temporary_password_length()); +} + +// Should only be called in server +pub fn temporary_password() -> String { + TEMPORARY_PASSWORD.read().unwrap().clone() +} + +fn verification_method() -> VerificationMethod { + let method = Config::get_option("verification-method"); + if method == "use-temporary-password" { + VerificationMethod::OnlyUseTemporaryPassword + } else if method == "use-permanent-password" { + VerificationMethod::OnlyUsePermanentPassword + } else { + VerificationMethod::UseBothPasswords // default + } +} + +pub fn temporary_password_length() -> usize { + let length = Config::get_option("temporary-password-length"); + if length == "8" { + 8 + } else if length == "10" { + 10 + } else { + 6 // default + } +} + +pub fn temporary_enabled() -> bool { + verification_method() != VerificationMethod::OnlyUsePermanentPassword +} + +pub fn permanent_enabled() -> bool { + verification_method() != VerificationMethod::OnlyUseTemporaryPassword +} + +pub fn has_valid_password() -> bool { + temporary_enabled() && !temporary_password().is_empty() + || permanent_enabled() && !Config::get_permanent_password().is_empty() +} + +pub fn approve_mode() -> ApproveMode { + let mode = Config::get_option("approve-mode"); + if mode == "password" { + ApproveMode::Password + } else if mode == "click" { + ApproveMode::Click + } else { + ApproveMode::Both + } +} + +pub fn hide_cm() -> bool { + approve_mode() == ApproveMode::Password + && verification_method() == VerificationMethod::OnlyUsePermanentPassword + && !Config::get_option("allow-hide-cm").is_empty() +} + +const VERSION_LEN: usize = 2; + +pub fn encrypt_str_or_original(s: &str, version: &str) -> String { + if decrypt_str_or_original(s, version).1 { + log::error!("Duplicate encryption!"); + return s.to_owned(); + } + if version == "00" { + if let Ok(s) = encrypt(s.as_bytes()) { + return version.to_owned() + &s; + } + } + s.to_owned() +} + +// String: password +// bool: whether decryption is successful +// bool: whether should store to re-encrypt when load +pub fn decrypt_str_or_original(s: &str, current_version: &str) -> (String, bool, bool) { + if s.len() > VERSION_LEN { + let version = &s[..VERSION_LEN]; + if version == "00" { + if let Ok(v) = decrypt(&s[VERSION_LEN..].as_bytes()) { + return ( + String::from_utf8_lossy(&v).to_string(), + true, + version != current_version, + ); + } + } + } + + (s.to_owned(), false, !s.is_empty()) +} + +pub fn encrypt_vec_or_original(v: &[u8], version: &str) -> Vec<u8> { + if decrypt_vec_or_original(v, version).1 { + log::error!("Duplicate encryption!"); + return v.to_owned(); + } + if version == "00" { + if let Ok(s) = encrypt(v) { + let mut version = version.to_owned().into_bytes(); + version.append(&mut s.into_bytes()); + return version; + } + } + v.to_owned() +} + +// Vec<u8>: password +// bool: whether decryption is successful +// bool: whether should store to re-encrypt when load +pub fn decrypt_vec_or_original(v: &[u8], current_version: &str) -> (Vec<u8>, bool, bool) { + if v.len() > VERSION_LEN { + let version = String::from_utf8_lossy(&v[..VERSION_LEN]); + if version == "00" { + if let Ok(v) = decrypt(&v[VERSION_LEN..]) { + return (v, true, version != current_version); + } + } + } + + (v.to_owned(), false, !v.is_empty()) +} + +fn encrypt(v: &[u8]) -> Result<String, ()> { + if v.len() > 0 { + symmetric_crypt(v, true).map(|v| base64::encode(v, base64::Variant::Original)) + } else { + Err(()) + } +} + +fn decrypt(v: &[u8]) -> Result<Vec<u8>, ()> { + if v.len() > 0 { + base64::decode(v, base64::Variant::Original).and_then(|v| symmetric_crypt(&v, false)) + } else { + Err(()) + } +} + +fn symmetric_crypt(data: &[u8], encrypt: bool) -> Result<Vec<u8>, ()> { + use sodiumoxide::crypto::secretbox; + use std::convert::TryInto; + + let mut keybuf = crate::get_uuid(); + keybuf.resize(secretbox::KEYBYTES, 0); + let key = secretbox::Key(keybuf.try_into().map_err(|_| ())?); + let nonce = secretbox::Nonce([0; secretbox::NONCEBYTES]); + + if encrypt { + Ok(secretbox::seal(data, &nonce, &key)) + } else { + secretbox::open(data, &nonce, &key) + } +} + +mod test { + + #[test] + fn test() { + use super::*; + + let version = "00"; + + println!("test str"); + let data = "Hello World"; + let encrypted = encrypt_str_or_original(data, version); + let (decrypted, succ, store) = decrypt_str_or_original(&encrypted, version); + println!("data: {}", data); + println!("encrypted: {}", encrypted); + println!("decrypted: {}", decrypted); + assert_eq!(data, decrypted); + assert_eq!(version, &encrypted[..2]); + assert_eq!(succ, true); + assert_eq!(store, false); + let (_, _, store) = decrypt_str_or_original(&encrypted, "99"); + assert_eq!(store, true); + assert_eq!(decrypt_str_or_original(&decrypted, version).1, false); + assert_eq!(encrypt_str_or_original(&encrypted, version), encrypted); + + println!("test vec"); + let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6]; + let encrypted = encrypt_vec_or_original(&data, version); + let (decrypted, succ, store) = decrypt_vec_or_original(&encrypted, version); + println!("data: {:?}", data); + println!("encrypted: {:?}", encrypted); + println!("decrypted: {:?}", decrypted); + assert_eq!(data, decrypted); + assert_eq!(version.as_bytes(), &encrypted[..2]); + assert_eq!(store, false); + assert_eq!(succ, true); + let (_, _, store) = decrypt_vec_or_original(&encrypted, "99"); + assert_eq!(store, true); + assert_eq!(decrypt_vec_or_original(&decrypted, version).1, false); + assert_eq!(encrypt_vec_or_original(&encrypted, version), encrypted); + + println!("test original"); + let data = version.to_string() + "Hello World"; + let (decrypted, succ, store) = decrypt_str_or_original(&data, version); + assert_eq!(data, decrypted); + assert_eq!(store, true); + assert_eq!(succ, false); + let verbytes = version.as_bytes(); + let data: Vec<u8> = vec![verbytes[0] as u8, verbytes[1] as u8, 1, 2, 3, 4, 5, 6]; + let (decrypted, succ, store) = decrypt_vec_or_original(&data, version); + assert_eq!(data, decrypted); + assert_eq!(store, true); + assert_eq!(succ, false); + let (_, succ, store) = decrypt_str_or_original("", version); + assert_eq!(store, false); + assert_eq!(succ, false); + let (_, succ, store) = decrypt_vec_or_original(&vec![], version); + assert_eq!(store, false); + assert_eq!(succ, false); + } +} diff --git a/libs/hbb_common/src/platform/linux.rs b/libs/hbb_common/src/platform/linux.rs new file mode 100644 index 0000000..4c6375d --- /dev/null +++ b/libs/hbb_common/src/platform/linux.rs @@ -0,0 +1,157 @@ +use crate::ResultType; + +lazy_static::lazy_static! { + pub static ref DISTRO: Disto = Disto::new(); +} + +pub struct Disto { + pub name: String, + pub version_id: String, +} + +impl Disto { + fn new() -> Self { + let name = run_cmds("awk -F'=' '/^NAME=/ {print $2}' /etc/os-release".to_owned()) + .unwrap_or_default() + .trim() + .trim_matches('"') + .to_string(); + let version_id = + run_cmds("awk -F'=' '/^VERSION_ID=/ {print $2}' /etc/os-release".to_owned()) + .unwrap_or_default() + .trim() + .trim_matches('"') + .to_string(); + Self { name, version_id } + } +} + +pub fn get_display_server() -> String { + let mut session = get_values_of_seat0([0].to_vec())[0].clone(); + if session.is_empty() { + // loginctl has not given the expected output. try something else. + if let Ok(sid) = std::env::var("XDG_SESSION_ID") { + // could also execute "cat /proc/self/sessionid" + session = sid.to_owned(); + } + if session.is_empty() { + session = run_cmds("cat /proc/self/sessionid".to_owned()).unwrap_or_default(); + } + } + + get_display_server_of_session(&session) +} + +fn get_display_server_of_session(session: &str) -> String { + let mut display_server = if let Ok(output) = + run_loginctl(Some(vec!["show-session", "-p", "Type", session])) + // Check session type of the session + { + let display_server = String::from_utf8_lossy(&output.stdout) + .replace("Type=", "") + .trim_end() + .into(); + if display_server == "tty" { + // If the type is tty... + if let Ok(output) = run_loginctl(Some(vec!["show-session", "-p", "TTY", session])) + // Get the tty number + { + let tty: String = String::from_utf8_lossy(&output.stdout) + .replace("TTY=", "") + .trim_end() + .into(); + if let Ok(xorg_results) = run_cmds(format!("ps -e | grep \"{}.\\\\+Xorg\"", tty)) + // And check if Xorg is running on that tty + { + if xorg_results.trim_end().to_string() != "" { + // If it is, manually return "x11", otherwise return tty + return "x11".to_owned(); + } + } + } + } + display_server + } else { + "".to_owned() + }; + if display_server.is_empty() { + // loginctl has not given the expected output. try something else. + if let Ok(sestype) = std::env::var("XDG_SESSION_TYPE") { + display_server = sestype; + } + } + // If the session is not a tty, then just return the type as usual + display_server +} + +pub fn get_values_of_seat0(indices: Vec<usize>) -> Vec<String> { + if let Ok(output) = run_loginctl(None) { + for line in String::from_utf8_lossy(&output.stdout).lines() { + if line.contains("seat0") { + if let Some(sid) = line.split_whitespace().nth(0) { + if is_active(sid) { + return indices + .into_iter() + .map(|idx| line.split_whitespace().nth(idx).unwrap_or("").to_owned()) + .collect::<Vec<String>>(); + } + } + } + } + } + + // some case, there is no seat0 https://github.com/rustdesk/rustdesk/issues/73 + if let Ok(output) = run_loginctl(None) { + for line in String::from_utf8_lossy(&output.stdout).lines() { + if let Some(sid) = line.split_whitespace().nth(0) { + let d = get_display_server_of_session(sid); + if is_active(sid) && d != "tty" { + return indices + .into_iter() + .map(|idx| line.split_whitespace().nth(idx).unwrap_or("").to_owned()) + .collect::<Vec<String>>(); + } + } + } + } + + return indices + .iter() + .map(|_x| "".to_owned()) + .collect::<Vec<String>>(); +} + +fn is_active(sid: &str) -> bool { + if let Ok(output) = run_loginctl(Some(vec!["show-session", "-p", "State", sid])) { + String::from_utf8_lossy(&output.stdout).contains("active") + } else { + false + } +} + +pub fn run_cmds(cmds: String) -> ResultType<String> { + let output = std::process::Command::new("sh") + .args(vec!["-c", &cmds]) + .output()?; + Ok(String::from_utf8_lossy(&output.stdout).to_string()) +} + +#[cfg(not(feature = "flatpak"))] +fn run_loginctl(args: Option<Vec<&str>>) -> std::io::Result<std::process::Output> { + let mut cmd = std::process::Command::new("loginctl"); + if let Some(a) = args { + return cmd.args(a).output(); + } + cmd.output() +} + +#[cfg(feature = "flatpak")] +fn run_loginctl(args: Option<Vec<&str>>) -> std::io::Result<std::process::Output> { + let mut l_args = String::from("loginctl"); + if let Some(a) = args { + l_args = format!("{} {}", l_args, a.join(" ")); + } + std::process::Command::new("flatpak-spawn") + .args(vec![String::from("--host"), l_args]) + .output() +} diff --git a/libs/hbb_common/src/platform/mod.rs b/libs/hbb_common/src/platform/mod.rs new file mode 100644 index 0000000..8daba25 --- /dev/null +++ b/libs/hbb_common/src/platform/mod.rs @@ -0,0 +1,2 @@ +#[cfg(target_os = "linux")] +pub mod linux; diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs index 72ab73f..b7cb137 100644 --- a/libs/hbb_common/src/socket_client.rs +++ b/libs/hbb_common/src/socket_client.rs @@ -9,31 +9,15 @@ use std::net::SocketAddr; use tokio::net::ToSocketAddrs; use tokio_socks::{IntoTargetAddr, TargetAddr}; -fn to_socket_addr(host: &str) -> ResultType<SocketAddr> { - use std::net::ToSocketAddrs; - host.to_socket_addrs()? - .filter(|x| x.is_ipv4()) - .next() - .context("Failed to solve") -} - -pub fn get_target_addr(host: &str) -> ResultType<TargetAddr<'static>> { - let addr = match Config::get_network_type() { - NetworkType::Direct => to_socket_addr(&host)?.into_target_addr()?, - NetworkType::ProxySocks => host.into_target_addr()?, - } - .to_owned(); - Ok(addr) -} - pub fn test_if_valid_server(host: &str) -> String { let mut host = host.to_owned(); if !host.contains(":") { host = format!("{}:{}", host, 0); } + use std::net::ToSocketAddrs; match Config::get_network_type() { - NetworkType::Direct => match to_socket_addr(&host) { + NetworkType::Direct => match host.to_socket_addrs() { Err(err) => err.to_string(), Ok(_) => "".to_owned(), }, @@ -44,33 +28,126 @@ pub fn test_if_valid_server(host: &str) -> String { } } -pub async fn connect_tcp<'t, T: IntoTargetAddr<'t>>( +pub trait IsResolvedSocketAddr { + fn resolve(&self) -> Option<&SocketAddr>; +} + +impl IsResolvedSocketAddr for SocketAddr { + fn resolve(&self) -> Option<&SocketAddr> { + Some(&self) + } +} + +impl IsResolvedSocketAddr for String { + fn resolve(&self) -> Option<&SocketAddr> { + None + } +} + +impl IsResolvedSocketAddr for &str { + fn resolve(&self) -> Option<&SocketAddr> { + None + } +} + +#[inline] +pub async fn connect_tcp< + 't, + T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display, +>( target: T, - local: SocketAddr, ms_timeout: u64, ) -> ResultType<FramedStream> { - let target_addr = target.into_target_addr()?; + connect_tcp_local(target, None, ms_timeout).await +} +pub async fn connect_tcp_local< + 't, + T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display, +>( + target: T, + local: Option<SocketAddr>, + ms_timeout: u64, +) -> ResultType<FramedStream> { if let Some(conf) = Config::get_socks() { - FramedStream::connect( + return FramedStream::connect( conf.proxy.as_str(), - target_addr, + target, local, conf.username.as_str(), conf.password.as_str(), ms_timeout, ) - .await - } else { - let addr = std::net::ToSocketAddrs::to_socket_addrs(&target_addr)? - .filter(|x| x.is_ipv4()) - .next() - .context("Invalid target addr, no valid ipv4 address can be resolved.")?; - Ok(FramedStream::new(addr, local, ms_timeout).await?) + .await; } + if let Some(target) = target.resolve() { + if let Some(local) = local { + if local.is_ipv6() && target.is_ipv4() { + let target = query_nip_io(&target).await?; + return Ok(FramedStream::new(target, Some(local), ms_timeout).await?); + } + } + } + Ok(FramedStream::new(target, local, ms_timeout).await?) } -pub async fn new_udp<T: ToSocketAddrs>(local: T, ms_timeout: u64) -> ResultType<FramedSocket> { +#[inline] +pub fn is_ipv4(target: &TargetAddr<'_>) -> bool { + match target { + TargetAddr::Ip(addr) => addr.is_ipv4(), + _ => true, + } +} + +#[inline] +pub async fn query_nip_io(addr: &SocketAddr) -> ResultType<SocketAddr> { + tokio::net::lookup_host(format!("{}.nip.io:{}", addr.ip(), addr.port())) + .await? + .filter(|x| x.is_ipv6()) + .next() + .context("Failed to get ipv6 from nip.io") +} + +#[inline] +pub fn ipv4_to_ipv6(addr: String, ipv4: bool) -> String { + if !ipv4 && crate::is_ipv4_str(&addr) { + if let Some(ip) = addr.split(":").next() { + return addr.replace(ip, &format!("{}.nip.io", ip)); + } + } + addr +} + +async fn test_target(target: &str) -> ResultType<SocketAddr> { + if let Ok(Ok(s)) = super::timeout(1000, tokio::net::TcpStream::connect(target)).await { + if let Ok(addr) = s.peer_addr() { + return Ok(addr); + } + } + tokio::net::lookup_host(target) + .await? + .next() + .context(format!("Failed to look up host for {}", target)) +} + +#[inline] +pub async fn new_udp_for( + target: &str, + ms_timeout: u64, +) -> ResultType<(FramedSocket, TargetAddr<'static>)> { + let (ipv4, target) = if NetworkType::Direct == Config::get_network_type() { + let addr = test_target(target).await?; + (addr.is_ipv4(), addr.into_target_addr()?) + } else { + (true, target.into_target_addr()?) + }; + Ok(( + new_udp(Config::get_any_listen_addr(ipv4), ms_timeout).await?, + target.to_owned(), + )) +} + +async fn new_udp<T: ToSocketAddrs>(local: T, ms_timeout: u64) -> ResultType<FramedSocket> { match Config::get_socks() { None => Ok(FramedSocket::new(local).await?), Some(conf) => { @@ -87,9 +164,56 @@ pub async fn new_udp<T: ToSocketAddrs>(local: T, ms_timeout: u64) -> ResultType< } } -pub async fn rebind_udp<T: ToSocketAddrs>(local: T) -> ResultType<Option<FramedSocket>> { - match Config::get_network_type() { - NetworkType::Direct => Ok(Some(FramedSocket::new(local).await?)), - _ => Ok(None), +pub async fn rebind_udp_for( + target: &str, +) -> ResultType<Option<(FramedSocket, TargetAddr<'static>)>> { + if Config::get_network_type() != NetworkType::Direct { + return Ok(None); + } + let addr = test_target(target).await?; + let v4 = addr.is_ipv4(); + Ok(Some(( + FramedSocket::new(Config::get_any_listen_addr(v4)).await?, + addr.into_target_addr()?.to_owned(), + ))) +} + +#[cfg(test)] +mod tests { + use std::net::ToSocketAddrs; + + use super::*; + + #[test] + fn test_nat64() { + test_nat64_async(); + } + + #[tokio::main(flavor = "current_thread")] + async fn test_nat64_async() { + assert_eq!(ipv4_to_ipv6("1.1.1.1".to_owned(), true), "1.1.1.1"); + assert_eq!(ipv4_to_ipv6("1.1.1.1".to_owned(), false), "1.1.1.1.nip.io"); + assert_eq!( + ipv4_to_ipv6("1.1.1.1:8080".to_owned(), false), + "1.1.1.1.nip.io:8080" + ); + assert_eq!( + ipv4_to_ipv6("rustdesk.com".to_owned(), false), + "rustdesk.com" + ); + if ("rustdesk.com:80") + .to_socket_addrs() + .unwrap() + .next() + .unwrap() + .is_ipv6() + { + assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()) + .await + .unwrap() + .is_ipv6()); + return; + } + assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()).await.is_err()); } } diff --git a/libs/hbb_common/src/tcp.rs b/libs/hbb_common/src/tcp.rs index 7966920..a1322fc 100644 --- a/libs/hbb_common/src/tcp.rs +++ b/libs/hbb_common/src/tcp.rs @@ -5,7 +5,7 @@ use protobuf::Message; use sodiumoxide::crypto::secretbox::{self, Key, Nonce}; use std::{ io::{self, Error, ErrorKind}, - net::SocketAddr, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, @@ -73,73 +73,79 @@ fn new_socket(addr: std::net::SocketAddr, reuse: bool) -> Result<TcpSocket, std: } impl FramedStream { - pub async fn new<T1: ToSocketAddrs, T2: ToSocketAddrs>( - remote_addr: T1, - local_addr: T2, + pub async fn new<T: ToSocketAddrs + std::fmt::Display>( + remote_addr: T, + local_addr: Option<SocketAddr>, ms_timeout: u64, ) -> ResultType<Self> { - for local_addr in lookup_host(&local_addr).await? { - for remote_addr in lookup_host(&remote_addr).await? { - let stream = super::timeout( - ms_timeout, - new_socket(local_addr, true)?.connect(remote_addr), - ) - .await??; - stream.set_nodelay(true).ok(); - let addr = stream.local_addr()?; - return Ok(Self( - Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), - addr, - None, - 0, - )); + for remote_addr in lookup_host(&remote_addr).await? { + let local = if let Some(addr) = local_addr { + addr + } else { + crate::config::Config::get_any_listen_addr(remote_addr.is_ipv4()) + }; + if let Ok(socket) = new_socket(local, true) { + if let Ok(Ok(stream)) = + super::timeout(ms_timeout, socket.connect(remote_addr)).await + { + stream.set_nodelay(true).ok(); + let addr = stream.local_addr()?; + return Ok(Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )); + } } } - bail!("could not resolve to any address"); + bail!(format!("Failed to connect to {}", remote_addr)); } - pub async fn connect<'a, 't, P, T1, T2>( + pub async fn connect<'a, 't, P, T>( proxy: P, - target: T1, - local: T2, + target: T, + local_addr: Option<SocketAddr>, username: &'a str, password: &'a str, ms_timeout: u64, ) -> ResultType<Self> where P: ToProxyAddrs, - T1: IntoTargetAddr<'t>, - T2: ToSocketAddrs, + T: IntoTargetAddr<'t>, { - if let Some(local) = lookup_host(&local).await?.next() { - if let Some(proxy) = proxy.to_proxy_addrs().next().await { - let stream = - super::timeout(ms_timeout, new_socket(local, true)?.connect(proxy?)).await??; - stream.set_nodelay(true).ok(); - let stream = if username.trim().is_empty() { - super::timeout( - ms_timeout, - Socks5Stream::connect_with_socket(stream, target), - ) - .await?? - } else { - super::timeout( - ms_timeout, - Socks5Stream::connect_with_password_and_socket( - stream, target, username, password, - ), - ) - .await?? - }; - let addr = stream.local_addr()?; - return Ok(Self( - Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), - addr, - None, - 0, - )); + if let Some(Ok(proxy)) = proxy.to_proxy_addrs().next().await { + let local = if let Some(addr) = local_addr { + addr + } else { + crate::config::Config::get_any_listen_addr(proxy.is_ipv4()) + }; + let stream = + super::timeout(ms_timeout, new_socket(local, true)?.connect(proxy)).await??; + stream.set_nodelay(true).ok(); + let stream = if username.trim().is_empty() { + super::timeout( + ms_timeout, + Socks5Stream::connect_with_socket(stream, target), + ) + .await?? + } else { + super::timeout( + ms_timeout, + Socks5Stream::connect_with_password_and_socket( + stream, target, username, password, + ), + ) + .await?? }; - }; + let addr = stream.local_addr()?; + return Ok(Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )); + } bail!("could not resolve to any address"); } @@ -252,6 +258,38 @@ pub async fn new_listener<T: ToSocketAddrs>(addr: T, reuse: bool) -> ResultType< } } +pub async fn listen_any(port: u16) -> ResultType<TcpListener> { + if let Ok(mut socket) = TcpSocket::new_v6() { + #[cfg(unix)] + { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + let raw_fd = socket.into_raw_fd(); + let sock2 = unsafe { socket2::Socket::from_raw_fd(raw_fd) }; + sock2.set_only_v6(false).ok(); + socket = unsafe { TcpSocket::from_raw_fd(sock2.into_raw_fd()) }; + } + #[cfg(windows)] + { + use std::os::windows::prelude::{FromRawSocket, IntoRawSocket}; + let raw_socket = socket.into_raw_socket(); + let sock2 = unsafe { socket2::Socket::from_raw_socket(raw_socket) }; + sock2.set_only_v6(false).ok(); + socket = unsafe { TcpSocket::from_raw_socket(sock2.into_raw_socket()) }; + } + if socket + .bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)) + .is_ok() + { + if let Ok(l) = socket.listen(DEFAULT_BACKLOG) { + return Ok(l); + } + } + } + let s = TcpSocket::new_v4()?; + s.bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port))?; + Ok(s.listen(DEFAULT_BACKLOG)?) +} + impl Unpin for DynTcpStream {} impl AsyncRead for DynTcpStream { diff --git a/libs/hbb_common/src/udp.rs b/libs/hbb_common/src/udp.rs index 3532dd1..38121a4 100644 --- a/libs/hbb_common/src/udp.rs +++ b/libs/hbb_common/src/udp.rs @@ -49,7 +49,7 @@ impl FramedSocket { #[allow(clippy::never_loop)] pub async fn new_reuse<T: std::net::ToSocketAddrs>(addr: T) -> ResultType<Self> { - for addr in addr.to_socket_addrs()?.filter(|x| x.is_ipv4()) { + for addr in addr.to_socket_addrs()? { let socket = new_socket(addr, true, 0)?.into_udp_socket(); return Ok(Self::Direct(UdpFramed::new( UdpSocket::from_std(socket)?, @@ -63,7 +63,7 @@ impl FramedSocket { addr: T, buf_size: usize, ) -> ResultType<Self> { - for addr in addr.to_socket_addrs()?.filter(|x| x.is_ipv4()) { + for addr in addr.to_socket_addrs()? { return Ok(Self::Direct(UdpFramed::new( UdpSocket::from_std(new_socket(addr, false, buf_size)?.into_udp_socket())?, BytesCodec::new(), @@ -164,4 +164,13 @@ impl FramedSocket { None } } + + pub fn is_ipv4(&self) -> bool { + if let FramedSocket::Direct(x) = self { + if let Ok(v) = x.get_ref().local_addr() { + return v.is_ipv4(); + } + } + true + } } |