mirror of
https://github.com/ekzhang/bore.git
synced 2025-07-05 16:02:25 +07:00
Use copy_bidirectional, handle half-closed TCP streams (#165)
This commit is contained in:
@ -8,9 +8,7 @@ use tracing::{error, info, info_span, warn, Instrument};
|
|||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::auth::Authenticator;
|
use crate::auth::Authenticator;
|
||||||
use crate::shared::{
|
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT};
|
||||||
proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// State structure for the client.
|
/// State structure for the client.
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
@ -112,10 +110,10 @@ impl Client {
|
|||||||
}
|
}
|
||||||
remote_conn.send(ClientMessage::Accept(id)).await?;
|
remote_conn.send(ClientMessage::Accept(id)).await?;
|
||||||
let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
|
let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
|
||||||
let parts = remote_conn.into_parts();
|
let mut parts = remote_conn.into_parts();
|
||||||
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
|
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
|
||||||
local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty
|
local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty
|
||||||
proxy(local_conn, parts.io).await?;
|
tokio::io::copy_bidirectional(&mut local_conn, &mut parts.io).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,7 @@ use tracing::{info, info_span, warn, Instrument};
|
|||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::auth::Authenticator;
|
use crate::auth::Authenticator;
|
||||||
use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
|
use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT};
|
||||||
|
|
||||||
/// State structure for the server.
|
/// State structure for the server.
|
||||||
pub struct Server {
|
pub struct Server {
|
||||||
@ -172,10 +172,10 @@ impl Server {
|
|||||||
info!(%id, "forwarding connection");
|
info!(%id, "forwarding connection");
|
||||||
match self.conns.remove(&id) {
|
match self.conns.remove(&id) {
|
||||||
Some((_, mut stream2)) => {
|
Some((_, mut stream2)) => {
|
||||||
let parts = stream.into_parts();
|
let mut parts = stream.into_parts();
|
||||||
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
|
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
|
||||||
stream2.write_all(&parts.read_buf).await?;
|
stream2.write_all(&parts.read_buf).await?;
|
||||||
proxy(parts.io, stream2).await?
|
tokio::io::copy_bidirectional(&mut parts.io, &mut stream2).await?;
|
||||||
}
|
}
|
||||||
None => warn!(%id, "missing connection"),
|
None => warn!(%id, "missing connection"),
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,7 @@ use std::time::Duration;
|
|||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
use tokio::io::{self, AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts};
|
use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts};
|
||||||
use tracing::trace;
|
use tracing::trace;
|
||||||
@ -97,18 +97,3 @@ impl<U: AsyncRead + AsyncWrite + Unpin> Delimited<U> {
|
|||||||
self.0.into_parts()
|
self.0.into_parts()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Copy data mutually between two read/write streams.
|
|
||||||
pub async fn proxy<S1, S2>(stream1: S1, stream2: S2) -> io::Result<()>
|
|
||||||
where
|
|
||||||
S1: AsyncRead + AsyncWrite + Unpin,
|
|
||||||
S2: AsyncRead + AsyncWrite + Unpin,
|
|
||||||
{
|
|
||||||
let (mut s1_read, mut s1_write) = io::split(stream1);
|
|
||||||
let (mut s2_read, mut s2_write) = io::split(stream2);
|
|
||||||
tokio::select! {
|
|
||||||
res = io::copy(&mut s1_read, &mut s2_write) => res,
|
|
||||||
res = io::copy(&mut s2_read, &mut s1_write) => res,
|
|
||||||
}?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
@ -125,3 +125,40 @@ fn empty_port_range() {
|
|||||||
let max_port = 3000;
|
let max_port = 3000;
|
||||||
let _ = Server::new(min_port..=max_port, None);
|
let _ = Server::new(min_port..=max_port, None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn half_closed_tcp_stream() -> Result<()> {
|
||||||
|
// Check that "half-closed" TCP streams will not result in spontaneous hangups.
|
||||||
|
let _guard = SERIAL_GUARD.lock().await;
|
||||||
|
|
||||||
|
spawn_server(None).await;
|
||||||
|
let (listener, addr) = spawn_client(None).await?;
|
||||||
|
|
||||||
|
let (mut cli, (mut srv, _)) = tokio::try_join!(TcpStream::connect(addr), listener.accept())?;
|
||||||
|
|
||||||
|
// Send data before half-closing one of the streams.
|
||||||
|
let mut buf = b"message before shutdown".to_vec();
|
||||||
|
cli.write_all(&buf).await?;
|
||||||
|
|
||||||
|
// Only close the write half of the stream. This is a half-closed stream. In the
|
||||||
|
// TCP protocol, it is represented as a FIN packet on one end. The entire stream
|
||||||
|
// is only closed after two FINs are exchanged and ACKed by the other end.
|
||||||
|
cli.shutdown().await?;
|
||||||
|
|
||||||
|
srv.read_exact(&mut buf).await?;
|
||||||
|
assert_eq!(buf, b"message before shutdown");
|
||||||
|
assert_eq!(srv.read(&mut buf).await?, 0); // EOF
|
||||||
|
|
||||||
|
// Now make sure that the other stream can still send data, despite
|
||||||
|
// half-shutdown on client->server side.
|
||||||
|
let mut buf = b"hello from the other side!".to_vec();
|
||||||
|
srv.write_all(&buf).await?;
|
||||||
|
cli.read_exact(&mut buf).await?;
|
||||||
|
assert_eq!(buf, b"hello from the other side!");
|
||||||
|
|
||||||
|
// We don't have to think about CLOSE_RD handling because that's not really
|
||||||
|
// part of the TCP protocol, just the POSIX streams API. It is implemented by
|
||||||
|
// the OS ignoring future packets received on that stream.
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user