//! Middleware layers for the RPC server. //! //! - `SessionValidator`: validates session tokens and resolves identity keys. //! - `RateLimiter`: per-key sliding-window rate limiting. //! - `log_rpc_call`: structured audit logging for RPC calls. use std::time::{Duration, Instant}; use dashmap::DashMap; use sha2::Digest; use crate::error::RpcStatus; // ── Auth middleware ────────────────────────────────────────────────────────── /// Validates bearer tokens and resolves identity keys. pub trait SessionValidator: Send + Sync + 'static { /// Validate a session token, returning the identity key if valid. fn validate(&self, token: &[u8]) -> Option>; } /// Auth context extracted from a validated session. #[derive(Debug, Clone)] pub struct AuthContext { /// The Ed25519 identity key of the authenticated caller. pub identity_key: Vec, } // ── Rate limiter ───────────────────────────────────────────────────────────── /// Simple per-key sliding-window rate limiter. pub struct RateLimiter { /// Max requests per window. max_requests: u32, /// Window duration. window: Duration, /// Map from key → (count, window_start). state: DashMap, (u32, Instant)>, } impl RateLimiter { /// Create a new rate limiter. pub fn new(max_requests: u32, window: Duration) -> Self { Self { max_requests, window, state: DashMap::new(), } } /// Check if a request from `key` is allowed. Returns `true` if allowed. pub fn check(&self, key: &[u8]) -> bool { let now = Instant::now(); let mut entry = self.state.entry(key.to_vec()).or_insert((0, now)); let (count, window_start) = entry.value_mut(); if now.duration_since(*window_start) >= self.window { // Reset window. *count = 1; *window_start = now; true } else if *count < self.max_requests { *count += 1; true } else { false } } /// Remove expired entries (call periodically for memory hygiene). pub fn gc(&self) { let now = Instant::now(); self.state.retain(|_, (_, start)| now.duration_since(*start) < self.window * 2); } } // ── Audit logging ─────────────────────────────────────────────────────────── /// Log an RPC call with timing and caller info. /// /// When `redact` is true, the identity key is hashed before logging so that /// raw keys never appear in log output. pub fn log_rpc_call( method_name: &str, identity_key: Option<&[u8]>, latency: Duration, status: RpcStatus, redact: bool, ) { let ik_display = match identity_key { Some(ik) if redact => { let hash_input_len = 8.min(ik.len()); let digest = sha2::Sha256::digest(&ik[..hash_input_len]); format!("h:{}", hex_prefix(&digest)) } Some(ik) => hex_prefix(ik), None => "anonymous".to_string(), }; tracing::info!( method = method_name, identity = %ik_display, latency_ms = latency.as_millis() as u64, status = ?status, "rpc" ); } fn hex_prefix(bytes: &[u8]) -> String { bytes .iter() .take(4) .map(|b| format!("{b:02x}")) .collect::() } #[cfg(test)] mod tests { use super::*; #[test] fn rate_limiter_allows_within_limit() { let rl = RateLimiter::new(3, Duration::from_secs(60)); let key = b"test-key"; assert!(rl.check(key)); assert!(rl.check(key)); assert!(rl.check(key)); assert!(!rl.check(key)); // 4th request denied } #[test] fn rate_limiter_resets_after_window() { let rl = RateLimiter::new(1, Duration::from_millis(1)); let key = b"test-key"; assert!(rl.check(key)); assert!(!rl.check(key)); std::thread::sleep(Duration::from_millis(5)); assert!(rl.check(key)); // window expired } #[test] fn hex_prefix_formats_first_4_bytes() { assert_eq!(hex_prefix(&[0xab, 0xcd, 0xef, 0x01, 0x99]), "abcdef01"); assert_eq!(hex_prefix(&[0x00, 0xff]), "00ff"); assert_eq!(hex_prefix(&[]), ""); } #[test] fn log_rpc_call_does_not_panic() { // Verify that audit log function does not panic with various inputs. log_rpc_call("test.method", None, Duration::from_millis(42), RpcStatus::Ok, false); log_rpc_call("test.method", Some(&[1, 2, 3, 4, 5, 6, 7, 8]), Duration::from_millis(1), RpcStatus::Internal, true); log_rpc_call("test.method", Some(&[0xab]), Duration::ZERO, RpcStatus::Unauthorized, true); } }