use async_speed_limit::Limiter; use async_trait::async_trait; use hbb_common::{ allow_err, bail, bytes::{Bytes, BytesMut}, futures_util::{sink::SinkExt, stream::StreamExt}, log, protobuf::Message as _, rendezvous_proto::*, sleep, tcp::{listen_any, FramedStream}, timeout, tokio::{ self, io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, sync::{Mutex, RwLock}, time::{interval, Duration}, }, ResultType, }; use sodiumoxide::crypto::sign; use std::{ collections::{HashMap, HashSet}, io::prelude::*, io::Error, net::SocketAddr, sync::atomic::{AtomicUsize, Ordering}, }; type Usage = (usize, usize, usize, usize); lazy_static::lazy_static! { static ref PEERS: Mutex>> = Default::default(); static ref USAGE: RwLock> = Default::default(); static ref BLACKLIST: RwLock> = Default::default(); static ref BLOCKLIST: RwLock> = Default::default(); } static DOWNGRADE_THRESHOLD_100: AtomicUsize = AtomicUsize::new(66); // 0.66 static DOWNGRADE_START_CHECK: AtomicUsize = AtomicUsize::new(1_800_000); // in ms static LIMIT_SPEED: AtomicUsize = AtomicUsize::new(4 * 1024 * 1024); // in bit/s static TOTAL_BANDWIDTH: AtomicUsize = AtomicUsize::new(1024 * 1024 * 1024); // in bit/s static SINGLE_BANDWIDTH: AtomicUsize = AtomicUsize::new(16 * 1024 * 1024); // in bit/s const BLACKLIST_FILE: &str = "blacklist.txt"; const BLOCKLIST_FILE: &str = "blocklist.txt"; #[tokio::main(flavor = "multi_thread")] pub async fn start(port: &str, key: &str) -> ResultType<()> { let key = get_server_sk(key); if let Ok(mut file) = std::fs::File::open(BLACKLIST_FILE) { let mut contents = String::new(); if file.read_to_string(&mut contents).is_ok() { for x in contents.split('\n') { if let Some(ip) = x.trim().split(' ').next() { BLACKLIST.write().await.insert(ip.to_owned()); } } } } log::info!( "#blacklist({}): {}", BLACKLIST_FILE, BLACKLIST.read().await.len() ); if let Ok(mut file) = std::fs::File::open(BLOCKLIST_FILE) { let mut contents = String::new(); if file.read_to_string(&mut contents).is_ok() { for x in contents.split('\n') { if let Some(ip) = x.trim().split(' ').next() { BLOCKLIST.write().await.insert(ip.to_owned()); } } } } log::info!( "#blocklist({}): {}", BLOCKLIST_FILE, BLOCKLIST.read().await.len() ); let port: u16 = port.parse()?; log::info!("Listening on tcp :{}", port); let port2 = port + 2; log::info!("Listening on websocket :{}", port2); let main_task = async move { loop { log::info!("Start"); io_loop(listen_any(port, true).await?, listen_any(port2, true).await?, &key).await; } }; let listen_signal = crate::common::listen_signal(); tokio::select!( res = main_task => res, res = listen_signal => res, ) } fn check_params() { let tmp = std::env::var("DOWNGRADE_THRESHOLD") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { DOWNGRADE_THRESHOLD_100.store((tmp * 100.) as _, Ordering::SeqCst); } log::info!( "DOWNGRADE_THRESHOLD: {}", DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. ); let tmp = std::env::var("DOWNGRADE_START_CHECK") .map(|x| x.parse::().unwrap_or(0)) .unwrap_or(0); if tmp > 0 { DOWNGRADE_START_CHECK.store(tmp * 1000, Ordering::SeqCst); } log::info!( "DOWNGRADE_START_CHECK: {}s", DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000 ); let tmp = std::env::var("LIMIT_SPEED") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { LIMIT_SPEED.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst); } log::info!( "LIMIT_SPEED: {}Mb/s", LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024. ); let tmp = std::env::var("TOTAL_BANDWIDTH") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { TOTAL_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst); } log::info!( "TOTAL_BANDWIDTH: {}Mb/s", TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024. ); let tmp = std::env::var("SINGLE_BANDWIDTH") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { SINGLE_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst); } log::info!( "SINGLE_BANDWIDTH: {}Mb/s", SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024. ) } async fn check_cmd(cmd: &str, limiter: Limiter) -> String { use std::fmt::Write; let mut res = "".to_owned(); let mut fds = cmd.trim().split(' '); match fds.next() { Some("h") => { res = format!( "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n", "blacklist-add(ba) ", "blacklist-remove(br) ", "blacklist(b) ", "blocklist-add(Ba) ", "blocklist-remove(Br) ", "blocklist(B) ", "downgrade-threshold(dt) [value]", "downgrade-start-check(t) [value(second)]", "limit-speed(ls) [value(Mb/s)]", "total-bandwidth(tb) [value(Mb/s)]", "single-bandwidth(sb) [value(Mb/s)]", "usage(u)" ) } Some("blacklist-add" | "ba") => { if let Some(ip) = fds.next() { for ip in ip.split('|') { BLACKLIST.write().await.insert(ip.to_owned()); } } } Some("blacklist-remove" | "br") => { if let Some(ip) = fds.next() { if ip == "all" { BLACKLIST.write().await.clear(); } else { for ip in ip.split('|') { BLACKLIST.write().await.remove(ip); } } } } Some("blacklist" | "b") => { if let Some(ip) = fds.next() { res = format!("{}\n", BLACKLIST.read().await.get(ip).is_some()); } else { for ip in BLACKLIST.read().await.clone().into_iter() { let _ = writeln!(res, "{ip}"); } } } Some("blocklist-add" | "Ba") => { if let Some(ip) = fds.next() { for ip in ip.split('|') { BLOCKLIST.write().await.insert(ip.to_owned()); } } } Some("blocklist-remove" | "Br") => { if let Some(ip) = fds.next() { if ip == "all" { BLOCKLIST.write().await.clear(); } else { for ip in ip.split('|') { BLOCKLIST.write().await.remove(ip); } } } } Some("blocklist" | "B") => { if let Some(ip) = fds.next() { res = format!("{}\n", BLOCKLIST.read().await.get(ip).is_some()); } else { for ip in BLOCKLIST.read().await.clone().into_iter() { let _ = writeln!(res, "{ip}"); } } } Some("downgrade-threshold" | "dt") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { DOWNGRADE_THRESHOLD_100.store((v * 100.) as _, Ordering::SeqCst); } } } else { res = format!( "{}\n", DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. ); } } Some("downgrade-start-check" | "t") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0 { DOWNGRADE_START_CHECK.store(v * 1000, Ordering::SeqCst); } } } else { res = format!("{}s\n", DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000); } } Some("limit-speed" | "ls") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { LIMIT_SPEED.store((v * 1024. * 1024.) as _, Ordering::SeqCst); } } } else { res = format!( "{}Mb/s\n", LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024. ); } } Some("total-bandwidth" | "tb") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { TOTAL_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst); limiter.set_speed_limit(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _); } } } else { res = format!( "{}Mb/s\n", TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024. ); } } Some("single-bandwidth" | "sb") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { SINGLE_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst); } } } else { res = format!( "{}Mb/s\n", SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024. ); } } Some("usage" | "u") => { let mut tmp: Vec<(String, Usage)> = USAGE .read() .await .iter() .map(|x| (x.0.clone(), *x.1)) .collect(); tmp.sort_by(|a, b| ((b.1).1).partial_cmp(&(a.1).1).unwrap()); for (ip, (elapsed, total, highest, speed)) in tmp { if elapsed == 0 { continue; } let _ = writeln!( res, "{}: {}s {:.2}MB {}kb/s {}kb/s {}kb/s", ip, elapsed / 1000, total as f64 / 1024. / 1024. / 8., highest, total / elapsed, speed ); } } _ => {} } res } async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) { check_params(); let limiter = ::new(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _); loop { tokio::select! { res = listener.accept() => { match res { Ok((stream, addr)) => { stream.set_nodelay(true).ok(); handle_connection(stream, addr, &limiter, key, false).await; } Err(err) => { log::error!("listener.accept failed: {}", err); break; } } } res = listener2.accept() => { match res { Ok((stream, addr)) => { stream.set_nodelay(true).ok(); handle_connection(stream, addr, &limiter, key, true).await; } Err(err) => { log::error!("listener2.accept failed: {}", err); break; } } } } } } async fn handle_connection( stream: TcpStream, addr: SocketAddr, limiter: &Limiter, key: &str, ws: bool, ) { let ip = hbb_common::try_into_v4(addr).ip(); if !ws && ip.is_loopback() { let limiter = limiter.clone(); tokio::spawn(async move { let mut stream = stream; let mut buffer = [0; 1024]; if let Ok(Ok(n)) = timeout(1000, stream.read(&mut buffer[..])).await { if let Ok(data) = std::str::from_utf8(&buffer[..n]) { let res = check_cmd(data, limiter).await; stream.write(res.as_bytes()).await.ok(); } } }); return; } let ip = ip.to_string(); if BLOCKLIST.read().await.get(&ip).is_some() { log::info!("{} blocked", ip); return; } let key = key.to_owned(); let limiter = limiter.clone(); tokio::spawn(async move { allow_err!(make_pair(stream, addr, &key, limiter, ws).await); }); } async fn make_pair( stream: TcpStream, addr: SocketAddr, key: &str, limiter: Limiter, ws: bool, ) -> ResultType<()> { if ws { make_pair_( tokio_tungstenite::accept_async(stream).await?, addr, key, limiter, ) .await; } else { make_pair_(FramedStream::from(stream, addr), addr, key, limiter).await; } Ok(()) } async fn make_pair_(stream: impl StreamTrait, addr: SocketAddr, key: &str, limiter: Limiter) { let mut stream = stream; if let Ok(Some(Ok(bytes))) = timeout(30_000, stream.recv()).await { if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { if let Some(rendezvous_message::Union::RequestRelay(rf)) = msg_in.union { if !key.is_empty() && rf.licence_key != key { return; } if !rf.uuid.is_empty() { let mut peer = PEERS.lock().await.remove(&rf.uuid); if let Some(peer) = peer.as_mut() { log::info!("Relayrequest {} from {} got paired", rf.uuid, addr); let id = format!("{}:{}", addr.ip(), addr.port()); USAGE.write().await.insert(id.clone(), Default::default()); if !stream.is_ws() && !peer.is_ws() { peer.set_raw(); stream.set_raw(); log::info!("Both are raw"); } if let Err(err) = relay(addr, &mut stream, peer, limiter, id.clone()).await { log::info!("Relay of {} closed: {}", addr, err); } else { log::info!("Relay of {} closed", addr); } USAGE.write().await.remove(&id); } else { log::info!("New relay request {} from {}", rf.uuid, addr); PEERS.lock().await.insert(rf.uuid.clone(), Box::new(stream)); sleep(30.).await; PEERS.lock().await.remove(&rf.uuid); } } } } } } async fn relay( addr: SocketAddr, stream: &mut impl StreamTrait, peer: &mut Box, total_limiter: Limiter, id: String, ) -> ResultType<()> { let ip = addr.ip().to_string(); let mut tm = std::time::Instant::now(); let mut elapsed = 0; let mut total = 0; let mut total_s = 0; let mut highest_s = 0; let mut downgrade: bool = false; let mut blacked: bool = false; let sb = SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64; let limiter = ::new(sb); let blacklist_limiter = ::new(LIMIT_SPEED.load(Ordering::SeqCst) as _); let downgrade_threshold = (sb * DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. / 1000.) as usize; // in bit/ms let mut timer = interval(Duration::from_secs(3)); let mut last_recv_time = std::time::Instant::now(); loop { tokio::select! { res = peer.recv() => { if let Some(Ok(bytes)) = res { last_recv_time = std::time::Instant::now(); let nb = bytes.len() * 8; if blacked || downgrade { blacklist_limiter.consume(nb).await; } else { limiter.consume(nb).await; } total_limiter.consume(nb).await; total += nb; total_s += nb; if !bytes.is_empty() { stream.send_raw(bytes.into()).await?; } } else { break; } }, res = stream.recv() => { if let Some(Ok(bytes)) = res { last_recv_time = std::time::Instant::now(); let nb = bytes.len() * 8; if blacked || downgrade { blacklist_limiter.consume(nb).await; } else { limiter.consume(nb).await; } total_limiter.consume(nb).await; total += nb; total_s += nb; if !bytes.is_empty() { peer.send_raw(bytes.into()).await?; } } else { break; } }, _ = timer.tick() => { if last_recv_time.elapsed().as_secs() > 30 { bail!("Timeout"); } } } let n = tm.elapsed().as_millis() as usize; if n >= 1_000 { if BLOCKLIST.read().await.get(&ip).is_some() { log::info!("{} blocked", ip); break; } blacked = BLACKLIST.read().await.get(&ip).is_some(); tm = std::time::Instant::now(); let speed = total_s / n; if speed > highest_s { highest_s = speed; } elapsed += n; USAGE.write().await.insert( id.clone(), (elapsed as _, total as _, highest_s as _, speed as _), ); total_s = 0; if elapsed > DOWNGRADE_START_CHECK.load(Ordering::SeqCst) && !downgrade && total > elapsed * downgrade_threshold { downgrade = true; log::info!( "Downgrade {}, exceed downgrade threshold {}bit/ms in {}ms", id, downgrade_threshold, elapsed ); } } } Ok(()) } fn get_server_sk(key: &str) -> String { let mut key = key.to_owned(); if let Ok(sk) = base64::decode(&key) { if sk.len() == sign::SECRETKEYBYTES { log::info!("The key is a crypto private key"); key = base64::encode(&sk[(sign::SECRETKEYBYTES / 2)..]); } } if key == "-" || key == "_" { let (pk, _) = crate::common::gen_sk(300); key = pk; } if !key.is_empty() { log::info!("Key: {}", key); } key } #[async_trait] trait StreamTrait: Send + Sync + 'static { async fn recv(&mut self) -> Option>; async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()>; fn is_ws(&self) -> bool; fn set_raw(&mut self); } #[async_trait] impl StreamTrait for FramedStream { async fn recv(&mut self) -> Option> { self.next().await } async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()> { self.send_bytes(bytes).await } fn is_ws(&self) -> bool { false } fn set_raw(&mut self) { self.set_raw(); } } #[async_trait] impl StreamTrait for tokio_tungstenite::WebSocketStream { async fn recv(&mut self) -> Option> { if let Some(msg) = self.next().await { match msg { Ok(msg) => { match msg { tungstenite::Message::Binary(bytes) => { Some(Ok(bytes[..].into())) // to-do: poor performance } _ => Some(Ok(BytesMut::new())), } } Err(err) => Some(Err(Error::new(std::io::ErrorKind::Other, err.to_string()))), } } else { None } } async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()> { Ok(self .send(tungstenite::Message::Binary(bytes.to_vec())) .await?) // to-do: poor performance } fn is_ws(&self) -> bool { true } fn set_raw(&mut self) {} }