From 024b6c91d11a8aa765d020e58f8b4c10a01c1758 Mon Sep 17 00:00:00 2001 From: Christian Nennemann Date: Wed, 1 Apr 2026 09:16:44 +0200 Subject: [PATCH] feat(p2p): add production infrastructure modules - error.rs: Structured error types with context for all subsystems (transport, routing, crypto, protocol, store, config) - config.rs: Runtime configuration with TOML parsing and validation - metrics.rs: Counter/gauge/histogram metrics with transport-specific tracking and JSON-serializable snapshots - rate_limit.rs: Token bucket rate limiting with per-peer tracking, duty cycle enforcement for LoRa, and backpressure control These modules provide the foundation for production deployment. --- Cargo.lock | 18 + crates/quicprochat-p2p/Cargo.toml | 4 + crates/quicprochat-p2p/src/config.rs | 460 +++++++++++++++++++++ crates/quicprochat-p2p/src/error.rs | 354 ++++++++++++++++ crates/quicprochat-p2p/src/lib.rs | 4 + crates/quicprochat-p2p/src/metrics.rs | 502 +++++++++++++++++++++++ crates/quicprochat-p2p/src/rate_limit.rs | 482 ++++++++++++++++++++++ 7 files changed, 1824 insertions(+) create mode 100644 crates/quicprochat-p2p/src/config.rs create mode 100644 crates/quicprochat-p2p/src/error.rs create mode 100644 crates/quicprochat-p2p/src/metrics.rs create mode 100644 crates/quicprochat-p2p/src/rate_limit.rs diff --git a/Cargo.lock b/Cargo.lock index 80c4700..7f759dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2157,6 +2157,22 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + +[[package]] +name = "humantime-serde" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"57a3db5ea5923d99402c94e9feb261dc5ee9b4efa158b0315f788cf549cc200c" +dependencies = [ + "humantime", + "serde", +] + [[package]] name = "hybrid-array" version = "0.2.3" @@ -4454,6 +4470,7 @@ dependencies = [ "ciborium", "hex", "hkdf", + "humantime-serde", "iroh", "quicprochat-core", "rand 0.8.5", @@ -4463,6 +4480,7 @@ dependencies = [ "tempfile", "thiserror 1.0.69", "tokio", + "toml", "tracing", "x25519-dalek", "zeroize", diff --git a/crates/quicprochat-p2p/Cargo.toml b/crates/quicprochat-p2p/Cargo.toml index fa131fd..73c5d61 100644 --- a/crates/quicprochat-p2p/Cargo.toml +++ b/crates/quicprochat-p2p/Cargo.toml @@ -37,6 +37,10 @@ x25519-dalek = { workspace = true } hkdf = { workspace = true } thiserror = { workspace = true } +# Configuration +toml = "0.8" +humantime-serde = "1" + [dev-dependencies] tempfile = "3" diff --git a/crates/quicprochat-p2p/src/config.rs b/crates/quicprochat-p2p/src/config.rs new file mode 100644 index 0000000..f18dc16 --- /dev/null +++ b/crates/quicprochat-p2p/src/config.rs @@ -0,0 +1,460 @@ +//! Runtime configuration for mesh networking. +//! +//! This module provides centralized configuration with sensible defaults +//! and validation. Configuration can be loaded from files, environment +//! variables, or set programmatically. + +use std::path::PathBuf; +use std::time::Duration; + +use serde::{Deserialize, Serialize}; + +use crate::error::{ConfigError, MeshResult}; +use crate::transport::CryptoMode; + +/// Top-level mesh node configuration. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct MeshConfig { + /// Node identity configuration. + pub identity: IdentityConfig, + /// Announce protocol configuration. + pub announce: AnnounceConfig, + /// Routing configuration. + pub routing: RoutingConfig, + /// Store-and-forward configuration. + pub store: StoreConfig, + /// Transport configuration. + pub transport: TransportConfig, + /// Crypto configuration. 
+ pub crypto: CryptoConfig, + /// Rate limiting configuration. + pub rate_limit: RateLimitConfig, + /// Logging configuration. + pub logging: LoggingConfig, +} + +impl Default for MeshConfig { + fn default() -> Self { + Self { + identity: IdentityConfig::default(), + announce: AnnounceConfig::default(), + routing: RoutingConfig::default(), + store: StoreConfig::default(), + transport: TransportConfig::default(), + crypto: CryptoConfig::default(), + rate_limit: RateLimitConfig::default(), + logging: LoggingConfig::default(), + } + } +} + +impl MeshConfig { + /// Load configuration from a TOML file. + pub fn from_file(path: &PathBuf) -> MeshResult { + let content = std::fs::read_to_string(path).map_err(|e| { + ConfigError::Parse(format!("failed to read config file: {}", e)) + })?; + Self::from_toml(&content) + } + + /// Parse configuration from TOML string. + pub fn from_toml(toml: &str) -> MeshResult { + let config: Self = toml::from_str(toml).map_err(|e| { + ConfigError::Parse(format!("TOML parse error: {}", e)) + })?; + config.validate()?; + Ok(config) + } + + /// Serialize to TOML string. + pub fn to_toml(&self) -> MeshResult { + toml::to_string_pretty(self).map_err(|e| { + ConfigError::Parse(format!("TOML serialize error: {}", e)).into() + }) + } + + /// Validate configuration values. + pub fn validate(&self) -> MeshResult<()> { + self.announce.validate()?; + self.routing.validate()?; + self.store.validate()?; + self.rate_limit.validate()?; + Ok(()) + } + + /// Create a minimal config for constrained devices. 
+ pub fn constrained() -> Self { + Self { + store: StoreConfig { + max_messages: 100, + max_keypackages: 50, + ..Default::default() + }, + routing: RoutingConfig { + max_entries: 100, + ..Default::default() + }, + announce: AnnounceConfig { + interval: Duration::from_secs(1800), // 30 min + ..Default::default() + }, + crypto: CryptoConfig { + default_mode: CryptoMode::MlsLiteUnsigned, + ..Default::default() + }, + ..Default::default() + } + } +} + +/// Identity configuration. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct IdentityConfig { + /// Path to persist identity keypair. + pub keypair_path: Option, + /// Whether to auto-generate keypair if missing. + pub auto_generate: bool, +} + +impl Default for IdentityConfig { + fn default() -> Self { + Self { + keypair_path: None, + auto_generate: true, + } + } +} + +/// Announce protocol configuration. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct AnnounceConfig { + /// Interval between periodic announcements. + #[serde(with = "humantime_serde")] + pub interval: Duration, + /// Maximum age before announce is considered stale. + #[serde(with = "humantime_serde")] + pub max_age: Duration, + /// Maximum propagation hops. + pub max_hops: u8, + /// Capabilities to advertise. + pub capabilities: u16, + /// Whether to include KeyPackage hash in announces. 
+ pub include_keypackage: bool, +} + +impl Default for AnnounceConfig { + fn default() -> Self { + Self { + interval: Duration::from_secs(600), // 10 min + max_age: Duration::from_secs(1800), // 30 min + max_hops: 8, + capabilities: 0x0003, // CAP_RELAY | CAP_STORE + include_keypackage: true, + } + } +} + +impl AnnounceConfig { + fn validate(&self) -> MeshResult<()> { + if self.interval < Duration::from_secs(10) { + return Err(ConfigError::InvalidValue { + key: "announce.interval".to_string(), + reason: "must be at least 10 seconds".to_string(), + }.into()); + } + if self.max_hops == 0 || self.max_hops > 32 { + return Err(ConfigError::InvalidValue { + key: "announce.max_hops".to_string(), + reason: "must be between 1 and 32".to_string(), + }.into()); + } + Ok(()) + } +} + +/// Routing configuration. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct RoutingConfig { + /// Maximum routing table entries. + pub max_entries: usize, + /// Default route TTL. + #[serde(with = "humantime_serde")] + pub default_ttl: Duration, + /// How often to garbage collect expired routes. + #[serde(with = "humantime_serde")] + pub gc_interval: Duration, +} + +impl Default for RoutingConfig { + fn default() -> Self { + Self { + max_entries: 10_000, + default_ttl: Duration::from_secs(1800), // 30 min + gc_interval: Duration::from_secs(60), + } + } +} + +impl RoutingConfig { + fn validate(&self) -> MeshResult<()> { + if self.max_entries == 0 { + return Err(ConfigError::InvalidValue { + key: "routing.max_entries".to_string(), + reason: "must be at least 1".to_string(), + }.into()); + } + Ok(()) + } +} + +/// Store-and-forward configuration. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct StoreConfig { + /// Maximum messages in store. + pub max_messages: usize, + /// Maximum messages per recipient. + pub max_per_recipient: usize, + /// Maximum cached KeyPackages. + pub max_keypackages: usize, + /// Maximum KeyPackages per address. 
+ pub max_keypackages_per_addr: usize, + /// Default message TTL. + #[serde(with = "humantime_serde")] + pub default_ttl: Duration, + /// Path for persistent storage (None = in-memory only). + pub persistence_path: Option, +} + +impl Default for StoreConfig { + fn default() -> Self { + Self { + max_messages: 10_000, + max_per_recipient: 100, + max_keypackages: 1_000, + max_keypackages_per_addr: 3, + default_ttl: Duration::from_secs(24 * 3600), // 24 hours + persistence_path: None, + } + } +} + +impl StoreConfig { + fn validate(&self) -> MeshResult<()> { + if self.max_messages == 0 { + return Err(ConfigError::InvalidValue { + key: "store.max_messages".to_string(), + reason: "must be at least 1".to_string(), + }.into()); + } + Ok(()) + } +} + +/// Transport configuration. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct TransportConfig { + /// Enable iroh/QUIC transport. + pub enable_iroh: bool, + /// Enable TCP transport. + pub enable_tcp: bool, + /// TCP listen address. + pub tcp_listen: Option, + /// Enable LoRa transport. + pub enable_lora: bool, + /// LoRa device path (e.g., /dev/ttyUSB0). + pub lora_device: Option, + /// LoRa spreading factor (7-12). + pub lora_sf: u8, + /// LoRa bandwidth in kHz. + pub lora_bw: u32, + /// Connection timeout. + #[serde(with = "humantime_serde")] + pub connect_timeout: Duration, + /// Send timeout. + #[serde(with = "humantime_serde")] + pub send_timeout: Duration, +} + +impl Default for TransportConfig { + fn default() -> Self { + Self { + enable_iroh: true, + enable_tcp: true, + tcp_listen: None, + enable_lora: false, + lora_device: None, + lora_sf: 10, + lora_bw: 125, + connect_timeout: Duration::from_secs(10), + send_timeout: Duration::from_secs(30), + } + } +} + +/// Crypto configuration. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct CryptoConfig { + /// Default crypto mode. 
+ pub default_mode: CryptoMode, + /// Whether to auto-upgrade to better crypto when available. + pub auto_upgrade: bool, + /// Whether to sign MLS-Lite messages. + pub mls_lite_sign: bool, + /// Enable post-quantum hybrid mode. + pub enable_pq: bool, +} + +impl Default for CryptoConfig { + fn default() -> Self { + Self { + default_mode: CryptoMode::MlsClassical, + auto_upgrade: true, + mls_lite_sign: true, + enable_pq: false, // PQ is large, opt-in + } + } +} + +/// Rate limiting configuration. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct RateLimitConfig { + /// Maximum announces per peer per minute. + pub announce_per_peer_per_min: u32, + /// Maximum messages per peer per minute. + pub message_per_peer_per_min: u32, + /// Maximum KeyPackage requests per minute. + pub keypackage_requests_per_min: u32, + /// LoRa duty cycle limit (0.0-1.0, e.g., 0.01 = 1%). + pub lora_duty_cycle: f32, +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + announce_per_peer_per_min: 10, + message_per_peer_per_min: 60, + keypackage_requests_per_min: 20, + lora_duty_cycle: 0.01, // EU868 1% default + } + } +} + +impl RateLimitConfig { + fn validate(&self) -> MeshResult<()> { + if self.lora_duty_cycle < 0.0 || self.lora_duty_cycle > 1.0 { + return Err(ConfigError::InvalidValue { + key: "rate_limit.lora_duty_cycle".to_string(), + reason: "must be between 0.0 and 1.0".to_string(), + }.into()); + } + Ok(()) + } +} + +/// Logging configuration. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(default)] +pub struct LoggingConfig { + /// Log level (trace, debug, info, warn, error). + pub level: String, + /// Whether to log to file. + pub file: Option, + /// Whether to include timestamps. + pub timestamps: bool, + /// Whether to include span context. 
+ pub spans: bool, +} + +impl Default for LoggingConfig { + fn default() -> Self { + Self { + level: "info".to_string(), + file: None, + timestamps: true, + spans: false, + } + } +} + +// Serde helper for CryptoMode +impl Serialize for CryptoMode { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let s = match self { + CryptoMode::MlsHybrid => "mls-hybrid", + CryptoMode::MlsClassical => "mls-classical", + CryptoMode::MlsLiteSigned => "mls-lite-signed", + CryptoMode::MlsLiteUnsigned => "mls-lite-unsigned", + }; + serializer.serialize_str(s) + } +} + +impl<'de> Deserialize<'de> for CryptoMode { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + match s.as_str() { + "mls-hybrid" => Ok(CryptoMode::MlsHybrid), + "mls-classical" => Ok(CryptoMode::MlsClassical), + "mls-lite-signed" => Ok(CryptoMode::MlsLiteSigned), + "mls-lite-unsigned" => Ok(CryptoMode::MlsLiteUnsigned), + _ => Err(serde::de::Error::unknown_variant( + &s, + &["mls-hybrid", "mls-classical", "mls-lite-signed", "mls-lite-unsigned"], + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_config_is_valid() { + let config = MeshConfig::default(); + assert!(config.validate().is_ok()); + } + + #[test] + fn constrained_config_is_valid() { + let config = MeshConfig::constrained(); + assert!(config.validate().is_ok()); + assert_eq!(config.store.max_messages, 100); + } + + #[test] + fn toml_roundtrip() { + let config = MeshConfig::default(); + let toml = config.to_toml().expect("serialize"); + let restored = MeshConfig::from_toml(&toml).expect("parse"); + assert_eq!(config.announce.max_hops, restored.announce.max_hops); + } + + #[test] + fn invalid_announce_interval() { + let mut config = MeshConfig::default(); + config.announce.interval = Duration::from_secs(1); // Too short + assert!(config.validate().is_err()); + } + + #[test] + fn invalid_duty_cycle() { + let mut 
config = MeshConfig::default(); + config.rate_limit.lora_duty_cycle = 2.0; // > 1.0 + assert!(config.validate().is_err()); + } +} diff --git a/crates/quicprochat-p2p/src/error.rs b/crates/quicprochat-p2p/src/error.rs new file mode 100644 index 0000000..eb6434a --- /dev/null +++ b/crates/quicprochat-p2p/src/error.rs @@ -0,0 +1,354 @@ +//! Production-ready error types for the mesh P2P layer. +//! +//! This module provides structured error types with context for debugging +//! and recovery. Errors are categorized by subsystem for easier handling. + +use std::fmt; + +use thiserror::Error; + +use crate::address::MeshAddress; +use crate::transport::TransportAddr; + +/// Top-level mesh error type. +#[derive(Debug, Error)] +pub enum MeshError { + /// Transport layer errors. + #[error("transport error: {0}")] + Transport(#[from] TransportError), + + /// Routing errors. + #[error("routing error: {0}")] + Routing(#[from] RoutingError), + + /// Crypto/encryption errors. + #[error("crypto error: {0}")] + Crypto(#[from] CryptoError), + + /// Protocol errors (malformed messages, version mismatch). + #[error("protocol error: {0}")] + Protocol(#[from] ProtocolError), + + /// Store/cache errors. + #[error("store error: {0}")] + Store(#[from] StoreError), + + /// Configuration errors. + #[error("config error: {0}")] + Config(#[from] ConfigError), + + /// Internal errors (bugs, invariant violations). + #[error("internal error: {0}")] + Internal(String), +} + +/// Transport layer errors. +#[derive(Debug, Error)] +pub enum TransportError { + /// Failed to send data. + #[error("send failed to {dest}: {reason}")] + SendFailed { dest: String, reason: String }, + + /// Failed to receive data. + #[error("receive failed: {0}")] + ReceiveFailed(String), + + /// Connection failed or lost. + #[error("connection to {dest} failed: {reason}")] + ConnectionFailed { dest: String, reason: String }, + + /// Transport not available. 
+ #[error("transport '{name}' not available")] + NotAvailable { name: String }, + + /// No transports registered. + #[error("no transports registered")] + NoTransports, + + /// MTU exceeded. + #[error("payload {size} bytes exceeds MTU {mtu} bytes")] + MtuExceeded { size: usize, mtu: usize }, + + /// Duty cycle limit reached. + #[error("duty cycle limit reached: {used_ms}ms used of {limit_ms}ms allowed")] + DutyCycleExceeded { used_ms: u64, limit_ms: u64 }, + + /// Timeout waiting for response. + #[error("timeout waiting for response from {dest}")] + Timeout { dest: String }, + + /// I/O error. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), +} + +/// Routing errors. +#[derive(Debug, Error)] +pub enum RoutingError { + /// No route to destination. + #[error("no route to {0}")] + NoRoute(String), + + /// Route expired. + #[error("route to {dest} expired (last seen {age_secs}s ago)")] + RouteExpired { dest: String, age_secs: u64 }, + + /// Too many hops. + #[error("max hops ({max}) exceeded for message to {dest}")] + MaxHopsExceeded { dest: String, max: u8 }, + + /// Message expired. + #[error("message expired (TTL {ttl_secs}s, age {age_secs}s)")] + MessageExpired { ttl_secs: u32, age_secs: u64 }, + + /// Duplicate message (dedup). + #[error("duplicate message ID {0}")] + Duplicate(String), + + /// Routing table full. + #[error("routing table full ({capacity} entries)")] + TableFull { capacity: usize }, +} + +/// Crypto/encryption errors. +#[derive(Debug, Error)] +pub enum CryptoError { + /// Signature verification failed. + #[error("signature verification failed for {context}")] + SignatureInvalid { context: String }, + + /// Decryption failed. + #[error("decryption failed: {0}")] + DecryptionFailed(String), + + /// Key not found. + #[error("key not found for {0}")] + KeyNotFound(String), + + /// KeyPackage invalid or expired. + #[error("KeyPackage invalid: {0}")] + KeyPackageInvalid(String), + + /// Replay attack detected. 
+ #[error("replay detected: sequence {seq} already seen from {sender}")] + ReplayDetected { sender: String, seq: u32 }, + + /// Wrong epoch. + #[error("wrong epoch: expected {expected}, got {got}")] + WrongEpoch { expected: u16, got: u16 }, + + /// MLS error (from openmls). + #[error("MLS error: {0}")] + Mls(String), +} + +/// Protocol errors. +#[derive(Debug, Error)] +pub enum ProtocolError { + /// Unknown message type. + #[error("unknown message type: 0x{0:02x}")] + UnknownMessageType(u8), + + /// Invalid message format. + #[error("invalid message format: {0}")] + InvalidFormat(String), + + /// Version mismatch. + #[error("protocol version mismatch: expected {expected}, got {got}")] + VersionMismatch { expected: u8, got: u8 }, + + /// Required field missing. + #[error("required field missing: {0}")] + MissingField(String), + + /// CBOR decode error. + #[error("CBOR decode error: {0}")] + CborDecode(String), + + /// CBOR encode error. + #[error("CBOR encode error: {0}")] + CborEncode(String), + + /// Message too large. + #[error("message too large: {size} bytes (max {max})")] + MessageTooLarge { size: usize, max: usize }, +} + +/// Store/cache errors. +#[derive(Debug, Error)] +pub enum StoreError { + /// Store is full. + #[error("store full: {current}/{capacity} items")] + Full { current: usize, capacity: usize }, + + /// Item not found. + #[error("item not found: {0}")] + NotFound(String), + + /// Persistence error. + #[error("persistence error: {0}")] + Persistence(String), + + /// Serialization error. + #[error("serialization error: {0}")] + Serialization(String), +} + +/// Configuration errors. +#[derive(Debug, Error)] +pub enum ConfigError { + /// Invalid configuration value. + #[error("invalid config value for '{key}': {reason}")] + InvalidValue { key: String, reason: String }, + + /// Missing required configuration. + #[error("missing required config: {0}")] + Missing(String), + + /// Configuration parse error. 
+ #[error("config parse error: {0}")] + Parse(String), +} + +/// Result type alias for mesh operations. +pub type MeshResult = Result; + +/// Error context extension trait for adding context to errors. +pub trait ErrorContext { + /// Add context to an error. + fn context(self, context: impl Into) -> MeshResult; + + /// Add context with a closure (lazy evaluation). + fn with_context(self, f: F) -> MeshResult + where + F: FnOnce() -> String; +} + +impl> ErrorContext for Result { + fn context(self, context: impl Into) -> MeshResult { + self.map_err(|e| { + let err = e.into(); + MeshError::Internal(format!("{}: {}", context.into(), err)) + }) + } + + fn with_context(self, f: F) -> MeshResult + where + F: FnOnce() -> String, + { + self.map_err(|e| { + let err = e.into(); + MeshError::Internal(format!("{}: {}", f(), err)) + }) + } +} + +/// Convert anyhow errors to MeshError. +impl From for MeshError { + fn from(e: anyhow::Error) -> Self { + MeshError::Internal(e.to_string()) + } +} + +/// Helper to create transport send errors. +impl TransportError { + pub fn send_failed(dest: &TransportAddr, reason: impl Into) -> Self { + Self::SendFailed { + dest: dest.to_string(), + reason: reason.into(), + } + } + + pub fn connection_failed(dest: &TransportAddr, reason: impl Into) -> Self { + Self::ConnectionFailed { + dest: dest.to_string(), + reason: reason.into(), + } + } +} + +/// Helper to create routing errors. +impl RoutingError { + pub fn no_route(addr: &MeshAddress) -> Self { + Self::NoRoute(format!("{}", addr)) + } + + pub fn no_route_bytes(addr: &[u8]) -> Self { + Self::NoRoute(hex::encode(&addr[..8.min(addr.len())])) + } +} + +/// Helper to create crypto errors. 
+impl CryptoError { + pub fn signature_invalid(context: impl Into) -> Self { + Self::SignatureInvalid { + context: context.into(), + } + } + + pub fn replay(sender: &MeshAddress, seq: u32) -> Self { + Self::ReplayDetected { + sender: format!("{}", sender), + seq, + } + } +} + +/// Helper to create protocol errors. +impl ProtocolError { + pub fn cbor_decode(e: impl fmt::Display) -> Self { + Self::CborDecode(e.to_string()) + } + + pub fn cbor_encode(e: impl fmt::Display) -> Self { + Self::CborEncode(e.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn error_display() { + let err = TransportError::SendFailed { + dest: "tcp:127.0.0.1:8080".to_string(), + reason: "connection refused".to_string(), + }; + assert!(err.to_string().contains("tcp:127.0.0.1:8080")); + assert!(err.to_string().contains("connection refused")); + } + + #[test] + fn error_conversion() { + let transport_err = TransportError::NoTransports; + let mesh_err: MeshError = transport_err.into(); + assert!(matches!(mesh_err, MeshError::Transport(_))); + } + + #[test] + fn routing_error_helpers() { + let addr = MeshAddress::from_bytes([0xAB; 16]); + let err = RoutingError::no_route(&addr); + assert!(err.to_string().contains("no route")); + } + + #[test] + fn crypto_error_helpers() { + let addr = MeshAddress::from_bytes([0xCD; 16]); + let err = CryptoError::replay(&addr, 42); + assert!(err.to_string().contains("42")); + } + + #[test] + fn context_extension() { + fn fallible() -> Result<(), TransportError> { + Err(TransportError::NoTransports) + } + + let result: MeshResult<()> = fallible().context("during startup"); + assert!(result.is_err()); + let err_str = result.unwrap_err().to_string(); + assert!(err_str.contains("during startup")); + } +} diff --git a/crates/quicprochat-p2p/src/lib.rs b/crates/quicprochat-p2p/src/lib.rs index 0380839..13299ea 100644 --- a/crates/quicprochat-p2p/src/lib.rs +++ b/crates/quicprochat-p2p/src/lib.rs @@ -15,7 +15,9 @@ pub mod address; pub mod 
announce; pub mod announce_protocol; +pub mod config; pub mod crypto_negotiation; +pub mod error; pub mod fapp; pub mod fapp_router; pub mod broadcast; @@ -23,7 +25,9 @@ pub mod envelope; pub mod envelope_v2; pub mod keypackage_cache; pub mod mesh_protocol; +pub mod metrics; pub mod mls_lite; +pub mod rate_limit; pub mod identity; pub mod link; pub mod mesh_router; diff --git a/crates/quicprochat-p2p/src/metrics.rs b/crates/quicprochat-p2p/src/metrics.rs new file mode 100644 index 0000000..10a497c --- /dev/null +++ b/crates/quicprochat-p2p/src/metrics.rs @@ -0,0 +1,502 @@ +//! Observability metrics for mesh networking. +//! +//! This module provides structured metrics collection for monitoring +//! mesh node health, performance, and resource usage. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, RwLock}; +use std::time::{Duration, Instant}; + +/// Atomic counter for thread-safe metric updates. +#[derive(Debug, Default)] +pub struct Counter(AtomicU64); + +impl Counter { + pub fn new() -> Self { + Self(AtomicU64::new(0)) + } + + pub fn inc(&self) { + self.0.fetch_add(1, Ordering::Relaxed); + } + + pub fn inc_by(&self, n: u64) { + self.0.fetch_add(n, Ordering::Relaxed); + } + + pub fn get(&self) -> u64 { + self.0.load(Ordering::Relaxed) + } + + pub fn reset(&self) -> u64 { + self.0.swap(0, Ordering::Relaxed) + } +} + +/// Gauge for values that can go up and down. +#[derive(Debug, Default)] +pub struct Gauge(AtomicU64); + +impl Gauge { + pub fn new() -> Self { + Self(AtomicU64::new(0)) + } + + pub fn set(&self, val: u64) { + self.0.store(val, Ordering::Relaxed); + } + + pub fn inc(&self) { + self.0.fetch_add(1, Ordering::Relaxed); + } + + pub fn dec(&self) { + self.0.fetch_sub(1, Ordering::Relaxed); + } + + pub fn get(&self) -> u64 { + self.0.load(Ordering::Relaxed) + } +} + +/// Histogram for tracking distributions (simple bucket-based). 
+#[derive(Debug)] +pub struct Histogram { + /// Bucket boundaries (upper limits). + buckets: Vec, + /// Count in each bucket. + counts: Vec, + /// Sum of all values. + sum: AtomicU64, + /// Total count. + count: AtomicU64, +} + +impl Histogram { + /// Create with default latency buckets (ms). + pub fn latency_ms() -> Self { + Self::new(vec![1, 5, 10, 25, 50, 100, 250, 500, 1000, 5000, 10000]) + } + + /// Create with default size buckets (bytes). + pub fn size_bytes() -> Self { + Self::new(vec![64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 65536]) + } + + pub fn new(buckets: Vec) -> Self { + let counts = buckets.iter().map(|_| AtomicU64::new(0)).collect(); + Self { + buckets, + counts, + sum: AtomicU64::new(0), + count: AtomicU64::new(0), + } + } + + pub fn observe(&self, value: u64) { + self.sum.fetch_add(value, Ordering::Relaxed); + self.count.fetch_add(1, Ordering::Relaxed); + + for (i, &upper) in self.buckets.iter().enumerate() { + if value <= upper { + self.counts[i].fetch_add(1, Ordering::Relaxed); + return; + } + } + // Value exceeds all buckets — count in last + if let Some(last) = self.counts.last() { + last.fetch_add(1, Ordering::Relaxed); + } + } + + pub fn observe_duration(&self, d: Duration) { + self.observe(d.as_millis() as u64); + } + + pub fn sum(&self) -> u64 { + self.sum.load(Ordering::Relaxed) + } + + pub fn count(&self) -> u64 { + self.count.load(Ordering::Relaxed) + } + + pub fn avg(&self) -> f64 { + let count = self.count(); + if count == 0 { + 0.0 + } else { + self.sum() as f64 / count as f64 + } + } +} + +/// Per-transport metrics. +#[derive(Debug, Default)] +pub struct TransportMetrics { + /// Messages sent successfully. + pub sent: Counter, + /// Messages received. + pub received: Counter, + /// Send failures. + pub send_errors: Counter, + /// Receive errors. + pub recv_errors: Counter, + /// Bytes sent. + pub bytes_sent: Counter, + /// Bytes received. 
+ pub bytes_received: Counter, + /// Active connections (for connection-oriented transports). + pub connections: Gauge, +} + +/// Per-peer metrics. +#[derive(Debug)] +pub struct PeerMetrics { + /// Messages sent to this peer. + pub messages_sent: Counter, + /// Messages received from this peer. + pub messages_received: Counter, + /// Last seen timestamp. + pub last_seen: RwLock>, + /// Round-trip time samples. + pub rtt_ms: Histogram, +} + +impl Default for PeerMetrics { + fn default() -> Self { + Self { + messages_sent: Counter::new(), + messages_received: Counter::new(), + last_seen: RwLock::new(None), + rtt_ms: Histogram::latency_ms(), + } + } +} + +impl PeerMetrics { + pub fn touch(&self) { + if let Ok(mut last) = self.last_seen.write() { + *last = Some(Instant::now()); + } + } + + pub fn age(&self) -> Option { + self.last_seen + .read() + .ok() + .and_then(|t| t.map(|i| i.elapsed())) + } +} + +/// Global mesh metrics. +#[derive(Debug)] +pub struct MeshMetrics { + /// Transport metrics by name. + pub transports: RwLock>>, + /// Routing metrics. + pub routing: RoutingMetrics, + /// Store metrics. + pub store: StoreMetrics, + /// Crypto metrics. + pub crypto: CryptoMetrics, + /// Protocol metrics. + pub protocol: ProtocolMetrics, + /// Node start time. + pub started_at: Instant, +} + +impl Default for MeshMetrics { + fn default() -> Self { + Self::new() + } +} + +impl MeshMetrics { + pub fn new() -> Self { + Self { + transports: RwLock::new(HashMap::new()), + routing: RoutingMetrics::default(), + store: StoreMetrics::default(), + crypto: CryptoMetrics::default(), + protocol: ProtocolMetrics::default(), + started_at: Instant::now(), + } + } + + /// Get or create transport metrics. 
+ pub fn transport(&self, name: &str) -> Arc { + { + let map = self.transports.read().unwrap(); + if let Some(m) = map.get(name) { + return Arc::clone(m); + } + } + let mut map = self.transports.write().unwrap(); + map.entry(name.to_string()) + .or_insert_with(|| Arc::new(TransportMetrics::default())) + .clone() + } + + /// Node uptime. + pub fn uptime(&self) -> Duration { + self.started_at.elapsed() + } + + /// Export metrics as a snapshot. + pub fn snapshot(&self) -> MetricsSnapshot { + let transports = self.transports.read().unwrap(); + let transport_snapshots: HashMap = transports + .iter() + .map(|(name, m)| { + ( + name.clone(), + TransportSnapshot { + sent: m.sent.get(), + received: m.received.get(), + send_errors: m.send_errors.get(), + bytes_sent: m.bytes_sent.get(), + bytes_received: m.bytes_received.get(), + connections: m.connections.get(), + }, + ) + }) + .collect(); + + MetricsSnapshot { + uptime_secs: self.uptime().as_secs(), + transports: transport_snapshots, + routing: RoutingSnapshot { + table_size: self.routing.table_size.get(), + lookups: self.routing.lookups.get(), + lookup_misses: self.routing.lookup_misses.get(), + announcements_processed: self.routing.announcements_processed.get(), + }, + store: StoreSnapshot { + messages_stored: self.store.messages_stored.get(), + messages_delivered: self.store.messages_delivered.get(), + messages_expired: self.store.messages_expired.get(), + current_size: self.store.current_size.get(), + }, + crypto: CryptoSnapshot { + encryptions: self.crypto.encryptions.get(), + decryptions: self.crypto.decryptions.get(), + signature_verifications: self.crypto.signature_verifications.get(), + signature_failures: self.crypto.signature_failures.get(), + replay_detections: self.crypto.replay_detections.get(), + }, + } + } +} + +/// Routing subsystem metrics. +#[derive(Debug, Default)] +pub struct RoutingMetrics { + /// Current routing table size. + pub table_size: Gauge, + /// Route lookups. 
+ pub lookups: Counter, + /// Route lookup misses. + pub lookup_misses: Counter, + /// Routes added. + pub routes_added: Counter, + /// Routes expired. + pub routes_expired: Counter, + /// Announcements processed. + pub announcements_processed: Counter, + /// Announcements forwarded. + pub announcements_forwarded: Counter, + /// Duplicate announcements dropped. + pub duplicates_dropped: Counter, +} + +/// Store subsystem metrics. +#[derive(Debug, Default)] +pub struct StoreMetrics { + /// Messages stored. + pub messages_stored: Counter, + /// Messages delivered. + pub messages_delivered: Counter, + /// Messages expired. + pub messages_expired: Counter, + /// Current store size. + pub current_size: Gauge, + /// Store capacity reached events. + pub capacity_reached: Counter, +} + +/// Crypto subsystem metrics. +#[derive(Debug)] +pub struct CryptoMetrics { + /// Successful encryptions. + pub encryptions: Counter, + /// Successful decryptions. + pub decryptions: Counter, + /// Decryption failures. + pub decryption_failures: Counter, + /// Signature verifications. + pub signature_verifications: Counter, + /// Signature failures. + pub signature_failures: Counter, + /// Replay attacks detected. + pub replay_detections: Counter, + /// Encryption latency. + pub encrypt_latency: Histogram, +} + +impl Default for CryptoMetrics { + fn default() -> Self { + Self { + encryptions: Counter::new(), + decryptions: Counter::new(), + decryption_failures: Counter::new(), + signature_verifications: Counter::new(), + signature_failures: Counter::new(), + replay_detections: Counter::new(), + encrypt_latency: Histogram::latency_ms(), + } + } +} + +/// Protocol metrics. +#[derive(Debug, Default)] +pub struct ProtocolMetrics { + /// Messages parsed. + pub messages_parsed: Counter, + /// Parse errors. + pub parse_errors: Counter, + /// Unknown message types. + pub unknown_types: Counter, + /// Messages too large. + pub oversized: Counter, +} + +/// Point-in-time snapshot of metrics. 
+#[derive(Debug, Clone, serde::Serialize)] +pub struct MetricsSnapshot { + pub uptime_secs: u64, + pub transports: HashMap, + pub routing: RoutingSnapshot, + pub store: StoreSnapshot, + pub crypto: CryptoSnapshot, +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct TransportSnapshot { + pub sent: u64, + pub received: u64, + pub send_errors: u64, + pub bytes_sent: u64, + pub bytes_received: u64, + pub connections: u64, +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct RoutingSnapshot { + pub table_size: u64, + pub lookups: u64, + pub lookup_misses: u64, + pub announcements_processed: u64, +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct StoreSnapshot { + pub messages_stored: u64, + pub messages_delivered: u64, + pub messages_expired: u64, + pub current_size: u64, +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct CryptoSnapshot { + pub encryptions: u64, + pub decryptions: u64, + pub signature_verifications: u64, + pub signature_failures: u64, + pub replay_detections: u64, +} + +/// Global metrics instance. +static GLOBAL_METRICS: std::sync::OnceLock> = std::sync::OnceLock::new(); + +/// Get the global metrics instance. 
+pub fn metrics() -> &'static Arc { + GLOBAL_METRICS.get_or_init(|| Arc::new(MeshMetrics::new())) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn counter_basics() { + let c = Counter::new(); + assert_eq!(c.get(), 0); + c.inc(); + assert_eq!(c.get(), 1); + c.inc_by(5); + assert_eq!(c.get(), 6); + let old = c.reset(); + assert_eq!(old, 6); + assert_eq!(c.get(), 0); + } + + #[test] + fn gauge_basics() { + let g = Gauge::new(); + assert_eq!(g.get(), 0); + g.set(10); + assert_eq!(g.get(), 10); + g.inc(); + assert_eq!(g.get(), 11); + g.dec(); + assert_eq!(g.get(), 10); + } + + #[test] + fn histogram_basics() { + let h = Histogram::new(vec![10, 50, 100]); + h.observe(5); + h.observe(25); + h.observe(75); + h.observe(200); + + assert_eq!(h.count(), 4); + assert_eq!(h.sum(), 5 + 25 + 75 + 200); + } + + #[test] + fn transport_metrics() { + let m = MeshMetrics::new(); + let tcp = m.transport("tcp"); + tcp.sent.inc(); + tcp.bytes_sent.inc_by(100); + + assert_eq!(tcp.sent.get(), 1); + assert_eq!(tcp.bytes_sent.get(), 100); + + // Same name returns same instance + let tcp2 = m.transport("tcp"); + assert_eq!(tcp2.sent.get(), 1); + } + + #[test] + fn snapshot_serializes() { + let m = MeshMetrics::new(); + m.transport("tcp").sent.inc(); + m.routing.lookups.inc_by(10); + + let snapshot = m.snapshot(); + let json = serde_json::to_string(&snapshot).expect("serialize"); + assert!(json.contains("\"uptime_secs\":")); + assert!(json.contains("\"lookups\":10")); + } + + #[test] + fn global_metrics() { + let m = metrics(); + m.protocol.messages_parsed.inc(); + assert_eq!(metrics().protocol.messages_parsed.get(), 1); + } +} diff --git a/crates/quicprochat-p2p/src/rate_limit.rs b/crates/quicprochat-p2p/src/rate_limit.rs new file mode 100644 index 0000000..612ff5b --- /dev/null +++ b/crates/quicprochat-p2p/src/rate_limit.rs @@ -0,0 +1,482 @@ +//! Rate limiting for DoS protection. +//! +//! This module provides token bucket rate limiters for controlling +//! 
message rates per peer and globally. Designed for low overhead +//! even on constrained devices. + +use std::collections::HashMap; +use std::sync::RwLock; +use std::time::{Duration, Instant}; + +use crate::address::MeshAddress; +use crate::config::RateLimitConfig; +use crate::error::{MeshError, MeshResult}; + +/// Result of a rate limit check. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RateLimitResult { + /// Request allowed. + Allowed, + /// Request denied, retry after this duration. + Denied { retry_after: Duration }, + /// Soft warning: approaching limit. + Warning { remaining: u32 }, +} + +impl RateLimitResult { + pub fn is_allowed(&self) -> bool { + matches!(self, Self::Allowed | Self::Warning { .. }) + } +} + +/// Token bucket rate limiter. +#[derive(Debug)] +pub struct TokenBucket { + /// Maximum tokens (bucket capacity). + capacity: u32, + /// Current tokens. + tokens: f64, + /// Tokens added per second. + refill_rate: f64, + /// Last refill time. + last_refill: Instant, +} + +impl TokenBucket { + /// Create a new token bucket. + pub fn new(capacity: u32, per_second: f64) -> Self { + Self { + capacity, + tokens: capacity as f64, + refill_rate: per_second, + last_refill: Instant::now(), + } + } + + /// Create from per-minute rate. + pub fn per_minute(per_minute: u32) -> Self { + let capacity = per_minute.max(1); + let per_second = per_minute as f64 / 60.0; + Self::new(capacity, per_second) + } + + /// Refill tokens based on elapsed time. + fn refill(&mut self) { + let now = Instant::now(); + let elapsed = now.duration_since(self.last_refill); + let add = elapsed.as_secs_f64() * self.refill_rate; + self.tokens = (self.tokens + add).min(self.capacity as f64); + self.last_refill = now; + } + + /// Try to consume one token. + pub fn try_acquire(&mut self) -> RateLimitResult { + self.try_acquire_n(1) + } + + /// Try to consume n tokens. 
+ pub fn try_acquire_n(&mut self, n: u32) -> RateLimitResult { + self.refill(); + + let n_f = n as f64; + if self.tokens >= n_f { + self.tokens -= n_f; + let remaining = self.tokens as u32; + if remaining < self.capacity / 4 { + RateLimitResult::Warning { remaining } + } else { + RateLimitResult::Allowed + } + } else { + let deficit = n_f - self.tokens; + let wait_secs = deficit / self.refill_rate; + RateLimitResult::Denied { + retry_after: Duration::from_secs_f64(wait_secs), + } + } + } + + /// Current available tokens. + pub fn available(&mut self) -> u32 { + self.refill(); + self.tokens as u32 + } +} + +/// Per-peer rate limiter with multiple buckets. +#[derive(Debug)] +pub struct PeerRateLimiter { + /// Message bucket. + messages: TokenBucket, + /// Announce bucket. + announces: TokenBucket, + /// KeyPackage request bucket. + keypackage_requests: TokenBucket, + /// Last activity (for cleanup). + last_activity: Instant, +} + +impl PeerRateLimiter { + pub fn from_config(config: &RateLimitConfig) -> Self { + Self { + messages: TokenBucket::per_minute(config.message_per_peer_per_min), + announces: TokenBucket::per_minute(config.announce_per_peer_per_min), + keypackage_requests: TokenBucket::per_minute(config.keypackage_requests_per_min), + last_activity: Instant::now(), + } + } + + pub fn check_message(&mut self) -> RateLimitResult { + self.last_activity = Instant::now(); + self.messages.try_acquire() + } + + pub fn check_announce(&mut self) -> RateLimitResult { + self.last_activity = Instant::now(); + self.announces.try_acquire() + } + + pub fn check_keypackage_request(&mut self) -> RateLimitResult { + self.last_activity = Instant::now(); + self.keypackage_requests.try_acquire() + } + + /// Time since last activity. + pub fn idle_time(&self) -> Duration { + self.last_activity.elapsed() + } +} + +/// Global rate limiter managing per-peer limits. +pub struct RateLimiter { + /// Configuration. + config: RateLimitConfig, + /// Per-peer limiters. 
+ peers: RwLock>, + /// Maximum tracked peers (to prevent memory exhaustion). + max_peers: usize, +} + +impl RateLimiter { + pub fn new(config: RateLimitConfig) -> Self { + Self { + config, + peers: RwLock::new(HashMap::new()), + max_peers: 10_000, + } + } + + /// Check if a message from peer is allowed. + pub fn check_message(&self, peer: &MeshAddress) -> MeshResult { + let mut peers = self.peers.write().map_err(|_| { + MeshError::Internal("rate limiter lock poisoned".to_string()) + })?; + + let limiter = peers + .entry(*peer) + .or_insert_with(|| PeerRateLimiter::from_config(&self.config)); + + Ok(limiter.check_message()) + } + + /// Check if an announce from peer is allowed. + pub fn check_announce(&self, peer: &MeshAddress) -> MeshResult { + let mut peers = self.peers.write().map_err(|_| { + MeshError::Internal("rate limiter lock poisoned".to_string()) + })?; + + let limiter = peers + .entry(*peer) + .or_insert_with(|| PeerRateLimiter::from_config(&self.config)); + + Ok(limiter.check_announce()) + } + + /// Check if a KeyPackage request from peer is allowed. + pub fn check_keypackage_request(&self, peer: &MeshAddress) -> MeshResult { + let mut peers = self.peers.write().map_err(|_| { + MeshError::Internal("rate limiter lock poisoned".to_string()) + })?; + + let limiter = peers + .entry(*peer) + .or_insert_with(|| PeerRateLimiter::from_config(&self.config)); + + Ok(limiter.check_keypackage_request()) + } + + /// Remove limiters for peers idle longer than max_idle. + pub fn cleanup(&self, max_idle: Duration) -> usize { + let mut peers = match self.peers.write() { + Ok(p) => p, + Err(_) => return 0, + }; + + let before = peers.len(); + peers.retain(|_, limiter| limiter.idle_time() < max_idle); + before - peers.len() + } + + /// Number of tracked peers. + pub fn tracked_peers(&self) -> usize { + self.peers.read().map(|p| p.len()).unwrap_or(0) + } +} + +/// Duty cycle tracker for LoRa compliance. 
+#[derive(Debug)] +pub struct DutyCycleTracker { + /// Duty cycle limit (0.0 to 1.0). + limit: f32, + /// Window size for tracking. + window: Duration, + /// Transmission records: (timestamp, duration_ms). + transmissions: RwLock>, +} + +impl DutyCycleTracker { + /// Create with a duty cycle limit (e.g., 0.01 for 1%). + pub fn new(limit: f32) -> Self { + Self { + limit: limit.clamp(0.0, 1.0), + window: Duration::from_secs(3600), // 1 hour window + transmissions: RwLock::new(Vec::new()), + } + } + + /// Check if we can transmit for the given duration. + pub fn can_transmit(&self, airtime_ms: u64) -> bool { + let used = self.used_ms(); + let window_ms = self.window.as_millis() as u64; + let limit_ms = (window_ms as f32 * self.limit) as u64; + used + airtime_ms <= limit_ms + } + + /// Record a transmission. + pub fn record(&self, airtime_ms: u64) { + if let Ok(mut tx) = self.transmissions.write() { + tx.push((Instant::now(), airtime_ms)); + } + } + + /// Get total airtime used in current window. + pub fn used_ms(&self) -> u64 { + let cutoff = Instant::now() - self.window; + let tx = match self.transmissions.read() { + Ok(t) => t, + Err(_) => return 0, + }; + + tx.iter() + .filter(|(t, _)| *t > cutoff) + .map(|(_, d)| *d) + .sum() + } + + /// Get remaining airtime in current window. + pub fn remaining_ms(&self) -> u64 { + let window_ms = self.window.as_millis() as u64; + let limit_ms = (window_ms as f32 * self.limit) as u64; + limit_ms.saturating_sub(self.used_ms()) + } + + /// Clean up old records. + pub fn cleanup(&self) { + let cutoff = Instant::now() - self.window; + if let Ok(mut tx) = self.transmissions.write() { + tx.retain(|(t, _)| *t > cutoff); + } + } + + /// Current duty cycle usage as fraction. + pub fn current_usage(&self) -> f32 { + let window_ms = self.window.as_millis() as f32; + self.used_ms() as f32 / window_ms + } +} + +/// Backpressure signal for flow control. 
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackpressureLevel {
    /// No backpressure, process normally.
    None,
    /// Light pressure, shed low-priority work.
    Light,
    /// Medium pressure, shed non-critical work.
    Medium,
    /// Heavy pressure, only process critical messages.
    Heavy,
    /// Overloaded, reject new work.
    Overloaded,
}

impl BackpressureLevel {
    /// Should we process a message at this priority (0 = highest)?
    pub fn should_process(&self, priority: u8) -> bool {
        // Each level admits priorities up to a cutoff; `None` admits
        // everything and `Overloaded` admits nothing, not even priority 0.
        let cutoff = match self {
            Self::None => return true,
            Self::Overloaded => return false,
            Self::Light => 2,
            Self::Medium => 1,
            Self::Heavy => 0,
        };
        priority <= cutoff
    }
}

/// Backpressure controller based on queue depth.
#[derive(Debug)]
pub struct BackpressureController {
    /// Ascending queue-depth thresholds for Light/Medium/Heavy/Overloaded.
    thresholds: [usize; 4],
    /// Most recently reported queue depth.
    current: std::sync::atomic::AtomicUsize,
}

impl BackpressureController {
    /// Build a controller from explicit (ascending) thresholds.
    pub fn new(light: usize, medium: usize, heavy: usize, overload: usize) -> Self {
        Self {
            thresholds: [light, medium, heavy, overload],
            current: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Conservative thresholds for memory-constrained devices.
    pub fn default_for_constrained() -> Self {
        Self::new(10, 25, 50, 100)
    }

    /// Thresholds for standard hardware.
    pub fn default_for_standard() -> Self {
        Self::new(100, 500, 1000, 5000)
    }

    /// Report the current queue depth.
    pub fn set_queue_depth(&self, depth: usize) {
        self.current
            .store(depth, std::sync::atomic::Ordering::Relaxed);
    }

    /// Translate the last reported queue depth into a backpressure level.
    pub fn level(&self) -> BackpressureLevel {
        let depth = self.current.load(std::sync::atomic::Ordering::Relaxed);
        let [light, medium, heavy, overload] = self.thresholds;
        match depth {
            d if d >= overload => BackpressureLevel::Overloaded,
            d if d >= heavy => BackpressureLevel::Heavy,
            d if d >= medium => BackpressureLevel::Medium,
            d if d >= light => BackpressureLevel::Light,
            _ => BackpressureLevel::None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_bucket_allows_burst() {
        // A full bucket permits exactly `capacity` immediate acquisitions.
        let mut bucket = TokenBucket::new(10, 1.0);
        for _ in 0..10 {
            assert!(bucket.try_acquire().is_allowed());
        }
        assert!(!bucket.try_acquire().is_allowed());
    }

    #[test]
    fn token_bucket_refills() {
        let mut bucket = TokenBucket::new(2, 100.0); // 100/sec refill
        bucket.try_acquire();
        bucket.try_acquire();
        assert!(!bucket.try_acquire().is_allowed());

        // 50 ms at 100 tokens/sec restores several tokens.
        std::thread::sleep(Duration::from_millis(50));
        assert!(bucket.try_acquire().is_allowed());
    }

    #[test]
    fn token_bucket_warning() {
        let mut bucket = TokenBucket::new(8, 1.0);
        // Drain 7 tokens; the 8th acquisition leaves 0 (< 8/4 = 2).
        for _ in 0..7 {
            bucket.try_acquire();
        }
        let result = bucket.try_acquire();
        assert!(matches!(result, RateLimitResult::Warning { remaining: 0 }));
    }

    #[test]
    fn peer_rate_limiter() {
        let config = RateLimitConfig {
            message_per_peer_per_min: 5,
            ..Default::default()
        };
        let mut limiter = PeerRateLimiter::from_config(&config);

        for _ in 0..5 {
            assert!(limiter.check_message().is_allowed());
        }
        assert!(!limiter.check_message().is_allowed());
    }

    #[test]
    fn rate_limiter_per_peer() {
        let config = RateLimitConfig {
            message_per_peer_per_min: 2,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        let peer1 = MeshAddress::from_bytes([1; 16]);
        let peer2 = MeshAddress::from_bytes([2; 16]);

        assert!(limiter.check_message(&peer1).unwrap().is_allowed());
        assert!(limiter.check_message(&peer1).unwrap().is_allowed());
        assert!(!limiter.check_message(&peer1).unwrap().is_allowed());

        // peer2 has its own bucket
        assert!(limiter.check_message(&peer2).unwrap().is_allowed());
    }

    #[test]
    fn duty_cycle_tracker() {
        let tracker = DutyCycleTracker::new(0.01); // 1%
        // 1 hour = 3600000 ms, 1% = 36000 ms

        assert!(tracker.can_transmit(1000));
        tracker.record(1000);
        assert_eq!(tracker.used_ms(), 1000);

        assert!(tracker.can_transmit(35000));
        tracker.record(35000);

        // Now at 36000ms, at limit
        assert!(!tracker.can_transmit(1000));
    }

    #[test]
    fn backpressure_levels() {
        let bp = BackpressureController::new(10, 50, 100, 200);

        bp.set_queue_depth(5);
        assert_eq!(bp.level(), BackpressureLevel::None);

        bp.set_queue_depth(30);
        assert_eq!(bp.level(), BackpressureLevel::Light);

        bp.set_queue_depth(75);
        assert_eq!(bp.level(), BackpressureLevel::Medium);

        bp.set_queue_depth(150);
        assert_eq!(bp.level(), BackpressureLevel::Heavy);

        bp.set_queue_depth(250);
        assert_eq!(bp.level(), BackpressureLevel::Overloaded);
    }

    #[test]
    fn backpressure_priority_filter() {
        assert!(BackpressureLevel::None.should_process(5));
        assert!(!BackpressureLevel::Light.should_process(5));
        assert!(BackpressureLevel::Light.should_process(2));
        assert!(!BackpressureLevel::Overloaded.should_process(0));
    }
}