refactor session insert, update if already exists (#9739)

* All share the same last_receive_time
* Not second port forward

Signed-off-by: 21pages <sunboeasy@gmail.com>
This commit is contained in:
21pages 2024-10-24 23:14:43 +08:00 committed by GitHub
parent 4da584055d
commit c8b9031996
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -152,8 +152,6 @@ struct Session {
last_recv_time: Arc<Mutex<Instant>>, last_recv_time: Arc<Mutex<Instant>>,
random_password: String, random_password: String,
tfa: bool, tfa: bool,
conn_type: AuthConnType,
conn_id: i32,
} }
#[cfg(not(any(target_os = "android", target_os = "ios")))] #[cfg(not(any(target_os = "android", target_os = "ios")))]
@ -217,7 +215,7 @@ pub struct Connection {
server_audit_conn: String, server_audit_conn: String,
server_audit_file: String, server_audit_file: String,
lr: LoginRequest, lr: LoginRequest,
last_recv_time: Arc<Mutex<Instant>>, session_last_recv_time: Option<Arc<Mutex<Instant>>>,
chat_unanswered: bool, chat_unanswered: bool,
file_transferred: bool, file_transferred: bool,
#[cfg(windows)] #[cfg(windows)]
@ -364,7 +362,7 @@ impl Connection {
server_audit_conn: "".to_owned(), server_audit_conn: "".to_owned(),
server_audit_file: "".to_owned(), server_audit_file: "".to_owned(),
lr: Default::default(), lr: Default::default(),
last_recv_time: Arc::new(Mutex::new(Instant::now())), session_last_recv_time: None,
chat_unanswered: false, chat_unanswered: false,
file_transferred: false, file_transferred: false,
#[cfg(windows)] #[cfg(windows)]
@ -595,7 +593,7 @@ impl Connection {
}, },
Ok(bytes) => { Ok(bytes) => {
last_recv_time = Instant::now(); last_recv_time = Instant::now();
*conn.last_recv_time.lock().unwrap() = Instant::now(); conn.session_last_recv_time.as_mut().map(|t| *t.lock().unwrap() = Instant::now());
if let Ok(msg_in) = Message::parse_from_bytes(&bytes) { if let Ok(msg_in) = Message::parse_from_bytes(&bytes) {
if !conn.on_message(msg_in).await { if !conn.on_message(msg_in).await {
break; break;
@ -762,6 +760,10 @@ impl Connection {
} }
if let Err(err) = conn.try_port_forward_loop(&mut rx_from_cm).await { if let Err(err) = conn.try_port_forward_loop(&mut rx_from_cm).await {
conn.on_close(&err.to_string(), false).await; conn.on_close(&err.to_string(), false).await;
raii::AuthedConnID::remove_session_if_last_duplication(
conn.inner.id(),
conn.session_key(),
);
} }
conn.post_conn_audit(json!({ conn.post_conn_audit(json!({
@ -1140,6 +1142,11 @@ impl Connection {
auth_conn_type, auth_conn_type,
self.session_key(), self.session_key(),
)); ));
self.session_last_recv_time = SESSIONS
.lock()
.unwrap()
.get(&self.session_key())
.map(|s| s.last_recv_time.clone());
self.post_conn_audit( self.post_conn_audit(
json!({"peer": ((&self.lr.my_id, &self.lr.my_name)), "type": conn_type}), json!({"peer": ((&self.lr.my_id, &self.lr.my_name)), "type": conn_type}),
); );
@ -1549,15 +1556,10 @@ impl Connection {
if password::temporary_enabled() { if password::temporary_enabled() {
let password = password::temporary_password(); let password = password::temporary_password();
if self.validate_one_password(password.clone()) { if self.validate_one_password(password.clone()) {
raii::AuthedConnID::insert_session( raii::AuthedConnID::update_or_insert_session(
self.session_key(), self.session_key(),
Session { Some(password),
last_recv_time: self.last_recv_time.clone(), Some(false),
random_password: password,
tfa: false,
conn_type: self.conn_type(),
conn_id: self.inner.id(),
},
); );
return true; return true;
} }
@ -1581,15 +1583,11 @@ impl Connection {
.get(&self.session_key()) .get(&self.session_key())
.map(|s| s.to_owned()); .map(|s| s.to_owned());
// last_recv_time is a mutex variable shared with connection, can be updated lively. // last_recv_time is a mutex variable shared with connection, can be updated lively.
if let Some(mut session) = session { if let Some(session) = session {
if !self.lr.password.is_empty() if !self.lr.password.is_empty()
&& (tfa && session.tfa && (tfa && session.tfa
|| !tfa && self.validate_one_password(session.random_password.clone())) || !tfa && self.validate_one_password(session.random_password.clone()))
{ {
session.last_recv_time = self.last_recv_time.clone();
session.conn_id = self.inner.id();
session.conn_type = self.conn_type();
raii::AuthedConnID::insert_session(self.session_key(), session);
log::info!("is recent session"); log::info!("is recent session");
return true; return true;
} }
@ -1841,34 +1839,13 @@ impl Connection {
if res { if res {
self.update_failure(failure, true, 1); self.update_failure(failure, true, 1);
self.require_2fa.take(); self.require_2fa.take();
raii::AuthedConnID::set_session_2fa(self.session_key());
self.send_logon_response().await; self.send_logon_response().await;
self.try_start_cm( self.try_start_cm(
self.lr.my_id.to_owned(), self.lr.my_id.to_owned(),
self.lr.my_name.to_owned(), self.lr.my_name.to_owned(),
self.authorized, self.authorized,
); );
let session = SESSIONS
.lock()
.unwrap()
.get(&self.session_key())
.map(|s| s.to_owned());
if let Some(mut session) = session {
session.tfa = true;
session.conn_id = self.inner.id();
session.conn_type = self.conn_type();
raii::AuthedConnID::insert_session(self.session_key(), session);
} else {
raii::AuthedConnID::insert_session(
self.session_key(),
Session {
last_recv_time: self.last_recv_time.clone(),
random_password: "".to_owned(),
tfa: true,
conn_type: self.conn_type(),
conn_id: self.inner.id(),
},
);
}
if !tfa.hwid.is_empty() && Self::enable_trusted_devices() { if !tfa.hwid.is_empty() && Self::enable_trusted_devices() {
Config::add_trusted_device(TrustedDevice { Config::add_trusted_device(TrustedDevice {
hwid: tfa.hwid, hwid: tfa.hwid,
@ -3872,16 +3849,17 @@ mod raii {
} }
pub fn remove_session_if_last_duplication(conn_id: i32, key: SessionKey) { pub fn remove_session_if_last_duplication(conn_id: i32, key: SessionKey) {
let contains = SESSIONS.lock().unwrap().contains_key(&key); let mut lock = SESSIONS.lock().unwrap();
let contains = lock.contains_key(&key);
if contains { if contains {
let another = AUTHED_CONNS let another = AUTHED_CONNS
.lock() .lock()
.unwrap() .unwrap()
.iter() .iter()
.any(|c| c.0 != conn_id && c.2 == key && c.1 != AuthConnType::PortForward); .any(|c| c.0 != conn_id && c.2 == key);
if !another { if !another {
// Keep the session if there is another connection with same peer_id and session_id. // Keep the session if there is another connection with same peer_id and session_id.
SESSIONS.lock().unwrap().remove(&key); lock.remove(&key);
log::info!("remove session"); log::info!("remove session");
} else { } else {
log::info!("skip remove session"); log::info!("skip remove session");
@ -3889,32 +3867,46 @@ mod raii {
} }
} }
pub fn insert_session(key: SessionKey, session: Session) { pub fn update_or_insert_session(
let mut insert = true; key: SessionKey,
if session.conn_type == AuthConnType::PortForward { password: Option<String>,
// port forward doesn't update last received time tfa: Option<bool>,
let other_alive_conns = AUTHED_CONNS ) {
.lock() let mut lock = SESSIONS.lock().unwrap();
.unwrap() let session = lock.get_mut(&key);
.iter() if let Some(session) = session {
.filter(|c| { if let Some(password) = password {
c.2 == key && c.1 != AuthConnType::PortForward // port forward doesn't remove itself session.random_password = password;
})
.map(|c| c.0)
.collect::<Vec<_>>();
let another = SESSIONS.lock().unwrap().get(&key).map(|s| {
other_alive_conns.contains(&s.conn_id)
&& s.tfa == session.tfa
&& s.conn_type != AuthConnType::PortForward
}) == Some(true);
if another {
insert = false;
log::info!("skip insert session for port forward");
} }
if let Some(tfa) = tfa {
session.tfa = tfa;
}
} else {
lock.insert(
key,
Session {
random_password: password.unwrap_or_default(),
tfa: tfa.unwrap_or_default(),
last_recv_time: Arc::new(Mutex::new(Instant::now())),
},
);
} }
if insert { }
log::info!("insert session for {:?}", session.conn_type);
SESSIONS.lock().unwrap().insert(key, session); pub fn set_session_2fa(key: SessionKey) {
let mut lock = SESSIONS.lock().unwrap();
let session = lock.get_mut(&key);
if let Some(session) = session {
session.tfa = true;
} else {
lock.insert(
key,
Session {
last_recv_time: Arc::new(Mutex::new(Instant::now())),
random_password: "".to_owned(),
tfa: true,
},
);
} }
} }
} }