diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs index 667a5161e..b7cb13754 100644 --- a/libs/hbb_common/src/socket_client.rs +++ b/libs/hbb_common/src/socket_client.rs @@ -118,20 +118,33 @@ pub fn ipv4_to_ipv6(addr: String, ipv4: bool) -> String { addr } -async fn test_is_ipv4(target: &str) -> bool { +async fn test_target(target: &str) -> ResultType { if let Ok(Ok(s)) = super::timeout(1000, tokio::net::TcpStream::connect(target)).await { - return s.local_addr().map(|x| x.is_ipv4()).unwrap_or(true); + if let Ok(addr) = s.peer_addr() { + return Ok(addr); + } } - true + tokio::net::lookup_host(target) + .await? + .next() + .context(format!("Failed to look up host for {}", target)) } #[inline] -pub async fn new_udp_for(target: &str, ms_timeout: u64) -> ResultType { - new_udp( - Config::get_any_listen_addr(test_is_ipv4(target).await), - ms_timeout, - ) - .await +pub async fn new_udp_for( + target: &str, + ms_timeout: u64, +) -> ResultType<(FramedSocket, TargetAddr<'static>)> { + let (ipv4, target) = if NetworkType::Direct == Config::get_network_type() { + let addr = test_target(target).await?; + (addr.is_ipv4(), addr.into_target_addr()?) + } else { + (true, target.into_target_addr()?) + }; + Ok(( + new_udp(Config::get_any_listen_addr(ipv4), ms_timeout).await?, + target.to_owned(), + )) } async fn new_udp(local: T, ms_timeout: u64) -> ResultType { @@ -151,13 +164,18 @@ async fn new_udp(local: T, ms_timeout: u64) -> ResultType ResultType> { - match Config::get_network_type() { - NetworkType::Direct => Ok(Some( - FramedSocket::new(Config::get_any_listen_addr(test_is_ipv4(target).await)).await?, - )), - _ => Ok(None), +pub async fn rebind_udp_for( + target: &str, +) -> ResultType)>> { + if Config::get_network_type() != NetworkType::Direct { + return Ok(None); } + let addr = test_target(target).await?; + let v4 = addr.is_ipv4(); + Ok(Some(( + FramedSocket::new(Config::get_any_listen_addr(v4)).await?, + addr.into_target_addr()?.to_owned(), + ))) } #[cfg(test)] diff --git a/src/common.rs b/src/common.rs index 1ae9b7dbe..c28bbc3fc 100644 --- a/src/common.rs +++ b/src/common.rs @@ -520,7 +520,8 @@ async fn check_software_update_() -> hbb_common::ResultType<()> { sleep(3.).await; let rendezvous_server = format!("rs-sg.rustdesk.com:{}", config::RENDEZVOUS_PORT); - let mut socket = socket_client::new_udp_for(&rendezvous_server, RENDEZVOUS_TIMEOUT).await?; + let (mut socket, rendezvous_server) = + socket_client::new_udp_for(&rendezvous_server, RENDEZVOUS_TIMEOUT).await?; let mut msg_out = RendezvousMessage::new(); msg_out.set_software_update(SoftwareUpdate { diff --git a/src/main.rs b/src/main.rs index ca0bc2234..67ddb875f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,7 +38,7 @@ fn main() { "-p, --port-forward=[PORT-FORWARD-OPTIONS] 'Format: remote-id:local-port:remote-port[:remote-host]' -c, --connect=[REMOTE_ID] 'test only' -k, --key=[KEY] '' - -s, --server... 'Start server'", + -s, --server=[] 'Start server'", ); let matches = App::new("rustdesk") .version(crate::VERSION) diff --git a/src/rendezvous_mediator.rs b/src/rendezvous_mediator.rs index ca08172d3..ec70bdf84 100644 --- a/src/rendezvous_mediator.rs +++ b/src/rendezvous_mediator.rs @@ -18,13 +18,14 @@ use hbb_common::{ log, protobuf::Message as _, rendezvous_proto::*, - sleep, socket_client, + sleep, + socket_client::{self, is_ipv4}, tokio::{ self, select, time::{interval, Duration}, }, udp::FramedSocket, - AddrMangle, ResultType, TargetAddr, + AddrMangle, ResultType, }; use crate::server::{check_zombie, new as new_server, ServerPtr}; @@ -38,11 +39,10 @@ static SHOULD_EXIT: AtomicBool = AtomicBool::new(false); #[derive(Clone)] pub struct RendezvousMediator { - addr: String, + addr: hbb_common::tokio_socks::TargetAddr<'static>, host: String, host_prefix: String, last_id_pk_registry: String, - is_ipv4: bool, } impl RendezvousMediator { @@ -111,17 +111,15 @@ impl RendezvousMediator { } }) .unwrap_or(host.to_owned()); + let host = crate::check_port(&host, RENDEZVOUS_PORT); + let (mut socket, addr) = socket_client::new_udp_for(&host, RENDEZVOUS_TIMEOUT).await?; let mut rz = Self { - addr: crate::check_port(&host, RENDEZVOUS_PORT), - is_ipv4: false, + addr: addr, host: host.clone(), host_prefix, last_id_pk_registry: "".to_owned(), }; - let mut socket = socket_client::new_udp_for(&rz.addr, RENDEZVOUS_TIMEOUT).await?; - rz.is_ipv4 = socket.is_ipv4(); - const TIMER_OUT: Duration = Duration::from_secs(1); let mut timer = interval(TIMER_OUT); let mut last_timer: Option = None; @@ -253,9 +251,9 @@ impl RendezvousMediator { if last_dns_check.elapsed().as_millis() as i64 > DNS_INTERVAL { // in some case of network reconnect (dial IP network), // old UDP socket not work any more after network recover - if let Some(s) = socket_client::rebind_udp_for(&rz.addr).await? { + if let Some((s, addr)) = socket_client::rebind_udp_for(&rz.host).await? { socket = s; - rz.is_ipv4 = socket.is_ipv4(); + rz.addr = addr; } last_dns_check = Instant::now(); } @@ -301,8 +299,7 @@ impl RendezvousMediator { secure, ); - let mut socket = - socket_client::connect_tcp(self.addr.to_owned(), RENDEZVOUS_TIMEOUT).await?; + let mut socket = socket_client::connect_tcp(&*self.host, RENDEZVOUS_TIMEOUT).await?; let mut msg_out = Message::new(); let mut rr = RelayResponse { @@ -317,14 +314,21 @@ impl RendezvousMediator { } msg_out.set_relay_response(rr); socket.send(&msg_out).await?; - crate::create_relay_connection(server, relay_server, uuid, peer_addr, secure, self.is_ipv4) - .await; + crate::create_relay_connection( + server, + relay_server, + uuid, + peer_addr, + secure, + is_ipv4(&self.addr), + ) + .await; Ok(()) } async fn handle_intranet(&self, fla: FetchLocalAddr, server: ServerPtr) -> ResultType<()> { let relay_server = self.get_relay_server(fla.relay_server); - if !self.is_ipv4 { + if !is_ipv4(&self.addr) { // nat64, go relay directly, because current hbbs will crash if demangle ipv6 address let uuid = Uuid::new_v4().to_string(); return self @@ -340,8 +344,7 @@ impl RendezvousMediator { } let peer_addr = AddrMangle::decode(&fla.socket_addr); log::debug!("Handle intranet from {:?}", peer_addr); - let mut socket = - socket_client::connect_tcp(self.addr.to_owned(), RENDEZVOUS_TIMEOUT).await?; + let mut socket = socket_client::connect_tcp(&*self.host, RENDEZVOUS_TIMEOUT).await?; let local_addr = socket.local_addr(); let local_addr: SocketAddr = format!("{}:{}", local_addr.ip(), local_addr.port()).parse()?; @@ -380,8 +383,7 @@ impl RendezvousMediator { let peer_addr = AddrMangle::decode(&ph.socket_addr); log::debug!("Punch hole to {:?}", peer_addr); let mut socket = { - let socket = - socket_client::connect_tcp(self.addr.to_owned(), RENDEZVOUS_TIMEOUT).await?; + let socket = socket_client::connect_tcp(&*self.host, RENDEZVOUS_TIMEOUT).await?; let local_addr = socket.local_addr(); // key important here for punch hole to tell my gateway incoming peer is safe. // it can not be async here, because local_addr can not be reused, we must close the connection before use it again.