From 3454454bd569c7b993bfd55b8f92ad4536edf9e5 Mon Sep 17 00:00:00 2001 From: fufesou Date: Wed, 19 Oct 2022 22:48:51 +0800 Subject: [PATCH] account oidc init rs Signed-off-by: fufesou --- Cargo.lock | 2 + Cargo.toml | 1 + src/hbbs_http.rs | 42 +++++++ src/hbbs_http/account.rs | 242 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 + 5 files changed, 289 insertions(+) create mode 100644 src/hbbs_http.rs create mode 100644 src/hbbs_http/account.rs diff --git a/Cargo.lock b/Cargo.lock index 26ed9455f..3fddb5c37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4419,6 +4419,7 @@ dependencies = [ "system_shutdown", "tray-item", "trayicon", + "url", "uuid", "virtual_display", "whoami", @@ -5492,6 +5493,7 @@ dependencies = [ "idna", "matches", "percent-encoding", + "serde 1.0.144", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c950c8723..9516d8526 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,7 @@ wol-rs = "0.9.1" flutter_rust_bridge = { git = "https://github.com/SoLongAndThanksForAllThePizza/flutter_rust_bridge", optional = true } errno = "0.2.8" rdev = { git = "https://github.com/asur4s/rdev" } +url = { version = "2.1", features = ["serde"] } [target.'cfg(not(target_os = "linux"))'.dependencies] reqwest = { version = "0.11", features = ["json", "rustls-tls"], default-features=false } diff --git a/src/hbbs_http.rs b/src/hbbs_http.rs new file mode 100644 index 000000000..b0e8cdbab --- /dev/null +++ b/src/hbbs_http.rs @@ -0,0 +1,42 @@ +use hbb_common::{ + anyhow::{self, bail}, + tokio, ResultType, +}; +use reqwest::Response; +use serde_derive::Deserialize; +use serde_json::{Map, Value}; +use serde::de::DeserializeOwned; + +pub mod account; + +pub enum HbbHttpResponse { + ErrorFormat, + Error(String), + DataTypeFormat, + Data(T), +} + +#[tokio::main(flavor = "current_thread")] +async fn resp_to_serde_map(resp: Response) -> reqwest::Result> { + resp.json().await +} + +impl TryFrom for HbbHttpResponse { + type Error = reqwest::Error; + + fn try_from(resp: Response) -> Result>::Error> { + let map = resp_to_serde_map(resp)?; + if let Some(error) = map.get("error") { + if let Some(err) = error.as_str() { + Ok(Self::Error(err.to_owned())) + } else { + Ok(Self::ErrorFormat) + } + } else { + match serde_json::from_value(Value::Object(map)) { + Ok(v) => Ok(Self::Data(v)), + Err(_) => Ok(Self::DataTypeFormat), + } + } + } +} diff --git a/src/hbbs_http/account.rs b/src/hbbs_http/account.rs new file mode 100644 index 000000000..85b2f7e82 --- /dev/null +++ b/src/hbbs_http/account.rs @@ -0,0 +1,242 @@ +use super::HbbHttpResponse; +use hbb_common::{config::Config, log, sleep, tokio, tokio::sync::RwLock, ResultType}; +use serde_derive::Deserialize; +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, Instant}, +}; +use url::Url; + +lazy_static::lazy_static! { + static ref API_SERVER: String = crate::get_api_server( + Config::get_option("api-server"), Config::get_option("custom-rendezvous-server")); + static ref OIDC_SESSION: Arc> = Arc::new(RwLock::new(OidcSession::new())); +} + +const QUERY_INTERVAL_SECS: f32 = 1.0; +const QUERY_TIMEOUT_SECS: u64 = 60; + +#[derive(Deserialize, Clone)] +pub struct OidcAuthUrl { + code: String, + url: Url, +} + +#[derive(Debug, Deserialize, Default, Clone)] +pub struct UserPayload { + pub id: String, + pub name: String, + pub email: Option, + pub note: Option, + pub status: Option, + pub grp: Option, + pub is_admin: Option, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct AuthBody { + access_token: String, + token_type: String, + user: UserPayload, +} + +#[derive(Copy, Clone)] +pub enum OidcState { + // initial request + OidcRequest = 1, + // initial request failed + OidcRequestFailed = 2, + // request succeeded, loop querying + OidcQuerying = 11, + // loop querying failed + OidcQueryFailed = 12, + // query sucess before + OidcNotExists = 13, + // query timeout + OidcQueryTimeout = 14, + // already login + OidcLogin = 21, +} + +pub struct OidcSession { + client: reqwest::Client, + state: OidcState, + failed_msg: String, + code_url: Option, + auth_body: Option, + keep_querying: bool, + running: bool, + query_timeout: Duration, +} + +impl OidcSession { + fn new() -> Self { + Self { + client: reqwest::Client::new(), + state: OidcState::OidcRequest, + failed_msg: "".to_owned(), + code_url: None, + auth_body: None, + keep_querying: false, + running: false, + query_timeout: Duration::from_secs(QUERY_TIMEOUT_SECS), + } + } + + async fn auth(op: &str, id: &str, uuid: &str) -> ResultType> { + Ok(OIDC_SESSION + .read() + .await + .client + .post(format!("{}/api/oidc/auth", *API_SERVER)) + .json(&HashMap::from([("op", op), ("id", id), ("uuid", uuid)])) + .send() + .await? + .try_into()?) + } + + async fn query(code: &str, id: &str, uuid: &str) -> ResultType> { + let url = reqwest::Url::parse_with_params( + &format!("{}/api/oidc/auth-query", *API_SERVER), + &[("code", code), ("id", id), ("uuid", uuid)], + )?; + Ok(OIDC_SESSION + .read() + .await + .client + .get(url) + .send() + .await? + .try_into()?) + } + + fn reset(&mut self) { + self.state = OidcState::OidcRequest; + self.failed_msg = "".to_owned(); + self.keep_querying = true; + self.running = false; + self.code_url = None; + self.auth_body = None; + } + + async fn before_task(&mut self) { + self.reset(); + self.running = true; + } + + async fn after_task(&mut self) { + self.running = false; + } + + async fn auth_task(op: String, id: String, uuid: String) { + let code_url = match Self::auth(&op, &id, &uuid).await { + Ok(HbbHttpResponse::<_>::Data(code_url)) => code_url, + Ok(HbbHttpResponse::<_>::Error(err)) => { + OIDC_SESSION + .write() + .await + .set_state(OidcState::OidcRequestFailed, err); + return; + } + Ok(_) => { + OIDC_SESSION.write().await.set_state( + OidcState::OidcRequestFailed, + "Invalid auth response".to_owned(), + ); + return; + } + Err(err) => { + OIDC_SESSION + .write() + .await + .set_state(OidcState::OidcRequestFailed, err.to_string()); + return; + } + }; + + OIDC_SESSION + .write() + .await + .set_state(OidcState::OidcQuerying, "".to_owned()); + OIDC_SESSION.write().await.code_url = Some(code_url.clone()); + + let begin = Instant::now(); + let query_timeout = OIDC_SESSION.read().await.query_timeout; + while OIDC_SESSION.read().await.keep_querying && begin.elapsed() < query_timeout { + match Self::query(&code_url.code, &id, &uuid).await { + Ok(HbbHttpResponse::<_>::Data(auth_body)) => { + OIDC_SESSION + .write() + .await + .set_state(OidcState::OidcLogin, "".to_owned()); + OIDC_SESSION.write().await.auth_body = Some(auth_body); + return; + // to-do, set access-token + } + Ok(HbbHttpResponse::<_>::Error(err)) => { + if err.contains("No authed oidc is found") { + // ignore, keep querying + } else { + OIDC_SESSION + .write() + .await + .set_state(OidcState::OidcQueryFailed, err); + return; + } + } + Ok(_) => { + // ignore + } + Err(err) => { + log::trace!("Failed query oidc {}", err); + // ignore + } + } + sleep(QUERY_INTERVAL_SECS).await; + } + + if begin.elapsed() >= query_timeout { + OIDC_SESSION + .write() + .await + .set_state(OidcState::OidcQueryTimeout, "timeout".to_owned()); + } + + // no need to handle "keep_querying == false" + } + + fn set_state(&mut self, state: OidcState, failed_msg: String) { + self.state = state; + self.failed_msg = failed_msg; + } + + pub async fn account_auth(op: String, id: String, uuid: String) { + if OIDC_SESSION.read().await.running { + OIDC_SESSION.write().await.keep_querying = false; + } + let wait_secs = 0.3; + sleep(wait_secs).await; + while OIDC_SESSION.read().await.running { + sleep(wait_secs).await; + } + + tokio::spawn(async move { + OIDC_SESSION.write().await.before_task().await; + Self::auth_task(op, id, uuid).await; + OIDC_SESSION.write().await.after_task().await; + }); + } + + fn get_result_(&self) -> (u8, String, Option) { + ( + self.state as u8, + self.failed_msg.clone(), + self.auth_body.clone(), + ) + } + + pub async fn get_result() -> (u8, String, Option) { + OIDC_SESSION.read().await.get_result_() + } +} diff --git a/src/lib.rs b/src/lib.rs index 58dc50b04..eb8a876ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,8 @@ mod ui_cm_interface; mod ui_interface; mod ui_session_interface; +mod hbbs_http; + #[cfg(windows)] pub mod clipboard_file;