use std::{ collections::{HashMap, VecDeque}, fs, hash::Hash, path::{Path, PathBuf}, sync::Mutex, }; use serde::{Deserialize, Serialize}; #[derive(thiserror::Error, Debug)] pub enum StorageError { #[error("io error: {0}")] Io(String), #[error("serialization error")] Serde, #[error("database error: {0}")] Db(String), } fn lock(m: &Mutex) -> Result, StorageError> { m.lock() .map_err(|e| StorageError::Io(format!("lock poisoned: {e}"))) } // ── Store trait ────────────────────────────────────────────────────────────── /// Abstraction over storage backends (file-backed, SQLCipher, etc.). pub trait Store: Send + Sync { fn upload_key_package(&self, identity_key: &[u8], package: Vec) -> Result<(), StorageError>; fn fetch_key_package(&self, identity_key: &[u8]) -> Result>, StorageError>; /// Enqueue a payload and return the monotonically increasing per-inbox sequence number /// assigned to this message. Clients sort by seq before MLS processing. fn enqueue( &self, recipient_key: &[u8], channel_id: &[u8], payload: Vec, ) -> Result; /// Fetch and drain all queued messages, returning `(seq, payload)` pairs ordered by seq. fn fetch( &self, recipient_key: &[u8], channel_id: &[u8], ) -> Result)>, StorageError>; /// Fetch up to `limit` messages without draining the entire queue (Fix 8). /// Returns `(seq, payload)` pairs ordered by seq. fn fetch_limited( &self, recipient_key: &[u8], channel_id: &[u8], limit: usize, ) -> Result)>, StorageError>; /// Return the number of queued messages for (recipient, channel) (Fix 7). fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result; /// Delete messages older than `max_age_secs`. Returns count deleted (Fix 7). fn gc_expired_messages(&self, max_age_secs: u64) -> Result; fn upload_hybrid_key( &self, identity_key: &[u8], hybrid_pk: Vec, ) -> Result<(), StorageError>; fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result>, StorageError>; /// Store the OPAQUE `ServerSetup` (generated once, loaded on restart). fn store_server_setup(&self, setup: Vec) -> Result<(), StorageError>; /// Load the persisted `ServerSetup`, if any. fn get_server_setup(&self) -> Result>, StorageError>; /// Store an OPAQUE user record (serialized `ServerRegistration`). fn store_user_record(&self, username: &str, record: Vec) -> Result<(), StorageError>; /// Retrieve an OPAQUE user record by username. fn get_user_record(&self, username: &str) -> Result>, StorageError>; /// Check if a user record already exists (Fix 5). fn has_user_record(&self, username: &str) -> Result; /// Store identity key for a user (Fix 2). fn store_user_identity_key( &self, username: &str, identity_key: Vec, ) -> Result<(), StorageError>; /// Retrieve identity key for a user (Fix 2). fn get_user_identity_key(&self, username: &str) -> Result>, StorageError>; /// Publish a P2P endpoint address for an identity key. fn publish_endpoint(&self, identity_key: &[u8], node_addr: Vec) -> Result<(), StorageError>; /// Resolve a peer's P2P endpoint address. fn resolve_endpoint(&self, identity_key: &[u8]) -> Result>, StorageError>; } // ── ChannelKey ─────────────────────────────────────────────────────────────── #[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)] pub struct ChannelKey { pub channel_id: Vec, pub recipient_key: Vec, } impl Hash for ChannelKey { fn hash(&self, state: &mut H) { self.channel_id.hash(state); self.recipient_key.hash(state); } } // ── FileBackedStore ────────────────────────────────────────────────────────── #[derive(Serialize, Deserialize, Default)] struct QueueMapV1 { map: HashMap, VecDeque>>, } #[derive(Serialize, Deserialize, Default)] struct QueueMapV2 { map: HashMap>>, } #[derive(Serialize, Deserialize, Default, Clone)] struct SeqEntry { seq: u64, data: Vec, } /// V3 delivery store: each queue entry carries a monotonic per-inbox sequence number. #[derive(Serialize, Deserialize, Default)] struct QueueMapV3 { map: HashMap>, next_seq: HashMap, } /// File-backed storage for KeyPackages and delivery queues. /// /// Each mutation flushes the entire map to disk. Suitable for MVP-scale loads. pub struct FileBackedStore { kp_path: PathBuf, ds_path: PathBuf, hk_path: PathBuf, setup_path: PathBuf, users_path: PathBuf, identity_keys_path: PathBuf, key_packages: Mutex, VecDeque>>>, deliveries: Mutex, hybrid_keys: Mutex, Vec>>, users: Mutex>>, identity_keys: Mutex>>, endpoints: Mutex, Vec>>, } impl FileBackedStore { pub fn open(dir: impl AsRef) -> Result { let dir = dir.as_ref(); if !dir.exists() { fs::create_dir_all(dir).map_err(|e| StorageError::Io(e.to_string()))?; } let kp_path = dir.join("keypackages.bin"); let ds_path = dir.join("deliveries.bin"); let hk_path = dir.join("hybridkeys.bin"); let setup_path = dir.join("server_setup.bin"); let users_path = dir.join("users.bin"); let identity_keys_path = dir.join("identity_keys.bin"); let key_packages = Mutex::new(Self::load_kp_map(&kp_path)?); let deliveries = Mutex::new(Self::load_delivery_map_v3(&ds_path)?); let hybrid_keys = Mutex::new(Self::load_hybrid_keys(&hk_path)?); let users = Mutex::new(Self::load_users(&users_path)?); let identity_keys = Mutex::new(Self::load_map_string_bytes(&identity_keys_path)?); Ok(Self { kp_path, ds_path, hk_path, setup_path, users_path, identity_keys_path, key_packages, deliveries, hybrid_keys, users, identity_keys, endpoints: Mutex::new(HashMap::new()), }) } fn load_kp_map(path: &Path) -> Result, VecDeque>>, StorageError> { if !path.exists() { return Ok(HashMap::new()); } let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?; if bytes.is_empty() { return Ok(HashMap::new()); } let map: QueueMapV1 = bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)?; Ok(map.map) } fn flush_kp_map( &self, path: &Path, map: &HashMap, VecDeque>>, ) -> Result<(), StorageError> { let payload = QueueMapV1 { map: map.clone() }; let bytes = bincode::serialize(&payload).map_err(|_| StorageError::Serde)?; if let Some(parent) = path.parent() { fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?; } fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string())) } /// Load deliveries as V3. Falls back to V2 format (assigns seqs starting at 0). fn load_delivery_map_v3(path: &Path) -> Result { if !path.exists() { return Ok(QueueMapV3::default()); } let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?; if bytes.is_empty() { return Ok(QueueMapV3::default()); } // Try V3 first. if let Ok(v3) = bincode::deserialize::(&bytes) { return Ok(v3); } // Fall back to V2: assign ascending seqs starting at 0 per channel. let v2 = bincode::deserialize::(&bytes) .map_err(|_| StorageError::Io("deliveries file: unrecognised format".into()))?; let mut v3 = QueueMapV3::default(); for (key, queue) in v2.map { let entries: VecDeque = queue .into_iter() .enumerate() .map(|(i, data)| SeqEntry { seq: i as u64, data }) .collect(); let next = entries.len() as u64; v3.next_seq.insert(key.clone(), next); v3.map.insert(key, entries); } Ok(v3) } fn flush_delivery_map(&self, path: &Path, map: &QueueMapV3) -> Result<(), StorageError> { let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?; if let Some(parent) = path.parent() { fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?; } fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string())) } fn load_hybrid_keys(path: &Path) -> Result, Vec>, StorageError> { if !path.exists() { return Ok(HashMap::new()); } let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?; if bytes.is_empty() { return Ok(HashMap::new()); } bincode::deserialize(&bytes).map_err(|_| StorageError::Serde) } fn flush_hybrid_keys( &self, path: &Path, map: &HashMap, Vec>, ) -> Result<(), StorageError> { let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?; if let Some(parent) = path.parent() { fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?; } fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string())) } fn load_users(path: &Path) -> Result>, StorageError> { if !path.exists() { return Ok(HashMap::new()); } let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?; if bytes.is_empty() { return Ok(HashMap::new()); } bincode::deserialize(&bytes).map_err(|_| StorageError::Serde) } fn flush_users(&self, path: &Path, map: &HashMap>) -> Result<(), StorageError> { let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?; if let Some(parent) = path.parent() { fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?; } fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string())) } fn load_map_string_bytes(path: &Path) -> Result>, StorageError> { Self::load_users(path) } fn flush_map_string_bytes( &self, path: &Path, map: &HashMap>, ) -> Result<(), StorageError> { self.flush_users(path, map) } } impl Store for FileBackedStore { fn upload_key_package( &self, identity_key: &[u8], package: Vec, ) -> Result<(), StorageError> { let mut map = lock(&self.key_packages)?; map.entry(identity_key.to_vec()) .or_default() .push_back(package); self.flush_kp_map(&self.kp_path, &*map) } fn fetch_key_package(&self, identity_key: &[u8]) -> Result>, StorageError> { let mut map = lock(&self.key_packages)?; let package = map.get_mut(identity_key).and_then(|q| q.pop_front()); self.flush_kp_map(&self.kp_path, &*map)?; Ok(package) } fn enqueue( &self, recipient_key: &[u8], channel_id: &[u8], payload: Vec, ) -> Result { let mut inner = lock(&self.deliveries)?; let key = ChannelKey { channel_id: channel_id.to_vec(), recipient_key: recipient_key.to_vec(), }; let seq = { let entry = inner.next_seq.entry(key.clone()).or_insert(0); let s = *entry; *entry = s + 1; s }; inner.map.entry(key).or_default().push_back(SeqEntry { seq, data: payload }); self.flush_delivery_map(&self.ds_path, &*inner)?; Ok(seq) } fn fetch( &self, recipient_key: &[u8], channel_id: &[u8], ) -> Result)>, StorageError> { let mut inner = lock(&self.deliveries)?; let key = ChannelKey { channel_id: channel_id.to_vec(), recipient_key: recipient_key.to_vec(), }; let messages: Vec<(u64, Vec)> = inner .map .get_mut(&key) .map(|q| q.drain(..).map(|e| (e.seq, e.data)).collect()) .unwrap_or_default(); self.flush_delivery_map(&self.ds_path, &*inner)?; Ok(messages) } fn fetch_limited( &self, recipient_key: &[u8], channel_id: &[u8], limit: usize, ) -> Result)>, StorageError> { let mut inner = lock(&self.deliveries)?; let key = ChannelKey { channel_id: channel_id.to_vec(), recipient_key: recipient_key.to_vec(), }; let messages: Vec<(u64, Vec)> = inner .map .get_mut(&key) .map(|q| { let count = limit.min(q.len()); q.drain(..count).map(|e| (e.seq, e.data)).collect() }) .unwrap_or_default(); self.flush_delivery_map(&self.ds_path, &*inner)?; Ok(messages) } fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result { let inner = lock(&self.deliveries)?; let key = ChannelKey { channel_id: channel_id.to_vec(), recipient_key: recipient_key.to_vec(), }; Ok(inner.map.get(&key).map(|q| q.len()).unwrap_or(0)) } fn gc_expired_messages(&self, _max_age_secs: u64) -> Result { // FileBackedStore does not track timestamps per message — no-op. Ok(0) } fn upload_hybrid_key( &self, identity_key: &[u8], hybrid_pk: Vec, ) -> Result<(), StorageError> { let mut map = lock(&self.hybrid_keys)?; map.insert(identity_key.to_vec(), hybrid_pk); self.flush_hybrid_keys(&self.hk_path, &*map) } fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result>, StorageError> { let map = lock(&self.hybrid_keys)?; Ok(map.get(identity_key).cloned()) } fn store_server_setup(&self, setup: Vec) -> Result<(), StorageError> { if let Some(parent) = self.setup_path.parent() { fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?; } fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string())) } fn get_server_setup(&self) -> Result>, StorageError> { if !self.setup_path.exists() { return Ok(None); } let bytes = fs::read(&self.setup_path).map_err(|e| StorageError::Io(e.to_string()))?; if bytes.is_empty() { return Ok(None); } Ok(Some(bytes)) } fn store_user_record(&self, username: &str, record: Vec) -> Result<(), StorageError> { let mut map = lock(&self.users)?; map.insert(username.to_string(), record); self.flush_users(&self.users_path, &*map) } fn get_user_record(&self, username: &str) -> Result>, StorageError> { let map = lock(&self.users)?; Ok(map.get(username).cloned()) } fn has_user_record(&self, username: &str) -> Result { let map = lock(&self.users)?; Ok(map.contains_key(username)) } fn store_user_identity_key( &self, username: &str, identity_key: Vec, ) -> Result<(), StorageError> { let mut map = lock(&self.identity_keys)?; map.insert(username.to_string(), identity_key); self.flush_map_string_bytes(&self.identity_keys_path, &*map) } fn get_user_identity_key(&self, username: &str) -> Result>, StorageError> { let map = lock(&self.identity_keys)?; Ok(map.get(username).cloned()) } fn publish_endpoint( &self, identity_key: &[u8], node_addr: Vec, ) -> Result<(), StorageError> { let mut map = lock(&self.endpoints)?; map.insert(identity_key.to_vec(), node_addr); Ok(()) } fn resolve_endpoint(&self, identity_key: &[u8]) -> Result>, StorageError> { let map = lock(&self.endpoints)?; Ok(map.get(identity_key).cloned()) } }