Files
quicproquo/crates/quicprochat-server/src/hooks.rs
Christian Nennemann 66eca065e0 feat: add in-flight RPC tracking, plugin shutdown hooks, and graceful drain
Replace the fixed 30s sleep-based shutdown drain with actual in-flight RPC
tracking using an Arc<AtomicUsize> counter and RAII InFlightGuard. On
SIGTERM/SIGINT the server now:

1. Stops accepting new client and federation connections
2. Sends QUIC CONNECTION_CLOSE with reason "server shutting down"
3. Polls the in-flight counter until it reaches 0 (or drain timeout)
4. Logs drain progress as RPCs complete
5. Calls plugin on_shutdown hooks before exit

Also adds:
- on_shutdown hook to HookVTable (C-ABI plugin API) and ServerHooks trait
- server_in_flight_rpcs Prometheus gauge metric
- Federation connection tracking via shared in-flight counter
2026-03-21 19:14:06 +01:00

209 lines
6.4 KiB
Rust

//! Server-side plugin hooks for extending quicprochat.
//!
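//! Implement the [`ServerHooks`] trait to intercept server events — message delivery,
//! authentication, channel creation, and more. [`ServerHooks::on_message_enqueue`] fires
//! after validation but before storage and can reject the message by returning
//! [`HookAction::Reject`]; the other hooks are notifications that can inspect or log an
//! event but cannot veto it.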
//!
//! # Built-in implementations
//!
//! - [`NoopHooks`] — does nothing (default when no hooks are configured)
//! - [`TracingHooks`] — logs events via `tracing` (info for most events, warn for failed
//!   logins, debug for non-empty fetches)
//!
//! # Writing a custom hook
//!
//! ```rust,ignore
//! use quicprochat_server::hooks::{ServerHooks, HookAction, MessageEvent};
//!
//! struct ModeratorHook {
//!     banned_words: Vec<String>,
//! }
//!
//! impl ServerHooks for ModeratorHook {
//!     fn on_message_enqueue(&self, event: &MessageEvent) -> HookAction {
//!         // Can't inspect encrypted content (E2E), but can enforce rate limits,
//!         // payload size limits, or sender restrictions.
//!         if event.payload_len > 1_000_000 {
//!             return HookAction::Reject("payload too large".into());
//!         }
//!         HookAction::Continue
//!     }
//! }
//! ```
/// The result of a hook invocation that is allowed to veto an operation.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare outcomes directly
/// (e.g. `action == HookAction::Continue`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum HookAction {
    /// Allow the operation to proceed.
    Continue,
    /// Reject the operation with a reason (returned to the client as an error).
    Reject(String),
}
/// Event data for message enqueue operations.
///
/// Passed to `ServerHooks::on_message_enqueue` and (as a slice) to
/// `ServerHooks::on_batch_enqueue`. Payload bytes are not included — only their length.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MessageEvent {
    /// Sender's identity key (32 bytes), if known (None in sealed sender mode).
    pub sender_identity: Option<Vec<u8>>,
    /// Recipient's identity key (32 bytes).
    pub recipient_key: Vec<u8>,
    /// Channel ID (16 bytes) when this is a DM channel message.
    ///
    /// NOTE(review): this is not an `Option`, so it is presumably empty for messages
    /// outside a DM channel — confirm against the callers that build this event.
    pub channel_id: Vec<u8>,
    /// Length of the encrypted payload in bytes.
    pub payload_len: usize,
    /// Server-assigned sequence number.
    pub seq: u64,
}
/// Event data for authentication operations, passed to `ServerHooks::on_auth`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AuthEvent {
    /// The username attempting to authenticate.
    pub username: String,
    /// Whether the authentication succeeded.
    pub success: bool,
    /// Failure reason; empty when `success` is true.
    pub failure_reason: String,
}
/// Event data for channel creation operations, passed to `ServerHooks::on_channel_created`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChannelEvent {
    /// The channel's unique ID (16 bytes).
    pub channel_id: Vec<u8>,
    /// Identity key of the initiator.
    pub initiator_key: Vec<u8>,
    /// Identity key of the peer.
    pub peer_key: Vec<u8>,
    /// True if this is a newly created channel (initiator creates the MLS group);
    /// false if an existing channel was looked up.
    pub was_new: bool,
}
/// Event data for message fetch operations, passed to `ServerHooks::on_fetch`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FetchEvent {
    /// Identity key of the fetcher.
    pub recipient_key: Vec<u8>,
    /// Channel ID being fetched from.
    pub channel_id: Vec<u8>,
    /// Number of messages returned (may be zero).
    pub message_count: usize,
}
/// Trait for server-side plugin hooks.
///
/// Every method has a default implementation, so you only need to override the events
/// you care about. [`on_message_enqueue`](Self::on_message_enqueue) is the only hook that
/// can veto an operation; its default returns [`HookAction::Continue`]. All other methods
/// are notifications that return `()` and default to doing nothing.
///
/// Hooks are called synchronously in the RPC handler path. Keep them fast —
/// offload heavy work (HTTP calls, disk I/O) to background tasks.
///
/// The trait requires `Send + Sync`, so implementations must be safe to share
/// across threads.
pub trait ServerHooks: Send + Sync {
    /// Called after validation, before a message is stored in the delivery queue.
    ///
    /// Return `HookAction::Reject` to prevent delivery. The default returns
    /// `HookAction::Continue`.
    fn on_message_enqueue(&self, _event: &MessageEvent) -> HookAction {
        HookAction::Continue
    }

    /// Called after a batch of messages is enqueued.
    ///
    /// Notification only: the batch has already been accepted and cannot be rejected here.
    fn on_batch_enqueue(&self, _events: &[MessageEvent]) {
        // Default: no-op
    }

    /// Called after a login attempt, whether it succeeded or failed
    /// (see [`AuthEvent::success`]).
    fn on_auth(&self, _event: &AuthEvent) {
        // Default: no-op
    }

    /// Called after a channel is created or looked up; [`ChannelEvent::was_new`]
    /// distinguishes the two cases.
    fn on_channel_created(&self, _event: &ChannelEvent) {
        // Default: no-op
    }

    /// Called after messages are fetched from the delivery queue.
    fn on_fetch(&self, _event: &FetchEvent) {
        // Default: no-op
    }

    /// Called when a user registers (OPAQUE registration complete).
    fn on_user_registered(&self, _username: &str, _identity_key: &[u8]) {
        // Default: no-op
    }

    /// Called near the end of graceful shutdown so plugins can flush buffers, close
    /// external connections, or perform other cleanup.
    ///
    /// In the graceful-drain shutdown sequence this runs *after* the server has stopped
    /// accepting connections, closed existing ones, and drained in-flight RPCs (or hit the
    /// drain timeout), and just before the process exits. Do not assume client
    /// connections are still usable here.
    fn on_shutdown(&self) {
        // Default: no-op
    }
}
/// No-op hook implementation (default when no hooks are configured).
///
/// Relies entirely on the trait defaults: message enqueues are allowed and every
/// notification is ignored. Being a zero-sized `Copy` type, it can be constructed via
/// `Default` or the literal `NoopHooks` and passed around freely.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct NoopHooks;

impl ServerHooks for NoopHooks {}
/// Hook implementation that logs events via `tracing`.
///
/// Zero-sized and `Copy`; construct it with `TracingHooks` or `TracingHooks::default()`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct TracingHooks;
// Logging policy: info for most events, warn for failed logins, debug for fetches.
// Fetches that returned no messages are not logged at all. Key material is never
// logged in full; only a short hex prefix via `hex_prefix`.
impl ServerHooks for TracingHooks {
    fn on_message_enqueue(&self, event: &MessageEvent) -> HookAction {
        let has_sender = event.sender_identity.is_some();
        tracing::info!(
            recipient_prefix = %hex_prefix(&event.recipient_key),
            payload_len = event.payload_len,
            seq = event.seq,
            has_sender,
            "hook: message enqueued"
        );
        // Logging-only hook: never vetoes a message.
        HookAction::Continue
    }

    fn on_batch_enqueue(&self, events: &[MessageEvent]) {
        let count = events.len();
        tracing::info!(count, "hook: batch enqueue");
    }

    fn on_auth(&self, event: &AuthEvent) {
        let username = &event.username;
        if !event.success {
            let reason = &event.failure_reason;
            tracing::warn!(%username, %reason, "hook: login failure");
            return;
        }
        tracing::info!(%username, "hook: login success");
    }

    fn on_channel_created(&self, event: &ChannelEvent) {
        let was_new = event.was_new;
        tracing::info!(
            channel_id = %hex_prefix(&event.channel_id),
            was_new,
            "hook: channel created"
        );
    }

    fn on_fetch(&self, event: &FetchEvent) {
        // Empty polls are frequent and uninteresting; skip them.
        if event.message_count == 0 {
            return;
        }
        tracing::debug!(
            recipient_prefix = %hex_prefix(&event.recipient_key),
            count = event.message_count,
            "hook: messages fetched"
        );
    }

    fn on_user_registered(&self, username: &str, _identity_key: &[u8]) {
        tracing::info!(%username, "hook: user registered");
    }

    fn on_shutdown(&self) {
        tracing::info!("hook: server shutting down");
    }
}
/// Renders at most the first four bytes of `bytes` as lowercase hex.
///
/// Used for compact, non-identifying log fields (e.g. `"deadbeef"`); shorter inputs
/// yield proportionally shorter output, and an empty slice yields `""`.
fn hex_prefix(bytes: &[u8]) -> String {
    bytes
        .iter()
        .take(4)
        .map(|byte| format!("{byte:02x}"))
        .collect()
}