diff --git a/g3proxy/src/inspect/stream/mod.rs b/g3proxy/src/inspect/stream/mod.rs index e85167abe..c69a9a4cc 100644 --- a/g3proxy/src/inspect/stream/mod.rs +++ b/g3proxy/src/inspect/stream/mod.rs @@ -16,7 +16,7 @@ use std::time::Duration; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::time::Instant; use g3_daemon::server::ServerQuitPolicy; @@ -91,7 +91,10 @@ pub(crate) trait StreamTransitTask { r = &mut clt_to_ups => { let _ = ups_to_clt.write_flush().await; return match r { - Ok(_) => Err(ServerTaskError::ClosedByClient), + Ok(_) => { + let _ = clt_to_ups.writer().shutdown().await; + Err(ServerTaskError::ClosedByClient) + } Err(LimitedCopyError::ReadFailed(e)) => Err(ServerTaskError::ClientTcpReadFailed(e)), Err(LimitedCopyError::WriteFailed(e)) => Err(ServerTaskError::UpstreamWriteFailed(e)), }; @@ -99,7 +102,10 @@ pub(crate) trait StreamTransitTask { r = &mut ups_to_clt => { let _ = clt_to_ups.write_flush().await; return match r { - Ok(_) => Err(ServerTaskError::ClosedByUpstream), + Ok(_) => { + let _ = ups_to_clt.writer().shutdown().await; + Err(ServerTaskError::ClosedByUpstream) + } Err(LimitedCopyError::ReadFailed(e)) => Err(ServerTaskError::UpstreamReadFailed(e)), Err(LimitedCopyError::WriteFailed(e)) => Err(ServerTaskError::ClientTcpWriteFailed(e)), }; @@ -267,7 +273,10 @@ where r = &mut clt_to_ups => { let _ = ups_to_clt.write_flush().await; return match r { - Ok(_) => Err(ServerTaskError::ClosedByClient), + Ok(_) => { + let _ = clt_to_ups.writer().shutdown().await; + Err(ServerTaskError::ClosedByClient) + } Err(LimitedCopyError::ReadFailed(e)) => Err(ServerTaskError::ClientTcpReadFailed(e)), Err(LimitedCopyError::WriteFailed(e)) => Err(ServerTaskError::UpstreamWriteFailed(e)), }; @@ -275,7 +284,10 @@ where r = &mut ups_to_clt => { let _ = clt_to_ups.write_flush().await; return match r { - Ok(_) => Err(ServerTaskError::ClosedByUpstream), + Ok(_) => { + let _ = ups_to_clt.writer().shutdown().await; + Err(ServerTaskError::ClosedByUpstream) + } Err(LimitedCopyError::ReadFailed(e)) => Err(ServerTaskError::UpstreamReadFailed(e)), Err(LimitedCopyError::WriteFailed(e)) => Err(ServerTaskError::ClientTcpWriteFailed(e)), }; diff --git a/g3proxy/src/serve/http_proxy/task/forward/task.rs b/g3proxy/src/serve/http_proxy/task/forward/task.rs index 128cebc1f..05cf5d4df 100644 --- a/g3proxy/src/serve/http_proxy/task/forward/task.rs +++ b/g3proxy/src/serve/http_proxy/task/forward/task.rs @@ -630,10 +630,8 @@ impl<'a> HttpProxyForwardTask<'a> { .run_with_connection(fwd_ctx, clt_r, clt_w, connection, audit_task) .await; match r { - Ok(r) => { - if let Some(connection) = r { - fwd_ctx.save_alive_connection(connection); - } + Ok(ups_s) => { + self.save_or_close(fwd_ctx, clt_w, ups_s).await; return Ok(()); } Err(e) => { @@ -661,10 +659,8 @@ impl<'a> HttpProxyForwardTask<'a> { .run_with_connection(fwd_ctx, clt_r, clt_w, connection, audit_task) .await { - Ok(r) => { - if let Some(connection) = r { - fwd_ctx.save_alive_connection(connection); - } + Ok(ups_s) => { + self.save_or_close(fwd_ctx, clt_w, ups_s).await; Ok(()) } Err(e) => { @@ -677,6 +673,24 @@ impl<'a> HttpProxyForwardTask<'a> { } } + async fn save_or_close( + &self, + fwd_ctx: &mut BoxHttpForwardContext, + clt_w: &mut HttpClientWriter, + ups_s: Option, + ) where + CDW: AsyncWrite + Unpin, + { + if self.should_close { + if let Some(mut connection) = ups_s { + let _ = connection.0.shutdown().await; + } + let _ = clt_w.shutdown().await; + } else if let Some(connection) = ups_s { + fwd_ctx.save_alive_connection(connection); + } + } + async fn get_new_connection( &mut self, fwd_ctx: &mut BoxHttpForwardContext, @@ -975,12 +989,8 @@ impl<'a> HttpProxyForwardTask<'a> { .await?; self.task_notes.stage = ServerTaskStage::Finished; - if self.should_close || close_remote { - if self.is_https { - // make sure we correctly shutdown tls connection, or the ticket won't be reused - // FIXME use async drop at escaper side when supported - let _ = ups_w.shutdown().await; - } + if close_remote { + let _ = ups_w.shutdown().await; Ok(None) } else { Ok(Some(ups_c)) @@ -1164,16 +1174,7 @@ impl<'a> HttpProxyForwardTask<'a> { .await?; self.task_notes.stage = ServerTaskStage::Finished; - if self.should_close { - if self.is_https { - // make sure we correctly shutdown tls connection, or the ticket won't be reused - // FIXME use async drop at escaper side when supported - let _ = ups_w.shutdown().await; - } - Ok(None) - } else { - Ok(Some(ups_c)) - } + Ok(Some(ups_c)) } async fn send_full_req_and_recv_rsp( @@ -1263,16 +1264,7 @@ impl<'a> HttpProxyForwardTask<'a> { .await?; self.task_notes.stage = ServerTaskStage::Finished; - return if self.should_close { - if self.is_https { - // make sure we correctly shutdown tls connection, or the ticket won't be reused - // FIXME use async drop at escaper side when supported - let _ = ups_w.shutdown().await; - } - Ok(None) - } else { - Ok(Some(ups_c)) - }; + return Ok(Some(ups_c)); } } @@ -1444,12 +1436,8 @@ impl<'a> HttpProxyForwardTask<'a> { .await?; self.task_notes.stage = ServerTaskStage::Finished; - if self.should_close || close_remote { - if self.is_https { - // make sure we correctly shutdown tls connection, or the ticket won't be reused - // FIXME use async drop at escaper side when supported - let _ = ups_w.shutdown().await; - } + if close_remote { + let _ = ups_w.shutdown().await; Ok(None) } else { Ok(Some(ups_c)) diff --git a/g3proxy/src/serve/http_rproxy/task/forward/task.rs b/g3proxy/src/serve/http_rproxy/task/forward/task.rs index 3f3d63416..963f56d0f 100644 --- a/g3proxy/src/serve/http_rproxy/task/forward/task.rs +++ b/g3proxy/src/serve/http_rproxy/task/forward/task.rs @@ -524,10 +524,8 @@ impl<'a> HttpRProxyForwardTask<'a> { .run_with_connection(fwd_ctx, clt_r, clt_w, connection) .await; match r { - Ok(r) => { - if let Some(connection) = r { - fwd_ctx.save_alive_connection(connection); - } + Ok(ups_s) => { + self.save_or_close(fwd_ctx, clt_w, ups_s).await; return Ok(()); } Err(e) => { @@ -555,10 +553,8 @@ impl<'a> HttpRProxyForwardTask<'a> { .run_with_connection(fwd_ctx, clt_r, clt_w, connection) .await { - Ok(r) => { - if let Some(connection) = r { - fwd_ctx.save_alive_connection(connection); - } + Ok(ups_s) => { + self.save_or_close(fwd_ctx, clt_w, ups_s).await; Ok(()) } Err(e) => { @@ -571,6 +567,24 @@ impl<'a> HttpRProxyForwardTask<'a> { } } + async fn save_or_close( + &self, + fwd_ctx: &mut BoxHttpForwardContext, + clt_w: &mut HttpClientWriter, + ups_s: Option, + ) where + CDW: AsyncWrite + Unpin, + { + if self.should_close { + if let Some(mut connection) = ups_s { + let _ = connection.0.shutdown().await; + } + let _ = clt_w.shutdown().await; + } else if let Some(connection) = ups_s { + fwd_ctx.save_alive_connection(connection); + } + } + async fn get_new_connection( &mut self, fwd_ctx: &mut BoxHttpForwardContext, @@ -785,11 +799,7 @@ impl<'a> HttpRProxyForwardTask<'a> { self.send_response(clt_w, ups_r, &rsp_header).await?; self.task_notes.stage = ServerTaskStage::Finished; - if self.should_close { - Ok(None) - } else { - Ok(Some(ups_c)) - } + Ok(Some(ups_c)) } async fn send_full_req_and_recv_rsp( @@ -879,11 +889,7 @@ impl<'a> HttpRProxyForwardTask<'a> { self.send_response(clt_w, ups_r, &rsp_header).await?; self.task_notes.stage = ServerTaskStage::Finished; - return if self.should_close { - Ok(None) - } else { - Ok(Some(ups_c)) - }; + return Ok(Some(ups_c)); } } @@ -1055,7 +1061,8 @@ impl<'a> HttpRProxyForwardTask<'a> { self.send_response(clt_w, ups_r, &rsp_header).await?; self.task_notes.stage = ServerTaskStage::Finished; - if self.should_close || close_remote { + if close_remote { + let _ = ups_w.shutdown().await; Ok(None) } else { Ok(Some(ups_c))