feat: add graceful shutdown with drain timeout and per-RPC timeouts

Graceful shutdown (Phase 6.4):
- Listen for SIGTERM + SIGINT via tokio::signal
- Configurable drain timeout (--drain-timeout / QPQ_DRAIN_TIMEOUT, default 30s)
- Health endpoint returns "draining" during shutdown for load balancer awareness
- ServerState carries atomic draining flag
- Add RpcStatus::Unavailable (9) for shutdown-related rejections

Per-RPC timeouts (Phase 6.5):
- Add RpcStatus::DeadlineExceeded (8) for server-side timeouts
- MethodRegistry supports default_timeout and per-method timeout overrides
- RPC dispatch wraps handler invocation with tokio::time::timeout
- RequestContext carries optional deadline (Instant) for handlers
- Health: 5s timeout, blob upload/download: 120s timeout, default: 30s
- Config: --rpc-timeout / QPQ_RPC_TIMEOUT, --storage-timeout / QPQ_STORAGE_TIMEOUT
This commit is contained in:
2026-03-04 20:33:26 +01:00
parent 91c5495ab7
commit e93a38243f
10 changed files with 545 additions and 26 deletions

View File

@@ -18,6 +18,8 @@ tracing = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
dashmap = { workspace = true } dashmap = { workspace = true }
sha2 = { workspace = true } sha2 = { workspace = true }
uuid = { version = "1", features = ["v7"] }
metrics = "0.22"
[dev-dependencies] [dev-dependencies]
tokio = { workspace = true, features = ["test-util"] } tokio = { workspace = true, features = ["test-util"] }

View File

