//! KeyPackage cache for mesh-based MLS group setup. //! //! The [`KeyPackageCache`] stores MLS KeyPackages received from other nodes, //! enabling group creation without a central server. KeyPackages are: //! //! - Indexed by the node's 16-byte mesh address //! - Hashed (8 bytes) for announce inclusion //! - TTL-managed for expiry (MLS KeyPackages are single-use but we cache N of them) //! - Bounded by capacity to prevent memory exhaustion //! //! # Protocol Flow //! //! 1. Bob generates KeyPackage, computes hash, includes hash in MeshAnnounce //! 2. Bob broadcasts full KeyPackage periodically (or on request) //! 3. Alice receives Bob's KeyPackage, stores in cache //! 4. Alice wants to create group with Bob: fetches from cache, creates Welcome //! 5. Alice sends Welcome to Bob via mesh routing use std::collections::HashMap; use std::time::{Duration, Instant}; use crate::address::MeshAddress; use crate::announce::compute_keypackage_hash; /// Default TTL for cached KeyPackages (24 hours). const DEFAULT_TTL: Duration = Duration::from_secs(24 * 60 * 60); /// Default maximum KeyPackages per address (allow rotation). const DEFAULT_MAX_PER_ADDRESS: usize = 3; /// A cached KeyPackage entry. #[derive(Clone, Debug)] pub struct CachedKeyPackage { /// The serialized MLS KeyPackage bytes. pub bytes: Vec, /// 8-byte truncated hash for matching against announces. pub hash: [u8; 8], /// When this entry was stored. pub stored_at: Instant, /// When this entry expires. pub expires_at: Instant, } impl CachedKeyPackage { /// Create a new cached entry with default TTL. pub fn new(bytes: Vec) -> Self { Self::with_ttl(bytes, DEFAULT_TTL) } /// Create with custom TTL. pub fn with_ttl(bytes: Vec, ttl: Duration) -> Self { let hash = compute_keypackage_hash(&bytes); let now = Instant::now(); Self { bytes, hash, stored_at: now, expires_at: now + ttl, } } /// Check if this entry has expired. pub fn is_expired(&self) -> bool { Instant::now() > self.expires_at } } /// Cache for KeyPackages received from mesh peers. pub struct KeyPackageCache { /// Address -> list of cached KeyPackages (multiple for rotation). entries: HashMap>, /// Maximum KeyPackages stored per address. max_per_address: usize, /// Total capacity (max addresses). max_addresses: usize, } impl KeyPackageCache { /// Create a new cache with default settings. pub fn new() -> Self { Self::with_capacity(1000, DEFAULT_MAX_PER_ADDRESS) } /// Create with custom capacity. pub fn with_capacity(max_addresses: usize, max_per_address: usize) -> Self { Self { entries: HashMap::new(), max_per_address, max_addresses, } } /// Store a KeyPackage for a given address. /// /// Returns `true` if stored, `false` if rejected (at capacity or duplicate hash). pub fn store(&mut self, address: MeshAddress, keypackage_bytes: Vec) -> bool { let entry = CachedKeyPackage::new(keypackage_bytes); self.store_entry(address, entry) } /// Store a KeyPackage entry. fn store_entry(&mut self, address: MeshAddress, entry: CachedKeyPackage) -> bool { // Check if we already have this exact KeyPackage if let Some(existing) = self.entries.get(&address) { if existing.iter().any(|e| e.hash == entry.hash) { return false; // Duplicate } } // Check total capacity if !self.entries.contains_key(&address) && self.entries.len() >= self.max_addresses { // Evict oldest entry self.evict_oldest(); } let list = self.entries.entry(address).or_default(); // Enforce per-address limit while list.len() >= self.max_per_address { list.remove(0); // Remove oldest } list.push(entry); true } /// Get the newest KeyPackage for an address. pub fn get(&self, address: &MeshAddress) -> Option<&CachedKeyPackage> { self.entries .get(address) .and_then(|list| list.iter().rev().find(|e| !e.is_expired())) } /// Get a KeyPackage by its hash. pub fn get_by_hash(&self, address: &MeshAddress, hash: &[u8; 8]) -> Option<&CachedKeyPackage> { self.entries.get(address).and_then(|list| { list.iter() .rev() .find(|e| &e.hash == hash && !e.is_expired()) }) } /// Get the newest KeyPackage bytes for an address. pub fn get_bytes(&self, address: &MeshAddress) -> Option> { self.get(address).map(|e| e.bytes.clone()) } /// Check if we have a KeyPackage matching a given hash. pub fn has_hash(&self, address: &MeshAddress, hash: &[u8; 8]) -> bool { self.get_by_hash(address, hash).is_some() } /// Remove all expired entries. Returns count removed. pub fn gc_expired(&mut self) -> usize { let mut removed = 0; self.entries.retain(|_, list| { let before = list.len(); list.retain(|e| !e.is_expired()); removed += before - list.len(); !list.is_empty() }); removed } /// Evict the oldest entry across all addresses. fn evict_oldest(&mut self) { let oldest_addr = self .entries .iter() .filter_map(|(addr, list)| { list.first().map(|e| (addr.clone(), e.stored_at)) }) .min_by_key(|(_, stored)| *stored) .map(|(addr, _)| addr); if let Some(addr) = oldest_addr { if let Some(list) = self.entries.get_mut(&addr) { list.remove(0); if list.is_empty() { self.entries.remove(&addr); } } } } /// Number of addresses with cached KeyPackages. pub fn len(&self) -> usize { self.entries.len() } /// Whether the cache is empty. pub fn is_empty(&self) -> bool { self.entries.is_empty() } /// Total number of cached KeyPackages. pub fn total_keypackages(&self) -> usize { self.entries.values().map(|v| v.len()).sum() } /// Consume a KeyPackage (remove after use, as MLS KeyPackages are single-use). /// /// Returns the KeyPackage bytes if found. pub fn consume(&mut self, address: &MeshAddress, hash: &[u8; 8]) -> Option> { let list = self.entries.get_mut(address)?; let idx = list.iter().position(|e| &e.hash == hash)?; let entry = list.remove(idx); if list.is_empty() { self.entries.remove(address); } Some(entry.bytes) } } impl Default for KeyPackageCache { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; fn make_keypackage(seed: u8) -> Vec { vec![seed; 100 + seed as usize] } fn make_address(seed: u8) -> MeshAddress { MeshAddress::from_bytes([seed; 16]) } #[test] fn store_and_retrieve() { let mut cache = KeyPackageCache::new(); let addr = make_address(1); let kp = make_keypackage(1); let hash = compute_keypackage_hash(&kp); assert!(cache.store(addr, kp.clone())); assert_eq!(cache.len(), 1); let retrieved = cache.get(&addr).expect("should exist"); assert_eq!(retrieved.bytes, kp); assert_eq!(retrieved.hash, hash); } #[test] fn reject_duplicate() { let mut cache = KeyPackageCache::new(); let addr = make_address(2); let kp = make_keypackage(2); assert!(cache.store(addr, kp.clone())); assert!(!cache.store(addr, kp), "duplicate should be rejected"); assert_eq!(cache.total_keypackages(), 1); } #[test] fn multiple_per_address() { let mut cache = KeyPackageCache::with_capacity(100, 3); let addr = make_address(3); assert!(cache.store(addr, make_keypackage(1))); assert!(cache.store(addr, make_keypackage(2))); assert!(cache.store(addr, make_keypackage(3))); assert_eq!(cache.total_keypackages(), 3); // Fourth should evict first assert!(cache.store(addr, make_keypackage(4))); assert_eq!(cache.total_keypackages(), 3); // First should be gone let hash1 = compute_keypackage_hash(&make_keypackage(1)); assert!(!cache.has_hash(&addr, &hash1)); // Fourth should be present let hash4 = compute_keypackage_hash(&make_keypackage(4)); assert!(cache.has_hash(&addr, &hash4)); } #[test] fn consume_removes_keypackage() { let mut cache = KeyPackageCache::new(); let addr = make_address(4); let kp = make_keypackage(4); let hash = compute_keypackage_hash(&kp); cache.store(addr, kp.clone()); assert!(cache.has_hash(&addr, &hash)); let consumed = cache.consume(&addr, &hash).expect("should consume"); assert_eq!(consumed, kp); assert!(!cache.has_hash(&addr, &hash)); assert!(cache.is_empty()); } #[test] fn get_by_hash() { let mut cache = KeyPackageCache::new(); let addr = make_address(5); let kp1 = make_keypackage(51); let kp2 = make_keypackage(52); let hash1 = compute_keypackage_hash(&kp1); let hash2 = compute_keypackage_hash(&kp2); cache.store(addr, kp1.clone()); cache.store(addr, kp2.clone()); let found1 = cache.get_by_hash(&addr, &hash1).expect("hash1"); assert_eq!(found1.bytes, kp1); let found2 = cache.get_by_hash(&addr, &hash2).expect("hash2"); assert_eq!(found2.bytes, kp2); let wrong_hash = [0xFFu8; 8]; assert!(cache.get_by_hash(&addr, &wrong_hash).is_none()); } #[test] fn capacity_eviction() { let mut cache = KeyPackageCache::with_capacity(2, 1); let addr1 = make_address(1); let addr2 = make_address(2); let addr3 = make_address(3); cache.store(addr1, make_keypackage(1)); cache.store(addr2, make_keypackage(2)); assert_eq!(cache.len(), 2); // Third should evict oldest (addr1) cache.store(addr3, make_keypackage(3)); assert_eq!(cache.len(), 2); assert!(cache.get(&addr1).is_none()); assert!(cache.get(&addr2).is_some()); assert!(cache.get(&addr3).is_some()); } #[test] fn expiry() { let mut cache = KeyPackageCache::new(); let addr = make_address(6); // Create entry with very short TTL let kp = make_keypackage(6); let entry = CachedKeyPackage::with_ttl(kp, Duration::from_millis(1)); cache.store_entry(addr, entry); assert_eq!(cache.total_keypackages(), 1); // Wait for expiry std::thread::sleep(Duration::from_millis(10)); // GC should remove it let removed = cache.gc_expired(); assert_eq!(removed, 1); assert!(cache.is_empty()); } }