From 875570e040ff5e0ef055e99aa32054687faa9954 Mon Sep 17 00:00:00 2001 From: fufesou Date: Tue, 4 Jan 2022 00:44:50 +0800 Subject: [PATCH] refactor udp framed Signed-off-by: fufesou --- Cargo.lock | 4 - libs/hbb_common/Cargo.toml | 4 - libs/hbb_common/src/config.rs | 13 +++ libs/hbb_common/src/lib.rs | 2 - libs/hbb_common/src/socket_client.rs | 54 ++++----- libs/hbb_common/src/udp.rs | 167 ++++++--------------------- src/common.rs | 52 ++------- src/rendezvous_mediator.rs | 145 ++++++++++------------- 8 files changed, 140 insertions(+), 301 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 16159ab24..c5afa5031 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1553,7 +1553,6 @@ name = "hbb_common" version = "0.1.0" dependencies = [ "anyhow", - "async-trait", "bytes", "confy", "directories-next", @@ -1561,13 +1560,10 @@ dependencies = [ "env_logger 0.9.0", "filetime", "futures", - "futures-core", - "futures-sink", "futures-util", "lazy_static", "log", "mac_address", - "pin-project", "protobuf", "protobuf-codegen-pure", "quinn", diff --git a/libs/hbb_common/Cargo.toml b/libs/hbb_common/Cargo.toml index dbebe8206..4dc7dcf86 100644 --- a/libs/hbb_common/Cargo.toml +++ b/libs/hbb_common/Cargo.toml @@ -29,10 +29,6 @@ dirs-next = "2.0" filetime = "0.2" sodiumoxide = "0.2" tokio-socks = { git = "https://github.com/fufesou/tokio-socks" } -futures-core = "0.3" -futures-sink = "0.3" -async-trait = "0.1" -pin-project = "1" [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] mac_address = "1.1" diff --git a/libs/hbb_common/src/config.rs b/libs/hbb_common/src/config.rs index fb558eef1..49b12f4b0 100644 --- a/libs/hbb_common/src/config.rs +++ b/libs/hbb_common/src/config.rs @@ -55,6 +55,12 @@ pub const RENDEZVOUS_SERVERS: &'static [&'static str] = &[ pub const RENDEZVOUS_PORT: i32 = 21116; pub const RELAY_PORT: i32 = 21117; +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum NetworkType { + Direct, + ProxySocks, +} + #[derive(Debug, Default, Serialize, Deserialize, Clone)] pub struct Config { #[serde(default)] @@ -642,6 +648,13 @@ impl Config { pub fn get_socks() -> Option { CONFIG2.read().unwrap().socks.clone() } + + pub fn get_network_type() -> NetworkType { + match &CONFIG2.read().unwrap().socks { + None => NetworkType::Direct, + Some(_) => NetworkType::ProxySocks, + } + } } const PEERS: &str = "peers"; diff --git a/libs/hbb_common/src/lib.rs b/libs/hbb_common/src/lib.rs index 4bc90d198..dd685f77b 100644 --- a/libs/hbb_common/src/lib.rs +++ b/libs/hbb_common/src/lib.rs @@ -24,8 +24,6 @@ pub mod bytes_codec; #[cfg(feature = "quic")] pub mod quic; pub use anyhow::{self, bail}; -pub use futures_core; -pub use futures_sink; pub use futures_util; pub mod config; pub mod fs; diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs index 596a96eaf..2fd9bcff1 100644 --- a/libs/hbb_common/src/socket_client.rs +++ b/libs/hbb_common/src/socket_client.rs @@ -1,22 +1,21 @@ use crate::{ - config::{Config, Socks5Server}, + config::{Config, NetworkType}, tcp::FramedStream, - udp::{FramedSocket, UdpFramedWrapper}, + udp::FramedSocket, ResultType, }; use anyhow::{bail, Context}; use std::net::SocketAddr; use tokio::net::ToSocketAddrs; -use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr}; -use tokio_util::{codec::BytesCodec, udp::UdpFramed}; +use tokio_socks::IntoTargetAddr; -pub fn get_socks5_conf() -> Option { - // Config::set_socks(Some(Socks5Server { - // proxy: "139.186.136.143:1080".to_owned(), - // ..Default::default() - // })); - Config::get_socks() -} +// fn get_socks5_conf() -> Option { +// // Config::set_socks(Some(Socks5Server { +// // proxy: "139.186.136.143:1080".to_owned(), +// // ..Default::default() +// // })); +// Config::get_socks() +// } pub async fn connect_tcp<'t, T: IntoTargetAddr<'t>>( target: T, @@ -25,7 +24,7 @@ pub async fn connect_tcp<'t, T: IntoTargetAddr<'t>>( ) -> ResultType { let target_addr = target.into_target_addr()?; - if let Some(conf) = get_socks5_conf() { + if let Some(conf) = Config::get_socks() { FramedStream::connect( conf.proxy.as_str(), target_addr, @@ -48,26 +47,13 @@ pub async fn connect_tcp<'t, T: IntoTargetAddr<'t>>( } } -// TODO: merge connect_udp and connect_udp_socks5 -pub async fn connect_udp_socket( - local: T1, -) -> ResultType<( - FramedSocket>>, - Option, -)> { - Ok((FramedSocket::new(local).await?, None)) -} - -pub async fn connect_udp_socks5<'t, T1: IntoTargetAddr<'t>, T2: ToSocketAddrs>( +pub async fn connect_udp<'t, T1: IntoTargetAddr<'t>, T2: ToSocketAddrs>( target: T1, local: T2, - socks5: &Option, ms_timeout: u64, -) -> ResultType<( - FramedSocket>, - Option, -)> { - match socks5 { +) -> ResultType<(FramedSocket, Option)> { + match Config::get_socks() { + None => Ok((FramedSocket::new(local).await?, None)), Some(conf) => { let (socket, addr) = FramedSocket::connect( conf.proxy.as_str(), @@ -80,8 +66,12 @@ pub async fn connect_udp_socks5<'t, T1: IntoTargetAddr<'t>, T2: ToSocketAddrs>( .await?; Ok((socket, Some(addr))) } - None => { - bail!("Nil socks5 server config") - } + } +} + +pub async fn reconnect_udp(local: T) -> ResultType> { + match Config::get_network_type() { + NetworkType::Direct => Ok(Some(FramedSocket::new(local).await?)), + _ => Ok(None), } } diff --git a/libs/hbb_common/src/udp.rs b/libs/hbb_common/src/udp.rs index 8203c47d0..719cea076 100644 --- a/libs/hbb_common/src/udp.rs +++ b/libs/hbb_common/src/udp.rs @@ -2,59 +2,16 @@ use crate::{bail, ResultType}; use anyhow::anyhow; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; -use futures_core::Stream; -use futures_sink::Sink; -use pin_project::pin_project; use protobuf::Message; use socket2::{Domain, Socket, Type}; -use std::{ - net::SocketAddr, - ops::{Deref, DerefMut}, - pin::Pin, - task::{Context, Poll}, -}; +use std::net::SocketAddr; use tokio::net::{ToSocketAddrs, UdpSocket}; -use tokio_socks::{ - udp::{Socks5UdpFramed, Socks5UdpMessage}, - IntoTargetAddr, TargetAddr, ToProxyAddrs, -}; +use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr, TargetAddr, ToProxyAddrs}; use tokio_util::{codec::BytesCodec, udp::UdpFramed}; -pub struct FramedSocket(F); - -#[pin_project] -pub struct UdpFramedWrapper(#[pin] F); - -pub trait BytesMutGetter<'a> { - fn get_bytes_mut(&'a self) -> &'a BytesMut; -} - -impl Deref for FramedSocket { - type Target = F; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for FramedSocket { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl Deref for UdpFramedWrapper { - type Target = F; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for UdpFramedWrapper { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } +pub enum FramedSocket { + Direct(UdpFramed), + ProxySocks(Socks5UdpFramed), } fn new_socket(addr: SocketAddr, reuse: bool) -> Result { @@ -74,29 +31,24 @@ fn new_socket(addr: SocketAddr, reuse: bool) -> Result { Ok(socket) } -impl FramedSocket>> { +impl FramedSocket { pub async fn new(addr: T) -> ResultType { let socket = UdpSocket::bind(addr).await?; - Ok(Self(UdpFramedWrapper(UdpFramed::new( - socket, - BytesCodec::new(), - )))) + Ok(Self::Direct(UdpFramed::new(socket, BytesCodec::new()))) } #[allow(clippy::never_loop)] pub async fn new_reuse(addr: T) -> ResultType { for addr in addr.to_socket_addrs()? { let socket = new_socket(addr, true)?.into_udp_socket(); - return Ok(Self(UdpFramedWrapper(UdpFramed::new( + return Ok(Self::Direct(UdpFramed::new( UdpSocket::from_std(socket)?, BytesCodec::new(), - )))); + ))); } bail!("could not resolve to any address"); } -} -impl FramedSocket> { pub async fn connect<'a, 't, P: ToProxyAddrs, T1: IntoTargetAddr<'t>, T2: ToSocketAddrs>( proxy: P, target: T1, @@ -134,43 +86,48 @@ impl FramedSocket> { framed.local_addr().unwrap(), &addr ); - Ok((Self(UdpFramedWrapper(framed)), addr)) - } -} - -// TODO: simplify this constraint -impl FramedSocket -where - F: Unpin + Stream + Sink<(Bytes, SocketAddr)>, - >::Error: Sync + Send + std::error::Error + 'static, -{ - pub async fn new_with(self) -> ResultType { - Ok(self) + Ok((Self::ProxySocks(framed), addr)) } #[inline] pub async fn send(&mut self, msg: &impl Message, addr: SocketAddr) -> ResultType<()> { - self.0 - .send((Bytes::from(msg.write_to_bytes().unwrap()), addr)) - .await?; + let send_data = (Bytes::from(msg.write_to_bytes().unwrap()), addr); + let _ = match self { + Self::Direct(f) => f.send(send_data).await?, + Self::ProxySocks(f) => f.send(send_data).await?, + }; Ok(()) } #[inline] pub async fn send_raw(&mut self, msg: &'static [u8], addr: SocketAddr) -> ResultType<()> { - self.0.send((Bytes::from(msg), addr)).await?; + let _ = match self { + Self::Direct(f) => f.send((Bytes::from(msg), addr)).await?, + Self::ProxySocks(f) => f.send((Bytes::from(msg), addr)).await?, + }; Ok(()) } #[inline] - pub async fn next(&mut self) -> Option<::Item> { - self.0.next().await + pub async fn next(&mut self) -> Option> { + match self { + Self::Direct(f) => match f.next().await { + Some(Ok((data, addr))) => Some(Ok((data, addr))), + Some(Err(e)) => Some(Err(anyhow!(e))), + None => None, + }, + Self::ProxySocks(f) => match f.next().await { + Some(Ok((data, addr))) => Some(Ok((data.data, addr))), + Some(Err(e)) => Some(Err(anyhow!(e))), + None => None, + }, + } } #[inline] - pub async fn next_timeout(&mut self, ms: u64) -> Option<::Item> { + pub async fn next_timeout(&mut self, ms: u64) -> Option> { if let Ok(res) = - tokio::time::timeout(std::time::Duration::from_millis(ms), self.0.next()).await + tokio::time::timeout(std::time::Duration::from_millis(ms), self.next()).await { res } else { @@ -178,59 +135,3 @@ where } } } - -impl<'a> BytesMutGetter<'a> for BytesMut { - fn get_bytes_mut(&'a self) -> &'a BytesMut { - self - } -} - -impl<'a> BytesMutGetter<'a> for Socks5UdpMessage { - fn get_bytes_mut(&'a self) -> &'a BytesMut { - &self.data - } -} - -impl Stream for UdpFramedWrapper -where - F: Stream>, - for<'b> M: BytesMutGetter<'b> + std::fmt::Debug, - E: std::error::Error + Into, -{ - type Item = ResultType<(BytesMut, SocketAddr)>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project().0.poll_next(cx) { - Poll::Ready(Some(Ok((msg, addr)))) => { - Poll::Ready(Some(Ok((msg.get_bytes_mut().clone(), addr)))) - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(anyhow!(e)))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -impl Sink<(Bytes, SocketAddr)> for UdpFramedWrapper -where - F: Sink<(Bytes, SocketAddr)>, -{ - type Error = >::Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().0.poll_ready(cx) - } - - fn start_send(self: Pin<&mut Self>, item: (Bytes, SocketAddr)) -> Result<(), Self::Error> { - self.project().0.start_send(item) - } - - #[allow(unused_mut)] - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().0.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().0.poll_close(cx) - } -} diff --git a/src/common.rs b/src/common.rs index 75a619c31..dcef673ed 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,27 +1,20 @@ -use std::net::SocketAddr; - pub use arboard::Clipboard as ClipboardContext; use hbb_common::{ allow_err, anyhow::bail, - bytes::{Bytes, BytesMut}, compress::{compress as compress_func, decompress}, - config::{Config, COMPRESS_LEVEL, RENDEZVOUS_TIMEOUT}, - futures_core::Stream, - futures_sink::Sink, + config::{Config, NetworkType, COMPRESS_LEVEL, RENDEZVOUS_TIMEOUT}, log, message_proto::*, protobuf::Message as _, protobuf::ProtobufEnum, rendezvous_proto::*, - sleep, socket_client, tokio, - udp::FramedSocket, - ResultType, + sleep, socket_client, tokio, ResultType, }; #[cfg(any(target_os = "android", target_os = "ios", feature = "cli"))] use hbb_common::{config::RENDEZVOUS_PORT, futures::future::join_all}; use std::{ - future::Future, + net::SocketAddr, sync::{Arc, Mutex}, }; @@ -274,7 +267,9 @@ async fn test_nat_type_() -> ResultType { RENDEZVOUS_TIMEOUT, ) .await?; - addr = socket.local_addr(); + if Config::get_network_type() == NetworkType::Direct { + addr = socket.local_addr(); + } socket.send(&msg_out).await?; if let Some(Ok(bytes)) = socket.next_timeout(3000).await { if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { @@ -448,35 +443,12 @@ async fn _check_software_update() -> hbb_common::ResultType<()> { sleep(3.).await; let rendezvous_server = get_rendezvous_server(1_000).await; - let socks5_conf = socket_client::get_socks5_conf(); - if socks5_conf.is_some() { - let conn_fn = |bind_addr: SocketAddr| async move { - socket_client::connect_udp_socks5( - rendezvous_server, - bind_addr, - &socks5_conf, - RENDEZVOUS_TIMEOUT, - ) - .await - }; - _inner_check_software_update(conn_fn, rendezvous_server).await - } else { - _inner_check_software_update(socket_client::connect_udp_socket, rendezvous_server).await - } -} - -pub async fn _inner_check_software_update<'a, F, Fut, Frm>( - conn_fn: F, - rendezvous_server: SocketAddr, -) -> ResultType<()> -where - F: FnOnce(SocketAddr) -> Fut, - Fut: Future, Option)>>, - Frm: Unpin + Stream> + Sink<(Bytes, SocketAddr)>, - >::Error: Sync + Send + std::error::Error + 'static, -{ - sleep(3.).await; - let (mut socket, _) = conn_fn(Config::get_any_listen_addr()).await?; + let (mut socket, _) = socket_client::connect_udp( + rendezvous_server, + Config::get_any_listen_addr(), + RENDEZVOUS_TIMEOUT, + ) + .await?; let mut msg_out = RendezvousMessage::new(); msg_out.set_software_update(SoftwareUpdate { url: crate::VERSION.to_owned(), diff --git a/src/rendezvous_mediator.rs b/src/rendezvous_mediator.rs index ffa2d565a..dc27126c5 100644 --- a/src/rendezvous_mediator.rs +++ b/src/rendezvous_mediator.rs @@ -2,11 +2,8 @@ use crate::server::{check_zombie, new as new_server, ServerPtr}; use hbb_common::{ allow_err, anyhow::bail, - bytes::{Bytes, BytesMut}, config::{Config, RENDEZVOUS_PORT, RENDEZVOUS_TIMEOUT}, futures::future::join_all, - futures_core::Stream, - futures_sink::Sink, log, protobuf::Message as _, rendezvous_proto::*, @@ -19,7 +16,6 @@ use hbb_common::{ AddrMangle, ResultType, }; use std::{ - future::Future, net::SocketAddr, sync::{Arc, Mutex}, time::SystemTime, @@ -63,35 +59,36 @@ impl RendezvousMediator { let server = server.clone(); let servers = servers.clone(); futs.push(tokio::spawn(async move { - let socks5_conf = socket_client::get_socks5_conf(); - if socks5_conf.is_some() { - let target = format!("{}:{}", host, RENDEZVOUS_PORT); - let conn_fn = |bind_addr: SocketAddr| { - let target = target.clone(); - let conf_ref = &socks5_conf; - async move { - socket_client::connect_udp_socks5( - target, - bind_addr, - conf_ref, - RENDEZVOUS_TIMEOUT, - ) - .await - } - }; - allow_err!(Self::start(server, host, servers, conn_fn, true).await); - } else { - allow_err!( - Self::start( - server, - host, - servers, - socket_client::connect_udp_socket, - false, - ) - .await - ); - } + allow_err!(Self::start(server, host, servers).await); + // let socks5_conf = socket_client::get_socks5_conf(); + // if socks5_conf.is_some() { + // let target = format!("{}:{}", host, RENDEZVOUS_PORT); + // let conn_fn = |bind_addr: SocketAddr| { + // let target = target.clone(); + // let conf_ref = &socks5_conf; + // async move { + // socket_client::connect_udp_socks5( + // target, + // bind_addr, + // conf_ref, + // RENDEZVOUS_TIMEOUT, + // ) + // .await + // } + // }; + // allow_err!(Self::start(server, host, servers, conn_fn, true).await); + // } else { + // allow_err!( + // Self::start( + // server, + // host, + // servers, + // socket_client::connect_udp_socket, + // false, + // ) + // .await + // ); + // } })); } join_all(futs).await; @@ -100,19 +97,11 @@ impl RendezvousMediator { } } - pub async fn start<'a, F, Fut, Frm>( + pub async fn start( server: ServerPtr, host: String, rendezvous_servers: Vec, - conn_fn: F, - socks5: bool, - ) -> ResultType<()> - where - F: Fn(SocketAddr) -> Fut, - Fut: Future, Option)>>, - Frm: Unpin + Stream> + Sink<(Bytes, SocketAddr)>, - >::Error: Sync + Send + std::error::Error + 'static, - { + ) -> ResultType<()> { log::info!("start rendezvous mediator of {}", host); let host_prefix: String = host .split(".") @@ -132,12 +121,17 @@ impl RendezvousMediator { rendezvous_servers, last_id_pk_registry: "".to_owned(), }; - allow_err!(rz.dns_check()); + let mut host_addr = rz.addr; + allow_err!(rz.dns_check(&mut host_addr)); let bind_addr = Config::get_any_listen_addr(); - let (mut socket, target_addr) = conn_fn(bind_addr).await?; + let target = format!("{}:{}", host, RENDEZVOUS_PORT); + let (mut socket, target_addr) = + socket_client::connect_udp(target, bind_addr, RENDEZVOUS_TIMEOUT).await?; if let Some(addr) = target_addr { rz.addr = addr; + } else { + rz.addr = host_addr; } const TIMER_OUT: Duration = Duration::from_secs(1); let mut timer = interval(TIMER_OUT); @@ -254,20 +248,16 @@ impl RendezvousMediator { } if rz.addr.port() == 0 { // tcp is established to help connecting socks5 - if !socks5 { - allow_err!(rz.dns_check()); - if rz.addr.port() == 0 { - continue; - } else { - // have to do this for osx, to avoid "Can't assign requested address" - // when socket created before OS network ready - - let r = conn_fn(bind_addr).await?; - socket = r.0; - if let Some(addr) = r.1 { - rz.addr = addr; - } - } + allow_err!(rz.dns_check(&mut host_addr)); + if host_addr.port() == 0 { + continue; + } else { + // have to do this for osx, to avoid "Can't assign requested address" + // when socket created before OS network ready + if let Some(s) = socket_client::reconnect_udp(bind_addr).await? { + socket = s; + rz.addr = host_addr; + }; } } let now = SystemTime::now(); @@ -287,18 +277,13 @@ impl RendezvousMediator { Config::update_latency(&host, -1); old_latency = 0; if now.duration_since(last_dns_check).map(|d| d.as_millis() as i64).unwrap_or(0) > DNS_INTERVAL { - // tcp is established to help connecting socks5 - if !socks5 { - if let Ok(_) = rz.dns_check() { + if let Ok(_) = rz.dns_check(&mut host_addr) { // in some case of network reconnect (dial IP network), // old UDP socket not work any more after network recover - - let r = conn_fn(bind_addr).await?; - socket = r.0; - if let Some(addr) = r.1 { - rz.addr = addr; - } - } + if let Some(s) = socket_client::reconnect_udp(bind_addr).await? { + socket = s; + rz.addr = host_addr; + }; } last_dns_check = now; } @@ -314,8 +299,8 @@ impl RendezvousMediator { Ok(()) } - fn dns_check(&mut self) -> ResultType<()> { - self.addr = hbb_common::to_socket_addr(&crate::check_port(&self.host, RENDEZVOUS_PORT))?; + fn dns_check(&self, addr: &mut SocketAddr) -> ResultType<()> { + *addr = hbb_common::to_socket_addr(&crate::check_port(&self.host, RENDEZVOUS_PORT))?; log::debug!("Lookup dns of {}", self.host); Ok(()) } @@ -449,11 +434,7 @@ impl RendezvousMediator { Ok(()) } - async fn register_pk(&mut self, socket: &mut FramedSocket) -> ResultType<()> - where - Frm: Unpin + Stream + Sink<(Bytes, SocketAddr)>, - >::Error: Sync + Send + std::error::Error + 'static, - { + async fn register_pk(&mut self, socket: &mut FramedSocket) -> ResultType<()> { let mut msg_out = Message::new(); let pk = Config::get_key_pair().1; let uuid = if let Ok(id) = machine_uid::get() { @@ -474,11 +455,7 @@ impl RendezvousMediator { Ok(()) } - async fn handle_uuid_mismatch(&mut self, socket: &mut FramedSocket) -> ResultType<()> - where - Frm: Unpin + Stream + Sink<(Bytes, SocketAddr)>, - >::Error: Sync + Send + std::error::Error + 'static, - { + async fn handle_uuid_mismatch(&mut self, socket: &mut FramedSocket) -> ResultType<()> { if self.last_id_pk_registry != Config::get_id() { return Ok(()); } @@ -496,11 +473,7 @@ impl RendezvousMediator { self.register_pk(socket).await } - async fn register_peer(&mut self, socket: &mut FramedSocket) -> ResultType<()> - where - Frm: Unpin + Stream + Sink<(Bytes, SocketAddr)>, - >::Error: Sync + Send + std::error::Error + 'static, - { + async fn register_peer(&mut self, socket: &mut FramedSocket) -> ResultType<()> { if !SOLVING_PK_MISMATCH.lock().unwrap().is_empty() { return Ok(()); }