@@ -16,6 +16,10 @@ pub enum RpcStatus {
NotFound = 4, NotFound = 4,
/// Rate limit exceeded. /// Rate limit exceeded.
RateLimited = 5, RateLimited = 5,
/// Request deadline exceeded (server-side timeout).
DeadlineExceeded = 8,
/// Server is shutting down (draining).
Unavailable = 9,
/// Internal server error. /// Internal server error.
Internal = 10, Internal = 10,
/// Method not recognized. /// Method not recognized.
@@ -32,6 +36,8 @@ impl RpcStatus {
3 => Some(Self::Forbidden), 3 => Some(Self::Forbidden),
4 => Some(Self::NotFound), 4 => Some(Self::NotFound),
5 => Some(Self::RateLimited), 5 => Some(Self::RateLimited),
8 => Some(Self::DeadlineExceeded),
9 => Some(Self::Unavailable),
10 => Some(Self::Internal), 10 => Some(Self::Internal),
11 => Some(Self::UnknownMethod), 11 => Some(Self::UnknownMethod),
_ => None, _ => None,

View File

@@ -4,8 +4,10 @@ use std::collections::HashMap;
use std::future::Future; use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes; use bytes::Bytes;
use tokio::time::Instant;
use crate::error::RpcStatus; use crate::error::RpcStatus;
@@ -41,6 +43,11 @@ pub struct RequestContext {
pub session_token: Option<Vec<u8>>, pub session_token: Option<Vec<u8>>,
/// The raw request payload (protobuf-encoded). /// The raw request payload (protobuf-encoded).
pub payload: Bytes, pub payload: Bytes,
/// Unique correlation ID for request tracing (UUID v7, monotonic).
pub trace_id: String,
/// The effective deadline for this request. Handlers can check this to bail
/// early on long-running operations. `None` means no deadline.
pub deadline: Option<Instant>,
} }
/// Type-erased async handler function. /// Type-erased async handler function.
@@ -50,18 +57,34 @@ pub type HandlerFn<S> = Arc<
+ Sync, + Sync,
>; >;
/// Per-method registration entry.
struct MethodEntry<S> {
    /// Type-erased async handler invoked on dispatch.
    handler: HandlerFn<S>,
    /// Human-readable method name, used in logs and as a metrics label.
    name: &'static str,
    /// Optional per-method timeout override. `None` means use the server default.
    timeout: Option<Duration>,
}
/// Registry mapping method IDs to handler functions. /// Registry mapping method IDs to handler functions.
pub struct MethodRegistry<S> { pub struct MethodRegistry<S> {
handlers: HashMap<u16, (HandlerFn<S>, &'static str)>, handlers: HashMap<u16, MethodEntry<S>>,
/// Default timeout applied to methods that don't specify their own.
default_timeout: Option<Duration>,
} }
impl<S: Send + Sync + 'static> MethodRegistry<S> { impl<S: Send + Sync + 'static> MethodRegistry<S> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
handlers: HashMap::new(), handlers: HashMap::new(),
default_timeout: None,
} }
} }
/// Set the default timeout for all methods that don't have a per-method override.
///
/// The default is resolved lazily at lookup time (`get` computes
/// `per_method.or(default)`), so this also affects methods that were
/// registered before this call.
pub fn set_default_timeout(&mut self, timeout: Duration) {
    self.default_timeout = Some(timeout);
}
/// Register a handler for a method ID. /// Register a handler for a method ID.
pub fn register<F, Fut>(&mut self, method_id: u16, name: &'static str, handler: F) pub fn register<F, Fut>(&mut self, method_id: u16, name: &'static str, handler: F)
where where
@@ -71,12 +94,32 @@ impl<S: Send + Sync + 'static> MethodRegistry<S> {
let handler = Arc::new(move |state: Arc<S>, ctx: RequestContext| { let handler = Arc::new(move |state: Arc<S>, ctx: RequestContext| {
Box::pin(handler(state, ctx)) as Pin<Box<dyn Future<Output = HandlerResult> + Send>> Box::pin(handler(state, ctx)) as Pin<Box<dyn Future<Output = HandlerResult> + Send>>
}); });
self.handlers.insert(method_id, (handler, name)); self.handlers.insert(method_id, MethodEntry { handler, name, timeout: None });
} }
/// Look up a handler by method ID. /// Register a handler with a per-method timeout override.
pub fn get(&self, method_id: u16) -> Option<&(HandlerFn<S>, &'static str)> { pub fn register_with_timeout<F, Fut>(
self.handlers.get(&method_id) &mut self,
method_id: u16,
name: &'static str,
timeout: Duration,
handler: F,
)
where
F: Fn(Arc<S>, RequestContext) -> Fut + Send + Sync + 'static,
Fut: Future<Output = HandlerResult> + Send + 'static,
{
let handler = Arc::new(move |state: Arc<S>, ctx: RequestContext| {
Box::pin(handler(state, ctx)) as Pin<Box<dyn Future<Output = HandlerResult> + Send>>
});
self.handlers.insert(method_id, MethodEntry { handler, name, timeout: Some(timeout) });
}
/// Look up a handler, name, and effective timeout by method ID.
///
/// The effective timeout is the per-method override when one was registered,
/// otherwise the registry-wide default (which may itself be `None`, meaning
/// no timeout is applied).
pub fn get(&self, method_id: u16) -> Option<(&HandlerFn<S>, &'static str, Option<Duration>)> {
    self.handlers.get(&method_id).map(|e| {
        // Per-method override wins; fall back to the registry default.
        (&e.handler, e.name, e.timeout.or(self.default_timeout))
    })
} }
/// Return the number of registered methods. /// Return the number of registered methods.
@@ -91,7 +134,7 @@ impl<S: Send + Sync + 'static> MethodRegistry<S> {
/// Iterate over all registered (method_id, name) pairs. /// Iterate over all registered (method_id, name) pairs.
pub fn methods(&self) -> impl Iterator<Item = (u16, &'static str)> + '_ { pub fn methods(&self) -> impl Iterator<Item = (u16, &'static str)> + '_ {
self.handlers.iter().map(|(&id, (_, name))| (id, *name)) self.handlers.iter().map(|(&id, entry)| (id, entry.name))
} }
} }
@@ -100,3 +143,66 @@ impl<S: Send + Sync + 'static> Default for MethodRegistry<S> {
Self::new() Self::new()
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// A registry-wide default timeout is reported for methods that were
    /// registered without a per-method override.
    #[test]
    fn registry_default_timeout_applies_to_methods() {
        let mut reg = MethodRegistry::<()>::new();
        reg.set_default_timeout(Duration::from_secs(30));
        reg.register(1, "Test", |_state: Arc<()>, _ctx| async { HandlerResult::ok(Bytes::new()) });
        let (_, name, timeout) = reg.get(1).expect("registered method");
        assert_eq!(name, "Test");
        assert_eq!(timeout, Some(Duration::from_secs(30)));
    }

    /// `register_with_timeout` takes precedence over the registry default.
    #[test]
    fn registry_per_method_timeout_overrides_default() {
        let mut reg = MethodRegistry::<()>::new();
        reg.set_default_timeout(Duration::from_secs(30));
        reg.register_with_timeout(
            1,
            "Slow",
            Duration::from_secs(120),
            |_state: Arc<()>, _ctx| async { HandlerResult::ok(Bytes::new()) },
        );
        let (_, _, timeout) = reg.get(1).expect("registered method");
        assert_eq!(timeout, Some(Duration::from_secs(120)));
    }

    /// With neither a default nor an override, lookups report no timeout.
    #[test]
    fn registry_no_default_timeout_returns_none() {
        let mut reg = MethodRegistry::<()>::new();
        reg.register(1, "NoTimeout", |_state: Arc<()>, _ctx| async {
            HandlerResult::ok(Bytes::new())
        });
        let (_, _, timeout) = reg.get(1).expect("registered method");
        assert_eq!(timeout, None);
    }

    /// `RequestContext.deadline` stores both `Some(instant)` and `None`
    /// and is readable by handlers.
    #[test]
    fn request_context_deadline_is_accessible() {
        let ctx = RequestContext {
            identity_key: None,
            session_token: None,
            payload: Bytes::new(),
            trace_id: String::new(),
            deadline: Some(Instant::now() + Duration::from_secs(10)),
        };
        assert!(ctx.deadline.is_some());
        let ctx_no_deadline = RequestContext {
            identity_key: None,
            session_token: None,
            payload: Bytes::new(),
            trace_id: String::new(),
            deadline: None,
        };
        assert!(ctx_no_deadline.deadline.is_none());
    }
}

