feat(rpc): auth handshake, server-push broker, audit logging

- auth_handshake.rs: connection-init protocol (magic 0x01, token, ack)
- push.rs: PushBroker manages per-identity push connections with gc
- server.rs: ConnectionState, auth handshake on first bi-stream, pass
  identity_key/session_token to RequestContext per stream
- client.rs: session_token in RpcClientConfig, auto auth handshake on connect
- middleware.rs: log_rpc_call with SHA-256 redaction, hex_prefix helper
- lib.rs: export auth_handshake and push modules
This commit is contained in:
2026-03-04 12:08:20 +01:00
parent ff93275dc1
commit f09dbe10ce
7 changed files with 387 additions and 8 deletions

View File

@@ -6,9 +6,20 @@ use bytes::BytesMut;
use quinn::{Endpoint, Incoming, RecvStream, SendStream};
use tracing::{debug, info, warn};
use crate::auth_handshake;
use crate::error::{RpcError, RpcStatus};
use crate::framing::{RequestFrame, ResponseFrame, PushFrame};
use crate::method::{HandlerResult, MethodRegistry, RequestContext};
use crate::middleware::SessionValidator;
/// Per-connection state established during the auth handshake.
#[derive(Debug, Clone, Default)]
pub struct ConnectionState {
/// The raw session token provided during the auth init handshake.
pub session_token: Option<Vec<u8>>,
/// The identity key resolved from the session token (populated by validator).
pub identity_key: Option<Vec<u8>>,
}
/// Configuration for the RPC server.
pub struct RpcServerConfig {
@@ -25,6 +36,8 @@ pub struct RpcServer<S: Send + Sync + 'static> {
endpoint: Endpoint,
state: Arc<S>,
registry: Arc<MethodRegistry<S>>,
/// Optional session validator for the auth handshake.
validator: Option<Arc<dyn SessionValidator>>,
}
impl<S: Send + Sync + 'static> RpcServer<S> {
@@ -49,17 +62,27 @@ impl<S: Send + Sync + 'static> RpcServer<S> {
endpoint,
state,
registry: Arc::new(registry),
validator: None,
})
}
/// Set a session validator for the auth handshake.
pub fn with_validator(mut self, validator: Arc<dyn SessionValidator>) -> Self {
self.validator = Some(validator);
self
}
/// Accept connections in a loop. Spawns a task per connection.
pub async fn serve(self) -> Result<(), RpcError> {
info!("RPC server accepting connections");
while let Some(incoming) = self.endpoint.accept().await {
let state = Arc::clone(&self.state);
let registry = Arc::clone(&self.registry);
let validator = self.validator.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(incoming, state, registry).await {
if let Err(e) =
handle_connection(incoming, state, registry, validator).await
{
warn!("connection error: {e}");
}
});
@@ -75,11 +98,13 @@ impl<S: Send + Sync + 'static> RpcServer<S> {
}
}
/// Handle a single QUIC connection: accept bi-directional streams for RPCs.
/// Handle a single QUIC connection: perform auth handshake, then accept
/// bi-directional streams for RPCs.
async fn handle_connection<S: Send + Sync + 'static>(
incoming: Incoming,
state: Arc<S>,
registry: Arc<MethodRegistry<S>>,
validator: Option<Arc<dyn SessionValidator>>,
) -> Result<(), RpcError> {
let connection = incoming
.await
@@ -88,14 +113,40 @@ async fn handle_connection<S: Send + Sync + 'static>(
let remote = connection.remote_address();
debug!(remote = %remote, "new connection");
// Perform auth handshake on the first bi-stream.
let conn_state = {
let (mut send, mut recv) = connection
.accept_bi()
.await
.map_err(|e| RpcError::Connection(format!("accept auth stream: {e}")))?;
let token = auth_handshake::recv_auth_init(&mut recv).await?;
debug!(remote = %remote, token_len = token.len(), "received auth init");
let identity_key = validator.as_ref().and_then(|v| v.validate(&token));
auth_handshake::send_auth_ack(&mut send).await?;
send.finish()
.map_err(|e| RpcError::Connection(format!("finish auth stream: {e}")))?;
Arc::new(ConnectionState {
session_token: Some(token),
identity_key,
})
};
// Accept RPC streams.
loop {
let stream = connection.accept_bi().await;
match stream {
Ok((send, recv)) => {
let state = Arc::clone(&state);
let registry = Arc::clone(&registry);
let conn_state = Arc::clone(&conn_state);
tokio::spawn(async move {
if let Err(e) = handle_stream(send, recv, state, registry).await {
if let Err(e) =
handle_stream(send, recv, state, registry, &conn_state).await
{
debug!("stream error: {e}");
}
});
@@ -120,6 +171,7 @@ async fn handle_stream<S: Send + Sync + 'static>(
mut recv: RecvStream,
state: Arc<S>,
registry: Arc<MethodRegistry<S>>,
conn_state: &ConnectionState,
) -> Result<(), RpcError> {
// Read the complete request from the stream.
let mut buf = BytesMut::new();
@@ -146,8 +198,8 @@ async fn handle_stream<S: Send + Sync + 'static>(
Some((handler, name)) => {
debug!(method_id = frame.method_id, method = name, req_id = frame.request_id, "dispatching");
let ctx = RequestContext {
identity_key: None, // populated by auth middleware
session_token: None,
identity_key: conn_state.identity_key.clone(),
session_token: conn_state.session_token.clone(),
payload: frame.payload,
};
handler(Arc::clone(&state), ctx).await