diff options
author | open-trade <[email protected]> | 2020-05-14 11:47:28 +0000 |
---|---|---|
committer | open-trade <[email protected]> | 2020-05-14 11:47:28 +0000 |
commit | 1f4f1cc8e2fcbe4e0a64c2d992dfcb72dbe86940 (patch) | |
tree | bd47f442d730caf7fb6bee3967fe9278c836ed51 /src | |
parent | b06f9d22aca9e51a4c29c74f89ac164cd2b8b896 (diff) | |
download | rustdesk-server-1f4f1cc8e2fcbe4e0a64c2d992dfcb72dbe86940.tar.gz rustdesk-server-1f4f1cc8e2fcbe4e0a64c2d992dfcb72dbe86940.zip |
refactor for preparing sled
Diffstat (limited to 'src')
-rw-r--r-- | src/rendezvous_server.rs | 118 |
1 files changed, 80 insertions, 38 deletions
diff --git a/src/rendezvous_server.rs b/src/rendezvous_server.rs index cb0e9e5..32105aa 100644 --- a/src/rendezvous_server.rs +++ b/src/rendezvous_server.rs @@ -19,34 +19,38 @@ use hbb_common::{ use std::{ collections::HashMap, net::SocketAddr, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, RwLock}, time::Instant, }; +#[derive(Clone)] struct Peer { socket_addr: SocketAddr, last_reg_time: Instant, } +#[derive(Clone)] struct PeerMap { - map: HashMap<String, Peer>, + map: Arc<RwLock<HashMap<String, Peer>>>, db: sled::Db, } impl PeerMap { fn new() -> ResultType<Self> { Ok(Self { - map: HashMap::new(), + map: Default::default(), db: sled::open("./sled.db")?, }) } + #[inline] fn insert(&mut self, key: String, peer: Peer) { - self.map.insert(key, peer); + self.map.write().unwrap().insert(key, peer); } - fn get(&self, key: &str) -> Option<&Peer> { - self.map.get(key) + #[inline] + fn get(&self, key: &str) -> Option<Peer> { + self.map.read().unwrap().get(key).map(|x| x.clone()) } } @@ -56,38 +60,40 @@ type Sink = SplitSink<Framed<TcpStream, BytesCodec>, Bytes>; #[derive(Clone)] pub struct RendezvousServer { tcp_punch: Arc<Mutex<HashMap<SocketAddr, Sink>>>, + pm: PeerMap, + tx: mpsc::UnboundedSender<(RendezvousMessage, SocketAddr)>, } impl RendezvousServer { pub async fn start(addr: &str) -> ResultType<()> { - let mut pm = PeerMap::new()?; let mut socket = FramedSocket::new(addr).await?; + let (tx, mut rx) = mpsc::unbounded_channel::<(RendezvousMessage, SocketAddr)>(); let mut rs = Self { tcp_punch: Arc::new(Mutex::new(HashMap::new())), + pm: PeerMap::new()?, + tx: tx.clone(), }; - let (tx, mut rx) = mpsc::unbounded_channel::<(SocketAddr, String)>(); let mut listener = new_listener(addr, true).await?; loop { tokio::select! { - Some((addr, id)) = rx.recv() => { - allow_err!(rs.handle_punch_hole_request(addr, &id, &mut socket, true, &pm).await); + Some((msg, addr)) = rx.recv() => { + allow_err!(socket.send(&msg, addr).await); } Some(Ok((bytes, addr))) = socket.next() => { - allow_err!(rs.handle_msg(&bytes, addr, &mut socket, &mut pm).await); + allow_err!(rs.handle_msg(&bytes, addr, &mut socket).await); } Ok((stream, addr)) = listener.accept() => { log::debug!("Tcp connection from {:?}", addr); let (a, mut b) = Framed::new(stream, BytesCodec::new()).split(); let tcp_punch = rs.tcp_punch.clone(); tcp_punch.lock().unwrap().insert(addr, a); - let tx = tx.clone(); let mut rs = rs.clone(); tokio::spawn(async move { while let Some(Ok(bytes)) = b.next().await { if let Ok(msg_in) = parse_from_bytes::<RendezvousMessage>(&bytes) { match msg_in.union { Some(rendezvous_message::Union::punch_hole_request(ph)) => { - allow_err!(tx.send((addr, ph.id))); + allow_err!(rs.handle_tcp_punch_hole_request(addr, &ph.id).await); } Some(rendezvous_message::Union::punch_hole_sent(phs)) => { allow_err!(rs.handle_hole_sent(&phs, addr, None).await); @@ -109,12 +115,12 @@ impl RendezvousServer { } } + #[inline] async fn handle_msg( &mut self, bytes: &BytesMut, addr: SocketAddr, socket: &mut FramedSocket, - pm: &mut PeerMap, ) -> ResultType<()> { if let Ok(msg_in) = parse_from_bytes::<RendezvousMessage>(&bytes) { match msg_in.union { @@ -122,7 +128,7 @@ impl RendezvousServer { // B registered if rp.id.len() > 0 { log::debug!("New peer registered: {:?} {:?}", &rp.id, &addr); - pm.insert( + self.pm.insert( rp.id, Peer { socket_addr: addr, @@ -135,8 +141,7 @@ impl RendezvousServer { } } Some(rendezvous_message::Union::punch_hole_request(ph)) => { - self.handle_punch_hole_request(addr, &ph.id, socket, false, &pm) - .await?; + self.handle_udp_punch_hole_request(addr, &ph.id).await?; } Some(rendezvous_message::Union::punch_hole_sent(phs)) => { self.handle_hole_sent(&phs, addr, Some(socket)).await?; @@ -150,6 +155,7 @@ impl RendezvousServer { Ok(()) } + #[inline] async fn handle_hole_sent<'a>( &mut self, phs: &PunchHoleSent, @@ -172,11 +178,12 @@ impl RendezvousServer { if let Some(socket) = socket { socket.send(&msg_out, addr_a).await?; } else { - self.send_to_tcp(&msg_out, addr_a).await?; + self.send_to_tcp(&msg_out, addr_a).await; } Ok(()) } + #[inline] async fn handle_local_addr<'a>( &mut self, la: &LocalAddr, @@ -199,36 +206,30 @@ impl RendezvousServer { if let Some(socket) = socket { socket.send(&msg_out, addr_a).await?; } else { - self.send_to_tcp(&msg_out, addr_a).await?; + self.send_to_tcp(&msg_out, addr_a).await; } Ok(()) } + #[inline] async fn handle_punch_hole_request( &mut self, addr: SocketAddr, id: &str, - socket: &mut FramedSocket, - is_tcp: bool, - pm: &PeerMap, - ) -> ResultType<()> { + ) -> ResultType<(RendezvousMessage, Option<SocketAddr>)> { // punch hole request from A, forward to B, // check if in same intranet first, // fetch local addrs if in same intranet. // because punch hole won't work if in the same intranet, // all routers will drop such self-connections. - if let Some(peer) = pm.get(id) { + if let Some(peer) = self.pm.get(id) { if peer.last_reg_time.elapsed().as_millis() as i32 >= REG_TIMEOUT { let mut msg_out = RendezvousMessage::new(); msg_out.set_punch_hole_response(PunchHoleResponse { failure: punch_hole_response::Failure::OFFLINE.into(), ..Default::default() }); - return if is_tcp { - self.send_to_tcp(&msg_out, addr).await - } else { - socket.send(&msg_out, addr).await - }; + return Ok((msg_out, None)); } let mut msg_out = RendezvousMessage::new(); let same_intranet = match peer.socket_addr { @@ -265,32 +266,73 @@ impl RendezvousServer { ..Default::default() }); } - socket.send(&msg_out, peer.socket_addr).await?; + return Ok((msg_out, Some(peer.socket_addr))); } else { let mut msg_out = RendezvousMessage::new(); msg_out.set_punch_hole_response(PunchHoleResponse { failure: punch_hole_response::Failure::ID_NOT_EXIST.into(), ..Default::default() }); - return if is_tcp { - self.send_to_tcp(&msg_out, addr).await - } else { - socket.send(&msg_out, addr).await - }; + return Ok((msg_out, None)); } - Ok(()) } - async fn send_to_tcp(&mut self, msg: &RendezvousMessage, addr: SocketAddr) -> ResultType<()> { + #[inline] + async fn send_to_tcp(&mut self, msg: &RendezvousMessage, addr: SocketAddr) { let tcp = self.tcp_punch.lock().unwrap().remove(&addr); if let Some(mut tcp) = tcp { if let Ok(bytes) = msg.write_to_bytes() { tokio::spawn(async move { allow_err!(tcp.send(Bytes::from(bytes)).await); - log::debug!("Send punch hole to {} via tcp", addr); }); } } + } + + #[inline] + async fn send_to_tcp_sync( + &mut self, + msg: &RendezvousMessage, + addr: SocketAddr, + ) -> ResultType<()> { + let tcp = self.tcp_punch.lock().unwrap().remove(&addr); + if let Some(mut tcp) = tcp { + if let Ok(bytes) = msg.write_to_bytes() { + tcp.send(Bytes::from(bytes)).await?; + } + } + Ok(()) + } + + #[inline] + async fn handle_tcp_punch_hole_request( + &mut self, + addr: SocketAddr, + id: &str, + ) -> ResultType<()> { + let (msg, to_addr) = self.handle_punch_hole_request(addr, id).await?; + if let Some(addr) = to_addr { + self.tx.send((msg, addr))?; + } else { + self.send_to_tcp_sync(&msg, addr).await?; + } + Ok(()) + } + + #[inline] + async fn handle_udp_punch_hole_request( + &mut self, + addr: SocketAddr, + id: &str, + ) -> ResultType<()> { + let (msg, to_addr) = self.handle_punch_hole_request(addr, id).await?; + self.tx.send(( + msg, + match to_addr { + Some(addr) => addr, + None => addr, + }, + ))?; Ok(()) } } |