Graceful shutdown (Phase 6.4): - Listen for SIGTERM + SIGINT via tokio::signal - Configurable drain timeout (--drain-timeout / QPQ_DRAIN_TIMEOUT, default 30s) - Health endpoint returns "draining" during shutdown for load balancer awareness - ServerState carries atomic draining flag - Add RpcStatus::Unavailable (9) for shutdown-related rejections Per-RPC timeouts (Phase 6.5): - Add RpcStatus::DeadlineExceeded (8) for server-side timeouts - MethodRegistry supports default_timeout and per-method timeout overrides - RPC dispatch wraps handler invocation with tokio::time::timeout - RequestContext carries optional deadline (Instant) for handlers - Health: 5s timeout, blob upload/download: 120s timeout, default: 30s - Config: --rpc-timeout / QPQ_RPC_TIMEOUT, --storage-timeout / QPQ_STORAGE_TIMEOUT
310 lines
10 KiB
Rust
310 lines
10 KiB
Rust
//! QUIC RPC server — accepts connections, dispatches requests to handlers.
|
|
|
|
use std::sync::Arc;
|
|
|
|
use bytes::BytesMut;
|
|
use quinn::{Endpoint, Incoming, RecvStream, SendStream};
|
|
use tracing::{debug, info, warn};
|
|
|
|
use crate::auth_handshake;
|
|
use crate::error::{RpcError, RpcStatus};
|
|
use crate::framing::{RequestFrame, ResponseFrame, PushFrame};
|
|
use crate::method::{HandlerResult, MethodRegistry, RequestContext};
|
|
use crate::middleware::SessionValidator;
|
|
|
|
/// Per-connection state established during the auth handshake.
///
/// `Debug` is implemented by hand instead of being derived so that the raw
/// session token bytes are never written to log output; only the token's
/// presence and length are shown, mirroring how the handshake itself logs
/// `token_len` rather than the token.
#[derive(Clone, Default)]
pub struct ConnectionState {
    /// The raw session token provided during the auth init handshake.
    pub session_token: Option<Vec<u8>>,
    /// The identity key resolved from the session token (populated by validator).
    pub identity_key: Option<Vec<u8>>,
}

impl std::fmt::Debug for ConnectionState {
    /// Format the state with the session token redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Replace the token bytes with a placeholder that still tells the
        // reader whether a token was supplied and how long it was.
        let token = self
            .session_token
            .as_ref()
            .map(|t| format!("<redacted, {} bytes>", t.len()));

