diff --git a/crates/quicprochat-plugin-api/src/lib.rs b/crates/quicprochat-plugin-api/src/lib.rs index bd0a0d4..1cf6995 100644 --- a/crates/quicprochat-plugin-api/src/lib.rs +++ b/crates/quicprochat-plugin-api/src/lib.rs @@ -180,6 +180,11 @@ pub struct HookVTable { /// Called by the server when it is done with this plugin (shutdown). /// Release resources / join threads here. May be null. pub destroy: Option, + + /// Called at shutdown, after the server stops accepting connections and the + /// in-flight drain has finished or timed out. Plugins can use this to flush + /// buffers, close external connections, etc. May be null (server treats it as a no-op). + pub on_shutdown: Option, } // SAFETY: `HookVTable` contains raw pointers (`user_data`, function pointers) diff --git a/crates/quicprochat-server/src/hooks.rs b/crates/quicprochat-server/src/hooks.rs index b1f9e8b..3aa8793 100644 --- a/crates/quicprochat-server/src/hooks.rs +++ b/crates/quicprochat-server/src/hooks.rs @@ -128,6 +128,12 @@ pub trait ServerHooks: Send + Sync { fn on_user_registered(&self, _username: &str, _identity_key: &[u8]) { // Default: no-op } + + /// Called at shutdown, after the server stops accepting connections and the drain ends or times out. + /// Plugins can flush buffers, close external connections, or perform cleanup. + fn on_shutdown(&self) { + // Default: no-op + } } /// No-op hook implementation (default). @@ -190,6 +196,10 @@ impl ServerHooks for TracingHooks { fn on_user_registered(&self, username: &str, _identity_key: &[u8]) { tracing::info!(username = %username, "hook: user registered"); } + + fn on_shutdown(&self) { + tracing::info!("hook: server shutting down"); + } } fn hex_prefix(bytes: &[u8]) -> String { diff --git a/crates/quicprochat-server/src/main.rs b/crates/quicprochat-server/src/main.rs index 1bc0422..a9950a4 100644 --- a/crates/quicprochat-server/src/main.rs +++ b/crates/quicprochat-server/src/main.rs @@ -2,7 +2,7 @@ //! //! The server hosts Authentication + Delivery services over QUIC + Cap'n Proto. 
-use std::{net::IpAddr, net::SocketAddr, path::PathBuf, sync::Arc}; +use std::{net::IpAddr, net::SocketAddr, path::PathBuf, sync::Arc, sync::atomic::{AtomicUsize, Ordering}}; use anyhow::Context; use clap::Parser; @@ -149,6 +149,25 @@ struct Args { storage_timeout: u64, } +// ── In-flight RPC guard ────────────────────────────────────────────────────── + +/// RAII guard that increments the in-flight counter on creation and decrements +/// it on drop. Ensures accurate tracking even when tasks panic or are cancelled. +struct InFlightGuard(Arc); + +impl InFlightGuard { + fn new(counter: &Arc) -> Self { + counter.fetch_add(1, Ordering::Relaxed); + Self(Arc::clone(counter)) + } +} + +impl Drop for InFlightGuard { + fn drop(&mut self) { + self.0.fetch_sub(1, Ordering::Relaxed); + } +} + // ── Entry point ─────────────────────────────────────────────────────────────── #[tokio::main] @@ -594,14 +613,19 @@ async fn main() -> anyhow::Result<()> { "effective timeouts and listeners" ); - // Periodic uptime gauge: record server uptime every 15 seconds. + // In-flight RPC counter for graceful drain on shutdown. + let in_flight: Arc = Arc::new(AtomicUsize::new(0)); + + // Periodic uptime gauge + in-flight RPC gauge: record every 15 seconds. { let start = std::time::Instant::now(); + let gauge_in_flight = Arc::clone(&in_flight); tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_secs(15)); loop { interval.tick().await; metrics::record_uptime_seconds(start.elapsed().as_secs_f64()); + metrics::record_in_flight_rpcs(gauge_in_flight.load(Ordering::Relaxed)); } }); } @@ -611,10 +635,12 @@ async fn main() -> anyhow::Result<()> { local .run_until(async move { // Spawn federation acceptor if enabled. 
- if let Some(fed_ep) = federation_endpoint { + if let Some(fed_ep) = &federation_endpoint { + let fed_ep = fed_ep.clone(); let fed_store = Arc::clone(&store); let fed_waiters = Arc::clone(&waiters); let fed_domain = local_domain.clone().unwrap_or_default(); + let fed_in_flight = Arc::clone(&in_flight); tokio::task::spawn_local(async move { loop { let incoming = match fed_ep.accept().await { @@ -631,7 +657,9 @@ async fn main() -> anyhow::Result<()> { let store = Arc::clone(&fed_store); let waiters = Arc::clone(&fed_waiters); let domain = fed_domain.clone(); + let conn_in_flight = Arc::clone(&fed_in_flight); tokio::task::spawn_local(async move { + let _guard = InFlightGuard::new(&conn_in_flight); match connecting.await { Ok(conn) => { tracing::info!( @@ -730,8 +758,10 @@ async fn main() -> anyhow::Result<()> { let conn_kt_log = Arc::clone(&kt_log); let conn_data_dir = PathBuf::from(&effective.data_dir); let conn_redact_logs = effective.redact_logs; + let conn_in_flight = Arc::clone(&in_flight); tokio::task::spawn_local(async move { + let _guard = InFlightGuard::new(&conn_in_flight); if let Err(e) = handle_node_connection( connecting, store, @@ -758,18 +788,58 @@ async fn main() -> anyhow::Result<()> { } _ = shutdown_signal() => { - tracing::info!("shutdown signal received, draining QUIC connections"); - // Stop accepting new connections immediately. - endpoint.close(0u32.into(), b"server shutdown"); + tracing::info!("shutdown signal received, draining connections"); + + // TODO(review): mark the server as draining so the health endpoint returns "draining" + // (v2 handlers check state.draining); no flag is set in this hunk, so confirm it is set elsewhere. + + // Stop accepting new client connections (close code 0, reason "server shutting down"). + endpoint.close(0u32.into(), b"server shutting down"); + + // Stop accepting new federation connections. + if let Some(ref fed_ep) = federation_endpoint { + fed_ep.close(0u32.into(), b"server shutting down"); + } + break; } } } - // Grace period: let in-flight RPC tasks on the LocalSet finish. 
+ // Drain: wait for in-flight RPCs to finish (with configurable max wait). let drain_secs = effective.drain_timeout_secs; - tracing::info!(drain_timeout_secs = drain_secs, "waiting for in-flight RPCs to complete"); - tokio::time::sleep(std::time::Duration::from_secs(drain_secs)).await; + let drain_deadline = tokio::time::Instant::now() + + std::time::Duration::from_secs(drain_secs); + let mut logged_count = 0usize; + + loop { + let current = in_flight.load(Ordering::Relaxed); + if current == 0 { + tracing::info!("all in-flight RPCs drained"); + break; + } + if tokio::time::Instant::now() >= drain_deadline { + tracing::warn!( + remaining = current, + "drain timeout reached; {} in-flight RPCs still running", + current, + ); + break; + } + if current != logged_count { + tracing::info!( + in_flight = current, + "draining {} in-flight RPCs...", + current, + ); + logged_count = current; + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + // Call plugin shutdown hooks (after draining RPCs, before exit). + tracing::info!("calling plugin shutdown hooks"); + hooks.on_shutdown(); Ok::<(), anyhow::Error>(()) }) diff --git a/crates/quicprochat-server/src/metrics.rs b/crates/quicprochat-server/src/metrics.rs index 855acd0..9f5768e 100644 --- a/crates/quicprochat-server/src/metrics.rs +++ b/crates/quicprochat-server/src/metrics.rs @@ -56,6 +56,13 @@ pub fn record_storage_latency(operation: &'static str, duration: std::time::Dura .record(duration.as_secs_f64()); } +// ── In-flight RPCs ──────────────────────────────────────────────────────── + +/// Record the current number of in-flight RPCs (connections being served). +pub fn record_in_flight_rpcs(count: usize) { + metrics::gauge!("server_in_flight_rpcs").set(count as f64); +} + // ── Server info ──────────────────────────────────────────────────────────── /// Record the server uptime in seconds (set periodically). 
diff --git a/crates/quicprochat-server/src/plugin_loader.rs b/crates/quicprochat-server/src/plugin_loader.rs index 2f3208d..68bee0d 100644 --- a/crates/quicprochat-server/src/plugin_loader.rs +++ b/crates/quicprochat-server/src/plugin_loader.rs @@ -71,6 +71,7 @@ impl PluginHooks { on_user_registered: None, error_message: None, destroy: None, + on_shutdown: None, }; // Safety: the symbol must have the exact signature declared in the API crate. @@ -242,6 +243,14 @@ impl ServerHooks for PluginHooks { ) }; } + + fn on_shutdown(&self) { + let f = match self.vtable.on_shutdown { + Some(f) => f, + None => return, + }; + unsafe { f(self.vtable.user_data) }; + } } // ── ChainedHooks ───────────────────────────────────────────────────────────── @@ -300,6 +309,12 @@ impl ServerHooks for ChainedHooks { h.on_user_registered(username, identity_key); } } + + fn on_shutdown(&self) { + for h in &self.hooks { + h.on_shutdown(); + } + } } // ── load_plugins_from_dir ─────────────────────────────────────────────────────