diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs index 72ab73f16..e09d38c3a 100644 --- a/libs/hbb_common/src/socket_client.rs +++ b/libs/hbb_common/src/socket_client.rs @@ -6,15 +6,25 @@ use crate::{ }; use anyhow::Context; use std::net::SocketAddr; -use tokio::net::ToSocketAddrs; +use std::net::ToSocketAddrs; use tokio_socks::{IntoTargetAddr, TargetAddr}; -fn to_socket_addr(host: &str) -> ResultType { - use std::net::ToSocketAddrs; - host.to_socket_addrs()? - .filter(|x| x.is_ipv4()) - .next() - .context("Failed to solve") +fn to_socket_addr(host: T) -> ResultType { + let mut addr_ipv4 = None; + let mut addr_ipv6 = None; + for addr in host.to_socket_addrs()? { + if addr.is_ipv4() && addr_ipv4.is_none() { + addr_ipv4 = Some(addr); + } + if addr.is_ipv6() && addr_ipv6.is_none() { + addr_ipv6 = Some(addr); + } + } + if let Some(addr) = addr_ipv4 { + Ok(addr) + } else { + addr_ipv6.context("Failed to solve") + } } pub fn get_target_addr(host: &str) -> ResultType> { @@ -44,15 +54,43 @@ pub fn test_if_valid_server(host: &str) -> String { } } -pub async fn connect_tcp<'t, T: IntoTargetAddr<'t>>( +pub trait IntoTargetAddr2<'a> { + /// Converts the value of self to a `TargetAddr`. + fn into_target_addr2(&self) -> ResultType>; +} + +impl<'a> IntoTargetAddr2<'a> for SocketAddr { + fn into_target_addr2(&self) -> ResultType> { + Ok(TargetAddr::Ip(*self)) + } +} + +impl<'a> IntoTargetAddr2<'a> for TargetAddr<'a> { + fn into_target_addr2(&self) -> ResultType> { + Ok(self.clone()) + } +} + +impl<'a> IntoTargetAddr2<'a> for String { + fn into_target_addr2(&self) -> ResultType> { + Ok(to_socket_addr(self)?.into_target_addr()?) + } +} + +impl<'a> IntoTargetAddr2<'a> for &str { + fn into_target_addr2(&self) -> ResultType> { + Ok(to_socket_addr(self)?.into_target_addr()?) + } +} + +pub async fn connect_tcp<'t, T: IntoTargetAddr2<'t> + std::fmt::Debug>( target: T, local: SocketAddr, ms_timeout: u64, ) -> ResultType { - let target_addr = target.into_target_addr()?; - + let target_addr = target.into_target_addr2()?; if let Some(conf) = Config::get_socks() { - FramedStream::connect( + return FramedStream::connect( conf.proxy.as_str(), target_addr, local, @@ -60,23 +98,21 @@ pub async fn connect_tcp<'t, T: IntoTargetAddr<'t>>( conf.password.as_str(), ms_timeout, ) - .await - } else { - let addr = std::net::ToSocketAddrs::to_socket_addrs(&target_addr)? - .filter(|x| x.is_ipv4()) - .next() - .context("Invalid target addr, no valid ipv4 address can be resolved.")?; - Ok(FramedStream::new(addr, local, ms_timeout).await?) + .await; } + let addr = ToSocketAddrs::to_socket_addrs(&target_addr)? + .next() + .context(format!("Invalid target addr: {:?}", target))?; + Ok(FramedStream::new(addr, local, ms_timeout).await?) } pub async fn new_udp(local: T, ms_timeout: u64) -> ResultType { match Config::get_socks() { - None => Ok(FramedSocket::new(local).await?), + None => Ok(FramedSocket::new(to_socket_addr(&local)?).await?), Some(conf) => { let socket = FramedSocket::new_proxy( conf.proxy.as_str(), - local, + to_socket_addr(local)?, conf.username.as_str(), conf.password.as_str(), ms_timeout, @@ -89,7 +125,17 @@ pub async fn new_udp(local: T, ms_timeout: u64) -> ResultType< pub async fn rebind_udp(local: T) -> ResultType> { match Config::get_network_type() { - NetworkType::Direct => Ok(Some(FramedSocket::new(local).await?)), + NetworkType::Direct => Ok(Some(FramedSocket::new(to_socket_addr(local)?).await?)), _ => Ok(None), } } +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_to_socket_addr() { + assert_eq!(to_socket_addr("127.0.0.1:8080").unwrap(), "127.0.0.1:8080".parse().unwrap()); + assert!(to_socket_addr("[ff::]:0").unwrap().is_ipv6()); + assert!(to_socket_addr("xx").is_err()); + } +} \ No newline at end of file diff --git a/libs/hbb_common/src/udp.rs b/libs/hbb_common/src/udp.rs index 3532dd1e0..1f5bf2637 100644 --- a/libs/hbb_common/src/udp.rs +++ b/libs/hbb_common/src/udp.rs @@ -49,7 +49,7 @@ impl FramedSocket { #[allow(clippy::never_loop)] pub async fn new_reuse(addr: T) -> ResultType { - for addr in addr.to_socket_addrs()?.filter(|x| x.is_ipv4()) { + for addr in addr.to_socket_addrs()? { let socket = new_socket(addr, true, 0)?.into_udp_socket(); return Ok(Self::Direct(UdpFramed::new( UdpSocket::from_std(socket)?, @@ -63,7 +63,7 @@ impl FramedSocket { addr: T, buf_size: usize, ) -> ResultType { - for addr in addr.to_socket_addrs()?.filter(|x| x.is_ipv4()) { + for addr in addr.to_socket_addrs()? { return Ok(Self::Direct(UdpFramed::new( UdpSocket::from_std(new_socket(addr, false, buf_size)?.into_udp_socket())?, BytesCodec::new(),