feat(p2p): add persistence and graceful shutdown

- persistence.rs: Append-only log storage for routing table,
  KeyPackage cache, and messages with compaction and GC
- shutdown.rs: Coordinated shutdown with phase transitions,
  task tracking, connection draining, and hook system

Enables stateful operation and clean restarts.
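
A minimal sketch of how the two modules are meant to compose (hypothetical
wiring: the `p2p::` paths, `state/routes.log`, and the tokio runtime are
assumptions; only the persistence and shutdown APIs come from this commit):

    use std::sync::Arc;
    use tokio::sync::Mutex;
    use p2p::persistence::PersistentRoutingTable;
    use p2p::shutdown::{ShutdownCoordinator, ShutdownHooks, ShutdownPhase};

    async fn run() -> p2p::error::MeshResult<()> {
        let routes = Arc::new(Mutex::new(PersistentRoutingTable::open("state/routes.log")?));
        let coord = ShutdownCoordinator::new();
        let mut hooks = ShutdownHooks::new();

        // Flush the routing table to disk during the Persisting phase.
        let routes_hook = Arc::clone(&routes);
        hooks.on_persist(move || async move {
            let _ = routes_hook.lock().await.sync();
        });

        // Drive the hooks off the coordinator's phase broadcasts.
        let mut phases = coord.subscribe();
        tokio::spawn(async move {
            while let Ok(phase) = phases.recv().await {
                match phase {
                    ShutdownPhase::Persisting => hooks.run_persist().await,
                    ShutdownPhase::Cleanup => hooks.run_cleanup().await,
                    ShutdownPhase::Complete => break,
                    _ => {}
                }
            }
        });

        // ... run the node; on shutdown request:
        coord.shutdown().await; // Draining -> Persisting -> Cleanup -> Complete
        Ok(())
    }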
2026-04-01 09:19:13 +02:00
parent 024b6c91d1
commit a258f98a40
3 changed files with 1165 additions and 0 deletions


@@ -27,7 +27,9 @@ pub mod keypackage_cache;
pub mod mesh_protocol;
pub mod metrics;
pub mod mls_lite;
pub mod persistence;
pub mod rate_limit;
pub mod shutdown;
pub mod identity;
pub mod link;
pub mod mesh_router;


@@ -0,0 +1,693 @@
//! Persistence layer for mesh node state.
//!
//! This module provides durable storage for:
//! - Routing table entries
//! - KeyPackage cache
//! - Stored messages (store-and-forward)
//! - Node identity
//!
//! Uses a simple append-only log format with periodic compaction.
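//!
//! On disk, each entry is one line of JSON tagged by `type` (a sketch with
//! illustrative values; a 16-byte address serializes as a JSON array):
//!
//! ```text
//! {"type":"Route","address":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"next_hop":"tcp:10.0.0.1:8080","hops":2,"sequence":42,"expires_at":1743500000}
//! {"type":"RouteRemove","address":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]}
//! ```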
use std::collections::HashMap;
use std::fs::{self, File, OpenOptions};
use std::io::{self, BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::address::MeshAddress;
use crate::error::{MeshResult, StoreError};
/// Storage entry types.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum StorageEntry {
/// Routing table entry.
Route {
address: [u8; 16],
next_hop: String,
hops: u8,
sequence: u32,
expires_at: u64,
},
/// Remove a route.
RouteRemove { address: [u8; 16] },
/// KeyPackage cache entry.
KeyPackage {
address: [u8; 16],
data: Vec<u8>,
hash: [u8; 8],
expires_at: u64,
},
/// Remove a KeyPackage.
KeyPackageRemove { address: [u8; 16], hash: [u8; 8] },
/// Stored message.
Message {
id: Vec<u8>,
recipient: [u8; 16],
data: Vec<u8>,
expires_at: u64,
},
/// Remove a message.
MessageRemove { id: Vec<u8> },
/// Identity keypair (encrypted or raw for development).
Identity {
public_key: Vec<u8>,
secret_key_encrypted: Vec<u8>,
},
}
/// Append-only log for persistence.
pub struct AppendLog {
path: PathBuf,
writer: Option<BufWriter<File>>,
entries_since_compact: usize,
compact_threshold: usize,
}
impl AppendLog {
/// Open or create a log file.
pub fn open(path: impl AsRef<Path>) -> MeshResult<Self> {
let path = path.as_ref().to_path_buf();
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|e| {
StoreError::Persistence(format!("failed to create directory: {}", e))
})?;
}
let file = OpenOptions::new()
.create(true)
.append(true)
.open(&path)
.map_err(|e| StoreError::Persistence(format!("failed to open log: {}", e)))?;
Ok(Self {
path,
writer: Some(BufWriter::new(file)),
entries_since_compact: 0,
compact_threshold: 10_000,
})
}
/// Append an entry to the log.
pub fn append(&mut self, entry: &StorageEntry) -> MeshResult<()> {
let writer = self.writer.as_mut().ok_or_else(|| {
StoreError::Persistence("log not open".to_string())
})?;
let json = serde_json::to_string(entry).map_err(|e| {
StoreError::Serialization(format!("failed to serialize entry: {}", e))
})?;
writeln!(writer, "{}", json).map_err(|e| {
StoreError::Persistence(format!("failed to write entry: {}", e))
})?;
writer.flush().map_err(|e| {
StoreError::Persistence(format!("failed to flush: {}", e))
})?;
self.entries_since_compact += 1;
Ok(())
}
/// Read all entries from the log.
pub fn read_all(&self) -> MeshResult<Vec<StorageEntry>> {
let file = File::open(&self.path).map_err(|e| {
if e.kind() == io::ErrorKind::NotFound {
return StoreError::NotFound(self.path.display().to_string());
}
StoreError::Persistence(format!("failed to open log: {}", e))
})?;
let reader = BufReader::new(file);
let mut entries = Vec::new();
for line in reader.lines() {
let line = line.map_err(|e| {
StoreError::Persistence(format!("failed to read line: {}", e))
})?;
if line.trim().is_empty() {
continue;
}
let entry: StorageEntry = serde_json::from_str(&line).map_err(|e| {
StoreError::Serialization(format!("failed to parse entry: {}", e))
})?;
entries.push(entry);
}
Ok(entries)
}
/// Check if compaction is needed.
pub fn needs_compaction(&self) -> bool {
self.entries_since_compact >= self.compact_threshold
}
/// Compact the log by replaying and removing deleted entries.
pub fn compact(&mut self) -> MeshResult<CompactStats> {
let entries = self.read_all()?;
let entries_before = entries.len();
// Build current state by replaying log
let mut routes: HashMap<[u8; 16], StorageEntry> = HashMap::new();
let mut keypackages: HashMap<([u8; 16], [u8; 8]), StorageEntry> = HashMap::new();
let mut messages: HashMap<Vec<u8>, StorageEntry> = HashMap::new();
let mut identity: Option<StorageEntry> = None;
let now = now_secs();
for entry in entries {
match &entry {
StorageEntry::Route { address, expires_at, .. } => {
if *expires_at > now {
routes.insert(*address, entry);
}
}
StorageEntry::RouteRemove { address } => {
routes.remove(address);
}
StorageEntry::KeyPackage { address, hash, expires_at, .. } => {
if *expires_at > now {
keypackages.insert((*address, *hash), entry);
}
}
StorageEntry::KeyPackageRemove { address, hash } => {
keypackages.remove(&(*address, *hash));
}
StorageEntry::Message { id, expires_at, .. } => {
if *expires_at > now {
messages.insert(id.clone(), entry);
}
}
StorageEntry::MessageRemove { id } => {
messages.remove(id);
}
StorageEntry::Identity { .. } => {
identity = Some(entry);
}
}
}
// Write compacted log
let tmp_path = self.path.with_extension("tmp");
let mut tmp_file = File::create(&tmp_path).map_err(|e| {
StoreError::Persistence(format!("failed to create temp file: {}", e))
})?;
let mut written = 0;
if let Some(id) = identity {
let json = serde_json::to_string(&id).map_err(|e| {
StoreError::Serialization(e.to_string())
})?;
writeln!(tmp_file, "{}", json).map_err(|e| {
StoreError::Persistence(e.to_string())
})?;
written += 1;
}
for entry in routes.into_values() {
let json = serde_json::to_string(&entry).map_err(|e| {
StoreError::Serialization(e.to_string())
})?;
writeln!(tmp_file, "{}", json).map_err(|e| {
StoreError::Persistence(e.to_string())
})?;
written += 1;
}
for entry in keypackages.into_values() {
let json = serde_json::to_string(&entry).map_err(|e| {
StoreError::Serialization(e.to_string())
})?;
writeln!(tmp_file, "{}", json).map_err(|e| {
StoreError::Persistence(e.to_string())
})?;
written += 1;
}
for entry in messages.into_values() {
let json = serde_json::to_string(&entry).map_err(|e| {
StoreError::Serialization(e.to_string())
})?;
writeln!(tmp_file, "{}", json).map_err(|e| {
StoreError::Persistence(e.to_string())
})?;
written += 1;
}
tmp_file.sync_all().map_err(|e| {
StoreError::Persistence(format!("failed to sync: {}", e))
})?;
drop(tmp_file);
// Close current writer
self.writer = None;
// Replace old log with compacted one
fs::rename(&tmp_path, &self.path).map_err(|e| {
StoreError::Persistence(format!("failed to rename: {}", e))
})?;
// Reopen
let file = OpenOptions::new()
.create(true)
.append(true)
.open(&self.path)
.map_err(|e| StoreError::Persistence(format!("failed to reopen: {}", e)))?;
self.writer = Some(BufWriter::new(file));
self.entries_since_compact = 0;
Ok(CompactStats {
entries_before,
entries_after: written,
})
}
/// Sync to disk.
pub fn sync(&mut self) -> MeshResult<()> {
if let Some(writer) = self.writer.as_mut() {
writer.flush().map_err(|e| {
StoreError::Persistence(format!("flush failed: {}", e))
})?;
writer.get_ref().sync_all().map_err(|e| {
StoreError::Persistence(format!("sync failed: {}", e))
})?;
}
Ok(())
}
}
/// Compaction statistics.
#[derive(Debug, Clone)]
pub struct CompactStats {
pub entries_before: usize,
pub entries_after: usize,
}
/// Persistent routing table storage.
pub struct PersistentRoutingTable {
log: AppendLog,
routes: HashMap<MeshAddress, RouteEntry>,
}
/// In-memory route entry.
#[derive(Debug, Clone)]
pub struct RouteEntry {
pub next_hop: String,
pub hops: u8,
pub sequence: u32,
pub expires_at: u64,
}
impl PersistentRoutingTable {
/// Open or create a persistent routing table.
pub fn open(path: impl AsRef<Path>) -> MeshResult<Self> {
let log = AppendLog::open(path)?;
let mut routes = HashMap::new();
let now = now_secs();
// Tolerate an unreadable log on startup; the next compaction rewrites it.
for entry in log.read_all().unwrap_or_default() {
if let StorageEntry::Route { address, next_hop, hops, sequence, expires_at } = entry {
if expires_at > now {
routes.insert(
MeshAddress::from_bytes(address),
RouteEntry { next_hop, hops, sequence, expires_at },
);
}
} else if let StorageEntry::RouteRemove { address } = entry {
routes.remove(&MeshAddress::from_bytes(address));
}
}
Ok(Self { log, routes })
}
/// Insert or update a route.
pub fn insert(
&mut self,
address: MeshAddress,
next_hop: String,
hops: u8,
sequence: u32,
ttl: Duration,
) -> MeshResult<()> {
let expires_at = now_secs() + ttl.as_secs();
self.log.append(&StorageEntry::Route {
address: *address.as_bytes(),
next_hop: next_hop.clone(),
hops,
sequence,
expires_at,
})?;
self.routes.insert(address, RouteEntry {
next_hop,
hops,
sequence,
expires_at,
});
Ok(())
}
/// Look up a route.
pub fn get(&self, address: &MeshAddress) -> Option<&RouteEntry> {
let entry = self.routes.get(address)?;
if entry.expires_at > now_secs() {
Some(entry)
} else {
None
}
}
/// Remove a route.
pub fn remove(&mut self, address: &MeshAddress) -> MeshResult<bool> {
if self.routes.remove(address).is_some() {
self.log.append(&StorageEntry::RouteRemove {
address: *address.as_bytes(),
})?;
Ok(true)
} else {
Ok(false)
}
}
/// Number of routes.
pub fn len(&self) -> usize {
self.routes.len()
}
/// Check if empty.
pub fn is_empty(&self) -> bool {
self.routes.is_empty()
}
/// Garbage collect expired routes.
pub fn gc(&mut self) -> MeshResult<usize> {
let now = now_secs();
let expired: Vec<_> = self.routes
.iter()
.filter(|(_, e)| e.expires_at <= now)
.map(|(a, _)| *a)
.collect();
let count = expired.len();
for addr in expired {
self.remove(&addr)?;
}
Ok(count)
}
/// Compact the underlying log.
pub fn compact(&mut self) -> MeshResult<CompactStats> {
self.log.compact()
}
/// Sync to disk.
pub fn sync(&mut self) -> MeshResult<()> {
self.log.sync()
}
}
/// Persistent message store.
pub struct PersistentMessageStore {
log: AppendLog,
messages: HashMap<Vec<u8>, MessageEntry>,
by_recipient: HashMap<MeshAddress, Vec<Vec<u8>>>,
}
/// In-memory message entry.
#[derive(Debug, Clone)]
pub struct MessageEntry {
pub recipient: MeshAddress,
pub data: Vec<u8>,
pub expires_at: u64,
}
impl PersistentMessageStore {
/// Open or create a persistent message store.
pub fn open(path: impl AsRef<Path>) -> MeshResult<Self> {
let log = AppendLog::open(path)?;
let mut messages = HashMap::new();
let mut by_recipient: HashMap<MeshAddress, Vec<Vec<u8>>> = HashMap::new();
let now = now_secs();
// Tolerate an unreadable log on startup; the next compaction rewrites it.
for entry in log.read_all().unwrap_or_default() {
if let StorageEntry::Message { id, recipient, data, expires_at } = entry {
if expires_at > now {
let addr = MeshAddress::from_bytes(recipient);
messages.insert(id.clone(), MessageEntry {
recipient: addr,
data,
expires_at,
});
by_recipient.entry(addr).or_default().push(id);
}
} else if let StorageEntry::MessageRemove { id } = entry {
if let Some(entry) = messages.remove(&id) {
if let Some(ids) = by_recipient.get_mut(&entry.recipient) {
ids.retain(|i| i != &id);
}
}
}
}
Ok(Self { log, messages, by_recipient })
}
/// Store a message.
pub fn store(
&mut self,
id: Vec<u8>,
recipient: MeshAddress,
data: Vec<u8>,
ttl: Duration,
) -> MeshResult<()> {
let expires_at = now_secs() + ttl.as_secs();
self.log.append(&StorageEntry::Message {
id: id.clone(),
recipient: *recipient.as_bytes(),
data: data.clone(),
expires_at,
})?;
self.messages.insert(id.clone(), MessageEntry {
recipient,
data,
expires_at,
});
let ids = self.by_recipient.entry(recipient).or_default();
// Avoid indexing the same message id twice if it is re-stored.
if !ids.contains(&id) {
ids.push(id);
}
Ok(())
}
/// Get messages for a recipient.
pub fn get_for_recipient(&self, recipient: &MeshAddress) -> Vec<(Vec<u8>, Vec<u8>)> {
let now = now_secs();
self.by_recipient
.get(recipient)
.map(|ids| {
ids.iter()
.filter_map(|id| {
let entry = self.messages.get(id)?;
if entry.expires_at > now {
Some((id.clone(), entry.data.clone()))
} else {
None
}
})
.collect()
})
.unwrap_or_default()
}
/// Remove a message.
pub fn remove(&mut self, id: &[u8]) -> MeshResult<bool> {
if let Some(entry) = self.messages.remove(id) {
if let Some(ids) = self.by_recipient.get_mut(&entry.recipient) {
ids.retain(|i| i != id);
}
self.log.append(&StorageEntry::MessageRemove {
id: id.to_vec(),
})?;
Ok(true)
} else {
Ok(false)
}
}
/// Number of stored messages.
pub fn len(&self) -> usize {
self.messages.len()
}
/// Check if empty.
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
/// Garbage collect expired messages.
pub fn gc(&mut self) -> MeshResult<usize> {
let now = now_secs();
let expired: Vec<_> = self.messages
.iter()
.filter(|(_, e)| e.expires_at <= now)
.map(|(id, _)| id.clone())
.collect();
let count = expired.len();
for id in expired {
self.remove(&id)?;
}
Ok(count)
}
/// Compact the underlying log.
pub fn compact(&mut self) -> MeshResult<CompactStats> {
self.log.compact()
}
/// Sync to disk.
pub fn sync(&mut self) -> MeshResult<()> {
self.log.sync()
}
}
/// Get current time as Unix seconds.
fn now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
fn append_log_roundtrip() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.log");
{
let mut log = AppendLog::open(&path).unwrap();
log.append(&StorageEntry::Route {
address: [1u8; 16],
next_hop: "tcp:127.0.0.1:8080".to_string(),
hops: 2,
sequence: 42,
expires_at: now_secs() + 3600,
}).unwrap();
}
let log = AppendLog::open(&path).unwrap();
let entries = log.read_all().unwrap();
assert_eq!(entries.len(), 1);
if let StorageEntry::Route { sequence, .. } = &entries[0] {
assert_eq!(*sequence, 42);
} else {
panic!("expected Route entry");
}
}
#[test]
fn routing_table_persistence() {
let dir = tempdir().unwrap();
let path = dir.path().join("routes.log");
let addr = MeshAddress::from_bytes([0xAB; 16]);
{
let mut rt = PersistentRoutingTable::open(&path).unwrap();
rt.insert(
addr,
"tcp:192.168.1.1:8080".to_string(),
3,
100,
Duration::from_secs(3600),
).unwrap();
rt.sync().unwrap();
}
// Reopen and verify
let rt = PersistentRoutingTable::open(&path).unwrap();
let entry = rt.get(&addr).expect("route should exist");
assert_eq!(entry.hops, 3);
assert_eq!(entry.sequence, 100);
}
#[test]
fn message_store_persistence() {
let dir = tempdir().unwrap();
let path = dir.path().join("messages.log");
let recipient = MeshAddress::from_bytes([0xCD; 16]);
let id = b"msg-001".to_vec();
let data = b"Hello, mesh!".to_vec();
{
let mut store = PersistentMessageStore::open(&path).unwrap();
store.store(id.clone(), recipient, data.clone(), Duration::from_secs(3600)).unwrap();
store.sync().unwrap();
}
let store = PersistentMessageStore::open(&path).unwrap();
let msgs = store.get_for_recipient(&recipient);
assert_eq!(msgs.len(), 1);
assert_eq!(msgs[0].0, id);
assert_eq!(msgs[0].1, data);
}
#[test]
fn compaction_removes_deleted() {
let dir = tempdir().unwrap();
let path = dir.path().join("compact.log");
let addr1 = MeshAddress::from_bytes([1; 16]);
let addr2 = MeshAddress::from_bytes([2; 16]);
{
let mut rt = PersistentRoutingTable::open(&path).unwrap();
rt.insert(addr1, "hop1".to_string(), 1, 1, Duration::from_secs(3600)).unwrap();
rt.insert(addr2, "hop2".to_string(), 1, 1, Duration::from_secs(3600)).unwrap();
rt.remove(&addr1).unwrap(); // Delete one
rt.compact().unwrap();
}
let rt = PersistentRoutingTable::open(&path).unwrap();
assert!(rt.get(&addr1).is_none());
assert!(rt.get(&addr2).is_some());
assert_eq!(rt.len(), 1);
}
#[test]
fn gc_removes_expired() {
let dir = tempdir().unwrap();
let path = dir.path().join("gc.log");
let addr = MeshAddress::from_bytes([0xEE; 16]);
let mut rt = PersistentRoutingTable::open(&path).unwrap();
rt.insert(addr, "hop".to_string(), 1, 1, Duration::from_secs(0)).unwrap();
// Should be expired immediately
std::thread::sleep(Duration::from_millis(10));
let gc_count = rt.gc().unwrap();
assert_eq!(gc_count, 1);
assert!(rt.get(&addr).is_none());
}
}


@@ -0,0 +1,470 @@
//! Graceful shutdown coordination for mesh nodes.
//!
//! This module provides coordinated shutdown with:
//! - Phase transitions (Draining -> Persisting -> Cleanup -> Complete)
//! - Shutdown signaling via `ShutdownTrigger`/`ShutdownSignal` (wiring OS
//!   signals such as SIGTERM/SIGINT to a trigger is left to the binary)
//! - Connection draining
//! - State persistence and cleanup hooks
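//!
//! A minimal usage sketch (illustrative only; assumes a tokio runtime):
//!
//! ```ignore
//! let coord = ShutdownCoordinator::new();
//! {
//!     // Workers hold a guard; draining waits until every guard drops.
//!     let _guard = coord.register_task();
//!     // ... do work ...
//! }
//! // Broadcasts Draining -> Persisting -> Cleanup -> Complete to subscribers.
//! coord.shutdown().await;
//! ```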
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, watch, Notify};
use tokio::time::timeout;
/// Shutdown phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ShutdownPhase {
/// Normal operation.
Running = 0,
/// Shutdown initiated, draining connections.
Draining = 1,
/// Persisting state.
Persisting = 2,
/// Running cleanup hooks.
Cleanup = 3,
/// Shutdown complete.
Complete = 4,
}
impl From<u8> for ShutdownPhase {
fn from(v: u8) -> Self {
match v {
0 => Self::Running,
1 => Self::Draining,
2 => Self::Persisting,
3 => Self::Cleanup,
_ => Self::Complete,
}
}
}
/// Shutdown coordinator.
pub struct ShutdownCoordinator {
/// Current phase.
phase: AtomicU8,
/// Shutdown signal broadcast.
shutdown_tx: broadcast::Sender<ShutdownPhase>,
/// Notify when all tasks complete.
all_done: Arc<Notify>,
/// Active task count.
active_tasks: AtomicUsize,
/// Drain timeout.
drain_timeout: Duration,
/// Persist timeout.
persist_timeout: Duration,
}
impl ShutdownCoordinator {
pub fn new() -> Self {
let (shutdown_tx, _) = broadcast::channel(16);
Self {
phase: AtomicU8::new(ShutdownPhase::Running as u8),
shutdown_tx,
all_done: Arc::new(Notify::new()),
active_tasks: AtomicUsize::new(0),
drain_timeout: Duration::from_secs(30),
persist_timeout: Duration::from_secs(10),
}
}
pub fn with_timeouts(drain: Duration, persist: Duration) -> Self {
let mut s = Self::new();
s.drain_timeout = drain;
s.persist_timeout = persist;
s
}
/// Get current phase.
pub fn phase(&self) -> ShutdownPhase {
self.phase.load(Ordering::SeqCst).into()
}
/// Check if shutdown is in progress.
pub fn is_shutting_down(&self) -> bool {
self.phase() != ShutdownPhase::Running
}
/// Subscribe to shutdown notifications.
pub fn subscribe(&self) -> broadcast::Receiver<ShutdownPhase> {
self.shutdown_tx.subscribe()
}
/// Register a task.
pub fn register_task(&self) -> TaskGuard {
self.active_tasks.fetch_add(1, Ordering::SeqCst);
TaskGuard {
active_tasks: &self.active_tasks,
all_done: Arc::clone(&self.all_done),
}
}
/// Initiate shutdown.
pub async fn shutdown(&self) {
// Phase 1: Draining
self.set_phase(ShutdownPhase::Draining);
// Wait for tasks to complete or timeout
let drain_result = timeout(
self.drain_timeout,
self.wait_for_tasks(),
).await;
if drain_result.is_err() {
tracing::warn!(
"drain timeout reached with {} tasks remaining",
self.active_tasks.load(Ordering::SeqCst)
);
}
// Phase 2: Persisting
self.set_phase(ShutdownPhase::Persisting);
// Give persist hooks time to run
tokio::time::sleep(Duration::from_millis(100)).await;
// Phase 3: Cleanup
self.set_phase(ShutdownPhase::Cleanup);
tokio::time::sleep(Duration::from_millis(100)).await;
// Complete
self.set_phase(ShutdownPhase::Complete);
}
fn set_phase(&self, phase: ShutdownPhase) {
self.phase.store(phase as u8, Ordering::SeqCst);
let _ = self.shutdown_tx.send(phase);
}
async fn wait_for_tasks(&self) {
loop {
// Enable the Notified future before checking the count: notify_waiters()
// only wakes registered waiters, so a task finishing between the check
// and the await would otherwise be missed.
let notified = self.all_done.notified();
tokio::pin!(notified);
notified.as_mut().enable();
if self.active_tasks.load(Ordering::SeqCst) == 0 {
break;
}
notified.await;
}
}
}
impl Default for ShutdownCoordinator {
fn default() -> Self {
Self::new()
}
}
/// RAII guard for tracking active tasks.
pub struct TaskGuard<'a> {
active_tasks: &'a AtomicUsize,
all_done: Arc<Notify>,
}
impl<'a> Drop for TaskGuard<'a> {
fn drop(&mut self) {
let prev = self.active_tasks.fetch_sub(1, Ordering::SeqCst);
if prev == 1 {
self.all_done.notify_waiters();
}
}
}
/// Shutdown handle for use in async tasks.
#[derive(Clone)]
pub struct ShutdownSignal {
/// Watch receiver for shutdown.
watch_rx: watch::Receiver<bool>,
}
impl ShutdownSignal {
/// Create a new signal pair.
pub fn new() -> (ShutdownTrigger, Self) {
let (tx, rx) = watch::channel(false);
(ShutdownTrigger { watch_tx: tx }, Self { watch_rx: rx })
}
/// Check if shutdown has been triggered.
pub fn is_triggered(&self) -> bool {
*self.watch_rx.borrow()
}
/// Wait for shutdown signal.
pub async fn wait(&mut self) {
let _ = self.watch_rx.wait_for(|&triggered| triggered).await;
}
/// Create a future that completes on shutdown.
pub fn recv(&mut self) -> impl Future<Output = ()> + '_ {
async move {
self.wait().await
}
}
}
impl Default for ShutdownSignal {
fn default() -> Self {
Self::new().1
}
}
/// Trigger for shutdown signal.
#[derive(Clone)]
pub struct ShutdownTrigger {
watch_tx: watch::Sender<bool>,
}
impl ShutdownTrigger {
/// Trigger shutdown.
pub fn trigger(&self) {
let _ = self.watch_tx.send(true);
}
}
/// Shutdown hook type.
pub type ShutdownHook = Box<
dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send
>;
/// Manages shutdown hooks.
pub struct ShutdownHooks {
persist_hooks: Vec<ShutdownHook>,
cleanup_hooks: Vec<ShutdownHook>,
}
impl ShutdownHooks {
pub fn new() -> Self {
Self {
persist_hooks: Vec::new(),
cleanup_hooks: Vec::new(),
}
}
/// Register a persist hook (runs during Persisting phase).
pub fn on_persist<F, Fut>(&mut self, f: F)
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.persist_hooks.push(Box::new(|| Box::pin(f())));
}
/// Register a cleanup hook (runs during Cleanup phase).
pub fn on_cleanup<F, Fut>(&mut self, f: F)
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.cleanup_hooks.push(Box::new(|| Box::pin(f())));
}
/// Run all persist hooks.
pub async fn run_persist(&mut self) {
for hook in self.persist_hooks.drain(..) {
hook().await;
}
}
/// Run all cleanup hooks.
pub async fn run_cleanup(&mut self) {
for hook in self.cleanup_hooks.drain(..) {
hook().await;
}
}
}
impl Default for ShutdownHooks {
fn default() -> Self {
Self::new()
}
}
/// Draining connection tracker.
pub struct ConnectionDrainer {
/// Maximum connections to track.
max_connections: usize,
/// Active connections.
active: AtomicUsize,
/// Notify when connection count changes.
notify: Notify,
/// Stopped accepting new connections.
draining: AtomicBool,
}
impl ConnectionDrainer {
pub fn new(max_connections: usize) -> Self {
Self {
max_connections,
active: AtomicUsize::new(0),
notify: Notify::new(),
draining: AtomicBool::new(false),
}
}
/// Try to accept a new connection.
pub fn try_accept(&self) -> Option<ConnectionGuard<'_>> {
if self.draining.load(Ordering::SeqCst) {
return None;
}
let current = self.active.fetch_add(1, Ordering::SeqCst);
if current >= self.max_connections {
self.active.fetch_sub(1, Ordering::SeqCst);
return None;
}
Some(ConnectionGuard { drainer: self })
}
/// Start draining (stop accepting new connections).
pub fn start_drain(&self) {
self.draining.store(true, Ordering::SeqCst);
}
/// Wait for all connections to close.
pub async fn wait_drained(&self) {
loop {
// Enable before checking so a connection closing in between cannot
// drop its wakeup (notify_waiters only wakes registered waiters).
let notified = self.notify.notified();
tokio::pin!(notified);
notified.as_mut().enable();
if self.active.load(Ordering::SeqCst) == 0 {
break;
}
notified.await;
}
}
/// Current connection count.
pub fn active_count(&self) -> usize {
self.active.load(Ordering::SeqCst)
}
/// Is draining?
pub fn is_draining(&self) -> bool {
self.draining.load(Ordering::SeqCst)
}
}
/// RAII guard for active connections.
pub struct ConnectionGuard<'a> {
drainer: &'a ConnectionDrainer,
}
impl<'a> Drop for ConnectionGuard<'a> {
fn drop(&mut self) {
self.drainer.active.fetch_sub(1, Ordering::SeqCst);
self.drainer.notify.notify_waiters();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn shutdown_phases() {
let coord = ShutdownCoordinator::with_timeouts(
Duration::from_millis(100),
Duration::from_millis(50),
);
assert_eq!(coord.phase(), ShutdownPhase::Running);
assert!(!coord.is_shutting_down());
let mut rx = coord.subscribe();
tokio::spawn(async move {
coord.shutdown().await;
});
// Should receive phase transitions
let phase = rx.recv().await.unwrap();
assert_eq!(phase, ShutdownPhase::Draining);
let phase = rx.recv().await.unwrap();
assert_eq!(phase, ShutdownPhase::Persisting);
let phase = rx.recv().await.unwrap();
assert_eq!(phase, ShutdownPhase::Cleanup);
let phase = rx.recv().await.unwrap();
assert_eq!(phase, ShutdownPhase::Complete);
}
#[tokio::test]
async fn task_tracking() {
let coord = ShutdownCoordinator::with_timeouts(
Duration::from_secs(1),
Duration::from_millis(50),
);
let guard1 = coord.register_task();
let guard2 = coord.register_task();
assert_eq!(coord.active_tasks.load(Ordering::SeqCst), 2);
drop(guard1);
assert_eq!(coord.active_tasks.load(Ordering::SeqCst), 1);
drop(guard2);
assert_eq!(coord.active_tasks.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn shutdown_signal() {
let (trigger, mut signal) = ShutdownSignal::new();
assert!(!signal.is_triggered());
let handle = tokio::spawn(async move {
signal.wait().await;
true
});
trigger.trigger();
assert!(handle.await.unwrap());
}
#[tokio::test]
async fn connection_drainer() {
let drainer = ConnectionDrainer::new(2);
let conn1 = drainer.try_accept().expect("should accept");
let conn2 = drainer.try_accept().expect("should accept");
assert!(drainer.try_accept().is_none()); // At capacity
assert_eq!(drainer.active_count(), 2);
drop(conn1);
assert_eq!(drainer.active_count(), 1);
drainer.start_drain();
assert!(drainer.try_accept().is_none()); // Draining
drop(conn2);
// Should complete immediately
tokio::time::timeout(
Duration::from_millis(100),
drainer.wait_drained(),
).await.expect("should drain quickly");
}
#[tokio::test]
async fn shutdown_hooks() {
use std::sync::atomic::AtomicBool;
let persist_ran = Arc::new(AtomicBool::new(false));
let cleanup_ran = Arc::new(AtomicBool::new(false));
let persist_flag = Arc::clone(&persist_ran);
let cleanup_flag = Arc::clone(&cleanup_ran);
let mut hooks = ShutdownHooks::new();
hooks.on_persist(move || async move {
persist_flag.store(true, Ordering::SeqCst);
});
hooks.on_cleanup(move || async move {
cleanup_flag.store(true, Ordering::SeqCst);
});
hooks.run_persist().await;
assert!(persist_ran.load(Ordering::SeqCst));
assert!(!cleanup_ran.load(Ordering::SeqCst));
hooks.run_cleanup().await;
assert!(cleanup_ran.load(Ordering::SeqCst));
}
}