add unit test to test_if_valid_server

This commit is contained in:
rustdesk 2023-01-09 18:28:11 +08:00
parent 3ac37b9686
commit 80c1b89b47
2 changed files with 70 additions and 55 deletions

View File

@ -9,17 +9,54 @@ use std::net::SocketAddr;
use tokio::net::ToSocketAddrs; use tokio::net::ToSocketAddrs;
use tokio_socks::{IntoTargetAddr, TargetAddr}; use tokio_socks::{IntoTargetAddr, TargetAddr};
pub fn test_if_valid_server(host: &str) -> String { #[inline]
let mut host = host.to_owned(); pub fn check_port<T: std::string::ToString>(host: T, port: i32) -> String {
if !host.contains(":") { let host = host.to_string();
host = format!("{}:{}", host, 0); if crate::is_ipv6_str(&host) {
if host.starts_with("[") {
return host;
}
return format!("[{}]:{}", host, port);
} }
if !host.contains(":") {
return format!("{}:{}", host, port);
}
return host;
}
#[inline]
pub fn increase_port<T: std::string::ToString>(host: T, offset: i32) -> String {
let host = host.to_string();
if crate::is_ipv6_str(&host) {
if host.starts_with("[") {
let tmp: Vec<&str> = host.split("]:").collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return format!("{}]:{}", tmp[0], port + offset);
}
}
}
} else if host.contains(":") {
let tmp: Vec<&str> = host.split(":").collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return format!("{}:{}", tmp[0], port + offset);
}
}
}
return host;
}
pub fn test_if_valid_server(host: &str) -> String {
let host = check_port(host, 0);
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
match Config::get_network_type() { match Config::get_network_type() {
NetworkType::Direct => match host.to_socket_addrs() { NetworkType::Direct => match host.to_socket_addrs() {
Err(err) => err.to_string(), Err(err) => err.to_string(),
Ok(_) => "".to_owned(), Ok(x) => "".to_owned(),
}, },
NetworkType::ProxySocks => match &host.into_target_addr() { NetworkType::ProxySocks => match &host.into_target_addr() {
Err(err) => err.to_string(), Err(err) => err.to_string(),
@ -216,4 +253,30 @@ mod tests {
} }
assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()).await.is_err()); assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()).await.is_err());
} }
#[test]
fn test_test_if_valid_server() {
assert!(!test_if_valid_server("a").is_empty());
// on Linux, "1" is resolved to "0.0.0.1"
assert!(test_if_valid_server("1.1.1.1").is_empty());
assert!(test_if_valid_server("1.1.1.1:1").is_empty());
}
#[test]
fn test_check_port() {
assert_eq!(check_port("[1:2]:12", 32), "[1:2]:12");
assert_eq!(check_port("1:2", 32), "[1:2]:32");
assert_eq!(check_port("z1:2", 32), "z1:2");
assert_eq!(check_port("1.1.1.1", 32), "1.1.1.1:32");
assert_eq!(check_port("1.1.1.1:32", 32), "1.1.1.1:32");
assert_eq!(check_port("test.com:32", 0), "test.com:32");
assert_eq!(increase_port("[1:2]:12", 1), "[1:2]:13");
assert_eq!(increase_port("1.2.2.4:12", 1), "1.2.2.4:13");
assert_eq!(increase_port("1.2.2.4", 1), "1.2.2.4");
assert_eq!(increase_port("test.com", 1), "test.com");
assert_eq!(increase_port("test.com:13", 4), "test.com:17");
assert_eq!(increase_port("1:13", 4), "1:13");
assert_eq!(increase_port("22:1:13", 4), "22:1:13");
assert_eq!(increase_port("z1:2", 1), "z1:3");
}
} }

View File

@ -476,42 +476,12 @@ pub fn username() -> String {
#[inline] #[inline]
pub fn check_port<T: std::string::ToString>(host: T, port: i32) -> String { pub fn check_port<T: std::string::ToString>(host: T, port: i32) -> String {
let host = host.to_string(); hbb_common::socket_client::check_port(host, port)
if is_ipv6_str(&host) {
if host.starts_with("[") {
return host;
}
return format!("[{}]:{}", host, port);
}
if !host.contains(":") {
return format!("{}:{}", host, port);
}
return host;
} }
#[inline] #[inline]
pub fn increase_port<T: std::string::ToString>(host: T, offset: i32) -> String { pub fn increase_port<T: std::string::ToString>(host: T, offset: i32) -> String {
let host = host.to_string(); hbb_common::socket_client::increase_port(host, offset)
if is_ipv6_str(&host) {
if host.starts_with("[") {
let tmp: Vec<&str> = host.split("]:").collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return format!("{}]:{}", tmp[0], port + offset);
}
}
}
} else if host.contains(":") {
let tmp: Vec<&str> = host.split(":").collect();
if tmp.len() == 2 {
let port: i32 = tmp[1].parse().unwrap_or(0);
if port > 0 {
return format!("{}:{}", tmp[0], port + offset);
}
}
}
return host;
} }
pub const POSTFIX_SERVICE: &'static str = "_service"; pub const POSTFIX_SERVICE: &'static str = "_service";
@ -741,22 +711,4 @@ pub fn make_fd_to_json(id: i32, path: String, entries: &Vec<FileEntry>) -> Strin
#[cfg(test)] #[cfg(test)]
mod test_common { mod test_common {
use super::*; use super::*;
#[test]
fn test_check_port() {
assert_eq!(check_port("[1:2]:12", 32), "[1:2]:12");
assert_eq!(check_port("1:2", 32), "[1:2]:32");
assert_eq!(check_port("z1:2", 32), "z1:2");
assert_eq!(check_port("1.1.1.1", 32), "1.1.1.1:32");
assert_eq!(check_port("1.1.1.1:32", 32), "1.1.1.1:32");
assert_eq!(check_port("test.com:32", 0), "test.com:32");
assert_eq!(increase_port("[1:2]:12", 1), "[1:2]:13");
assert_eq!(increase_port("1.2.2.4:12", 1), "1.2.2.4:13");
assert_eq!(increase_port("1.2.2.4", 1), "1.2.2.4");
assert_eq!(increase_port("test.com", 1), "test.com");
assert_eq!(increase_port("test.com:13", 4), "test.com:17");
assert_eq!(increase_port("1:13", 4), "1:13");
assert_eq!(increase_port("22:1:13", 4), "22:1:13");
assert_eq!(increase_port("z1:2", 1), "z1:3");
}
} }