diff --git a/libs/hbb_common/src/socket_client.rs b/libs/hbb_common/src/socket_client.rs index b7cb13754..dde237267 100644 --- a/libs/hbb_common/src/socket_client.rs +++ b/libs/hbb_common/src/socket_client.rs @@ -9,17 +9,54 @@ use std::net::SocketAddr; use tokio::net::ToSocketAddrs; use tokio_socks::{IntoTargetAddr, TargetAddr}; -pub fn test_if_valid_server(host: &str) -> String { - let mut host = host.to_owned(); - if !host.contains(":") { - host = format!("{}:{}", host, 0); +#[inline] +pub fn check_port(host: T, port: i32) -> String { + let host = host.to_string(); + if crate::is_ipv6_str(&host) { + if host.starts_with("[") { + return host; + } + return format!("[{}]:{}", host, port); } + if !host.contains(":") { + return format!("{}:{}", host, port); + } + return host; +} + +#[inline] +pub fn increase_port(host: T, offset: i32) -> String { + let host = host.to_string(); + if crate::is_ipv6_str(&host) { + if host.starts_with("[") { + let tmp: Vec<&str> = host.split("]:").collect(); + if tmp.len() == 2 { + let port: i32 = tmp[1].parse().unwrap_or(0); + if port > 0 { + return format!("{}]:{}", tmp[0], port + offset); + } + } + } + } else if host.contains(":") { + let tmp: Vec<&str> = host.split(":").collect(); + if tmp.len() == 2 { + let port: i32 = tmp[1].parse().unwrap_or(0); + if port > 0 { + return format!("{}:{}", tmp[0], port + offset); + } + } + } + return host; +} + +pub fn test_if_valid_server(host: &str) -> String { + let host = check_port(host, 0); use std::net::ToSocketAddrs; match Config::get_network_type() { NetworkType::Direct => match host.to_socket_addrs() { Err(err) => err.to_string(), - Ok(_) => "".to_owned(), + Ok(x) => "".to_owned(), }, NetworkType::ProxySocks => match &host.into_target_addr() { Err(err) => err.to_string(), @@ -216,4 +253,30 @@ mod tests { } assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()).await.is_err()); } + + #[test] + fn test_test_if_valid_server() { + assert!(!test_if_valid_server("a").is_empty()); + // on Linux, "1" is resolved to "0.0.0.1" + assert!(test_if_valid_server("1.1.1.1").is_empty()); + assert!(test_if_valid_server("1.1.1.1:1").is_empty()); + } + + #[test] + fn test_check_port() { + assert_eq!(check_port("[1:2]:12", 32), "[1:2]:12"); + assert_eq!(check_port("1:2", 32), "[1:2]:32"); + assert_eq!(check_port("z1:2", 32), "z1:2"); + assert_eq!(check_port("1.1.1.1", 32), "1.1.1.1:32"); + assert_eq!(check_port("1.1.1.1:32", 32), "1.1.1.1:32"); + assert_eq!(check_port("test.com:32", 0), "test.com:32"); + assert_eq!(increase_port("[1:2]:12", 1), "[1:2]:13"); + assert_eq!(increase_port("1.2.2.4:12", 1), "1.2.2.4:13"); + assert_eq!(increase_port("1.2.2.4", 1), "1.2.2.4"); + assert_eq!(increase_port("test.com", 1), "test.com"); + assert_eq!(increase_port("test.com:13", 4), "test.com:17"); + assert_eq!(increase_port("1:13", 4), "1:13"); + assert_eq!(increase_port("22:1:13", 4), "22:1:13"); + assert_eq!(increase_port("z1:2", 1), "z1:3"); + } } diff --git a/src/common.rs b/src/common.rs index 254c910e6..96a7763d0 100644 --- a/src/common.rs +++ b/src/common.rs @@ -476,42 +476,12 @@ pub fn username() -> String { #[inline] pub fn check_port(host: T, port: i32) -> String { - let host = host.to_string(); - if is_ipv6_str(&host) { - if host.starts_with("[") { - return host; - } - return format!("[{}]:{}", host, port); - } - if !host.contains(":") { - return format!("{}:{}", host, port); - } - return host; + hbb_common::socket_client::check_port(host, port) } #[inline] pub fn increase_port(host: T, offset: i32) -> String { - let host = host.to_string(); - if is_ipv6_str(&host) { - if host.starts_with("[") { - let tmp: Vec<&str> = host.split("]:").collect(); - if tmp.len() == 2 { - let port: i32 = tmp[1].parse().unwrap_or(0); - if port > 0 { - return format!("{}]:{}", tmp[0], port + offset); - } - } - } - } else if host.contains(":") { - let tmp: Vec<&str> = host.split(":").collect(); - if tmp.len() == 2 { - let port: i32 = tmp[1].parse().unwrap_or(0); - if port > 0 { - return format!("{}:{}", tmp[0], port + offset); - } - } - } - return host; + hbb_common::socket_client::increase_port(host, offset) } pub const POSTFIX_SERVICE: &'static str = "_service"; @@ -741,22 +711,4 @@ pub fn make_fd_to_json(id: i32, path: String, entries: &Vec) -> Strin #[cfg(test)] mod test_common { use super::*; - - #[test] - fn test_check_port() { - assert_eq!(check_port("[1:2]:12", 32), "[1:2]:12"); - assert_eq!(check_port("1:2", 32), "[1:2]:32"); - assert_eq!(check_port("z1:2", 32), "z1:2"); - assert_eq!(check_port("1.1.1.1", 32), "1.1.1.1:32"); - assert_eq!(check_port("1.1.1.1:32", 32), "1.1.1.1:32"); - assert_eq!(check_port("test.com:32", 0), "test.com:32"); - assert_eq!(increase_port("[1:2]:12", 1), "[1:2]:13"); - assert_eq!(increase_port("1.2.2.4:12", 1), "1.2.2.4:13"); - assert_eq!(increase_port("1.2.2.4", 1), "1.2.2.4"); - assert_eq!(increase_port("test.com", 1), "test.com"); - assert_eq!(increase_port("test.com:13", 4), "test.com:17"); - assert_eq!(increase_port("1:13", 4), "1:13"); - assert_eq!(increase_port("22:1:13", 4), "22:1:13"); - assert_eq!(increase_port("z1:2", 1), "z1:3"); - } }