Files
quicproquo/crates/quicproquo-rpc/src/middleware.rs
Christian Nennemann f09dbe10ce feat(rpc): auth handshake, server-push broker, audit logging
- auth_handshake.rs: connection-init protocol (magic 0x01, token, ack)
- push.rs: PushBroker manages per-identity push connections with gc
- server.rs: ConnectionState, auth handshake on first bi-stream, pass
  identity_key/session_token to RequestContext per stream
- client.rs: session_token in RpcClientConfig, auto auth handshake on connect
- middleware.rs: log_rpc_call with SHA-256 redaction, hex_prefix helper
- lib.rs: export auth_handshake and push modules
2026-03-04 12:08:20 +01:00

155 lines
5.0 KiB
Rust

//! Middleware layers for the RPC server.
//!
//! - `SessionValidator`: validates session tokens and resolves identity keys.
//! - `RateLimiter`: per-key fixed-window rate limiting (the count resets once a window elapses).
//! - `log_rpc_call`: structured audit logging for RPC calls.
use std::time::{Duration, Instant};
use dashmap::DashMap;
use sha2::Digest;
use crate::error::RpcStatus;
// ── Auth middleware ──────────────────────────────────────────────────────────
/// Validates bearer (session) tokens and resolves them to identity keys.
///
/// Implementors must satisfy the `Send + Sync + 'static` supertrait bounds.
pub trait SessionValidator: Send + Sync + 'static {
/// Validate a session token.
///
/// Returns `Some(identity_key)` holding the identity key bytes when `token`
/// is valid, and `None` otherwise.
fn validate(&self, token: &[u8]) -> Option<Vec<u8>>;
}
/// Auth context extracted from a validated session.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output contains the full
/// `identity_key` bytes, unlike the truncated/hashed form that `log_rpc_call`
/// emits — confirm this is acceptable wherever the context is logged.
#[derive(Debug, Clone)]
pub struct AuthContext {
/// The Ed25519 identity key of the authenticated caller, as raw bytes.
pub identity_key: Vec<u8>,
}
// ── Rate limiter ─────────────────────────────────────────────────────────────
/// Simple per-key rate limiter.
///
/// Every key owns a request counter and the instant at which its current
/// window started. `check` resets the whole counter once `window` has elapsed
/// since that instant, so this behaves as a fixed-window counter rather than a
/// true sliding window.
pub struct RateLimiter {
/// Max requests per window.
max_requests: u32,
/// Window duration.
window: Duration,
/// Map from key → (count, window_start).
///
/// `count` is the number of requests admitted in the current window and
/// `window_start` is when that window began.
state: DashMap<Vec<u8>, (u32, Instant)>,
}
impl RateLimiter {
    /// Create a rate limiter that admits at most `max_requests` requests per
    /// key within each `window`.
    ///
    /// A `max_requests` of `0` rejects every request.
    pub fn new(max_requests: u32, window: Duration) -> Self {
        Self {
            max_requests,
            window,
            state: DashMap::new(),
        }
    }

    /// Check if a request from `key` is allowed, counting it when it is.
    ///
    /// Returns `true` if the request fits into the key's current window and
    /// `false` once the key has used up `max_requests` within it. Denied
    /// requests are not counted and do not move the window start.
    pub fn check(&self, key: &[u8]) -> bool {
        let now = Instant::now();
        let mut entry = self.state.entry(key.to_vec()).or_insert((0, now));
        let (count, window_start) = entry.value_mut();

        // `now` was sampled before the entry was locked, so a concurrent caller
        // may already have stored a later `window_start`. Saturate to zero
        // instead of relying on `duration_since` never panicking.
        if now.saturating_duration_since(*window_start) >= self.window {
            // Start a fresh window. The request is counted below so that the
            // limit (including a limit of zero) is enforced on every path.
            *count = 0;
            *window_start = now;
        }

        if *count < self.max_requests {
            *count += 1;
            true
        } else {
            false
        }
    }

