//! Simple TCP mesh transport for testing and local networks. //! //! Uses length-prefixed framing (`[u32 BE length][payload]`) over raw TCP //! connections. Each send opens a new connection; each recv accepts one. use std::net::SocketAddr; use std::sync::Arc; use anyhow::{bail, Result}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use crate::transport::{MeshTransport, TransportAddr, TransportInfo, TransportPacket}; /// TCP mesh transport. /// /// Listens on a local port for incoming connections and sends packets by /// connecting to remote socket addresses. pub struct TcpTransport { listener: Arc, local_addr: SocketAddr, } impl TcpTransport { /// Bind a new TCP transport on the given address. /// /// Use `"127.0.0.1:0"` to let the OS assign a free port. pub async fn bind(addr: &str) -> Result { let listener = TcpListener::bind(addr).await?; let local_addr = listener.local_addr()?; tracing::info!(%local_addr, "TcpTransport listening"); Ok(Self { listener: Arc::new(listener), local_addr, }) } /// The local address this transport is listening on. pub fn local_addr(&self) -> SocketAddr { self.local_addr } /// Create a [`TransportAddr::Socket`] pointing to this transport's listen address. pub fn transport_addr(&self) -> TransportAddr { TransportAddr::Socket(self.local_addr) } } #[async_trait::async_trait] impl MeshTransport for TcpTransport { fn info(&self) -> TransportInfo { TransportInfo { name: "tcp".to_string(), mtu: 65535, bitrate: 1_000_000_000, bidirectional: true, } } async fn send(&self, dest: &TransportAddr, data: &[u8]) -> Result<()> { let addr = match dest { TransportAddr::Socket(addr) => *addr, other => bail!("TcpTransport cannot send to {other}"), }; let mut stream = TcpStream::connect(addr).await?; // Length-prefixed framing: [u32 BE length][payload]. let len = (data.len() as u32).to_be_bytes(); stream.write_all(&len).await?; stream.write_all(data).await?; stream.flush().await?; stream.shutdown().await?; tracing::debug!(%addr, bytes = data.len(), "TcpTransport: message sent"); Ok(()) } async fn recv(&self) -> Result { let (mut stream, peer_addr) = self.listener.accept().await?; // Read length-prefixed payload. let mut len_buf = [0u8; 4]; stream.read_exact(&mut len_buf).await?; let len = u32::from_be_bytes(len_buf) as usize; if len > 5 * 1024 * 1024 { bail!("payload too large: {len} bytes"); } let mut payload = vec![0u8; len]; stream.read_exact(&mut payload).await?; tracing::debug!(%peer_addr, bytes = len, "TcpTransport: message received"); Ok(TransportPacket { from: TransportAddr::Socket(peer_addr), data: payload, }) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn tcp_roundtrip() { let transport = TcpTransport::bind("127.0.0.1:0") .await .expect("bind TCP transport"); let dest = transport.transport_addr(); let payload = b"hello over TCP"; let recv_handle = tokio::spawn(async move { let packet = transport.recv().await.expect("recv packet"); assert_eq!(packet.data, payload.to_vec()); // Source should be a Socket address. match &packet.from { TransportAddr::Socket(_) => {} other => panic!("expected Socket addr, got {other}"), } }); // Give the listener a moment to be ready. tokio::time::sleep(std::time::Duration::from_millis(50)).await; // Send via a separate TcpTransport (simulating a different node). let sender = TcpTransport::bind("127.0.0.1:0") .await .expect("bind sender"); sender.send(&dest, payload).await.expect("send packet"); recv_handle.await.expect("recv task completed"); } #[tokio::test] async fn tcp_rejects_non_socket_addr() { let transport = TcpTransport::bind("127.0.0.1:0") .await .expect("bind TCP transport"); let bad_addr = TransportAddr::LoRa([0x01, 0x02, 0x03, 0x04]); let result = transport.send(&bad_addr, b"nope").await; assert!(result.is_err()); } }