use async_speed_limit::Limiter; use async_trait::async_trait; use hbb_common::{ allow_err, bail, bytes::{Bytes, BytesMut}, futures_util::{sink::SinkExt, stream::StreamExt}, log, protobuf::Message as _, rendezvous_proto::*, sleep, tcp::{new_listener, FramedStream}, timeout, tokio::{ self, io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, sync::{Mutex, RwLock}, time::{interval, Duration}, }, ResultType, }; use sodiumoxide::crypto::sign; use std::{ collections::{HashMap, HashSet}, io::prelude::*, io::Error, net::SocketAddr, }; type Usage = (usize, usize, usize, usize); lazy_static::lazy_static! { static ref PEERS: Mutex>> = Default::default(); static ref USAGE: RwLock> = Default::default(); static ref BLACKLIST: RwLock> = Default::default(); static ref BLOCKLIST: RwLock> = Default::default(); } static mut DOWNGRADE_THRESHOLD: f64 = 0.66; static mut DOWNGRADE_START_CHECK: usize = 1800_000; // in ms static mut LIMIT_SPEED: usize = 4 * 1024 * 1024; // in bit/s static mut TOTAL_BANDWIDTH: usize = 1024 * 1024 * 1024; // in bit/s static mut SINGLE_BANDWIDTH: usize = 16 * 1024 * 1024; // in bit/s const BLACKLIST_FILE: &'static str = "blacklist.txt"; const BLOCKLIST_FILE: &'static str = "blocklist.txt"; #[tokio::main(flavor = "multi_thread")] pub async fn start(port: &str, key: &str) -> ResultType<()> { let key = get_server_sk(key); if let Ok(mut file) = std::fs::File::open(BLACKLIST_FILE) { let mut contents = String::new(); if file.read_to_string(&mut contents).is_ok() { for x in contents.split("\n") { if let Some(ip) = x.trim().split(' ').nth(0) { BLACKLIST.write().await.insert(ip.to_owned()); } } } } log::info!( "#blacklist({}): {}", BLACKLIST_FILE, BLACKLIST.read().await.len() ); if let Ok(mut file) = std::fs::File::open(BLOCKLIST_FILE) { let mut contents = String::new(); if file.read_to_string(&mut contents).is_ok() { for x in contents.split("\n") { if let Some(ip) = x.trim().split(' ').nth(0) { BLOCKLIST.write().await.insert(ip.to_owned()); } } } } log::info!( "#blocklist({}): {}", BLOCKLIST_FILE, BLOCKLIST.read().await.len() ); let addr = format!("0.0.0.0:{}", port); log::info!("Listening on tcp {}", addr); let addr2 = format!("0.0.0.0:{}", port.parse::().unwrap() + 2); log::info!("Listening on websocket {}", addr2); loop { log::info!("Start"); io_loop( new_listener(&addr, false).await?, new_listener(&addr2, false).await?, &key, ) .await; } } fn check_params() { let tmp = std::env::var("DOWNGRADE_THRESHOLD") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { unsafe { DOWNGRADE_THRESHOLD = tmp; } } unsafe { log::info!("DOWNGRADE_THRESHOLD: {}", DOWNGRADE_THRESHOLD) }; let tmp = std::env::var("DOWNGRADE_START_CHECK") .map(|x| x.parse::().unwrap_or(0)) .unwrap_or(0); if tmp > 0 { unsafe { DOWNGRADE_START_CHECK = tmp * 1000; } } unsafe { log::info!("DOWNGRADE_START_CHECK: {}s", DOWNGRADE_START_CHECK / 1000) }; let tmp = std::env::var("LIMIT_SPEED") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { unsafe { LIMIT_SPEED = (tmp * 1024. * 1024.) as usize; } } unsafe { log::info!("LIMIT_SPEED: {}Mb/s", LIMIT_SPEED as f64 / 1024. / 1024.) }; let tmp = std::env::var("TOTAL_BANDWIDTH") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { unsafe { TOTAL_BANDWIDTH = (tmp * 1024. * 1024.) as usize; } } unsafe { log::info!( "TOTAL_BANDWIDTH: {}Mb/s", TOTAL_BANDWIDTH as f64 / 1024. / 1024. ) }; let tmp = std::env::var("SINGLE_BANDWIDTH") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { unsafe { SINGLE_BANDWIDTH = (tmp * 1024. * 1024.) as usize; } } unsafe { log::info!( "SINGLE_BANDWIDTH: {}Mb/s", SINGLE_BANDWIDTH as f64 / 1024. / 1024. ) }; } async fn check_cmd(cmd: &str, limiter: Limiter) -> String { let mut res = "".to_owned(); let mut fds = cmd.trim().split(" "); match fds.next() { Some("h") => { res = format!( "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n", "blacklist-add(ba) ", "blacklist-remove(br) ", "blacklist(b) ", "blocklist-add(Ba) ", "blocklist-remove(Br) ", "blocklist(B) ", "downgrade-threshold(dt) [value]", "downgrade-start-check(t) [value(second)]", "limit-speed(ls) [value(Mb/s)]", "total-bandwidth(tb) [value(Mb/s)]", "single-bandwidth(sb) [value(Mb/s)]", "usage(u)" ) } Some("blacklist-add" | "ba") => { if let Some(ip) = fds.next() { for ip in ip.split("|") { BLACKLIST.write().await.insert(ip.to_owned()); } } } Some("blacklist-remove" | "br") => { if let Some(ip) = fds.next() { if ip == "all" { BLACKLIST.write().await.clear(); } else { for ip in ip.split("|") { BLACKLIST.write().await.remove(ip); } } } } Some("blacklist" | "b") => { if let Some(ip) = fds.next() { res = format!("{}\n", BLACKLIST.read().await.get(ip).is_some()); } else { for ip in BLACKLIST.read().await.clone().into_iter() { res += &format!("{}\n", ip); } } } Some("blocklist-add" | "Ba") => { if let Some(ip) = fds.next() { for ip in ip.split("|") { BLOCKLIST.write().await.insert(ip.to_owned()); } } } Some("blocklist-remove" | "Br") => { if let Some(ip) = fds.next() { if ip == "all" { BLOCKLIST.write().await.clear(); } else { for ip in ip.split("|") { BLOCKLIST.write().await.remove(ip); } } } } Some("blocklist" | "B") => { if let Some(ip) = fds.next() { res = format!("{}\n", BLOCKLIST.read().await.get(ip).is_some()); } else { for ip in BLOCKLIST.read().await.clone().into_iter() { res += &format!("{}\n", ip); } } } Some("downgrade-threshold" | "dt") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { unsafe { DOWNGRADE_THRESHOLD = v; } } } } else { unsafe { res = format!("{}\n", DOWNGRADE_THRESHOLD); } } } Some("downgrade-start-check" | "t") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0 { unsafe { DOWNGRADE_START_CHECK = v * 1000; } } } } else { unsafe { res = format!("{}s\n", DOWNGRADE_START_CHECK / 1000); } } } Some("limit-speed" | "ls") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { unsafe { LIMIT_SPEED = (v * 1024. * 1024.) as _; } } } } else { unsafe { res = format!("{}Mb/s\n", LIMIT_SPEED as f64 / 1024. / 1024.); } } } Some("total-bandwidth" | "tb") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { unsafe { TOTAL_BANDWIDTH = (v * 1024. * 1024.) as _; limiter.set_speed_limit(TOTAL_BANDWIDTH as _); } } } } else { unsafe { res = format!("{}Mb/s\n", TOTAL_BANDWIDTH as f64 / 1024. / 1024.); } } } Some("single-bandwidth" | "sb") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { unsafe { SINGLE_BANDWIDTH = (v * 1024. * 1024.) as _; } } } } else { unsafe { res = format!("{}Mb/s\n", SINGLE_BANDWIDTH as f64 / 1024. / 1024.); } } } Some("usage" | "u") => { let mut tmp: Vec<(String, Usage)> = USAGE .read() .await .iter() .map(|x| (x.0.clone(), x.1.clone())) .collect(); tmp.sort_by(|a, b| ((b.1).1).partial_cmp(&(a.1).1).unwrap()); for (ip, (elapsed, total, highest, speed)) in tmp { if elapsed <= 0 { continue; } res += &format!( "{}: {}s {:.2}MB {}kb/s {}kb/s {}kb/s\n", ip, elapsed / 1000, total as f64 / 1024. / 1024. / 8., highest, total / elapsed, speed ); } } _ => {} } res } async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) { check_params(); let limiter = ::new(unsafe { TOTAL_BANDWIDTH as _ }); loop { tokio::select! { res = listener.accept() => { match res { Ok((stream, addr)) => { stream.set_nodelay(true).ok(); handle_connection(stream, addr, &limiter, key, false).await; } Err(err) => { log::error!("listener.accept failed: {}", err); break; } } } res = listener2.accept() => { match res { Ok((stream, addr)) => { stream.set_nodelay(true).ok(); handle_connection(stream, addr, &limiter, key, true).await; } Err(err) => { log::error!("listener2.accept failed: {}", err); break; } } } } } } async fn handle_connection( stream: TcpStream, addr: SocketAddr, limiter: &Limiter, key: &str, ws: bool, ) { let ip = addr.ip().to_string(); if !ws && ip == "127.0.0.1" { let limiter = limiter.clone(); tokio::spawn(async move { let mut stream = stream; let mut buffer = [0; 64]; if let Ok(Ok(n)) = timeout(1000, stream.read(&mut buffer[..])).await { if let Ok(data) = std::str::from_utf8(&buffer[..n]) { let res = check_cmd(data, limiter).await; stream.write(res.as_bytes()).await.ok(); } } }); return; } if BLOCKLIST.read().await.get(&ip).is_some() { log::info!("{} blocked", ip); return; } let key = key.to_owned(); let limiter = limiter.clone(); tokio::spawn(async move { allow_err!(make_pair(stream, addr, &key, limiter, ws).await); }); } async fn make_pair( stream: TcpStream, addr: SocketAddr, key: &str, limiter: Limiter, ws: bool, ) -> ResultType<()> { if ws { make_pair_( tokio_tungstenite::accept_async(stream).await?, addr, key, limiter, ) .await; } else { make_pair_(FramedStream::from(stream, addr), addr, key, limiter).await; } Ok(()) } async fn make_pair_(stream: impl StreamTrait, addr: SocketAddr, key: &str, limiter: Limiter) { let mut stream = stream; if let Ok(Some(Ok(bytes))) = timeout(30_000, stream.recv()).await { if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { if let Some(rendezvous_message::Union::request_relay(rf)) = msg_in.union { if !key.is_empty() && rf.licence_key != key { return; } if !rf.uuid.is_empty() { let mut peer = PEERS.lock().await.remove(&rf.uuid); if let Some(peer) = peer.as_mut() { log::info!("Relayrequest {} from {} got paired", rf.uuid, addr); let id = format!("{}:{}", addr.ip(), addr.port()); USAGE.write().await.insert(id.clone(), Default::default()); if !stream.is_ws() && !peer.is_ws() { peer.set_raw(); stream.set_raw(); log::info!("Both are raw"); } if let Err(err) = relay(addr, &mut stream, peer, limiter, id.clone()).await { log::info!("Relay of {} closed: {}", addr, err); } else { log::info!("Relay of {} closed", addr); } USAGE.write().await.remove(&id); } else { log::info!("New relay request {} from {}", rf.uuid, addr); PEERS.lock().await.insert(rf.uuid.clone(), Box::new(stream)); sleep(30.).await; PEERS.lock().await.remove(&rf.uuid); } } } } } } async fn relay( addr: SocketAddr, stream: &mut impl StreamTrait, peer: &mut Box, total_limiter: Limiter, id: String, ) -> ResultType<()> { let ip = addr.ip().to_string(); let mut tm = std::time::Instant::now(); let mut elapsed = 0; let mut total = 0; let mut total_s = 0; let mut highest_s = 0; let mut downgrade: bool = false; let mut blacked: bool = false; let limiter = ::new(unsafe { SINGLE_BANDWIDTH as _ }); let blacklist_limiter = ::new(unsafe { LIMIT_SPEED as _ }); let downgrade_threshold = (unsafe { SINGLE_BANDWIDTH as f64 * DOWNGRADE_THRESHOLD } / 1000.) as usize; // in bit/ms let mut timer = interval(Duration::from_secs(3)); let mut last_recv_time = std::time::Instant::now(); loop { tokio::select! { res = peer.recv() => { if let Some(Ok(bytes)) = res { last_recv_time = std::time::Instant::now(); let nb = bytes.len() * 8; if blacked || downgrade { blacklist_limiter.consume(nb).await; } else { limiter.consume(nb).await; } total_limiter.consume(nb).await; total += nb; total_s += nb; if bytes.len() > 0 { stream.send_raw(bytes.into()).await?; } } else { break; } }, res = stream.recv() => { if let Some(Ok(bytes)) = res { last_recv_time = std::time::Instant::now(); let nb = bytes.len() * 8; if blacked || downgrade { blacklist_limiter.consume(nb).await; } else { limiter.consume(nb).await; } total_limiter.consume(nb).await; total += nb; total_s += nb; if bytes.len() > 0 { peer.send_raw(bytes.into()).await?; } } else { break; } }, _ = timer.tick() => { if last_recv_time.elapsed().as_secs() > 30 { bail!("Timeout"); } } } let n = tm.elapsed().as_millis() as usize; if n >= 1_000 { if BLOCKLIST.read().await.get(&ip).is_some() { log::info!("{} blocked", ip); break; } blacked = BLACKLIST.read().await.get(&ip).is_some(); tm = std::time::Instant::now(); let speed = total_s / (n as usize); if speed > highest_s { highest_s = speed; } elapsed += n; USAGE.write().await.insert( id.clone(), (elapsed as _, total as _, highest_s as _, speed as _), ); total_s = 0; if elapsed > unsafe { DOWNGRADE_START_CHECK } && !downgrade { if total > elapsed * downgrade_threshold { downgrade = true; log::info!( "Downgrade {}, exceed downgrade threshold {}bit/ms in {}ms", id, downgrade_threshold, elapsed ); } } } } Ok(()) } fn get_server_sk(key: &str) -> String { let mut key = key.to_owned(); if let Ok(sk) = base64::decode(&key) { if sk.len() == sign::SECRETKEYBYTES { log::info!("The key is a crypto private key"); key = base64::encode(&sk[(sign::SECRETKEYBYTES / 2)..]); } } if key == "-" || key == "_" { let (pk, _) = crate::common::gen_sk(); key = pk; } if !key.is_empty() { log::info!("Key: {}", key); } key } #[async_trait] trait StreamTrait: Send + Sync + 'static { async fn recv(&mut self) -> Option>; async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()>; fn is_ws(&self) -> bool; fn set_raw(&mut self); } #[async_trait] impl StreamTrait for FramedStream { async fn recv(&mut self) -> Option> { self.next().await } async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()> { self.send_bytes(bytes).await } fn is_ws(&self) -> bool { false } fn set_raw(&mut self) { self.set_raw(); } } #[async_trait] impl StreamTrait for tokio_tungstenite::WebSocketStream { async fn recv(&mut self) -> Option> { if let Some(msg) = self.next().await { match msg { Ok(msg) => { match msg { tungstenite::Message::Binary(bytes) => { Some(Ok(bytes[..].into())) // to-do: poor performance } _ => Some(Ok(BytesMut::new())), } } Err(err) => Some(Err(Error::new(std::io::ErrorKind::Other, err.to_string()))), } } else { None } } async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()> { Ok(self .send(tungstenite::Message::Binary(bytes.to_vec())) .await?) // to-do: poor performance } fn is_ws(&self) -> bool { true } fn set_raw(&mut self) {} }