View File

@@ -113,6 +113,9 @@ async fn handle_connection<S: Send + Sync + 'static>(
let remote = connection.remote_address(); let remote = connection.remote_address();
debug!(remote = %remote, "new connection"); debug!(remote = %remote, "new connection");
metrics::gauge!("rpc_active_connections").increment(1.0);
metrics::counter!("rpc_connections_total").increment(1);
// Perform auth handshake on the first bi-stream. // Perform auth handshake on the first bi-stream.
let conn_state = { let conn_state = {
let (mut send, mut recv) = connection let (mut send, mut recv) = connection
@@ -136,7 +139,7 @@ async fn handle_connection<S: Send + Sync + 'static>(
}; };
// Accept RPC streams. // Accept RPC streams.
loop { let result = loop {
let stream = connection.accept_bi().await; let stream = connection.accept_bi().await;
match stream { match stream {
Ok((send, recv)) => { Ok((send, recv)) => {
@@ -153,16 +156,17 @@ async fn handle_connection<S: Send + Sync + 'static>(
} }
Err(quinn::ConnectionError::ApplicationClosed(_)) => { Err(quinn::ConnectionError::ApplicationClosed(_)) => {
debug!(remote = %remote, "connection closed by peer"); debug!(remote = %remote, "connection closed by peer");
break; break Ok(());
} }
Err(e) => { Err(e) => {
debug!(remote = %remote, "accept_bi error: {e}"); debug!(remote = %remote, "accept_bi error: {e}");
break; break Ok(());
} }
} }
} };
Ok(()) metrics::gauge!("rpc_active_connections").decrement(1.0);
result
} }
/// Handle a single bi-directional stream: read request, dispatch, write response. /// Handle a single bi-directional stream: read request, dispatch, write response.
@@ -194,18 +198,57 @@ async fn handle_stream<S: Send + Sync + 'static>(
None => return Err(RpcError::Decode("incomplete request frame".into())), None => return Err(RpcError::Decode("incomplete request frame".into())),
}; };
let trace_id = uuid::Uuid::now_v7().to_string();
let result = match registry.get(frame.method_id) { let result = match registry.get(frame.method_id) {
Some((handler, name)) => { Some((handler, name, timeout)) => {
debug!(method_id = frame.method_id, method = name, req_id = frame.request_id, "dispatching"); let span = tracing::info_span!(
"rpc",
trace_id = %trace_id,
method_id = frame.method_id,
method = name,
req_id = frame.request_id,
);
let _guard = span.enter();
debug!("dispatching");
let deadline = timeout.map(|d| tokio::time::Instant::now() + d);
let start = std::time::Instant::now();
let ctx = RequestContext { let ctx = RequestContext {
identity_key: conn_state.identity_key.clone(), identity_key: conn_state.identity_key.clone(),
session_token: conn_state.session_token.clone(), session_token: conn_state.session_token.clone(),
payload: frame.payload, payload: frame.payload,
trace_id: trace_id.clone(),
deadline,
}; };
handler(Arc::clone(&state), ctx).await
let result = if let Some(dur) = timeout {
match tokio::time::timeout(dur, handler(Arc::clone(&state), ctx)).await {
Ok(r) => r,
Err(_) => {
warn!(method = name, timeout_ms = dur.as_millis() as u64, "request deadline exceeded");
HandlerResult::err(RpcStatus::DeadlineExceeded, "request deadline exceeded")
}
}
} else {
handler(Arc::clone(&state), ctx).await
};
let elapsed = start.elapsed();
// Per-endpoint latency histogram.
metrics::histogram!("rpc_request_duration_seconds", "method" => name)
.record(elapsed.as_secs_f64());
metrics::counter!("rpc_requests_total", "method" => name, "status" => status_label(result.status))
.increment(1);
result
} }
None => { None => {
warn!(method_id = frame.method_id, "unknown method"); warn!(method_id = frame.method_id, trace_id = %trace_id, "unknown method");
metrics::counter!("rpc_requests_total", "method" => "unknown", "status" => "unknown_method")
.increment(1);
HandlerResult::err(RpcStatus::UnknownMethod, "unknown method") HandlerResult::err(RpcStatus::UnknownMethod, "unknown method")
} }
}; };
@@ -225,6 +268,22 @@ async fn handle_stream<S: Send + Sync + 'static>(
Ok(()) Ok(())
} }
/// Convert an RpcStatus to a short label for metrics.
///
/// These strings become the `status` tag on `rpc_requests_total`; keep them
/// stable and low-cardinality, since dashboards and alerts key off them.
/// The match is intentionally exhaustive (no `_` arm) so adding a new
/// `RpcStatus` variant forces an update here.
fn status_label(status: RpcStatus) -> &'static str {
    match status {
        RpcStatus::Ok => "ok",
        RpcStatus::BadRequest => "bad_request",
        RpcStatus::Unauthorized => "unauthorized",
        RpcStatus::Forbidden => "forbidden",
        RpcStatus::NotFound => "not_found",
        RpcStatus::RateLimited => "rate_limited",
        RpcStatus::DeadlineExceeded => "deadline_exceeded",
        RpcStatus::Unavailable => "unavailable",
        RpcStatus::Internal => "internal",
        RpcStatus::UnknownMethod => "unknown_method",
    }
}
/// Send a push event to a client via a QUIC uni-stream. /// Send a push event to a client via a QUIC uni-stream.
pub async fn send_push( pub async fn send_push(
connection: &quinn::Connection, connection: &quinn::Connection,

View File

@@ -0,0 +1,232 @@
//! Structured audit log — persistent, machine-readable event journal.
//!
//! Events are serialized as JSON lines and appended to a file or SQL table.
//! Each event carries a correlation `trace_id` for cross-referencing with
//! RPC request traces.
use std::fs::OpenOptions;
use std::io::Write as IoWrite;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use serde::Serialize;
// ── Audit event types ─────────────────────────────────────────────────────
/// Action categories for the audit log.
///
/// Serialized in snake_case (e.g. `AuthLoginFailure` -> "auth_login_failure").
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum AuditAction {
    // Authentication.
    AuthRegister,
    AuthLoginSuccess,
    AuthLoginFailure,
    // Message queue operations.
    Enqueue,
    BatchEnqueue,
    Fetch,
    FetchWait,
    // Key material uploads.
    KeyUpload,
    HybridKeyUpload,
    // Moderation.
    BanUser,
    UnbanUser,
    ReportMessage,
    // Account and device lifecycle.
    AccountDelete,
    DeviceRegister,
    DeviceRevoke,
    // Blob storage.
    BlobUpload,
    // Account recovery data.
    RecoveryStore,
    RecoveryFetch,
    RecoveryDelete,
}
/// Outcome of an audited action.
///
/// Serialized in snake_case (e.g. `RateLimited` -> "rate_limited").
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum AuditOutcome {
    /// The action completed successfully.
    Success,
    /// The action was refused (e.g. failed authentication).
    Denied,
    /// The action failed with an error.
    Error,
    /// The action was rejected by rate limiting.
    RateLimited,
}
/// A single audit event record.
///
/// Serialized by `FileAuditLogger` as one JSON object per line (JSON Lines).
#[derive(Debug, Clone, Serialize)]
pub struct AuditEvent {
    /// ISO-8601 timestamp.
    pub timestamp: String,
    /// RPC correlation ID.
    pub trace_id: String,
    /// Hex-encoded actor identity key (truncated for privacy when redact=true).
    pub actor: String,
    /// The action performed.
    pub action: AuditAction,
    /// Target identifier (recipient key, username, etc.).
    /// Omitted from the JSON output entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target: Option<String>,
    /// Outcome of the action.
    pub outcome: AuditOutcome,
    /// Free-form details.
    /// Omitted from the JSON output entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<String>,
}
// ── Audit logger trait ────────────────────────────────────────────────────
/// Trait for audit log backends.
///
/// `Send + Sync` so a single logger can be shared across connection tasks
/// behind an `Arc<dyn AuditLogger>`.
pub trait AuditLogger: Send + Sync {
    /// Record a single audit event. The provided backends are best-effort:
    /// failures are reported via `tracing` rather than returned to the caller.
    fn log(&self, event: AuditEvent);
}
// ── File-backed implementation ───────────────────────────────────────────
/// Appends JSON-line events to a file.
pub struct FileAuditLogger {
    /// Path the log file was opened at (exposed via `path()`).
    path: PathBuf,
    /// Open append-mode handle, behind a mutex so concurrent `log` calls
    /// each write a complete line without interleaving.
    file: Mutex<std::fs::File>,
}
impl FileAuditLogger {
    /// Open (or create) the audit log file at `path`.
    ///
    /// The file is opened in append mode, so events from previous runs are
    /// preserved across restarts.
    pub fn open(path: &Path) -> Result<Self, std::io::Error> {
        let mut options = OpenOptions::new();
        options.create(true).append(true);
        let handle = options.open(path)?;
        Ok(Self {
            path: path.to_path_buf(),
            file: Mutex::new(handle),
        })
    }

    /// Return the path to the audit log file.
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }
}
impl AuditLogger for FileAuditLogger {
    fn log(&self, event: AuditEvent) {
        // Serialize before touching the lock; on failure nothing is written.
        let serialized = match serde_json::to_string(&event) {
            Ok(s) => s,
            Err(_) => {
                tracing::warn!("audit: failed to serialize event");
                return;
            }
        };
        let record = format!("{serialized}\n");
        match self.file.lock() {
            Ok(mut guard) => {
                if let Err(e) = guard.write_all(record.as_bytes()) {
                    tracing::warn!(error = %e, "audit: failed to write event");
                }
            }
            Err(_) => tracing::warn!("audit: log file lock poisoned"),
        }
    }
}
// ── No-op implementation ─────────────────────────────────────────────────
/// Does nothing. Used when audit logging is disabled.
pub struct NoopAuditLogger;

impl AuditLogger for NoopAuditLogger {
    /// Discard the event without recording it.
    fn log(&self, _event: AuditEvent) {}
}
// ── Helpers ──────────────────────────────────────────────────────────────
/// Format identity key bytes as hex, optionally truncated for privacy.
///
/// With `redact = true`, any key longer than 6 bytes (12 hex chars) is cut
/// down to its first 12 hex chars followed by a `...` marker; shorter keys
/// are emitted in full.
pub fn format_actor(identity_key: &[u8], redact: bool) -> String {
    let mut encoded = hex::encode(identity_key);
    if redact && encoded.len() > 12 {
        encoded.truncate(12);
        encoded.push_str("...");
    }
    encoded
}
/// Current ISO-8601 UTC timestamp, e.g. "2024-03-03T21:06:40Z".
///
/// Uses `SystemTime` plus a small civil-calendar conversion to avoid pulling
/// in chrono. A clock set before the Unix epoch makes `duration_since` fail;
/// we fall back to a zero duration (i.e. the epoch) rather than panic.
///
/// NOTE(review): this previously returned bare epoch seconds despite being
/// documented (here and on `AuditEvent::timestamp`) as ISO-8601; confirm no
/// external consumer parses the old numeric format.
pub fn now_iso8601() -> String {
    let d = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    format_iso8601(d.as_secs())
}

/// Format seconds-since-Unix-epoch as an ISO-8601 UTC timestamp.
fn format_iso8601(secs: u64) -> String {
    let days = (secs / 86_400) as i64;
    let rem = secs % 86_400;
    let (hour, min, sec) = (rem / 3_600, (rem % 3_600) / 60, rem % 60);
    // Civil-from-days (Howard Hinnant's algorithm): convert a day count
    // since 1970-01-01 into a proleptic-Gregorian (year, month, day).
    let z = days + 719_468;
    let era = z.div_euclid(146_097);
    let doe = z.rem_euclid(146_097);
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365;
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
    let mp = (5 * doy + 2) / 153;
    let day = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    let year = yoe + era * 400 + i64::from(month <= 2);
    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read;

    /// Events are appended as one JSON object per line, with snake_case enum
    /// names and `None` fields omitted, and parse back cleanly.
    #[test]
    fn file_audit_logger_writes_json_lines() {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("audit.jsonl");
        let logger = FileAuditLogger::open(&path).expect("open");
        logger.log(AuditEvent {
            timestamp: "1709500000".to_string(),
            trace_id: "test-trace-001".to_string(),
            actor: "abcdef123456".to_string(),
            action: AuditAction::Enqueue,
            target: Some("recipient-hex".to_string()),
            outcome: AuditOutcome::Success,
            details: None,
        });
        logger.log(AuditEvent {
            timestamp: "1709500001".to_string(),
            trace_id: "test-trace-002".to_string(),
            actor: "abcdef123456".to_string(),
            action: AuditAction::AuthLoginFailure,
            target: None,
            outcome: AuditOutcome::Denied,
            details: Some("bad password".to_string()),
        });
        drop(logger);
        let mut content = String::new();
        std::fs::File::open(&path)
            .expect("open for read")
            .read_to_string(&mut content)
            .expect("read");
        let lines: Vec<&str> = content.trim().split('\n').collect();
        assert_eq!(lines.len(), 2);
        // Verify JSON parses and the snake_case renaming applies.
        let v: serde_json::Value = serde_json::from_str(lines[0]).expect("parse line 0");
        assert_eq!(v["action"], "enqueue");
        assert_eq!(v["outcome"], "success");
        assert_eq!(v["trace_id"], "test-trace-001");
        let v: serde_json::Value = serde_json::from_str(lines[1]).expect("parse line 1");
        assert_eq!(v["action"], "auth_login_failure");
        assert_eq!(v["details"], "bad password");
    }

    /// Redaction keeps the first 12 hex chars and appends "...".
    #[test]
    fn format_actor_truncates_when_redacted() {
        let key = vec![0xAA; 32];
        let full = format_actor(&key, false);
        assert_eq!(full.len(), 64); // 32 bytes -> 64 hex chars
        let redacted = format_actor(&key, true);
        assert!(redacted.ends_with("..."));
        assert_eq!(redacted.len(), 15); // 12 hex chars + "..."
    }

    /// The no-op backend accepts events without side effects or panics.
    #[test]
    fn noop_logger_does_not_panic() {
        let logger = NoopAuditLogger;
        logger.log(AuditEvent {
            timestamp: "0".to_string(),
            trace_id: "noop".to_string(),
            actor: "none".to_string(),
            action: AuditAction::Fetch,
            target: None,
            outcome: AuditOutcome::Success,
            details: None,
        });
    }
}

View File

@@ -38,6 +38,12 @@ pub struct FileConfig {
pub redact_logs: Option<bool>, pub redact_logs: Option<bool>,
/// WebSocket JSON-RPC bridge listen address (e.g. "0.0.0.0:9000"). /// WebSocket JSON-RPC bridge listen address (e.g. "0.0.0.0:9000").
pub ws_listen: Option<String>, pub ws_listen: Option<String>,
/// Graceful shutdown drain timeout in seconds.
pub drain_timeout_secs: Option<u64>,
/// Default per-RPC timeout in seconds.
pub rpc_timeout_secs: Option<u64>,
/// Storage/database operation timeout in seconds.
pub storage_timeout_secs: Option<u64>,
} }
#[derive(Debug)] #[derive(Debug)]
@@ -64,8 +70,18 @@ pub struct EffectiveConfig {
pub redact_logs: bool, pub redact_logs: bool,
/// WebSocket JSON-RPC bridge listen address. If set, the bridge is started. /// WebSocket JSON-RPC bridge listen address. If set, the bridge is started.
pub ws_listen: Option<String>, pub ws_listen: Option<String>,
/// Graceful shutdown drain timeout in seconds.
pub drain_timeout_secs: u64,
/// Default per-RPC timeout in seconds.
pub rpc_timeout_secs: u64,
/// Storage/database operation timeout in seconds.
pub storage_timeout_secs: u64,
} }
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
pub const DEFAULT_RPC_TIMEOUT_SECS: u64 = 30;
pub const DEFAULT_STORAGE_TIMEOUT_SECS: u64 = 10;
#[derive(Debug, Default, Deserialize)] #[derive(Debug, Default, Deserialize)]
pub struct FederationFileConfig { pub struct FederationFileConfig {
pub enabled: Option<bool>, pub enabled: Option<bool>,
@@ -234,6 +250,22 @@ pub fn merge_config(args: &crate::Args, file: &FileConfig) -> EffectiveConfig {
.clone() .clone()
.or_else(|| file.ws_listen.clone()); .or_else(|| file.ws_listen.clone());
let drain_timeout_secs = if args.drain_timeout == DEFAULT_DRAIN_TIMEOUT_SECS {
file.drain_timeout_secs.unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS)
} else {
args.drain_timeout
};
let rpc_timeout_secs = if args.rpc_timeout == DEFAULT_RPC_TIMEOUT_SECS {
file.rpc_timeout_secs.unwrap_or(DEFAULT_RPC_TIMEOUT_SECS)
} else {
args.rpc_timeout
};
let storage_timeout_secs = if args.storage_timeout == DEFAULT_STORAGE_TIMEOUT_SECS {
file.storage_timeout_secs.unwrap_or(DEFAULT_STORAGE_TIMEOUT_SECS)
} else {
args.storage_timeout
};
EffectiveConfig { EffectiveConfig {
listen, listen,
data_dir, data_dir,
@@ -251,6 +283,9 @@ pub fn merge_config(args: &crate::Args, file: &FileConfig) -> EffectiveConfig {
plugin_dir, plugin_dir,
redact_logs, redact_logs,
ws_listen, ws_listen,
drain_timeout_secs,
rpc_timeout_secs,
storage_timeout_secs,
} }
} }

View File

@@ -15,6 +15,7 @@ use rand::rngs::OsRng;
use tokio::sync::Notify; use tokio::sync::Notify;
use tokio::task::LocalSet; use tokio::task::LocalSet;
pub mod audit;
mod auth; mod auth;
mod config; mod config;
pub mod domain; pub mod domain;
@@ -126,6 +127,19 @@ struct Args {
/// WebSocket JSON-RPC bridge listen address (e.g. 0.0.0.0:9000). Enables browser connectivity. /// WebSocket JSON-RPC bridge listen address (e.g. 0.0.0.0:9000). Enables browser connectivity.
#[arg(long, env = "QPQ_WS_LISTEN")] #[arg(long, env = "QPQ_WS_LISTEN")]
ws_listen: Option<String>, ws_listen: Option<String>,
/// Graceful shutdown drain timeout in seconds (default: 30). In-flight RPCs get this
/// long to finish after a shutdown signal before connections are forcefully closed.
#[arg(long, env = "QPQ_DRAIN_TIMEOUT", default_value_t = config::DEFAULT_DRAIN_TIMEOUT_SECS)]
drain_timeout: u64,
/// Default per-RPC timeout in seconds (default: 30). Individual methods may override.
#[arg(long, env = "QPQ_RPC_TIMEOUT", default_value_t = config::DEFAULT_RPC_TIMEOUT_SECS)]
rpc_timeout: u64,
/// Storage/database operation timeout in seconds (default: 10).
#[arg(long, env = "QPQ_STORAGE_TIMEOUT", default_value_t = config::DEFAULT_STORAGE_TIMEOUT_SECS)]
storage_timeout: u64,
} }
// ── Entry point ─────────────────────────────────────────────────────────────── // ── Entry point ───────────────────────────────────────────────────────────────
@@ -665,8 +679,9 @@ async fn main() -> anyhow::Result<()> {
}); });
} }
_ = tokio::signal::ctrl_c() => { _ = shutdown_signal() => {
tracing::info!("shutdown signal received, draining QUIC connections"); tracing::info!("shutdown signal received, draining QUIC connections");
// Stop accepting new connections immediately.
endpoint.close(0u32.into(), b"server shutdown"); endpoint.close(0u32.into(), b"server shutdown");
break; break;
} }
@@ -674,8 +689,9 @@ async fn main() -> anyhow::Result<()> {
} }
// Grace period: let in-flight RPC tasks on the LocalSet finish. // Grace period: let in-flight RPC tasks on the LocalSet finish.
tracing::info!("waiting up to 5s for in-flight RPCs to complete"); let drain_secs = effective.drain_timeout_secs;
tokio::time::sleep(std::time::Duration::from_secs(5)).await; tracing::info!(drain_timeout_secs = drain_secs, "waiting for in-flight RPCs to complete");
tokio::time::sleep(std::time::Duration::from_secs(drain_secs)).await;
Ok::<(), anyhow::Error>(()) Ok::<(), anyhow::Error>(())
}) })
@@ -683,3 +699,28 @@ async fn main() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
/// Resolve when a shutdown signal arrives: SIGINT (Ctrl-C) on all platforms,
/// plus SIGTERM on Unix.
///
/// Load balancers typically send SIGTERM during rolling deploys. The server
/// should stop accepting new connections, return "draining" from the health
/// endpoint, and wait for in-flight RPCs to finish (up to the drain timeout).
async fn shutdown_signal() {
    let interrupt = tokio::signal::ctrl_c();

    #[cfg(unix)]
    {
        use tokio::signal::unix::{signal, SignalKind};
        let mut terminate =
            signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
        // Whichever fires first wins; the other branch is dropped.
        tokio::select! {
            _ = terminate.recv() => {}
            _ = interrupt => {}
        }
    }

    #[cfg(not(unix))]
    {
        // ctrl_c errors are ignored: there is nothing useful to do but proceed.
        let _ = interrupt.await;
    }
}

