From ac433dc11a61aff7576c431bd6f6c6b8be925279 Mon Sep 17 00:00:00 2001
From: 21pages <pages21@163.com>
Date: Tue, 3 Jan 2023 19:16:23 +0800
Subject: [PATCH] fix post heartbeat block

Signed-off-by: 21pages <pages21@163.com>
---
 src/server/connection.rs | 37 ++++++++++++++++++++++++-------------
 1 file changed, 24 insertions(+), 13 deletions(-)

diff --git a/src/server/connection.rs b/src/server/connection.rs
index 919aeae99..f91281a52 100644
--- a/src/server/connection.rs
+++ b/src/server/connection.rs
@@ -154,6 +154,7 @@ impl Connection {
         let (tx, mut rx) = mpsc::unbounded_channel::<(Instant, Arc<Message>)>();
         let (tx_video, mut rx_video) = mpsc::unbounded_channel::<(Instant, Arc<Message>)>();
         let (tx_input, rx_input) = std_mpsc::channel();
+        let (tx_stop, mut rx_stop) = mpsc::unbounded_channel::<String>();
 
         let tx_cloned = tx.clone();
         let mut conn = Self {
@@ -393,11 +394,12 @@ impl Connection {
                     }
                 }
                 _ = conn.http_timer.tick() => {
-                    if let Err(_) = Connection::post_heartbeat(conn.server_audit_conn.clone(), conn.inner.id).await {
-                        conn.on_close_manually("web console", "web console").await;
-                        break;
-                    }
+                    Connection::post_heartbeat(conn.server_audit_conn.clone(), conn.inner.id, tx_stop.clone());
                 },
+                Some(reason) = rx_stop.recv() => {
+                    conn.on_close_manually(&reason, &reason).await;
+
+                }
                 Some((instant, value)) = rx_video.recv() => {
                     if !conn.video_ack_required {
                         video_service::notify_video_frame_fetched(id, Some(instant.into()));
@@ -582,6 +584,7 @@ impl Connection {
         rx_from_cm: &mut mpsc::UnboundedReceiver<Data>,
     ) -> ResultType<()> {
         let mut last_recv_time = Instant::now();
+        let (tx_stop, mut rx_stop) = mpsc::unbounded_channel::<String>();
         if let Some(mut forward) = self.port_forward_socket.take() {
             log::info!("Running port forwarding loop");
             self.stream.set_raw();
@@ -615,7 +618,10 @@ impl Connection {
                         if last_recv_time.elapsed() >= H1 {
                             bail!("Timeout");
                         }
-                        Connection::post_heartbeat(self.server_audit_conn.clone(), self.inner.id).await?;
+                        Connection::post_heartbeat(self.server_audit_conn.clone(), self.inner.id, tx_stop.clone());
+                    }
+                    Some(reason) = rx_stop.recv() => {
+                        bail!(reason);
                     }
                 }
             }
@@ -705,23 +711,28 @@ impl Connection {
         });
     }
 
-    async fn post_heartbeat(server_audit_conn: String, conn_id: i32) -> ResultType<()> {
+    fn post_heartbeat(
+        server_audit_conn: String,
+        conn_id: i32,
+        tx_stop: mpsc::UnboundedSender<String>,
+    ) {
         if server_audit_conn.is_empty() {
-            return Ok(());
+            return;
         }
         let url = server_audit_conn.clone();
         let mut v = Value::default();
         v["id"] = json!(Config::get_id());
         v["uuid"] = json!(base64::encode(hbb_common::get_uuid()));
         v["conn_id"] = json!(conn_id);
-        if let Ok(rsp) = Self::post_audit_async(url, v).await {
-            if let Ok(rsp) = serde_json::from_str::<ConnAuditResponse>(&rsp) {
-                if rsp.action == "disconnect" {
-                    bail!("disconnect by server");
+        tokio::spawn(async move {
+            if let Ok(rsp) = Self::post_audit_async(url, v).await {
+                if let Ok(rsp) = serde_json::from_str::<ConnAuditResponse>(&rsp) {
+                    if rsp.action == "disconnect" {
+                        tx_stop.send("web console".to_string()).ok();
+                    }
                 }
             }
-        }
-        return Ok(());
+        });
     }
 
     fn post_file_audit(