Use copy_bidirectional, handle half-closed TCP streams (#165)

This commit is contained in:
Eric Zhang
2025-06-09 16:10:40 -04:00
committed by GitHub
parent 8ad7ee212b
commit 7969486d32
4 changed files with 44 additions and 24 deletions

View File

@ -8,9 +8,7 @@ use tracing::{error, info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::shared::{
proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT,
};
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT};
/// State structure for the client.
pub struct Client {
@ -112,10 +110,10 @@ impl Client {
}
remote_conn.send(ClientMessage::Accept(id)).await?;
let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
let parts = remote_conn.into_parts();
let mut parts = remote_conn.into_parts();
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty
proxy(local_conn, parts.io).await?;
tokio::io::copy_bidirectional(&mut local_conn, &mut parts.io).await?;
Ok(())
}
}

View File

@ -12,7 +12,7 @@ use tracing::{info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
/// State structure for the server.
pub struct Server {
@ -172,10 +172,10 @@ impl Server {
info!(%id, "forwarding connection");
match self.conns.remove(&id) {
Some((_, mut stream2)) => {
let parts = stream.into_parts();
let mut parts = stream.into_parts();
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
stream2.write_all(&parts.read_buf).await?;
proxy(parts.io, stream2).await?
tokio::io::copy_bidirectional(&mut parts.io, &mut stream2).await?;
}
None => warn!(%id, "missing connection"),
}

View File

@ -5,7 +5,7 @@ use std::time::Duration;
use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::io::{self, AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::timeout;
use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts};
use tracing::trace;
@ -97,18 +97,3 @@ impl<U: AsyncRead + AsyncWrite + Unpin> Delimited<U> {
self.0.into_parts()
}
}
/// Copy data mutually between two read/write streams.
pub async fn proxy<S1, S2>(stream1: S1, stream2: S2) -> io::Result<()>
where
S1: AsyncRead + AsyncWrite + Unpin,
S2: AsyncRead + AsyncWrite + Unpin,
{
let (mut s1_read, mut s1_write) = io::split(stream1);
let (mut s2_read, mut s2_write) = io::split(stream2);
tokio::select! {
res = io::copy(&mut s1_read, &mut s2_write) => res,
res = io::copy(&mut s2_read, &mut s1_write) => res,
}?;
Ok(())
}

View File

@ -125,3 +125,40 @@ fn empty_port_range() {
let max_port = 3000;
let _ = Server::new(min_port..=max_port, None);
}
#[tokio::test]
async fn half_closed_tcp_stream() -> Result<()> {
// Check that "half-closed" TCP streams will not result in spontaneous hangups.
let _guard = SERIAL_GUARD.lock().await;
spawn_server(None).await;
let (listener, addr) = spawn_client(None).await?;
let (mut cli, (mut srv, _)) = tokio::try_join!(TcpStream::connect(addr), listener.accept())?;
// Send data before half-closing one of the streams.
let mut buf = b"message before shutdown".to_vec();
cli.write_all(&buf).await?;
// Only close the write half of the stream. This is a half-closed stream. In the
// TCP protocol, it is represented as a FIN packet on one end. The entire stream
// is only closed after two FINs are exchanged and ACKed by the other end.
cli.shutdown().await?;
srv.read_exact(&mut buf).await?;
assert_eq!(buf, b"message before shutdown");
assert_eq!(srv.read(&mut buf).await?, 0); // EOF
// Now make sure that the other stream can still send data, despite
// half-shutdown on client->server side.
let mut buf = b"hello from the other side!".to_vec();
srv.write_all(&buf).await?;
cli.read_exact(&mut buf).await?;
assert_eq!(buf, b"hello from the other side!");
// We don't have to think about CLOSE_RD handling because that's not really
// part of the TCP protocol, just the POSIX streams API. It is implemented by
// the OS ignoring future packets received on that stream.
Ok(())
}