diff --git a/Cargo.lock b/Cargo.lock index 9af999253..16159ab24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1553,6 +1553,7 @@ name = "hbb_common" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "bytes", "confy", "directories-next", @@ -1560,10 +1561,13 @@ dependencies = [ "env_logger 0.9.0", "filetime", "futures", + "futures-core", + "futures-sink", "futures-util", "lazy_static", "log", "mac_address", + "pin-project", "protobuf", "protobuf-codegen-pure", "quinn", @@ -1574,6 +1578,7 @@ dependencies = [ "socket2 0.3.19", "sodiumoxide", "tokio", + "tokio-socks", "tokio-util", "toml", "winapi 0.3.9", @@ -2455,6 +2460,26 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58ad3879ad3baf4e44784bc6a718a8698867bb991f8ce24d1bcbe2cfb4c3a75e" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "744b6f092ba29c3650faf274db506afd39944f48420f6c86b17cfe0ee1cb36bb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.8" @@ -3643,6 +3668,22 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-socks" +version = "0.5.1" +source = "git+https://github.com/fufesou/tokio-socks#121a780c7e6a31c3aac70e7234f5c62eecaf0629" +dependencies = [ + "bytes", + "either", + "futures-core", + "futures-sink", + "futures-util", + "pin-project", + "thiserror", + "tokio", + "tokio-util", +] + [[package]] name = "tokio-util" version = "0.6.9" diff --git a/libs/hbb_common/Cargo.toml b/libs/hbb_common/Cargo.toml index c67c71647..dbebe8206 100644 --- a/libs/hbb_common/Cargo.toml +++ b/libs/hbb_common/Cargo.toml @@ -28,6 +28,11 @@ confy = { git = "https://github.com/open-trade/confy" } dirs-next = "2.0" filetime = "0.2" sodiumoxide = "0.2" +tokio-socks = { git = "https://github.com/fufesou/tokio-socks" } +futures-core = "0.3" +futures-sink = "0.3" +async-trait = "0.1" +pin-project = "1" [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies] mac_address = "1.1" diff --git a/libs/hbb_common/src/config.rs b/libs/hbb_common/src/config.rs index 26b8d160f..fb558eef1 100644 --- a/libs/hbb_common/src/config.rs +++ b/libs/hbb_common/src/config.rs @@ -71,6 +71,16 @@ pub struct Config { keys_confirmed: HashMap, } +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct Socks5Server { + #[serde(default)] + pub proxy: String, + #[serde(default)] + pub username: String, + #[serde(default)] + pub password: String, +} + // more variable configs #[derive(Debug, Default, Serialize, Deserialize, Clone)] pub struct Config2 { @@ -85,6 +95,9 @@ pub struct Config2 { #[serde(default)] serial: i32, + #[serde(default)] + socks: Option, + // the other scalar value must before this #[serde(default)] pub options: HashMap, @@ -619,6 +632,16 @@ impl Config { pub fn get_remote_id() -> String { CONFIG2.read().unwrap().remote_id.clone() } + + pub fn set_socks(socks: Option) { + let mut config = CONFIG2.write().unwrap(); + config.socks = socks; + config.store(); + } + + pub fn get_socks() -> Option { + CONFIG2.read().unwrap().socks.clone() + } } const PEERS: &str = "peers"; diff --git a/libs/hbb_common/src/lib.rs b/libs/hbb_common/src/lib.rs index dc0c3e3e2..4bc90d198 100644 --- a/libs/hbb_common/src/lib.rs +++ b/libs/hbb_common/src/lib.rs @@ -17,16 +17,20 @@ pub use tokio; pub use tokio_util; pub mod tcp; pub mod udp; +pub mod socket_client; pub use env_logger; pub use log; pub mod bytes_codec; #[cfg(feature = "quic")] pub mod quic; pub use anyhow::{self, bail}; +pub use futures_core; +pub use futures_sink; pub use futures_util; pub mod config; pub mod fs; pub use sodiumoxide; +pub use tokio_socks; #[cfg(feature = "quic")] pub type Stream = quic::Connection; diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs new file mode 100644 index 000000000..596a96eaf --- /dev/null +++ b/libs/hbb_common/src/socket_client.rs @@ -0,0 +1,87 @@ +use crate::{ + config::{Config, Socks5Server}, + tcp::FramedStream, + udp::{FramedSocket, UdpFramedWrapper}, + ResultType, +}; +use anyhow::{bail, Context}; +use std::net::SocketAddr; +use tokio::net::ToSocketAddrs; +use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr}; +use tokio_util::{codec::BytesCodec, udp::UdpFramed}; + +pub fn get_socks5_conf() -> Option { + // Config::set_socks(Some(Socks5Server { + // proxy: "139.186.136.143:1080".to_owned(), + // ..Default::default() + // })); + Config::get_socks() +} + +pub async fn connect_tcp<'t, T: IntoTargetAddr<'t>>( + target: T, + local: SocketAddr, + ms_timeout: u64, +) -> ResultType { + let target_addr = target.into_target_addr()?; + + if let Some(conf) = get_socks5_conf() { + FramedStream::connect( + conf.proxy.as_str(), + target_addr, + local, + conf.username.as_str(), + conf.password.as_str(), + ms_timeout, + ) + .await + } else { + let addrs: Vec = + std::net::ToSocketAddrs::to_socket_addrs(&target_addr)?.collect(); + if addrs.is_empty() { + bail!("Invalid target addr"); + }; + + FramedStream::new(addrs[0], local, ms_timeout) + .await + .with_context(|| "Failed to connect to rendezvous server") + } +} + +// TODO: merge connect_udp and connect_udp_socks5 +pub async fn connect_udp_socket( + local: T1, +) -> ResultType<( + FramedSocket>>, + Option, +)> { + Ok((FramedSocket::new(local).await?, None)) +} + +pub async fn connect_udp_socks5<'t, T1: IntoTargetAddr<'t>, T2: ToSocketAddrs>( + target: T1, + local: T2, + socks5: &Option, + ms_timeout: u64, +) -> ResultType<( + FramedSocket>, + Option, +)> { + match socks5 { + Some(conf) => { + let (socket, addr) = FramedSocket::connect( + conf.proxy.as_str(), + target, + local, + conf.username.as_str(), + conf.password.as_str(), + ms_timeout, + ) + .await?; + Ok((socket, Some(addr))) + } + None => { + bail!("Nil socks5 server config") + } + } +} diff --git a/libs/hbb_common/src/tcp.rs b/libs/hbb_common/src/tcp.rs index 4189963ba..9ee33ca5a 100644 --- a/libs/hbb_common/src/tcp.rs +++ b/libs/hbb_common/src/tcp.rs @@ -4,16 +4,31 @@ use futures::{SinkExt, StreamExt}; use protobuf::Message; use sodiumoxide::crypto::secretbox::{self, Key, Nonce}; use std::{ - io::{Error, ErrorKind}, + io::{self, Error, ErrorKind}, + net::SocketAddr, ops::{Deref, DerefMut}, + pin::Pin, + task::{Context, Poll}, }; -use tokio::net::{lookup_host, TcpListener, TcpSocket, TcpStream, ToSocketAddrs}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::{lookup_host, TcpListener, TcpSocket, ToSocketAddrs}, +}; +use tokio_socks::{tcp::Socks5Stream, IntoTargetAddr, ToProxyAddrs}; use tokio_util::codec::Framed; -pub struct FramedStream(Framed, Option<(Key, u64, u64)>, u64); +pub trait TcpStreamTrait: AsyncRead + AsyncWrite + Unpin {} +pub struct DynTcpStream(Box); + +pub struct FramedStream( + Framed, + SocketAddr, + Option<(Key, u64, u64)>, + u64, +); impl Deref for FramedStream { - type Target = Framed; + type Target = Framed; fn deref(&self) -> &Self::Target { &self.0 @@ -26,6 +41,20 @@ impl DerefMut for FramedStream { } } +impl Deref for DynTcpStream { + type Target = Box; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DynTcpStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + fn new_socket(addr: std::net::SocketAddr, reuse: bool) -> Result { let socket = match addr { std::net::SocketAddr::V4(..) => TcpSocket::new_v4()?, @@ -44,8 +73,8 @@ fn new_socket(addr: std::net::SocketAddr, reuse: bool) -> Result( - remote_addr: T, + pub async fn new( + remote_addr: T1, local_addr: T2, ms_timeout: u64, ) -> ResultType { @@ -56,27 +85,86 @@ impl FramedStream { new_socket(local_addr, true)?.connect(remote_addr), ) .await??; - return Ok(Self(Framed::new(stream, BytesCodec::new()), None, 0)); + let addr = stream.local_addr()?; + return Ok(Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )); } } bail!("could not resolve to any address"); } - pub fn set_send_timeout(&mut self, ms: u64) { - self.2 = ms; + pub async fn connect<'a, 't, P, T1, T2>( + proxy: P, + target: T1, + local: T2, + username: &'a str, + password: &'a str, + ms_timeout: u64, + ) -> ResultType + where + P: ToProxyAddrs, + T1: IntoTargetAddr<'t>, + T2: ToSocketAddrs, + { + if let Some(local) = lookup_host(&local).await?.next() { + if let Some(proxy) = proxy.to_proxy_addrs().next().await { + let stream = + super::timeout(ms_timeout, new_socket(local, true)?.connect(proxy?)).await??; + let stream = if username.trim().is_empty() { + super::timeout( + ms_timeout, + Socks5Stream::connect_with_socket(stream, target), + ) + .await?? + } else { + super::timeout( + ms_timeout, + Socks5Stream::connect_with_password_and_socket( + stream, target, username, password, + ), + ) + .await?? + }; + let addr = stream.local_addr()?; + return Ok(Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + )); + }; + }; + bail!("could not resolve to any address"); } - pub fn from(stream: TcpStream) -> Self { - Self(Framed::new(stream, BytesCodec::new()), None, 0) + pub fn local_addr(&self) -> SocketAddr { + self.1 + } + + pub fn set_send_timeout(&mut self, ms: u64) { + self.3 = ms; + } + + pub fn from(stream: impl TcpStreamTrait + Send + 'static, addr: SocketAddr) -> Self { + Self( + Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), + addr, + None, + 0, + ) } pub fn set_raw(&mut self) { self.0.codec_mut().set_raw(); - self.1 = None; + self.2 = None; } pub fn is_secured(&self) -> bool { - self.1.is_some() + self.2.is_some() } #[inline] @@ -87,7 +175,7 @@ impl FramedStream { #[inline] pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { let mut msg = msg; - if let Some(key) = self.1.as_mut() { + if let Some(key) = self.2.as_mut() { key.1 += 1; let nonce = Self::get_nonce(key.1); msg = secretbox::seal(&msg, &nonce, &key.0); @@ -98,8 +186,8 @@ impl FramedStream { #[inline] pub async fn send_bytes(&mut self, bytes: Bytes) -> ResultType<()> { - if self.2 > 0 { - super::timeout(self.2, self.0.send(bytes)).await??; + if self.3 > 0 { + super::timeout(self.3, self.0.send(bytes)).await??; } else { self.0.send(bytes).await?; } @@ -109,7 +197,7 @@ impl FramedStream { #[inline] pub async fn next(&mut self) -> Option> { let mut res = self.0.next().await; - if let Some(key) = self.1.as_mut() { + if let Some(key) = self.2.as_mut() { if let Some(Ok(bytes)) = res.as_mut() { key.2 += 1; let nonce = Self::get_nonce(key.2); @@ -137,7 +225,7 @@ impl FramedStream { } pub fn set_key(&mut self, key: Key) { - self.1 = Some((key, 0, 0)); + self.2 = Some((key, 0, 0)); } fn get_nonce(seqnum: u64) -> Nonce { @@ -161,3 +249,35 @@ pub async fn new_listener(addr: T, reuse: bool) -> ResultType< bail!("could not resolve to any address"); } } + +impl Unpin for DynTcpStream {} + +impl AsyncRead for DynTcpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf) + } +} + +impl AsyncWrite for DynTcpStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut self.0), cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx) + } +} + +impl TcpStreamTrait for R {} diff --git a/libs/hbb_common/src/udp.rs b/libs/hbb_common/src/udp.rs index 637de6218..8203c47d0 100644 --- a/libs/hbb_common/src/udp.rs +++ b/libs/hbb_common/src/udp.rs @@ -1,26 +1,62 @@ use crate::{bail, ResultType}; -use bytes::BytesMut; +use anyhow::anyhow; +use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; +use futures_core::Stream; +use futures_sink::Sink; +use pin_project::pin_project; use protobuf::Message; use socket2::{Domain, Socket, Type}; use std::{ - io::Error, net::SocketAddr, ops::{Deref, DerefMut}, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::net::{ToSocketAddrs, UdpSocket}; +use tokio_socks::{ + udp::{Socks5UdpFramed, Socks5UdpMessage}, + IntoTargetAddr, TargetAddr, ToProxyAddrs, }; -use tokio::{net::ToSocketAddrs, net::UdpSocket}; use tokio_util::{codec::BytesCodec, udp::UdpFramed}; -pub struct FramedSocket(UdpFramed); +pub struct FramedSocket(F); -impl Deref for FramedSocket { - type Target = UdpFramed; +#[pin_project] +pub struct UdpFramedWrapper(#[pin] F); + +pub trait BytesMutGetter<'a> { + fn get_bytes_mut(&'a self) -> &'a BytesMut; +} + +impl Deref for FramedSocket { + type Target = F; fn deref(&self) -> &Self::Target { &self.0 } } +impl DerefMut for FramedSocket { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Deref for UdpFramedWrapper { + type Target = F; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for UdpFramedWrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + fn new_socket(addr: SocketAddr, reuse: bool) -> Result { let socket = match addr { SocketAddr::V4(..) => Socket::new(Domain::ipv4(), Type::dgram(), None), @@ -38,50 +74,101 @@ fn new_socket(addr: SocketAddr, reuse: bool) -> Result { Ok(socket) } -impl DerefMut for FramedSocket { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl FramedSocket { +impl FramedSocket>> { pub async fn new(addr: T) -> ResultType { let socket = UdpSocket::bind(addr).await?; - Ok(Self(UdpFramed::new(socket, BytesCodec::new()))) + Ok(Self(UdpFramedWrapper(UdpFramed::new( + socket, + BytesCodec::new(), + )))) } #[allow(clippy::never_loop)] pub async fn new_reuse(addr: T) -> ResultType { for addr in addr.to_socket_addrs()? { - return Ok(Self(UdpFramed::new( - UdpSocket::from_std(new_socket(addr, true)?.into_udp_socket())?, + let socket = new_socket(addr, true)?.into_udp_socket(); + return Ok(Self(UdpFramedWrapper(UdpFramed::new( + UdpSocket::from_std(socket)?, BytesCodec::new(), - ))); + )))); } bail!("could not resolve to any address"); } +} + +impl FramedSocket> { + pub async fn connect<'a, 't, P: ToProxyAddrs, T1: IntoTargetAddr<'t>, T2: ToSocketAddrs>( + proxy: P, + target: T1, + local: T2, + username: &'a str, + password: &'a str, + ms_timeout: u64, + ) -> ResultType<(Self, SocketAddr)> { + let framed = if username.trim().is_empty() { + super::timeout( + ms_timeout, + Socks5UdpFramed::connect(proxy, target, Some(local)), + ) + .await?? + } else { + super::timeout( + ms_timeout, + Socks5UdpFramed::connect_with_password( + proxy, + target, + Some(local), + username, + password, + ), + ) + .await?? + }; + let addr = if let TargetAddr::Ip(c) = framed.target_addr() { + c + } else { + unreachable!() + }; + log::trace!( + "Socks5 udp connected, local addr: {}, target addr: {}", + framed.local_addr().unwrap(), + &addr + ); + Ok((Self(UdpFramedWrapper(framed)), addr)) + } +} + +// TODO: simplify this constraint +impl FramedSocket +where + F: Unpin + Stream + Sink<(Bytes, SocketAddr)>, + >::Error: Sync + Send + std::error::Error + 'static, +{ + pub async fn new_with(self) -> ResultType { + Ok(self) + } #[inline] pub async fn send(&mut self, msg: &impl Message, addr: SocketAddr) -> ResultType<()> { self.0 - .send((bytes::Bytes::from(msg.write_to_bytes().unwrap()), addr)) + .send((Bytes::from(msg.write_to_bytes().unwrap()), addr)) .await?; Ok(()) } #[inline] pub async fn send_raw(&mut self, msg: &'static [u8], addr: SocketAddr) -> ResultType<()> { - self.0.send((bytes::Bytes::from(msg), addr)).await?; + self.0.send((Bytes::from(msg), addr)).await?; Ok(()) } #[inline] - pub async fn next(&mut self) -> Option> { + pub async fn next(&mut self) -> Option<::Item> { self.0.next().await } #[inline] - pub async fn next_timeout(&mut self, ms: u64) -> Option> { + pub async fn next_timeout(&mut self, ms: u64) -> Option<::Item> { if let Ok(res) = tokio::time::timeout(std::time::Duration::from_millis(ms), self.0.next()).await { @@ -91,3 +178,59 @@ impl FramedSocket { } } } + +impl<'a> BytesMutGetter<'a> for BytesMut { + fn get_bytes_mut(&'a self) -> &'a BytesMut { + self + } +} + +impl<'a> BytesMutGetter<'a> for Socks5UdpMessage { + fn get_bytes_mut(&'a self) -> &'a BytesMut { + &self.data + } +} + +impl Stream for UdpFramedWrapper +where + F: Stream>, + for<'b> M: BytesMutGetter<'b> + std::fmt::Debug, + E: std::error::Error + Into, +{ + type Item = ResultType<(BytesMut, SocketAddr)>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project().0.poll_next(cx) { + Poll::Ready(Some(Ok((msg, addr)))) => { + Poll::Ready(Some(Ok((msg.get_bytes_mut().clone(), addr)))) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(anyhow!(e)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl Sink<(Bytes, SocketAddr)> for UdpFramedWrapper +where + F: Sink<(Bytes, SocketAddr)>, +{ + type Error = >::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().0.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: (Bytes, SocketAddr)) -> Result<(), Self::Error> { + self.project().0.start_send(item) + } + + #[allow(unused_mut)] + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().0.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().0.poll_close(cx) + } +} diff --git a/src/client.rs b/src/client.rs index c3096cb53..c37413ffb 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,8 +13,8 @@ use hbb_common::{ message_proto::*, protobuf::Message as _, rendezvous_proto::*, + socket_client, sodiumoxide::crypto::{box_, secretbox, sign}, - tcp::FramedStream, timeout, tokio::time::Duration, AddrMangle, ResultType, Stream, @@ -107,10 +107,10 @@ impl Client { let any_addr = Config::get_any_listen_addr(); let rendezvous_server = crate::get_rendezvous_server(1_000).await; log::info!("rendezvous server: {}", rendezvous_server); - let mut socket = FramedStream::new(rendezvous_server, any_addr, RENDEZVOUS_TIMEOUT) - .await - .with_context(|| "Failed to connect to rendezvous server")?; - let my_addr = socket.get_ref().local_addr()?; + + let mut socket = + socket_client::connect_tcp(rendezvous_server, any_addr, RENDEZVOUS_TIMEOUT).await?; + let my_addr = socket.local_addr(); let mut pk = Vec::new(); let mut relay_server = "".to_owned(); @@ -262,7 +262,8 @@ impl Client { } log::info!("peer address: {}, timeout: {}", peer, connect_timeout); let start = std::time::Instant::now(); - let mut conn = FramedStream::new(peer, local_addr, connect_timeout).await; + // NOTICE: Socks5 is be used event in intranet. Which may be not a good way. + let mut conn = socket_client::connect_tcp(peer, local_addr, connect_timeout).await; let direct = !conn.is_err(); if conn.is_err() { if !relay_server.is_empty() { @@ -393,9 +394,11 @@ impl Client { let mut uuid = "".to_owned(); for i in 1..=3 { // use different socket due to current hbbs implement requiring different nat address for each attempt - let mut socket = FramedStream::new(rendezvous_server, any_addr, RENDEZVOUS_TIMEOUT) - .await - .with_context(|| "Failed to connect to rendezvous server")?; + let mut socket = + socket_client::connect_tcp(rendezvous_server, any_addr, RENDEZVOUS_TIMEOUT) + .await + .with_context(|| "Failed to connect to rendezvous server")?; + let mut msg_out = RendezvousMessage::new(); uuid = Uuid::new_v4().to_string(); log::info!( @@ -438,7 +441,7 @@ impl Client { relay_server: String, conn_type: ConnType, ) -> ResultType { - let mut conn = FramedStream::new( + let mut conn = socket_client::connect_tcp( crate::check_port(relay_server, RELAY_PORT), Config::get_any_listen_addr(), CONNECT_TIMEOUT, diff --git a/src/common.rs b/src/common.rs index 921516311..75a619c31 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,20 +1,29 @@ +use std::net::SocketAddr; + pub use arboard::Clipboard as ClipboardContext; use hbb_common::{ - allow_err, bail, + allow_err, + anyhow::bail, + bytes::{Bytes, BytesMut}, compress::{compress as compress_func, decompress}, config::{Config, COMPRESS_LEVEL, RENDEZVOUS_TIMEOUT}, + futures_core::Stream, + futures_sink::Sink, log, message_proto::*, protobuf::Message as _, protobuf::ProtobufEnum, rendezvous_proto::*, - sleep, - tcp::FramedStream, - tokio, ResultType, + sleep, socket_client, tokio, + udp::FramedSocket, + ResultType, }; #[cfg(any(target_os = "android", target_os = "ios", feature = "cli"))] use hbb_common::{config::RENDEZVOUS_PORT, futures::future::join_all}; -use std::sync::{Arc, Mutex}; +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; pub const CLIPBOARD_NAME: &'static str = "clipboard"; pub const CLIPBOARD_INTERVAL: u64 = 333; @@ -259,13 +268,13 @@ async fn test_nat_type_() -> ResultType { let mut port2 = 0; let mut addr = Config::get_any_listen_addr(); for i in 0..2 { - let mut socket = FramedStream::new( + let mut socket = socket_client::connect_tcp( if i == 0 { &server1 } else { &server2 }, addr, RENDEZVOUS_TIMEOUT, ) .await?; - addr = socket.get_ref().local_addr()?; + addr = socket.local_addr(); socket.send(&msg_out).await?; if let Some(Ok(bytes)) = socket.next_timeout(3000).await { if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { @@ -302,12 +311,12 @@ async fn test_nat_type_() -> ResultType { } #[cfg(any(target_os = "android", target_os = "ios"))] -pub async fn get_rendezvous_server(_ms_timeout: u64) -> std::net::SocketAddr { +pub async fn get_rendezvous_server(_ms_timeout: u64) -> SocketAddr { Config::get_rendezvous_server() } #[cfg(not(any(target_os = "android", target_os = "ios")))] -pub async fn get_rendezvous_server(ms_timeout: u64) -> std::net::SocketAddr { +pub async fn get_rendezvous_server(ms_timeout: u64) -> SocketAddr { crate::ipc::get_rendezvous_server(ms_timeout).await } @@ -330,7 +339,7 @@ async fn test_rendezvous_server_() { for host in servers { futs.push(tokio::spawn(async move { let tm = std::time::Instant::now(); - if FramedStream::new( + if socket_client::connect_tcp( &crate::check_port(&host, RENDEZVOUS_PORT), Config::get_any_listen_addr(), RENDEZVOUS_TIMEOUT, @@ -437,8 +446,37 @@ pub fn check_software_update() { #[tokio::main(flavor = "current_thread")] async fn _check_software_update() -> hbb_common::ResultType<()> { sleep(3.).await; + let rendezvous_server = get_rendezvous_server(1_000).await; - let mut socket = hbb_common::udp::FramedSocket::new(Config::get_any_listen_addr()).await?; + let socks5_conf = socket_client::get_socks5_conf(); + if socks5_conf.is_some() { + let conn_fn = |bind_addr: SocketAddr| async move { + socket_client::connect_udp_socks5( + rendezvous_server, + bind_addr, + &socks5_conf, + RENDEZVOUS_TIMEOUT, + ) + .await + }; + _inner_check_software_update(conn_fn, rendezvous_server).await + } else { + _inner_check_software_update(socket_client::connect_udp_socket, rendezvous_server).await + } +} + +pub async fn _inner_check_software_update<'a, F, Fut, Frm>( + conn_fn: F, + rendezvous_server: SocketAddr, +) -> ResultType<()> +where + F: FnOnce(SocketAddr) -> Fut, + Fut: Future, Option)>>, + Frm: Unpin + Stream> + Sink<(Bytes, SocketAddr)>, + >::Error: Sync + Send + std::error::Error + 'static, +{ + sleep(3.).await; + let (mut socket, _) = conn_fn(Config::get_any_listen_addr()).await?; let mut msg_out = RendezvousMessage::new(); msg_out.set_software_update(SoftwareUpdate { url: crate::VERSION.to_owned(), diff --git a/src/rendezvous_mediator.rs b/src/rendezvous_mediator.rs index 3719b353d..ffa2d565a 100644 --- a/src/rendezvous_mediator.rs +++ b/src/rendezvous_mediator.rs @@ -1,13 +1,16 @@ use crate::server::{check_zombie, new as new_server, ServerPtr}; use hbb_common::{ allow_err, + anyhow::bail, + bytes::{Bytes, BytesMut}, config::{Config, RENDEZVOUS_PORT, RENDEZVOUS_TIMEOUT}, futures::future::join_all, + futures_core::Stream, + futures_sink::Sink, log, protobuf::Message as _, rendezvous_proto::*, - sleep, - tcp::FramedStream, + sleep, socket_client, tokio::{ self, select, time::{interval, Duration}, @@ -16,6 +19,7 @@ use hbb_common::{ AddrMangle, ResultType, }; use std::{ + future::Future, net::SocketAddr, sync::{Arc, Mutex}, time::SystemTime, @@ -59,7 +63,35 @@ impl RendezvousMediator { let server = server.clone(); let servers = servers.clone(); futs.push(tokio::spawn(async move { - allow_err!(Self::start(server, host, servers).await); + let socks5_conf = socket_client::get_socks5_conf(); + if socks5_conf.is_some() { + let target = format!("{}:{}", host, RENDEZVOUS_PORT); + let conn_fn = |bind_addr: SocketAddr| { + let target = target.clone(); + let conf_ref = &socks5_conf; + async move { + socket_client::connect_udp_socks5( + target, + bind_addr, + conf_ref, + RENDEZVOUS_TIMEOUT, + ) + .await + } + }; + allow_err!(Self::start(server, host, servers, conn_fn, true).await); + } else { + allow_err!( + Self::start( + server, + host, + servers, + socket_client::connect_udp_socket, + false, + ) + .await + ); + } })); } join_all(futs).await; @@ -68,11 +100,19 @@ impl RendezvousMediator { } } - pub async fn start( + pub async fn start<'a, F, Fut, Frm>( server: ServerPtr, host: String, rendezvous_servers: Vec, - ) -> ResultType<()> { + conn_fn: F, + socks5: bool, + ) -> ResultType<()> + where + F: Fn(SocketAddr) -> Fut, + Fut: Future, Option)>>, + Frm: Unpin + Stream> + Sink<(Bytes, SocketAddr)>, + >::Error: Sync + Send + std::error::Error + 'static, + { log::info!("start rendezvous mediator of {}", host); let host_prefix: String = host .split(".") @@ -93,7 +133,12 @@ impl RendezvousMediator { last_id_pk_registry: "".to_owned(), }; allow_err!(rz.dns_check()); - let mut socket = FramedSocket::new(Config::get_any_listen_addr()).await?; + + let bind_addr = Config::get_any_listen_addr(); + let (mut socket, target_addr) = conn_fn(bind_addr).await?; + if let Some(addr) = target_addr { + rz.addr = addr; + } const TIMER_OUT: Duration = Duration::from_secs(1); let mut timer = interval(TIMER_OUT); let mut last_timer = SystemTime::UNIX_EPOCH; @@ -136,60 +181,68 @@ impl RendezvousMediator { } }; select! { - Some(Ok((bytes, _))) = socket.next() => { - if let Ok(msg_in) = Message::parse_from_bytes(&bytes) { - match msg_in.union { - Some(rendezvous_message::Union::register_peer_response(rpr)) => { - update_latency(); - if rpr.request_pk { - log::info!("request_pk received from {}", host); - allow_err!(rz.register_pk(&mut socket).await); - continue; - } - } - Some(rendezvous_message::Union::register_pk_response(rpr)) => { - update_latency(); - match rpr.result.enum_value_or_default() { - register_pk_response::Result::OK => { - Config::set_key_confirmed(true); - Config::set_host_key_confirmed(&rz.host_prefix, true); - *SOLVING_PK_MISMATCH.lock().unwrap() = "".to_owned(); + n = socket.next() => { + match n { + Some(Ok((bytes, _))) => { + if let Ok(msg_in) = Message::parse_from_bytes(&bytes) { + match msg_in.union { + Some(rendezvous_message::Union::register_peer_response(rpr)) => { + update_latency(); + if rpr.request_pk { + log::info!("request_pk received from {}", host); + allow_err!(rz.register_pk(&mut socket).await); + continue; + } } - register_pk_response::Result::UUID_MISMATCH => { - allow_err!(rz.handle_uuid_mismatch(&mut socket).await); + Some(rendezvous_message::Union::register_pk_response(rpr)) => { + update_latency(); + match rpr.result.enum_value_or_default() { + register_pk_response::Result::OK => { + Config::set_key_confirmed(true); + Config::set_host_key_confirmed(&rz.host_prefix, true); + *SOLVING_PK_MISMATCH.lock().unwrap() = "".to_owned(); + } + register_pk_response::Result::UUID_MISMATCH => { + allow_err!(rz.handle_uuid_mismatch(&mut socket).await); + } + _ => {} + } + } + Some(rendezvous_message::Union::punch_hole(ph)) => { + let rz = rz.clone(); + let server = server.clone(); + tokio::spawn(async move { + allow_err!(rz.handle_punch_hole(ph, server).await); + }); + } + Some(rendezvous_message::Union::request_relay(rr)) => { + let rz = rz.clone(); + let server = server.clone(); + tokio::spawn(async move { + allow_err!(rz.handle_request_relay(rr, server).await); + }); + } + Some(rendezvous_message::Union::fetch_local_addr(fla)) => { + let rz = rz.clone(); + let server = server.clone(); + tokio::spawn(async move { + allow_err!(rz.handle_intranet(fla, server).await); + }); + } + Some(rendezvous_message::Union::configure_update(cu)) => { + Config::set_option("rendezvous-servers".to_owned(), cu.rendezvous_servers.join(",")); + Config::set_serial(cu.serial); } _ => {} } + } else { + log::debug!("Non-protobuf message bytes received: {:?}", bytes); } - Some(rendezvous_message::Union::punch_hole(ph)) => { - let rz = rz.clone(); - let server = server.clone(); - tokio::spawn(async move { - allow_err!(rz.handle_punch_hole(ph, server).await); - }); - } - Some(rendezvous_message::Union::request_relay(rr)) => { - let rz = rz.clone(); - let server = server.clone(); - tokio::spawn(async move { - allow_err!(rz.handle_request_relay(rr, server).await); - }); - } - Some(rendezvous_message::Union::fetch_local_addr(fla)) => { - let rz = rz.clone(); - let server = server.clone(); - tokio::spawn(async move { - allow_err!(rz.handle_intranet(fla, server).await); - }); - } - Some(rendezvous_message::Union::configure_update(cu)) => { - Config::set_option("rendezvous-servers".to_owned(), cu.rendezvous_servers.join(",")); - Config::set_serial(cu.serial); - } - _ => {} - } - } else { - log::debug!("Non-protobuf message bytes received: {:?}", bytes); + }, + Some(Err(e)) => bail!("Failed to receive next {}", e), // maybe socks5 tcp disconnected + None => { + // unreachable!() + }, } }, _ = timer.tick() => { @@ -200,13 +253,21 @@ impl RendezvousMediator { break; } if rz.addr.port() == 0 { - allow_err!(rz.dns_check()); - if rz.addr.port() == 0 { - continue; - } else { - // have to do this for osx, to avoid "Can't assign requested address" - // when socket created before OS network ready - socket = FramedSocket::new(Config::get_any_listen_addr()).await?; + // tcp is established to help connecting socks5 + if !socks5 { + allow_err!(rz.dns_check()); + if rz.addr.port() == 0 { + continue; + } else { + // have to do this for osx, to avoid "Can't assign requested address" + // when socket created before OS network ready + + let r = conn_fn(bind_addr).await?; + socket = r.0; + if let Some(addr) = r.1 { + rz.addr = addr; + } + } } } let now = SystemTime::now(); @@ -226,10 +287,18 @@ impl RendezvousMediator { Config::update_latency(&host, -1); old_latency = 0; if now.duration_since(last_dns_check).map(|d| d.as_millis() as i64).unwrap_or(0) > DNS_INTERVAL { - if let Ok(_) = rz.dns_check() { + // tcp is established to help connecting socks5 + if !socks5 { + if let Ok(_) = rz.dns_check() { // in some case of network reconnect (dial IP network), // old UDP socket not work any more after network recover - socket = FramedSocket::new(Config::get_any_listen_addr()).await?; + + let r = conn_fn(bind_addr).await?; + socket = r.0; + if let Some(addr) = r.1 { + rz.addr = addr; + } + } } last_dns_check = now; } @@ -280,8 +349,14 @@ impl RendezvousMediator { uuid, secure, ); - let mut socket = - FramedStream::new(self.addr, Config::get_any_listen_addr(), RENDEZVOUS_TIMEOUT).await?; + + let mut socket = socket_client::connect_tcp( + format!("{}:{}", self.host, RENDEZVOUS_PORT), + Config::get_any_listen_addr(), + RENDEZVOUS_TIMEOUT, + ) + .await?; + let mut msg_out = Message::new(); let mut rr = RelayResponse { socket_addr, @@ -303,15 +378,15 @@ impl RendezvousMediator { async fn handle_intranet(&self, fla: FetchLocalAddr, server: ServerPtr) -> ResultType<()> { let peer_addr = AddrMangle::decode(&fla.socket_addr); log::debug!("Handle intranet from {:?}", peer_addr); - let (mut socket, port) = { - let socket = - FramedStream::new(self.addr, Config::get_any_listen_addr(), RENDEZVOUS_TIMEOUT) - .await?; - let port = socket.get_ref().local_addr()?.port(); - (socket, port) - }; - let local_addr = socket.get_ref().local_addr()?; - let local_addr: SocketAddr = format!("{}:{}", local_addr.ip(), port).parse()?; + let mut socket = socket_client::connect_tcp( + format!("{}:{}", self.host, RENDEZVOUS_PORT), + Config::get_any_listen_addr(), + RENDEZVOUS_TIMEOUT, + ) + .await?; + let local_addr = socket.local_addr(); + let local_addr: SocketAddr = + format!("{}:{}", local_addr.ip(), local_addr.port()).parse()?; let mut msg_out = Message::new(); let mut relay_server = Config::get_option("relay-server"); if relay_server.is_empty() { @@ -347,10 +422,14 @@ impl RendezvousMediator { let peer_addr = AddrMangle::decode(&ph.socket_addr); log::debug!("Punch hole to {:?}", peer_addr); let mut socket = { - let socket = - FramedStream::new(self.addr, Config::get_any_listen_addr(), RENDEZVOUS_TIMEOUT) - .await?; - allow_err!(FramedStream::new(peer_addr, socket.get_ref().local_addr()?, 300).await); + let socket = socket_client::connect_tcp( + format!("{}:{}", self.host, RENDEZVOUS_PORT), + Config::get_any_listen_addr(), + RENDEZVOUS_TIMEOUT, + ) + .await?; + let local_addr = socket.local_addr(); + allow_err!(socket_client::connect_tcp(peer_addr, local_addr, 300).await); socket }; let mut msg_out = Message::new(); @@ -370,7 +449,11 @@ impl RendezvousMediator { Ok(()) } - async fn register_pk(&mut self, socket: &mut FramedSocket) -> ResultType<()> { + async fn register_pk(&mut self, socket: &mut FramedSocket) -> ResultType<()> + where + Frm: Unpin + Stream + Sink<(Bytes, SocketAddr)>, + >::Error: Sync + Send + std::error::Error + 'static, + { let mut msg_out = Message::new(); let pk = Config::get_key_pair().1; let uuid = if let Ok(id) = machine_uid::get() { @@ -391,7 +474,11 @@ impl RendezvousMediator { Ok(()) } - async fn handle_uuid_mismatch(&mut self, socket: &mut FramedSocket) -> ResultType<()> { + async fn handle_uuid_mismatch(&mut self, socket: &mut FramedSocket) -> ResultType<()> + where + Frm: Unpin + Stream + Sink<(Bytes, SocketAddr)>, + >::Error: Sync + Send + std::error::Error + 'static, + { if self.last_id_pk_registry != Config::get_id() { return Ok(()); } @@ -409,7 +496,11 @@ impl RendezvousMediator { self.register_pk(socket).await } - async fn register_peer(&mut self, socket: &mut FramedSocket) -> ResultType<()> { + async fn register_peer(&mut self, socket: &mut FramedSocket) -> ResultType<()> + where + Frm: Unpin + Stream + Sink<(Bytes, SocketAddr)>, + >::Error: Sync + Send + std::error::Error + 'static, + { if !SOLVING_PK_MISMATCH.lock().unwrap().is_empty() { return Ok(()); } diff --git a/src/server.rs b/src/server.rs index 270b8075e..451c18e1b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,5 @@ use crate::ipc::Data; -pub use connection::*; +use connection::{ConnInner, Connection}; use hbb_common::{ allow_err, anyhow::{anyhow, Context}, @@ -11,8 +11,8 @@ use hbb_common::{ rendezvous_proto::*, sleep, sodiumoxide::crypto::{box_, secretbox, sign}, - tcp::FramedStream, timeout, tokio, ResultType, Stream, + socket_client, }; use service::{GenericService, Service, ServiceTmpl, Subscriber}; use std::{ @@ -61,7 +61,7 @@ pub fn new() -> ServerPtr { } async fn accept_connection_(server: ServerPtr, socket: Stream, secure: bool) -> ResultType<()> { - let local_addr = socket.get_ref().local_addr()?; + let local_addr = socket.local_addr(); drop(socket); // even we drop socket, below still may fail if not use reuse_addr, // there is TIME_WAIT before socket really released, so sometimes we @@ -69,7 +69,8 @@ async fn accept_connection_(server: ServerPtr, socket: Stream, secure: bool) -> let listener = new_listener(local_addr, true).await?; log::info!("Server listening on: {}", &listener.local_addr()?); if let Ok((stream, addr)) = timeout(CONNECT_TIMEOUT, listener.accept()).await? { - create_tcp_connection(server, Stream::from(stream), addr, secure).await?; + let stream_addr = stream.local_addr()?; + create_tcp_connection(server, Stream::from(stream, stream_addr), addr, secure).await?; } Ok(()) } @@ -183,8 +184,8 @@ async fn create_relay_connection_( peer_addr: SocketAddr, secure: bool, ) -> ResultType<()> { - let mut stream = FramedStream::new( - &crate::check_port(relay_server, RELAY_PORT), + let mut stream = socket_client::connect_tcp( + crate::check_port(relay_server, RELAY_PORT), Config::get_any_listen_addr(), CONNECT_TIMEOUT, )