- auth_handshake.rs: connection-init protocol (magic 0x01, token, ack) - push.rs: PushBroker manages per-identity push connections with gc - server.rs: ConnectionState, auth handshake on first bi-stream, pass identity_key/session_token to RequestContext per stream - client.rs: session_token in RpcClientConfig, auto auth handshake on connect - middleware.rs: log_rpc_call with SHA-256 redaction, hex_prefix helper - lib.rs: export auth_handshake and push modules
155 lines
5.0 KiB
Rust
155 lines
5.0 KiB
Rust
//! Middleware layers for the RPC server.
|
|
//!
|
|
//! - `SessionValidator`: validates session tokens and resolves identity keys.
|
|
//! - `RateLimiter`: per-key fixed-window rate limiting (the counter resets once the window elapses).
|
|
//! - `log_rpc_call`: structured audit logging for RPC calls.
|
|
|
|
use std::time::{Duration, Instant};
|
|
|
|
use dashmap::DashMap;
|
|
use sha2::Digest;
|
|
|
|
use crate::error::RpcStatus;
|
|
|
|
// ── Auth middleware ──────────────────────────────────────────────────────────
|
|
|
|
/// Pluggable check that turns a session token into a caller identity.
///
/// Implementors must be `Send + Sync + 'static`, as required by the trait
/// bounds, so a single validator can be shared by reference across threads.
pub trait SessionValidator: Send + Sync + 'static {
    /// Inspect `token` and report who it belongs to.
    ///
    /// Returns `Some(identity_key)` when the token is accepted and `None`
    /// when it is rejected.
    fn validate(&self, token: &[u8]) -> Option<Vec<u8>>;
}
|
|
|
|
/// Identity information carried alongside a request once its session has
/// been validated.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// Raw bytes of the authenticated caller's Ed25519 identity key.
    pub identity_key: Vec<u8>,
}
|
|
|
|
// ── Rate limiter ─────────────────────────────────────────────────────────────
|
|
|
|
/// Simple per-key sliding-window rate limiter.
|
|
pub struct RateLimiter {
|
|
/// Max requests per window.
|
|
max_requests: u32,
|
|
/// Window duration.
|
|
window: Duration,
|
|
/// Map from key → (count, window_start).
|
|
state: DashMap<Vec<u8>, (u32, Instant)>,
|
|
}
|
|
|
|
impl RateLimiter {
|
|
/// Create a new rate limiter.
|
|
pub fn new(max_requests: u32, window: Duration) -> Self {
|
|
Self {
|
|
max_requests,
|
|
window,
|
|
state: DashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Check if a request from `key` is allowed. Returns `true` if allowed.
|
|
pub fn check(&self, key: &[u8]) -> bool {
|
|
let now = Instant::now();
|
|
let mut entry = self.state.entry(key.to_vec()).or_insert((0, now));
|
|
let (count, window_start) = entry.value_mut();
|
|
|
|
if now.duration_since(*window_start) >= self.window {
|
|
// Reset window.
|
|
*count = 1;
|
|
*window_start = now;
|
|
true
|
|
} else if *count < self.max_requests {
|
|
*count += 1;
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Remove expired entries (call periodically for memory hygiene).
|
|
pub fn gc(&self) {
|
|
let now = Instant::now();
|
|
self.state.retain(|_, (_, start)| now.duration_since(*start) < self.window * 2);
|
|
}
|
|
}
|
|
|
|
// ── Audit logging ───────────────────────────────────────────────────────────
|
|
|
|
/// Log an RPC call with timing and caller info.
|
|
///
|
|
/// When `redact` is true, the identity key is hashed before logging so that
|
|
/// raw keys never appear in log output.
|
|
pub fn log_rpc_call(
|
|
method_name: &str,
|
|
identity_key: Option<&[u8]>,
|
|
latency: Duration,
|
|
status: RpcStatus,
|
|
redact: bool,
|
|
) {
|
|
let ik_display = match identity_key {
|
|
Some(ik) if redact => {
|
|
let hash_input_len = 8.min(ik.len());
|
|
let digest = sha2::Sha256::digest(&ik[..hash_input_len]);
|
|
format!("h:{}", hex_prefix(&digest))
|
|
}
|
|
Some(ik) => hex_prefix(ik),
|
|
None => "anonymous".to_string(),
|
|
};
|
|
tracing::info!(
|
|
method = method_name,
|
|
identity = %ik_display,
|
|
latency_ms = latency.as_millis() as u64,
|
|
status = ?status,
|
|
"rpc"
|
|
);
|
|
}
|
|
|
|
/// Render at most the first four bytes of `bytes` as lowercase hexadecimal.
///
/// Every byte becomes exactly two hex digits, so the result is between zero
/// and eight characters long; an empty slice yields an empty string.
fn hex_prefix(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let shown = &bytes[..bytes.len().min(4)];
    let mut rendered = String::with_capacity(shown.len() * 2);
    for &byte in shown {
        // High nibble first, then low nibble.
        rendered.push(char::from(DIGITS[usize::from(byte >> 4)]));
        rendered.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    rendered
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A budget of three admits exactly three calls inside one window.
    #[test]
    fn rate_limiter_allows_within_limit() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        let caller = b"test-key";
        for _ in 0..3 {
            assert!(limiter.check(caller));
        }
        // 4th request denied
        assert!(!limiter.check(caller));
    }

    /// Once the window has elapsed, the same key is admitted again.
    #[test]
    fn rate_limiter_resets_after_window() {
        let limiter = RateLimiter::new(1, Duration::from_millis(1));
        let caller = b"test-key";

        let first = limiter.check(caller);
        let second = limiter.check(caller);
        assert!(first);
        assert!(!second);

        std::thread::sleep(Duration::from_millis(5));

        // window expired
        assert!(limiter.check(caller));
    }

    /// Only the leading four bytes are rendered, two hex digits per byte.
    #[test]
    fn hex_prefix_formats_first_4_bytes() {
        let five_bytes: &[u8] = &[0xab, 0xcd, 0xef, 0x01, 0x99];
        let two_bytes: &[u8] = &[0x00, 0xff];
        let no_bytes: &[u8] = &[];

        assert_eq!(hex_prefix(five_bytes), "abcdef01");
        assert_eq!(hex_prefix(two_bytes), "00ff");
        assert_eq!(hex_prefix(no_bytes), "");
    }

    /// Smoke test: the audit logger accepts anonymous, full-length and
    /// short identity keys without panicking.
    #[test]
    fn log_rpc_call_does_not_panic() {
        let method = "test.method";
        let full_key: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8];
        let short_key: &[u8] = &[0xab];

        log_rpc_call(method, None, Duration::from_millis(42), RpcStatus::Ok, false);
        log_rpc_call(
            method,
            Some(full_key),
            Duration::from_millis(1),
            RpcStatus::Internal,
            true,
        );
        log_rpc_call(
            method,
            Some(short_key),
            Duration::ZERO,
            RpcStatus::Unauthorized,
            true,
        );
    }
}
|