        f.debug_struct("ConnectionState")
            .field("session_token", &token)
            .field("identity_key", &self.identity_key)
            .finish()
    }
}
|
|
|
|
/// Configuration for the RPC server.
|
|
pub struct RpcServerConfig {
|
|
/// QUIC listen address.
|
|
pub listen_addr: std::net::SocketAddr,
|
|
/// TLS server config (rustls).
|
|
pub tls_config: Arc<rustls::ServerConfig>,
|
|
/// ALPN protocol for the RPC service.
|
|
pub alpn: Vec<u8>,
|
|
}
|
|
|
|
/// The QUIC RPC server.
|
|
pub struct RpcServer<S: Send + Sync + 'static> {
|
|
endpoint: Endpoint,
|
|
state: Arc<S>,
|
|
registry: Arc<MethodRegistry<S>>,
|
|
/// Optional session validator for the auth handshake.
|
|
validator: Option<Arc<dyn SessionValidator>>,
|
|
}
|
|
|
|
impl<S: Send + Sync + 'static> RpcServer<S> {
|
|
/// Create and bind the QUIC endpoint. Does not start accepting yet.
|
|
pub fn bind(
|
|
config: RpcServerConfig,
|
|
state: Arc<S>,
|
|
registry: MethodRegistry<S>,
|
|
) -> Result<Self, RpcError> {
|
|
let mut tls = (*config.tls_config).clone();
|
|
tls.alpn_protocols = vec![config.alpn];
|
|
let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(tls)
|
|
.map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?;
|
|
let server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_tls));
|
|
|
|
let endpoint = Endpoint::server(server_config, config.listen_addr)
|
|
.map_err(|e| RpcError::Connection(format!("bind {}: {e}", config.listen_addr)))?;
|
|
|
|
info!(addr = %config.listen_addr, "RPC server bound");
|
|
|
|
Ok(Self {
|
|
endpoint,
|
|
state,
|
|
registry: Arc::new(registry),
|
|
validator: None,
|
|
})
|
|
}
|
|
|
|
/// Set a session validator for the auth handshake.
|
|
pub fn with_validator(mut self, validator: Arc<dyn SessionValidator>) -> Self {
|
|
self.validator = Some(validator);
|
|
self
|
|
}
|
|
|
|
/// Accept connections in a loop. Spawns a task per connection.
|
|
pub async fn serve(self) -> Result<(), RpcError> {
|
|
info!("RPC server accepting connections");
|
|
while let Some(incoming) = self.endpoint.accept().await {
|
|
let state = Arc::clone(&self.state);
|
|
let registry = Arc::clone(&self.registry);
|
|
let validator = self.validator.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) =
|
|
handle_connection(incoming, state, registry, validator).await
|
|
{
|
|
warn!("connection error: {e}");
|
|
}
|
|
});
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Get the local address the server is listening on.
|
|
pub fn local_addr(&self) -> Result<std::net::SocketAddr, RpcError> {
|
|
self.endpoint
|
|
.local_addr()
|
|
.map_err(|e| RpcError::Connection(e.to_string()))
|
|
}
|
|
}
|
|
|
|
/// Handle a single QUIC connection: perform auth handshake, then accept
|
|
/// bi-directional streams for RPCs.
|
|
async fn handle_connection<S: Send + Sync + 'static>(
|
|
incoming: Incoming,
|
|
state: Arc<S>,
|
|
registry: Arc<MethodRegistry<S>>,
|
|
validator: Option<Arc<dyn SessionValidator>>,
|
|
) -> Result<(), RpcError> {
|
|
let connection = incoming
|
|
.await
|
|
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
|
|
|
let remote = connection.remote_address();
|
|
debug!(remote = %remote, "new connection");
|
|
|
|
metrics::gauge!("rpc_active_connections").increment(1.0);
|
|
metrics::counter!("rpc_connections_total").increment(1);
|
|
|
|
// Perform auth handshake on the first bi-stream.
|
|
let conn_state = {
|
|
let (mut send, mut recv) = connection
|
|
.accept_bi()
|
|
.await
|
|
.map_err(|e| RpcError::Connection(format!("accept auth stream: {e}")))?;
|
|
|
|
let token = auth_handshake::recv_auth_init(&mut recv).await?;
|
|
debug!(remote = %remote, token_len = token.len(), "received auth init");
|
|
|
|
let identity_key = validator.as_ref().and_then(|v| v.validate(&token));
|
|
|
|
auth_handshake::send_auth_ack(&mut send).await?;
|
|
send.finish()
|
|
.map_err(|e| RpcError::Connection(format!("finish auth stream: {e}")))?;
|
|
|
|
Arc::new(ConnectionState {
|
|
session_token: Some(token),
|
|
identity_key,
|
|
})
|
|
};
|
|
|
|
// Accept RPC streams.
|
|
let result = loop {
|
|
let stream = connection.accept_bi().await;
|
|
match stream {
|
|
Ok((send, recv)) => {
|
|
let state = Arc::clone(&state);
|
|
let registry = Arc::clone(®istry);
|
|
let conn_state = Arc::clone(&conn_state);
|
|
tokio::spawn(async move {
|
|
if let Err(e) =
|
|
handle_stream(send, recv, state, registry, &conn_state).await
|
|
{
|
|
debug!("stream error: {e}");
|
|
}
|
|
});
|
|
}
|
|
Err(quinn::ConnectionError::ApplicationClosed(_)) => {
|
|
debug!(remote = %remote, "connection closed by peer");
|
|
break Ok(());
|
|
}
|
|
Err(e) => {
|
|
debug!(remote = %remote, "accept_bi error: {e}");
|
|
break Ok(());
|
|
}
|
|
}
|
|
};
|
|
|
|
metrics::gauge!("rpc_active_connections").decrement(1.0);
|
|
result
|
|
}
|
|
|
|
/// Handle a single bi-directional stream: read request, dispatch, write response.
|
|
async fn handle_stream<S: Send + Sync + 'static>(
|
|
mut send: SendStream,
|
|
mut recv: RecvStream,
|
|
state: Arc<S>,
|
|
registry: Arc<MethodRegistry<S>>,
|
|
conn_state: &ConnectionState,
|
|
) -> Result<(), RpcError> {
|
|
// Read the complete request from the stream.
|
|
let mut buf = BytesMut::new();
|
|
while let Some(chunk) = recv
|
|
.read_chunk(65536, true)
|
|
.await
|
|
.map_err(|e| RpcError::Connection(e.to_string()))?
|
|
{
|
|
buf.extend_from_slice(&chunk.bytes);
|
|
if buf.len() > crate::framing::MAX_PAYLOAD_SIZE + crate::framing::REQUEST_HEADER_SIZE {
|
|
return Err(RpcError::PayloadTooLarge {
|
|
size: buf.len(),
|
|
max: crate::framing::MAX_PAYLOAD_SIZE,
|
|
});
|
|
}
|
|
}
|
|
|
|
let frame = match RequestFrame::decode(&mut buf)? {
|
|
Some(f) => f,
|
|
None => return Err(RpcError::Decode("incomplete request frame".into())),
|
|
};
|
|
|
|
let trace_id = uuid::Uuid::now_v7().to_string();
|
|
|
|
let result = match registry.get(frame.method_id) {
|
|
Some((handler, name, timeout)) => {
|
|
let span = tracing::info_span!(
|
|
"rpc",
|
|
trace_id = %trace_id,
|
|
method_id = frame.method_id,
|
|
method = name,
|
|
req_id = frame.request_id,
|
|
);
|
|
let _guard = span.enter();
|
|
debug!("dispatching");
|
|
|
|
let deadline = timeout.map(|d| tokio::time::Instant::now() + d);
|
|
|
|
let start = std::time::Instant::now();
|
|
let ctx = RequestContext {
|
|
identity_key: conn_state.identity_key.clone(),
|
|
session_token: conn_state.session_token.clone(),
|
|
payload: frame.payload,
|
|
trace_id: trace_id.clone(),
|
|
deadline,
|
|
};
|
|
|
|
let result = if let Some(dur) = timeout {
|
|
match tokio::time::timeout(dur, handler(Arc::clone(&state), ctx)).await {
|
|
Ok(r) => r,
|
|
Err(_) => {
|
|
warn!(method = name, timeout_ms = dur.as_millis() as u64, "request deadline exceeded");
|
|
HandlerResult::err(RpcStatus::DeadlineExceeded, "request deadline exceeded")
|
|
}
|
|
}
|
|
} else {
|
|
handler(Arc::clone(&state), ctx).await
|
|
};
|
|
|
|
let elapsed = start.elapsed();
|
|
|
|
// Per-endpoint latency histogram.
|
|
metrics::histogram!("rpc_request_duration_seconds", "method" => name)
|
|
.record(elapsed.as_secs_f64());
|
|
metrics::counter!("rpc_requests_total", "method" => name, "status" => status_label(result.status))
|
|
.increment(1);
|
|
|
|
result
|
|
}
|
|
None => {
|
|
warn!(method_id = frame.method_id, trace_id = %trace_id, "unknown method");
|
|
metrics::counter!("rpc_requests_total", "method" => "unknown", "status" => "unknown_method")
|
|
.increment(1);
|
|
HandlerResult::err(RpcStatus::UnknownMethod, "unknown method")
|
|
}
|
|
};
|
|
|
|
let response = ResponseFrame {
|
|
status: result.status as u8,
|
|
request_id: frame.request_id,
|
|
payload: result.payload,
|
|
};
|
|
|
|
let encoded = response.encode();
|
|
send.write_all(&encoded)
|
|
.await
|
|
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
|
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Convert an RpcStatus to a short label for metrics.
|
|
fn status_label(status: RpcStatus) -> &'static str {
|
|
match status {
|
|
RpcStatus::Ok => "ok",
|
|
RpcStatus::BadRequest => "bad_request",
|
|
RpcStatus::Unauthorized => "unauthorized",
|
|
RpcStatus::Forbidden => "forbidden",
|
|
RpcStatus::NotFound => "not_found",
|
|
RpcStatus::RateLimited => "rate_limited",
|
|
RpcStatus::DeadlineExceeded => "deadline_exceeded",
|
|
RpcStatus::Unavailable => "unavailable",
|
|
RpcStatus::Internal => "internal",
|
|
RpcStatus::UnknownMethod => "unknown_method",
|
|
}
|
|
}
|
|
|
|
/// Send a push event to a client via a QUIC uni-stream.
|
|
pub async fn send_push(
|
|
connection: &quinn::Connection,
|
|
event_type: u16,
|
|
payload: bytes::Bytes,
|
|
) -> Result<(), RpcError> {
|
|
let mut send = connection
|
|
.open_uni()
|
|
.await
|
|
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
|
|
|
let frame = PushFrame {
|
|
event_type,
|
|
payload,
|
|
};
|
|
let encoded = frame.encode();
|
|
send.write_all(&encoded)
|
|
.await
|
|
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
|
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|