use std::{
    collections::{HashMap, VecDeque},
    fs,
    hash::Hash,
    path::{Path, PathBuf},
    sync::Mutex,
};

use rand::RngCore;
use serde::{Deserialize, Serialize};

#[derive(thiserror::Error, Debug)]
pub enum StorageError {
    #[error("io error: {0}")]
    Io(String),
    #[error("serialization error")]
    Serde,
    #[error("database error: {0}")]
    Db(String),
    /// Unique constraint violation (e.g. user already exists).
    #[error("duplicate user: {0}")]
    DuplicateUser(String),
}

/// A persisted session record mapping a bearer token to an authenticated user.
pub struct SessionRecord {
    pub username: String,
    pub identity_key: Vec<u8>,
    pub created_at: u64,
    pub expires_at: u64,
}

fn lock<T>(m: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>, StorageError> {
    m.lock()
        .map_err(|e| StorageError::Io(format!("lock poisoned: {e}")))
}

// ── Store trait ──────────────────────────────────────────────────────────────

/// Abstraction over storage backends (file-backed, SQLCipher, etc.).
pub trait Store: Send + Sync {
    fn upload_key_package(&self, identity_key: &[u8], package: Vec<u8>) -> Result<(), StorageError>;

    fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;

    /// Enqueue a payload and return the monotonically increasing per-inbox sequence number
    /// assigned to this message. Clients sort by seq before MLS processing.
    /// When `ttl_secs` is `Some(n)`, the message expires n seconds from now.
    fn enqueue(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
        payload: Vec<u8>,
        ttl_secs: Option<u64>,
    ) -> Result<u64, StorageError>;

    /// Fetch and drain all queued messages, returning `(seq, payload)` pairs ordered by seq.
    fn fetch(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
    ) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;
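    // Editor's sketch of the sequencing contract above, for a fresh inbox
    // (illustrative pseudocall trace, not part of the API):
    //
    //     let s0 = store.enqueue(rk, ch, m1, None)?; // s0 == 0
    //     let s1 = store.enqueue(rk, ch, m2, None)?; // s1 == 1
    //     store.fetch(rk, ch)?                       // -> [(0, m1), (1, m2)]
    //
    // A second `fetch` returns an empty Vec: the queue drains on read.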
    /// Fetch up to `limit` messages without draining the entire queue (Fix 8).
    /// Returns `(seq, payload)` pairs ordered by seq.
    fn fetch_limited(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
        limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;

    /// Return the number of queued messages for (recipient, channel) (Fix 7).
    fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<usize, StorageError>;

    /// Delete messages older than `max_age_secs`. Returns count deleted (Fix 7).
    fn gc_expired_messages(&self, max_age_secs: u64) -> Result<usize, StorageError>;

    fn upload_hybrid_key(
        &self,
        identity_key: &[u8],
        hybrid_pk: Vec<u8>,
    ) -> Result<(), StorageError>;

    fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;

    /// Store the OPAQUE `ServerSetup` (generated once, loaded on restart).
    fn store_server_setup(&self, setup: Vec<u8>) -> Result<(), StorageError>;

    /// Load the persisted `ServerSetup`, if any.
    fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError>;

    /// Persist the server's Ed25519 signing key seed (32 bytes) for delivery proofs.
    fn store_signing_key_seed(&self, seed: Vec<u8>) -> Result<(), StorageError>;

    /// Load the persisted signing key seed, if any.
    fn get_signing_key_seed(&self) -> Result<Option<Vec<u8>>, StorageError>;

    /// Persist the Key Transparency Merkle log (bincode-serialised `MerkleLog` bytes).
    fn save_kt_log(&self, bytes: Vec<u8>) -> Result<(), StorageError>;

    /// Load the persisted KT Merkle log, if any.
    fn load_kt_log(&self) -> Result<Option<Vec<u8>>, StorageError>;

    /// Store an OPAQUE user record (serialized `ServerRegistration`).
    fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError>;

    /// Retrieve an OPAQUE user record by username.
    fn get_user_record(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError>;

    /// Check if a user record already exists (Fix 5).
    fn has_user_record(&self, username: &str) -> Result<bool, StorageError>;

    /// Store identity key for a user (Fix 2).
    fn store_user_identity_key(
        &self,
        username: &str,
        identity_key: Vec<u8>,
    ) -> Result<(), StorageError>;

    /// Retrieve identity key for a user (Fix 2).
    fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError>;

    /// Reverse lookup: resolve an identity key to the registered username.
    fn resolve_identity_key(&self, identity_key: &[u8]) -> Result<Option<String>, StorageError>;

    /// Peek at queued messages without removing them (non-destructive).
    /// Returns `(seq, payload)` pairs ordered by seq.
    fn peek(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
        limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;

    /// Acknowledge (remove) all messages with seq <= seq_up_to.
    fn ack(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
        seq_up_to: u64,
    ) -> Result<usize, StorageError>;

    /// Publish a P2P endpoint address for an identity key.
    fn publish_endpoint(&self, identity_key: &[u8], node_addr: Vec<u8>) -> Result<(), StorageError>;

    /// Resolve a peer's P2P endpoint address.
    fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;

    /// Create a 1:1 channel between two members.
    /// Returns `(channel_id, was_new)` where `was_new` is true iff the channel was created by
    /// this call (false = it already existed). Members are stored in sorted order for deterministic
    /// lookup — both `create_channel(a, b)` and `create_channel(b, a)` return the same channel_id.
    /// The caller who receives `was_new = true` is the MLS group initiator and must send the Welcome.
    fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<(Vec<u8>, bool), StorageError>;

    /// Get the two members of a channel by channel_id (16 bytes). Returns (member_a, member_b) in sorted order.
    #[allow(clippy::type_complexity)]
    fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError>;

    // ── Federation ──────────────────────────────────────────────────────────

    /// Store the home server domain for an identity key.
    #[allow(dead_code)] // federation not yet wired up
    fn store_identity_home_server(
        &self,
        identity_key: &[u8],
        home_server: &str,
    ) -> Result<(), StorageError>;

    /// Get the home server domain for an identity key.
    fn get_identity_home_server(
        &self,
        identity_key: &[u8],
    ) -> Result<Option<String>, StorageError>;

    /// Insert or update a federation peer.
    fn upsert_federation_peer(
        &self,
        domain: &str,
        is_active: bool,
    ) -> Result<(), StorageError>;

    /// List all active federation peers.
    #[allow(dead_code)] // federation not yet wired up
    fn list_federation_peers(&self) -> Result<Vec<String>, StorageError>;

    /// Permanently delete all data associated with an identity key.
    /// Removes deliveries, key packages, hybrid keys, channel memberships,
    /// user identity key mapping, and the user record itself.
    /// Does NOT delete KT log entries (append-only for auditability).
    fn delete_account(&self, identity_key: &[u8]) -> Result<(), StorageError>;

    // ── Device registry ─────────────────────────────────────────────────────

    /// Register a device for an identity. Returns false if the device already exists.
    /// Caller must check device_count < 5 before calling.
    fn register_device(&self, identity_key: &[u8], device_id: &[u8], device_name: &str) -> Result<bool, StorageError>;

    /// List all registered devices for an identity: (device_id, name, registered_at).
    fn list_devices(&self, identity_key: &[u8]) -> Result<Vec<(Vec<u8>, String, u64)>, StorageError>;

    /// Revoke (remove) a registered device. Returns false if not found.
    fn revoke_device(&self, identity_key: &[u8], device_id: &[u8]) -> Result<bool, StorageError>;
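    // Editor's sketch of the intended registration guard (the cap of 5 comes
    // from the doc comment on `register_device` above; enforcement is the
    // caller's responsibility, not the store's):
    //
    //     if store.device_count(ik)? < 5 {
    //         let added = store.register_device(ik, device_id, "laptop")?;
    //         // `added == false` means this device_id was already registered.
    //     }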
    /// Return the number of registered devices for an identity.
    fn device_count(&self, identity_key: &[u8]) -> Result<usize, StorageError>;

    // ── Session persistence ────────────────────────────────────────────────

    /// Store a session token → record mapping.
    fn store_session(&self, _token: &[u8], _record: &SessionRecord) -> Result<(), StorageError> {
        Ok(())
    }

    /// Retrieve a session record by bearer token.
    fn get_session(&self, _token: &[u8]) -> Result<Option<SessionRecord>, StorageError> {
        Ok(None)
    }

    /// Delete all sessions whose `expires_at` <= `now`. Returns count deleted.
    fn delete_expired_sessions(&self, _now: u64) -> Result<usize, StorageError> {
        Ok(0)
    }

    /// Delete a single session by token.
    fn delete_session(&self, _token: &[u8]) -> Result<(), StorageError> {
        Ok(())
    }

    // ── Blob storage ───────────────────────────────────────────────────────

    /// Append a chunk to the staging area for an in-progress upload.
    /// When all chunks have arrived (sum of chunk sizes == `total_size`), assembles the blob,
    /// verifies its SHA-256 hash against `blob_hash`, inserts into permanent storage, and
    /// returns `Some(blob_id)`. Otherwise returns `None`.
    fn store_blob_chunk(
        &self,
        _blob_hash: &[u8],
        _chunk: &[u8],
        _offset: u64,
        _total_size: u64,
        _mime_type: &str,
    ) -> Result<Option<Vec<u8>>, StorageError> {
        Err(StorageError::Io("blob storage not supported".into()))
    }

    /// Read a slice of a completed blob. Returns `(chunk_data, total_size, mime_type)`.
    fn get_blob_chunk(
        &self,
        _blob_id: &[u8],
        _offset: u64,
        _length: u32,
    ) -> Result<Option<(Vec<u8>, u64, String)>, StorageError> {
        Err(StorageError::Io("blob storage not supported".into()))
    }
}

// ── ChannelKey ───────────────────────────────────────────────────────────────

#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
pub struct ChannelKey {
    pub channel_id: Vec<u8>,
    pub recipient_key: Vec<u8>,
}

impl Hash for ChannelKey {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.channel_id.hash(state);
        self.recipient_key.hash(state);
    }
}
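// A minimal consumer sketch (editor's illustration, not used by the server):
// drains an inbox in bounded pages using the non-destructive `peek` + `ack`
// contract declared on `Store` above, so a crash between processing and
// acknowledgement never loses a queued message. The page size of 64 and the
// `process` callback are assumptions for illustration only.
#[allow(dead_code)]
fn drain_inbox_example(
    store: &dyn Store,
    recipient_key: &[u8],
    channel_id: &[u8],
    mut process: impl FnMut(u64, &[u8]),
) -> Result<(), StorageError> {
    loop {
        // Peek a bounded page; entries stay queued until explicitly acked.
        let page = store.peek(recipient_key, channel_id, 64)?;
        let last_seq = match page.last() {
            Some((seq, _)) => *seq,
            None => return Ok(()), // inbox drained
        };
        for (seq, payload) in &page {
            process(*seq, payload);
        }
        // Only now remove everything up to and including the last processed seq.
        store.ack(recipient_key, channel_id, last_seq)?;
    }
}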
// ── FileBackedStore ──────────────────────────────────────────────────────────

#[derive(Serialize, Deserialize, Default)]
struct QueueMapV1 {
    map: HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
}

#[derive(Serialize, Deserialize, Default)]
struct QueueMapV2 {
    map: HashMap<ChannelKey, VecDeque<Vec<u8>>>,
}

#[derive(Serialize, Deserialize, Default, Clone)]
struct SeqEntry {
    seq: u64,
    data: Vec<u8>,
}

/// V3 delivery store: each queue entry carries a monotonic per-inbox sequence number.
#[derive(Serialize, Deserialize, Default)]
struct QueueMapV3 {
    map: HashMap<ChannelKey, VecDeque<SeqEntry>>,
    next_seq: HashMap<ChannelKey, u64>,
}

/// File-backed storage for KeyPackages and delivery queues.
///
/// Each mutation flushes the entire map to disk. Suitable for MVP-scale loads.
pub struct FileBackedStore {
    kp_path: PathBuf,
    ds_path: PathBuf,
    hk_path: PathBuf,
    setup_path: PathBuf,
    signing_key_path: PathBuf,
    kt_log_path: PathBuf,
    users_path: PathBuf,
    identity_keys_path: PathBuf,
    channels_path: PathBuf,
    key_packages: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
    deliveries: Mutex<QueueMapV3>,
    #[allow(clippy::type_complexity)]
    channels: Mutex<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>>,
    hybrid_keys: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
    users: Mutex<HashMap<String, Vec<u8>>>,
    identity_keys: Mutex<HashMap<String, Vec<u8>>>,
    endpoints: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
    /// Device registry: identity_key -> Vec<(device_id, device_name, registered_at)>
    #[allow(clippy::type_complexity)]
    devices: Mutex<HashMap<Vec<u8>, Vec<(Vec<u8>, String, u64)>>>,
}

impl FileBackedStore {
    pub fn open(dir: impl AsRef<Path>) -> Result<Self, StorageError> {
        let dir = dir.as_ref();
        if !dir.exists() {
            fs::create_dir_all(dir).map_err(|e| StorageError::Io(e.to_string()))?;
        }
        let kp_path = dir.join("keypackages.bin");
        let ds_path = dir.join("deliveries.bin");
        let hk_path = dir.join("hybridkeys.bin");
        let setup_path = dir.join("server_setup.bin");
        let signing_key_path = dir.join("server_signing_key.bin");
        let kt_log_path = dir.join("kt_log.bin");
        let users_path = dir.join("users.bin");
        let identity_keys_path = dir.join("identity_keys.bin");
        let channels_path = dir.join("channels.bin");

        let key_packages = Mutex::new(Self::load_kp_map(&kp_path)?);
        let deliveries = Mutex::new(Self::load_delivery_map_v3(&ds_path)?);
        let hybrid_keys = Mutex::new(Self::load_hybrid_keys(&hk_path)?);
        let users = Mutex::new(Self::load_users(&users_path)?);
        let identity_keys = Mutex::new(Self::load_map_string_bytes(&identity_keys_path)?);
        let channels = Mutex::new(Self::load_channels(&channels_path)?);

        Ok(Self {
            kp_path,
            ds_path,
            hk_path,
            setup_path,
            signing_key_path,
            kt_log_path,
            users_path,
            identity_keys_path,
            channels_path,
            key_packages,
            deliveries,
            channels,
            hybrid_keys,
            users,
            identity_keys,
            endpoints: Mutex::new(HashMap::new()),
            devices: Mutex::new(HashMap::new()),
        })
    }

    #[allow(clippy::type_complexity)]
    fn load_channels(
        path: &Path,
    ) -> Result<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>, StorageError> {
        if !path.exists() {
            return Ok(HashMap::new());
        }
        let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
        if bytes.is_empty() {
            return Ok(HashMap::new());
        }
        bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
    }

    fn flush_channels(
        &self,
        path: &Path,
        map: &HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>,
    ) -> Result<(), StorageError> {
        let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
        }
        fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
    }

    fn load_kp_map(path: &Path) -> Result<HashMap<Vec<u8>, VecDeque<Vec<u8>>>, StorageError> {
        if !path.exists() {
            return Ok(HashMap::new());
        }
        let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
        if bytes.is_empty() {
            return Ok(HashMap::new());
        }
        let map: QueueMapV1 = bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)?;
        Ok(map.map)
    }

    fn flush_kp_map(
        &self,
        path: &Path,
        map: &HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
    ) -> Result<(), StorageError> {
        let payload = QueueMapV1 { map: map.clone() };
        let bytes = bincode::serialize(&payload).map_err(|_| StorageError::Serde)?;
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
        }
        fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
    }
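    // Editor's note: `fs::write` in the flush helpers truncates the target file
    // in place, so a crash mid-write can corrupt the map on disk. A common
    // hardening (sketch only, not wired into the flush paths above) is
    // write-to-temp-then-rename, which is atomic on POSIX filesystems as long
    // as both paths live on the same filesystem:
    #[allow(dead_code)]
    fn write_atomic(path: &Path, bytes: &[u8]) -> Result<(), StorageError> {
        let tmp = path.with_extension("tmp");
        fs::write(&tmp, bytes).map_err(|e| StorageError::Io(e.to_string()))?;
        fs::rename(&tmp, path).map_err(|e| StorageError::Io(e.to_string()))
    }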
    /// Load deliveries as V3. Falls back to V2 format (assigns seqs starting at 0).
    fn load_delivery_map_v3(path: &Path) -> Result<QueueMapV3, StorageError> {
        if !path.exists() {
            return Ok(QueueMapV3::default());
        }
        let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
        if bytes.is_empty() {
            return Ok(QueueMapV3::default());
        }
        // Try V3 first.
        if let Ok(v3) = bincode::deserialize::<QueueMapV3>(&bytes) {
            return Ok(v3);
        }
        // Fall back to V2: assign ascending seqs starting at 0 per channel.
        let v2 = bincode::deserialize::<QueueMapV2>(&bytes)
            .map_err(|_| StorageError::Io("deliveries file: unrecognised format".into()))?;
        let mut v3 = QueueMapV3::default();
        for (key, queue) in v2.map {
            let entries: VecDeque<SeqEntry> = queue
                .into_iter()
                .enumerate()
                .map(|(i, data)| SeqEntry { seq: i as u64, data })
                .collect();
            let next = entries.len() as u64;
            v3.next_seq.insert(key.clone(), next);
            v3.map.insert(key, entries);
        }
        Ok(v3)
    }

    fn flush_delivery_map(&self, path: &Path, map: &QueueMapV3) -> Result<(), StorageError> {
        let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
        }
        fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
    }

    fn load_hybrid_keys(path: &Path) -> Result<HashMap<Vec<u8>, Vec<u8>>, StorageError> {
        if !path.exists() {
            return Ok(HashMap::new());
        }
        let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
        if bytes.is_empty() {
            return Ok(HashMap::new());
        }
        bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
    }

    fn flush_hybrid_keys(
        &self,
        path: &Path,
        map: &HashMap<Vec<u8>, Vec<u8>>,
    ) -> Result<(), StorageError> {
        let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
        }
        fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
    }

    fn load_users(path: &Path) -> Result<HashMap<String, Vec<u8>>, StorageError> {
        if !path.exists() {
            return Ok(HashMap::new());
        }
        let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
        if bytes.is_empty() {
            return Ok(HashMap::new());
        }
        bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
    }

    fn flush_users(&self, path: &Path, map: &HashMap<String, Vec<u8>>) -> Result<(), StorageError> {
        let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
        }
        fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
    }

    fn load_map_string_bytes(path: &Path) -> Result<HashMap<String, Vec<u8>>, StorageError> {
        Self::load_users(path)
    }

    fn flush_map_string_bytes(
        &self,
        path: &Path,
        map: &HashMap<String, Vec<u8>>,
    ) -> Result<(), StorageError> {
        self.flush_users(path, map)
    }
}
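// A minimal end-to-end usage sketch (editor's illustration; the directory name,
// identity key, and payloads are assumptions). Every call resolves to the
// `Store` methods implemented for `FileBackedStore` below.
#[allow(dead_code)]
fn file_store_usage_example() -> Result<(), StorageError> {
    let store = FileBackedStore::open("./server-data")?;
    let alice_ik = vec![1u8; 32];
    // Publish a key package so peers can add Alice to MLS groups.
    store.upload_key_package(&alice_ik, b"serialized-keypackage".to_vec())?;
    // Queue a ciphertext; an empty channel_id is the direct-inbox convention
    // used by the tests at the bottom of this file. For a fresh inbox the
    // returned per-inbox sequence number starts at 0.
    let _seq = store.enqueue(&alice_ik, &[], b"ciphertext".to_vec(), None)?;
    Ok(())
}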
impl Store for FileBackedStore {
    fn upload_key_package(
        &self,
        identity_key: &[u8],
        package: Vec<u8>,
    ) -> Result<(), StorageError> {
        let mut map = lock(&self.key_packages)?;
        map.entry(identity_key.to_vec())
            .or_default()
            .push_back(package);
        self.flush_kp_map(&self.kp_path, &map)
    }

    fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
        let mut map = lock(&self.key_packages)?;
        let package = map.get_mut(identity_key).and_then(|q| q.pop_front());
        self.flush_kp_map(&self.kp_path, &map)?;
        Ok(package)
    }

    fn enqueue(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
        payload: Vec<u8>,
        _ttl_secs: Option<u64>,
    ) -> Result<u64, StorageError> {
        let mut inner = lock(&self.deliveries)?;
        let key = ChannelKey {
            channel_id: channel_id.to_vec(),
            recipient_key: recipient_key.to_vec(),
        };
        let entry = inner.next_seq.entry(key.clone()).or_insert(0);
        let seq = *entry;
        *entry = seq + 1;
        inner.map.entry(key).or_default().push_back(SeqEntry { seq, data: payload });
        self.flush_delivery_map(&self.ds_path, &inner)?;
        Ok(seq)
    }

    fn fetch(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
    ) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
        let mut inner = lock(&self.deliveries)?;
        let key = ChannelKey {
            channel_id: channel_id.to_vec(),
            recipient_key: recipient_key.to_vec(),
        };
        let messages: Vec<(u64, Vec<u8>)> = inner
            .map
            .get_mut(&key)
            .map(|q| q.drain(..).map(|e| (e.seq, e.data)).collect())
            .unwrap_or_default();
        self.flush_delivery_map(&self.ds_path, &inner)?;
        Ok(messages)
    }

    fn fetch_limited(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
        limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
        let mut inner = lock(&self.deliveries)?;
        let key = ChannelKey {
            channel_id: channel_id.to_vec(),
            recipient_key: recipient_key.to_vec(),
        };
        let messages: Vec<(u64, Vec<u8>)> = inner
            .map
            .get_mut(&key)
            .map(|q| {
                let count = limit.min(q.len());
                q.drain(..count).map(|e| (e.seq, e.data)).collect()
            })
            .unwrap_or_default();
        self.flush_delivery_map(&self.ds_path, &inner)?;
        Ok(messages)
    }

    fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<usize, StorageError> {
        let inner = lock(&self.deliveries)?;
        let key = ChannelKey {
            channel_id: channel_id.to_vec(),
            recipient_key: recipient_key.to_vec(),
        };
        Ok(inner.map.get(&key).map(|q| q.len()).unwrap_or(0))
    }

    fn gc_expired_messages(&self, _max_age_secs: u64) -> Result<usize, StorageError> {
        // FileBackedStore does not track timestamps per message — no-op.
        Ok(0)
    }

    fn upload_hybrid_key(
        &self,
        identity_key: &[u8],
        hybrid_pk: Vec<u8>,
    ) -> Result<(), StorageError> {
        let mut map = lock(&self.hybrid_keys)?;
        map.insert(identity_key.to_vec(), hybrid_pk);
        self.flush_hybrid_keys(&self.hk_path, &map)
    }

    fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
        let map = lock(&self.hybrid_keys)?;
        Ok(map.get(identity_key).cloned())
    }

    fn store_server_setup(&self, setup: Vec<u8>) -> Result<(), StorageError> {
        if let Some(parent) = self.setup_path.parent() {
            fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
        }
        fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let _ = std::fs::set_permissions(
                &self.setup_path,
                std::fs::Permissions::from_mode(0o600),
            );
        }
        Ok(())
    }

    fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError> {
        if !self.setup_path.exists() {
            return Ok(None);
        }
        let bytes = fs::read(&self.setup_path).map_err(|e| StorageError::Io(e.to_string()))?;
        if bytes.is_empty() {
            return Ok(None);
        }
        Ok(Some(bytes))
    }

    fn store_signing_key_seed(&self, seed: Vec<u8>) -> Result<(), StorageError> {
        if let Some(parent) = self.signing_key_path.parent() {
            fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
        }
        fs::write(&self.signing_key_path, &seed).map_err(|e| StorageError::Io(e.to_string()))?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let _ = std::fs::set_permissions(
                &self.signing_key_path,
                std::fs::Permissions::from_mode(0o600),
            );
        }
        Ok(())
    }

    fn get_signing_key_seed(&self) -> Result<Option<Vec<u8>>, StorageError> {
        if !self.signing_key_path.exists() {
            return Ok(None);
        }
        let bytes = fs::read(&self.signing_key_path).map_err(|e| StorageError::Io(e.to_string()))?;
        if bytes.is_empty() {
            return Ok(None);
        }
        Ok(Some(bytes))
    }

    fn save_kt_log(&self, bytes: Vec<u8>) -> Result<(), StorageError> {
        if let Some(parent) = self.kt_log_path.parent() {
            fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
        }
        fs::write(&self.kt_log_path, &bytes).map_err(|e| StorageError::Io(e.to_string()))
    }
    fn load_kt_log(&self) -> Result<Option<Vec<u8>>, StorageError> {
        if !self.kt_log_path.exists() {
            return Ok(None);
        }
        let bytes = fs::read(&self.kt_log_path).map_err(|e| StorageError::Io(e.to_string()))?;
        if bytes.is_empty() {
            return Ok(None);
        }
        Ok(Some(bytes))
    }

    fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
        let mut map = lock(&self.users)?;
        match map.entry(username.to_string()) {
            std::collections::hash_map::Entry::Occupied(_) => {
                return Err(StorageError::DuplicateUser(username.to_string()))
            }
            std::collections::hash_map::Entry::Vacant(v) => {
                v.insert(record);
            }
        }
        self.flush_users(&self.users_path, &map)
    }

    fn get_user_record(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
        let map = lock(&self.users)?;
        Ok(map.get(username).cloned())
    }

    fn has_user_record(&self, username: &str) -> Result<bool, StorageError> {
        let map = lock(&self.users)?;
        Ok(map.contains_key(username))
    }

    fn store_user_identity_key(
        &self,
        username: &str,
        identity_key: Vec<u8>,
    ) -> Result<(), StorageError> {
        let mut map = lock(&self.identity_keys)?;
        map.insert(username.to_string(), identity_key);
        self.flush_map_string_bytes(&self.identity_keys_path, &map)
    }

    fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
        let map = lock(&self.identity_keys)?;
        Ok(map.get(username).cloned())
    }

    fn resolve_identity_key(&self, identity_key: &[u8]) -> Result<Option<String>, StorageError> {
        let map = lock(&self.identity_keys)?;
        for (username, ik) in map.iter() {
            if ik.as_slice() == identity_key {
                return Ok(Some(username.clone()));
            }
        }
        Ok(None)
    }

    fn peek(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
        limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
        let inner = lock(&self.deliveries)?;
        let key = ChannelKey {
            channel_id: channel_id.to_vec(),
            recipient_key: recipient_key.to_vec(),
        };
        let messages: Vec<(u64, Vec<u8>)> = inner
            .map
            .get(&key)
            .map(|q| {
                let count = if limit == 0 { q.len() } else { limit.min(q.len()) };
                q.iter()
                    .take(count)
                    .map(|e| (e.seq, e.data.clone()))
                    .collect()
            })
            .unwrap_or_default();
        // Non-destructive: do NOT flush.
        Ok(messages)
    }

    fn ack(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
        seq_up_to: u64,
    ) -> Result<usize, StorageError> {
        let mut inner = lock(&self.deliveries)?;
        let key = ChannelKey {
            channel_id: channel_id.to_vec(),
            recipient_key: recipient_key.to_vec(),
        };
        let removed = if let Some(q) = inner.map.get_mut(&key) {
            let before = q.len();
            q.retain(|e| e.seq > seq_up_to);
            before - q.len()
        } else {
            0
        };
        self.flush_delivery_map(&self.ds_path, &inner)?;
        Ok(removed)
    }

    fn publish_endpoint(
        &self,
        identity_key: &[u8],
        node_addr: Vec<u8>,
    ) -> Result<(), StorageError> {
        let mut map = lock(&self.endpoints)?;
        map.insert(identity_key.to_vec(), node_addr);
        Ok(())
    }

    fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
        let map = lock(&self.endpoints)?;
        Ok(map.get(identity_key).cloned())
    }

    fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<(Vec<u8>, bool), StorageError> {
        let (a, b) = if member_a < member_b {
            (member_a.to_vec(), member_b.to_vec())
        } else {
            (member_b.to_vec(), member_a.to_vec())
        };
        let mut map = lock(&self.channels)?;
        if let Some((channel_id, _)) = map.iter().find(|(_, (ma, mb))| ma == &a && mb == &b) {
            return Ok((channel_id.clone(), false));
        }
        let mut channel_id = [0u8; 16];
        rand::rngs::OsRng.fill_bytes(&mut channel_id);
        let channel_id = channel_id.to_vec();
        map.insert(channel_id.clone(), (a, b));
        self.flush_channels(&self.channels_path, &map)?;
        Ok((channel_id, true))
    }

    fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
        let map = lock(&self.channels)?;
        Ok(map.get(channel_id).cloned())
    }

    fn store_identity_home_server(
        &self,
        _identity_key: &[u8],
        _home_server: &str,
    ) -> Result<(), StorageError> {
        // File-backed store: federation mappings are ephemeral (in-memory only).
        Ok(())
    }

    fn get_identity_home_server(
        &self,
        _identity_key: &[u8],
    ) -> Result<Option<String>, StorageError> {
        Ok(None)
    }

    fn upsert_federation_peer(
        &self,
        _domain: &str,
        _is_active: bool,
    ) -> Result<(), StorageError> {
        Ok(())
    }

    fn list_federation_peers(&self) -> Result<Vec<String>, StorageError> {
        Ok(vec![])
    }

    fn delete_account(&self, identity_key: &[u8]) -> Result<(), StorageError> {
        // Resolve username from identity key for user record deletion.
        let username = {
            let ik_map = lock(&self.identity_keys)?;
            ik_map.iter()
                .find(|(_, v)| v.as_slice() == identity_key)
                .map(|(k, _)| k.clone())
        };
        // Remove deliveries where this identity is the recipient.
        {
            let mut deliveries = lock(&self.deliveries)?;
            deliveries.map.retain(|k, _| k.recipient_key != identity_key);
            deliveries.next_seq.retain(|k, _| k.recipient_key != identity_key);
            self.flush_delivery_map(&self.ds_path, &deliveries)?;
        }
        // Remove key packages.
        {
            let mut kp = lock(&self.key_packages)?;
            kp.remove(identity_key);
            self.flush_kp_map(&self.kp_path, &kp)?;
        }
        // Remove hybrid keys.
        {
            let mut hk = lock(&self.hybrid_keys)?;
            hk.remove(identity_key);
            self.flush_hybrid_keys(&self.hk_path, &hk)?;
        }
        // Remove channels where this identity is a member.
        {
            let mut ch = lock(&self.channels)?;
            ch.retain(|_, (a, b)| a.as_slice() != identity_key && b.as_slice() != identity_key);
            self.flush_channels(&self.channels_path, &ch)?;
        }
        // Remove identity key mapping and user record.
        if let Some(uname) = username {
            {
                let mut ik_map = lock(&self.identity_keys)?;
                ik_map.remove(&uname);
                self.flush_map_string_bytes(&self.identity_keys_path, &ik_map)?;
            }
            {
                let mut users = lock(&self.users)?;
                users.remove(&uname);
                self.flush_users(&self.users_path, &users)?;
            }
        }
        // Remove endpoint.
        {
            let mut ep = lock(&self.endpoints)?;
            ep.remove(identity_key);
        }
        // Remove devices.
        {
            let mut dev = lock(&self.devices)?;
            dev.remove(identity_key);
        }
        Ok(())
    }

    fn register_device(&self, identity_key: &[u8], device_id: &[u8], device_name: &str) -> Result<bool, StorageError> {
        let mut map = lock(&self.devices)?;
        let devices = map.entry(identity_key.to_vec()).or_default();
        if devices.iter().any(|(id, _, _)| id == device_id) {
            return Ok(false);
        }
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        devices.push((device_id.to_vec(), device_name.to_string(), now));
        Ok(true)
    }

    fn list_devices(&self, identity_key: &[u8]) -> Result<Vec<(Vec<u8>, String, u64)>, StorageError> {
        let map = lock(&self.devices)?;
        Ok(map.get(identity_key).cloned().unwrap_or_default())
    }

    fn revoke_device(&self, identity_key: &[u8], device_id: &[u8]) -> Result<bool, StorageError> {
        let mut map = lock(&self.devices)?;
        if let Some(devices) = map.get_mut(identity_key) {
            let before = devices.len();
            devices.retain(|(id, _, _)| id != device_id);
            Ok(devices.len() < before)
        } else {
            Ok(false)
        }
    }

    fn device_count(&self, identity_key: &[u8]) -> Result<usize, StorageError> {
        let map = lock(&self.devices)?;
        Ok(map.get(identity_key).map(|v| v.len()).unwrap_or(0))
    }
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn temp_store() -> (TempDir, FileBackedStore) {
        let dir = TempDir::new().unwrap();
        let store = FileBackedStore::open(dir.path()).unwrap();
        (dir, store)
    }

    #[test]
    fn key_package_upload_fetch() {
        let (_dir, store) = temp_store();
        let ik = vec![1u8; 32];
        store.upload_key_package(&ik, vec![10, 20, 30]).unwrap();
        let pkg = store.fetch_key_package(&ik).unwrap();
        assert_eq!(pkg, Some(vec![10, 20, 30]));
        // Second fetch should return None (consumed)
        let pkg2 = store.fetch_key_package(&ik).unwrap();
        assert_eq!(pkg2, None);
    }

    #[test]
    fn enqueue_fetch_with_seq() {
        let (_dir, store) = temp_store();
        let rk = vec![2u8; 32];
        let ch = vec![];
        let seq0 = store.enqueue(&rk, &ch, vec![1], None).unwrap();
        let seq1 = store.enqueue(&rk, &ch, vec![2], None).unwrap();
        assert_eq!(seq0, 0);
        assert_eq!(seq1, 1);
        let msgs = store.fetch(&rk, &ch).unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0], (0, vec![1]));
        assert_eq!(msgs[1], (1, vec![2]));
        // After fetch, queue should be empty
        let msgs2 = store.fetch(&rk, &ch).unwrap();
        assert!(msgs2.is_empty());
    }

    #[test]
    fn fetch_limited_respects_limit() {
        let (_dir, store) = temp_store();
        let rk = vec![3u8; 32];
        let ch = vec![];
        for i in 0..5 {
            store.enqueue(&rk, &ch, vec![i], None).unwrap();
        }
        let msgs = store.fetch_limited(&rk, &ch, 2).unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].1, vec![0]);
        assert_eq!(msgs[1].1, vec![1]);
        // Remaining 3 should still be there
        let depth = store.queue_depth(&rk, &ch).unwrap();
        assert_eq!(depth, 3);
    }

    #[test]
    fn queue_depth_tracking() {
        let (_dir, store) = temp_store();
        let rk = vec![4u8; 32];
        let ch = vec![];
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
        store.enqueue(&rk, &ch, vec![1], None).unwrap();
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 1);
        store.enqueue(&rk, &ch, vec![2], None).unwrap();
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 2);
        store.fetch(&rk, &ch).unwrap();
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
    }

    #[test]
    fn hybrid_key_upload_fetch() {
        let (_dir, store) = temp_store();
        let ik = vec![5u8; 32];
        assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), None);
        store.upload_hybrid_key(&ik, vec![99; 100]).unwrap();
        assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(vec![99; 100]));
    }
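    // Editor's sketch of a test exercising the non-destructive `peek` + `ack`
    // flow that the destructive `fetch` path bypasses; the key and payload
    // values are arbitrary, matching the style of the tests above.
    #[test]
    fn peek_then_ack_flow() {
        let (_dir, store) = temp_store();
        let rk = vec![9u8; 32];
        let ch = vec![];
        for i in 0..3 {
            store.enqueue(&rk, &ch, vec![i], None).unwrap();
        }
        // Peek returns the first `limit` entries without removing anything.
        let peeked = store.peek(&rk, &ch, 2).unwrap();
        assert_eq!(peeked.len(), 2);
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 3);
        // Ack removes every entry with seq <= 1 and reports the count.
        let removed = store.ack(&rk, &ch, 1).unwrap();
        assert_eq!(removed, 2);
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 1);
    }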
    #[test]
    fn user_record_crud() {
        let (_dir, store) = temp_store();
        assert!(!store.has_user_record("alice").unwrap());
        store.store_user_record("alice", vec![1, 2, 3]).unwrap();
        assert!(store.has_user_record("alice").unwrap());
        assert_eq!(store.get_user_record("alice").unwrap(), Some(vec![1, 2, 3]));
    }

    #[test]
    fn user_identity_key_crud() {
        let (_dir, store) = temp_store();
        assert_eq!(store.get_user_identity_key("bob").unwrap(), None);
        store.store_user_identity_key("bob", vec![7u8; 32]).unwrap();
        assert_eq!(store.get_user_identity_key("bob").unwrap(), Some(vec![7u8; 32]));
    }

    #[test]
    fn endpoint_publish_resolve() {
        let (_dir, store) = temp_store();
        let ik = vec![8u8; 32];
        assert_eq!(store.resolve_endpoint(&ik).unwrap(), None);
        store.publish_endpoint(&ik, vec![10, 20]).unwrap();
        assert_eq!(store.resolve_endpoint(&ik).unwrap(), Some(vec![10, 20]));
    }

    #[test]
    fn create_channel_and_members() {
        let (_dir, store) = temp_store();
        let a = vec![1u8; 32];
        let b = vec![2u8; 32];
        assert_eq!(store.get_channel_members(&[0u8; 16]).unwrap(), None);
        let (id1, was_new1) = store.create_channel(&a, &b).unwrap();
        assert_eq!(id1.len(), 16);
        assert!(was_new1, "first call must return was_new=true");
        let members = store.get_channel_members(&id1).unwrap().unwrap();
        assert_eq!(members.0, a);
        assert_eq!(members.1, b);
        let (id2, was_new2) = store.create_channel(&b, &a).unwrap();
        assert_eq!(id1, id2, "reversed key order must return same channel_id");
        assert!(!was_new2, "second call (reversed) must return was_new=false");
    }

    #[test]
    fn create_channel_idempotent_same_direction() {
        let (_dir, store) = temp_store();
        let a = vec![3u8; 32];
        let b = vec![4u8; 32];
        let (id1, was_new1) = store.create_channel(&a, &b).unwrap();
        let (id2, was_new2) = store.create_channel(&a, &b).unwrap();
        assert_eq!(id1, id2);
        assert!(was_new1);
        assert!(!was_new2);
    }

    #[test]
    fn create_channel_different_pairs_get_different_ids() {
        let (_dir, store) = temp_store();
        let a = vec![5u8; 32];
        let b = vec![6u8; 32];
        let c = vec![7u8; 32];
        let (id_ab, _) = store.create_channel(&a, &b).unwrap();
        let (id_ac, _) = store.create_channel(&a, &c).unwrap();
        let (id_bc, _) = store.create_channel(&b, &c).unwrap();
        assert_ne!(id_ab, id_ac);
        assert_ne!(id_ab, id_bc);
        assert_ne!(id_ac, id_bc);
    }
}