diff --git a/libs/hbb_common/protos/message.proto b/libs/hbb_common/protos/message.proto index d064cd122..afceede23 100644 --- a/libs/hbb_common/protos/message.proto +++ b/libs/hbb_common/protos/message.proto @@ -25,6 +25,11 @@ message VideoFrame { } } +message IdPk { + string id = 1; + bytes pk = 2; +} + message DisplayInfo { sint32 x = 1; sint32 y = 2; diff --git a/libs/hbb_common/src/lib.rs b/libs/hbb_common/src/lib.rs index c5c2ac5dd..81871bf26 100644 --- a/libs/hbb_common/src/lib.rs +++ b/libs/hbb_common/src/lib.rs @@ -31,20 +31,11 @@ pub use lazy_static; pub use mac_address; pub use rand; pub use regex; -use serde_derive::{Deserialize, Serialize}; pub use sodiumoxide; pub use tokio_socks; pub use tokio_socks::IntoTargetAddr; pub use tokio_socks::TargetAddr; -#[derive(Debug, Default, Serialize, Deserialize, Clone)] -pub struct IdPk { - #[serde(default)] - pub id: String, - #[serde(default)] - pub pk: [u8; 32], -} - #[cfg(feature = "quic")] pub type Stream = quic::Connection; #[cfg(not(feature = "quic"))] diff --git a/src/client.rs b/src/client.rs index 71d4d7f08..4b0b7147c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -17,7 +17,7 @@ use hbb_common::{ sodiumoxide::crypto::{box_, secretbox, sign}, timeout, tokio::time::Duration, - AddrMangle, IdPk, ResultType, Stream, + AddrMangle, ResultType, Stream, }; use magnum_opus::{Channels::*, Decoder as AudioDecoder}; use scrap::{Decoder, Image, VideoCodecId}; @@ -327,11 +327,9 @@ impl Client { let rs_pk = get_rs_pk("OeVuKk5nlHiXp+APNn0Y3pC1Iwpwn44JGqrQCsWqmBw="); let mut sign_pk = None; if !signed_id_pk.is_empty() && rs_pk.is_some() { - if let Ok(data) = sign::verify(&signed_id_pk, &rs_pk.unwrap()) { - if let Ok(v) = serde_json::from_slice::(&data) { - if v.id == peer_id { - sign_pk = Some(sign::PublicKey(v.pk)); - } + if let Ok((id, pk)) = decode_id_pk(&signed_id_pk, &rs_pk.unwrap()) { + if id == peer_id { + sign_pk = Some(sign::PublicKey(pk)); } } if sign_pk.is_none() { @@ -351,18 +349,9 @@ impl Client { let bytes = res?; if let Ok(msg_in) = Message::parse_from_bytes(&bytes) { if let Some(message::Union::signed_id(si)) = msg_in.union { - if let Ok(data) = sign::verify(&si.id, &sign_pk) { - let (id, their_pk_b) = match serde_json::from_slice::(&data) { - Ok(v) => (v.id, box_::PublicKey(v.pk)), - Err(_) => { - log::error!( - "Handshake failed: invalid public box key length from peer" - ); - conn.send(&Message::new()).await?; - return Ok(()); - } - }; + if let Ok((id, their_pk_b)) = decode_id_pk(&si.id, &sign_pk) { if id == peer_id { + let their_pk_b = box_::PublicKey(their_pk_b); let (our_pk_b, out_sk_b) = box_::gen_keypair(); let key = secretbox::gen_key(); let nonce = box_::Nonce([0u8; box_::NONCEBYTES]); @@ -384,7 +373,7 @@ impl Client { log::info!("pk mismatch, fall back to non-secure"); let mut msg_out = Message::new(); msg_out.set_public_key(PublicKey::new()); - timeout(CONNECT_TIMEOUT, conn.send(&msg_out)).await??; + conn.send(&msg_out).await?; } } else { log::error!("Handshake failed: invalid message type"); @@ -1298,16 +1287,36 @@ pub fn check_if_retry(msgtype: &str, title: &str, text: &str) -> bool { && !text.to_lowercase().contains("resolve") && !text.to_lowercase().contains("mismatch") && !text.to_lowercase().contains("manually") + && !text.to_lowercase().contains("not allowed") +} + +#[inline] +fn get_pk(pk: &[u8]) -> Option<[u8; 32]> { + if pk.len() == 32 { + let mut tmp = [0u8; 32]; + tmp[..].copy_from_slice(&pk); + Some(tmp) + } else { + None + } } #[inline] fn get_rs_pk(str_base64: &str) -> Option { if let Ok(pk) = base64::decode(str_base64) { - if pk.len() == sign::PUBLICKEYBYTES { - let mut tmp = [0u8; sign::PUBLICKEYBYTES]; - tmp[..].copy_from_slice(&pk); - return Some(sign::PublicKey(tmp)); - } + get_pk(&pk).map(|x| sign::PublicKey(x)) + } else { + None + } +} + +fn decode_id_pk(signed: &[u8], key: &sign::PublicKey) -> ResultType<(String, [u8; 32])> { + let res = IdPk::parse_from_bytes( + &sign::verify(signed, key).map_err(|_| anyhow!("Signature mismatch"))?, + )?; + if let Some(pk) = get_pk(&res.pk) { + Ok((res.id, pk)) + } else { + bail!("Wrong public length"); } - None } diff --git a/src/server.rs b/src/server.rs index 3f93adf1e..ac8530669 100644 --- a/src/server.rs +++ b/src/server.rs @@ -103,10 +103,12 @@ pub async fn create_tcp_connection( let (our_pk_b, our_sk_b) = box_::gen_keypair(); msg_out.set_signed_id(SignedId { id: sign::sign( - &serde_json::to_vec(&hbb_common::IdPk { + &IdPk { id: Config::get_id(), - pk: our_pk_b.0, - }) + pk: our_pk_b.0.to_vec(), + ..Default::default() + } + .write_to_bytes() .unwrap_or_default(), &sk, ),