feat(rpc): auth handshake, server-push broker, audit logging
- auth_handshake.rs: connection-init protocol (magic 0x01, token, ack) - push.rs: PushBroker manages per-identity push connections with gc - server.rs: ConnectionState, auth handshake on first bi-stream, pass identity_key/session_token to RequestContext per stream - client.rs: session_token in RpcClientConfig, auto auth handshake on connect - middleware.rs: log_rpc_call with SHA-256 redaction, hex_prefix helper - lib.rs: export auth_handshake and push modules
This commit is contained in:
135
crates/quicproquo-rpc/src/auth_handshake.rs
Normal file
135
crates/quicproquo-rpc/src/auth_handshake.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
//! Auth handshake protocol for QUIC connections.
|
||||
//!
|
||||
//! The session token is sent as a "connection init" message on the first
|
||||
//! bi-stream before any RPC calls.
|
||||
//!
|
||||
//! ## Protocol
|
||||
//! ```text
|
||||
//! Client → Server: [0x01 magic][token_len: u16 BE][token bytes]
|
||||
//! Server → Client: [0x01 magic][0x00 status OK]
|
||||
//! ```
|
||||
|
||||
use crate::error::RpcError;
|
||||
|
||||
/// Magic byte identifying an auth init frame.
|
||||
pub const AUTH_INIT_MAGIC: u8 = 0x01;
|
||||
|
||||
/// Status byte: auth accepted.
|
||||
const AUTH_STATUS_OK: u8 = 0x00;
|
||||
|
||||
/// Maximum token length (64 KiB).
|
||||
const MAX_TOKEN_LEN: usize = 65535;
|
||||
|
||||
/// Write an auth init frame to a QUIC send stream.
|
||||
pub async fn send_auth_init(
|
||||
send: &mut quinn::SendStream,
|
||||
token: &[u8],
|
||||
) -> Result<(), RpcError> {
|
||||
if token.len() > MAX_TOKEN_LEN {
|
||||
return Err(RpcError::Encode(format!(
|
||||
"auth token too large: {} bytes (max {MAX_TOKEN_LEN})",
|
||||
token.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let token_len = token.len() as u16;
|
||||
let mut buf = Vec::with_capacity(1 + 2 + token.len());
|
||||
buf.push(AUTH_INIT_MAGIC);
|
||||
buf.extend_from_slice(&token_len.to_be_bytes());
|
||||
buf.extend_from_slice(token);
|
||||
|
||||
send.write_all(&buf)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("send auth init: {e}")))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read an auth init frame from a QUIC recv stream.
|
||||
pub async fn recv_auth_init(
|
||||
recv: &mut quinn::RecvStream,
|
||||
) -> Result<Vec<u8>, RpcError> {
|
||||
// Read magic byte.
|
||||
let mut header = [0u8; 3];
|
||||
recv.read_exact(&mut header)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("read auth init header: {e}")))?;
|
||||
|
||||
if header[0] != AUTH_INIT_MAGIC {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"bad auth init magic: expected 0x{AUTH_INIT_MAGIC:02x}, got 0x{:02x}",
|
||||
header[0]
|
||||
)));
|
||||
}
|
||||
|
||||
let token_len = u16::from_be_bytes([header[1], header[2]]) as usize;
|
||||
if token_len > MAX_TOKEN_LEN {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"auth token length {token_len} exceeds max {MAX_TOKEN_LEN}"
|
||||
)));
|
||||
}
|
||||
|
||||
let mut token = vec![0u8; token_len];
|
||||
if token_len > 0 {
|
||||
recv.read_exact(&mut token)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("read auth token: {e}")))?;
|
||||
}
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Send auth ack (success response).
|
||||
pub async fn send_auth_ack(
|
||||
send: &mut quinn::SendStream,
|
||||
) -> Result<(), RpcError> {
|
||||
let buf = [AUTH_INIT_MAGIC, AUTH_STATUS_OK];
|
||||
send.write_all(&buf)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("send auth ack: {e}")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read auth ack from the server.
|
||||
pub async fn recv_auth_ack(
|
||||
recv: &mut quinn::RecvStream,
|
||||
) -> Result<(), RpcError> {
|
||||
let mut buf = [0u8; 2];
|
||||
recv.read_exact(&mut buf)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("read auth ack: {e}")))?;
|
||||
|
||||
if buf[0] != AUTH_INIT_MAGIC {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"bad auth ack magic: expected 0x{AUTH_INIT_MAGIC:02x}, got 0x{:02x}",
|
||||
buf[0]
|
||||
)));
|
||||
}
|
||||
|
||||
if buf[1] != AUTH_STATUS_OK {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"auth rejected with status 0x{:02x}",
|
||||
buf[1]
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn auth_init_magic_is_0x01() {
|
||||
assert_eq!(AUTH_INIT_MAGIC, 0x01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_too_large_returns_encode_error() {
|
||||
// We cannot call send_auth_init without a real stream, but we can
|
||||
// verify the length check logic by constructing the guard condition.
|
||||
let big_token = vec![0u8; MAX_TOKEN_LEN + 1];
|
||||
assert!(big_token.len() > MAX_TOKEN_LEN);
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ use quinn::{Connection, Endpoint};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::auth_handshake;
|
||||
use crate::error::{RpcError, RpcStatus};
|
||||
use crate::framing::{PushFrame, RequestFrame, ResponseFrame};
|
||||
|
||||
@@ -21,6 +22,8 @@ pub struct RpcClientConfig {
|
||||
pub tls_config: Arc<rustls::ClientConfig>,
|
||||
/// ALPN protocol.
|
||||
pub alpn: Vec<u8>,
|
||||
/// Session token to send during auth handshake.
|
||||
pub session_token: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
/// A QUIC RPC client connection.
|
||||
@@ -49,6 +52,20 @@ impl RpcClient {
|
||||
|
||||
debug!(remote = %connection.remote_address(), "connected to RPC server");
|
||||
|
||||
// Perform auth handshake if a session token was provided.
|
||||
if let Some(ref token) = config.session_token {
|
||||
let (mut send, mut recv) = connection
|
||||
.open_bi()
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("open auth stream: {e}")))?;
|
||||
|
||||
auth_handshake::send_auth_init(&mut send, token).await?;
|
||||
send.finish()
|
||||
.map_err(|e| RpcError::Connection(format!("finish auth send: {e}")))?;
|
||||
auth_handshake::recv_auth_ack(&mut recv).await?;
|
||||
debug!("auth handshake complete");
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
connection,
|
||||
next_request_id: AtomicU32::new(1),
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
//! - Response: `[status: u8][request_id: u32][payload_len: u32][protobuf bytes]`
|
||||
//! - Push: `[event_type: u16][payload_len: u32][protobuf bytes]` (uni-stream)
|
||||
|
||||
pub mod auth_handshake;
|
||||
pub mod framing;
|
||||
pub mod method;
|
||||
pub mod server;
|
||||
pub mod client;
|
||||
pub mod middleware;
|
||||
pub mod push;
|
||||
pub mod error;
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
//! Tower-based middleware layers for the RPC server.
|
||||
//! Middleware layers for the RPC server.
|
||||
//!
|
||||
//! - `AuthLayer`: validates session tokens and attaches identity to context.
|
||||
//! - `RateLimitLayer`: per-IP request rate limiting.
|
||||
//! - `SessionValidator`: validates session tokens and resolves identity keys.
|
||||
//! - `RateLimiter`: per-key sliding-window rate limiting.
|
||||
//! - `log_rpc_call`: structured audit logging for RPC calls.
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use sha2::Digest;
|
||||
|
||||
use crate::error::RpcStatus;
|
||||
|
||||
// ── Auth middleware ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -70,6 +74,45 @@ impl RateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Audit logging ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Log an RPC call with timing and caller info.
|
||||
///
|
||||
/// When `redact` is true, the identity key is hashed before logging so that
|
||||
/// raw keys never appear in log output.
|
||||
pub fn log_rpc_call(
|
||||
method_name: &str,
|
||||
identity_key: Option<&[u8]>,
|
||||
latency: Duration,
|
||||
status: RpcStatus,
|
||||
redact: bool,
|
||||
) {
|
||||
let ik_display = match identity_key {
|
||||
Some(ik) if redact => {
|
||||
let hash_input_len = 8.min(ik.len());
|
||||
let digest = sha2::Sha256::digest(&ik[..hash_input_len]);
|
||||
format!("h:{}", hex_prefix(&digest))
|
||||
}
|
||||
Some(ik) => hex_prefix(ik),
|
||||
None => "anonymous".to_string(),
|
||||
};
|
||||
tracing::info!(
|
||||
method = method_name,
|
||||
identity = %ik_display,
|
||||
latency_ms = latency.as_millis() as u64,
|
||||
status = ?status,
|
||||
"rpc"
|
||||
);
|
||||
}
|
||||
|
||||
fn hex_prefix(bytes: &[u8]) -> String {
|
||||
bytes
|
||||
.iter()
|
||||
.take(4)
|
||||
.map(|b| format!("{b:02x}"))
|
||||
.collect::<String>()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -93,4 +136,19 @@ mod tests {
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
assert!(rl.check(key)); // window expired
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hex_prefix_formats_first_4_bytes() {
|
||||
assert_eq!(hex_prefix(&[0xab, 0xcd, 0xef, 0x01, 0x99]), "abcdef01");
|
||||
assert_eq!(hex_prefix(&[0x00, 0xff]), "00ff");
|
||||
assert_eq!(hex_prefix(&[]), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_rpc_call_does_not_panic() {
|
||||
// Verify that audit log function does not panic with various inputs.
|
||||
log_rpc_call("test.method", None, Duration::from_millis(42), RpcStatus::Ok, false);
|
||||
log_rpc_call("test.method", Some(&[1, 2, 3, 4, 5, 6, 7, 8]), Duration::from_millis(1), RpcStatus::Internal, true);
|
||||
log_rpc_call("test.method", Some(&[0xab]), Duration::ZERO, RpcStatus::Unauthorized, true);
|
||||
}
|
||||
}
|
||||
|
||||
114
crates/quicproquo-rpc/src/push.rs
Normal file
114
crates/quicproquo-rpc/src/push.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
//! Server-push infrastructure — manages push connections per identity key.
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use quinn::Connection;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::error::RpcError;
|
||||
|
||||
/// Manages push connections per identity key.
|
||||
pub struct PushBroker {
|
||||
/// Map from identity_key to active QUIC connections.
|
||||
connections: DashMap<Vec<u8>, Vec<Connection>>,
|
||||
}
|
||||
|
||||
impl PushBroker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
connections: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a connection for an identity.
|
||||
pub fn register(&self, identity_key: Vec<u8>, connection: Connection) {
|
||||
self.connections
|
||||
.entry(identity_key)
|
||||
.or_default()
|
||||
.push(connection);
|
||||
}
|
||||
|
||||
/// Remove closed connections from the registry.
|
||||
pub fn gc(&self) {
|
||||
self.connections.alter_all(|_, mut conns| {
|
||||
conns.retain(|c| c.close_reason().is_none());
|
||||
conns
|
||||
});
|
||||
self.connections.retain(|_, conns| !conns.is_empty());
|
||||
}
|
||||
|
||||
/// Send a push event to all connections for an identity.
|
||||
/// Returns the number of successful sends.
|
||||
pub async fn send_to(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
event_type: u16,
|
||||
payload: Bytes,
|
||||
) -> usize {
|
||||
let conns = match self.connections.get(identity_key) {
|
||||
Some(entry) => entry.clone(),
|
||||
None => return 0,
|
||||
};
|
||||
|
||||
let mut sent = 0usize;
|
||||
for conn in &conns {
|
||||
match crate::server::send_push(conn, event_type, payload.clone()).await {
|
||||
Ok(()) => {
|
||||
sent += 1;
|
||||
debug!("push sent to connection {}", conn.remote_address());
|
||||
}
|
||||
Err(RpcError::Connection(e)) => {
|
||||
warn!("push send failed to {}: {e}", conn.remote_address());
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("push send error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
sent
|
||||
}
|
||||
|
||||
/// Send a push event to all members of a channel.
|
||||
/// `member_keys` is the list of identity keys in the channel.
|
||||
/// Returns the total number of successful sends.
|
||||
pub async fn send_to_channel(
|
||||
&self,
|
||||
member_keys: &[Vec<u8>],
|
||||
event_type: u16,
|
||||
payload: Bytes,
|
||||
) -> usize {
|
||||
let mut total = 0usize;
|
||||
for key in member_keys {
|
||||
total += self.send_to(key, event_type, payload.clone()).await;
|
||||
}
|
||||
total
|
||||
}
|
||||
|
||||
/// Number of identities with registered connections.
|
||||
pub fn identity_count(&self) -> usize {
|
||||
self.connections.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PushBroker {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_broker_is_empty() {
|
||||
let broker = PushBroker::new();
|
||||
assert_eq!(broker.identity_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_broker_is_empty() {
|
||||
let broker = PushBroker::default();
|
||||
assert_eq!(broker.identity_count(), 0);
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,20 @@ use bytes::BytesMut;
|
||||
use quinn::{Endpoint, Incoming, RecvStream, SendStream};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth_handshake;
|
||||
use crate::error::{RpcError, RpcStatus};
|
||||
use crate::framing::{RequestFrame, ResponseFrame, PushFrame};
|
||||
use crate::method::{HandlerResult, MethodRegistry, RequestContext};
|
||||
use crate::middleware::SessionValidator;
|
||||
|
||||
/// Per-connection state established during the auth handshake.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ConnectionState {
|
||||
/// The raw session token provided during the auth init handshake.
|
||||
pub session_token: Option<Vec<u8>>,
|
||||
/// The identity key resolved from the session token (populated by validator).
|
||||
pub identity_key: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
/// Configuration for the RPC server.
|
||||
pub struct RpcServerConfig {
|
||||
@@ -25,6 +36,8 @@ pub struct RpcServer<S: Send + Sync + 'static> {
|
||||
endpoint: Endpoint,
|
||||
state: Arc<S>,
|
||||
registry: Arc<MethodRegistry<S>>,
|
||||
/// Optional session validator for the auth handshake.
|
||||
validator: Option<Arc<dyn SessionValidator>>,
|
||||
}
|
||||
|
||||
impl<S: Send + Sync + 'static> RpcServer<S> {
|
||||
@@ -49,17 +62,27 @@ impl<S: Send + Sync + 'static> RpcServer<S> {
|
||||
endpoint,
|
||||
state,
|
||||
registry: Arc::new(registry),
|
||||
validator: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set a session validator for the auth handshake.
|
||||
pub fn with_validator(mut self, validator: Arc<dyn SessionValidator>) -> Self {
|
||||
self.validator = Some(validator);
|
||||
self
|
||||
}
|
||||
|
||||
/// Accept connections in a loop. Spawns a task per connection.
|
||||
pub async fn serve(self) -> Result<(), RpcError> {
|
||||
info!("RPC server accepting connections");
|
||||
while let Some(incoming) = self.endpoint.accept().await {
|
||||
let state = Arc::clone(&self.state);
|
||||
let registry = Arc::clone(&self.registry);
|
||||
let validator = self.validator.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_connection(incoming, state, registry).await {
|
||||
if let Err(e) =
|
||||
handle_connection(incoming, state, registry, validator).await
|
||||
{
|
||||
warn!("connection error: {e}");
|
||||
}
|
||||
});
|
||||
@@ -75,11 +98,13 @@ impl<S: Send + Sync + 'static> RpcServer<S> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single QUIC connection: accept bi-directional streams for RPCs.
|
||||
/// Handle a single QUIC connection: perform auth handshake, then accept
|
||||
/// bi-directional streams for RPCs.
|
||||
async fn handle_connection<S: Send + Sync + 'static>(
|
||||
incoming: Incoming,
|
||||
state: Arc<S>,
|
||||
registry: Arc<MethodRegistry<S>>,
|
||||
validator: Option<Arc<dyn SessionValidator>>,
|
||||
) -> Result<(), RpcError> {
|
||||
let connection = incoming
|
||||
.await
|
||||
@@ -88,14 +113,40 @@ async fn handle_connection<S: Send + Sync + 'static>(
|
||||
let remote = connection.remote_address();
|
||||
debug!(remote = %remote, "new connection");
|
||||
|
||||
// Perform auth handshake on the first bi-stream.
|
||||
let conn_state = {
|
||||
let (mut send, mut recv) = connection
|
||||
.accept_bi()
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("accept auth stream: {e}")))?;
|
||||
|
||||
let token = auth_handshake::recv_auth_init(&mut recv).await?;
|
||||
debug!(remote = %remote, token_len = token.len(), "received auth init");
|
||||
|
||||
let identity_key = validator.as_ref().and_then(|v| v.validate(&token));
|
||||
|
||||
auth_handshake::send_auth_ack(&mut send).await?;
|
||||
send.finish()
|
||||
.map_err(|e| RpcError::Connection(format!("finish auth stream: {e}")))?;
|
||||
|
||||
Arc::new(ConnectionState {
|
||||
session_token: Some(token),
|
||||
identity_key,
|
||||
})
|
||||
};
|
||||
|
||||
// Accept RPC streams.
|
||||
loop {
|
||||
let stream = connection.accept_bi().await;
|
||||
match stream {
|
||||
Ok((send, recv)) => {
|
||||
let state = Arc::clone(&state);
|
||||
let registry = Arc::clone(®istry);
|
||||
let conn_state = Arc::clone(&conn_state);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_stream(send, recv, state, registry).await {
|
||||
if let Err(e) =
|
||||
handle_stream(send, recv, state, registry, &conn_state).await
|
||||
{
|
||||
debug!("stream error: {e}");
|
||||
}
|
||||
});
|
||||
@@ -120,6 +171,7 @@ async fn handle_stream<S: Send + Sync + 'static>(
|
||||
mut recv: RecvStream,
|
||||
state: Arc<S>,
|
||||
registry: Arc<MethodRegistry<S>>,
|
||||
conn_state: &ConnectionState,
|
||||
) -> Result<(), RpcError> {
|
||||
// Read the complete request from the stream.
|
||||
let mut buf = BytesMut::new();
|
||||
@@ -146,8 +198,8 @@ async fn handle_stream<S: Send + Sync + 'static>(
|
||||
Some((handler, name)) => {
|
||||
debug!(method_id = frame.method_id, method = name, req_id = frame.request_id, "dispatching");
|
||||
let ctx = RequestContext {
|
||||
identity_key: None, // populated by auth middleware
|
||||
session_token: None,
|
||||
identity_key: conn_state.identity_key.clone(),
|
||||
session_token: conn_state.session_token.clone(),
|
||||
payload: frame.payload,
|
||||
};
|
||||
handler(Arc::clone(&state), ctx).await
|
||||
|
||||
Reference in New Issue
Block a user