feat(rpc): auth handshake, server-push broker, audit logging

- auth_handshake.rs: connection-init protocol (magic 0x01, token, ack)
- push.rs: PushBroker manages per-identity push connections with gc
- server.rs: ConnectionState, auth handshake on first bi-stream, pass
  identity_key/session_token to RequestContext per stream
- client.rs: session_token in RpcClientConfig, auto auth handshake on connect
- middleware.rs: log_rpc_call with SHA-256 redaction, hex_prefix helper
- lib.rs: export auth_handshake and push modules
This commit is contained in:
2026-03-04 12:08:20 +01:00
parent ff93275dc1
commit f09dbe10ce
7 changed files with 387 additions and 8 deletions

View File

@@ -0,0 +1,135 @@
//! Auth handshake protocol for QUIC connections.
//!
//! The session token is sent as a "connection init" message on the first
//! bi-stream before any RPC calls.
//!
//! ## Protocol
//! ```text
//! Client → Server: [0x01 magic][token_len: u16 BE][token bytes]
//! Server → Client: [0x01 magic][0x00 status OK]
//! ```
use crate::error::RpcError;
/// Magic byte identifying an auth init frame.
pub const AUTH_INIT_MAGIC: u8 = 0x01;
/// Status byte: auth accepted.
const AUTH_STATUS_OK: u8 = 0x00;
/// Maximum token length (64 KiB).
const MAX_TOKEN_LEN: usize = 65535;
/// Write an auth init frame to a QUIC send stream.
pub async fn send_auth_init(
send: &mut quinn::SendStream,
token: &[u8],
) -> Result<(), RpcError> {
if token.len() > MAX_TOKEN_LEN {
return Err(RpcError::Encode(format!(
"auth token too large: {} bytes (max {MAX_TOKEN_LEN})",
token.len()
)));
}
let token_len = token.len() as u16;
let mut buf = Vec::with_capacity(1 + 2 + token.len());
buf.push(AUTH_INIT_MAGIC);
buf.extend_from_slice(&token_len.to_be_bytes());
buf.extend_from_slice(token);
send.write_all(&buf)
.await
.map_err(|e| RpcError::Connection(format!("send auth init: {e}")))?;
Ok(())
}
/// Read an auth init frame from a QUIC recv stream.
pub async fn recv_auth_init(
recv: &mut quinn::RecvStream,
) -> Result<Vec<u8>, RpcError> {
// Read magic byte.
let mut header = [0u8; 3];
recv.read_exact(&mut header)
.await
.map_err(|e| RpcError::Connection(format!("read auth init header: {e}")))?;
if header[0] != AUTH_INIT_MAGIC {
return Err(RpcError::Decode(format!(
"bad auth init magic: expected 0x{AUTH_INIT_MAGIC:02x}, got 0x{:02x}",
header[0]
)));
}
let token_len = u16::from_be_bytes([header[1], header[2]]) as usize;
if token_len > MAX_TOKEN_LEN {
return Err(RpcError::Decode(format!(
"auth token length {token_len} exceeds max {MAX_TOKEN_LEN}"
)));
}
let mut token = vec![0u8; token_len];
if token_len > 0 {
recv.read_exact(&mut token)
.await
.map_err(|e| RpcError::Connection(format!("read auth token: {e}")))?;
}
Ok(token)
}
/// Send auth ack (success response).
pub async fn send_auth_ack(
send: &mut quinn::SendStream,
) -> Result<(), RpcError> {
let buf = [AUTH_INIT_MAGIC, AUTH_STATUS_OK];
send.write_all(&buf)
.await
.map_err(|e| RpcError::Connection(format!("send auth ack: {e}")))?;
Ok(())
}
/// Read auth ack from the server.
pub async fn recv_auth_ack(
recv: &mut quinn::RecvStream,
) -> Result<(), RpcError> {
let mut buf = [0u8; 2];
recv.read_exact(&mut buf)
.await
.map_err(|e| RpcError::Connection(format!("read auth ack: {e}")))?;
if buf[0] != AUTH_INIT_MAGIC {
return Err(RpcError::Decode(format!(
"bad auth ack magic: expected 0x{AUTH_INIT_MAGIC:02x}, got 0x{:02x}",
buf[0]
)));
}
if buf[1] != AUTH_STATUS_OK {
return Err(RpcError::Decode(format!(
"auth rejected with status 0x{:02x}",
buf[1]
)));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn auth_init_magic_is_0x01() {
assert_eq!(AUTH_INIT_MAGIC, 0x01);
}
#[test]
fn token_too_large_returns_encode_error() {
// We cannot call send_auth_init without a real stream, but we can
// verify the length check logic by constructing the guard condition.
let big_token = vec![0u8; MAX_TOKEN_LEN + 1];
assert!(big_token.len() > MAX_TOKEN_LEN);
}
}

View File

@@ -8,6 +8,7 @@ use quinn::{Connection, Endpoint};
use tokio::sync::mpsc;
use tracing::{debug, warn};
use crate::auth_handshake;
use crate::error::{RpcError, RpcStatus};
use crate::framing::{PushFrame, RequestFrame, ResponseFrame};
@@ -21,6 +22,8 @@ pub struct RpcClientConfig {
pub tls_config: Arc<rustls::ClientConfig>,
/// ALPN protocol.
pub alpn: Vec<u8>,
/// Session token to send during auth handshake.
pub session_token: Option<Vec<u8>>,
}
/// A QUIC RPC client connection.
@@ -49,6 +52,20 @@ impl RpcClient {
debug!(remote = %connection.remote_address(), "connected to RPC server");
// Perform auth handshake if a session token was provided.
if let Some(ref token) = config.session_token {
let (mut send, mut recv) = connection
.open_bi()
.await
.map_err(|e| RpcError::Connection(format!("open auth stream: {e}")))?;
auth_handshake::send_auth_init(&mut send, token).await?;
send.finish()
.map_err(|e| RpcError::Connection(format!("finish auth send: {e}")))?;
auth_handshake::recv_auth_ack(&mut recv).await?;
debug!("auth handshake complete");
}
Ok(Self {
connection,
next_request_id: AtomicU32::new(1),

View File

@@ -5,9 +5,11 @@
//! - Response: `[status: u8][request_id: u32][payload_len: u32][protobuf bytes]`
//! - Push: `[event_type: u16][payload_len: u32][protobuf bytes]` (uni-stream)
pub mod auth_handshake;
pub mod framing;
pub mod method;
pub mod server;
pub mod client;
pub mod middleware;
pub mod push;
pub mod error;

View File

@@ -1,11 +1,15 @@
//! Tower-based middleware layers for the RPC server.
//! Middleware layers for the RPC server.
//!
//! - `AuthLayer`: validates session tokens and attaches identity to context.
//! - `RateLimitLayer`: per-IP request rate limiting.
//! - `SessionValidator`: validates session tokens and resolves identity keys.
//! - `RateLimiter`: per-key sliding-window rate limiting.
//! - `log_rpc_call`: structured audit logging for RPC calls.
use std::time::{Duration, Instant};
use dashmap::DashMap;
use sha2::Digest;
use crate::error::RpcStatus;
// ── Auth middleware ──────────────────────────────────────────────────────────
@@ -70,6 +74,45 @@ impl RateLimiter {
}
}
// ── Audit logging ───────────────────────────────────────────────────────────
/// Log an RPC call with timing and caller info.
///
/// When `redact` is true, the identity key is hashed before logging so that
/// raw keys never appear in log output.
pub fn log_rpc_call(
method_name: &str,
identity_key: Option<&[u8]>,
latency: Duration,
status: RpcStatus,
redact: bool,
) {
let ik_display = match identity_key {
Some(ik) if redact => {
let hash_input_len = 8.min(ik.len());
let digest = sha2::Sha256::digest(&ik[..hash_input_len]);
format!("h:{}", hex_prefix(&digest))
}
Some(ik) => hex_prefix(ik),
None => "anonymous".to_string(),
};
tracing::info!(
method = method_name,
identity = %ik_display,
latency_ms = latency.as_millis() as u64,
status = ?status,
"rpc"
);
}
fn hex_prefix(bytes: &[u8]) -> String {
bytes
.iter()
.take(4)
.map(|b| format!("{b:02x}"))
.collect::<String>()
}
#[cfg(test)]
mod tests {
use super::*;
@@ -93,4 +136,19 @@ mod tests {
std::thread::sleep(Duration::from_millis(5));
assert!(rl.check(key)); // window expired
}
#[test]
fn hex_prefix_formats_first_4_bytes() {
assert_eq!(hex_prefix(&[0xab, 0xcd, 0xef, 0x01, 0x99]), "abcdef01");
assert_eq!(hex_prefix(&[0x00, 0xff]), "00ff");
assert_eq!(hex_prefix(&[]), "");
}
#[test]
fn log_rpc_call_does_not_panic() {
// Verify that audit log function does not panic with various inputs.
log_rpc_call("test.method", None, Duration::from_millis(42), RpcStatus::Ok, false);
log_rpc_call("test.method", Some(&[1, 2, 3, 4, 5, 6, 7, 8]), Duration::from_millis(1), RpcStatus::Internal, true);
log_rpc_call("test.method", Some(&[0xab]), Duration::ZERO, RpcStatus::Unauthorized, true);
}
}

View File

@@ -0,0 +1,114 @@
//! Server-push infrastructure — manages push connections per identity key.
use bytes::Bytes;
use dashmap::DashMap;
use quinn::Connection;
use tracing::{debug, warn};
use crate::error::RpcError;
/// Manages push connections per identity key.
pub struct PushBroker {
/// Map from identity_key to active QUIC connections.
connections: DashMap<Vec<u8>, Vec<Connection>>,
}
impl PushBroker {
pub fn new() -> Self {
Self {
connections: DashMap::new(),
}
}
/// Register a connection for an identity.
pub fn register(&self, identity_key: Vec<u8>, connection: Connection) {
self.connections
.entry(identity_key)
.or_default()
.push(connection);
}
/// Remove closed connections from the registry.
pub fn gc(&self) {
self.connections.alter_all(|_, mut conns| {
conns.retain(|c| c.close_reason().is_none());
conns
});
self.connections.retain(|_, conns| !conns.is_empty());
}
/// Send a push event to all connections for an identity.
/// Returns the number of successful sends.
pub async fn send_to(
&self,
identity_key: &[u8],
event_type: u16,
payload: Bytes,
) -> usize {
let conns = match self.connections.get(identity_key) {
Some(entry) => entry.clone(),
None => return 0,
};
let mut sent = 0usize;
for conn in &conns {
match crate::server::send_push(conn, event_type, payload.clone()).await {
Ok(()) => {
sent += 1;
debug!("push sent to connection {}", conn.remote_address());
}
Err(RpcError::Connection(e)) => {
warn!("push send failed to {}: {e}", conn.remote_address());
}
Err(e) => {
warn!("push send error: {e}");
}
}
}
sent
}
/// Send a push event to all members of a channel.
/// `member_keys` is the list of identity keys in the channel.
/// Returns the total number of successful sends.
pub async fn send_to_channel(
&self,
member_keys: &[Vec<u8>],
event_type: u16,
payload: Bytes,
) -> usize {
let mut total = 0usize;
for key in member_keys {
total += self.send_to(key, event_type, payload.clone()).await;
}
total
}
/// Number of identities with registered connections.
pub fn identity_count(&self) -> usize {
self.connections.len()
}
}
impl Default for PushBroker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_broker_is_empty() {
let broker = PushBroker::new();
assert_eq!(broker.identity_count(), 0);
}
#[test]
fn default_broker_is_empty() {
let broker = PushBroker::default();
assert_eq!(broker.identity_count(), 0);
}
}

View File

@@ -6,9 +6,20 @@ use bytes::BytesMut;
use quinn::{Endpoint, Incoming, RecvStream, SendStream};
use tracing::{debug, info, warn};
use crate::auth_handshake;
use crate::error::{RpcError, RpcStatus};
use crate::framing::{RequestFrame, ResponseFrame, PushFrame};
use crate::method::{HandlerResult, MethodRegistry, RequestContext};
use crate::middleware::SessionValidator;
/// Per-connection state established during the auth handshake.
#[derive(Debug, Clone, Default)]
pub struct ConnectionState {
/// The raw session token provided during the auth init handshake.
pub session_token: Option<Vec<u8>>,
/// The identity key resolved from the session token (populated by validator).
pub identity_key: Option<Vec<u8>>,
}
/// Configuration for the RPC server.
pub struct RpcServerConfig {
@@ -25,6 +36,8 @@ pub struct RpcServer<S: Send + Sync + 'static> {
endpoint: Endpoint,
state: Arc<S>,
registry: Arc<MethodRegistry<S>>,
/// Optional session validator for the auth handshake.
validator: Option<Arc<dyn SessionValidator>>,
}
impl<S: Send + Sync + 'static> RpcServer<S> {
@@ -49,17 +62,27 @@ impl<S: Send + Sync + 'static> RpcServer<S> {
endpoint,
state,
registry: Arc::new(registry),
validator: None,
})
}
/// Set a session validator for the auth handshake.
pub fn with_validator(mut self, validator: Arc<dyn SessionValidator>) -> Self {
self.validator = Some(validator);
self
}
/// Accept connections in a loop. Spawns a task per connection.
pub async fn serve(self) -> Result<(), RpcError> {
info!("RPC server accepting connections");
while let Some(incoming) = self.endpoint.accept().await {
let state = Arc::clone(&self.state);
let registry = Arc::clone(&self.registry);
let validator = self.validator.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(incoming, state, registry).await {
if let Err(e) =
handle_connection(incoming, state, registry, validator).await
{
warn!("connection error: {e}");
}
});
@@ -75,11 +98,13 @@ impl<S: Send + Sync + 'static> RpcServer<S> {
}
}
/// Handle a single QUIC connection: accept bi-directional streams for RPCs.
/// Handle a single QUIC connection: perform auth handshake, then accept
/// bi-directional streams for RPCs.
async fn handle_connection<S: Send + Sync + 'static>(
incoming: Incoming,
state: Arc<S>,
registry: Arc<MethodRegistry<S>>,
validator: Option<Arc<dyn SessionValidator>>,
) -> Result<(), RpcError> {
let connection = incoming
.await
@@ -88,14 +113,40 @@ async fn handle_connection<S: Send + Sync + 'static>(
let remote = connection.remote_address();
debug!(remote = %remote, "new connection");
// Perform auth handshake on the first bi-stream.
let conn_state = {
let (mut send, mut recv) = connection
.accept_bi()
.await
.map_err(|e| RpcError::Connection(format!("accept auth stream: {e}")))?;
let token = auth_handshake::recv_auth_init(&mut recv).await?;
debug!(remote = %remote, token_len = token.len(), "received auth init");
let identity_key = validator.as_ref().and_then(|v| v.validate(&token));
auth_handshake::send_auth_ack(&mut send).await?;
send.finish()
.map_err(|e| RpcError::Connection(format!("finish auth stream: {e}")))?;
Arc::new(ConnectionState {
session_token: Some(token),
identity_key,
})
};
// Accept RPC streams.
loop {
let stream = connection.accept_bi().await;
match stream {
Ok((send, recv)) => {
let state = Arc::clone(&state);
let registry = Arc::clone(&registry);
let conn_state = Arc::clone(&conn_state);
tokio::spawn(async move {
if let Err(e) = handle_stream(send, recv, state, registry).await {
if let Err(e) =
handle_stream(send, recv, state, registry, &conn_state).await
{
debug!("stream error: {e}");
}
});
@@ -120,6 +171,7 @@ async fn handle_stream<S: Send + Sync + 'static>(
mut recv: RecvStream,
state: Arc<S>,
registry: Arc<MethodRegistry<S>>,
conn_state: &ConnectionState,
) -> Result<(), RpcError> {
// Read the complete request from the stream.
let mut buf = BytesMut::new();
@@ -146,8 +198,8 @@ async fn handle_stream<S: Send + Sync + 'static>(
Some((handler, name)) => {
debug!(method_id = frame.method_id, method = name, req_id = frame.request_id, "dispatching");
let ctx = RequestContext {
identity_key: None, // populated by auth middleware
session_token: None,
identity_key: conn_state.identity_key.clone(),
session_token: conn_state.session_token.clone(),
payload: frame.payload,
};
handler(Arc::clone(&state), ctx).await