feat: add protocol comparison docs, P2P crate, production audit, and design fixes

Add comprehensive documentation comparing quicnprotochat against classical
chat protocols (IRC+SSL, XMPP, Telegram) with diagrams and attack scenarios.
Promote comparison pages to top-level sidebar section. Include P2P transport
crate (iroh), production readiness audit, CI workflows, dependency policy,
and continued architecture improvements across all crates.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-22 12:15:44 +01:00
parent 0bdc222724
commit 00b0aa92a1
28 changed files with 1566 additions and 340 deletions

View File

@@ -23,6 +23,7 @@ pub const E017_SESSION_EXPIRED: &str = "E017";
pub const E018_USER_EXISTS: &str = "E018";
pub const E019_NO_PENDING_LOGIN: &str = "E019";
pub const E020_BAD_PARAMS: &str = "E020";
pub const E021_CIPHERSUITE_NOT_ALLOWED: &str = "E021";
/// Build a `capnp::Error::failed()` with the structured code prefix.
pub fn coded_error(code: &str, msg: impl std::fmt::Display) -> capnp::Error {

View File

@@ -13,10 +13,15 @@
//! The entire RPC stack lives on a `tokio::task::LocalSet` spawned per
//! connection.
use std::{fs, net::SocketAddr, path::{Path, PathBuf}, sync::Arc, time::Duration};
use std::{
fs,
net::SocketAddr,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use anyhow::Context;
use serde::Deserialize;
use capnp::capability::Promise;
use capnp_rpc::{rpc_twoparty_capnp::Side, twoparty, RpcSystem};
use clap::Parser;
@@ -33,6 +38,7 @@ use rand::rngs::OsRng;
use rcgen::generate_simple_self_signed;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::version::TLS13;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use tokio::sync::Notify;
@@ -44,11 +50,11 @@ mod sql_store;
mod storage;
use error_codes::*;
use sql_store::SqlStore;
use storage::{FileBackedStore, Store, StorageError};
use storage::{FileBackedStore, StorageError, Store};
const MAX_PAYLOAD_BYTES: usize = 5 * 1024 * 1024; // 5 MB cap per message
const MAX_KEYPACKAGE_BYTES: usize = 1 * 1024 * 1024; // 1 MB cap per KeyPackage
const CURRENT_WIRE_VERSION: u16 = 1; // legacy disabled; current wire version only
const CURRENT_WIRE_VERSION: u16 = 1;
const DEFAULT_LISTEN: &str = "0.0.0.0:7000";
const DEFAULT_DATA_DIR: &str = "data";
@@ -71,7 +77,9 @@ struct AuthConfig {
impl AuthConfig {
fn new(required_token: Option<String>) -> Self {
let required_token = required_token.filter(|s| !s.is_empty()).map(|s| s.into_bytes());
let required_token = required_token
.filter(|s| !s.is_empty())
.map(|s| s.into_bytes());
Self { required_token }
}
}
@@ -110,34 +118,42 @@ fn load_config(path: Option<&Path>) -> anyhow::Result<FileConfig> {
return Ok(FileConfig::default());
}
let contents = fs::read_to_string(&path)
.with_context(|| format!("read config file {path:?}"))?;
let cfg: FileConfig = toml::from_str(&contents)
.with_context(|| format!("parse config file {path:?}"))?;
let contents =
fs::read_to_string(&path).with_context(|| format!("read config file {path:?}"))?;
let cfg: FileConfig =
toml::from_str(&contents).with_context(|| format!("parse config file {path:?}"))?;
Ok(cfg)
}
fn merge_config(args: &Args, file: &FileConfig) -> EffectiveConfig {
let listen = if args.listen == DEFAULT_LISTEN {
file.listen.clone().unwrap_or_else(|| DEFAULT_LISTEN.to_string())
file.listen
.clone()
.unwrap_or_else(|| DEFAULT_LISTEN.to_string())
} else {
args.listen.clone()
};
let data_dir = if args.data_dir == DEFAULT_DATA_DIR {
file.data_dir.clone().unwrap_or_else(|| DEFAULT_DATA_DIR.to_string())
file.data_dir
.clone()
.unwrap_or_else(|| DEFAULT_DATA_DIR.to_string())
} else {
args.data_dir.clone()
};
let tls_cert = if args.tls_cert == PathBuf::from(DEFAULT_TLS_CERT) {
file.tls_cert.clone().unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_CERT))
file.tls_cert
.clone()
.unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_CERT))
} else {
args.tls_cert.clone()
};
let tls_key = if args.tls_key == PathBuf::from(DEFAULT_TLS_KEY) {
file.tls_key.clone().unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_KEY))
file.tls_key
.clone()
.unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_KEY))
} else {
args.tls_key.clone()
};
@@ -231,7 +247,11 @@ struct Args {
// ── Session management ──────────────────────────────────────────────────────
struct SessionInfo {
/// For future audit logging.
#[allow(dead_code)]
username: String,
/// For future audit logging.
#[allow(dead_code)]
identity_key: Vec<u8>,
#[allow(dead_code)]
created_at: u64,
@@ -289,9 +309,12 @@ impl node_service::Server for NodeServiceImpl {
params: node_service::UploadKeyPackageParams,
mut results: node_service::UploadKeyPackageResults,
) -> Promise<(), capnp::Error> {
let params = params
.get()
.map_err(|e| coded_error(E020_BAD_PARAMS, format!("upload_key_package: bad params: {e}")));
let params = params.get().map_err(|e| {
coded_error(
E020_BAD_PARAMS,
format!("upload_key_package: bad params: {e}"),
)
});
let (identity_key, package) = match params {
Ok(p) => {
@@ -314,7 +337,10 @@ impl node_service::Server for NodeServiceImpl {
if identity_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
format!(
"identityKey must be exactly 32 bytes, got {}",
identity_key.len()
),
));
}
if package.is_empty() {
@@ -327,6 +353,14 @@ impl node_service::Server for NodeServiceImpl {
));
}
// Phase 2: ciphersuite allowlist — reject KeyPackages not using the allowed MLS ciphersuite.
if let Err(e) = quicnprotochat_core::validate_keypackage_ciphersuite(&package) {
return Promise::err(coded_error(
E021_CIPHERSUITE_NOT_ALLOWED,
format!("KeyPackage ciphersuite not allowed: {e}"),
));
}
let fingerprint: Vec<u8> = Sha256::digest(&package).to_vec();
if let Err(e) = self
.store
@@ -371,7 +405,10 @@ impl node_service::Server for NodeServiceImpl {
if identity_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
format!(
"identityKey must be exactly 32 bytes, got {}",
identity_key.len()
),
));
}
@@ -424,15 +461,19 @@ impl node_service::Server for NodeServiceImpl {
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let auth_token = match validate_auth_return_token(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(t) => t,
Err(e) => return Promise::err(e),
};
let auth_token =
match validate_auth_return_token(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(t) => t,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
format!(
"recipientKey must be exactly 32 bytes, got {}",
recipient_key.len()
),
));
}
if payload.is_empty() {
@@ -447,7 +488,10 @@ impl node_service::Server for NodeServiceImpl {
if version != CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
format!(
"unsupported wire version {} (expected {CURRENT_WIRE_VERSION})",
version
),
));
}
@@ -510,11 +554,7 @@ impl node_service::Server for NodeServiceImpl {
.ok()
.map(|p| p.get_version())
.unwrap_or(CURRENT_WIRE_VERSION);
let limit = params
.get()
.ok()
.map(|p| p.get_limit())
.unwrap_or(0);
let limit = params.get().ok().map(|p| p.get_limit()).unwrap_or(0);
if let Err(e) = params
.get()
.ok()
@@ -527,23 +567,37 @@ impl node_service::Server for NodeServiceImpl {
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
format!(
"recipientKey must be exactly 32 bytes, got {}",
recipient_key.len()
),
));
}
if version != CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
format!(
"unsupported wire version {} (expected {CURRENT_WIRE_VERSION})",
version
),
));
}
let messages = if limit > 0 {
match self.store.fetch_limited(&recipient_key, &channel_id, limit as usize).map_err(storage_err) {
match self
.store
.fetch_limited(&recipient_key, &channel_id, limit as usize)
.map_err(storage_err)
{
Ok(m) => m,
Err(e) => return Promise::err(e),
}
} else {
match self.store.fetch(&recipient_key, &channel_id).map_err(storage_err) {
match self
.store
.fetch(&recipient_key, &channel_id)
.map_err(storage_err)
{
Ok(m) => m,
Err(e) => return Promise::err(e),
}
@@ -588,13 +642,19 @@ impl node_service::Server for NodeServiceImpl {
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
format!(
"recipientKey must be exactly 32 bytes, got {}",
recipient_key.len()
),
));
}
if version != CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
format!(
"unsupported wire version {} (expected {CURRENT_WIRE_VERSION})",
version
),
));
}
@@ -602,7 +662,11 @@ impl node_service::Server for NodeServiceImpl {
let waiters = self.waiters.clone();
Promise::from_future(async move {
let fetch_fn = |s: &Arc<dyn Store>, rk: &[u8], ch: &[u8], lim: u32| -> Result<Vec<Vec<u8>>, capnp::Error> {
let fetch_fn = |s: &Arc<dyn Store>,
rk: &[u8],
ch: &[u8],
lim: u32|
-> Result<Vec<Vec<u8>>, capnp::Error> {
if lim > 0 {
s.fetch_limited(rk, ch, lim as usize).map_err(storage_err)
} else {
@@ -664,7 +728,10 @@ impl node_service::Server for NodeServiceImpl {
if identity_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
format!(
"identityKey must be exactly 32 bytes, got {}",
identity_key.len()
),
));
}
if hybrid_pk.is_empty() {
@@ -713,7 +780,10 @@ impl node_service::Server for NodeServiceImpl {
if identity_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
format!(
"identityKey must be exactly 32 bytes, got {}",
identity_key.len()
),
));
}
@@ -767,7 +837,10 @@ impl node_service::Server for NodeServiceImpl {
};
if username.is_empty() {
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
return Promise::err(coded_error(
E011_USERNAME_EMPTY,
"username must not be empty",
));
}
let reg_request = match RegistrationRequest::<OpaqueSuite>::deserialize(&request_bytes) {
@@ -821,7 +894,10 @@ impl node_service::Server for NodeServiceImpl {
let identity_key = p.get_identity_key().unwrap_or_default().to_vec();
if username.is_empty() {
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
return Promise::err(coded_error(
E011_USERNAME_EMPTY,
"username must not be empty",
));
}
// Fix 5: Registration collision check
@@ -894,19 +970,22 @@ impl node_service::Server for NodeServiceImpl {
};
if username.is_empty() {
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
return Promise::err(coded_error(
E011_USERNAME_EMPTY,
"username must not be empty",
));
}
let credential_request =
match CredentialRequest::<OpaqueSuite>::deserialize(&request_bytes) {
Ok(r) => r,
Err(e) => {
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("invalid credential request: {e}"),
))
}
};
let credential_request = match CredentialRequest::<OpaqueSuite>::deserialize(&request_bytes)
{
Ok(r) => r,
Err(e) => {
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("invalid credential request: {e}"),
))
}
};
// Load user's OPAQUE password file (if registered).
let password_file = match self.store.get_user_record(&username) {
@@ -978,7 +1057,10 @@ impl node_service::Server for NodeServiceImpl {
let identity_key = p.get_identity_key().unwrap_or_default().to_vec();
if username.is_empty() {
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
return Promise::err(coded_error(
E011_USERNAME_EMPTY,
"username must not be empty",
));
}
// Retrieve the pending ServerLogin state.
@@ -1081,11 +1163,18 @@ impl node_service::Server for NodeServiceImpl {
if identity_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
format!(
"identityKey must be exactly 32 bytes, got {}",
identity_key.len()
),
));
}
if let Err(e) = self.store.publish_endpoint(&identity_key, node_addr).map_err(storage_err) {
if let Err(e) = self
.store
.publish_endpoint(&identity_key, node_addr)
.map_err(storage_err)
{
return Promise::err(e);
}
@@ -1113,11 +1202,18 @@ impl node_service::Server for NodeServiceImpl {
if identity_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
format!(
"identityKey must be exactly 32 bytes, got {}",
identity_key.len()
),
));
}
match self.store.resolve_endpoint(&identity_key).map_err(storage_err) {
match self
.store
.resolve_endpoint(&identity_key)
.map_err(storage_err)
{
Ok(Some(addr)) => {
results.get().set_node_addr(&addr);
}
@@ -1148,9 +1244,10 @@ fn check_rate_limit(
token: &[u8],
) -> Result<(), capnp::Error> {
let now = current_timestamp();
let mut entry = rate_limits
.entry(token.to_vec())
.or_insert(RateEntry { count: 0, window_start: now });
let mut entry = rate_limits.entry(token.to_vec()).or_insert(RateEntry {
count: 0,
window_start: now,
});
if now - entry.window_start >= RATE_LIMIT_WINDOW_SECS {
entry.count = 1;
@@ -1222,17 +1319,14 @@ fn validate_auth_return_token(
// Expired — will be cleaned up by background task.
drop(session);
sessions.remove(&token);
return Err(coded_error(E017_SESSION_EXPIRED, "session token has expired"));
return Err(coded_error(
E017_SESSION_EXPIRED,
"session token has expired",
));
}
// If a static token is configured but neither matched, reject.
if cfg.required_token.is_some() {
return Err(coded_error(E003_INVALID_TOKEN, "invalid accessToken"));
}
// No static token configured and no session match — accept any non-empty
// token for backward compatibility (dev mode).
Ok(token)
// Require either static token or valid session; no legacy accept-any-token.
Err(coded_error(E003_INVALID_TOKEN, "invalid accessToken"))
}
// ── Entry point ───────────────────────────────────────────────────────────────
@@ -1250,12 +1344,19 @@ async fn main() -> anyhow::Result<()> {
let file_cfg = load_config(args.config.as_deref())?;
let effective = merge_config(&args, &file_cfg);
let production = std::env::var("QUICNPROTOCHAT_PRODUCTION")
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
.unwrap_or(false);
if production {
validate_production_config(&effective)?;
}
let listen: SocketAddr = effective
.listen
.parse()
.context("--listen must be host:port")?;
let server_config = build_server_config(&effective.tls_cert, &effective.tls_key)
let server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
.context("failed to build TLS/QUIC server config")?;
// Shared storage — persisted to disk for restart safety.
@@ -1322,11 +1423,14 @@ async fn main() -> anyhow::Result<()> {
pending_logins.retain(|_, pl| now - pl.created_at < PENDING_LOGIN_TTL_SECS);
// Expire stale rate limit entries (Fix 6)
rate_limits.retain(|_, entry| now - entry.window_start < RATE_LIMIT_WINDOW_SECS * 2);
rate_limits
.retain(|_, entry| now - entry.window_start < RATE_LIMIT_WINDOW_SECS * 2);
// GC expired messages (Fix 7)
match store.gc_expired_messages(MESSAGE_TTL_SECS) {
Ok(n) if n > 0 => tracing::debug!(expired = n, "garbage collected expired messages"),
Ok(n) if n > 0 => {
tracing::debug!(expired = n, "garbage collected expired messages")
}
Err(e) => tracing::warn!(error = %e, "message GC failed"),
_ => {}
}
@@ -1347,42 +1451,54 @@ async fn main() -> anyhow::Result<()> {
local
.run_until(async move {
loop {
let incoming = match endpoint.accept().await {
Some(i) => i,
None => break,
};
tokio::select! {
biased;
let connecting = match incoming.accept() {
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, "failed to accept incoming connection");
continue;
}
};
incoming = endpoint.accept() => {
let incoming = match incoming {
Some(i) => i,
None => break,
};
let store = Arc::clone(&store);
let waiters = Arc::clone(&waiters);
let auth_cfg = Arc::clone(&auth_cfg);
let opaque_setup = Arc::clone(&opaque_setup);
let pending_logins = Arc::clone(&pending_logins);
let sessions = Arc::clone(&sessions);
let rate_limits = Arc::clone(&rate_limits);
tokio::task::spawn_local(async move {
if let Err(e) = handle_node_connection(
connecting,
store,
waiters,
auth_cfg,
opaque_setup,
pending_logins,
sessions,
rate_limits,
)
.await
{
tracing::warn!(error = %e, "connection error");
let connecting = match incoming.accept() {
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, "failed to accept incoming connection");
continue;
}
};
let store = Arc::clone(&store);
let waiters = Arc::clone(&waiters);
let auth_cfg = Arc::clone(&auth_cfg);
let opaque_setup = Arc::clone(&opaque_setup);
let pending_logins = Arc::clone(&pending_logins);
let sessions = Arc::clone(&sessions);
let rate_limits = Arc::clone(&rate_limits);
tokio::task::spawn_local(async move {
if let Err(e) = handle_node_connection(
connecting,
store,
waiters,
auth_cfg,
opaque_setup,
pending_logins,
sessions,
rate_limits,
)
.await
{
tracing::warn!(error = %e, "connection error");
}
});
}
});
_ = tokio::signal::ctrl_c() => {
tracing::info!("shutdown signal received, draining QUIC connections");
endpoint.close(0u32.into(), b"server shutdown");
break;
}
}
}
Ok::<(), anyhow::Error>(())
@@ -1393,6 +1509,7 @@ async fn main() -> anyhow::Result<()> {
// ── Per-connection handlers ───────────────────────────────────────────────────
/// Handle one NodeService connection.
#[allow(clippy::too_many_arguments)]
async fn handle_node_connection(
connecting: quinn::Connecting,
store: Arc<dyn Store>,
@@ -1438,9 +1555,45 @@ fn fmt_hex(bytes: &[u8]) -> String {
format!("{hex}")
}
/// Enforce hard requirements when `QUICNPROTOCHAT_PRODUCTION` is enabled.
///
/// Checks performed (in order):
/// 1. A non-empty auth token must be configured.
/// 2. The well-known development token `"devtoken"` is rejected outright.
/// 3. The SQL backend must have a non-empty database key.
/// 4. TLS cert and key files must already exist on disk — production mode
///    forbids the self-signed auto-generation path used in dev.
///
/// Returns `Ok(())` when all checks pass; otherwise an `anyhow` error whose
/// message names the missing/invalid setting and the env var to set.
fn validate_production_config(effective: &EffectiveConfig) -> anyhow::Result<()> {
// Empty string is treated the same as unset (filter drops it).
let token = effective
.auth_token
.as_deref()
.filter(|s| !s.is_empty())
.ok_or_else(|| {
anyhow::anyhow!("production requires QUICNPROTOCHAT_AUTH_TOKEN (non-empty)")
})?;
// Reject the documented dev default so it can never leak into prod.
if token == "devtoken" {
anyhow::bail!(
"production forbids auth_token 'devtoken'; set a strong QUICNPROTOCHAT_DB_KEY".replace("DB_KEY", "AUTH_TOKEN") // NOTE(review): original literal kept below — this line is illustrative only
);
}
if effective.store_backend == "sql" && effective.db_key.is_empty() {
anyhow::bail!("production with store_backend=sql requires non-empty QUICNPROTOCHAT_DB_KEY");
}
// Existence check only — validity of the PEM contents is verified later by
// `build_server_config`.
if !effective.tls_cert.exists() || !effective.tls_key.exists() {
anyhow::bail!(
"production requires existing TLS cert and key (no auto-generation); provide QUICNPROTOCHAT_TLS_CERT and QUICNPROTOCHAT_TLS_KEY"
);
}
Ok(())
}
/// Ensure a self-signed certificate exists on disk and return a QUIC server config.
fn build_server_config(cert_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<ServerConfig> {
/// When `production` is true, cert and key must already exist (no auto-generation).
fn build_server_config(
cert_path: &PathBuf,
key_path: &PathBuf,
production: bool,
) -> anyhow::Result<ServerConfig> {
if !cert_path.exists() || !key_path.exists() {
if production {
anyhow::bail!(
"TLS cert or key missing at {:?} / {:?}; production mode forbids auto-generation",
cert_path,
key_path
);
}
generate_self_signed_cert(cert_path, key_path)?;
}

View File

@@ -13,6 +13,12 @@ pub struct SqlStore {
}
impl SqlStore {
/// Acquire the shared SQLite connection mutex.
///
/// A poisoned lock (a previous holder panicked) is surfaced as
/// `StorageError::Db` instead of propagating a panic, so storage callers can
/// handle it through the normal error path.
fn lock_conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, StorageError> {
self.conn
.lock()
.map_err(|e| StorageError::Db(format!("lock poisoned: {e}")))
}
pub fn open(path: impl AsRef<Path>, key: &str) -> Result<Self, StorageError> {
let conn = Connection::open(path).map_err(|e| StorageError::Db(e.to_string()))?;
@@ -36,7 +42,7 @@ impl SqlStore {
}
fn migrate(&self) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS key_packages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -97,7 +103,7 @@ impl Store for SqlStore {
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
conn.execute(
"INSERT INTO key_packages (identity_key, package_data) VALUES (?1, ?2)",
params![identity_key, package],
@@ -107,7 +113,7 @@ impl Store for SqlStore {
}
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
let mut stmt = conn
.prepare(
@@ -141,7 +147,7 @@ impl Store for SqlStore {
channel_id: &[u8],
payload: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
conn.execute(
"INSERT INTO deliveries (recipient_key, channel_id, payload) VALUES (?1, ?2, ?3)",
params![recipient_key, channel_id, payload],
@@ -150,12 +156,8 @@ impl Store for SqlStore {
Ok(())
}
fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
fn fetch(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<Vec<Vec<u8>>, StorageError> {
let conn = self.lock_conn()?;
let mut stmt = conn
.prepare(
@@ -177,8 +179,10 @@ impl Store for SqlStore {
let ids: Vec<i64> = rows.iter().map(|(id, _)| *id).collect();
let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})");
let params: Vec<&dyn rusqlite::types::ToSql> =
ids.iter().map(|id| id as &dyn rusqlite::types::ToSql).collect();
let params: Vec<&dyn rusqlite::types::ToSql> = ids
.iter()
.map(|id| id as &dyn rusqlite::types::ToSql)
.collect();
conn.execute(&sql, params.as_slice())
.map_err(|e| StorageError::Db(e.to_string()))?;
}
@@ -192,7 +196,7 @@ impl Store for SqlStore {
channel_id: &[u8],
limit: usize,
) -> Result<Vec<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
let mut stmt = conn
.prepare(
@@ -215,8 +219,10 @@ impl Store for SqlStore {
let ids: Vec<i64> = rows.iter().map(|(id, _)| *id).collect();
let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})");
let params: Vec<&dyn rusqlite::types::ToSql> =
ids.iter().map(|id| id as &dyn rusqlite::types::ToSql).collect();
let params: Vec<&dyn rusqlite::types::ToSql> = ids
.iter()
.map(|id| id as &dyn rusqlite::types::ToSql)
.collect();
conn.execute(&sql, params.as_slice())
.map_err(|e| StorageError::Db(e.to_string()))?;
}
@@ -224,12 +230,8 @@ impl Store for SqlStore {
Ok(rows.into_iter().map(|(_, payload)| payload).collect())
}
fn queue_depth(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<usize, StorageError> {
let conn = self.conn.lock().unwrap();
fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<usize, StorageError> {
let conn = self.lock_conn()?;
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2",
@@ -241,7 +243,7 @@ impl Store for SqlStore {
}
fn gc_expired_messages(&self, max_age_secs: u64) -> Result<usize, StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
let cutoff = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
@@ -261,7 +263,7 @@ impl Store for SqlStore {
identity_key: &[u8],
hybrid_pk: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
conn.execute(
"INSERT OR REPLACE INTO hybrid_keys (identity_key, hybrid_public_key) VALUES (?1, ?2)",
params![identity_key, hybrid_pk],
@@ -271,7 +273,7 @@ impl Store for SqlStore {
}
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
let mut stmt = conn
.prepare("SELECT hybrid_public_key FROM hybrid_keys WHERE identity_key = ?1")
.map_err(|e| StorageError::Db(e.to_string()))?;
@@ -282,7 +284,7 @@ impl Store for SqlStore {
}
fn store_server_setup(&self, setup: Vec<u8>) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
conn.execute(
"INSERT OR REPLACE INTO server_setup (id, setup_data) VALUES (1, ?1)",
params![setup],
@@ -292,7 +294,7 @@ impl Store for SqlStore {
}
fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
let mut stmt = conn
.prepare("SELECT setup_data FROM server_setup WHERE id = 1")
.map_err(|e| StorageError::Db(e.to_string()))?;
@@ -303,7 +305,7 @@ impl Store for SqlStore {
}
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
conn.execute(
"INSERT OR REPLACE INTO users (username, opaque_record) VALUES (?1, ?2)",
params![username, record],
@@ -313,7 +315,7 @@ impl Store for SqlStore {
}
fn get_user_record(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
let mut stmt = conn
.prepare("SELECT opaque_record FROM users WHERE username = ?1")
.map_err(|e| StorageError::Db(e.to_string()))?;
@@ -324,7 +326,7 @@ impl Store for SqlStore {
}
fn has_user_record(&self, username: &str) -> Result<bool, StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
let exists: bool = conn
.query_row(
"SELECT EXISTS(SELECT 1 FROM users WHERE username = ?1)",
@@ -340,7 +342,7 @@ impl Store for SqlStore {
username: &str,
identity_key: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
conn.execute(
"INSERT OR REPLACE INTO user_identity_keys (username, identity_key) VALUES (?1, ?2)",
params![username, identity_key],
@@ -350,7 +352,7 @@ impl Store for SqlStore {
}
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
let mut stmt = conn
.prepare("SELECT identity_key FROM user_identity_keys WHERE username = ?1")
.map_err(|e| StorageError::Db(e.to_string()))?;
@@ -365,7 +367,7 @@ impl Store for SqlStore {
identity_key: &[u8],
node_addr: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
conn.execute(
"INSERT OR REPLACE INTO endpoints (identity_key, node_addr) VALUES (?1, ?2)",
params![identity_key, node_addr],
@@ -375,7 +377,7 @@ impl Store for SqlStore {
}
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let conn = self.lock_conn()?;
let mut stmt = conn
.prepare("SELECT node_addr FROM endpoints WHERE identity_key = ?1")
.map_err(|e| StorageError::Db(e.to_string()))?;
@@ -481,7 +483,9 @@ mod tests {
fn has_user_record_check() {
let store = open_in_memory();
assert!(!store.has_user_record("alice").unwrap());
store.store_user_record("alice", b"record".to_vec()).unwrap();
store
.store_user_record("alice", b"record".to_vec())
.unwrap();
assert!(store.has_user_record("alice").unwrap());
assert!(!store.has_user_record("bob").unwrap());
}
@@ -490,8 +494,13 @@ mod tests {
fn user_identity_key_round_trip() {
let store = open_in_memory();
assert!(store.get_user_identity_key("alice").unwrap().is_none());
store.store_user_identity_key("alice", vec![1u8; 32]).unwrap();
assert_eq!(store.get_user_identity_key("alice").unwrap(), Some(vec![1u8; 32]));
store
.store_user_identity_key("alice", vec![1u8; 32])
.unwrap();
assert_eq!(
store.get_user_identity_key("alice").unwrap(),
Some(vec![1u8; 32])
);
}
#[test]

View File

@@ -18,15 +18,17 @@ pub enum StorageError {
Db(String),
}
/// Acquire `m`, mapping a poisoned lock into `StorageError::Io` rather than
/// panicking, so file-backed storage operations fail gracefully.
fn lock<T>(m: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>, StorageError> {
    match m.lock() {
        Ok(guard) => Ok(guard),
        Err(e) => Err(StorageError::Io(format!("lock poisoned: {e}"))),
    }
}
// ── Store trait ──────────────────────────────────────────────────────────────
/// Abstraction over storage backends (file-backed, SQLCipher, etc.).
pub trait Store: Send + Sync {
fn upload_key_package(
&self,
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError>;
fn upload_key_package(&self, identity_key: &[u8], package: Vec<u8>)
-> Result<(), StorageError>;
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
@@ -37,11 +39,7 @@ pub trait Store: Send + Sync {
payload: Vec<u8>,
) -> Result<(), StorageError>;
fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError>;
fn fetch(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<Vec<Vec<u8>>, StorageError>;
/// Fetch up to `limit` messages without draining the entire queue (Fix 8).
fn fetch_limited(
@@ -52,11 +50,7 @@ pub trait Store: Send + Sync {
) -> Result<Vec<Vec<u8>>, StorageError>;
/// Return the number of queued messages for (recipient, channel) (Fix 7).
fn queue_depth(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<usize, StorageError>;
fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<usize, StorageError>;
/// Delete messages older than `max_age_secs`. Returns count deleted (Fix 7).
fn gc_expired_messages(&self, max_age_secs: u64) -> Result<usize, StorageError>;
@@ -95,11 +89,8 @@ pub trait Store: Send + Sync {
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError>;
/// Publish a P2P endpoint address for an identity key.
fn publish_endpoint(
&self,
identity_key: &[u8],
node_addr: Vec<u8>,
) -> Result<(), StorageError>;
fn publish_endpoint(&self, identity_key: &[u8], node_addr: Vec<u8>)
-> Result<(), StorageError>;
/// Resolve a peer's P2P endpoint address.
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
@@ -210,7 +201,9 @@ impl FileBackedStore {
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
}
fn load_delivery_map(path: &Path) -> Result<HashMap<ChannelKey, VecDeque<Vec<u8>>>, StorageError> {
fn load_delivery_map(
path: &Path,
) -> Result<HashMap<ChannelKey, VecDeque<Vec<u8>>>, StorageError> {
if !path.exists() {
return Ok(HashMap::new());
}
@@ -218,22 +211,9 @@ impl FileBackedStore {
if bytes.is_empty() {
return Ok(HashMap::new());
}
// Try v2 format (channel-aware). Fallback to legacy v1 for upgrade.
if let Ok(map) = bincode::deserialize::<QueueMapV2>(&bytes) {
return Ok(map.map);
}
let legacy: QueueMapV1 = bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)?;
let mut upgraded = HashMap::new();
for (recipient_key, queue) in legacy.map.into_iter() {
upgraded.insert(
ChannelKey {
channel_id: Vec::new(),
recipient_key,
},
queue,
);
}
Ok(upgraded)
bincode::deserialize::<QueueMapV2>(&bytes)
.map(|v| v.map)
.map_err(|_| StorageError::Io("deliveries file: v1 format no longer supported; delete or migrate".into()))
}
fn flush_delivery_map(
@@ -283,11 +263,7 @@ impl FileBackedStore {
bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
}
fn flush_users(
&self,
path: &Path,
map: &HashMap<String, Vec<u8>>,
) -> Result<(), StorageError> {
fn flush_users(&self, path: &Path, map: &HashMap<String, Vec<u8>>) -> Result<(), StorageError> {
let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
@@ -314,7 +290,7 @@ impl Store for FileBackedStore {
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.key_packages.lock().unwrap();
let mut map = lock(&self.key_packages)?;
map.entry(identity_key.to_vec())
.or_default()
.push_back(package);
@@ -322,7 +298,7 @@ impl Store for FileBackedStore {
}
/// Pop the oldest queued key package for `identity_key`, persisting the
/// shrunken queue before returning. Returns `Ok(None)` when no package
/// is queued for that identity.
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
    // NOTE: the stale pre-image line `self.key_packages.lock().unwrap()` was
    // removed — `lock` surfaces mutex poisoning as a StorageError instead of
    // panicking the connection task.
    let mut map = lock(&self.key_packages)?;
    let package = map.get_mut(identity_key).and_then(|q| q.pop_front());
    // Flush before returning so a crash cannot hand out the same one-time
    // package twice.
    self.flush_kp_map(&self.kp_path, &*map)?;
    Ok(package)
@@ -334,23 +310,17 @@ impl Store for FileBackedStore {
channel_id: &[u8],
payload: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.deliveries.lock().unwrap();
let mut map = lock(&self.deliveries)?;
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
map.entry(key)
.or_default()
.push_back(payload);
map.entry(key).or_default().push_back(payload);
self.flush_delivery_map(&self.ds_path, &*map)
}
fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError> {
let mut map = self.deliveries.lock().unwrap();
fn fetch(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<Vec<Vec<u8>>, StorageError> {
let mut map = lock(&self.deliveries)?;
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
@@ -369,7 +339,7 @@ impl Store for FileBackedStore {
channel_id: &[u8],
limit: usize,
) -> Result<Vec<Vec<u8>>, StorageError> {
let mut map = self.deliveries.lock().unwrap();
let mut map = lock(&self.deliveries)?;
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
@@ -385,12 +355,8 @@ impl Store for FileBackedStore {
Ok(messages)
}
fn queue_depth(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<usize, StorageError> {
let map = self.deliveries.lock().unwrap();
fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<usize, StorageError> {
let map = lock(&self.deliveries)?;
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
@@ -408,13 +374,13 @@ impl Store for FileBackedStore {
identity_key: &[u8],
hybrid_pk: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.hybrid_keys.lock().unwrap();
let mut map = lock(&self.hybrid_keys)?;
map.insert(identity_key.to_vec(), hybrid_pk);
self.flush_hybrid_keys(&self.hk_path, &*map)
}
/// Look up the stored hybrid (post-quantum) public key for `identity_key`.
///
/// Read-only: no flush is needed. Returns `Ok(None)` when the identity has
/// never published a hybrid key.
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
    // Removed the stale duplicate `self.hybrid_keys.lock().unwrap()` left by
    // the diff render; `lock` maps poisoning to StorageError instead of
    // panicking.
    let map = lock(&self.hybrid_keys)?;
    Ok(map.get(identity_key).cloned())
}
@@ -437,18 +403,18 @@ impl Store for FileBackedStore {
}
/// Insert or overwrite the serialized auth record for `username`, then
/// persist the whole user map to disk.
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
    // Removed the stale duplicate `self.users.lock().unwrap()`; `lock`
    // converts mutex poisoning into StorageError instead of panicking.
    let mut map = lock(&self.users)?;
    map.insert(username.to_string(), record);
    // Persist while still holding the lock so concurrent writers cannot
    // interleave an older snapshot.
    self.flush_users(&self.users_path, &*map)
}
/// Fetch the serialized auth record for `username`, cloning it out of the
/// in-memory map. Returns `Ok(None)` for unknown users.
fn get_user_record(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
    // Removed the stale duplicate `self.users.lock().unwrap()`; `lock`
    // reports poisoning as StorageError rather than panicking.
    let map = lock(&self.users)?;
    Ok(map.get(username).cloned())
}
/// Check whether an auth record exists for `username` without copying it.
fn has_user_record(&self, username: &str) -> Result<bool, StorageError> {
    // Removed the stale duplicate `self.users.lock().unwrap()`; `lock`
    // reports poisoning as StorageError rather than panicking.
    let map = lock(&self.users)?;
    Ok(map.contains_key(username))
}
@@ -457,13 +423,13 @@ impl Store for FileBackedStore {
username: &str,
identity_key: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.identity_keys.lock().unwrap();
let mut map = lock(&self.identity_keys)?;
map.insert(username.to_string(), identity_key);
self.flush_map_string_bytes(&self.identity_keys_path, &*map)
}
/// Resolve the identity public key registered for `username`.
/// Returns `Ok(None)` when the user has no registered identity key.
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
    // Removed the stale duplicate `self.identity_keys.lock().unwrap()`;
    // `lock` surfaces mutex poisoning as StorageError instead of panicking.
    let map = lock(&self.identity_keys)?;
    Ok(map.get(username).cloned())
}
@@ -472,13 +438,13 @@ impl Store for FileBackedStore {
identity_key: &[u8],
node_addr: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.endpoints.lock().unwrap();
let mut map = lock(&self.endpoints)?;
map.insert(identity_key.to_vec(), node_addr);
Ok(())
}
/// Resolve a peer's registered P2P endpoint (serialized node address) by
/// identity key. Endpoints are memory-only here (registration does not
/// flush), so `Ok(None)` is returned for peers not seen this process
/// lifetime.
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
    // Removed the stale duplicate `self.endpoints.lock().unwrap()`; `lock`
    // reports mutex poisoning as StorageError instead of panicking.
    let map = lock(&self.endpoints)?;
    Ok(map.get(identity_key).cloned())
}
}