feat: add graceful shutdown with drain timeout and per-RPC timeouts
Graceful shutdown (Phase 6.4): - Listen for SIGTERM + SIGINT via tokio::signal - Configurable drain timeout (--drain-timeout / QPQ_DRAIN_TIMEOUT, default 30s) - Health endpoint returns "draining" during shutdown for load balancer awareness - ServerState carries atomic draining flag - Add RpcStatus::Unavailable (9) for shutdown-related rejections Per-RPC timeouts (Phase 6.5): - Add RpcStatus::DeadlineExceeded (8) for server-side timeouts - MethodRegistry supports default_timeout and per-method timeout overrides - RPC dispatch wraps handler invocation with tokio::time::timeout - RequestContext carries optional deadline (Instant) for handlers - Health: 5s timeout, blob upload/download: 120s timeout, default: 30s - Config: --rpc-timeout / QPQ_RPC_TIMEOUT, --storage-timeout / QPQ_STORAGE_TIMEOUT
This commit is contained in:
@@ -18,6 +18,8 @@ tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
uuid = { version = "1", features = ["v7"] }
|
||||
metrics = "0.22"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util"] }
|
||||
|
||||
@@ -16,6 +16,10 @@ pub enum RpcStatus {
|
||||
NotFound = 4,
|
||||
/// Rate limit exceeded.
|
||||
RateLimited = 5,
|
||||
/// Request deadline exceeded (server-side timeout).
|
||||
DeadlineExceeded = 8,
|
||||
/// Server is shutting down (draining).
|
||||
Unavailable = 9,
|
||||
/// Internal server error.
|
||||
Internal = 10,
|
||||
/// Method not recognized.
|
||||
@@ -32,6 +36,8 @@ impl RpcStatus {
|
||||
3 => Some(Self::Forbidden),
|
||||
4 => Some(Self::NotFound),
|
||||
5 => Some(Self::RateLimited),
|
||||
8 => Some(Self::DeadlineExceeded),
|
||||
9 => Some(Self::Unavailable),
|
||||
10 => Some(Self::Internal),
|
||||
11 => Some(Self::UnknownMethod),
|
||||
_ => None,
|
||||
|
||||
@@ -4,8 +4,10 @@ use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use bytes::Bytes;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::error::RpcStatus;
|
||||
|
||||
@@ -41,6 +43,11 @@ pub struct RequestContext {
|
||||
pub session_token: Option<Vec<u8>>,
|
||||
/// The raw request payload (protobuf-encoded).
|
||||
pub payload: Bytes,
|
||||
/// Unique correlation ID for request tracing (UUID v7, monotonic).
|
||||
pub trace_id: String,
|
||||
/// The effective deadline for this request. Handlers can check this to bail
|
||||
/// early on long-running operations. `None` means no deadline.
|
||||
pub deadline: Option<Instant>,
|
||||
}
|
||||
|
||||
/// Type-erased async handler function.
|
||||
@@ -50,18 +57,34 @@ pub type HandlerFn<S> = Arc<
|
||||
+ Sync,
|
||||
>;
|
||||
|
||||
/// Per-method registration entry.
|
||||
struct MethodEntry<S> {
|
||||
handler: HandlerFn<S>,
|
||||
name: &'static str,
|
||||
/// Optional per-method timeout override. `None` means use the server default.
|
||||
timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
/// Registry mapping method IDs to handler functions.
|
||||
pub struct MethodRegistry<S> {
|
||||
handlers: HashMap<u16, (HandlerFn<S>, &'static str)>,
|
||||
handlers: HashMap<u16, MethodEntry<S>>,
|
||||
/// Default timeout applied to methods that don't specify their own.
|
||||
default_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl<S: Send + Sync + 'static> MethodRegistry<S> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handlers: HashMap::new(),
|
||||
default_timeout: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the default timeout for all methods that don't have a per-method override.
|
||||
pub fn set_default_timeout(&mut self, timeout: Duration) {
|
||||
self.default_timeout = Some(timeout);
|
||||
}
|
||||
|
||||
/// Register a handler for a method ID.
|
||||
pub fn register<F, Fut>(&mut self, method_id: u16, name: &'static str, handler: F)
|
||||
where
|
||||
@@ -71,12 +94,32 @@ impl<S: Send + Sync + 'static> MethodRegistry<S> {
|
||||
let handler = Arc::new(move |state: Arc<S>, ctx: RequestContext| {
|
||||
Box::pin(handler(state, ctx)) as Pin<Box<dyn Future<Output = HandlerResult> + Send>>
|
||||
});
|
||||
self.handlers.insert(method_id, (handler, name));
|
||||
self.handlers.insert(method_id, MethodEntry { handler, name, timeout: None });
|
||||
}
|
||||
|
||||
/// Look up a handler by method ID.
|
||||
pub fn get(&self, method_id: u16) -> Option<&(HandlerFn<S>, &'static str)> {
|
||||
self.handlers.get(&method_id)
|
||||
/// Register a handler with a per-method timeout override.
|
||||
pub fn register_with_timeout<F, Fut>(
|
||||
&mut self,
|
||||
method_id: u16,
|
||||
name: &'static str,
|
||||
timeout: Duration,
|
||||
handler: F,
|
||||
)
|
||||
where
|
||||
F: Fn(Arc<S>, RequestContext) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = HandlerResult> + Send + 'static,
|
||||
{
|
||||
let handler = Arc::new(move |state: Arc<S>, ctx: RequestContext| {
|
||||
Box::pin(handler(state, ctx)) as Pin<Box<dyn Future<Output = HandlerResult> + Send>>
|
||||
});
|
||||
self.handlers.insert(method_id, MethodEntry { handler, name, timeout: Some(timeout) });
|
||||
}
|
||||
|
||||
/// Look up a handler, name, and effective timeout by method ID.
|
||||
pub fn get(&self, method_id: u16) -> Option<(&HandlerFn<S>, &'static str, Option<Duration>)> {
|
||||
self.handlers.get(&method_id).map(|e| {
|
||||
(&e.handler, e.name, e.timeout.or(self.default_timeout))
|
||||
})
|
||||
}
|
||||
|
||||
/// Return the number of registered methods.
|
||||
@@ -91,7 +134,7 @@ impl<S: Send + Sync + 'static> MethodRegistry<S> {
|
||||
|
||||
/// Iterate over all registered (method_id, name) pairs.
|
||||
pub fn methods(&self) -> impl Iterator<Item = (u16, &'static str)> + '_ {
|
||||
self.handlers.iter().map(|(&id, (_, name))| (id, *name))
|
||||
self.handlers.iter().map(|(&id, entry)| (id, entry.name))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,3 +143,66 @@ impl<S: Send + Sync + 'static> Default for MethodRegistry<S> {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn registry_default_timeout_applies_to_methods() {
|
||||
let mut reg = MethodRegistry::<()>::new();
|
||||
reg.set_default_timeout(Duration::from_secs(30));
|
||||
reg.register(1, "Test", |_state: Arc<()>, _ctx| async { HandlerResult::ok(Bytes::new()) });
|
||||
|
||||
let (_, name, timeout) = reg.get(1).expect("registered method");
|
||||
assert_eq!(name, "Test");
|
||||
assert_eq!(timeout, Some(Duration::from_secs(30)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_per_method_timeout_overrides_default() {
|
||||
let mut reg = MethodRegistry::<()>::new();
|
||||
reg.set_default_timeout(Duration::from_secs(30));
|
||||
reg.register_with_timeout(
|
||||
1,
|
||||
"Slow",
|
||||
Duration::from_secs(120),
|
||||
|_state: Arc<()>, _ctx| async { HandlerResult::ok(Bytes::new()) },
|
||||
);
|
||||
|
||||
let (_, _, timeout) = reg.get(1).expect("registered method");
|
||||
assert_eq!(timeout, Some(Duration::from_secs(120)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_no_default_timeout_returns_none() {
|
||||
let mut reg = MethodRegistry::<()>::new();
|
||||
reg.register(1, "NoTimeout", |_state: Arc<()>, _ctx| async {
|
||||
HandlerResult::ok(Bytes::new())
|
||||
});
|
||||
|
||||
let (_, _, timeout) = reg.get(1).expect("registered method");
|
||||
assert_eq!(timeout, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_context_deadline_is_accessible() {
|
||||
let ctx = RequestContext {
|
||||
identity_key: None,
|
||||
session_token: None,
|
||||
payload: Bytes::new(),
|
||||
trace_id: String::new(),
|
||||
deadline: Some(Instant::now() + Duration::from_secs(10)),
|
||||
};
|
||||
assert!(ctx.deadline.is_some());
|
||||
|
||||
let ctx_no_deadline = RequestContext {
|
||||
identity_key: None,
|
||||
session_token: None,
|
||||
payload: Bytes::new(),
|
||||
trace_id: String::new(),
|
||||
deadline: None,
|
||||
};
|
||||
assert!(ctx_no_deadline.deadline.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,6 +113,9 @@ async fn handle_connection<S: Send + Sync + 'static>(
|
||||
let remote = connection.remote_address();
|
||||
debug!(remote = %remote, "new connection");
|
||||
|
||||
metrics::gauge!("rpc_active_connections").increment(1.0);
|
||||
metrics::counter!("rpc_connections_total").increment(1);
|
||||
|
||||
// Perform auth handshake on the first bi-stream.
|
||||
let conn_state = {
|
||||
let (mut send, mut recv) = connection
|
||||
@@ -136,7 +139,7 @@ async fn handle_connection<S: Send + Sync + 'static>(
|
||||
};
|
||||
|
||||
// Accept RPC streams.
|
||||
loop {
|
||||
let result = loop {
|
||||
let stream = connection.accept_bi().await;
|
||||
match stream {
|
||||
Ok((send, recv)) => {
|
||||
@@ -153,16 +156,17 @@ async fn handle_connection<S: Send + Sync + 'static>(
|
||||
}
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => {
|
||||
debug!(remote = %remote, "connection closed by peer");
|
||||
break;
|
||||
break Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(remote = %remote, "accept_bi error: {e}");
|
||||
break;
|
||||
break Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
metrics::gauge!("rpc_active_connections").decrement(1.0);
|
||||
result
|
||||
}
|
||||
|
||||
/// Handle a single bi-directional stream: read request, dispatch, write response.
|
||||
@@ -194,18 +198,57 @@ async fn handle_stream<S: Send + Sync + 'static>(
|
||||
None => return Err(RpcError::Decode("incomplete request frame".into())),
|
||||
};
|
||||
|
||||
let trace_id = uuid::Uuid::now_v7().to_string();
|
||||
|
||||
let result = match registry.get(frame.method_id) {
|
||||
Some((handler, name)) => {
|
||||
debug!(method_id = frame.method_id, method = name, req_id = frame.request_id, "dispatching");
|
||||
Some((handler, name, timeout)) => {
|
||||
let span = tracing::info_span!(
|
||||
"rpc",
|
||||
trace_id = %trace_id,
|
||||
method_id = frame.method_id,
|
||||
method = name,
|
||||
req_id = frame.request_id,
|
||||
);
|
||||
let _guard = span.enter();
|
||||
debug!("dispatching");
|
||||
|
||||
let deadline = timeout.map(|d| tokio::time::Instant::now() + d);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let ctx = RequestContext {
|
||||
identity_key: conn_state.identity_key.clone(),
|
||||
session_token: conn_state.session_token.clone(),
|
||||
payload: frame.payload,
|
||||
trace_id: trace_id.clone(),
|
||||
deadline,
|
||||
};
|
||||
handler(Arc::clone(&state), ctx).await
|
||||
|
||||
let result = if let Some(dur) = timeout {
|
||||
match tokio::time::timeout(dur, handler(Arc::clone(&state), ctx)).await {
|
||||
Ok(r) => r,
|
||||
Err(_) => {
|
||||
warn!(method = name, timeout_ms = dur.as_millis() as u64, "request deadline exceeded");
|
||||
HandlerResult::err(RpcStatus::DeadlineExceeded, "request deadline exceeded")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
handler(Arc::clone(&state), ctx).await
|
||||
};
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
// Per-endpoint latency histogram.
|
||||
metrics::histogram!("rpc_request_duration_seconds", "method" => name)
|
||||
.record(elapsed.as_secs_f64());
|
||||
metrics::counter!("rpc_requests_total", "method" => name, "status" => status_label(result.status))
|
||||
.increment(1);
|
||||
|
||||
result
|
||||
}
|
||||
None => {
|
||||
warn!(method_id = frame.method_id, "unknown method");
|
||||
warn!(method_id = frame.method_id, trace_id = %trace_id, "unknown method");
|
||||
metrics::counter!("rpc_requests_total", "method" => "unknown", "status" => "unknown_method")
|
||||
.increment(1);
|
||||
HandlerResult::err(RpcStatus::UnknownMethod, "unknown method")
|
||||
}
|
||||
};
|
||||
@@ -225,6 +268,22 @@ async fn handle_stream<S: Send + Sync + 'static>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert an RpcStatus to a short label for metrics.
|
||||
fn status_label(status: RpcStatus) -> &'static str {
|
||||
match status {
|
||||
RpcStatus::Ok => "ok",
|
||||
RpcStatus::BadRequest => "bad_request",
|
||||
RpcStatus::Unauthorized => "unauthorized",
|
||||
RpcStatus::Forbidden => "forbidden",
|
||||
RpcStatus::NotFound => "not_found",
|
||||
RpcStatus::RateLimited => "rate_limited",
|
||||
RpcStatus::DeadlineExceeded => "deadline_exceeded",
|
||||
RpcStatus::Unavailable => "unavailable",
|
||||
RpcStatus::Internal => "internal",
|
||||
RpcStatus::UnknownMethod => "unknown_method",
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a push event to a client via a QUIC uni-stream.
|
||||
pub async fn send_push(
|
||||
connection: &quinn::Connection,
|
||||
|
||||
Reference in New Issue
Block a user