feat: add post-quantum hybrid KEM + SQLCipher persistence

Feature 1 — Post-Quantum Hybrid KEM (X25519 + ML-KEM-768):
- Create hybrid_kem.rs with keygen, encrypt, decrypt + 11 unit tests
- Wire format: version(1) | x25519_eph_pk(32) | mlkem_ct(1088) | nonce(12) | ct
- Add uploadHybridKey/fetchHybridKey RPCs to node.capnp schema
- Server: hybrid key storage in FileBackedStore + RPC handlers
- Client: hybrid keypair in StoredState, auto-wrap/unwrap in send/recv/invite/join
- demo-group runs full hybrid PQ envelope round-trip

Feature 2 — SQLCipher Persistence:
- Extract Store trait from FileBackedStore API
- Create SqlStore (rusqlite + bundled-sqlcipher) with encrypted-at-rest SQLite
- Schema: key_packages, deliveries, hybrid_keys tables with indexes
- Server CLI: --store-backend=sql, --db-path, --db-key flags
- 5 unit tests for SqlStore (FIFO, round-trip, upsert, channel isolation)

Also includes: client lib.rs refactor, auth config, TOML config file support,
mdBook documentation, and various cleanups by user.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-22 08:07:48 +01:00
parent d1ddef4cea
commit f334ed3d43
81 changed files with 14502 additions and 2289 deletions

View File

@@ -32,6 +32,9 @@ quinn-proto = { workspace = true }
rustls = { workspace = true }
rcgen = { workspace = true }
# Database
rusqlite = { workspace = true }
# Error handling
anyhow = { workspace = true }
thiserror = { workspace = true }
@@ -40,3 +43,4 @@ serde = { workspace = true }
# CLI
clap = { workspace = true }
toml = { version = "0.8" }

View File

