//! QUIC RPC server — accepts connections, dispatches requests to handlers. use std::sync::Arc; use bytes::BytesMut; use quinn::{Endpoint, Incoming, RecvStream, SendStream}; use tracing::{debug, info, warn}; use crate::auth_handshake; use crate::error::{RpcError, RpcStatus}; use crate::framing::{RequestFrame, ResponseFrame, PushFrame}; use crate::method::{HandlerResult, MethodRegistry, RequestContext}; use crate::middleware::SessionValidator; /// Per-connection state established during the auth handshake. #[derive(Debug, Clone, Default)] pub struct ConnectionState { /// The raw session token provided during the auth init handshake. pub session_token: Option>, /// The identity key resolved from the session token (populated by validator). pub identity_key: Option>, } /// Configuration for the RPC server. pub struct RpcServerConfig { /// QUIC listen address. pub listen_addr: std::net::SocketAddr, /// TLS server config (rustls). pub tls_config: Arc, /// ALPN protocol for the RPC service. pub alpn: Vec, } /// The QUIC RPC server. pub struct RpcServer { endpoint: Endpoint, state: Arc, registry: Arc>, /// Optional session validator for the auth handshake. validator: Option>, } impl RpcServer { /// Create and bind the QUIC endpoint. Does not start accepting yet. pub fn bind( config: RpcServerConfig, state: Arc, registry: MethodRegistry, ) -> Result { let mut tls = (*config.tls_config).clone(); tls.alpn_protocols = vec![config.alpn]; let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(tls) .map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?; let server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_tls)); let endpoint = Endpoint::server(server_config, config.listen_addr) .map_err(|e| RpcError::Connection(format!("bind {}: {e}", config.listen_addr)))?; info!(addr = %config.listen_addr, "RPC server bound"); Ok(Self { endpoint, state, registry: Arc::new(registry), validator: None, }) } /// Set a session validator for the auth handshake. pub fn with_validator(mut self, validator: Arc) -> Self { self.validator = Some(validator); self } /// Accept connections in a loop. Spawns a task per connection. pub async fn serve(self) -> Result<(), RpcError> { info!("RPC server accepting connections"); while let Some(incoming) = self.endpoint.accept().await { let state = Arc::clone(&self.state); let registry = Arc::clone(&self.registry); let validator = self.validator.clone(); tokio::spawn(async move { if let Err(e) = handle_connection(incoming, state, registry, validator).await { warn!("connection error: {e}"); } }); } Ok(()) } /// Get the local address the server is listening on. pub fn local_addr(&self) -> Result { self.endpoint .local_addr() .map_err(|e| RpcError::Connection(e.to_string())) } } /// Handle a single QUIC connection: perform auth handshake, then accept /// bi-directional streams for RPCs. async fn handle_connection( incoming: Incoming, state: Arc, registry: Arc>, validator: Option>, ) -> Result<(), RpcError> { let connection = incoming .await .map_err(|e| RpcError::Connection(e.to_string()))?; let remote = connection.remote_address(); debug!(remote = %remote, "new connection"); metrics::gauge!("rpc_active_connections").increment(1.0); metrics::counter!("rpc_connections_total").increment(1); // Perform auth handshake on the first bi-stream. let conn_state = { let (mut send, mut recv) = connection .accept_bi() .await .map_err(|e| RpcError::Connection(format!("accept auth stream: {e}")))?; let token = auth_handshake::recv_auth_init(&mut recv).await?; debug!(remote = %remote, token_len = token.len(), "received auth init"); let identity_key = validator.as_ref().and_then(|v| v.validate(&token)); auth_handshake::send_auth_ack(&mut send).await?; send.finish() .map_err(|e| RpcError::Connection(format!("finish auth stream: {e}")))?; Arc::new(ConnectionState { session_token: Some(token), identity_key, }) }; // Accept RPC streams. let result = loop { let stream = connection.accept_bi().await; match stream { Ok((send, recv)) => { let state = Arc::clone(&state); let registry = Arc::clone(®istry); let conn_state = Arc::clone(&conn_state); tokio::spawn(async move { if let Err(e) = handle_stream(send, recv, state, registry, &conn_state).await { debug!("stream error: {e}"); } }); } Err(quinn::ConnectionError::ApplicationClosed(_)) => { debug!(remote = %remote, "connection closed by peer"); break Ok(()); } Err(e) => { debug!(remote = %remote, "accept_bi error: {e}"); break Ok(()); } } }; metrics::gauge!("rpc_active_connections").decrement(1.0); result } /// Handle a single bi-directional stream: read request, dispatch, write response. async fn handle_stream( mut send: SendStream, mut recv: RecvStream, state: Arc, registry: Arc>, conn_state: &ConnectionState, ) -> Result<(), RpcError> { // Read the complete request from the stream. let mut buf = BytesMut::new(); while let Some(chunk) = recv .read_chunk(65536, true) .await .map_err(|e| RpcError::Connection(e.to_string()))? { buf.extend_from_slice(&chunk.bytes); if buf.len() > crate::framing::MAX_PAYLOAD_SIZE + crate::framing::REQUEST_HEADER_SIZE { return Err(RpcError::PayloadTooLarge { size: buf.len(), max: crate::framing::MAX_PAYLOAD_SIZE, }); } } let frame = match RequestFrame::decode(&mut buf)? { Some(f) => f, None => return Err(RpcError::Decode("incomplete request frame".into())), }; let trace_id = uuid::Uuid::now_v7().to_string(); let result = match registry.get(frame.method_id) { Some((handler, name, timeout)) => { let span = tracing::info_span!( "rpc", trace_id = %trace_id, method_id = frame.method_id, method = name, req_id = frame.request_id, ); let _guard = span.enter(); debug!("dispatching"); let deadline = timeout.map(|d| tokio::time::Instant::now() + d); let start = std::time::Instant::now(); let ctx = RequestContext { identity_key: conn_state.identity_key.clone(), session_token: conn_state.session_token.clone(), payload: frame.payload, trace_id: trace_id.clone(), deadline, }; let result = if let Some(dur) = timeout { match tokio::time::timeout(dur, handler(Arc::clone(&state), ctx)).await { Ok(r) => r, Err(_) => { warn!(method = name, timeout_ms = dur.as_millis() as u64, "request deadline exceeded"); HandlerResult::err(RpcStatus::DeadlineExceeded, "request deadline exceeded") } } } else { handler(Arc::clone(&state), ctx).await }; let elapsed = start.elapsed(); // Per-endpoint latency histogram. metrics::histogram!("rpc_request_duration_seconds", "method" => name) .record(elapsed.as_secs_f64()); metrics::counter!("rpc_requests_total", "method" => name, "status" => status_label(result.status)) .increment(1); result } None => { warn!(method_id = frame.method_id, trace_id = %trace_id, "unknown method"); metrics::counter!("rpc_requests_total", "method" => "unknown", "status" => "unknown_method") .increment(1); HandlerResult::err(RpcStatus::UnknownMethod, "unknown method") } }; let response = ResponseFrame { status: result.status as u8, request_id: frame.request_id, payload: result.payload, }; let encoded = response.encode(); send.write_all(&encoded) .await .map_err(|e| RpcError::Connection(e.to_string()))?; send.finish().map_err(|e| RpcError::Connection(e.to_string()))?; Ok(()) } /// Convert an RpcStatus to a short label for metrics. fn status_label(status: RpcStatus) -> &'static str { match status { RpcStatus::Ok => "ok", RpcStatus::BadRequest => "bad_request", RpcStatus::Unauthorized => "unauthorized", RpcStatus::Forbidden => "forbidden", RpcStatus::NotFound => "not_found", RpcStatus::RateLimited => "rate_limited", RpcStatus::DeadlineExceeded => "deadline_exceeded", RpcStatus::Unavailable => "unavailable", RpcStatus::Internal => "internal", RpcStatus::UnknownMethod => "unknown_method", } } /// Send a push event to a client via a QUIC uni-stream. pub async fn send_push( connection: &quinn::Connection, event_type: u16, payload: bytes::Bytes, ) -> Result<(), RpcError> { let mut send = connection .open_uni() .await .map_err(|e| RpcError::Connection(e.to_string()))?; let frame = PushFrame { event_type, payload, }; let encoded = frame.encode(); send.write_all(&encoded) .await .map_err(|e| RpcError::Connection(e.to_string()))?; send.finish().map_err(|e| RpcError::Connection(e.to_string()))?; Ok(()) }