diff --git a/libs/hbb_common/protos/rendezvous.proto b/libs/hbb_common/protos/rendezvous.proto index d86226290..5d57b083d 100644 --- a/libs/hbb_common/protos/rendezvous.proto +++ b/libs/hbb_common/protos/rendezvous.proto @@ -163,6 +163,10 @@ message OnlineResponse { bytes states = 1; } +message KeyExchange { + repeated bytes keys = 1; +} + message RendezvousMessage { oneof union { RegisterPeer register_peer = 6; @@ -184,5 +188,6 @@ message RendezvousMessage { PeerDiscovery peer_discovery = 22; OnlineRequest online_request = 23; OnlineResponse online_response = 24; + KeyExchange key_exchange = 25; } } diff --git a/libs/hbb_common/src/config.rs b/libs/hbb_common/src/config.rs index 5b3ded04c..132f27545 100644 --- a/libs/hbb_common/src/config.rs +++ b/libs/hbb_common/src/config.rs @@ -25,7 +25,7 @@ use crate::{ pub const RENDEZVOUS_TIMEOUT: u64 = 12_000; pub const CONNECT_TIMEOUT: u64 = 18_000; -pub const READ_TIMEOUT: u64 = 30_000; +pub const READ_TIMEOUT: u64 = 18_000; pub const REG_INTERVAL: i64 = 12_000; pub const COMPRESS_LEVEL: i32 = 3; const SERIAL: i32 = 3; diff --git a/libs/hbb_common/src/tcp.rs b/libs/hbb_common/src/tcp.rs index f574e8309..2285e6430 100644 --- a/libs/hbb_common/src/tcp.rs +++ b/libs/hbb_common/src/tcp.rs @@ -3,7 +3,10 @@ use anyhow::Context as AnyhowCtx; use bytes::{BufMut, Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; use protobuf::Message; -use sodiumoxide::crypto::secretbox::{self, Key, Nonce}; +use sodiumoxide::crypto::{ + box_, + secretbox::{self, Key, Nonce}, +}; use std::{ io::{self, Error, ErrorKind}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, @@ -21,10 +24,13 @@ use tokio_util::codec::Framed; pub trait TcpStreamTrait: AsyncRead + AsyncWrite + Unpin {} pub struct DynTcpStream(Box); +#[derive(Clone)] +pub struct Encrypt(Key, u64, u64); + pub struct FramedStream( Framed, SocketAddr, - Option<(Key, u64, u64)>, + Option, u64, ); @@ -185,9 +191,7 @@ impl FramedStream { pub async fn send_raw(&mut self, msg: Vec) -> ResultType<()> { let mut msg = msg; if let Some(key) = self.2.as_mut() { - key.1 += 1; - let nonce = Self::get_nonce(key.1); - msg = secretbox::seal(&msg, &nonce, &key.0); + msg = key.enc(&msg); } self.send_bytes(bytes::Bytes::from(msg)).await?; Ok(()) @@ -206,18 +210,10 @@ impl FramedStream { #[inline] pub async fn next(&mut self) -> Option> { let mut res = self.0.next().await; - if let Some(key) = self.2.as_mut() { - if let Some(Ok(bytes)) = res.as_mut() { - key.2 += 1; - let nonce = Self::get_nonce(key.2); - match secretbox::open(bytes, &nonce, &key.0) { - Ok(res) => { - bytes.clear(); - bytes.put_slice(&res); - } - Err(()) => { - return Some(Err(Error::new(ErrorKind::Other, "decryption error"))); - } + if let Some(Ok(bytes)) = res.as_mut() { + if let Some(key) = self.2.as_mut() { + if let Err(err) = key.dec(bytes) { + return Some(Err(err)); } } } @@ -234,7 +230,7 @@ impl FramedStream { } pub fn set_key(&mut self, key: Key) { - self.2 = Some((key, 0, 0)); + self.2 = Some(Encrypt::new(key)); } fn get_nonce(seqnum: u64) -> Nonce { @@ -323,3 +319,48 @@ impl AsyncWrite for DynTcpStream { } impl TcpStreamTrait for R {} + +impl Encrypt { + pub fn new(key: Key) -> Self { + Self(key, 0, 0) + } + + pub fn dec(&mut self, bytes: &mut BytesMut) -> Result<(), Error> { + self.2 += 1; + let nonce = FramedStream::get_nonce(self.2); + match secretbox::open(bytes, &nonce, &self.0) { + Ok(res) => { + bytes.clear(); + bytes.put_slice(&res); + Ok(()) + } + Err(()) => Err(Error::new(ErrorKind::Other, "decryption error")), + } + } + + pub fn enc(&mut self, data: &[u8]) -> Vec { + self.1 += 1; + let nonce = FramedStream::get_nonce(self.1); + secretbox::seal(&data, &nonce, &self.0) + } + + pub fn decode( + symmetric_data: &[u8], + their_pk_b: &[u8], + our_sk_b: &box_::SecretKey, + ) -> ResultType { + assert!(their_pk_b.len() == box_::PUBLICKEYBYTES); + let nonce = box_::Nonce([0u8; box_::NONCEBYTES]); + let mut pk_ = [0u8; box_::PUBLICKEYBYTES]; + pk_[..].copy_from_slice(their_pk_b); + let their_pk_b = box_::PublicKey(pk_); + let symmetric_key = box_::open(symmetric_data, &nonce, &their_pk_b, &our_sk_b) + .map_err(|_| anyhow::anyhow!("Handshake failed: box decryption failure"))?; + if symmetric_key.len() != secretbox::KEYBYTES { + anyhow::bail!("Handshake failed: invalid secret key length from peer"); + } + let mut key = [0u8; secretbox::KEYBYTES]; + key[..].copy_from_slice(&symmetric_key); + Ok(Key(key)) + } +} diff --git a/src/client.rs b/src/client.rs index b762b41d3..72195a54c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -31,10 +31,7 @@ use hbb_common::{ allow_err, anyhow::{anyhow, Context}, bail, - config::{ - Config, PeerConfig, PeerInfoSerde, CONNECT_TIMEOUT, READ_TIMEOUT, RELAY_PORT, - RENDEZVOUS_TIMEOUT, - }, + config::{Config, PeerConfig, PeerInfoSerde, CONNECT_TIMEOUT, READ_TIMEOUT, RELAY_PORT}, get_version_number, log, message_proto::{option_message::BoolOption, *}, protobuf::Message as _, @@ -42,6 +39,7 @@ use hbb_common::{ rendezvous_proto::*, socket_client, sodiumoxide::crypto::{box_, secretbox, sign}, + tcp::FramedStream, timeout, tokio::time::Duration, AddrMangle, ResultType, Stream, @@ -239,7 +237,7 @@ impl Client { return Ok(( socket_client::connect_tcp( crate::check_port(peer, RELAY_PORT + 1), - RENDEZVOUS_TIMEOUT, + CONNECT_TIMEOUT, ) .await?, true, @@ -249,18 +247,18 @@ impl Client { // Allow connect to {domain}:{port} if hbb_common::is_domain_port_str(peer) { return Ok(( - socket_client::connect_tcp(peer, RENDEZVOUS_TIMEOUT).await?, + socket_client::connect_tcp(peer, CONNECT_TIMEOUT).await?, true, None, )); } let (mut rendezvous_server, servers, contained) = crate::get_rendezvous_server(1_000).await; - let mut socket = socket_client::connect_tcp(&*rendezvous_server, RENDEZVOUS_TIMEOUT).await; + let mut socket = socket_client::connect_tcp(&*rendezvous_server, CONNECT_TIMEOUT).await; debug_assert!(!servers.contains(&rendezvous_server)); if socket.is_err() && !servers.is_empty() { log::info!("try the other servers: {:?}", servers); for server in servers { - socket = socket_client::connect_tcp(&*server, RENDEZVOUS_TIMEOUT).await; + socket = socket_client::connect_tcp(&*server, CONNECT_TIMEOUT).await; if socket.is_ok() { rendezvous_server = server; break; @@ -276,6 +274,11 @@ impl Client { let mut signed_id_pk = Vec::new(); let mut relay_server = "".to_owned(); + if !key.is_empty() && !token.is_empty() { + // mainly for the security of token + allow_err!(secure_punch_connection(&mut socket, key).await); + } + let start = std::time::Instant::now(); let mut peer_addr = Config::get_any_listen_addr(true); let mut peer_nat_type = NatType::UNKNOWN_NAT; @@ -299,71 +302,69 @@ impl Client { ..Default::default() }); socket.send(&msg_out).await?; - if let Some(Ok(bytes)) = socket.next_timeout(i * 6000).await { - if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { - match msg_in.union { - Some(rendezvous_message::Union::PunchHoleResponse(ph)) => { - if ph.socket_addr.is_empty() { - if !ph.other_failure.is_empty() { - bail!(ph.other_failure); - } - match ph.failure.enum_value_or_default() { - punch_hole_response::Failure::ID_NOT_EXIST => { - bail!("ID does not exist"); - } - punch_hole_response::Failure::OFFLINE => { - bail!("Remote desktop is offline"); - } - punch_hole_response::Failure::LICENSE_MISMATCH => { - bail!("Key mismatch"); - } - punch_hole_response::Failure::LICENSE_OVERUSE => { - bail!("Key overuse"); - } - } - } else { - peer_nat_type = ph.nat_type(); - is_local = ph.is_local(); - signed_id_pk = ph.pk.into(); - relay_server = ph.relay_server; - peer_addr = AddrMangle::decode(&ph.socket_addr); - log::info!("Hole Punched {} = {}", peer, peer_addr); - break; + if let Some(msg_in) = + crate::common::get_next_nonkeyexchange_msg(&mut socket, Some(i * 6000)).await + { + match msg_in.union { + Some(rendezvous_message::Union::PunchHoleResponse(ph)) => { + if ph.socket_addr.is_empty() { + if !ph.other_failure.is_empty() { + bail!(ph.other_failure); } - } - Some(rendezvous_message::Union::RelayResponse(rr)) => { - log::info!( - "relay requested from peer, time used: {:?}, relay_server: {}", - start.elapsed(), - rr.relay_server - ); - signed_id_pk = rr.pk().into(); - let mut conn = Self::create_relay( - peer, - rr.uuid, - rr.relay_server, - key, - conn_type, - my_addr.is_ipv4(), - ) - .await?; - let pk = Self::secure_connection( - peer, - signed_id_pk, - key, - &mut conn, - false, - interface, - ) - .await?; - return Ok((conn, false, pk)); - } - _ => { - log::error!("Unexpected protobuf msg received: {:?}", msg_in); + match ph.failure.enum_value_or_default() { + punch_hole_response::Failure::ID_NOT_EXIST => { + bail!("ID does not exist"); + } + punch_hole_response::Failure::OFFLINE => { + bail!("Remote desktop is offline"); + } + punch_hole_response::Failure::LICENSE_MISMATCH => { + bail!("Key mismatch"); + } + punch_hole_response::Failure::LICENSE_OVERUSE => { + bail!("Key overuse"); + } + } + } else { + peer_nat_type = ph.nat_type(); + is_local = ph.is_local(); + signed_id_pk = ph.pk.into(); + relay_server = ph.relay_server; + peer_addr = AddrMangle::decode(&ph.socket_addr); + log::info!("Hole Punched {} = {}", peer, peer_addr); + break; } } - } else { - log::error!("Non-protobuf message bytes received: {:?}", bytes); + Some(rendezvous_message::Union::RelayResponse(rr)) => { + log::info!( + "relay requested from peer, time used: {:?}, relay_server: {}", + start.elapsed(), + rr.relay_server + ); + signed_id_pk = rr.pk().into(); + let mut conn = Self::create_relay( + peer, + rr.uuid, + rr.relay_server, + key, + conn_type, + my_addr.is_ipv4(), + ) + .await?; + let pk = Self::secure_connection( + peer, + signed_id_pk, + key, + &mut conn, + false, + interface, + ) + .await?; + return Ok((conn, false, pk)); + } + _ => { + log::error!("Unexpected protobuf msg received: {:?}", msg_in); + } } } } @@ -540,8 +541,15 @@ impl Client { if let Some(message::Union::SignedId(si)) = msg_in.union { if let Ok((id, their_pk_b)) = decode_id_pk(&si.id, &sign_pk) { if id == peer_id { - let (msg, key) = create_symmetric_key_msg(their_pk_b); - timeout(CONNECT_TIMEOUT, conn.send(&msg)).await??; + let (asymmetric_value, symmetric_value, key) = + create_symmetric_key_msg(their_pk_b); + let mut msg_out = Message::new(); + msg_out.set_public_key(PublicKey { + asymmetric_value, + symmetric_value, + ..Default::default() + }); + timeout(CONNECT_TIMEOUT, conn.send(&msg_out)).await??; conn.set_key(key); } else { log::error!("Handshake failed: sign failure"); @@ -585,7 +593,7 @@ impl Client { let mut ipv4 = true; for i in 1..=3 { // use different socket due to current hbbs implement requiring different nat address for each attempt - let mut socket = socket_client::connect_tcp(rendezvous_server, RENDEZVOUS_TIMEOUT) + let mut socket = socket_client::connect_tcp(rendezvous_server, CONNECT_TIMEOUT) .await .with_context(|| "Failed to connect to rendezvous server")?; @@ -2486,21 +2494,52 @@ fn decode_id_pk(signed: &[u8], key: &sign::PublicKey) -> ResultType<(String, [u8 if let Some(pk) = get_pk(&res.pk) { Ok((res.id, pk)) } else { - bail!("Wrong public length"); + bail!("Wrong their public length"); } } -fn create_symmetric_key_msg(their_pk_b: [u8; 32]) -> (Message, secretbox::Key) { +fn create_symmetric_key_msg(their_pk_b: [u8; 32]) -> (Bytes, Bytes, secretbox::Key) { let their_pk_b = box_::PublicKey(their_pk_b); let (our_pk_b, out_sk_b) = box_::gen_keypair(); let key = secretbox::gen_key(); let nonce = box_::Nonce([0u8; box_::NONCEBYTES]); let sealed_key = box_::seal(&key.0, &nonce, &their_pk_b, &out_sk_b); - let mut msg_out = Message::new(); - msg_out.set_public_key(PublicKey { - asymmetric_value: Vec::from(our_pk_b.0).into(), - symmetric_value: sealed_key.into(), - ..Default::default() - }); - (msg_out, key) + (Vec::from(our_pk_b.0).into(), sealed_key.into(), key) +} + +async fn secure_punch_connection(conn: &mut FramedStream, key: &str) -> ResultType<()> { + let rs_pk = get_rs_pk(key); + let Some(rs_pk) = rs_pk else { + bail!("Handshake failed: invalid public key from rendezvous server"); + }; + match timeout(READ_TIMEOUT, conn.next()).await? { + Some(Ok(bytes)) => { + if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { + match msg_in.union { + Some(rendezvous_message::Union::KeyExchange(ex)) => { + if ex.keys.len() != 1 { + bail!("Handshake failed: invalid key exchange message"); + } + let their_pk_b = sign::verify(&ex.keys[0], &rs_pk) + .map_err(|_| anyhow!("Signature mismatch in key exchange"))?; + let (asymmetric_value, symmetric_value, key) = create_symmetric_key_msg( + get_pk(&their_pk_b) + .context("Wrong their public length in key exchange")?, + ); + let mut msg_out = RendezvousMessage::new(); + msg_out.set_key_exchange(KeyExchange { + keys: vec![asymmetric_value, symmetric_value], + ..Default::default() + }); + timeout(CONNECT_TIMEOUT, conn.send(&msg_out)).await??; + conn.set_key(key); + log::info!("Token secured"); + } + _ => {} + } + } + } + _ => {} + } + Ok(()) } diff --git a/src/common.rs b/src/common.rs index 959ce2896..d891648f2 100644 --- a/src/common.rs +++ b/src/common.rs @@ -19,13 +19,15 @@ use hbb_common::compress::decompress; use hbb_common::{ allow_err, compress::compress as compress_func, - config::{self, Config, RENDEZVOUS_TIMEOUT}, + config::{self, Config, CONNECT_TIMEOUT, READ_TIMEOUT}, get_version_number, log, message_proto::*, protobuf::Enum, protobuf::Message as _, rendezvous_proto::*, - sleep, socket_client, tokio, ResultType, + sleep, socket_client, + tcp::FramedStream, + tokio, ResultType, }; // #[cfg(any(target_os = "android", target_os = "ios", feature = "cli"))] use hbb_common::{config::RENDEZVOUS_PORT, futures::future::join_all}; @@ -231,7 +233,7 @@ pub fn update_clipboard(clipboard: Clipboard, old: Option<&Arc>>) pub async fn send_opts_after_login( config: &crate::client::LoginConfigHandler, - peer: &mut hbb_common::tcp::FramedStream, + peer: &mut FramedStream, ) { if let Some(opts) = config.get_option_message_after_login() { let mut misc = Misc::new(); @@ -551,11 +553,8 @@ async fn test_nat_type_() -> ResultType { let mut port1 = 0; let mut port2 = 0; for i in 0..2 { - let mut socket = socket_client::connect_tcp( - if i == 0 { &*server1 } else { &*server2 }, - RENDEZVOUS_TIMEOUT, - ) - .await?; + let server = if i == 0 { &*server1 } else { &*server2 }; + let mut socket = socket_client::connect_tcp(server, CONNECT_TIMEOUT).await?; if i == 0 { Config::set_option( "local-ip-addr".to_owned(), @@ -563,21 +562,20 @@ async fn test_nat_type_() -> ResultType { ); } socket.send(&msg_out).await?; - if let Some(Ok(bytes)) = socket.next_timeout(RENDEZVOUS_TIMEOUT).await { - if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { - if let Some(rendezvous_message::Union::TestNatResponse(tnr)) = msg_in.union { - if i == 0 { - port1 = tnr.port; - } else { - port2 = tnr.port; - } - if let Some(cu) = tnr.cu.as_ref() { - Config::set_option( - "rendezvous-servers".to_owned(), - cu.rendezvous_servers.join(","), - ); - Config::set_serial(cu.serial); - } + if let Some(msg_in) = get_next_nonkeyexchange_msg(&mut socket, None).await { + if let Some(rendezvous_message::Union::TestNatResponse(tnr)) = msg_in.union { + log::debug!("Got nat response from {}: port={}", server, tnr.port); + if i == 0 { + port1 = tnr.port; + } else { + port2 = tnr.port; + } + if let Some(cu) = tnr.cu.as_ref() { + Config::set_option( + "rendezvous-servers".to_owned(), + cu.rendezvous_servers.join(","), + ); + Config::set_serial(cu.serial); } } } else { @@ -654,7 +652,7 @@ async fn test_rendezvous_server_() { let tm = std::time::Instant::now(); if socket_client::connect_tcp( crate::check_port(&host, RENDEZVOUS_PORT), - RENDEZVOUS_TIMEOUT, + CONNECT_TIMEOUT, ) .await .is_ok() @@ -765,7 +763,7 @@ async fn check_software_update_() -> hbb_common::ResultType<()> { let rendezvous_server = format!("rs-sg.rustdesk.com:{}", config::RENDEZVOUS_PORT); let (mut socket, rendezvous_server) = - socket_client::new_udp_for(&rendezvous_server, RENDEZVOUS_TIMEOUT).await?; + socket_client::new_udp_for(&rendezvous_server, CONNECT_TIMEOUT).await?; let mut msg_out = RendezvousMessage::new(); msg_out.set_software_update(SoftwareUpdate { @@ -774,12 +772,14 @@ async fn check_software_update_() -> hbb_common::ResultType<()> { }); socket.send(&msg_out, rendezvous_server).await?; use hbb_common::protobuf::Message; - if let Some(Ok((bytes, _))) = socket.next_timeout(30_000).await { - if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { - if let Some(rendezvous_message::Union::SoftwareUpdate(su)) = msg_in.union { - let version = hbb_common::get_version_from_url(&su.url); - if get_version_number(&version) > get_version_number(crate::VERSION) { - *SOFTWARE_UPDATE_URL.lock().unwrap() = su.url; + for _ in 0..2 { + if let Some(Ok((bytes, _))) = socket.next_timeout(READ_TIMEOUT).await { + if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { + if let Some(rendezvous_message::Union::SoftwareUpdate(su)) = msg_in.union { + let version = hbb_common::get_version_from_url(&su.url); + if get_version_number(&version) > get_version_number(crate::VERSION) { + *SOFTWARE_UPDATE_URL.lock().unwrap() = su.url; + } } } } @@ -1002,3 +1002,27 @@ pub fn pk_to_fingerprint(pk: Vec) -> String { }) .collect() } + +#[inline] +pub async fn get_next_nonkeyexchange_msg( + conn: &mut FramedStream, + timeout: Option, +) -> Option { + let timeout = timeout.unwrap_or(READ_TIMEOUT); + for _ in 0..2 { + if let Some(Ok(bytes)) = conn.next_timeout(timeout).await { + if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { + match &msg_in.union { + Some(rendezvous_message::Union::KeyExchange(_)) => { + continue; + } + _ => { + return Some(msg_in); + } + } + } + } + break; + } + None +} diff --git a/src/rendezvous_mediator.rs b/src/rendezvous_mediator.rs index b6422e9a6..faf77b284 100644 --- a/src/rendezvous_mediator.rs +++ b/src/rendezvous_mediator.rs @@ -13,7 +13,7 @@ use hbb_common::tcp::FramedStream; use hbb_common::{ allow_err, anyhow::bail, - config::{Config, REG_INTERVAL, RENDEZVOUS_PORT, RENDEZVOUS_TIMEOUT}, + config::{Config, CONNECT_TIMEOUT, READ_TIMEOUT, REG_INTERVAL, RENDEZVOUS_PORT}, futures::future::join_all, log, protobuf::Message as _, @@ -120,7 +120,7 @@ impl RendezvousMediator { }) .unwrap_or(host.to_owned()); let host = crate::check_port(&host, RENDEZVOUS_PORT); - let (mut socket, addr) = socket_client::new_udp_for(&host, RENDEZVOUS_TIMEOUT).await?; + let (mut socket, addr) = socket_client::new_udp_for(&host, CONNECT_TIMEOUT).await?; let mut rz = Self { addr: addr, host: host.clone(), @@ -307,7 +307,7 @@ impl RendezvousMediator { secure, ); - let mut socket = socket_client::connect_tcp(&*self.host, RENDEZVOUS_TIMEOUT).await?; + let mut socket = socket_client::connect_tcp(&*self.host, CONNECT_TIMEOUT).await?; let mut msg_out = Message::new(); let mut rr = RelayResponse { @@ -352,7 +352,7 @@ impl RendezvousMediator { } let peer_addr = AddrMangle::decode(&fla.socket_addr); log::debug!("Handle intranet from {:?}", peer_addr); - let mut socket = socket_client::connect_tcp(&*self.host, RENDEZVOUS_TIMEOUT).await?; + let mut socket = socket_client::connect_tcp(&*self.host, CONNECT_TIMEOUT).await?; let local_addr = socket.local_addr(); let local_addr: SocketAddr = format!("{}:{}", local_addr.ip(), local_addr.port()).parse()?; @@ -391,7 +391,7 @@ impl RendezvousMediator { let peer_addr = AddrMangle::decode(&ph.socket_addr); log::debug!("Punch hole to {:?}", peer_addr); let mut socket = { - let socket = socket_client::connect_tcp(&*self.host, RENDEZVOUS_TIMEOUT).await?; + let socket = socket_client::connect_tcp(&*self.host, CONNECT_TIMEOUT).await?; let local_addr = socket.local_addr(); // key important here for punch hole to tell my gateway incoming peer is safe. // it can not be async here, because local_addr can not be reused, we must close the connection before use it again. @@ -649,7 +649,8 @@ pub async fn query_online_states, Vec)>(ids: Vec ResultType { - let (rendezvous_server, _servers, _contained) = crate::get_rendezvous_server(1_000).await; + let (rendezvous_server, _servers, _contained) = + crate::get_rendezvous_server(READ_TIMEOUT).await; let tmp: Vec<&str> = rendezvous_server.split(":").collect(); if tmp.len() != 2 { bail!("Invalid server address: {}", rendezvous_server); @@ -659,7 +660,7 @@ async fn create_online_stream() -> ResultType { bail!("Invalid server address: {}", rendezvous_server); } let online_server = format!("{}:{}", tmp[0], port - 1); - socket_client::connect_tcp(online_server, RENDEZVOUS_TIMEOUT).await + socket_client::connect_tcp(online_server, CONNECT_TIMEOUT).await } async fn query_online_states_( @@ -683,38 +684,30 @@ async fn query_online_states_( let mut socket = create_online_stream().await?; socket.send(&msg_out).await?; - match socket.next_timeout(RENDEZVOUS_TIMEOUT).await { - Some(Ok(bytes)) => { - if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { - match msg_in.union { - Some(rendezvous_message::Union::OnlineResponse(online_response)) => { - let states = online_response.states; - let mut onlines = Vec::new(); - let mut offlines = Vec::new(); - for i in 0..ids.len() { - // bytes index from left to right - let bit_value = 0x01 << (7 - i % 8); - if (states[i / 8] & bit_value) == bit_value { - onlines.push(ids[i].clone()); - } else { - offlines.push(ids[i].clone()); - } - } - return Ok((onlines, offlines)); - } - _ => { - // ignore + if let Some(msg_in) = crate::common::get_next_nonkeyexchange_msg(&mut socket, None).await { + match msg_in.union { + Some(rendezvous_message::Union::OnlineResponse(online_response)) => { + let states = online_response.states; + let mut onlines = Vec::new(); + let mut offlines = Vec::new(); + for i in 0..ids.len() { + // bytes index from left to right + let bit_value = 0x01 << (7 - i % 8); + if (states[i / 8] & bit_value) == bit_value { + onlines.push(ids[i].clone()); + } else { + offlines.push(ids[i].clone()); } } + return Ok((onlines, offlines)); + } + _ => { + // ignore } } - Some(Err(e)) => { - log::error!("Failed to receive {e}"); - } - None => { - // TODO: Make sure socket closed? - bail!("Online stream receives None"); - } + } else { + // TODO: Make sure socket closed? + bail!("Online stream receives None"); } if query_begin.elapsed() > timeout { diff --git a/src/server.rs b/src/server.rs index 39cb258af..66095fd9e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,7 +10,7 @@ use bytes::Bytes; pub use connection::*; #[cfg(not(any(target_os = "android", target_os = "ios")))] use hbb_common::config::Config2; -use hbb_common::tcp::new_listener; +use hbb_common::tcp::{self, new_listener}; use hbb_common::{ allow_err, anyhow::{anyhow, Context}, @@ -21,7 +21,7 @@ use hbb_common::{ protobuf::{Enum, Message as _}, rendezvous_proto::*, socket_client, - sodiumoxide::crypto::{box_, secretbox, sign}, + sodiumoxide::crypto::{box_, sign}, timeout, tokio, ResultType, Stream, }; #[cfg(not(any(target_os = "android", target_os = "ios")))] @@ -158,21 +158,11 @@ pub async fn create_tcp_connection( if let Ok(msg_in) = Message::parse_from_bytes(&bytes) { if let Some(message::Union::PublicKey(pk)) = msg_in.union { if pk.asymmetric_value.len() == box_::PUBLICKEYBYTES { - let nonce = box_::Nonce([0u8; box_::NONCEBYTES]); - let mut pk_ = [0u8; box_::PUBLICKEYBYTES]; - pk_[..].copy_from_slice(&pk.asymmetric_value); - let their_pk_b = box_::PublicKey(pk_); - let symmetric_key = - box_::open(&pk.symmetric_value, &nonce, &their_pk_b, &our_sk_b) - .map_err(|_| { - anyhow!("Handshake failed: box decryption failure") - })?; - if symmetric_key.len() != secretbox::KEYBYTES { - bail!("Handshake failed: invalid secret key length from peer"); - } - let mut key = [0u8; secretbox::KEYBYTES]; - key[..].copy_from_slice(&symmetric_key); - stream.set_key(secretbox::Key(key)); + stream.set_key(tcp::Encrypt::decode( + &pk.symmetric_value, + &pk.asymmetric_value, + &our_sk_b, + )?); } else if pk.asymmetric_value.is_empty() { Config::set_key_confirmed(false); log::info!("Force to update pk"); @@ -445,7 +435,10 @@ pub async fn start_ipc_url_server() { m.insert("name", "on_url_scheme_received"); m.insert("url", url.as_str()); let event = serde_json::to_string(&m).unwrap_or("".to_owned()); - match crate::flutter::push_global_event(crate::flutter::APP_TYPE_MAIN, event) { + match crate::flutter::push_global_event( + crate::flutter::APP_TYPE_MAIN, + event, + ) { None => log::warn!("No main window app found!"), Some(..) => {} } diff --git a/src/ui/common.tis b/src/ui/common.tis index 92e704052..932c6e76a 100644 --- a/src/ui/common.tis +++ b/src/ui/common.tis @@ -330,7 +330,7 @@ handler.msgbox_retry = function(type, title, text, link, hasRetry) { function retryConnect(cancelTimer=false) { if (cancelTimer) self.timer(0, retryConnect); if (!is_port_forward) connecting(); - handler.reconnect(); + handler.reconnect(false); } /******************** end of msgbox ****************************************/ diff --git a/src/ui_interface.rs b/src/ui_interface.rs index a6fc39a88..d2496d4d3 100644 --- a/src/ui_interface.rs +++ b/src/ui_interface.rs @@ -18,9 +18,8 @@ use hbb_common::{ }; use hbb_common::{ - config::{RENDEZVOUS_PORT, RENDEZVOUS_TIMEOUT}, + config::{CONNECT_TIMEOUT, RENDEZVOUS_PORT}, futures::future::join_all, - protobuf::Message as _, rendezvous_proto::*, }; @@ -1003,7 +1002,7 @@ async fn check_id( ) -> &'static str { if let Ok(mut socket) = hbb_common::socket_client::connect_tcp( crate::check_port(rendezvous_server, RENDEZVOUS_PORT), - RENDEZVOUS_TIMEOUT, + CONNECT_TIMEOUT, ) .await { @@ -1016,34 +1015,34 @@ async fn check_id( }); let mut ok = false; if socket.send(&msg_out).await.is_ok() { - if let Some(Ok(bytes)) = socket.next_timeout(3_000).await { - if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) { - match msg_in.union { - Some(rendezvous_message::Union::RegisterPkResponse(rpr)) => { - match rpr.result.enum_value_or_default() { - register_pk_response::Result::OK => { - ok = true; - } - register_pk_response::Result::ID_EXISTS => { - return "Not available"; - } - register_pk_response::Result::TOO_FREQUENT => { - return "Too frequent"; - } - register_pk_response::Result::NOT_SUPPORT => { - return "server_not_support"; - } - register_pk_response::Result::SERVER_ERROR => { - return "Server error"; - } - register_pk_response::Result::INVALID_ID_FORMAT => { - return INVALID_FORMAT; - } - _ => {} + if let Some(msg_in) = + crate::common::get_next_nonkeyexchange_msg(&mut socket, None).await + { + match msg_in.union { + Some(rendezvous_message::Union::RegisterPkResponse(rpr)) => { + match rpr.result.enum_value_or_default() { + register_pk_response::Result::OK => { + ok = true; } + register_pk_response::Result::ID_EXISTS => { + return "Not available"; + } + register_pk_response::Result::TOO_FREQUENT => { + return "Too frequent"; + } + register_pk_response::Result::NOT_SUPPORT => { + return "server_not_support"; + } + register_pk_response::Result::SERVER_ERROR => { + return "Server error"; + } + register_pk_response::Result::INVALID_ID_FORMAT => { + return INVALID_FORMAT; + } + _ => {} } - _ => {} } + _ => {} } } }