diff options
author | open-trade <[email protected]> | 2020-03-09 23:54:36 +0800 |
---|---|---|
committer | open-trade <[email protected]> | 2020-03-09 23:54:36 +0800 |
commit | 69f60499dbef16a7976d9eff737254509bf7fc0b (patch) | |
tree | 5acd6ebc273bbeea0f6e1e1b8bf4b8960a8adb84 /src | |
parent | df9de9596868b99faf5b0b02d95c83186ae97ae2 (diff) | |
download | rustdesk-server-69f60499dbef16a7976d9eff737254509bf7fc0b.tar.gz rustdesk-server-69f60499dbef16a7976d9eff737254509bf7fc0b.zip |
refactored
Diffstat (limited to 'src')
-rw-r--r-- | src/rendezvous_server.rs | 137 |
1 files changed, 65 insertions, 72 deletions
diff --git a/src/rendezvous_server.rs b/src/rendezvous_server.rs index 3f75e77..cbd5355 100644 --- a/src/rendezvous_server.rs +++ b/src/rendezvous_server.rs @@ -1,11 +1,6 @@ use hbb_common::{ - bytes::BytesMut, - log, - message_proto::*, - protobuf::parse_from_bytes, - tokio::{net::UdpSocket, stream::StreamExt}, - udp::FramedSocket, - AddrMangle, ResultType, + bytes::BytesMut, log, message_proto::*, protobuf::parse_from_bytes, tokio::net::UdpSocket, + udp::FramedSocket, AddrMangle, ResultType, }; use std::{collections::HashMap, net::SocketAddr}; @@ -86,7 +81,7 @@ impl RendezvousServer { mod tests { use super::*; use hbb_common::tokio; - use std::time::Duration; + use std::io::{Error, ErrorKind}; #[allow(unused_must_use)] #[tokio::main] @@ -100,74 +95,72 @@ mod tests { let addr_server = format!("127.0.0.1:{}", port_server); let f1 = RendezvousServer::start(&addr_server); let addr_server = addr_server.parse().unwrap(); - let f2 = async { - // B register it to server - let socket_b = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - let local_addr_b = socket_b.local_addr().unwrap(); - let mut socket_b = FramedSocket::new(socket_b); - let mut msg_out = Message::new(); - msg_out.set_register_peer(RegisterPeer { - hbb_addr: "123".to_string(), - ..Default::default() - }); - socket_b.send(&msg_out, addr_server).await; - - // A send punch request to server - let socket_a = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - let local_addr_a = socket_a.local_addr().unwrap(); - let mut socket_a = FramedSocket::new(socket_a); - msg_out.set_punch_hole_request(PunchHoleRequest { - hbb_addr: "123".to_string(), + let f2 = punch_hole(addr_server); + tokio::try_join!(f1, f2); + } + + async fn punch_hole(addr_server: SocketAddr) -> ResultType<()> { + // B register it to server + let socket_b = UdpSocket::bind("127.0.0.1:0").await?; + let local_addr_b = socket_b.local_addr().unwrap(); + let mut socket_b = FramedSocket::new(socket_b); + let mut msg_out = Message::new(); + msg_out.set_register_peer(RegisterPeer { + hbb_addr: "123".to_string(), + ..Default::default() + }); + socket_b.send(&msg_out, addr_server).await?; + + // A send punch request to server + let socket_a = UdpSocket::bind("127.0.0.1:0").await?; + let local_addr_a = socket_a.local_addr().unwrap(); + let mut socket_a = FramedSocket::new(socket_a); + msg_out.set_punch_hole_request(PunchHoleRequest { + hbb_addr: "123".to_string(), + ..Default::default() + }); + socket_a.send(&msg_out, addr_server).await?; + + println!( + "A {:?} request punch hole to B {:?} via server {:?}", + local_addr_a, local_addr_b, addr_server, + ); + + // on B side, responsed to A's punch request forwarded from server + if let Some(Ok((bytes, addr))) = socket_b.next_timeout(1000).await { + assert_eq!(addr_server, addr); + let msg_in = parse_from_bytes::<Message>(&bytes)?; + let remote_addr_a = AddrMangle::decode(&msg_in.get_punch_hole().socket_addr[..]); + assert_eq!(local_addr_a, remote_addr_a); + + // B punch A + socket_b + .get_mut() + .send_to(&b"SYN"[..], &remote_addr_a) + .await?; + + msg_out.set_punch_hole_sent(PunchHoleSent { + socket_addr: AddrMangle::encode(&remote_addr_a), ..Default::default() }); - socket_a.send(&msg_out, addr_server).await; - - println!( - "A {:?} request punch hole to B {:?} via server {:?}", - local_addr_a, local_addr_b, addr_server, - ); - - // on B side, responsed to A's punch request forwarded from server - if let Ok(Some(Ok((bytes, addr)))) = - tokio::time::timeout(Duration::from_millis(1000), socket_b.next()).await - { - assert_eq!(addr_server, addr); - let msg_in = parse_from_bytes::<Message>(&bytes).unwrap(); - let remote_addr_a = AddrMangle::decode(&msg_in.get_punch_hole().socket_addr[..]); - assert_eq!(local_addr_a, remote_addr_a); - - // B punch A - socket_b - .get_mut() - .send_to(&b"SYN"[..], &remote_addr_a) - .await; - - msg_out.set_punch_hole_sent(PunchHoleSent { - socket_addr: AddrMangle::encode(&remote_addr_a), - ..Default::default() - }); - socket_b.send(&msg_out, addr_server).await; - } + socket_b.send(&msg_out, addr_server).await?; + } else { + panic!("failed"); + } - // on A side - socket_a.next().await; // skip "SYN" - if let Ok(Some(Ok((bytes, addr)))) = - tokio::time::timeout(Duration::from_millis(1000), socket_a.next()).await - { - assert_eq!(addr_server, addr); - let msg_in = parse_from_bytes::<Message>(&bytes).unwrap(); - let remote_addr_b = - AddrMangle::decode(&msg_in.get_punch_hole_response().socket_addr[..]); - assert_eq!(local_addr_b, remote_addr_b); - } + // on A side + socket_a.next().await; // skip "SYN" + if let Some(Ok((bytes, addr))) = socket_a.next_timeout(1000).await { + assert_eq!(addr_server, addr); + let msg_in = parse_from_bytes::<Message>(&bytes)?; + let remote_addr_b = + AddrMangle::decode(&msg_in.get_punch_hole_response().socket_addr[..]); + assert_eq!(local_addr_b, remote_addr_b); + } else { + panic!("failed"); + } - if true { - Err(Box::new(simple_error::SimpleError::new("done"))) - } else { - Ok(()) - } - }; - tokio::try_join!(f1, f2); + Err(Box::new(Error::new(ErrorKind::Other, "done"))) } #[test] |