@@ -25,14 +25,15 @@
//! | `QUICNPROTOCHAT_LISTEN` | `--listen` | `0.0.0.0:4201` |
//! | `RUST_LOG` | — | `info` |
use std::{fs, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use std::{fs, net::SocketAddr, path::{Path, PathBuf}, sync::Arc, time::Duration};
use anyhow::Context;
use serde::Deserialize;
use capnp::capability::Promise;
use capnp_rpc::{rpc_twoparty_capnp::Side, twoparty, RpcSystem};
use clap::Parser;
use dashmap::DashMap;
use quicnprotochat_proto::node_capnp::node_service;
use quicnprotochat_proto::node_capnp::{auth, node_service};
use quinn::{Endpoint, ServerConfig};
use quinn_proto::crypto::rustls::QuicServerConfig;
use rcgen::generate_simple_self_signed;
@@ -43,12 +44,139 @@ use tokio::sync::Notify;
use tokio::time::timeout;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
mod sql_store;
mod storage;
use storage::{FileBackedStore, StorageError};
use sql_store::SqlStore;
use storage::{FileBackedStore, Store, StorageError};
const MAX_PAYLOAD_BYTES: usize = 5 * 1024 * 1024; // 5 MB cap per message
const MAX_KEYPACKAGE_BYTES: usize = 1 * 1024 * 1024; // 1 MB cap per KeyPackage
const CURRENT_WIRE_VERSION: u16 = 1; // allow 0 (legacy) and 1 (current)
const CURRENT_WIRE_VERSION: u16 = 1; // legacy disabled; current wire version only
const DEFAULT_LISTEN: &str = "0.0.0.0:7000";
const DEFAULT_DATA_DIR: &str = "data";
const DEFAULT_TLS_CERT: &str = "data/server-cert.der";
const DEFAULT_TLS_KEY: &str = "data/server-key.der";
const DEFAULT_STORE_BACKEND: &str = "file";
const DEFAULT_DB_PATH: &str = "data/quicnprotochat.db";
#[derive(Clone, Debug)]
struct AuthConfig {
required_token: Option<Vec<u8>>,
}
impl AuthConfig {
fn new(required_token: Option<String>) -> Self {
let required_token = required_token.filter(|s| !s.is_empty()).map(|s| s.into_bytes());
Self { required_token }
}
}
#[derive(Debug, Default, Deserialize)]
struct FileConfig {
listen: Option<String>,
data_dir: Option<String>,
tls_cert: Option<PathBuf>,
tls_key: Option<PathBuf>,
auth_token: Option<String>,
store_backend: Option<String>,
db_path: Option<PathBuf>,
db_key: Option<String>,
}
#[derive(Debug)]
struct EffectiveConfig {
listen: String,
data_dir: String,
tls_cert: PathBuf,
tls_key: PathBuf,
auth_token: Option<String>,
store_backend: String,
db_path: PathBuf,
db_key: String,
}
fn load_config(path: Option<&Path>) -> anyhow::Result<FileConfig> {
let path = match path {
Some(p) => PathBuf::from(p),
None => PathBuf::from("quicnprotochat-server.toml"),
};
if !path.exists() {
return Ok(FileConfig::default());
}
let contents = fs::read_to_string(&path)
.with_context(|| format!("read config file {path:?}"))?;
let cfg: FileConfig = toml::from_str(&contents)
.with_context(|| format!("parse config file {path:?}"))?;
Ok(cfg)
}
fn merge_config(args: &Args, file: &FileConfig) -> EffectiveConfig {
let listen = if args.listen == DEFAULT_LISTEN {
file.listen.clone().unwrap_or_else(|| DEFAULT_LISTEN.to_string())
} else {
args.listen.clone()
};
let data_dir = if args.data_dir == DEFAULT_DATA_DIR {
file.data_dir.clone().unwrap_or_else(|| DEFAULT_DATA_DIR.to_string())
} else {
args.data_dir.clone()
};
let tls_cert = if args.tls_cert == PathBuf::from(DEFAULT_TLS_CERT) {
file.tls_cert.clone().unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_CERT))
} else {
args.tls_cert.clone()
};
let tls_key = if args.tls_key == PathBuf::from(DEFAULT_TLS_KEY) {
file.tls_key.clone().unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_KEY))
} else {
args.tls_key.clone()
};
let auth_token = if args.auth_token.is_some() {
args.auth_token.clone()
} else {
file.auth_token.clone()
};
let store_backend = if args.store_backend == DEFAULT_STORE_BACKEND {
file.store_backend
.clone()
.unwrap_or_else(|| DEFAULT_STORE_BACKEND.to_string())
} else {
args.store_backend.clone()
};
let db_path = if args.db_path == PathBuf::from(DEFAULT_DB_PATH) {
file.db_path
.clone()
.unwrap_or_else(|| PathBuf::from(DEFAULT_DB_PATH))
} else {
args.db_path.clone()
};
let db_key = if args.db_key.is_empty() {
file.db_key.clone().unwrap_or_else(|| args.db_key.clone())
} else {
args.db_key.clone()
};
EffectiveConfig {
listen,
data_dir,
tls_cert,
tls_key,
auth_token,
store_backend,
db_path,
db_key,
}
}
// ── CLI ───────────────────────────────────────────────────────────────────────
@@ -59,37 +187,50 @@ const CURRENT_WIRE_VERSION: u16 = 1; // allow 0 (legacy) and 1 (current)
version
)]
struct Args {
/// Optional path to a TOML config file (fields map to CLI flags).
#[arg(long, env = "QUICNPROTOCHAT_CONFIG")]
config: Option<PathBuf>,
/// QUIC listen address (host:port).
#[arg(long, default_value = "0.0.0.0:4201", env = "QUICNPROTOCHAT_LISTEN")]
#[arg(long, default_value = DEFAULT_LISTEN, env = "QUICNPROTOCHAT_LISTEN")]
listen: String,
/// Directory for persisted server data (KeyPackages + delivery queues).
#[arg(long, default_value = "data", env = "QUICNPROTOCHAT_DATA_DIR")]
#[arg(long, default_value = DEFAULT_DATA_DIR, env = "QUICNPROTOCHAT_DATA_DIR")]
data_dir: String,
/// TLS certificate path (generated automatically if missing).
#[arg(
long,
default_value = "data/server-cert.der",
env = "QUICNPROTOCHAT_TLS_CERT"
)]
#[arg(long, default_value = DEFAULT_TLS_CERT, env = "QUICNPROTOCHAT_TLS_CERT")]
tls_cert: PathBuf,
/// TLS private key path (generated automatically if missing).
#[arg(
long,
default_value = "data/server-key.der",
env = "QUICNPROTOCHAT_TLS_KEY"
)]
#[arg(long, default_value = DEFAULT_TLS_KEY, env = "QUICNPROTOCHAT_TLS_KEY")]
tls_key: PathBuf,
/// Required bearer token for auth.version=1 requests. If unset, any non-empty token is accepted.
#[arg(long, env = "QUICNPROTOCHAT_AUTH_TOKEN")]
auth_token: Option<String>,
/// Storage backend: "file" (bincode) or "sql" (SQLCipher-encrypted).
#[arg(long, default_value = DEFAULT_STORE_BACKEND, env = "QUICNPROTOCHAT_STORE_BACKEND")]
store_backend: String,
/// Path to the SQLCipher database file (only used when --store-backend=sql).
#[arg(long, default_value = DEFAULT_DB_PATH, env = "QUICNPROTOCHAT_DB_PATH")]
db_path: PathBuf,
/// SQLCipher encryption key. Empty string disables encryption.
#[arg(long, default_value = "", env = "QUICNPROTOCHAT_DB_KEY")]
db_key: String,
}
// ── Node service implementation ─────────────────────────────────────────────
/// Cap'n Proto RPC server implementation for `NodeService` (Auth + Delivery).
struct NodeServiceImpl {
store: Arc<FileBackedStore>,
store: Arc<dyn Store>,
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
auth_cfg: Arc<AuthConfig>,
}
impl NodeServiceImpl {
@@ -114,6 +255,9 @@ impl node_service::Server for NodeServiceImpl {
let (identity_key, package) = match params {
Ok(p) => {
if let Err(e) = validate_auth(&self.auth_cfg, p.get_auth()) {
return Promise::err(e);
}
let ik = match p.get_identity_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
@@ -177,6 +321,14 @@ impl node_service::Server for NodeServiceImpl {
},
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
};
if let Err(e) = params
.get()
.ok()
.map(|p| validate_auth(&self.auth_cfg, p.get_auth()))
.transpose()
{
return Promise::err(e);
}
if identity_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
@@ -234,6 +386,9 @@ impl node_service::Server for NodeServiceImpl {
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
if let Err(e) = validate_auth(&self.auth_cfg, p.get_auth()) {
return Promise::err(e);
}
if recipient_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
@@ -252,9 +407,9 @@ impl node_service::Server for NodeServiceImpl {
MAX_PAYLOAD_BYTES
)));
}
if version != 0 && version != CURRENT_WIRE_VERSION {
if version != CURRENT_WIRE_VERSION {
return Promise::err(capnp::Error::failed(format!(
"unsupported wire version {} (expected 0 or {CURRENT_WIRE_VERSION})",
"unsupported wire version {} (expected {CURRENT_WIRE_VERSION})",
version
)));
}
@@ -300,7 +455,15 @@ impl node_service::Server for NodeServiceImpl {
.get()
.ok()
.map(|p| p.get_version())
.unwrap_or(0);
.unwrap_or(CURRENT_WIRE_VERSION);
if let Err(e) = params
.get()
.ok()
.map(|p| validate_auth(&self.auth_cfg, p.get_auth()))
.transpose()
{
return Promise::err(e);
}
if recipient_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
@@ -308,9 +471,9 @@ impl node_service::Server for NodeServiceImpl {
recipient_key.len()
)));
}
if version != 0 && version != CURRENT_WIRE_VERSION {
if version != CURRENT_WIRE_VERSION {
return Promise::err(capnp::Error::failed(format!(
"unsupported wire version {} (expected 0 or {CURRENT_WIRE_VERSION})",
"unsupported wire version {} (expected {CURRENT_WIRE_VERSION})",
version
)));
}
@@ -355,6 +518,9 @@ impl node_service::Server for NodeServiceImpl {
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let timeout_ms = p.get_timeout_ms();
if let Err(e) = validate_auth(&self.auth_cfg, p.get_auth()) {
return Promise::err(e);
}
if recipient_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
@@ -362,9 +528,9 @@ impl node_service::Server for NodeServiceImpl {
recipient_key.len()
)));
}
if version != 0 && version != CURRENT_WIRE_VERSION {
if version != CURRENT_WIRE_VERSION {
return Promise::err(capnp::Error::failed(format!(
"unsupported wire version {} (expected 0 or {CURRENT_WIRE_VERSION})",
"unsupported wire version {} (expected {CURRENT_WIRE_VERSION})",
version
)));
}
@@ -403,6 +569,103 @@ impl node_service::Server for NodeServiceImpl {
results.get().set_status("ok");
Promise::ok(())
}
/// Store a hybrid (X25519 + ML-KEM-768) public key for an identity.
fn upload_hybrid_key(
&mut self,
params: node_service::UploadHybridKeyParams,
_results: node_service::UploadHybridKeyResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
};
let identity_key = match p.get_identity_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
};
let hybrid_pk = match p.get_hybrid_public_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
};
if identity_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
"identityKey must be exactly 32 bytes, got {}",
identity_key.len()
)));
}
if hybrid_pk.is_empty() {
return Promise::err(capnp::Error::failed(
"hybridPublicKey must not be empty".to_string(),
));
}
if let Err(e) = self
.store
.upload_hybrid_key(&identity_key, hybrid_pk)
.map_err(storage_err)
{
return Promise::err(e);
}
tracing::debug!(
identity = %fmt_hex(&identity_key[..4]),
"hybrid public key uploaded"
);
Promise::ok(())
}
/// Fetch a peer's hybrid public key.
fn fetch_hybrid_key(
&mut self,
params: node_service::FetchHybridKeyParams,
mut results: node_service::FetchHybridKeyResults,
) -> Promise<(), capnp::Error> {
let identity_key = match params.get() {
Ok(p) => match p.get_identity_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
},
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
};
if identity_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
"identityKey must be exactly 32 bytes, got {}",
identity_key.len()
)));
}
let hybrid_pk = match self
.store
.fetch_hybrid_key(&identity_key)
.map_err(storage_err)
{
Ok(p) => p,
Err(e) => return Promise::err(e),
};
match hybrid_pk {
Some(pk) => {
tracing::debug!(
identity = %fmt_hex(&identity_key[..4]),
"hybrid key fetched"
);
results.get().set_hybrid_public_key(&pk);
}
None => {
tracing::debug!(
identity = %fmt_hex(&identity_key[..4]),
"no hybrid key for identity"
);
results.get().set_hybrid_public_key(&[]);
}
}
Promise::ok(())
}
}
fn fill_payloads_wait(results: &mut node_service::FetchWaitResults, messages: Vec<Vec<u8>>) {
@@ -416,6 +679,42 @@ fn storage_err(err: StorageError) -> capnp::Error {
capnp::Error::failed(format!("{err}"))
}
fn validate_auth(
cfg: &AuthConfig,
auth: Result<auth::Reader<'_>, capnp::Error>,
) -> Result<(), capnp::Error> {
let auth = auth?;
let version = auth.get_version();
if version != 1 {
return Err(capnp::Error::failed(format!(
"unsupported auth version {} (expected 1)",
version
)));
}
let token = auth
.get_access_token()
.map_err(|e| capnp::Error::failed(format!("auth.accessToken: {e}")))?
.to_vec();
if token.is_empty() {
return Err(capnp::Error::failed(
"auth.version=1 requires non-empty accessToken".to_string(),
));
}
if let Some(expected) = &cfg.required_token {
if &token != expected {
return Err(capnp::Error::failed("invalid accessToken".to_string()));
}
}
// Early-development stance: no legacy/no-auth path to avoid maintaining divergent behavior.
Ok(())
}
// ── Entry point ───────────────────────────────────────────────────────────────
#[tokio::main]
@@ -428,20 +727,42 @@ async fn main() -> anyhow::Result<()> {
.init();
let args = Args::parse();
let file_cfg = load_config(args.config.as_deref())?;
let effective = merge_config(&args, &file_cfg);
let listen: SocketAddr = args.listen.parse().context("--listen must be host:port")?;
let listen: SocketAddr = effective
.listen
.parse()
.context("--listen must be host:port")?;
let server_config = build_server_config(&args.tls_cert, &args.tls_key)
let server_config = build_server_config(&effective.tls_cert, &effective.tls_key)
.context("failed to build TLS/QUIC server config")?;
// Shared storage — persisted to disk for restart safety.
let store = Arc::new(FileBackedStore::open(&args.data_dir)?);
let store: Arc<dyn Store> = match effective.store_backend.as_str() {
"sql" => {
if let Some(parent) = effective.db_path.parent() {
std::fs::create_dir_all(parent).context("create db dir")?;
}
tracing::info!(
path = %effective.db_path.display(),
encrypted = !effective.db_key.is_empty(),
"opening SQLCipher store"
);
Arc::new(SqlStore::open(&effective.db_path, &effective.db_key)?)
}
"file" | _ => {
tracing::info!(dir = %effective.data_dir, "opening file-backed store");
Arc::new(FileBackedStore::open(&effective.data_dir)?)
}
};
let auth_cfg = Arc::new(AuthConfig::new(effective.auth_token.clone()));
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = Arc::new(DashMap::new());
let endpoint = Endpoint::server(server_config, listen)?;
tracing::info!(
addr = %args.listen,
addr = %effective.listen,
"accepting QUIC connections"
);
@@ -466,8 +787,9 @@ async fn main() -> anyhow::Result<()> {
let store = Arc::clone(&store);
let waiters = Arc::clone(&waiters);
let auth_cfg = Arc::clone(&auth_cfg);
tokio::task::spawn_local(async move {
if let Err(e) = handle_node_connection(connecting, store, waiters).await {
if let Err(e) = handle_node_connection(connecting, store, waiters, auth_cfg).await {
tracing::warn!(error = %e, "connection error");
}
});
@@ -483,8 +805,9 @@ async fn main() -> anyhow::Result<()> {
/// Handle one NodeService connection.
async fn handle_node_connection(
connecting: quinn::Connecting,
store: Arc<FileBackedStore>,
store: Arc<dyn Store>,
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
auth_cfg: Arc<AuthConfig>,
) -> Result<(), anyhow::Error> {
let connection = connecting.await?;
@@ -498,7 +821,11 @@ async fn handle_node_connection(
let network = twoparty::VatNetwork::new(reader, writer, Side::Server, Default::default());
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl { store, waiters });
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl {
store,
waiters,
auth_cfg,
});
RpcSystem::new(Box::new(network), Some(service.client))
.await

View File

@@ -0,0 +1,315 @@
//! SQLCipher-backed persistent storage.
//!
//! Uses `rusqlite` with `bundled-sqlcipher` for encrypted-at-rest storage.
//! Implements the same [`Store`] trait as [`FileBackedStore`] but with proper
//! ACID transactions and indexed queries.
use std::path::Path;
use std::sync::Mutex;
use rusqlite::{params, Connection};
use crate::storage::{StorageError, Store};
/// SQLCipher-encrypted storage backend.
///
/// All data is stored in a single encrypted SQLite database. The encryption
/// key is set via `PRAGMA key` at open time.
pub struct SqlStore {
conn: Mutex<Connection>,
}
impl SqlStore {
/// Open (or create) an encrypted database at `path`.
///
/// `key` is the passphrase used by SQLCipher. Pass an empty string for an
/// unencrypted database (useful for testing).
pub fn open(path: impl AsRef<Path>, key: &str) -> Result<Self, StorageError> {
let conn = Connection::open(path).map_err(|e| StorageError::Db(e.to_string()))?;
if !key.is_empty() {
conn.pragma_update(None, "key", key)
.map_err(|e| StorageError::Db(format!("PRAGMA key failed: {e}")))?;
}
// Performance pragmas — safe for a single-writer server.
conn.execute_batch(
"PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA foreign_keys = ON;",
)
.map_err(|e| StorageError::Db(e.to_string()))?;
let store = Self {
conn: Mutex::new(conn),
};
store.migrate()?;
Ok(store)
}
/// Create schema tables if they don't exist yet.
fn migrate(&self) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS key_packages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
identity_key BLOB NOT NULL,
package_data BLOB NOT NULL,
created_at INTEGER DEFAULT (strftime('%s','now'))
);
CREATE TABLE IF NOT EXISTS deliveries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
recipient_key BLOB NOT NULL,
channel_id BLOB NOT NULL DEFAULT X'',
payload BLOB NOT NULL,
created_at INTEGER DEFAULT (strftime('%s','now'))
);
CREATE TABLE IF NOT EXISTS hybrid_keys (
identity_key BLOB PRIMARY KEY,
hybrid_public_key BLOB NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_kp_identity
ON key_packages(identity_key);
CREATE INDEX IF NOT EXISTS idx_del_recipient_channel
ON deliveries(recipient_key, channel_id);",
)
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(())
}
}
impl Store for SqlStore {
fn upload_key_package(
&self,
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute(
"INSERT INTO key_packages (identity_key, package_data) VALUES (?1, ?2)",
params![identity_key, package],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(())
}
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
// Find the oldest KeyPackage (FIFO) and delete it atomically.
let mut stmt = conn
.prepare(
"SELECT id, package_data FROM key_packages
WHERE identity_key = ?1
ORDER BY id ASC
LIMIT 1",
)
.map_err(|e| StorageError::Db(e.to_string()))?;
let row = stmt
.query_row(params![identity_key], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, Vec<u8>>(1)?))
})
.optional()
.map_err(|e| StorageError::Db(e.to_string()))?;
match row {
Some((id, package)) => {
conn.execute("DELETE FROM key_packages WHERE id = ?1", params![id])
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(Some(package))
}
None => Ok(None),
}
}
fn enqueue(
&self,
recipient_key: &[u8],
channel_id: &[u8],
payload: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute(
"INSERT INTO deliveries (recipient_key, channel_id, payload) VALUES (?1, ?2, ?3)",
params![recipient_key, channel_id, payload],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(())
}
fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let mut stmt = conn
.prepare(
"SELECT id, payload FROM deliveries
WHERE recipient_key = ?1 AND channel_id = ?2
ORDER BY id ASC",
)
.map_err(|e| StorageError::Db(e.to_string()))?;
let rows: Vec<(i64, Vec<u8>)> = stmt
.query_map(params![recipient_key, channel_id], |row| {
Ok((row.get(0)?, row.get(1)?))
})
.map_err(|e| StorageError::Db(e.to_string()))?
.collect::<Result<Vec<_>, _>>()
.map_err(|e| StorageError::Db(e.to_string()))?;
if !rows.is_empty() {
let ids: Vec<i64> = rows.iter().map(|(id, _)| *id).collect();
// Delete fetched rows in a single statement.
let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})");
let params: Vec<&dyn rusqlite::types::ToSql> =
ids.iter().map(|id| id as &dyn rusqlite::types::ToSql).collect();
conn.execute(&sql, params.as_slice())
.map_err(|e| StorageError::Db(e.to_string()))?;
}
Ok(rows.into_iter().map(|(_, payload)| payload).collect())
}
fn upload_hybrid_key(
&self,
identity_key: &[u8],
hybrid_pk: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute(
"INSERT OR REPLACE INTO hybrid_keys (identity_key, hybrid_public_key) VALUES (?1, ?2)",
params![identity_key, hybrid_pk],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(())
}
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let mut stmt = conn
.prepare("SELECT hybrid_public_key FROM hybrid_keys WHERE identity_key = ?1")
.map_err(|e| StorageError::Db(e.to_string()))?;
stmt.query_row(params![identity_key], |row| row.get(0))
.optional()
.map_err(|e| StorageError::Db(e.to_string()))
}
}
/// Convenience extension for `rusqlite::OptionalExtension`.
trait OptionalExt<T> {
fn optional(self) -> Result<Option<T>, rusqlite::Error>;
}
impl<T> OptionalExt<T> for Result<T, rusqlite::Error> {
fn optional(self) -> Result<Option<T>, rusqlite::Error> {
match self {
Ok(v) => Ok(Some(v)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn open_in_memory() -> SqlStore {
SqlStore::open(":memory:", "").unwrap()
}
#[test]
fn key_package_fifo() {
let store = open_in_memory();
let ik = b"alice_identity_key__32bytes_long";
// Pad to 32 bytes to match real usage
let mut identity = [0u8; 32];
identity[..ik.len()].copy_from_slice(ik);
store
.upload_key_package(&identity, b"kp1".to_vec())
.unwrap();
store
.upload_key_package(&identity, b"kp2".to_vec())
.unwrap();
assert_eq!(
store.fetch_key_package(&identity).unwrap(),
Some(b"kp1".to_vec())
);
assert_eq!(
store.fetch_key_package(&identity).unwrap(),
Some(b"kp2".to_vec())
);
assert_eq!(store.fetch_key_package(&identity).unwrap(), None);
}
#[test]
fn delivery_round_trip() {
let store = open_in_memory();
let rk = [1u8; 32];
let ch = b"channel-1";
store.enqueue(&rk, ch, b"msg1".to_vec()).unwrap();
store.enqueue(&rk, ch, b"msg2".to_vec()).unwrap();
let msgs = store.fetch(&rk, ch).unwrap();
assert_eq!(msgs, vec![b"msg1".to_vec(), b"msg2".to_vec()]);
// Queue is drained.
assert!(store.fetch(&rk, ch).unwrap().is_empty());
}
#[test]
fn hybrid_key_round_trip() {
let store = open_in_memory();
let ik = [2u8; 32];
let pk = b"hybrid_public_key_data".to_vec();
store.upload_hybrid_key(&ik, pk.clone()).unwrap();
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(pk));
}
#[test]
fn hybrid_key_upsert() {
let store = open_in_memory();
let ik = [3u8; 32];
store
.upload_hybrid_key(&ik, b"v1".to_vec())
.unwrap();
store
.upload_hybrid_key(&ik, b"v2".to_vec())
.unwrap();
assert_eq!(
store.fetch_hybrid_key(&ik).unwrap(),
Some(b"v2".to_vec())
);
}
#[test]
fn separate_channels_isolated() {
let store = open_in_memory();
let rk = [4u8; 32];
store.enqueue(&rk, b"ch-a", b"a1".to_vec()).unwrap();
store.enqueue(&rk, b"ch-b", b"b1".to_vec()).unwrap();
let a_msgs = store.fetch(&rk, b"ch-a").unwrap();
assert_eq!(a_msgs, vec![b"a1".to_vec()]);
let b_msgs = store.fetch(&rk, b"ch-b").unwrap();
assert_eq!(b_msgs, vec![b"b1".to_vec()]);
}
}

View File

@@ -1,7 +1,7 @@
use std::{
collections::{HashMap, VecDeque},
fs,
hash::{Hash, Hasher},
hash::Hash,
path::{Path, PathBuf},
sync::Mutex,
};
@@ -14,13 +14,46 @@ pub enum StorageError {
Io(String),
#[error("serialization error")]
Serde,
#[error("database error: {0}")]
Db(String),
}
#[derive(Serialize, Deserialize, Default)]
struct QueueMapV1 {
map: HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
// ── Store trait ──────────────────────────────────────────────────────────────
/// Abstraction over storage backends (file-backed, SQLCipher, etc.).
pub trait Store: Send + Sync {
fn upload_key_package(
&self,
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError>;
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
fn enqueue(
&self,
recipient_key: &[u8],
channel_id: &[u8],
payload: Vec<u8>,
) -> Result<(), StorageError>;
fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError>;
fn upload_hybrid_key(
&self,
identity_key: &[u8],
hybrid_pk: Vec<u8>,
) -> Result<(), StorageError>;
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
}
// ── ChannelKey ───────────────────────────────────────────────────────────────
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
pub struct ChannelKey {
pub channel_id: Vec<u8>,
@@ -28,12 +61,19 @@ pub struct ChannelKey {
}
impl Hash for ChannelKey {
fn hash<H: Hasher>(&self, state: &mut H) {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.channel_id.hash(state);
self.recipient_key.hash(state);
}
}
// ── FileBackedStore ──────────────────────────────────────────────────────────
#[derive(Serialize, Deserialize, Default)]
struct QueueMapV1 {
map: HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
}
#[derive(Serialize, Deserialize, Default)]
struct QueueMapV2 {
map: HashMap<ChannelKey, VecDeque<Vec<u8>>>,
@@ -45,8 +85,10 @@ struct QueueMapV2 {
pub struct FileBackedStore {
kp_path: PathBuf,
ds_path: PathBuf,
hk_path: PathBuf,
key_packages: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
deliveries: Mutex<HashMap<ChannelKey, VecDeque<Vec<u8>>>>,
hybrid_keys: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
}
impl FileBackedStore {
@@ -57,73 +99,23 @@ impl FileBackedStore {
}
let kp_path = dir.join("keypackages.bin");
let ds_path = dir.join("deliveries.bin");
let hk_path = dir.join("hybridkeys.bin");
let key_packages = Mutex::new(Self::load_map(&kp_path)?);
let deliveries = Mutex::new(Self::load_map(&ds_path)?);
let key_packages = Mutex::new(Self::load_kp_map(&kp_path)?);
let deliveries = Mutex::new(Self::load_delivery_map(&ds_path)?);
let hybrid_keys = Mutex::new(Self::load_hybrid_keys(&hk_path)?);
Ok(Self {
kp_path,
ds_path,
hk_path,
key_packages,
deliveries,
hybrid_keys,
})
}
pub fn upload_key_package(
&self,
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.key_packages.lock().unwrap();
map.entry(identity_key.to_vec())
.or_default()
.push_back(package);
self.flush_map(&self.kp_path, &*map)
}
pub fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let mut map = self.key_packages.lock().unwrap();
let package = map.get_mut(identity_key).and_then(|q| q.pop_front());
self.flush_map(&self.kp_path, &*map)?;
Ok(package)
}
pub fn enqueue(
&self,
recipient_key: &[u8],
channel_id: &[u8],
payload: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.deliveries.lock().unwrap();
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
map.entry(key)
.or_default()
.push_back(payload);
self.flush_map(&self.ds_path, &*map)
}
pub fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError> {
let mut map = self.deliveries.lock().unwrap();
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
let messages = map
.get_mut(&key)
.map(|q| q.drain(..).collect())
.unwrap_or_default();
self.flush_map(&self.ds_path, &*map)?;
Ok(messages)
}
fn load_map(path: &Path) -> Result<HashMap<ChannelKey, VecDeque<Vec<u8>>>, StorageError> {
fn load_kp_map(path: &Path) -> Result<HashMap<Vec<u8>, VecDeque<Vec<u8>>>, StorageError> {
if !path.exists() {
return Ok(HashMap::new());
}
@@ -131,7 +123,32 @@ impl FileBackedStore {
if bytes.is_empty() {
return Ok(HashMap::new());
}
// Try v2 format (channel-aware). Fallback to legacy v1.
let map: QueueMapV1 = bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)?;
Ok(map.map)
}
fn flush_kp_map(
&self,
path: &Path,
map: &HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
) -> Result<(), StorageError> {
let payload = QueueMapV1 { map: map.clone() };
let bytes = bincode::serialize(&payload).map_err(|_| StorageError::Serde)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
}
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
}
fn load_delivery_map(path: &Path) -> Result<HashMap<ChannelKey, VecDeque<Vec<u8>>>, StorageError> {
if !path.exists() {
return Ok(HashMap::new());
}
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
if bytes.is_empty() {
return Ok(HashMap::new());
}
// Try v2 format (channel-aware). Fallback to legacy v1 for upgrade.
if let Ok(map) = bincode::deserialize::<QueueMapV2>(&bytes) {
return Ok(map.map);
}
@@ -149,7 +166,7 @@ impl FileBackedStore {
Ok(upgraded)
}
fn flush_map(
fn flush_delivery_map(
&self,
path: &Path,
map: &HashMap<ChannelKey, VecDeque<Vec<u8>>>,
@@ -161,4 +178,98 @@ impl FileBackedStore {
}
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
}
fn load_hybrid_keys(path: &Path) -> Result<HashMap<Vec<u8>, Vec<u8>>, StorageError> {
if !path.exists() {
return Ok(HashMap::new());
}
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
if bytes.is_empty() {
return Ok(HashMap::new());
}
bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
}
fn flush_hybrid_keys(
&self,
path: &Path,
map: &HashMap<Vec<u8>, Vec<u8>>,
) -> Result<(), StorageError> {
let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
}
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
}
}
impl Store for FileBackedStore {
fn upload_key_package(
&self,
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.key_packages.lock().unwrap();
map.entry(identity_key.to_vec())
.or_default()
.push_back(package);
self.flush_kp_map(&self.kp_path, &*map)
}
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let mut map = self.key_packages.lock().unwrap();
let package = map.get_mut(identity_key).and_then(|q| q.pop_front());
self.flush_kp_map(&self.kp_path, &*map)?;
Ok(package)
}
fn enqueue(
&self,
recipient_key: &[u8],
channel_id: &[u8],
payload: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.deliveries.lock().unwrap();
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
map.entry(key)
.or_default()
.push_back(payload);
self.flush_delivery_map(&self.ds_path, &*map)
}
fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError> {
let mut map = self.deliveries.lock().unwrap();
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
let messages = map
.get_mut(&key)
.map(|q| q.drain(..).collect())
.unwrap_or_default();
self.flush_delivery_map(&self.ds_path, &*map)?;
Ok(messages)
}
fn upload_hybrid_key(
&self,
identity_key: &[u8],
hybrid_pk: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.hybrid_keys.lock().unwrap();
map.insert(identity_key.to_vec(), hybrid_pk);
self.flush_hybrid_keys(&self.hk_path, &*map)
}
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let map = self.hybrid_keys.lock().unwrap();
Ok(map.get(identity_key).cloned())
}
}