use super::HbbHttpResponse; use hbb_common::{ config::{Config, LocalConfig}, log, ResultType, }; use reqwest::blocking::Client; use serde_derive::{Deserialize, Serialize}; use std::{ collections::HashMap, sync::{Arc, RwLock}, time::{Duration, Instant}, }; use url::Url; lazy_static::lazy_static! { static ref API_SERVER: String = crate::get_api_server( Config::get_option("api-server"), Config::get_option("custom-rendezvous-server")); static ref OIDC_SESSION: Arc> = Arc::new(RwLock::new(OidcSession::new())); } const QUERY_INTERVAL_SECS: f32 = 1.0; const QUERY_TIMEOUT_SECS: u64 = 60 * 3; const REQUESTING_ACCOUNT_AUTH: &str = "Requesting account auth"; const WAITING_ACCOUNT_AUTH: &str = "Waiting account auth"; const LOGIN_ACCOUNT_AUTH: &str = "Login account auth"; #[derive(Deserialize, Clone, Debug)] pub struct OidcAuthUrl { code: String, url: Url, } #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct UserPayload { pub id: String, pub name: String, pub email: Option, pub note: Option, pub status: Option, pub grp: Option, pub is_admin: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AuthBody { pub access_token: String, pub token_type: String, pub user: UserPayload, } pub struct OidcSession { client: Client, state_msg: &'static str, failed_msg: String, code_url: Option, auth_body: Option, keep_querying: bool, running: bool, query_timeout: Duration, } #[derive(Serialize)] pub struct AuthResult { pub state_msg: String, pub failed_msg: String, pub url: Option, pub auth_body: Option, } impl OidcSession { fn new() -> Self { Self { client: Client::new(), state_msg: REQUESTING_ACCOUNT_AUTH, failed_msg: "".to_owned(), code_url: None, auth_body: None, keep_querying: false, running: false, query_timeout: Duration::from_secs(QUERY_TIMEOUT_SECS), } } fn auth(op: &str, id: &str, uuid: &str) -> ResultType> { Ok(OIDC_SESSION .read() .unwrap() .client .post(format!("{}/api/oidc/auth", *API_SERVER)) .json(&HashMap::from([("op", op), ("id", id), ("uuid", uuid)])) .send()? .try_into()?) } fn query(code: &str, id: &str, uuid: &str) -> ResultType> { let url = reqwest::Url::parse_with_params( &format!("{}/api/oidc/auth-query", *API_SERVER), &[("code", code), ("id", id), ("uuid", uuid)], )?; Ok(OIDC_SESSION .read() .unwrap() .client .get(url) .send()? .try_into()?) } fn reset(&mut self) { self.state_msg = REQUESTING_ACCOUNT_AUTH; self.failed_msg = "".to_owned(); self.keep_querying = true; self.running = false; self.code_url = None; self.auth_body = None; } fn before_task(&mut self) { self.reset(); self.running = true; } fn after_task(&mut self) { self.running = false; } fn sleep(secs: f32) { std::thread::sleep(std::time::Duration::from_secs_f32(secs)); } fn auth_task(op: String, id: String, uuid: String) { let auth_request_res = Self::auth(&op, &id, &uuid); log::info!("Request oidc auth result: {:?}", &auth_request_res); let code_url = match auth_request_res { Ok(HbbHttpResponse::<_>::Data(code_url)) => code_url, Ok(HbbHttpResponse::<_>::Error(err)) => { OIDC_SESSION .write() .unwrap() .set_state(REQUESTING_ACCOUNT_AUTH, err); return; } Ok(_) => { OIDC_SESSION .write() .unwrap() .set_state(REQUESTING_ACCOUNT_AUTH, "Invalid auth response".to_owned()); return; } Err(err) => { OIDC_SESSION .write() .unwrap() .set_state(REQUESTING_ACCOUNT_AUTH, err.to_string()); return; } }; OIDC_SESSION .write() .unwrap() .set_state(WAITING_ACCOUNT_AUTH, "".to_owned()); OIDC_SESSION.write().unwrap().code_url = Some(code_url.clone()); let begin = Instant::now(); let query_timeout = OIDC_SESSION.read().unwrap().query_timeout; while OIDC_SESSION.read().unwrap().keep_querying && begin.elapsed() < query_timeout { match Self::query(&code_url.code, &id, &uuid) { Ok(HbbHttpResponse::<_>::Data(auth_body)) => { LocalConfig::set_option( "access_token".to_owned(), auth_body.access_token.clone(), ); LocalConfig::set_option( "user_info".to_owned(), serde_json::to_string(&auth_body.user).unwrap_or_default(), ); OIDC_SESSION .write() .unwrap() .set_state(LOGIN_ACCOUNT_AUTH, "".to_owned()); OIDC_SESSION.write().unwrap().auth_body = Some(auth_body); return; } Ok(HbbHttpResponse::<_>::Error(err)) => { if err.contains("No authed oidc is found") { // ignore, keep querying } else { OIDC_SESSION .write() .unwrap() .set_state(WAITING_ACCOUNT_AUTH, err); return; } } Ok(_) => { // ignore } Err(err) => { log::trace!("Failed query oidc {}", err); // ignore } } Self::sleep(QUERY_INTERVAL_SECS); } if begin.elapsed() >= query_timeout { OIDC_SESSION .write() .unwrap() .set_state(WAITING_ACCOUNT_AUTH, "timeout".to_owned()); } // no need to handle "keep_querying == false" } fn set_state(&mut self, state_msg: &'static str, failed_msg: String) { self.state_msg = state_msg; self.failed_msg = failed_msg; } fn wait_stop_querying() { let wait_secs = 0.3; while OIDC_SESSION.read().unwrap().running { Self::sleep(wait_secs); } } pub fn account_auth(op: String, id: String, uuid: String) { Self::auth_cancel(); Self::wait_stop_querying(); OIDC_SESSION.write().unwrap().before_task(); std::thread::spawn(|| { Self::auth_task(op, id, uuid); OIDC_SESSION.write().unwrap().after_task(); }); } fn get_result_(&self) -> AuthResult { AuthResult { state_msg: self.state_msg.to_string(), failed_msg: self.failed_msg.clone(), url: self.code_url.as_ref().map(|x| x.url.to_string()), auth_body: self.auth_body.clone(), } } pub fn auth_cancel() { OIDC_SESSION.write().unwrap().keep_querying = false; } pub fn get_result() -> AuthResult { OIDC_SESSION.read().unwrap().get_result_() } }