chore: rename quicproquo → quicprochat in Rust workspace

Rename all crate directories, package names, binary names, proto
package/module paths, ALPN strings, env var prefixes, config filenames,
mDNS service names, and plugin ABI symbols from quicproquo/qpq to
quicprochat/qpc.
This commit is contained in:
2026-03-07 18:24:52 +01:00
parent d8c1392587
commit a710037dde
212 changed files with 609 additions and 609 deletions

View File

@@ -0,0 +1,232 @@
//! Structured audit log — persistent, machine-readable event journal.
//!
//! Events are serialized as JSON lines and appended to a file or SQL table.
//! Each event carries a correlation `trace_id` for cross-referencing with
//! RPC request traces.
use std::fs::OpenOptions;
use std::io::Write as IoWrite;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use serde::Serialize;
// ── Audit event types ─────────────────────────────────────────────────────
/// Action categories for the audit log.
///
/// Serialized by serde as snake_case strings (e.g. `auth_login_failure`),
/// which is what the JSON-line consumers and the tests below expect.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum AuditAction {
    AuthRegister,
    AuthLoginSuccess,
    AuthLoginFailure,
    Enqueue,
    BatchEnqueue,
    Fetch,
    FetchWait,
    KeyUpload,
    HybridKeyUpload,
    BanUser,
    UnbanUser,
    ReportMessage,
    AccountDelete,
    DeviceRegister,
    DeviceRevoke,
    BlobUpload,
    RecoveryStore,
    RecoveryFetch,
    RecoveryDelete,
}
/// Outcome of an audited action.
///
/// Serialized by serde as snake_case strings (e.g. `rate_limited`).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum AuditOutcome {
    Success,
    Denied,
    Error,
    RateLimited,
}
/// A single audit event record.
///
/// Each event is serialized as one self-contained JSON object (one line in
/// the JSON-lines file backend). Optional fields are omitted entirely when
/// absent, keeping lines compact.
#[derive(Debug, Clone, Serialize)]
pub struct AuditEvent {
    /// ISO-8601 timestamp.
    pub timestamp: String,
    /// RPC correlation ID.
    pub trace_id: String,
    /// Hex-encoded actor identity key (truncated for privacy when redact=true).
    pub actor: String,
    /// The action performed.
    pub action: AuditAction,
    /// Target identifier (recipient key, username, etc.).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target: Option<String>,
    /// Outcome of the action.
    pub outcome: AuditOutcome,
    /// Free-form details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<String>,
}
// ── Audit logger trait ────────────────────────────────────────────────────
/// Trait for audit log backends.
///
/// `Send + Sync` so a single logger can be shared across request handlers.
pub trait AuditLogger: Send + Sync {
    /// Record one audit event. Implementations in this module are
    /// best-effort: the file backend logs a warning and continues on
    /// serialization/lock/I-O failure rather than propagating an error.
    fn log(&self, event: AuditEvent);
}
// ── File-backed implementation ───────────────────────────────────────────
/// Appends JSON-line events to a file.
pub struct FileAuditLogger {
    path: PathBuf,
    file: Mutex<std::fs::File>,
}

impl FileAuditLogger {
    /// Open (or create) the audit log file at `path`.
    ///
    /// The file is opened in append mode so restarts never truncate the
    /// existing journal.
    pub fn open(path: &Path) -> Result<Self, std::io::Error> {
        let file = OpenOptions::new().create(true).append(true).open(path)?;
        let this = Self {
            path: PathBuf::from(path),
            file: Mutex::new(file),
        };
        Ok(this)
    }

    /// Return the path to the audit log file.
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }
}
impl AuditLogger for FileAuditLogger {
    /// Serialize `event` as a single JSON line and append it to the file.
    ///
    /// All failure modes (serialization error, poisoned lock, write error)
    /// emit a warning and return — audit logging never crashes the caller.
    fn log(&self, event: AuditEvent) {
        let serialized = match serde_json::to_string(&event) {
            Ok(s) => s,
            Err(_) => {
                tracing::warn!("audit: failed to serialize event");
                return;
            }
        };
        match self.file.lock() {
            Ok(mut f) => {
                if let Err(e) = f.write_all(format!("{serialized}\n").as_bytes()) {
                    tracing::warn!(error = %e, "audit: failed to write event");
                }
            }
            Err(_) => tracing::warn!("audit: log file lock poisoned"),
        }
    }
}
// ── No-op implementation ─────────────────────────────────────────────────
/// Does nothing. Used when audit logging is disabled.
pub struct NoopAuditLogger;

impl AuditLogger for NoopAuditLogger {
    /// Discard the event; no side effects at all.
    fn log(&self, _event: AuditEvent) {}
}
// ── Helpers ──────────────────────────────────────────────────────────────
/// Format identity key bytes as hex, optionally truncated for privacy.
///
/// With `redact`, any encoding longer than 12 hex characters (6 key bytes)
/// is shortened to its first 12 characters plus `...`.
pub fn format_actor(identity_key: &[u8], redact: bool) -> String {
    use std::fmt::Write as _;
    let mut encoded = String::with_capacity(identity_key.len() * 2);
    for byte in identity_key {
        // Infallible for String; lowercase to match hex::encode output.
        let _ = write!(encoded, "{byte:02x}");
    }
    match (redact, encoded.len() > 12) {
        (true, true) => format!("{}...", &encoded[..12]),
        _ => encoded,
    }
}
/// Current ISO-8601 UTC timestamp (e.g. `2024-03-03T21:06:40Z`).
///
/// The previous implementation returned bare epoch seconds, contradicting
/// both this doc comment and `AuditEvent::timestamp`, which is documented
/// as ISO-8601.
pub fn now_iso8601() -> String {
    // Use SystemTime to avoid pulling in chrono.
    let d = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    format_iso8601(d.as_secs())
}

/// Format `secs` (seconds since the Unix epoch) as `YYYY-MM-DDThh:mm:ssZ`.
///
/// Date conversion uses the civil-from-days algorithm (Howard Hinnant,
/// "chrono-Compatible Low-Level Date Algorithms"). Leap seconds are
/// ignored, matching Unix-time semantics.
fn format_iso8601(secs: u64) -> String {
    let days = secs / 86_400;
    let rem = secs % 86_400;
    let (hour, min, sec) = (rem / 3_600, (rem % 3_600) / 60, rem % 60);
    // Shift epoch to 0000-03-01 so leap days land at the end of the cycle.
    let z = days as i64 + 719_468;
    let era = z.div_euclid(146_097);
    let doe = z.rem_euclid(146_097); // day of 400-year era
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year
    let mp = (5 * doy + 2) / 153; // March-based month index (0 = March)
    let day = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    // January/February belong to the next civil year.
    let year = era * 400 + yoe + i64::from(month <= 2);
    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{min:02}:{sec:02}Z")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read;

    /// Two events logged through the file backend become two separate,
    /// parseable JSON lines with snake_case enum values and omitted
    /// `None` fields.
    #[test]
    fn file_audit_logger_writes_json_lines() {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("audit.jsonl");
        let logger = FileAuditLogger::open(&path).expect("open");
        logger.log(AuditEvent {
            timestamp: "1709500000".to_string(),
            trace_id: "test-trace-001".to_string(),
            actor: "abcdef123456".to_string(),
            action: AuditAction::Enqueue,
            target: Some("recipient-hex".to_string()),
            outcome: AuditOutcome::Success,
            details: None,
        });
        logger.log(AuditEvent {
            timestamp: "1709500001".to_string(),
            trace_id: "test-trace-002".to_string(),
            actor: "abcdef123456".to_string(),
            action: AuditAction::AuthLoginFailure,
            target: None,
            outcome: AuditOutcome::Denied,
            details: Some("bad password".to_string()),
        });
        // Drop flushes/releases the file before we read it back.
        drop(logger);
        let mut content = String::new();
        std::fs::File::open(&path)
            .expect("open for read")
            .read_to_string(&mut content)
            .expect("read");
        let lines: Vec<&str> = content.trim().split('\n').collect();
        assert_eq!(lines.len(), 2);
        // Verify JSON parses.
        let v: serde_json::Value = serde_json::from_str(lines[0]).expect("parse line 0");
        assert_eq!(v["action"], "enqueue");
        assert_eq!(v["outcome"], "success");
        assert_eq!(v["trace_id"], "test-trace-001");
        let v: serde_json::Value = serde_json::from_str(lines[1]).expect("parse line 1");
        assert_eq!(v["action"], "auth_login_failure");
        assert_eq!(v["details"], "bad password");
    }

    /// Redaction keeps 12 hex chars + "..."; the full form is 2 chars/byte.
    #[test]
    fn format_actor_truncates_when_redacted() {
        let key = vec![0xAA; 32];
        let full = format_actor(&key, false);
        assert_eq!(full.len(), 64);
        let redacted = format_actor(&key, true);
        assert!(redacted.ends_with("..."));
        assert_eq!(redacted.len(), 15); // 12 hex chars + "..."
    }

    /// The disabled-logging backend accepts events without side effects.
    #[test]
    fn noop_logger_does_not_panic() {
        let logger = NoopAuditLogger;
        logger.log(AuditEvent {
            timestamp: "0".to_string(),
            trace_id: "noop".to_string(),
            actor: "none".to_string(),
            action: AuditAction::Fetch,
            target: None,
            outcome: AuditOutcome::Success,
            details: None,
        });
    }
}

View File

@@ -0,0 +1,304 @@
use std::net::IpAddr;
use std::sync::Arc;
use dashmap::DashMap;
use quicprochat_proto::node_capnp::auth;
use sha2::Digest;
use subtle::ConstantTimeEq;
use tokio::sync::Notify;
use zeroize::Zeroizing;
use crate::error_codes::*;
/// Lifetime of an issued session token.
pub const SESSION_TTL_SECS: u64 = 24 * 60 * 60; // 24 hours
/// How long a pending login handshake may sit unfinished before it can
/// be expired — TODO confirm enforcement at the call sites using `PendingLogin`.
pub const PENDING_LOGIN_TTL_SECS: u64 = 300; // 5 minutes
/// Fixed-window length used by `check_rate_limit`.
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
/// Maximum enqueues allowed per token within one window.
pub const RATE_LIMIT_MAX_ENQUEUES: u32 = 100;
/// Server authentication configuration (static bearer token + dev switches).
#[derive(Clone)]
pub struct AuthConfig {
    /// Server bearer token — zeroized on drop to prevent memory disclosure.
    /// `None` disables the static bearer-token check (see `AuthConfig::new`).
    pub required_token: Option<Zeroizing<Vec<u8>>>,
    /// When true, a valid bearer token (no session) is accepted and the request's identity/key is used (dev/e2e only).
    /// CLI flag: --allow-insecure-auth / QPC_ALLOW_INSECURE_AUTH.
    pub allow_insecure_identity_from_request: bool,
}
impl std::fmt::Debug for AuthConfig {
    /// Manual `Debug` so the bearer token can never leak into logs:
    /// the secret is replaced by `"[REDACTED]"` while preserving whether
    /// a token is configured at all.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token_placeholder = self.required_token.as_ref().map(|_| "[REDACTED]");
        f.debug_struct("AuthConfig")
            .field("required_token", &token_placeholder)
            .field(
                "allow_insecure_identity_from_request",
                &self.allow_insecure_identity_from_request,
            )
            .finish()
    }
}
impl AuthConfig {
    /// Build an `AuthConfig`.
    ///
    /// An empty or absent `required_token` disables the static bearer-token
    /// check; a non-empty token is wrapped in `Zeroizing` so its bytes are
    /// wiped on drop.
    pub fn new(required_token: Option<String>, allow_insecure_identity_from_request: bool) -> Self {
        let required_token = match required_token {
            Some(s) if !s.is_empty() => Some(Zeroizing::new(s.into_bytes())),
            _ => None,
        };
        Self {
            required_token,
            allow_insecure_identity_from_request,
        }
    }
}
/// Server-side state for one issued session token.
#[derive(Clone)]
pub struct SessionInfo {
    /// Username the session was issued for (currently unread).
    #[allow(dead_code)]
    pub username: String,
    /// Identity key the session is bound to; empty means not identity-bound.
    pub identity_key: Vec<u8>,
    /// Unix timestamp at session creation (currently unread).
    #[allow(dead_code)]
    pub created_at: u64,
    /// Unix timestamp after which validation rejects and evicts the session.
    pub expires_at: u64,
}
/// An in-flight login handshake awaiting its finish step.
pub struct PendingLogin {
    /// Serialized server-side login state (opaque to this module).
    pub state_bytes: Vec<u8>,
    /// Unix timestamp when the handshake started; presumably checked
    /// against `PENDING_LOGIN_TTL_SECS` by the login flow — confirm at call sites.
    pub created_at: u64,
}
/// Fixed-window rate-limit counter state (see `check_rate_limit` and
/// `check_conn_rate_limit`).
pub struct RateEntry {
    /// Requests counted in the current window.
    pub count: u32,
    /// Unix timestamp at which the current window began.
    pub window_start: u64,
}
/// Resolved caller context after token validation.
#[derive(Clone)]
pub struct AuthContext {
    /// The raw access token the caller presented.
    pub token: Vec<u8>,
    /// Identity key bound to the session, or `None` for a bare bearer token.
    pub identity_key: Option<Vec<u8>>,
}
/// Seconds since the Unix epoch, clamped to 0 if the clock predates it.
pub fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or_else(|_| {
            // A pre-epoch clock would make duration_since fail; fall back
            // to 0 rather than panicking in session/rate-limit paths.
            tracing::warn!("system time is before UNIX_EPOCH; using 0 for session/rate-limit timestamps");
            0
        })
}
/// Enforce the per-token enqueue rate limit (fixed window).
///
/// At most [`RATE_LIMIT_MAX_ENQUEUES`] enqueues are allowed per token per
/// [`RATE_LIMIT_WINDOW_SECS`]-second window.
///
/// # Errors
/// Returns a coded `E014_RATE_LIMITED` error once the window budget is
/// exhausted.
pub fn check_rate_limit(
    rate_limits: &DashMap<Vec<u8>, RateEntry>,
    token: &[u8],
) -> Result<(), capnp::Error> {
    let now = current_timestamp();
    let mut entry = rate_limits.entry(token.to_vec()).or_insert(RateEntry {
        count: 0,
        window_start: now,
    });
    // saturating_sub: `current_timestamp()` can return 0 (pre-epoch clock)
    // or the wall clock can step backwards, which would otherwise underflow
    // the u64 subtraction and panic in debug builds.
    if now.saturating_sub(entry.window_start) >= RATE_LIMIT_WINDOW_SECS {
        // New window: this request is the first one counted.
        entry.count = 1;
        entry.window_start = now;
    } else {
        entry.count = entry.count.saturating_add(1);
        if entry.count > RATE_LIMIT_MAX_ENQUEUES {
            return Err(crate::error_codes::coded_error(
                E014_RATE_LIMITED,
                format!(
                    "rate limit exceeded: {} enqueues in {}s window",
                    RATE_LIMIT_MAX_ENQUEUES, RATE_LIMIT_WINDOW_SECS
                ),
            ));
        }
    }
    Ok(())
}
/// Validate the request's auth struct, discarding the resolved context.
///
/// Thin wrapper over [`validate_auth_context`] for call sites that only
/// need a pass/fail answer.
pub fn validate_auth(
    cfg: &AuthConfig,
    sessions: &DashMap<Vec<u8>, SessionInfo>,
    auth: Result<auth::Reader<'_>, capnp::Error>,
) -> Result<(), capnp::Error> {
    let _ctx = validate_auth_context(cfg, sessions, auth)?;
    Ok(())
}
/// Validate the request's auth struct and resolve the caller's context.
///
/// Accepts either the static server bearer token (yields a context without
/// an identity) or a live session token (yields the bound identity, if any).
/// Expired sessions are evicted as a side effect.
///
/// # Errors
/// Coded errors: bad auth version, empty token, expired session, or an
/// unrecognized token.
pub fn validate_auth_context(
    cfg: &AuthConfig,
    sessions: &DashMap<Vec<u8>, SessionInfo>,
    auth: Result<auth::Reader<'_>, capnp::Error>,
) -> Result<AuthContext, capnp::Error> {
    let auth = auth?;
    let version = auth.get_version();
    if version != 1 {
        return Err(crate::error_codes::coded_error(
            E001_BAD_AUTH_VERSION,
            format!("unsupported auth version {} (expected 1)", version),
        ));
    }
    let token = auth
        .get_access_token()
        .map_err(|e| crate::error_codes::coded_error(E020_BAD_PARAMS, format!("auth.accessToken: {e}")))?
        .to_vec();
    if token.is_empty() {
        return Err(crate::error_codes::coded_error(
            E002_EMPTY_TOKEN,
            "auth.version=1 requires non-empty accessToken",
        ));
    }
    // Static bearer token first; ct_eq keeps the comparison constant-time
    // (the length check alone leaks only the length, not the content).
    if let Some(expected) = &cfg.required_token {
        if expected.len() == token.len() && bool::from(expected.as_slice().ct_eq(&token)) {
            return Ok(AuthContext {
                token,
                identity_key: None,
            });
        }
    }
    if let Some(session) = sessions.get(&token) {
        let now = current_timestamp();
        if session.expires_at > now {
            // Empty identity_key means the session is not identity-bound.
            let identity = if session.identity_key.is_empty() {
                None
            } else {
                Some(session.identity_key.clone())
            };
            return Ok(AuthContext {
                token,
                identity_key: identity,
            });
        }
        // Must release the Ref guard before remove(): DashMap may deadlock
        // when mutating a shard while a reference into it is still held.
        drop(session);
        sessions.remove(&token);
        return Err(crate::error_codes::coded_error(
            E017_SESSION_EXPIRED,
            "session token has expired",
        ));
    }
    Err(crate::error_codes::coded_error(E003_INVALID_TOKEN, "invalid accessToken"))
}
/// Validate a raw bearer token (no Cap'n Proto dependency).
/// Used by the WebSocket JSON-RPC bridge.
///
/// Mirrors [`validate_auth_context`]: static bearer token first, then live
/// sessions; expired sessions are evicted. Errors are plain strings because
/// the bridge has no capnp error type.
pub fn validate_token_raw(
    cfg: &AuthConfig,
    sessions: &DashMap<Vec<u8>, SessionInfo>,
    token: &[u8],
) -> Result<AuthContext, String> {
    if token.is_empty() {
        return Err("empty access token".to_string());
    }
    // Check static bearer token.
    // Constant-time comparison; only the length check short-circuits.
    if let Some(expected) = &cfg.required_token {
        if expected.len() == token.len() && bool::from(expected.as_slice().ct_eq(token)) {
            return Ok(AuthContext {
                token: token.to_vec(),
                identity_key: None,
            });
        }
    }
    // Check session tokens.
    if let Some(session) = sessions.get(token) {
        let now = current_timestamp();
        if session.expires_at > now {
            // Empty identity_key means the session is not identity-bound.
            let identity = if session.identity_key.is_empty() {
                None
            } else {
                Some(session.identity_key.clone())
            };
            return Ok(AuthContext {
                token: token.to_vec(),
                identity_key: identity,
            });
        }
        // Release the Ref guard before remove() — DashMap may deadlock when
        // mutating a shard while a reference into it is still held.
        drop(session);
        sessions.remove(token);
        return Err("session token has expired".to_string());
    }
    Err("invalid access token".to_string())
}
/// Return the identity key bound to `auth_ctx`, or a coded error when the
/// token is a bare bearer token (no login-bound identity).
pub fn require_identity(auth_ctx: &AuthContext) -> Result<&[u8], capnp::Error> {
    auth_ctx.identity_key.as_deref().ok_or_else(|| {
        crate::error_codes::coded_error(
            E003_INVALID_TOKEN,
            "access token is not identity-bound; login required",
        )
    })
}
/// Require that `auth_ctx` is identity-bound AND bound to exactly `expected`.
///
/// The comparison uses `ct_eq` so a mismatched identity cannot be probed
/// byte-by-byte via timing.
pub fn require_identity_match(auth_ctx: &AuthContext, expected: &[u8]) -> Result<(), capnp::Error> {
    let ik = require_identity(auth_ctx)?;
    let matches = ik.len() == expected.len() && bool::from(ik.ct_eq(expected));
    if matches {
        Ok(())
    } else {
        Err(crate::error_codes::coded_error(
            E016_IDENTITY_MISMATCH,
            "access token is bound to a different identity",
        ))
    }
}
/// When the token is a valid session, require it to match `request_identity`.
/// When the token is a bearer token (no identity) and `allow_insecure` is
/// true, accept the request-supplied identity instead (dev/e2e only).
pub fn require_identity_or_request(
    auth_ctx: &AuthContext,
    request_identity: &[u8],
    allow_insecure: bool,
) -> Result<(), capnp::Error> {
    if auth_ctx.identity_key.is_some() {
        // Identity-bound token: it must match the identity in the request.
        return require_identity_match(auth_ctx, request_identity);
    }
    if allow_insecure {
        Ok(())
    } else {
        Err(crate::error_codes::coded_error(
            E003_INVALID_TOKEN,
            "access token is not identity-bound; login required",
        ))
    }
}
/// Render `bytes` as a lowercase hex string (two digits per byte).
///
/// The previous version wrapped the already-built string in a redundant
/// `format!("{hex}")` (clippy `useless_format`); the collected string is
/// returned directly.
pub fn fmt_hex(bytes: &[u8]) -> String {
    bytes.iter().map(|b| format!("{b:02x}")).collect()
}
/// Return the `Notify` associated with `recipient_key`, creating one on
/// first use. Each recipient queue shares a single `Notify` instance.
pub fn waiter(waiters: &DashMap<Vec<u8>, Arc<Notify>>, recipient_key: &[u8]) -> Arc<Notify> {
    let entry = waiters
        .entry(recipient_key.to_vec())
        .or_insert_with(|| Arc::new(Notify::new()));
    Arc::clone(entry.value())
}
/// Fixed-window length for per-IP connection rate limiting.
pub const CONN_RATE_LIMIT_WINDOW_SECS: u64 = 60;
/// Maximum connections admitted per IP per window.
pub const CONN_RATE_LIMIT_MAX: u32 = 50;
/// Per-IP connection rate limiter. Returns `true` if the connection is allowed.
///
/// Fixed window of [`CONN_RATE_LIMIT_WINDOW_SECS`] seconds; at most
/// [`CONN_RATE_LIMIT_MAX`] connections are admitted per window.
pub fn check_conn_rate_limit(
    conn_rate_limits: &DashMap<IpAddr, RateEntry>,
    ip: IpAddr,
) -> bool {
    let now = current_timestamp();
    let mut entry = conn_rate_limits.entry(ip).or_insert(RateEntry {
        count: 0,
        window_start: now,
    });
    // saturating_sub: `current_timestamp()` can return 0 (pre-epoch clock)
    // or the wall clock can step backwards, which would otherwise underflow
    // the u64 subtraction and panic in debug builds.
    if now.saturating_sub(entry.window_start) >= CONN_RATE_LIMIT_WINDOW_SECS {
        entry.count = 1;
        entry.window_start = now;
        true
    } else {
        entry.count = entry.count.saturating_add(1);
        entry.count <= CONN_RATE_LIMIT_MAX
    }
}
/// SHA-256 digest of `data` as an owned byte vector (32 bytes).
pub fn fingerprint(data: &[u8]) -> Vec<u8> {
    let digest = sha2::Sha256::digest(data);
    Vec::from(digest.as_slice())
}
/// Convenience re-export of [`crate::error_codes::coded_error`] so callers
/// of this module can build coded capnp errors without the full path.
pub fn coded_error(code: &str, msg: impl std::fmt::Display) -> capnp::Error {
    crate::error_codes::coded_error(code, msg)
}

View File

@@ -0,0 +1,350 @@
use std::path::{Path, PathBuf};
use anyhow::Context;
use serde::Deserialize;
/// Default listen address for the main node service.
pub const DEFAULT_LISTEN: &str = "0.0.0.0:7000";
/// Default data directory (relative to the working directory).
pub const DEFAULT_DATA_DIR: &str = "data";
/// Default TLS certificate path (DER).
pub const DEFAULT_TLS_CERT: &str = "data/server-cert.der";
/// Default TLS private key path (DER).
pub const DEFAULT_TLS_KEY: &str = "data/server-key.der";
/// Default storage backend (`"file"`; `"sql"` is the alternative checked
/// in `validate_production_config`).
pub const DEFAULT_STORE_BACKEND: &str = "file";
/// Default SQL database path.
pub const DEFAULT_DB_PATH: &str = "data/qpc.db";
/// Raw configuration as deserialized from the TOML config file.
///
/// Every field is optional; `merge_config` resolves it against CLI
/// arguments and compile-time defaults into an [`EffectiveConfig`].
#[derive(Debug, Default, Deserialize)]
pub struct FileConfig {
    /// Main listen address (overrides `DEFAULT_LISTEN`).
    pub listen: Option<String>,
    /// Data directory (overrides `DEFAULT_DATA_DIR`).
    pub data_dir: Option<String>,
    /// TLS certificate path (overrides `DEFAULT_TLS_CERT`).
    pub tls_cert: Option<PathBuf>,
    /// TLS private key path (overrides `DEFAULT_TLS_KEY`).
    pub tls_key: Option<PathBuf>,
    /// Static bearer token for request authentication.
    pub auth_token: Option<String>,
    /// Dev/e2e switch: accept request-supplied identities (see auth module).
    pub allow_insecure_auth: Option<bool>,
    /// When true, enqueue does not require an identity-bound session: only a valid token is required.
    /// The server does not associate the request with a specific sender (Sealed Sender).
    #[serde(default)]
    pub sealed_sender: Option<bool>,
    /// Storage backend name (`"file"` or `"sql"`).
    pub store_backend: Option<String>,
    /// SQL database path (used when `store_backend = "sql"`).
    pub db_path: Option<PathBuf>,
    /// SQL database encryption key.
    pub db_key: Option<String>,
    /// Metrics HTTP listen address (e.g. "0.0.0.0:9090"). If set, /metrics is served there.
    pub metrics_listen: Option<String>,
    /// When true and metrics_listen is set, start the metrics server.
    #[serde(default)]
    pub metrics_enabled: Option<bool>,
    /// Optional federation section.
    pub federation: Option<FederationFileConfig>,
    /// Directory containing plugin `.so` / `.dylib` files to load at startup.
    pub plugin_dir: Option<PathBuf>,
    /// When true, audit logs hash identity key prefixes and omit payload sizes.
    #[serde(default)]
    pub redact_logs: Option<bool>,
    /// WebSocket JSON-RPC bridge listen address (e.g. "0.0.0.0:9000").
    pub ws_listen: Option<String>,
    /// WebTransport (HTTP/3) listen address for browser clients (e.g. "0.0.0.0:7443").
    pub webtransport_listen: Option<String>,
    /// Graceful shutdown drain timeout in seconds.
    pub drain_timeout_secs: Option<u64>,
    /// Default per-RPC timeout in seconds.
    pub rpc_timeout_secs: Option<u64>,
    /// Storage/database operation timeout in seconds.
    pub storage_timeout_secs: Option<u64>,
}
/// Fully-resolved runtime configuration produced by `merge_config`
/// (CLI arguments merged over the config file over defaults).
#[derive(Debug)]
pub struct EffectiveConfig {
    /// Main listen address.
    pub listen: String,
    /// Data directory.
    pub data_dir: String,
    /// TLS certificate path.
    pub tls_cert: PathBuf,
    /// TLS private key path.
    pub tls_key: PathBuf,
    /// Static bearer token; `None`/empty disables the static check.
    pub auth_token: Option<String>,
    /// Dev/e2e switch; rejected by `validate_production_config`.
    pub allow_insecure_auth: bool,
    /// When true, enqueue does not require identity; valid token only (Sealed Sender).
    pub sealed_sender: bool,
    /// Storage backend name (`"file"` or `"sql"`).
    pub store_backend: String,
    /// SQL database path.
    pub db_path: PathBuf,
    /// SQL database encryption key.
    pub db_key: String,
    /// If Some(addr), metrics server listens here (e.g. "0.0.0.0:9090").
    pub metrics_listen: Option<String>,
    /// Start metrics server only when true and metrics_listen is set.
    pub metrics_enabled: bool,
    /// Resolved federation settings; `None` when federation is disabled.
    pub federation: Option<EffectiveFederationConfig>,
    /// Directory to scan for plugin `.so` / `.dylib` files at startup. None = no plugins.
    pub plugin_dir: Option<PathBuf>,
    /// When true, audit logs hash identity key prefixes and omit payload sizes.
    pub redact_logs: bool,
    /// WebSocket JSON-RPC bridge listen address. If set, the bridge is started.
    pub ws_listen: Option<String>,
    /// WebTransport (HTTP/3) listen address. If set, the WebTransport endpoint is started.
    pub webtransport_listen: Option<String>,
    /// Graceful shutdown drain timeout in seconds.
    pub drain_timeout_secs: u64,
    /// Default per-RPC timeout in seconds.
    pub rpc_timeout_secs: u64,
    /// Storage/database operation timeout in seconds.
    pub storage_timeout_secs: u64,
}
/// Default graceful-shutdown drain timeout (seconds).
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
/// Default per-RPC timeout (seconds).
pub const DEFAULT_RPC_TIMEOUT_SECS: u64 = 30;
/// Default storage/database operation timeout (seconds).
pub const DEFAULT_STORAGE_TIMEOUT_SECS: u64 = 10;
/// Federation section of the config file; all fields optional, resolved by
/// `merge_config` into [`EffectiveFederationConfig`] when enabled.
#[derive(Debug, Default, Deserialize)]
pub struct FederationFileConfig {
    /// Master switch for federation.
    pub enabled: Option<bool>,
    /// This server's federation domain.
    pub domain: Option<String>,
    /// Federation listen address (defaults to "0.0.0.0:7001" in merge).
    pub listen: Option<String>,
    /// Federation TLS certificate path (DER).
    pub federation_cert: Option<PathBuf>,
    /// Federation TLS private key path (DER).
    pub federation_key: Option<PathBuf>,
    /// CA certificate used for federation peers (DER).
    pub federation_ca: Option<PathBuf>,
    /// Statically-configured peer list.
    #[serde(default)]
    pub peers: Vec<FederationPeerConfig>,
}
/// A statically-configured federation peer.
#[derive(Debug, Clone, Deserialize)]
pub struct FederationPeerConfig {
    /// Peer's federation domain.
    pub domain: String,
    /// Peer's network address — presumably host:port; confirm at the dial site.
    pub address: String,
}
/// Resolved federation settings (only built when federation is enabled).
#[derive(Debug)]
#[allow(dead_code)] // federation not yet wired up
pub struct EffectiveFederationConfig {
    /// Always true when this struct exists (see `merge_config`).
    pub enabled: bool,
    /// This server's federation domain (may be empty if unset).
    pub domain: String,
    /// Federation listen address.
    pub listen: String,
    /// Federation TLS certificate path.
    pub federation_cert: PathBuf,
    /// Federation TLS private key path.
    pub federation_key: PathBuf,
    /// CA certificate path for peer verification.
    pub federation_ca: PathBuf,
    /// Statically-configured peer list.
    pub peers: Vec<FederationPeerConfig>,
}
pub fn load_config(path: Option<&Path>) -> anyhow::Result<FileConfig> {
let path = match path {
Some(p) => PathBuf::from(p),
None => PathBuf::from("qpc-server.toml"),
};
if !path.exists() {
return Ok(FileConfig::default());
}
let contents =
std::fs::read_to_string(&path).with_context(|| format!("read config file {path:?}"))?;
let cfg: FileConfig =
toml::from_str(&contents).with_context(|| format!("parse config file {path:?}"))?;
Ok(cfg)
}
/// Merge CLI arguments with the optional config file into the effective
/// runtime configuration.
///
/// Precedence per field: a CLI value that differs from its compile-time
/// default wins; otherwise the config-file value; otherwise the default.
///
/// NOTE(review): "CLI equals default" is indistinguishable from "CLI not
/// given", so explicitly passing a default value on the CLI cannot override
/// a conflicting config-file value.
pub fn merge_config(args: &crate::Args, file: &FileConfig) -> EffectiveConfig {
    let listen = if args.listen == DEFAULT_LISTEN {
        file.listen
            .clone()
            .unwrap_or_else(|| DEFAULT_LISTEN.to_string())
    } else {
        args.listen.clone()
    };
    let data_dir = if args.data_dir == DEFAULT_DATA_DIR {
        file.data_dir
            .clone()
            .unwrap_or_else(|| DEFAULT_DATA_DIR.to_string())
    } else {
        args.data_dir.clone()
    };
    let tls_cert = if args.tls_cert == Path::new(DEFAULT_TLS_CERT) {
        file.tls_cert
            .clone()
            .unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_CERT))
    } else {
        args.tls_cert.clone()
    };
    let tls_key = if args.tls_key == Path::new(DEFAULT_TLS_KEY) {
        file.tls_key
            .clone()
            .unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_KEY))
    } else {
        args.tls_key.clone()
    };
    let auth_token = if args.auth_token.is_some() {
        args.auth_token.clone()
    } else {
        file.auth_token.clone()
    };
    // Boolean flags OR together: either the CLI or the file can enable them.
    let allow_insecure_auth = if args.allow_insecure_auth {
        true
    } else {
        file.allow_insecure_auth.unwrap_or(false)
    };
    let sealed_sender = args.sealed_sender || file.sealed_sender.unwrap_or(false);
    let store_backend = if args.store_backend == DEFAULT_STORE_BACKEND {
        file.store_backend
            .clone()
            .unwrap_or_else(|| DEFAULT_STORE_BACKEND.to_string())
    } else {
        args.store_backend.clone()
    };
    let db_path = if args.db_path == Path::new(DEFAULT_DB_PATH) {
        file.db_path
            .clone()
            .unwrap_or_else(|| PathBuf::from(DEFAULT_DB_PATH))
    } else {
        args.db_path.clone()
    };
    // Empty CLI db_key falls through to the file value (or stays empty).
    let db_key = if args.db_key.is_empty() {
        file.db_key.clone().unwrap_or_else(|| args.db_key.clone())
    } else {
        args.db_key.clone()
    };
    let metrics_listen = args
        .metrics_listen
        .clone()
        .or_else(|| file.metrics_listen.clone());
    // Metrics default to "on" whenever a metrics address is configured.
    let metrics_enabled = args
        .metrics_enabled
        .or(file.metrics_enabled)
        .unwrap_or(metrics_listen.is_some());
    let federation = {
        let file_fed = file.federation.as_ref();
        let enabled = args.federation_enabled
            || file_fed.and_then(|f| f.enabled).unwrap_or(false);
        if enabled {
            let domain = args.federation_domain.clone()
                .or_else(|| file_fed.and_then(|f| f.domain.clone()))
                .unwrap_or_default();
            let listen_fed = args.federation_listen.clone()
                .or_else(|| file_fed.and_then(|f| f.listen.clone()))
                .unwrap_or_else(|| "0.0.0.0:7001".to_string());
            // Cert/key/CA paths are file-config-only (no CLI equivalents).
            let federation_cert = file_fed.and_then(|f| f.federation_cert.clone())
                .unwrap_or_else(|| PathBuf::from("data/federation-cert.der"));
            let federation_key = file_fed.and_then(|f| f.federation_key.clone())
                .unwrap_or_else(|| PathBuf::from("data/federation-key.der"));
            let federation_ca = file_fed.and_then(|f| f.federation_ca.clone())
                .unwrap_or_else(|| PathBuf::from("data/federation-ca.der"));
            let peers = file_fed
                .map(|f| f.peers.clone())
                .unwrap_or_default();
            Some(EffectiveFederationConfig {
                enabled,
                domain,
                listen: listen_fed,
                federation_cert,
                federation_key,
                federation_ca,
                peers,
            })
        } else {
            None
        }
    };
    let plugin_dir = args.plugin_dir.clone().or_else(|| file.plugin_dir.clone());
    let redact_logs = args.redact_logs || file.redact_logs.unwrap_or(false);
    let ws_listen = args
        .ws_listen
        .clone()
        .or_else(|| file.ws_listen.clone());
    let webtransport_listen = args
        .webtransport_listen
        .clone()
        .or_else(|| file.webtransport_listen.clone());
    let drain_timeout_secs = if args.drain_timeout == DEFAULT_DRAIN_TIMEOUT_SECS {
        file.drain_timeout_secs.unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS)
    } else {
        args.drain_timeout
    };
    let rpc_timeout_secs = if args.rpc_timeout == DEFAULT_RPC_TIMEOUT_SECS {
        file.rpc_timeout_secs.unwrap_or(DEFAULT_RPC_TIMEOUT_SECS)
    } else {
        args.rpc_timeout
    };
    let storage_timeout_secs = if args.storage_timeout == DEFAULT_STORAGE_TIMEOUT_SECS {
        file.storage_timeout_secs.unwrap_or(DEFAULT_STORAGE_TIMEOUT_SECS)
    } else {
        args.storage_timeout
    };
    EffectiveConfig {
        listen,
        data_dir,
        tls_cert,
        tls_key,
        auth_token,
        allow_insecure_auth,
        sealed_sender,
        store_backend,
        db_path,
        db_key,
        metrics_listen,
        metrics_enabled,
        federation,
        plugin_dir,
        redact_logs,
        ws_listen,
        webtransport_listen,
        drain_timeout_secs,
        rpc_timeout_secs,
        storage_timeout_secs,
    }
}
/// Enforce production-hardening invariants on the merged configuration.
///
/// Fixes leftover `QPQ_*` env-var names in user-facing messages: this
/// workspace's rename commit moved the env prefix from `QPQ` to `QPC`
/// (the file already uses `qpc.db`, `qpc-server.toml`, `.qpc-write-probe`).
///
/// # Errors
/// Fails when insecure auth is enabled, the auth token is missing/weak,
/// SQL storage lacks an encryption key or a writable DB directory, or the
/// TLS cert/key files do not exist.
pub fn validate_production_config(effective: &EffectiveConfig) -> anyhow::Result<()> {
    if effective.allow_insecure_auth {
        anyhow::bail!("production forbids --allow-insecure-auth");
    }
    let token = effective
        .auth_token
        .as_deref()
        .filter(|s| !s.is_empty())
        .ok_or_else(|| {
            anyhow::anyhow!("production requires QPC_AUTH_TOKEN (non-empty)")
        })?;
    if token == "devtoken" {
        anyhow::bail!(
            "production forbids auth_token 'devtoken'; set a strong QPC_AUTH_TOKEN"
        );
    }
    if token.len() < 16 {
        anyhow::bail!(
            "production requires QPC_AUTH_TOKEN of at least 16 characters (got {})",
            token.len()
        );
    }
    if effective.store_backend == "sql" && effective.db_key.is_empty() {
        anyhow::bail!("production with store_backend=sql requires non-empty QPC_DB_KEY");
    }
    if effective.store_backend == "sql" {
        let db_dir = effective
            .db_path
            .parent()
            .unwrap_or_else(|| Path::new("."));
        // Verify the directory exists and is writable by creating+removing a probe file.
        let probe = db_dir.join(".qpc-write-probe");
        std::fs::write(&probe, b"probe")
            .with_context(|| format!("DB path parent {:?} is not writable", db_dir))?;
        let _ = std::fs::remove_file(&probe);
    }
    if effective.store_backend != "sql" {
        tracing::warn!(
            "production is using file-backed storage; \
            consider store_backend=sql with QPC_DB_KEY for encryption at rest"
        );
    }
    if !effective.tls_cert.exists() || !effective.tls_key.exists() {
        anyhow::bail!(
            "production requires existing TLS cert and key (no auto-generation); provide QPC_TLS_CERT and QPC_TLS_KEY"
        );
    }
    Ok(())
}

View File

@@ -0,0 +1,28 @@
//! Account domain logic — account deletion with KT tombstone.
use std::sync::{Arc, Mutex};
use quicprochat_kt::MerkleLog;
use crate::storage::Store;
use super::types::DomainError;
/// Domain service for account lifecycle operations.
pub struct AccountService {
    /// Persistent storage backing account records.
    pub store: Arc<dyn Store>,
    /// Key-transparency Merkle log; receives a tombstone entry on deletion.
    pub kt_log: Arc<Mutex<MerkleLog>>,
}
impl AccountService {
pub fn delete_account(&self, caller_identity_key: &[u8]) -> Result<(), DomainError> {
self.store.delete_account(caller_identity_key)?;
// Append a KT tombstone entry so the deletion is auditable.
if let Ok(mut log) = self.kt_log.lock() {
log.append("__tombstone__", caller_identity_key);
}
Ok(())
}
}

View File

@@ -0,0 +1,72 @@
//! Authentication domain logic — OPAQUE registration and login.
//!
//! This module contains the pure business logic for OPAQUE auth,
//! extracted from `node_service/auth_ops.rs`. It operates on domain
//! types and the `Store` trait, with no dependency on Cap'n Proto or Protobuf.
use std::sync::Arc;
use dashmap::DashMap;
use opaque_ke::ServerSetup;
use quicprochat_core::opaque_auth::OpaqueSuite;
use crate::auth::{AuthConfig, PendingLogin, SessionInfo};
use crate::storage::{Store, StorageError};
use super::types::*;
/// Shared state needed by auth operations.
pub struct AuthService {
    /// Persistent storage for user records and identity keys.
    pub store: Arc<dyn Store>,
    /// Server-side OPAQUE setup shared across registrations and logins.
    pub opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
    /// In-flight login handshakes — presumably keyed by username; confirm at call sites.
    pub pending_logins: Arc<DashMap<String, PendingLogin>>,
    /// Active sessions keyed by session token bytes.
    pub sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
    /// Bearer-token / insecure-auth configuration.
    pub auth_cfg: Arc<AuthConfig>,
}
impl AuthService {
    /// Validate a session token and return the caller's auth context.
    ///
    /// Expired sessions are evicted as a side effect.
    pub fn validate_session(&self, token: &[u8]) -> Option<CallerAuth> {
        let info = self.sessions.get(token)?;
        if info.expires_at <= crate::auth::current_timestamp() {
            // Release the Ref guard BEFORE remove(): DashMap documents that
            // mutating a shard while holding a reference into it may
            // deadlock. The sibling `crate::auth::validate_auth_context`
            // already drops its guard before removing for the same reason.
            drop(info);
            self.sessions.remove(token);
            return None;
        }
        Some(CallerAuth {
            identity_key: info.identity_key.clone(),
            token: token.to_vec(),
            device_id: None,
        })
    }

    /// Start OPAQUE registration.
    ///
    /// Deserializes the client's registration request and produces the
    /// server's first-round response bytes.
    pub fn register_start(&self, req: RegisterStartReq) -> Result<RegisterStartResp, StorageError> {
        use opaque_ke::ServerRegistration;
        let result = ServerRegistration::<OpaqueSuite>::start(
            &self.opaque_setup,
            opaque_ke::RegistrationRequest::deserialize(&req.request_bytes)
                .map_err(|e| StorageError::Io(format!("bad registration request: {e}")))?,
            req.username.as_bytes(),
        )
        .map_err(|e| StorageError::Io(format!("OPAQUE register start: {e}")))?;
        let response_bytes = result.message.serialize().to_vec();
        Ok(RegisterStartResp { response_bytes })
    }

    /// Finish OPAQUE registration — persist user record and identity key.
    pub fn register_finish(&self, req: RegisterFinishReq) -> Result<RegisterFinishResp, StorageError> {
        let upload = opaque_ke::RegistrationUpload::<OpaqueSuite>::deserialize(&req.upload_bytes)
            .map_err(|e| StorageError::Io(format!("bad registration upload: {e}")))?;
        let record = opaque_ke::ServerRegistration::<OpaqueSuite>::finish(upload);
        let serialized = record.serialize().to_vec();
        self.store.store_user_record(&req.username, serialized)?;
        self.store
            .store_user_identity_key(&req.username, req.identity_key)?;
        Ok(RegisterFinishResp { success: true })
    }
}

View File

@@ -0,0 +1,193 @@
//! Blob domain logic — chunked file upload/download with SHA-256 verification.
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::PathBuf;
use sha2::{Digest, Sha256};
use super::types::*;
/// Maximum blob size: 100 MB.
const MAX_BLOB_SIZE: u64 = 100 * 1024 * 1024;
/// Maximum download chunk size: 256 KB.
const MAX_DOWNLOAD_CHUNK: u32 = 256 * 1024;
/// Metadata stored alongside each completed blob (a `<hex-hash>.meta`
/// JSON file next to the blob itself).
#[derive(serde::Serialize, serde::Deserialize)]
struct BlobMeta {
    /// MIME type supplied by the uploader.
    mime_type: String,
    /// Total size of the finished blob in bytes.
    total_size: u64,
    /// Unix timestamp (seconds) when the upload completed.
    uploaded_at: u64,
}
/// Domain service for blob (file attachment) storage.
pub struct BlobService {
    /// Root data directory; blobs live under `<data_dir>/blobs/`.
    pub data_dir: PathBuf,
}
impl BlobService {
    /// Directory holding completed blobs, `.part` uploads, and `.meta` files.
    fn blobs_dir(&self) -> PathBuf {
        self.data_dir.join("blobs")
    }

    /// Accept one chunk of a content-addressed blob upload.
    ///
    /// Chunks are written at `req.offset` into a `<hash>.part` file. When
    /// the final byte arrives, the whole file is SHA-256-verified against
    /// `req.blob_hash`, renamed to its final name, and a `.meta` JSON file
    /// is written. Re-uploading an already-completed blob is a no-op.
    ///
    /// NOTE(review): concurrent uploads of the same blob share one `.part`
    /// file with no locking — confirm callers serialize per-hash uploads.
    pub fn upload_blob(
        &self,
        req: UploadBlobReq,
        _auth: &CallerAuth,
    ) -> Result<UploadBlobResp, DomainError> {
        if req.blob_hash.len() != 32 {
            return Err(DomainError::BlobHashLength(req.blob_hash.len()));
        }
        if req.total_size > MAX_BLOB_SIZE {
            return Err(DomainError::BlobTooLarge(req.total_size));
        }
        if req.total_size == 0 {
            return Err(DomainError::BadParams("total_size must be > 0".into()));
        }
        // checked_add guards against offset + len overflowing u64.
        if req
            .offset
            .checked_add(req.chunk.len() as u64)
            .is_none_or(|end| end > req.total_size)
        {
            return Err(DomainError::BadParams(format!(
                "chunk out of bounds: offset={} + chunk_len={} > total_size={}",
                req.offset,
                req.chunk.len(),
                req.total_size
            )));
        }
        let blob_hex = hex::encode(&req.blob_hash);
        let dir = self.blobs_dir();
        std::fs::create_dir_all(&dir)
            .map_err(|e| DomainError::Io(format!("create blobs directory: {e}")))?;
        let part_path = dir.join(format!("{blob_hex}.part"));
        let final_path = dir.join(&blob_hex);
        let meta_path = dir.join(format!("{blob_hex}.meta"));
        // Already fully uploaded.
        if final_path.exists() {
            return Ok(UploadBlobResp {
                blob_id: req.blob_hash,
            });
        }
        // Write chunk at offset.
        // truncate(false): preserve previously-written chunks of this upload.
        let mut file = std::fs::OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(false)
            .open(&part_path)
            .map_err(|e| DomainError::Io(format!("open .part file: {e}")))?;
        file.seek(SeekFrom::Start(req.offset))
            .map_err(|e| DomainError::Io(format!("seek: {e}")))?;
        file.write_all(&req.chunk)
            .map_err(|e| DomainError::Io(format!("write chunk: {e}")))?;
        file.sync_all()
            .map_err(|e| DomainError::Io(format!("sync: {e}")))?;
        // Check if upload is complete.
        let end = req.offset + req.chunk.len() as u64;
        if end == req.total_size {
            // Verify SHA-256.
            let mut vfile = std::fs::File::open(&part_path)
                .map_err(|e| DomainError::Io(format!("open for verify: {e}")))?;
            let mut hasher = Sha256::new();
            let mut buf = [0u8; 64 * 1024];
            loop {
                let n = vfile
                    .read(&mut buf)
                    .map_err(|e| DomainError::Io(format!("read: {e}")))?;
                if n == 0 {
                    break;
                }
                hasher.update(&buf[..n]);
            }
            let computed: [u8; 32] = hasher.finalize().into();
            if computed[..] != req.blob_hash[..] {
                // Corrupt or mismatched upload: discard the partial file.
                let _ = std::fs::remove_file(&part_path);
                return Err(DomainError::BlobHashMismatch);
            }
            // Finalize.
            std::fs::rename(&part_path, &final_path)
                .map_err(|e| DomainError::Io(format!("rename .part: {e}")))?;
            // Write metadata.
            let now = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            let meta = BlobMeta {
                mime_type: req.mime_type,
                total_size: req.total_size,
                uploaded_at: now,
            };
            // Metadata write is best-effort; the blob itself is already final.
            if let Ok(json) = serde_json::to_string_pretty(&meta) {
                let _ = std::fs::write(&meta_path, json.as_bytes());
            }
        }
        Ok(UploadBlobResp {
            blob_id: req.blob_hash,
        })
    }

    /// Serve one chunk of a stored blob.
    ///
    /// Returns at most `min(req.length, MAX_DOWNLOAD_CHUNK)` bytes starting
    /// at `req.offset`; an offset at/past EOF yields an empty chunk with
    /// intact metadata.
    pub fn download_blob(
        &self,
        req: DownloadBlobReq,
        _auth: &CallerAuth,
    ) -> Result<DownloadBlobResp, DomainError> {
        if req.blob_id.len() != 32 {
            return Err(DomainError::BlobHashLength(req.blob_id.len()));
        }
        let blob_hex = hex::encode(&req.blob_id);
        let dir = self.blobs_dir();
        let blob_path = dir.join(&blob_hex);
        let meta_path = dir.join(format!("{blob_hex}.meta"));
        if !blob_path.exists() {
            return Err(DomainError::BlobNotFound);
        }
        // Read metadata.
        let meta_json = std::fs::read_to_string(&meta_path)
            .map_err(|e| DomainError::Io(format!("read blob metadata: {e}")))?;
        let meta: BlobMeta = serde_json::from_str(&meta_json)
            .map_err(|e| DomainError::Io(format!("corrupt blob metadata: {e}")))?;
        // Read chunk.
        let mut file = std::fs::File::open(&blob_path)
            .map_err(|e| DomainError::Io(format!("open blob: {e}")))?;
        let file_len = file
            .metadata()
            .map_err(|e| DomainError::Io(format!("file metadata: {e}")))?
            .len();
        if req.offset >= file_len {
            return Ok(DownloadBlobResp {
                chunk: vec![],
                total_size: meta.total_size,
                mime_type: meta.mime_type,
            });
        }
        file.seek(SeekFrom::Start(req.offset))
            .map_err(|e| DomainError::Io(format!("seek: {e}")))?;
        let remaining = (file_len - req.offset) as usize;
        let to_read = remaining.min(req.length.min(MAX_DOWNLOAD_CHUNK) as usize);
        let mut chunk = vec![0u8; to_read];
        file.read_exact(&mut chunk)
            .map_err(|e| DomainError::Io(format!("read chunk: {e}")))?;
        Ok(DownloadBlobResp {
            chunk,
            total_size: meta.total_size,
            mime_type: meta.mime_type,
        })
    }
}

View File

@@ -0,0 +1,38 @@
//! Channel domain logic — 1:1 DM channel creation.
use std::sync::Arc;
use crate::storage::Store;
use super::types::*;
/// Domain service for 1:1 channel management.
pub struct ChannelService {
    /// Persistence backend; channel-id allocation is delegated to it.
    pub store: Arc<dyn Store>,
}
impl ChannelService {
    /// Create (or look up) the 1:1 DM channel between the caller and
    /// `req.peer_key`.
    ///
    /// Validates that the peer key is exactly 32 bytes and differs from the
    /// caller's own identity key, then delegates channel-id allocation to
    /// the store. The response carries the channel id and whether the
    /// channel was newly created.
    pub fn create_channel(
        &self,
        req: CreateChannelReq,
        caller_identity_key: &[u8],
    ) -> Result<CreateChannelResp, DomainError> {
        let peer = req.peer_key.as_slice();
        if peer.len() != 32 {
            return Err(DomainError::InvalidIdentityKey(peer.len()));
        }
        // A user cannot open a DM channel with themselves.
        if peer == caller_identity_key {
            return Err(DomainError::BadParams(
                "peer_key must not equal caller identity".into(),
            ));
        }
        let (channel_id, was_new) = self.store.create_channel(caller_identity_key, peer)?;
        Ok(CreateChannelResp { channel_id, was_new })
    }
}

View File

@@ -0,0 +1,352 @@
//! Delivery domain logic — enqueue, fetch, peek, ack.
//!
//! Pure business logic operating on `Store` trait and domain types.
//!
//! ## Multi-device delivery
//!
//! When a message is enqueued for a recipient identity, the service resolves
//! all registered device IDs for that identity and enqueues a copy of the
//! payload to each device-scoped queue. The queue key is a composite of
//! `identity_key + device_id`, so each device maintains its own sequence
//! counter and ack state.
//!
//! If the recipient has no registered devices, the message is delivered to
//! the bare `identity_key` queue (backwards compatible with single-device
//! clients).
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::Notify;
use crate::storage::Store;
use super::types::*;
/// Build a device-scoped recipient key: `identity_key || device_id`.
///
/// An empty `device_id` yields the bare `identity_key` unchanged, which is
/// the single-device compatibility path.
fn device_recipient_key(identity_key: &[u8], device_id: &[u8]) -> Vec<u8> {
    match device_id {
        [] => identity_key.to_vec(),
        _ => [identity_key, device_id].concat(),
    }
}
/// Shared state needed by delivery operations.
pub struct DeliveryService {
    /// Persistent message queues, device registry, and ack state.
    pub store: Arc<dyn Store>,
    /// Long-poll wakers keyed by (device-scoped) recipient queue key.
    /// Presumably populated by the fetch_wait RPC handlers — confirm.
    pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
}
impl DeliveryService {
    /// Resolve the device-scoped recipient keys for an identity.
    ///
    /// Produces one composite key (`identity_key || device_id`) per
    /// registered device. When no devices are registered — or the device
    /// lookup errors, which is treated the same — a single-element list with
    /// the bare `identity_key` is returned for backwards compatibility with
    /// single-device clients.
    fn resolve_device_keys(&self, identity_key: &[u8]) -> Vec<Vec<u8>> {
        let devices = self.store.list_devices(identity_key).unwrap_or_default();
        if devices.is_empty() {
            return vec![identity_key.to_vec()];
        }
        devices
            .into_iter()
            .map(|(device_id, _, _)| device_recipient_key(identity_key, &device_id))
            .collect()
    }

    /// Wake the long-polling waiter registered for `recipient_key`, if any.
    fn wake_waiter(&self, recipient_key: &[u8]) {
        if let Some(notify) = self.waiters.get(recipient_key) {
            notify.notify_one();
        }
    }

    /// Enqueue a payload for delivery to all devices of the recipient.
    ///
    /// Returns the sequence number from the *first* device queue (for
    /// backwards compatibility with single-device callers).
    pub fn enqueue(&self, req: EnqueueReq) -> Result<EnqueueResp, crate::storage::StorageError> {
        // ttl_secs == 0 means "no expiry".
        let ttl = (req.ttl_secs > 0).then_some(req.ttl_secs);
        let mut first_seq = None;
        for dk in self.resolve_device_keys(&req.recipient_key) {
            // One copy per device queue; each queue keeps its own counter.
            let seq = self
                .store
                .enqueue(&dk, &req.channel_id, req.payload.clone(), ttl)?;
            first_seq.get_or_insert(seq);
            self.wake_waiter(&dk);
        }
        // Also wake the bare identity_key waiter (legacy clients).
        self.wake_waiter(&req.recipient_key);
        Ok(EnqueueResp {
            seq: first_seq.unwrap_or_default(),
            delivery_proof: Vec::new(), // Proof generated at RPC handler layer (see v2_handlers/delivery.rs)
        })
    }

    /// Fetch and drain queued messages for a specific device.
    ///
    /// `recipient_key` should be the device-scoped composite key
    /// (`identity_key + device_id`) or the bare `identity_key` for
    /// single-device clients. A non-positive `limit` means "no limit".
    pub fn fetch(&self, req: FetchReq) -> Result<FetchResp, crate::storage::StorageError> {
        let messages = if req.limit > 0 {
            self.store
                .fetch_limited(&req.recipient_key, &req.channel_id, req.limit as usize)?
        } else {
            self.store.fetch(&req.recipient_key, &req.channel_id)?
        };
        let payloads = messages
            .into_iter()
            .map(|(seq, data)| Envelope { seq, data })
            .collect();
        Ok(FetchResp { payloads })
    }

    /// Peek at messages without removing them from the queue.
    pub fn peek(&self, req: PeekReq) -> Result<PeekResp, crate::storage::StorageError> {
        // Non-positive limits are normalized to 0 before hitting the store.
        let limit = if req.limit > 0 { req.limit as usize } else { 0 };
        let messages = self.store.peek(&req.recipient_key, &req.channel_id, limit)?;
        let payloads = messages
            .into_iter()
            .map(|(seq, data)| Envelope { seq, data })
            .collect();
        Ok(PeekResp { payloads })
    }

    /// Acknowledge (and discard) messages up to a sequence number, inclusive.
    pub fn ack(&self, req: AckReq) -> Result<(), crate::storage::StorageError> {
        self.store
            .ack(&req.recipient_key, &req.channel_id, req.seq_up_to)
            .map(|_| ())
    }

    /// Batch enqueue to multiple recipients (with multi-device fan-out for
    /// each). Returns one sequence number per recipient identity, taken from
    /// that identity's first device queue.
    pub fn batch_enqueue(
        &self,
        req: BatchEnqueueReq,
    ) -> Result<BatchEnqueueResp, crate::storage::StorageError> {
        let ttl = (req.ttl_secs > 0).then_some(req.ttl_secs);
        let mut seqs = Vec::with_capacity(req.recipient_keys.len());
        for rk in &req.recipient_keys {
            let mut first_seq = None;
            for dk in self.resolve_device_keys(rk) {
                let seq = self
                    .store
                    .enqueue(&dk, &req.channel_id, req.payload.clone(), ttl)?;
                first_seq.get_or_insert(seq);
                self.wake_waiter(&dk);
            }
            self.wake_waiter(rk);
            seqs.push(first_seq.unwrap_or_default());
        }
        Ok(BatchEnqueueResp { seqs })
    }

    /// Build a device-scoped recipient key from identity_key and device_id.
    /// Public helper for RPC handlers to build the correct fetch/ack key.
    pub fn device_recipient_key(identity_key: &[u8], device_id: &[u8]) -> Vec<u8> {
        device_recipient_key(identity_key, device_id)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::storage::FileBackedStore;
    // Fresh service over a temp-dir FileBackedStore. The TempDir must be
    // returned and held by the caller so the backing files outlive the test.
    fn test_service() -> (tempfile::TempDir, DeliveryService) {
        let dir = tempfile::tempdir().unwrap();
        let store = Arc::new(FileBackedStore::open(dir.path()).unwrap());
        let svc = DeliveryService {
            store,
            waiters: Arc::new(DashMap::new()),
        };
        (dir, svc)
    }
    #[test]
    fn enqueue_single_device_backwards_compat() {
        let (_dir, svc) = test_service();
        let ik = vec![1u8; 32];
        let ch = vec![0u8; 16];
        // No devices registered — should enqueue to bare identity_key.
        let resp = svc
            .enqueue(EnqueueReq {
                recipient_key: ik.clone(),
                payload: b"hello".to_vec(),
                channel_id: ch.clone(),
                ttl_secs: 0,
            })
            .unwrap();
        assert_eq!(resp.seq, 0);
        // Fetch from bare identity_key.
        let fetched = svc
            .fetch(FetchReq {
                recipient_key: ik,
                channel_id: ch,
                limit: 10,
            })
            .unwrap();
        assert_eq!(fetched.payloads.len(), 1);
        assert_eq!(fetched.payloads[0].data, b"hello");
    }
    #[test]
    fn enqueue_multi_device_fanout() {
        let (_dir, svc) = test_service();
        let ik = vec![2u8; 32];
        let ch = vec![0u8; 16];
        let dev_a = b"device-a".to_vec();
        let dev_b = b"device-b".to_vec();
        // Register two devices.
        svc.store
            .register_device(&ik, &dev_a, "Phone")
            .unwrap();
        svc.store
            .register_device(&ik, &dev_b, "Laptop")
            .unwrap();
        // Enqueue a message.
        svc.enqueue(EnqueueReq {
            recipient_key: ik.clone(),
            payload: b"fanout-msg".to_vec(),
            channel_id: ch.clone(),
            ttl_secs: 0,
        })
        .unwrap();
        // Each device should receive the message on its own queue.
        let key_a = device_recipient_key(&ik, &dev_a);
        let key_b = device_recipient_key(&ik, &dev_b);
        let msgs_a = svc
            .fetch(FetchReq {
                recipient_key: key_a,
                channel_id: ch.clone(),
                limit: 10,
            })
            .unwrap();
        assert_eq!(msgs_a.payloads.len(), 1);
        assert_eq!(msgs_a.payloads[0].data, b"fanout-msg");
        let msgs_b = svc
            .fetch(FetchReq {
                recipient_key: key_b,
                channel_id: ch.clone(),
                limit: 10,
            })
            .unwrap();
        assert_eq!(msgs_b.payloads.len(), 1);
        assert_eq!(msgs_b.payloads[0].data, b"fanout-msg");
        // Bare identity_key queue should be empty (not used when devices exist).
        let msgs_bare = svc
            .fetch(FetchReq {
                recipient_key: ik,
                channel_id: ch,
                limit: 10,
            })
            .unwrap();
        assert!(msgs_bare.payloads.is_empty());
    }
    #[test]
    fn batch_enqueue_multi_device() {
        let (_dir, svc) = test_service();
        let ik1 = vec![3u8; 32];
        let ik2 = vec![4u8; 32];
        let ch = vec![0u8; 16];
        let dev = b"dev1".to_vec();
        // ik1 has a device, ik2 has none.
        svc.store
            .register_device(&ik1, &dev, "Phone")
            .unwrap();
        let resp = svc
            .batch_enqueue(BatchEnqueueReq {
                recipient_keys: vec![ik1.clone(), ik2.clone()],
                payload: b"batch-msg".to_vec(),
                channel_id: ch.clone(),
                ttl_secs: 0,
            })
            .unwrap();
        // One seq per recipient identity, regardless of device count.
        assert_eq!(resp.seqs.len(), 2);
        // ik1 device should have the message.
        let key_1 = device_recipient_key(&ik1, &dev);
        let msgs_1 = svc
            .fetch(FetchReq {
                recipient_key: key_1,
                channel_id: ch.clone(),
                limit: 10,
            })
            .unwrap();
        assert_eq!(msgs_1.payloads.len(), 1);
        // ik2 (no devices) should have it on bare key.
        let msgs_2 = svc
            .fetch(FetchReq {
                recipient_key: ik2,
                channel_id: ch,
                limit: 10,
            })
            .unwrap();
        assert_eq!(msgs_2.payloads.len(), 1);
    }
    #[test]
    fn device_recipient_key_construction() {
        let ik = vec![1u8; 32];
        let dev = b"my-device".to_vec();
        // With device_id: identity_key || device_id, in that order.
        let key = device_recipient_key(&ik, &dev);
        assert_eq!(key.len(), 32 + dev.len());
        assert_eq!(&key[..32], &ik[..]);
        assert_eq!(&key[32..], dev.as_slice());
        // Empty device_id returns bare identity_key.
        let bare = device_recipient_key(&ik, &[]);
        assert_eq!(bare, ik);
    }
}

View File

@@ -0,0 +1,76 @@
//! Device registry domain logic — register, list, revoke devices.
use std::sync::Arc;
use crate::storage::Store;
use super::types::*;
/// Hard cap on registered devices per identity key; `register_device`
/// rejects registrations beyond this with `DomainError::DeviceLimit`.
const MAX_DEVICES_PER_IDENTITY: usize = 5;
/// Domain service for multi-device management: register, list, revoke.
pub struct DeviceService {
    pub store: Arc<dyn Store>,
}

impl DeviceService {
    /// Register a new device for the caller, enforcing the per-identity cap.
    ///
    /// Errors: `BadParams` on an empty device id, `DeviceLimit` when the
    /// identity already has `MAX_DEVICES_PER_IDENTITY` devices.
    pub fn register_device(
        &self,
        req: RegisterDeviceReq,
        caller_identity_key: &[u8],
    ) -> Result<RegisterDeviceResp, DomainError> {
        if req.device_id.is_empty() {
            return Err(DomainError::BadParams(
                "device_id must not be empty".into(),
            ));
        }
        // Check the cap before touching the registry.
        if self.store.device_count(caller_identity_key)? >= MAX_DEVICES_PER_IDENTITY {
            return Err(DomainError::DeviceLimit(MAX_DEVICES_PER_IDENTITY));
        }
        let success = self
            .store
            .register_device(caller_identity_key, &req.device_id, &req.device_name)?;
        Ok(RegisterDeviceResp { success })
    }

    /// List the caller's registered devices (id, name, registration time).
    pub fn list_devices(
        &self,
        caller_identity_key: &[u8],
    ) -> Result<ListDevicesResp, DomainError> {
        let devices = self
            .store
            .list_devices(caller_identity_key)?
            .into_iter()
            .map(|(device_id, device_name, registered_at)| DeviceInfo {
                device_id,
                device_name,
                registered_at,
            })
            .collect();
        Ok(ListDevicesResp { devices })
    }

    /// Revoke one of the caller's devices.
    ///
    /// Errors: `BadParams` on an empty device id, `DeviceNotFound` when the
    /// store reports nothing was removed.
    pub fn revoke_device(
        &self,
        req: RevokeDeviceReq,
        caller_identity_key: &[u8],
    ) -> Result<RevokeDeviceResp, DomainError> {
        if req.device_id.is_empty() {
            return Err(DomainError::BadParams(
                "device_id must not be empty".into(),
            ));
        }
        match self.store.revoke_device(caller_identity_key, &req.device_id)? {
            false => Err(DomainError::DeviceNotFound),
            success => Ok(RevokeDeviceResp { success }),
        }
    }
}

View File

@@ -0,0 +1,208 @@
//! Group management domain logic — metadata, membership tracking.
use std::sync::Arc;
use crate::storage::Store;
use super::types::*;
/// Domain service for group metadata and membership tracking.
pub struct GroupService {
    pub store: Arc<dyn Store>,
}

impl GroupService {
    /// Reject empty group ids with a uniform `BadParams` error.
    fn require_group_id(group_id: &[u8]) -> Result<(), DomainError> {
        if group_id.is_empty() {
            return Err(DomainError::BadParams("group_id must not be empty".into()));
        }
        Ok(())
    }

    /// Update group metadata (name, description, avatar_hash). The caller is
    /// recorded by the store alongside the metadata.
    pub fn update_metadata(
        &self,
        req: UpdateGroupMetadataReq,
        caller_identity_key: &[u8],
    ) -> Result<(), DomainError> {
        Self::require_group_id(&req.group_id)?;
        self.store.store_group_metadata(
            &req.group_id,
            &req.name,
            &req.description,
            &req.avatar_hash,
            caller_identity_key,
        )?;
        Ok(())
    }

    /// List group members with resolved usernames.
    ///
    /// Username resolution is best-effort: both lookup errors and missing
    /// entries surface as an empty username.
    pub fn list_members(
        &self,
        req: ListGroupMembersReq,
    ) -> Result<ListGroupMembersResp, DomainError> {
        Self::require_group_id(&req.group_id)?;
        let mut members = Vec::new();
        for (identity_key, joined_at) in self.store.list_group_members(&req.group_id)? {
            let username = self
                .store
                .resolve_identity_key(&identity_key)
                .ok()
                .flatten()
                .unwrap_or_default();
            members.push(GroupMemberInfo {
                identity_key,
                username,
                joined_at,
            });
        }
        Ok(ListGroupMembersResp { members })
    }

    /// Track a member addition in the server-side membership table.
    pub fn add_member(
        &self,
        group_id: &[u8],
        identity_key: &[u8],
    ) -> Result<(), DomainError> {
        self.store.add_group_member(group_id, identity_key)?;
        Ok(())
    }

    /// Track a member removal in the server-side membership table.
    /// Returns whether the member was actually present.
    pub fn remove_member(
        &self,
        group_id: &[u8],
        identity_key: &[u8],
    ) -> Result<bool, DomainError> {
        Ok(self.store.remove_group_member(group_id, identity_key)?)
    }

    /// Get group metadata, or `None` when the group is unknown.
    pub fn get_metadata(
        &self,
        group_id: &[u8],
    ) -> Result<Option<GroupMetadata>, DomainError> {
        Self::require_group_id(group_id)?;
        let meta = self.store.get_group_metadata(group_id)?.map(
            |(name, description, avatar_hash, creator_key, created_at)| GroupMetadata {
                group_id: group_id.to_vec(),
                name,
                description,
                avatar_hash,
                creator_key,
                created_at,
            },
        );
        Ok(meta)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::FileBackedStore;
    // Fresh service over a temp-dir FileBackedStore; the TempDir must stay
    // alive for the test's duration.
    fn make_service() -> (GroupService, tempfile::TempDir) {
        let dir = tempfile::tempdir().expect("tempdir");
        let store = FileBackedStore::open(dir.path()).expect("open store");
        let svc = GroupService {
            store: Arc::new(store),
        };
        (svc, dir)
    }
    #[test]
    fn update_and_get_metadata() {
        let (svc, _dir) = make_service();
        let group_id = b"test-group-00001".to_vec();
        let caller = b"caller-key".to_vec();
        svc.update_metadata(
            UpdateGroupMetadataReq {
                group_id: group_id.clone(),
                name: "Test Group".into(),
                description: "A test group".into(),
                avatar_hash: vec![0xAB],
            },
            &caller,
        )
        .expect("update_metadata should succeed");
        let meta = svc.get_metadata(&group_id).expect("get_metadata should succeed");
        let meta = meta.expect("metadata should exist");
        assert_eq!(meta.name, "Test Group");
        assert_eq!(meta.description, "A test group");
        assert_eq!(meta.avatar_hash, vec![0xAB]);
        // The caller is recorded as the metadata's creator.
        assert_eq!(meta.creator_key, caller);
    }
    #[test]
    fn get_metadata_nonexistent_returns_none() {
        let (svc, _dir) = make_service();
        let result = svc.get_metadata(b"no-such-group").expect("should not error");
        assert!(result.is_none());
    }
    #[test]
    fn empty_group_id_rejected() {
        let (svc, _dir) = make_service();
        let err = svc.update_metadata(
            UpdateGroupMetadataReq {
                group_id: Vec::new(),
                name: "X".into(),
                description: String::new(),
                avatar_hash: Vec::new(),
            },
            b"caller",
        );
        assert!(err.is_err());
    }
    #[test]
    fn add_list_remove_members() {
        let (svc, _dir) = make_service();
        let group_id = b"membership-group".to_vec();
        let member_a = b"member-a-key".to_vec();
        let member_b = b"member-b-key".to_vec();
        svc.add_member(&group_id, &member_a).expect("add a");
        svc.add_member(&group_id, &member_b).expect("add b");
        let resp = svc
            .list_members(ListGroupMembersReq {
                group_id: group_id.clone(),
            })
            .expect("list members");
        assert_eq!(resp.members.len(), 2);
        let removed = svc.remove_member(&group_id, &member_a).expect("remove a");
        assert!(removed);
        let resp = svc
            .list_members(ListGroupMembersReq {
                group_id: group_id.clone(),
            })
            .expect("list members after removal");
        assert_eq!(resp.members.len(), 1);
        assert_eq!(resp.members[0].identity_key, member_b);
    }
    #[test]
    fn remove_nonexistent_member_returns_false() {
        let (svc, _dir) = make_service();
        let removed = svc
            .remove_member(b"group", b"nobody")
            .expect("remove nonexistent");
        // Removal of an unknown member is not an error, just `false`.
        assert!(!removed);
    }
}

View File

@@ -0,0 +1,93 @@
//! Key management domain logic — KeyPackage and hybrid key operations.
use std::sync::Arc;
use sha2::{Digest, Sha256};
use crate::storage::Store;
use super::types::*;
const MAX_KEYPACKAGE_BYTES: usize = 1024 * 1024; // 1 MB cap enforced by upload_key_package
/// Domain service for MLS KeyPackage and hybrid (PQ) key management.
pub struct KeyService {
    /// Persistence backend for key packages and hybrid public keys.
    pub store: Arc<dyn Store>,
}
impl KeyService {
    /// Validate and store an MLS KeyPackage for an identity.
    ///
    /// Returns the SHA-256 fingerprint of the stored package. Errors:
    /// `InvalidIdentityKey` (key not 32 bytes), `EmptyPackage`,
    /// `PackageTooLarge` (over `MAX_KEYPACKAGE_BYTES`).
    pub fn upload_key_package(
        &self,
        req: UploadKeyPackageReq,
        _auth: &CallerAuth,
    ) -> Result<UploadKeyPackageResp, DomainError> {
        if req.identity_key.len() != 32 {
            return Err(DomainError::InvalidIdentityKey(req.identity_key.len()));
        }
        match req.package.len() {
            0 => return Err(DomainError::EmptyPackage),
            n if n > MAX_KEYPACKAGE_BYTES => return Err(DomainError::PackageTooLarge(n)),
            _ => {}
        }
        // Fingerprint before the package is moved into the store.
        let fingerprint = Sha256::digest(&req.package).to_vec();
        self.store
            .upload_key_package(&req.identity_key, req.package)?;
        Ok(UploadKeyPackageResp { fingerprint })
    }

    /// Fetch a stored KeyPackage; an unknown identity yields an empty package.
    pub fn fetch_key_package(
        &self,
        req: FetchKeyPackageReq,
        _auth: &CallerAuth,
    ) -> Result<FetchKeyPackageResp, DomainError> {
        let package = self
            .store
            .fetch_key_package(&req.identity_key)?
            .unwrap_or_default();
        Ok(FetchKeyPackageResp { package })
    }

    /// Store a hybrid (post-quantum) public key for an identity.
    ///
    /// Errors: `InvalidIdentityKey` (key not 32 bytes), `EmptyHybridKey`.
    pub fn upload_hybrid_key(
        &self,
        req: UploadHybridKeyReq,
        _auth: &CallerAuth,
    ) -> Result<(), DomainError> {
        if req.identity_key.len() != 32 {
            return Err(DomainError::InvalidIdentityKey(req.identity_key.len()));
        }
        if req.hybrid_public_key.is_empty() {
            return Err(DomainError::EmptyHybridKey);
        }
        self.store
            .upload_hybrid_key(&req.identity_key, req.hybrid_public_key)?;
        Ok(())
    }

    /// Fetch one hybrid key; unknown identities yield an empty key.
    pub fn fetch_hybrid_key(
        &self,
        req: FetchHybridKeyReq,
        _auth: &CallerAuth,
    ) -> Result<FetchHybridKeyResp, DomainError> {
        let hybrid_public_key = self
            .store
            .fetch_hybrid_key(&req.identity_key)?
            .unwrap_or_default();
        Ok(FetchHybridKeyResp { hybrid_public_key })
    }

    /// Batch variant of [`Self::fetch_hybrid_key`]; output order matches the
    /// input key order, with empty entries for unknown identities.
    pub fn fetch_hybrid_keys(
        &self,
        req: FetchHybridKeysReq,
        _auth: &CallerAuth,
    ) -> Result<FetchHybridKeysResp, DomainError> {
        let keys = req
            .identity_keys
            .iter()
            .map(|ik| Ok(self.store.fetch_hybrid_key(ik)?.unwrap_or_default()))
            .collect::<Result<Vec<_>, DomainError>>()?;
        Ok(FetchHybridKeysResp { keys })
    }
}

View File

@@ -0,0 +1,24 @@
//! Domain types and service logic — protocol-agnostic.
//!
//! These types define the server's business logic independently of any
//! serialization format (Cap'n Proto, Protobuf). RPC handlers translate
//! wire-format messages into these types, call service functions, and
//! translate the results back.
pub mod types;
pub mod auth;
pub mod delivery;
pub mod keys;
pub mod channels;
pub mod users;
pub mod blobs;
pub mod devices;
pub mod groups;
pub mod p2p;
pub mod account;
pub mod moderation;
pub mod notification;
pub mod rate_limit;
pub mod recovery;
#[cfg(feature = "traffic-resistance")]
pub mod traffic_resistance;

View File

@@ -0,0 +1,304 @@
//! Moderation domain logic — report, ban, unban, list.
//!
//! Pure business logic operating on `Store` trait and domain types.
use std::sync::Arc;
use crate::storage::Store;
use super::types::*;
/// Shared state needed by moderation operations.
pub struct ModerationService {
    /// Persistence backend for reports and the ban list.
    pub store: Arc<dyn Store>,
}
impl ModerationService {
    /// Submit an encrypted report for a message.
    ///
    /// The report body is opaque to the server (client-side encrypted); only
    /// a truncated reporter-key prefix is written to the audit log.
    pub fn report_message(
        &self,
        req: ReportMessageReq,
    ) -> Result<ReportMessageResp, DomainError> {
        if req.encrypted_report.is_empty() {
            return Err(DomainError::BadParams(
                "encrypted report must not be empty".into(),
            ));
        }
        self.store
            .store_report(
                &req.encrypted_report,
                &req.conversation_id,
                &req.reporter_identity,
            )
            .map_err(DomainError::Storage)?;
        tracing::info!(
            reporter_prefix = %hex_prefix(&req.reporter_identity),
            "audit: message reported"
        );
        Ok(ReportMessageResp { accepted: true })
    }

    /// Ban a user by identity key. `duration_secs == 0` means permanent
    /// (stored as `expires_at == 0`).
    pub fn ban_user(&self, req: BanUserReq) -> Result<BanUserResp, DomainError> {
        if req.identity_key.len() != 32 {
            return Err(DomainError::InvalidIdentityKey(req.identity_key.len()));
        }
        let expires_at = match req.duration_secs {
            0 => 0, // permanent
            d => now_secs() + d,
        };
        self.store
            .ban_user(&req.identity_key, &req.reason, expires_at)
            .map_err(DomainError::Storage)?;
        tracing::info!(
            identity_prefix = %hex_prefix(&req.identity_key),
            reason = %req.reason,
            expires_at,
            "audit: user banned"
        );
        Ok(BanUserResp { success: true })
    }

    /// Unban a user by identity key. `success` is `false` when the user was
    /// not on the ban list; only an actual removal is audit-logged.
    pub fn unban_user(&self, req: UnbanUserReq) -> Result<UnbanUserResp, DomainError> {
        if req.identity_key.len() != 32 {
            return Err(DomainError::InvalidIdentityKey(req.identity_key.len()));
        }
        let removed = self
            .store
            .unban_user(&req.identity_key)
            .map_err(DomainError::Storage)?;
        if removed {
            tracing::info!(
                identity_prefix = %hex_prefix(&req.identity_key),
                "audit: user unbanned"
            );
        }
        Ok(UnbanUserResp { success: removed })
    }

    /// Check if a user is currently banned; `Some(reason)` when banned.
    pub fn check_ban(&self, identity_key: &[u8]) -> Result<Option<String>, DomainError> {
        self.store
            .is_banned(identity_key)
            .map_err(DomainError::Storage)
    }

    /// List stored reports with offset/limit pagination.
    pub fn list_reports(&self, req: ListReportsReq) -> Result<ListReportsResp, DomainError> {
        let mut reports = Vec::new();
        for (id, encrypted_report, conversation_id, reporter_identity, timestamp) in self
            .store
            .list_reports(req.limit, req.offset)
            .map_err(DomainError::Storage)?
        {
            reports.push(ReportEntry {
                id,
                encrypted_report,
                conversation_id,
                reporter_identity,
                timestamp,
            });
        }
        Ok(ListReportsResp { reports })
    }

    /// List all currently banned users.
    pub fn list_banned(&self) -> Result<ListBannedResp, DomainError> {
        let mut users = Vec::new();
        for (identity_key, reason, banned_at, expires_at) in
            self.store.list_banned().map_err(DomainError::Storage)?
        {
            users.push(BannedUserEntry {
                identity_key,
                reason,
                banned_at,
                expires_at,
            });
        }
        Ok(ListBannedResp { users })
    }
}
/// Seconds since the Unix epoch; 0 if the system clock is before 1970.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0)
}
/// Render at most the first 4 bytes as lowercase hex followed by "..." —
/// used to audit-log identity keys without revealing them in full.
fn hex_prefix(bytes: &[u8]) -> String {
    let mut out = String::with_capacity(11);
    for b in bytes.iter().take(4) {
        out.push_str(&format!("{b:02x}"));
    }
    out.push_str("...");
    out
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::storage::FileBackedStore;
    // Fresh service over a temp-dir FileBackedStore; the TempDir must stay
    // alive for the test's duration.
    fn test_service() -> (tempfile::TempDir, ModerationService) {
        let dir = tempfile::tempdir().unwrap();
        let store = Arc::new(FileBackedStore::open(dir.path()).unwrap());
        let svc = ModerationService { store };
        (dir, svc)
    }
    #[test]
    fn report_store_and_list() {
        let (_dir, svc) = test_service();
        let resp = svc
            .report_message(ReportMessageReq {
                encrypted_report: vec![1, 2, 3],
                conversation_id: vec![10; 16],
                reporter_identity: vec![20; 32],
            })
            .unwrap();
        assert!(resp.accepted);
        let reports = svc
            .list_reports(ListReportsReq {
                limit: 10,
                offset: 0,
            })
            .unwrap();
        assert_eq!(reports.reports.len(), 1);
        assert_eq!(reports.reports[0].encrypted_report, vec![1, 2, 3]);
        assert_eq!(reports.reports[0].conversation_id, vec![10; 16]);
        assert_eq!(reports.reports[0].reporter_identity, vec![20; 32]);
    }
    #[test]
    fn report_empty_rejected() {
        let (_dir, svc) = test_service();
        let result = svc.report_message(ReportMessageReq {
            encrypted_report: vec![],
            conversation_id: vec![10; 16],
            reporter_identity: vec![20; 32],
        });
        assert!(result.is_err());
    }
    #[test]
    fn ban_unban_lifecycle() {
        let (_dir, svc) = test_service();
        let ik = vec![1u8; 32];
        // Not banned initially.
        assert!(svc.check_ban(&ik).unwrap().is_none());
        // Ban permanently (duration_secs == 0).
        let resp = svc
            .ban_user(BanUserReq {
                identity_key: ik.clone(),
                reason: "spam".into(),
                duration_secs: 0,
            })
            .unwrap();
        assert!(resp.success);
        // Now banned.
        let reason = svc.check_ban(&ik).unwrap();
        assert_eq!(reason, Some("spam".to_string()));
        // Listed in banned users.
        let banned = svc.list_banned().unwrap();
        assert_eq!(banned.users.len(), 1);
        assert_eq!(banned.users[0].identity_key, ik);
        assert_eq!(banned.users[0].reason, "spam");
        assert_eq!(banned.users[0].expires_at, 0); // permanent
        // Unban.
        let resp = svc.unban_user(UnbanUserReq { identity_key: ik.clone() }).unwrap();
        assert!(resp.success);
        // No longer banned.
        assert!(svc.check_ban(&ik).unwrap().is_none());
        assert!(svc.list_banned().unwrap().users.is_empty());
    }
    #[test]
    fn ban_invalid_identity_key() {
        let (_dir, svc) = test_service();
        let result = svc.ban_user(BanUserReq {
            identity_key: vec![1u8; 16], // wrong length
            reason: "test".into(),
            duration_secs: 0,
        });
        assert!(result.is_err());
    }
    #[test]
    fn list_reports_pagination() {
        let (_dir, svc) = test_service();
        // Five reports with distinguishable one-byte bodies.
        for i in 0..5u8 {
            svc.report_message(ReportMessageReq {
                encrypted_report: vec![i],
                conversation_id: vec![10; 16],
                reporter_identity: vec![20; 32],
            })
            .unwrap();
        }
        let page1 = svc
            .list_reports(ListReportsReq {
                limit: 2,
                offset: 0,
            })
            .unwrap();
        assert_eq!(page1.reports.len(), 2);
        assert_eq!(page1.reports[0].encrypted_report, vec![0]);
        let page2 = svc
            .list_reports(ListReportsReq {
                limit: 2,
                offset: 2,
            })
            .unwrap();
        assert_eq!(page2.reports.len(), 2);
        assert_eq!(page2.reports[0].encrypted_report, vec![2]);
        // Last page is short: only one report remains past offset 4.
        let page3 = svc
            .list_reports(ListReportsReq {
                limit: 2,
                offset: 4,
            })
            .unwrap();
        assert_eq!(page3.reports.len(), 1);
    }
    #[test]
    fn unban_nonexistent_returns_false() {
        let (_dir, svc) = test_service();
        let resp = svc
            .unban_user(UnbanUserReq {
                identity_key: vec![99u8; 32],
            })
            .unwrap();
        assert!(!resp.success);
    }
}

View File

@@ -0,0 +1,131 @@
//! Cross-node notification bus for message delivery fan-out.
//!
//! When a message is enqueued, the bus publishes a notification so that
//! any node running a `fetch_wait` long-poll for that recipient can
//! wake up — even if the enqueue happened on a different node.
//!
//! Two backends:
//! - `InMemoryNotificationBus`: single-node, tokio::sync::Notify (default)
//! - Redis pub/sub (feature-gated `redis-pubsub`, implemented externally)
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::Notify;
// ── Trait ────────────────────────────────────────────────────────────────────
/// Cross-node notification bus.
///
/// Publishers call `publish` when a message is enqueued.
/// Subscribers call `subscribe` to get a future that resolves when
/// a notification arrives for the given topic.
///
/// NOTE(review): delivery guarantees are implementation-defined — the
/// in-memory backend drops publishes with no registered subscriber entry;
/// confirm what a Redis/NATS backend should guarantee before relying on
/// at-least-once semantics.
pub trait NotificationBus: Send + Sync {
    /// Notify all waiters for `topic` that new data is available.
    fn publish(&self, topic: &[u8]);
    /// Return a future that completes when `topic` receives a notification.
    /// The returned `Notify` can be `.notified().await`'d.
    fn subscribe(&self, topic: &[u8]) -> Arc<Notify>;
}
// ── In-memory implementation ────────────────────────────────────────────────
/// Single-node notification bus backed by `tokio::sync::Notify`.
///
/// This is the default for single-node deployments. For multi-node,
/// replace with a Redis pub/sub or NATS implementation.
pub struct InMemoryNotificationBus {
    // One shared Notify per topic, created lazily on first subscribe.
    // Entries are never removed, so the map grows with distinct topics.
    waiters: DashMap<Vec<u8>, Arc<Notify>>,
}
impl InMemoryNotificationBus {
    /// Create an empty bus with no registered topics.
    pub fn new() -> Self {
        Self {
            waiters: DashMap::new(),
        }
    }
}
impl Default for InMemoryNotificationBus {
    // Same as `new()`; provided so the bus works with `Default`-based wiring.
    fn default() -> Self {
        Self::new()
    }
}
impl NotificationBus for InMemoryNotificationBus {
    /// Wake every waiter currently registered on `topic`.
    ///
    /// A publish to a topic nobody has subscribed to is dropped — no permit
    /// is stored for future subscribers.
    fn publish(&self, topic: &[u8]) {
        let Some(notify) = self.waiters.get(topic) else {
            return;
        };
        notify.notify_waiters();
    }

    /// Return the shared `Notify` for `topic`, creating it on first use.
    fn subscribe(&self, topic: &[u8]) -> Arc<Notify> {
        let entry = self
            .waiters
            .entry(topic.to_vec())
            .or_insert_with(|| Arc::new(Notify::new()));
        Arc::clone(&entry)
    }
}
/// Create the default notification bus (in-memory, single-node).
pub fn default_notification_bus() -> Arc<dyn NotificationBus> {
    // Multi-node deployments should swap in a pub/sub-backed implementation.
    Arc::new(InMemoryNotificationBus::new())
}
// ── Tests ───────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[tokio::test]
async fn publish_wakes_subscriber() {
let bus = InMemoryNotificationBus::new();
let topic = b"user:alice";
let notify = bus.subscribe(topic);
let notified = notify.notified();
// Publish from another "node" (same process in this case).
bus.publish(topic);
// Should resolve immediately since we published.
tokio::time::timeout(Duration::from_millis(100), notified)
.await
.expect("notification should arrive");
}
#[tokio::test]
async fn no_publish_times_out() {
let bus = InMemoryNotificationBus::new();
let topic = b"user:bob";
let notify = bus.subscribe(topic);
let notified = notify.notified();
let result = tokio::time::timeout(Duration::from_millis(50), notified).await;
assert!(result.is_err(), "should time out without publish");
}
#[tokio::test]
async fn independent_topics() {
let bus = InMemoryNotificationBus::new();
let notify_a = bus.subscribe(b"topic-a");
let notified_a = notify_a.notified();
let notify_b = bus.subscribe(b"topic-b");
let notified_b = notify_b.notified();
// Only publish to topic-a.
bus.publish(b"topic-a");
tokio::time::timeout(Duration::from_millis(100), notified_a)
.await
.expect("topic-a should wake");
let result = tokio::time::timeout(Duration::from_millis(50), notified_b).await;
assert!(result.is_err(), "topic-b should not wake");
}
}

View File

@@ -0,0 +1,50 @@
//! P2P endpoint domain logic — publish, resolve, health.
use std::sync::Arc;
use crate::storage::Store;
use super::types::*;
/// Domain service for P2P endpoint management and health checks.
pub struct P2pService {
    pub store: Arc<dyn Store>,
}

impl P2pService {
    /// Reject identity keys that are not exactly 32 bytes.
    fn check_key(identity_key: &[u8]) -> Result<(), DomainError> {
        if identity_key.len() != 32 {
            return Err(DomainError::InvalidIdentityKey(identity_key.len()));
        }
        Ok(())
    }

    /// Publish the caller's P2P node address under their identity key.
    pub fn publish_endpoint(
        &self,
        req: PublishEndpointReq,
        _auth: &CallerAuth,
    ) -> Result<(), DomainError> {
        Self::check_key(&req.identity_key)?;
        self.store
            .publish_endpoint(&req.identity_key, req.node_addr)?;
        Ok(())
    }

    /// Resolve an identity's published node address; empty when none is
    /// published.
    pub fn resolve_endpoint(
        &self,
        req: ResolveEndpointReq,
        _auth: &CallerAuth,
    ) -> Result<ResolveEndpointResp, DomainError> {
        Self::check_key(&req.identity_key)?;
        let node_addr = self
            .store
            .resolve_endpoint(&req.identity_key)?
            .unwrap_or_default();
        Ok(ResolveEndpointResp { node_addr })
    }

    /// Static liveness probe: always reports "ok".
    pub fn health() -> HealthResp {
        HealthResp {
            status: "ok".into(),
        }
    }
}

View File

@@ -0,0 +1,257 @@
//! Distributed rate limiting — sliding window algorithm.
//!
//! Two backends:
//! - `InMemoryRateLimiter`: single-process, DashMap-based (default)
//! - `RedisRateLimiter`: shared across nodes via Redis (feature-gated `redis-ratelimit`)
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
// ── Public types ────────────────────────────────────────────────────────────
/// Result of a rate-limit check.
#[derive(Debug, Clone)]
pub struct RateResult {
    /// Whether the request is allowed.
    pub allowed: bool,
    /// Remaining requests in the current window (0 when denied).
    pub remaining: u32,
    /// When the window resets (seconds from now); 0 when the request
    /// was allowed.
    pub retry_after_secs: u32,
}
/// Configuration for a specific rate-limit bucket.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests in the window.
    pub max_requests: u32,
    /// Length of the sliding window.
    pub window: Duration,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
max_requests: 100,
window: Duration::from_secs(60),
}
}
}
// ── Trait ────────────────────────────────────────────────────────────────────
/// Abstraction over rate-limit backends.
pub trait RateLimiter: Send + Sync {
    /// Check whether `key` is within its rate limit. If allowed, the counter
    /// is incremented atomically. A denied check must NOT consume quota.
    fn check_rate(&self, key: &str, config: &RateLimitConfig) -> RateResult;
}
// ── In-memory sliding window ────────────────────────────────────────────────
/// Per-key state for the sliding window algorithm.
struct SlidingWindow {
    /// Timestamps (epoch milliseconds) of recent requests within the window;
    /// stale entries are evicted lazily on each check.
    timestamps: Vec<u64>,
}
/// In-memory rate limiter using a sliding window log.
pub struct InMemoryRateLimiter {
    // Per-key request logs; DashMap shards give per-key concurrency.
    buckets: DashMap<String, SlidingWindow>,
    /// Last time we ran GC on expired entries.
    last_gc: std::sync::Mutex<Instant>,
}
impl InMemoryRateLimiter {
    /// Create an empty limiter with a fresh GC clock.
    pub fn new() -> Self {
        Self {
            buckets: DashMap::new(),
            last_gc: std::sync::Mutex::new(Instant::now()),
        }
    }
    /// Wall-clock time in milliseconds since the Unix epoch (0 if the system
    /// clock is before 1970).
    fn now_millis() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }
    /// Remove entries whose entire window has expired. Called lazily from
    /// `check_rate`, at most once per minute.
    ///
    /// FIX: the deadline check and the `last_gc` update now happen inside a
    /// single critical section. Previously they used two separate lock
    /// acquisitions, so concurrent callers could both observe an elapsed
    /// deadline and both run the full-table sweep.
    ///
    /// NOTE(review): the sweep judges staleness by the *caller's* window, so
    /// a bucket populated under a longer window can be evicted early. That
    /// only resets the evicted key's counter (transiently more permissive,
    /// never tighter) — confirm whether per-key windows need to be persisted.
    fn gc_if_needed(&self, window: Duration) {
        {
            // Test-and-set under one lock: exactly one caller claims the GC
            // slot. A poisoned lock skips GC (fail-open, limiter still works).
            let Ok(mut last) = self.last_gc.lock() else {
                return;
            };
            if last.elapsed() <= Duration::from_secs(60) {
                return;
            }
            *last = Instant::now();
        }
        let now_ms = Self::now_millis();
        let window_ms = window.as_millis() as u64;
        self.buckets.retain(|_key, window_state| {
            // Keep the bucket if any timestamp is still within the window.
            window_state
                .timestamps
                .iter()
                .any(|&ts| now_ms.saturating_sub(ts) < window_ms)
        });
    }
}
impl Default for InMemoryRateLimiter {
    /// Equivalent to [`InMemoryRateLimiter::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl RateLimiter for InMemoryRateLimiter {
    /// Sliding-window check: evict stale timestamps for `key`, then either
    /// record the request (allowed) or report when the window frees up.
    fn check_rate(&self, key: &str, config: &RateLimitConfig) -> RateResult {
        let now_ms = Self::now_millis();
        let window_ms = config.window.as_millis() as u64;
        self.gc_if_needed(config.window);

        let mut bucket = self
            .buckets
            .entry(key.to_string())
            .or_insert_with(|| SlidingWindow { timestamps: Vec::new() });

        // Drop request timestamps that have slid out of the window.
        let cutoff = now_ms.saturating_sub(window_ms);
        bucket.timestamps.retain(|&ts| ts > cutoff);

        let used = bucket.timestamps.len() as u32;
        if used < config.max_requests {
            // Under the limit: record this request and report what's left.
            bucket.timestamps.push(now_ms);
            return RateResult {
                allowed: true,
                remaining: config.max_requests.saturating_sub(used + 1),
                retry_after_secs: 0,
            };
        }

        // Over the limit: retry once the earliest recorded request expires.
        let earliest = bucket.timestamps.iter().copied().min().unwrap_or(now_ms);
        let retry_after_ms = (earliest + window_ms).saturating_sub(now_ms);
        RateResult {
            allowed: false,
            remaining: 0,
            retry_after_secs: (retry_after_ms / 1000).max(1) as u32,
        }
    }
}
/// Create the default rate limiter (in-memory).
pub fn default_rate_limiter() -> Arc<dyn RateLimiter> {
    let limiter = InMemoryRateLimiter::new();
    Arc::new(limiter)
}
// ── Tests ───────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh limiter admits up to `max_requests` calls for a single key.
    #[test]
    fn allows_within_limit() {
        let limiter = InMemoryRateLimiter::new();
        let config = RateLimitConfig {
            max_requests: 3,
            window: Duration::from_secs(60),
        };
        for _ in 0..3 {
            let result = limiter.check_rate("user1", &config);
            assert!(result.allowed);
        }
    }

    /// The first request past the limit is denied with a positive retry hint.
    #[test]
    fn blocks_over_limit() {
        let limiter = InMemoryRateLimiter::new();
        let config = RateLimitConfig {
            max_requests: 2,
            window: Duration::from_secs(60),
        };
        assert!(limiter.check_rate("user1", &config).allowed);
        assert!(limiter.check_rate("user1", &config).allowed);
        let result = limiter.check_rate("user1", &config);
        assert!(!result.allowed);
        assert_eq!(result.remaining, 0);
        assert!(result.retry_after_secs > 0);
    }

    /// Buckets are tracked per key; one key's exhaustion never affects another.
    #[test]
    fn independent_keys() {
        let limiter = InMemoryRateLimiter::new();
        let config = RateLimitConfig {
            max_requests: 1,
            window: Duration::from_secs(60),
        };
        assert!(limiter.check_rate("user1", &config).allowed);
        assert!(!limiter.check_rate("user1", &config).allowed);
        // Different key should still be allowed.
        assert!(limiter.check_rate("user2", &config).allowed);
    }

    /// `remaining` counts down by exactly one per allowed request.
    #[test]
    fn remaining_decreases() {
        let limiter = InMemoryRateLimiter::new();
        let config = RateLimitConfig {
            max_requests: 5,
            window: Duration::from_secs(60),
        };
        let r1 = limiter.check_rate("user1", &config);
        assert_eq!(r1.remaining, 4);
        let r2 = limiter.check_rate("user1", &config);
        assert_eq!(r2.remaining, 3);
    }

    /// 10 threads × 100 requests against one key must count exactly 1000 —
    /// exercises the DashMap entry path under contention.
    #[test]
    fn concurrent_access_is_safe() {
        use std::sync::Arc;
        use std::thread;
        let limiter = Arc::new(InMemoryRateLimiter::new());
        let config = RateLimitConfig {
            max_requests: 1000,
            window: Duration::from_secs(60),
        };
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let l = Arc::clone(&limiter);
                let c = config.clone();
                thread::spawn(move || {
                    for _ in 0..100 {
                        l.check_rate("shared_key", &c);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        // After 1000 requests exactly, next should be blocked.
        let result = limiter.check_rate("shared_key", &config);
        assert!(!result.allowed);
    }
}

View File

@@ -0,0 +1,76 @@
//! Recovery domain logic — encrypted recovery bundle CRUD.
use std::sync::Arc;
use crate::storage::Store;
use super::types::DomainError;
/// Maximum recovery bundle size: 64 KiB.
const MAX_BUNDLE_SIZE: usize = 64 * 1024;

/// Default TTL for recovery bundles: 90 days.
pub const DEFAULT_TTL_SECS: u64 = 90 * 24 * 60 * 60;

/// Domain service for recovery bundle operations.
pub struct RecoveryService {
    /// Backing store for the encrypted bundles.
    pub store: Arc<dyn Store>,
}

impl RecoveryService {
    /// Validate that `token_hash` is exactly 32 bytes (a SHA-256 digest).
    ///
    /// Extracted because store/fetch/delete all performed this identical
    /// check inline; the error message is byte-for-byte the original one.
    fn validate_token_hash(token_hash: &[u8]) -> Result<(), DomainError> {
        if token_hash.len() != 32 {
            return Err(DomainError::BadParams(format!(
                "token_hash must be 32 bytes, got {}",
                token_hash.len()
            )));
        }
        Ok(())
    }

    /// Store an encrypted recovery bundle.
    ///
    /// `token_hash` is the SHA-256 of a recovery token derived from the code.
    /// `bundle` is the encrypted blob (opaque to server).
    /// `ttl_secs` is the time-to-live; 0 uses the default (90 days).
    pub fn store_bundle(
        &self,
        token_hash: &[u8],
        bundle: Vec<u8>,
        ttl_secs: u64,
    ) -> Result<(), DomainError> {
        Self::validate_token_hash(token_hash)?;
        if bundle.is_empty() {
            return Err(DomainError::BadParams("recovery bundle must not be empty".into()));
        }
        if bundle.len() > MAX_BUNDLE_SIZE {
            return Err(DomainError::BadParams(format!(
                "recovery bundle exceeds max size ({} > {MAX_BUNDLE_SIZE})",
                bundle.len()
            )));
        }
        let ttl = if ttl_secs == 0 { DEFAULT_TTL_SECS } else { ttl_secs };
        self.store.store_recovery_bundle(token_hash, bundle, ttl)?;
        Ok(())
    }

    /// Fetch an encrypted recovery bundle by token_hash.
    /// Returns `Ok(None)` when no bundle is stored for this token.
    pub fn fetch_bundle(&self, token_hash: &[u8]) -> Result<Option<Vec<u8>>, DomainError> {
        Self::validate_token_hash(token_hash)?;
        Ok(self.store.get_recovery_bundle(token_hash)?)
    }

    /// Delete an encrypted recovery bundle by token_hash.
    /// Returns whether a bundle was actually deleted.
    pub fn delete_bundle(&self, token_hash: &[u8]) -> Result<bool, DomainError> {
        Self::validate_token_hash(token_hash)?;
        Ok(self.store.delete_recovery_bundle(token_hash)?)
    }
}

View File

@@ -0,0 +1,249 @@
//! Traffic analysis resistance — decoy traffic generation and timing jitter.
//!
//! When enabled (via the `traffic-resistance` feature), the server:
//!
//! 1. Pads all enqueued payloads to a uniform boundary using [`quicprochat_core::padding::pad_uniform`].
//! 2. Injects random jitter delays before enqueue responses to mask timing patterns.
//! 3. Runs a background decoy traffic generator that enqueues fake encrypted messages
//! at a configurable rate to connected recipients.
//!
//! Decoy messages are indistinguishable from real padded messages on the wire.
//! Recipients detect and discard them by unpadding to an empty payload.
use std::sync::Arc;
use rand::Rng;
use tokio::sync::Notify;
use super::delivery::DeliveryService;
use super::types::EnqueueReq;
/// Configuration for traffic analysis resistance.
#[derive(Clone, Debug)]
pub struct TrafficResistanceConfig {
    /// Padding boundary in bytes (default 256). All enqueued payloads are
    /// padded to the nearest multiple of this value.
    pub padding_boundary: usize,
    /// Mean interval in milliseconds between decoy messages per recipient.
    /// Set to 0 to disable decoy traffic.
    pub decoy_interval_ms: u64,
    /// Maximum random jitter in milliseconds added before enqueue responses.
    /// Set to 0 to disable jitter.
    pub jitter_max_ms: u64,
}
impl Default for TrafficResistanceConfig {
    /// Defaults: core library's padding boundary, 5 s mean decoy interval,
    /// up to 50 ms of response jitter.
    fn default() -> Self {
        Self {
            padding_boundary: quicprochat_core::padding::DEFAULT_PADDING_BOUNDARY,
            decoy_interval_ms: 5000,
            jitter_max_ms: 50,
        }
    }
}
/// Pad a payload to the configured uniform boundary.
///
/// Thin wrapper over [`quicprochat_core::padding::pad_uniform`] using
/// `config.padding_boundary`; the result's length is a multiple of the boundary.
pub fn pad_payload(payload: &[u8], config: &TrafficResistanceConfig) -> Vec<u8> {
    quicprochat_core::padding::pad_uniform(payload, config.padding_boundary)
}
/// Apply random jitter delay to mask timing patterns.
///
/// Sleeps for a random duration in `[0, config.jitter_max_ms)` milliseconds.
/// Does nothing if `jitter_max_ms` is 0.
pub async fn apply_jitter(config: &TrafficResistanceConfig) {
    let max = config.jitter_max_ms;
    if max == 0 {
        return;
    }
    let delay_ms = rand::thread_rng().gen_range(0..max);
    if delay_ms == 0 {
        return;
    }
    tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
/// Spawn a background task that generates decoy traffic.
///
/// Sends decoy messages to the provided `recipient_keys` at random intervals
/// around `config.decoy_interval_ms`. The task runs until `shutdown` is notified.
///
/// Returns a `JoinHandle` for the spawned task. With `decoy_interval_ms == 0`
/// or no recipients, the task idles until shutdown and sends nothing.
pub fn spawn_decoy_generator(
    delivery: Arc<DeliveryService>,
    recipient_keys: Vec<Vec<u8>>,
    channel_id: Vec<u8>,
    config: TrafficResistanceConfig,
    shutdown: Arc<Notify>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        if config.decoy_interval_ms == 0 || recipient_keys.is_empty() {
            // Decoy traffic disabled or no recipients — wait for shutdown.
            shutdown.notified().await;
            return;
        }
        let base_interval = std::time::Duration::from_millis(config.decoy_interval_ms);
        loop {
            // Randomize interval: 50%–150% of base to avoid periodic patterns.
            let jitter_factor: f64 = rand::thread_rng().gen_range(0.5..1.5);
            let interval = base_interval.mul_f64(jitter_factor);
            // Race the sleep against shutdown so the task stops promptly.
            tokio::select! {
                () = tokio::time::sleep(interval) => {}
                () = shutdown.notified() => {
                    tracing::debug!("decoy traffic generator shutting down");
                    return;
                }
            }
            // Pick a random recipient.
            let idx = rand::thread_rng().gen_range(0..recipient_keys.len());
            let recipient_key = &recipient_keys[idx];
            // Generate a decoy payload that is indistinguishable from a real padded message.
            let decoy = quicprochat_core::padding::generate_decoy(config.padding_boundary);
            let req = EnqueueReq {
                recipient_key: recipient_key.clone(),
                payload: decoy,
                channel_id: channel_id.clone(),
                ttl_secs: 60, // Short TTL for decoys.
            };
            // Best-effort: a failed decoy injection is logged, never fatal.
            match delivery.enqueue(req) {
                Ok(_) => {
                    tracing::trace!("decoy message injected");
                }
                Err(e) => {
                    tracing::warn!(error = %e, "failed to inject decoy message");
                }
            }
        }
    })
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::storage::FileBackedStore;
    use dashmap::DashMap;

    /// Build a `DeliveryService` backed by a temp-dir store; the `TempDir`
    /// is returned so it stays alive for the test's duration.
    fn test_delivery() -> (tempfile::TempDir, Arc<DeliveryService>) {
        let dir = tempfile::tempdir().unwrap();
        let store = Arc::new(FileBackedStore::open(dir.path()).unwrap());
        let svc = Arc::new(DeliveryService {
            store,
            waiters: Arc::new(DashMap::new()),
        });
        (dir, svc)
    }

    /// Padded output must be boundary-aligned and round-trip via unpad.
    #[test]
    fn pad_payload_is_boundary_aligned() {
        let config = TrafficResistanceConfig {
            padding_boundary: 256,
            ..Default::default()
        };
        let payload = b"test message";
        let padded = pad_payload(payload, &config);
        assert_eq!(padded.len() % 256, 0);
        // Unpad should recover original.
        let unpadded = quicprochat_core::padding::unpad_uniform(&padded).unwrap();
        assert_eq!(unpadded, payload);
    }

    /// A 300-byte payload at boundary 512 pads to exactly one block.
    #[test]
    fn pad_payload_custom_boundary() {
        let config = TrafficResistanceConfig {
            padding_boundary: 512,
            ..Default::default()
        };
        let payload = vec![0xAA; 300];
        let padded = pad_payload(&payload, &config);
        assert_eq!(padded.len() % 512, 0);
        assert_eq!(padded.len(), 512);
    }

    /// `jitter_max_ms == 0` must not sleep at all.
    #[tokio::test]
    async fn jitter_zero_is_noop() {
        let config = TrafficResistanceConfig {
            jitter_max_ms: 0,
            ..Default::default()
        };
        let start = std::time::Instant::now();
        apply_jitter(&config).await;
        // Should return almost immediately.
        assert!(start.elapsed() < std::time::Duration::from_millis(5));
    }

    /// End-to-end: the generator enqueues decoys that unpad to empty.
    #[tokio::test]
    async fn decoy_generator_produces_messages() {
        let (_dir, delivery) = test_delivery();
        let recipient = vec![0xFFu8; 32];
        let channel = vec![0u8; 16];
        let shutdown = Arc::new(Notify::new());
        let config = TrafficResistanceConfig {
            padding_boundary: 256,
            decoy_interval_ms: 50, // Fast interval for testing.
            jitter_max_ms: 0,
        };
        let handle = spawn_decoy_generator(
            Arc::clone(&delivery),
            vec![recipient.clone()],
            channel.clone(),
            config,
            Arc::clone(&shutdown),
        );
        // Wait enough time for at least one decoy.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        shutdown.notify_one();
        handle.await.unwrap();
        // Check that decoy messages were enqueued.
        let fetched = delivery
            .fetch(super::super::types::FetchReq {
                recipient_key: recipient,
                channel_id: channel,
                limit: 100,
            })
            .unwrap();
        assert!(!fetched.payloads.is_empty(), "decoy generator should have enqueued at least one message");
        // Every decoy should unpad to an empty payload.
        for env in &fetched.payloads {
            let unpadded = quicprochat_core::padding::unpad_uniform(&env.data).unwrap();
            assert!(unpadded.is_empty(), "decoy payload should unpad to empty");
        }
    }

    /// With interval 0 the task sends nothing and exits on shutdown.
    #[tokio::test]
    async fn decoy_generator_disabled_when_zero_interval() {
        let (_dir, delivery) = test_delivery();
        let shutdown = Arc::new(Notify::new());
        let config = TrafficResistanceConfig {
            decoy_interval_ms: 0,
            ..Default::default()
        };
        let handle = spawn_decoy_generator(
            delivery,
            vec![vec![1u8; 32]],
            vec![0u8; 16],
            config,
            Arc::clone(&shutdown),
        );
        // Signal shutdown immediately — should return without having sent anything.
        shutdown.notify_one();
        handle.await.unwrap();
    }
}

View File

@@ -0,0 +1,441 @@
//! Plain Rust request/response types for server domain logic.
//!
//! No proto, no capnp — just Rust structs.
use crate::storage::StorageError;
// ── Domain Error ────────────────────────────────────────────────────────────
/// Errors returned by domain service methods.
///
/// The `#[error]` attribute text is the stable user-visible message for each
/// variant; `Storage` wraps lower-level storage failures via `#[from]`.
#[derive(thiserror::Error, Debug)]
pub enum DomainError {
    #[error("identity key must be exactly 32 bytes, got {0}")]
    InvalidIdentityKey(usize),
    #[error("key package must not be empty")]
    EmptyPackage,
    #[error("key package exceeds max size ({0} bytes)")]
    PackageTooLarge(usize),
    #[error("hybrid public key must not be empty")]
    EmptyHybridKey,
    #[error("username must not be empty")]
    EmptyUsername,
    #[error("blob hash must be exactly 32 bytes, got {0}")]
    BlobHashLength(usize),
    #[error("blob exceeds max size ({0} bytes)")]
    BlobTooLarge(u64),
    #[error("SHA-256 of uploaded data does not match blob hash")]
    BlobHashMismatch,
    #[error("blob not found")]
    BlobNotFound,
    #[error("maximum {0} devices per identity")]
    DeviceLimit(usize),
    #[error("device not found")]
    DeviceNotFound,
    #[error("group not found")]
    GroupNotFound,
    #[error("bad parameters: {0}")]
    BadParams(String),
    #[error("I/O error: {0}")]
    Io(String),
    #[error("storage error: {0}")]
    Storage(#[from] StorageError),
}
// ── Auth ─────────────────────────────────────────────────────────────────────
/// Caller authentication context (resolved from session token).
#[derive(Debug, Clone)]
pub struct CallerAuth {
    /// Ed25519 identity key of the authenticated caller (32 bytes).
    pub identity_key: Vec<u8>,
    /// Session token bytes.
    pub token: Vec<u8>,
    /// Device ID (optional, for auditing).
    pub device_id: Option<Vec<u8>>,
}
/// OPAQUE registration start.
pub struct RegisterStartReq {
    pub username: String,
    pub request_bytes: Vec<u8>,
}
/// Server response to registration start (opaque protocol bytes).
pub struct RegisterStartResp {
    pub response_bytes: Vec<u8>,
}
/// OPAQUE registration finish.
pub struct RegisterFinishReq {
    pub username: String,
    pub upload_bytes: Vec<u8>,
    pub identity_key: Vec<u8>,
}
/// Outcome of registration finish.
pub struct RegisterFinishResp {
    pub success: bool,
}
/// OPAQUE login start.
pub struct LoginStartReq {
    pub username: String,
    pub request_bytes: Vec<u8>,
}
/// Server response to login start (opaque protocol bytes).
pub struct LoginStartResp {
    pub response_bytes: Vec<u8>,
}
/// OPAQUE login finish.
pub struct LoginFinishReq {
    pub username: String,
    pub finalization_bytes: Vec<u8>,
    pub identity_key: Vec<u8>,
}
/// Successful login yields a session token.
pub struct LoginFinishResp {
    pub session_token: Vec<u8>,
}
// ── Delivery ─────────────────────────────────────────────────────────────────
/// An envelope pairing a sequence number with an opaque payload.
#[derive(Debug, Clone)]
pub struct Envelope {
    pub seq: u64,
    pub data: Vec<u8>,
}
/// Enqueue one payload for one recipient on a channel.
pub struct EnqueueReq {
    pub recipient_key: Vec<u8>,
    pub payload: Vec<u8>,
    pub channel_id: Vec<u8>,
    pub ttl_secs: u32,
}
/// Assigned sequence number plus a delivery proof.
pub struct EnqueueResp {
    pub seq: u64,
    pub delivery_proof: Vec<u8>,
}
/// Fetch (and consume) up to `limit` queued envelopes.
pub struct FetchReq {
    pub recipient_key: Vec<u8>,
    pub channel_id: Vec<u8>,
    pub limit: u32,
}
pub struct FetchResp {
    pub payloads: Vec<Envelope>,
}
/// Non-destructive variant of fetch.
pub struct PeekReq {
    pub recipient_key: Vec<u8>,
    pub channel_id: Vec<u8>,
    pub limit: u32,
}
pub struct PeekResp {
    pub payloads: Vec<Envelope>,
}
/// Acknowledge (discard) all envelopes up to `seq_up_to`.
pub struct AckReq {
    pub recipient_key: Vec<u8>,
    pub channel_id: Vec<u8>,
    pub seq_up_to: u64,
}
/// Enqueue the same payload for several recipients at once.
pub struct BatchEnqueueReq {
    pub recipient_keys: Vec<Vec<u8>>,
    pub payload: Vec<u8>,
    pub channel_id: Vec<u8>,
    pub ttl_secs: u32,
}
/// One assigned sequence number per recipient, in request order.
pub struct BatchEnqueueResp {
    pub seqs: Vec<u64>,
}
// ── Keys ─────────────────────────────────────────────────────────────────────
/// Publish a key package for an identity.
pub struct UploadKeyPackageReq {
    pub identity_key: Vec<u8>,
    pub package: Vec<u8>,
}
/// Server-computed fingerprint of the stored package.
pub struct UploadKeyPackageResp {
    pub fingerprint: Vec<u8>,
}
pub struct FetchKeyPackageReq {
    pub identity_key: Vec<u8>,
}
pub struct FetchKeyPackageResp {
    pub package: Vec<u8>,
}
/// Publish a hybrid (classical + post-quantum) public key.
pub struct UploadHybridKeyReq {
    pub identity_key: Vec<u8>,
    pub hybrid_public_key: Vec<u8>,
}
pub struct FetchHybridKeyReq {
    pub identity_key: Vec<u8>,
}
pub struct FetchHybridKeyResp {
    pub hybrid_public_key: Vec<u8>,
}
/// Batch hybrid-key lookup; `keys` is in request order.
pub struct FetchHybridKeysReq {
    pub identity_keys: Vec<Vec<u8>>,
}
pub struct FetchHybridKeysResp {
    pub keys: Vec<Vec<u8>>,
}
// ── Key Transparency / Revocation ────────────────────────────────────
/// Revoke an identity key, recording the reason in the KT log.
pub struct RevokeKeyReq {
    pub identity_key: Vec<u8>,
    pub reason: String,
}
pub struct RevokeKeyResp {
    pub success: bool,
    pub leaf_index: u64,
}
pub struct CheckRevocationReq {
    pub identity_key: Vec<u8>,
}
/// `reason`/`timestamp_ms` are only meaningful when `revoked` is true.
pub struct CheckRevocationResp {
    pub revoked: bool,
    pub reason: String,
    pub timestamp_ms: u64,
}
/// Request a `[start, end)` range of KT log entries for client-side audit.
pub struct AuditKeyTransparencyReq {
    pub start: u64,
    pub end: u64,
}
pub struct AuditLogEntry {
    pub index: u64,
    pub leaf_hash: Vec<u8>,
}
pub struct AuditKeyTransparencyResp {
    pub entries: Vec<AuditLogEntry>,
    pub tree_size: u64,
    pub root: Vec<u8>,
}
// ── Channel ──────────────────────────────────────────────────────────────────
/// Create (or look up) a channel with a peer.
pub struct CreateChannelReq {
    pub peer_key: Vec<u8>,
}
/// `was_new` distinguishes a fresh channel from an existing one.
pub struct CreateChannelResp {
    pub channel_id: Vec<u8>,
    pub was_new: bool,
}
// ── User ─────────────────────────────────────────────────────────────────────
/// Resolve a username to an identity key (with a KT inclusion proof).
pub struct ResolveUserReq {
    pub username: String,
}
pub struct ResolveUserResp {
    pub identity_key: Vec<u8>,
    pub inclusion_proof: Vec<u8>,
}
/// Reverse lookup: identity key to username.
pub struct ResolveIdentityReq {
    pub identity_key: Vec<u8>,
}
pub struct ResolveIdentityResp {
    pub username: String,
}
// ── Blob ─────────────────────────────────────────────────────────────────────
/// Upload one chunk of a blob at `offset`; `total_size` is the full blob size.
pub struct UploadBlobReq {
    pub blob_hash: Vec<u8>,
    pub chunk: Vec<u8>,
    pub offset: u64,
    pub total_size: u64,
    pub mime_type: String,
}
pub struct UploadBlobResp {
    pub blob_id: Vec<u8>,
}
/// Download `length` bytes of a blob starting at `offset`.
pub struct DownloadBlobReq {
    pub blob_id: Vec<u8>,
    pub offset: u64,
    pub length: u32,
}
pub struct DownloadBlobResp {
    pub chunk: Vec<u8>,
    pub total_size: u64,
    pub mime_type: String,
}
// ── Device ───────────────────────────────────────────────────────────────────
/// Register a device under the caller's identity.
pub struct RegisterDeviceReq {
    pub device_id: Vec<u8>,
    pub device_name: String,
}
pub struct RegisterDeviceResp {
    pub success: bool,
}
/// One registered device, as returned by the list endpoint.
pub struct DeviceInfo {
    pub device_id: Vec<u8>,
    pub device_name: String,
    pub registered_at: u64,
}
pub struct ListDevicesResp {
    pub devices: Vec<DeviceInfo>,
}
pub struct RevokeDeviceReq {
    pub device_id: Vec<u8>,
}
pub struct RevokeDeviceResp {
    pub success: bool,
}
// ── Group metadata ───────────────────────────────────────────────────
/// Server-side metadata for a group conversation.
pub struct GroupMetadata {
    pub group_id: Vec<u8>,
    pub name: String,
    pub description: String,
    pub avatar_hash: Vec<u8>,
    pub creator_key: Vec<u8>,
    pub created_at: u64,
}
/// Replace a group's mutable metadata fields.
pub struct UpdateGroupMetadataReq {
    pub group_id: Vec<u8>,
    pub name: String,
    pub description: String,
    pub avatar_hash: Vec<u8>,
}
pub struct ListGroupMembersReq {
    pub group_id: Vec<u8>,
}
/// One group member, as returned by the list endpoint.
pub struct GroupMemberInfo {
    pub identity_key: Vec<u8>,
    pub username: String,
    pub joined_at: u64,
}
pub struct ListGroupMembersResp {
    pub members: Vec<GroupMemberInfo>,
}
// ── Moderation ───────────────────────────────────────────────────────────────
/// Submit an end-to-end-encrypted abuse report.
pub struct ReportMessageReq {
    pub encrypted_report: Vec<u8>,
    pub conversation_id: Vec<u8>,
    pub reporter_identity: Vec<u8>,
}
pub struct ReportMessageResp {
    pub accepted: bool,
}
/// Ban a user; `duration_secs` is the ban length.
pub struct BanUserReq {
    pub identity_key: Vec<u8>,
    pub reason: String,
    pub duration_secs: u64,
}
pub struct BanUserResp {
    pub success: bool,
}
pub struct UnbanUserReq {
    pub identity_key: Vec<u8>,
}
pub struct UnbanUserResp {
    pub success: bool,
}
/// Paginated report listing (`limit`/`offset`).
pub struct ListReportsReq {
    pub limit: u32,
    pub offset: u32,
}
/// One stored abuse report.
pub struct ReportEntry {
    pub id: u64,
    pub encrypted_report: Vec<u8>,
    pub conversation_id: Vec<u8>,
    pub reporter_identity: Vec<u8>,
    pub timestamp: u64,
}
pub struct ListReportsResp {
    pub reports: Vec<ReportEntry>,
}
/// One banned user with ban window timestamps.
pub struct BannedUserEntry {
    pub identity_key: Vec<u8>,
    pub reason: String,
    pub banned_at: u64,
    pub expires_at: u64,
}
pub struct ListBannedResp {
    pub users: Vec<BannedUserEntry>,
}
// ── P2P ──────────────────────────────────────────────────────────────────────
/// Publish this identity's P2P node address.
pub struct PublishEndpointReq {
    pub identity_key: Vec<u8>,
    pub node_addr: Vec<u8>,
}
pub struct ResolveEndpointReq {
    pub identity_key: Vec<u8>,
}
pub struct ResolveEndpointResp {
    pub node_addr: Vec<u8>,
}
/// Health-check response (free-form status string).
pub struct HealthResp {
    pub status: String,
}

View File

@@ -0,0 +1,146 @@
//! User resolution domain logic — username <-> identity key lookups.
use std::sync::{Arc, Mutex};
use quicprochat_kt::{MerkleLog, RevocationLog, RevocationReason};
use crate::storage::Store;
use super::types::*;
/// Domain service for user/identity resolution.
pub struct UserService {
    /// Username/identity persistence.
    pub store: Arc<dyn Store>,
    /// Key Transparency Merkle log (shared, lock-guarded).
    pub kt_log: Arc<Mutex<MerkleLog>>,
    /// Revocation log layered on top of the KT log.
    pub revocation_log: Arc<Mutex<RevocationLog>>,
}
impl UserService {
pub fn resolve_user(&self, req: ResolveUserReq) -> Result<ResolveUserResp, DomainError> {
if req.username.is_empty() {
return Err(DomainError::EmptyUsername);
}
let identity_key = self
.store
.get_user_identity_key(&req.username)?
.unwrap_or_default();
let mut inclusion_proof = Vec::new();
if !identity_key.is_empty() {
if let Ok(log) = self.kt_log.lock() {
if let Some(leaf_idx) = log.find(&req.username, &identity_key) {
if let Ok(proof) = log.inclusion_proof(leaf_idx) {
if let Ok(bytes) = proof.to_bytes() {
inclusion_proof = bytes;
}
}
}
}
}
Ok(ResolveUserResp {
identity_key,
inclusion_proof,
})
}
pub fn resolve_identity(
&self,
req: ResolveIdentityReq,
) -> Result<ResolveIdentityResp, DomainError> {
if req.identity_key.len() != 32 {
return Err(DomainError::InvalidIdentityKey(req.identity_key.len()));
}
let username = self
.store
.resolve_identity_key(&req.identity_key)?
.unwrap_or_default();
Ok(ResolveIdentityResp { username })
}
/// Revoke an identity key in the Key Transparency log.
pub fn revoke_key(&self, req: RevokeKeyReq) -> Result<RevokeKeyResp, DomainError> {
if req.identity_key.len() != 32 {
return Err(DomainError::InvalidIdentityKey(req.identity_key.len()));
}
let reason = RevocationReason::from_tag(&req.reason)
.ok_or_else(|| DomainError::BadParams(format!("invalid revocation reason: {}", req.reason)))?;
let timestamp_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let mut kt = self.kt_log.lock().map_err(|e| DomainError::Io(e.to_string()))?;
let mut revlog = self.revocation_log.lock().map_err(|e| DomainError::Io(e.to_string()))?;
let leaf_index = revlog
.revoke(&mut kt, &req.identity_key, reason, timestamp_ms)
.map_err(|e| DomainError::BadParams(e.to_string()))?;
// Persist updated logs.
if let Ok(bytes) = kt.to_bytes() {
let _ = self.store.save_kt_log(bytes);
}
if let Ok(bytes) = revlog.to_bytes() {
let _ = self.store.save_revocation_log(bytes);
}
Ok(RevokeKeyResp {
success: true,
leaf_index,
})
}
/// Check if an identity key has been revoked.
pub fn check_revocation(&self, req: CheckRevocationReq) -> Result<CheckRevocationResp, DomainError> {
let revlog = self.revocation_log.lock().map_err(|e| DomainError::Io(e.to_string()))?;
if let Some(entry) = revlog.get(&req.identity_key) {
Ok(CheckRevocationResp {
revoked: true,
reason: entry.reason.as_tag().to_string(),
timestamp_ms: entry.timestamp_ms,
})
} else {
Ok(CheckRevocationResp {
revoked: false,
reason: String::new(),
timestamp_ms: 0,
})
}
}
/// Return a range of KT log entries for client-side audit.
pub fn audit_key_transparency(
&self,
req: AuditKeyTransparencyReq,
) -> Result<AuditKeyTransparencyResp, DomainError> {
let kt = self.kt_log.lock().map_err(|e| DomainError::Io(e.to_string()))?;
let end = if req.end == 0 { kt.len() } else { req.end };
let log_entries = kt.audit_log(req.start, end);
let entries: Vec<AuditLogEntry> = log_entries
.into_iter()
.map(|(index, hash)| AuditLogEntry {
index,
leaf_hash: hash.to_vec(),
})
.collect();
let tree_size = kt.len();
let root = kt.root().map(|r| r.to_vec()).unwrap_or_default();
Ok(AuditKeyTransparencyResp {
entries,
tree_size,
root,
})
}
}

View File

@@ -0,0 +1,46 @@
//! Structured error codes for server RPC responses.
//!
//! Every `capnp::Error::failed()` message is prefixed with a stable code
//! (E001–E033) so clients can match on the code without parsing free-text.
// Codes are append-only: never renumber or reuse a code once shipped,
// since clients match on the prefix.
pub const E001_BAD_AUTH_VERSION: &str = "E001";
pub const E002_EMPTY_TOKEN: &str = "E002";
pub const E003_INVALID_TOKEN: &str = "E003";
pub const E004_IDENTITY_KEY_LENGTH: &str = "E004";
pub const E005_PAYLOAD_EMPTY: &str = "E005";
pub const E006_PAYLOAD_TOO_LARGE: &str = "E006";
pub const E007_PACKAGE_EMPTY: &str = "E007";
pub const E008_PACKAGE_TOO_LARGE: &str = "E008";
pub const E009_STORAGE_ERROR: &str = "E009";
pub const E010_OPAQUE_ERROR: &str = "E010";
pub const E011_USERNAME_EMPTY: &str = "E011";
pub const E012_WIRE_VERSION: &str = "E012";
pub const E013_HYBRID_KEY_EMPTY: &str = "E013";
pub const E014_RATE_LIMITED: &str = "E014";
pub const E015_QUEUE_FULL: &str = "E015";
pub const E016_IDENTITY_MISMATCH: &str = "E016";
pub const E017_SESSION_EXPIRED: &str = "E017";
pub const E018_USER_EXISTS: &str = "E018";
pub const E019_NO_PENDING_LOGIN: &str = "E019";
pub const E020_BAD_PARAMS: &str = "E020";
pub const E021_CIPHERSUITE_NOT_ALLOWED: &str = "E021";
pub const E022_CHANNEL_ACCESS_DENIED: &str = "E022";
pub const E023_CHANNEL_NOT_FOUND: &str = "E023";
pub const E024_BLOB_TOO_LARGE: &str = "E024";
pub const E025_BLOB_HASH_LENGTH: &str = "E025";
pub const E026_BLOB_HASH_MISMATCH: &str = "E026";
pub const E027_BLOB_NOT_FOUND: &str = "E027";
pub const E028_ACCOUNT_DELETION_FAILED: &str = "E028";
pub const E029_DEVICE_LIMIT: &str = "E029";
pub const E030_DEVICE_NOT_FOUND: &str = "E030";
#[allow(dead_code)] // used by v2 RPC moderation handlers
pub const E031_USER_BANNED: &str = "E031";
#[allow(dead_code)] // used by v2 RPC moderation handlers
pub const E032_REPORT_EMPTY: &str = "E032";
#[allow(dead_code)] // used by v2 RPC moderation handlers
pub const E033_ADMIN_REQUIRED: &str = "E033";
/// Build a `capnp::Error::failed()` with the structured code prefix.
///
/// The resulting message is `"<code>: <msg>"`, e.g. `"E014: rate limited"`.
pub fn coded_error(code: &str, msg: impl std::fmt::Display) -> capnp::Error {
    capnp::Error::failed(format!("{code}: {msg}"))
}

View File

@@ -0,0 +1,80 @@
//! Parse `username@domain` federated addresses.
//!
//! A bare `username` (no `@`) is treated as local.
#![allow(dead_code)] // federation not yet wired up
/// A parsed federated address.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FederatedAddress {
    pub username: String,
    pub domain: Option<String>,
}

impl FederatedAddress {
    /// Parse a `user@domain` string. Bare `user` → domain is `None`.
    ///
    /// Splits on the *last* `@`, so `a@b@c` yields user `a@b` at domain `c`.
    /// An empty user or domain part (`"@d"`, `"u@"`) makes the whole input a
    /// bare local username.
    pub fn parse(input: &str) -> Self {
        if let Some((user, domain)) = input.rsplit_once('@') {
            if !user.is_empty() && !domain.is_empty() {
                return Self {
                    username: user.to_string(),
                    domain: Some(domain.to_string()),
                };
            }
        }
        Self {
            username: input.to_string(),
            domain: None,
        }
    }

    /// Returns true if this address refers to a local user (no domain or domain matches local).
    pub fn is_local(&self, local_domain: &str) -> bool {
        self.domain.as_deref().map_or(true, |d| d == local_domain)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// No `@` at all → local address, whole input is the username.
    #[test]
    fn bare_username() {
        let addr = FederatedAddress::parse("alice");
        assert_eq!(addr.username, "alice");
        assert_eq!(addr.domain, None);
        assert!(addr.is_local("example.com"));
    }

    /// Standard `user@domain` form; locality depends on the local domain.
    #[test]
    fn user_at_domain() {
        let addr = FederatedAddress::parse("alice@remote.example.com");
        assert_eq!(addr.username, "alice");
        assert_eq!(addr.domain, Some("remote.example.com".into()));
        assert!(!addr.is_local("local.example.com"));
        assert!(addr.is_local("remote.example.com"));
    }

    /// An empty domain part falls back to a bare (local) username.
    #[test]
    fn trailing_at_is_bare() {
        let addr = FederatedAddress::parse("alice@");
        assert_eq!(addr.username, "alice@");
        assert_eq!(addr.domain, None);
    }

    /// An empty user part falls back to a bare (local) username.
    #[test]
    fn leading_at_is_bare() {
        let addr = FederatedAddress::parse("@domain.com");
        assert_eq!(addr.username, "@domain.com");
        assert_eq!(addr.domain, None);
    }

    /// With several `@` characters, the split is on the last one.
    #[test]
    fn multiple_at_uses_last() {
        let addr = FederatedAddress::parse("user@org@domain.com");
        assert_eq!(addr.username, "user@org");
        assert_eq!(addr.domain, Some("domain.com".into()));
    }
}

View File

@@ -0,0 +1,288 @@
//! Outbound federation client: connects to peer servers to relay messages.
//!
//! Uses a lazy connection pool (DashMap) to reuse QUIC connections to known peers.
#![allow(dead_code)] // federation not yet wired up
use std::collections::HashMap;
use std::net::SocketAddr;
use anyhow::Context;
use dashmap::DashMap;
use quinn::Endpoint;
use crate::config::EffectiveFederationConfig;
/// Outbound federation client for relaying to peer servers.
pub struct FederationClient {
    /// Peer domain → address mapping from config (immutable after `new`).
    peer_addresses: HashMap<String, SocketAddr>,
    /// Lazy QUIC connection pool: domain → active Connection.
    /// Connections are reused across calls; streams are opened per RPC.
    connections: DashMap<String, quinn::Connection>,
    /// Local QUIC endpoint (shared for all outbound federation connections).
    endpoint: Endpoint,
    /// Local domain (for the FederationAuth.origin field).
    local_domain: String,
}
impl FederationClient {
/// Create a new federation client from config.
///
/// The `endpoint` should be configured with mTLS client credentials.
pub fn new(
config: &EffectiveFederationConfig,
endpoint: Endpoint,
) -> anyhow::Result<Self> {
let mut peer_addresses = HashMap::new();
for peer in &config.peers {
let addr: SocketAddr = peer.address.parse().with_context(|| {
format!("parse federation peer address '{}' for '{}'", peer.address, peer.domain)
})?;
peer_addresses.insert(peer.domain.clone(), addr);
}
Ok(Self {
peer_addresses,
connections: DashMap::new(),
endpoint,
local_domain: config.domain.clone(),
})
}
/// Check if we have a configured peer for the given domain.
pub fn has_peer(&self, domain: &str) -> bool {
self.peer_addresses.contains_key(domain)
}
/// List all configured peer domains.
pub fn peer_domains(&self) -> Vec<String> {
self.peer_addresses.keys().cloned().collect()
}
/// Get the local domain.
pub fn local_domain(&self) -> &str {
&self.local_domain
}
    /// Relay a single enqueue to a remote peer. Returns the seq assigned by the remote server.
    ///
    /// Reuses a pooled connection to `domain` (via `get_or_connect` —
    /// presumably dials lazily; defined past this view) but opens a fresh
    /// bidirectional stream and capnp RPC session per call.
    pub async fn relay_enqueue(
        &self,
        domain: &str,
        recipient_key: &[u8],
        payload: &[u8],
        channel_id: &[u8],
    ) -> anyhow::Result<u64> {
        let conn = self.get_or_connect(domain).await?;
        let (send, recv) = conn.open_bi().await.context("open bi stream to peer")?;
        // Bridge quinn's tokio I/O to the futures-io traits capnp-rpc expects.
        let (reader, writer) = (
            tokio_util::compat::TokioAsyncReadCompatExt::compat(recv),
            tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send),
        );
        let rpc_network = capnp_rpc::twoparty::VatNetwork::new(
            reader,
            writer,
            capnp_rpc::rpc_twoparty_capnp::Side::Client,
            Default::default(),
        );
        let mut rpc_system = capnp_rpc::RpcSystem::new(Box::new(rpc_network), None);
        let client: quicprochat_proto::federation_capnp::federation_service::Client =
            rpc_system.bootstrap(capnp_rpc::rpc_twoparty_capnp::Side::Server);
        // Drive the RPC event loop on the current LocalSet; the JoinHandle is
        // detached, so event-loop errors are not observed here.
        tokio::task::spawn_local(rpc_system);
        let mut req = client.relay_enqueue_request();
        {
            // Builder scope: borrows `req` mutably, must end before send().
            let mut builder = req.get();
            builder.set_recipient_key(recipient_key);
            builder.set_payload(payload);
            builder.set_channel_id(channel_id);
            builder.set_version(1);
            let mut auth = builder.init_auth();
            auth.set_origin(&self.local_domain);
        }
        let response = req.send().promise.await
            .map_err(|e| anyhow::anyhow!("federation relay_enqueue failed: {e}"))?;
        let seq = response.get()
            .map_err(|e| anyhow::anyhow!("read relay_enqueue response: {e}"))?
            .get_seq();
        Ok(seq)
    }
    /// Proxy a key package fetch to a remote peer.
    ///
    /// Returns `Ok(None)` when the remote replies with an empty package —
    /// the wire encodes "not found" as an empty byte string.
    pub async fn proxy_fetch_key_package(
        &self,
        domain: &str,
        identity_key: &[u8],
    ) -> anyhow::Result<Option<Vec<u8>>> {
        let conn = self.get_or_connect(domain).await?;
        let (send, recv) = conn.open_bi().await.context("open bi stream to peer")?;
        // Bridge quinn's tokio I/O to the futures-io traits capnp-rpc expects.
        let (reader, writer) = (
            tokio_util::compat::TokioAsyncReadCompatExt::compat(recv),
            tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send),
        );
        let rpc_network = capnp_rpc::twoparty::VatNetwork::new(
            reader,
            writer,
            capnp_rpc::rpc_twoparty_capnp::Side::Client,
            Default::default(),
        );
        let mut rpc_system = capnp_rpc::RpcSystem::new(Box::new(rpc_network), None);
        let client: quicprochat_proto::federation_capnp::federation_service::Client =
            rpc_system.bootstrap(capnp_rpc::rpc_twoparty_capnp::Side::Server);
        // Detached RPC event loop, same pattern as relay_enqueue.
        tokio::task::spawn_local(rpc_system);
        let mut req = client.proxy_fetch_key_package_request();
        {
            let mut builder = req.get();
            builder.set_identity_key(identity_key);
            let mut auth = builder.init_auth();
            auth.set_origin(&self.local_domain);
        }
        let response = req.send().promise.await
            .map_err(|e| anyhow::anyhow!("federation proxy_fetch_key_package failed: {e}"))?;
        let pkg = response.get()
            .map_err(|e| anyhow::anyhow!("read proxy_fetch_key_package response: {e}"))?
            .get_package()
            .map_err(|e| anyhow::anyhow!("get package: {e}"))?;
        if pkg.is_empty() {
            Ok(None)
        } else {
            Ok(Some(pkg.to_vec()))
        }
    }
    /// Proxy a hybrid key fetch to a remote peer.
    ///
    /// Returns `Ok(None)` when the remote server has no hybrid public key for
    /// `identity_key`; the wire protocol encodes "not found" as empty bytes.
    pub async fn proxy_fetch_hybrid_key(
        &self,
        domain: &str,
        identity_key: &[u8],
    ) -> anyhow::Result<Option<Vec<u8>>> {
        let conn = self.get_or_connect(domain).await?;
        let (send, recv) = conn.open_bi().await.context("open bi stream to peer")?;
        // Bridge tokio streams to the futures-io traits capnp-rpc expects.
        let (reader, writer) = (
            tokio_util::compat::TokioAsyncReadCompatExt::compat(recv),
            tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send),
        );
        let rpc_network = capnp_rpc::twoparty::VatNetwork::new(
            reader,
            writer,
            capnp_rpc::rpc_twoparty_capnp::Side::Client,
            Default::default(),
        );
        let mut rpc_system = capnp_rpc::RpcSystem::new(Box::new(rpc_network), None);
        let client: quicprochat_proto::federation_capnp::federation_service::Client =
            rpc_system.bootstrap(capnp_rpc::rpc_twoparty_capnp::Side::Server);
        // Background task drives the RPC session until the streams close.
        tokio::task::spawn_local(rpc_system);
        let mut req = client.proxy_fetch_hybrid_key_request();
        {
            let mut builder = req.get();
            builder.set_identity_key(identity_key);
            let mut auth = builder.init_auth();
            auth.set_origin(&self.local_domain);
        }
        let response = req.send().promise.await
            .map_err(|e| anyhow::anyhow!("federation proxy_fetch_hybrid_key failed: {e}"))?;
        let pk = response.get()
            .map_err(|e| anyhow::anyhow!("read proxy_fetch_hybrid_key response: {e}"))?
            .get_hybrid_public_key()
            .map_err(|e| anyhow::anyhow!("get hybrid_public_key: {e}"))?;
        // Empty bytes = not-found sentinel (see the service-side handler).
        if pk.is_empty() {
            Ok(None)
        } else {
            Ok(Some(pk.to_vec()))
        }
    }
    /// Proxy a user resolution to a remote peer.
    ///
    /// Resolves `username` to its identity key on the remote server. Returns
    /// `Ok(None)` when the user is unknown there; the wire protocol encodes
    /// "not found" as empty bytes.
    pub async fn proxy_resolve_user(
        &self,
        domain: &str,
        username: &str,
    ) -> anyhow::Result<Option<Vec<u8>>> {
        let conn = self.get_or_connect(domain).await?;
        let (send, recv) = conn.open_bi().await.context("open bi stream to peer")?;
        // Bridge tokio streams to the futures-io traits capnp-rpc expects.
        let (reader, writer) = (
            tokio_util::compat::TokioAsyncReadCompatExt::compat(recv),
            tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send),
        );
        let rpc_network = capnp_rpc::twoparty::VatNetwork::new(
            reader,
            writer,
            capnp_rpc::rpc_twoparty_capnp::Side::Client,
            Default::default(),
        );
        let mut rpc_system = capnp_rpc::RpcSystem::new(Box::new(rpc_network), None);
        let client: quicprochat_proto::federation_capnp::federation_service::Client =
            rpc_system.bootstrap(capnp_rpc::rpc_twoparty_capnp::Side::Server);
        // Background task drives the RPC session until the streams close.
        tokio::task::spawn_local(rpc_system);
        let mut req = client.proxy_resolve_user_request();
        {
            let mut builder = req.get();
            builder.set_username(username);
            let mut auth = builder.init_auth();
            auth.set_origin(&self.local_domain);
        }
        let response = req.send().promise.await
            .map_err(|e| anyhow::anyhow!("federation proxy_resolve_user failed: {e}"))?;
        let key = response.get()
            .map_err(|e| anyhow::anyhow!("read proxy_resolve_user response: {e}"))?
            .get_identity_key()
            .map_err(|e| anyhow::anyhow!("get identity_key: {e}"))?;
        // Empty bytes = not-found sentinel (see the service-side handler).
        if key.is_empty() {
            Ok(None)
        } else {
            Ok(Some(key.to_vec()))
        }
    }
    /// Get an existing connection or create a new one to a peer domain.
    ///
    /// A cached connection that has already closed is replaced transparently.
    /// `domain` doubles as the TLS server name passed to `connect`, so the
    /// peer's certificate must match it.
    ///
    /// NOTE(review): two tasks calling this concurrently may both dial and
    /// race the cache insert; last write wins and the loser's connection is
    /// simply dropped — confirm that is acceptable.
    async fn get_or_connect(&self, domain: &str) -> anyhow::Result<quinn::Connection> {
        // Check for cached connection that's still alive.
        if let Some(conn) = self.connections.get(domain) {
            if conn.close_reason().is_none() {
                return Ok(conn.clone());
            }
        }
        // Not cached (or dead): look up the statically configured address.
        let addr = self.peer_addresses.get(domain).ok_or_else(|| {
            anyhow::anyhow!("no federation peer configured for domain '{domain}'")
        })?;
        tracing::info!(domain = domain, addr = %addr, "connecting to federation peer");
        let conn = self
            .endpoint
            .connect(*addr, domain)
            .map_err(|e| anyhow::anyhow!("federation connect to {domain}: {e}"))?
            .await
            .with_context(|| format!("federation QUIC handshake with {domain}"))?;
        // Cache for subsequent relays/proxies to the same peer.
        self.connections.insert(domain.to_string(), conn.clone());
        Ok(conn)
    }
}

View File

@@ -0,0 +1,14 @@
//! Federation subsystem: server-to-server message relay over mutual TLS + QUIC.
//!
//! When federation is enabled, the server binds a second QUIC endpoint on a
//! dedicated port (default 7001) that only accepts connections from known peers
//! authenticated via mTLS. Inbound requests are handled by [`service::FederationServiceImpl`],
//! which delegates to the local [`Store`]. Outbound relay uses [`client::FederationClient`].
pub mod address;
pub mod client; // outbound relay/proxy: `FederationClient`
pub mod routing; // local-vs-remote destination resolution
pub mod service; // inbound Cap'n Proto handler: `FederationServiceImpl`
pub mod tls; // mTLS server/client config builders
// Re-export the outbound client at the subsystem root for convenience.
pub use client::FederationClient;

View File

@@ -0,0 +1,45 @@
//! Federation routing: determine whether a recipient is local or remote.
use std::sync::Arc;
use crate::storage::Store;
/// Where a message should be delivered.
///
/// Produced by [`resolve_destination`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Destination {
    /// Recipient is on this server.
    Local,
    /// Recipient's home server is the given domain.
    Remote(String),
}
/// Resolve a recipient identity key to a routing destination.
///
/// 1. Check the `identity_home_servers` table for an explicit mapping.
/// 2. If no mapping exists — or the lookup errors — assume local, which keeps
///    single-server deployments working without any mapping table.
pub fn resolve_destination(
    store: &Arc<dyn Store>,
    recipient_key: &[u8],
    local_domain: &str,
) -> Destination {
    if let Ok(Some(home)) = store.get_identity_home_server(recipient_key) {
        // Only a mapping pointing at a *different* domain is remote.
        if home != local_domain {
            return Destination::Remote(home);
        }
    }
    Destination::Local
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    /// An identity with no `identity_home_servers` mapping must route local
    /// (backwards compatibility with single-server deployments).
    #[test]
    fn unknown_identity_routes_local() {
        // Fresh file-backed store in a temp dir: contains no mappings at all.
        let store: Arc<dyn Store> =
            Arc::new(crate::storage::FileBackedStore::open(
                tempfile::tempdir().unwrap().path(),
            ).unwrap());
        let dest = resolve_destination(&store, &[1u8; 32], "local.example.com");
        assert_eq!(dest, Destination::Local);
    }
}

View File

@@ -0,0 +1,349 @@
//! Inbound federation handler: implements `FederationService` Cap'n Proto interface.
//!
//! Delegates all operations to the local [`Store`], acting as a trusted relay
//! from authenticated peer servers.
//!
//! **Security:** Each handler validates the request's `origin` field against
//! the `verified_peer_domain` extracted from the mTLS client certificate at
//! connection time. Per-peer rate limits prevent abuse.
use std::sync::Arc;
use capnp::capability::Promise;
use dashmap::DashMap;
use quicprochat_proto::federation_capnp::federation_service;
use tokio::sync::Notify;
use crate::auth::RateEntry;
use crate::storage::Store;
/// Per-peer federation rate limit: length of the fixed window, in seconds.
const FED_RATE_LIMIT_WINDOW_SECS: u64 = 60;
/// Per-peer federation rate limit: maximum requests allowed per window.
const FED_RATE_LIMIT_MAX: u32 = 200;
/// Inbound federation RPC handler.
pub struct FederationServiceImpl {
    /// Local store that relayed enqueues and proxied lookups operate on.
    pub store: Arc<dyn Store>,
    /// Per-recipient notifiers; relayed enqueues wake waiting fetchWait clients.
    pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
    /// This server's own federation domain (reported by `federationHealth`).
    pub local_domain: String,
    /// The peer domain extracted from the mTLS client certificate's CN/SAN
    /// at connection time. All requests must declare an `origin` matching this.
    pub verified_peer_domain: Option<String>,
    /// Per-peer rate limiter (keyed by peer domain).
    pub rate_limits: Arc<DashMap<String, RateEntry>>,
}
/// Validate that the request's `origin` matches the mTLS-verified peer domain.
fn validate_origin(
verified: &Option<String>,
declared: &str,
) -> Result<(), capnp::Error> {
match verified {
Some(ref expected) if expected == declared => Ok(()),
Some(ref expected) => Err(capnp::Error::failed(format!(
"federation auth: origin '{}' does not match mTLS cert '{}'",
declared, expected
))),
None => Err(capnp::Error::failed(
"federation auth: no verified peer domain (mTLS required)".into(),
)),
}
}
/// Extract and validate the origin string from the request's auth field.
///
/// In order:
/// 1. presence, UTF-8, and non-empty checks on `auth.origin`;
/// 2. [`validate_origin`] against the mTLS-verified peer domain;
/// 3. the per-peer rate limit ([`check_federation_rate_limit`]).
///
/// Returns the validated origin, or a `capnp::Error` for the first failure.
fn extract_and_validate_origin(
    service: &FederationServiceImpl,
    get_auth: Result<quicprochat_proto::federation_capnp::federation_auth::Reader<'_>, capnp::Error>,
) -> Result<String, capnp::Error> {
    let auth = get_auth
        .map_err(|_| capnp::Error::failed("federation auth: missing auth field".into()))?;
    let origin_reader = auth.get_origin()
        .map_err(|_| capnp::Error::failed("federation auth: missing origin".into()))?;
    let origin = origin_reader.to_str()
        .map_err(|_| capnp::Error::failed("federation auth: origin is not valid UTF-8".into()))?;
    if origin.is_empty() {
        return Err(capnp::Error::failed("federation auth: origin must not be empty".into()));
    }
    validate_origin(&service.verified_peer_domain, origin)?;
    // Rate limit keyed by the (now verified) origin domain.
    check_federation_rate_limit(&service.rate_limits, origin)?;
    Ok(origin.to_string())
}
/// Per-peer federation rate limiter.
///
/// Fixed-window counter: each peer may make up to [`FED_RATE_LIMIT_MAX`]
/// requests per [`FED_RATE_LIMIT_WINDOW_SECS`]-second window. The window is
/// reset lazily on the first request after it has expired.
///
/// # Errors
/// Returns a `capnp::Error` once a peer exceeds its per-window budget.
fn check_federation_rate_limit(
    rate_limits: &DashMap<String, RateEntry>,
    peer_domain: &str,
) -> Result<(), capnp::Error> {
    let now = crate::auth::current_timestamp();
    let mut entry = rate_limits.entry(peer_domain.to_string()).or_insert(RateEntry {
        count: 0,
        window_start: now,
    });
    // saturating_sub: a wall-clock step backwards (e.g. NTP correction) must
    // not underflow-panic on u64; it simply keeps the current window open.
    if now.saturating_sub(entry.window_start) >= FED_RATE_LIMIT_WINDOW_SECS {
        // Window expired — start a new one that counts this request.
        entry.count = 1;
        entry.window_start = now;
    } else {
        // saturating_add: even a pegged counter stays above the limit below.
        entry.count = entry.count.saturating_add(1);
        if entry.count > FED_RATE_LIMIT_MAX {
            return Err(capnp::Error::failed(format!(
                "federation rate limit exceeded for peer '{peer_domain}'"
            )));
        }
    }
    Ok(())
}
/// Cap'n Proto handlers for the inbound federation interface.
///
/// Every handler except `federation_health` first runs
/// [`extract_and_validate_origin`] (mTLS origin check + per-peer rate limit)
/// before touching the store.
impl federation_service::Server for FederationServiceImpl {
    // Store a single relayed message for a local recipient, then wake waiters.
    fn relay_enqueue(
        &mut self,
        params: federation_service::RelayEnqueueParams,
        mut results: federation_service::RelayEnqueueResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };
        // Validate origin against mTLS cert and apply rate limit.
        let origin = match extract_and_validate_origin(self, p.get_auth()) {
            Ok(o) => o,
            Err(e) => return Promise::err(e),
        };
        let recipient_key = match p.get_recipient_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad recipient_key: {e}"))),
        };
        let payload = match p.get_payload() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad payload: {e}"))),
        };
        // Channel id is optional on the wire; missing means empty.
        let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
        // Reject malformed relays before touching the store.
        if recipient_key.len() != 32 {
            return Promise::err(capnp::Error::failed("recipient_key must be 32 bytes".into()));
        }
        if payload.is_empty() {
            return Promise::err(capnp::Error::failed("payload must not be empty".into()));
        }
        let seq = match self.store.enqueue(&recipient_key, &channel_id, payload, None) {
            Ok(s) => s,
            Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
        };
        results.get().set_seq(seq);
        // Wake any waiting fetchWait clients.
        if let Some(waiter) = self.waiters.get(&recipient_key) {
            waiter.notify_waiters();
        }
        tracing::info!(
            origin = %origin,
            // Indexing is safe: length checked to be exactly 32 above.
            recipient_prefix = %hex::encode(&recipient_key[..4]),
            seq = seq,
            "federation: relayed enqueue"
        );
        Promise::ok(())
    }
    // Relay one payload to many local recipients; returns one seq per key.
    fn relay_batch_enqueue(
        &mut self,
        params: federation_service::RelayBatchEnqueueParams,
        mut results: federation_service::RelayBatchEnqueueResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };
        // Validate origin against mTLS cert and apply rate limit.
        let _origin = match extract_and_validate_origin(self, p.get_auth()) {
            Ok(o) => o,
            Err(e) => return Promise::err(e),
        };
        let recipient_keys = match p.get_recipient_keys() {
            Ok(v) => v,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad recipient_keys: {e}"))),
        };
        // NOTE(review): unlike relay_enqueue, an empty payload is not rejected
        // here — confirm whether that asymmetry is intentional.
        let payload = match p.get_payload() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad payload: {e}"))),
        };
        let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
        let mut seqs = Vec::with_capacity(recipient_keys.len() as usize);
        // Fail the whole batch on the first bad key or store error; earlier
        // recipients will already have been enqueued and notified by then.
        for i in 0..recipient_keys.len() {
            let rk = match recipient_keys.get(i) {
                Ok(v) => v.to_vec(),
                Err(e) => return Promise::err(capnp::Error::failed(format!("bad key[{i}]: {e}"))),
            };
            if rk.len() != 32 {
                return Promise::err(capnp::Error::failed(
                    format!("recipient_key[{i}] must be 32 bytes"),
                ));
            }
            let seq = match self.store.enqueue(&rk, &channel_id, payload.clone(), None) {
                Ok(s) => s,
                Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
            };
            seqs.push(seq);
            if let Some(waiter) = self.waiters.get(&rk) {
                waiter.notify_waiters();
            }
        }
        let mut list = results.get().init_seqs(seqs.len() as u32);
        for (i, seq) in seqs.iter().enumerate() {
            list.set(i as u32, *seq);
        }
        tracing::info!(
            recipient_count = recipient_keys.len(),
            "federation: relayed batch_enqueue"
        );
        Promise::ok(())
    }
    // Serve a locally stored key package to a peer; empty bytes = not found.
    fn proxy_fetch_key_package(
        &mut self,
        params: federation_service::ProxyFetchKeyPackageParams,
        mut results: federation_service::ProxyFetchKeyPackageResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };
        // Validate origin against mTLS cert and apply rate limit.
        if let Err(e) = extract_and_validate_origin(self, p.get_auth()) {
            return Promise::err(e);
        }
        let identity_key = match p.get_identity_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad identity_key: {e}"))),
        };
        // Empty bytes encode "not found"; the client maps them back to None.
        match self.store.fetch_key_package(&identity_key) {
            Ok(Some(pkg)) => results.get().set_package(&pkg),
            Ok(None) => results.get().set_package(&[]),
            Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
        }
        Promise::ok(())
    }
    // Serve a locally stored hybrid public key; empty bytes = not found.
    fn proxy_fetch_hybrid_key(
        &mut self,
        params: federation_service::ProxyFetchHybridKeyParams,
        mut results: federation_service::ProxyFetchHybridKeyResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };
        // Validate origin against mTLS cert and apply rate limit.
        if let Err(e) = extract_and_validate_origin(self, p.get_auth()) {
            return Promise::err(e);
        }
        let identity_key = match p.get_identity_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad identity_key: {e}"))),
        };
        // Empty bytes encode "not found"; the client maps them back to None.
        match self.store.fetch_hybrid_key(&identity_key) {
            Ok(Some(pk)) => results.get().set_hybrid_public_key(&pk),
            Ok(None) => results.get().set_hybrid_public_key(&[]),
            Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
        }
        Promise::ok(())
    }
    // Resolve a local username to its identity key; empty bytes = not found.
    fn proxy_resolve_user(
        &mut self,
        params: federation_service::ProxyResolveUserParams,
        mut results: federation_service::ProxyResolveUserResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };
        // Validate origin against mTLS cert and apply rate limit.
        if let Err(e) = extract_and_validate_origin(self, p.get_auth()) {
            return Promise::err(e);
        }
        let username = match p.get_username() {
            Ok(u) => match u.to_str() {
                Ok(s) => s.to_string(),
                Err(e) => return Promise::err(capnp::Error::failed(format!("bad utf-8: {e}"))),
            },
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };
        // Empty bytes encode "not found"; the client maps them back to None.
        match self.store.get_user_identity_key(&username) {
            Ok(Some(key)) => results.get().set_identity_key(&key),
            Ok(None) => results.get().set_identity_key(&[]),
            Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
        }
        Promise::ok(())
    }
    // Liveness probe: reports "ok" plus this server's federation domain.
    fn federation_health(
        &mut self,
        _params: federation_service::FederationHealthParams,
        mut results: federation_service::FederationHealthResults,
    ) -> Promise<(), capnp::Error> {
        // Health check does not require origin validation (diagnostic endpoint).
        results.get().set_status("ok");
        results.get().set_server_domain(&self.local_domain);
        Promise::ok(())
    }
}
/// Extract the peer domain from the mTLS client certificate's first SAN (DNS name)
/// or CN, given the QUIC connection's peer identity (a certificate chain).
///
/// Returns `None` when the peer presented no certificate, the identity is not
/// the expected `Vec<CertificateDer>` shape, the leaf fails DER parsing, or the
/// certificate carries neither a DNS SAN nor a CN.
pub fn extract_peer_domain(conn: &quinn::Connection) -> Option<String> {
    let identity = conn.peer_identity()?;
    // quinn exposes the peer identity as Box<dyn Any>; rustls puts the cert chain there.
    let certs = identity.downcast::<Vec<rustls::pki_types::CertificateDer<'static>>>().ok()?;
    // The first element is the leaf (end-entity) certificate.
    let first_cert = certs.first()?;
    // Parse the DER certificate to extract SAN DNS names or CN.
    let (_, parsed) = x509_parser::parse_x509_certificate(first_cert.as_ref()).ok()?;
    // Prefer SAN DNS names.
    if let Ok(Some(san)) = parsed.subject_alternative_name() {
        for name in &san.value.general_names {
            if let x509_parser::extensions::GeneralName::DNSName(dns) = name {
                // First DNS SAN wins; any further entries are ignored.
                return Some(dns.to_string());
            }
        }
    }
    // Fall back to CN.
    for rdn in parsed.subject().iter() {
        for attr in rdn.iter() {
            if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME {
                if let Ok(cn) = attr.as_str() {
                    return Some(cn.to_string());
                }
            }
        }
    }
    None
}

View File

@@ -0,0 +1,85 @@
//! Build mTLS server/client configs for the federation endpoint.
//!
//! Federation uses a separate CA from the public-facing QUIC endpoint.
//! Both server and client present certificates; the server verifies the client
//! cert is signed by the federation CA.
use std::path::Path;
use std::sync::Arc;
use anyhow::Context;
use quinn::ServerConfig;
use quinn_proto::crypto::rustls::QuicServerConfig;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::version::TLS13;
/// Build a QUIC server config for the federation listener with mutual TLS.
///
/// `cert_path`/`key_path`: this server's federation certificate and private key.
/// `ca_path`: the federation CA certificate used to verify peer client certs.
///
/// NOTE(review): all three files are read as raw DER
/// (`CertificateDer::from(bytes)`); PEM-encoded files would not parse —
/// confirm the provisioning format.
pub fn build_federation_server_config(
    cert_path: &Path,
    key_path: &Path,
    ca_path: &Path,
) -> anyhow::Result<ServerConfig> {
    // Read raw key material; each error names the offending file.
    let cert_der = std::fs::read(cert_path)
        .with_context(|| format!("read federation cert: {:?}", cert_path))?;
    let key_der = std::fs::read(key_path)
        .with_context(|| format!("read federation key: {:?}", key_path))?;
    let ca_der = std::fs::read(ca_path)
        .with_context(|| format!("read federation CA: {:?}", ca_path))?;

    let chain = vec![CertificateDer::from(cert_der)];
    let private_key = PrivateKeyDer::try_from(key_der)
        .map_err(|_| anyhow::anyhow!("invalid federation private key"))?;

    // Clients must present a certificate chaining to the federation CA.
    let mut roots = rustls::RootCertStore::empty();
    roots
        .add(CertificateDer::from(ca_der))
        .context("add federation CA to root store")?;
    let verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
        .build()
        .context("build client cert verifier")?;

    let mut tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13])
        .with_client_cert_verifier(verifier)
        .with_single_cert(chain, private_key)?;
    // Dedicated ALPN keeps federation traffic distinct from the public endpoint.
    tls_config.alpn_protocols = vec![b"quicprochat/federation/1".to_vec()];

    let quic = QuicServerConfig::try_from(tls_config)
        .map_err(|e| anyhow::anyhow!("invalid federation server TLS config: {e}"))?;
    Ok(ServerConfig::with_crypto(Arc::new(quic)))
}
/// Build a QUIC client config for connecting to a federation peer with mutual TLS.
///
/// Presents this server's federation certificate as the client credential and
/// trusts only peers whose certificates chain to the federation CA.
///
/// NOTE(review): files are read as raw DER, matching the server-side builder —
/// confirm the provisioning format.
pub fn build_federation_client_config(
    cert_path: &Path,
    key_path: &Path,
    ca_path: &Path,
) -> anyhow::Result<rustls::ClientConfig> {
    let cert_bytes = std::fs::read(cert_path)
        .with_context(|| format!("read federation cert: {:?}", cert_path))?;
    let key_bytes = std::fs::read(key_path)
        .with_context(|| format!("read federation key: {:?}", key_path))?;
    let ca_bytes = std::fs::read(ca_path)
        .with_context(|| format!("read federation CA: {:?}", ca_path))?;
    let cert_chain = vec![CertificateDer::from(cert_bytes)];
    let key = PrivateKeyDer::try_from(key_bytes)
        .map_err(|_| anyhow::anyhow!("invalid federation client private key"))?;
    let mut root_store = rustls::RootCertStore::empty();
    root_store
        .add(CertificateDer::from(ca_bytes))
        .context("add federation CA to root store")?;
    let mut tls = rustls::ClientConfig::builder_with_protocol_versions(&[&TLS13])
        .with_root_certificates(root_store)
        .with_client_auth_cert(cert_chain, key)
        .context("set client auth cert")?;
    // Offer the same ALPN the federation server requires
    // (`build_federation_server_config` sets "quicprochat/federation/1");
    // without it the QUIC handshake fails with "no application protocol".
    tls.alpn_protocols = vec![b"quicprochat/federation/1".to_vec()];
    Ok(tls)
}

View File

@@ -0,0 +1,198 @@
//! Server-side plugin hooks for extending quicprochat.
//!
//! Implement the [`ServerHooks`] trait to intercept server events — message delivery,
//! authentication, channel creation, and more. Hooks fire after validation but before
//! storage, so they can inspect, log, or reject operations.
//!
//! # Built-in implementations
//!
//! - [`NoopHooks`] — does nothing (default when no hooks are configured)
//! - [`TracingHooks`] — logs all events via `tracing` at info/debug level
//!
//! # Writing a custom hook
//!
//! ```rust,ignore
//! use quicprochat_server::hooks::{ServerHooks, HookAction, MessageEvent};
//!
//! struct ModeratorHook {
//! banned_words: Vec<String>,
//! }
//!
//! impl ServerHooks for ModeratorHook {
//! fn on_message_enqueue(&self, event: &MessageEvent) -> HookAction {
//! // Can't inspect encrypted content (E2E), but can enforce rate limits,
//! // payload size limits, or sender restrictions.
//! if event.payload_len > 1_000_000 {
//! return HookAction::Reject("payload too large".into());
//! }
//! HookAction::Continue
//! }
//! }
//! ```
/// The result of a hook invocation.
///
/// Returned by [`ServerHooks::on_message_enqueue`]; the other hook methods
/// are purely observational and return nothing.
#[derive(Clone, Debug)]
pub enum HookAction {
    /// Allow the operation to proceed.
    Continue,
    /// Reject the operation with a reason (returned to the client as an error).
    Reject(String),
}
/// Event data for message enqueue operations.
///
/// Passed to [`ServerHooks::on_message_enqueue`] and, per message, to
/// [`ServerHooks::on_batch_enqueue`].
#[derive(Clone, Debug)]
pub struct MessageEvent {
    /// Sender's identity key (32 bytes), if known (None in sealed sender mode).
    pub sender_identity: Option<Vec<u8>>,
    /// Recipient's identity key (32 bytes).
    pub recipient_key: Vec<u8>,
    /// Channel ID (16 bytes) if this is a DM channel message.
    pub channel_id: Vec<u8>,
    /// Length of the encrypted payload in bytes (content itself is E2E-encrypted).
    pub payload_len: usize,
    /// Server-assigned sequence number.
    pub seq: u64,
}
/// Event data for authentication operations.
///
/// Passed to [`ServerHooks::on_auth`] after each login attempt.
#[derive(Clone, Debug)]
pub struct AuthEvent {
    /// The username attempting to authenticate.
    pub username: String,
    /// Whether the authentication succeeded.
    pub success: bool,
    /// Failure reason (empty on success).
    pub failure_reason: String,
}
/// Event data for channel creation operations.
///
/// Passed to [`ServerHooks::on_channel_created`].
#[derive(Clone, Debug)]
pub struct ChannelEvent {
    /// The channel's unique ID (16 bytes).
    pub channel_id: Vec<u8>,
    /// Identity key of the initiator.
    pub initiator_key: Vec<u8>,
    /// Identity key of the peer.
    pub peer_key: Vec<u8>,
    /// True if this is a newly created channel (initiator creates the MLS group).
    pub was_new: bool,
}
/// Event data for message fetch operations.
///
/// Passed to [`ServerHooks::on_fetch`] after the queue has been read.
#[derive(Clone, Debug)]
pub struct FetchEvent {
    /// Identity key of the fetcher.
    pub recipient_key: Vec<u8>,
    /// Channel ID being fetched from.
    pub channel_id: Vec<u8>,
    /// Number of messages returned.
    pub message_count: usize,
}
/// Trait for server-side plugin hooks.
///
/// All methods have default implementations that return [`HookAction::Continue`]
/// or do nothing, so you only need to override the events you care about.
///
/// Hooks are called synchronously in the RPC handler path. Keep them fast —
/// offload heavy work (HTTP calls, disk I/O) to background tasks.
pub trait ServerHooks: Send + Sync {
    /// Called after validation, before a message is stored in the delivery queue.
    ///
    /// Return `HookAction::Reject` to prevent delivery. Default: allow.
    fn on_message_enqueue(&self, _event: &MessageEvent) -> HookAction {
        HookAction::Continue
    }
    /// Called after a batch of messages is enqueued. Observational only —
    /// batch delivery cannot be rejected from here.
    fn on_batch_enqueue(&self, _events: &[MessageEvent]) {
        // Default: no-op
    }
    /// Called after a successful or failed login attempt.
    fn on_auth(&self, _event: &AuthEvent) {
        // Default: no-op
    }
    /// Called after a channel is created or looked up.
    fn on_channel_created(&self, _event: &ChannelEvent) {
        // Default: no-op
    }
    /// Called after messages are fetched from the delivery queue.
    fn on_fetch(&self, _event: &FetchEvent) {
        // Default: no-op
    }
    /// Called when a user registers (OPAQUE registration complete).
    fn on_user_registered(&self, _username: &str, _identity_key: &[u8]) {
        // Default: no-op
    }
}
/// No-op hook implementation (default).
///
/// Relies entirely on the trait's default method bodies, which allow every
/// operation and observe nothing.
pub struct NoopHooks;
impl ServerHooks for NoopHooks {}
/// Hook implementation that logs all events via `tracing`.
///
/// Message/auth/channel/registration events log at info (auth failures at
/// warn), fetches at debug. Identity keys are truncated to a 4-byte hex
/// prefix via [`hex_prefix`] rather than logged in full.
pub struct TracingHooks;
impl ServerHooks for TracingHooks {
    fn on_message_enqueue(&self, event: &MessageEvent) -> HookAction {
        tracing::info!(
            recipient_prefix = %hex_prefix(&event.recipient_key),
            payload_len = event.payload_len,
            seq = event.seq,
            has_sender = event.sender_identity.is_some(),
            "hook: message enqueued"
        );
        // Logging only — never blocks delivery.
        HookAction::Continue
    }
    fn on_batch_enqueue(&self, events: &[MessageEvent]) {
        tracing::info!(
            count = events.len(),
            "hook: batch enqueue"
        );
    }
    fn on_auth(&self, event: &AuthEvent) {
        if event.success {
            tracing::info!(username = %event.username, "hook: login success");
        } else {
            // Failures carry the reason and log at warn for easier triage.
            tracing::warn!(
                username = %event.username,
                reason = %event.failure_reason,
                "hook: login failure"
            );
        }
    }
    fn on_channel_created(&self, event: &ChannelEvent) {
        tracing::info!(
            channel_id = %hex_prefix(&event.channel_id),
            was_new = event.was_new,
            "hook: channel created"
        );
    }
    fn on_fetch(&self, event: &FetchEvent) {
        // Only log non-empty fetches to keep routine polling out of the logs.
        if event.message_count > 0 {
            tracing::debug!(
                recipient_prefix = %hex_prefix(&event.recipient_key),
                count = event.message_count,
                "hook: messages fetched"
            );
        }
    }
    fn on_user_registered(&self, username: &str, _identity_key: &[u8]) {
        tracing::info!(username = %username, "hook: user registered");
    }
}
/// Lowercase-hex-encode at most the first four bytes of `bytes` — used to log
/// identity-key/channel-id prefixes without exposing the full value.
fn hex_prefix(bytes: &[u8]) -> String {
    bytes.iter().take(4).map(|b| format!("{b:02x}")).collect()
}

View File

@@ -0,0 +1,783 @@
//! qpc-server — unified Authentication + Delivery service.
//!
//! The server hosts Authentication + Delivery services over QUIC + Cap'n Proto.
use std::{net::IpAddr, net::SocketAddr, path::PathBuf, sync::Arc};
use anyhow::Context;
use clap::Parser;
use dashmap::DashMap;
use opaque_ke::ServerSetup;
use quicprochat_core::opaque_auth::OpaqueSuite;
use quicprochat_kt::MerkleLog;
use quinn::Endpoint;
use rand::rngs::OsRng;
use tokio::sync::Notify;
use tokio::task::LocalSet;
pub mod audit;
mod auth;
mod config;
pub mod domain;
mod error_codes;
mod federation;
pub mod hooks;
mod metrics;
mod node_service;
#[allow(unsafe_code)] // FFI: C-ABI plugin interaction requires unsafe blocks
mod plugin_loader;
mod sql_store;
mod tls;
mod storage;
pub mod v2_handlers;
mod ws_bridge;
#[cfg(feature = "webtransport")]
mod webtransport;
use auth::{AuthConfig, PendingLogin, RateEntry, SessionInfo};
use config::{
load_config, merge_config, validate_production_config, DEFAULT_DATA_DIR, DEFAULT_DB_PATH,
DEFAULT_LISTEN, DEFAULT_STORE_BACKEND, DEFAULT_TLS_CERT, DEFAULT_TLS_KEY,
};
use node_service::{handle_node_connection, spawn_cleanup_task};
use sql_store::SqlStore;
use storage::{FileBackedStore, Store};
use tls::build_server_config;
// ── CLI ───────────────────────────────────────────────────────────────────────
/// Command-line arguments for `qpc-server`.
///
/// Every flag can also be supplied via a `QPC_*` environment variable
/// (renamed from the legacy `QPQ_*` prefix along with the
/// quicproquo → quicprochat rename).
#[derive(Debug, Parser)]
#[command(
    name = "qpc-server",
    about = "quicprochat Delivery Service + Authentication Service",
    version
)]
struct Args {
    /// Optional path to a TOML config file (fields map to CLI flags).
    #[arg(long, env = "QPC_CONFIG")]
    config: Option<PathBuf>,
    /// QUIC listen address (host:port).
    #[arg(long, default_value = DEFAULT_LISTEN, env = "QPC_LISTEN")]
    listen: String,
    /// Directory for persisted server data (KeyPackages + delivery queues).
    #[arg(long, default_value = DEFAULT_DATA_DIR, env = "QPC_DATA_DIR")]
    data_dir: String,
    /// TLS certificate path (generated automatically if missing).
    #[arg(long, default_value = DEFAULT_TLS_CERT, env = "QPC_TLS_CERT")]
    tls_cert: PathBuf,
    /// TLS private key path (generated automatically if missing).
    #[arg(long, default_value = DEFAULT_TLS_KEY, env = "QPC_TLS_KEY")]
    tls_key: PathBuf,
    /// Required bearer token for auth.version=1 requests. Use --allow-insecure-auth to run without it (dev only).
    #[arg(long, env = "QPC_AUTH_TOKEN")]
    auth_token: Option<String>,
    /// Allow running without QPC_AUTH_TOKEN (development only).
    #[arg(long, env = "QPC_ALLOW_INSECURE_AUTH", default_value_t = false)]
    allow_insecure_auth: bool,
    /// Enable Sealed Sender: enqueue does not require identity-bound session, only a valid token.
    #[arg(long, env = "QPC_SEALED_SENDER", default_value_t = false)]
    sealed_sender: bool,
    /// Storage backend: "file" (bincode) or "sql" (SQLCipher-encrypted).
    #[arg(long, default_value = DEFAULT_STORE_BACKEND, env = "QPC_STORE_BACKEND")]
    store_backend: String,
    /// Path to the SQLCipher database file (only used when --store-backend=sql).
    #[arg(long, default_value = DEFAULT_DB_PATH, env = "QPC_DB_PATH")]
    db_path: PathBuf,
    /// SQLCipher encryption key. Empty string disables encryption.
    #[arg(long, default_value = "", env = "QPC_DB_KEY")]
    db_key: String,
    /// Metrics HTTP listen address (e.g. 0.0.0.0:9090). If set and metrics enabled, /metrics is served.
    #[arg(long, env = "QPC_METRICS_LISTEN")]
    metrics_listen: Option<String>,
    /// Enable metrics server when metrics_listen is set.
    #[arg(long, env = "QPC_METRICS_ENABLED")]
    metrics_enabled: Option<bool>,
    /// Enable federation (server-to-server message relay).
    #[arg(long, env = "QPC_FEDERATION_ENABLED", default_value_t = false)]
    federation_enabled: bool,
    /// This server's domain for federation addressing (e.g. "chat.example.com").
    #[arg(long, env = "QPC_FEDERATION_DOMAIN")]
    federation_domain: Option<String>,
    /// Federation QUIC listen address (default: 0.0.0.0:7001).
    #[arg(long, env = "QPC_FEDERATION_LISTEN")]
    federation_listen: Option<String>,
    /// Directory containing plugin `.so` / `.dylib` files to load at startup.
    /// Each library must export `extern "C" fn qpc_plugin_init(vtable: *mut HookVTable) -> i32`.
    #[arg(long, env = "QPC_PLUGIN_DIR")]
    plugin_dir: Option<PathBuf>,
    /// Redact identity key prefixes and payload sizes in audit logs for metadata minimization.
    #[arg(long, env = "QPC_REDACT_LOGS", default_value_t = false)]
    redact_logs: bool,
    /// WebSocket JSON-RPC bridge listen address (e.g. 0.0.0.0:9000). Enables browser connectivity.
    #[arg(long, env = "QPC_WS_LISTEN")]
    ws_listen: Option<String>,
    /// WebTransport (HTTP/3) listen address for browser clients (e.g. 0.0.0.0:7443).
    /// Requires --features webtransport.
    #[arg(long, env = "QPC_WEBTRANSPORT_LISTEN")]
    webtransport_listen: Option<String>,
    /// Graceful shutdown drain timeout in seconds (default: 30). In-flight RPCs get this
    /// long to finish after a shutdown signal before connections are forcefully closed.
    #[arg(long, env = "QPC_DRAIN_TIMEOUT", default_value_t = config::DEFAULT_DRAIN_TIMEOUT_SECS)]
    drain_timeout: u64,
    /// Default per-RPC timeout in seconds (default: 30). Individual methods may override.
    #[arg(long, env = "QPC_RPC_TIMEOUT", default_value_t = config::DEFAULT_RPC_TIMEOUT_SECS)]
    rpc_timeout: u64,
    /// Storage/database operation timeout in seconds (default: 10).
    #[arg(long, env = "QPC_STORAGE_TIMEOUT", default_value_t = config::DEFAULT_STORAGE_TIMEOUT_SECS)]
    storage_timeout: u64,
}
// ── Entry point ───────────────────────────────────────────────────────────────
/// Server entry point.
///
/// Startup sequence (order matters — later stages depend on earlier ones):
/// crypto provider → tracing → config merge → optional metrics exporter →
/// auth sanity checks → QUIC server config → storage backend → persisted
/// server secrets (OPAQUE setup, signing key, KT Merkle log) → plugins →
/// optional WS / WebTransport / federation / mDNS frontends → accept loop
/// on a `LocalSet` (capnp-rpc futures are `!Send`) → graceful drain.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Install ring as the process-wide rustls crypto provider; the result is
    // ignored because another component may have installed a provider first.
    let _ = rustls::crypto::ring::default_provider().install_default();
    // Log filtering comes from RUST_LOG when set, otherwise defaults to "info".
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .init();
    // CLI flags layered over the optional config file; `effective` is the
    // single merged view used everywhere below.
    let args = Args::parse();
    let file_cfg = load_config(args.config.as_deref())?;
    let effective = merge_config(&args, &file_cfg);
    // Production mode is opt-in via env var; it enables stricter validation.
    let production = std::env::var("QPQ_PRODUCTION")
        .map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
        .unwrap_or(false);
    if production {
        validate_production_config(&effective)?;
    }
    // Optional metrics server: only start when metrics_enabled and metrics_listen are set.
    if effective.metrics_enabled {
        if let Some(addr_str) = &effective.metrics_listen {
            let addr: std::net::SocketAddr = addr_str
                .parse()
                .context("metrics_listen must be host:port (e.g. 0.0.0.0:9090)")?;
            metrics_exporter_prometheus::PrometheusBuilder::new()
                .with_http_listener(addr)
                .install()
                .context("failed to install Prometheus metrics exporter")?;
            tracing::info!(addr = %addr_str, "metrics server listening on /metrics");
        }
    }
    // In non-production, require an explicit opt-out before running without a static token.
    if !production
        && effective
            .auth_token
            .as_deref()
            .map(|s| s.is_empty())
            .unwrap_or(true)
        && !effective.allow_insecure_auth
    {
        anyhow::bail!(
            "missing QPQ_AUTH_TOKEN; set one or pass --allow-insecure-auth for development"
        );
    }
    if effective.allow_insecure_auth
        && effective
            .auth_token
            .as_deref()
            .map(|s| s.is_empty())
            .unwrap_or(true)
    {
        tracing::warn!("running without QPQ_AUTH_TOKEN (allow-insecure-auth enabled); development only");
    }
    let listen: SocketAddr = effective
        .listen
        .parse()
        .context("--listen must be host:port")?;
    let mut server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
        .context("failed to build TLS/QUIC server config")?;
    // Harden QUIC transport: idle timeout, limit stream concurrency.
    let mut transport = quinn::TransportConfig::default();
    transport.max_idle_timeout(Some(
        std::time::Duration::from_secs(300)
            .try_into()
            .map_err(|e| anyhow::anyhow!("idle timeout: {e}"))?,
    ));
    // One bidi stream (the RPC channel) and no uni streams per connection.
    transport.max_concurrent_bidi_streams(1u32.into());
    transport.max_concurrent_uni_streams(0u32.into());
    server_config.transport_config(Arc::new(transport));
    // Shared storage — persisted to disk for restart safety.
    let store: Arc<dyn Store> = match effective.store_backend.as_str() {
        "sql" => {
            if let Some(parent) = effective.db_path.parent() {
                std::fs::create_dir_all(parent).context("create db dir")?;
            }
            tracing::info!(
                path = %effective.db_path.display(),
                encrypted = !effective.db_key.is_empty(),
                "opening SQLCipher store"
            );
            if effective.db_key.is_empty() {
                tracing::warn!("db_key is empty; SQL store will be plaintext (development only)");
            }
            Arc::new(SqlStore::open(&effective.db_path, &effective.db_key)?)
        }
        // Any backend name other than "sql" falls through to the file store.
        _ => {
            tracing::info!(dir = %effective.data_dir, "opening file-backed store");
            Arc::new(FileBackedStore::open(&effective.data_dir)?)
        }
    };
    // Ensure blobs directory exists for file transfer support.
    std::fs::create_dir_all(PathBuf::from(&effective.data_dir).join("blobs"))
        .context("create blobs directory")?;
    let auth_cfg = Arc::new(AuthConfig::new(
        effective.auth_token.clone(),
        effective.allow_insecure_auth,
    ));
    // Per-recipient wakeup handles used by fetchWait long-polling.
    let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = Arc::new(DashMap::new());
    // OPAQUE ServerSetup: load from storage or generate fresh.
    let opaque_setup: Arc<ServerSetup<OpaqueSuite>> = match store.get_server_setup() {
        Ok(Some(bytes)) => {
            let setup = ServerSetup::<OpaqueSuite>::deserialize(&bytes)
                .map_err(|e| anyhow::anyhow!("corrupt OPAQUE server setup: {e}"))?;
            tracing::info!("loaded persisted OPAQUE ServerSetup");
            Arc::new(setup)
        }
        Ok(None) => {
            let setup = ServerSetup::<OpaqueSuite>::new(&mut OsRng);
            let bytes = setup.serialize().to_vec();
            store
                .store_server_setup(bytes)
                .context("persist OPAQUE ServerSetup")?;
            tracing::info!("generated and persisted new OPAQUE ServerSetup");
            Arc::new(setup)
        }
        Err(e) => return Err(anyhow::anyhow!("load OPAQUE server setup: {e}")),
    };
    // Server Ed25519 signing key for delivery proofs: load from storage or generate fresh.
    let signing_key: Arc<quicprochat_core::IdentityKeypair> = match store.get_signing_key_seed() {
        Ok(Some(seed_bytes)) => {
            let seed: [u8; 32] = seed_bytes
                .as_slice()
                .try_into()
                .context("signing key seed must be 32 bytes")?;
            tracing::info!("loaded persisted server signing key");
            Arc::new(quicprochat_core::IdentityKeypair::from_seed(seed))
        }
        Ok(None) => {
            let kp = quicprochat_core::IdentityKeypair::generate();
            store
                .store_signing_key_seed(kp.seed_bytes().to_vec())
                .context("persist server signing key")?;
            tracing::info!("generated and persisted new server signing key");
            Arc::new(kp)
        }
        Err(e) => return Err(anyhow::anyhow!("load server signing key: {e}")),
    };
    // Key Transparency Merkle log: load from storage or start fresh.
    // A corrupt log is tolerated (fresh start) so the server can still boot.
    let kt_log: Arc<std::sync::Mutex<MerkleLog>> = match store.load_kt_log() {
        Ok(Some(bytes)) => {
            match MerkleLog::from_bytes(&bytes) {
                Ok(log) => {
                    tracing::info!(entries = log.len(), "loaded persisted KT Merkle log");
                    Arc::new(std::sync::Mutex::new(log))
                }
                Err(e) => {
                    tracing::warn!(error = %e, "KT log deserialise failed; starting fresh");
                    Arc::new(std::sync::Mutex::new(MerkleLog::new()))
                }
            }
        }
        Ok(None) => {
            tracing::info!("no KT log found; starting fresh");
            Arc::new(std::sync::Mutex::new(MerkleLog::new()))
        }
        Err(e) => return Err(anyhow::anyhow!("load KT log: {e}")),
    };
    // ── Plugin hooks ──────────────────────────────────────────────────────────
    // Zero or many plugins collapse to a single ServerHooks implementation:
    // NoopHooks when none, ChainedHooks when one or more loaded.
    let hooks: Arc<dyn hooks::ServerHooks> = if let Some(dir) = &effective.plugin_dir {
        let plugins = plugin_loader::load_plugins_from_dir(dir);
        if plugins.is_empty() {
            tracing::info!(dir = %dir.display(), "plugin_dir set but no plugins loaded");
            Arc::new(hooks::NoopHooks)
        } else {
            tracing::info!(count = plugins.len(), "plugins loaded");
            let boxed: Vec<Box<dyn hooks::ServerHooks>> = plugins
                .into_iter()
                .map(|p| Box::new(p) as Box<dyn hooks::ServerHooks>)
                .collect();
            Arc::new(plugin_loader::ChainedHooks::new(boxed))
        }
    } else {
        Arc::new(hooks::NoopHooks)
    };
    // In-memory auth/session/rate-limit state, shared across all frontends.
    let pending_logins: Arc<DashMap<String, PendingLogin>> = Arc::new(DashMap::new());
    let sessions: Arc<DashMap<Vec<u8>, SessionInfo>> = Arc::new(DashMap::new());
    let rate_limits: Arc<DashMap<Vec<u8>, RateEntry>> = Arc::new(DashMap::new());
    let conn_rate_limits: Arc<DashMap<IpAddr, RateEntry>> = Arc::new(DashMap::new());
    // Background cleanup task (expire sessions, pending logins, rate limits, and stale messages).
    spawn_cleanup_task(
        Arc::clone(&sessions),
        Arc::clone(&pending_logins),
        Arc::clone(&rate_limits),
        Arc::clone(&conn_rate_limits),
        Arc::clone(&store),
        Arc::clone(&waiters),
    );
    // ── WebSocket JSON-RPC bridge ──────────────────────────────────────────
    if let Some(ws_addr_str) = &effective.ws_listen {
        let ws_addr: SocketAddr = ws_addr_str
            .parse()
            .context("--ws-listen must be host:port (e.g. 0.0.0.0:9000)")?;
        let ws_state = Arc::new(ws_bridge::WsBridgeState {
            store: Arc::clone(&store),
            waiters: Arc::clone(&waiters),
            auth_cfg: Arc::clone(&auth_cfg),
            sessions: Arc::clone(&sessions),
            rate_limits: Arc::clone(&rate_limits),
            sealed_sender: effective.sealed_sender,
            allow_insecure_auth: effective.allow_insecure_auth,
        });
        ws_bridge::spawn_ws_bridge(ws_addr, ws_state);
    }
    // ── WebTransport (HTTP/3) endpoint ─────────────────────────────────────
    #[cfg(feature = "webtransport")]
    if let Some(wt_addr_str) = &effective.webtransport_listen {
        let wt_addr: SocketAddr = wt_addr_str
            .parse()
            .context("--webtransport-listen must be host:port (e.g. 0.0.0.0:7443)")?;
        let wt_server_config = webtransport::build_webtransport_server_config(
            &effective.tls_cert,
            &effective.tls_key,
        )
        .context("build WebTransport server config")?;
        let wt_state = Arc::new(v2_handlers::ServerState {
            store: Arc::clone(&store),
            waiters: Arc::clone(&waiters),
            auth_cfg: Arc::clone(&auth_cfg),
            opaque_setup: Arc::clone(&opaque_setup),
            pending_logins: Arc::clone(&pending_logins),
            sessions: Arc::clone(&sessions),
            rate_limits: Arc::clone(&rate_limits),
            sealed_sender: effective.sealed_sender,
            hooks: Arc::clone(&hooks),
            signing_key: Arc::clone(&signing_key),
            kt_log: Arc::clone(&kt_log),
            revocation_log: Arc::new(std::sync::Mutex::new(
                quicprochat_kt::RevocationLog::new(),
            )),
            data_dir: PathBuf::from(&effective.data_dir),
            redact_logs: effective.redact_logs,
            audit_logger: Arc::new(audit::NoopAuditLogger),
            draining: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            seen_message_ids: Arc::new(DashMap::new()),
            banned_users: Arc::new(DashMap::new()),
            moderation_reports: Arc::new(std::sync::Mutex::new(Vec::new())),
            node_id: format!("wt-{}", hex::encode(&signing_key.public_key_bytes()[..4])),
            start_time: std::time::Instant::now(),
            storage_backend: effective.store_backend.clone(),
            federation_client: None,
            local_domain: effective.federation.as_ref().map(|f| f.domain.clone()).unwrap_or_default(),
        });
        let wt_registry = Arc::new(v2_handlers::build_registry(
            std::time::Duration::from_secs(effective.rpc_timeout_secs),
        ));
        webtransport::spawn_webtransport_listener(wt_addr, wt_server_config, wt_state, wt_registry)
            .context("spawn WebTransport listener")?;
    }
    // Main QUIC endpoint for native clients.
    let endpoint = Endpoint::server(server_config, listen)?;
    tracing::info!(
        addr = %effective.listen,
        "accepting QUIC connections"
    );
    // ── Federation setup ─────────────────────────────────────────────────────
    let federation_client: Option<Arc<federation::FederationClient>> =
        if let Some(fed_cfg) = &effective.federation {
            tracing::info!(
                domain = %fed_cfg.domain,
                listen = %fed_cfg.listen,
                peers = fed_cfg.peers.len(),
                "federation enabled"
            );
            // Build a client endpoint for outbound federation connections.
            // For now we create a simple endpoint; full mTLS is used when certs are provided.
            let client_config = if fed_cfg.federation_cert.exists()
                && fed_cfg.federation_key.exists()
                && fed_cfg.federation_ca.exists()
            {
                let tls_cfg = federation::tls::build_federation_client_config(
                    &fed_cfg.federation_cert,
                    &fed_cfg.federation_key,
                    &fed_cfg.federation_ca,
                )
                .context("build federation client TLS config")?;
                let crypto = quinn::crypto::rustls::QuicClientConfig::try_from(tls_cfg)
                    .map_err(|e| anyhow::anyhow!("invalid federation client QUIC config: {e}"))?;
                let mut qc = quinn::ClientConfig::new(Arc::new(crypto));
                let mut transport = quinn::TransportConfig::default();
                transport.max_idle_timeout(Some(
                    std::time::Duration::from_secs(120)
                        .try_into()
                        .map_err(|e| anyhow::anyhow!("idle timeout: {e}"))?,
                ));
                qc.transport_config(Arc::new(transport));
                Some(qc)
            } else {
                tracing::warn!("federation cert/key/CA not found; outbound federation connections will fail");
                None
            };
            // Bind to an ephemeral port for outbound connections.
            let fed_bind: SocketAddr = SocketAddr::from(([0, 0, 0, 0], 0));
            let mut fed_endpoint = Endpoint::client(fed_bind)
                .context("create federation client endpoint")?;
            if let Some(cc) = client_config {
                fed_endpoint.set_default_client_config(cc);
            }
            let client = federation::FederationClient::new(fed_cfg, fed_endpoint)
                .context("create federation client")?;
            // Register configured peers in storage.
            for peer in &fed_cfg.peers {
                if let Err(e) = store.upsert_federation_peer(&peer.domain, true) {
                    tracing::warn!(domain = %peer.domain, error = %e, "failed to register federation peer");
                }
            }
            Some(Arc::new(client))
        } else {
            None
        };
    let local_domain: Option<String> = effective.federation.as_ref().map(|f| f.domain.clone());
    // ── Federation listener ──────────────────────────────────────────────────
    let federation_endpoint: Option<Endpoint> =
        if let Some(fed_cfg) = &effective.federation {
            if fed_cfg.federation_cert.exists()
                && fed_cfg.federation_key.exists()
                && fed_cfg.federation_ca.exists()
            {
                let fed_server_config = federation::tls::build_federation_server_config(
                    &fed_cfg.federation_cert,
                    &fed_cfg.federation_key,
                    &fed_cfg.federation_ca,
                )
                .context("build federation server TLS config")?;
                let fed_listen: SocketAddr = fed_cfg
                    .listen
                    .parse()
                    .context("federation listen must be host:port")?;
                let ep = Endpoint::server(fed_server_config, fed_listen)
                    .context("bind federation QUIC endpoint")?;
                tracing::info!(addr = %fed_cfg.listen, "federation endpoint listening");
                Some(ep)
            } else {
                tracing::warn!("federation certs not found; federation listener not started");
                None
            }
        } else {
            None
        };
    // ── mDNS local mesh discovery ─────────────────────────────────────────────
    // Announce this server on the local network so mesh-mode clients (and other
    // Freifunk nodes) can discover it automatically without manual configuration.
    // Non-critical: failures are logged as warnings; the server starts regardless.
    // Held in `_mdns_daemon` so the announcement lives as long as main().
    let _mdns_daemon = {
        let listen_port: u16 = listen.port();
        // Use the federation domain as the mDNS instance name when available.
        let mdns_instance = effective
            .federation
            .as_ref()
            .map(|f| f.domain.clone())
            .unwrap_or_else(|| "qpc-server".to_string());
        // mDNS host names must end with a dot.
        let mdns_host = if mdns_instance.ends_with('.') {
            mdns_instance.clone()
        } else {
            format!("{mdns_instance}.local.")
        };
        match mdns_sd::ServiceDaemon::new() {
            Ok(daemon) => {
                let mut props = std::collections::HashMap::new();
                props.insert("ver".to_string(), "1".to_string());
                props.insert("server".to_string(), effective.listen.clone());
                props.insert("domain".to_string(), mdns_instance.clone());
                match mdns_sd::ServiceInfo::new(
                    "_quicprochat._udp.local.",
                    &mdns_instance,
                    &mdns_host,
                    &[] as &[std::net::IpAddr],
                    listen_port,
                    Some(props),
                ) {
                    Ok(info) => match daemon.register(info) {
                        Ok(()) => {
                            tracing::info!(
                                instance = %mdns_instance,
                                port = listen_port,
                                "mDNS: announced qpc server on local network (_quicprochat._udp.local.)"
                            );
                        }
                        Err(e) => {
                            tracing::warn!(error = %e, "mDNS: service registration failed; mesh discovery disabled");
                        }
                    },
                    Err(e) => {
                        tracing::warn!(error = %e, "mDNS: failed to build service info; mesh discovery disabled");
                    }
                }
                Some(daemon)
            }
            Err(e) => {
                tracing::warn!(error = %e, "mDNS: daemon start failed; mesh discovery disabled");
                None
            }
        }
    };
    // capnp-rpc is !Send (Rc internals), so all RPC tasks must stay on a LocalSet.
    let local = LocalSet::new();
    local
        .run_until(async move {
            // Spawn federation acceptor if enabled.
            if let Some(fed_ep) = federation_endpoint {
                let fed_store = Arc::clone(&store);
                let fed_waiters = Arc::clone(&waiters);
                let fed_domain = local_domain.clone().unwrap_or_default();
                tokio::task::spawn_local(async move {
                    loop {
                        let incoming = match fed_ep.accept().await {
                            Some(i) => i,
                            None => break,
                        };
                        let connecting = match incoming.accept() {
                            Ok(c) => c,
                            Err(e) => {
                                tracing::warn!(error = %e, "federation: accept error");
                                continue;
                            }
                        };
                        let store = Arc::clone(&fed_store);
                        let waiters = Arc::clone(&fed_waiters);
                        let domain = fed_domain.clone();
                        // One task per peer connection; each gets its own RPC system.
                        tokio::task::spawn_local(async move {
                            match connecting.await {
                                Ok(conn) => {
                                    tracing::info!(
                                        peer = %conn.remote_address(),
                                        "federation: peer connected"
                                    );
                                    let (send, recv) = match conn.accept_bi().await {
                                        Ok(s) => s,
                                        Err(e) => {
                                            tracing::warn!(error = %e, "federation: accept bi error");
                                            return;
                                        }
                                    };
                                    // capnp-rpc speaks futures-io; adapt the tokio streams.
                                    let reader = tokio_util::compat::TokioAsyncReadCompatExt::compat(recv);
                                    let writer = tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send);
                                    let network = capnp_rpc::twoparty::VatNetwork::new(
                                        reader,
                                        writer,
                                        capnp_rpc::rpc_twoparty_capnp::Side::Server,
                                        Default::default(),
                                    );
                                    let verified_peer_domain =
                                        federation::service::extract_peer_domain(&conn);
                                    if let Some(ref peer) = verified_peer_domain {
                                        tracing::info!(peer_domain = %peer, "federation: mTLS peer authenticated");
                                    } else {
                                        tracing::warn!(peer = %conn.remote_address(), "federation: could not extract peer domain from mTLS cert");
                                    }
                                    let service_impl = federation::service::FederationServiceImpl {
                                        store,
                                        waiters,
                                        local_domain: domain,
                                        verified_peer_domain,
                                        rate_limits: Arc::new(dashmap::DashMap::new()),
                                    };
                                    let client: quicprochat_proto::federation_capnp::federation_service::Client =
                                        capnp_rpc::new_client(service_impl);
                                    if let Err(e) = capnp_rpc::RpcSystem::new(
                                        Box::new(network),
                                        Some(client.client),
                                    ).await {
                                        tracing::warn!(error = %e, "federation: RPC error");
                                    }
                                }
                                Err(e) => {
                                    tracing::warn!(error = %e, "federation: connection error");
                                }
                            }
                        });
                    }
                });
            }
            // Main accept loop: `biased` so the shutdown branch is polled first.
            loop {
                tokio::select! {
                    biased;
                    incoming = endpoint.accept() => {
                        let incoming = match incoming {
                            Some(i) => i,
                            None => break,
                        };
                        // Per-IP connection rate limiting.
                        let remote_ip = incoming.remote_address().ip();
                        if !auth::check_conn_rate_limit(&conn_rate_limits, remote_ip) {
                            tracing::warn!(ip = %remote_ip, "connection rate limit exceeded, dropping");
                            incoming.refuse();
                            continue;
                        }
                        let connecting = match incoming.accept() {
                            Ok(c) => c,
                            Err(e) => {
                                tracing::warn!(error = %e, "failed to accept incoming connection");
                                continue;
                            }
                        };
                        // Clone shared state into the per-connection task.
                        let store = Arc::clone(&store);
                        let waiters = Arc::clone(&waiters);
                        let auth_cfg = Arc::clone(&auth_cfg);
                        let opaque_setup = Arc::clone(&opaque_setup);
                        let pending_logins = Arc::clone(&pending_logins);
                        let sessions = Arc::clone(&sessions);
                        let rate_limits = Arc::clone(&rate_limits);
                        let sealed_sender = effective.sealed_sender;
                        let fed_client = federation_client.clone();
                        let local_dom = local_domain.clone();
                        let sk = Arc::clone(&signing_key);
                        let conn_hooks = Arc::clone(&hooks);
                        let conn_kt_log = Arc::clone(&kt_log);
                        let conn_data_dir = PathBuf::from(&effective.data_dir);
                        let conn_redact_logs = effective.redact_logs;
                        tokio::task::spawn_local(async move {
                            if let Err(e) = handle_node_connection(
                                connecting,
                                store,
                                waiters,
                                auth_cfg,
                                opaque_setup,
                                pending_logins,
                                sessions,
                                rate_limits,
                                sealed_sender,
                                fed_client,
                                local_dom,
                                sk,
                                conn_hooks,
                                conn_kt_log,
                                conn_data_dir,
                                conn_redact_logs,
                            )
                            .await
                            {
                                tracing::warn!(error = %e, "connection error");
                            }
                        });
                    }
                    _ = shutdown_signal() => {
                        tracing::info!("shutdown signal received, draining QUIC connections");
                        // Stop accepting new connections immediately.
                        endpoint.close(0u32.into(), b"server shutdown");
                        break;
                    }
                }
            }
            // Grace period: let in-flight RPC tasks on the LocalSet finish.
            // NOTE(review): this always sleeps the full drain timeout even when
            // no RPCs are in flight — consider tracking outstanding tasks.
            let drain_secs = effective.drain_timeout_secs;
            tracing::info!(drain_timeout_secs = drain_secs, "waiting for in-flight RPCs to complete");
            tokio::time::sleep(std::time::Duration::from_secs(drain_secs)).await;
            Ok::<(), anyhow::Error>(())
        })
        .await?;
    Ok(())
}
/// Resolve once a shutdown request arrives: SIGINT (Ctrl-C) on every
/// platform, or SIGTERM additionally on Unix.
///
/// Load balancers typically send SIGTERM during rolling deploys. The server
/// should stop accepting new connections, return "draining" from the health
/// endpoint, and wait for in-flight RPCs to finish (up to the drain timeout).
async fn shutdown_signal() {
    #[cfg(unix)]
    {
        // Race Ctrl-C against SIGTERM; whichever fires first wins.
        let mut sigterm =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("failed to install SIGTERM handler");
        tokio::select! {
            _ = tokio::signal::ctrl_c() => {}
            _ = sigterm.recv() => {}
        }
    }
    #[cfg(not(unix))]
    {
        // Only Ctrl-C is available; a failed registration is treated as
        // "never fires", matching the original `.ok()` behavior.
        let _ = tokio::signal::ctrl_c().await;
    }
}

View File

@@ -0,0 +1,64 @@
//! Prometheus metrics for the server.
//!
//! All counters/histograms/gauges use the `metrics` crate and are exported
//! via metrics-exporter-prometheus on a configurable HTTP port (e.g. /metrics).
/// Count one successful message enqueue; invoke after the message is stored.
pub fn record_enqueue_total() {
    let counter = metrics::counter!("enqueue_total");
    counter.increment(1);
}
/// Accumulate enqueued payload bytes.
pub fn record_enqueue_bytes(bytes: u64) {
    let counter = metrics::counter!("enqueue_bytes_total");
    counter.increment(bytes);
}
/// Count one successful fetch; invoke when fetch returns.
pub fn record_fetch_total() {
    let counter = metrics::counter!("fetch_total");
    counter.increment(1);
}
/// Count one successful fetch_wait; invoke when fetch_wait returns.
pub fn record_fetch_wait_total() {
    let counter = metrics::counter!("fetch_wait_total");
    counter.increment(1);
}
/// Sample the delivery queue depth; updated at enqueue/fetch time.
pub fn record_delivery_queue_depth(depth: usize) {
    let gauge = metrics::gauge!("delivery_queue_depth");
    gauge.set(depth as f64);
}
/// Count one successful KeyPackage upload.
pub fn record_key_package_upload_total() {
    let counter = metrics::counter!("key_package_upload_total");
    counter.increment(1);
}
/// Count one successful auth login (session token issued).
pub fn record_auth_login_success_total() {
    let counter = metrics::counter!("auth_login_success_total");
    counter.increment(1);
}
/// Count one failed auth login attempt.
pub fn record_auth_login_failure_total() {
    let counter = metrics::counter!("auth_login_failure_total");
    counter.increment(1);
}
/// Count one rate-limit rejection (enqueue refused).
pub fn record_rate_limit_hit_total() {
    let counter = metrics::counter!("rate_limit_hit_total");
    counter.increment(1);
}
// ── Storage operation latency ───────────────────────────────────────────────
/// Observe one storage operation's latency, labelled by operation name.
/// Called by instrumented Store wrappers.
pub fn record_storage_latency(operation: &'static str, duration: std::time::Duration) {
    let histogram = metrics::histogram!("storage_operation_duration_seconds", "op" => operation);
    histogram.record(duration.as_secs_f64());
}
// ── Server info ────────────────────────────────────────────────────────────
/// Publish the server uptime in seconds (set periodically).
pub fn record_uptime_seconds(secs: f64) {
    let gauge = metrics::gauge!("server_uptime_seconds");
    gauge.set(secs);
}

View File

@@ -0,0 +1,63 @@
use capnp::capability::Promise;
use quicprochat_proto::node_capnp::node_service;
use crate::auth::{coded_error, require_identity, validate_auth_context};
use crate::error_codes::*;
use super::NodeServiceImpl;
impl NodeServiceImpl {
    /// deleteAccount RPC: wipe the caller's stored account data and revoke
    /// every live session bound to the same identity key.
    ///
    /// The identity key comes from the authenticated session, never from a
    /// caller-supplied parameter, so a user can only delete their own account.
    pub fn handle_delete_account(
        &mut self,
        params: node_service::DeleteAccountParams,
        mut results: node_service::DeleteAccountResults,
    ) -> Promise<(), capnp::Error> {
        let reader = match params.get() {
            Ok(r) => r,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        // Auth first: the session must be valid and carry an identity key.
        let ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, reader.get_auth()) {
            Ok(c) => c,
            Err(e) => return Promise::err(e),
        };
        let identity_key = match require_identity(&ctx) {
            Ok(ik) => ik.to_vec(),
            Err(e) => return Promise::err(e),
        };
        // Only a short hex prefix of the key ever reaches the logs.
        let prefix_len = identity_key.len().min(8);
        let identity_prefix = crate::auth::fmt_hex(&identity_key[..prefix_len]);
        // Remove persisted account state before touching sessions, so a
        // storage failure leaves both the account and its sessions intact.
        if let Err(e) = self.store.delete_account(&identity_key) {
            tracing::error!(identity = %identity_prefix, error = %e, "account deletion failed");
            return Promise::err(coded_error(
                E028_ACCOUNT_DELETION_FAILED,
                format!("account deletion failed: {e}"),
            ));
        }
        // Collect the matching tokens first, then remove them — avoids
        // mutating the session map while iterating it.
        let stale_tokens: Vec<Vec<u8>> = self
            .sessions
            .iter()
            .filter(|entry| entry.value().identity_key == identity_key)
            .map(|entry| entry.key().clone())
            .collect();
        let invalidated = stale_tokens.len();
        for token in &stale_tokens {
            self.sessions.remove(token);
        }
        tracing::info!(
            identity = %identity_prefix,
            sessions_invalidated = invalidated,
            "audit: account deleted"
        );
        results.get().set_success(true);
        Promise::ok(())
    }
}

View File

@@ -0,0 +1,412 @@
use capnp::capability::Promise;
use opaque_ke::{
CredentialFinalization, CredentialRequest, RegistrationRequest, RegistrationUpload,
ServerLogin, ServerRegistration,
};
use quicprochat_core::opaque_auth::OpaqueSuite;
use quicprochat_proto::node_capnp::node_service;
use crate::auth::{coded_error, current_timestamp, PendingLogin, SESSION_TTL_SECS};
use crate::error_codes::*;
use crate::metrics;
use crate::storage::StorageError;
use crate::hooks::AuthEvent;
use super::NodeServiceImpl;
// Audit events in this module must never include secrets (no session tokens, passwords, or raw keys).
/// Map a storage-layer failure onto the wire-level E009 error code so all
/// handlers in this module report storage problems uniformly.
fn storage_err(err: StorageError) -> capnp::Error {
    coded_error(E009_STORAGE_ERROR, err)
}
/// Decode a Cap'n Proto text field into an owned username `String`.
///
/// Both a failed reader and non-UTF-8 bytes surface as E020 parameter errors.
fn parse_username_param(
    result: Result<capnp::text::Reader<'_>, capnp::Error>,
) -> Result<String, capnp::Error> {
    let reader = result.map_err(|e| coded_error(E020_BAD_PARAMS, e))?;
    match reader.to_string() {
        Ok(name) => Ok(name),
        Err(_) => Err(coded_error(E020_BAD_PARAMS, "username must be valid UTF-8")),
    }
}
impl NodeServiceImpl {
pub fn handle_opaque_login_start(
&mut self,
params: node_service::OpaqueLoginStartParams,
mut results: node_service::OpaqueLoginStartResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let request_bytes = match p.get_request() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
if username.is_empty() {
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
}
// Check for existing recent pending login before expensive OPAQUE/storage work (DoS mitigation).
if let Some(existing) = self.pending_logins.get(&username) {
let age = current_timestamp().saturating_sub(existing.created_at);
if age < 60 {
return Promise::err(coded_error(E010_OPAQUE_ERROR, "login already in progress"));
}
}
let credential_request = match CredentialRequest::<OpaqueSuite>::deserialize(&request_bytes) {
Ok(r) => r,
Err(e) => {
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("invalid credential request: {e}"),
))
}
};
let password_file = match self.store.get_user_record(&username) {
Ok(Some(bytes)) => match ServerRegistration::<OpaqueSuite>::deserialize(&bytes) {
Ok(pf) => Some(pf),
Err(e) => {
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("corrupt user record: {e}"),
))
}
},
Ok(None) => None,
Err(e) => return Promise::err(storage_err(e)),
};
let mut rng = rand::rngs::OsRng;
let result = match ServerLogin::<OpaqueSuite>::start(
&mut rng,
&self.opaque_setup,
password_file,
credential_request,
username.as_bytes(),
Default::default(),
) {
Ok(r) => r,
Err(e) => {
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("OPAQUE login start failed: {e}"),
))
}
};
let state_bytes = result.state.serialize().to_vec();
self.pending_logins.insert(
username.clone(),
PendingLogin {
state_bytes,
created_at: current_timestamp(),
},
);
let response_bytes = result.message.serialize();
results.get().set_response(&response_bytes);
tracing::info!(user = %username, "OPAQUE login started");
Promise::ok(())
}
pub fn handle_opaque_register_start(
&mut self,
params: node_service::OpaqueRegisterStartParams,
mut results: node_service::OpaqueRegisterStartResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let request_bytes = match p.get_request() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
if username.is_empty() {
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
}
if let Ok(true) = self.store.has_user_record(&username) {
return Promise::err(coded_error(
E018_USER_EXISTS,
format!("user '{}' already registered", username),
));
}
let registration_request = match RegistrationRequest::<OpaqueSuite>::deserialize(&request_bytes) {
Ok(r) => r,
Err(e) => {
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("invalid registration request: {e}"),
))
}
};
let result = match ServerRegistration::<OpaqueSuite>::start(
&self.opaque_setup,
registration_request,
username.as_bytes(),
) {
Ok(r) => r,
Err(e) => {
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("OPAQUE registration start failed: {e}"),
))
}
};
let response_bytes = result.message.serialize();
results.get().set_response(&response_bytes);
tracing::info!(user = %username, "OPAQUE registration started");
Promise::ok(())
}
pub fn handle_opaque_login_finish(
&mut self,
params: node_service::OpaqueLoginFinishParams,
mut results: node_service::OpaqueLoginFinishResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let finalization_bytes = match p.get_finalization() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let identity_key = p.get_identity_key().unwrap_or_default().to_vec();
if username.is_empty() {
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
}
let pending = match self.pending_logins.remove(&username) {
Some((_, pl)) => pl,
None => {
// Audit: login failure — do not log secrets (no token, no password).
tracing::warn!(user = %username, "audit: auth login failure (no pending login)");
metrics::record_auth_login_failure_total();
self.hooks.on_auth(&AuthEvent {
username: username.clone(),
success: false,
failure_reason: "no pending login".to_string(),
});
return Promise::err(coded_error(E019_NO_PENDING_LOGIN, "no pending login for this username"))
}
};
let server_login = match ServerLogin::<OpaqueSuite>::deserialize(&pending.state_bytes) {
Ok(s) => s,
Err(e) => {
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("corrupt login state: {e}"),
))
}
};
let finalization = match CredentialFinalization::<OpaqueSuite>::deserialize(&finalization_bytes) {
Ok(f) => f,
Err(e) => {
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("invalid credential finalization: {e}"),
))
}
};
let _result = match server_login.finish(finalization, Default::default()) {
Ok(r) => r,
Err(e) => {
tracing::warn!(user = %username, "audit: auth login failure (OPAQUE finish failed)");
metrics::record_auth_login_failure_total();
self.hooks.on_auth(&AuthEvent {
username: username.clone(),
success: false,
failure_reason: format!("OPAQUE finish failed: {e}"),
});
return Promise::err(coded_error(
E010_OPAQUE_ERROR,
format!("OPAQUE login finish failed (bad password?): {e}"),
))
}
};
if identity_key.is_empty() {
metrics::record_auth_login_failure_total();
return Promise::err(coded_error(
E016_IDENTITY_MISMATCH,
"identity key required to bind session token",
));
}
if let Ok(Some(stored_ik)) = self.store.get_user_identity_key(&username) {
if stored_ik != identity_key {
tracing::warn!(user = %username, "audit: auth login failure (identity mismatch)");
metrics::record_auth_login_failure_total();
self.hooks.on_auth(&AuthEvent {
username: username.clone(),
success: false,
failure_reason: "identity key mismatch".to_string(),
});
return Promise::err(coded_error(
E016_IDENTITY_MISMATCH,
"identity key does not match registered key",
));
}
}
let mut token = [0u8; 32];
rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut token);
let token_vec = token.to_vec();
let now = current_timestamp();
self.sessions.insert(
token_vec.clone(),
crate::auth::SessionInfo {
username: username.clone(),
identity_key,
created_at: now,
expires_at: now + SESSION_TTL_SECS,
},
);
results.get().set_session_token(&token_vec);
// Hook: on_auth — fires after successful login.
self.hooks.on_auth(&AuthEvent {
username: username.clone(),
success: true,
failure_reason: String::new(),
});
// Audit: login success — do not log session token or any secrets.
metrics::record_auth_login_success_total();
tracing::info!(user = %username, "audit: auth login success — session token issued");
Promise::ok(())
}
/// opaqueRegisterFinish RPC: second half of OPAQUE registration.
///
/// Deserializes the client's `RegistrationUpload`, derives the server-side
/// password file, and persists it under `username`. If a non-empty
/// identityKey was supplied, it is stored and appended to the Key
/// Transparency (KT) Merkle log.
///
/// Errors: E020 (bad params), E011 (empty username), E018 (already
/// registered), E010 (malformed OPAQUE upload), E009 (storage failure).
pub fn handle_opaque_register_finish(
    &mut self,
    params: node_service::OpaqueRegisterFinishParams,
    mut results: node_service::OpaqueRegisterFinishResults,
) -> Promise<(), capnp::Error> {
    let p = match params.get() {
        Ok(p) => p,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    let username = match parse_username_param(p.get_username()) {
        Ok(s) => s,
        Err(e) => return Promise::err(e),
    };
    let upload_bytes = match p.get_upload() {
        Ok(v) => v.to_vec(),
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    // identityKey is optional; a missing/unreadable field becomes empty.
    let identity_key = p.get_identity_key().unwrap_or_default().to_vec();
    if username.is_empty() {
        return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
    }
    // Early duplicate check; the DuplicateUser arm below still covers a
    // concurrent registration that slips past this read.
    match self.store.has_user_record(&username) {
        Ok(true) => {
            return Promise::err(coded_error(
                E018_USER_EXISTS,
                format!("user '{}' already registered", username),
            ))
        }
        Err(e) => return Promise::err(storage_err(e)),
        _ => {}
    }
    let upload = match RegistrationUpload::<OpaqueSuite>::deserialize(&upload_bytes) {
        Ok(u) => u,
        Err(e) => {
            return Promise::err(coded_error(
                E010_OPAQUE_ERROR,
                format!("invalid registration upload: {e}"),
            ))
        }
    };
    // Finalize the OPAQUE protocol: derive the server's password file.
    let password_file = ServerRegistration::<OpaqueSuite>::finish(upload);
    let record_bytes = password_file.serialize().to_vec();
    match self
        .store
        .store_user_record(&username, record_bytes)
    {
        Ok(()) => {}
        Err(crate::storage::StorageError::DuplicateUser(_)) => {
            return Promise::err(coded_error(
                E018_USER_EXISTS,
                format!("user '{}' already registered", username),
            ))
        }
        Err(e) => return Promise::err(storage_err(e)),
    }
    // Hook: on_user_registered — fires after successful registration.
    // NOTE(review): the hook fires before the identity key / KT append
    // below; confirm hook consumers don't rely on the key being stored yet.
    self.hooks.on_user_registered(&username, &identity_key);
    if !identity_key.is_empty() {
        if let Err(e) = self
            .store
            .store_user_identity_key(&username, identity_key.clone())
            .map_err(storage_err)
        {
            return Promise::err(e);
        }
        // Append (username, identity_key) to the Key Transparency Merkle log.
        match self.kt_log.lock() {
            Ok(mut log) => {
                log.append(&username, &identity_key);
                // Persist after each append (small extra cost, but ensures durability).
                match log.to_bytes() {
                    Ok(bytes) => {
                        if let Err(e) = self.store.save_kt_log(bytes) {
                            // Persistence failure is logged but does not fail the RPC.
                            tracing::warn!(user = %username, error = %e, "KT log persist failed");
                        }
                    }
                    Err(e) => {
                        tracing::warn!(user = %username, error = %e, "KT log serialise failed");
                    }
                }
                tracing::info!(user = %username, tree_size = log.len(), "KT: appended identity binding");
            }
            Err(e) => {
                // A poisoned lock skips the KT append rather than failing registration.
                tracing::warn!(user = %username, error = %e, "KT log lock poisoned; skipping append");
            }
        }
    }
    results.get().set_success(true);
    tracing::info!(user = %username, "OPAQUE registration complete");
    Promise::ok(())
}
}

View File

@@ -0,0 +1,326 @@
//! uploadBlob / downloadBlob RPCs: chunked file transfer with SHA-256 integrity verification.
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::PathBuf;
use capnp::capability::Promise;
use quicprochat_proto::node_capnp::node_service;
use sha2::{Digest, Sha256};
use crate::auth::{coded_error, fmt_hex, validate_auth_context};
use crate::error_codes::*;
use super::NodeServiceImpl;
/// Maximum blob size: 50 MB.
const MAX_BLOB_SIZE: u64 = 50 * 1024 * 1024;
/// Maximum download chunk size: 256 KB.
const MAX_DOWNLOAD_CHUNK: u32 = 256 * 1024;
/// Metadata stored alongside each completed blob.
#[derive(serde::Serialize, serde::Deserialize)]
struct BlobMeta {
mime_type: String,
total_size: u64,
uploaded_at: u64,
uploader_key_prefix: String,
}
/// Resolve the blobs directory from the server's data_dir.
///
/// Pure path arithmetic — does not touch the filesystem.
fn blobs_dir(data_dir: &std::path::Path) -> PathBuf {
    let mut dir = data_dir.to_path_buf();
    dir.push("blobs");
    dir
}
impl NodeServiceImpl {
    /// uploadBlob RPC: write one chunk of a content-addressed blob.
    ///
    /// Chunks accumulate in `<hash>.part` inside the blobs directory. When a
    /// chunk ends exactly at `totalSize`, the whole `.part` file is re-hashed
    /// with SHA-256; on a match it is renamed into place as `<hash>` and a
    /// `.meta` JSON sidecar (MIME type, size, uploader prefix) is written.
    ///
    /// Errors: E020 (bad params / out-of-bounds chunk / zero size),
    /// E025 (hash length), E024 (blob too large), E026 (hash mismatch),
    /// E009 (storage/IO failure).
    ///
    /// Fix vs. previous revision: the paths and the chunk are *moved* into
    /// the `spawn_blocking` closure instead of being cloned first — they were
    /// never used again afterwards (clippy `redundant_clone`). Only
    /// `blob_hash`, `mime_type` and `uploader_prefix`, which are needed for
    /// the response/audit log, are still cloned.
    pub fn handle_upload_blob(
        &mut self,
        params: node_service::UploadBlobParams,
        mut results: node_service::UploadBlobResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
            Ok(ctx) => ctx,
            Err(e) => return Promise::err(e),
        };
        let blob_hash = match p.get_blob_hash() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let chunk = match p.get_chunk() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let offset = p.get_offset();
        let total_size = p.get_total_size();
        let mime_type = match p.get_mime_type() {
            Ok(v) => match v.to_str() {
                Ok(s) => s.to_string(),
                Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
            },
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        // Validate blobHash length (a SHA-256 digest is exactly 32 bytes).
        if blob_hash.len() != 32 {
            return Promise::err(coded_error(
                E025_BLOB_HASH_LENGTH,
                format!("blobHash must be exactly 32 bytes, got {}", blob_hash.len()),
            ));
        }
        // Validate totalSize against the server-wide cap.
        if total_size > MAX_BLOB_SIZE {
            return Promise::err(coded_error(
                E024_BLOB_TOO_LARGE,
                format!("totalSize {} exceeds max blob size ({} bytes)", total_size, MAX_BLOB_SIZE),
            ));
        }
        if total_size == 0 {
            return Promise::err(coded_error(E020_BAD_PARAMS, "totalSize must be > 0"));
        }
        // Validate chunk bounds; checked_add also rejects u64 overflow.
        if offset.checked_add(chunk.len() as u64).is_none_or(|end| end > total_size) {
            return Promise::err(coded_error(
                E020_BAD_PARAMS,
                format!(
                    "chunk out of bounds: offset={} + chunk_len={} > totalSize={}",
                    offset,
                    chunk.len(),
                    total_size
                ),
            ));
        }
        let blob_hex = hex::encode(&blob_hash);
        let dir = blobs_dir(&self.data_dir);
        // Ensure the blobs directory exists before any file I/O.
        if let Err(e) = std::fs::create_dir_all(&dir) {
            return Promise::err(coded_error(
                E009_STORAGE_ERROR,
                format!("failed to create blobs directory: {e}"),
            ));
        }
        let part_path = dir.join(format!("{blob_hex}.part"));
        let final_path = dir.join(&blob_hex);
        let meta_path = dir.join(format!("{blob_hex}.meta"));
        // Redacted uploader identifier for the .meta sidecar and audit log.
        let uploader_prefix = auth_ctx
            .identity_key
            .as_deref()
            .filter(|k| k.len() >= 4)
            .map(|k| hex::encode(&k[..4]))
            .unwrap_or_default();
        // All file I/O is delegated to spawn_blocking to avoid stalling the
        // Tokio event loop.
        Promise::from_future(async move {
            // Clone only the values needed again after the blocking task;
            // the three paths and the chunk are moved into the closure.
            let blob_hash_for_io = blob_hash.clone();
            let mime_for_io = mime_type.clone();
            let prefix_for_io = uploader_prefix.clone();
            let io_result = tokio::task::spawn_blocking(move || -> Result<Option<bool>, String> {
                // If the blob already exists (fully uploaded), return immediately.
                if final_path.exists() {
                    return Ok(None); // signals "already done"
                }
                // Write the chunk at the given offset. truncate(false) keeps
                // previously written chunks intact.
                let mut file = std::fs::OpenOptions::new()
                    .create(true)
                    .write(true)
                    .truncate(false)
                    .open(&part_path)
                    .map_err(|e| format!("open .part file: {e}"))?;
                file.seek(SeekFrom::Start(offset))
                    .map_err(|e| format!("seek: {e}"))?;
                file.write_all(&chunk)
                    .map_err(|e| format!("write chunk: {e}"))?;
                file.sync_all()
                    .map_err(|e| format!("sync: {e}"))?;
                // A chunk ending exactly at totalSize is the completion signal.
                let end = offset + chunk.len() as u64;
                if end == total_size {
                    // Verify SHA-256 of the complete file, streamed 64 KiB at a time.
                    let mut vfile = std::fs::File::open(&part_path)
                        .map_err(|e| format!("open for verify: {e}"))?;
                    let mut hasher = Sha256::new();
                    let mut buf = [0u8; 64 * 1024];
                    loop {
                        let n = vfile.read(&mut buf).map_err(|e| format!("read: {e}"))?;
                        if n == 0 {
                            break;
                        }
                        hasher.update(&buf[..n]);
                    }
                    let computed: [u8; 32] = hasher.finalize().into();
                    if computed != blob_hash_for_io.as_slice() {
                        // Corrupt or mismatched upload: discard the partial file.
                        let _ = std::fs::remove_file(&part_path);
                        return Ok(Some(false)); // hash mismatch
                    }
                    // Hash matches — finalize the blob by renaming into place.
                    std::fs::rename(&part_path, &final_path)
                        .map_err(|e| format!("rename .part to final: {e}"))?;
                    // Write the metadata sidecar; failure here is non-fatal
                    // (logged, upload still reported as complete).
                    let now = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_default()
                        .as_secs();
                    let meta = BlobMeta {
                        mime_type: mime_for_io,
                        total_size,
                        uploaded_at: now,
                        uploader_key_prefix: prefix_for_io,
                    };
                    if let Err(e) = (|| -> Result<(), String> {
                        let json = serde_json::to_string_pretty(&meta)
                            .map_err(|e| format!("serialize meta: {e}"))?;
                        std::fs::write(&meta_path, json.as_bytes())
                            .map_err(|e| format!("write meta: {e}"))?;
                        Ok(())
                    })() {
                        tracing::warn!(error = %e, "failed to write blob metadata");
                    }
                    return Ok(Some(true)); // complete + verified
                }
                Ok(None) // chunk written, not yet complete
            })
            .await
            .map_err(|e| capnp::Error::failed(format!("spawn_blocking join: {e}")))?;
            match io_result {
                Ok(None) => {
                    // Already existed or chunk written (not yet complete).
                    results.get().set_blob_id(&blob_hash);
                }
                Ok(Some(true)) => {
                    // Complete and verified.
                    tracing::info!(
                        blob_hash_prefix = %fmt_hex(&blob_hash[..4]),
                        total_size = total_size,
                        mime_type = %mime_type,
                        uploader_prefix = %uploader_prefix,
                        "audit: blob_upload_complete"
                    );
                    results.get().set_blob_id(&blob_hash);
                }
                Ok(Some(false)) => {
                    return Err(coded_error(
                        E026_BLOB_HASH_MISMATCH,
                        "SHA-256 of uploaded data does not match blobHash",
                    ));
                }
                Err(e) => {
                    return Err(coded_error(E009_STORAGE_ERROR, e));
                }
            }
            Ok(())
        })
    }

    /// downloadBlob RPC: return up to `length` bytes (capped at
    /// MAX_DOWNLOAD_CHUNK) of the blob starting at `offset`, together with
    /// the total size and MIME type read from the `.meta` sidecar.
    ///
    /// A read starting at or past EOF returns an empty chunk (not an error).
    ///
    /// Errors: E020 (bad params), E025 (id length), E027 (blob not found),
    /// E009 (corrupt metadata or IO failure).
    pub fn handle_download_blob(
        &mut self,
        params: node_service::DownloadBlobParams,
        mut results: node_service::DownloadBlobResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        if let Err(e) = validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
            return Promise::err(e);
        }
        let blob_id = match p.get_blob_id() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let offset = p.get_offset();
        // Clamp the requested read size to the server-side chunk cap.
        let length = p.get_length().min(MAX_DOWNLOAD_CHUNK);
        if blob_id.len() != 32 {
            return Promise::err(coded_error(
                E025_BLOB_HASH_LENGTH,
                format!("blobId must be exactly 32 bytes, got {}", blob_id.len()),
            ));
        }
        let blob_hex = hex::encode(&blob_id);
        let dir = blobs_dir(&self.data_dir);
        let blob_path = dir.join(&blob_hex);
        let meta_path = dir.join(format!("{blob_hex}.meta"));
        // Delegate all file I/O to spawn_blocking to avoid stalling the event loop.
        Promise::from_future(async move {
            let io_result = tokio::task::spawn_blocking(move || -> Result<(Vec<u8>, BlobMeta), capnp::Error> {
                // Only finalized (renamed) blobs are served; .part files are not.
                if !blob_path.exists() {
                    return Err(coded_error(E027_BLOB_NOT_FOUND, "blob not found"));
                }
                // Read the metadata sidecar.
                let meta: BlobMeta = match std::fs::read_to_string(&meta_path) {
                    Ok(json) => serde_json::from_str(&json).map_err(|e| {
                        coded_error(E009_STORAGE_ERROR, format!("corrupt blob metadata: {e}"))
                    })?,
                    Err(e) => {
                        return Err(coded_error(
                            E009_STORAGE_ERROR,
                            format!("read blob metadata: {e}"),
                        ));
                    }
                };
                // Read the requested chunk.
                let mut file = std::fs::File::open(&blob_path)
                    .map_err(|e| coded_error(E009_STORAGE_ERROR, format!("open blob: {e}")))?;
                let file_len = file
                    .metadata()
                    .map_err(|e| coded_error(E009_STORAGE_ERROR, format!("file metadata: {e}")))?
                    .len();
                // Reads past EOF yield an empty chunk rather than an error.
                if offset >= file_len {
                    return Ok((vec![], meta));
                }
                file.seek(SeekFrom::Start(offset))
                    .map_err(|e| coded_error(E009_STORAGE_ERROR, format!("seek: {e}")))?;
                let remaining = (file_len - offset) as usize;
                let to_read = remaining.min(length as usize);
                let mut buf = vec![0u8; to_read];
                file.read_exact(&mut buf)
                    .map_err(|e| coded_error(E009_STORAGE_ERROR, format!("read chunk: {e}")))?;
                Ok((buf, meta))
            })
            .await
            .map_err(|e| capnp::Error::failed(format!("spawn_blocking join: {e}")))?;
            let (chunk, meta) = io_result?;
            let mut r = results.get();
            r.set_chunk(&chunk);
            r.set_total_size(meta.total_size);
            r.set_mime_type(&meta.mime_type);
            Ok(())
        })
    }
}

View File

@@ -0,0 +1,74 @@
//! createChannel RPC: create or look up a 1:1 DM channel.
use capnp::capability::Promise;
use quicprochat_proto::node_capnp::node_service;
use crate::auth::{coded_error, require_identity, validate_auth_context};
use crate::error_codes::*;
use crate::storage::StorageError;
use crate::hooks::ChannelEvent;
use super::NodeServiceImpl;
/// Map a storage-layer error onto the generic E009 coded capnp error.
fn storage_err(err: StorageError) -> capnp::Error {
    coded_error(E009_STORAGE_ERROR, err)
}
impl NodeServiceImpl {
pub fn handle_create_channel(
&mut self,
params: node_service::CreateChannelParams,
mut results: node_service::CreateChannelResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let peer_key = match p.get_peer_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
let identity = match require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
if peer_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("peerKey must be exactly 32 bytes, got {}", peer_key.len()),
));
}
if identity == peer_key {
return Promise::err(coded_error(
E020_BAD_PARAMS,
"peerKey must not equal caller identity",
));
}
let (channel_id, was_new) = match self.store.create_channel(identity, &peer_key) {
Ok(pair) => pair,
Err(e) => return Promise::err(storage_err(e)),
};
// Hook: on_channel_created — fires after channel is created or looked up.
self.hooks.on_channel_created(&ChannelEvent {
channel_id: channel_id.clone(),
initiator_key: identity.to_vec(),
peer_key: peer_key.clone(),
was_new,
});
let mut r = results.get();
r.set_channel_id(&channel_id);
r.set_was_new(was_new);
Promise::ok(())
}
}

View File

@@ -0,0 +1,916 @@
use std::sync::Arc;
use std::time::Duration;
use capnp::capability::Promise;
use dashmap::DashMap;
use quicprochat_proto::node_capnp::node_service;
use tokio::sync::Notify;
use tokio::time::timeout;
use sha2::{Digest, Sha256};
use crate::auth::{
check_rate_limit, coded_error, fmt_hex, require_identity_or_request, validate_auth_context,
};
use crate::error_codes::*;
use crate::metrics;
use crate::storage::{StorageError, Store};
use super::{NodeServiceImpl, CURRENT_WIRE_VERSION};
use crate::hooks::{HookAction, MessageEvent, FetchEvent};
// Audit events here must not include secrets: no payload content, no full recipient/token bytes (prefix only).
/// Redacted identifier for audit logs: hex of the first 4 bytes of
/// SHA-256(key), so raw key material never reaches the log stream.
fn redacted_prefix(key: &[u8]) -> String {
    fmt_hex(&Sha256::digest(key)[..4])
}
const MAX_PAYLOAD_BYTES: usize = 5 * 1024 * 1024; // 5 MB cap per message
const MAX_QUEUE_DEPTH: usize = 1000;
/// Build a 96-byte delivery proof: SHA-256(seq || recipient_key || timestamp_ms) || Ed25519 sig.
///
/// Layout:
///   bytes 0..32  — SHA-256 preimage hash
///   bytes 32..96 — Ed25519 signature over those 32 bytes
fn build_delivery_proof(
    signing_key: &quicprochat_core::IdentityKeypair,
    seq: u64,
    recipient_key: &[u8],
    timestamp_ms: u64,
) -> [u8; 96] {
    // Digest the little-endian framing of (seq, recipient, timestamp).
    let digest: [u8; 32] = {
        let mut hasher = Sha256::new();
        hasher.update(seq.to_le_bytes());
        hasher.update(recipient_key);
        hasher.update(timestamp_ms.to_le_bytes());
        hasher.finalize().into()
    };
    let signature = signing_key.sign_raw(&digest);
    // Assemble hash || signature into the fixed 96-byte wire format.
    let mut proof = [0u8; 96];
    let (hash_part, sig_part) = proof.split_at_mut(32);
    hash_part.copy_from_slice(&digest);
    sig_part.copy_from_slice(&signature);
    proof
}
/// Map a storage-layer error onto the generic E009 coded capnp error.
fn storage_err(err: StorageError) -> capnp::Error {
    coded_error(E009_STORAGE_ERROR, err)
}
/// Copy fetched `(seq, payload)` pairs into the fetchWait results list.
///
/// Consumes `messages`; the list is sized up front from its length.
pub fn fill_payloads_wait(
    results: &mut node_service::FetchWaitResults,
    messages: Vec<(u64, Vec<u8>)>,
) {
    let mut list = results.get().init_payloads(messages.len() as u32);
    for (idx, (seq, data)) in messages.into_iter().enumerate() {
        let mut slot = list.reborrow().get(idx as u32);
        slot.set_seq(seq);
        slot.set_data(&data);
    }
}
impl NodeServiceImpl {
/// enqueue RPC: validate, authorize, and durably queue one message for
/// `recipientKey`; returns the assigned sequence number plus a signed
/// delivery proof. When federation routing resolves the recipient to a
/// remote domain, the message is relayed there instead of stored locally.
///
/// Errors: E020 (bad params), E004 (key length), E005/E006 (payload),
/// E012 (wire version), E015 (queue full), E022/E023 (channel authz),
/// E009 (storage), plus rate-limit and hook rejections.
pub fn handle_enqueue(
    &mut self,
    params: node_service::EnqueueParams,
    mut results: node_service::EnqueueResults,
) -> Promise<(), capnp::Error> {
    let p = match params.get() {
        Ok(p) => p,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    let recipient_key = match p.get_recipient_key() {
        Ok(v) => v.to_vec(),
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    let payload = match p.get_payload() {
        Ok(v) => v.to_vec(),
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    // A missing/unreadable channelId degrades to "no channel".
    let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
    let version = p.get_version();
    let ttl_secs_raw = p.get_ttl_secs();
    // ttlSecs == 0 means "no explicit TTL".
    let ttl_secs = if ttl_secs_raw > 0 { Some(ttl_secs_raw) } else { None };
    let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
        Ok(ctx) => ctx,
        Err(e) => return Promise::err(e),
    };
    if recipient_key.len() != 32 {
        return Promise::err(coded_error(
            E004_IDENTITY_KEY_LENGTH,
            format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
        ));
    }
    if payload.is_empty() {
        return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
    }
    if payload.len() > MAX_PAYLOAD_BYTES {
        return Promise::err(coded_error(
            E006_PAYLOAD_TOO_LARGE,
            format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
        ));
    }
    if version > CURRENT_WIRE_VERSION {
        return Promise::err(coded_error(
            E012_WIRE_VERSION,
            format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
        ));
    }
    if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
        // Audit: rate limit hit — do not log token or identity.
        tracing::warn!("rate_limit_hit");
        metrics::record_rate_limit_hit_total();
        return Promise::err(e);
    }
    // Phase 4.3 — DS sender identity binding.
    // When sealed_sender is false, the sender MUST have an identity-bound session.
    // The sender_identity used for audit/hooks is ALWAYS derived from
    // auth_ctx.identity_key (populated by OPAQUE session lookup in validate_auth_context),
    // never from any client-supplied field. This guarantees that the server only
    // attributes messages to the cryptographically authenticated identity.
    if !self.sealed_sender {
        if let Err(e) = crate::auth::require_identity(&auth_ctx) {
            return Promise::err(e);
        }
    }
    // Federation routing: if the recipient's home server differs from ours, relay the
    // message to the remote server instead of enqueueing locally. This enables
    // cross-node delivery in a Freifunk / community mesh deployment.
    if let (Some(fed_client), Some(local_domain)) =
        (&self.federation_client, &self.local_domain)
    {
        let dest = crate::federation::routing::resolve_destination(
            &self.store,
            &recipient_key,
            local_domain,
        );
        if let crate::federation::routing::Destination::Remote(remote_domain) = dest {
            let fed = Arc::clone(fed_client);
            let rk = recipient_key;
            let pl = payload;
            let ch = channel_id;
            tracing::info!(
                recipient_prefix = %fmt_hex(&rk[..4]),
                domain = %remote_domain,
                "federation: routing enqueue to remote server"
            );
            // Relay asynchronously; the remote server assigns the seq.
            // NOTE(review): no delivery proof is returned on this path —
            // confirm whether remote enqueues should carry one.
            return Promise::from_future(async move {
                let seq = fed
                    .relay_enqueue(&remote_domain, &rk, &pl, &ch)
                    .await
                    .map_err(|e| {
                        capnp::Error::failed(format!("federation relay failed: {e}"))
                    })?;
                results.get().set_seq(seq);
                metrics::record_enqueue_total();
                metrics::record_enqueue_bytes(pl.len() as u64);
                Ok(())
            });
        }
    }
    // DM channel authz: channel_id.len() == 16 means a created channel; caller and recipient must be the two members.
    if channel_id.len() == 16 {
        let members = match self.store.get_channel_members(&channel_id) {
            Ok(Some(m)) => m,
            Ok(None) => {
                return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
            }
            Err(e) => return Promise::err(storage_err(e)),
        };
        let caller = match crate::auth::require_identity(&auth_ctx) {
            Ok(id) => id,
            Err(e) => return Promise::err(e),
        };
        let (a, b) = &members;
        // Caller must be one member and the recipient must be the *other* one.
        let caller_in = caller == a.as_slice() || caller == b.as_slice();
        let recipient_other = (recipient_key == *a && caller == b.as_slice())
            || (recipient_key == *b && caller == a.as_slice());
        if !caller_in || !recipient_other {
            return Promise::err(coded_error(
                E022_CHANNEL_ACCESS_DENIED,
                "caller or recipient not a member of this channel",
            ));
        }
    }
    // Back-pressure: reject when the recipient's queue is at capacity.
    match self.store.queue_depth(&recipient_key, &channel_id) {
        Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
            return Promise::err(coded_error(
                E015_QUEUE_FULL,
                format!("queue depth {} exceeds limit {}", depth, MAX_QUEUE_DEPTH),
            ));
        }
        Err(e) => return Promise::err(storage_err(e)),
        _ => {}
    }
    let payload_len = payload.len();
    // sender_identity is derived solely from auth_ctx (server-side session state).
    let sender_identity = if self.sealed_sender {
        None
    } else {
        crate::auth::require_identity(&auth_ctx).ok().map(|v| v.to_vec())
    };
    let sender_prefix = sender_identity
        .as_deref()
        .filter(|id| id.len() >= 4)
        .map(|id| fmt_hex(&id[..4]));
    // Hook: on_message_enqueue — fires after validation, before storage.
    let hook_event = MessageEvent {
        sender_identity: sender_identity.clone(),
        recipient_key: recipient_key.clone(),
        channel_id: channel_id.clone(),
        payload_len,
        seq: 0, // not yet assigned
    };
    if let HookAction::Reject(reason) = self.hooks.on_message_enqueue(&hook_event) {
        return Promise::err(capnp::Error::failed(format!("hook rejected enqueue: {reason}")));
    }
    let seq = match self
        .store
        .enqueue(&recipient_key, &channel_id, payload, ttl_secs)
        .map_err(storage_err)
    {
        Ok(seq) => seq,
        Err(e) => return Promise::err(e),
    };
    let timestamp_ms = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as u64;
    // Sign a proof binding (seq, recipient, timestamp) to this server's key.
    let proof = build_delivery_proof(&self.signing_key, seq, &recipient_key, timestamp_ms);
    let mut r = results.get();
    r.set_seq(seq);
    r.set_delivery_proof(&proof);
    // Metrics and audit. Audit events must not include secrets (no payload, no full keys).
    metrics::record_enqueue_total();
    metrics::record_enqueue_bytes(payload_len as u64);
    if let Ok(depth) = self.store.queue_depth(&recipient_key, &channel_id) {
        metrics::record_delivery_queue_depth(depth);
    }
    if self.redact_logs {
        let redacted_sender = sender_identity
            .as_deref()
            .map(redacted_prefix)
            .unwrap_or_else(|| "sealed".to_string());
        tracing::info!(
            sender_prefix = %redacted_sender,
            recipient_prefix = %redacted_prefix(&recipient_key),
            seq = seq,
            "audit: enqueue"
        );
    } else {
        tracing::info!(
            sender_prefix = sender_prefix.as_deref().unwrap_or("sealed"),
            recipient_prefix = %fmt_hex(&recipient_key[..4]),
            payload_len = payload_len,
            seq = seq,
            "audit: enqueue"
        );
    }
    // Wake any fetchWait long-pollers blocked on this recipient.
    crate::auth::waiter(&self.waiters, &recipient_key).notify_waiters();
    Promise::ok(())
}
/// fetch RPC: drain queued messages for `recipientKey` (optionally limited
/// via `limit` and scoped to a 16-byte DM `channelId`) and return them as
/// (seq, payload) pairs.
///
/// The params struct is read once up front, matching the sibling handlers,
/// instead of re-traversing `params.get()` per field; the previous
/// unreachable `Option`/`transpose` dance around auth validation is gone
/// (auth errors still surface before the length/version checks, and an
/// unreadable channelId still degrades to "no channel").
///
/// Errors: E020 (bad params), E004 (key length), E012 (wire version),
/// E022/E023 (channel authz), E009 (storage), plus auth failures.
pub fn handle_fetch(
    &mut self,
    params: node_service::FetchParams,
    mut results: node_service::FetchResults,
) -> Promise<(), capnp::Error> {
    let p = match params.get() {
        Ok(p) => p,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    let recipient_key = match p.get_recipient_key() {
        Ok(v) => v.to_vec(),
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    // A missing/unreadable channelId degrades to "no channel".
    let channel_id = p.get_channel_id().map(|c| c.to_vec()).unwrap_or_default();
    let version = p.get_version();
    let limit = p.get_limit();
    let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
        Ok(ctx) => ctx,
        Err(e) => return Promise::err(e),
    };
    if recipient_key.len() != 32 {
        return Promise::err(coded_error(
            E004_IDENTITY_KEY_LENGTH,
            format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
        ));
    }
    if version > CURRENT_WIRE_VERSION {
        return Promise::err(coded_error(
            E012_WIRE_VERSION,
            format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
        ));
    }
    // Fetch requires an identity-bound session, or (when explicitly enabled
    // in config) the request-supplied recipient key.
    if let Err(e) = require_identity_or_request(
        &auth_ctx,
        &recipient_key,
        self.auth_cfg.allow_insecure_identity_from_request,
    ) {
        return Promise::err(e);
    }
    // DM channel authz: a 16-byte channelId must name a channel whose two
    // members are exactly the caller and the recipient.
    if channel_id.len() == 16 {
        let members = match self.store.get_channel_members(&channel_id) {
            Ok(Some(m)) => m,
            Ok(None) => {
                return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
            }
            Err(e) => return Promise::err(storage_err(e)),
        };
        let caller = match crate::auth::require_identity(&auth_ctx) {
            Ok(id) => id,
            Err(e) => return Promise::err(e),
        };
        let (a, b) = &members;
        let caller_in = caller == a.as_slice() || caller == b.as_slice();
        let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
            || (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
        if !caller_in || !recipient_other {
            return Promise::err(coded_error(
                E022_CHANNEL_ACCESS_DENIED,
                "caller or recipient not a member of this channel",
            ));
        }
    }
    // limit == 0 means "no limit": drain the whole queue.
    let messages = if limit > 0 {
        match self
            .store
            .fetch_limited(&recipient_key, &channel_id, limit as usize)
            .map_err(storage_err)
        {
            Ok(m) => m,
            Err(e) => return Promise::err(e),
        }
    } else {
        match self
            .store
            .fetch(&recipient_key, &channel_id)
            .map_err(storage_err)
        {
            Ok(m) => m,
            Err(e) => return Promise::err(e),
        }
    };
    // Hook: on_fetch — fires after messages are retrieved.
    self.hooks.on_fetch(&FetchEvent {
        recipient_key: recipient_key.clone(),
        channel_id: channel_id.clone(),
        message_count: messages.len(),
    });
    // Audit: fetch — do not log payload or full keys.
    metrics::record_fetch_total();
    if self.redact_logs {
        tracing::info!(
            recipient_prefix = %redacted_prefix(&recipient_key),
            count = messages.len(),
            "audit: fetch"
        );
    } else {
        tracing::info!(
            recipient_prefix = %fmt_hex(&recipient_key[..4]),
            count = messages.len(),
            "audit: fetch"
        );
    }
    let mut list = results.get().init_payloads(messages.len() as u32);
    for (i, (seq, data)) in messages.iter().enumerate() {
        let mut entry = list.reborrow().get(i as u32);
        entry.set_seq(*seq);
        entry.set_data(data);
    }
    Promise::ok(())
}
/// fetchWait RPC: long-poll variant of fetch. Returns immediately when
/// messages are queued; otherwise (with timeoutMs > 0) parks on a per-key
/// Notify until an enqueue wakes it or the timeout elapses, then fetches
/// one final time.
///
/// Errors: E020 (bad params), E004 (key length), E012 (wire version),
/// E022/E023 (channel authz), E009 (storage), plus auth failures.
pub fn handle_fetch_wait(
    &mut self,
    params: node_service::FetchWaitParams,
    mut results: node_service::FetchWaitResults,
) -> Promise<(), capnp::Error> {
    let p = match params.get() {
        Ok(p) => p,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    let recipient_key = match p.get_recipient_key() {
        Ok(v) => v.to_vec(),
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    // A missing/unreadable channelId degrades to "no channel".
    let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
    let version = p.get_version();
    let timeout_ms = p.get_timeout_ms();
    let limit = p.get_limit();
    let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
        Ok(ctx) => ctx,
        Err(e) => return Promise::err(e),
    };
    if recipient_key.len() != 32 {
        return Promise::err(coded_error(
            E004_IDENTITY_KEY_LENGTH,
            format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
        ));
    }
    if version > CURRENT_WIRE_VERSION {
        return Promise::err(coded_error(
            E012_WIRE_VERSION,
            format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
        ));
    }
    if let Err(e) = require_identity_or_request(
        &auth_ctx,
        &recipient_key,
        self.auth_cfg.allow_insecure_identity_from_request,
    ) {
        return Promise::err(e);
    }
    // DM channel authz: a 16-byte channelId must name a channel whose two
    // members are exactly the caller and the recipient.
    if channel_id.len() == 16 {
        let members = match self.store.get_channel_members(&channel_id) {
            Ok(Some(m)) => m,
            Ok(None) => {
                return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
            }
            Err(e) => return Promise::err(storage_err(e)),
        };
        let caller = match crate::auth::require_identity(&auth_ctx) {
            Ok(id) => id,
            Err(e) => return Promise::err(e),
        };
        let (a, b) = &members;
        let caller_in = caller == a.as_slice() || caller == b.as_slice();
        let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
            || (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
        if !caller_in || !recipient_other {
            return Promise::err(coded_error(
                E022_CHANNEL_ACCESS_DENIED,
                "caller or recipient not a member of this channel",
            ));
        }
    }
    // Clone the shared handles the async body needs; `self` cannot be
    // captured across the await.
    let store = Arc::clone(&self.store);
    let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = self.waiters.clone();
    Promise::from_future(async move {
        // Shared fetch helper: limited fetch when limit > 0, full drain otherwise.
        let fetch_fn = |s: &Arc<dyn Store>, rk: &[u8], ch: &[u8], lim: u32| -> Result<Vec<(u64, Vec<u8>)>, capnp::Error> {
            if lim > 0 {
                s.fetch_limited(rk, ch, lim as usize).map_err(storage_err)
            } else {
                s.fetch(rk, ch).map_err(storage_err)
            }
        };
        // Register waiter BEFORE the initial fetch to close the TOCTOU window:
        // an enqueue between fetch and registration would fire notify before
        // the waiter exists, causing a missed wakeup.
        let waiter = if timeout_ms > 0 {
            Some(
                waiters
                    .entry(recipient_key.clone())
                    .or_insert_with(|| Arc::new(Notify::new()))
                    .clone(),
            )
        } else {
            None
        };
        let messages = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
        if messages.is_empty() {
            if let Some(waiter) = waiter {
                // Park until notified or the timeout fires; either way fetch
                // once more and return whatever is queued (possibly nothing).
                let _ = timeout(Duration::from_millis(timeout_ms), waiter.notified()).await;
                let msgs = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
                fill_payloads_wait(&mut results, msgs);
                metrics::record_fetch_wait_total();
                return Ok(());
            }
        }
        fill_payloads_wait(&mut results, messages);
        metrics::record_fetch_wait_total();
        Ok(())
    })
}
pub fn handle_peek(
&mut self,
params: node_service::PeekParams,
mut results: node_service::PeekResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_key = match p.get_recipient_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let limit = p.get_limit();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = require_identity_or_request(
&auth_ctx,
&recipient_key,
self.auth_cfg.allow_insecure_identity_from_request,
) {
return Promise::err(e);
}
let messages = match self
.store
.peek(&recipient_key, &channel_id, limit as usize)
.map_err(storage_err)
{
Ok(m) => m,
Err(e) => return Promise::err(e),
};
if self.redact_logs {
tracing::info!(
recipient_prefix = %redacted_prefix(&recipient_key),
count = messages.len(),
"audit: peek"
);
} else {
tracing::info!(
recipient_prefix = %fmt_hex(&recipient_key[..4]),
count = messages.len(),
"audit: peek"
);
}
let mut list = results.get().init_payloads(messages.len() as u32);
for (i, (seq, data)) in messages.iter().enumerate() {
let mut entry = list.reborrow().get(i as u32);
entry.set_seq(*seq);
entry.set_data(data);
}
Promise::ok(())
}
pub fn handle_ack(
&mut self,
params: node_service::AckParams,
_results: node_service::AckResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_key = match p.get_recipient_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let seq_up_to = p.get_seq_up_to();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = require_identity_or_request(
&auth_ctx,
&recipient_key,
self.auth_cfg.allow_insecure_identity_from_request,
) {
return Promise::err(e);
}
match self
.store
.ack(&recipient_key, &channel_id, seq_up_to)
.map_err(storage_err)
{
Ok(removed) => {
if self.redact_logs {
tracing::info!(
recipient_prefix = %redacted_prefix(&recipient_key),
seq_up_to = seq_up_to,
removed = removed,
"audit: ack"
);
} else {
tracing::info!(
recipient_prefix = %fmt_hex(&recipient_key[..4]),
seq_up_to = seq_up_to,
removed = removed,
"audit: ack"
);
}
}
Err(e) => return Promise::err(e),
}
Promise::ok(())
}
pub fn handle_batch_enqueue(
&mut self,
params: node_service::BatchEnqueueParams,
mut results: node_service::BatchEnqueueResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_keys = match p.get_recipient_keys() {
Ok(v) => v,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let payload = match p.get_payload() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let ttl_secs_raw = p.get_ttl_secs();
let ttl_secs = if ttl_secs_raw > 0 { Some(ttl_secs_raw) } else { None };
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if payload.is_empty() {
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
}
if payload.len() > MAX_PAYLOAD_BYTES {
return Promise::err(coded_error(
E006_PAYLOAD_TOO_LARGE,
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
tracing::warn!("rate_limit_hit");
metrics::record_rate_limit_hit_total();
return Promise::err(e);
}
// Phase 4.3 — DS sender identity binding (same guarantee as handle_enqueue).
// sender_identity is derived solely from auth_ctx.identity_key, never client data.
if !self.sealed_sender {
if let Err(e) = crate::auth::require_identity(&auth_ctx) {
return Promise::err(e);
}
}
// DM channel authz: validate caller membership once before the loop.
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
if !caller_in {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller is not a member of this channel",
));
}
}
// Eagerly collect recipient keys so params can be dropped before any async work.
let mut recipient_key_vecs: Vec<Vec<u8>> = Vec::with_capacity(recipient_keys.len() as usize);
for i in 0..recipient_keys.len() {
let rk = match recipient_keys.get(i) {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
if rk.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey[{}] must be exactly 32 bytes, got {}", i, rk.len()),
));
}
// Per-recipient DM channel membership check (only when channel_id is a 16-byte UUID).
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(
E023_CHANNEL_NOT_FOUND,
"channel not found",
));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let recipient_other = (rk == *a && caller == b.as_slice())
|| (rk == *b && caller == a.as_slice());
if !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"recipient is not a member of this channel",
));
}
}
recipient_key_vecs.push(rk);
}
// Hook: on_message_enqueue for each recipient — fires before storage.
// sender_identity is derived solely from auth_ctx (server-side session state).
let sender_identity = if self.sealed_sender {
None
} else {
crate::auth::require_identity(&auth_ctx).ok().map(|v| v.to_vec())
};
let sender_prefix = sender_identity
.as_deref()
.filter(|id| id.len() >= 4)
.map(|id| fmt_hex(&id[..4]));
let mut hook_events = Vec::with_capacity(recipient_key_vecs.len());
for rk in &recipient_key_vecs {
let event = MessageEvent {
sender_identity: sender_identity.clone(),
recipient_key: rk.clone(),
channel_id: channel_id.clone(),
payload_len: payload.len(),
seq: 0,
};
if let HookAction::Reject(reason) = self.hooks.on_message_enqueue(&event) {
return Promise::err(capnp::Error::failed(format!("hook rejected enqueue: {reason}")));
}
hook_events.push(event);
}
let n = recipient_key_vecs.len();
let store = Arc::clone(&self.store);
let waiters = Arc::clone(&self.waiters);
let fed_client = self.federation_client.clone();
let local_domain = self.local_domain.clone();
let hooks = Arc::clone(&self.hooks);
let redact_logs = self.redact_logs;
// Use an async future to support federation relay alongside local enqueue.
// All storage operations are synchronous; only federation relay calls are await-ed.
Promise::from_future(async move {
let mut seqs = Vec::with_capacity(n);
for rk in &recipient_key_vecs {
// Federation routing: relay to the recipient's home server when remote.
let dest = if let (Some(ref _fed), Some(ref domain)) = (&fed_client, &local_domain) {
crate::federation::routing::resolve_destination(&store, rk, domain)
} else {
crate::federation::routing::Destination::Local
};
let seq = match dest {
crate::federation::routing::Destination::Remote(ref remote_domain) => {
let fed = fed_client.as_deref().ok_or_else(|| {
capnp::Error::failed("federation client unavailable for remote routing".into())
})?;
tracing::info!(
recipient_prefix = %fmt_hex(&rk[..4]),
domain = %remote_domain,
"federation: routing batch enqueue to remote server"
);
fed.relay_enqueue(remote_domain, rk, &payload, &channel_id)
.await
.map_err(|e| {
capnp::Error::failed(format!("federation relay failed: {e}"))
})?
}
crate::federation::routing::Destination::Local => {
match store.queue_depth(rk, &channel_id) {
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
return Err(coded_error(
E015_QUEUE_FULL,
format!("queue depth {} exceeds limit {MAX_QUEUE_DEPTH}", depth),
));
}
Err(e) => return Err(storage_err(e)),
_ => {}
}
store
.enqueue(rk, &channel_id, payload.clone(), ttl_secs)
.map_err(storage_err)?
}
};
seqs.push(seq);
metrics::record_enqueue_total();
metrics::record_enqueue_bytes(payload.len() as u64);
crate::auth::waiter(&waiters, rk).notify_waiters();
}
let mut list = results.get().init_seqs(seqs.len() as u32);
for (i, seq) in seqs.iter().enumerate() {
list.set(i as u32, *seq);
}
// Hook: on_batch_enqueue — fires after all messages are stored.
hooks.on_batch_enqueue(&hook_events);
if redact_logs {
tracing::info!(
recipient_count = n,
"audit: batch_enqueue"
);
} else {
tracing::info!(
sender_prefix = sender_prefix.as_deref().unwrap_or("sealed"),
recipient_count = n,
payload_len = payload.len(),
"audit: batch_enqueue"
);
}
Ok(())
})
}
}

View File

@@ -0,0 +1,154 @@
//! Device registry RPC handlers: registerDevice, listDevices, revokeDevice.
use capnp::capability::Promise;
use quicprochat_proto::node_capnp::node_service;
use crate::auth::{coded_error, require_identity, validate_auth_context};
use crate::error_codes::*;
use crate::storage::StorageError;
use super::NodeServiceImpl;
/// Maximum number of concurrently registered devices per identity key;
/// enforced by `handle_register_device` before persisting a new device.
const MAX_DEVICES_PER_IDENTITY: usize = 5;
fn storage_err(err: StorageError) -> capnp::Error {
coded_error(E009_STORAGE_ERROR, err)
}
impl NodeServiceImpl {
    /// registerDevice RPC: bind a new device id/name to the caller's identity.
    ///
    /// Requires an identity-bound session (the registry is keyed by the
    /// authenticated identity, never by client-supplied data) and enforces
    /// the per-identity cap `MAX_DEVICES_PER_IDENTITY`.
    pub fn handle_register_device(
        &mut self,
        params: node_service::RegisterDeviceParams,
        mut results: node_service::RegisterDeviceResults,
    ) -> Promise<(), capnp::Error> {
        // Malformed Cap'n Proto input maps to E020.
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
            Ok(ctx) => ctx,
            Err(e) => return Promise::err(e),
        };
        // Identity comes from the session, not the request.
        let identity_key = match require_identity(&auth_ctx) {
            Ok(ik) => ik.to_vec(),
            Err(e) => return Promise::err(e),
        };
        let device_id = match p.get_device_id() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        if device_id.is_empty() {
            return Promise::err(coded_error(E020_BAD_PARAMS, "deviceId must not be empty"));
        }
        let device_name = match p.get_device_name() {
            Ok(n) => match n.to_str() {
                Ok(s) => s.to_string(),
                Err(_) => return Promise::err(coded_error(E020_BAD_PARAMS, "deviceName must be valid UTF-8")),
            },
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        // Check device limit.
        // NOTE(review): count-then-register is two separate store calls, so
        // concurrent registrations could briefly exceed the cap — confirm the
        // store enforces the limit atomically if that matters.
        match self.store.device_count(&identity_key) {
            Ok(count) if count >= MAX_DEVICES_PER_IDENTITY => {
                return Promise::err(coded_error(
                    E029_DEVICE_LIMIT,
                    format!("maximum {MAX_DEVICES_PER_IDENTITY} devices per identity"),
                ));
            }
            Err(e) => return Promise::err(storage_err(e)),
            _ => {}
        }
        match self.store.register_device(&identity_key, &device_id, &device_name) {
            Ok(success) => {
                results.get().set_success(success);
                Promise::ok(())
            }
            Err(e) => Promise::err(storage_err(e)),
        }
    }
    /// listDevices RPC: enumerate the caller's registered devices as
    /// (deviceId, deviceName, registeredAt) entries. Only the authenticated
    /// identity's own devices are visible.
    pub fn handle_list_devices(
        &mut self,
        params: node_service::ListDevicesParams,
        mut results: node_service::ListDevicesResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
            Ok(ctx) => ctx,
            Err(e) => return Promise::err(e),
        };
        let identity_key = match require_identity(&auth_ctx) {
            Ok(ik) => ik.to_vec(),
            Err(e) => return Promise::err(e),
        };
        let devices = match self.store.list_devices(&identity_key) {
            Ok(d) => d,
            Err(e) => return Promise::err(storage_err(e)),
        };
        let r = results.get();
        let mut list = r.init_devices(devices.len() as u32);
        for (i, (device_id, name, registered_at)) in devices.iter().enumerate() {
            let mut entry = list.reborrow().get(i as u32);
            entry.set_device_id(device_id);
            entry.set_device_name(name);
            entry.set_registered_at(*registered_at);
        }
        Promise::ok(())
    }
    /// revokeDevice RPC: remove one of the caller's devices by id.
    ///
    /// Returns E030 when the device does not exist under the caller's
    /// identity; revoking another user's device is impossible because the
    /// lookup is scoped to the session identity.
    pub fn handle_revoke_device(
        &mut self,
        params: node_service::RevokeDeviceParams,
        mut results: node_service::RevokeDeviceResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
            Ok(ctx) => ctx,
            Err(e) => return Promise::err(e),
        };
        let identity_key = match require_identity(&auth_ctx) {
            Ok(ik) => ik.to_vec(),
            Err(e) => return Promise::err(e),
        };
        let device_id = match p.get_device_id() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        if device_id.is_empty() {
            return Promise::err(coded_error(E020_BAD_PARAMS, "deviceId must not be empty"));
        }
        // `revoke_device` reports whether a matching row was deleted.
        match self.store.revoke_device(&identity_key, &device_id) {
            Ok(true) => {
                results.get().set_success(true);
                Promise::ok(())
            }
            Ok(false) => {
                Promise::err(coded_error(E030_DEVICE_NOT_FOUND, "device not found"))
            }
            Err(e) => Promise::err(storage_err(e)),
        }
    }
}

View File

@@ -0,0 +1,294 @@
use capnp::capability::Promise;
use quicprochat_proto::node_capnp::node_service;
use crate::auth::{coded_error, fmt_hex, require_identity_or_request, validate_auth_context};
use crate::error_codes::*;
use crate::metrics;
use crate::storage::StorageError;
use super::NodeServiceImpl;
fn storage_err(err: StorageError) -> capnp::Error {
coded_error(E009_STORAGE_ERROR, err)
}
/// Upper bound on a single uploaded MLS KeyPackage.
const MAX_KEYPACKAGE_BYTES: usize = 1024 * 1024; // 1 MB cap per KeyPackage
impl NodeServiceImpl {
    /// uploadKeyPackage RPC: store an MLS KeyPackage for `identityKey`.
    ///
    /// Validates auth, 32-byte key length, package size bounds, and the
    /// KeyPackage ciphersuite before persisting, then returns the package
    /// fingerprint to the caller.
    pub fn handle_upload_key_package(
        &mut self,
        params: node_service::UploadKeyPackageParams,
        mut results: node_service::UploadKeyPackageResults,
    ) -> Promise<(), capnp::Error> {
        let (auth_ctx, identity_key, package) = match params.get() {
            Ok(p) => {
                let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
                    Ok(ctx) => ctx,
                    Err(e) => return Promise::err(e),
                };
                let ik = match p.get_identity_key() {
                    Ok(v) => v.to_vec(),
                    Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
                };
                let pkg = match p.get_package() {
                    Ok(v) => v.to_vec(),
                    Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
                };
                (auth_ctx, ik, pkg)
            }
            Err(e) => return Promise::err(e),
        };
        if identity_key.len() != 32 {
            return Promise::err(coded_error(
                E004_IDENTITY_KEY_LENGTH,
                format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
            ));
        }
        if package.is_empty() {
            return Promise::err(coded_error(E007_PACKAGE_EMPTY, "package must not be empty"));
        }
        if package.len() > MAX_KEYPACKAGE_BYTES {
            return Promise::err(coded_error(
                E008_PACKAGE_TOO_LARGE,
                format!("package exceeds max size ({} bytes)", MAX_KEYPACKAGE_BYTES),
            ));
        }
        // Callers may only upload packages for their own identity (unless
        // explicitly configured to trust request-supplied identities).
        if let Err(e) = require_identity_or_request(
            &auth_ctx,
            &identity_key,
            self.auth_cfg.allow_insecure_identity_from_request,
        ) {
            return Promise::err(e);
        }
        // Reject KeyPackages advertising ciphersuites outside the allow-list.
        if let Err(e) = quicprochat_core::validate_keypackage_ciphersuite(&package) {
            return Promise::err(coded_error(
                E021_CIPHERSUITE_NOT_ALLOWED,
                format!("KeyPackage ciphersuite not allowed: {e}"),
            ));
        }
        let fingerprint: Vec<u8> = crate::auth::fingerprint(&package);
        if let Err(e) = self
            .store
            .upload_key_package(&identity_key, package)
            .map_err(storage_err)
        {
            return Promise::err(e);
        }
        results.get().set_fingerprint(&fingerprint);
        metrics::record_key_package_upload_total();
        // Audit: KeyPackage upload — only fingerprint prefix, no secrets.
        // (Assumes fingerprint() yields >= 4 bytes — it is a hash digest.)
        tracing::info!(
            identity_prefix = %fmt_hex(&identity_key[..4]),
            fingerprint_prefix = %fmt_hex(&fingerprint[..4]),
            "audit: key_package_upload"
        );
        Promise::ok(())
    }
    /// fetchKeyPackage RPC: return the stored KeyPackage for an identity,
    /// or an empty package when none is available.
    ///
    /// Fix: params are now parsed once and auth is always validated. The
    /// previous form re-parsed params for the auth check via
    /// `.ok().map(..).transpose()`, which silently *skipped* auth validation
    /// whenever the second parse failed, and was inconsistent with every
    /// other handler in this impl.
    pub fn handle_fetch_key_package(
        &mut self,
        params: node_service::FetchKeyPackageParams,
        mut results: node_service::FetchKeyPackageResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let identity_key = match p.get_identity_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        // Auth check only — any authenticated user may fetch any identity's
        // KeyPackage (required for session establishment).
        if let Err(e) = crate::auth::validate_auth(&self.auth_cfg, &self.sessions, p.get_auth()) {
            return Promise::err(e);
        }
        if identity_key.len() != 32 {
            return Promise::err(coded_error(
                E004_IDENTITY_KEY_LENGTH,
                format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
            ));
        }
        let package = match self
            .store
            .fetch_key_package(&identity_key)
            .map_err(storage_err)
        {
            Ok(p) => p,
            Err(e) => return Promise::err(e),
        };
        match package {
            Some(pkg) => {
                tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "KeyPackage fetched");
                results.get().set_package(&pkg);
            }
            None => {
                // Absence is not an error: an empty package signals "none".
                tracing::debug!(
                    identity = %fmt_hex(&identity_key[..4]),
                    "no KeyPackage available for identity"
                );
                results.get().set_package(&[]);
            }
        }
        Promise::ok(())
    }
    /// uploadHybridKey RPC: store a post-quantum hybrid public key for the
    /// caller's identity. Same ownership rule as uploadKeyPackage.
    pub fn handle_upload_hybrid_key(
        &mut self,
        params: node_service::UploadHybridKeyParams,
        _results: node_service::UploadHybridKeyResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let identity_key = match p.get_identity_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let hybrid_pk = match p.get_hybrid_public_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
            Ok(ctx) => ctx,
            Err(e) => return Promise::err(e),
        };
        if identity_key.len() != 32 {
            return Promise::err(coded_error(
                E004_IDENTITY_KEY_LENGTH,
                format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
            ));
        }
        if hybrid_pk.is_empty() {
            return Promise::err(coded_error(E013_HYBRID_KEY_EMPTY, "hybridPublicKey must not be empty"));
        }
        if let Err(e) = require_identity_or_request(
            &auth_ctx,
            &identity_key,
            self.auth_cfg.allow_insecure_identity_from_request,
        ) {
            return Promise::err(e);
        }
        if let Err(e) = self
            .store
            .upload_hybrid_key(&identity_key, hybrid_pk)
            .map_err(storage_err)
        {
            return Promise::err(e);
        }
        tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "hybrid public key uploaded");
        Promise::ok(())
    }
    /// fetchHybridKey RPC: fetch one peer's hybrid public key; returns empty
    /// bytes when none is stored.
    pub fn handle_fetch_hybrid_key(
        &mut self,
        params: node_service::FetchHybridKeyParams,
        mut results: node_service::FetchHybridKeyResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let identity_key = match p.get_identity_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        // Auth check only — any authenticated user can fetch any peer's hybrid public key.
        if let Err(e) = validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
            return Promise::err(e);
        }
        if identity_key.len() != 32 {
            return Promise::err(coded_error(
                E004_IDENTITY_KEY_LENGTH,
                format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
            ));
        }
        let hybrid_pk = match self
            .store
            .fetch_hybrid_key(&identity_key)
            .map_err(storage_err)
        {
            Ok(p) => p,
            Err(e) => return Promise::err(e),
        };
        match hybrid_pk {
            Some(pk) => {
                tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "hybrid key fetched");
                results.get().set_hybrid_public_key(&pk);
            }
            None => {
                tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "no hybrid key for identity");
                results.get().set_hybrid_public_key(&[]);
            }
        }
        Promise::ok(())
    }
    /// fetchHybridKeys RPC: batch variant of fetchHybridKey. The result list
    /// is positionally aligned with the request list; missing keys yield
    /// empty byte slots.
    ///
    /// NOTE(review): unlike the single-key path, request keys are not
    /// length-validated here; unknown/odd-length keys simply resolve to
    /// empty — confirm that is intentional.
    pub fn handle_fetch_hybrid_keys(
        &mut self,
        params: node_service::FetchHybridKeysParams,
        mut results: node_service::FetchHybridKeysResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let identity_keys = match p.get_identity_keys() {
            Ok(v) => v,
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        if let Err(e) = validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
            return Promise::err(e);
        }
        let count = identity_keys.len() as usize;
        let mut key_data: Vec<Vec<u8>> = Vec::with_capacity(count);
        for i in 0..identity_keys.len() {
            let ik = match identity_keys.get(i) {
                Ok(v) => v.to_vec(),
                Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
            };
            let pk = match self.store.fetch_hybrid_key(&ik).map_err(storage_err) {
                Ok(Some(pk)) => pk,
                Ok(None) => vec![],
                Err(e) => return Promise::err(e),
            };
            key_data.push(pk);
        }
        let mut list = results.get().init_keys(key_data.len() as u32);
        for (i, pk) in key_data.iter().enumerate() {
            list.set(i as u32, pk);
        }
        tracing::debug!(count = count, "batch hybrid key fetch");
        Promise::ok(())
    }
}

View File

@@ -0,0 +1,434 @@
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use capnp_rpc::RpcSystem;
use dashmap::DashMap;
use opaque_ke::ServerSetup;
use quicprochat_core::opaque_auth::OpaqueSuite;
use quicprochat_kt::MerkleLog;
use quicprochat_proto::node_capnp::node_service;
use tokio::sync::Notify;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use crate::auth::{
current_timestamp, AuthConfig, PendingLogin, RateEntry, SessionInfo,
PENDING_LOGIN_TTL_SECS, RATE_LIMIT_WINDOW_SECS,
};
use crate::storage::Store;
/// Cap'n Proto traversal limit (words). 4 Mi words = 32 MiB; bounds DoS from deeply nested or large messages.
const CAPNP_TRAVERSAL_LIMIT_WORDS: usize = 4 * 1024 * 1024;
mod account_ops;
mod auth_ops;
mod blob_ops;
mod channel_ops;
mod delivery;
mod device_ops;
mod key_ops;
mod p2p_ops;
mod user_ops;
// Cap'n Proto trait dispatch. Every RPC method forwards 1:1 to a `handle_*`
// implementation in the sibling `*_ops` submodules; no request logic lives
// in this impl.
impl node_service::Server for NodeServiceImpl {
    // — Key distribution (see key_ops) —
    fn upload_key_package(
        &mut self,
        params: node_service::UploadKeyPackageParams,
        results: node_service::UploadKeyPackageResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_upload_key_package(params, results)
    }
    fn fetch_key_package(
        &mut self,
        params: node_service::FetchKeyPackageParams,
        results: node_service::FetchKeyPackageResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_fetch_key_package(params, results)
    }
    // — Message queueing and delivery —
    fn enqueue(
        &mut self,
        params: node_service::EnqueueParams,
        results: node_service::EnqueueResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_enqueue(params, results)
    }
    fn fetch(
        &mut self,
        params: node_service::FetchParams,
        results: node_service::FetchResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_fetch(params, results)
    }
    fn fetch_wait(
        &mut self,
        params: node_service::FetchWaitParams,
        results: node_service::FetchWaitResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_fetch_wait(params, results)
    }
    fn health(
        &mut self,
        params: node_service::HealthParams,
        results: node_service::HealthResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_health(params, results)
    }
    fn upload_hybrid_key(
        &mut self,
        params: node_service::UploadHybridKeyParams,
        results: node_service::UploadHybridKeyResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_upload_hybrid_key(params, results)
    }
    fn fetch_hybrid_key(
        &mut self,
        params: node_service::FetchHybridKeyParams,
        results: node_service::FetchHybridKeyResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_fetch_hybrid_key(params, results)
    }
    // — OPAQUE password authentication —
    fn opaque_login_start(
        &mut self,
        params: node_service::OpaqueLoginStartParams,
        results: node_service::OpaqueLoginStartResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_opaque_login_start(params, results)
    }
    fn opaque_register_start(
        &mut self,
        params: node_service::OpaqueRegisterStartParams,
        results: node_service::OpaqueRegisterStartResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_opaque_register_start(params, results)
    }
    fn opaque_login_finish(
        &mut self,
        params: node_service::OpaqueLoginFinishParams,
        results: node_service::OpaqueLoginFinishResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_opaque_login_finish(params, results)
    }
    fn opaque_register_finish(
        &mut self,
        params: node_service::OpaqueRegisterFinishParams,
        results: node_service::OpaqueRegisterFinishResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_opaque_register_finish(params, results)
    }
    // — P2P endpoint discovery (see p2p_ops) —
    fn publish_endpoint(
        &mut self,
        params: node_service::PublishEndpointParams,
        results: node_service::PublishEndpointResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_publish_endpoint(params, results)
    }
    fn resolve_endpoint(
        &mut self,
        params: node_service::ResolveEndpointParams,
        results: node_service::ResolveEndpointResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_resolve_endpoint(params, results)
    }
    fn peek(
        &mut self,
        params: node_service::PeekParams,
        results: node_service::PeekResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_peek(params, results)
    }
    fn ack(
        &mut self,
        params: node_service::AckParams,
        results: node_service::AckResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_ack(params, results)
    }
    fn fetch_hybrid_keys(
        &mut self,
        params: node_service::FetchHybridKeysParams,
        results: node_service::FetchHybridKeysResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_fetch_hybrid_keys(params, results)
    }
    fn batch_enqueue(
        &mut self,
        params: node_service::BatchEnqueueParams,
        results: node_service::BatchEnqueueResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_batch_enqueue(params, results)
    }
    // — Channels, users, blobs, accounts —
    fn create_channel(
        &mut self,
        params: node_service::CreateChannelParams,
        results: node_service::CreateChannelResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_create_channel(params, results)
    }
    fn resolve_user(
        &mut self,
        params: node_service::ResolveUserParams,
        results: node_service::ResolveUserResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_resolve_user(params, results)
    }
    fn resolve_identity(
        &mut self,
        params: node_service::ResolveIdentityParams,
        results: node_service::ResolveIdentityResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_resolve_identity(params, results)
    }
    fn upload_blob(
        &mut self,
        params: node_service::UploadBlobParams,
        results: node_service::UploadBlobResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_upload_blob(params, results)
    }
    fn download_blob(
        &mut self,
        params: node_service::DownloadBlobParams,
        results: node_service::DownloadBlobResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_download_blob(params, results)
    }
    fn delete_account(
        &mut self,
        params: node_service::DeleteAccountParams,
        results: node_service::DeleteAccountResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_delete_account(params, results)
    }
    // — Device registry (see device_ops) —
    fn register_device(
        &mut self,
        params: node_service::RegisterDeviceParams,
        results: node_service::RegisterDeviceResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_register_device(params, results)
    }
    fn list_devices(
        &mut self,
        params: node_service::ListDevicesParams,
        results: node_service::ListDevicesResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_list_devices(params, results)
    }
    fn revoke_device(
        &mut self,
        params: node_service::RevokeDeviceParams,
        results: node_service::RevokeDeviceResults,
    ) -> capnp::capability::Promise<(), capnp::Error> {
        self.handle_revoke_device(params, results)
    }
}
/// Highest wire-protocol version this server accepts; requests advertising a
/// greater version are rejected with E012.
pub const CURRENT_WIRE_VERSION: u16 = 1;
/// Per-connection Cap'n Proto service state. A fresh instance is built for
/// each inbound QUIC connection (see `handle_node_connection`); all mutable
/// server-wide state is shared behind `Arc`s.
pub struct NodeServiceImpl {
    /// Persistence backend for messages, keys, endpoints, devices, blobs.
    pub store: Arc<dyn Store>,
    /// Per-recipient notifiers, signalled on enqueue (presumably awaited by
    /// fetchWait — the delivery handlers are defined elsewhere).
    pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
    /// Authentication policy and configuration.
    pub auth_cfg: Arc<AuthConfig>,
    /// OPAQUE server setup shared by all connections.
    pub opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
    /// In-flight OPAQUE login state, keyed by a string login identifier.
    pub pending_logins: Arc<DashMap<String, PendingLogin>>,
    /// Active sessions: token bytes → session metadata (expiry, identity).
    pub sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
    /// Per-token rate-limit windows (consulted via `check_rate_limit`).
    pub rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
    /// When true, enqueue does not require identity-bound session (Sealed Sender).
    pub sealed_sender: bool,
    /// Outbound federation client for relaying to remote servers (None if federation disabled).
    pub federation_client: Option<Arc<crate::federation::FederationClient>>,
    /// This server's federation domain (empty if federation disabled).
    pub local_domain: Option<String>,
    /// Server-side plugin hooks for extensibility.
    pub hooks: Arc<dyn crate::hooks::ServerHooks>,
    /// Server Ed25519 signing key for delivery proofs.
    pub signing_key: Arc<quicprochat_core::IdentityKeypair>,
    /// Key Transparency Merkle log (shared across connections).
    pub kt_log: Arc<std::sync::Mutex<MerkleLog>>,
    /// Server data directory (used for blob storage).
    pub data_dir: PathBuf,
    /// When true, hash identity key prefixes and omit payload sizes in audit logs.
    pub redact_logs: bool,
}
impl NodeServiceImpl {
#[allow(clippy::too_many_arguments)]
pub fn new(
store: Arc<dyn Store>,
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
auth_cfg: Arc<AuthConfig>,
opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
pending_logins: Arc<DashMap<String, PendingLogin>>,
sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
sealed_sender: bool,
federation_client: Option<Arc<crate::federation::FederationClient>>,
local_domain: Option<String>,
signing_key: Arc<quicprochat_core::IdentityKeypair>,
hooks: Arc<dyn crate::hooks::ServerHooks>,
kt_log: Arc<std::sync::Mutex<MerkleLog>>,
data_dir: PathBuf,
redact_logs: bool,
) -> Self {
Self {
store,
waiters,
auth_cfg,
opaque_setup,
pending_logins,
sessions,
rate_limits,
sealed_sender,
federation_client,
local_domain,
hooks,
signing_key,
kt_log,
data_dir,
redact_logs,
}
}
}
/// Serve the NodeService RPC API over one inbound QUIC connection.
///
/// Completes the QUIC handshake, accepts exactly one bidirectional stream,
/// and runs a two-party Cap'n Proto RPC system over it until the peer
/// disconnects or an RPC error occurs. Handshake, stream-accept, and RPC
/// errors are all surfaced to the caller as `anyhow::Error`.
#[allow(clippy::too_many_arguments)]
pub async fn handle_node_connection(
    connecting: quinn::Connecting,
    store: Arc<dyn Store>,
    waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
    auth_cfg: Arc<AuthConfig>,
    opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
    pending_logins: Arc<DashMap<String, PendingLogin>>,
    sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
    rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
    sealed_sender: bool,
    federation_client: Option<Arc<crate::federation::FederationClient>>,
    local_domain: Option<String>,
    signing_key: Arc<quicprochat_core::IdentityKeypair>,
    hooks: Arc<dyn crate::hooks::ServerHooks>,
    kt_log: Arc<std::sync::Mutex<MerkleLog>>,
    data_dir: PathBuf,
    redact_logs: bool,
) -> Result<(), anyhow::Error> {
    let connection = connecting.await?;
    tracing::info!(peer = %connection.remote_address(), "QUIC connected");
    // One bi-directional stream carries the entire RPC session.
    let (send, recv) = connection
        .accept_bi()
        .await
        .map_err(|e| anyhow::anyhow!("failed to accept bi stream: {e}"))?;
    // Bridge tokio's AsyncRead/Write to the futures traits capnp-rpc expects.
    let (reader, writer) = (recv.compat(), send.compat_write());
    // Bound message traversal to limit resource exhaustion from oversized
    // or deeply nested messages (see CAPNP_TRAVERSAL_LIMIT_WORDS).
    let mut reader_opts = capnp::message::ReaderOptions::new();
    reader_opts.traversal_limit_in_words(Some(CAPNP_TRAVERSAL_LIMIT_WORDS));
    let network = capnp_rpc::twoparty::VatNetwork::new(
        reader,
        writer,
        capnp_rpc::rpc_twoparty_capnp::Side::Server,
        reader_opts,
    );
    // Each connection gets its own service impl over the shared Arc state.
    let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl::new(
        store,
        waiters,
        auth_cfg,
        opaque_setup,
        pending_logins,
        sessions,
        rate_limits,
        sealed_sender,
        federation_client,
        local_domain,
        signing_key,
        hooks,
        kt_log,
        data_dir,
        redact_logs,
    ));
    // Drive the RPC event loop to completion for this connection.
    RpcSystem::new(Box::new(network), Some(service.client))
        .await
        .map_err(|e| anyhow::anyhow!("NodeService RPC error: {e}"))
}
const MESSAGE_TTL_SECS: u64 = 7 * 24 * 60 * 60; // 7 days
/// Spawn the background janitor: once a minute it expires sessions, pending
/// logins, and rate-limit windows; caps the session/waiter maps; and
/// garbage-collects messages older than `MESSAGE_TTL_SECS`.
///
/// The task runs for the lifetime of the process (detached tokio task).
pub fn spawn_cleanup_task(
    sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
    pending_logins: Arc<DashMap<String, PendingLogin>>,
    rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
    conn_rate_limits: Arc<DashMap<std::net::IpAddr, RateEntry>>,
    store: Arc<dyn Store>,
    waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
) {
    // Bound map sizes to prevent unbounded growth from malicious clients.
    const MAX_SESSIONS: usize = 100_000;
    const MAX_WAITERS: usize = 100_000;
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(60));
        loop {
            interval.tick().await;
            let now = current_timestamp();
            // `saturating_sub` guards against u64 underflow (a panic in debug
            // builds, silent wrap in release) when the system clock steps
            // backwards and an entry's timestamp is ahead of `now`.
            sessions.retain(|_, info| info.expires_at > now);
            pending_logins
                .retain(|_, pl| now.saturating_sub(pl.created_at) < PENDING_LOGIN_TTL_SECS);
            rate_limits.retain(|_, entry| {
                now.saturating_sub(entry.window_start) < RATE_LIMIT_WINDOW_SECS * 2
            });
            conn_rate_limits.retain(|_, entry| {
                now.saturating_sub(entry.window_start)
                    < crate::auth::CONN_RATE_LIMIT_WINDOW_SECS * 2
            });
            if sessions.len() > MAX_SESSIONS {
                // Evict the soonest-to-expire sessions first.
                let overflow = sessions.len() - MAX_SESSIONS;
                let mut entries: Vec<_> = sessions
                    .iter()
                    .map(|e| (e.key().clone(), e.expires_at))
                    .collect();
                entries.sort_unstable_by_key(|(_, exp)| *exp);
                for (key, _) in entries.into_iter().take(overflow) {
                    sessions.remove(&key);
                }
            }
            if waiters.len() > MAX_WAITERS {
                // Waiter entries carry no ordering information; evict arbitrarily.
                let overflow = waiters.len() - MAX_WAITERS;
                let keys: Vec<_> =
                    waiters.iter().take(overflow).map(|e| e.key().clone()).collect();
                for key in keys {
                    waiters.remove(&key);
                }
            }
            match store.gc_expired_messages(MESSAGE_TTL_SECS) {
                Ok(n) if n > 0 => {
                    tracing::debug!(expired = n, "garbage collected expired messages")
                }
                Err(e) => tracing::warn!(error = %e, "message GC failed"),
                _ => {}
            }
        }
    });
}

View File

@@ -0,0 +1,119 @@
use capnp::capability::Promise;
use quicprochat_proto::node_capnp::node_service;
use crate::auth::{
coded_error, fmt_hex, require_identity_or_request, validate_auth, validate_auth_context,
};
use crate::error_codes::*;
use crate::storage::StorageError;
use super::NodeServiceImpl;
fn storage_err(err: StorageError) -> capnp::Error {
coded_error(E009_STORAGE_ERROR, err)
}
impl NodeServiceImpl {
/// Health check: unauthenticated by design for liveness probes and load balancers.
pub fn handle_health(
&mut self,
_params: node_service::HealthParams,
mut results: node_service::HealthResults,
) -> Promise<(), capnp::Error> {
results.get().set_status("ok");
Promise::ok(())
}
pub fn handle_publish_endpoint(
&mut self,
params: node_service::PublishEndpointParams,
_results: node_service::PublishEndpointResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let identity_key = match p.get_identity_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let node_addr = match p.get_node_addr() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if identity_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
));
}
if let Err(e) = require_identity_or_request(
&auth_ctx,
&identity_key,
self.auth_cfg.allow_insecure_identity_from_request,
) {
return Promise::err(e);
}
if let Err(e) = self
.store
.publish_endpoint(&identity_key, node_addr)
.map_err(storage_err)
{
return Promise::err(e);
}
tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "endpoint published");
Promise::ok(())
}
pub fn handle_resolve_endpoint(
&mut self,
params: node_service::ResolveEndpointParams,
mut results: node_service::ResolveEndpointResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let identity_key = match p.get_identity_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
if let Err(e) = validate_auth(&self.auth_cfg, &self.sessions, p.get_auth()) {
return Promise::err(e);
}
if identity_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
));
}
let endpoint = match self
.store
.resolve_endpoint(&identity_key)
.map_err(storage_err)
{
Ok(e) => e,
Err(e) => return Promise::err(e),
};
if let Some(ep) = endpoint {
tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "endpoint resolved");
results.get().set_node_addr(&ep);
} else {
results.get().set_node_addr(&[]);
}
Promise::ok(())
}
}

View File

@@ -0,0 +1,186 @@
//! resolveUser / resolveIdentity RPCs: bidirectional username ↔ identity key lookup.
use capnp::capability::Promise;
use quicprochat_proto::node_capnp::node_service;
use std::time::Duration;
use tokio::time::Instant;
use crate::auth::{check_rate_limit, coded_error, validate_auth_context};
use crate::error_codes::*;
use crate::metrics;
use crate::storage::StorageError;
use super::NodeServiceImpl;
/// Minimum response time for resolveUser to mask DB lookup timing differences.
///
/// Both successful and not-found lookups are padded to at least this duration
/// so a caller cannot distinguish existing from non-existing usernames by
/// response timing alone.
const RESOLVE_TIMING_FLOOR: Duration = Duration::from_millis(5);
/// Map a storage-layer failure onto the wire-level E009 storage error code.
fn storage_err(err: StorageError) -> capnp::Error {
coded_error(E009_STORAGE_ERROR, err)
}
impl NodeServiceImpl {
/// Resolve a username to its identity key.
///
/// Remote addresses (`user@otherdomain`) are proxied to the federation peer;
/// local lookups read the user store directly. "Not found" is signalled by an
/// empty `identityKey` in the response. Local lookups are padded to
/// `RESOLVE_TIMING_FLOOR` and, when possible, carry a Key Transparency Merkle
/// inclusion proof.
pub fn handle_resolve_user(
&mut self,
params: node_service::ResolveUserParams,
mut results: node_service::ResolveUserResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match p.get_username() {
Ok(u) => u,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
let username_str = match username.to_str() {
Ok(s) => s,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
if username_str.is_empty() {
return Promise::err(coded_error(E020_BAD_PARAMS, "username must not be empty"));
}
// Federation: parse user@domain format.
let addr = crate::federation::address::FederatedAddress::parse(username_str);
// A domain different from ours (or any domain when we have no local
// domain configured) means the lookup must be proxied.
let is_remote = match (&addr.domain, &self.local_domain) {
(Some(d), Some(ld)) => d != ld,
(Some(_), None) => true,
_ => false,
};
if is_remote {
// Proxy to remote server via federation.
// NOTE(review): the remote path skips both the rate limit and the
// timing floor applied to local lookups below — confirm intended.
if let (Some(ref fed_client), Some(ref domain)) = (&self.federation_client, &addr.domain) {
if fed_client.has_peer(domain) {
let fed = fed_client.clone();
let user = addr.username.clone();
let dom = domain.clone();
return Promise::from_future(async move {
match fed.proxy_resolve_user(&dom, &user).await {
Ok(Some(key)) => {
results.get().set_identity_key(&key);
}
Ok(None) => {
// Not found on remote — return empty.
}
Err(e) => {
tracing::warn!(error = %e, "federation proxy_resolve_user failed");
// Fall through — return empty (not found).
}
}
Ok(())
});
}
}
// No federation client or unknown peer — return empty (not found).
return Promise::ok(());
}
// Rate-limit resolve requests to prevent bulk enumeration.
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
tracing::warn!("rate_limit_hit");
metrics::record_rate_limit_hit_total();
return Promise::err(e);
}
// Timing floor: record the start time so we can pad the response to a
// fixed minimum duration, masking DB-lookup timing differences between
// existing and non-existing usernames.
let deadline = Instant::now() + RESOLVE_TIMING_FLOOR;
// Local resolution.
let identity_key = match self.store.get_user_identity_key(&addr.username) {
Ok(Some(key)) => key,
Ok(None) => {
// Return empty Data — caller checks length to detect "not found".
// Pad to timing floor before responding.
return Promise::from_future(async move {
tokio::time::sleep_until(deadline).await;
Ok(())
});
}
// NOTE(review): the storage-error path returns without padding to the
// timing floor; errors are distinguishable from lookups by timing.
Err(e) => return Promise::err(storage_err(e)),
};
let mut r = results.get();
r.set_identity_key(&identity_key);
// Attempt to include a KT Merkle inclusion proof.
// Non-fatal: if the log is unavailable or has no entry, return just the key.
if let Ok(log) = self.kt_log.lock() {
if let Some(leaf_idx) = log.find(&addr.username, &identity_key) {
match log.inclusion_proof(leaf_idx) {
Ok(proof) => match proof.to_bytes() {
Ok(bytes) => {
r.set_inclusion_proof(&bytes);
}
Err(e) => {
tracing::warn!(error = %e, "KT proof serialise failed");
}
},
Err(e) => {
tracing::warn!(error = %e, "KT inclusion_proof failed");
}
}
}
}
// Pad to timing floor before responding.
Promise::from_future(async move {
tokio::time::sleep_until(deadline).await;
Ok(())
})
}
/// Reverse lookup: map a 32-byte identity key back to its username.
///
/// Returns an empty username when no account owns the key; the response is
/// padded to `RESOLVE_TIMING_FLOOR` to mask lookup-timing differences.
pub fn handle_resolve_identity(
    &mut self,
    params: node_service::ResolveIdentityParams,
    mut results: node_service::ResolveIdentityResults,
) -> Promise<(), capnp::Error> {
    let reader = match params.get() {
        Ok(r) => r,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    let key = match reader.get_identity_key() {
        Ok(k) => k,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    // A valid session is required even though no context fields are used,
    // so anonymous bulk lookups are impossible.
    let _auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, reader.get_auth()) {
        Ok(ctx) => ctx,
        Err(e) => return Promise::err(e),
    };
    if key.len() != 32 {
        return Promise::err(coded_error(
            E004_IDENTITY_KEY_LENGTH,
            format!("identityKey must be exactly 32 bytes, got {}", key.len()),
        ));
    }
    // Timing floor: mask DB-lookup timing differences (same as resolveUser).
    let respond_at = Instant::now() + RESOLVE_TIMING_FLOOR;
    match self.store.resolve_identity_key(key) {
        Ok(Some(username)) => {
            results.get().set_username(&username);
        }
        Ok(None) => {
            // Empty string means "not found"; the caller checks the length.
        }
        Err(e) => return Promise::err(storage_err(e)),
    }
    // Delay the reply until the timing floor has elapsed.
    Promise::from_future(async move {
        tokio::time::sleep_until(respond_at).await;
        Ok(())
    })
}
}

View File

@@ -0,0 +1,342 @@
//! Dynamic plugin loader for server-side hook extensions.
//!
//! Loads shared libraries (`*.so` / `*.dylib`) from a directory at server
//! startup. Each library must export:
//!
//! ```c
//! extern "C" int32_t qpc_plugin_init(HookVTable *vtable);
//! ```
//!
//! The server creates a zeroed [`HookVTable`], passes it to `qpc_plugin_init`,
//! and wraps the resulting vtable in a [`PluginHooks`] that implements
//! [`ServerHooks`]. Multiple plugins are chained via [`ChainedHooks`].
//!
//! # Safety model
//!
//! Dynamic loading is inherently unsafe. The plugin binary MUST:
//! - be compiled against the same `quicprochat-plugin-api` version
//! - not store the event-struct pointers beyond the callback duration
//! - be `Send + Sync` (the wrapper is put behind an `Arc`)
//!
//! The server operator is responsible for only loading trusted plugin binaries.
use std::path::Path;
use libloading::{Library, Symbol};
use quicprochat_plugin_api::{
CAuthEvent, CChannelEvent, CFetchEvent, CMessageEvent, HookVTable, HOOK_CONTINUE, PLUGIN_OK,
};
use crate::hooks::{AuthEvent, ChannelEvent, FetchEvent, HookAction, MessageEvent, ServerHooks};
// ── PluginHooks ───────────────────────────────────────────────────────────────
/// A [`ServerHooks`] implementation backed by a dynamically loaded plugin vtable.
///
/// Holds the [`Library`] alive alongside the vtable so that the loaded code
/// is not unmapped while the vtable function pointers are still reachable.
pub struct PluginHooks {
/// The vtable filled by `qpc_plugin_init`.
vtable: HookVTable,
/// Keeps the shared library mapped. Must be dropped after `vtable`.
/// (Rust drops fields in declaration order, so `vtable` — and the
/// `destroy` call in `Drop` — runs before the library is unmapped.)
_lib: Library,
/// Name of the plugin file, for diagnostics.
name: String,
}
impl PluginHooks {
/// Load a plugin from `path` and call `qpc_plugin_init`.
///
/// Returns `Err` if the library cannot be opened, the symbol is missing,
/// or `qpc_plugin_init` returns a non-zero error code.
pub fn load(path: &Path) -> anyhow::Result<Self> {
// Diagnostic name: file name if present, full path otherwise.
let name = path
.file_name()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_else(|| path.display().to_string())
// Safety: loading arbitrary shared libraries is inherently unsafe.
// The server operator is responsible for only loading trusted plugins.
let lib = unsafe { Library::new(path) }
.map_err(|e| anyhow::anyhow!("plugin '{}': load failed: {}", name, e))?;
// Zero-initialise the vtable so unused slots are null.
let mut vtable = HookVTable {
user_data: core::ptr::null_mut(),
on_message_enqueue: None,
on_batch_enqueue: None,
on_auth: None,
on_channel_created: None,
on_fetch: None,
on_user_registered: None,
error_message: None,
destroy: None,
};
// Safety: the symbol must have the exact signature declared in the API crate.
let init: Symbol<unsafe extern "C" fn(*mut HookVTable) -> i32> =
unsafe { lib.get(b"qpc_plugin_init\0") }.map_err(|e| {
anyhow::anyhow!("plugin '{}': missing qpc_plugin_init: {}", name, e)
})?;
// The plugin fills in whichever hook slots it implements.
let rc = unsafe { init(&mut vtable) };
if rc != PLUGIN_OK {
anyhow::bail!("plugin '{}': qpc_plugin_init returned error {}", name, rc);
}
tracing::info!(plugin = %name, "loaded plugin");
Ok(Self { vtable, _lib: lib, name })
}
/// Human-readable plugin name (filename).
pub fn name(&self) -> &str {
&self.name
}
/// Retrieve the rejection reason from the plugin, falling back to a generic string.
///
/// NOTE(review): assumes the pointer returned by `error_message` remains
/// valid until this call returns — confirm against the plugin ABI contract.
fn rejection_reason(&self) -> String {
if let Some(f) = self.vtable.error_message {
let ptr = unsafe { f(self.vtable.user_data) };
if !ptr.is_null() {
// Safety: plugin must return a valid null-terminated UTF-8 (or ASCII) string.
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr as *const core::ffi::c_char) };
return cstr.to_string_lossy().into_owned();
}
}
"rejected by plugin".to_string()
}
}
impl Drop for PluginHooks {
    /// Give the plugin a chance to release its `user_data` before the shared
    /// library is unmapped (`_lib` drops after `vtable` by field order).
    fn drop(&mut self) {
        let destroy = match self.vtable.destroy {
            Some(f) => f,
            None => return,
        };
        // Safety: the plugin contract requires `destroy` to be callable at
        // any point after a successful init.
        unsafe { destroy(self.vtable.user_data) };
    }
}
/// Bridge each [`ServerHooks`] callback onto the plugin's C vtable.
///
/// Every method follows the same pattern: if the plugin did not fill the
/// corresponding slot, the hook is a no-op; otherwise the Rust event is
/// marshalled into the C struct (pointers into the event's own buffers,
/// valid only for the duration of the call) and the function pointer invoked.
impl ServerHooks for PluginHooks {
fn on_message_enqueue(&self, event: &MessageEvent) -> HookAction {
let f = match self.vtable.on_message_enqueue {
Some(f) => f,
None => return HookAction::Continue,
};
// sender_identity is optional: null pointer + zero length when absent.
let sender_ptr = event
.sender_identity
.as_deref()
.map(|s| s.as_ptr())
.unwrap_or(core::ptr::null());
let sender_len = event.sender_identity.as_deref().map_or(0, |s| s.len());
let c_event = CMessageEvent {
sender_identity: sender_ptr,
sender_identity_len: sender_len,
recipient_key: event.recipient_key.as_ptr(),
recipient_key_len: event.recipient_key.len(),
channel_id: event.channel_id.as_ptr(),
channel_id_len: event.channel_id.len(),
payload_len: event.payload_len,
seq: event.seq,
};
let rc = unsafe { f(self.vtable.user_data, &c_event) };
// Any non-CONTINUE return is a rejection; fetch the plugin's reason.
if rc == HOOK_CONTINUE {
HookAction::Continue
} else {
HookAction::Reject(self.rejection_reason())
}
}
fn on_batch_enqueue(&self, events: &[MessageEvent]) {
let f = match self.vtable.on_batch_enqueue {
Some(f) => f,
None => return,
};
// Marshal all events into a contiguous Vec; it must stay alive until
// the FFI call below returns, since c_events holds raw pointers into
// the original events.
let c_events: Vec<CMessageEvent> = events
.iter()
.map(|e| {
let sender_ptr = e
.sender_identity
.as_deref()
.map(|s| s.as_ptr())
.unwrap_or(core::ptr::null());
let sender_len = e.sender_identity.as_deref().map_or(0, |s| s.len());
CMessageEvent {
sender_identity: sender_ptr,
sender_identity_len: sender_len,
recipient_key: e.recipient_key.as_ptr(),
recipient_key_len: e.recipient_key.len(),
channel_id: e.channel_id.as_ptr(),
channel_id_len: e.channel_id.len(),
payload_len: e.payload_len,
seq: e.seq,
}
})
.collect();
unsafe { f(self.vtable.user_data, c_events.as_ptr(), c_events.len()) };
}
// Fire-and-forget notification; the plugin's return value is ignored.
fn on_auth(&self, event: &AuthEvent) {
let f = match self.vtable.on_auth {
Some(f) => f,
None => return,
};
let c_event = CAuthEvent {
username: event.username.as_ptr(),
username_len: event.username.len(),
success: if event.success { 1 } else { 0 },
failure_reason: event.failure_reason.as_ptr(),
failure_reason_len: event.failure_reason.len(),
};
unsafe { f(self.vtable.user_data, &c_event) };
}
// Fire-and-forget notification; the plugin's return value is ignored.
fn on_channel_created(&self, event: &ChannelEvent) {
let f = match self.vtable.on_channel_created {
Some(f) => f,
None => return,
};
let c_event = CChannelEvent {
channel_id: event.channel_id.as_ptr(),
channel_id_len: event.channel_id.len(),
initiator_key: event.initiator_key.as_ptr(),
initiator_key_len: event.initiator_key.len(),
peer_key: event.peer_key.as_ptr(),
peer_key_len: event.peer_key.len(),
was_new: if event.was_new { 1 } else { 0 },
};
unsafe { f(self.vtable.user_data, &c_event) };
}
// Fire-and-forget notification; the plugin's return value is ignored.
fn on_fetch(&self, event: &FetchEvent) {
let f = match self.vtable.on_fetch {
Some(f) => f,
None => return,
};
let c_event = CFetchEvent {
recipient_key: event.recipient_key.as_ptr(),
recipient_key_len: event.recipient_key.len(),
channel_id: event.channel_id.as_ptr(),
channel_id_len: event.channel_id.len(),
message_count: event.message_count,
};
unsafe { f(self.vtable.user_data, &c_event) };
}
// Passes (pointer, length) pairs directly; no C struct for this hook.
fn on_user_registered(&self, username: &str, identity_key: &[u8]) {
let f = match self.vtable.on_user_registered {
Some(f) => f,
None => return,
};
unsafe {
f(
self.vtable.user_data,
username.as_ptr(),
username.len(),
identity_key.as_ptr(),
identity_key.len(),
)
};
}
}
// ── ChainedHooks ─────────────────────────────────────────────────────────────
/// Composes multiple [`ServerHooks`] implementations into one.
///
/// The filtering hook (`on_message_enqueue`) short-circuits on the first
/// rejection; all fire-and-forget hooks fan out to every plugin in order.
pub struct ChainedHooks {
    hooks: Vec<Box<dyn ServerHooks>>,
}
impl ChainedHooks {
    /// Wrap an ordered list of hook implementations.
    pub fn new(hooks: Vec<Box<dyn ServerHooks>>) -> Self {
        Self { hooks }
    }
}
impl ServerHooks for ChainedHooks {
    fn on_message_enqueue(&self, event: &MessageEvent) -> HookAction {
        // Lazily poll each plugin; the first non-Continue action wins.
        self.hooks
            .iter()
            .map(|h| h.on_message_enqueue(event))
            .find(|a| !matches!(a, HookAction::Continue))
            .unwrap_or(HookAction::Continue)
    }
    fn on_batch_enqueue(&self, events: &[MessageEvent]) {
        self.hooks.iter().for_each(|h| h.on_batch_enqueue(events));
    }
    fn on_auth(&self, event: &AuthEvent) {
        self.hooks.iter().for_each(|h| h.on_auth(event));
    }
    fn on_channel_created(&self, event: &ChannelEvent) {
        self.hooks.iter().for_each(|h| h.on_channel_created(event));
    }
    fn on_fetch(&self, event: &FetchEvent) {
        self.hooks.iter().for_each(|h| h.on_fetch(event));
    }
    fn on_user_registered(&self, username: &str, identity_key: &[u8]) {
        self.hooks
            .iter()
            .for_each(|h| h.on_user_registered(username, identity_key));
    }
}
// ── load_plugins_from_dir ─────────────────────────────────────────────────────
/// Load all `*.so` / `*.dylib` files from `dir` as plugins.
///
/// Non-fatal errors (unreadable files, init failures) are logged as warnings
/// and skipped; the server continues with the plugins that did load.
/// Returns the full list of successfully loaded plugins.
pub fn load_plugins_from_dir(dir: &Path) -> Vec<PluginHooks> {
    let mut plugins = Vec::new();
    // An unreadable plugin directory is non-fatal: run with no plugins.
    let entries = match std::fs::read_dir(dir) {
        Ok(e) => e,
        Err(e) => {
            tracing::warn!(dir = %dir.display(), error = %e, "plugin_dir unreadable; no plugins loaded");
            return plugins;
        }
    };
    for entry in entries.flatten() {
        let path = entry.path();
        // Only dynamic-library extensions are considered; everything else
        // (configs, subdirectories, etc.) is silently ignored.
        if !matches!(
            path.extension().and_then(|e| e.to_str()),
            Some("so") | Some("dylib")
        ) {
            continue;
        }
        match PluginHooks::load(&path) {
            // `PluginHooks::load` already emits the success line at info
            // level; logging again here produced a duplicate log entry.
            Ok(p) => plugins.push(p),
            Err(e) => {
                tracing::warn!(path = %path.display(), error = %e, "failed to load plugin; skipping");
            }
        }
    }
    plugins
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,117 @@
use std::path::PathBuf;
use anyhow::Context;
use quinn::ServerConfig;
use quinn_proto::crypto::rustls::QuicServerConfig;
use rcgen::generate_simple_self_signed;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::version::TLS13;
/// Ensure a self-signed certificate exists on disk and return a QUIC server config.
/// When `production` is true, cert and key must already exist (no auto-generation).
pub fn build_server_config(
    cert_path: &PathBuf,
    key_path: &PathBuf,
    production: bool,
) -> anyhow::Result<ServerConfig> {
    if !(cert_path.exists() && key_path.exists()) {
        if production {
            anyhow::bail!(
                "TLS cert or key missing at {:?} / {:?}; production mode forbids auto-generation",
                cert_path,
                key_path
            );
        }
        // Dev convenience: mint a localhost certificate on first start.
        generate_self_signed_cert(cert_path, key_path)?;
    }
    let cert_bytes = std::fs::read(cert_path).context("read cert")?;
    let key_bytes = std::fs::read(key_path).context("read key")?;
    // Fail fast on an expired cert; warn on soon-to-expire / self-signed.
    validate_certificate(&cert_bytes)?;
    let cert_chain = vec![CertificateDer::from(cert_bytes)];
    let key = PrivateKeyDer::try_from(key_bytes).map_err(|_| anyhow::anyhow!("invalid key"))?;
    // TLS 1.3 only, no client certificates; the RPC layer authenticates callers.
    let mut tls = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13])
        .with_no_client_auth()
        .with_single_cert(cert_chain, key)?;
    tls.alpn_protocols = vec![b"capnp".to_vec()];
    let crypto = QuicServerConfig::try_from(tls)
        .map_err(|e| anyhow::anyhow!("invalid server TLS config: {e}"))?;
    Ok(ServerConfig::with_crypto(std::sync::Arc::new(crypto)))
}
fn generate_self_signed_cert(cert_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<()> {
if let Some(parent) = cert_path.parent() {
std::fs::create_dir_all(parent).context("create cert dir")?;
}
if let Some(parent) = key_path.parent() {
std::fs::create_dir_all(parent).context("create key dir")?;
}
let subject_alt_names = vec![
"localhost".to_string(),
"127.0.0.1".to_string(),
"::1".to_string(),
];
let issued = generate_simple_self_signed(subject_alt_names)?;
let key_der = issued.key_pair.serialize_der();
std::fs::write(cert_path, issued.cert.der()).context("write cert")?;
std::fs::write(key_path, &key_der).context("write key")?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(0o600);
std::fs::set_permissions(key_path, perms).context("set key permissions")?;
}
tracing::info!(
cert = %cert_path.display(),
key = %key_path.display(),
"generated self-signed TLS certificate"
);
Ok(())
}
/// Validate a DER-encoded X.509 certificate: bail if expired, warn if expiring
/// soon or self-signed.
fn validate_certificate(der_bytes: &[u8]) -> anyhow::Result<()> {
    use x509_parser::prelude::*;
    let (_, cert) = X509Certificate::from_der(der_bytes)
        .map_err(|e| anyhow::anyhow!("failed to parse X.509 certificate: {e}"))?;
    let validity = cert.validity();
    let now = ASN1Time::now();
    // Hard failure: never start serving with an expired certificate.
    if !validity.is_valid_at(now) {
        anyhow::bail!(
            "TLS certificate expired (not_after: {})",
            validity.not_after
        );
    }
    // Soft warning when fewer than 30 days of validity remain.
    const THIRTY_DAYS_SECS: i64 = 30 * 24 * 60 * 60;
    if validity.not_after.timestamp() < now.timestamp() + THIRTY_DAYS_SECS {
        tracing::warn!(
            not_after = %validity.not_after,
            "TLS certificate expires within 30 days"
        );
    }
    // Self-signed heuristic: issuer DN equals subject DN.
    if cert.issuer() == cert.subject() {
        tracing::warn!("TLS certificate is self-signed (issuer == subject)");
    }
    Ok(())
}

View File

@@ -0,0 +1,51 @@
//! Account handler — account deletion.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::domain::account::AccountService;
use super::{domain_err, require_auth, ServerState};
/// Delete the authenticated caller's account, then tear down their session.
pub async fn handle_delete_account(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    // The request carries no fields, but decoding still validates framing.
    if let Err(e) = v1::DeleteAccountRequest::decode(ctx.payload) {
        return HandlerResult::err(
            quicprochat_rpc::error::RpcStatus::BadRequest,
            &format!("decode: {e}"),
        );
    }
    let svc = AccountService {
        store: Arc::clone(&state.store),
        kt_log: Arc::clone(&state.kt_log),
    };
    if let Err(e) = svc.delete_account(&identity_key) {
        return domain_err(e);
    }
    // Invalidate the session of the now-deleted account.
    if let Some(token) = ctx.session_token.as_deref() {
        state.sessions.remove(token);
    }
    let proto = v1::DeleteAccountResponse { success: true };
    HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}

View File

@@ -0,0 +1,255 @@
//! OPAQUE auth handlers — registration and login.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::error::RpcStatus;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::auth::{PendingLogin, SessionInfo, SESSION_TTL_SECS};
use crate::domain::auth::AuthService;
use crate::domain::types::{RegisterFinishReq, RegisterStartReq};
use super::ServerState;
/// OPAQUE registration, step 1: forward the client's blinded registration
/// request to the OPAQUE server and return its response.
pub async fn handle_opaque_register_start(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let req = match v1::OpaqueRegisterStartRequest::decode(ctx.payload) {
        Ok(r) => r,
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
    };
    if req.username.is_empty() {
        return HandlerResult::err(RpcStatus::BadRequest, "username must not be empty");
    }
    let svc = AuthService {
        store: Arc::clone(&state.store),
        opaque_setup: Arc::clone(&state.opaque_setup),
        pending_logins: Arc::clone(&state.pending_logins),
        sessions: Arc::clone(&state.sessions),
        auth_cfg: Arc::clone(&state.auth_cfg),
    };
    let start = RegisterStartReq {
        username: req.username,
        request_bytes: req.request,
    };
    match svc.register_start(start) {
        Ok(resp) => {
            let proto = v1::OpaqueRegisterStartResponse {
                response: resp.response_bytes,
            };
            HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
        }
        Err(e) => HandlerResult::err(RpcStatus::Internal, &format!("register_start: {e}")),
    }
}
/// OPAQUE registration, step 2: persist the client's registration upload and
/// bind the supplied identity key to the username.
pub async fn handle_opaque_register_finish(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let req = match v1::OpaqueRegisterFinishRequest::decode(ctx.payload) {
        Ok(r) => r,
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
    };
    if req.username.is_empty() {
        return HandlerResult::err(RpcStatus::BadRequest, "username must not be empty");
    }
    let svc = AuthService {
        store: Arc::clone(&state.store),
        opaque_setup: Arc::clone(&state.opaque_setup),
        pending_logins: Arc::clone(&state.pending_logins),
        sessions: Arc::clone(&state.sessions),
        auth_cfg: Arc::clone(&state.auth_cfg),
    };
    let finish = RegisterFinishReq {
        username: req.username.clone(),
        upload_bytes: req.upload,
        identity_key: req.identity_key.clone(),
    };
    let resp = match svc.register_finish(finish) {
        Ok(r) => r,
        Err(e) => {
            return HandlerResult::err(RpcStatus::Internal, &format!("register_finish: {e}"))
        }
    };
    // Notify plugins only once the registration has persisted successfully.
    state
        .hooks
        .on_user_registered(&req.username, &req.identity_key);
    let proto = v1::OpaqueRegisterFinishResponse {
        success: resp.success,
    };
    HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
/// OPAQUE login, step 1: run the server side of the credential exchange and
/// stash the login state in `pending_logins` for step 2.
///
/// NOTE(review): a missing user returns `NotFound`, which reveals username
/// existence; OPAQUE deployments often answer unknown users with a simulated
/// response instead — confirm whether enumeration resistance is wanted here.
pub async fn handle_opaque_login_start(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let req = match v1::OpaqueLoginStartRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
};
if req.username.is_empty() {
return HandlerResult::err(RpcStatus::BadRequest, "username must not be empty");
}
// Look up user record.
let user_record = match state.store.get_user_record(&req.username) {
Ok(Some(r)) => r,
Ok(None) => {
return HandlerResult::err(RpcStatus::NotFound, "user not found");
}
Err(e) => return HandlerResult::err(RpcStatus::Internal, &format!("store: {e}")),
};
// Deserialise stored registration.
let registration =
match opaque_ke::ServerRegistration::<quicprochat_core::opaque_auth::OpaqueSuite>::deserialize(&user_record) {
Ok(r) => r,
Err(e) => {
// A record that fails to parse is server-side corruption, not a
// client error.
return HandlerResult::err(
RpcStatus::Internal,
&format!("corrupt user record: {e}"),
)
}
};
// Start login.
let credential_request =
match opaque_ke::CredentialRequest::<quicprochat_core::opaque_auth::OpaqueSuite>::deserialize(&req.request)
{
Ok(r) => r,
Err(e) => {
return HandlerResult::err(RpcStatus::BadRequest, &format!("bad login request: {e}"))
}
};
let login_start = match opaque_ke::ServerLogin::<
quicprochat_core::opaque_auth::OpaqueSuite,
>::start(
&mut rand::rngs::OsRng,
&state.opaque_setup,
Some(registration),
credential_request,
req.username.as_bytes(),
Default::default(),
) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(RpcStatus::Internal, &format!("login start: {e}"));
}
};
let response_bytes = login_start.message.serialize().to_vec();
// Store pending login state.
// A second loginStart for the same username overwrites any earlier
// pending entry; `created_at` lets stale entries be expired elsewhere.
let now = crate::auth::current_timestamp();
state.pending_logins.insert(
req.username.clone(),
PendingLogin {
state_bytes: login_start.state.serialize().to_vec(),
created_at: now,
},
);
let proto = v1::OpaqueLoginStartResponse {
response: response_bytes,
};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
/// OPAQUE login, step 2: verify the client's credential finalization against
/// the pending login state and, on success, mint a 32-byte session token.
///
/// The pending entry is removed up front, so a failed finalization cannot be
/// retried without a fresh loginStart.
///
/// NOTE(review): `req.identity_key` is stored into the session exactly as the
/// client supplied it — nothing here checks it against the key registered for
/// the username. Confirm that binding is enforced elsewhere.
pub async fn handle_opaque_login_finish(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let req = match v1::OpaqueLoginFinishRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
};
if req.username.is_empty() {
return HandlerResult::err(RpcStatus::BadRequest, "username must not be empty");
}
// Retrieve pending login state.
let pending = match state.pending_logins.remove(&req.username) {
Some((_, p)) => p,
None => {
return HandlerResult::err(
RpcStatus::BadRequest,
"no pending login for this username",
);
}
};
let login_state = match opaque_ke::ServerLogin::<
quicprochat_core::opaque_auth::OpaqueSuite,
>::deserialize(&pending.state_bytes)
{
Ok(s) => s,
Err(e) => {
return HandlerResult::err(
RpcStatus::Internal,
&format!("corrupt pending login: {e}"),
)
}
};
let finalization = match opaque_ke::CredentialFinalization::<
quicprochat_core::opaque_auth::OpaqueSuite,
>::deserialize(&req.finalization)
{
Ok(f) => f,
Err(e) => {
return HandlerResult::err(RpcStatus::BadRequest, &format!("bad finalization: {e}"));
}
};
// A failed finish means the password was wrong (or the transcript was
// tampered with); report it to plugins before returning Unauthorized.
if let Err(e) = login_state.finish(finalization, Default::default()) {
state.hooks.on_auth(&crate::hooks::AuthEvent {
username: req.username.clone(),
success: false,
failure_reason: format!("{e}"),
});
return HandlerResult::err(RpcStatus::Unauthorized, &format!("login failed: {e}"));
}
// Generate session token.
let mut token = vec![0u8; 32];
rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut token);
let now = crate::auth::current_timestamp();
state.sessions.insert(
token.clone(),
SessionInfo {
username: req.username.clone(),
identity_key: req.identity_key.clone(),
created_at: now,
expires_at: now + SESSION_TTL_SECS,
},
);
state.hooks.on_auth(&crate::hooks::AuthEvent {
username: req.username,
success: true,
failure_reason: String::new(),
});
let proto = v1::OpaqueLoginFinishResponse {
session_token: token,
};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}

View File

@@ -0,0 +1,101 @@
//! Blob handlers — chunked file upload/download.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::domain::blobs::BlobService;
use crate::domain::types::{CallerAuth, DownloadBlobReq, UploadBlobReq};
use super::{domain_err, require_auth, ServerState};
/// Build a [`CallerAuth`] for blob operations from just the identity key.
/// The blob service does not consult `token` or `device_id`, so they are
/// left empty/None here.
fn caller_auth(identity_key: Vec<u8>) -> CallerAuth {
CallerAuth {
identity_key,
token: Vec::new(),
device_id: None,
}
}
/// Receive one chunk of a client blob upload and hand it to the blob service.
pub async fn handle_upload_blob(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::UploadBlobRequest::decode(ctx.payload) {
        Ok(r) => r,
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
    };
    let svc = BlobService {
        data_dir: state.data_dir.clone(),
    };
    let domain_req = UploadBlobReq {
        blob_hash: req.blob_hash,
        chunk: req.chunk,
        offset: req.offset,
        total_size: req.total_size,
        mime_type: req.mime_type,
    };
    match svc.upload_blob(domain_req, &caller_auth(identity_key)) {
        Ok(resp) => {
            let proto = v1::UploadBlobResponse {
                blob_id: resp.blob_id,
            };
            HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
        }
        Err(e) => domain_err(e),
    }
}
/// Serve one chunk of a stored blob back to an authenticated caller.
pub async fn handle_download_blob(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::DownloadBlobRequest::decode(ctx.payload) {
        Ok(r) => r,
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
    };
    let svc = BlobService {
        data_dir: state.data_dir.clone(),
    };
    let domain_req = DownloadBlobReq {
        blob_id: req.blob_id,
        offset: req.offset,
        length: req.length,
    };
    match svc.download_blob(domain_req, &caller_auth(identity_key)) {
        Ok(resp) => {
            let proto = v1::DownloadBlobResponse {
                chunk: resp.chunk,
                total_size: resp.total_size,
                mime_type: resp.mime_type,
            };
            HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
        }
        Err(e) => domain_err(e),
    }
}

View File

@@ -0,0 +1,60 @@
//! Channel handler — 1:1 DM channel creation.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::domain::channels::ChannelService;
use crate::domain::types::CreateChannelReq;
use crate::hooks::ChannelEvent;
use super::{domain_err, require_auth, ServerState};
/// Create (or idempotently look up) a 1:1 DM channel between the caller and
/// `peer_key`; `was_new` tells the client which case occurred.
pub async fn handle_create_channel(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::CreateChannelRequest::decode(ctx.payload) {
        Ok(r) => r,
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
    };
    let svc = ChannelService {
        store: Arc::clone(&state.store),
    };
    let domain_req = CreateChannelReq {
        peer_key: req.peer_key.clone(),
    };
    let resp = match svc.create_channel(domain_req, &identity_key) {
        Ok(r) => r,
        Err(e) => return domain_err(e),
    };
    // Notify plugins after the channel exists in storage.
    state.hooks.on_channel_created(&ChannelEvent {
        channel_id: resp.channel_id.clone(),
        initiator_key: identity_key,
        peer_key: req.peer_key,
        was_new: resp.was_new,
    });
    let proto = v1::CreateChannelResponse {
        channel_id: resp.channel_id,
        was_new: resp.was_new,
    };
    HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}

View File

@@ -0,0 +1,408 @@
//! Delivery handlers — enqueue, fetch, fetch_wait, peek, ack, batch_enqueue.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::error::RpcStatus;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use sha2::{Digest, Sha256};
use tokio::sync::Notify;
use crate::domain::delivery::DeliveryService;
use crate::domain::types::{AckReq, BatchEnqueueReq, EnqueueReq, FetchReq, PeekReq};
use crate::hooks::{HookAction, MessageEvent};
use super::{require_auth, ServerState};
/// Build a 96-byte delivery proof: `SHA-256(seq || recipient_key || timestamp_ms) || Ed25519(hash)`.
///
/// The sender stores this as cryptographic evidence that the server enqueued the message.
fn build_delivery_proof(
    signing_key: &quicprochat_core::IdentityKeypair,
    seq: u64,
    recipient_key: &[u8],
    timestamp_ms: u64,
) -> Vec<u8> {
    // Digest over (seq, recipient, timestamp); integers are little-endian.
    let digest: [u8; 32] = Sha256::new()
        .chain_update(seq.to_le_bytes())
        .chain_update(recipient_key)
        .chain_update(timestamp_ms.to_le_bytes())
        .finalize()
        .into();
    let signature = signing_key.sign_raw(&digest);
    // Layout: bytes 0..32 = digest, bytes 32..96 = signature.
    let mut proof = vec![0u8; 96];
    proof[..32].copy_from_slice(&digest);
    proof[32..].copy_from_slice(&signature);
    proof
}
/// Enqueue a single message for `recipient_key`.
///
/// Flow: auth → decode → rate-limit → idempotency lookup → store enqueue →
/// plugin hook → dedup record → server-signed delivery proof.
pub async fn handle_enqueue(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::EnqueueRequest::decode(ctx.payload) {
        Ok(r) => r,
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
    };
    if req.recipient_key.is_empty() || req.payload.is_empty() {
        return HandlerResult::err(RpcStatus::BadRequest, "recipient_key and payload required");
    }
    // Rate limiting.
    if let Err(_e) = crate::auth::check_rate_limit(&state.rate_limits, &identity_key) {
        return HandlerResult::err(RpcStatus::RateLimited, "rate limit exceeded");
    }
    // Idempotency dedup: if message_id is provided and already seen, return the cached seq.
    if !req.message_id.is_empty() {
        if let Some(entry) = state.seen_message_ids.get(&req.message_id) {
            let (cached_seq, _ts) = *entry;
            let proto = v1::EnqueueResponse {
                seq: cached_seq,
                delivery_proof: Vec::new(),
                duplicate: true,
            };
            return HandlerResult::ok(Bytes::from(proto.encode_to_vec()));
        }
    }
    let svc = DeliveryService {
        store: Arc::clone(&state.store),
        waiters: Arc::clone(&state.waiters),
    };
    let domain_req = EnqueueReq {
        recipient_key: req.recipient_key.clone(),
        payload: req.payload.clone(),
        channel_id: req.channel_id.clone(),
        ttl_secs: req.ttl_secs,
    };
    let resp = match svc.enqueue(domain_req) {
        Ok(r) => r,
        Err(e) => return HandlerResult::err(RpcStatus::Internal, &format!("enqueue: {e}")),
    };
    // Fire the plugin hook. It needs the assigned seq, so it runs after the
    // store enqueue; a rejection here reports an error to the sender but does
    // not un-enqueue the stored message — TODO confirm this is intended.
    let action = state.hooks.on_message_enqueue(&MessageEvent {
        sender_identity: Some(identity_key),
        recipient_key: req.recipient_key.clone(),
        channel_id: req.channel_id,
        payload_len: req.payload.len(),
        seq: resp.seq,
    });
    if let HookAction::Reject(reason) = action {
        // Deliberately NOT recording message_id here: caching a rejected send
        // (as the previous ordering did) made a retry of the same id return
        // `duplicate: true` success instead of the rejection.
        return HandlerResult::err(RpcStatus::Forbidden, &reason);
    }
    // Record message_id for dedup only once the message is fully accepted.
    if !req.message_id.is_empty() {
        let now = crate::auth::current_timestamp();
        state.seen_message_ids.insert(req.message_id, (resp.seq, now));
    }
    // Build server-signed delivery proof.
    let timestamp_ms = crate::auth::current_timestamp();
    let delivery_proof = build_delivery_proof(
        &state.signing_key,
        resp.seq,
        &req.recipient_key,
        timestamp_ms,
    );
    let proto = v1::EnqueueResponse {
        seq: resp.seq,
        delivery_proof,
        duplicate: false,
    };
    HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
/// Handle Fetch: drain queued envelopes for the authenticated caller (or an
/// explicit recipient key), optionally narrowed to a device sub-queue.
pub async fn handle_fetch(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::FetchRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    // Default to the caller's own queue when no explicit recipient is given.
    let mut recipient_key = if req.recipient_key.is_empty() {
        identity_key
    } else {
        req.recipient_key
    };
    // A device id narrows the queue to that device's derived sub-key.
    if !req.device_id.is_empty() {
        recipient_key = DeliveryService::device_recipient_key(&recipient_key, &req.device_id);
    }
    let service = DeliveryService {
        store: Arc::clone(&state.store),
        waiters: Arc::clone(&state.waiters),
    };
    let outcome = service.fetch(FetchReq {
        recipient_key,
        channel_id: req.channel_id,
        limit: req.limit,
    });
    match outcome {
        Err(e) => HandlerResult::err(RpcStatus::Internal, &format!("fetch: {e}")),
        Ok(resp) => {
            let payloads = resp
                .payloads
                .into_iter()
                .map(|entry| v1::Envelope {
                    seq: entry.seq,
                    data: entry.data,
                })
                .collect();
            HandlerResult::ok(Bytes::from(
                v1::FetchResponse { payloads }.encode_to_vec(),
            ))
        }
    }
}
/// Handle FetchWait: long-poll variant of Fetch.
///
/// Tries an immediate fetch; if the queue is empty, parks on a per-recipient
/// `Notify` until an enqueue wakes it or the (clamped) timeout elapses, then
/// fetches once more.
///
/// NOTE(review): `Notify::notify_waiters` does not store a permit, so a
/// message enqueued between the first fetch and the `notified().await` below
/// is not signalled — the caller then waits the full timeout, but the final
/// re-fetch still picks the message up (delivery is delayed, never lost).
///
/// NOTE(review): waiter entries are inserted into `state.waiters` but never
/// removed here — confirm a cleanup task prunes the map.
pub async fn handle_fetch_wait(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::FetchWaitRequest::decode(ctx.payload) {
        Ok(r) => r,
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
    };
    // Default to the caller's own queue when no explicit recipient is given.
    let base_key = if req.recipient_key.is_empty() {
        identity_key
    } else {
        req.recipient_key
    };
    // A device id narrows the queue to that device's derived sub-key.
    let recipient_key = if req.device_id.is_empty() {
        base_key
    } else {
        DeliveryService::device_recipient_key(&base_key, &req.device_id)
    };
    // Clamp the wait: default 30 s when unset, hard cap 60 s.
    let timeout_ms = if req.timeout_ms == 0 {
        30_000
    } else {
        req.timeout_ms.min(60_000)
    };
    let svc = DeliveryService {
        store: Arc::clone(&state.store),
        waiters: Arc::clone(&state.waiters),
    };
    // Try immediate fetch first.
    let fetch_req = FetchReq {
        recipient_key: recipient_key.clone(),
        channel_id: req.channel_id.clone(),
        limit: req.limit,
    };
    match svc.fetch(fetch_req) {
        Ok(resp) if !resp.payloads.is_empty() => {
            let proto = v1::FetchWaitResponse {
                payloads: resp
                    .payloads
                    .into_iter()
                    .map(|e| v1::Envelope {
                        seq: e.seq,
                        data: e.data,
                    })
                    .collect(),
            };
            return HandlerResult::ok(Bytes::from(proto.encode_to_vec()));
        }
        Err(e) => {
            return HandlerResult::err(RpcStatus::Internal, &format!("fetch: {e}"));
        }
        // Ok but empty: fall through to the long-poll path below.
        _ => {}
    }
    // Long-poll: wait for notification or timeout. The Notify handle is
    // shared by every poller of this recipient key.
    let notify = state
        .waiters
        .entry(recipient_key.clone())
        .or_insert_with(|| Arc::new(Notify::new()))
        .clone();
    let timeout = tokio::time::Duration::from_millis(timeout_ms);
    // Timeout expiry is not an error: we re-fetch either way.
    let _ = tokio::time::timeout(timeout, notify.notified()).await;
    // Re-fetch after wake or timeout.
    let fetch_req = FetchReq {
        recipient_key,
        channel_id: req.channel_id,
        limit: req.limit,
    };
    match svc.fetch(fetch_req) {
        Ok(resp) => {
            let proto = v1::FetchWaitResponse {
                payloads: resp
                    .payloads
                    .into_iter()
                    .map(|e| v1::Envelope {
                        seq: e.seq,
                        data: e.data,
                    })
                    .collect(),
            };
            HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
        }
        Err(e) => HandlerResult::err(RpcStatus::Internal, &format!("fetch: {e}")),
    }
}
/// Handle Peek: read queued envelopes without consuming them
/// (consumption semantics are delegated to `DeliveryService::peek`).
pub async fn handle_peek(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::PeekRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    // Default to the caller's own queue, then narrow to a device sub-queue.
    let mut recipient_key = if req.recipient_key.is_empty() {
        identity_key
    } else {
        req.recipient_key
    };
    if !req.device_id.is_empty() {
        recipient_key = DeliveryService::device_recipient_key(&recipient_key, &req.device_id);
    }
    let service = DeliveryService {
        store: Arc::clone(&state.store),
        waiters: Arc::clone(&state.waiters),
    };
    let outcome = service.peek(PeekReq {
        recipient_key,
        channel_id: req.channel_id,
        limit: req.limit,
    });
    match outcome {
        Err(e) => HandlerResult::err(RpcStatus::Internal, &format!("peek: {e}")),
        Ok(resp) => {
            let payloads = resp
                .payloads
                .into_iter()
                .map(|entry| v1::Envelope {
                    seq: entry.seq,
                    data: entry.data,
                })
                .collect();
            HandlerResult::ok(Bytes::from(
                v1::PeekResponse { payloads }.encode_to_vec(),
            ))
        }
    }
}
/// Handle Ack: acknowledge envelopes up to `seq_up_to` for a queue
/// (removal semantics are delegated to `DeliveryService::ack`).
pub async fn handle_ack(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::AckRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    // Default to the caller's own queue, then narrow to a device sub-queue.
    let mut recipient_key = if req.recipient_key.is_empty() {
        identity_key
    } else {
        req.recipient_key
    };
    if !req.device_id.is_empty() {
        recipient_key = DeliveryService::device_recipient_key(&recipient_key, &req.device_id);
    }
    let service = DeliveryService {
        store: Arc::clone(&state.store),
        waiters: Arc::clone(&state.waiters),
    };
    let request = AckReq {
        recipient_key,
        channel_id: req.channel_id,
        seq_up_to: req.seq_up_to,
    };
    if let Err(e) = service.ack(request) {
        return HandlerResult::err(RpcStatus::Internal, &format!("ack: {e}"));
    }
    HandlerResult::ok(Bytes::from(v1::AckResponse {}.encode_to_vec()))
}
/// Handle BatchEnqueue: fan one payload out to several recipients at once.
pub async fn handle_batch_enqueue(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::BatchEnqueueRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    if req.recipient_keys.is_empty() || req.payload.is_empty() {
        return HandlerResult::err(
            RpcStatus::BadRequest,
            "recipient_keys and payload required",
        );
    }
    // Batch sends count against the same per-identity rate limit as Enqueue.
    if crate::auth::check_rate_limit(&state.rate_limits, &identity_key).is_err() {
        return HandlerResult::err(RpcStatus::RateLimited, "rate limit exceeded");
    }
    let service = DeliveryService {
        store: Arc::clone(&state.store),
        waiters: Arc::clone(&state.waiters),
    };
    let request = BatchEnqueueReq {
        recipient_keys: req.recipient_keys,
        payload: req.payload,
        channel_id: req.channel_id,
        ttl_secs: req.ttl_secs,
    };
    match service.batch_enqueue(request) {
        Err(e) => HandlerResult::err(RpcStatus::Internal, &format!("batch_enqueue: {e}")),
        Ok(resp) => HandlerResult::ok(Bytes::from(
            v1::BatchEnqueueResponse { seqs: resp.seqs }.encode_to_vec(),
        )),
    }
}

View File

@@ -0,0 +1,127 @@
//! Device handlers — register, list, revoke devices.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::domain::devices::DeviceService;
use crate::domain::types::{RegisterDeviceReq, RevokeDeviceReq};
use super::{domain_err, require_auth, ServerState};
/// Handle RegisterDevice: attach a named device to the caller's account.
pub async fn handle_register_device(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::RegisterDeviceRequest::decode(ctx.payload) {
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
        Ok(r) => r,
    };
    let service = DeviceService {
        store: Arc::clone(&state.store),
    };
    let request = RegisterDeviceReq {
        device_id: req.device_id,
        device_name: req.device_name,
    };
    match service.register_device(request, &identity_key) {
        Err(e) => domain_err(e),
        Ok(resp) => HandlerResult::ok(Bytes::from(
            v1::RegisterDeviceResponse {
                success: resp.success,
            }
            .encode_to_vec(),
        )),
    }
}
/// Handle ListDevices: enumerate the caller's registered devices.
pub async fn handle_list_devices(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    // ListDevicesRequest carries no fields, but decoding still validates framing.
    if let Err(e) = v1::ListDevicesRequest::decode(ctx.payload) {
        return HandlerResult::err(
            quicprochat_rpc::error::RpcStatus::BadRequest,
            &format!("decode: {e}"),
        );
    }
    let service = DeviceService {
        store: Arc::clone(&state.store),
    };
    match service.list_devices(&identity_key) {
        Err(e) => domain_err(e),
        Ok(resp) => {
            let devices = resp
                .devices
                .into_iter()
                .map(|d| v1::Device {
                    device_id: d.device_id,
                    device_name: d.device_name,
                    registered_at: d.registered_at,
                })
                .collect();
            HandlerResult::ok(Bytes::from(
                v1::ListDevicesResponse { devices }.encode_to_vec(),
            ))
        }
    }
}
/// Handle RevokeDevice: detach a device from the caller's account.
pub async fn handle_revoke_device(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::RevokeDeviceRequest::decode(ctx.payload) {
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
        Ok(r) => r,
    };
    let service = DeviceService {
        store: Arc::clone(&state.store),
    };
    let request = RevokeDeviceReq {
        device_id: req.device_id,
    };
    match service.revoke_device(request, &identity_key) {
        Err(e) => domain_err(e),
        Ok(resp) => HandlerResult::ok(Bytes::from(
            v1::RevokeDeviceResponse {
                success: resp.success,
            }
            .encode_to_vec(),
        )),
    }
}

View File

@@ -0,0 +1,227 @@
//! Federation v2 RPC handlers — relay, proxy, and health.
//!
//! Implements the inbound side of server-to-server federation: accepts relay
//! and proxy requests from peer servers and delegates to local storage.
//! Outbound relay to remote peers is handled by the capnp-based
//! `FederationClient` on the main connection path.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::error::RpcStatus;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::federation::address::FederatedAddress;
use super::ServerState;
/// Validate that the request carries a valid federation auth origin.
///
/// Returns the peer's origin on success, or a ready-to-send Unauthorized
/// `HandlerResult` when the auth block is missing or has an empty origin.
fn validate_federation_auth(auth: &Option<v1::FederationAuth>) -> Result<String, HandlerResult> {
    match auth {
        None => Err(HandlerResult::err(
            RpcStatus::Unauthorized,
            "missing federation auth",
        )),
        Some(a) if a.origin.is_empty() => Err(HandlerResult::err(
            RpcStatus::Unauthorized,
            "federation auth origin must not be empty",
        )),
        Some(a) => Ok(a.origin.clone()),
    }
}
/// Relay a single message to a local recipient.
///
/// This handler is called by peer servers to deliver messages to users
/// homed on this server. If the recipient is not local, returns NotFound
/// (the originating server should route directly to the correct home server).
pub async fn handle_relay_enqueue(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
let req = match v1::RelayEnqueueRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
};
let origin = match validate_federation_auth(&req.auth) {
Ok(o) => o,
Err(e) => return e,
};
if req.recipient_key.len() != 32 {
return HandlerResult::err(RpcStatus::BadRequest, "recipient_key must be 32 bytes");
}
if req.payload.is_empty() {
return HandlerResult::err(RpcStatus::BadRequest, "payload must not be empty");
}
match state
.store
.enqueue(&req.recipient_key, &req.channel_id, req.payload, None)
{
Ok(seq) => {
if let Some(waiter) = state.waiters.get(&req.recipient_key) {
waiter.notify_waiters();
}
tracing::info!(
origin = %origin,
recipient_prefix = %hex::encode(&req.recipient_key[..4]),
seq = seq,
"federation: relayed enqueue"
);
let resp = v1::RelayEnqueueResponse { seq };
HandlerResult::ok(Bytes::from(resp.encode_to_vec()))
}
Err(e) => HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}")),
}
}
/// Relay a batch of messages to local recipients.
///
/// Called by peer servers to fan one payload out to several locally-homed
/// recipients. All recipient keys are validated up front so a malformed key
/// cannot leave a partially-applied batch behind; a store error mid-batch
/// still returns Internal (seqs assigned before the failure are kept in the
/// store — the peer is expected to retry idempotently).
pub async fn handle_relay_batch_enqueue(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let req = match v1::RelayBatchEnqueueRequest::decode(ctx.payload) {
        Ok(r) => r,
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
    };
    let _origin = match validate_federation_auth(&req.auth) {
        Ok(o) => o,
        Err(e) => return e,
    };
    // Mirror the client-facing BatchEnqueue validation: an empty recipient
    // list is a malformed request, not a successful no-op.
    if req.recipient_keys.is_empty() {
        return HandlerResult::err(RpcStatus::BadRequest, "recipient_keys must not be empty");
    }
    if req.payload.is_empty() {
        return HandlerResult::err(RpcStatus::BadRequest, "payload must not be empty");
    }
    // Validate every key before touching storage so we never partially
    // enqueue a batch that contains a malformed key.
    if req.recipient_keys.iter().any(|rk| rk.len() != 32) {
        return HandlerResult::err(
            RpcStatus::BadRequest,
            "each recipient_key must be 32 bytes",
        );
    }
    let mut seqs = Vec::with_capacity(req.recipient_keys.len());
    for rk in &req.recipient_keys {
        match state
            .store
            .enqueue(rk, &req.channel_id, req.payload.clone(), None)
        {
            Ok(seq) => {
                // Wake any long-poll waiters on this recipient's queue.
                if let Some(waiter) = state.waiters.get(rk.as_slice()) {
                    waiter.notify_waiters();
                }
                seqs.push(seq);
            }
            Err(e) => {
                return HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}"))
            }
        }
    }
    tracing::info!(
        recipient_count = req.recipient_keys.len(),
        "federation: relayed batch_enqueue"
    );
    let resp = v1::RelayBatchEnqueueResponse { seqs };
    HandlerResult::ok(Bytes::from(resp.encode_to_vec()))
}
/// Proxy a key package fetch from local storage.
pub async fn handle_proxy_fetch_key_package(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let req = match v1::ProxyFetchKeyPackageRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
};
let _origin = match validate_federation_auth(&req.auth) {
Ok(o) => o,
Err(e) => return e,
};
let package = match state.store.fetch_key_package(&req.identity_key) {
Ok(pkg) => pkg.unwrap_or_default(),
Err(e) => return HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}")),
};
let resp = v1::ProxyFetchKeyPackageResponse { package };
HandlerResult::ok(Bytes::from(resp.encode_to_vec()))
}
/// Proxy a hybrid key fetch from local storage.
pub async fn handle_proxy_fetch_hybrid_key(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let req = match v1::ProxyFetchHybridKeyRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
};
let _origin = match validate_federation_auth(&req.auth) {
Ok(o) => o,
Err(e) => return e,
};
let hybrid_public_key = match state.store.fetch_hybrid_key(&req.identity_key) {
Ok(pk) => pk.unwrap_or_default(),
Err(e) => return HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}")),
};
let resp = v1::ProxyFetchHybridKeyResponse { hybrid_public_key };
HandlerResult::ok(Bytes::from(resp.encode_to_vec()))
}
/// Proxy a user resolution from local storage.
///
/// Supports federated `user@domain` addresses: only users homed on this
/// server can be resolved; remote or unknown addresses yield an empty
/// identity key so the peer knows to look elsewhere.
pub async fn handle_proxy_resolve_user(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let req = match v1::ProxyResolveUserRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    if let Err(e) = validate_federation_auth(&req.auth) {
        return e;
    }
    let addr = FederatedAddress::parse(&req.username);
    let mut identity_key = Vec::new();
    if addr.is_local(&state.local_domain) {
        match state.store.get_user_identity_key(&addr.username) {
            Ok(Some(key)) => identity_key = key,
            // Unknown local user: fall through with an empty key.
            Ok(None) => {}
            Err(e) => {
                return HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}"))
            }
        }
    }
    let resp = v1::ProxyResolveUserResponse { identity_key };
    HandlerResult::ok(Bytes::from(resp.encode_to_vec()))
}
/// Federation health check — returns ok status and this server's domain.
pub async fn handle_federation_health(
    state: Arc<ServerState>,
    _ctx: RequestContext,
) -> HandlerResult {
    let reply = v1::FederationHealthResponse {
        server_domain: state.local_domain.clone(),
        status: String::from("ok"),
    };
    HandlerResult::ok(Bytes::from(reply.encode_to_vec()))
}

View File

@@ -0,0 +1,162 @@
//! Group management handlers — remove member, update metadata, list members, rotate keys.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::error::RpcStatus;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::domain::groups::GroupService;
use crate::domain::types::{ListGroupMembersReq, UpdateGroupMetadataReq};
use super::{domain_err, require_auth, ServerState};
/// Handle RemoveMember (410): track member removal server-side.
///
/// Note: actual MLS removal (Remove proposal + Commit) is done client-side
/// via the SDK. This handler records the membership change on the server.
pub async fn handle_remove_member(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::RemoveMemberRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    if req.group_id.is_empty() || req.member_identity_key.is_empty() {
        return HandlerResult::err(
            RpcStatus::BadRequest,
            "group_id and member_identity_key required",
        );
    }
    let service = GroupService {
        store: Arc::clone(&state.store),
    };
    if let Err(e) = service.remove_member(&req.group_id, &req.member_identity_key) {
        return domain_err(e);
    }
    let _ = identity_key; // caller is authorized; removal tracked
    // The MLS Commit is produced client-side, so the field is always empty here.
    let reply = v1::RemoveMemberResponse { commit: Vec::new() };
    HandlerResult::ok(Bytes::from(reply.encode_to_vec()))
}
/// Handle UpdateGroupMetadata (411): store group name, description, avatar.
pub async fn handle_update_group_metadata(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::UpdateGroupMetadataRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    let service = GroupService {
        store: Arc::clone(&state.store),
    };
    let request = UpdateGroupMetadataReq {
        group_id: req.group_id,
        name: req.name,
        description: req.description,
        avatar_hash: req.avatar_hash,
    };
    if let Err(e) = service.update_metadata(request, &identity_key) {
        return domain_err(e);
    }
    HandlerResult::ok(Bytes::from(
        v1::UpdateGroupMetadataResponse { success: true }.encode_to_vec(),
    ))
}
/// Handle ListGroupMembers (412): return member list with resolved usernames.
pub async fn handle_list_group_members(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    // Authentication is required, but any authenticated caller may list.
    if let Err(e) = require_auth(&state, &ctx) {
        return e;
    }
    let req = match v1::ListGroupMembersRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    let service = GroupService {
        store: Arc::clone(&state.store),
    };
    let request = ListGroupMembersReq {
        group_id: req.group_id,
    };
    match service.list_members(request) {
        Err(e) => domain_err(e),
        Ok(resp) => {
            let members = resp
                .members
                .into_iter()
                .map(|m| v1::GroupMemberInfo {
                    identity_key: m.identity_key,
                    username: m.username,
                    joined_at: m.joined_at,
                })
                .collect();
            HandlerResult::ok(Bytes::from(
                v1::ListGroupMembersResponse { members }.encode_to_vec(),
            ))
        }
    }
}
/// Handle RotateKeys (413): acknowledge key rotation.
///
/// Actual MLS key rotation (Update proposal + Commit) is done client-side.
/// This handler exists for server-side tracking and future rate limiting.
pub async fn handle_rotate_keys(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let _identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::RotateKeysRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    if req.group_id.is_empty() {
        return HandlerResult::err(RpcStatus::BadRequest, "group_id required");
    }
    // The server only acknowledges: the MLS Update/Commit happens client-side,
    // so the commit field is always empty here.
    let reply = v1::RotateKeysResponse { commit: Vec::new() };
    HandlerResult::ok(Bytes::from(reply.encode_to_vec()))
}

View File

@@ -0,0 +1,217 @@
//! Key management handlers — KeyPackage and hybrid key operations.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::domain::keys::KeyService;
use crate::domain::types::{
CallerAuth, FetchHybridKeyReq, FetchHybridKeysReq, FetchKeyPackageReq, UploadHybridKeyReq,
UploadKeyPackageReq,
};
use super::{domain_err, require_auth, ServerState};
/// Build a `CallerAuth` for an already-authenticated identity key.
/// Token and device id are left empty: the session layer validated them.
fn caller_auth(identity_key: Vec<u8>) -> CallerAuth {
    CallerAuth {
        token: Vec::new(),
        device_id: None,
        identity_key,
    }
}
/// Store an MLS key package for an identity; returns its fingerprint.
pub async fn handle_upload_key_package(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::UploadKeyPackageRequest::decode(ctx.payload) {
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
        Ok(r) => r,
    };
    let service = KeyService {
        store: Arc::clone(&state.store),
    };
    let request = UploadKeyPackageReq {
        identity_key: req.identity_key,
        package: req.package,
    };
    match service.upload_key_package(request, &caller_auth(identity_key)) {
        Err(e) => domain_err(e),
        Ok(resp) => HandlerResult::ok(Bytes::from(
            v1::UploadKeyPackageResponse {
                fingerprint: resp.fingerprint,
            }
            .encode_to_vec(),
        )),
    }
}
/// Fetch a stored MLS key package for the requested identity.
pub async fn handle_fetch_key_package(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::FetchKeyPackageRequest::decode(ctx.payload) {
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
        Ok(r) => r,
    };
    let service = KeyService {
        store: Arc::clone(&state.store),
    };
    let request = FetchKeyPackageReq {
        identity_key: req.identity_key,
    };
    match service.fetch_key_package(request, &caller_auth(identity_key)) {
        Err(e) => domain_err(e),
        Ok(resp) => HandlerResult::ok(Bytes::from(
            v1::FetchKeyPackageResponse {
                package: resp.package,
            }
            .encode_to_vec(),
        )),
    }
}
/// Store the caller-supplied hybrid public key for an identity.
pub async fn handle_upload_hybrid_key(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::UploadHybridKeyRequest::decode(ctx.payload) {
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
        Ok(r) => r,
    };
    let service = KeyService {
        store: Arc::clone(&state.store),
    };
    let request = UploadHybridKeyReq {
        identity_key: req.identity_key,
        hybrid_public_key: req.hybrid_public_key,
    };
    if let Err(e) = service.upload_hybrid_key(request, &caller_auth(identity_key)) {
        return domain_err(e);
    }
    HandlerResult::ok(Bytes::from(v1::UploadHybridKeyResponse {}.encode_to_vec()))
}
/// Fetch a stored hybrid public key for the requested identity.
pub async fn handle_fetch_hybrid_key(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::FetchHybridKeyRequest::decode(ctx.payload) {
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
        Ok(r) => r,
    };
    let service = KeyService {
        store: Arc::clone(&state.store),
    };
    let request = FetchHybridKeyReq {
        identity_key: req.identity_key,
    };
    match service.fetch_hybrid_key(request, &caller_auth(identity_key)) {
        Err(e) => domain_err(e),
        Ok(resp) => HandlerResult::ok(Bytes::from(
            v1::FetchHybridKeyResponse {
                hybrid_public_key: resp.hybrid_public_key,
            }
            .encode_to_vec(),
        )),
    }
}
/// Batch-fetch hybrid public keys for several identities at once.
pub async fn handle_fetch_hybrid_keys(
    state: Arc<ServerState>,
    ctx: RequestContext,
) -> HandlerResult {
    let identity_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::FetchHybridKeysRequest::decode(ctx.payload) {
        Err(e) => {
            return HandlerResult::err(
                quicprochat_rpc::error::RpcStatus::BadRequest,
                &format!("decode: {e}"),
            )
        }
        Ok(r) => r,
    };
    let service = KeyService {
        store: Arc::clone(&state.store),
    };
    let request = FetchHybridKeysReq {
        identity_keys: req.identity_keys,
    };
    match service.fetch_hybrid_keys(request, &caller_auth(identity_key)) {
        Err(e) => domain_err(e),
        Ok(resp) => {
            let reply = v1::FetchHybridKeysResponse { keys: resp.keys };
            HandlerResult::ok(Bytes::from(reply.encode_to_vec()))
        }
    }
}

View File

@@ -0,0 +1,439 @@
//! v2 RPC handler dispatch — protobuf in, domain logic, protobuf out.
use std::path::PathBuf;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use dashmap::DashMap;
use opaque_ke::ServerSetup;
use quicprochat_core::opaque_auth::OpaqueSuite;
use quicprochat_proto::method_ids;
use quicprochat_rpc::error::RpcStatus;
use quicprochat_rpc::method::{HandlerResult, MethodRegistry, RequestContext};
use tokio::sync::Notify;
use crate::audit::AuditLogger;
use crate::auth::{AuthConfig, PendingLogin, RateEntry, SessionInfo};
use crate::hooks::ServerHooks;
use crate::storage::Store;
pub mod account;
pub mod auth;
pub mod blob;
pub mod channel;
pub mod delivery;
pub mod device;
pub mod federation;
pub mod group;
pub mod keys;
pub mod moderation;
pub mod p2p;
pub mod recovery;
pub mod user;
/// Shared server state accessible by all v2 RPC handlers.
pub struct ServerState {
    /// Pluggable persistence backend (messages, keys, devices, groups, blobs).
    pub store: Arc<dyn Store>,
    /// Per-recipient wakeups used by the FetchWait long-poll path.
    pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
    /// Authentication configuration, including the dev-mode
    /// `allow_insecure_identity_from_request` flag checked by `require_auth`.
    pub auth_cfg: Arc<AuthConfig>,
    /// OPAQUE server setup shared by the register/login handlers.
    pub opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
    /// In-flight OPAQUE login states, keyed by a string id
    /// (presumably the username — confirm against the auth handlers).
    pub pending_logins: Arc<DashMap<String, PendingLogin>>,
    /// Active sessions keyed by session token (see `require_auth`).
    pub sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
    /// Per-identity rate-limit buckets (see `crate::auth::check_rate_limit`).
    pub rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
    /// Sealed-sender mode flag (presumably hides sender identity from
    /// stored envelopes — confirm at the call sites).
    pub sealed_sender: bool,
    /// Plugin hooks fired on server events (e.g. message enqueue).
    pub hooks: Arc<dyn ServerHooks>,
    /// Server identity keypair; used to sign delivery proofs.
    pub signing_key: Arc<quicprochat_core::IdentityKeypair>,
    /// Key-transparency Merkle log (mutex-guarded, shared across handlers).
    pub kt_log: Arc<std::sync::Mutex<quicprochat_kt::MerkleLog>>,
    /// Key revocation log (mutex-guarded, shared across handlers).
    pub revocation_log: Arc<std::sync::Mutex<quicprochat_kt::RevocationLog>>,
    /// Root directory for on-disk server data.
    pub data_dir: PathBuf,
    /// When true, identifying data is redacted from logs.
    pub redact_logs: bool,
    /// Structured audit logger for security-relevant events.
    pub audit_logger: Arc<dyn AuditLogger>,
    /// When true, the server is draining and will reject new work.
    /// Health endpoint returns "draining" status so load balancers stop routing.
    pub draining: Arc<AtomicBool>,
    /// Idempotency dedup: message_id -> (seq, timestamp). TTL-cleaned by cleanup task.
    pub seen_message_ids: Arc<DashMap<Vec<u8>, (u64, u64)>>,
    /// Banned users: identity_key -> BanRecord.
    pub banned_users: Arc<DashMap<Vec<u8>, BanRecord>>,
    /// Moderation reports (append-only).
    pub moderation_reports: Arc<std::sync::Mutex<Vec<ModerationReport>>>,
    /// Unique node identifier for multi-node health reporting.
    pub node_id: String,
    /// Process start time for uptime calculation.
    pub start_time: std::time::Instant,
    /// Storage backend name (e.g. "sql", "file").
    pub storage_backend: String,
    /// Federation client for outbound server-to-server relay. None when federation is disabled.
    pub federation_client: Option<Arc<crate::federation::FederationClient>>,
    /// This server's domain for federation addressing. Empty when federation is disabled.
    pub local_domain: String,
}
/// A ban record for a user (stored in `ServerState::banned_users`).
#[derive(Debug, Clone)]
pub struct BanRecord {
    /// Human-readable reason recorded when the ban was issued.
    pub reason: String,
    /// Timestamp when the ban was issued (same clock as `expires_at`).
    pub banned_at: u64,
    /// Expiry timestamp; 0 = permanent.
    pub expires_at: u64,
}
/// A stored moderation report (appended to `ServerState::moderation_reports`).
#[derive(Debug, Clone)]
pub struct ModerationReport {
    /// Report id assigned on submission.
    pub id: u64,
    /// Encrypted report payload (contents opaque to the server).
    pub encrypted_report: Vec<u8>,
    /// Conversation the report refers to.
    pub conversation_id: Vec<u8>,
    /// Identity key of the reporting user.
    pub reporter_identity: Vec<u8>,
    /// Submission timestamp.
    pub timestamp: u64,
}
/// Validate the session token from the request context and return the
/// authenticated caller's identity key. Returns an Unauthorized HandlerResult
/// on failure, or Forbidden when the account is actively banned.
pub fn require_auth(state: &ServerState, ctx: &RequestContext) -> Result<Vec<u8>, HandlerResult> {
    // Prefer the session token; fall back to the raw identity-key bytes so
    // the dev-mode path below can still find a candidate credential.
    let token = ctx
        .session_token
        .as_deref()
        .or(ctx.identity_key.as_deref())
        .unwrap_or(&[]);
    if token.is_empty() {
        return Err(HandlerResult::err(
            RpcStatus::Unauthorized,
            "missing session token",
        ));
    }
    // Check session store: the token must map to a non-expired session.
    if let Some(session) = state.sessions.get(token) {
        let now = crate::auth::current_timestamp();
        if session.expires_at > now && !session.identity_key.is_empty() {
            // Check ban status (expires_at == 0 means permanent).
            if let Some(ban) = state.banned_users.get(&session.identity_key) {
                if ban.expires_at == 0 || ban.expires_at > now {
                    return Err(HandlerResult::err(
                        RpcStatus::Forbidden,
                        "account banned",
                    ));
                }
                // Ban expired — remove it. Release the DashMap guard first
                // so the remove() below cannot deadlock against it.
                drop(ban);
                state.banned_users.remove(&session.identity_key);
            }
            return Ok(session.identity_key.clone());
        }
    }
    // Dev-mode fallback: trust the client-supplied identity key verbatim.
    // Only enabled via the explicitly insecure config flag.
    if state.auth_cfg.allow_insecure_identity_from_request {
        if let Some(ik) = ctx.identity_key.as_deref() {
            if !ik.is_empty() {
                return Ok(ik.to_vec());
            }
        }
    }
    Err(HandlerResult::err(
        RpcStatus::Unauthorized,
        "invalid or expired session token",
    ))
}
/// Map a domain error to an RPC HandlerResult error.
///
/// Validation failures become BadRequest, lookups NotFound, quota/permission
/// failures Forbidden, and infrastructure failures Internal.
pub fn domain_err(e: crate::domain::types::DomainError) -> HandlerResult {
    use crate::domain::types::DomainError;
    // Pick the status class first, then render the message exactly once.
    let status = match &e {
        DomainError::InvalidIdentityKey(_)
        | DomainError::EmptyPackage
        | DomainError::EmptyHybridKey
        | DomainError::EmptyUsername
        | DomainError::BlobHashLength(_)
        | DomainError::BadParams(_)
        | DomainError::PackageTooLarge(_)
        | DomainError::BlobTooLarge(_)
        | DomainError::BlobHashMismatch => RpcStatus::BadRequest,
        DomainError::BlobNotFound | DomainError::DeviceNotFound | DomainError::GroupNotFound => {
            RpcStatus::NotFound
        }
        DomainError::DeviceLimit(_) => RpcStatus::Forbidden,
        DomainError::Io(_) | DomainError::Storage(_) => RpcStatus::Internal,
    };
    HandlerResult::err(status, &e.to_string())
}
/// Build the v2 method registry with all handlers registered.
///
/// `default_rpc_timeout` sets the server-wide per-RPC timeout. Individual methods
/// (e.g. blob upload, health) may override this with shorter or longer values.
///
/// Method IDs come from `method_ids`; the numeric ranges in the section
/// comments below group related RPCs (100s auth, 200s delivery, 300s keys,
/// 400s channel/group/moderation, 500s user/KT, 600s blob, 700s device/recovery,
/// 800s P2P, 900s federation/account).
pub fn build_registry(default_rpc_timeout: std::time::Duration) -> MethodRegistry<ServerState> {
    let mut reg = MethodRegistry::new();
    reg.set_default_timeout(default_rpc_timeout);
    // Auth (100-103)
    reg.register(
        method_ids::OPAQUE_REGISTER_START,
        "OpaqueRegisterStart",
        auth::handle_opaque_register_start,
    );
    reg.register(
        method_ids::OPAQUE_REGISTER_FINISH,
        "OpaqueRegisterFinish",
        auth::handle_opaque_register_finish,
    );
    reg.register(
        method_ids::OPAQUE_LOGIN_START,
        "OpaqueLoginStart",
        auth::handle_opaque_login_start,
    );
    reg.register(
        method_ids::OPAQUE_LOGIN_FINISH,
        "OpaqueLoginFinish",
        auth::handle_opaque_login_finish,
    );
    // Delivery (200-205)
    reg.register(method_ids::ENQUEUE, "Enqueue", delivery::handle_enqueue);
    reg.register(method_ids::FETCH, "Fetch", delivery::handle_fetch);
    reg.register(
        method_ids::FETCH_WAIT,
        "FetchWait",
        delivery::handle_fetch_wait,
    );
    reg.register(method_ids::PEEK, "Peek", delivery::handle_peek);
    reg.register(method_ids::ACK, "Ack", delivery::handle_ack);
    reg.register(
        method_ids::BATCH_ENQUEUE,
        "BatchEnqueue",
        delivery::handle_batch_enqueue,
    );
    // Keys (300-304)
    reg.register(
        method_ids::UPLOAD_KEY_PACKAGE,
        "UploadKeyPackage",
        keys::handle_upload_key_package,
    );
    reg.register(
        method_ids::FETCH_KEY_PACKAGE,
        "FetchKeyPackage",
        keys::handle_fetch_key_package,
    );
    reg.register(
        method_ids::UPLOAD_HYBRID_KEY,
        "UploadHybridKey",
        keys::handle_upload_hybrid_key,
    );
    reg.register(
        method_ids::FETCH_HYBRID_KEY,
        "FetchHybridKey",
        keys::handle_fetch_hybrid_key,
    );
    reg.register(
        method_ids::FETCH_HYBRID_KEYS,
        "FetchHybridKeys",
        keys::handle_fetch_hybrid_keys,
    );
    // Channel (400)
    reg.register(
        method_ids::CREATE_CHANNEL,
        "CreateChannel",
        channel::handle_create_channel,
    );
    // Group management (410-413)
    reg.register(
        method_ids::REMOVE_MEMBER,
        "RemoveMember",
        group::handle_remove_member,
    );
    reg.register(
        method_ids::UPDATE_GROUP_METADATA,
        "UpdateGroupMetadata",
        group::handle_update_group_metadata,
    );
    reg.register(
        method_ids::LIST_GROUP_MEMBERS,
        "ListGroupMembers",
        group::handle_list_group_members,
    );
    reg.register(
        method_ids::ROTATE_KEYS,
        "RotateKeys",
        group::handle_rotate_keys,
    );
    // User (500-501)
    reg.register(
        method_ids::RESOLVE_USER,
        "ResolveUser",
        user::handle_resolve_user,
    );
    reg.register(
        method_ids::RESOLVE_IDENTITY,
        "ResolveIdentity",
        user::handle_resolve_identity,
    );
    // Key Transparency (510-520)
    reg.register(
        method_ids::REVOKE_KEY,
        "RevokeKey",
        user::handle_revoke_key,
    );
    reg.register(
        method_ids::CHECK_REVOCATION,
        "CheckRevocation",
        user::handle_check_revocation,
    );
    reg.register(
        method_ids::AUDIT_KEY_TRANSPARENCY,
        "AuditKeyTransparency",
        user::handle_audit_key_transparency,
    );
    // Blob (600-601) — longer timeout for file transfers.
    reg.register_with_timeout(
        method_ids::UPLOAD_BLOB,
        "UploadBlob",
        std::time::Duration::from_secs(120),
        blob::handle_upload_blob,
    );
    reg.register_with_timeout(
        method_ids::DOWNLOAD_BLOB,
        "DownloadBlob",
        std::time::Duration::from_secs(120),
        blob::handle_download_blob,
    );
    // Device (700-702)
    reg.register(
        method_ids::REGISTER_DEVICE,
        "RegisterDevice",
        device::handle_register_device,
    );
    reg.register(
        method_ids::LIST_DEVICES,
        "ListDevices",
        device::handle_list_devices,
    );
    reg.register(
        method_ids::REVOKE_DEVICE,
        "RevokeDevice",
        device::handle_revoke_device,
    );
    // P2P (800-802)
    reg.register(
        method_ids::PUBLISH_ENDPOINT,
        "PublishEndpoint",
        p2p::handle_publish_endpoint,
    );
    reg.register(
        method_ids::RESOLVE_ENDPOINT,
        "ResolveEndpoint",
        p2p::handle_resolve_endpoint,
    );
    // Health overrides the default with a short 5s timeout so probes fail fast.
    reg.register_with_timeout(
        method_ids::HEALTH,
        "Health",
        std::time::Duration::from_secs(5),
        p2p::handle_health,
    );
    // Federation (900-905)
    reg.register(
        method_ids::RELAY_ENQUEUE,
        "RelayEnqueue",
        federation::handle_relay_enqueue,
    );
    reg.register(
        method_ids::RELAY_BATCH_ENQUEUE,
        "RelayBatchEnqueue",
        federation::handle_relay_batch_enqueue,
    );
    reg.register(
        method_ids::PROXY_FETCH_KEY_PACKAGE,
        "ProxyFetchKeyPackage",
        federation::handle_proxy_fetch_key_package,
    );
    reg.register(
        method_ids::PROXY_FETCH_HYBRID_KEY,
        "ProxyFetchHybridKey",
        federation::handle_proxy_fetch_hybrid_key,
    );
    reg.register(
        method_ids::PROXY_RESOLVE_USER,
        "ProxyResolveUser",
        federation::handle_proxy_resolve_user,
    );
    reg.register(
        method_ids::FEDERATION_HEALTH,
        "FederationHealth",
        federation::handle_federation_health,
    );
    // Moderation (420-424)
    reg.register(
        method_ids::REPORT_MESSAGE,
        "ReportMessage",
        moderation::handle_report_message,
    );
    reg.register(
        method_ids::BAN_USER,
        "BanUser",
        moderation::handle_ban_user,
    );
    reg.register(
        method_ids::UNBAN_USER,
        "UnbanUser",
        moderation::handle_unban_user,
    );
    reg.register(
        method_ids::LIST_REPORTS,
        "ListReports",
        moderation::handle_list_reports,
    );
    reg.register(
        method_ids::LIST_BANNED,
        "ListBanned",
        moderation::handle_list_banned,
    );
    // Recovery (750-752)
    reg.register(
        method_ids::STORE_RECOVERY_BUNDLE,
        "StoreRecoveryBundle",
        recovery::handle_store_recovery_bundle,
    );
    reg.register(
        method_ids::FETCH_RECOVERY_BUNDLE,
        "FetchRecoveryBundle",
        recovery::handle_fetch_recovery_bundle,
    );
    reg.register(
        method_ids::DELETE_RECOVERY_BUNDLE,
        "DeleteRecoveryBundle",
        recovery::handle_delete_recovery_bundle,
    );
    // Account (950)
    reg.register(
        method_ids::DELETE_ACCOUNT,
        "DeleteAccount",
        account::handle_delete_account,
    );
    reg
}

View File

@@ -0,0 +1,199 @@
//! Moderation handlers — report, ban, unban, list reports, list banned.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::error::RpcStatus;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use tracing::{info, warn};
use super::{require_auth, BanRecord, ModerationReport, ServerState};
/// Accept an end-to-end-encrypted abuse report from any authenticated user.
///
/// The report body is opaque to the server: it is stored verbatim alongside
/// the reporter's identity and a server-side timestamp.
pub async fn handle_report_message(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let reporter = match require_auth(&state, &ctx) {
        Err(e) => return e,
        Ok(ik) => ik,
    };
    let req = match v1::ReportMessageRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    if req.encrypted_report.is_empty() {
        return HandlerResult::err(RpcStatus::BadRequest, "encrypted_report required");
    }
    let submitted_at = crate::auth::current_timestamp();
    // Append under the lock; the report id is the current journal length.
    // NOTE(review): ids collide if reports are ever removed — confirm the
    // journal is append-only.
    let stored = {
        let mut journal = match state.moderation_reports.lock() {
            Err(e) => {
                warn!("moderation_reports lock poisoned: {e}");
                return HandlerResult::err(RpcStatus::Internal, "internal error");
            }
            Ok(g) => g,
        };
        let entry = ModerationReport {
            id: journal.len() as u64,
            encrypted_report: req.encrypted_report,
            conversation_id: req.conversation_id,
            reporter_identity: reporter.clone(),
            timestamp: submitted_at,
        };
        journal.push(entry.clone());
        entry
    };
    info!(
        report_id = stored.id,
        reporter = hex::encode(&reporter[..4.min(reporter.len())]),
        "moderation report submitted"
    );
    HandlerResult::ok(Bytes::from(
        v1::ReportMessageResponse { accepted: true }.encode_to_vec(),
    ))
}
/// Ban a user by identity key.
///
/// Requires admin role (MVP: any authenticated caller is treated as admin).
/// `duration_secs == 0` encodes a permanent ban; otherwise the ban expires
/// at `now + duration_secs` (Unix seconds).
pub async fn handle_ban_user(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let admin_key = match require_auth(&state, &ctx) {
        Ok(ik) => ik,
        Err(e) => return e,
    };
    let req = match v1::BanUserRequest::decode(ctx.payload) {
        Ok(r) => r,
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
    };
    // Exactly 32 bytes required; this also rejects the empty case, so the
    // previous separate `is_empty()` check was redundant.
    if req.identity_key.len() != 32 {
        return HandlerResult::err(RpcStatus::BadRequest, "identity_key must be 32 bytes");
    }
    let now = crate::auth::current_timestamp();
    // 0 means "never expires". saturating_add avoids overflow on an absurdly
    // large client-supplied duration (would panic in debug builds).
    let expires_at = if req.duration_secs == 0 {
        0 // permanent
    } else {
        now.saturating_add(req.duration_secs)
    };
    let record = BanRecord {
        reason: req.reason.clone(),
        banned_at: now,
        expires_at,
    };
    state.banned_users.insert(req.identity_key.clone(), record);
    // Slicing [..4] is safe: identity_key length was validated as 32 above.
    info!(
        target_key = hex::encode(&req.identity_key[..4]),
        admin_key = hex::encode(&admin_key[..4.min(admin_key.len())]),
        reason = %req.reason,
        duration_secs = req.duration_secs,
        "user banned"
    );
    let proto = v1::BanUserResponse { success: true };
    HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
/// Lift a ban on a user. Requires admin role.
pub async fn handle_unban_user(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    let admin_key = match require_auth(&state, &ctx) {
        Err(e) => return e,
        Ok(ik) => ik,
    };
    let req = match v1::UnbanUserRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    if req.identity_key.is_empty() {
        return HandlerResult::err(RpcStatus::BadRequest, "identity_key required");
    }
    // `remove` yields the old entry if one existed; success mirrors whether a
    // ban was actually present.
    let removed = state.banned_users.remove(&req.identity_key).is_some();
    info!(
        target_key = hex::encode(&req.identity_key[..4.min(req.identity_key.len())]),
        admin_key = hex::encode(&admin_key[..4.min(admin_key.len())]),
        removed,
        "user unbanned"
    );
    HandlerResult::ok(Bytes::from(
        v1::UnbanUserResponse { success: removed }.encode_to_vec(),
    ))
}
/// Page through stored moderation reports. Requires admin role.
///
/// `limit == 0` selects the default page size of 50.
pub async fn handle_list_reports(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    // Authentication is required; the admin identity itself is unused here.
    if let Err(e) = require_auth(&state, &ctx) {
        return e;
    }
    let req = match v1::ListReportsRequest::decode(ctx.payload) {
        Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")),
        Ok(r) => r,
    };
    let guard = match state.moderation_reports.lock() {
        Err(e) => {
            warn!("moderation_reports lock poisoned: {e}");
            return HandlerResult::err(RpcStatus::Internal, "internal error");
        }
        Ok(g) => g,
    };
    let page_size = if req.limit == 0 { 50 } else { req.limit as usize };
    let page: Vec<v1::ReportEntry> = guard
        .iter()
        .skip(req.offset as usize)
        .take(page_size)
        .map(|r| v1::ReportEntry {
            id: r.id,
            encrypted_report: r.encrypted_report.clone(),
            conversation_id: r.conversation_id.clone(),
            reporter_identity: r.reporter_identity.clone(),
            timestamp: r.timestamp,
        })
        .collect();
    HandlerResult::ok(Bytes::from(
        v1::ListReportsResponse { reports: page }.encode_to_vec(),
    ))
}
/// List currently-active bans (permanent, or not yet expired). Requires admin role.
pub async fn handle_list_banned(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
    // Authentication is required; the admin identity itself is unused here.
    if let Err(e) = require_auth(&state, &ctx) {
        return e;
    }
    // The request carries no fields we use, but it must still parse.
    if let Err(e) = v1::ListBannedRequest::decode(ctx.payload) {
        return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}"));
    }
    let now = crate::auth::current_timestamp();
    // NOTE(review): expired entries are filtered out of the response but never
    // purged from the map — confirm cleanup happens elsewhere.
    let users: Vec<v1::BannedUserEntry> = state
        .banned_users
        .iter()
        .filter(|entry| entry.expires_at == 0 || entry.expires_at > now)
        .map(|entry| v1::BannedUserEntry {
            identity_key: entry.key().clone(),
            reason: entry.reason.clone(),
            banned_at: entry.banned_at,
            expires_at: entry.expires_at,
        })
        .collect();
    HandlerResult::ok(Bytes::from(
        v1::ListBannedResponse { users }.encode_to_vec(),
    ))
}

View File

@@ -0,0 +1,118 @@
//! P2P handlers — publish/resolve endpoints and health check.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::domain::p2p::P2pService;
use crate::domain::types::{CallerAuth, PublishEndpointReq, ResolveEndpointReq};
use super::{domain_err, require_auth, ServerState};
/// Wrap a bare identity key in the domain-layer `CallerAuth` shape.
///
/// Token and device id are left empty: these handlers run `require_auth`
/// first, so transport-level authentication has already happened.
fn caller_auth(identity_key: Vec<u8>) -> CallerAuth {
    let token = Vec::new();
    let device_id = None;
    CallerAuth {
        identity_key,
        token,
        device_id,
    }
}
pub async fn handle_publish_endpoint(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let identity_key = match require_auth(&state, &ctx) {
Ok(ik) => ik,
Err(e) => return e,
};
let req = match v1::PublishEndpointRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = P2pService {
store: Arc::clone(&state.store),
};
let auth = caller_auth(identity_key);
let domain_req = PublishEndpointReq {
identity_key: req.identity_key,
node_addr: req.node_addr,
};
match svc.publish_endpoint(domain_req, &auth) {
Ok(()) => {
let proto = v1::PublishEndpointResponse {};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}
pub async fn handle_resolve_endpoint(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let identity_key = match require_auth(&state, &ctx) {
Ok(ik) => ik,
Err(e) => return e,
};
let req = match v1::ResolveEndpointRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = P2pService {
store: Arc::clone(&state.store),
};
let auth = caller_auth(identity_key);
let domain_req = ResolveEndpointReq {
identity_key: req.identity_key,
};
match svc.resolve_endpoint(domain_req, &auth) {
Ok(resp) => {
let proto = v1::ResolveEndpointResponse {
node_addr: resp.node_addr,
};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}
/// Health probe (no authentication): reports drain state, node identity,
/// build version, uptime, and the active storage backend.
pub async fn handle_health(
    state: Arc<ServerState>,
    _ctx: RequestContext,
) -> HandlerResult {
    // "draining" tells load balancers to stop routing new work here.
    let draining = state.draining.load(std::sync::atomic::Ordering::Relaxed);
    let status = if draining { "draining" } else { "ok" };
    let resp = v1::HealthResponse {
        status: status.into(),
        node_id: state.node_id.clone(),
        version: env!("CARGO_PKG_VERSION").to_string(),
        uptime_secs: state.start_time.elapsed().as_secs(),
        storage_backend: state.storage_backend.clone(),
    };
    HandlerResult::ok(Bytes::from(resp.encode_to_vec()))
}

View File

@@ -0,0 +1,99 @@
//! Recovery handlers — store/fetch/delete encrypted recovery bundles.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::domain::recovery::RecoveryService;
use super::{domain_err, ServerState};
/// Store an encrypted recovery bundle (no auth required — recovery is pre-login).
pub async fn handle_store_recovery_bundle(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let req = match v1::StoreRecoveryBundleRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = RecoveryService {
store: Arc::clone(&state.store),
};
match svc.store_bundle(&req.token_hash, req.bundle, req.ttl_secs) {
Ok(()) => {
let proto = v1::StoreRecoveryBundleResponse { success: true };
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}
/// Fetch an encrypted recovery bundle (no auth required — recovery is pre-login).
pub async fn handle_fetch_recovery_bundle(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let req = match v1::FetchRecoveryBundleRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = RecoveryService {
store: Arc::clone(&state.store),
};
match svc.fetch_bundle(&req.token_hash) {
Ok(bundle_opt) => {
let proto = v1::FetchRecoveryBundleResponse {
bundle: bundle_opt.unwrap_or_default(),
};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}
/// Delete an encrypted recovery bundle (no auth required — caller proves
/// knowledge of the token_hash).
pub async fn handle_delete_recovery_bundle(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let req = match v1::DeleteRecoveryBundleRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = RecoveryService {
store: Arc::clone(&state.store),
};
match svc.delete_bundle(&req.token_hash) {
Ok(deleted) => {
let proto = v1::DeleteRecoveryBundleResponse { success: deleted };
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}

View File

@@ -0,0 +1,213 @@
//! User resolution handlers — username <-> identity key lookups,
//! key revocation, and KT audit.
use std::sync::Arc;
use bytes::Bytes;
use prost::Message;
use quicprochat_proto::qpc::v1;
use quicprochat_rpc::method::{HandlerResult, RequestContext};
use crate::domain::types::{
AuditKeyTransparencyReq, CheckRevocationReq, ResolveIdentityReq, ResolveUserReq, RevokeKeyReq,
};
use crate::domain::users::UserService;
use super::{domain_err, require_auth, ServerState};
/// Assemble a `UserService` sharing the store, the key-transparency log, and
/// the revocation log held in server state.
fn user_svc(state: &Arc<ServerState>) -> UserService {
    let store = Arc::clone(&state.store);
    let kt_log = Arc::clone(&state.kt_log);
    let revocation_log = Arc::clone(&state.revocation_log);
    UserService {
        store,
        kt_log,
        revocation_log,
    }
}
pub async fn handle_resolve_user(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
let _identity_key = match require_auth(&state, &ctx) {
Ok(ik) => ik,
Err(e) => return e,
};
let req = match v1::ResolveUserRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = user_svc(&state);
let domain_req = ResolveUserReq {
username: req.username,
};
match svc.resolve_user(domain_req) {
Ok(resp) => {
let proto = v1::ResolveUserResponse {
identity_key: resp.identity_key,
inclusion_proof: resp.inclusion_proof,
};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}
pub async fn handle_resolve_identity(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let _identity_key = match require_auth(&state, &ctx) {
Ok(ik) => ik,
Err(e) => return e,
};
let req = match v1::ResolveIdentityRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = user_svc(&state);
let domain_req = ResolveIdentityReq {
identity_key: req.identity_key,
};
match svc.resolve_identity(domain_req) {
Ok(resp) => {
let proto = v1::ResolveIdentityResponse {
username: resp.username,
};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}
pub async fn handle_revoke_key(state: Arc<ServerState>, ctx: RequestContext) -> HandlerResult {
let _identity_key = match require_auth(&state, &ctx) {
Ok(ik) => ik,
Err(e) => return e,
};
let req = match v1::RevokeKeyRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = user_svc(&state);
let domain_req = RevokeKeyReq {
identity_key: req.identity_key,
reason: req.reason,
};
match svc.revoke_key(domain_req) {
Ok(resp) => {
let proto = v1::RevokeKeyResponse {
success: resp.success,
leaf_index: resp.leaf_index,
};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}
pub async fn handle_check_revocation(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let _identity_key = match require_auth(&state, &ctx) {
Ok(ik) => ik,
Err(e) => return e,
};
let req = match v1::CheckRevocationRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = user_svc(&state);
let domain_req = CheckRevocationReq {
identity_key: req.identity_key,
};
match svc.check_revocation(domain_req) {
Ok(resp) => {
let proto = v1::CheckRevocationResponse {
revoked: resp.revoked,
reason: resp.reason,
timestamp_ms: resp.timestamp_ms,
};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}
pub async fn handle_audit_key_transparency(
state: Arc<ServerState>,
ctx: RequestContext,
) -> HandlerResult {
let _identity_key = match require_auth(&state, &ctx) {
Ok(ik) => ik,
Err(e) => return e,
};
let req = match v1::AuditKeyTransparencyRequest::decode(ctx.payload) {
Ok(r) => r,
Err(e) => {
return HandlerResult::err(
quicprochat_rpc::error::RpcStatus::BadRequest,
&format!("decode: {e}"),
)
}
};
let svc = user_svc(&state);
let domain_req = AuditKeyTransparencyReq {
start: req.start,
end: req.end,
};
match svc.audit_key_transparency(domain_req) {
Ok(resp) => {
let proto = v1::AuditKeyTransparencyResponse {
entries: resp
.entries
.into_iter()
.map(|e| v1::LogEntry {
index: e.index,
leaf_hash: e.leaf_hash,
})
.collect(),
tree_size: resp.tree_size,
root: resp.root,
};
HandlerResult::ok(Bytes::from(proto.encode_to_vec()))
}
Err(e) => domain_err(e),
}
}

View File

@@ -0,0 +1,420 @@
//! WebTransport server endpoint for browser clients.
//!
//! Accepts HTTP/3 WebTransport sessions and dispatches RPC requests through the
//! same v2 handler registry as the native QUIC endpoint. Browsers connect via:
//!
//! ```js
//! const wt = new WebTransport("https://server:7443");
//! ```
//!
//! Each WebTransport bidirectional stream carries a single RPC request/response
//! using the same wire format as the native QUIC transport:
//!
//! ```text
//! [method_id: u16][request_id: u32][payload_len: u32][protobuf bytes]
//! ```
use std::sync::Arc;
use bytes::BytesMut;
use h3_quinn::quinn;
use h3_webtransport::server::AcceptedBi;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::{debug, info, warn};
use crate::v2_handlers::ServerState;
use quicprochat_rpc::error::RpcStatus;
use quicprochat_rpc::framing::{RequestFrame, ResponseFrame};
use quicprochat_rpc::method::{HandlerResult, MethodRegistry, RequestContext};
/// Concrete H3 connection type.
type H3Conn = h3::server::Connection<h3_quinn::Connection, bytes::Bytes>;
/// Concrete request resolver type.
type H3Resolver = h3::server::RequestResolver<h3_quinn::Connection, bytes::Bytes>;
/// Type alias for the concrete WebTransport session type used by this server.
type WtSession =
h3_webtransport::server::WebTransportSession<h3_quinn::Connection, bytes::Bytes>;
/// Type alias for the concrete WebTransport bidi stream.
type WtBidiStream = h3_webtransport::stream::BidiStream<
h3_quinn::BidiStream<bytes::Bytes>,
bytes::Bytes,
>;
/// Start the WebTransport listener in a background task.
///
/// Binds a quinn endpoint on `listen_addr` using the provided `ServerConfig`
/// (expected to carry the "h3" ALPN), then spawns the accept loop and returns
/// as soon as the socket is bound. Incoming HTTP/3 CONNECT requests are
/// upgraded to WebTransport sessions by the spawned loop.
pub fn spawn_webtransport_listener(
    listen_addr: std::net::SocketAddr,
    server_config: quinn::ServerConfig,
    state: Arc<ServerState>,
    registry: Arc<MethodRegistry<ServerState>>,
) -> anyhow::Result<()> {
    let endpoint = match quinn::Endpoint::server(server_config, listen_addr) {
        Ok(ep) => ep,
        Err(e) => {
            return Err(anyhow::anyhow!(
                "bind WebTransport endpoint {listen_addr}: {e}"
            ))
        }
    };
    info!(addr = %listen_addr, "WebTransport endpoint listening");
    // Spawn the accept loop future directly; it runs for the endpoint's lifetime.
    tokio::spawn(accept_loop(endpoint, state, registry));
    Ok(())
}
/// Accept QUIC connections and upgrade them to HTTP/3 + WebTransport.
///
/// Runs until the endpoint is closed. Each accepted connection gets its own
/// task so a slow TLS handshake or long session cannot stall the accept loop.
async fn accept_loop(
    endpoint: quinn::Endpoint,
    state: Arc<ServerState>,
    registry: Arc<MethodRegistry<ServerState>>,
) {
    while let Some(incoming) = endpoint.accept().await {
        let state = Arc::clone(&state);
        let registry = Arc::clone(&registry);
        tokio::spawn(async move {
            // Finish the QUIC handshake; failures are logged and the task exits.
            let connection = match incoming.await {
                Ok(c) => c,
                Err(e) => {
                    warn!(error = %e, "WebTransport: QUIC accept failed");
                    return;
                }
            };
            let remote = connection.remote_address();
            debug!(remote = %remote, "WebTransport: new QUIC connection");
            metrics::counter!("webtransport_connections_total").increment(1);
            // The gauge is incremented and decremented on this same task, so it
            // tracks live connections even when the session ends with an error.
            metrics::gauge!("webtransport_active_connections").increment(1.0);
            if let Err(e) = handle_h3_connection(connection, state, registry).await {
                debug!(remote = %remote, error = %e, "WebTransport: session error");
            }
            metrics::gauge!("webtransport_active_connections").decrement(1.0);
        });
    }
}
/// Handle an HTTP/3 connection: accept the WebTransport CONNECT request and
/// process bidirectional streams as RPC calls.
///
/// Loops over incoming HTTP/3 requests; the first CONNECT upgrades the
/// connection to a WebTransport session (consuming `h3_conn`), and this
/// function returns once that session ends.
async fn handle_h3_connection(
    connection: quinn::Connection,
    state: Arc<ServerState>,
    registry: Arc<MethodRegistry<ServerState>>,
) -> anyhow::Result<()> {
    let h3_quinn_conn = h3_quinn::Connection::new(connection);
    // WebTransport rides on extended CONNECT; datagram support is negotiated too.
    let mut h3_conn: H3Conn = h3::server::builder()
        .enable_webtransport(true)
        .enable_extended_connect(true)
        .enable_datagram(true)
        .build::<h3_quinn::Connection, bytes::Bytes>(h3_quinn_conn)
        .await
        .map_err(|e| anyhow::anyhow!("H3 connection setup: {e}"))?;
    // Accept HTTP/3 requests until we get a CONNECT for WebTransport.
    loop {
        let resolver: H3Resolver = match h3_conn.accept().await {
            Ok(Some(r)) => r,
            Ok(None) => {
                // Peer closed the H3 connection before any CONNECT arrived.
                debug!("WebTransport: H3 connection closed");
                return Ok(());
            }
            Err(e) => {
                return Err(anyhow::anyhow!("WebTransport: H3 accept error: {e}"));
            }
        };
        let (request, stream) = resolver
            .resolve_request()
            .await
            .map_err(|e| anyhow::anyhow!("resolve request: {e}"))?;
        let method = request.method().clone();
        let uri = request.uri().clone();
        if method == http::Method::CONNECT {
            debug!(uri = %uri, "WebTransport: CONNECT request");
            let wt_session = h3_webtransport::server::WebTransportSession::accept(
                request, stream, h3_conn,
            )
            .await
            .map_err(|e| anyhow::anyhow!("WebTransport session accept: {e}"))?;
            info!("WebTransport: session established");
            metrics::counter!("webtransport_sessions_total").increment(1);
            serve_wt_streams(wt_session, state, registry).await;
            return Ok(());
        }
        // NOTE(review): non-CONNECT requests are dropped without a response
        // (their stream is discarded) — confirm clients are expected to see a
        // stream reset rather than an HTTP error status.
        debug!(method = %method, uri = %uri, "WebTransport: non-CONNECT request ignored");
    }
}
/// Per-connection state from the WebTransport auth handshake.
#[derive(Debug, Clone, Default)]
struct ConnectionState {
    /// Raw bearer token read during the handshake; forwarded to every RPC on
    /// this session via `RequestContext::session_token`.
    session_token: Option<Vec<u8>>,
    /// Left `None` by `accept_auth_stream`, and this state is shared immutably
    /// (`Arc`) afterwards — presumably identity is resolved downstream from
    /// the token; confirm.
    identity_key: Option<Vec<u8>>,
}
/// Accept bidirectional streams from a WebTransport session and dispatch
/// each as an RPC request.
///
/// The first bidi stream is consumed by the auth handshake; every later
/// stream is handled on its own task so one slow RPC cannot block the rest.
async fn serve_wt_streams(
    session: WtSession,
    state: Arc<ServerState>,
    registry: Arc<MethodRegistry<ServerState>>,
) {
    // Auth handshake: the first bidi stream carries the session token.
    let conn_state: Arc<ConnectionState> = match accept_auth_stream(&session).await {
        Ok(cs) => Arc::new(cs),
        Err(e) => {
            warn!(error = %e, "WebTransport: auth handshake failed");
            return;
        }
    };
    loop {
        match session.accept_bi().await {
            Ok(Some(AcceptedBi::BidiStream(_session_id, stream))) => {
                let state = Arc::clone(&state);
                let registry = Arc::clone(&registry);
                let conn_state = Arc::clone(&conn_state);
                // One task per RPC stream; failures are per-stream and only logged.
                tokio::spawn(async move {
                    if let Err(e) =
                        handle_wt_bidi_stream(stream, state, registry, &conn_state).await
                    {
                        debug!(error = %e, "WebTransport: stream error");
                    }
                });
            }
            Ok(Some(AcceptedBi::Request(_req, _stream))) => {
                // Nested HTTP/3 requests are not part of the RPC protocol.
                debug!("WebTransport: ignoring nested HTTP/3 request");
            }
            Ok(None) => {
                // Peer closed the session cleanly.
                debug!("WebTransport: no more bidi streams");
                break;
            }
            Err(e) => {
                debug!(error = %e, "WebTransport: accept_bi error");
                break;
            }
        }
    }
}
/// Accept the first bidirectional stream as an auth init handshake.
///
/// The client sends a raw session token (length-prefixed: `u32 BE + token bytes`).
/// The server reads it and sends a 1-byte ack (0x00).
///
/// The token is NOT validated here; it is stored on the connection state and
/// forwarded to handlers via `RequestContext::session_token`.
async fn accept_auth_stream(session: &WtSession) -> anyhow::Result<ConnectionState> {
    let accepted = session
        .accept_bi()
        .await
        .map_err(|e| anyhow::anyhow!("auth stream accept: {e}"))?
        .ok_or_else(|| anyhow::anyhow!("session closed before auth handshake"))?;
    let mut stream: WtBidiStream = match accepted {
        AcceptedBi::BidiStream(_session_id, stream) => stream,
        AcceptedBi::Request(_, _) => {
            anyhow::bail!("expected bidi stream for auth, got HTTP/3 request")
        }
    };
    // Read the token: [len: u32 BE][token bytes]
    let mut header = [0u8; 4];
    AsyncReadExt::read_exact(&mut stream, &mut header)
        .await
        .map_err(|e| anyhow::anyhow!("auth read header: {e}"))?;
    let len = u32::from_be_bytes(header) as usize;
    // Cap the length so a hostile prefix cannot force a large allocation.
    if len > 4096 {
        anyhow::bail!("auth token too large: {len} bytes");
    }
    // A zero-length token is accepted here — presumably for insecure-auth
    // setups; handlers decide whether that is sufficient. TODO confirm.
    let mut token = vec![0u8; len];
    if len > 0 {
        AsyncReadExt::read_exact(&mut stream, &mut token)
            .await
            .map_err(|e| anyhow::anyhow!("auth read token: {e}"))?;
    }
    // Send ack: single zero byte.
    AsyncWriteExt::write_all(&mut stream, &[0u8])
        .await
        .map_err(|e| anyhow::anyhow!("auth ack send: {e}"))?;
    debug!(token_len = token.len(), "WebTransport: auth init received");
    Ok(ConnectionState {
        session_token: Some(token),
        identity_key: None,
    })
}
/// Handle a single WebTransport bidirectional stream: read request, dispatch,
/// write response. Uses the same framing as native QUIC.
///
/// The request is read until the client finishes (FINs) its send side, then
/// decoded as one `RequestFrame`; exactly one `ResponseFrame` is written back.
async fn handle_wt_bidi_stream(
    mut stream: WtBidiStream,
    state: Arc<ServerState>,
    registry: Arc<MethodRegistry<ServerState>>,
    conn_state: &ConnectionState,
) -> anyhow::Result<()> {
    // Read the complete request from the stream.
    // Hard cap: frame header plus the maximum RPC payload.
    let max_size =
        quicprochat_rpc::framing::MAX_PAYLOAD_SIZE + quicprochat_rpc::framing::REQUEST_HEADER_SIZE;
    let mut buf = Vec::with_capacity(1024);
    let mut tmp = [0u8; 8192];
    loop {
        let n = AsyncReadExt::read(&mut stream, &mut tmp)
            .await
            .map_err(|e| anyhow::anyhow!("recv: {e}"))?;
        if n == 0 {
            break;
        }
        buf.extend_from_slice(&tmp[..n]);
        // Checked inside the loop so an oversized sender is cut off early.
        if buf.len() > max_size {
            anyhow::bail!("payload too large");
        }
    }
    let mut bytes = BytesMut::from(buf.as_slice());
    let frame = match RequestFrame::decode(&mut bytes)
        .map_err(|e| anyhow::anyhow!("decode: {e}"))?
    {
        Some(f) => f,
        // The stream ended before a whole frame arrived.
        None => anyhow::bail!("incomplete request frame"),
    };
    // Correlation id for logs, derived from the wall clock in nanoseconds.
    // NOTE(review): not guaranteed unique under concurrency — fine for log
    // correlation, but do not use it as a key.
    let trace_id = format!(
        "wt-{:016x}",
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64
    );
    let result = match registry.get(frame.method_id) {
        Some((handler, name, timeout)) => {
            let span = tracing::info_span!(
                "wt_rpc",
                trace_id = %trace_id,
                method_id = frame.method_id,
                method = name,
                req_id = frame.request_id,
            );
            // NOTE(review): this span guard is held across `.await` points
            // below — tracing docs recommend `.instrument()` instead, since an
            // entered guard can attach the span to unrelated poll activity.
            let _guard = span.enter();
            debug!("dispatching");
            // The ctx deadline lets handlers self-limit; the enforced cutoff
            // is the tokio::time::timeout wrapper below.
            let deadline = timeout.map(|d| tokio::time::Instant::now() + d);
            let start = std::time::Instant::now();
            let ctx = RequestContext {
                identity_key: conn_state.identity_key.clone(),
                session_token: conn_state.session_token.clone(),
                payload: frame.payload,
                trace_id: trace_id.clone(),
                deadline,
            };
            let result = if let Some(dur) = timeout {
                match tokio::time::timeout(dur, handler(Arc::clone(&state), ctx)).await {
                    Ok(r) => r,
                    Err(_) => {
                        warn!(method = name, "WebTransport: request deadline exceeded");
                        HandlerResult::err(RpcStatus::DeadlineExceeded, "request deadline exceeded")
                    }
                }
            } else {
                handler(Arc::clone(&state), ctx).await
            };
            let elapsed = start.elapsed();
            metrics::histogram!("webtransport_request_duration_seconds", "method" => name)
                .record(elapsed.as_secs_f64());
            metrics::counter!("webtransport_requests_total", "method" => name).increment(1);
            result
        }
        None => {
            warn!(method_id = frame.method_id, "WebTransport: unknown method");
            HandlerResult::err(RpcStatus::UnknownMethod, "unknown method")
        }
    };
    let response = ResponseFrame {
        status: result.status as u8,
        request_id: frame.request_id,
        payload: result.payload,
    };
    let encoded = response.encode();
    AsyncWriteExt::write_all(&mut stream, &encoded)
        .await
        .map_err(|e| anyhow::anyhow!("send response: {e}"))?;
    // Half-close our send side so the client sees a completed response.
    AsyncWriteExt::shutdown(&mut stream)
        .await
        .map_err(|e| anyhow::anyhow!("shutdown: {e}"))?;
    Ok(())
}
/// Build a quinn `ServerConfig` for the WebTransport endpoint.
///
/// Uses the same TLS cert/key as the main server but with "h3" ALPN.
///
/// NOTE(review): the cert and key files are consumed as raw DER (exactly one
/// certificate, one key) — PEM-encoded files would fail here; confirm the
/// deployment always provides DER.
pub fn build_webtransport_server_config(
    cert_path: &std::path::Path,
    key_path: &std::path::Path,
) -> anyhow::Result<quinn::ServerConfig> {
    use anyhow::Context;
    use rustls::pki_types::{CertificateDer, PrivateKeyDer};
    use rustls::version::TLS13;
    let cert_bytes = std::fs::read(cert_path).context("read WebTransport cert")?;
    let key_bytes = std::fs::read(key_path).context("read WebTransport key")?;
    let cert_chain = vec![CertificateDer::from(cert_bytes)];
    let key = PrivateKeyDer::try_from(key_bytes)
        .map_err(|_| anyhow::anyhow!("invalid WebTransport private key"))?;
    // TLS 1.3 only (QUIC requires it); "h3" ALPN is what browsers offer for
    // WebTransport.
    let mut tls = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13])
        .with_no_client_auth()
        .with_single_cert(cert_chain, key)?;
    tls.alpn_protocols = vec![b"h3".to_vec()];
    let crypto = quinn_proto::crypto::rustls::QuicServerConfig::try_from(tls)
        .map_err(|e| anyhow::anyhow!("invalid WebTransport TLS config: {e}"))?;
    let mut transport = quinn::TransportConfig::default();
    // Generous 5-minute idle timeout: browser sessions can sit quiet between RPCs.
    transport.max_idle_timeout(Some(
        std::time::Duration::from_secs(300)
            .try_into()
            .map_err(|e| anyhow::anyhow!("idle timeout: {e}"))?,
    ));
    // WebTransport sessions may have multiple simultaneous streams.
    transport.max_concurrent_bidi_streams(64u32.into());
    transport.max_concurrent_uni_streams(16u32.into());
    let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(crypto));
    server_config.transport_config(Arc::new(transport));
    Ok(server_config)
}
#[cfg(test)]
mod tests {
    /// Placeholder: real coverage for this module lives in integration tests
    /// that drive a live WebTransport session. An empty body already proves
    /// the module compiles under `--cfg test`; the previous `assert!(true)`
    /// was a no-op flagged by clippy (`assertions_on_constants`).
    #[test]
    fn webtransport_module_compiles() {}
}

View File

@@ -0,0 +1,645 @@
//! WebSocket JSON-RPC bridge for browser clients.
//!
//! Provides a lightweight JSON-RPC interface over WebSocket so that browsers
//! can interact with the server without a Cap'n Proto / QUIC stack.
//!
//! Security parity with the Cap'n Proto path:
//! - Rate limiting via `check_rate_limit()` on all mutating handlers
//! - DM channel membership verification on `send`
//! - Payload size limits (5 MB)
//! - Timing floor on `resolveUser` to mask lookup timing
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use base64::Engine;
use dashmap::DashMap;
use futures::stream::StreamExt;
use futures::SinkExt;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::time::Instant;
use tokio_tungstenite::tungstenite::Message;
use crate::auth::{check_rate_limit, validate_token_raw, AuthConfig, AuthContext, RateEntry, SessionInfo};
use crate::storage::Store;
/// Standard-alphabet (padded) base64 engine used by the WS bridge.
const B64: base64::engine::general_purpose::GeneralPurpose =
    base64::engine::general_purpose::STANDARD;
/// Maximum payload size for WS bridge (same as Cap'n Proto path): 5 MB.
const MAX_PAYLOAD_BYTES: usize = 5 * 1024 * 1024;
/// Minimum response time for resolveUser to mask DB lookup timing differences.
const RESOLVE_TIMING_FLOOR: Duration = Duration::from_millis(5);
// ── Shared state ────────────────────────────────────────────────────────────
/// Subset of server state needed by the WS bridge (all `Send + Sync`).
#[allow(dead_code)] // sealed_sender plumbed for future use
pub struct WsBridgeState {
    /// Persistent user / channel / message storage backend.
    pub store: Arc<dyn Store>,
    /// Long-poll wakeups: recipient identity key -> notifier fired on enqueue.
    pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
    /// Token/session validation configuration shared with the Cap'n Proto path.
    pub auth_cfg: Arc<AuthConfig>,
    /// Active sessions, keyed by raw bearer-token bytes.
    pub sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
    /// Per-token rate-limit state, keyed by raw bearer-token bytes.
    pub rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
    /// Sealed-sender flag; currently unused by the WS handlers (see allow above).
    pub sealed_sender: bool,
    /// Dev/demo mode: accept identities from request params without a session.
    pub allow_insecure_auth: bool,
}
// ── JSON-RPC types ──────────────────────────────────────────────────────────
/// Incoming JSON-RPC request frame.
#[derive(Deserialize)]
struct RpcRequest {
    /// Client-chosen correlation ID, echoed back verbatim in the response.
    id: serde_json::Value,
    /// Method name; routed by `dispatch`.
    method: String,
    /// Method parameters; `null` when omitted from the frame.
    #[serde(default)]
    params: serde_json::Value,
}
/// Outgoing JSON-RPC response frame.
#[derive(Serialize)]
struct RpcResponse {
    /// Correlation ID copied from the request (`null` for unparseable frames).
    id: serde_json::Value,
    /// `true` on success, `false` on error.
    ok: bool,
    /// Success payload; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<serde_json::Value>,
    /// Error message; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
impl RpcResponse {
fn success(id: serde_json::Value, result: serde_json::Value) -> Self {
Self {
id,
ok: true,
result: Some(result),
error: None,
}
}
fn error(id: serde_json::Value, msg: impl Into<String>) -> Self {
Self {
id,
ok: false,
result: None,
error: Some(msg.into()),
}
}
}
// ── Auth helper ─────────────────────────────────────────────────────────────
/// Extract and validate the "token" field from params.
///
/// In insecure-auth mode with no token configured, any bearer (including the
/// empty string) is accepted and the identity is resolved later from params;
/// otherwise the token must validate against the session store.
fn extract_auth(
    state: &WsBridgeState,
    params: &serde_json::Value,
) -> Result<AuthContext, String> {
    let token = params
        .get("token")
        .and_then(serde_json::Value::as_str)
        .unwrap_or_default()
        .as_bytes()
        .to_vec();
    let insecure_without_token =
        state.allow_insecure_auth && state.auth_cfg.required_token.is_none();
    if insecure_without_token {
        // Mirrors the Cap'n Proto path: no identity is bound here; handlers
        // resolve it from the request params via `resolve_identity`.
        Ok(AuthContext {
            token,
            identity_key: None,
        })
    } else {
        validate_token_raw(&state.auth_cfg, &state.sessions, &token)
    }
}
/// Resolve identity key: either from auth context (session-bound) or from
/// request params (insecure-auth mode). Returns the 32-byte identity key.
fn resolve_identity(
state: &WsBridgeState,
auth_ctx: &AuthContext,
params: &serde_json::Value,
) -> Result<Vec<u8>, String> {
// If auth context has an identity-bound session, use that.
if let Some(ref ik) = auth_ctx.identity_key {
return Ok(ik.clone());
}
// In insecure-auth mode, accept identity from params.
if state.allow_insecure_auth {
// Try base64-encoded identityKey first.
if let Some(b64) = params.get("identityKey").and_then(|v| v.as_str()) {
return B64
.decode(b64)
.map_err(|e| format!("bad base64 identityKey: {e}"));
}
// Try username lookup.
if let Some(username) = params.get("username").and_then(|v| v.as_str()) {
if let Ok(Some(ik)) = state.store.get_user_identity_key(username) {
return Ok(ik);
}
return Err(format!("user not found: {username}"));
}
}
Err("no identity: login required or pass identityKey/username in insecure mode".to_string())
}
/// Apply rate limiting using the auth token. Returns an error string on limit exceeded.
fn ws_check_rate_limit(state: &WsBridgeState, auth_ctx: &AuthContext) -> Result<(), String> {
    match check_rate_limit(&state.rate_limits, &auth_ctx.token) {
        Ok(()) => Ok(()),
        Err(e) => Err(format!("rate limit exceeded: {e}")),
    }
}
// ── Dispatch ────────────────────────────────────────────────────────────────
async fn dispatch(state: &WsBridgeState, req: RpcRequest) -> RpcResponse {
match req.method.as_str() {
"health" => handle_health(req.id),
"resolveUser" => handle_resolve_user(state, req.id, &req.params).await,
"createChannel" => handle_create_channel(state, req.id, &req.params),
"send" => handle_send(state, req.id, &req.params),
"receive" => handle_receive(state, req.id, &req.params),
"deleteAccount" => handle_delete_account(state, req.id, &req.params),
"register" => handle_register(state, req.id, &req.params),
_ => RpcResponse::error(req.id, format!("unknown method: {}", req.method)),
}
}
// ── Handlers ────────────────────────────────────────────────────────────────
/// `health`: liveness probe; always replies with the string "ok".
fn handle_health(id: serde_json::Value) -> RpcResponse {
    let status = serde_json::json!("ok");
    RpcResponse::success(id, status)
}
/// `register`: bind a username to a 32-byte identity key.
///
/// Only available under `--allow-insecure-auth` (development/demo).
/// Registration is idempotent for the same (username, key) pair and rejects
/// a username already owned by a different key. Returns the stored pair.
fn handle_register(
    state: &WsBridgeState,
    id: serde_json::Value,
    params: &serde_json::Value,
) -> RpcResponse {
    // Only allow in insecure-auth mode (development/demo).
    if !state.allow_insecure_auth {
        return RpcResponse::error(id, "register is only available in --allow-insecure-auth mode");
    }
    // Rate limit.
    let auth_ctx = match extract_auth(state, params) {
        Ok(ctx) => ctx,
        Err(e) => return RpcResponse::error(id, e),
    };
    if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
        return RpcResponse::error(id, e);
    }
    // Validate username: non-empty, <= 32 chars, [A-Za-z0-9_] only.
    let username = match params.get("username").and_then(|v| v.as_str()) {
        Some(u) if !u.is_empty() => u,
        _ => return RpcResponse::error(id, "missing or empty 'username' param"),
    };
    if username.len() > 32 {
        return RpcResponse::error(id, "username must be at most 32 characters");
    }
    if !username.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
        return RpcResponse::error(id, "username must be alphanumeric or underscore only");
    }
    // Validate identity key: base64 decoding to exactly 32 bytes.
    let ik_b64 = match params.get("identityKey").and_then(|v| v.as_str()) {
        Some(s) if !s.is_empty() => s,
        _ => return RpcResponse::error(id, "missing or empty 'identityKey' param"),
    };
    let identity_key = match B64.decode(ik_b64) {
        Ok(k) => k,
        Err(e) => return RpcResponse::error(id, format!("bad base64 identityKey: {e}")),
    };
    if identity_key.len() != 32 {
        return RpcResponse::error(id, "identityKey must be 32 bytes");
    }
    // Check if username is already taken by a different key.
    match state.store.get_user_identity_key(username) {
        Ok(Some(existing)) if existing == identity_key => {
            // Idempotent: same key, return success.
            return RpcResponse::success(
                id,
                serde_json::json!({
                    "username": username,
                    "identityKey": B64.encode(&identity_key),
                }),
            );
        }
        Ok(Some(_)) => {
            return RpcResponse::error(id, "username already taken");
        }
        Ok(None) => {} // Available, proceed.
        Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
    }
    // Store the mapping.
    if let Err(e) = state.store.store_user_identity_key(username, identity_key.clone()) {
        return RpcResponse::error(id, format!("storage error: {e}"));
    }
    // Audit log: only a 4-byte key prefix is recorded, never the full key.
    // The [..4] slice is safe: length == 32 was verified above.
    tracing::info!(
        username = %username,
        key_prefix = %hex::encode(&identity_key[..4]),
        "audit: ws_bridge register"
    );
    RpcResponse::success(
        id,
        serde_json::json!({
            "username": username,
            "identityKey": B64.encode(&identity_key),
        }),
    )
}
/// `resolveUser`: look up a username's identity key.
///
/// Returns `{"identityKey": <base64>}` or `{"identityKey": null}` when the
/// user does not exist. The lookup path is padded to `RESOLVE_TIMING_FLOOR`
/// so that existing and non-existing usernames are indistinguishable by
/// response timing.
async fn handle_resolve_user(
    state: &WsBridgeState,
    id: serde_json::Value,
    params: &serde_json::Value,
) -> RpcResponse {
    let auth_ctx = match extract_auth(state, params) {
        Ok(ctx) => ctx,
        Err(e) => return RpcResponse::error(id, e),
    };
    // Rate limit resolve requests to prevent bulk enumeration.
    if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
        return RpcResponse::error(id, e);
    }
    let username = match params.get("username").and_then(|v| v.as_str()) {
        Some(u) if !u.is_empty() => u,
        _ => return RpcResponse::error(id, "missing or empty 'username' param"),
    };
    // Start the timing floor *before* touching the store, then build the
    // response and sleep out the remainder regardless of outcome (parity
    // with the Cap'n Proto resolveUser handler).
    let deadline = Instant::now() + RESOLVE_TIMING_FLOOR;
    let lookup = state.store.get_user_identity_key(username);
    let response = match lookup {
        Ok(Some(key)) => {
            RpcResponse::success(id, serde_json::json!({ "identityKey": B64.encode(&key) }))
        }
        Ok(None) => RpcResponse::success(id, serde_json::json!({ "identityKey": null })),
        Err(e) => RpcResponse::error(id, format!("storage error: {e}")),
    };
    tokio::time::sleep_until(deadline).await;
    response
}
fn handle_create_channel(
state: &WsBridgeState,
id: serde_json::Value,
params: &serde_json::Value,
) -> RpcResponse {
let auth_ctx = match extract_auth(state, params) {
Ok(ctx) => ctx,
Err(e) => return RpcResponse::error(id, e),
};
// Rate limit.
if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
return RpcResponse::error(id, e);
}
let my_key = match resolve_identity(state, &auth_ctx, params) {
Ok(k) => k,
Err(e) => return RpcResponse::error(id, e),
};
// Accept peer key as base64 or resolve from username.
let peer_key = if let Some(b64) = params.get("peerKey").and_then(|v| v.as_str()) {
match B64.decode(b64) {
Ok(k) => k,
Err(e) => return RpcResponse::error(id, format!("bad base64 peerKey: {e}")),
}
} else if let Some(username) = params.get("peerUsername").and_then(|v| v.as_str()) {
match state.store.get_user_identity_key(username) {
Ok(Some(k)) => k,
Ok(None) => return RpcResponse::error(id, format!("peer user not found: {username}")),
Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
}
} else {
return RpcResponse::error(id, "missing 'peerKey' (base64) or 'peerUsername'");
};
if peer_key.len() != 32 {
return RpcResponse::error(id, "peerKey must be 32 bytes");
}
match state.store.create_channel(&my_key, &peer_key) {
Ok((channel_id, was_new)) => RpcResponse::success(
id,
serde_json::json!({
"channelId": B64.encode(&channel_id),
"wasNew": was_new,
}),
),
Err(e) => RpcResponse::error(id, format!("storage error: {e}")),
}
}
/// `send`: enqueue an encrypted payload for a recipient in their DM channel.
///
/// Security parity with the Cap'n Proto enqueue path: rate limiting, the
/// 5 MB payload cap, and DM channel membership verification before enqueue.
/// On success, wakes any long-poll fetchers waiting on the recipient and
/// writes an audit log entry (key prefix only, never the payload).
fn handle_send(
    state: &WsBridgeState,
    id: serde_json::Value,
    params: &serde_json::Value,
) -> RpcResponse {
    let auth_ctx = match extract_auth(state, params) {
        Ok(ctx) => ctx,
        Err(e) => return RpcResponse::error(id, e),
    };
    // Rate limit (parity with Cap'n Proto enqueue path).
    if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
        return RpcResponse::error(id, e);
    }
    let sender_key = match resolve_identity(state, &auth_ctx, params) {
        Ok(k) => k,
        Err(e) => return RpcResponse::error(id, e),
    };
    // Resolve recipient: base64 key or username.
    let recipient_key =
        if let Some(b64) = params.get("recipientKey").and_then(|v| v.as_str()) {
            match B64.decode(b64) {
                Ok(k) => k,
                Err(e) => {
                    return RpcResponse::error(id, format!("bad base64 recipientKey: {e}"))
                }
            }
        } else if let Some(username) = params.get("recipient").and_then(|v| v.as_str()) {
            match state.store.get_user_identity_key(username) {
                Ok(Some(k)) => k,
                Ok(None) => {
                    return RpcResponse::error(id, format!("recipient not found: {username}"))
                }
                Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
            }
        } else {
            return RpcResponse::error(id, "missing 'recipientKey' (base64) or 'recipient' (username)");
        };
    if recipient_key.len() != 32 {
        return RpcResponse::error(id, "recipientKey must be 32 bytes");
    }
    // Payload: base64-encoded binary or plain text message.
    let payload = if let Some(b64) = params.get("payload").and_then(|v| v.as_str()) {
        match B64.decode(b64) {
            Ok(p) => p,
            Err(e) => return RpcResponse::error(id, format!("bad base64 payload: {e}")),
        }
    } else if let Some(msg) = params.get("message").and_then(|v| v.as_str()) {
        msg.as_bytes().to_vec()
    } else {
        return RpcResponse::error(id, "missing 'payload' (base64) or 'message' (text)");
    };
    if payload.is_empty() {
        return RpcResponse::error(id, "payload must not be empty");
    }
    // Payload size limit (same as Cap'n Proto path: 5 MB).
    if payload.len() > MAX_PAYLOAD_BYTES {
        return RpcResponse::error(
            id,
            format!("payload exceeds max size ({MAX_PAYLOAD_BYTES} bytes)"),
        );
    }
    // Create or look up the DM channel between sender and recipient.
    let channel_id = match state.store.create_channel(&sender_key, &recipient_key) {
        Ok((ch, _)) => ch,
        Err(e) => return RpcResponse::error(id, format!("channel error: {e}")),
    };
    // DM channel membership verification (parity with Cap'n Proto enqueue path).
    // NOTE(review): only 16-byte channel ids receive the membership check —
    // presumably these are DM channels; confirm other id lengths cannot occur here.
    if channel_id.len() == 16 {
        let members = match state.store.get_channel_members(&channel_id) {
            Ok(Some(m)) => m,
            Ok(None) => return RpcResponse::error(id, "channel not found"),
            Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
        };
        let (a, b) = &members;
        // Caller must be one of the two members, and the recipient must be
        // the *other* member of the same pair.
        let caller_in = sender_key == *a || sender_key == *b;
        let recipient_other = (recipient_key == *a && sender_key == *b)
            || (recipient_key == *b && sender_key == *a);
        if !caller_in || !recipient_other {
            return RpcResponse::error(id, "caller or recipient not a member of this channel");
        }
    }
    match state
        .store
        .enqueue(&recipient_key, &channel_id, payload, None)
    {
        Ok(seq) => {
            // Notify any waiting long-poll fetchers.
            if let Some(notify) = state.waiters.get(&recipient_key) {
                notify.notify_waiters();
            }
            // Audit logging (no secrets: no payload, no full keys).
            tracing::info!(
                recipient_prefix = %hex::encode(&recipient_key[..std::cmp::min(4, recipient_key.len())]),
                seq = seq,
                "audit: ws_bridge enqueue"
            );
            RpcResponse::success(id, serde_json::json!({ "seq": seq }))
        }
        Err(e) => RpcResponse::error(id, format!("enqueue error: {e}")),
    }
}
/// `receive`: drain all pending messages for the caller in the DM channel
/// shared with a peer.
///
/// NOTE(review): the peer is supplied via the 'recipientKey'/'recipient'
/// params even though it is the *sender* from the caller's perspective —
/// confirm clients depend on these names before renaming them.
fn handle_receive(
    state: &WsBridgeState,
    id: serde_json::Value,
    params: &serde_json::Value,
) -> RpcResponse {
    let auth_ctx = match extract_auth(state, params) {
        Ok(ctx) => ctx,
        Err(e) => return RpcResponse::error(id, e),
    };
    // Rate limit.
    if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
        return RpcResponse::error(id, e);
    }
    let my_key = match resolve_identity(state, &auth_ctx, params) {
        Ok(k) => k,
        Err(e) => return RpcResponse::error(id, e),
    };
    // Resolve sender/peer: base64 key or username (needed to find the channel).
    let peer_key = if let Some(b64) = params.get("recipientKey").and_then(|v| v.as_str()) {
        match B64.decode(b64) {
            Ok(k) => k,
            Err(e) => return RpcResponse::error(id, format!("bad base64 recipientKey: {e}")),
        }
    } else if let Some(username) = params.get("recipient").and_then(|v| v.as_str()) {
        match state.store.get_user_identity_key(username) {
            Ok(Some(k)) => k,
            Ok(None) => return RpcResponse::error(id, format!("user not found: {username}")),
            Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
        }
    } else {
        return RpcResponse::error(id, "missing 'recipientKey' (base64) or 'recipient' (username)");
    };
    // Find the channel between me and the peer. `create_channel` doubles as a
    // lookup here: it returns the existing channel id when one already exists
    // (see the `wasNew` flag it exposes to `createChannel`).
    let channel_id = match state.store.create_channel(&my_key, &peer_key) {
        Ok((ch, _)) => ch,
        Err(e) => return RpcResponse::error(id, format!("channel error: {e}")),
    };
    // Fetch (drain) all messages for me in this channel.
    match state.store.fetch(&my_key, &channel_id) {
        Ok(messages) => {
            let items: Vec<serde_json::Value> = messages
                .into_iter()
                .map(|(seq, data)| {
                    // Try to decode as UTF-8 text; fall back to base64.
                    let text = String::from_utf8(data.clone()).ok();
                    serde_json::json!({
                        "seq": seq,
                        "data": B64.encode(&data),
                        "text": text,
                    })
                })
                .collect();
            RpcResponse::success(id, serde_json::json!(items))
        }
        Err(e) => RpcResponse::error(id, format!("fetch error: {e}")),
    }
}
fn handle_delete_account(
state: &WsBridgeState,
id: serde_json::Value,
params: &serde_json::Value,
) -> RpcResponse {
let auth_ctx = match extract_auth(state, params) {
Ok(ctx) => ctx,
Err(e) => return RpcResponse::error(id, e),
};
// Rate limit.
if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
return RpcResponse::error(id, e);
}
let identity_key = match resolve_identity(state, &auth_ctx, params) {
Ok(k) => k,
Err(e) => return RpcResponse::error(id, e),
};
match state.store.delete_account(&identity_key) {
Ok(()) => {
// Invalidate sessions for this identity.
let tokens_to_remove: Vec<Vec<u8>> = state
.sessions
.iter()
.filter(|entry| entry.value().identity_key == identity_key)
.map(|entry| entry.key().clone())
.collect();
for token in &tokens_to_remove {
state.sessions.remove(token);
}
RpcResponse::success(id, serde_json::json!({ "deleted": true }))
}
Err(e) => RpcResponse::error(id, format!("delete failed: {e}")),
}
}
// ── WebSocket listener ──────────────────────────────────────────────────────
/// Spawn the WebSocket JSON-RPC bridge as a background tokio task.
///
/// Binds `addr` and then accepts connections forever; each client gets its
/// own task that reads text frames, parses them as JSON-RPC requests, and
/// writes exactly one JSON response per request. A bind failure is logged
/// and ends only the bridge task — it is not propagated to the caller.
pub fn spawn_ws_bridge(addr: SocketAddr, state: Arc<WsBridgeState>) {
    tokio::spawn(async move {
        let listener = match TcpListener::bind(addr).await {
            Ok(l) => l,
            Err(e) => {
                // Fatal for the bridge only: log and exit this task.
                tracing::error!(addr = %addr, error = %e, "ws_bridge: failed to bind");
                return;
            }
        };
        tracing::info!(addr = %addr, "ws_bridge: accepting WebSocket connections");
        loop {
            let (stream, peer) = match listener.accept().await {
                Ok(pair) => pair,
                Err(e) => {
                    // Transient accept errors must not kill the listener loop.
                    tracing::warn!(error = %e, "ws_bridge: accept error");
                    continue;
                }
            };
            let state = Arc::clone(&state);
            // One detached task per client connection.
            tokio::spawn(async move {
                let ws = match tokio_tungstenite::accept_async(stream).await {
                    Ok(ws) => ws,
                    Err(e) => {
                        tracing::debug!(peer = %peer, error = %e, "ws_bridge: handshake failed");
                        return;
                    }
                };
                tracing::debug!(peer = %peer, "ws_bridge: client connected");
                let (mut sink, mut stream) = ws.split();
                while let Some(msg) = stream.next().await {
                    let msg = match msg {
                        Ok(m) => m,
                        Err(e) => {
                            tracing::debug!(peer = %peer, error = %e, "ws_bridge: read error");
                            break;
                        }
                    };
                    // Only text frames carry RPC traffic; close ends the loop,
                    // ping/pong and binary frames are ignored.
                    let text = match msg {
                        Message::Text(t) => t,
                        Message::Close(_) => break,
                        Message::Ping(_) | Message::Pong(_) => continue,
                        _ => continue,
                    };
                    let req: RpcRequest = match serde_json::from_str(&text) {
                        Ok(r) => r,
                        Err(e) => {
                            // Unparseable frame: reply with a null-id error
                            // and keep the connection alive.
                            let resp = RpcResponse::error(
                                serde_json::Value::Null,
                                format!("invalid JSON: {e}"),
                            );
                            let json = serde_json::to_string(&resp).unwrap_or_default();
                            if sink.send(Message::Text(json.into())).await.is_err() {
                                break;
                            }
                            continue;
                        }
                    };
                    let resp = dispatch(&state, req).await;
                    let json = serde_json::to_string(&resp).unwrap_or_default();
                    if sink.send(Message::Text(json.into())).await.is_err() {
                        break;
                    }
                }
                tracing::debug!(peer = %peer, "ws_bridge: client disconnected");
            });
        }
    });
}