Files
quicproquo/crates/quicproquo-rpc/src/server.rs
Christian Nennemann e93a38243f feat: add graceful shutdown with drain timeout and per-RPC timeouts
Graceful shutdown (Phase 6.4):
- Listen for SIGTERM + SIGINT via tokio::signal
- Configurable drain timeout (--drain-timeout / QPQ_DRAIN_TIMEOUT, default 30s)
- Health endpoint returns "draining" during shutdown for load balancer awareness
- ServerState carries atomic draining flag
- Add RpcStatus::Unavailable (9) for shutdown-related rejections

Per-RPC timeouts (Phase 6.5):
- Add RpcStatus::DeadlineExceeded (8) for server-side timeouts
- MethodRegistry supports default_timeout and per-method timeout overrides
- RPC dispatch wraps handler invocation with tokio::time::timeout
- RequestContext carries optional deadline (Instant) for handlers
- Health: 5s timeout, blob upload/download: 120s timeout, default: 30s
- Config: --rpc-timeout / QPQ_RPC_TIMEOUT, --storage-timeout / QPQ_STORAGE_TIMEOUT
2026-03-04 20:33:26 +01:00

310 lines
10 KiB
Rust

//! QUIC RPC server — accepts connections, dispatches requests to handlers.
use std::sync::Arc;
use bytes::BytesMut;
use quinn::{Endpoint, Incoming, RecvStream, SendStream};
use tracing::{debug, info, warn};
use crate::auth_handshake;
use crate::error::{RpcError, RpcStatus};
use crate::framing::{RequestFrame, ResponseFrame, PushFrame};
use crate::method::{HandlerResult, MethodRegistry, RequestContext};
use crate::middleware::SessionValidator;
/// Authentication state attached to one QUIC connection.
///
/// Built once, after the auth handshake on the connection's first
/// bi-directional stream, and then read by every RPC stream on that
/// connection. `Default` yields a state with neither field set.
#[derive(Clone, Debug, Default)]
pub struct ConnectionState {
    /// Raw session token bytes the client sent in its auth-init message.
    pub session_token: Option<Vec<u8>>,
    /// Identity key the session validator resolved from the token; `None`
    /// when no validator is configured or the validator returned nothing.
    pub identity_key: Option<Vec<u8>>,
}
/// Settings consumed by [`RpcServer::bind`] to create the QUIC endpoint.
pub struct RpcServerConfig {
    /// Socket address the QUIC endpoint binds to.
    pub listen_addr: std::net::SocketAddr,
    /// Base rustls server configuration. `bind` works on a clone of it, so
    /// the shared value behind the `Arc` is never modified.
    pub tls_config: Arc<rustls::ServerConfig>,
    /// ALPN protocol identifier for the RPC service. `bind` installs it as
    /// the only entry of the cloned TLS config's ALPN list.
    pub alpn: Vec<u8>,
}
/// QUIC RPC server, generic over the application state `S` that is handed to
/// every method handler.
pub struct RpcServer<S: Send + Sync + 'static> {
    /// Bound QUIC endpoint that incoming connections are accepted from.
    endpoint: Endpoint,
    /// Application state, cloned (by `Arc`) into each connection task.
    state: Arc<S>,
    /// Method table used to look up the handler for a request's method id.
    registry: Arc<MethodRegistry<S>>,
    /// Optional validator that maps a session token to an identity key
    /// during the auth handshake; `None` until `with_validator` is called.
    validator: Option<Arc<dyn SessionValidator>>,
}
impl<S: Send + Sync + 'static> RpcServer<S> {
/// Create and bind the QUIC endpoint. Does not start accepting yet.
pub fn bind(
config: RpcServerConfig,
state: Arc<S>,
registry: MethodRegistry<S>,
) -> Result<Self, RpcError> {
let mut tls = (*config.tls_config).clone();
tls.alpn_protocols = vec![config.alpn];
let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(tls)
.map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?;
let server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_tls));
let endpoint = Endpoint::server(server_config, config.listen_addr)
.map_err(|e| RpcError::Connection(format!("bind {}: {e}", config.listen_addr)))?;
info!(addr = %config.listen_addr, "RPC server bound");
Ok(Self {
endpoint,
state,
registry: Arc::new(registry),
validator: None,
})
}
/// Set a session validator for the auth handshake.
pub fn with_validator(mut self, validator: Arc<dyn SessionValidator>) -> Self {
self.validator = Some(validator);
self
}
/// Accept connections in a loop. Spawns a task per connection.
pub async fn serve(self) -> Result<(), RpcError> {
info!("RPC server accepting connections");
while let Some(incoming) = self.endpoint.accept().await {
let state = Arc::clone(&self.state);
let registry = Arc::clone(&self.registry);
let validator = self.validator.clone();
tokio::spawn(async move {
if let Err(e) =
handle_connection(incoming, state, registry, validator).await
{
warn!("connection error: {e}");
}
});
}
Ok(())
}
/// Get the local address the server is listening on.
pub fn local_addr(&self) -> Result<std::net::SocketAddr, RpcError> {
self.endpoint
.local_addr()
.map_err(|e| RpcError::Connection(e.to_string()))
}
}
/// Handle a single QUIC connection: perform the auth handshake on the first
/// bi-directional stream, then accept further bi-directional streams and
/// spawn one task per RPC stream.
///
/// During the handshake the client's token is passed to `validator` (when one
/// is configured) and the resulting identity key, if any, is stored in the
/// connection's [`ConnectionState`]. The auth ack is sent whether or not the
/// validator produced an identity key.
///
/// Metrics: `rpc_connections_total` is incremented once per established
/// connection. `rpc_active_connections` is incremented when the connection is
/// established and decremented on every exit path, including handshake
/// failures.
///
/// # Errors
///
/// Returns an [`RpcError`] if the connection cannot be established or the
/// auth handshake fails. Once the handshake has succeeded, a closed or failed
/// connection ends the accept loop and the function returns `Ok(())`.
async fn handle_connection<S: Send + Sync + 'static>(
    incoming: Incoming,
    state: Arc<S>,
    registry: Arc<MethodRegistry<S>>,
    validator: Option<Arc<dyn SessionValidator>>,
) -> Result<(), RpcError> {
    /// Decrements the `rpc_active_connections` gauge when dropped, so the
    /// gauge stays balanced even when the function returns early via `?`.
    struct ActiveConnectionGuard;

    impl Drop for ActiveConnectionGuard {
        fn drop(&mut self) {
            metrics::gauge!("rpc_active_connections").decrement(1.0);
        }
    }

    let connection = incoming
        .await
        .map_err(|e| RpcError::Connection(e.to_string()))?;
    let remote = connection.remote_address();
    debug!(remote = %remote, "new connection");

    metrics::gauge!("rpc_active_connections").increment(1.0);
    // From here on, every return path (including the `?`s below) must undo
    // the increment above; the guard's `Drop` takes care of that.
    let _active_connection = ActiveConnectionGuard;
    metrics::counter!("rpc_connections_total").increment(1);

    // Perform auth handshake on the first bi-stream.
    let conn_state = {
        let (mut send, mut recv) = connection
            .accept_bi()
            .await
            .map_err(|e| RpcError::Connection(format!("accept auth stream: {e}")))?;
        let token = auth_handshake::recv_auth_init(&mut recv).await?;
        debug!(remote = %remote, token_len = token.len(), "received auth init");

        // Without a validator (or when it rejects the token) the identity
        // key stays `None`.
        let identity_key = validator.as_ref().and_then(|v| v.validate(&token));

        auth_handshake::send_auth_ack(&mut send).await?;
        send.finish()
            .map_err(|e| RpcError::Connection(format!("finish auth stream: {e}")))?;

        Arc::new(ConnectionState {
            session_token: Some(token),
            identity_key,
        })
    };

    // Accept RPC streams until the connection goes away.
    loop {
        match connection.accept_bi().await {
            Ok((send, recv)) => {
                let state = Arc::clone(&state);
                let registry = Arc::clone(&registry);
                let conn_state = Arc::clone(&conn_state);
                tokio::spawn(async move {
                    if let Err(e) =
                        handle_stream(send, recv, state, registry, &conn_state).await
                    {
                        debug!("stream error: {e}");
                    }
                });
            }
            Err(quinn::ConnectionError::ApplicationClosed(_)) => {
                debug!(remote = %remote, "connection closed by peer");
                break;
            }
            Err(e) => {
                // Any other accept failure also ends the connection; it is
                // logged but not treated as an error of this task.
                debug!(remote = %remote, "accept_bi error: {e}");
                break;
            }
        }
    }

    Ok(())
}
/// Handle a single bi-directional stream: read one request, dispatch it to
/// the registered handler, and write back one response.
///
/// The request is read until the receive side reports end of stream, and is
/// rejected once the buffered size exceeds
/// `MAX_PAYLOAD_SIZE + REQUEST_HEADER_SIZE`. The handler is looked up by the
/// frame's method id; when the registry supplies a timeout, the handler
/// invocation is wrapped in `tokio::time::timeout` and a timeout is answered
/// with `RpcStatus::DeadlineExceeded`. An unknown method id is answered with
/// `RpcStatus::UnknownMethod`.
///
/// Metrics: `rpc_request_duration_seconds` (per method) and
/// `rpc_requests_total` (per method and status).
///
/// # Errors
///
/// Returns [`RpcError::Connection`] on stream read/write failures,
/// [`RpcError::PayloadTooLarge`] when the size cap is exceeded, and
/// [`RpcError::Decode`] when the stream ended before a full frame arrived
/// (plus whatever `RequestFrame::decode` itself reports). In these cases no
/// response frame is written.
async fn handle_stream<S: Send + Sync + 'static>(
    mut send: SendStream,
    mut recv: RecvStream,
    state: Arc<S>,
    registry: Arc<MethodRegistry<S>>,
    conn_state: &ConnectionState,
) -> Result<(), RpcError> {
    // Read the complete request from the stream.
    let mut buf = BytesMut::new();
    while let Some(chunk) = recv
        .read_chunk(65536, true)
        .await
        .map_err(|e| RpcError::Connection(e.to_string()))?
    {
        buf.extend_from_slice(&chunk.bytes);
        // Enforce the cap while reading so an oversized request is dropped
        // before the rest of it is buffered.
        if buf.len() > crate::framing::MAX_PAYLOAD_SIZE + crate::framing::REQUEST_HEADER_SIZE {
            return Err(RpcError::PayloadTooLarge {
                size: buf.len(),
                max: crate::framing::MAX_PAYLOAD_SIZE,
            });
        }
    }

    let frame = match RequestFrame::decode(&mut buf)? {
        Some(f) => f,
        None => return Err(RpcError::Decode("incomplete request frame".into())),
    };

    let trace_id = uuid::Uuid::now_v7().to_string();

    let result = match registry.get(frame.method_id) {
        Some((handler, name, timeout)) => {
            let span = tracing::info_span!(
                "rpc",
                trace_id = %trace_id,
                method_id = frame.method_id,
                method = name,
                req_id = frame.request_id,
            );
            // Do not hold a `Span::enter` guard across `.await`: the span
            // would stay entered on the worker thread while this task is
            // suspended. Synchronous sections use `in_scope`; the handler
            // future is instrumented so the span is entered on each poll.
            span.in_scope(|| debug!("dispatching"));

            let deadline = timeout.map(|d| tokio::time::Instant::now() + d);
            let start = std::time::Instant::now();

            let ctx = RequestContext {
                identity_key: conn_state.identity_key.clone(),
                session_token: conn_state.session_token.clone(),
                payload: frame.payload,
                trace_id: trace_id.clone(),
                deadline,
            };

            // Create the handler future inside the span as well, so anything
            // the handler logs before its first await is attributed correctly.
            let call = tracing::Instrument::instrument(
                span.in_scope(|| handler(Arc::clone(&state), ctx)),
                span.clone(),
            );

            let result = match timeout {
                Some(dur) => match tokio::time::timeout(dur, call).await {
                    Ok(r) => r,
                    Err(_) => {
                        span.in_scope(|| {
                            warn!(
                                method = name,
                                // Saturate instead of silently truncating the u128.
                                timeout_ms = u64::try_from(dur.as_millis()).unwrap_or(u64::MAX),
                                "request deadline exceeded"
                            );
                        });
                        HandlerResult::err(RpcStatus::DeadlineExceeded, "request deadline exceeded")
                    }
                },
                None => call.await,
            };

            let elapsed = start.elapsed();

            // Per-endpoint latency histogram.
            metrics::histogram!("rpc_request_duration_seconds", "method" => name)
                .record(elapsed.as_secs_f64());
            metrics::counter!("rpc_requests_total", "method" => name, "status" => status_label(result.status))
                .increment(1);

            result
        }
        None => {
            warn!(method_id = frame.method_id, trace_id = %trace_id, "unknown method");
            metrics::counter!("rpc_requests_total", "method" => "unknown", "status" => "unknown_method")
                .increment(1);
            HandlerResult::err(RpcStatus::UnknownMethod, "unknown method")
        }
    };

    let response = ResponseFrame {
        status: result.status as u8,
        request_id: frame.request_id,
        payload: result.payload,
    };

    let encoded = response.encode();
    send.write_all(&encoded)
        .await
        .map_err(|e| RpcError::Connection(e.to_string()))?;
    send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;

    Ok(())
}
/// Map an [`RpcStatus`] to the snake_case label used as the `status`
/// dimension of the request metrics.
///
/// The match is intentionally exhaustive (no wildcard arm) so that adding a
/// new status variant forces a label to be chosen here.
fn status_label(status: RpcStatus) -> &'static str {
    let label = match status {
        // Success.
        RpcStatus::Ok => "ok",

        // Errors attributable to the request or the caller.
        RpcStatus::BadRequest => "bad_request",
        RpcStatus::UnknownMethod => "unknown_method",
        RpcStatus::Unauthorized => "unauthorized",
        RpcStatus::Forbidden => "forbidden",
        RpcStatus::NotFound => "not_found",
        RpcStatus::RateLimited => "rate_limited",

        // Errors attributable to the server.
        RpcStatus::Internal => "internal",
        RpcStatus::Unavailable => "unavailable",
        RpcStatus::DeadlineExceeded => "deadline_exceeded",
    };
    label
}
/// Deliver one push event to the peer over a freshly opened QUIC
/// uni-directional stream.
///
/// The event is encoded as a [`PushFrame`] made of `event_type` and
/// `payload`, written in full, and the stream is then finished.
///
/// # Errors
///
/// Returns [`RpcError::Connection`] carrying the underlying error text if
/// opening the stream, writing the frame, or finishing the stream fails.
pub async fn send_push(
    connection: &quinn::Connection,
    event_type: u16,
    payload: bytes::Bytes,
) -> Result<(), RpcError> {
    /// Wrap any displayable transport error as a connection-level RPC error.
    fn connection_error(e: impl std::fmt::Display) -> RpcError {
        RpcError::Connection(e.to_string())
    }

    let mut stream = connection.open_uni().await.map_err(connection_error)?;

    let wire_bytes = PushFrame {
        event_type,
        payload,
    }
    .encode();

    stream
        .write_all(&wire_bytes)
        .await
        .map_err(connection_error)?;
    stream.finish().map_err(connection_error)?;

    Ok(())
}