Skip to content

Commit

Permalink
g3proxy: try shutdown WR cleanly when possoble
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-jq committed Jan 18, 2025
1 parent 6bc990d commit 229bbf7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 64 deletions.
22 changes: 17 additions & 5 deletions g3proxy/src/inspect/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

use std::time::Duration;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::time::Instant;

use g3_daemon::server::ServerQuitPolicy;
Expand Down Expand Up @@ -91,15 +91,21 @@ pub(crate) trait StreamTransitTask {
r = &mut clt_to_ups => {
let _ = ups_to_clt.write_flush().await;
return match r {
Ok(_) => Err(ServerTaskError::ClosedByClient),
Ok(_) => {
let _ = clt_to_ups.writer().shutdown().await;
Err(ServerTaskError::ClosedByClient)
}
Err(LimitedCopyError::ReadFailed(e)) => Err(ServerTaskError::ClientTcpReadFailed(e)),
Err(LimitedCopyError::WriteFailed(e)) => Err(ServerTaskError::UpstreamWriteFailed(e)),
};
}
r = &mut ups_to_clt => {
let _ = clt_to_ups.write_flush().await;
return match r {
Ok(_) => Err(ServerTaskError::ClosedByUpstream),
Ok(_) => {
let _ = ups_to_clt.writer().shutdown().await;
Err(ServerTaskError::ClosedByUpstream)
}
Err(LimitedCopyError::ReadFailed(e)) => Err(ServerTaskError::UpstreamReadFailed(e)),
Err(LimitedCopyError::WriteFailed(e)) => Err(ServerTaskError::ClientTcpWriteFailed(e)),
};
Expand Down Expand Up @@ -267,15 +273,21 @@ where
r = &mut clt_to_ups => {
let _ = ups_to_clt.write_flush().await;
return match r {
Ok(_) => Err(ServerTaskError::ClosedByClient),
Ok(_) => {
let _ = clt_to_ups.writer().shutdown().await;
Err(ServerTaskError::ClosedByClient)
}
Err(LimitedCopyError::ReadFailed(e)) => Err(ServerTaskError::ClientTcpReadFailed(e)),
Err(LimitedCopyError::WriteFailed(e)) => Err(ServerTaskError::UpstreamWriteFailed(e)),
};
}
r = &mut ups_to_clt => {
let _ = clt_to_ups.write_flush().await;
return match r {
Ok(_) => Err(ServerTaskError::ClosedByUpstream),
Ok(_) => {
let _ = ups_to_clt.writer().shutdown().await;
Err(ServerTaskError::ClosedByUpstream)
}
Err(LimitedCopyError::ReadFailed(e)) => Err(ServerTaskError::UpstreamReadFailed(e)),
Err(LimitedCopyError::WriteFailed(e)) => Err(ServerTaskError::ClientTcpWriteFailed(e)),
};
Expand Down
68 changes: 28 additions & 40 deletions g3proxy/src/serve/http_proxy/task/forward/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,10 +630,8 @@ impl<'a> HttpProxyForwardTask<'a> {
.run_with_connection(fwd_ctx, clt_r, clt_w, connection, audit_task)
.await;
match r {
Ok(r) => {
if let Some(connection) = r {
fwd_ctx.save_alive_connection(connection);
}
Ok(ups_s) => {
self.save_or_close(fwd_ctx, clt_w, ups_s).await;
return Ok(());
}
Err(e) => {
Expand Down Expand Up @@ -661,10 +659,8 @@ impl<'a> HttpProxyForwardTask<'a> {
.run_with_connection(fwd_ctx, clt_r, clt_w, connection, audit_task)
.await
{
Ok(r) => {
if let Some(connection) = r {
fwd_ctx.save_alive_connection(connection);
}
Ok(ups_s) => {
self.save_or_close(fwd_ctx, clt_w, ups_s).await;
Ok(())
}
Err(e) => {
Expand All @@ -677,6 +673,24 @@ impl<'a> HttpProxyForwardTask<'a> {
}
}

async fn save_or_close<CDW>(
&self,
fwd_ctx: &mut BoxHttpForwardContext,
clt_w: &mut HttpClientWriter<CDW>,
ups_s: Option<BoxHttpForwardConnection>,
) where
CDW: AsyncWrite + Unpin,
{
if self.should_close {
if let Some(mut connection) = ups_s {
let _ = connection.0.shutdown().await;
}
let _ = clt_w.shutdown().await;
} else if let Some(connection) = ups_s {
fwd_ctx.save_alive_connection(connection);
}
}

async fn get_new_connection<CDW>(
&mut self,
fwd_ctx: &mut BoxHttpForwardContext,
Expand Down Expand Up @@ -975,12 +989,8 @@ impl<'a> HttpProxyForwardTask<'a> {
.await?;

self.task_notes.stage = ServerTaskStage::Finished;
if self.should_close || close_remote {
if self.is_https {
// make sure we correctly shutdown tls connection, or the ticket won't be reused
// FIXME use async drop at escaper side when supported
let _ = ups_w.shutdown().await;
}
if close_remote {
let _ = ups_w.shutdown().await;
Ok(None)
} else {
Ok(Some(ups_c))
Expand Down Expand Up @@ -1164,16 +1174,7 @@ impl<'a> HttpProxyForwardTask<'a> {
.await?;

self.task_notes.stage = ServerTaskStage::Finished;
if self.should_close {
if self.is_https {
// make sure we correctly shutdown tls connection, or the ticket won't be reused
// FIXME use async drop at escaper side when supported
let _ = ups_w.shutdown().await;
}
Ok(None)
} else {
Ok(Some(ups_c))
}
Ok(Some(ups_c))
}

async fn send_full_req_and_recv_rsp(
Expand Down Expand Up @@ -1263,16 +1264,7 @@ impl<'a> HttpProxyForwardTask<'a> {
.await?;

self.task_notes.stage = ServerTaskStage::Finished;
return if self.should_close {
if self.is_https {
// make sure we correctly shutdown tls connection, or the ticket won't be reused
// FIXME use async drop at escaper side when supported
let _ = ups_w.shutdown().await;
}
Ok(None)
} else {
Ok(Some(ups_c))
};
return Ok(Some(ups_c));
}
}

Expand Down Expand Up @@ -1444,12 +1436,8 @@ impl<'a> HttpProxyForwardTask<'a> {
.await?;

self.task_notes.stage = ServerTaskStage::Finished;
if self.should_close || close_remote {
if self.is_https {
// make sure we correctly shutdown tls connection, or the ticket won't be reused
// FIXME use async drop at escaper side when supported
let _ = ups_w.shutdown().await;
}
if close_remote {
let _ = ups_w.shutdown().await;
Ok(None)
} else {
Ok(Some(ups_c))
Expand Down
45 changes: 26 additions & 19 deletions g3proxy/src/serve/http_rproxy/task/forward/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,10 +524,8 @@ impl<'a> HttpRProxyForwardTask<'a> {
.run_with_connection(fwd_ctx, clt_r, clt_w, connection)
.await;
match r {
Ok(r) => {
if let Some(connection) = r {
fwd_ctx.save_alive_connection(connection);
}
Ok(ups_s) => {
self.save_or_close(fwd_ctx, clt_w, ups_s).await;
return Ok(());
}
Err(e) => {
Expand Down Expand Up @@ -555,10 +553,8 @@ impl<'a> HttpRProxyForwardTask<'a> {
.run_with_connection(fwd_ctx, clt_r, clt_w, connection)
.await
{
Ok(r) => {
if let Some(connection) = r {
fwd_ctx.save_alive_connection(connection);
}
Ok(ups_s) => {
self.save_or_close(fwd_ctx, clt_w, ups_s).await;
Ok(())
}
Err(e) => {
Expand All @@ -571,6 +567,24 @@ impl<'a> HttpRProxyForwardTask<'a> {
}
}

async fn save_or_close<CDW>(
&self,
fwd_ctx: &mut BoxHttpForwardContext,
clt_w: &mut HttpClientWriter<CDW>,
ups_s: Option<BoxHttpForwardConnection>,
) where
CDW: AsyncWrite + Unpin,
{
if self.should_close {
if let Some(mut connection) = ups_s {
let _ = connection.0.shutdown().await;
}
let _ = clt_w.shutdown().await;
} else if let Some(connection) = ups_s {
fwd_ctx.save_alive_connection(connection);
}
}

async fn get_new_connection<CDW>(
&mut self,
fwd_ctx: &mut BoxHttpForwardContext,
Expand Down Expand Up @@ -785,11 +799,7 @@ impl<'a> HttpRProxyForwardTask<'a> {
self.send_response(clt_w, ups_r, &rsp_header).await?;

self.task_notes.stage = ServerTaskStage::Finished;
if self.should_close {
Ok(None)
} else {
Ok(Some(ups_c))
}
Ok(Some(ups_c))
}

async fn send_full_req_and_recv_rsp(
Expand Down Expand Up @@ -879,11 +889,7 @@ impl<'a> HttpRProxyForwardTask<'a> {
self.send_response(clt_w, ups_r, &rsp_header).await?;

self.task_notes.stage = ServerTaskStage::Finished;
return if self.should_close {
Ok(None)
} else {
Ok(Some(ups_c))
};
return Ok(Some(ups_c));
}
}

Expand Down Expand Up @@ -1055,7 +1061,8 @@ impl<'a> HttpRProxyForwardTask<'a> {
self.send_response(clt_w, ups_r, &rsp_header).await?;

self.task_notes.stage = ServerTaskStage::Finished;
if self.should_close || close_remote {
if close_remote {
let _ = ups_w.shutdown().await;
Ok(None)
} else {
Ok(Some(ups_c))
Expand Down

0 comments on commit 229bbf7

Please sign in to comment.