mirror of
https://github.com/ekzhang/bore.git
synced 2025-07-05 07:51:56 +07:00
Use copy_bidirectional, handle half-closed TCP streams (#165)
This commit is contained in:
@ -8,9 +8,7 @@ use tracing::{error, info, info_span, warn, Instrument};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::Authenticator;
|
||||
use crate::shared::{
|
||||
proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT,
|
||||
};
|
||||
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT};
|
||||
|
||||
/// State structure for the client.
|
||||
pub struct Client {
|
||||
@ -112,10 +110,10 @@ impl Client {
|
||||
}
|
||||
remote_conn.send(ClientMessage::Accept(id)).await?;
|
||||
let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
|
||||
let parts = remote_conn.into_parts();
|
||||
let mut parts = remote_conn.into_parts();
|
||||
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
|
||||
local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty
|
||||
proxy(local_conn, parts.io).await?;
|
||||
tokio::io::copy_bidirectional(&mut local_conn, &mut parts.io).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ use tracing::{info, info_span, warn, Instrument};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::Authenticator;
|
||||
use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
|
||||
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
|
||||
|
||||
/// State structure for the server.
|
||||
pub struct Server {
|
||||
@ -172,10 +172,10 @@ impl Server {
|
||||
info!(%id, "forwarding connection");
|
||||
match self.conns.remove(&id) {
|
||||
Some((_, mut stream2)) => {
|
||||
let parts = stream.into_parts();
|
||||
let mut parts = stream.into_parts();
|
||||
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
|
||||
stream2.write_all(&parts.read_buf).await?;
|
||||
proxy(parts.io, stream2).await?
|
||||
tokio::io::copy_bidirectional(&mut parts.io, &mut stream2).await?;
|
||||
}
|
||||
None => warn!(%id, "missing connection"),
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ use std::time::Duration;
|
||||
use anyhow::{Context, Result};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::io::{self, AsyncRead, AsyncWrite};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts};
|
||||
use tracing::trace;
|
||||
@ -97,18 +97,3 @@ impl<U: AsyncRead + AsyncWrite + Unpin> Delimited<U> {
|
||||
self.0.into_parts()
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy data mutually between two read/write streams.
|
||||
pub async fn proxy<S1, S2>(stream1: S1, stream2: S2) -> io::Result<()>
|
||||
where
|
||||
S1: AsyncRead + AsyncWrite + Unpin,
|
||||
S2: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let (mut s1_read, mut s1_write) = io::split(stream1);
|
||||
let (mut s2_read, mut s2_write) = io::split(stream2);
|
||||
tokio::select! {
|
||||
res = io::copy(&mut s1_read, &mut s2_write) => res,
|
||||
res = io::copy(&mut s2_read, &mut s1_write) => res,
|
||||
}?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -125,3 +125,40 @@ fn empty_port_range() {
|
||||
let max_port = 3000;
|
||||
let _ = Server::new(min_port..=max_port, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn half_closed_tcp_stream() -> Result<()> {
|
||||
// Check that "half-closed" TCP streams will not result in spontaneous hangups.
|
||||
let _guard = SERIAL_GUARD.lock().await;
|
||||
|
||||
spawn_server(None).await;
|
||||
let (listener, addr) = spawn_client(None).await?;
|
||||
|
||||
let (mut cli, (mut srv, _)) = tokio::try_join!(TcpStream::connect(addr), listener.accept())?;
|
||||
|
||||
// Send data before half-closing one of the streams.
|
||||
let mut buf = b"message before shutdown".to_vec();
|
||||
cli.write_all(&buf).await?;
|
||||
|
||||
// Only close the write half of the stream. This is a half-closed stream. In the
|
||||
// TCP protocol, it is represented as a FIN packet on one end. The entire stream
|
||||
// is only closed after two FINs are exchanged and ACKed by the other end.
|
||||
cli.shutdown().await?;
|
||||
|
||||
srv.read_exact(&mut buf).await?;
|
||||
assert_eq!(buf, b"message before shutdown");
|
||||
assert_eq!(srv.read(&mut buf).await?, 0); // EOF
|
||||
|
||||
// Now make sure that the other stream can still send data, despite
|
||||
// half-shutdown on client->server side.
|
||||
let mut buf = b"hello from the other side!".to_vec();
|
||||
srv.write_all(&buf).await?;
|
||||
cli.read_exact(&mut buf).await?;
|
||||
assert_eq!(buf, b"hello from the other side!");
|
||||
|
||||
// We don't have to think about CLOSE_RD handling because that's not really
|
||||
// part of the TCP protocol, just the POSIX streams API. It is implemented by
|
||||
// the OS ignoring future packets received on that stream.
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user