diff options
author | rustdesk <[email protected]> | 2023-01-27 11:37:15 +0800 |
---|---|---|
committer | rustdesk <[email protected]> | 2023-01-27 11:37:15 +0800 |
commit | a974906fdca2eb88d84e16bca1f89bc85accc1e3 (patch) | |
tree | 6bc4f8bf7be80388c07f570710337e624231e3b0 | |
parent | 17ddc89bd025dd0304b858f39a264fe707ea40bc (diff) | |
parent | 088a009078abf1cae873b600ac4a50a9a2bfbace (diff) | |
download | rustdesk-server-a974906fdca2eb88d84e16bca1f89bc85accc1e3.tar.gz rustdesk-server-a974906fdca2eb88d84e16bca1f89bc85accc1e3.zip |
Merge branch 'master' into tmp
32 files changed, 669 insertions, 521 deletions
diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index e37f85b..c227484 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -26,7 +26,7 @@ jobs: build: name: Build - ${{ matrix.job.name }} - runs-on: ubuntu-22.04 + runs-on: ubuntu-18.04 strategy: fail-fast: false matrix: @@ -114,7 +114,7 @@ jobs: needs: - build - build-win - runs-on: ubuntu-22.04 + runs-on: ubuntu-18.04 strategy: fail-fast: false matrix: @@ -153,7 +153,7 @@ jobs: name: Docker push - ${{ matrix.job.name }} needs: build - runs-on: ubuntu-22.04 + runs-on: ubuntu-18.04 strategy: fail-fast: false matrix: @@ -223,7 +223,7 @@ jobs: name: Docker manifest needs: docker - runs-on: ubuntu-22.04 + runs-on: ubuntu-18.04 steps: @@ -275,7 +275,7 @@ jobs: name: Docker push classic - ${{ matrix.job.name }} needs: build - runs-on: ubuntu-22.04 + runs-on: ubuntu-18.04 strategy: fail-fast: false matrix: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..d425a10 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,72 @@ +name: test + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + - uses: Swatinem/rust-cache@v2 + - uses: actions-rs/cargo@v1 + with: + command: check + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + - uses: Swatinem/rust-cache@v2 + - uses: actions-rs/cargo@v1 + with: + command: test + args: --all + + fmt: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + components: rustfmt + - uses: Swatinem/rust-cache@v2 + - uses: actions-rs/cargo@v1 + with: + command: build + - uses: actions-rs/cargo@v1 + with: + command: fmt + args: --all -- --check + + clippy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + components: clippy + - uses: Swatinem/rust-cache@v2 + - uses: actions-rs/cargo@v1 + with: + command: clippy + args: --all -- -D warnings @@ -6,3 +6,5 @@ debian/.debhelper debian/debhelper-build-stamp .DS_Store .vscode +src/version.rs +db_v2.sqlite3 @@ -785,7 +785,7 @@ dependencies = [ [[package]] name = "hbbs" -version = "1.1.6" +version = "1.1.7" dependencies = [ "async-speed-limit", "async-trait", @@ -1,6 +1,6 @@ [package] name = "hbbs" -version = "1.1.6" +version = "1.1.7" authors = ["open-trade <[email protected]>"] edition = "2021" build = "build.rs" @@ -173,16 +173,14 @@ services: restart: unless-stopped ``` -We use these environment variables: +For this container image, you can use these environment variables, **in addition** to the ones specified in the following **ENV variables** section: | variable | optional | description | | --- | --- | --- | | RELAY | no | the IP address/DNS name of the machine running this container | | ENCRYPTED_ONLY | yes | if set to **"1"** unencrypted connection will not be accepted | -| DB_URL | yes | path for database file | | KEY_PUB | yes | public part of the key pair | | KEY_PRIV | yes | private part of the key pair | -| RUST_LOG | yes | set debug level (error|warn|info|debug|trace) | ### Secret management in S6-overlay based images @@ -316,3 +314,22 @@ These packages are meant for the following distributions: - Ubuntu 18.04 LTS - Debian 11 bullseye - Debian 10 buster + +## ENV variables + +hbbs and hbbr can be configured using these ENV variables. +You can specify the variables as usual or use an `.env` file. + +| variable | binary | description | +| --- | --- | --- | +| ALWAYS_USE_RELAY | hbbs | if set to **"Y"** disallows direct peer connection | +| DB_URL | hbbs | path for database file | +| DOWNGRADE_START_CHECK | hbbr | delay (in seconds) before downgrade check | +| DOWNGRADE_THRESHOLD | hbbr | threshold of downgrade check (bit/ms) | +| KEY | hbbs/hbbr | if set force the use of a specific key, if set to **"_"** force the use of any key | +| LIMIT_SPEED | hbbr | speed limit (in Mb/s) | +| PORT | hbbs/hbbr | listening port (21116 for hbbs - 21117 for hbbr) | +| RELAY_SERVERS | hbbs | IP address/DNS name of the machines running hbbr (separated by comma) | +| RUST_LOG | all | set debug level (error\|warn\|info\|debug\|trace) | +| SINGLE_BANDWIDTH | hbbr | max bandwidth for a single connection (in Mb/s) | +| TOTAL_BANDWIDTH | hbbr | max total bandwidth (in Mb/s) | diff --git a/db_v2.sqlite3 b/db_v2.sqlite3 Binary files differindex 3d9350d..c95a2f3 100644 --- a/db_v2.sqlite3 +++ b/db_v2.sqlite3 diff --git a/debian/changelog b/debian/changelog index c9918b0..db3af9f 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +rustdesk-server (1.1.7) UNRELEASED; urgency=medium + + * ipv6 support + + -- rustdesk <[email protected]> Wed, 11 Jan 2023 11:27:00 +0800 + rustdesk-server (1.1.6) UNRELEASED; urgency=medium * Initial release diff --git a/debian/rustdesk-server-hbbr.postinst b/debian/rustdesk-server-hbbr.postinst index f65944f..27fbbd5 100644 --- a/debian/rustdesk-server-hbbr.postinst +++ b/debian/rustdesk-server-hbbr.postinst @@ -3,6 +3,10 @@ set -e SERVICE=rustdesk-hbbr.service +if [ "$1" = "configure" ]; then + mkdir -p /var/log/rustdesk +fi + case "$1" in configure|abort-upgrade|abort-deconfigure|abort-remove) mkdir -p /var/lib/rustdesk-server/ diff --git a/debian/rustdesk-server-hbbr.postrm b/debian/rustdesk-server-hbbr.postrm index b5922d0..0420d84 100644 --- a/debian/rustdesk-server-hbbr.postrm +++ b/debian/rustdesk-server-hbbr.postrm @@ -6,7 +6,7 @@ SERVICE=rustdesk-hbbr.service systemctl --system daemon-reload >/dev/null || true if [ "$1" = "purge" ]; then - rm -rf /var/lib/rustdesk-server/ + rm -rf /var/lib/rustdesk-server/ /var/log/rustdesk/rustdesk-hbbr.* /var/log/rustdesk/rustdesk-hbbs.* deb-systemd-helper purge "${SERVICE}" >/dev/null || true deb-systemd-helper unmask "${SERVICE}" >/dev/null || true fi diff --git a/debian/rustdesk-server-hbbs.postinst b/debian/rustdesk-server-hbbs.postinst index b0004f2..df42813 100644 --- a/debian/rustdesk-server-hbbs.postinst +++ b/debian/rustdesk-server-hbbs.postinst @@ -3,6 +3,10 @@ set -e SERVICE=rustdesk-hbbs.service +if [ "$1" = "configure" ]; then + mkdir -p /var/log/rustdesk +fi + case "$1" in configure|abort-upgrade|abort-deconfigure|abort-remove) mkdir -p /var/lib/rustdesk-server/ diff --git a/docker/rootfs/etc/s6-overlay/s6-rc.d/hbbr/run b/docker/rootfs/etc/s6-overlay/s6-rc.d/hbbr/run index eae3611..c7e1b19 100755 --- a/docker/rootfs/etc/s6-overlay/s6-rc.d/hbbr/run +++ b/docker/rootfs/etc/s6-overlay/s6-rc.d/hbbr/run @@ -1,3 +1,3 @@ -#!/command/execlineb -P -posix-cd /data +#!/command/with-contenv sh +cd /data /usr/bin/hbbr diff --git a/libs/hbb_common/build.rs b/libs/hbb_common/build.rs index bff0cfa..fe0d310 100644 --- a/libs/hbb_common/build.rs +++ b/libs/hbb_common/build.rs @@ -8,10 +8,7 @@ fn main() { .out_dir(out_dir) .inputs(&["protos/rendezvous.proto", "protos/message.proto"]) .include("protos") - .customize( - protobuf_codegen::Customize::default() - .tokio_bytes(true) - ) + .customize(protobuf_codegen::Customize::default().tokio_bytes(true)) .run() .expect("Codegen failed."); } diff --git a/libs/hbb_common/src/bytes_codec.rs b/libs/hbb_common/src/bytes_codec.rs index e029f1c..699aa9b 100644 --- a/libs/hbb_common/src/bytes_codec.rs +++ b/libs/hbb_common/src/bytes_codec.rs @@ -15,6 +15,12 @@ enum DecodeState { Data(usize), } +impl Default for BytesCodec { + fn default() -> Self { + Self::new() + } +} + impl BytesCodec { pub fn new() -> Self { Self { @@ -56,7 +62,7 @@ impl BytesCodec { } src.advance(head_len); src.reserve(n); - return Ok(Some(n)); + Ok(Some(n)) } fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> { diff --git a/libs/hbb_common/src/compress.rs b/libs/hbb_common/src/compress.rs index a969ccf..e7668a9 100644 --- a/libs/hbb_common/src/compress.rs +++ b/libs/hbb_common/src/compress.rs @@ -32,12 +32,7 @@ pub fn decompress(data: &[u8]) -> Vec<u8> { const MAX: usize = 1024 * 1024 * 64; const MIN: usize = 1024 * 1024; let mut n = 30 * data.len(); - if n > MAX { - n = MAX; - } - if n < MIN { - n = MIN; - } + n = n.clamp(MIN, MAX); match d.decompress(data, n) { Ok(res) => out = res, Err(err) => { diff --git a/libs/hbb_common/src/config.rs b/libs/hbb_common/src/config.rs index 20334ed..943917a 100644 --- a/libs/hbb_common/src/config.rs +++ b/libs/hbb_common/src/config.rs @@ -29,7 +29,7 @@ pub const READ_TIMEOUT: u64 = 30_000; pub const REG_INTERVAL: i64 = 12_000; pub const COMPRESS_LEVEL: i32 = 3; const SERIAL: i32 = 3; -const PASSWORD_ENC_VERSION: &'static str = "00"; +const PASSWORD_ENC_VERSION: &str = "00"; // 128x128 #[cfg(target_os = "macos")] // 128x128 on 160x160 canvas, then shrink to 128, mac looks better with padding pub const ICON: &str = " @@ -43,6 +43,7 @@ lazy_static::lazy_static! { } type Size = (i32, i32, i32, i32); +type KeyPair = (Vec<u8>, Vec<u8>); lazy_static::lazy_static! { static ref CONFIG: Arc<RwLock<Config>> = Arc::new(RwLock::new(Config::load())); @@ -54,7 +55,7 @@ lazy_static::lazy_static! { _ => "", }.to_owned())); pub static ref APP_NAME: Arc<RwLock<String>> = Arc::new(RwLock::new("RustDesk".to_owned())); - static ref KEY_PAIR: Arc<Mutex<Option<(Vec<u8>, Vec<u8>)>>> = Default::default(); + static ref KEY_PAIR: Arc<Mutex<Option<KeyPair>>> = Default::default(); static ref HW_CODEC_CONFIG: Arc<RwLock<HwCodecConfig>> = Arc::new(RwLock::new(HwCodecConfig::load())); } @@ -75,12 +76,12 @@ lazy_static::lazy_static! { ]); } -const CHARS: &'static [char] = &[ +const CHARS: &[char] = &[ '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ]; -const RENDEZVOUS_SERVERS: &'static [&'static str] = &[ +pub const RENDEZVOUS_SERVERS: &[&str] = &[ "rs-ny.rustdesk.com", "rs-sg.rustdesk.com", "rs-cn.rustdesk.com", @@ -131,7 +132,7 @@ pub struct Config { #[serde(default)] salt: String, #[serde(default)] - key_pair: (Vec<u8>, Vec<u8>), // sk, pk + key_pair: KeyPair, // sk, pk #[serde(default)] key_confirmed: bool, #[serde(default)] @@ -319,7 +320,7 @@ impl Config2 { pub fn load_path<T: serde::Serialize + serde::de::DeserializeOwned + Default + std::fmt::Debug>( file: PathBuf, ) -> T { - let cfg = match confy::load_path(&file) { + let cfg = match confy::load_path(file) { Ok(config) => config, Err(err) => { log::error!("Failed to load config: {}", err); @@ -366,20 +367,16 @@ impl Config { config.id = id; id_valid = true; store |= store2; - } else { - if crate::get_modified_time(&Self::file_("")) - .checked_sub(std::time::Duration::from_secs(30)) // allow modification during installation - .unwrap_or(crate::get_exe_time()) - < crate::get_exe_time() - { - if !config.id.is_empty() - && config.enc_id.is_empty() - && !decrypt_str_or_original(&config.id, PASSWORD_ENC_VERSION).1 - { - id_valid = true; - store = true; - } - } + } else if crate::get_modified_time(&Self::file_("")) + .checked_sub(std::time::Duration::from_secs(30)) // allow modification during installation + .unwrap_or_else(crate::get_exe_time) + < crate::get_exe_time() + && !config.id.is_empty() + && config.enc_id.is_empty() + && !decrypt_str_or_original(&config.id, PASSWORD_ENC_VERSION).1 + { + id_valid = true; + store = true; } if !id_valid { for _ in 0..3 { @@ -444,18 +441,18 @@ impl Config { #[cfg(not(any(target_os = "android", target_os = "ios")))] { #[cfg(not(target_os = "macos"))] - let org = ""; + let org = "".to_owned(); #[cfg(target_os = "macos")] let org = ORG.read().unwrap().clone(); // /var/root for root if let Some(project) = - directories_next::ProjectDirs::from("", &org, &*APP_NAME.read().unwrap()) + directories_next::ProjectDirs::from("", &org, &APP_NAME.read().unwrap()) { let mut path = patch(project.config_dir().to_path_buf()); path.push(p); return path; } - return "".into(); + "".into() } } @@ -539,9 +536,9 @@ impl Config { rendezvous_server = Self::get_rendezvous_servers() .drain(..) .next() - .unwrap_or("".to_owned()); + .unwrap_or_default(); } - if !rendezvous_server.contains(":") { + if !rendezvous_server.contains(':') { rendezvous_server = format!("{}:{}", rendezvous_server, RENDEZVOUS_PORT); } rendezvous_server @@ -559,8 +556,8 @@ impl Config { let serial_obsolute = CONFIG2.read().unwrap().serial > SERIAL; if serial_obsolute { let ss: Vec<String> = Self::get_option("rendezvous-servers") - .split(",") - .filter(|x| x.contains(".")) + .split(',') + .filter(|x| x.contains('.')) .map(|x| x.to_owned()) .collect(); if !ss.is_empty() { @@ -580,7 +577,7 @@ impl Config { let mut delay = i64::MAX; for (tmp_host, tmp_delay) in ONLINE.lock().unwrap().iter() { if tmp_delay > &0 && tmp_delay < &delay { - delay = tmp_delay.clone(); + delay = *tmp_delay; host = tmp_host.to_string(); } } @@ -647,7 +644,7 @@ impl Config { for x in &ma.bytes()[2..] { id = (id << 8) | (*x as u32); } - id = id & 0x1FFFFFFF; + id &= 0x1FFFFFFF; Some(id.to_string()) } else { None @@ -679,11 +676,7 @@ impl Config { } pub fn get_host_key_confirmed(host: &str) -> bool { - if let Some(true) = CONFIG.read().unwrap().keys_confirmed.get(host) { - true - } else { - false - } + matches!(CONFIG.read().unwrap().keys_confirmed.get(host), Some(true)) } pub fn set_host_key_confirmed(host: &str, v: bool) { @@ -695,7 +688,7 @@ impl Config { config.store(); } - pub fn get_key_pair() -> (Vec<u8>, Vec<u8>) { + pub fn get_key_pair() -> KeyPair { // lock here to make sure no gen_keypair more than once // no use of CONFIG directly here to ensure no recursive calling in Config::load because of password dec which calling this function let mut lock = KEY_PAIR.lock().unwrap(); @@ -714,7 +707,7 @@ impl Config { }); } *lock = Some(config.key_pair.clone()); - return config.key_pair; + config.key_pair } pub fn get_id() -> String { @@ -849,7 +842,7 @@ impl Config { let ext = path.extension(); if let Some(ext) = ext { let ext = format!("{}.toml", ext.to_string_lossy()); - path.with_extension(&ext) + path.with_extension(ext) } else { path.with_extension("toml") } @@ -861,7 +854,7 @@ const PEERS: &str = "peers"; impl PeerConfig { pub fn load(id: &str) -> PeerConfig { let _lock = CONFIG.read().unwrap(); - match confy::load_path(&Self::path(id)) { + match confy::load_path(Self::path(id)) { Ok(config) => { let mut config: PeerConfig = config; let mut store = false; @@ -869,16 +862,16 @@ impl PeerConfig { decrypt_vec_or_original(&config.password, PASSWORD_ENC_VERSION); config.password = password; store = store || store2; - config.options.get_mut("rdp_password").map(|v| { + if let Some(v) = config.options.get_mut("rdp_password") { let (password, _, store2) = decrypt_str_or_original(v, PASSWORD_ENC_VERSION); *v = password; store = store || store2; - }); - config.options.get_mut("os-password").map(|v| { + } + if let Some(v) = config.options.get_mut("os-password") { let (password, _, store2) = decrypt_str_or_original(v, PASSWORD_ENC_VERSION); *v = password; store = store || store2; - }); + } if store { config.store(id); } @@ -895,34 +888,29 @@ impl PeerConfig { let _lock = CONFIG.read().unwrap(); let mut config = self.clone(); config.password = encrypt_vec_or_original(&config.password, PASSWORD_ENC_VERSION); - config - .options - .get_mut("rdp_password") - .map(|v| *v = encrypt_str_or_original(v, PASSWORD_ENC_VERSION)); - config - .options - .get_mut("os-password") - .map(|v| *v = encrypt_str_or_original(v, PASSWORD_ENC_VERSION)); + if let Some(v) = config.options.get_mut("rdp_password") { + *v = encrypt_str_or_original(v, PASSWORD_ENC_VERSION) + } + if let Some(v) = config.options.get_mut("os-password") { + *v = encrypt_str_or_original(v, PASSWORD_ENC_VERSION) + }; if let Err(err) = store_path(Self::path(id), config) { log::error!("Failed to store config: {}", err); } } pub fn remove(id: &str) { - fs::remove_file(&Self::path(id)).ok(); + fs::remove_file(Self::path(id)).ok(); } fn path(id: &str) -> PathBuf { - let id_encoded: String; - //If the id contains invalid chars, encode it let forbidden_paths = Regex::new(r".*[<>:/\\|\?\*].*").unwrap(); - if forbidden_paths.is_match(id) { - id_encoded = - "base64_".to_string() + base64::encode(id, base64::Variant::Original).as_str(); + let id_encoded = if forbidden_paths.is_match(id) { + "base64_".to_string() + base64::encode(id, base64::Variant::Original).as_str() } else { - id_encoded = id.to_string(); - } + id.to_string() + }; let path: PathBuf = [PEERS, id_encoded.as_str()].iter().collect(); Config::with_extension(Config::path(path)) } @@ -940,26 +928,24 @@ impl PeerConfig { && p.extension().map(|p| p.to_str().unwrap_or("")) == Some("toml") }) .map(|p| { - let t = crate::get_modified_time(&p); + let t = crate::get_modified_time(p); let id = p .file_stem() .map(|p| p.to_str().unwrap_or("")) .unwrap_or("") .to_owned(); - let id_decoded_string: String; - if id.starts_with("base64_") && id.len() != 7 { + let id_decoded_string = if id.starts_with("base64_") && id.len() != 7 { let id_decoded = base64::decode(&id[7..], base64::Variant::Original) - .unwrap_or(Vec::new()); - id_decoded_string = - String::from_utf8_lossy(&id_decoded).as_ref().to_owned(); + .unwrap_or_default(); + String::from_utf8_lossy(&id_decoded).as_ref().to_owned() } else { - id_decoded_string = id; - } + id + }; let c = PeerConfig::load(&id_decoded_string); if c.info.platform.is_empty() { - fs::remove_file(&p).ok(); + fs::remove_file(p).ok(); } (id_decoded_string, t, c) }) @@ -1149,7 +1135,7 @@ pub struct LanPeers { impl LanPeers { pub fn load() -> LanPeers { let _lock = CONFIG.read().unwrap(); - match confy::load_path(&Config::file_("_lan_peers")) { + match confy::load_path(Config::file_("_lan_peers")) { Ok(peers) => peers, Err(err) => { log::error!("Failed to load lan peers: {}", err); @@ -1158,9 +1144,9 @@ impl LanPeers { } } - pub fn store(peers: &Vec<DiscoveryPeer>) { + pub fn store(peers: &[DiscoveryPeer]) { let f = LanPeers { - peers: peers.clone(), + peers: peers.to_owned(), }; if let Err(err) = store_path(Config::file_("_lan_peers"), f) { log::error!("Failed to store lan peers: {}", err); diff --git a/libs/hbb_common/src/fs.rs b/libs/hbb_common/src/fs.rs index fec8b86..ea54e11 100644 --- a/libs/hbb_common/src/fs.rs +++ b/libs/hbb_common/src/fs.rs @@ -13,13 +13,13 @@ use crate::{ config::{Config, COMPRESS_LEVEL}, }; -pub fn read_dir(path: &PathBuf, include_hidden: bool) -> ResultType<FileDirectory> { +pub fn read_dir(path: &Path, include_hidden: bool) -> ResultType<FileDirectory> { let mut dir = FileDirectory { - path: get_string(&path), + path: get_string(path), ..Default::default() }; #[cfg(windows)] - if "/" == &get_string(&path) { + if "/" == &get_string(path) { let drives = unsafe { winapi::um::fileapi::GetLogicalDrives() }; for i in 0..32 { if drives & (1 << i) != 0 { @@ -36,74 +36,70 @@ pub fn read_dir(path: &PathBuf, include_hidden: bool) -> ResultType<FileDirector } return Ok(dir); } - for entry in path.read_dir()? { - if let Ok(entry) = entry { - let p = entry.path(); - let name = p - .file_name() - .map(|p| p.to_str().unwrap_or("")) - .unwrap_or("") - .to_owned(); - if name.is_empty() { - continue; - } - let mut is_hidden = false; - let meta; - if let Ok(tmp) = std::fs::symlink_metadata(&p) { - meta = tmp; - } else { - continue; - } - // docs.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants - #[cfg(windows)] - if meta.file_attributes() & 0x2 != 0 { - is_hidden = true; - } - #[cfg(not(windows))] - if name.find('.').unwrap_or(usize::MAX) == 0 { - is_hidden = true; - } - if is_hidden && !include_hidden { - continue; - } - let (entry_type, size) = { - if p.is_dir() { - if meta.file_type().is_symlink() { - (FileType::DirLink.into(), 0) - } else { - (FileType::Dir.into(), 0) - } + for entry in path.read_dir()?.flatten() { + let p = entry.path(); + let name = p + .file_name() + .map(|p| p.to_str().unwrap_or("")) + .unwrap_or("") + .to_owned(); + if name.is_empty() { + continue; + } + let mut is_hidden = false; + let meta; + if let Ok(tmp) = std::fs::symlink_metadata(&p) { + meta = tmp; + } else { + continue; + } + // docs.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants + #[cfg(windows)] + if meta.file_attributes() & 0x2 != 0 { + is_hidden = true; + } + #[cfg(not(windows))] + if name.find('.').unwrap_or(usize::MAX) == 0 { + is_hidden = true; + } + if is_hidden && !include_hidden { + continue; + } + let (entry_type, size) = { + if p.is_dir() { + if meta.file_type().is_symlink() { + (FileType::DirLink.into(), 0) } else { - if meta.file_type().is_symlink() { - (FileType::FileLink.into(), 0) - } else { - (FileType::File.into(), meta.len()) - } + (FileType::Dir.into(), 0) } - }; - let modified_time = meta - .modified() - .map(|x| { - x.duration_since(std::time::SystemTime::UNIX_EPOCH) - .map(|x| x.as_secs()) - .unwrap_or(0) - }) - .unwrap_or(0) as u64; - dir.entries.push(FileEntry { - name: get_file_name(&p), - entry_type, - is_hidden, - size, - modified_time, - ..Default::default() - }); - } + } else if meta.file_type().is_symlink() { + (FileType::FileLink.into(), 0) + } else { + (FileType::File.into(), meta.len()) + } + }; + let modified_time = meta + .modified() + .map(|x| { + x.duration_since(std::time::SystemTime::UNIX_EPOCH) + .map(|x| x.as_secs()) + .unwrap_or(0) + }) + .unwrap_or(0); + dir.entries.push(FileEntry { + name: get_file_name(&p), + entry_type, + is_hidden, + size, + modified_time, + ..Default::default() + }); } Ok(dir) } #[inline] -pub fn get_file_name(p: &PathBuf) -> String { +pub fn get_file_name(p: &Path) -> String { p.file_name() .map(|p| p.to_str().unwrap_or("")) .unwrap_or("") @@ -111,7 +107,7 @@ pub fn get_file_name(p: &PathBuf) -> String { } #[inline] -pub fn get_string(path: &PathBuf) -> String { +pub fn get_string(path: &Path) -> String { path.to_str().unwrap_or("").to_owned() } @@ -127,14 +123,14 @@ pub fn get_home_as_string() -> String { fn read_dir_recursive( path: &PathBuf, - prefix: &PathBuf, + prefix: &Path, include_hidden: bool, ) -> ResultType<Vec<FileEntry>> { let mut files = Vec::new(); if path.is_dir() { // to-do: symbol link handling, cp the link rather than the content // to-do: file mode, for unix - let fd = read_dir(&path, include_hidden)?; + let fd = read_dir(path, include_hidden)?; for entry in fd.entries.iter() { match entry.entry_type.enum_value() { Ok(FileType::File) => { @@ -158,7 +154,7 @@ fn read_dir_recursive( } Ok(files) } else if path.is_file() { - let (size, modified_time) = if let Ok(meta) = std::fs::metadata(&path) { + let (size, modified_time) = if let Ok(meta) = std::fs::metadata(path) { ( meta.len(), meta.modified() @@ -167,7 +163,7 @@ fn read_dir_recursive( .map(|x| x.as_secs()) .unwrap_or(0) }) - .unwrap_or(0) as u64, + .unwrap_or(0), ) } else { (0, 0) @@ -249,7 +245,7 @@ pub struct RemoveJobMeta { #[inline] fn get_ext(name: &str) -> &str { - if let Some(i) = name.rfind(".") { + if let Some(i) = name.rfind('.') { return &name[i + 1..]; } "" @@ -270,6 +266,7 @@ fn is_compressed_file(name: &str) -> bool { } impl TransferJob { + #[allow(clippy::too_many_arguments)] pub fn new_write( id: i32, remote: String, @@ -281,7 +278,7 @@ impl TransferJob { enable_overwrite_detection: bool, ) -> Self { log::info!("new write {}", path); - let total_size = files.iter().map(|x| x.size as u64).sum(); + let total_size = files.iter().map(|x| x.size).sum(); Self { id, remote, @@ -307,7 +304,7 @@ impl TransferJob { ) -> ResultType<Self> { log::info!("new read {}", path); let files = get_recursive_files(&path, show_hidden)?; - let total_size = files.iter().map(|x| x.size as u64).sum(); + let total_size = files.iter().map(|x| x.size).sum(); Ok(Self { id, remote, @@ -363,7 +360,7 @@ impl TransferJob { let entry = &self.files[file_num]; let path = self.join(&entry.name); let download_path = format!("{}.download", get_string(&path)); - std::fs::rename(&download_path, &path).ok(); + std::fs::rename(download_path, &path).ok(); filetime::set_file_mtime( &path, filetime::FileTime::from_unix_time(entry.modified_time as _, 0), @@ -378,7 +375,7 @@ impl TransferJob { let entry = &self.files[file_num]; let path = self.join(&entry.name); let download_path = format!("{}.download", get_string(&path)); - std::fs::remove_file(&download_path).ok(); + std::fs::remove_file(download_path).ok(); } } @@ -433,7 +430,7 @@ impl TransferJob { } let name = &self.files[file_num].name; if self.file.is_none() { - match File::open(self.join(&name)).await { + match File::open(self.join(name)).await { Ok(file) => { self.file = Some(file); self.file_confirmed = false; @@ -447,20 +444,15 @@ impl TransferJob { } } } - if self.enable_overwrite_detection { - if !self.file_confirmed() { - if !self.file_is_waiting() { - self.send_current_digest(stream).await?; - self.set_file_is_waiting(true); - } - return Ok(None); + if self.enable_overwrite_detection && !self.file_confirmed() { + if !self.file_is_waiting() { + self.send_current_digest(stream).await?; + self.set_file_is_waiting(true); } + return Ok(None); } const BUF_SIZE: usize = 128 * 1024; - let mut buf: Vec<u8> = Vec::with_capacity(BUF_SIZE); - unsafe { - buf.set_len(BUF_SIZE); - } + let mut buf: Vec<u8> = vec![0; BUF_SIZE]; let mut compressed = false; let mut offset: usize = 0; loop { @@ -582,10 +574,7 @@ impl TransferJob { #[inline] pub fn job_completed(&self) -> bool { // has no error, Condition 2 - if !self.enable_overwrite_detection || (!self.file_confirmed && !self.file_is_waiting) { - return true; - } - return false; + !self.enable_overwrite_detection || (!self.file_confirmed && !self.file_is_waiting) } /// Get job error message, useful for getting status when job had finished @@ -660,7 +649,7 @@ pub fn new_dir(id: i32, path: String, files: Vec<FileEntry>) -> Message { resp.set_dir(FileDirectory { id, path, - entries: files.into(), + entries: files, ..Default::default() }); let mut msg_out = Message::new(); @@ -692,7 +681,7 @@ pub fn new_receive(id: i32, path: String, file_num: i32, files: Vec<FileEntry>) action.set_receive(FileTransferReceiveRequest { id, path, - files: files.into(), + files, file_num, ..Default::default() }); @@ -736,8 +725,8 @@ pub fn remove_job(id: i32, jobs: &mut Vec<TransferJob>) { } #[inline] -pub fn get_job(id: i32, jobs: &mut Vec<TransferJob>) -> Option<&mut TransferJob> { - jobs.iter_mut().filter(|x| x.id() == id).next() +pub fn get_job(id: i32, jobs: &mut [TransferJob]) -> Option<&mut TransferJob> { + jobs.iter_mut().find(|x| x.id() == id) } pub async fn handle_read_jobs( @@ -789,7 +778,7 @@ pub fn remove_all_empty_dir(path: &PathBuf) -> ResultType<()> { remove_all_empty_dir(&path.join(&entry.name)).ok(); } Ok(FileType::DirLink) | Ok(FileType::FileLink) => { - std::fs::remove_file(&path.join(&entry.name)).ok(); + std::fs::remove_file(path.join(&entry.name)).ok(); } _ => {} } @@ -813,7 +802,7 @@ pub fn create_dir(dir: &str) -> ResultType<()> { #[inline] pub fn transform_windows_path(entries: &mut Vec<FileEntry>) { for entry in entries { - entry.name = entry.name.replace("\\", "/"); + entry.name = entry.name.replace('\\', "/"); } } diff --git a/libs/hbb_common/src/lib.rs b/libs/hbb_common/src/lib.rs index e57994f..9e00437 100644 --- a/libs/hbb_common/src/lib.rs +++ b/libs/hbb_common/src/lib.rs @@ -96,8 +96,24 @@ pub type ResultType<F, E = anyhow::Error> = anyhow::Result<F, E>; pub struct AddrMangle(); +#[inline] +pub fn try_into_v4(addr: SocketAddr) -> SocketAddr { + match addr { + SocketAddr::V6(v6) if !addr.ip().is_loopback() => { + if let Some(v4) = v6.ip().to_ipv4() { + SocketAddr::new(IpAddr::V4(v4), addr.port()) + } else { + addr + } + } + _ => addr, + } +} + impl AddrMangle { pub fn encode(addr: SocketAddr) -> Vec<u8> { + // not work with [:1]:<port> + let addr = try_into_v4(addr); match addr { SocketAddr::V4(addr_v4) => { let tm = (SystemTime::now() @@ -129,22 +145,20 @@ impl AddrMangle { } pub fn decode(bytes: &[u8]) -> SocketAddr { + use std::convert::TryInto; + if bytes.len() > 16 { if bytes.len() != 18 { return Config::get_any_listen_addr(false); } - #[allow(invalid_value)] - let mut tmp: [u8; 2] = unsafe { std::mem::MaybeUninit::uninit().assume_init() }; - tmp.copy_from_slice(&bytes[16..]); + let tmp: [u8; 2] = bytes[16..].try_into().unwrap(); let port = u16::from_le_bytes(tmp); - #[allow(invalid_value)] - let mut tmp: [u8; 16] = unsafe { std::mem::MaybeUninit::uninit().assume_init() }; - tmp.copy_from_slice(&bytes[..16]); + let tmp: [u8; 16] = bytes[..16].try_into().unwrap(); let ip = std::net::Ipv6Addr::from(tmp); return SocketAddr::new(IpAddr::V6(ip), port); } let mut padded = [0u8; 16]; - padded[..bytes.len()].copy_from_slice(&bytes); + padded[..bytes.len()].copy_from_slice(bytes); let number = u128::from_le_bytes(padded); let tm = (number >> 17) & (u32::max_value() as u128); let ip = (((number >> 49) - tm) as u32).to_le_bytes(); @@ -158,21 +172,9 @@ impl AddrMangle { pub fn get_version_from_url(url: &str) -> String { let n = url.chars().count(); - let a = url - .chars() - .rev() - .enumerate() - .filter(|(_, x)| x == &'-') - .next() - .map(|(i, _)| i); + let a = url.chars().rev().position(|x| x == '-'); if let Some(a) = a { - let b = url - .chars() - .rev() - .enumerate() - .filter(|(_, x)| x == &'.') - .next() - .map(|(i, _)| i); + let b = url.chars().rev().position(|x| x == '.'); if let Some(b) = b { if a > b { if url @@ -195,22 +197,30 @@ pub fn get_version_from_url(url: &str) -> String { } pub fn gen_version() { + if Ok("release".to_owned()) != std::env::var("PROFILE") { + return; + } + println!("cargo:rerun-if-changed=Cargo.toml"); use std::io::prelude::*; let mut file = File::create("./src/version.rs").unwrap(); - for line in read_lines("Cargo.toml").unwrap() { - if let Ok(line) = line { - let ab: Vec<&str> = line.split("=").map(|x| x.trim()).collect(); - if ab.len() == 2 && ab[0] == "version" { - file.write_all(format!("pub const VERSION: &str = {};\n", ab[1]).as_bytes()) - .ok(); - break; - } + for line in read_lines("Cargo.toml").unwrap().flatten() { + let ab: Vec<&str> = line.split('=').map(|x| x.trim()).collect(); + if ab.len() == 2 && ab[0] == "version" { + file.write_all(format!("pub const VERSION: &str = {};\n", ab[1]).as_bytes()) + .ok(); + break; } } // generate build date let build_date = format!("{}", chrono::Local::now().format("%Y-%m-%d %H:%M")); - file.write_all(format!("pub const BUILD_DATE: &str = \"{}\";", build_date).as_bytes()) - .ok(); + file.write_all( + format!( + "#[allow(dead_code)]\npub const BUILD_DATE: &str = \"{}\";", + build_date + ) + .as_bytes(), + ) + .ok(); file.sync_all().ok(); } @@ -230,20 +240,20 @@ pub fn is_valid_custom_id(id: &str) -> bool { pub fn get_version_number(v: &str) -> i64 { let mut n = 0; - for x in v.split(".") { + for x in v.split('.') { n = n * 1000 + x.parse::<i64>().unwrap_or(0); } n } pub fn get_modified_time(path: &std::path::Path) -> SystemTime { - std::fs::metadata(&path) + std::fs::metadata(path) .map(|m| m.modified().unwrap_or(UNIX_EPOCH)) .unwrap_or(UNIX_EPOCH) } pub fn get_created_time(path: &std::path::Path) -> SystemTime { - std::fs::metadata(&path) + std::fs::metadata(path) .map(|m| m.created().unwrap_or(UNIX_EPOCH)) .unwrap_or(UNIX_EPOCH) } @@ -276,32 +286,6 @@ pub fn get_time() -> i64 { .unwrap_or(0) as _ } -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn test_mangle() { - let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 16, 32), 21116)); - assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); - - let addr = "[2001:db8::1]:8080".parse::<SocketAddr>().unwrap(); - assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); - - let addr = "[2001:db8:ff::1111]:80".parse::<SocketAddr>().unwrap(); - assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); - } - - #[test] - fn test_allow_err() { - allow_err!(Err("test err") as Result<(), &str>); - allow_err!( - Err("test err with msg") as Result<(), &str>, - "prompt {}", - "failed" - ); - } -} - #[inline] pub fn is_ipv4_str(id: &str) -> bool { regex::Regex::new(r"^\d+\.\d+\.\d+\.\d+(:\d+)?$") @@ -334,10 +318,32 @@ pub fn is_domain_port_str(id: &str) -> bool { } #[cfg(test)] -mod test_lib { +mod test { use super::*; #[test] + fn test_mangle() { + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 16, 32), 21116)); + assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + + let addr = "[2001:db8::1]:8080".parse::<SocketAddr>().unwrap(); + assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + + let addr = "[2001:db8:ff::1111]:80".parse::<SocketAddr>().unwrap(); + assert_eq!(addr, AddrMangle::decode(&AddrMangle::encode(addr))); + } + + #[test] + fn test_allow_err() { + allow_err!(Err("test err") as Result<(), &str>); + allow_err!( + Err("test err with msg") as Result<(), &str>, + "prompt {}", + "failed" + ); + } + + #[test] fn test_ipv6() { assert_eq!(is_ipv6_str("1:2:3"), true); assert_eq!(is_ipv6_str("[ab:2:3]:12"), true); @@ -373,4 +379,20 @@ mod test_lib { assert_eq!(is_domain_port_str("test.com:0"), true); assert_eq!(is_domain_port_str("test.com:98989"), true); } + + #[test] + fn test_mangle2() { + let addr = "[::ffff:127.0.0.1]:8080".parse().unwrap(); + let addr_v4 = "127.0.0.1:8080".parse().unwrap(); + assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr)), addr_v4); + assert_eq!( + AddrMangle::decode(&AddrMangle::encode("[::127.0.0.1]:8080".parse().unwrap())), + addr_v4 + ); + assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr_v4)), addr_v4); + let addr_v6 = "[ef::fe]:8080".parse().unwrap(); + assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr_v6)), addr_v6); + let addr_v6 = "[::1]:8080".parse().unwrap(); + assert_eq!(AddrMangle::decode(&AddrMangle::encode(addr_v6)), addr_v6); + } } diff --git a/libs/hbb_common/src/password_security.rs b/libs/hbb_common/src/password_security.rs index 6029069..0b66107 100644 --- a/libs/hbb_common/src/password_security.rs +++ b/libs/hbb_common/src/password_security.rs @@ -104,7 +104,7 @@ pub fn decrypt_str_or_original(s: &str, current_version: &str) -> (String, bool, if s.len() > VERSION_LEN { let version = &s[..VERSION_LEN]; if version == "00" { - if let Ok(v) = decrypt(&s[VERSION_LEN..].as_bytes()) { + if let Ok(v) = decrypt(s[VERSION_LEN..].as_bytes()) { return ( String::from_utf8_lossy(&v).to_string(), true, @@ -149,7 +149,7 @@ pub fn decrypt_vec_or_original(v: &[u8], current_version: &str) -> (Vec<u8>, boo } fn encrypt(v: &[u8]) -> Result<String, ()> { - if v.len() > 0 { + if !v.is_empty() { symmetric_crypt(v, true).map(|v| base64::encode(v, base64::Variant::Original)) } else { Err(()) @@ -157,7 +157,7 @@ fn encrypt(v: &[u8]) -> Result<String, ()> { } fn decrypt(v: &[u8]) -> Result<Vec<u8>, ()> { - if v.len() > 0 { + if !v.is_empty() { base64::decode(v, base64::Variant::Original).and_then(|v| symmetric_crypt(&v, false)) } else { Err(()) diff --git a/libs/hbb_common/src/platform/linux.rs b/libs/hbb_common/src/platform/linux.rs index e824163..716025d 100644 --- a/libs/hbb_common/src/platform/linux.rs +++ b/libs/hbb_common/src/platform/linux.rs @@ -32,7 +32,7 @@ pub fn get_display_server() -> String { // loginctl has not given the expected output. try something else. if let Ok(sid) = std::env::var("XDG_SESSION_ID") { // could also execute "cat /proc/self/sessionid" - session = sid.to_owned(); + session = sid; } if session.is_empty() { session = run_cmds("cat /proc/self/sessionid".to_owned()).unwrap_or_default(); @@ -63,7 +63,7 @@ fn get_display_server_of_session(session: &str) -> String { if let Ok(xorg_results) = run_cmds(format!("ps -e | grep \"{}.\\\\+Xorg\"", tty)) // And check if Xorg is running on that tty { - if xorg_results.trim_end().to_string() != "" { + if xorg_results.trim_end() != "" { // If it is, manually return "x11", otherwise return tty return "x11".to_owned(); } @@ -88,7 +88,7 @@ pub fn get_values_of_seat0(indices: Vec<usize>) -> Vec<String> { if let Ok(output) = run_loginctl(None) { for line in String::from_utf8_lossy(&output.stdout).lines() { if line.contains("seat0") { - if let Some(sid) = line.split_whitespace().nth(0) { + if let Some(sid) = line.split_whitespace().next() { if is_active(sid) { return indices .into_iter() @@ -103,7 +103,7 @@ pub fn get_values_of_seat0(indices: Vec<usize>) -> Vec<String> { // some case, there is no seat0 https://github.com/rustdesk/rustdesk/issues/73 if let Ok(output) = run_loginctl(None) { for line in String::from_utf8_lossy(&output.stdout).lines() { - if let Some(sid) = line.split_whitespace().nth(0) { + if let Some(sid) = line.split_whitespace().next() { let d = get_display_server_of_session(sid); if is_active(sid) && d != "tty" { return indices diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs index 6f62163..a034b4e 100644 --- a/libs/hbb_common/src/socket_client.rs +++ b/libs/hbb_common/src/socket_client.rs @@ -71,7 +71,7 @@ pub trait IsResolvedSocketAddr { impl IsResolvedSocketAddr for SocketAddr { fn resolve(&self) -> Option<&SocketAddr> { - Some(&self) + Some(self) } } @@ -120,12 +120,12 @@ pub async fn connect_tcp_local< if let Some(target) = target.resolve() { if let Some(local) = local { if local.is_ipv6() && target.is_ipv4() { - let target = query_nip_io(&target).await?; - return Ok(FramedStream::new(target, Some(local), ms_timeout).await?); + let target = query_nip_io(target).await?; + return FramedStream::new(target, Some(local), ms_timeout).await; } } } - Ok(FramedStream::new(target, local, ms_timeout).await?) + FramedStream::new(target, local, ms_timeout).await } #[inline] @@ -140,15 +140,14 @@ pub fn is_ipv4(target: &TargetAddr<'_>) -> bool { pub async fn query_nip_io(addr: &SocketAddr) -> ResultType<SocketAddr> { tokio::net::lookup_host(format!("{}.nip.io:{}", addr.ip(), addr.port())) .await? - .filter(|x| x.is_ipv6()) - .next() + .find(|x| x.is_ipv6()) .context("Failed to get ipv6 from nip.io") } #[inline] pub fn ipv4_to_ipv6(addr: String, ipv4: bool) -> String { if !ipv4 && crate::is_ipv4_str(&addr) { - if let Some(ip) = addr.split(":").next() { + if let Some(ip) = addr.split(':').next() { return addr.replace(ip, &format!("{}.nip.io", ip)); } } diff --git a/libs/hbb_common/src/tcp.rs b/libs/hbb_common/src/tcp.rs index a1322fc..a7ac4eb 100644 --- a/libs/hbb_common/src/tcp.rs +++ b/libs/hbb_common/src/tcp.rs @@ -1,4 +1,5 @@ use crate::{bail, bytes_codec::BytesCodec, ResultType}; +use anyhow::Context as AnyhowCtx; use bytes::{BufMut, Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use protobuf::Message; @@ -209,7 +210,7 @@ impl FramedStream { if let Some(Ok(bytes)) = res.as_mut() { key.2 += 1; let nonce = Self::get_nonce(key.2); - match secretbox::open(&bytes, &nonce, &key.0) { + match secretbox::open(bytes, &nonce, &key.0) { Ok(res) => { bytes.clear(); bytes.put_slice(&res); @@ -245,16 +246,17 @@ impl FramedStream { const DEFAULT_BACKLOG: u32 = 128; -#[allow(clippy::never_loop)] pub async fn new_listener<T: ToSocketAddrs>(addr: T, reuse: bool) -> ResultType<TcpListener> { if !reuse { Ok(TcpListener::bind(addr).await?) } else { - for addr in lookup_host(&addr).await? { - let socket = new_socket(addr, true)?; - return Ok(socket.listen(DEFAULT_BACKLOG)?); - } - bail!("could not resolve to any address"); + let addr = lookup_host(&addr) + .await? + .next() + .context("could not resolve to any address")?; + new_socket(addr, true)? + .listen(DEFAULT_BACKLOG) + .map_err(anyhow::Error::msg) } } diff --git a/libs/hbb_common/src/udp.rs b/libs/hbb_common/src/udp.rs index 38121a4..bb0d071 100644 --- a/libs/hbb_common/src/udp.rs +++ b/libs/hbb_common/src/udp.rs @@ -1,11 +1,11 @@ -use crate::{bail, ResultType}; -use anyhow::anyhow; +use crate::ResultType; +use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use protobuf::Message; use socket2::{Domain, Socket, Type}; use std::net::SocketAddr; -use tokio::net::{ToSocketAddrs, UdpSocket}; +use tokio::net::{lookup_host, ToSocketAddrs, UdpSocket}; use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr, TargetAddr, ToProxyAddrs}; use tokio_util::{codec::BytesCodec, udp::UdpFramed}; @@ -37,39 +37,31 @@ fn new_socket(addr: SocketAddr, reuse: bool, buf_size: usize) -> Result<Socket, addr, socket.recv_buffer_size() ); + if addr.is_ipv6() && addr.ip().is_unspecified() && addr.port() > 0 { + socket.set_only_v6(false).ok(); + } socket.bind(&addr.into())?; Ok(socket) } impl FramedSocket { pub async fn new<T: ToSocketAddrs>(addr: T) -> ResultType<Self> { - let socket = UdpSocket::bind(addr).await?; - Ok(Self::Direct(UdpFramed::new(socket, BytesCodec::new()))) - } - - #[allow(clippy::never_loop)] - pub async fn new_reuse<T: std::net::ToSocketAddrs>(addr: T) -> ResultType<Self> { - for addr in addr.to_socket_addrs()? { - let socket = new_socket(addr, true, 0)?.into_udp_socket(); - return Ok(Self::Direct(UdpFramed::new( - UdpSocket::from_std(socket)?, - BytesCodec::new(), - ))); - } - bail!("could not resolve to any address"); + Self::new_reuse(addr, false, 0).await } - pub async fn new_with_buf_size<T: std::net::ToSocketAddrs>( + pub async fn new_reuse<T: ToSocketAddrs>( addr: T, + reuse: bool, buf_size: usize, ) -> ResultType<Self> { - for addr in addr.to_socket_addrs()? { - return Ok(Self::Direct(UdpFramed::new( - UdpSocket::from_std(new_socket(addr, false, buf_size)?.into_udp_socket())?, - BytesCodec::new(), - ))); - } - bail!("could not resolve to any address"); + let addr = lookup_host(&addr) + .await? + .next() + .context("could not resolve to any address")?; + Ok(Self::Direct(UdpFramed::new( + UdpSocket::from_std(new_socket(addr, reuse, buf_size)?.into_udp_socket())?, + BytesCodec::new(), + ))) } pub async fn new_proxy<'a, 't, P: ToProxyAddrs, T: ToSocketAddrs>( @@ -104,11 +96,12 @@ impl FramedSocket { ) -> ResultType<()> { let addr = addr.into_target_addr()?.to_owned(); let send_data = Bytes::from(msg.write_to_bytes()?); - let _ = match self { - Self::Direct(f) => match addr { - TargetAddr::Ip(addr) => f.send((send_data, addr)).await?, - _ => {} - }, + match self { + Self::Direct(f) => { + if let TargetAddr::Ip(addr) = addr { + f.send((send_data, addr)).await? + } + } Self::ProxySocks(f) => f.send((send_data, addr)).await?, }; Ok(()) @@ -123,11 +116,12 @@ impl FramedSocket { ) -> ResultType<()> { let addr = addr.into_target_addr()?.to_owned(); - let _ = match self { - Self::Direct(f) => match addr { - TargetAddr::Ip(addr) => f.send((Bytes::from(msg), addr)).await?, - _ => {} - }, + match self { + Self::Direct(f) => { + if let TargetAddr::Ip(addr) = addr { + f.send((Bytes::from(msg), addr)).await? + } + } Self::ProxySocks(f) => f.send((Bytes::from(msg), addr)).await?, }; Ok(()) @@ -165,12 +159,12 @@ impl FramedSocket { } } - pub fn is_ipv4(&self) -> bool { + pub fn local_addr(&self) -> Option<SocketAddr> { if let FramedSocket::Direct(x) = self { if let Ok(v) = x.get_ref().local_addr() { - return v.is_ipv4(); + return Some(v); } } - true + None } } diff --git a/src/common.rs b/src/common.rs index a57485b..9ea4a1e 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,5 +1,8 @@ use clap::App; -use hbb_common::{anyhow::Context, log, ResultType}; +use hbb_common::{ + anyhow::{Context, Result}, + log, ResultType, +}; use ini::Ini; use sodiumoxide::crypto::sign; use std::{ @@ -9,15 +12,17 @@ use std::{ time::{Instant, SystemTime}, }; +#[allow(dead_code)] pub(crate) fn get_expired_time() -> Instant { let now = Instant::now(); now.checked_sub(std::time::Duration::from_secs(3600)) .unwrap_or(now) } +#[allow(dead_code)] pub(crate) fn test_if_valid_server(host: &str, name: &str) -> ResultType<SocketAddr> { use std::net::ToSocketAddrs; - let res = if host.contains(":") { + let res = if host.contains(':') { host.to_socket_addrs()?.next().context("") } else { format!("{}:{}", host, 0) @@ -31,9 +36,10 @@ pub(crate) fn test_if_valid_server(host: &str, name: &str) -> ResultType<SocketA res } +#[allow(dead_code)] pub(crate) fn get_servers(s: &str, tag: &str) -> Vec<String> { let servers: Vec<String> = s - .split(",") + .split(',') .filter(|x| !x.is_empty() && test_if_valid_server(x, tag).is_ok()) .map(|x| x.to_owned()) .collect(); @@ -41,17 +47,19 @@ pub(crate) fn get_servers(s: &str, tag: &str) -> Vec<String> { servers } +#[allow(dead_code)] #[inline] fn arg_name(name: &str) -> String { - name.to_uppercase().replace("_", "-") + name.to_uppercase().replace('_', "-") } +#[allow(dead_code)] pub fn init_args(args: &str, name: &str, about: &str) { let matches = App::new(name) .version(crate::version::VERSION) .author("Purslane Ltd. <[email protected]>") .about(about) - .args_from_usage(&args) + .args_from_usage(args) .get_matches(); if let Ok(v) = Ini::load_from_file(".env") { if let Some(section) = v.section(None::<String>) { @@ -76,16 +84,19 @@ pub fn init_args(args: &str, name: &str, about: &str) { } } +#[allow(dead_code)] #[inline] pub fn get_arg(name: &str) -> String { get_arg_or(name, "".to_owned()) } +#[allow(dead_code)] #[inline] pub fn get_arg_or(name: &str, default: String) -> String { std::env::var(arg_name(name)).unwrap_or(default) } +#[allow(dead_code)] #[inline] pub fn now() -> u64 { SystemTime::now() @@ -118,7 +129,7 @@ pub fn gen_sk(wait: u64) -> (String, Option<sign::SecretKey>) { }; let (mut pk, mut sk) = gen_func(); for _ in 0..300 { - if !pk.contains("/") && !pk.contains(":") { + if !pk.contains('/') && !pk.contains(':') { break; } (pk, sk) = gen_func(); @@ -138,3 +149,43 @@ pub fn gen_sk(wait: u64) -> (String, Option<sign::SecretKey>) { } ("".to_owned(), None) } + +#[cfg(unix)] +pub async fn listen_signal() -> Result<()> { + use hbb_common::tokio; + use hbb_common::tokio::signal::unix::{signal, SignalKind}; + + tokio::spawn(async { + let mut s = signal(SignalKind::hangup())?; + let hangup = s.recv(); + let mut s = signal(SignalKind::terminate())?; + let terminate = s.recv(); + let mut s = signal(SignalKind::interrupt())?; + let interrupt = s.recv(); + let mut s = signal(SignalKind::quit())?; + let quit = s.recv(); + + tokio::select! { + _ = hangup => { + log::info!("signal hangup"); + } + _ = terminate => { + log::info!("signal terminate"); + } + _ = interrupt => { + log::info!("signal interrupt"); + } + _ = quit => { + log::info!("signal quit"); + } + } + Ok(()) + }) + .await? +} + +#[cfg(not(unix))] +pub async fn listen_signal() -> Result<()> { + let () = std::future::pending().await; + unreachable!(); +} diff --git a/src/database.rs b/src/database.rs index 41ad5e3..fa1b6ed 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,6 +1,5 @@ use async_trait::async_trait; use hbb_common::{log, ResultType}; -use serde_json::value::Value; use sqlx::{ sqlite::SqliteConnectOptions, ConnectOptions, Connection, Error as SqlxError, SqliteConnection, }; @@ -8,7 +7,6 @@ use std::{ops::DerefMut, str::FromStr}; //use sqlx::postgres::PgPoolOptions; //use sqlx::mysql::MySqlPoolOptions; -pub(crate) type MapValue = serde_json::map::Map<String, Value>; type Pool = deadpool::managed::Pool<DbPool>; pub struct DbPool { @@ -54,7 +52,7 @@ impl Database { std::fs::File::create(url).ok(); } let n: usize = std::env::var("MAX_DATABASE_CONNECTIONS") - .unwrap_or("1".to_owned()) + .unwrap_or_else(|_| "1".to_owned()) .parse() .unwrap_or(1); log::debug!("MAX_DATABASE_CONNECTIONS={}", n); @@ -105,24 +103,6 @@ impl Database { .await?) } - #[inline] - pub async fn get_conn(&self) -> ResultType<deadpool::managed::Object<DbPool>> { - Ok(self.pool.get().await?) - } - - pub async fn update_peer(&self, payload: MapValue, guid: &[u8]) -> ResultType<()> { - let mut conn = self.get_conn().await?; - let mut tx = conn.begin().await?; - if let Some(v) = payload.get("note") { - let v = get_str(v); - sqlx::query!("update peer set note = ? where guid = ?", v, guid) - .execute(&mut tx) - .await?; - } - tx.commit().await?; - Ok(()) - } - pub async fn insert_peer( &self, id: &str, @@ -199,17 +179,3 @@ mod tests { hbb_common::futures::future::join_all(jobs).await; } } - -pub(crate) fn get_str(v: &Value) -> Option<&str> { - match v { - Value::String(v) => { - let v = v.trim(); - if v.is_empty() { - None - } else { - Some(v) - } - } - _ => None, - } -} diff --git a/src/peer.rs b/src/peer.rs index 49d440d..4c3ab72 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -1,19 +1,22 @@ use crate::common::*; use crate::database; use hbb_common::{ + bytes::Bytes, log, rendezvous_proto::*, tokio::sync::{Mutex, RwLock}, - bytes::Bytes, ResultType, }; use serde_derive::{Deserialize, Serialize}; use std::{collections::HashMap, collections::HashSet, net::SocketAddr, sync::Arc, time::Instant}; +type IpBlockMap = HashMap<String, ((u32, Instant), (HashSet<String>, Instant))>; +type UserStatusMap = HashMap<Vec<u8>, Arc<(Option<Vec<u8>>, bool)>>; +type IpChangesMap = HashMap<String, (Instant, HashMap<String, i32>)>; lazy_static::lazy_static! { - pub(crate) static ref IP_BLOCKER: Mutex<HashMap<String, ((u32, Instant), (HashSet<String>, Instant))>> = Default::default(); - pub(crate) static ref USER_STATUS: RwLock<HashMap<Vec<u8>, Arc<(Option<Vec<u8>>, bool)>>> = Default::default(); - pub(crate) static ref IP_CHANGES: Mutex<HashMap<String, (Instant, HashMap<String, i32>)>> = Default::default(); + pub(crate) static ref IP_BLOCKER: Mutex<IpBlockMap> = Default::default(); + pub(crate) static ref USER_STATUS: RwLock<UserStatusMap> = Default::default(); + pub(crate) static ref IP_CHANGES: Mutex<IpChangesMap> = Default::default(); } pub static IP_CHANGE_DUR: u64 = 180; pub static IP_CHANGE_DUR_X2: u64 = IP_CHANGE_DUR * 2; @@ -32,9 +35,9 @@ pub(crate) struct Peer { pub(crate) guid: Vec<u8>, pub(crate) uuid: Bytes, pub(crate) pk: Bytes, - pub(crate) user: Option<Vec<u8>>, + // pub(crate) user: Option<Vec<u8>>, pub(crate) info: PeerInfo, - pub(crate) disabled: bool, + // pub(crate) disabled: bool, pub(crate) reg_pk: (u32, Instant), // how often register_pk } @@ -47,8 +50,8 @@ impl Default for Peer { uuid: Bytes::new(), pk: Bytes::new(), info: Default::default(), - user: None, - disabled: false, + // user: None, + // disabled: false, reg_pk: (0, get_expired_time()), } } @@ -65,7 +68,6 @@ pub(crate) struct PeerMap { impl PeerMap { pub(crate) async fn new() -> ResultType<Self> { let db = std::env::var("DB_URL").unwrap_or({ - #[allow(unused_mut)] let mut db = "db_v2.sqlite3".to_owned(); #[cfg(all(windows, not(debug_assertions)))] { @@ -132,24 +134,22 @@ impl PeerMap { #[inline] pub(crate) async fn get(&self, id: &str) -> Option<LockPeer> { - let p = self.map.read().await.get(id).map(|x| x.clone()); + let p = self.map.read().await.get(id).cloned(); if p.is_some() { return p; - } else { - if let Ok(Some(v)) = self.db.get_peer(id).await { - let peer = Peer { - guid: v.guid, - uuid: v.uuid.into(), - pk: v.pk.into(), - user: v.user, - info: serde_json::from_str::<PeerInfo>(&v.info).unwrap_or_default(), - disabled: v.status == Some(0), - ..Default::default() - }; - let peer = Arc::new(RwLock::new(peer)); - self.map.write().await.insert(id.to_owned(), peer.clone()); - return Some(peer); - } + } else if let Ok(Some(v)) = self.db.get_peer(id).await { + let peer = Peer { + guid: v.guid, + uuid: v.uuid.into(), + pk: v.pk.into(), + // user: v.user, + info: serde_json::from_str::<PeerInfo>(&v.info).unwrap_or_default(), + // disabled: v.status == Some(0), + ..Default::default() + }; + let peer = Arc::new(RwLock::new(peer)); + self.map.write().await.insert(id.to_owned(), peer.clone()); + return Some(peer); } None } @@ -170,7 +170,7 @@ impl PeerMap { #[inline] pub(crate) async fn get_in_memory(&self, id: &str) -> Option<LockPeer> { - self.map.read().await.get(id).map(|x| x.clone()) + self.map.read().await.get(id).cloned() } #[inline] diff --git a/src/relay_server.rs b/src/relay_server.rs index 7b198d6..94a5cb3 100644 --- a/src/relay_server.rs +++ b/src/relay_server.rs @@ -8,7 +8,7 @@ use hbb_common::{ protobuf::Message as _, rendezvous_proto::*, sleep, - tcp::{new_listener, FramedStream}, + tcp::{listen_any, FramedStream}, timeout, tokio::{ self, @@ -37,12 +37,12 @@ lazy_static::lazy_static! { } static mut DOWNGRADE_THRESHOLD: f64 = 0.66; -static mut DOWNGRADE_START_CHECK: usize = 1800_000; // in ms +static mut DOWNGRADE_START_CHECK: usize = 1_800_000; // in ms static mut LIMIT_SPEED: usize = 4 * 1024 * 1024; // in bit/s static mut TOTAL_BANDWIDTH: usize = 1024 * 1024 * 1024; // in bit/s static mut SINGLE_BANDWIDTH: usize = 16 * 1024 * 1024; // in bit/s -const BLACKLIST_FILE: &'static str = "blacklist.txt"; -const BLOCKLIST_FILE: &'static str = "blocklist.txt"; +const BLACKLIST_FILE: &str = "blacklist.txt"; +const BLOCKLIST_FILE: &str = "blocklist.txt"; #[tokio::main(flavor = "multi_thread")] pub async fn start(port: &str, key: &str) -> ResultType<()> { @@ -50,8 +50,8 @@ pub async fn start(port: &str, key: &str) -> ResultType<()> { if let Ok(mut file) = std::fs::File::open(BLACKLIST_FILE) { let mut contents = String::new(); if file.read_to_string(&mut contents).is_ok() { - for x in contents.split("\n") { - if let Some(ip) = x.trim().split(' ').nth(0) { + for x in contents.split('\n') { + if let Some(ip) = x.trim().split(' ').next() { BLACKLIST.write().await.insert(ip.to_owned()); } } @@ -65,8 +65,8 @@ pub async fn start(port: &str, key: &str) -> ResultType<()> { if let Ok(mut file) = std::fs::File::open(BLOCKLIST_FILE) { let mut contents = String::new(); if file.read_to_string(&mut contents).is_ok() { - for x in contents.split("\n") { - if let Some(ip) = x.trim().split(' ').nth(0) { + for x in contents.split('\n') { + if let Some(ip) = x.trim().split(' ').next() { BLOCKLIST.write().await.insert(ip.to_owned()); } } @@ -77,19 +77,21 @@ pub async fn start(port: &str, key: &str) -> ResultType<()> { BLOCKLIST_FILE, BLOCKLIST.read().await.len() ); - let addr = format!("0.0.0.0:{}", port); - log::info!("Listening on tcp {}", addr); - let addr2 = format!("0.0.0.0:{}", port.parse::<u16>().unwrap() + 2); - log::info!("Listening on websocket {}", addr2); - loop { - log::info!("Start"); - io_loop( - new_listener(&addr, false).await?, - new_listener(&addr2, false).await?, - &key, - ) - .await; - } + let port: u16 = port.parse()?; + log::info!("Listening on tcp :{}", port); + let port2 = port + 2; + log::info!("Listening on websocket :{}", port2); + let main_task = async move { + loop { + log::info!("Start"); + io_loop(listen_any(port).await?, listen_any(port2).await?, &key).await; + } + }; + let listen_signal = crate::common::listen_signal(); + tokio::select!( + res = main_task => res, + res = listen_signal => res, + ) } fn check_params() { @@ -151,8 +153,10 @@ fn check_params() { } async fn check_cmd(cmd: &str, limiter: Limiter) -> String { + use std::fmt::Write; + let mut res = "".to_owned(); - let mut fds = cmd.trim().split(" "); + let mut fds = cmd.trim().split(' '); match fds.next() { Some("h") => { res = format!( @@ -173,7 +177,7 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String { } Some("blacklist-add" | "ba") => { if let Some(ip) = fds.next() { - for ip in ip.split("|") { + for ip in ip.split('|') { BLACKLIST.write().await.insert(ip.to_owned()); } } @@ -183,7 +187,7 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String { if ip == "all" { BLACKLIST.write().await.clear(); } else { - for ip in ip.split("|") { + for ip in ip.split('|') { BLACKLIST.write().await.remove(ip); } } @@ -194,13 +198,13 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String { res = format!("{}\n", BLACKLIST.read().await.get(ip).is_some()); } else { for ip in BLACKLIST.read().await.clone().into_iter() { - res += &format!("{}\n", ip); + let _ = writeln!(res, "{ip}"); } } } Some("blocklist-add" | "Ba") => { if let Some(ip) = fds.next() { - for ip in ip.split("|") { + for ip in ip.split('|') { BLOCKLIST.write().await.insert(ip.to_owned()); } } @@ -210,7 +214,7 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String { if ip == "all" { BLOCKLIST.write().await.clear(); } else { - for ip in ip.split("|") { + for ip in ip.split('|') { BLOCKLIST.write().await.remove(ip); } } @@ -221,7 +225,7 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String { res = format!("{}\n", BLOCKLIST.read().await.get(ip).is_some()); } else { for ip in BLOCKLIST.read().await.clone().into_iter() { - res += &format!("{}\n", ip); + let _ = writeln!(res, "{ip}"); } } } @@ -306,15 +310,16 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String { .read() .await .iter() - .map(|x| (x.0.clone(), x.1.clone())) + .map(|x| (x.0.clone(), *x.1)) .collect(); tmp.sort_by(|a, b| ((b.1).1).partial_cmp(&(a.1).1).unwrap()); for (ip, (elapsed, total, highest, speed)) in tmp { - if elapsed <= 0 { + if elapsed == 0 { continue; } - res += &format!( - "{}: {}s {:.2}MB {}kb/s {}kb/s {}kb/s\n", + let _ = writeln!( + res, + "{}: {}s {:.2}MB {}kb/s {}kb/s {}kb/s", ip, elapsed / 1000, total as f64 / 1024. / 1024. / 8., @@ -489,7 +494,7 @@ async fn relay( total_limiter.consume(nb).await; total += nb; total_s += nb; - if bytes.len() > 0 { + if !bytes.is_empty() { stream.send_raw(bytes.into()).await?; } } else { @@ -508,7 +513,7 @@ async fn relay( total_limiter.consume(nb).await; total += nb; total_s += nb; - if bytes.len() > 0 { + if !bytes.is_empty() { peer.send_raw(bytes.into()).await?; } } else { @@ -530,7 +535,7 @@ async fn relay( } blacked = BLACKLIST.read().await.get(&ip).is_some(); tm = std::time::Instant::now(); - let speed = total_s / (n as usize); + let speed = total_s / n; if speed > highest_s { highest_s = speed; } @@ -540,16 +545,17 @@ async fn relay( (elapsed as _, total as _, highest_s as _, speed as _), ); total_s = 0; - if elapsed > unsafe { DOWNGRADE_START_CHECK } && !downgrade { - if total > elapsed * downgrade_threshold { - downgrade = true; - log::info!( - "Downgrade {}, exceed downgrade threshold {}bit/ms in {}ms", - id, - downgrade_threshold, - elapsed - ); - } + if elapsed > unsafe { DOWNGRADE_START_CHECK } + && !downgrade + && total > elapsed * downgrade_threshold + { + downgrade = true; + log::info!( + "Downgrade {}, exceed downgrade threshold {}bit/ms in {}ms", + id, + downgrade_threshold, + elapsed + ); } } } diff --git a/src/rendezvous_server.rs b/src/rendezvous_server.rs index ed8796d..e7418db 100644 --- a/src/rendezvous_server.rs +++ b/src/rendezvous_server.rs @@ -4,6 +4,7 @@ use hbb_common::{ allow_err, bytes::{Bytes, BytesMut}, bytes_codec::BytesCodec, + config, futures::future::join_all, futures_util::{ sink::SinkExt, @@ -15,7 +16,7 @@ use hbb_common::{ register_pk_response::Result::{TOO_FREQUENT, UUID_MISMATCH}, *, }, - tcp::{new_listener, FramedStream}, + tcp::{listen_any, FramedStream}, timeout, tokio::{ self, @@ -25,6 +26,7 @@ use hbb_common::{ time::{interval, Duration}, }, tokio_util::codec::Framed, + try_into_v4, udp::FramedSocket, AddrMangle, ResultType, }; @@ -32,7 +34,7 @@ use ipnetwork::Ipv4Network; use sodiumoxide::crypto::sign; use std::{ collections::HashMap, - net::{IpAddr, Ipv4Addr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, sync::Arc, time::Instant, }; @@ -40,7 +42,7 @@ const ADDR_127: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); #[derive(Clone, Debug)] enum Data { - Msg(RendezvousMessage, SocketAddr), + Msg(Box<RendezvousMessage>, SocketAddr), RelayServers0(String), RelayServers(RelayServers), } @@ -92,15 +94,15 @@ impl RendezvousServer { pub async fn start(port: i32, serial: i32, key: &str, rmem: usize) -> ResultType<()> { let (key, sk) = Self::get_server_sk(key); let addr = format!("0.0.0.0:{}", port); - let addr2 = format!("0.0.0.0:{}", port - 1); - let addr3 = format!("0.0.0.0:{}", port + 2); + let nat_port = port - 1; + let ws_port = port + 2; let pm = PeerMap::new().await?; log::info!("serial={}", serial); let rendezvous_servers = get_servers(&get_arg("rendezvous-servers"), "rendezvous-servers"); - log::info!("Listening on tcp/udp {}", addr); - log::info!("Listening on tcp {}, extra port for NAT test", addr2); - log::info!("Listening on websocket {}", addr3); - let mut socket = FramedSocket::new_with_buf_size(&addr, rmem).await?; + log::info!("Listening on tcp/udp :{}", port); + log::info!("Listening on tcp :{}, extra port for NAT test", nat_port); + log::info!("Listening on websocket :{}", ws_port); + let mut socket = create_udp_listener(port, rmem).await?; let (tx, mut rx) = mpsc::unbounded_channel::<Data>(); let software_url = get_arg("software-url"); let version = hbb_common::get_version_from_url(&software_url); @@ -138,9 +140,9 @@ impl RendezvousServer { log::info!("local-ip: {:?}", rs.inner.local_ip); std::env::set_var("PORT_FOR_API", port.to_string()); rs.parse_relay_servers(&get_arg("relay-servers")); - let mut listener = new_listener(&addr, false).await?; - let mut listener2 = new_listener(&addr2, false).await?; - let mut listener3 = new_listener(&addr3, false).await?; + let mut listener = create_tcp_listener(port).await?; + let mut listener2 = create_tcp_listener(nat_port).await?; + let mut listener3 = create_tcp_listener(ws_port).await?; let test_addr = std::env::var("TEST_HBBS").unwrap_or_default(); if std::env::var("ALWAYS_USE_RELAY") .unwrap_or_default() @@ -170,37 +172,44 @@ impl RendezvousServer { allow_err!(test_hbbs(test_addr).await); }); }; - loop { - log::info!("Start"); - match rs - .io_loop( - &mut rx, - &mut listener, - &mut listener2, - &mut listener3, - &mut socket, - &key, - ) - .await - { - LoopFailure::UdpSocket => { - drop(socket); - socket = FramedSocket::new_with_buf_size(&addr, rmem).await?; - } - LoopFailure::Listener => { - drop(listener); - listener = new_listener(&addr, false).await?; - } - LoopFailure::Listener2 => { - drop(listener2); - listener2 = new_listener(&addr2, false).await?; - } - LoopFailure::Listener3 => { - drop(listener3); - listener3 = new_listener(&addr3, false).await?; + let main_task = async move { + loop { + log::info!("Start"); + match rs + .io_loop( + &mut rx, + &mut listener, + &mut listener2, + &mut listener3, + &mut socket, + &key, + ) + .await + { + LoopFailure::UdpSocket => { + drop(socket); + socket = create_udp_listener(port, rmem).await?; + } + LoopFailure::Listener => { + drop(listener); + listener = create_tcp_listener(port).await?; + } + LoopFailure::Listener2 => { + drop(listener2); + listener2 = create_tcp_listener(nat_port).await?; + } + LoopFailure::Listener3 => { + drop(listener3); + listener3 = create_tcp_listener(ws_port).await?; + } } } - } + }; + let listen_signal = listen_signal(); + tokio::select!( + res = main_task => res, + res = listen_signal => res, + ) } async fn io_loop( @@ -226,7 +235,7 @@ impl RendezvousServer { } Some(data) = rx.recv() => { match data { - Data::Msg(msg, addr) => { allow_err!(socket.send(&msg, addr).await); } + Data::Msg(msg, addr) => { allow_err!(socket.send(msg.as_ref(), addr).await); } Data::RelayServers0(rs) => { self.parse_relay_servers(&rs); } Data::RelayServers(rs) => { self.relay_servers = Arc::new(rs); } } @@ -296,11 +305,11 @@ impl RendezvousServer { socket: &mut FramedSocket, key: &str, ) -> ResultType<()> { - if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { + if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(bytes) { match msg_in.union { Some(rendezvous_message::Union::RegisterPeer(rp)) => { // B registered - if rp.id.len() > 0 { + if !rp.id.is_empty() { log::trace!("New peer registered: {:?} {:?}", &rp.id, &addr); self.update_addr(rp.id, addr, socket).await?; if self.inner.serial > rp.serial { @@ -377,12 +386,10 @@ impl RendezvousServer { *tm = Instant::now(); ips.clear(); ips.insert(ip.clone(), 1); + } else if let Some(v) = ips.get_mut(&ip) { + *v += 1; } else { - if let Some(v) = ips.get_mut(&ip) { - *v += 1; - } else { - ips.insert(ip.clone(), 1); - } + ips.insert(ip.clone(), 1); } } else { lock.insert( @@ -465,27 +472,27 @@ impl RendezvousServer { key: &str, ws: bool, ) -> bool { - if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { + if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(bytes) { match msg_in.union { Some(rendezvous_message::Union::PunchHoleRequest(ph)) => { // there maybe several attempt, so sink can be none if let Some(sink) = sink.take() { - self.tcp_punch.lock().await.insert(addr, sink); + self.tcp_punch.lock().await.insert(try_into_v4(addr), sink); } - allow_err!(self.handle_tcp_punch_hole_request(addr, ph, &key, ws).await); + allow_err!(self.handle_tcp_punch_hole_request(addr, ph, key, ws).await); return true; } Some(rendezvous_message::Union::RequestRelay(mut rf)) => { // there maybe several attempt, so sink can be none if let Some(sink) = sink.take() { - self.tcp_punch.lock().await.insert(addr, sink); + self.tcp_punch.lock().await.insert(try_into_v4(addr), sink); } if let Some(peer) = self.pm.get_in_memory(&rf.id).await { let mut msg_out = RendezvousMessage::new(); rf.socket_addr = AddrMangle::encode(addr).into(); msg_out.set_request_relay(rf); let peer_addr = peer.read().await.socket_addr; - self.tx.send(Data::Msg(msg_out, peer_addr)).ok(); + self.tx.send(Data::Msg(msg_out.into(), peer_addr)).ok(); } return true; } @@ -740,14 +747,14 @@ impl RendezvousServer { ..Default::default() }); } - return Ok((msg_out, Some(peer_addr))); + Ok((msg_out, Some(peer_addr))) } else { let mut msg_out = RendezvousMessage::new(); msg_out.set_punch_hole_response(PunchHoleResponse { failure: punch_hole_response::Failure::ID_NOT_EXIST.into(), ..Default::default() }); - return Ok((msg_out, None)); + Ok((msg_out, None)) } } @@ -758,8 +765,8 @@ impl RendezvousServer { peers: Vec<String>, ) -> ResultType<()> { let mut states = BytesMut::zeroed((peers.len() + 7) / 8); - for i in 0..peers.len() { - if let Some(peer) = self.pm.get_in_memory(&peers[i]).await { + for (i, peer_id) in peers.iter().enumerate() { + if let Some(peer) = self.pm.get_in_memory(peer_id).await { let elapsed = peer.read().await.last_reg_time.elapsed().as_millis() as i32; // bytes index from left to right let states_idx = i / 8; @@ -825,7 +832,7 @@ impl RendezvousServer { ) -> ResultType<()> { let (msg, to_addr) = self.handle_punch_hole_request(addr, ph, key, ws).await?; if let Some(addr) = to_addr { - self.tx.send(Data::Msg(msg, addr))?; + self.tx.send(Data::Msg(msg.into(), addr))?; } else { self.send_to_tcp_sync(msg, addr).await?; } @@ -841,7 +848,7 @@ impl RendezvousServer { ) -> ResultType<()> { let (msg, to_addr) = self.handle_punch_hole_request(addr, ph, key, false).await?; self.tx.send(Data::Msg( - msg, + msg.into(), match to_addr { Some(addr) => addr, None => addr, @@ -900,8 +907,10 @@ impl RendezvousServer { } async fn check_cmd(&self, cmd: &str) -> String { + use std::fmt::Write as _; + let mut res = "".to_owned(); - let mut fds = cmd.trim().split(" "); + let mut fds = cmd.trim().split(' '); match fds.next() { Some("h") => { res = format!( @@ -919,7 +928,7 @@ impl RendezvousServer { self.tx.send(Data::RelayServers0(rs.to_owned())).ok(); } else { for ip in self.relay_servers.iter() { - res += &format!("{}\n", ip); + let _ = writeln!(res, "{ip}"); } } } @@ -935,8 +944,9 @@ impl RendezvousServer { if start < 0 { if let Some(ip) = ip { if let Some((a, b)) = lock.get(ip) { - res += &format!( - "{}/{}s {}/{}s\n", + let _ = writeln!( + res, + "{}/{}s {}/{}s", a.0, a.1.elapsed().as_secs(), b.0.len(), @@ -961,8 +971,9 @@ impl RendezvousServer { continue; } if let Some((ip, (a, b))) = x { - res += &format!( - "{}: {}/{}s {}/{}s\n", + let _ = writeln!( + res, + "{}: {}/{}s {}/{}s", ip, a.0, a.1.elapsed().as_secs(), @@ -979,10 +990,10 @@ impl RendezvousServer { res = format!("{}\n", lock.len()); let id = fds.next(); let mut start = id.map(|x| x.parse::<i32>().unwrap_or(-1)).unwrap_or(-1); - if start < 0 || start > 10_000_000 { + if !(0..=10_000_000).contains(&start) { if let Some(id) = id { if let Some((tm, ips)) = lock.get(id) { - res += &format!("{}s {:?}\n", tm.elapsed().as_secs(), ips); + let _ = writeln!(res, "{}s {:?}", tm.elapsed().as_secs(), ips); } if fds.next() == Some("-") { lock.remove(id); @@ -1002,7 +1013,7 @@ impl RendezvousServer { continue; } if let Some((id, (tm, ips))) = x { - res += &format!("{}: {}s {:?}\n", id, tm.elapsed().as_secs(), ips,); + let _ = writeln!(res, "{}: {}s {:?}", id, tm.elapsed().as_secs(), ips,); } } } @@ -1016,7 +1027,7 @@ impl RendezvousServer { } self.tx.send(Data::RelayServers0(rs.to_owned())).ok(); } else { - res += &format!("ALWAYS_USE_RELAY: {:?}\n", unsafe { ALWAYS_USE_RELAY }); + let _ = writeln!(res, "ALWAYS_USE_RELAY: {:?}", unsafe { ALWAYS_USE_RELAY }); } } Some("test-geo" | "tg") => { @@ -1039,7 +1050,7 @@ impl RendezvousServer { async fn handle_listener2(&self, stream: TcpStream, addr: SocketAddr) { let mut rs = self.clone(); - if addr.ip().to_string() == "127.0.0.1" { + if addr.ip().is_loopback() { tokio::spawn(async move { let mut stream = stream; let mut buffer = [0; 64]; @@ -1099,13 +1110,10 @@ impl RendezvousServer { let (a, mut b) = ws_stream.split(); sink = Some(Sink::Ws(a)); while let Ok(Some(Ok(msg))) = timeout(30_000, b.next()).await { - match msg { - tungstenite::Message::Binary(bytes) => { - if !self.handle_tcp(&bytes, &mut sink, addr, key, ws).await { - break; - } + if let tungstenite::Message::Binary(bytes) = msg { + if !self.handle_tcp(&bytes, &mut sink, addr, key, ws).await { + break; } - _ => {} } } } else { @@ -1131,7 +1139,7 @@ impl RendezvousServer { } else { match self.pm.get(&id).await { Some(peer) => { - let pk = peer.read().await.pk.clone().into(); + let pk = peer.read().await.pk.clone(); sign::sign( &hbb_common::message_proto::IdPk { id, @@ -1140,7 +1148,7 @@ impl RendezvousServer { } .write_to_bytes() .unwrap_or_default(), - &self.inner.sk.as_ref().unwrap(), + self.inner.sk.as_ref().unwrap(), ) .into() } @@ -1196,8 +1204,8 @@ async fn check_relay_servers(rs0: Arc<RelayServers>, tx: Sender) { let rs = Arc::new(Mutex::new(Vec::new())); for x in rs0.iter() { let mut host = x.to_owned(); - if !host.contains(":") { - host = format!("{}:{}", host, hbb_common::config::RELAY_PORT); + if !host.contains(':') { + host = format!("{}:{}", host, config::RELAY_PORT); } let rs = rs.clone(); let x = x.clone(); @@ -1212,7 +1220,7 @@ async fn check_relay_servers(rs0: Arc<RelayServers>, tx: Sender) { } join_all(futs).await; log::debug!("check_relay_servers"); - let rs = std::mem::replace(&mut *rs.lock().await, Default::default()); + let rs = std::mem::take(&mut *rs.lock().await); if !rs.is_empty() { tx.send(Data::RelayServers(rs)).ok(); } @@ -1220,7 +1228,7 @@ async fn check_relay_servers(rs0: Arc<RelayServers>, tx: Sender) { // temp solution to solve udp socket failure async fn test_hbbs(addr: SocketAddr) -> ResultType<()> { - let mut socket = FramedSocket::new("0.0.0.0:0").await?; + let mut socket = FramedSocket::new(config::Config::get_any_listen_addr(addr.is_ipv4())).await?; let mut msg_out = RendezvousMessage::new(); msg_out.set_register_peer(RegisterPeer { id: "(:test_hbbs:)".to_owned(), @@ -1261,3 +1269,22 @@ async fn send_rk_res( }); socket.send(&msg_out, addr).await } + +async fn create_udp_listener(port: i32, rmem: usize) -> ResultType<FramedSocket> { + let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port as _); + if let Ok(s) = FramedSocket::new_reuse(&addr, false, rmem).await { + log::debug!("listen on udp {:?}", s.local_addr()); + return Ok(s); + } + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port as _); + let s = FramedSocket::new_reuse(&addr, false, rmem).await?; + log::debug!("listen on udp {:?}", s.local_addr()); + Ok(s) +} + +#[inline] +async fn create_tcp_listener(port: i32) -> ResultType<TcpListener> { + let s = listen_any(port as _).await?; + log::debug!("listen on tcp {:?}", s.local_addr()); + Ok(s) +} diff --git a/src/utils.rs b/src/utils.rs index 7b36aed..b37878a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -33,7 +33,7 @@ fn gen_keypair() { } fn validate_keypair(pk: &str, sk: &str) -> ResultType<()> { - let sk1 = base64::decode(&sk); + let sk1 = base64::decode(sk); if sk1.is_err() { bail!("Invalid secret key"); } @@ -45,7 +45,7 @@ fn validate_keypair(pk: &str, sk: &str) -> ResultType<()> { } let secret_key = secret_key.unwrap(); - let pk1 = base64::decode(&pk); + let pk1 = base64::decode(pk); if pk1.is_err() { bail!("Invalid public key"); } @@ -96,14 +96,13 @@ fn doctor_ip(server_ip_address: std::net::IpAddr, server_address: Option<&str>) // reverse dns lookup // TODO: (check) doesn't seem to do reverse lookup on OSX... let reverse = lookup_addr(&server_ip_address).unwrap(); - if server_address.is_some() { - if reverse == server_address.unwrap() { + if let Some(server_address) = server_address { + if reverse == server_address { println!("Reverse DNS lookup: '{}' MATCHES server address", reverse); } else { println!( "Reverse DNS lookup: '{}' DOESN'T MATCH server address '{}'", - reverse, - server_address.unwrap() + reverse, server_address ); } } @@ -126,19 +125,18 @@ fn doctor(server_address_unclean: &str) { let server_address2 = server_address3.to_lowercase(); let server_address = server_address2.as_str(); println!("Checking server: {}\n", server_address); - let server_ipaddr = server_address.parse::<IpAddr>(); - if server_ipaddr.is_err() { + if let Ok(server_ipaddr) = server_address.parse::<IpAddr>() { + // user requested an ip address + doctor_ip(server_ipaddr, None); + } else { // the passed string is not an ip address let ips: Vec<std::net::IpAddr> = lookup_host(server_address).unwrap(); - println!("Found {} IP addresses: ", ips.iter().count()); + println!("Found {} IP addresses: ", ips.len()); ips.iter().for_each(|ip| println!(" - {ip}")); - ips.iter().for_each(|ip| doctor_ip(*ip, Some(server_address))); - - } else { - // user requested an ip address - doctor_ip(server_ipaddr.unwrap(), None); + ips.iter() + .for_each(|ip| doctor_ip(*ip, Some(server_address))); } } diff --git a/src/version.rs b/src/version.rs index cd0f4db..7df1331 100644 --- a/src/version.rs +++ b/src/version.rs @@ -1,2 +1,3 @@ -pub const VERSION: &str = "1.1.6"; -pub const BUILD_DATE: &str = "2023-01-06 10:39";
\ No newline at end of file +pub const VERSION: &str = "1.1.7"; +#[allow(dead_code)] +pub const BUILD_DATE: &str = "2023-01-10 22:43"; diff --git a/systemd/rustdesk-hbbr.service b/systemd/rustdesk-hbbr.service index 43396a3..2ef10e2 100644 --- a/systemd/rustdesk-hbbr.service +++ b/systemd/rustdesk-hbbr.service @@ -10,6 +10,8 @@ WorkingDirectory=/var/lib/rustdesk-server/ User= Group= Restart=always +StandardOutput=append:/var/log/rustdesk/rustdesk-hbbr.log +StandardError=append:/var/log/rustdesk/rustdesk-hbbr.error # Restart service after 10 seconds if node service crashes RestartSec=10 diff --git a/systemd/rustdesk-hbbs.service b/systemd/rustdesk-hbbs.service index 16427f1..344f015 100644 --- a/systemd/rustdesk-hbbs.service +++ b/systemd/rustdesk-hbbs.service @@ -10,6 +10,8 @@ WorkingDirectory=/var/lib/rustdesk-server/ User= Group= Restart=always +StandardOutput=append:/var/log/rustdesk/rustdesk-hbbs.log +StandardError=append:/var/log/rustdesk/rustdesk-hbbs.error # Restart service after 10 seconds if node service crashes RestartSec=10 |