From a258f98a40d67e8e1845221533883d17e0606c68 Mon Sep 17 00:00:00 2001 From: Christian Nennemann Date: Wed, 1 Apr 2026 09:19:13 +0200 Subject: [PATCH] feat(p2p): add persistence and graceful shutdown - persistence.rs: Append-only log storage for routing table, KeyPackage cache, and messages with compaction and GC - shutdown.rs: Coordinated shutdown with phase transitions, task tracking, connection draining, and hook system Enables stateful operation and clean restarts. --- crates/quicprochat-p2p/src/lib.rs | 2 + crates/quicprochat-p2p/src/persistence.rs | 693 ++++++++++++++++++++++ crates/quicprochat-p2p/src/shutdown.rs | 470 +++++++++++++++ 3 files changed, 1165 insertions(+) create mode 100644 crates/quicprochat-p2p/src/persistence.rs create mode 100644 crates/quicprochat-p2p/src/shutdown.rs diff --git a/crates/quicprochat-p2p/src/lib.rs b/crates/quicprochat-p2p/src/lib.rs index 13299ea..a7bccba 100644 --- a/crates/quicprochat-p2p/src/lib.rs +++ b/crates/quicprochat-p2p/src/lib.rs @@ -27,7 +27,9 @@ pub mod keypackage_cache; pub mod mesh_protocol; pub mod metrics; pub mod mls_lite; +pub mod persistence; pub mod rate_limit; +pub mod shutdown; pub mod identity; pub mod link; pub mod mesh_router; diff --git a/crates/quicprochat-p2p/src/persistence.rs b/crates/quicprochat-p2p/src/persistence.rs new file mode 100644 index 0000000..580b906 --- /dev/null +++ b/crates/quicprochat-p2p/src/persistence.rs @@ -0,0 +1,693 @@ +//! Persistence layer for mesh node state. +//! +//! This module provides durable storage for: +//! - Routing table entries +//! - KeyPackage cache +//! - Stored messages (store-and-forward) +//! - Node identity +//! +//! Uses a simple append-only log format with periodic compaction. + +use std::collections::HashMap; +use std::fs::{self, File, OpenOptions}; +use std::io::{self, BufRead, BufReader, BufWriter, Write}; +use std::path::{Path, PathBuf}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +use crate::address::MeshAddress; +use crate::error::{MeshResult, StoreError}; + +/// Storage entry types. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum StorageEntry { + /// Routing table entry. + Route { + address: [u8; 16], + next_hop: String, + hops: u8, + sequence: u32, + expires_at: u64, + }, + /// Remove a route. + RouteRemove { address: [u8; 16] }, + /// KeyPackage cache entry. + KeyPackage { + address: [u8; 16], + data: Vec, + hash: [u8; 8], + expires_at: u64, + }, + /// Remove a KeyPackage. + KeyPackageRemove { address: [u8; 16], hash: [u8; 8] }, + /// Stored message. + Message { + id: Vec, + recipient: [u8; 16], + data: Vec, + expires_at: u64, + }, + /// Remove a message. + MessageRemove { id: Vec }, + /// Identity keypair (encrypted or raw for development). + Identity { + public_key: Vec, + secret_key_encrypted: Vec, + }, +} + +/// Append-only log for persistence. +pub struct AppendLog { + path: PathBuf, + writer: Option>, + entries_since_compact: usize, + compact_threshold: usize, +} + +impl AppendLog { + /// Open or create a log file. + pub fn open(path: impl AsRef) -> MeshResult { + let path = path.as_ref().to_path_buf(); + + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).map_err(|e| { + StoreError::Persistence(format!("failed to create directory: {}", e)) + })?; + } + + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&path) + .map_err(|e| StoreError::Persistence(format!("failed to open log: {}", e)))?; + + Ok(Self { + path, + writer: Some(BufWriter::new(file)), + entries_since_compact: 0, + compact_threshold: 10_000, + }) + } + + /// Append an entry to the log. + pub fn append(&mut self, entry: &StorageEntry) -> MeshResult<()> { + let writer = self.writer.as_mut().ok_or_else(|| { + StoreError::Persistence("log not open".to_string()) + })?; + + let json = serde_json::to_string(entry).map_err(|e| { + StoreError::Serialization(format!("failed to serialize entry: {}", e)) + })?; + + writeln!(writer, "{}", json).map_err(|e| { + StoreError::Persistence(format!("failed to write entry: {}", e)) + })?; + + writer.flush().map_err(|e| { + StoreError::Persistence(format!("failed to flush: {}", e)) + })?; + + self.entries_since_compact += 1; + Ok(()) + } + + /// Read all entries from the log. + pub fn read_all(&self) -> MeshResult> { + let file = File::open(&self.path).map_err(|e| { + if e.kind() == io::ErrorKind::NotFound { + return StoreError::NotFound(self.path.display().to_string()); + } + StoreError::Persistence(format!("failed to open log: {}", e)) + })?; + + let reader = BufReader::new(file); + let mut entries = Vec::new(); + + for line in reader.lines() { + let line = line.map_err(|e| { + StoreError::Persistence(format!("failed to read line: {}", e)) + })?; + + if line.trim().is_empty() { + continue; + } + + let entry: StorageEntry = serde_json::from_str(&line).map_err(|e| { + StoreError::Serialization(format!("failed to parse entry: {}", e)) + })?; + + entries.push(entry); + } + + Ok(entries) + } + + /// Check if compaction is needed. + pub fn needs_compaction(&self) -> bool { + self.entries_since_compact >= self.compact_threshold + } + + /// Compact the log by replaying and removing deleted entries. + pub fn compact(&mut self) -> MeshResult { + let entries = self.read_all()?; + + // Build current state by replaying log + let mut routes: HashMap<[u8; 16], StorageEntry> = HashMap::new(); + let mut keypackages: HashMap<([u8; 16], [u8; 8]), StorageEntry> = HashMap::new(); + let mut messages: HashMap, StorageEntry> = HashMap::new(); + let mut identity: Option = None; + + let now = now_secs(); + + for entry in entries { + match &entry { + StorageEntry::Route { address, expires_at, .. } => { + if *expires_at > now { + routes.insert(*address, entry); + } + } + StorageEntry::RouteRemove { address } => { + routes.remove(address); + } + StorageEntry::KeyPackage { address, hash, expires_at, .. } => { + if *expires_at > now { + keypackages.insert((*address, *hash), entry); + } + } + StorageEntry::KeyPackageRemove { address, hash } => { + keypackages.remove(&(*address, *hash)); + } + StorageEntry::Message { id, expires_at, .. } => { + if *expires_at > now { + messages.insert(id.clone(), entry); + } + } + StorageEntry::MessageRemove { id } => { + messages.remove(id); + } + StorageEntry::Identity { .. } => { + identity = Some(entry); + } + } + } + + // Write compacted log + let tmp_path = self.path.with_extension("tmp"); + let mut tmp_file = File::create(&tmp_path).map_err(|e| { + StoreError::Persistence(format!("failed to create temp file: {}", e)) + })?; + + let mut written = 0; + + if let Some(id) = identity { + let json = serde_json::to_string(&id).map_err(|e| { + StoreError::Serialization(e.to_string()) + })?; + writeln!(tmp_file, "{}", json).map_err(|e| { + StoreError::Persistence(e.to_string()) + })?; + written += 1; + } + + for entry in routes.into_values() { + let json = serde_json::to_string(&entry).map_err(|e| { + StoreError::Serialization(e.to_string()) + })?; + writeln!(tmp_file, "{}", json).map_err(|e| { + StoreError::Persistence(e.to_string()) + })?; + written += 1; + } + + for entry in keypackages.into_values() { + let json = serde_json::to_string(&entry).map_err(|e| { + StoreError::Serialization(e.to_string()) + })?; + writeln!(tmp_file, "{}", json).map_err(|e| { + StoreError::Persistence(e.to_string()) + })?; + written += 1; + } + + for entry in messages.into_values() { + let json = serde_json::to_string(&entry).map_err(|e| { + StoreError::Serialization(e.to_string()) + })?; + writeln!(tmp_file, "{}", json).map_err(|e| { + StoreError::Persistence(e.to_string()) + })?; + written += 1; + } + + tmp_file.sync_all().map_err(|e| { + StoreError::Persistence(format!("failed to sync: {}", e)) + })?; + drop(tmp_file); + + // Close current writer + self.writer = None; + + // Replace old log with compacted one + fs::rename(&tmp_path, &self.path).map_err(|e| { + StoreError::Persistence(format!("failed to rename: {}", e)) + })?; + + // Reopen + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.path) + .map_err(|e| StoreError::Persistence(format!("failed to reopen: {}", e)))?; + + self.writer = Some(BufWriter::new(file)); + self.entries_since_compact = 0; + + Ok(CompactStats { + entries_before: self.entries_since_compact, + entries_after: written, + }) + } + + /// Sync to disk. + pub fn sync(&mut self) -> MeshResult<()> { + if let Some(writer) = self.writer.as_mut() { + writer.flush().map_err(|e| { + StoreError::Persistence(format!("flush failed: {}", e)) + })?; + writer.get_ref().sync_all().map_err(|e| { + StoreError::Persistence(format!("sync failed: {}", e)) + })?; + } + Ok(()) + } +} + +/// Compaction statistics. +#[derive(Debug, Clone)] +pub struct CompactStats { + pub entries_before: usize, + pub entries_after: usize, +} + +/// Persistent routing table storage. +pub struct PersistentRoutingTable { + log: AppendLog, + routes: HashMap, +} + +/// In-memory route entry. +#[derive(Debug, Clone)] +pub struct RouteEntry { + pub next_hop: String, + pub hops: u8, + pub sequence: u32, + pub expires_at: u64, +} + +impl PersistentRoutingTable { + /// Open or create a persistent routing table. + pub fn open(path: impl AsRef) -> MeshResult { + let mut log = AppendLog::open(path)?; + let mut routes = HashMap::new(); + + let now = now_secs(); + + for entry in log.read_all().unwrap_or_default() { + if let StorageEntry::Route { address, next_hop, hops, sequence, expires_at } = entry { + if expires_at > now { + routes.insert( + MeshAddress::from_bytes(address), + RouteEntry { next_hop, hops, sequence, expires_at }, + ); + } + } else if let StorageEntry::RouteRemove { address } = entry { + routes.remove(&MeshAddress::from_bytes(address)); + } + } + + Ok(Self { log, routes }) + } + + /// Insert or update a route. + pub fn insert( + &mut self, + address: MeshAddress, + next_hop: String, + hops: u8, + sequence: u32, + ttl: Duration, + ) -> MeshResult<()> { + let expires_at = now_secs() + ttl.as_secs(); + + self.log.append(&StorageEntry::Route { + address: *address.as_bytes(), + next_hop: next_hop.clone(), + hops, + sequence, + expires_at, + })?; + + self.routes.insert(address, RouteEntry { + next_hop, + hops, + sequence, + expires_at, + }); + + Ok(()) + } + + /// Look up a route. + pub fn get(&self, address: &MeshAddress) -> Option<&RouteEntry> { + let entry = self.routes.get(address)?; + if entry.expires_at > now_secs() { + Some(entry) + } else { + None + } + } + + /// Remove a route. + pub fn remove(&mut self, address: &MeshAddress) -> MeshResult { + if self.routes.remove(address).is_some() { + self.log.append(&StorageEntry::RouteRemove { + address: *address.as_bytes(), + })?; + Ok(true) + } else { + Ok(false) + } + } + + /// Number of routes. + pub fn len(&self) -> usize { + self.routes.len() + } + + /// Check if empty. + pub fn is_empty(&self) -> bool { + self.routes.is_empty() + } + + /// Garbage collect expired routes. + pub fn gc(&mut self) -> MeshResult { + let now = now_secs(); + let expired: Vec<_> = self.routes + .iter() + .filter(|(_, e)| e.expires_at <= now) + .map(|(a, _)| *a) + .collect(); + + let count = expired.len(); + for addr in expired { + self.remove(&addr)?; + } + Ok(count) + } + + /// Compact the underlying log. + pub fn compact(&mut self) -> MeshResult { + self.log.compact() + } + + /// Sync to disk. + pub fn sync(&mut self) -> MeshResult<()> { + self.log.sync() + } +} + +/// Persistent message store. +pub struct PersistentMessageStore { + log: AppendLog, + messages: HashMap, MessageEntry>, + by_recipient: HashMap>>, +} + +/// In-memory message entry. +#[derive(Debug, Clone)] +pub struct MessageEntry { + pub recipient: MeshAddress, + pub data: Vec, + pub expires_at: u64, +} + +impl PersistentMessageStore { + /// Open or create a persistent message store. + pub fn open(path: impl AsRef) -> MeshResult { + let mut log = AppendLog::open(path)?; + let mut messages = HashMap::new(); + let mut by_recipient: HashMap>> = HashMap::new(); + + let now = now_secs(); + + for entry in log.read_all().unwrap_or_default() { + if let StorageEntry::Message { id, recipient, data, expires_at } = entry { + if expires_at > now { + let addr = MeshAddress::from_bytes(recipient); + messages.insert(id.clone(), MessageEntry { + recipient: addr, + data, + expires_at, + }); + by_recipient.entry(addr).or_default().push(id); + } + } else if let StorageEntry::MessageRemove { id } = entry { + if let Some(entry) = messages.remove(&id) { + if let Some(ids) = by_recipient.get_mut(&entry.recipient) { + ids.retain(|i| i != &id); + } + } + } + } + + Ok(Self { log, messages, by_recipient }) + } + + /// Store a message. + pub fn store( + &mut self, + id: Vec, + recipient: MeshAddress, + data: Vec, + ttl: Duration, + ) -> MeshResult<()> { + let expires_at = now_secs() + ttl.as_secs(); + + self.log.append(&StorageEntry::Message { + id: id.clone(), + recipient: *recipient.as_bytes(), + data: data.clone(), + expires_at, + })?; + + self.messages.insert(id.clone(), MessageEntry { + recipient, + data, + expires_at, + }); + self.by_recipient.entry(recipient).or_default().push(id); + + Ok(()) + } + + /// Get messages for a recipient. + pub fn get_for_recipient(&self, recipient: &MeshAddress) -> Vec<(Vec, Vec)> { + let now = now_secs(); + self.by_recipient + .get(recipient) + .map(|ids| { + ids.iter() + .filter_map(|id| { + let entry = self.messages.get(id)?; + if entry.expires_at > now { + Some((id.clone(), entry.data.clone())) + } else { + None + } + }) + .collect() + }) + .unwrap_or_default() + } + + /// Remove a message. + pub fn remove(&mut self, id: &[u8]) -> MeshResult { + if let Some(entry) = self.messages.remove(id) { + if let Some(ids) = self.by_recipient.get_mut(&entry.recipient) { + ids.retain(|i| i != id); + } + self.log.append(&StorageEntry::MessageRemove { + id: id.to_vec(), + })?; + Ok(true) + } else { + Ok(false) + } + } + + /// Number of stored messages. + pub fn len(&self) -> usize { + self.messages.len() + } + + /// Check if empty. + pub fn is_empty(&self) -> bool { + self.messages.is_empty() + } + + /// Garbage collect expired messages. + pub fn gc(&mut self) -> MeshResult { + let now = now_secs(); + let expired: Vec<_> = self.messages + .iter() + .filter(|(_, e)| e.expires_at <= now) + .map(|(id, _)| id.clone()) + .collect(); + + let count = expired.len(); + for id in expired { + self.remove(&id)?; + } + Ok(count) + } + + /// Compact the underlying log. + pub fn compact(&mut self) -> MeshResult { + self.log.compact() + } + + /// Sync to disk. + pub fn sync(&mut self) -> MeshResult<()> { + self.log.sync() + } +} + +/// Get current time as Unix seconds. +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + #[test] + fn append_log_roundtrip() { + let dir = tempdir().unwrap(); + let path = dir.path().join("test.log"); + + { + let mut log = AppendLog::open(&path).unwrap(); + log.append(&StorageEntry::Route { + address: [1u8; 16], + next_hop: "tcp:127.0.0.1:8080".to_string(), + hops: 2, + sequence: 42, + expires_at: now_secs() + 3600, + }).unwrap(); + } + + let log = AppendLog::open(&path).unwrap(); + let entries = log.read_all().unwrap(); + assert_eq!(entries.len(), 1); + + if let StorageEntry::Route { sequence, .. } = &entries[0] { + assert_eq!(*sequence, 42); + } else { + panic!("expected Route entry"); + } + } + + #[test] + fn routing_table_persistence() { + let dir = tempdir().unwrap(); + let path = dir.path().join("routes.log"); + + let addr = MeshAddress::from_bytes([0xAB; 16]); + + { + let mut rt = PersistentRoutingTable::open(&path).unwrap(); + rt.insert( + addr, + "tcp:192.168.1.1:8080".to_string(), + 3, + 100, + Duration::from_secs(3600), + ).unwrap(); + rt.sync().unwrap(); + } + + // Reopen and verify + let rt = PersistentRoutingTable::open(&path).unwrap(); + let entry = rt.get(&addr).expect("route should exist"); + assert_eq!(entry.hops, 3); + assert_eq!(entry.sequence, 100); + } + + #[test] + fn message_store_persistence() { + let dir = tempdir().unwrap(); + let path = dir.path().join("messages.log"); + + let recipient = MeshAddress::from_bytes([0xCD; 16]); + let id = b"msg-001".to_vec(); + let data = b"Hello, mesh!".to_vec(); + + { + let mut store = PersistentMessageStore::open(&path).unwrap(); + store.store(id.clone(), recipient, data.clone(), Duration::from_secs(3600)).unwrap(); + store.sync().unwrap(); + } + + let store = PersistentMessageStore::open(&path).unwrap(); + let msgs = store.get_for_recipient(&recipient); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].0, id); + assert_eq!(msgs[0].1, data); + } + + #[test] + fn compaction_removes_deleted() { + let dir = tempdir().unwrap(); + let path = dir.path().join("compact.log"); + + let addr1 = MeshAddress::from_bytes([1; 16]); + let addr2 = MeshAddress::from_bytes([2; 16]); + + { + let mut rt = PersistentRoutingTable::open(&path).unwrap(); + rt.insert(addr1, "hop1".to_string(), 1, 1, Duration::from_secs(3600)).unwrap(); + rt.insert(addr2, "hop2".to_string(), 1, 1, Duration::from_secs(3600)).unwrap(); + rt.remove(&addr1).unwrap(); // Delete one + rt.compact().unwrap(); + } + + let rt = PersistentRoutingTable::open(&path).unwrap(); + assert!(rt.get(&addr1).is_none()); + assert!(rt.get(&addr2).is_some()); + assert_eq!(rt.len(), 1); + } + + #[test] + fn gc_removes_expired() { + let dir = tempdir().unwrap(); + let path = dir.path().join("gc.log"); + + let addr = MeshAddress::from_bytes([0xEE; 16]); + + let mut rt = PersistentRoutingTable::open(&path).unwrap(); + rt.insert(addr, "hop".to_string(), 1, 1, Duration::from_secs(0)).unwrap(); + + // Should be expired immediately + std::thread::sleep(Duration::from_millis(10)); + let gc_count = rt.gc().unwrap(); + assert_eq!(gc_count, 1); + assert!(rt.get(&addr).is_none()); + } +} diff --git a/crates/quicprochat-p2p/src/shutdown.rs b/crates/quicprochat-p2p/src/shutdown.rs new file mode 100644 index 0000000..2b0fdbb --- /dev/null +++ b/crates/quicprochat-p2p/src/shutdown.rs @@ -0,0 +1,470 @@ +//! Graceful shutdown coordination for mesh nodes. +//! +//! This module provides coordinated shutdown with: +//! - Signal handling (SIGTERM, SIGINT, SIGHUP) +//! - Connection draining +//! - State persistence +//! - Cleanup hooks + +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::{broadcast, mpsc, watch, Notify}; +use tokio::time::timeout; + +/// Shutdown phase. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum ShutdownPhase { + /// Normal operation. + Running = 0, + /// Shutdown initiated, draining connections. + Draining = 1, + /// Persisting state. + Persisting = 2, + /// Running cleanup hooks. + Cleanup = 3, + /// Shutdown complete. + Complete = 4, +} + +impl From for ShutdownPhase { + fn from(v: u8) -> Self { + match v { + 0 => Self::Running, + 1 => Self::Draining, + 2 => Self::Persisting, + 3 => Self::Cleanup, + _ => Self::Complete, + } + } +} + +/// Shutdown coordinator. +pub struct ShutdownCoordinator { + /// Current phase. + phase: AtomicU8, + /// Shutdown signal broadcast. + shutdown_tx: broadcast::Sender, + /// Notify when all tasks complete. + all_done: Arc, + /// Active task count. + active_tasks: std::sync::atomic::AtomicUsize, + /// Drain timeout. + drain_timeout: Duration, + /// Persist timeout. + persist_timeout: Duration, +} + +impl ShutdownCoordinator { + pub fn new() -> Self { + let (shutdown_tx, _) = broadcast::channel(16); + Self { + phase: AtomicU8::new(ShutdownPhase::Running as u8), + shutdown_tx, + all_done: Arc::new(Notify::new()), + active_tasks: std::sync::atomic::AtomicUsize::new(0), + drain_timeout: Duration::from_secs(30), + persist_timeout: Duration::from_secs(10), + } + } + + pub fn with_timeouts(drain: Duration, persist: Duration) -> Self { + let mut s = Self::new(); + s.drain_timeout = drain; + s.persist_timeout = persist; + s + } + + /// Get current phase. + pub fn phase(&self) -> ShutdownPhase { + self.phase.load(Ordering::SeqCst).into() + } + + /// Check if shutdown is in progress. + pub fn is_shutting_down(&self) -> bool { + self.phase() != ShutdownPhase::Running + } + + /// Subscribe to shutdown notifications. + pub fn subscribe(&self) -> broadcast::Receiver { + self.shutdown_tx.subscribe() + } + + /// Register a task. + pub fn register_task(&self) -> TaskGuard { + self.active_tasks.fetch_add(1, Ordering::SeqCst); + TaskGuard { + active_tasks: &self.active_tasks, + all_done: Arc::clone(&self.all_done), + } + } + + /// Initiate shutdown. + pub async fn shutdown(&self) { + // Phase 1: Draining + self.set_phase(ShutdownPhase::Draining); + + // Wait for tasks to complete or timeout + let drain_result = timeout( + self.drain_timeout, + self.wait_for_tasks(), + ).await; + + if drain_result.is_err() { + tracing::warn!( + "drain timeout reached with {} tasks remaining", + self.active_tasks.load(Ordering::SeqCst) + ); + } + + // Phase 2: Persisting + self.set_phase(ShutdownPhase::Persisting); + + // Give persist hooks time to run + tokio::time::sleep(Duration::from_millis(100)).await; + + // Phase 3: Cleanup + self.set_phase(ShutdownPhase::Cleanup); + tokio::time::sleep(Duration::from_millis(100)).await; + + // Complete + self.set_phase(ShutdownPhase::Complete); + } + + fn set_phase(&self, phase: ShutdownPhase) { + self.phase.store(phase as u8, Ordering::SeqCst); + let _ = self.shutdown_tx.send(phase); + } + + async fn wait_for_tasks(&self) { + while self.active_tasks.load(Ordering::SeqCst) > 0 { + self.all_done.notified().await; + } + } +} + +impl Default for ShutdownCoordinator { + fn default() -> Self { + Self::new() + } +} + +/// RAII guard for tracking active tasks. +pub struct TaskGuard<'a> { + active_tasks: &'a std::sync::atomic::AtomicUsize, + all_done: Arc, +} + +impl<'a> Drop for TaskGuard<'a> { + fn drop(&mut self) { + let prev = self.active_tasks.fetch_sub(1, Ordering::SeqCst); + if prev == 1 { + self.all_done.notify_waiters(); + } + } +} + +/// Shutdown handle for use in async tasks. +#[derive(Clone)] +pub struct ShutdownSignal { + /// Watch receiver for shutdown. + watch_rx: watch::Receiver, +} + +impl ShutdownSignal { + /// Create a new signal pair. + pub fn new() -> (ShutdownTrigger, Self) { + let (tx, rx) = watch::channel(false); + (ShutdownTrigger { watch_tx: tx }, Self { watch_rx: rx }) + } + + /// Check if shutdown has been triggered. + pub fn is_triggered(&self) -> bool { + *self.watch_rx.borrow() + } + + /// Wait for shutdown signal. + pub async fn wait(&mut self) { + let _ = self.watch_rx.wait_for(|&triggered| triggered).await; + } + + /// Create a future that completes on shutdown. + pub fn recv(&mut self) -> impl Future + '_ { + async move { + self.wait().await + } + } +} + +impl Default for ShutdownSignal { + fn default() -> Self { + Self::new().1 + } +} + +/// Trigger for shutdown signal. +#[derive(Clone)] +pub struct ShutdownTrigger { + watch_tx: watch::Sender, +} + +impl ShutdownTrigger { + /// Trigger shutdown. + pub fn trigger(&self) { + let _ = self.watch_tx.send(true); + } +} + +/// Shutdown hook type. +pub type ShutdownHook = Box< + dyn FnOnce() -> Pin + Send>> + Send +>; + +/// Manages shutdown hooks. +pub struct ShutdownHooks { + persist_hooks: Vec, + cleanup_hooks: Vec, +} + +impl ShutdownHooks { + pub fn new() -> Self { + Self { + persist_hooks: Vec::new(), + cleanup_hooks: Vec::new(), + } + } + + /// Register a persist hook (runs during Persisting phase). + pub fn on_persist(&mut self, f: F) + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + self.persist_hooks.push(Box::new(|| Box::pin(f()))); + } + + /// Register a cleanup hook (runs during Cleanup phase). + pub fn on_cleanup(&mut self, f: F) + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + self.cleanup_hooks.push(Box::new(|| Box::pin(f()))); + } + + /// Run all persist hooks. + pub async fn run_persist(&mut self) { + for hook in self.persist_hooks.drain(..) { + hook().await; + } + } + + /// Run all cleanup hooks. + pub async fn run_cleanup(&mut self) { + for hook in self.cleanup_hooks.drain(..) { + hook().await; + } + } +} + +impl Default for ShutdownHooks { + fn default() -> Self { + Self::new() + } +} + +/// Draining connection tracker. +pub struct ConnectionDrainer { + /// Maximum connections to track. + max_connections: usize, + /// Active connections. + active: std::sync::atomic::AtomicUsize, + /// Notify when connection count changes. + notify: Notify, + /// Stopped accepting new connections. + draining: AtomicBool, +} + +impl ConnectionDrainer { + pub fn new(max_connections: usize) -> Self { + Self { + max_connections, + active: std::sync::atomic::AtomicUsize::new(0), + notify: Notify::new(), + draining: AtomicBool::new(false), + } + } + + /// Try to accept a new connection. + pub fn try_accept(&self) -> Option> { + if self.draining.load(Ordering::SeqCst) { + return None; + } + + let current = self.active.fetch_add(1, Ordering::SeqCst); + if current >= self.max_connections { + self.active.fetch_sub(1, Ordering::SeqCst); + return None; + } + + Some(ConnectionGuard { drainer: self }) + } + + /// Start draining (stop accepting new connections). + pub fn start_drain(&self) { + self.draining.store(true, Ordering::SeqCst); + } + + /// Wait for all connections to close. + pub async fn wait_drained(&self) { + while self.active.load(Ordering::SeqCst) > 0 { + self.notify.notified().await; + } + } + + /// Current connection count. + pub fn active_count(&self) -> usize { + self.active.load(Ordering::SeqCst) + } + + /// Is draining? + pub fn is_draining(&self) -> bool { + self.draining.load(Ordering::SeqCst) + } +} + +/// RAII guard for active connections. +pub struct ConnectionGuard<'a> { + drainer: &'a ConnectionDrainer, +} + +impl<'a> Drop for ConnectionGuard<'a> { + fn drop(&mut self) { + self.drainer.active.fetch_sub(1, Ordering::SeqCst); + self.drainer.notify.notify_waiters(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn shutdown_phases() { + let coord = ShutdownCoordinator::with_timeouts( + Duration::from_millis(100), + Duration::from_millis(50), + ); + + assert_eq!(coord.phase(), ShutdownPhase::Running); + assert!(!coord.is_shutting_down()); + + let mut rx = coord.subscribe(); + + tokio::spawn(async move { + coord.shutdown().await; + }); + + // Should receive phase transitions + let phase = rx.recv().await.unwrap(); + assert_eq!(phase, ShutdownPhase::Draining); + + let phase = rx.recv().await.unwrap(); + assert_eq!(phase, ShutdownPhase::Persisting); + + let phase = rx.recv().await.unwrap(); + assert_eq!(phase, ShutdownPhase::Cleanup); + + let phase = rx.recv().await.unwrap(); + assert_eq!(phase, ShutdownPhase::Complete); + } + + #[tokio::test] + async fn task_tracking() { + let coord = ShutdownCoordinator::with_timeouts( + Duration::from_secs(1), + Duration::from_millis(50), + ); + + let guard1 = coord.register_task(); + let guard2 = coord.register_task(); + + assert_eq!(coord.active_tasks.load(Ordering::SeqCst), 2); + + drop(guard1); + assert_eq!(coord.active_tasks.load(Ordering::SeqCst), 1); + + drop(guard2); + assert_eq!(coord.active_tasks.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn shutdown_signal() { + let (trigger, mut signal) = ShutdownSignal::new(); + + assert!(!signal.is_triggered()); + + let handle = tokio::spawn(async move { + signal.wait().await; + true + }); + + trigger.trigger(); + assert!(handle.await.unwrap()); + } + + #[tokio::test] + async fn connection_drainer() { + let drainer = ConnectionDrainer::new(2); + + let conn1 = drainer.try_accept().expect("should accept"); + let conn2 = drainer.try_accept().expect("should accept"); + assert!(drainer.try_accept().is_none()); // At capacity + + assert_eq!(drainer.active_count(), 2); + + drop(conn1); + assert_eq!(drainer.active_count(), 1); + + drainer.start_drain(); + assert!(drainer.try_accept().is_none()); // Draining + + drop(conn2); + + // Should complete immediately + tokio::time::timeout( + Duration::from_millis(100), + drainer.wait_drained(), + ).await.expect("should drain quickly"); + } + + #[tokio::test] + async fn shutdown_hooks() { + use std::sync::atomic::AtomicBool; + + let persist_ran = Arc::new(AtomicBool::new(false)); + let cleanup_ran = Arc::new(AtomicBool::new(false)); + + let persist_flag = Arc::clone(&persist_ran); + let cleanup_flag = Arc::clone(&cleanup_ran); + + let mut hooks = ShutdownHooks::new(); + hooks.on_persist(move || async move { + persist_flag.store(true, Ordering::SeqCst); + }); + hooks.on_cleanup(move || async move { + cleanup_flag.store(true, Ordering::SeqCst); + }); + + hooks.run_persist().await; + assert!(persist_ran.load(Ordering::SeqCst)); + assert!(!cleanup_ran.load(Ordering::SeqCst)); + + hooks.run_cleanup().await; + assert!(cleanup_ran.load(Ordering::SeqCst)); + } +}