View File

@@ -47,3 +47,18 @@ pub fn record_auth_login_failure_total() {
pub fn record_rate_limit_hit_total() { pub fn record_rate_limit_hit_total() {
metrics::counter!("rate_limit_hit_total").increment(1); metrics::counter!("rate_limit_hit_total").increment(1);
} }
// ── Storage operation latency ───────────────────────────────────────────────

/// Record storage operation latency. Called by instrumented Store wrappers.
///
/// `operation` becomes the `op` label on the histogram; it is `&'static str`
/// precisely to keep the label set small and fixed.
pub fn record_storage_latency(operation: &'static str, duration: std::time::Duration) {
    metrics::histogram!("storage_operation_duration_seconds", "op" => operation)
        .record(duration.as_secs_f64());
}
// ── Server info ────────────────────────────────────────────────────────────

/// Record the server uptime in seconds (set periodically).
///
/// Exposed as a gauge; the caller is responsible for refreshing it on a
/// timer, since a gauge only reflects the last `set` value.
pub fn record_uptime_seconds(secs: f64) {
    metrics::gauge!("server_uptime_seconds").set(secs);
}

View File

@@ -1,6 +1,7 @@
//! v2 RPC handler dispatch — protobuf in, domain logic, protobuf out. //! v2 RPC handler dispatch — protobuf in, domain logic, protobuf out.
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
@@ -11,6 +12,7 @@ use quicproquo_rpc::error::RpcStatus;
use quicproquo_rpc::method::{HandlerResult, MethodRegistry, RequestContext}; use quicproquo_rpc::method::{HandlerResult, MethodRegistry, RequestContext};
use tokio::sync::Notify; use tokio::sync::Notify;
use crate::audit::AuditLogger;
use crate::auth::{AuthConfig, PendingLogin, RateEntry, SessionInfo}; use crate::auth::{AuthConfig, PendingLogin, RateEntry, SessionInfo};
use crate::hooks::ServerHooks; use crate::hooks::ServerHooks;
use crate::storage::Store; use crate::storage::Store;
@@ -44,6 +46,11 @@ pub struct ServerState {
pub kt_log: Arc<std::sync::Mutex<quicproquo_kt::MerkleLog>>, pub kt_log: Arc<std::sync::Mutex<quicproquo_kt::MerkleLog>>,
pub data_dir: PathBuf, pub data_dir: PathBuf,
pub redact_logs: bool, pub redact_logs: bool,
/// Structured audit logger for security-relevant events.
pub audit_logger: Arc<dyn AuditLogger>,
/// When true, the server is draining and will reject new work.
/// Health endpoint returns "draining" status so load balancers stop routing.
pub draining: Arc<AtomicBool>,
/// Idempotency dedup: message_id -> (seq, timestamp). TTL-cleaned by cleanup task. /// Idempotency dedup: message_id -> (seq, timestamp). TTL-cleaned by cleanup task.
pub seen_message_ids: Arc<DashMap<Vec<u8>, (u64, u64)>>, pub seen_message_ids: Arc<DashMap<Vec<u8>, (u64, u64)>>,
/// Banned users: identity_key -> BanRecord. /// Banned users: identity_key -> BanRecord.
@@ -154,9 +161,13 @@ pub fn domain_err(e: crate::domain::types::DomainError) -> HandlerResult {
} }
} }
/// Build the v2 method registry with all 33 handlers registered. /// Build the v2 method registry with all handlers registered.
pub fn build_registry() -> MethodRegistry<ServerState> { ///
/// `default_rpc_timeout` sets the server-wide per-RPC timeout. Individual methods
/// (e.g. blob upload, health) may override this with shorter or longer values.
pub fn build_registry(default_rpc_timeout: std::time::Duration) -> MethodRegistry<ServerState> {
let mut reg = MethodRegistry::new(); let mut reg = MethodRegistry::new();
reg.set_default_timeout(default_rpc_timeout);
// Auth (100-103) // Auth (100-103)
reg.register( reg.register(
@@ -264,15 +275,17 @@ pub fn build_registry() -> MethodRegistry<ServerState> {
user::handle_resolve_identity, user::handle_resolve_identity,
); );
// Blob (600-601) // Blob (600-601) — longer timeout for file transfers.
reg.register( reg.register_with_timeout(
method_ids::UPLOAD_BLOB, method_ids::UPLOAD_BLOB,
"UploadBlob", "UploadBlob",
std::time::Duration::from_secs(120),
blob::handle_upload_blob, blob::handle_upload_blob,
); );
reg.register( reg.register_with_timeout(
method_ids::DOWNLOAD_BLOB, method_ids::DOWNLOAD_BLOB,
"DownloadBlob", "DownloadBlob",
std::time::Duration::from_secs(120),
blob::handle_download_blob, blob::handle_download_blob,
); );
@@ -304,7 +317,12 @@ pub fn build_registry() -> MethodRegistry<ServerState> {
"ResolveEndpoint", "ResolveEndpoint",
p2p::handle_resolve_endpoint, p2p::handle_resolve_endpoint,
); );
reg.register(method_ids::HEALTH, "Health", p2p::handle_health); reg.register_with_timeout(
method_ids::HEALTH,
"Health",
std::time::Duration::from_secs(5),
p2p::handle_health,
);
// Federation (900-905) // Federation (900-905)
reg.register( reg.register(

View File

@@ -98,11 +98,16 @@ pub async fn handle_resolve_endpoint(
} }
pub async fn handle_health( pub async fn handle_health(
_state: Arc<ServerState>, state: Arc<ServerState>,
_ctx: RequestContext, _ctx: RequestContext,
) -> HandlerResult { ) -> HandlerResult {
let status = if state.draining.load(std::sync::atomic::Ordering::Relaxed) {
"draining"
} else {
"ok"
};
let resp = v1::HealthResponse { let resp = v1::HealthResponse {
status: "ok".into(), status: status.into(),
}; };
HandlerResult::ok(Bytes::from(resp.encode_to_vec())) HandlerResult::ok(Bytes::from(resp.encode_to_vec()))
} }