From f09dbe10cefa1769dd5eb3c24ccd3d73836b1085 Mon Sep 17 00:00:00 2001 From: Christian Nennemann Date: Wed, 4 Mar 2026 12:08:20 +0100 Subject: [PATCH] feat(rpc): auth handshake, server-push broker, audit logging - auth_handshake.rs: connection-init protocol (magic 0x01, token, ack) - push.rs: PushBroker manages per-identity push connections with gc - server.rs: ConnectionState, auth handshake on first bi-stream, pass identity_key/session_token to RequestContext per stream - client.rs: session_token in RpcClientConfig, auto auth handshake on connect - middleware.rs: log_rpc_call with SHA-256 redaction, hex_prefix helper - lib.rs: export auth_handshake and push modules --- crates/quicproquo-rpc/Cargo.toml | 1 + crates/quicproquo-rpc/src/auth_handshake.rs | 135 ++++++++++++++++++++ crates/quicproquo-rpc/src/client.rs | 17 +++ crates/quicproquo-rpc/src/lib.rs | 2 + crates/quicproquo-rpc/src/middleware.rs | 64 +++++++++- crates/quicproquo-rpc/src/push.rs | 114 +++++++++++++++++ crates/quicproquo-rpc/src/server.rs | 62 ++++++++- 7 files changed, 387 insertions(+), 8 deletions(-) create mode 100644 crates/quicproquo-rpc/src/auth_handshake.rs create mode 100644 crates/quicproquo-rpc/src/push.rs diff --git a/crates/quicproquo-rpc/Cargo.toml b/crates/quicproquo-rpc/Cargo.toml index eccc986..4882c44 100644 --- a/crates/quicproquo-rpc/Cargo.toml +++ b/crates/quicproquo-rpc/Cargo.toml @@ -17,6 +17,7 @@ tower = { workspace = true } tracing = { workspace = true } thiserror = { workspace = true } dashmap = { workspace = true } +sha2 = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["test-util"] } diff --git a/crates/quicproquo-rpc/src/auth_handshake.rs b/crates/quicproquo-rpc/src/auth_handshake.rs new file mode 100644 index 0000000..8b86fbd --- /dev/null +++ b/crates/quicproquo-rpc/src/auth_handshake.rs @@ -0,0 +1,135 @@ +//! Auth handshake protocol for QUIC connections. +//! +//! The session token is sent as a "connection init" message on the first +//! bi-stream before any RPC calls. +//! +//! ## Protocol +//! ```text +//! Client → Server: [0x01 magic][token_len: u16 BE][token bytes] +//! Server → Client: [0x01 magic][0x00 status OK] +//! ``` + +use crate::error::RpcError; + +/// Magic byte identifying an auth init frame. +pub const AUTH_INIT_MAGIC: u8 = 0x01; + +/// Status byte: auth accepted. +const AUTH_STATUS_OK: u8 = 0x00; + +/// Maximum token length (64 KiB). +const MAX_TOKEN_LEN: usize = 65535; + +/// Write an auth init frame to a QUIC send stream. +pub async fn send_auth_init( + send: &mut quinn::SendStream, + token: &[u8], +) -> Result<(), RpcError> { + if token.len() > MAX_TOKEN_LEN { + return Err(RpcError::Encode(format!( + "auth token too large: {} bytes (max {MAX_TOKEN_LEN})", + token.len() + ))); + } + + let token_len = token.len() as u16; + let mut buf = Vec::with_capacity(1 + 2 + token.len()); + buf.push(AUTH_INIT_MAGIC); + buf.extend_from_slice(&token_len.to_be_bytes()); + buf.extend_from_slice(token); + + send.write_all(&buf) + .await + .map_err(|e| RpcError::Connection(format!("send auth init: {e}")))?; + + Ok(()) +} + +/// Read an auth init frame from a QUIC recv stream. +pub async fn recv_auth_init( + recv: &mut quinn::RecvStream, +) -> Result, RpcError> { + // Read magic byte. + let mut header = [0u8; 3]; + recv.read_exact(&mut header) + .await + .map_err(|e| RpcError::Connection(format!("read auth init header: {e}")))?; + + if header[0] != AUTH_INIT_MAGIC { + return Err(RpcError::Decode(format!( + "bad auth init magic: expected 0x{AUTH_INIT_MAGIC:02x}, got 0x{:02x}", + header[0] + ))); + } + + let token_len = u16::from_be_bytes([header[1], header[2]]) as usize; + if token_len > MAX_TOKEN_LEN { + return Err(RpcError::Decode(format!( + "auth token length {token_len} exceeds max {MAX_TOKEN_LEN}" + ))); + } + + let mut token = vec![0u8; token_len]; + if token_len > 0 { + recv.read_exact(&mut token) + .await + .map_err(|e| RpcError::Connection(format!("read auth token: {e}")))?; + } + + Ok(token) +} + +/// Send auth ack (success response). +pub async fn send_auth_ack( + send: &mut quinn::SendStream, +) -> Result<(), RpcError> { + let buf = [AUTH_INIT_MAGIC, AUTH_STATUS_OK]; + send.write_all(&buf) + .await + .map_err(|e| RpcError::Connection(format!("send auth ack: {e}")))?; + Ok(()) +} + +/// Read auth ack from the server. +pub async fn recv_auth_ack( + recv: &mut quinn::RecvStream, +) -> Result<(), RpcError> { + let mut buf = [0u8; 2]; + recv.read_exact(&mut buf) + .await + .map_err(|e| RpcError::Connection(format!("read auth ack: {e}")))?; + + if buf[0] != AUTH_INIT_MAGIC { + return Err(RpcError::Decode(format!( + "bad auth ack magic: expected 0x{AUTH_INIT_MAGIC:02x}, got 0x{:02x}", + buf[0] + ))); + } + + if buf[1] != AUTH_STATUS_OK { + return Err(RpcError::Decode(format!( + "auth rejected with status 0x{:02x}", + buf[1] + ))); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn auth_init_magic_is_0x01() { + assert_eq!(AUTH_INIT_MAGIC, 0x01); + } + + #[test] + fn token_too_large_returns_encode_error() { + // We cannot call send_auth_init without a real stream, but we can + // verify the length check logic by constructing the guard condition. + let big_token = vec![0u8; MAX_TOKEN_LEN + 1]; + assert!(big_token.len() > MAX_TOKEN_LEN); + } +} diff --git a/crates/quicproquo-rpc/src/client.rs b/crates/quicproquo-rpc/src/client.rs index 268f3c2..56523f7 100644 --- a/crates/quicproquo-rpc/src/client.rs +++ b/crates/quicproquo-rpc/src/client.rs @@ -8,6 +8,7 @@ use quinn::{Connection, Endpoint}; use tokio::sync::mpsc; use tracing::{debug, warn}; +use crate::auth_handshake; use crate::error::{RpcError, RpcStatus}; use crate::framing::{PushFrame, RequestFrame, ResponseFrame}; @@ -21,6 +22,8 @@ pub struct RpcClientConfig { pub tls_config: Arc, /// ALPN protocol. pub alpn: Vec, + /// Session token to send during auth handshake. + pub session_token: Option>, } /// A QUIC RPC client connection. @@ -49,6 +52,20 @@ impl RpcClient { debug!(remote = %connection.remote_address(), "connected to RPC server"); + // Perform auth handshake if a session token was provided. + if let Some(ref token) = config.session_token { + let (mut send, mut recv) = connection + .open_bi() + .await + .map_err(|e| RpcError::Connection(format!("open auth stream: {e}")))?; + + auth_handshake::send_auth_init(&mut send, token).await?; + send.finish() + .map_err(|e| RpcError::Connection(format!("finish auth send: {e}")))?; + auth_handshake::recv_auth_ack(&mut recv).await?; + debug!("auth handshake complete"); + } + Ok(Self { connection, next_request_id: AtomicU32::new(1), diff --git a/crates/quicproquo-rpc/src/lib.rs b/crates/quicproquo-rpc/src/lib.rs index 4789425..ccf0c7c 100644 --- a/crates/quicproquo-rpc/src/lib.rs +++ b/crates/quicproquo-rpc/src/lib.rs @@ -5,9 +5,11 @@ //! - Response: `[status: u8][request_id: u32][payload_len: u32][protobuf bytes]` //! - Push: `[event_type: u16][payload_len: u32][protobuf bytes]` (uni-stream) +pub mod auth_handshake; pub mod framing; pub mod method; pub mod server; pub mod client; pub mod middleware; +pub mod push; pub mod error; diff --git a/crates/quicproquo-rpc/src/middleware.rs b/crates/quicproquo-rpc/src/middleware.rs index b0d62e2..5b3817e 100644 --- a/crates/quicproquo-rpc/src/middleware.rs +++ b/crates/quicproquo-rpc/src/middleware.rs @@ -1,11 +1,15 @@ -//! Tower-based middleware layers for the RPC server. +//! Middleware layers for the RPC server. //! -//! - `AuthLayer`: validates session tokens and attaches identity to context. -//! - `RateLimitLayer`: per-IP request rate limiting. +//! - `SessionValidator`: validates session tokens and resolves identity keys. +//! - `RateLimiter`: per-key sliding-window rate limiting. +//! - `log_rpc_call`: structured audit logging for RPC calls. use std::time::{Duration, Instant}; use dashmap::DashMap; +use sha2::Digest; + +use crate::error::RpcStatus; // ── Auth middleware ────────────────────────────────────────────────────────── @@ -70,6 +74,45 @@ impl RateLimiter { } } +// ── Audit logging ─────────────────────────────────────────────────────────── + +/// Log an RPC call with timing and caller info. +/// +/// When `redact` is true, the identity key is hashed before logging so that +/// raw keys never appear in log output. +pub fn log_rpc_call( + method_name: &str, + identity_key: Option<&[u8]>, + latency: Duration, + status: RpcStatus, + redact: bool, +) { + let ik_display = match identity_key { + Some(ik) if redact => { + let hash_input_len = 8.min(ik.len()); + let digest = sha2::Sha256::digest(&ik[..hash_input_len]); + format!("h:{}", hex_prefix(&digest)) + } + Some(ik) => hex_prefix(ik), + None => "anonymous".to_string(), + }; + tracing::info!( + method = method_name, + identity = %ik_display, + latency_ms = latency.as_millis() as u64, + status = ?status, + "rpc" + ); +} + +fn hex_prefix(bytes: &[u8]) -> String { + bytes + .iter() + .take(4) + .map(|b| format!("{b:02x}")) + .collect::() +} + #[cfg(test)] mod tests { use super::*; @@ -93,4 +136,19 @@ mod tests { std::thread::sleep(Duration::from_millis(5)); assert!(rl.check(key)); // window expired } + + #[test] + fn hex_prefix_formats_first_4_bytes() { + assert_eq!(hex_prefix(&[0xab, 0xcd, 0xef, 0x01, 0x99]), "abcdef01"); + assert_eq!(hex_prefix(&[0x00, 0xff]), "00ff"); + assert_eq!(hex_prefix(&[]), ""); + } + + #[test] + fn log_rpc_call_does_not_panic() { + // Verify that audit log function does not panic with various inputs. + log_rpc_call("test.method", None, Duration::from_millis(42), RpcStatus::Ok, false); + log_rpc_call("test.method", Some(&[1, 2, 3, 4, 5, 6, 7, 8]), Duration::from_millis(1), RpcStatus::Internal, true); + log_rpc_call("test.method", Some(&[0xab]), Duration::ZERO, RpcStatus::Unauthorized, true); + } } diff --git a/crates/quicproquo-rpc/src/push.rs b/crates/quicproquo-rpc/src/push.rs new file mode 100644 index 0000000..a205fde --- /dev/null +++ b/crates/quicproquo-rpc/src/push.rs @@ -0,0 +1,114 @@ +//! Server-push infrastructure — manages push connections per identity key. + +use bytes::Bytes; +use dashmap::DashMap; +use quinn::Connection; +use tracing::{debug, warn}; + +use crate::error::RpcError; + +/// Manages push connections per identity key. +pub struct PushBroker { + /// Map from identity_key to active QUIC connections. + connections: DashMap, Vec>, +} + +impl PushBroker { + pub fn new() -> Self { + Self { + connections: DashMap::new(), + } + } + + /// Register a connection for an identity. + pub fn register(&self, identity_key: Vec, connection: Connection) { + self.connections + .entry(identity_key) + .or_default() + .push(connection); + } + + /// Remove closed connections from the registry. + pub fn gc(&self) { + self.connections.alter_all(|_, mut conns| { + conns.retain(|c| c.close_reason().is_none()); + conns + }); + self.connections.retain(|_, conns| !conns.is_empty()); + } + + /// Send a push event to all connections for an identity. + /// Returns the number of successful sends. + pub async fn send_to( + &self, + identity_key: &[u8], + event_type: u16, + payload: Bytes, + ) -> usize { + let conns = match self.connections.get(identity_key) { + Some(entry) => entry.clone(), + None => return 0, + }; + + let mut sent = 0usize; + for conn in &conns { + match crate::server::send_push(conn, event_type, payload.clone()).await { + Ok(()) => { + sent += 1; + debug!("push sent to connection {}", conn.remote_address()); + } + Err(RpcError::Connection(e)) => { + warn!("push send failed to {}: {e}", conn.remote_address()); + } + Err(e) => { + warn!("push send error: {e}"); + } + } + } + sent + } + + /// Send a push event to all members of a channel. + /// `member_keys` is the list of identity keys in the channel. + /// Returns the total number of successful sends. + pub async fn send_to_channel( + &self, + member_keys: &[Vec], + event_type: u16, + payload: Bytes, + ) -> usize { + let mut total = 0usize; + for key in member_keys { + total += self.send_to(key, event_type, payload.clone()).await; + } + total + } + + /// Number of identities with registered connections. + pub fn identity_count(&self) -> usize { + self.connections.len() + } +} + +impl Default for PushBroker { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_broker_is_empty() { + let broker = PushBroker::new(); + assert_eq!(broker.identity_count(), 0); + } + + #[test] + fn default_broker_is_empty() { + let broker = PushBroker::default(); + assert_eq!(broker.identity_count(), 0); + } +} diff --git a/crates/quicproquo-rpc/src/server.rs b/crates/quicproquo-rpc/src/server.rs index a6bff6b..d6453a9 100644 --- a/crates/quicproquo-rpc/src/server.rs +++ b/crates/quicproquo-rpc/src/server.rs @@ -6,9 +6,20 @@ use bytes::BytesMut; use quinn::{Endpoint, Incoming, RecvStream, SendStream}; use tracing::{debug, info, warn}; +use crate::auth_handshake; use crate::error::{RpcError, RpcStatus}; use crate::framing::{RequestFrame, ResponseFrame, PushFrame}; use crate::method::{HandlerResult, MethodRegistry, RequestContext}; +use crate::middleware::SessionValidator; + +/// Per-connection state established during the auth handshake. +#[derive(Debug, Clone, Default)] +pub struct ConnectionState { + /// The raw session token provided during the auth init handshake. + pub session_token: Option>, + /// The identity key resolved from the session token (populated by validator). + pub identity_key: Option>, +} /// Configuration for the RPC server. pub struct RpcServerConfig { @@ -25,6 +36,8 @@ pub struct RpcServer { endpoint: Endpoint, state: Arc, registry: Arc>, + /// Optional session validator for the auth handshake. + validator: Option>, } impl RpcServer { @@ -49,17 +62,27 @@ impl RpcServer { endpoint, state, registry: Arc::new(registry), + validator: None, }) } + /// Set a session validator for the auth handshake. + pub fn with_validator(mut self, validator: Arc) -> Self { + self.validator = Some(validator); + self + } + /// Accept connections in a loop. Spawns a task per connection. pub async fn serve(self) -> Result<(), RpcError> { info!("RPC server accepting connections"); while let Some(incoming) = self.endpoint.accept().await { let state = Arc::clone(&self.state); let registry = Arc::clone(&self.registry); + let validator = self.validator.clone(); tokio::spawn(async move { - if let Err(e) = handle_connection(incoming, state, registry).await { + if let Err(e) = + handle_connection(incoming, state, registry, validator).await + { warn!("connection error: {e}"); } }); @@ -75,11 +98,13 @@ impl RpcServer { } } -/// Handle a single QUIC connection: accept bi-directional streams for RPCs. +/// Handle a single QUIC connection: perform auth handshake, then accept +/// bi-directional streams for RPCs. async fn handle_connection( incoming: Incoming, state: Arc, registry: Arc>, + validator: Option>, ) -> Result<(), RpcError> { let connection = incoming .await @@ -88,14 +113,40 @@ async fn handle_connection( let remote = connection.remote_address(); debug!(remote = %remote, "new connection"); + // Perform auth handshake on the first bi-stream. + let conn_state = { + let (mut send, mut recv) = connection + .accept_bi() + .await + .map_err(|e| RpcError::Connection(format!("accept auth stream: {e}")))?; + + let token = auth_handshake::recv_auth_init(&mut recv).await?; + debug!(remote = %remote, token_len = token.len(), "received auth init"); + + let identity_key = validator.as_ref().and_then(|v| v.validate(&token)); + + auth_handshake::send_auth_ack(&mut send).await?; + send.finish() + .map_err(|e| RpcError::Connection(format!("finish auth stream: {e}")))?; + + Arc::new(ConnectionState { + session_token: Some(token), + identity_key, + }) + }; + + // Accept RPC streams. loop { let stream = connection.accept_bi().await; match stream { Ok((send, recv)) => { let state = Arc::clone(&state); let registry = Arc::clone(®istry); + let conn_state = Arc::clone(&conn_state); tokio::spawn(async move { - if let Err(e) = handle_stream(send, recv, state, registry).await { + if let Err(e) = + handle_stream(send, recv, state, registry, &conn_state).await + { debug!("stream error: {e}"); } }); @@ -120,6 +171,7 @@ async fn handle_stream( mut recv: RecvStream, state: Arc, registry: Arc>, + conn_state: &ConnectionState, ) -> Result<(), RpcError> { // Read the complete request from the stream. let mut buf = BytesMut::new(); @@ -146,8 +198,8 @@ async fn handle_stream( Some((handler, name)) => { debug!(method_id = frame.method_id, method = name, req_id = frame.request_id, "dispatching"); let ctx = RequestContext { - identity_key: None, // populated by auth middleware - session_token: None, + identity_key: conn_state.identity_key.clone(), + session_token: conn_state.session_token.clone(), payload: frame.payload, }; handler(Arc::clone(&state), ctx).await