    /// Remove expired entries (call periodically for memory hygiene).
    ///
    /// An entry is dropped once its window started at least two window
    /// lengths ago; younger entries are kept.
    pub fn gc(&self) {
        let now = Instant::now();
        // `window * 2` panics on overflow for very large windows; saturate.
        let max_age = self.window.saturating_mul(2);
        self.state
            .retain(|_, (_, start)| now.saturating_duration_since(*start) < max_age);
    }
}
// ── Audit logging ───────────────────────────────────────────────────────────
/// Emit one structured audit record for a finished RPC call.
///
/// The record is logged at `info` level with the message `"rpc"` and the
/// fields `method`, `identity`, `latency_ms` and `status`.
///
/// The `identity` field is derived from `identity_key`:
/// - `None` is logged as `anonymous`;
/// - with `redact == false`, the hex of the key's first four bytes is logged;
/// - with `redact == true`, the first eight bytes of the key (or fewer, if the
///   key is shorter) are hashed with SHA-256 and the hex of the first four
///   digest bytes is logged with an `h:` prefix, so that raw key bytes never
///   appear in log output.
pub fn log_rpc_call(
    method_name: &str,
    identity_key: Option<&[u8]>,
    latency: Duration,
    status: RpcStatus,
    redact: bool,
) {
    let caller = identity_key.map_or_else(
        || "anonymous".to_string(),
        |key| {
            if redact {
                // Only a bounded head of the key is fed to the hash.
                let head = &key[..key.len().min(8)];
                let fingerprint = sha2::Sha256::digest(head);
                format!("h:{}", hex_prefix(&fingerprint))
            } else {
                hex_prefix(key)
            }
        },
    );
    let latency_ms = latency.as_millis() as u64;

    tracing::info!(
        method = method_name,
        identity = %caller,
        latency_ms = latency_ms,
        status = ?status,
        "rpc"
    );
}
/// Render at most the first four bytes of `bytes` as lowercase hex.
///
/// Every byte becomes two zero-padded hex digits, so the result is between
/// zero and eight characters long. An empty slice yields an empty string.
fn hex_prefix(bytes: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut rendered = String::new();
    for &byte in bytes.iter().take(4) {
        // High nibble first, then low nibble.
        rendered.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        rendered.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    rendered
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Requests up to the configured maximum pass; the next one is rejected.
    #[test]
    fn rate_limiter_allows_within_limit() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        let caller: &[u8] = b"test-key";

        for _ in 0..3 {
            assert!(limiter.check(caller));
        }
        // 4th request denied
        assert!(!limiter.check(caller));
    }

    /// Once the window has elapsed, the key is admitted again.
    #[test]
    fn rate_limiter_resets_after_window() {
        let limiter = RateLimiter::new(1, Duration::from_millis(1));
        let caller: &[u8] = b"test-key";

        assert!(limiter.check(caller));
        assert!(!limiter.check(caller));

        std::thread::sleep(Duration::from_millis(5));
        // window expired
        assert!(limiter.check(caller));
    }

    /// Only the first four bytes are rendered, as zero-padded lowercase hex.
    #[test]
    fn hex_prefix_formats_first_4_bytes() {
        let cases: [(&[u8], &str); 3] = [
            (&[0xab, 0xcd, 0xef, 0x01, 0x99], "abcdef01"),
            (&[0x00, 0xff], "00ff"),
            (&[], ""),
        ];

        for (input, expected) in cases {
            assert_eq!(hex_prefix(input), expected);
        }
    }

    /// Verify that the audit log function does not panic with various inputs.
    #[test]
    fn log_rpc_call_does_not_panic() {
        let eight_byte_key = [1u8, 2, 3, 4, 5, 6, 7, 8];
        let one_byte_key = [0xabu8];

        log_rpc_call("test.method", None, Duration::from_millis(42), RpcStatus::Ok, false);
        log_rpc_call(
            "test.method",
            Some(&eight_byte_key[..]),
            Duration::from_millis(1),
            RpcStatus::Internal,
            true,
        );
        log_rpc_call(
            "test.method",
            Some(&one_byte_key[..]),
            Duration::ZERO,
            RpcStatus::Unauthorized,
            true,
        );
    }
}