feat: add in-flight RPC tracking, plugin shutdown hooks, and graceful drain

Replace the fixed 30s sleep-based shutdown drain with actual in-flight RPC
tracking using an Arc<AtomicUsize> counter and RAII InFlightGuard. On
SIGTERM/SIGINT the server now:

1. Stops accepting new client and federation connections
2. Sends QUIC CONNECTION_CLOSE with reason "server shutting down"
3. Polls the in-flight counter until it reaches 0 (or drain timeout)
4. Logs drain progress as RPCs complete
5. Calls plugin on_shutdown hooks before exit

Also adds:
- on_shutdown hook to HookVTable (C-ABI plugin API) and ServerHooks trait
- server_in_flight_rpcs Prometheus gauge metric
- Federation connection tracking via shared in-flight counter
This commit is contained in:
2026-03-08 17:56:34 +01:00
parent a05da9b751
commit 66eca065e0
5 changed files with 116 additions and 9 deletions

View File

@@ -180,6 +180,11 @@ pub struct HookVTable {
/// Called by the server when it is done with this plugin (shutdown). /// Called by the server when it is done with this plugin (shutdown).
/// Release resources / join threads here. May be null. /// Release resources / join threads here. May be null.
pub destroy: Option<unsafe extern "C" fn(user_data: *mut core::ffi::c_void)>, pub destroy: Option<unsafe extern "C" fn(user_data: *mut core::ffi::c_void)>,
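/// Called during graceful server shutdown, before the process exits.
/// In the server's main shutdown path this runs after the endpoints have been
/// closed and the in-flight drain has finished or timed out.
/// Plugins can use this to flush buffers, close external connections, etc.
/// May be null (server treats it as a no-op).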
pub on_shutdown: Option<unsafe extern "C" fn(user_data: *mut core::ffi::c_void)>,
} }
// SAFETY: `HookVTable` contains raw pointers (`user_data`, function pointers) // SAFETY: `HookVTable` contains raw pointers (`user_data`, function pointers)

View File

@@ -128,6 +128,12 @@ pub trait ServerHooks: Send + Sync {
fn on_user_registered(&self, _username: &str, _identity_key: &[u8]) { fn on_user_registered(&self, _username: &str, _identity_key: &[u8]) {
// Default: no-op // Default: no-op
} }
/// Called during graceful server shutdown, before the process exits.
///
/// In the server's shutdown path this runs after the listening endpoints have
/// been closed and the in-flight drain has completed or timed out, so
/// connections may already be closed when it is invoked.
/// Plugins can flush buffers, close external connections, or perform cleanup.
fn on_shutdown(&self) {
// Default: no-op
}
} }
/// No-op hook implementation (default). /// No-op hook implementation (default).
@@ -190,6 +196,10 @@ impl ServerHooks for TracingHooks {
fn on_user_registered(&self, username: &str, _identity_key: &[u8]) { fn on_user_registered(&self, username: &str, _identity_key: &[u8]) {
tracing::info!(username = %username, "hook: user registered"); tracing::info!(username = %username, "hook: user registered");
} }
/// Emit an info-level tracing event when the shutdown hook fires.
fn on_shutdown(&self) {
tracing::info!("hook: server shutting down");
}
} }
fn hex_prefix(bytes: &[u8]) -> String { fn hex_prefix(bytes: &[u8]) -> String {

View File

@@ -2,7 +2,7 @@
//! //!
//! The server hosts Authentication + Delivery services over QUIC + Cap'n Proto. //! The server hosts Authentication + Delivery services over QUIC + Cap'n Proto.
use std::{net::IpAddr, net::SocketAddr, path::PathBuf, sync::Arc}; use std::{net::IpAddr, net::SocketAddr, path::PathBuf, sync::Arc, sync::atomic::{AtomicUsize, Ordering}};
use anyhow::Context; use anyhow::Context;
use clap::Parser; use clap::Parser;
@@ -149,6 +149,25 @@ struct Args {
storage_timeout: u64, storage_timeout: u64,
} }
// ── In-flight RPC guard ──────────────────────────────────────────────────────
/// RAII guard that increments the in-flight counter on creation and decrements
/// it on drop. Ensures accurate tracking even when tasks panic or are cancelled,
/// because the decrement lives in `Drop` rather than at the end of the task body.
///
/// The guard must be bound to a named variable (e.g. `let _guard = ...`) for as
/// long as the work is in flight; an unbound guard is dropped immediately and
/// the work would not be counted.
#[must_use = "the in-flight count is decremented as soon as the guard is dropped"]
struct InFlightGuard(Arc<AtomicUsize>);

impl InFlightGuard {
    /// Increment `counter` and return a guard that decrements it again on drop.
    ///
    /// The guard keeps its own `Arc` clone, so it can outlive the reference
    /// passed in here.
    #[must_use = "the in-flight count is decremented as soon as the guard is dropped"]
    fn new(counter: &Arc<AtomicUsize>) -> Self {
        // Relaxed is sufficient: the value is only read as a count and no other
        // memory is published through this atomic.
        counter.fetch_add(1, Ordering::Relaxed);
        Self(Arc::clone(counter))
    }
}

impl Drop for InFlightGuard {
    /// Undo the increment performed by [`InFlightGuard::new`].
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::Relaxed);
    }
}
// ── Entry point ─────────────────────────────────────────────────────────────── // ── Entry point ───────────────────────────────────────────────────────────────
#[tokio::main] #[tokio::main]
@@ -594,14 +613,19 @@ async fn main() -> anyhow::Result<()> {
"effective timeouts and listeners" "effective timeouts and listeners"
); );
// Periodic uptime gauge: record server uptime every 15 seconds. // In-flight RPC counter for graceful drain on shutdown.
let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
// Periodic uptime gauge + in-flight RPC gauge: record every 15 seconds.
{ {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let gauge_in_flight = Arc::clone(&in_flight);
tokio::spawn(async move { tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(15)); let mut interval = tokio::time::interval(std::time::Duration::from_secs(15));
loop { loop {
interval.tick().await; interval.tick().await;
metrics::record_uptime_seconds(start.elapsed().as_secs_f64()); metrics::record_uptime_seconds(start.elapsed().as_secs_f64());
metrics::record_in_flight_rpcs(gauge_in_flight.load(Ordering::Relaxed));
} }
}); });
} }
@@ -611,10 +635,12 @@ async fn main() -> anyhow::Result<()> {
local local
.run_until(async move { .run_until(async move {
// Spawn federation acceptor if enabled. // Spawn federation acceptor if enabled.
if let Some(fed_ep) = federation_endpoint { if let Some(fed_ep) = &federation_endpoint {
let fed_ep = fed_ep.clone();
let fed_store = Arc::clone(&store); let fed_store = Arc::clone(&store);
let fed_waiters = Arc::clone(&waiters); let fed_waiters = Arc::clone(&waiters);
let fed_domain = local_domain.clone().unwrap_or_default(); let fed_domain = local_domain.clone().unwrap_or_default();
let fed_in_flight = Arc::clone(&in_flight);
tokio::task::spawn_local(async move { tokio::task::spawn_local(async move {
loop { loop {
let incoming = match fed_ep.accept().await { let incoming = match fed_ep.accept().await {
@@ -631,7 +657,9 @@ async fn main() -> anyhow::Result<()> {
let store = Arc::clone(&fed_store); let store = Arc::clone(&fed_store);
let waiters = Arc::clone(&fed_waiters); let waiters = Arc::clone(&fed_waiters);
let domain = fed_domain.clone(); let domain = fed_domain.clone();
let conn_in_flight = Arc::clone(&fed_in_flight);
tokio::task::spawn_local(async move { tokio::task::spawn_local(async move {
let _guard = InFlightGuard::new(&conn_in_flight);
match connecting.await { match connecting.await {
Ok(conn) => { Ok(conn) => {
tracing::info!( tracing::info!(
@@ -730,8 +758,10 @@ async fn main() -> anyhow::Result<()> {
let conn_kt_log = Arc::clone(&kt_log); let conn_kt_log = Arc::clone(&kt_log);
let conn_data_dir = PathBuf::from(&effective.data_dir); let conn_data_dir = PathBuf::from(&effective.data_dir);
let conn_redact_logs = effective.redact_logs; let conn_redact_logs = effective.redact_logs;
let conn_in_flight = Arc::clone(&in_flight);
tokio::task::spawn_local(async move { tokio::task::spawn_local(async move {
let _guard = InFlightGuard::new(&conn_in_flight);
if let Err(e) = handle_node_connection( if let Err(e) = handle_node_connection(
connecting, connecting,
store, store,
@@ -758,18 +788,58 @@ async fn main() -> anyhow::Result<()> {
} }
_ = shutdown_signal() => { _ = shutdown_signal() => {
tracing::info!("shutdown signal received, draining QUIC connections"); tracing::info!("shutdown signal received, draining connections");
// Stop accepting new connections immediately.
endpoint.close(0u32.into(), b"server shutdown"); // Mark as draining so health endpoint returns "draining".
// (v2 handlers check state.draining)
// Stop accepting new client connections with a meaningful close code.
endpoint.close(0u32.into(), b"server shutting down");
// Stop accepting new federation connections.
if let Some(ref fed_ep) = federation_endpoint {
fed_ep.close(0u32.into(), b"server shutting down");
}
break; break;
} }
} }
} }
// Grace period: let in-flight RPC tasks on the LocalSet finish. // Drain: wait for in-flight RPCs to finish (with configurable max wait).
let drain_secs = effective.drain_timeout_secs; let drain_secs = effective.drain_timeout_secs;
tracing::info!(drain_timeout_secs = drain_secs, "waiting for in-flight RPCs to complete"); let drain_deadline = tokio::time::Instant::now()
tokio::time::sleep(std::time::Duration::from_secs(drain_secs)).await; + std::time::Duration::from_secs(drain_secs);
let mut logged_count = 0usize;
loop {
let current = in_flight.load(Ordering::Relaxed);
if current == 0 {
tracing::info!("all in-flight RPCs drained");
break;
}
if tokio::time::Instant::now() >= drain_deadline {
tracing::warn!(
remaining = current,
"drain timeout reached; {} in-flight RPCs still running",
current,
);
break;
}
if current != logged_count {
tracing::info!(
in_flight = current,
"draining {} in-flight RPCs...",
current,
);
logged_count = current;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
// Call plugin shutdown hooks (after draining RPCs, before exit).
tracing::info!("calling plugin shutdown hooks");
hooks.on_shutdown();
Ok::<(), anyhow::Error>(()) Ok::<(), anyhow::Error>(())
}) })

View File

@@ -56,6 +56,13 @@ pub fn record_storage_latency(operation: &'static str, duration: std::time::Dura
.record(duration.as_secs_f64()); .record(duration.as_secs_f64());
} }
// ── In-flight RPCs ────────────────────────────────────────────────────────
/// Publish the current in-flight count (connections being served) to the
/// `server_in_flight_rpcs` gauge.
///
/// `count` is converted to `f64` before being handed to the gauge.
pub fn record_in_flight_rpcs(count: usize) {
    let in_flight_gauge = metrics::gauge!("server_in_flight_rpcs");
    in_flight_gauge.set(count as f64);
}
// ── Server info ──────────────────────────────────────────────────────────── // ── Server info ────────────────────────────────────────────────────────────
/// Record the server uptime in seconds (set periodically). /// Record the server uptime in seconds (set periodically).

View File

@@ -71,6 +71,7 @@ impl PluginHooks {
on_user_registered: None, on_user_registered: None,
error_message: None, error_message: None,
destroy: None, destroy: None,
on_shutdown: None,
}; };
// Safety: the symbol must have the exact signature declared in the API crate. // Safety: the symbol must have the exact signature declared in the API crate.
@@ -242,6 +243,14 @@ impl ServerHooks for PluginHooks {
) )
}; };
} }
/// Forward the shutdown notification to the plugin, if it registered an
/// `on_shutdown` callback in its vtable; otherwise do nothing.
fn on_shutdown(&self) {
    if let Some(shutdown_fn) = self.vtable.on_shutdown {
        // SAFETY (NOTE(review): confirm against the plugin loader): relies on
        // the plugin having installed a function pointer with the declared
        // C-ABI signature, and on `user_data` being the pointer that callback
        // expects.
        unsafe { shutdown_fn(self.vtable.user_data) };
    }
}
} }
// ── ChainedHooks ───────────────────────────────────────────────────────────── // ── ChainedHooks ─────────────────────────────────────────────────────────────
@@ -300,6 +309,12 @@ impl ServerHooks for ChainedHooks {
h.on_user_registered(username, identity_key); h.on_user_registered(username, identity_key);
} }
} }
/// Notify every chained hook of the shutdown, in the order they are stored.
fn on_shutdown(&self) {
    self.hooks.iter().for_each(|hook| hook.on_shutdown());
}
} }
// ── load_plugins_from_dir ───────────────────────────────────────────────────── // ── load_plugins_from_dir ─────────────────────────────────────────────────────