Rename all crate directories, package names, binary names, proto package/module paths, ALPN strings, env var prefixes, config filenames, mDNS service names, and plugin ABI symbols from quicproquo/qpq to quicprochat/qpc.
1463 lines
52 KiB
Rust
1463 lines
52 KiB
Rust
//! SQLCipher-backed persistent storage.
|
|
|
|
use std::path::Path;
|
|
use std::sync::Mutex;
|
|
|
|
use rand::RngCore;
|
|
use rusqlite::{params, Connection};
|
|
use sha2::{Digest, Sha256};
|
|
|
|
use crate::storage::{SessionRecord, StorageError, Store};
|
|
|
|
/// Schema version after introducing the migration runner (existing DBs had 1).
|
|
const SCHEMA_VERSION: i32 = 13;
|
|
|
|
/// How many SQLite connections `SqlStore::open` keeps in its pool.
///
/// `SqlStore::open_with_pool_size` lets callers pick a different size.
const DEFAULT_POOL_SIZE: usize = 4;
|
|
|
|
/// Migrations: (migration_number, SQL). Files named NNN_name.sql, applied in order when N > user_version.
|
|
const MIGRATIONS: &[(i32, &str)] = &[
|
|
(1, include_str!("../migrations/001_initial.sql")),
|
|
(3, include_str!("../migrations/002_add_seq.sql")),
|
|
(4, include_str!("../migrations/003_channels.sql")),
|
|
(5, include_str!("../migrations/004_federation.sql")),
|
|
(6, include_str!("../migrations/005_signing_key.sql")),
|
|
(7, include_str!("../migrations/006_kt_log.sql")),
|
|
(8, include_str!("../migrations/007_add_expiry.sql")),
|
|
(9, include_str!("../migrations/008_devices.sql")),
|
|
(10, include_str!("../migrations/009_sessions.sql")),
|
|
(11, include_str!("../migrations/010_blobs.sql")),
|
|
(12, include_str!("../migrations/011_recovery_bundles.sql")),
|
|
(13, include_str!("../migrations/012_moderation.sql")),
|
|
];
|
|
|
|
/// Runs pending migrations on an open connection: applies any migration whose number is greater
|
|
/// than the current PRAGMA user_version, then sets user_version to SCHEMA_VERSION.
|
|
fn run_migrations(conn: &Connection) -> Result<(), StorageError> {
|
|
let current_version: i32 = conn
|
|
.pragma_query_value(None, "user_version", |row| row.get(0))
|
|
.map_err(|e| StorageError::Db(format!("PRAGMA user_version failed: {e}")))?;
|
|
|
|
for (migration_num, sql) in MIGRATIONS {
|
|
if *migration_num > current_version {
|
|
conn.execute_batch(sql).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
}
|
|
}
|
|
|
|
conn.pragma_update(None, "user_version", SCHEMA_VERSION)
|
|
.map_err(|e| StorageError::Db(format!("set user_version failed: {e}")))?;
|
|
Ok(())
|
|
}
|
|
|
|
/// SQLCipher-encrypted storage backend with a connection pool.
|
|
///
|
|
/// Maintains `pool_size` SQLite connections (default 4) behind `std::sync::Mutex`.
|
|
/// Each store method tries all connections via `try_lock()` before falling back to
|
|
/// blocking on the first connection. WAL mode allows concurrent readers; writers
|
|
/// are serialised by SQLite itself.
|
|
pub struct SqlStore {
|
|
pool: Vec<Mutex<Connection>>,
|
|
}
|
|
|
|
impl SqlStore {
|
|
/// Try to acquire any connection from the pool without blocking.
|
|
/// Falls back to blocking on the first connection.
|
|
fn get_conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, StorageError> {
|
|
// Fast path: try each connection without blocking.
|
|
for conn in &self.pool {
|
|
if let Ok(guard) = conn.try_lock() {
|
|
return Ok(guard);
|
|
}
|
|
}
|
|
// Slow path: block on the first connection.
|
|
self.pool[0]
|
|
.lock()
|
|
.map_err(|e| StorageError::Db(format!("lock poisoned: {e}")))
|
|
}
|
|
|
|
pub fn open(path: impl AsRef<Path>, key: &str) -> Result<Self, StorageError> {
|
|
Self::open_with_pool_size(path, key, DEFAULT_POOL_SIZE)
|
|
}
|
|
|
|
pub fn open_with_pool_size(
|
|
path: impl AsRef<Path>,
|
|
key: &str,
|
|
pool_size: usize,
|
|
) -> Result<Self, StorageError> {
|
|
let pool_size = pool_size.max(1);
|
|
let path = path.as_ref();
|
|
|
|
// Open the first connection and run migrations.
|
|
let first = Self::open_one(path, key)?;
|
|
let current_version: i32 = first
|
|
.pragma_query_value(None, "user_version", |row| row.get(0))
|
|
.map_err(|e| StorageError::Db(format!("PRAGMA user_version failed: {e}")))?;
|
|
|
|
if current_version > SCHEMA_VERSION {
|
|
return Err(StorageError::Db(format!(
|
|
"database schema version {current_version} is newer than supported {SCHEMA_VERSION}"
|
|
)));
|
|
}
|
|
|
|
run_migrations(&first)?;
|
|
|
|
let mut pool = Vec::with_capacity(pool_size);
|
|
pool.push(Mutex::new(first));
|
|
|
|
// Open remaining connections (they skip migrations since the first one already ran them).
|
|
for _ in 1..pool_size {
|
|
let conn = Self::open_one(path, key)?;
|
|
pool.push(Mutex::new(conn));
|
|
}
|
|
|
|
Ok(Self { pool })
|
|
}
|
|
|
|
/// Open a single connection with shared pragmas.
|
|
fn open_one(path: &Path, key: &str) -> Result<Connection, StorageError> {
|
|
let conn = Connection::open(path).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
if !key.is_empty() {
|
|
conn.pragma_update(None, "key", key)
|
|
.map_err(|e| StorageError::Db(format!("PRAGMA key failed: {e}")))?;
|
|
}
|
|
|
|
conn.execute_batch(
|
|
"PRAGMA journal_mode = WAL;
|
|
PRAGMA synchronous = NORMAL;
|
|
PRAGMA foreign_keys = ON;",
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
Ok(conn)
|
|
}
|
|
}
|
|
|
|
impl Store for SqlStore {
|
|
fn upload_key_package(
|
|
&self,
|
|
identity_key: &[u8],
|
|
package: Vec<u8>,
|
|
) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT INTO key_packages (identity_key, package_data) VALUES (?1, ?2)",
|
|
params![identity_key, package],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
|
|
let mut stmt = conn
|
|
.prepare(
|
|
"SELECT id, package_data FROM key_packages
|
|
WHERE identity_key = ?1
|
|
ORDER BY id ASC
|
|
LIMIT 1",
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
let row = stmt
|
|
.query_row(params![identity_key], |row| {
|
|
Ok((row.get::<_, i64>(0)?, row.get::<_, Vec<u8>>(1)?))
|
|
})
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
match row {
|
|
Some((id, package)) => {
|
|
conn.execute("DELETE FROM key_packages WHERE id = ?1", params![id])
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(Some(package))
|
|
}
|
|
None => Ok(None),
|
|
}
|
|
}
|
|
|
|
fn enqueue(
|
|
&self,
|
|
recipient_key: &[u8],
|
|
channel_id: &[u8],
|
|
payload: Vec<u8>,
|
|
ttl_secs: Option<u32>,
|
|
) -> Result<u64, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
// Atomically get-and-increment the per-inbox sequence counter.
|
|
// RETURNING gives us the post-update next_seq; the assigned seq is next_seq - 1.
|
|
let seq: i64 = conn
|
|
.query_row(
|
|
"INSERT INTO delivery_seq_counters (recipient_key, channel_id, next_seq)
|
|
VALUES (?1, ?2, 1)
|
|
ON CONFLICT(recipient_key, channel_id) DO UPDATE SET next_seq = next_seq + 1
|
|
RETURNING next_seq - 1",
|
|
params![recipient_key, channel_id],
|
|
|row| row.get(0),
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
let expires_at: Option<i64> = ttl_secs.map(|ttl| {
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs() as i64;
|
|
now + ttl as i64
|
|
});
|
|
conn.execute(
|
|
"INSERT INTO deliveries (recipient_key, channel_id, seq, payload, expires_at) VALUES (?1, ?2, ?3, ?4, ?5)",
|
|
params![recipient_key, channel_id, seq, payload, expires_at],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(seq as u64)
|
|
}
|
|
|
|
fn fetch(
|
|
&self,
|
|
recipient_key: &[u8],
|
|
channel_id: &[u8],
|
|
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
|
|
let mut stmt = conn
|
|
.prepare(
|
|
"SELECT id, seq, payload FROM deliveries
|
|
WHERE recipient_key = ?1 AND channel_id = ?2
|
|
AND (expires_at IS NULL OR expires_at > strftime('%s','now'))
|
|
ORDER BY seq ASC",
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
let rows: Vec<(i64, i64, Vec<u8>)> = stmt
|
|
.query_map(params![recipient_key, channel_id], |row| {
|
|
Ok((row.get(0)?, row.get(1)?, row.get(2)?))
|
|
})
|
|
.map_err(|e| StorageError::Db(e.to_string()))?
|
|
.collect::<Result<Vec<_>, _>>()
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
if !rows.is_empty() {
|
|
let ids: Vec<i64> = rows.iter().map(|(id, _, _)| *id).collect();
|
|
let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
|
|
let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})");
|
|
let params: Vec<&dyn rusqlite::types::ToSql> = ids
|
|
.iter()
|
|
.map(|id| id as &dyn rusqlite::types::ToSql)
|
|
.collect();
|
|
conn.execute(&sql, params.as_slice())
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
}
|
|
|
|
Ok(rows.into_iter().map(|(_, seq, payload)| (seq as u64, payload)).collect())
|
|
}
|
|
|
|
fn fetch_limited(
|
|
&self,
|
|
recipient_key: &[u8],
|
|
channel_id: &[u8],
|
|
limit: usize,
|
|
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
|
|
let mut stmt = conn
|
|
.prepare(
|
|
"SELECT id, seq, payload FROM deliveries
|
|
WHERE recipient_key = ?1 AND channel_id = ?2
|
|
AND (expires_at IS NULL OR expires_at > strftime('%s','now'))
|
|
ORDER BY seq ASC
|
|
LIMIT ?3",
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
let rows: Vec<(i64, i64, Vec<u8>)> = stmt
|
|
.query_map(params![recipient_key, channel_id, limit as i64], |row| {
|
|
Ok((row.get(0)?, row.get(1)?, row.get(2)?))
|
|
})
|
|
.map_err(|e| StorageError::Db(e.to_string()))?
|
|
.collect::<Result<Vec<_>, _>>()
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
if !rows.is_empty() {
|
|
let ids: Vec<i64> = rows.iter().map(|(id, _, _)| *id).collect();
|
|
let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
|
|
let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})");
|
|
let params: Vec<&dyn rusqlite::types::ToSql> = ids
|
|
.iter()
|
|
.map(|id| id as &dyn rusqlite::types::ToSql)
|
|
.collect();
|
|
conn.execute(&sql, params.as_slice())
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
}
|
|
|
|
Ok(rows.into_iter().map(|(_, seq, payload)| (seq as u64, payload)).collect())
|
|
}
|
|
|
|
fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<usize, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let count: i64 = conn
|
|
.query_row(
|
|
"SELECT COUNT(*) FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND (expires_at IS NULL OR expires_at > strftime('%s','now'))",
|
|
params![recipient_key, channel_id],
|
|
|row| row.get(0),
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(count as usize)
|
|
}
|
|
|
|
fn gc_expired_messages(&self, max_age_secs: u64) -> Result<usize, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs();
|
|
let cutoff = now.saturating_sub(max_age_secs);
|
|
// Delete messages older than max_age_secs based on created_at.
|
|
let deleted_age = conn
|
|
.execute(
|
|
"DELETE FROM deliveries WHERE created_at < ?1",
|
|
params![cutoff as i64],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
// Delete messages that have passed their per-message TTL expiry.
|
|
let deleted_ttl = conn
|
|
.execute(
|
|
"DELETE FROM deliveries WHERE expires_at IS NOT NULL AND expires_at <= ?1",
|
|
params![now as i64],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(deleted_age + deleted_ttl)
|
|
}
|
|
|
|
fn upload_hybrid_key(
|
|
&self,
|
|
identity_key: &[u8],
|
|
hybrid_pk: Vec<u8>,
|
|
) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO hybrid_keys (identity_key, hybrid_public_key) VALUES (?1, ?2)",
|
|
params![identity_key, hybrid_pk],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT hybrid_public_key FROM hybrid_keys WHERE identity_key = ?1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row(params![identity_key], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn store_server_setup(&self, setup: Vec<u8>) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO server_setup (id, setup_data) VALUES (1, ?1)",
|
|
params![setup],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT setup_data FROM server_setup WHERE id = 1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row([], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn store_signing_key_seed(&self, seed: Vec<u8>) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO server_signing_key (id, seed_data) VALUES (1, ?1)",
|
|
params![seed],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn get_signing_key_seed(&self) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT seed_data FROM server_signing_key WHERE id = 1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row([], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn save_kt_log(&self, bytes: Vec<u8>) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO kt_log (id, log_data) VALUES (1, ?1)",
|
|
params![bytes],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn load_kt_log(&self) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT log_data FROM kt_log WHERE id = 1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row([], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn save_revocation_log(&self, bytes: Vec<u8>) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO kt_log (id, log_data) VALUES (2, ?1)",
|
|
params![bytes],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn load_revocation_log(&self) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT log_data FROM kt_log WHERE id = 2")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row([], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT INTO users (username, opaque_record) VALUES (?1, ?2)",
|
|
params![username, record],
|
|
)
|
|
.map_err(|e| {
|
|
if let rusqlite::Error::SqliteFailure(ref err, _) = &e {
|
|
if err.code == rusqlite::ErrorCode::ConstraintViolation {
|
|
return StorageError::DuplicateUser(username.to_string());
|
|
}
|
|
}
|
|
StorageError::Db(e.to_string())
|
|
})?;
|
|
Ok(())
|
|
}
|
|
|
|
fn get_user_record(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT opaque_record FROM users WHERE username = ?1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row(params![username], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn has_user_record(&self, username: &str) -> Result<bool, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let exists: bool = conn
|
|
.query_row(
|
|
"SELECT EXISTS(SELECT 1 FROM users WHERE username = ?1)",
|
|
params![username],
|
|
|row| row.get(0),
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(exists)
|
|
}
|
|
|
|
fn store_user_identity_key(
|
|
&self,
|
|
username: &str,
|
|
identity_key: Vec<u8>,
|
|
) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO user_identity_keys (username, identity_key) VALUES (?1, ?2)",
|
|
params![username, identity_key],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT identity_key FROM user_identity_keys WHERE username = ?1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row(params![username], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn resolve_identity_key(&self, identity_key: &[u8]) -> Result<Option<String>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT username FROM user_identity_keys WHERE identity_key = ?1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row(params![identity_key], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn peek(
|
|
&self,
|
|
recipient_key: &[u8],
|
|
channel_id: &[u8],
|
|
limit: usize,
|
|
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
|
|
let sql = if limit == 0 {
|
|
"SELECT seq, payload FROM deliveries
|
|
WHERE recipient_key = ?1 AND channel_id = ?2
|
|
AND (expires_at IS NULL OR expires_at > strftime('%s','now'))
|
|
ORDER BY seq ASC".to_string()
|
|
} else {
|
|
format!(
|
|
"SELECT seq, payload FROM deliveries
|
|
WHERE recipient_key = ?1 AND channel_id = ?2
|
|
AND (expires_at IS NULL OR expires_at > strftime('%s','now'))
|
|
ORDER BY seq ASC
|
|
LIMIT {}",
|
|
limit
|
|
)
|
|
};
|
|
|
|
let mut stmt = conn.prepare(&sql).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
let rows: Vec<(i64, Vec<u8>)> = stmt
|
|
.query_map(params![recipient_key, channel_id], |row| {
|
|
Ok((row.get(0)?, row.get(1)?))
|
|
})
|
|
.map_err(|e| StorageError::Db(e.to_string()))?
|
|
.collect::<Result<Vec<_>, _>>()
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
Ok(rows.into_iter().map(|(seq, payload)| (seq as u64, payload)).collect())
|
|
}
|
|
|
|
fn ack(
|
|
&self,
|
|
recipient_key: &[u8],
|
|
channel_id: &[u8],
|
|
seq_up_to: u64,
|
|
) -> Result<usize, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let deleted = conn
|
|
.execute(
|
|
"DELETE FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND seq <= ?3",
|
|
params![recipient_key, channel_id, seq_up_to as i64],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(deleted)
|
|
}
|
|
|
|
fn publish_endpoint(
|
|
&self,
|
|
identity_key: &[u8],
|
|
node_addr: Vec<u8>,
|
|
) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO endpoints (identity_key, node_addr) VALUES (?1, ?2)",
|
|
params![identity_key, node_addr],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT node_addr FROM endpoints WHERE identity_key = ?1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row(params![identity_key], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<(Vec<u8>, bool), StorageError> {
|
|
let (a, b) = if member_a < member_b {
|
|
(member_a.to_vec(), member_b.to_vec())
|
|
} else {
|
|
(member_b.to_vec(), member_a.to_vec())
|
|
};
|
|
let conn = self.get_conn()?;
|
|
let existing: Option<Vec<u8>> = conn
|
|
.query_row(
|
|
"SELECT channel_id FROM channels WHERE member_a = ?1 AND member_b = ?2",
|
|
params![a, b],
|
|
|row| row.get(0),
|
|
)
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
if let Some(id) = existing {
|
|
return Ok((id, false));
|
|
}
|
|
let mut channel_id = [0u8; 16];
|
|
rand::rngs::OsRng.fill_bytes(&mut channel_id);
|
|
conn.execute(
|
|
"INSERT INTO channels (channel_id, member_a, member_b) VALUES (?1, ?2, ?3)",
|
|
params![channel_id.as_slice(), a, b],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok((channel_id.to_vec(), true))
|
|
}
|
|
|
|
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.query_row(
|
|
"SELECT member_a, member_b FROM channels WHERE channel_id = ?1",
|
|
params![channel_id],
|
|
|row| Ok((row.get::<_, Vec<u8>>(0)?, row.get::<_, Vec<u8>>(1)?)),
|
|
)
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn store_identity_home_server(
|
|
&self,
|
|
identity_key: &[u8],
|
|
home_server: &str,
|
|
) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs() as i64;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO identity_home_servers (identity_key, home_server, updated_at) VALUES (?1, ?2, ?3)",
|
|
params![identity_key, home_server, now],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn get_identity_home_server(
|
|
&self,
|
|
identity_key: &[u8],
|
|
) -> Result<Option<String>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT home_server FROM identity_home_servers WHERE identity_key = ?1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
stmt.query_row(params![identity_key], |row| row.get(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn upsert_federation_peer(
|
|
&self,
|
|
domain: &str,
|
|
is_active: bool,
|
|
) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs() as i64;
|
|
conn.execute(
|
|
"INSERT INTO federation_peers (domain, last_seen, is_active) VALUES (?1, ?2, ?3)
|
|
ON CONFLICT(domain) DO UPDATE SET last_seen = ?2, is_active = ?3",
|
|
params![domain, now, is_active as i32],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn list_federation_peers(&self) -> Result<Vec<(String, bool)>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT domain, is_active FROM federation_peers WHERE is_active = 1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
let rows = stmt
|
|
.query_map([], |row| {
|
|
Ok((row.get::<_, String>(0)?, row.get::<_, i32>(1)? != 0))
|
|
})
|
|
.map_err(|e| StorageError::Db(e.to_string()))?
|
|
.collect::<Result<Vec<_>, _>>()
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(rows)
|
|
}
|
|
|
|
fn delete_account(&self, identity_key: &[u8]) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
|
|
// Resolve the username for this identity key.
|
|
let username: Option<String> = conn
|
|
.query_row(
|
|
"SELECT username FROM user_identity_keys WHERE identity_key = ?1",
|
|
params![identity_key],
|
|
|row| row.get(0),
|
|
)
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
// Use a transaction for atomicity.
|
|
conn.execute_batch("BEGIN IMMEDIATE")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
let result = (|| -> Result<(), StorageError> {
|
|
// 1. Delete queued deliveries.
|
|
conn.execute(
|
|
"DELETE FROM deliveries WHERE recipient_key = ?1",
|
|
params![identity_key],
|
|
).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
conn.execute(
|
|
"DELETE FROM delivery_seq_counters WHERE recipient_key = ?1",
|
|
params![identity_key],
|
|
).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
// 2. Delete key packages.
|
|
conn.execute(
|
|
"DELETE FROM key_packages WHERE identity_key = ?1",
|
|
params![identity_key],
|
|
).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
// 3. Delete hybrid keys.
|
|
conn.execute(
|
|
"DELETE FROM hybrid_keys WHERE identity_key = ?1",
|
|
params![identity_key],
|
|
).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
// 4. Delete channel memberships.
|
|
conn.execute(
|
|
"DELETE FROM channels WHERE member_a = ?1 OR member_b = ?1",
|
|
params![identity_key],
|
|
).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
// 5. Delete identity key mapping.
|
|
conn.execute(
|
|
"DELETE FROM user_identity_keys WHERE identity_key = ?1",
|
|
params![identity_key],
|
|
).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
// 6. Delete user record (by username).
|
|
if let Some(ref uname) = username {
|
|
conn.execute(
|
|
"DELETE FROM users WHERE username = ?1",
|
|
params![uname],
|
|
).map_err(|e| StorageError::Db(e.to_string()))?;
|
|
}
|
|
|
|
// 7. Delete endpoints (table may not exist on older schemas).
|
|
let _ = conn.execute(
|
|
"DELETE FROM endpoints WHERE identity_key = ?1",
|
|
params![identity_key],
|
|
);
|
|
|
|
// 8. Delete devices.
|
|
let _ = conn.execute(
|
|
"DELETE FROM devices WHERE identity_key = ?1",
|
|
params![identity_key],
|
|
);
|
|
|
|
// Do NOT delete KT log entries — append-only for auditability.
|
|
|
|
Ok(())
|
|
})();
|
|
|
|
match result {
|
|
Ok(()) => {
|
|
conn.execute_batch("COMMIT")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
Err(e) => {
|
|
let _ = conn.execute_batch("ROLLBACK");
|
|
Err(e)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn register_device(&self, identity_key: &[u8], device_id: &[u8], device_name: &str) -> Result<bool, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
// Check if device already exists.
|
|
let exists: bool = conn
|
|
.query_row(
|
|
"SELECT EXISTS(SELECT 1 FROM devices WHERE identity_key = ?1 AND device_id = ?2)",
|
|
params![identity_key, device_id],
|
|
|row| row.get(0),
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
if exists {
|
|
return Ok(false);
|
|
}
|
|
conn.execute(
|
|
"INSERT INTO devices (identity_key, device_id, device_name) VALUES (?1, ?2, ?3)",
|
|
params![identity_key, device_id, device_name],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(true)
|
|
}
|
|
|
|
fn list_devices(&self, identity_key: &[u8]) -> Result<Vec<(Vec<u8>, String, u64)>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT device_id, device_name, registered_at FROM devices WHERE identity_key = ?1 ORDER BY registered_at ASC")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
let rows = stmt
|
|
.query_map(params![identity_key], |row| {
|
|
Ok((
|
|
row.get::<_, Vec<u8>>(0)?,
|
|
row.get::<_, String>(1)?,
|
|
row.get::<_, i64>(2)? as u64,
|
|
))
|
|
})
|
|
.map_err(|e| StorageError::Db(e.to_string()))?
|
|
.collect::<Result<Vec<_>, _>>()
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(rows)
|
|
}
|
|
|
|
fn revoke_device(&self, identity_key: &[u8], device_id: &[u8]) -> Result<bool, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let deleted = conn
|
|
.execute(
|
|
"DELETE FROM devices WHERE identity_key = ?1 AND device_id = ?2",
|
|
params![identity_key, device_id],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(deleted > 0)
|
|
}
|
|
|
|
fn device_count(&self, identity_key: &[u8]) -> Result<usize, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let count: i64 = conn
|
|
.query_row(
|
|
"SELECT COUNT(*) FROM devices WHERE identity_key = ?1",
|
|
params![identity_key],
|
|
|row| row.get(0),
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(count as usize)
|
|
}
|
|
|
|
// ── Session persistence ────────────────────────────────────────────────
|
|
|
|
fn store_session(&self, token: &[u8], record: &SessionRecord) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO sessions (token, username, identity_key, created_at, expires_at) VALUES (?1, ?2, ?3, ?4, ?5)",
|
|
params![token, record.username, record.identity_key, record.created_at as i64, record.expires_at as i64],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn get_session(&self, token: &[u8]) -> Result<Option<SessionRecord>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT username, identity_key, created_at, expires_at FROM sessions WHERE token = ?1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
stmt.query_row(params![token], |row| {
|
|
Ok(SessionRecord {
|
|
username: row.get(0)?,
|
|
identity_key: row.get(1)?,
|
|
created_at: row.get::<_, i64>(2)? as u64,
|
|
expires_at: row.get::<_, i64>(3)? as u64,
|
|
})
|
|
})
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn delete_expired_sessions(&self, now: u64) -> Result<usize, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let deleted = conn
|
|
.execute(
|
|
"DELETE FROM sessions WHERE expires_at <= ?1",
|
|
params![now as i64],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(deleted)
|
|
}
|
|
|
|
fn delete_session(&self, token: &[u8]) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute("DELETE FROM sessions WHERE token = ?1", params![token])
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
// ── Blob storage ───────────────────────────────────────────────────────
|
|
|
|
fn store_blob_chunk(
|
|
&self,
|
|
blob_hash: &[u8],
|
|
chunk: &[u8],
|
|
offset: u64,
|
|
total_size: u64,
|
|
mime_type: &str,
|
|
) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
|
|
// Insert chunk into staging.
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO blob_staging (blob_hash, offset, chunk, total_size, mime_type) VALUES (?1, ?2, ?3, ?4, ?5)",
|
|
params![blob_hash, offset as i64, chunk, total_size as i64, mime_type],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
// Check if all chunks have arrived.
|
|
let staged_size: i64 = conn
|
|
.query_row(
|
|
"SELECT COALESCE(SUM(LENGTH(chunk)), 0) FROM blob_staging WHERE blob_hash = ?1",
|
|
params![blob_hash],
|
|
|row| row.get(0),
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
if staged_size as u64 != total_size {
|
|
return Ok(None);
|
|
}
|
|
|
|
// All chunks received — assemble in offset order.
|
|
let mut stmt = conn
|
|
.prepare("SELECT chunk FROM blob_staging WHERE blob_hash = ?1 ORDER BY offset ASC")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
let chunks: Vec<Vec<u8>> = stmt
|
|
.query_map(params![blob_hash], |row| row.get(0))
|
|
.map_err(|e| StorageError::Db(e.to_string()))?
|
|
.collect::<Result<Vec<_>, _>>()
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
let mut assembled = Vec::with_capacity(total_size as usize);
|
|
for c in &chunks {
|
|
assembled.extend_from_slice(c);
|
|
}
|
|
|
|
// Verify SHA-256.
|
|
let hash = Sha256::digest(&assembled);
|
|
if hash.as_slice() != blob_hash {
|
|
// Clean up staging rows for this blob.
|
|
conn.execute(
|
|
"DELETE FROM blob_staging WHERE blob_hash = ?1",
|
|
params![blob_hash],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
return Err(StorageError::Db(
|
|
"blob hash mismatch after assembly".into(),
|
|
));
|
|
}
|
|
|
|
// Use the hash as the blob_id (content-addressable).
|
|
let blob_id = hash.to_vec();
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs() as i64;
|
|
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO blobs (blob_id, data, total_size, mime_type, uploaded_at) VALUES (?1, ?2, ?3, ?4, ?5)",
|
|
params![blob_id, assembled, total_size as i64, mime_type, now],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
// Clean up staging.
|
|
conn.execute(
|
|
"DELETE FROM blob_staging WHERE blob_hash = ?1",
|
|
params![blob_hash],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
Ok(Some(blob_id))
|
|
}
|
|
|
|
fn get_blob_chunk(
|
|
&self,
|
|
blob_id: &[u8],
|
|
offset: u64,
|
|
length: u32,
|
|
) -> Result<Option<(Vec<u8>, u64, String)>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare(
|
|
"SELECT substr(data, ?2, ?3), total_size, mime_type FROM blobs WHERE blob_id = ?1",
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
|
|
// SQLite substr is 1-indexed.
|
|
stmt.query_row(
|
|
params![blob_id, (offset + 1) as i64, length as i64],
|
|
|row| {
|
|
Ok((
|
|
row.get::<_, Vec<u8>>(0)?,
|
|
row.get::<_, i64>(1)? as u64,
|
|
row.get::<_, String>(2)?,
|
|
))
|
|
},
|
|
)
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn store_group_metadata(
|
|
&self,
|
|
_group_id: &[u8],
|
|
_name: &str,
|
|
_description: &str,
|
|
_avatar_hash: &[u8],
|
|
_creator_key: &[u8],
|
|
) -> Result<(), StorageError> {
|
|
Ok(())
|
|
}
|
|
|
|
/// Looks up group metadata. This backend stores none, so every lookup misses.
fn get_group_metadata(
    &self,
    group_id: &[u8],
) -> Result<Option<(String, String, Vec<u8>, Vec<u8>, u64)>, StorageError> {
    let _ = group_id;
    let not_stored = None;
    Ok(not_stored)
}
|
|
|
|
/// Records a group member. This backend does not persist group membership, so this is a
/// successful no-op.
fn add_group_member(&self, group_id: &[u8], identity_key: &[u8]) -> Result<(), StorageError> {
    let _ = (group_id, identity_key);
    Ok(())
}
|
|
|
|
/// Removes a group member. Nothing is ever stored by this backend, so nothing is removed
/// and the result is always `false`.
fn remove_group_member(&self, group_id: &[u8], identity_key: &[u8]) -> Result<bool, StorageError> {
    let _ = (group_id, identity_key);
    let removed = false;
    Ok(removed)
}
|
|
|
|
/// Lists group members. This backend does not persist membership, so the list is always
/// empty.
fn list_group_members(&self, group_id: &[u8]) -> Result<Vec<(Vec<u8>, u64)>, StorageError> {
    let _ = group_id;
    Ok(vec![])
}
|
|
|
|
fn store_recovery_bundle(
|
|
&self,
|
|
token_hash: &[u8],
|
|
bundle: Vec<u8>,
|
|
ttl_secs: u64,
|
|
) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO recovery_bundles (token_hash, bundle, ttl_secs) VALUES (?1, ?2, ?3)",
|
|
params![token_hash, bundle, ttl_secs as i64],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn get_recovery_bundle(&self, token_hash: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let mut stmt = conn
|
|
.prepare("SELECT bundle FROM recovery_bundles WHERE token_hash = ?1")
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
stmt.query_row(params![token_hash], |row| row.get::<_, Vec<u8>>(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn delete_recovery_bundle(&self, token_hash: &[u8]) -> Result<bool, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let affected = conn
|
|
.execute(
|
|
"DELETE FROM recovery_bundles WHERE token_hash = ?1",
|
|
params![token_hash],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(affected > 0)
|
|
}
|
|
|
|
fn store_report(
|
|
&self,
|
|
encrypted_report: &[u8],
|
|
conversation_id: &[u8],
|
|
reporter_identity: &[u8],
|
|
) -> Result<u64, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT INTO reports (encrypted_report, conversation_id, reporter_identity)
|
|
VALUES (?1, ?2, ?3)",
|
|
params![encrypted_report, conversation_id, reporter_identity],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(conn.last_insert_rowid() as u64)
|
|
}
|
|
|
|
fn list_reports(
|
|
&self,
|
|
limit: u32,
|
|
offset: u32,
|
|
) -> Result<Vec<(u64, Vec<u8>, Vec<u8>, Vec<u8>, u64)>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let effective_limit = if limit == 0 { i64::MAX } else { limit as i64 };
|
|
let mut stmt = conn
|
|
.prepare(
|
|
"SELECT id, encrypted_report, conversation_id, reporter_identity, created_at
|
|
FROM reports ORDER BY id LIMIT ?1 OFFSET ?2",
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
let rows = stmt
|
|
.query_map(params![effective_limit, offset as i64], |row| {
|
|
Ok((
|
|
row.get::<_, i64>(0)? as u64,
|
|
row.get::<_, Vec<u8>>(1)?,
|
|
row.get::<_, Vec<u8>>(2)?,
|
|
row.get::<_, Vec<u8>>(3)?,
|
|
row.get::<_, i64>(4)? as u64,
|
|
))
|
|
})
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
let mut result = Vec::new();
|
|
for row in rows {
|
|
result.push(row.map_err(|e| StorageError::Db(e.to_string()))?);
|
|
}
|
|
Ok(result)
|
|
}
|
|
|
|
fn ban_user(
|
|
&self,
|
|
identity_key: &[u8],
|
|
reason: &str,
|
|
expires_at: u64,
|
|
) -> Result<(), StorageError> {
|
|
let conn = self.get_conn()?;
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO bans (identity_key, reason, expires_at)
|
|
VALUES (?1, ?2, ?3)",
|
|
params![identity_key, reason, expires_at as i64],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(())
|
|
}
|
|
|
|
fn unban_user(&self, identity_key: &[u8]) -> Result<bool, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let affected = conn
|
|
.execute(
|
|
"DELETE FROM bans WHERE identity_key = ?1",
|
|
params![identity_key],
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
Ok(affected > 0)
|
|
}
|
|
|
|
fn is_banned(&self, identity_key: &[u8]) -> Result<Option<String>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs() as i64;
|
|
let mut stmt = conn
|
|
.prepare(
|
|
"SELECT reason FROM bans
|
|
WHERE identity_key = ?1 AND (expires_at = 0 OR expires_at > ?2)",
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
stmt.query_row(params![identity_key, now], |row| row.get::<_, String>(0))
|
|
.optional()
|
|
.map_err(|e| StorageError::Db(e.to_string()))
|
|
}
|
|
|
|
fn list_banned(&self) -> Result<Vec<(Vec<u8>, String, u64, u64)>, StorageError> {
|
|
let conn = self.get_conn()?;
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs() as i64;
|
|
let mut stmt = conn
|
|
.prepare(
|
|
"SELECT identity_key, reason, banned_at, expires_at FROM bans
|
|
WHERE expires_at = 0 OR expires_at > ?1
|
|
ORDER BY banned_at",
|
|
)
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
let rows = stmt
|
|
.query_map(params![now], |row| {
|
|
Ok((
|
|
row.get::<_, Vec<u8>>(0)?,
|
|
row.get::<_, String>(1)?,
|
|
row.get::<_, i64>(2)? as u64,
|
|
row.get::<_, i64>(3)? as u64,
|
|
))
|
|
})
|
|
.map_err(|e| StorageError::Db(e.to_string()))?;
|
|
let mut result = Vec::new();
|
|
for row in rows {
|
|
result.push(row.map_err(|e| StorageError::Db(e.to_string()))?);
|
|
}
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
/// Convenience extension for `rusqlite::OptionalExtension`.
|
|
trait OptionalExt<T> {
|
|
fn optional(self) -> Result<Option<T>, rusqlite::Error>;
|
|
}
|
|
|
|
impl<T> OptionalExt<T> for Result<T, rusqlite::Error> {
|
|
fn optional(self) -> Result<Option<T>, rusqlite::Error> {
|
|
match self {
|
|
Ok(v) => Ok(Some(v)),
|
|
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
|
Err(e) => Err(e),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn open_in_memory() -> SqlStore {
        // Pool size 1 for :memory: — each connection is a separate DB.
        SqlStore::open_with_pool_size(":memory:", "", 1).unwrap()
    }

    #[test]
    fn sets_user_version_after_migrate() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path: PathBuf = dir.path().join("store.db");

        {
            let store = SqlStore::open(&db_path, "").expect("open store");
            let _guard = store.get_conn().unwrap();
        }

        let conn = rusqlite::Connection::open(&db_path).expect("reopen db");
        let version: i32 = conn
            .pragma_query_value(None, "user_version", |row| row.get(0))
            .expect("read user_version");

        assert_eq!(version, SCHEMA_VERSION);
    }

    #[test]
    fn key_package_fifo() {
        let store = open_in_memory();
        let identity = [1u8; 32];

        store
            .upload_key_package(&identity, b"kp1".to_vec())
            .unwrap();
        store
            .upload_key_package(&identity, b"kp2".to_vec())
            .unwrap();

        assert_eq!(
            store.fetch_key_package(&identity).unwrap(),
            Some(b"kp1".to_vec())
        );
        assert_eq!(
            store.fetch_key_package(&identity).unwrap(),
            Some(b"kp2".to_vec())
        );
        assert_eq!(store.fetch_key_package(&identity).unwrap(), None);
    }

    #[test]
    fn delivery_round_trip() {
        let store = open_in_memory();
        let rk = [1u8; 32];
        let ch = b"channel-1";

        let seq0 = store.enqueue(&rk, ch, b"msg1".to_vec(), None).unwrap();
        let seq1 = store.enqueue(&rk, ch, b"msg2".to_vec(), None).unwrap();
        assert_eq!(seq0, 0);
        assert_eq!(seq1, 1);

        let msgs = store.fetch(&rk, ch).unwrap();
        assert_eq!(msgs, vec![(0u64, b"msg1".to_vec()), (1u64, b"msg2".to_vec())]);

        assert!(store.fetch(&rk, ch).unwrap().is_empty());
    }

    #[test]
    fn fetch_limited_partial_drain() {
        let store = open_in_memory();
        let rk = [5u8; 32];
        let ch = b"ch";

        store.enqueue(&rk, ch, b"a".to_vec(), None).unwrap();
        store.enqueue(&rk, ch, b"b".to_vec(), None).unwrap();
        store.enqueue(&rk, ch, b"c".to_vec(), None).unwrap();

        let msgs = store.fetch_limited(&rk, ch, 2).unwrap();
        assert_eq!(msgs, vec![(0u64, b"a".to_vec()), (1u64, b"b".to_vec())]);

        let remaining = store.fetch(&rk, ch).unwrap();
        assert_eq!(remaining, vec![(2u64, b"c".to_vec())]);
    }

    #[test]
    fn queue_depth_count() {
        let store = open_in_memory();
        let rk = [6u8; 32];
        let ch = b"ch";

        assert_eq!(store.queue_depth(&rk, ch).unwrap(), 0);
        store.enqueue(&rk, ch, b"x".to_vec(), None).unwrap();
        store.enqueue(&rk, ch, b"y".to_vec(), None).unwrap();
        assert_eq!(store.queue_depth(&rk, ch).unwrap(), 2);
    }

    #[test]
    fn has_user_record_check() {
        let store = open_in_memory();
        assert!(!store.has_user_record("user1").unwrap());
        store
            .store_user_record("user1", b"record".to_vec())
            .unwrap();
        assert!(store.has_user_record("user1").unwrap());
        assert!(!store.has_user_record("user2").unwrap());
    }

    #[test]
    fn user_identity_key_round_trip() {
        let store = open_in_memory();
        assert!(store.get_user_identity_key("user1").unwrap().is_none());
        store
            .store_user_identity_key("user1", vec![1u8; 32])
            .unwrap();
        assert_eq!(
            store.get_user_identity_key("user1").unwrap(),
            Some(vec![1u8; 32])
        );
    }

    #[test]
    fn hybrid_key_round_trip() {
        let store = open_in_memory();
        let ik = [2u8; 32];
        let pk = b"hybrid_public_key_data".to_vec();

        store.upload_hybrid_key(&ik, pk.clone()).unwrap();
        assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(pk));
    }

    #[test]
    fn separate_channels_isolated() {
        let store = open_in_memory();
        let rk = [4u8; 32];

        store.enqueue(&rk, b"ch-a", b"a1".to_vec(), None).unwrap();
        store.enqueue(&rk, b"ch-b", b"b1".to_vec(), None).unwrap();

        let a_msgs = store.fetch(&rk, b"ch-a").unwrap();
        assert_eq!(a_msgs, vec![(0u64, b"a1".to_vec())]);

        let b_msgs = store.fetch(&rk, b"ch-b").unwrap();
        assert_eq!(b_msgs, vec![(0u64, b"b1".to_vec())]);
    }

    #[test]
    fn create_channel_was_new_first_call() {
        let store = open_in_memory();
        let a = [10u8; 32];
        let b = [11u8; 32];
        let (id, was_new) = store.create_channel(&a, &b).unwrap();
        assert_eq!(id.len(), 16, "channel_id must be 16 bytes");
        assert!(was_new, "first create_channel must return was_new=true");
    }

    #[test]
    fn create_channel_idempotent_same_direction() {
        let store = open_in_memory();
        let a = [12u8; 32];
        let b = [13u8; 32];
        let (id1, was_new1) = store.create_channel(&a, &b).unwrap();
        let (id2, was_new2) = store.create_channel(&a, &b).unwrap();
        assert_eq!(id1, id2, "repeated call must return same channel_id");
        assert!(was_new1);
        assert!(!was_new2, "second call must return was_new=false");
    }

    #[test]
    fn create_channel_idempotent_reversed_direction() {
        let store = open_in_memory();
        let a = [14u8; 32];
        let b = [15u8; 32];
        let (id1, was_new1) = store.create_channel(&a, &b).unwrap();
        let (id2, was_new2) = store.create_channel(&b, &a).unwrap();
        assert_eq!(id1, id2, "reversed-key call must return same channel_id");
        assert!(was_new1);
        assert!(!was_new2, "reversed-key second call must return was_new=false");
    }

    #[test]
    fn create_channel_different_pairs_isolated() {
        let store = open_in_memory();
        let a = [16u8; 32];
        let b = [17u8; 32];
        let c = [18u8; 32];
        let (id_ab, _) = store.create_channel(&a, &b).unwrap();
        let (id_ac, _) = store.create_channel(&a, &c).unwrap();
        let (id_bc, _) = store.create_channel(&b, &c).unwrap();
        assert_ne!(id_ab, id_ac);
        assert_ne!(id_ab, id_bc);
        assert_ne!(id_ac, id_bc);
    }

    #[test]
    fn create_channel_get_members_roundtrip() {
        let store = open_in_memory();
        let a = [20u8; 32];
        let b = [21u8; 32];
        let (id, _) = store.create_channel(&a, &b).unwrap();
        let members = store.get_channel_members(&id).unwrap();
        assert!(members.is_some(), "get_channel_members must return Some after create");
        let (ma, mb) = members.unwrap();
        // members stored in canonical (lex) order
        let (expected_a, expected_b) = if a < b {
            (a.to_vec(), b.to_vec())
        } else {
            (b.to_vec(), a.to_vec())
        };
        assert_eq!(ma, expected_a);
        assert_eq!(mb, expected_b);
    }

    #[test]
    fn get_channel_members_unknown_id_returns_none() {
        let store = open_in_memory();
        assert!(store.get_channel_members(&[0u8; 16]).unwrap().is_none());
    }

    #[test]
    fn resolve_identity_key_after_store() {
        let store = open_in_memory();
        let ik = [30u8; 32];
        store.store_user_record("carol", b"record".to_vec()).unwrap();
        store.store_user_identity_key("carol", ik.to_vec()).unwrap();
        let resolved = store.resolve_identity_key(&ik).unwrap();
        assert_eq!(resolved, Some("carol".to_string()));
    }

    #[test]
    fn resolve_identity_key_unknown_returns_none() {
        let store = open_in_memory();
        let unknown = [31u8; 32];
        assert!(store.resolve_identity_key(&unknown).unwrap().is_none());
    }

    #[test]
    fn resolve_identity_key_two_users_distinct() {
        let store = open_in_memory();
        let ik_a = [32u8; 32];
        let ik_b = [33u8; 32];
        store.store_user_record("user_a", b"ra".to_vec()).unwrap();
        store.store_user_record("user_b", b"rb".to_vec()).unwrap();
        store.store_user_identity_key("user_a", ik_a.to_vec()).unwrap();
        store.store_user_identity_key("user_b", ik_b.to_vec()).unwrap();
        assert_eq!(store.resolve_identity_key(&ik_a).unwrap(), Some("user_a".to_string()));
        assert_eq!(store.resolve_identity_key(&ik_b).unwrap(), Some("user_b".to_string()));
    }

    #[test]
    fn blob_chunks_assemble_out_of_order_and_read_back() {
        let store = open_in_memory();
        let data: &[u8] = b"hello, blob world";
        let hash = Sha256::digest(data).to_vec();
        let total = data.len() as u64;

        // Tail first: the upload is not complete yet.
        assert_eq!(
            store
                .store_blob_chunk(&hash, &data[6..], 6, total, "text/plain")
                .unwrap(),
            None
        );
        // Head second: every byte is now staged, so the blob is assembled.
        let blob_id = store
            .store_blob_chunk(&hash, &data[..6], 0, total, "text/plain")
            .unwrap()
            .expect("upload completes once all bytes are staged");
        assert_eq!(blob_id, hash);

        let (bytes, size, mime) = store.get_blob_chunk(&blob_id, 7, 4).unwrap().unwrap();
        assert_eq!(bytes, b"blob".to_vec());
        assert_eq!(size, total);
        assert_eq!(mime, "text/plain");

        // Reading past the end yields an empty slice, not an error.
        let (tail, _, _) = store.get_blob_chunk(&blob_id, 100, 4).unwrap().unwrap();
        assert!(tail.is_empty());
    }

    #[test]
    fn blob_hash_mismatch_is_rejected() {
        let store = open_in_memory();
        let wrong_hash = [0u8; 32];
        assert!(store
            .store_blob_chunk(&wrong_hash, b"data", 0, 4, "application/octet-stream")
            .is_err());
        assert!(store.get_blob_chunk(&wrong_hash, 0, 4).unwrap().is_none());
    }

    #[test]
    fn recovery_bundle_round_trip_and_delete() {
        let store = open_in_memory();
        let token = [40u8; 32];

        assert!(store.get_recovery_bundle(&token).unwrap().is_none());
        store
            .store_recovery_bundle(&token, b"bundle".to_vec(), 3600)
            .unwrap();
        assert_eq!(
            store.get_recovery_bundle(&token).unwrap(),
            Some(b"bundle".to_vec())
        );

        assert!(store.delete_recovery_bundle(&token).unwrap());
        assert!(store.get_recovery_bundle(&token).unwrap().is_none());
        assert!(!store.delete_recovery_bundle(&token).unwrap());
    }

    #[test]
    fn reports_listed_in_id_order_with_paging() {
        let store = open_in_memory();
        let id1 = store.store_report(b"r1", b"conv", b"alice").unwrap();
        let id2 = store.store_report(b"r2", b"conv", b"bob").unwrap();
        assert!(id2 > id1);

        let all = store.list_reports(0, 0).unwrap();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].0, id1);
        assert_eq!(all[0].1, b"r1".to_vec());
        assert_eq!(all[1].0, id2);
        assert_eq!(all[1].3, b"bob".to_vec());

        let page = store.list_reports(1, 1).unwrap();
        assert_eq!(page.len(), 1);
        assert_eq!(page[0].0, id2);
    }

    #[test]
    fn permanent_ban_then_unban() {
        let store = open_in_memory();
        let ik = [50u8; 32];

        assert!(store.is_banned(&ik).unwrap().is_none());
        store.ban_user(&ik, "spam", 0).unwrap();
        assert_eq!(store.is_banned(&ik).unwrap(), Some("spam".to_string()));

        let banned = store.list_banned().unwrap();
        assert_eq!(banned.len(), 1);
        assert_eq!(banned[0].0, ik.to_vec());
        assert_eq!(banned[0].1, "spam");
        assert_eq!(banned[0].3, 0);

        assert!(store.unban_user(&ik).unwrap());
        assert!(store.is_banned(&ik).unwrap().is_none());
        assert!(!store.unban_user(&ik).unwrap());
    }

    #[test]
    fn expired_ban_is_ignored() {
        let store = open_in_memory();
        let ik = [51u8; 32];
        // Expired one second after the Unix epoch.
        store.ban_user(&ik, "old", 1).unwrap();
        assert!(store.is_banned(&ik).unwrap().is_none());
        assert!(store.list_banned().unwrap().is_empty());
    }
}
|