//! SQLCipher-backed persistent storage. use std::path::Path; use std::sync::Mutex; use rand::RngCore; use rusqlite::{params, Connection}; use sha2::{Digest, Sha256}; use crate::storage::{SessionRecord, StorageError, Store}; /// Schema version after introducing the migration runner (existing DBs had 1). const SCHEMA_VERSION: i32 = 13; /// Default number of connections in the pool. const DEFAULT_POOL_SIZE: usize = 4; /// Migrations: (migration_number, SQL). Files named NNN_name.sql, applied in order when N > user_version. const MIGRATIONS: &[(i32, &str)] = &[ (1, include_str!("../migrations/001_initial.sql")), (3, include_str!("../migrations/002_add_seq.sql")), (4, include_str!("../migrations/003_channels.sql")), (5, include_str!("../migrations/004_federation.sql")), (6, include_str!("../migrations/005_signing_key.sql")), (7, include_str!("../migrations/006_kt_log.sql")), (8, include_str!("../migrations/007_add_expiry.sql")), (9, include_str!("../migrations/008_devices.sql")), (10, include_str!("../migrations/009_sessions.sql")), (11, include_str!("../migrations/010_blobs.sql")), (12, include_str!("../migrations/011_recovery_bundles.sql")), (13, include_str!("../migrations/012_moderation.sql")), ]; /// Runs pending migrations on an open connection: applies any migration whose number is greater /// than the current PRAGMA user_version, then sets user_version to SCHEMA_VERSION. fn run_migrations(conn: &Connection) -> Result<(), StorageError> { let current_version: i32 = conn .pragma_query_value(None, "user_version", |row| row.get(0)) .map_err(|e| StorageError::Db(format!("PRAGMA user_version failed: {e}")))?; for (migration_num, sql) in MIGRATIONS { if *migration_num > current_version { conn.execute_batch(sql).map_err(|e| StorageError::Db(e.to_string()))?; } } conn.pragma_update(None, "user_version", SCHEMA_VERSION) .map_err(|e| StorageError::Db(format!("set user_version failed: {e}")))?; Ok(()) } /// SQLCipher-encrypted storage backend with a connection pool. 
/// /// Maintains `pool_size` SQLite connections (default 4) behind `std::sync::Mutex`. /// Each store method tries all connections via `try_lock()` before falling back to /// blocking on the first connection. WAL mode allows concurrent readers; writers /// are serialised by SQLite itself. pub struct SqlStore { pool: Vec>, } impl SqlStore { /// Try to acquire any connection from the pool without blocking. /// Falls back to blocking on the first connection. fn get_conn(&self) -> Result, StorageError> { // Fast path: try each connection without blocking. for conn in &self.pool { if let Ok(guard) = conn.try_lock() { return Ok(guard); } } // Slow path: block on the first connection. self.pool[0] .lock() .map_err(|e| StorageError::Db(format!("lock poisoned: {e}"))) } pub fn open(path: impl AsRef, key: &str) -> Result { Self::open_with_pool_size(path, key, DEFAULT_POOL_SIZE) } pub fn open_with_pool_size( path: impl AsRef, key: &str, pool_size: usize, ) -> Result { let pool_size = pool_size.max(1); let path = path.as_ref(); // Open the first connection and run migrations. let first = Self::open_one(path, key)?; let current_version: i32 = first .pragma_query_value(None, "user_version", |row| row.get(0)) .map_err(|e| StorageError::Db(format!("PRAGMA user_version failed: {e}")))?; if current_version > SCHEMA_VERSION { return Err(StorageError::Db(format!( "database schema version {current_version} is newer than supported {SCHEMA_VERSION}" ))); } run_migrations(&first)?; let mut pool = Vec::with_capacity(pool_size); pool.push(Mutex::new(first)); // Open remaining connections (they skip migrations since the first one already ran them). for _ in 1..pool_size { let conn = Self::open_one(path, key)?; pool.push(Mutex::new(conn)); } Ok(Self { pool }) } /// Open a single connection with shared pragmas. 
fn open_one(path: &Path, key: &str) -> Result { let conn = Connection::open(path).map_err(|e| StorageError::Db(e.to_string()))?; if !key.is_empty() { conn.pragma_update(None, "key", key) .map_err(|e| StorageError::Db(format!("PRAGMA key failed: {e}")))?; } conn.execute_batch( "PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL; PRAGMA foreign_keys = ON;", ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(conn) } } impl Store for SqlStore { fn upload_key_package( &self, identity_key: &[u8], package: Vec, ) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT INTO key_packages (identity_key, package_data) VALUES (?1, ?2)", params![identity_key, package], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn fetch_key_package(&self, identity_key: &[u8]) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare( "SELECT id, package_data FROM key_packages WHERE identity_key = ?1 ORDER BY id ASC LIMIT 1", ) .map_err(|e| StorageError::Db(e.to_string()))?; let row = stmt .query_row(params![identity_key], |row| { Ok((row.get::<_, i64>(0)?, row.get::<_, Vec>(1)?)) }) .optional() .map_err(|e| StorageError::Db(e.to_string()))?; match row { Some((id, package)) => { conn.execute("DELETE FROM key_packages WHERE id = ?1", params![id]) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(Some(package)) } None => Ok(None), } } fn enqueue( &self, recipient_key: &[u8], channel_id: &[u8], payload: Vec, ttl_secs: Option, ) -> Result { let conn = self.get_conn()?; // Atomically get-and-increment the per-inbox sequence counter. // RETURNING gives us the post-update next_seq; the assigned seq is next_seq - 1. 
let seq: i64 = conn .query_row( "INSERT INTO delivery_seq_counters (recipient_key, channel_id, next_seq) VALUES (?1, ?2, 1) ON CONFLICT(recipient_key, channel_id) DO UPDATE SET next_seq = next_seq + 1 RETURNING next_seq - 1", params![recipient_key, channel_id], |row| row.get(0), ) .map_err(|e| StorageError::Db(e.to_string()))?; let expires_at: Option = ttl_secs.map(|ttl| { let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as i64; now + ttl as i64 }); conn.execute( "INSERT INTO deliveries (recipient_key, channel_id, seq, payload, expires_at) VALUES (?1, ?2, ?3, ?4, ?5)", params![recipient_key, channel_id, seq, payload, expires_at], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(seq as u64) } fn fetch( &self, recipient_key: &[u8], channel_id: &[u8], ) -> Result)>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare( "SELECT id, seq, payload FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND (expires_at IS NULL OR expires_at > strftime('%s','now')) ORDER BY seq ASC", ) .map_err(|e| StorageError::Db(e.to_string()))?; let rows: Vec<(i64, i64, Vec)> = stmt .query_map(params![recipient_key, channel_id], |row| { Ok((row.get(0)?, row.get(1)?, row.get(2)?)) }) .map_err(|e| StorageError::Db(e.to_string()))? 
.collect::, _>>() .map_err(|e| StorageError::Db(e.to_string()))?; if !rows.is_empty() { let ids: Vec = rows.iter().map(|(id, _, _)| *id).collect(); let placeholders: String = ids.iter().map(|_| "?").collect::>().join(","); let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})"); let params: Vec<&dyn rusqlite::types::ToSql> = ids .iter() .map(|id| id as &dyn rusqlite::types::ToSql) .collect(); conn.execute(&sql, params.as_slice()) .map_err(|e| StorageError::Db(e.to_string()))?; } Ok(rows.into_iter().map(|(_, seq, payload)| (seq as u64, payload)).collect()) } fn fetch_limited( &self, recipient_key: &[u8], channel_id: &[u8], limit: usize, ) -> Result)>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare( "SELECT id, seq, payload FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND (expires_at IS NULL OR expires_at > strftime('%s','now')) ORDER BY seq ASC LIMIT ?3", ) .map_err(|e| StorageError::Db(e.to_string()))?; let rows: Vec<(i64, i64, Vec)> = stmt .query_map(params![recipient_key, channel_id, limit as i64], |row| { Ok((row.get(0)?, row.get(1)?, row.get(2)?)) }) .map_err(|e| StorageError::Db(e.to_string()))? 
.collect::, _>>() .map_err(|e| StorageError::Db(e.to_string()))?; if !rows.is_empty() { let ids: Vec = rows.iter().map(|(id, _, _)| *id).collect(); let placeholders: String = ids.iter().map(|_| "?").collect::>().join(","); let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})"); let params: Vec<&dyn rusqlite::types::ToSql> = ids .iter() .map(|id| id as &dyn rusqlite::types::ToSql) .collect(); conn.execute(&sql, params.as_slice()) .map_err(|e| StorageError::Db(e.to_string()))?; } Ok(rows.into_iter().map(|(_, seq, payload)| (seq as u64, payload)).collect()) } fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result { let conn = self.get_conn()?; let count: i64 = conn .query_row( "SELECT COUNT(*) FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND (expires_at IS NULL OR expires_at > strftime('%s','now'))", params![recipient_key, channel_id], |row| row.get(0), ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(count as usize) } fn gc_expired_messages(&self, max_age_secs: u64) -> Result { let conn = self.get_conn()?; let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(); let cutoff = now.saturating_sub(max_age_secs); // Delete messages older than max_age_secs based on created_at. let deleted_age = conn .execute( "DELETE FROM deliveries WHERE created_at < ?1", params![cutoff as i64], ) .map_err(|e| StorageError::Db(e.to_string()))?; // Delete messages that have passed their per-message TTL expiry. 
let deleted_ttl = conn .execute( "DELETE FROM deliveries WHERE expires_at IS NOT NULL AND expires_at <= ?1", params![now as i64], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(deleted_age + deleted_ttl) } fn upload_hybrid_key( &self, identity_key: &[u8], hybrid_pk: Vec, ) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO hybrid_keys (identity_key, hybrid_public_key) VALUES (?1, ?2)", params![identity_key, hybrid_pk], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT hybrid_public_key FROM hybrid_keys WHERE identity_key = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![identity_key], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn store_server_setup(&self, setup: Vec) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO server_setup (id, setup_data) VALUES (1, ?1)", params![setup], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn get_server_setup(&self) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT setup_data FROM server_setup WHERE id = 1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row([], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn store_signing_key_seed(&self, seed: Vec) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO server_signing_key (id, seed_data) VALUES (1, ?1)", params![seed], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn get_signing_key_seed(&self) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT seed_data FROM server_signing_key WHERE id = 1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row([], |row| row.get(0)) 
.optional() .map_err(|e| StorageError::Db(e.to_string())) } fn save_kt_log(&self, bytes: Vec) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO kt_log (id, log_data) VALUES (1, ?1)", params![bytes], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn load_kt_log(&self) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT log_data FROM kt_log WHERE id = 1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row([], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn save_revocation_log(&self, bytes: Vec) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO kt_log (id, log_data) VALUES (2, ?1)", params![bytes], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn load_revocation_log(&self) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT log_data FROM kt_log WHERE id = 2") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row([], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn store_user_record(&self, username: &str, record: Vec) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT INTO users (username, opaque_record) VALUES (?1, ?2)", params![username, record], ) .map_err(|e| { if let rusqlite::Error::SqliteFailure(ref err, _) = &e { if err.code == rusqlite::ErrorCode::ConstraintViolation { return StorageError::DuplicateUser(username.to_string()); } } StorageError::Db(e.to_string()) })?; Ok(()) } fn get_user_record(&self, username: &str) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT opaque_record FROM users WHERE username = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![username], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn has_user_record(&self, username: 
&str) -> Result { let conn = self.get_conn()?; let exists: bool = conn .query_row( "SELECT EXISTS(SELECT 1 FROM users WHERE username = ?1)", params![username], |row| row.get(0), ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(exists) } fn store_user_identity_key( &self, username: &str, identity_key: Vec, ) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO user_identity_keys (username, identity_key) VALUES (?1, ?2)", params![username, identity_key], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn get_user_identity_key(&self, username: &str) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT identity_key FROM user_identity_keys WHERE username = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![username], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn resolve_identity_key(&self, identity_key: &[u8]) -> Result, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT username FROM user_identity_keys WHERE identity_key = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![identity_key], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn peek( &self, recipient_key: &[u8], channel_id: &[u8], limit: usize, ) -> Result)>, StorageError> { let conn = self.get_conn()?; let sql = if limit == 0 { "SELECT seq, payload FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND (expires_at IS NULL OR expires_at > strftime('%s','now')) ORDER BY seq ASC".to_string() } else { format!( "SELECT seq, payload FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND (expires_at IS NULL OR expires_at > strftime('%s','now')) ORDER BY seq ASC LIMIT {}", limit ) }; let mut stmt = conn.prepare(&sql).map_err(|e| StorageError::Db(e.to_string()))?; let rows: Vec<(i64, Vec)> = stmt .query_map(params![recipient_key, channel_id], |row| { 
Ok((row.get(0)?, row.get(1)?)) }) .map_err(|e| StorageError::Db(e.to_string()))? .collect::, _>>() .map_err(|e| StorageError::Db(e.to_string()))?; Ok(rows.into_iter().map(|(seq, payload)| (seq as u64, payload)).collect()) } fn ack( &self, recipient_key: &[u8], channel_id: &[u8], seq_up_to: u64, ) -> Result { let conn = self.get_conn()?; let deleted = conn .execute( "DELETE FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND seq <= ?3", params![recipient_key, channel_id, seq_up_to as i64], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(deleted) } fn publish_endpoint( &self, identity_key: &[u8], node_addr: Vec, ) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO endpoints (identity_key, node_addr) VALUES (?1, ?2)", params![identity_key, node_addr], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn resolve_endpoint(&self, identity_key: &[u8]) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT node_addr FROM endpoints WHERE identity_key = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![identity_key], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<(Vec, bool), StorageError> { let (a, b) = if member_a < member_b { (member_a.to_vec(), member_b.to_vec()) } else { (member_b.to_vec(), member_a.to_vec()) }; let conn = self.get_conn()?; let existing: Option> = conn .query_row( "SELECT channel_id FROM channels WHERE member_a = ?1 AND member_b = ?2", params![a, b], |row| row.get(0), ) .optional() .map_err(|e| StorageError::Db(e.to_string()))?; if let Some(id) = existing { return Ok((id, false)); } let mut channel_id = [0u8; 16]; rand::rngs::OsRng.fill_bytes(&mut channel_id); conn.execute( "INSERT INTO channels (channel_id, member_a, member_b) VALUES (?1, ?2, ?3)", params![channel_id.as_slice(), a, b], ) .map_err(|e| 
StorageError::Db(e.to_string()))?; Ok((channel_id.to_vec(), true)) } fn get_channel_members(&self, channel_id: &[u8]) -> Result, Vec)>, StorageError> { let conn = self.get_conn()?; conn.query_row( "SELECT member_a, member_b FROM channels WHERE channel_id = ?1", params![channel_id], |row| Ok((row.get::<_, Vec>(0)?, row.get::<_, Vec>(1)?)), ) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn store_identity_home_server( &self, identity_key: &[u8], home_server: &str, ) -> Result<(), StorageError> { let conn = self.get_conn()?; let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as i64; conn.execute( "INSERT OR REPLACE INTO identity_home_servers (identity_key, home_server, updated_at) VALUES (?1, ?2, ?3)", params![identity_key, home_server, now], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn get_identity_home_server( &self, identity_key: &[u8], ) -> Result, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT home_server FROM identity_home_servers WHERE identity_key = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![identity_key], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn upsert_federation_peer( &self, domain: &str, is_active: bool, ) -> Result<(), StorageError> { let conn = self.get_conn()?; let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as i64; conn.execute( "INSERT INTO federation_peers (domain, last_seen, is_active) VALUES (?1, ?2, ?3) ON CONFLICT(domain) DO UPDATE SET last_seen = ?2, is_active = ?3", params![domain, now, is_active as i32], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn list_federation_peers(&self) -> Result, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT domain, is_active FROM federation_peers WHERE is_active = 1") .map_err(|e| 
StorageError::Db(e.to_string()))?; let rows = stmt .query_map([], |row| { Ok((row.get::<_, String>(0)?, row.get::<_, i32>(1)? != 0)) }) .map_err(|e| StorageError::Db(e.to_string()))? .collect::, _>>() .map_err(|e| StorageError::Db(e.to_string()))?; Ok(rows) } fn delete_account(&self, identity_key: &[u8]) -> Result<(), StorageError> { let conn = self.get_conn()?; // Resolve the username for this identity key. let username: Option = conn .query_row( "SELECT username FROM user_identity_keys WHERE identity_key = ?1", params![identity_key], |row| row.get(0), ) .optional() .map_err(|e| StorageError::Db(e.to_string()))?; // Use a transaction for atomicity. conn.execute_batch("BEGIN IMMEDIATE") .map_err(|e| StorageError::Db(e.to_string()))?; let result = (|| -> Result<(), StorageError> { // 1. Delete queued deliveries. conn.execute( "DELETE FROM deliveries WHERE recipient_key = ?1", params![identity_key], ).map_err(|e| StorageError::Db(e.to_string()))?; conn.execute( "DELETE FROM delivery_seq_counters WHERE recipient_key = ?1", params![identity_key], ).map_err(|e| StorageError::Db(e.to_string()))?; // 2. Delete key packages. conn.execute( "DELETE FROM key_packages WHERE identity_key = ?1", params![identity_key], ).map_err(|e| StorageError::Db(e.to_string()))?; // 3. Delete hybrid keys. conn.execute( "DELETE FROM hybrid_keys WHERE identity_key = ?1", params![identity_key], ).map_err(|e| StorageError::Db(e.to_string()))?; // 4. Delete channel memberships. conn.execute( "DELETE FROM channels WHERE member_a = ?1 OR member_b = ?1", params![identity_key], ).map_err(|e| StorageError::Db(e.to_string()))?; // 5. Delete identity key mapping. conn.execute( "DELETE FROM user_identity_keys WHERE identity_key = ?1", params![identity_key], ).map_err(|e| StorageError::Db(e.to_string()))?; // 6. Delete user record (by username). if let Some(ref uname) = username { conn.execute( "DELETE FROM users WHERE username = ?1", params![uname], ).map_err(|e| StorageError::Db(e.to_string()))?; } // 7. 
Delete endpoints (table may not exist on older schemas). let _ = conn.execute( "DELETE FROM endpoints WHERE identity_key = ?1", params![identity_key], ); // 8. Delete devices. let _ = conn.execute( "DELETE FROM devices WHERE identity_key = ?1", params![identity_key], ); // Do NOT delete KT log entries — append-only for auditability. Ok(()) })(); match result { Ok(()) => { conn.execute_batch("COMMIT") .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } Err(e) => { let _ = conn.execute_batch("ROLLBACK"); Err(e) } } } fn register_device(&self, identity_key: &[u8], device_id: &[u8], device_name: &str) -> Result { let conn = self.get_conn()?; // Check if device already exists. let exists: bool = conn .query_row( "SELECT EXISTS(SELECT 1 FROM devices WHERE identity_key = ?1 AND device_id = ?2)", params![identity_key, device_id], |row| row.get(0), ) .map_err(|e| StorageError::Db(e.to_string()))?; if exists { return Ok(false); } conn.execute( "INSERT INTO devices (identity_key, device_id, device_name) VALUES (?1, ?2, ?3)", params![identity_key, device_id, device_name], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(true) } fn list_devices(&self, identity_key: &[u8]) -> Result, String, u64)>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT device_id, device_name, registered_at FROM devices WHERE identity_key = ?1 ORDER BY registered_at ASC") .map_err(|e| StorageError::Db(e.to_string()))?; let rows = stmt .query_map(params![identity_key], |row| { Ok(( row.get::<_, Vec>(0)?, row.get::<_, String>(1)?, row.get::<_, i64>(2)? as u64, )) }) .map_err(|e| StorageError::Db(e.to_string()))? 
.collect::, _>>() .map_err(|e| StorageError::Db(e.to_string()))?; Ok(rows) } fn revoke_device(&self, identity_key: &[u8], device_id: &[u8]) -> Result { let conn = self.get_conn()?; let deleted = conn .execute( "DELETE FROM devices WHERE identity_key = ?1 AND device_id = ?2", params![identity_key, device_id], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(deleted > 0) } fn device_count(&self, identity_key: &[u8]) -> Result { let conn = self.get_conn()?; let count: i64 = conn .query_row( "SELECT COUNT(*) FROM devices WHERE identity_key = ?1", params![identity_key], |row| row.get(0), ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(count as usize) } // ── Session persistence ──────────────────────────────────────────────── fn store_session(&self, token: &[u8], record: &SessionRecord) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO sessions (token, username, identity_key, created_at, expires_at) VALUES (?1, ?2, ?3, ?4, ?5)", params![token, record.username, record.identity_key, record.created_at as i64, record.expires_at as i64], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn get_session(&self, token: &[u8]) -> Result, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT username, identity_key, created_at, expires_at FROM sessions WHERE token = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![token], |row| { Ok(SessionRecord { username: row.get(0)?, identity_key: row.get(1)?, created_at: row.get::<_, i64>(2)? as u64, expires_at: row.get::<_, i64>(3)? 
as u64, }) }) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn delete_expired_sessions(&self, now: u64) -> Result { let conn = self.get_conn()?; let deleted = conn .execute( "DELETE FROM sessions WHERE expires_at <= ?1", params![now as i64], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(deleted) } fn delete_session(&self, token: &[u8]) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute("DELETE FROM sessions WHERE token = ?1", params![token]) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } // ── Blob storage ─────────────────────────────────────────────────────── fn store_blob_chunk( &self, blob_hash: &[u8], chunk: &[u8], offset: u64, total_size: u64, mime_type: &str, ) -> Result>, StorageError> { let conn = self.get_conn()?; // Insert chunk into staging. conn.execute( "INSERT OR REPLACE INTO blob_staging (blob_hash, offset, chunk, total_size, mime_type) VALUES (?1, ?2, ?3, ?4, ?5)", params![blob_hash, offset as i64, chunk, total_size as i64, mime_type], ) .map_err(|e| StorageError::Db(e.to_string()))?; // Check if all chunks have arrived. let staged_size: i64 = conn .query_row( "SELECT COALESCE(SUM(LENGTH(chunk)), 0) FROM blob_staging WHERE blob_hash = ?1", params![blob_hash], |row| row.get(0), ) .map_err(|e| StorageError::Db(e.to_string()))?; if staged_size as u64 != total_size { return Ok(None); } // All chunks received — assemble in offset order. let mut stmt = conn .prepare("SELECT chunk FROM blob_staging WHERE blob_hash = ?1 ORDER BY offset ASC") .map_err(|e| StorageError::Db(e.to_string()))?; let chunks: Vec> = stmt .query_map(params![blob_hash], |row| row.get(0)) .map_err(|e| StorageError::Db(e.to_string()))? .collect::, _>>() .map_err(|e| StorageError::Db(e.to_string()))?; let mut assembled = Vec::with_capacity(total_size as usize); for c in &chunks { assembled.extend_from_slice(c); } // Verify SHA-256. 
let hash = Sha256::digest(&assembled); if hash.as_slice() != blob_hash { // Clean up staging rows for this blob. conn.execute( "DELETE FROM blob_staging WHERE blob_hash = ?1", params![blob_hash], ) .map_err(|e| StorageError::Db(e.to_string()))?; return Err(StorageError::Db( "blob hash mismatch after assembly".into(), )); } // Use the hash as the blob_id (content-addressable). let blob_id = hash.to_vec(); let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as i64; conn.execute( "INSERT OR REPLACE INTO blobs (blob_id, data, total_size, mime_type, uploaded_at) VALUES (?1, ?2, ?3, ?4, ?5)", params![blob_id, assembled, total_size as i64, mime_type, now], ) .map_err(|e| StorageError::Db(e.to_string()))?; // Clean up staging. conn.execute( "DELETE FROM blob_staging WHERE blob_hash = ?1", params![blob_hash], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(Some(blob_id)) } fn get_blob_chunk( &self, blob_id: &[u8], offset: u64, length: u32, ) -> Result, u64, String)>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare( "SELECT substr(data, ?2, ?3), total_size, mime_type FROM blobs WHERE blob_id = ?1", ) .map_err(|e| StorageError::Db(e.to_string()))?; // SQLite substr is 1-indexed. stmt.query_row( params![blob_id, (offset + 1) as i64, length as i64], |row| { Ok(( row.get::<_, Vec>(0)?, row.get::<_, i64>(1)? 
as u64, row.get::<_, String>(2)?, )) }, ) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn store_group_metadata( &self, _group_id: &[u8], _name: &str, _description: &str, _avatar_hash: &[u8], _creator_key: &[u8], ) -> Result<(), StorageError> { Ok(()) } fn get_group_metadata(&self, _group_id: &[u8]) -> Result, Vec, u64)>, StorageError> { Ok(None) } fn add_group_member(&self, _group_id: &[u8], _identity_key: &[u8]) -> Result<(), StorageError> { Ok(()) } fn remove_group_member(&self, _group_id: &[u8], _identity_key: &[u8]) -> Result { Ok(false) } fn list_group_members(&self, _group_id: &[u8]) -> Result, u64)>, StorageError> { Ok(Vec::new()) } fn store_recovery_bundle( &self, token_hash: &[u8], bundle: Vec, ttl_secs: u64, ) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO recovery_bundles (token_hash, bundle, ttl_secs) VALUES (?1, ?2, ?3)", params![token_hash, bundle, ttl_secs as i64], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn get_recovery_bundle(&self, token_hash: &[u8]) -> Result>, StorageError> { let conn = self.get_conn()?; let mut stmt = conn .prepare("SELECT bundle FROM recovery_bundles WHERE token_hash = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![token_hash], |row| row.get::<_, Vec>(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn delete_recovery_bundle(&self, token_hash: &[u8]) -> Result { let conn = self.get_conn()?; let affected = conn .execute( "DELETE FROM recovery_bundles WHERE token_hash = ?1", params![token_hash], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(affected > 0) } fn store_report( &self, encrypted_report: &[u8], conversation_id: &[u8], reporter_identity: &[u8], ) -> Result { let conn = self.get_conn()?; conn.execute( "INSERT INTO reports (encrypted_report, conversation_id, reporter_identity) VALUES (?1, ?2, ?3)", params![encrypted_report, conversation_id, reporter_identity], ) .map_err(|e| 
StorageError::Db(e.to_string()))?; Ok(conn.last_insert_rowid() as u64) } fn list_reports( &self, limit: u32, offset: u32, ) -> Result, Vec, Vec, u64)>, StorageError> { let conn = self.get_conn()?; let effective_limit = if limit == 0 { i64::MAX } else { limit as i64 }; let mut stmt = conn .prepare( "SELECT id, encrypted_report, conversation_id, reporter_identity, created_at FROM reports ORDER BY id LIMIT ?1 OFFSET ?2", ) .map_err(|e| StorageError::Db(e.to_string()))?; let rows = stmt .query_map(params![effective_limit, offset as i64], |row| { Ok(( row.get::<_, i64>(0)? as u64, row.get::<_, Vec>(1)?, row.get::<_, Vec>(2)?, row.get::<_, Vec>(3)?, row.get::<_, i64>(4)? as u64, )) }) .map_err(|e| StorageError::Db(e.to_string()))?; let mut result = Vec::new(); for row in rows { result.push(row.map_err(|e| StorageError::Db(e.to_string()))?); } Ok(result) } fn ban_user( &self, identity_key: &[u8], reason: &str, expires_at: u64, ) -> Result<(), StorageError> { let conn = self.get_conn()?; conn.execute( "INSERT OR REPLACE INTO bans (identity_key, reason, expires_at) VALUES (?1, ?2, ?3)", params![identity_key, reason, expires_at as i64], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn unban_user(&self, identity_key: &[u8]) -> Result { let conn = self.get_conn()?; let affected = conn .execute( "DELETE FROM bans WHERE identity_key = ?1", params![identity_key], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(affected > 0) } fn is_banned(&self, identity_key: &[u8]) -> Result, StorageError> { let conn = self.get_conn()?; let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as i64; let mut stmt = conn .prepare( "SELECT reason FROM bans WHERE identity_key = ?1 AND (expires_at = 0 OR expires_at > ?2)", ) .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![identity_key, now], |row| row.get::<_, String>(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn list_banned(&self) -> 
Result, String, u64, u64)>, StorageError> { let conn = self.get_conn()?; let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as i64; let mut stmt = conn .prepare( "SELECT identity_key, reason, banned_at, expires_at FROM bans WHERE expires_at = 0 OR expires_at > ?1 ORDER BY banned_at", ) .map_err(|e| StorageError::Db(e.to_string()))?; let rows = stmt .query_map(params![now], |row| { Ok(( row.get::<_, Vec>(0)?, row.get::<_, String>(1)?, row.get::<_, i64>(2)? as u64, row.get::<_, i64>(3)? as u64, )) }) .map_err(|e| StorageError::Db(e.to_string()))?; let mut result = Vec::new(); for row in rows { result.push(row.map_err(|e| StorageError::Db(e.to_string()))?); } Ok(result) } } /// Convenience extension for `rusqlite::OptionalExtension`. trait OptionalExt { fn optional(self) -> Result, rusqlite::Error>; } impl OptionalExt for Result { fn optional(self) -> Result, rusqlite::Error> { match self { Ok(v) => Ok(Some(v)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e), } } } #[cfg(test)] #[allow(clippy::unwrap_used)] mod tests { use super::*; use std::path::PathBuf; fn open_in_memory() -> SqlStore { // Pool size 1 for :memory: — each connection is a separate DB. 
SqlStore::open_with_pool_size(":memory:", "", 1).unwrap()
}

/// Opening a store on disk runs the migrations and persists the final
/// schema version into `PRAGMA user_version`.
#[test]
fn sets_user_version_after_migrate() {
    let dir = tempfile::tempdir().expect("tempdir");
    let db_path: PathBuf = dir.path().join("store.db");

    // Open (and immediately drop) the store so migrations run and the
    // connection pool is exercised at least once.
    {
        let store = SqlStore::open(&db_path, "").expect("open store");
        let _guard = store.get_conn().unwrap();
    }

    // Re-open with a plain rusqlite connection and read the pragma directly.
    let conn = rusqlite::Connection::open(&db_path).expect("reopen db");
    let version: i32 = conn
        .pragma_query_value(None, "user_version", |row| row.get(0))
        .expect("read user_version");
    assert_eq!(version, SCHEMA_VERSION);
}

/// Key packages for an identity are consumed in upload (FIFO) order and
/// each package is handed out exactly once.
#[test]
fn key_package_fifo() {
    let store = open_in_memory();
    let identity = [1u8; 32];
    for package in [&b"kp1"[..], &b"kp2"[..]] {
        store.upload_key_package(&identity, package.to_vec()).unwrap();
    }
    assert_eq!(
        store.fetch_key_package(&identity).unwrap(),
        Some(b"kp1".to_vec())
    );
    assert_eq!(
        store.fetch_key_package(&identity).unwrap(),
        Some(b"kp2".to_vec())
    );
    assert_eq!(store.fetch_key_package(&identity).unwrap(), None);
}

/// Enqueue assigns sequence numbers starting at 0; fetch drains the queue.
#[test]
fn delivery_round_trip() {
    let store = open_in_memory();
    let recipient = [1u8; 32];
    let channel = b"channel-1";
    let first = store.enqueue(&recipient, channel, b"msg1".to_vec(), None).unwrap();
    let second = store.enqueue(&recipient, channel, b"msg2".to_vec(), None).unwrap();
    assert_eq!(first, 0);
    assert_eq!(second, 1);

    let drained = store.fetch(&recipient, channel).unwrap();
    assert_eq!(
        drained,
        vec![(0u64, b"msg1".to_vec()), (1u64, b"msg2".to_vec())]
    );
    // A second fetch finds nothing left.
    assert!(store.fetch(&recipient, channel).unwrap().is_empty());
}

/// `fetch_limited` removes only up to the limit; the rest stays queued.
#[test]
fn fetch_limited_partial_drain() {
    let store = open_in_memory();
    let recipient = [5u8; 32];
    let channel = b"ch";
    for payload in [&b"a"[..], &b"b"[..], &b"c"[..]] {
        store.enqueue(&recipient, channel, payload.to_vec(), None).unwrap();
    }
    let head = store.fetch_limited(&recipient, channel, 2).unwrap();
    assert_eq!(head, vec![(0u64, b"a".to_vec()), (1u64, b"b".to_vec())]);
    let tail = store.fetch(&recipient, channel).unwrap();
    assert_eq!(tail, vec![(2u64, b"c".to_vec())]);
}

/// `queue_depth` reflects the number of pending messages.
#[test]
fn queue_depth_count() {
    let store = open_in_memory();
    let recipient = [6u8; 32];
    let channel = b"ch";
    assert_eq!(store.queue_depth(&recipient, channel).unwrap(), 0);
    store.enqueue(&recipient, channel, b"x".to_vec(), None).unwrap();
    store.enqueue(&recipient, channel, b"y".to_vec(), None).unwrap();
    assert_eq!(store.queue_depth(&recipient, channel).unwrap(), 2);
}

/// `has_user_record` flips to true after a record is stored, and only for
/// that user.
#[test]
fn has_user_record_check() {
    let store = open_in_memory();
    assert!(!store.has_user_record("user1").unwrap());
    store.store_user_record("user1", b"record".to_vec()).unwrap();
    assert!(store.has_user_record("user1").unwrap());
    // An unrelated user stays absent.
    assert!(!store.has_user_record("user2").unwrap());
}

/// Identity key storage: absent before store, returned verbatim after.
#[test]
fn user_identity_key_round_trip() {
    let store = open_in_memory();
    assert!(store.get_user_identity_key("user1").unwrap().is_none());
    store.store_user_identity_key("user1", vec![1u8; 32]).unwrap();
    assert_eq!(
        store.get_user_identity_key("user1").unwrap(),
        Some(vec![1u8; 32])
    );
}

/// Hybrid (post-quantum) public key upload/fetch round trip.
#[test]
fn hybrid_key_round_trip() {
    let store = open_in_memory();
    let identity = [2u8; 32];
    let public_key = b"hybrid_public_key_data".to_vec();
    store.upload_hybrid_key(&identity, public_key.clone()).unwrap();
    assert_eq!(store.fetch_hybrid_key(&identity).unwrap(), Some(public_key));
}

/// Messages enqueued on different channels never cross over.
#[test]
fn separate_channels_isolated() {
    let store = open_in_memory();
    let recipient = [4u8; 32];
    store.enqueue(&recipient, b"ch-a", b"a1".to_vec(), None).unwrap();
    store.enqueue(&recipient, b"ch-b", b"b1".to_vec(), None).unwrap();
    assert_eq!(
        store.fetch(&recipient, b"ch-a").unwrap(),
        vec![(0u64, b"a1".to_vec())]
    );
    assert_eq!(
        store.fetch(&recipient, b"ch-b").unwrap(),
        vec![(0u64, b"b1".to_vec())]
    );
}

/// A brand-new channel reports `was_new = true` and a 16-byte id.
#[test]
fn create_channel_was_new_first_call() {
    let store = open_in_memory();
    let alice = [10u8; 32];
    let bob = [11u8; 32];
    let (id, was_new) = store.create_channel(&alice, &bob).unwrap();
    assert_eq!(id.len(), 16, "channel_id must be 16 bytes");
    assert!(was_new, "first create_channel must return was_new=true");
}

/// Calling create_channel twice with the same argument order is idempotent.
#[test]
fn create_channel_idempotent_same_direction() {
    let store = open_in_memory();
    let alice = [12u8; 32];
    let bob = [13u8; 32];
    let (id1, was_new1) = store.create_channel(&alice, &bob).unwrap();
    let (id2, was_new2) = store.create_channel(&alice, &bob).unwrap();
    assert_eq!(id1, id2, "repeated call must return same channel_id");
    assert!(was_new1);
    assert!(!was_new2, "second call must return was_new=false");
}

/// Swapping the two member keys still resolves to the same channel.
#[test]
fn create_channel_idempotent_reversed_direction() {
    let store = open_in_memory();
    let alice = [14u8; 32];
    let bob = [15u8; 32];
    let (id1, was_new1) = store.create_channel(&alice, &bob).unwrap();
    let (id2, was_new2) = store.create_channel(&bob, &alice).unwrap();
    assert_eq!(id1, id2, "reversed-key call must return same channel_id");
    assert!(was_new1);
    assert!(!was_new2, "reversed-key second call must return was_new=false");
}

/// Distinct member pairs get pairwise-distinct channel ids.
#[test]
fn create_channel_different_pairs_isolated() {
    let store = open_in_memory();
    let a = [16u8; 32];
    let b = [17u8; 32];
    let c = [18u8; 32];
    let (id_ab, _) = store.create_channel(&a, &b).unwrap();
    let (id_ac, _) = store.create_channel(&a, &c).unwrap();
    let (id_bc, _) = store.create_channel(&b, &c).unwrap();
    assert_ne!(id_ab, id_ac);
    assert_ne!(id_ab, id_bc);
    assert_ne!(id_ac, id_bc);
}

/// Members come back from `get_channel_members` in canonical (lexicographic)
/// order regardless of the order they were passed to create_channel.
#[test]
fn create_channel_get_members_roundtrip() {
    let store = open_in_memory();
    let a = [20u8; 32];
    let b = [21u8; 32];
    let (id, _) = store.create_channel(&a, &b).unwrap();
    let members = store.get_channel_members(&id).unwrap();
    assert!(members.is_some(), "get_channel_members must return Some after create");
    let (ma, mb) = members.unwrap();
    // Sorting the inputs yields the expected canonical order.
    let mut expected = [a.to_vec(), b.to_vec()];
    expected.sort();
    assert_eq!(ma, expected[0]);
    assert_eq!(mb, expected[1]);
}

/// Looking up an id that was never created yields None, not an error.
#[test]
fn get_channel_members_unknown_id_returns_none() {
    let store = open_in_memory();
    assert!(store.get_channel_members(&[0u8; 16]).unwrap().is_none());
}

/// A stored identity key resolves back to the owning username.
#[test]
fn resolve_identity_key_after_store() {
    let store = open_in_memory();
    let identity = [30u8; 32];
    store.store_user_record("carol", b"record".to_vec()).unwrap();
    store.store_user_identity_key("carol", identity.to_vec()).unwrap();
    assert_eq!(
        store.resolve_identity_key(&identity).unwrap(),
        Some("carol".to_string())
    );
}

/// Resolving a key nobody registered yields None, not an error.
#[test]
fn resolve_identity_key_unknown_returns_none() {
    let store = open_in_memory();
    assert!(store.resolve_identity_key(&[31u8; 32]).unwrap().is_none());
}

/// Two users' identity keys resolve independently to the right usernames.
#[test]
fn resolve_identity_key_two_users_distinct() {
    let store = open_in_memory();
    let key_a = [32u8; 32];
    let key_b = [33u8; 32];
    store.store_user_record("user_a", b"ra".to_vec()).unwrap();
    store.store_user_record("user_b", b"rb".to_vec()).unwrap();
    store.store_user_identity_key("user_a", key_a.to_vec()).unwrap();
    store.store_user_identity_key("user_b", key_b.to_vec()).unwrap();
    assert_eq!(
        store.resolve_identity_key(&key_a).unwrap(),
        Some("user_a".to_string())
    );
    assert_eq!(
        store.resolve_identity_key(&key_b).unwrap(),
        Some("user_b".to_string())
    );
}
}