//! SQLCipher-backed persistent storage. use std::path::Path; use std::sync::Mutex; use rusqlite::{params, Connection}; use crate::storage::{StorageError, Store}; /// SQLCipher-encrypted storage backend. pub struct SqlStore { conn: Mutex, } impl SqlStore { pub fn open(path: impl AsRef, key: &str) -> Result { let conn = Connection::open(path).map_err(|e| StorageError::Db(e.to_string()))?; if !key.is_empty() { conn.pragma_update(None, "key", key) .map_err(|e| StorageError::Db(format!("PRAGMA key failed: {e}")))?; } conn.execute_batch( "PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL; PRAGMA foreign_keys = ON;", ) .map_err(|e| StorageError::Db(e.to_string()))?; let store = Self { conn: Mutex::new(conn), }; store.migrate()?; Ok(store) } fn migrate(&self) -> Result<(), StorageError> { let conn = self.conn.lock().unwrap(); conn.execute_batch( "CREATE TABLE IF NOT EXISTS key_packages ( id INTEGER PRIMARY KEY AUTOINCREMENT, identity_key BLOB NOT NULL, package_data BLOB NOT NULL, created_at INTEGER DEFAULT (strftime('%s','now')) ); CREATE TABLE IF NOT EXISTS deliveries ( id INTEGER PRIMARY KEY AUTOINCREMENT, recipient_key BLOB NOT NULL, channel_id BLOB NOT NULL DEFAULT X'', payload BLOB NOT NULL, created_at INTEGER DEFAULT (strftime('%s','now')) ); CREATE TABLE IF NOT EXISTS hybrid_keys ( identity_key BLOB PRIMARY KEY, hybrid_public_key BLOB NOT NULL ); CREATE INDEX IF NOT EXISTS idx_kp_identity ON key_packages(identity_key); CREATE INDEX IF NOT EXISTS idx_del_recipient_channel ON deliveries(recipient_key, channel_id); CREATE TABLE IF NOT EXISTS server_setup ( id INTEGER PRIMARY KEY CHECK (id = 1), setup_data BLOB NOT NULL ); CREATE TABLE IF NOT EXISTS users ( username TEXT PRIMARY KEY, opaque_record BLOB NOT NULL, created_at INTEGER DEFAULT (strftime('%s','now')) ); CREATE TABLE IF NOT EXISTS user_identity_keys ( username TEXT PRIMARY KEY, identity_key BLOB NOT NULL ); CREATE TABLE IF NOT EXISTS endpoints ( identity_key BLOB PRIMARY KEY, node_addr BLOB NOT NULL, updated_at INTEGER DEFAULT (strftime('%s','now')) );", ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } } impl Store for SqlStore { fn upload_key_package( &self, identity_key: &[u8], package: Vec, ) -> Result<(), StorageError> { let conn = self.conn.lock().unwrap(); conn.execute( "INSERT INTO key_packages (identity_key, package_data) VALUES (?1, ?2)", params![identity_key, package], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn fetch_key_package(&self, identity_key: &[u8]) -> Result>, StorageError> { let conn = self.conn.lock().unwrap(); let mut stmt = conn .prepare( "SELECT id, package_data FROM key_packages WHERE identity_key = ?1 ORDER BY id ASC LIMIT 1", ) .map_err(|e| StorageError::Db(e.to_string()))?; let row = stmt .query_row(params![identity_key], |row| { Ok((row.get::<_, i64>(0)?, row.get::<_, Vec>(1)?)) }) .optional() .map_err(|e| StorageError::Db(e.to_string()))?; match row { Some((id, package)) => { conn.execute("DELETE FROM key_packages WHERE id = ?1", params![id]) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(Some(package)) } None => Ok(None), } } fn enqueue( &self, recipient_key: &[u8], channel_id: &[u8], payload: Vec, ) -> Result<(), StorageError> { let conn = self.conn.lock().unwrap(); conn.execute( "INSERT INTO deliveries (recipient_key, channel_id, payload) VALUES (?1, ?2, ?3)", params![recipient_key, channel_id, payload], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn fetch( &self, recipient_key: &[u8], channel_id: &[u8], ) -> Result>, StorageError> { let conn = self.conn.lock().unwrap(); let mut stmt = conn .prepare( "SELECT id, payload FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 ORDER BY id ASC", ) .map_err(|e| StorageError::Db(e.to_string()))?; let rows: Vec<(i64, Vec)> = stmt .query_map(params![recipient_key, channel_id], |row| { Ok((row.get(0)?, row.get(1)?)) }) .map_err(|e| StorageError::Db(e.to_string()))? .collect::, _>>() .map_err(|e| StorageError::Db(e.to_string()))?; if !rows.is_empty() { let ids: Vec = rows.iter().map(|(id, _)| *id).collect(); let placeholders: String = ids.iter().map(|_| "?").collect::>().join(","); let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})"); let params: Vec<&dyn rusqlite::types::ToSql> = ids.iter().map(|id| id as &dyn rusqlite::types::ToSql).collect(); conn.execute(&sql, params.as_slice()) .map_err(|e| StorageError::Db(e.to_string()))?; } Ok(rows.into_iter().map(|(_, payload)| payload).collect()) } fn fetch_limited( &self, recipient_key: &[u8], channel_id: &[u8], limit: usize, ) -> Result>, StorageError> { let conn = self.conn.lock().unwrap(); let mut stmt = conn .prepare( "SELECT id, payload FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 ORDER BY id ASC LIMIT ?3", ) .map_err(|e| StorageError::Db(e.to_string()))?; let rows: Vec<(i64, Vec)> = stmt .query_map(params![recipient_key, channel_id, limit as i64], |row| { Ok((row.get(0)?, row.get(1)?)) }) .map_err(|e| StorageError::Db(e.to_string()))? .collect::, _>>() .map_err(|e| StorageError::Db(e.to_string()))?; if !rows.is_empty() { let ids: Vec = rows.iter().map(|(id, _)| *id).collect(); let placeholders: String = ids.iter().map(|_| "?").collect::>().join(","); let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})"); let params: Vec<&dyn rusqlite::types::ToSql> = ids.iter().map(|id| id as &dyn rusqlite::types::ToSql).collect(); conn.execute(&sql, params.as_slice()) .map_err(|e| StorageError::Db(e.to_string()))?; } Ok(rows.into_iter().map(|(_, payload)| payload).collect()) } fn queue_depth( &self, recipient_key: &[u8], channel_id: &[u8], ) -> Result { let conn = self.conn.lock().unwrap(); let count: i64 = conn .query_row( "SELECT COUNT(*) FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2", params![recipient_key, channel_id], |row| row.get(0), ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(count as usize) } fn gc_expired_messages(&self, max_age_secs: u64) -> Result { let conn = self.conn.lock().unwrap(); let cutoff = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() .saturating_sub(max_age_secs); let deleted = conn .execute( "DELETE FROM deliveries WHERE created_at < ?1", params![cutoff as i64], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(deleted) } fn upload_hybrid_key( &self, identity_key: &[u8], hybrid_pk: Vec, ) -> Result<(), StorageError> { let conn = self.conn.lock().unwrap(); conn.execute( "INSERT OR REPLACE INTO hybrid_keys (identity_key, hybrid_public_key) VALUES (?1, ?2)", params![identity_key, hybrid_pk], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result>, StorageError> { let conn = self.conn.lock().unwrap(); let mut stmt = conn .prepare("SELECT hybrid_public_key FROM hybrid_keys WHERE identity_key = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![identity_key], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn store_server_setup(&self, setup: Vec) -> Result<(), StorageError> { let conn = self.conn.lock().unwrap(); conn.execute( "INSERT OR REPLACE INTO server_setup (id, setup_data) VALUES (1, ?1)", params![setup], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn get_server_setup(&self) -> Result>, StorageError> { let conn = self.conn.lock().unwrap(); let mut stmt = conn .prepare("SELECT setup_data FROM server_setup WHERE id = 1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row([], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn store_user_record(&self, username: &str, record: Vec) -> Result<(), StorageError> { let conn = self.conn.lock().unwrap(); conn.execute( "INSERT OR REPLACE INTO users (username, opaque_record) VALUES (?1, ?2)", params![username, record], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn get_user_record(&self, username: &str) -> Result>, StorageError> { let conn = self.conn.lock().unwrap(); let mut stmt = conn .prepare("SELECT opaque_record FROM users WHERE username = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![username], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn has_user_record(&self, username: &str) -> Result { let conn = self.conn.lock().unwrap(); let exists: bool = conn .query_row( "SELECT EXISTS(SELECT 1 FROM users WHERE username = ?1)", params![username], |row| row.get(0), ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(exists) } fn store_user_identity_key( &self, username: &str, identity_key: Vec, ) -> Result<(), StorageError> { let conn = self.conn.lock().unwrap(); conn.execute( "INSERT OR REPLACE INTO user_identity_keys (username, identity_key) VALUES (?1, ?2)", params![username, identity_key], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn get_user_identity_key(&self, username: &str) -> Result>, StorageError> { let conn = self.conn.lock().unwrap(); let mut stmt = conn .prepare("SELECT identity_key FROM user_identity_keys WHERE username = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![username], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } fn publish_endpoint( &self, identity_key: &[u8], node_addr: Vec, ) -> Result<(), StorageError> { let conn = self.conn.lock().unwrap(); conn.execute( "INSERT OR REPLACE INTO endpoints (identity_key, node_addr) VALUES (?1, ?2)", params![identity_key, node_addr], ) .map_err(|e| StorageError::Db(e.to_string()))?; Ok(()) } fn resolve_endpoint(&self, identity_key: &[u8]) -> Result>, StorageError> { let conn = self.conn.lock().unwrap(); let mut stmt = conn .prepare("SELECT node_addr FROM endpoints WHERE identity_key = ?1") .map_err(|e| StorageError::Db(e.to_string()))?; stmt.query_row(params![identity_key], |row| row.get(0)) .optional() .map_err(|e| StorageError::Db(e.to_string())) } } /// Convenience extension for `rusqlite::OptionalExtension`. trait OptionalExt { fn optional(self) -> Result, rusqlite::Error>; } impl OptionalExt for Result { fn optional(self) -> Result, rusqlite::Error> { match self { Ok(v) => Ok(Some(v)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e), } } } #[cfg(test)] mod tests { use super::*; fn open_in_memory() -> SqlStore { SqlStore::open(":memory:", "").unwrap() } #[test] fn key_package_fifo() { let store = open_in_memory(); let mut identity = [0u8; 32]; identity[..31].copy_from_slice(b"alice_identity_key__32bytes_lon"); store .upload_key_package(&identity, b"kp1".to_vec()) .unwrap(); store .upload_key_package(&identity, b"kp2".to_vec()) .unwrap(); assert_eq!( store.fetch_key_package(&identity).unwrap(), Some(b"kp1".to_vec()) ); assert_eq!( store.fetch_key_package(&identity).unwrap(), Some(b"kp2".to_vec()) ); assert_eq!(store.fetch_key_package(&identity).unwrap(), None); } #[test] fn delivery_round_trip() { let store = open_in_memory(); let rk = [1u8; 32]; let ch = b"channel-1"; store.enqueue(&rk, ch, b"msg1".to_vec()).unwrap(); store.enqueue(&rk, ch, b"msg2".to_vec()).unwrap(); let msgs = store.fetch(&rk, ch).unwrap(); assert_eq!(msgs, vec![b"msg1".to_vec(), b"msg2".to_vec()]); assert!(store.fetch(&rk, ch).unwrap().is_empty()); } #[test] fn fetch_limited_partial_drain() { let store = open_in_memory(); let rk = [5u8; 32]; let ch = b"ch"; store.enqueue(&rk, ch, b"a".to_vec()).unwrap(); store.enqueue(&rk, ch, b"b".to_vec()).unwrap(); store.enqueue(&rk, ch, b"c".to_vec()).unwrap(); let msgs = store.fetch_limited(&rk, ch, 2).unwrap(); assert_eq!(msgs, vec![b"a".to_vec(), b"b".to_vec()]); let remaining = store.fetch(&rk, ch).unwrap(); assert_eq!(remaining, vec![b"c".to_vec()]); } #[test] fn queue_depth_count() { let store = open_in_memory(); let rk = [6u8; 32]; let ch = b"ch"; assert_eq!(store.queue_depth(&rk, ch).unwrap(), 0); store.enqueue(&rk, ch, b"x".to_vec()).unwrap(); store.enqueue(&rk, ch, b"y".to_vec()).unwrap(); assert_eq!(store.queue_depth(&rk, ch).unwrap(), 2); } #[test] fn has_user_record_check() { let store = open_in_memory(); assert!(!store.has_user_record("alice").unwrap()); store.store_user_record("alice", b"record".to_vec()).unwrap(); assert!(store.has_user_record("alice").unwrap()); assert!(!store.has_user_record("bob").unwrap()); } #[test] fn user_identity_key_round_trip() { let store = open_in_memory(); assert!(store.get_user_identity_key("alice").unwrap().is_none()); store.store_user_identity_key("alice", vec![1u8; 32]).unwrap(); assert_eq!(store.get_user_identity_key("alice").unwrap(), Some(vec![1u8; 32])); } #[test] fn hybrid_key_round_trip() { let store = open_in_memory(); let ik = [2u8; 32]; let pk = b"hybrid_public_key_data".to_vec(); store.upload_hybrid_key(&ik, pk.clone()).unwrap(); assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(pk)); } #[test] fn separate_channels_isolated() { let store = open_in_memory(); let rk = [4u8; 32]; store.enqueue(&rk, b"ch-a", b"a1".to_vec()).unwrap(); store.enqueue(&rk, b"ch-b", b"b1".to_vec()).unwrap(); let a_msgs = store.fetch(&rk, b"ch-a").unwrap(); assert_eq!(a_msgs, vec![b"a1".to_vec()]); let b_msgs = store.fetch(&rk, b"ch-b").unwrap(); assert_eq!(b_msgs, vec![b"b1".to_vec()]); } }