chore: rename project quicnprotochat -> quicproquo (binaries: qpq)
Rename the entire workspace:
- Crate packages: quicnprotochat-{core,proto,server,client,gui,p2p,mobile} -> quicproquo-*
- Binary names: quicnprotochat -> qpq, quicnprotochat-server -> qpq-server,
quicnprotochat-gui -> qpq-gui
- Default files: *-state.bin -> qpq-state.bin, *-server.toml -> qpq-server.toml,
*.db -> qpq.db
- Environment variable prefix: QUICNPROTOCHAT_* -> QPQ_*
- App identifier: chat.quicnproto.gui -> chat.quicproquo.gui
- Proto package: quicnprotochat.bench -> quicproquo.bench
- All documentation, Docker, CI, and script references updated
HKDF domain-separation strings and P2P ALPN remain unchanged for
backward compatibility with existing encrypted state and wire protocol.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
261
crates/quicproquo-server/src/auth.rs
Normal file
261
crates/quicproquo-server/src/auth.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use quicproquo_proto::node_capnp::auth;
|
||||
use sha2::Digest;
|
||||
use subtle::ConstantTimeEq;
|
||||
use tokio::sync::Notify;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use crate::error_codes::*;
|
||||
|
||||
pub const SESSION_TTL_SECS: u64 = 24 * 60 * 60; // 24 hours
|
||||
pub const PENDING_LOGIN_TTL_SECS: u64 = 300; // 5 minutes
|
||||
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||
pub const RATE_LIMIT_MAX_ENQUEUES: u32 = 100;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthConfig {
|
||||
/// Server bearer token — zeroized on drop to prevent memory disclosure.
|
||||
pub required_token: Option<Zeroizing<Vec<u8>>>,
|
||||
/// When true, a valid bearer token (no session) is accepted and the request's identity/key is used (dev/e2e only).
|
||||
/// CLI flag: --allow-insecure-auth / QPQ_ALLOW_INSECURE_AUTH.
|
||||
pub allow_insecure_identity_from_request: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AuthConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AuthConfig")
|
||||
.field("required_token", &self.required_token.as_ref().map(|_| "[REDACTED]"))
|
||||
.field("allow_insecure_identity_from_request", &self.allow_insecure_identity_from_request)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthConfig {
|
||||
pub fn new(required_token: Option<String>, allow_insecure_identity_from_request: bool) -> Self {
|
||||
let required_token = required_token
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| Zeroizing::new(s.into_bytes()));
|
||||
Self {
|
||||
required_token,
|
||||
allow_insecure_identity_from_request,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionInfo {
|
||||
#[allow(dead_code)]
|
||||
pub username: String,
|
||||
pub identity_key: Vec<u8>,
|
||||
#[allow(dead_code)]
|
||||
pub created_at: u64,
|
||||
pub expires_at: u64,
|
||||
}
|
||||
|
||||
pub struct PendingLogin {
|
||||
pub state_bytes: Vec<u8>,
|
||||
pub created_at: u64,
|
||||
}
|
||||
|
||||
pub struct RateEntry {
|
||||
pub count: u32,
|
||||
pub window_start: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthContext {
|
||||
pub token: Vec<u8>,
|
||||
pub identity_key: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
pub fn current_timestamp() -> u64 {
|
||||
match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
|
||||
Ok(d) => d.as_secs(),
|
||||
Err(_) => {
|
||||
tracing::warn!("system time is before UNIX_EPOCH; using 0 for session/rate-limit timestamps");
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_rate_limit(
|
||||
rate_limits: &DashMap<Vec<u8>, RateEntry>,
|
||||
token: &[u8],
|
||||
) -> Result<(), capnp::Error> {
|
||||
let now = current_timestamp();
|
||||
let mut entry = rate_limits.entry(token.to_vec()).or_insert(RateEntry {
|
||||
count: 0,
|
||||
window_start: now,
|
||||
});
|
||||
|
||||
if now - entry.window_start >= RATE_LIMIT_WINDOW_SECS {
|
||||
entry.count = 1;
|
||||
entry.window_start = now;
|
||||
} else {
|
||||
entry.count += 1;
|
||||
if entry.count > RATE_LIMIT_MAX_ENQUEUES {
|
||||
return Err(crate::error_codes::coded_error(
|
||||
E014_RATE_LIMITED,
|
||||
format!(
|
||||
"rate limit exceeded: {} enqueues in {}s window",
|
||||
RATE_LIMIT_MAX_ENQUEUES, RATE_LIMIT_WINDOW_SECS
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_auth(
|
||||
cfg: &AuthConfig,
|
||||
sessions: &DashMap<Vec<u8>, SessionInfo>,
|
||||
auth: Result<auth::Reader<'_>, capnp::Error>,
|
||||
) -> Result<(), capnp::Error> {
|
||||
validate_auth_context(cfg, sessions, auth).map(|_| ())
|
||||
}
|
||||
|
||||
pub fn validate_auth_context(
|
||||
cfg: &AuthConfig,
|
||||
sessions: &DashMap<Vec<u8>, SessionInfo>,
|
||||
auth: Result<auth::Reader<'_>, capnp::Error>,
|
||||
) -> Result<AuthContext, capnp::Error> {
|
||||
let auth = auth?;
|
||||
let version = auth.get_version();
|
||||
|
||||
if version != 1 {
|
||||
return Err(crate::error_codes::coded_error(
|
||||
E001_BAD_AUTH_VERSION,
|
||||
format!("unsupported auth version {} (expected 1)", version),
|
||||
));
|
||||
}
|
||||
|
||||
let token = auth
|
||||
.get_access_token()
|
||||
.map_err(|e| crate::error_codes::coded_error(E020_BAD_PARAMS, format!("auth.accessToken: {e}")))?
|
||||
.to_vec();
|
||||
|
||||
if token.is_empty() {
|
||||
return Err(crate::error_codes::coded_error(
|
||||
E002_EMPTY_TOKEN,
|
||||
"auth.version=1 requires non-empty accessToken",
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(expected) = &cfg.required_token {
|
||||
if expected.len() == token.len() && bool::from(expected.as_slice().ct_eq(&token)) {
|
||||
return Ok(AuthContext {
|
||||
token,
|
||||
identity_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(session) = sessions.get(&token) {
|
||||
let now = current_timestamp();
|
||||
if session.expires_at > now {
|
||||
let identity = if session.identity_key.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(session.identity_key.clone())
|
||||
};
|
||||
|
||||
return Ok(AuthContext {
|
||||
token,
|
||||
identity_key: identity,
|
||||
});
|
||||
}
|
||||
drop(session);
|
||||
sessions.remove(&token);
|
||||
return Err(crate::error_codes::coded_error(
|
||||
E017_SESSION_EXPIRED,
|
||||
"session token has expired",
|
||||
));
|
||||
}
|
||||
|
||||
Err(crate::error_codes::coded_error(E003_INVALID_TOKEN, "invalid accessToken"))
|
||||
}
|
||||
|
||||
pub fn require_identity<'a>(auth_ctx: &'a AuthContext) -> Result<&'a [u8], capnp::Error> {
|
||||
match auth_ctx.identity_key.as_deref() {
|
||||
Some(ik) => Ok(ik),
|
||||
None => Err(crate::error_codes::coded_error(
|
||||
E003_INVALID_TOKEN,
|
||||
"access token is not identity-bound; login required",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn require_identity_match(auth_ctx: &AuthContext, expected: &[u8]) -> Result<(), capnp::Error> {
|
||||
let ik = require_identity(auth_ctx)?;
|
||||
if ik.len() != expected.len() || !bool::from(ik.ct_eq(expected)) {
|
||||
return Err(crate::error_codes::coded_error(
|
||||
E016_IDENTITY_MISMATCH,
|
||||
"access token is bound to a different identity",
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// When the token is a valid session, require it to match `request_identity`.
|
||||
/// When the token is a bearer token (no identity) and `allow_insecure_identity_from_request` is true, accept the request identity (dev/e2e).
|
||||
pub fn require_identity_or_request(
|
||||
auth_ctx: &AuthContext,
|
||||
request_identity: &[u8],
|
||||
allow_insecure: bool,
|
||||
) -> Result<(), capnp::Error> {
|
||||
match auth_ctx.identity_key.as_deref() {
|
||||
Some(_) => require_identity_match(auth_ctx, request_identity),
|
||||
None if allow_insecure => Ok(()),
|
||||
None => Err(crate::error_codes::coded_error(
|
||||
E003_INVALID_TOKEN,
|
||||
"access token is not identity-bound; login required",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fmt_hex(bytes: &[u8]) -> String {
|
||||
let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
|
||||
format!("{hex}…")
|
||||
}
|
||||
|
||||
pub fn waiter(waiters: &DashMap<Vec<u8>, Arc<Notify>>, recipient_key: &[u8]) -> Arc<Notify> {
|
||||
waiters
|
||||
.entry(recipient_key.to_vec())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub const CONN_RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||
pub const CONN_RATE_LIMIT_MAX: u32 = 50;
|
||||
|
||||
/// Per-IP connection rate limiter. Returns `true` if the connection is allowed.
|
||||
pub fn check_conn_rate_limit(
|
||||
conn_rate_limits: &DashMap<IpAddr, RateEntry>,
|
||||
ip: IpAddr,
|
||||
) -> bool {
|
||||
let now = current_timestamp();
|
||||
let mut entry = conn_rate_limits.entry(ip).or_insert(RateEntry {
|
||||
count: 0,
|
||||
window_start: now,
|
||||
});
|
||||
|
||||
if now - entry.window_start >= CONN_RATE_LIMIT_WINDOW_SECS {
|
||||
entry.count = 1;
|
||||
entry.window_start = now;
|
||||
true
|
||||
} else {
|
||||
entry.count += 1;
|
||||
entry.count <= CONN_RATE_LIMIT_MAX
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fingerprint(data: &[u8]) -> Vec<u8> {
|
||||
sha2::Sha256::digest(data).to_vec()
|
||||
}
|
||||
|
||||
pub fn coded_error(code: &str, msg: impl std::fmt::Display) -> capnp::Error {
|
||||
crate::error_codes::coded_error(code, msg)
|
||||
}
|
||||
264
crates/quicproquo-server/src/config.rs
Normal file
264
crates/quicproquo-server/src/config.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DEFAULT_LISTEN: &str = "0.0.0.0:7000";
|
||||
pub const DEFAULT_DATA_DIR: &str = "data";
|
||||
pub const DEFAULT_TLS_CERT: &str = "data/server-cert.der";
|
||||
pub const DEFAULT_TLS_KEY: &str = "data/server-key.der";
|
||||
pub const DEFAULT_STORE_BACKEND: &str = "file";
|
||||
pub const DEFAULT_DB_PATH: &str = "data/qpq.db";
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
pub struct FileConfig {
|
||||
pub listen: Option<String>,
|
||||
pub data_dir: Option<String>,
|
||||
pub tls_cert: Option<PathBuf>,
|
||||
pub tls_key: Option<PathBuf>,
|
||||
pub auth_token: Option<String>,
|
||||
pub allow_insecure_auth: Option<bool>,
|
||||
/// When true, enqueue does not require an identity-bound session: only a valid token is required.
|
||||
/// The server does not associate the request with a specific sender (Sealed Sender).
|
||||
#[serde(default)]
|
||||
pub sealed_sender: Option<bool>,
|
||||
pub store_backend: Option<String>,
|
||||
pub db_path: Option<PathBuf>,
|
||||
pub db_key: Option<String>,
|
||||
/// Metrics HTTP listen address (e.g. "0.0.0.0:9090"). If set, /metrics is served there.
|
||||
pub metrics_listen: Option<String>,
|
||||
/// When true and metrics_listen is set, start the metrics server.
|
||||
#[serde(default)]
|
||||
pub metrics_enabled: Option<bool>,
|
||||
pub federation: Option<FederationFileConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EffectiveConfig {
|
||||
pub listen: String,
|
||||
pub data_dir: String,
|
||||
pub tls_cert: PathBuf,
|
||||
pub tls_key: PathBuf,
|
||||
pub auth_token: Option<String>,
|
||||
pub allow_insecure_auth: bool,
|
||||
/// When true, enqueue does not require identity; valid token only (Sealed Sender).
|
||||
pub sealed_sender: bool,
|
||||
pub store_backend: String,
|
||||
pub db_path: PathBuf,
|
||||
pub db_key: String,
|
||||
/// If Some(addr), metrics server listens here (e.g. "0.0.0.0:9090").
|
||||
pub metrics_listen: Option<String>,
|
||||
/// Start metrics server only when true and metrics_listen is set.
|
||||
pub metrics_enabled: bool,
|
||||
pub federation: Option<EffectiveFederationConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
pub struct FederationFileConfig {
|
||||
pub enabled: Option<bool>,
|
||||
pub domain: Option<String>,
|
||||
pub listen: Option<String>,
|
||||
pub federation_cert: Option<PathBuf>,
|
||||
pub federation_key: Option<PathBuf>,
|
||||
pub federation_ca: Option<PathBuf>,
|
||||
#[serde(default)]
|
||||
pub peers: Vec<FederationPeerConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FederationPeerConfig {
|
||||
pub domain: String,
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EffectiveFederationConfig {
|
||||
pub enabled: bool,
|
||||
pub domain: String,
|
||||
pub listen: String,
|
||||
pub federation_cert: PathBuf,
|
||||
pub federation_key: PathBuf,
|
||||
pub federation_ca: PathBuf,
|
||||
pub peers: Vec<FederationPeerConfig>,
|
||||
}
|
||||
|
||||
pub fn load_config(path: Option<&Path>) -> anyhow::Result<FileConfig> {
|
||||
let path = match path {
|
||||
Some(p) => PathBuf::from(p),
|
||||
None => PathBuf::from("qpq-server.toml"),
|
||||
};
|
||||
|
||||
if !path.exists() {
|
||||
return Ok(FileConfig::default());
|
||||
}
|
||||
|
||||
let contents =
|
||||
std::fs::read_to_string(&path).with_context(|| format!("read config file {path:?}"))?;
|
||||
let cfg: FileConfig =
|
||||
toml::from_str(&contents).with_context(|| format!("parse config file {path:?}"))?;
|
||||
Ok(cfg)
|
||||
}
|
||||
|
||||
pub fn merge_config(args: &crate::Args, file: &FileConfig) -> EffectiveConfig {
|
||||
let listen = if args.listen == DEFAULT_LISTEN {
|
||||
file.listen
|
||||
.clone()
|
||||
.unwrap_or_else(|| DEFAULT_LISTEN.to_string())
|
||||
} else {
|
||||
args.listen.clone()
|
||||
};
|
||||
|
||||
let data_dir = if args.data_dir == DEFAULT_DATA_DIR {
|
||||
file.data_dir
|
||||
.clone()
|
||||
.unwrap_or_else(|| DEFAULT_DATA_DIR.to_string())
|
||||
} else {
|
||||
args.data_dir.clone()
|
||||
};
|
||||
|
||||
let tls_cert = if args.tls_cert == PathBuf::from(DEFAULT_TLS_CERT) {
|
||||
file.tls_cert
|
||||
.clone()
|
||||
.unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_CERT))
|
||||
} else {
|
||||
args.tls_cert.clone()
|
||||
};
|
||||
|
||||
let tls_key = if args.tls_key == PathBuf::from(DEFAULT_TLS_KEY) {
|
||||
file.tls_key
|
||||
.clone()
|
||||
.unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_KEY))
|
||||
} else {
|
||||
args.tls_key.clone()
|
||||
};
|
||||
|
||||
let auth_token = if args.auth_token.is_some() {
|
||||
args.auth_token.clone()
|
||||
} else {
|
||||
file.auth_token.clone()
|
||||
};
|
||||
|
||||
let allow_insecure_auth = if args.allow_insecure_auth {
|
||||
true
|
||||
} else {
|
||||
file.allow_insecure_auth.unwrap_or(false)
|
||||
};
|
||||
|
||||
let sealed_sender = args.sealed_sender || file.sealed_sender.unwrap_or(false);
|
||||
|
||||
let store_backend = if args.store_backend == DEFAULT_STORE_BACKEND {
|
||||
file.store_backend
|
||||
.clone()
|
||||
.unwrap_or_else(|| DEFAULT_STORE_BACKEND.to_string())
|
||||
} else {
|
||||
args.store_backend.clone()
|
||||
};
|
||||
|
||||
let db_path = if args.db_path == PathBuf::from(DEFAULT_DB_PATH) {
|
||||
file.db_path
|
||||
.clone()
|
||||
.unwrap_or_else(|| PathBuf::from(DEFAULT_DB_PATH))
|
||||
} else {
|
||||
args.db_path.clone()
|
||||
};
|
||||
|
||||
let db_key = if args.db_key.is_empty() {
|
||||
file.db_key.clone().unwrap_or_else(|| args.db_key.clone())
|
||||
} else {
|
||||
args.db_key.clone()
|
||||
};
|
||||
|
||||
let metrics_listen = args
|
||||
.metrics_listen
|
||||
.clone()
|
||||
.or_else(|| file.metrics_listen.clone());
|
||||
let metrics_enabled = args
|
||||
.metrics_enabled
|
||||
.or(file.metrics_enabled)
|
||||
.unwrap_or(metrics_listen.is_some());
|
||||
|
||||
let federation = {
|
||||
let file_fed = file.federation.as_ref();
|
||||
let enabled = args.federation_enabled
|
||||
|| file_fed.and_then(|f| f.enabled).unwrap_or(false);
|
||||
|
||||
if enabled {
|
||||
let domain = args.federation_domain.clone()
|
||||
.or_else(|| file_fed.and_then(|f| f.domain.clone()))
|
||||
.unwrap_or_default();
|
||||
let listen_fed = args.federation_listen.clone()
|
||||
.or_else(|| file_fed.and_then(|f| f.listen.clone()))
|
||||
.unwrap_or_else(|| "0.0.0.0:7001".to_string());
|
||||
let federation_cert = file_fed.and_then(|f| f.federation_cert.clone())
|
||||
.unwrap_or_else(|| PathBuf::from("data/federation-cert.der"));
|
||||
let federation_key = file_fed.and_then(|f| f.federation_key.clone())
|
||||
.unwrap_or_else(|| PathBuf::from("data/federation-key.der"));
|
||||
let federation_ca = file_fed.and_then(|f| f.federation_ca.clone())
|
||||
.unwrap_or_else(|| PathBuf::from("data/federation-ca.der"));
|
||||
let peers = file_fed
|
||||
.map(|f| f.peers.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
Some(EffectiveFederationConfig {
|
||||
enabled,
|
||||
domain,
|
||||
listen: listen_fed,
|
||||
federation_cert,
|
||||
federation_key,
|
||||
federation_ca,
|
||||
peers,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
EffectiveConfig {
|
||||
listen,
|
||||
data_dir,
|
||||
tls_cert,
|
||||
tls_key,
|
||||
auth_token,
|
||||
allow_insecure_auth,
|
||||
sealed_sender,
|
||||
store_backend,
|
||||
db_path,
|
||||
db_key,
|
||||
metrics_listen,
|
||||
metrics_enabled,
|
||||
federation,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_production_config(effective: &EffectiveConfig) -> anyhow::Result<()> {
|
||||
if effective.allow_insecure_auth {
|
||||
anyhow::bail!("production forbids --allow-insecure-auth");
|
||||
}
|
||||
let token = effective
|
||||
.auth_token
|
||||
.as_deref()
|
||||
.filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("production requires QPQ_AUTH_TOKEN (non-empty)")
|
||||
})?;
|
||||
if token == "devtoken" {
|
||||
anyhow::bail!(
|
||||
"production forbids auth_token 'devtoken'; set a strong QPQ_AUTH_TOKEN"
|
||||
);
|
||||
}
|
||||
if effective.store_backend == "sql" && effective.db_key.is_empty() {
|
||||
anyhow::bail!("production with store_backend=sql requires non-empty QPQ_DB_KEY");
|
||||
}
|
||||
if effective.store_backend != "sql" {
|
||||
tracing::warn!(
|
||||
"production is using file-backed storage; \
|
||||
consider store_backend=sql with QPQ_DB_KEY for encryption at rest"
|
||||
);
|
||||
}
|
||||
if !effective.tls_cert.exists() || !effective.tls_key.exists() {
|
||||
anyhow::bail!(
|
||||
"production requires existing TLS cert and key (no auto-generation); provide QPQ_TLS_CERT and QPQ_TLS_KEY"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
33
crates/quicproquo-server/src/error_codes.rs
Normal file
33
crates/quicproquo-server/src/error_codes.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! Structured error codes for server RPC responses.
|
||||
//!
|
||||
//! Every `capnp::Error::failed()` message is prefixed with a stable code
|
||||
//! (E001–E020) so clients can match on the code without parsing free-text.
|
||||
|
||||
pub const E001_BAD_AUTH_VERSION: &str = "E001";
|
||||
pub const E002_EMPTY_TOKEN: &str = "E002";
|
||||
pub const E003_INVALID_TOKEN: &str = "E003";
|
||||
pub const E004_IDENTITY_KEY_LENGTH: &str = "E004";
|
||||
pub const E005_PAYLOAD_EMPTY: &str = "E005";
|
||||
pub const E006_PAYLOAD_TOO_LARGE: &str = "E006";
|
||||
pub const E007_PACKAGE_EMPTY: &str = "E007";
|
||||
pub const E008_PACKAGE_TOO_LARGE: &str = "E008";
|
||||
pub const E009_STORAGE_ERROR: &str = "E009";
|
||||
pub const E010_OPAQUE_ERROR: &str = "E010";
|
||||
pub const E011_USERNAME_EMPTY: &str = "E011";
|
||||
pub const E012_WIRE_VERSION: &str = "E012";
|
||||
pub const E013_HYBRID_KEY_EMPTY: &str = "E013";
|
||||
pub const E014_RATE_LIMITED: &str = "E014";
|
||||
pub const E015_QUEUE_FULL: &str = "E015";
|
||||
pub const E016_IDENTITY_MISMATCH: &str = "E016";
|
||||
pub const E017_SESSION_EXPIRED: &str = "E017";
|
||||
pub const E018_USER_EXISTS: &str = "E018";
|
||||
pub const E019_NO_PENDING_LOGIN: &str = "E019";
|
||||
pub const E020_BAD_PARAMS: &str = "E020";
|
||||
pub const E021_CIPHERSUITE_NOT_ALLOWED: &str = "E021";
|
||||
pub const E022_CHANNEL_ACCESS_DENIED: &str = "E022";
|
||||
pub const E023_CHANNEL_NOT_FOUND: &str = "E023";
|
||||
|
||||
/// Build a `capnp::Error::failed()` with the structured code prefix.
|
||||
pub fn coded_error(code: &str, msg: impl std::fmt::Display) -> capnp::Error {
|
||||
capnp::Error::failed(format!("{code}: {msg}"))
|
||||
}
|
||||
78
crates/quicproquo-server/src/federation/address.rs
Normal file
78
crates/quicproquo-server/src/federation/address.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
//! Parse `username@domain` federated addresses.
|
||||
//!
|
||||
//! A bare `username` (no `@`) is treated as local.
|
||||
|
||||
/// A parsed federated address.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct FederatedAddress {
|
||||
pub username: String,
|
||||
pub domain: Option<String>,
|
||||
}
|
||||
|
||||
impl FederatedAddress {
|
||||
/// Parse a `user@domain` string. Bare `user` → domain is `None`.
|
||||
pub fn parse(input: &str) -> Self {
|
||||
// Split on the *last* '@' so usernames can contain '@' in theory.
|
||||
match input.rsplit_once('@') {
|
||||
Some((user, domain)) if !domain.is_empty() && !user.is_empty() => Self {
|
||||
username: user.to_string(),
|
||||
domain: Some(domain.to_string()),
|
||||
},
|
||||
_ => Self {
|
||||
username: input.to_string(),
|
||||
domain: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this address refers to a local user (no domain or domain matches local).
|
||||
pub fn is_local(&self, local_domain: &str) -> bool {
|
||||
match &self.domain {
|
||||
None => true,
|
||||
Some(d) => d == local_domain,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn bare_username() {
|
||||
let addr = FederatedAddress::parse("alice");
|
||||
assert_eq!(addr.username, "alice");
|
||||
assert_eq!(addr.domain, None);
|
||||
assert!(addr.is_local("example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_at_domain() {
|
||||
let addr = FederatedAddress::parse("alice@remote.example.com");
|
||||
assert_eq!(addr.username, "alice");
|
||||
assert_eq!(addr.domain, Some("remote.example.com".into()));
|
||||
assert!(!addr.is_local("local.example.com"));
|
||||
assert!(addr.is_local("remote.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trailing_at_is_bare() {
|
||||
let addr = FederatedAddress::parse("alice@");
|
||||
assert_eq!(addr.username, "alice@");
|
||||
assert_eq!(addr.domain, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn leading_at_is_bare() {
|
||||
let addr = FederatedAddress::parse("@domain.com");
|
||||
assert_eq!(addr.username, "@domain.com");
|
||||
assert_eq!(addr.domain, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_at_uses_last() {
|
||||
let addr = FederatedAddress::parse("user@org@domain.com");
|
||||
assert_eq!(addr.username, "user@org");
|
||||
assert_eq!(addr.domain, Some("domain.com".into()));
|
||||
}
|
||||
}
|
||||
287
crates/quicproquo-server/src/federation/client.rs
Normal file
287
crates/quicproquo-server/src/federation/client.rs
Normal file
@@ -0,0 +1,287 @@
|
||||
//! Outbound federation client: connects to peer servers to relay messages.
|
||||
//!
|
||||
//! Uses a lazy connection pool (DashMap) to reuse QUIC connections to known peers.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use dashmap::DashMap;
|
||||
use quinn::Endpoint;
|
||||
|
||||
use crate::config::{EffectiveFederationConfig, FederationPeerConfig};
|
||||
|
||||
/// Outbound federation client for relaying to peer servers.
|
||||
pub struct FederationClient {
|
||||
/// Peer domain → address mapping from config.
|
||||
peer_addresses: HashMap<String, SocketAddr>,
|
||||
/// Lazy QUIC connection pool: domain → active Connection.
|
||||
connections: DashMap<String, quinn::Connection>,
|
||||
/// Local QUIC endpoint (shared for all outbound federation connections).
|
||||
endpoint: Endpoint,
|
||||
/// Local domain (for the FederationAuth.origin field).
|
||||
local_domain: String,
|
||||
}
|
||||
|
||||
impl FederationClient {
|
||||
/// Create a new federation client from config.
|
||||
///
|
||||
/// The `endpoint` should be configured with mTLS client credentials.
|
||||
pub fn new(
|
||||
config: &EffectiveFederationConfig,
|
||||
endpoint: Endpoint,
|
||||
) -> anyhow::Result<Self> {
|
||||
let mut peer_addresses = HashMap::new();
|
||||
for peer in &config.peers {
|
||||
let addr: SocketAddr = peer.address.parse().with_context(|| {
|
||||
format!("parse federation peer address '{}' for '{}'", peer.address, peer.domain)
|
||||
})?;
|
||||
peer_addresses.insert(peer.domain.clone(), addr);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
peer_addresses,
|
||||
connections: DashMap::new(),
|
||||
endpoint,
|
||||
local_domain: config.domain.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if we have a configured peer for the given domain.
|
||||
pub fn has_peer(&self, domain: &str) -> bool {
|
||||
self.peer_addresses.contains_key(domain)
|
||||
}
|
||||
|
||||
/// List all configured peer domains.
|
||||
pub fn peer_domains(&self) -> Vec<String> {
|
||||
self.peer_addresses.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Get the local domain.
|
||||
pub fn local_domain(&self) -> &str {
|
||||
&self.local_domain
|
||||
}
|
||||
|
||||
/// Relay a single enqueue to a remote peer. Returns the seq assigned by the remote server.
|
||||
pub async fn relay_enqueue(
|
||||
&self,
|
||||
domain: &str,
|
||||
recipient_key: &[u8],
|
||||
payload: &[u8],
|
||||
channel_id: &[u8],
|
||||
) -> anyhow::Result<u64> {
|
||||
let conn = self.get_or_connect(domain).await?;
|
||||
let (send, recv) = conn.open_bi().await.context("open bi stream to peer")?;
|
||||
|
||||
let (reader, writer) = (
|
||||
tokio_util::compat::TokioAsyncReadCompatExt::compat(recv),
|
||||
tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send),
|
||||
);
|
||||
|
||||
let rpc_network = capnp_rpc::twoparty::VatNetwork::new(
|
||||
reader,
|
||||
writer,
|
||||
capnp_rpc::rpc_twoparty_capnp::Side::Client,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let mut rpc_system = capnp_rpc::RpcSystem::new(Box::new(rpc_network), None);
|
||||
let client: quicproquo_proto::federation_capnp::federation_service::Client =
|
||||
rpc_system.bootstrap(capnp_rpc::rpc_twoparty_capnp::Side::Server);
|
||||
|
||||
tokio::task::spawn_local(rpc_system);
|
||||
|
||||
let mut req = client.relay_enqueue_request();
|
||||
{
|
||||
let mut builder = req.get();
|
||||
builder.set_recipient_key(recipient_key);
|
||||
builder.set_payload(payload);
|
||||
builder.set_channel_id(channel_id);
|
||||
builder.set_version(1);
|
||||
let mut auth = builder.init_auth();
|
||||
auth.set_origin(&self.local_domain);
|
||||
}
|
||||
|
||||
let response = req.send().promise.await
|
||||
.map_err(|e| anyhow::anyhow!("federation relay_enqueue failed: {e}"))?;
|
||||
let seq = response.get()
|
||||
.map_err(|e| anyhow::anyhow!("read relay_enqueue response: {e}"))?
|
||||
.get_seq();
|
||||
|
||||
Ok(seq)
|
||||
}
|
||||
|
||||
/// Proxy a key package fetch to a remote peer.
|
||||
pub async fn proxy_fetch_key_package(
|
||||
&self,
|
||||
domain: &str,
|
||||
identity_key: &[u8],
|
||||
) -> anyhow::Result<Option<Vec<u8>>> {
|
||||
let conn = self.get_or_connect(domain).await?;
|
||||
let (send, recv) = conn.open_bi().await.context("open bi stream to peer")?;
|
||||
|
||||
let (reader, writer) = (
|
||||
tokio_util::compat::TokioAsyncReadCompatExt::compat(recv),
|
||||
tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send),
|
||||
);
|
||||
|
||||
let rpc_network = capnp_rpc::twoparty::VatNetwork::new(
|
||||
reader,
|
||||
writer,
|
||||
capnp_rpc::rpc_twoparty_capnp::Side::Client,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let mut rpc_system = capnp_rpc::RpcSystem::new(Box::new(rpc_network), None);
|
||||
let client: quicproquo_proto::federation_capnp::federation_service::Client =
|
||||
rpc_system.bootstrap(capnp_rpc::rpc_twoparty_capnp::Side::Server);
|
||||
|
||||
tokio::task::spawn_local(rpc_system);
|
||||
|
||||
let mut req = client.proxy_fetch_key_package_request();
|
||||
{
|
||||
let mut builder = req.get();
|
||||
builder.set_identity_key(identity_key);
|
||||
let mut auth = builder.init_auth();
|
||||
auth.set_origin(&self.local_domain);
|
||||
}
|
||||
|
||||
let response = req.send().promise.await
|
||||
.map_err(|e| anyhow::anyhow!("federation proxy_fetch_key_package failed: {e}"))?;
|
||||
let pkg = response.get()
|
||||
.map_err(|e| anyhow::anyhow!("read proxy_fetch_key_package response: {e}"))?
|
||||
.get_package()
|
||||
.map_err(|e| anyhow::anyhow!("get package: {e}"))?;
|
||||
|
||||
if pkg.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(pkg.to_vec()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy a hybrid key fetch to a remote peer.
|
||||
pub async fn proxy_fetch_hybrid_key(
|
||||
&self,
|
||||
domain: &str,
|
||||
identity_key: &[u8],
|
||||
) -> anyhow::Result<Option<Vec<u8>>> {
|
||||
let conn = self.get_or_connect(domain).await?;
|
||||
let (send, recv) = conn.open_bi().await.context("open bi stream to peer")?;
|
||||
|
||||
let (reader, writer) = (
|
||||
tokio_util::compat::TokioAsyncReadCompatExt::compat(recv),
|
||||
tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send),
|
||||
);
|
||||
|
||||
let rpc_network = capnp_rpc::twoparty::VatNetwork::new(
|
||||
reader,
|
||||
writer,
|
||||
capnp_rpc::rpc_twoparty_capnp::Side::Client,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let mut rpc_system = capnp_rpc::RpcSystem::new(Box::new(rpc_network), None);
|
||||
let client: quicproquo_proto::federation_capnp::federation_service::Client =
|
||||
rpc_system.bootstrap(capnp_rpc::rpc_twoparty_capnp::Side::Server);
|
||||
|
||||
tokio::task::spawn_local(rpc_system);
|
||||
|
||||
let mut req = client.proxy_fetch_hybrid_key_request();
|
||||
{
|
||||
let mut builder = req.get();
|
||||
builder.set_identity_key(identity_key);
|
||||
let mut auth = builder.init_auth();
|
||||
auth.set_origin(&self.local_domain);
|
||||
}
|
||||
|
||||
let response = req.send().promise.await
|
||||
.map_err(|e| anyhow::anyhow!("federation proxy_fetch_hybrid_key failed: {e}"))?;
|
||||
let pk = response.get()
|
||||
.map_err(|e| anyhow::anyhow!("read proxy_fetch_hybrid_key response: {e}"))?
|
||||
.get_hybrid_public_key()
|
||||
.map_err(|e| anyhow::anyhow!("get hybrid_public_key: {e}"))?;
|
||||
|
||||
if pk.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(pk.to_vec()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy a user resolution to a remote peer.
|
||||
pub async fn proxy_resolve_user(
|
||||
&self,
|
||||
domain: &str,
|
||||
username: &str,
|
||||
) -> anyhow::Result<Option<Vec<u8>>> {
|
||||
let conn = self.get_or_connect(domain).await?;
|
||||
let (send, recv) = conn.open_bi().await.context("open bi stream to peer")?;
|
||||
|
||||
let (reader, writer) = (
|
||||
tokio_util::compat::TokioAsyncReadCompatExt::compat(recv),
|
||||
tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send),
|
||||
);
|
||||
|
||||
let rpc_network = capnp_rpc::twoparty::VatNetwork::new(
|
||||
reader,
|
||||
writer,
|
||||
capnp_rpc::rpc_twoparty_capnp::Side::Client,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let mut rpc_system = capnp_rpc::RpcSystem::new(Box::new(rpc_network), None);
|
||||
let client: quicproquo_proto::federation_capnp::federation_service::Client =
|
||||
rpc_system.bootstrap(capnp_rpc::rpc_twoparty_capnp::Side::Server);
|
||||
|
||||
tokio::task::spawn_local(rpc_system);
|
||||
|
||||
let mut req = client.proxy_resolve_user_request();
|
||||
{
|
||||
let mut builder = req.get();
|
||||
builder.set_username(username);
|
||||
let mut auth = builder.init_auth();
|
||||
auth.set_origin(&self.local_domain);
|
||||
}
|
||||
|
||||
let response = req.send().promise.await
|
||||
.map_err(|e| anyhow::anyhow!("federation proxy_resolve_user failed: {e}"))?;
|
||||
let key = response.get()
|
||||
.map_err(|e| anyhow::anyhow!("read proxy_resolve_user response: {e}"))?
|
||||
.get_identity_key()
|
||||
.map_err(|e| anyhow::anyhow!("get identity_key: {e}"))?;
|
||||
|
||||
if key.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(key.to_vec()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an existing connection or create a new one to a peer domain.
|
||||
async fn get_or_connect(&self, domain: &str) -> anyhow::Result<quinn::Connection> {
|
||||
// Check for cached connection that's still alive.
|
||||
if let Some(conn) = self.connections.get(domain) {
|
||||
if conn.close_reason().is_none() {
|
||||
return Ok(conn.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let addr = self.peer_addresses.get(domain).ok_or_else(|| {
|
||||
anyhow::anyhow!("no federation peer configured for domain '{domain}'")
|
||||
})?;
|
||||
|
||||
tracing::info!(domain = domain, addr = %addr, "connecting to federation peer");
|
||||
|
||||
let conn = self
|
||||
.endpoint
|
||||
.connect(*addr, domain)
|
||||
.map_err(|e| anyhow::anyhow!("federation connect to {domain}: {e}"))?
|
||||
.await
|
||||
.with_context(|| format!("federation QUIC handshake with {domain}"))?;
|
||||
|
||||
self.connections.insert(domain.to_string(), conn.clone());
|
||||
Ok(conn)
|
||||
}
|
||||
}
|
||||
16
crates/quicproquo-server/src/federation/mod.rs
Normal file
16
crates/quicproquo-server/src/federation/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
//! Federation subsystem: server-to-server message relay over mutual TLS + QUIC.
|
||||
//!
|
||||
//! When federation is enabled, the server binds a second QUIC endpoint on a
|
||||
//! dedicated port (default 7001) that only accepts connections from known peers
|
||||
//! authenticated via mTLS. Inbound requests are handled by [`service::FederationServiceImpl`],
|
||||
//! which delegates to the local [`Store`]. Outbound relay uses [`client::FederationClient`].
|
||||
|
||||
pub mod address;
|
||||
pub mod client;
|
||||
pub mod routing;
|
||||
pub mod service;
|
||||
pub mod tls;
|
||||
|
||||
pub use address::FederatedAddress;
|
||||
pub use client::FederationClient;
|
||||
pub use routing::Destination;
|
||||
44
crates/quicproquo-server/src/federation/routing.rs
Normal file
44
crates/quicproquo-server/src/federation/routing.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
//! Federation routing: determine whether a recipient is local or remote.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::storage::Store;
|
||||
|
||||
/// Where a message should be delivered.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Destination {
|
||||
/// Recipient is on this server.
|
||||
Local,
|
||||
/// Recipient's home server is the given domain.
|
||||
Remote(String),
|
||||
}
|
||||
|
||||
/// Resolve a recipient identity key to a routing destination.
|
||||
///
|
||||
/// 1. Check the `identity_home_servers` table for an explicit mapping.
|
||||
/// 2. If no mapping exists, assume local (backwards compatible with single-server deployments).
|
||||
pub fn resolve_destination(
|
||||
store: &Arc<dyn Store>,
|
||||
recipient_key: &[u8],
|
||||
local_domain: &str,
|
||||
) -> Destination {
|
||||
match store.get_identity_home_server(recipient_key) {
|
||||
Ok(Some(domain)) if domain != local_domain => Destination::Remote(domain),
|
||||
_ => Destination::Local,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn unknown_identity_routes_local() {
|
||||
let store: Arc<dyn Store> =
|
||||
Arc::new(crate::storage::FileBackedStore::open(
|
||||
tempfile::tempdir().unwrap().path(),
|
||||
).unwrap());
|
||||
let dest = resolve_destination(&store, &[1u8; 32], "local.example.com");
|
||||
assert_eq!(dest, Destination::Local);
|
||||
}
|
||||
}
|
||||
201
crates/quicproquo-server/src/federation/service.rs
Normal file
201
crates/quicproquo-server/src/federation/service.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
//! Inbound federation handler: implements `FederationService` Cap'n Proto interface.
|
||||
//!
|
||||
//! Delegates all operations to the local [`Store`], acting as a trusted relay
|
||||
//! from authenticated peer servers.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use capnp::capability::Promise;
|
||||
use quicproquo_proto::federation_capnp::federation_service;
|
||||
use tokio::sync::Notify;
|
||||
use dashmap::DashMap;
|
||||
|
||||
use crate::storage::Store;
|
||||
|
||||
/// Inbound federation RPC handler.
|
||||
pub struct FederationServiceImpl {
|
||||
pub store: Arc<dyn Store>,
|
||||
pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
pub local_domain: String,
|
||||
}
|
||||
|
||||
impl federation_service::Server for FederationServiceImpl {
|
||||
fn relay_enqueue(
|
||||
&mut self,
|
||||
params: federation_service::RelayEnqueueParams,
|
||||
mut results: federation_service::RelayEnqueueResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad recipient_key: {e}"))),
|
||||
};
|
||||
let payload = match p.get_payload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad payload: {e}"))),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
|
||||
if let Ok(a) = p.get_auth() {
|
||||
if let Ok(origin) = a.get_origin() {
|
||||
let origin = origin.to_str().unwrap_or("?");
|
||||
tracing::debug!(origin = origin, "federation relay_enqueue");
|
||||
}
|
||||
}
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(capnp::Error::failed("recipient_key must be 32 bytes".into()));
|
||||
}
|
||||
if payload.is_empty() {
|
||||
return Promise::err(capnp::Error::failed("payload must not be empty".into()));
|
||||
}
|
||||
|
||||
let seq = match self.store.enqueue(&recipient_key, &channel_id, payload) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
|
||||
};
|
||||
|
||||
results.get().set_seq(seq);
|
||||
|
||||
// Wake any waiting fetchWait clients.
|
||||
if let Some(waiter) = self.waiters.get(&recipient_key) {
|
||||
waiter.notify_waiters();
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
recipient_prefix = %hex::encode(&recipient_key[..4]),
|
||||
seq = seq,
|
||||
"federation: relayed enqueue"
|
||||
);
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
fn relay_batch_enqueue(
|
||||
&mut self,
|
||||
params: federation_service::RelayBatchEnqueueParams,
|
||||
mut results: federation_service::RelayBatchEnqueueResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
let recipient_keys = match p.get_recipient_keys() {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad recipient_keys: {e}"))),
|
||||
};
|
||||
let payload = match p.get_payload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad payload: {e}"))),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
|
||||
let mut seqs = Vec::with_capacity(recipient_keys.len() as usize);
|
||||
for i in 0..recipient_keys.len() {
|
||||
let rk = match recipient_keys.get(i) {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad key[{i}]: {e}"))),
|
||||
};
|
||||
if rk.len() != 32 {
|
||||
return Promise::err(capnp::Error::failed(
|
||||
format!("recipient_key[{i}] must be 32 bytes"),
|
||||
));
|
||||
}
|
||||
let seq = match self.store.enqueue(&rk, &channel_id, payload.clone()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
|
||||
};
|
||||
seqs.push(seq);
|
||||
if let Some(waiter) = self.waiters.get(&rk) {
|
||||
waiter.notify_waiters();
|
||||
}
|
||||
}
|
||||
|
||||
let mut list = results.get().init_seqs(seqs.len() as u32);
|
||||
for (i, seq) in seqs.iter().enumerate() {
|
||||
list.set(i as u32, *seq);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
recipient_count = recipient_keys.len(),
|
||||
"federation: relayed batch_enqueue"
|
||||
);
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
fn proxy_fetch_key_package(
|
||||
&mut self,
|
||||
params: federation_service::ProxyFetchKeyPackageParams,
|
||||
mut results: federation_service::ProxyFetchKeyPackageResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let identity_key = match params.get().and_then(|p| p.get_identity_key()) {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
match self.store.fetch_key_package(&identity_key) {
|
||||
Ok(Some(pkg)) => results.get().set_package(&pkg),
|
||||
Ok(None) => results.get().set_package(&[]),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
fn proxy_fetch_hybrid_key(
|
||||
&mut self,
|
||||
params: federation_service::ProxyFetchHybridKeyParams,
|
||||
mut results: federation_service::ProxyFetchHybridKeyResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let identity_key = match params.get().and_then(|p| p.get_identity_key()) {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
match self.store.fetch_hybrid_key(&identity_key) {
|
||||
Ok(Some(pk)) => results.get().set_hybrid_public_key(&pk),
|
||||
Ok(None) => results.get().set_hybrid_public_key(&[]),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
fn proxy_resolve_user(
|
||||
&mut self,
|
||||
params: federation_service::ProxyResolveUserParams,
|
||||
mut results: federation_service::ProxyResolveUserResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let username = match params.get().and_then(|p| p.get_username()) {
|
||||
Ok(u) => match u.to_str() {
|
||||
Ok(s) => s.to_string(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad utf-8: {e}"))),
|
||||
},
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
match self.store.get_user_identity_key(&username) {
|
||||
Ok(Some(key)) => results.get().set_identity_key(&key),
|
||||
Ok(None) => results.get().set_identity_key(&[]),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("store error: {e}"))),
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
fn federation_health(
|
||||
&mut self,
|
||||
_params: federation_service::FederationHealthParams,
|
||||
mut results: federation_service::FederationHealthResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
results.get().set_status("ok");
|
||||
results.get().set_server_domain(&self.local_domain);
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
85
crates/quicproquo-server/src/federation/tls.rs
Normal file
85
crates/quicproquo-server/src/federation/tls.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
//! Build mTLS server/client configs for the federation endpoint.
|
||||
//!
|
||||
//! Federation uses a separate CA from the public-facing QUIC endpoint.
|
||||
//! Both server and client present certificates; the server verifies the client
|
||||
//! cert is signed by the federation CA.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use quinn::ServerConfig;
|
||||
use quinn_proto::crypto::rustls::QuicServerConfig;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use rustls::version::TLS13;
|
||||
|
||||
/// Build a QUIC server config for the federation listener with mutual TLS.
|
||||
///
|
||||
/// `cert`/`key`: this server's federation certificate and private key.
|
||||
/// `ca`: the federation CA certificate used to verify peer certificates.
|
||||
pub fn build_federation_server_config(
|
||||
cert_path: &Path,
|
||||
key_path: &Path,
|
||||
ca_path: &Path,
|
||||
) -> anyhow::Result<ServerConfig> {
|
||||
let cert_bytes = std::fs::read(cert_path)
|
||||
.with_context(|| format!("read federation cert: {:?}", cert_path))?;
|
||||
let key_bytes = std::fs::read(key_path)
|
||||
.with_context(|| format!("read federation key: {:?}", key_path))?;
|
||||
let ca_bytes = std::fs::read(ca_path)
|
||||
.with_context(|| format!("read federation CA: {:?}", ca_path))?;
|
||||
|
||||
let cert_chain = vec![CertificateDer::from(cert_bytes)];
|
||||
let key = PrivateKeyDer::try_from(key_bytes)
|
||||
.map_err(|_| anyhow::anyhow!("invalid federation private key"))?;
|
||||
|
||||
// Build a root cert store with the federation CA for client verification.
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
root_store
|
||||
.add(CertificateDer::from(ca_bytes))
|
||||
.context("add federation CA to root store")?;
|
||||
|
||||
let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
|
||||
.build()
|
||||
.context("build client cert verifier")?;
|
||||
|
||||
let mut tls = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13])
|
||||
.with_client_cert_verifier(client_verifier)
|
||||
.with_single_cert(cert_chain, key)?;
|
||||
tls.alpn_protocols = vec![b"qnpc-fed".to_vec()];
|
||||
|
||||
let crypto = QuicServerConfig::try_from(tls)
|
||||
.map_err(|e| anyhow::anyhow!("invalid federation server TLS config: {e}"))?;
|
||||
|
||||
Ok(ServerConfig::with_crypto(Arc::new(crypto)))
|
||||
}
|
||||
|
||||
/// Build a QUIC client config for connecting to a federation peer with mutual TLS.
|
||||
pub fn build_federation_client_config(
|
||||
cert_path: &Path,
|
||||
key_path: &Path,
|
||||
ca_path: &Path,
|
||||
) -> anyhow::Result<rustls::ClientConfig> {
|
||||
let cert_bytes = std::fs::read(cert_path)
|
||||
.with_context(|| format!("read federation cert: {:?}", cert_path))?;
|
||||
let key_bytes = std::fs::read(key_path)
|
||||
.with_context(|| format!("read federation key: {:?}", key_path))?;
|
||||
let ca_bytes = std::fs::read(ca_path)
|
||||
.with_context(|| format!("read federation CA: {:?}", ca_path))?;
|
||||
|
||||
let cert_chain = vec![CertificateDer::from(cert_bytes)];
|
||||
let key = PrivateKeyDer::try_from(key_bytes)
|
||||
.map_err(|_| anyhow::anyhow!("invalid federation client private key"))?;
|
||||
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
root_store
|
||||
.add(CertificateDer::from(ca_bytes))
|
||||
.context("add federation CA to root store")?;
|
||||
|
||||
let tls = rustls::ClientConfig::builder_with_protocol_versions(&[&TLS13])
|
||||
.with_root_certificates(root_store)
|
||||
.with_client_auth_cert(cert_chain, key)
|
||||
.context("set client auth cert")?;
|
||||
|
||||
Ok(tls)
|
||||
}
|
||||
505
crates/quicproquo-server/src/main.rs
Normal file
505
crates/quicproquo-server/src/main.rs
Normal file
@@ -0,0 +1,505 @@
|
||||
//! qpq-server — unified Authentication + Delivery service.
|
||||
//!
|
||||
//! The server hosts Authentication + Delivery services over QUIC + Cap'n Proto.
|
||||
|
||||
use std::{net::IpAddr, net::SocketAddr, path::PathBuf, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use dashmap::DashMap;
|
||||
use opaque_ke::ServerSetup;
|
||||
use quicproquo_core::opaque_auth::OpaqueSuite;
|
||||
use quinn::Endpoint;
|
||||
use rand::rngs::OsRng;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::task::LocalSet;
|
||||
|
||||
mod auth;
|
||||
mod config;
|
||||
mod error_codes;
|
||||
mod federation;
|
||||
mod metrics;
|
||||
mod node_service;
|
||||
mod sql_store;
|
||||
mod tls;
|
||||
mod storage;
|
||||
|
||||
use auth::{AuthConfig, PendingLogin, RateEntry, SessionInfo};
|
||||
use config::{
|
||||
load_config, merge_config, validate_production_config, DEFAULT_DATA_DIR, DEFAULT_DB_PATH,
|
||||
DEFAULT_LISTEN, DEFAULT_STORE_BACKEND, DEFAULT_TLS_CERT, DEFAULT_TLS_KEY,
|
||||
};
|
||||
use node_service::{handle_node_connection, spawn_cleanup_task};
|
||||
use sql_store::SqlStore;
|
||||
use storage::{FileBackedStore, Store};
|
||||
use tls::build_server_config;
|
||||
|
||||
// ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(
|
||||
name = "qpq-server",
|
||||
about = "quicproquo Delivery Service + Authentication Service",
|
||||
version
|
||||
)]
|
||||
struct Args {
|
||||
/// Optional path to a TOML config file (fields map to CLI flags).
|
||||
#[arg(long, env = "QPQ_CONFIG")]
|
||||
config: Option<PathBuf>,
|
||||
|
||||
/// QUIC listen address (host:port).
|
||||
#[arg(long, default_value = DEFAULT_LISTEN, env = "QPQ_LISTEN")]
|
||||
listen: String,
|
||||
|
||||
/// Directory for persisted server data (KeyPackages + delivery queues).
|
||||
#[arg(long, default_value = DEFAULT_DATA_DIR, env = "QPQ_DATA_DIR")]
|
||||
data_dir: String,
|
||||
|
||||
/// TLS certificate path (generated automatically if missing).
|
||||
#[arg(long, default_value = DEFAULT_TLS_CERT, env = "QPQ_TLS_CERT")]
|
||||
tls_cert: PathBuf,
|
||||
|
||||
/// TLS private key path (generated automatically if missing).
|
||||
#[arg(long, default_value = DEFAULT_TLS_KEY, env = "QPQ_TLS_KEY")]
|
||||
tls_key: PathBuf,
|
||||
|
||||
/// Required bearer token for auth.version=1 requests. Use --allow-insecure-auth to run without it (dev only).
|
||||
#[arg(long, env = "QPQ_AUTH_TOKEN")]
|
||||
auth_token: Option<String>,
|
||||
|
||||
/// Allow running without QPQ_AUTH_TOKEN (development only).
|
||||
#[arg(long, env = "QPQ_ALLOW_INSECURE_AUTH", default_value_t = false)]
|
||||
allow_insecure_auth: bool,
|
||||
|
||||
/// Enable Sealed Sender: enqueue does not require identity-bound session, only a valid token.
|
||||
#[arg(long, env = "QPQ_SEALED_SENDER", default_value_t = false)]
|
||||
sealed_sender: bool,
|
||||
|
||||
/// Storage backend: "file" (bincode) or "sql" (SQLCipher-encrypted).
|
||||
#[arg(long, default_value = DEFAULT_STORE_BACKEND, env = "QPQ_STORE_BACKEND")]
|
||||
store_backend: String,
|
||||
|
||||
/// Path to the SQLCipher database file (only used when --store-backend=sql).
|
||||
#[arg(long, default_value = DEFAULT_DB_PATH, env = "QPQ_DB_PATH")]
|
||||
db_path: PathBuf,
|
||||
|
||||
/// SQLCipher encryption key. Empty string disables encryption.
|
||||
#[arg(long, default_value = "", env = "QPQ_DB_KEY")]
|
||||
db_key: String,
|
||||
|
||||
/// Metrics HTTP listen address (e.g. 0.0.0.0:9090). If set and metrics enabled, /metrics is served.
|
||||
#[arg(long, env = "QPQ_METRICS_LISTEN")]
|
||||
metrics_listen: Option<String>,
|
||||
|
||||
/// Enable metrics server when metrics_listen is set.
|
||||
#[arg(long, env = "QPQ_METRICS_ENABLED")]
|
||||
metrics_enabled: Option<bool>,
|
||||
|
||||
/// Enable federation (server-to-server message relay).
|
||||
#[arg(long, env = "QPQ_FEDERATION_ENABLED", default_value_t = false)]
|
||||
federation_enabled: bool,
|
||||
|
||||
/// This server's domain for federation addressing (e.g. "chat.example.com").
|
||||
#[arg(long, env = "QPQ_FEDERATION_DOMAIN")]
|
||||
federation_domain: Option<String>,
|
||||
|
||||
/// Federation QUIC listen address (default: 0.0.0.0:7001).
|
||||
#[arg(long, env = "QPQ_FEDERATION_LISTEN")]
|
||||
federation_listen: Option<String>,
|
||||
}
|
||||
|
||||
// ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
|
||||
)
|
||||
.init();
|
||||
|
||||
let args = Args::parse();
|
||||
let file_cfg = load_config(args.config.as_deref())?;
|
||||
let effective = merge_config(&args, &file_cfg);
|
||||
|
||||
let production = std::env::var("QPQ_PRODUCTION")
|
||||
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
|
||||
.unwrap_or(false);
|
||||
if production {
|
||||
validate_production_config(&effective)?;
|
||||
}
|
||||
|
||||
// Optional metrics server: only start when metrics_enabled and metrics_listen are set.
|
||||
if effective.metrics_enabled {
|
||||
if let Some(addr_str) = &effective.metrics_listen {
|
||||
let addr: std::net::SocketAddr = addr_str
|
||||
.parse()
|
||||
.context("metrics_listen must be host:port (e.g. 0.0.0.0:9090)")?;
|
||||
metrics_exporter_prometheus::PrometheusBuilder::new()
|
||||
.with_http_listener(addr)
|
||||
.install()
|
||||
.context("failed to install Prometheus metrics exporter")?;
|
||||
tracing::info!(addr = %addr_str, "metrics server listening on /metrics");
|
||||
}
|
||||
}
|
||||
|
||||
// In non-production, require an explicit opt-out before running without a static token.
|
||||
if !production
|
||||
&& effective
|
||||
.auth_token
|
||||
.as_deref()
|
||||
.map(|s| s.is_empty())
|
||||
.unwrap_or(true)
|
||||
&& !effective.allow_insecure_auth
|
||||
{
|
||||
anyhow::bail!(
|
||||
"missing QPQ_AUTH_TOKEN; set one or pass --allow-insecure-auth for development"
|
||||
);
|
||||
}
|
||||
|
||||
if effective.allow_insecure_auth
|
||||
&& effective
|
||||
.auth_token
|
||||
.as_deref()
|
||||
.map(|s| s.is_empty())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
tracing::warn!("running without QPQ_AUTH_TOKEN (allow-insecure-auth enabled); development only");
|
||||
}
|
||||
|
||||
let listen: SocketAddr = effective
|
||||
.listen
|
||||
.parse()
|
||||
.context("--listen must be host:port")?;
|
||||
|
||||
let mut server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
|
||||
.context("failed to build TLS/QUIC server config")?;
|
||||
|
||||
// Harden QUIC transport: idle timeout, limit stream concurrency.
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(
|
||||
std::time::Duration::from_secs(300)
|
||||
.try_into()
|
||||
.expect("300s is a valid IdleTimeout"),
|
||||
));
|
||||
transport.max_concurrent_bidi_streams(1u32.into());
|
||||
transport.max_concurrent_uni_streams(0u32.into());
|
||||
server_config.transport_config(Arc::new(transport));
|
||||
|
||||
// Shared storage — persisted to disk for restart safety.
|
||||
let store: Arc<dyn Store> = match effective.store_backend.as_str() {
|
||||
"sql" => {
|
||||
if let Some(parent) = effective.db_path.parent() {
|
||||
std::fs::create_dir_all(parent).context("create db dir")?;
|
||||
}
|
||||
tracing::info!(
|
||||
path = %effective.db_path.display(),
|
||||
encrypted = !effective.db_key.is_empty(),
|
||||
"opening SQLCipher store"
|
||||
);
|
||||
if effective.db_key.is_empty() {
|
||||
tracing::warn!("db_key is empty; SQL store will be plaintext (development only)");
|
||||
}
|
||||
Arc::new(SqlStore::open(&effective.db_path, &effective.db_key)?)
|
||||
}
|
||||
"file" | _ => {
|
||||
tracing::info!(dir = %effective.data_dir, "opening file-backed store");
|
||||
Arc::new(FileBackedStore::open(&effective.data_dir)?)
|
||||
}
|
||||
};
|
||||
|
||||
let auth_cfg = Arc::new(AuthConfig::new(
|
||||
effective.auth_token.clone(),
|
||||
effective.allow_insecure_auth,
|
||||
));
|
||||
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = Arc::new(DashMap::new());
|
||||
|
||||
// OPAQUE ServerSetup: load from storage or generate fresh.
|
||||
let opaque_setup: Arc<ServerSetup<OpaqueSuite>> = match store.get_server_setup() {
|
||||
Ok(Some(bytes)) => {
|
||||
let setup = ServerSetup::<OpaqueSuite>::deserialize(&bytes)
|
||||
.map_err(|e| anyhow::anyhow!("corrupt OPAQUE server setup: {e}"))?;
|
||||
tracing::info!("loaded persisted OPAQUE ServerSetup");
|
||||
Arc::new(setup)
|
||||
}
|
||||
Ok(None) => {
|
||||
let setup = ServerSetup::<OpaqueSuite>::new(&mut OsRng);
|
||||
let bytes = setup.serialize().to_vec();
|
||||
store
|
||||
.store_server_setup(bytes)
|
||||
.context("persist OPAQUE ServerSetup")?;
|
||||
tracing::info!("generated and persisted new OPAQUE ServerSetup");
|
||||
Arc::new(setup)
|
||||
}
|
||||
Err(e) => return Err(anyhow::anyhow!("load OPAQUE server setup: {e}")),
|
||||
};
|
||||
|
||||
let pending_logins: Arc<DashMap<String, PendingLogin>> = Arc::new(DashMap::new());
|
||||
let sessions: Arc<DashMap<Vec<u8>, SessionInfo>> = Arc::new(DashMap::new());
|
||||
let rate_limits: Arc<DashMap<Vec<u8>, RateEntry>> = Arc::new(DashMap::new());
|
||||
let conn_rate_limits: Arc<DashMap<IpAddr, RateEntry>> = Arc::new(DashMap::new());
|
||||
|
||||
// Background cleanup task (expire sessions, pending logins, rate limits, and stale messages).
|
||||
spawn_cleanup_task(
|
||||
Arc::clone(&sessions),
|
||||
Arc::clone(&pending_logins),
|
||||
Arc::clone(&rate_limits),
|
||||
Arc::clone(&conn_rate_limits),
|
||||
Arc::clone(&store),
|
||||
Arc::clone(&waiters),
|
||||
);
|
||||
|
||||
let endpoint = Endpoint::server(server_config, listen)?;
|
||||
|
||||
tracing::info!(
|
||||
addr = %effective.listen,
|
||||
"accepting QUIC connections"
|
||||
);
|
||||
|
||||
// ── Federation setup ─────────────────────────────────────────────────────
|
||||
let federation_client: Option<Arc<federation::FederationClient>> =
|
||||
if let Some(fed_cfg) = &effective.federation {
|
||||
tracing::info!(
|
||||
domain = %fed_cfg.domain,
|
||||
listen = %fed_cfg.listen,
|
||||
peers = fed_cfg.peers.len(),
|
||||
"federation enabled"
|
||||
);
|
||||
|
||||
// Build a client endpoint for outbound federation connections.
|
||||
// For now we create a simple endpoint; full mTLS is used when certs are provided.
|
||||
let client_config = if fed_cfg.federation_cert.exists()
|
||||
&& fed_cfg.federation_key.exists()
|
||||
&& fed_cfg.federation_ca.exists()
|
||||
{
|
||||
let tls_cfg = federation::tls::build_federation_client_config(
|
||||
&fed_cfg.federation_cert,
|
||||
&fed_cfg.federation_key,
|
||||
&fed_cfg.federation_ca,
|
||||
)
|
||||
.context("build federation client TLS config")?;
|
||||
|
||||
let crypto = quinn::crypto::rustls::QuicClientConfig::try_from(tls_cfg)
|
||||
.map_err(|e| anyhow::anyhow!("invalid federation client QUIC config: {e}"))?;
|
||||
let mut qc = quinn::ClientConfig::new(Arc::new(crypto));
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(
|
||||
std::time::Duration::from_secs(120)
|
||||
.try_into()
|
||||
.expect("120s is valid"),
|
||||
));
|
||||
qc.transport_config(Arc::new(transport));
|
||||
Some(qc)
|
||||
} else {
|
||||
tracing::warn!("federation cert/key/CA not found; outbound federation connections will fail");
|
||||
None
|
||||
};
|
||||
|
||||
let fed_bind: SocketAddr = "0.0.0.0:0".parse().unwrap();
|
||||
let mut fed_endpoint = Endpoint::client(fed_bind)
|
||||
.context("create federation client endpoint")?;
|
||||
if let Some(cc) = client_config {
|
||||
fed_endpoint.set_default_client_config(cc);
|
||||
}
|
||||
|
||||
let client = federation::FederationClient::new(fed_cfg, fed_endpoint)
|
||||
.context("create federation client")?;
|
||||
|
||||
// Register configured peers in storage.
|
||||
for peer in &fed_cfg.peers {
|
||||
if let Err(e) = store.upsert_federation_peer(&peer.domain, true) {
|
||||
tracing::warn!(domain = %peer.domain, error = %e, "failed to register federation peer");
|
||||
}
|
||||
}
|
||||
|
||||
Some(Arc::new(client))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let local_domain: Option<String> = effective.federation.as_ref().map(|f| f.domain.clone());
|
||||
|
||||
// ── Federation listener ──────────────────────────────────────────────────
|
||||
let federation_endpoint: Option<Endpoint> =
|
||||
if let Some(fed_cfg) = &effective.federation {
|
||||
if fed_cfg.federation_cert.exists()
|
||||
&& fed_cfg.federation_key.exists()
|
||||
&& fed_cfg.federation_ca.exists()
|
||||
{
|
||||
let fed_server_config = federation::tls::build_federation_server_config(
|
||||
&fed_cfg.federation_cert,
|
||||
&fed_cfg.federation_key,
|
||||
&fed_cfg.federation_ca,
|
||||
)
|
||||
.context("build federation server TLS config")?;
|
||||
|
||||
let fed_listen: SocketAddr = fed_cfg
|
||||
.listen
|
||||
.parse()
|
||||
.context("federation listen must be host:port")?;
|
||||
|
||||
let ep = Endpoint::server(fed_server_config, fed_listen)
|
||||
.context("bind federation QUIC endpoint")?;
|
||||
|
||||
tracing::info!(addr = %fed_cfg.listen, "federation endpoint listening");
|
||||
Some(ep)
|
||||
} else {
|
||||
tracing::warn!("federation certs not found; federation listener not started");
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// capnp-rpc is !Send (Rc internals), so all RPC tasks must stay on a LocalSet.
|
||||
let local = LocalSet::new();
|
||||
local
|
||||
.run_until(async move {
|
||||
// Spawn federation acceptor if enabled.
|
||||
if let Some(fed_ep) = federation_endpoint {
|
||||
let fed_store = Arc::clone(&store);
|
||||
let fed_waiters = Arc::clone(&waiters);
|
||||
let fed_domain = local_domain.clone().unwrap_or_default();
|
||||
tokio::task::spawn_local(async move {
|
||||
loop {
|
||||
let incoming = match fed_ep.accept().await {
|
||||
Some(i) => i,
|
||||
None => break,
|
||||
};
|
||||
let connecting = match incoming.accept() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "federation: accept error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let store = Arc::clone(&fed_store);
|
||||
let waiters = Arc::clone(&fed_waiters);
|
||||
let domain = fed_domain.clone();
|
||||
tokio::task::spawn_local(async move {
|
||||
match connecting.await {
|
||||
Ok(conn) => {
|
||||
tracing::info!(
|
||||
peer = %conn.remote_address(),
|
||||
"federation: peer connected"
|
||||
);
|
||||
let (send, recv) = match conn.accept_bi().await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "federation: accept bi error");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let reader = tokio_util::compat::TokioAsyncReadCompatExt::compat(recv);
|
||||
let writer = tokio_util::compat::TokioAsyncWriteCompatExt::compat_write(send);
|
||||
|
||||
let network = capnp_rpc::twoparty::VatNetwork::new(
|
||||
reader,
|
||||
writer,
|
||||
capnp_rpc::rpc_twoparty_capnp::Side::Server,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let service_impl = federation::service::FederationServiceImpl {
|
||||
store,
|
||||
waiters,
|
||||
local_domain: domain,
|
||||
};
|
||||
let client: quicproquo_proto::federation_capnp::federation_service::Client =
|
||||
capnp_rpc::new_client(service_impl);
|
||||
|
||||
if let Err(e) = capnp_rpc::RpcSystem::new(
|
||||
Box::new(network),
|
||||
Some(client.client),
|
||||
).await {
|
||||
tracing::warn!(error = %e, "federation: RPC error");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "federation: connection error");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
incoming = endpoint.accept() => {
|
||||
let incoming = match incoming {
|
||||
Some(i) => i,
|
||||
None => break,
|
||||
};
|
||||
|
||||
// Per-IP connection rate limiting.
|
||||
let remote_ip = incoming.remote_address().ip();
|
||||
if !auth::check_conn_rate_limit(&conn_rate_limits, remote_ip) {
|
||||
tracing::warn!(ip = %remote_ip, "connection rate limit exceeded, dropping");
|
||||
incoming.refuse();
|
||||
continue;
|
||||
}
|
||||
|
||||
let connecting = match incoming.accept() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to accept incoming connection");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let store = Arc::clone(&store);
|
||||
let waiters = Arc::clone(&waiters);
|
||||
let auth_cfg = Arc::clone(&auth_cfg);
|
||||
let opaque_setup = Arc::clone(&opaque_setup);
|
||||
let pending_logins = Arc::clone(&pending_logins);
|
||||
let sessions = Arc::clone(&sessions);
|
||||
let rate_limits = Arc::clone(&rate_limits);
|
||||
let sealed_sender = effective.sealed_sender;
|
||||
let fed_client = federation_client.clone();
|
||||
let local_dom = local_domain.clone();
|
||||
|
||||
tokio::task::spawn_local(async move {
|
||||
if let Err(e) = handle_node_connection(
|
||||
connecting,
|
||||
store,
|
||||
waiters,
|
||||
auth_cfg,
|
||||
opaque_setup,
|
||||
pending_logins,
|
||||
sessions,
|
||||
rate_limits,
|
||||
sealed_sender,
|
||||
fed_client,
|
||||
local_dom,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(error = %e, "connection error");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
tracing::info!("shutdown signal received, draining QUIC connections");
|
||||
endpoint.close(0u32.into(), b"server shutdown");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Grace period: let in-flight RPC tasks on the LocalSet finish.
|
||||
tracing::info!("waiting up to 5s for in-flight RPCs to complete");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
|
||||
Ok::<(), anyhow::Error>(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
49
crates/quicproquo-server/src/metrics.rs
Normal file
49
crates/quicproquo-server/src/metrics.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
//! Prometheus metrics for the server.
|
||||
//!
|
||||
//! All counters/histograms/gauges use the `metrics` crate and are exported
|
||||
//! via metrics-exporter-prometheus on a configurable HTTP port (e.g. /metrics).
|
||||
|
||||
/// Record one enqueue (success). Call after a message is enqueued.
|
||||
pub fn record_enqueue_total() {
|
||||
metrics::counter!("enqueue_total").increment(1);
|
||||
}
|
||||
|
||||
/// Record enqueued payload size in bytes.
|
||||
pub fn record_enqueue_bytes(bytes: u64) {
|
||||
metrics::counter!("enqueue_bytes_total").increment(bytes);
|
||||
}
|
||||
|
||||
/// Record one fetch (success). Call when fetch returns.
|
||||
pub fn record_fetch_total() {
|
||||
metrics::counter!("fetch_total").increment(1);
|
||||
}
|
||||
|
||||
/// Record one fetch_wait (success). Call when fetch_wait returns.
|
||||
pub fn record_fetch_wait_total() {
|
||||
metrics::counter!("fetch_wait_total").increment(1);
|
||||
}
|
||||
|
||||
/// Set the delivery queue depth gauge (sample). Updated at enqueue/fetch time.
|
||||
pub fn record_delivery_queue_depth(depth: usize) {
|
||||
metrics::gauge!("delivery_queue_depth").set(depth as f64);
|
||||
}
|
||||
|
||||
/// Record one KeyPackage upload (success).
|
||||
pub fn record_key_package_upload_total() {
|
||||
metrics::counter!("key_package_upload_total").increment(1);
|
||||
}
|
||||
|
||||
/// Record successful auth login (session token issued).
|
||||
pub fn record_auth_login_success_total() {
|
||||
metrics::counter!("auth_login_success_total").increment(1);
|
||||
}
|
||||
|
||||
/// Record failed auth login attempt.
|
||||
pub fn record_auth_login_failure_total() {
|
||||
metrics::counter!("auth_login_failure_total").increment(1);
|
||||
}
|
||||
|
||||
/// Record rate limit hit (enqueue rejected).
|
||||
pub fn record_rate_limit_hit_total() {
|
||||
metrics::counter!("rate_limit_hit_total").increment(1);
|
||||
}
|
||||
373
crates/quicproquo-server/src/node_service/auth_ops.rs
Normal file
373
crates/quicproquo-server/src/node_service/auth_ops.rs
Normal file
@@ -0,0 +1,373 @@
|
||||
use capnp::capability::Promise;
|
||||
use opaque_ke::{
|
||||
CredentialFinalization, CredentialRequest, RegistrationRequest, RegistrationUpload,
|
||||
ServerLogin, ServerRegistration,
|
||||
};
|
||||
use quicproquo_core::opaque_auth::OpaqueSuite;
|
||||
use quicproquo_proto::node_capnp::node_service;
|
||||
|
||||
use crate::auth::{coded_error, current_timestamp, PendingLogin, SESSION_TTL_SECS};
|
||||
use crate::error_codes::*;
|
||||
use crate::metrics;
|
||||
use crate::storage::StorageError;
|
||||
|
||||
use super::NodeServiceImpl;
|
||||
|
||||
// Audit events in this module must never include secrets (no session tokens, passwords, or raw keys).
|
||||
|
||||
fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
/// Parse username from Cap'n Proto reader; requires valid UTF-8.
|
||||
fn parse_username_param(
|
||||
result: Result<capnp::text::Reader<'_>, capnp::Error>,
|
||||
) -> Result<String, capnp::Error> {
|
||||
let reader = result.map_err(|e| coded_error(E020_BAD_PARAMS, e))?;
|
||||
reader
|
||||
.to_string()
|
||||
.map_err(|_| coded_error(E020_BAD_PARAMS, "username must be valid UTF-8"))
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn handle_opaque_login_start(
|
||||
&mut self,
|
||||
params: node_service::OpaqueLoginStartParams,
|
||||
mut results: node_service::OpaqueLoginStartResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let request_bytes = match p.get_request() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
|
||||
if username.is_empty() {
|
||||
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
|
||||
}
|
||||
|
||||
// Check for existing recent pending login before expensive OPAQUE/storage work (DoS mitigation).
|
||||
if let Some(existing) = self.pending_logins.get(&username) {
|
||||
let age = current_timestamp().saturating_sub(existing.created_at);
|
||||
if age < 60 {
|
||||
return Promise::err(coded_error(E010_OPAQUE_ERROR, "login already in progress"));
|
||||
}
|
||||
}
|
||||
|
||||
let credential_request = match CredentialRequest::<OpaqueSuite>::deserialize(&request_bytes) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("invalid credential request: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let password_file = match self.store.get_user_record(&username) {
|
||||
Ok(Some(bytes)) => match ServerRegistration::<OpaqueSuite>::deserialize(&bytes) {
|
||||
Ok(pf) => Some(pf),
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("corrupt user record: {e}"),
|
||||
))
|
||||
}
|
||||
},
|
||||
Ok(None) => None,
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
let result = match ServerLogin::<OpaqueSuite>::start(
|
||||
&mut rng,
|
||||
&self.opaque_setup,
|
||||
password_file,
|
||||
credential_request,
|
||||
username.as_bytes(),
|
||||
Default::default(),
|
||||
) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("OPAQUE login start failed: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let state_bytes = result.state.serialize().to_vec();
|
||||
self.pending_logins.insert(
|
||||
username.clone(),
|
||||
PendingLogin {
|
||||
state_bytes,
|
||||
created_at: current_timestamp(),
|
||||
},
|
||||
);
|
||||
|
||||
let response_bytes = result.message.serialize();
|
||||
results.get().set_response(&response_bytes);
|
||||
|
||||
tracing::info!(user = %username, "OPAQUE login started");
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_opaque_register_start(
|
||||
&mut self,
|
||||
params: node_service::OpaqueRegisterStartParams,
|
||||
mut results: node_service::OpaqueRegisterStartResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let request_bytes = match p.get_request() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
|
||||
if username.is_empty() {
|
||||
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
|
||||
}
|
||||
|
||||
if let Ok(true) = self.store.has_user_record(&username) {
|
||||
return Promise::err(coded_error(
|
||||
E018_USER_EXISTS,
|
||||
format!("user '{}' already registered", username),
|
||||
));
|
||||
}
|
||||
|
||||
let registration_request = match RegistrationRequest::<OpaqueSuite>::deserialize(&request_bytes) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("invalid registration request: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let result = match ServerRegistration::<OpaqueSuite>::start(
|
||||
&self.opaque_setup,
|
||||
registration_request,
|
||||
username.as_bytes(),
|
||||
) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("OPAQUE registration start failed: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let response_bytes = result.message.serialize();
|
||||
results.get().set_response(&response_bytes);
|
||||
|
||||
tracing::info!(user = %username, "OPAQUE registration started");
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_opaque_login_finish(
|
||||
&mut self,
|
||||
params: node_service::OpaqueLoginFinishParams,
|
||||
mut results: node_service::OpaqueLoginFinishResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let finalization_bytes = match p.get_finalization() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_key = p.get_identity_key().unwrap_or_default().to_vec();
|
||||
|
||||
if username.is_empty() {
|
||||
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
|
||||
}
|
||||
|
||||
let pending = match self.pending_logins.remove(&username) {
|
||||
Some((_, pl)) => pl,
|
||||
None => {
|
||||
// Audit: login failure — do not log secrets (no token, no password).
|
||||
tracing::warn!(user = %username, "audit: auth login failure (no pending login)");
|
||||
metrics::record_auth_login_failure_total();
|
||||
return Promise::err(coded_error(E019_NO_PENDING_LOGIN, "no pending login for this username"))
|
||||
}
|
||||
};
|
||||
|
||||
let server_login = match ServerLogin::<OpaqueSuite>::deserialize(&pending.state_bytes) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("corrupt login state: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let finalization = match CredentialFinalization::<OpaqueSuite>::deserialize(&finalization_bytes) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("invalid credential finalization: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let _result = match server_login.finish(finalization, Default::default()) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!(user = %username, "audit: auth login failure (OPAQUE finish failed)");
|
||||
metrics::record_auth_login_failure_total();
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("OPAQUE login finish failed (bad password?): {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
if identity_key.is_empty() {
|
||||
metrics::record_auth_login_failure_total();
|
||||
return Promise::err(coded_error(
|
||||
E016_IDENTITY_MISMATCH,
|
||||
"identity key required to bind session token",
|
||||
));
|
||||
}
|
||||
|
||||
if let Ok(Some(stored_ik)) = self.store.get_user_identity_key(&username) {
|
||||
if stored_ik != identity_key {
|
||||
tracing::warn!(user = %username, "audit: auth login failure (identity mismatch)");
|
||||
metrics::record_auth_login_failure_total();
|
||||
return Promise::err(coded_error(
|
||||
E016_IDENTITY_MISMATCH,
|
||||
"identity key does not match registered key",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let mut token = [0u8; 32];
|
||||
rand::RngCore::fill_bytes(&mut rand::rngs::OsRng, &mut token);
|
||||
let token_vec = token.to_vec();
|
||||
|
||||
let now = current_timestamp();
|
||||
self.sessions.insert(
|
||||
token_vec.clone(),
|
||||
crate::auth::SessionInfo {
|
||||
username: username.clone(),
|
||||
identity_key,
|
||||
created_at: now,
|
||||
expires_at: now + SESSION_TTL_SECS,
|
||||
},
|
||||
);
|
||||
|
||||
results.get().set_session_token(&token_vec);
|
||||
|
||||
// Audit: login success — do not log session token or any secrets.
|
||||
metrics::record_auth_login_success_total();
|
||||
tracing::info!(user = %username, "audit: auth login success — session token issued");
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_opaque_register_finish(
|
||||
&mut self,
|
||||
params: node_service::OpaqueRegisterFinishParams,
|
||||
mut results: node_service::OpaqueRegisterFinishResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let upload_bytes = match p.get_upload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_key = p.get_identity_key().unwrap_or_default().to_vec();
|
||||
|
||||
if username.is_empty() {
|
||||
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
|
||||
}
|
||||
|
||||
let _request = match RegistrationRequest::<OpaqueSuite>::deserialize(&upload_bytes) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("invalid registration upload: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match self.store.has_user_record(&username) {
|
||||
Ok(true) => {
|
||||
return Promise::err(coded_error(
|
||||
E018_USER_EXISTS,
|
||||
format!("user '{}' already registered", username),
|
||||
))
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let upload = match RegistrationUpload::<OpaqueSuite>::deserialize(&upload_bytes) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("invalid registration upload: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let password_file = ServerRegistration::<OpaqueSuite>::finish(upload);
|
||||
let record_bytes = password_file.serialize().to_vec();
|
||||
|
||||
match self
|
||||
.store
|
||||
.store_user_record(&username, record_bytes)
|
||||
{
|
||||
Ok(()) => {}
|
||||
Err(crate::storage::StorageError::DuplicateUser(_)) => {
|
||||
return Promise::err(coded_error(
|
||||
E018_USER_EXISTS,
|
||||
format!("user '{}' already registered", username),
|
||||
))
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
}
|
||||
|
||||
if !identity_key.is_empty() {
|
||||
if let Err(e) = self
|
||||
.store
|
||||
.store_user_identity_key(&username, identity_key)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
return Promise::err(e);
|
||||
}
|
||||
}
|
||||
|
||||
results.get().set_success(true);
|
||||
tracing::info!(user = %username, "OPAQUE registration complete");
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
62
crates/quicproquo-server/src/node_service/channel_ops.rs
Normal file
62
crates/quicproquo-server/src/node_service/channel_ops.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
//! createChannel RPC: create or look up a 1:1 DM channel.
|
||||
|
||||
use capnp::capability::Promise;
|
||||
use quicproquo_proto::node_capnp::node_service;
|
||||
|
||||
use crate::auth::{coded_error, require_identity, validate_auth_context};
|
||||
use crate::error_codes::*;
|
||||
use crate::storage::StorageError;
|
||||
|
||||
use super::NodeServiceImpl;
|
||||
|
||||
fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn handle_create_channel(
|
||||
&mut self,
|
||||
params: node_service::CreateChannelParams,
|
||||
mut results: node_service::CreateChannelResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let peer_key = match p.get_peer_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
let identity = match require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if peer_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("peerKey must be exactly 32 bytes, got {}", peer_key.len()),
|
||||
));
|
||||
}
|
||||
|
||||
if identity == peer_key {
|
||||
return Promise::err(coded_error(
|
||||
E020_BAD_PARAMS,
|
||||
"peerKey must not equal caller identity",
|
||||
));
|
||||
}
|
||||
|
||||
let channel_id = match self.store.create_channel(&identity, &peer_key) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
|
||||
results.get().set_channel_id(&channel_id);
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
674
crates/quicproquo-server/src/node_service/delivery.rs
Normal file
674
crates/quicproquo-server/src/node_service/delivery.rs
Normal file
@@ -0,0 +1,674 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use capnp::capability::Promise;
|
||||
use dashmap::DashMap;
|
||||
use quicproquo_proto::node_capnp::node_service;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::auth::{
|
||||
check_rate_limit, coded_error, fmt_hex, require_identity_or_request, validate_auth_context,
|
||||
};
|
||||
use crate::error_codes::*;
|
||||
use crate::metrics;
|
||||
use crate::storage::{StorageError, Store};
|
||||
|
||||
use super::{NodeServiceImpl, CURRENT_WIRE_VERSION};
|
||||
|
||||
// Audit events here must not include secrets: no payload content, no full recipient/token bytes (prefix only).
|
||||
|
||||
const MAX_PAYLOAD_BYTES: usize = 5 * 1024 * 1024; // 5 MB cap per message
|
||||
const MAX_QUEUE_DEPTH: usize = 1000;
|
||||
|
||||
fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
pub fn fill_payloads_wait(
|
||||
results: &mut node_service::FetchWaitResults,
|
||||
messages: Vec<(u64, Vec<u8>)>,
|
||||
) {
|
||||
let mut list = results.get().init_payloads(messages.len() as u32);
|
||||
for (i, (seq, data)) in messages.iter().enumerate() {
|
||||
let mut entry = list.reborrow().get(i as u32);
|
||||
entry.set_seq(*seq);
|
||||
entry.set_data(data);
|
||||
}
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn handle_enqueue(
|
||||
&mut self,
|
||||
params: node_service::EnqueueParams,
|
||||
mut results: node_service::EnqueueResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let payload = match p.get_payload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if payload.is_empty() {
|
||||
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
|
||||
}
|
||||
if payload.len() > MAX_PAYLOAD_BYTES {
|
||||
return Promise::err(coded_error(
|
||||
E006_PAYLOAD_TOO_LARGE,
|
||||
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
|
||||
// Audit: rate limit hit — do not log token or identity.
|
||||
tracing::warn!("rate_limit_hit");
|
||||
metrics::record_rate_limit_hit_total();
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
// When sealed_sender is true, enqueue does not require identity; valid token only.
|
||||
// Otherwise, the sender must have an identity-bound session (but their identity
|
||||
// does NOT need to match the recipient — they're sending TO the recipient).
|
||||
if !self.sealed_sender {
|
||||
if let Err(e) = crate::auth::require_identity(&auth_ctx) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
}
|
||||
|
||||
// DM channel authz: channel_id.len() == 16 means a created channel; caller and recipient must be the two members.
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
let recipient_other = (recipient_key == *a && caller == b.as_slice())
|
||||
|| (recipient_key == *b && caller == a.as_slice());
|
||||
if !caller_in || !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller or recipient not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
match self.store.queue_depth(&recipient_key, &channel_id) {
|
||||
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
|
||||
return Promise::err(coded_error(
|
||||
E015_QUEUE_FULL,
|
||||
format!("queue depth {} exceeds limit {}", depth, MAX_QUEUE_DEPTH),
|
||||
));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let payload_len = payload.len();
|
||||
let seq = match self
|
||||
.store
|
||||
.enqueue(&recipient_key, &channel_id, payload)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(seq) => seq,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
results.get().set_seq(seq);
|
||||
|
||||
// Metrics and audit. Audit events must not include secrets (no payload, no full keys).
|
||||
metrics::record_enqueue_total();
|
||||
metrics::record_enqueue_bytes(payload_len as u64);
|
||||
if let Ok(depth) = self.store.queue_depth(&recipient_key, &channel_id) {
|
||||
metrics::record_delivery_queue_depth(depth);
|
||||
}
|
||||
tracing::info!(
|
||||
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
||||
payload_len = payload_len,
|
||||
seq = seq,
|
||||
"audit: enqueue"
|
||||
);
|
||||
|
||||
crate::auth::waiter(&self.waiters, &recipient_key).notify_waiters();
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_fetch(
|
||||
&mut self,
|
||||
params: node_service::FetchParams,
|
||||
mut results: node_service::FetchResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let recipient_key = match params.get() {
|
||||
Ok(p) => match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
},
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = params
|
||||
.get()
|
||||
.ok()
|
||||
.and_then(|p| p.get_channel_id().ok())
|
||||
.map(|c| c.to_vec())
|
||||
.unwrap_or_default();
|
||||
let version = params
|
||||
.get()
|
||||
.ok()
|
||||
.map(|p| p.get_version())
|
||||
.unwrap_or(CURRENT_WIRE_VERSION);
|
||||
let limit = params.get().ok().map(|p| p.get_limit()).unwrap_or(0);
|
||||
let auth_ctx = match params
|
||||
.get()
|
||||
.ok()
|
||||
.map(|p| validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()))
|
||||
.transpose()
|
||||
{
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
let auth_ctx = match auth_ctx {
|
||||
Some(ctx) => ctx,
|
||||
None => return Promise::err(coded_error(E003_INVALID_TOKEN, "auth required")),
|
||||
};
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&recipient_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|
||||
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
|
||||
if !caller_in || !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller or recipient not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let messages = if limit > 0 {
|
||||
match self
|
||||
.store
|
||||
.fetch_limited(&recipient_key, &channel_id, limit as usize)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(m) => m,
|
||||
Err(e) => return Promise::err(e),
|
||||
}
|
||||
} else {
|
||||
match self
|
||||
.store
|
||||
.fetch(&recipient_key, &channel_id)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(m) => m,
|
||||
Err(e) => return Promise::err(e),
|
||||
}
|
||||
};
|
||||
|
||||
// Audit: fetch — do not log payload or full keys.
|
||||
metrics::record_fetch_total();
|
||||
tracing::info!(
|
||||
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
||||
count = messages.len(),
|
||||
"audit: fetch"
|
||||
);
|
||||
|
||||
let mut list = results.get().init_payloads(messages.len() as u32);
|
||||
for (i, (seq, data)) in messages.iter().enumerate() {
|
||||
let mut entry = list.reborrow().get(i as u32);
|
||||
entry.set_seq(*seq);
|
||||
entry.set_data(data);
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_fetch_wait(
|
||||
&mut self,
|
||||
params: node_service::FetchWaitParams,
|
||||
mut results: node_service::FetchWaitResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let timeout_ms = p.get_timeout_ms();
|
||||
let limit = p.get_limit();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&recipient_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|
||||
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
|
||||
if !caller_in || !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller or recipient not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let store = Arc::clone(&self.store);
|
||||
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = self.waiters.clone();
|
||||
|
||||
Promise::from_future(async move {
|
||||
let fetch_fn = |s: &Arc<dyn Store>, rk: &[u8], ch: &[u8], lim: u32| -> Result<Vec<(u64, Vec<u8>)>, capnp::Error> {
|
||||
if lim > 0 {
|
||||
s.fetch_limited(rk, ch, lim as usize).map_err(storage_err)
|
||||
} else {
|
||||
s.fetch(rk, ch).map_err(storage_err)
|
||||
}
|
||||
};
|
||||
|
||||
let messages = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
|
||||
|
||||
if messages.is_empty() && timeout_ms > 0 {
|
||||
let waiter = waiters
|
||||
.entry(recipient_key.clone())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone();
|
||||
let _ = timeout(Duration::from_millis(timeout_ms), waiter.notified()).await;
|
||||
let msgs = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
|
||||
fill_payloads_wait(&mut results, msgs);
|
||||
metrics::record_fetch_wait_total();
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fill_payloads_wait(&mut results, messages);
|
||||
metrics::record_fetch_wait_total();
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle_peek(
|
||||
&mut self,
|
||||
params: node_service::PeekParams,
|
||||
mut results: node_service::PeekResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let limit = p.get_limit();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&recipient_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let messages = match self
|
||||
.store
|
||||
.peek(&recipient_key, &channel_id, limit as usize)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(m) => m,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
||||
count = messages.len(),
|
||||
"audit: peek"
|
||||
);
|
||||
|
||||
let mut list = results.get().init_payloads(messages.len() as u32);
|
||||
for (i, (seq, data)) in messages.iter().enumerate() {
|
||||
let mut entry = list.reborrow().get(i as u32);
|
||||
entry.set_seq(*seq);
|
||||
entry.set_data(data);
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_ack(
|
||||
&mut self,
|
||||
params: node_service::AckParams,
|
||||
_results: node_service::AckResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let seq_up_to = p.get_seq_up_to();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&recipient_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
match self
|
||||
.store
|
||||
.ack(&recipient_key, &channel_id, seq_up_to)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(removed) => {
|
||||
tracing::info!(
|
||||
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
||||
seq_up_to = seq_up_to,
|
||||
removed = removed,
|
||||
"audit: ack"
|
||||
);
|
||||
}
|
||||
Err(e) => return Promise::err(e),
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_batch_enqueue(
|
||||
&mut self,
|
||||
params: node_service::BatchEnqueueParams,
|
||||
mut results: node_service::BatchEnqueueResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_keys = match p.get_recipient_keys() {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let payload = match p.get_payload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if payload.is_empty() {
|
||||
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
|
||||
}
|
||||
if payload.len() > MAX_PAYLOAD_BYTES {
|
||||
return Promise::err(coded_error(
|
||||
E006_PAYLOAD_TOO_LARGE,
|
||||
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
|
||||
tracing::warn!("rate_limit_hit");
|
||||
metrics::record_rate_limit_hit_total();
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
// When sealed_sender is false, require an identity-bound session.
|
||||
if !self.sealed_sender {
|
||||
if let Err(e) = crate::auth::require_identity(&auth_ctx) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
}
|
||||
|
||||
// DM channel authz: validate caller membership once before the loop.
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
if !caller_in {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller is not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let mut seqs = Vec::with_capacity(recipient_keys.len() as usize);
|
||||
for i in 0..recipient_keys.len() {
|
||||
let rk = match recipient_keys.get(i) {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
if rk.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey[{}] must be exactly 32 bytes, got {}", i, rk.len()),
|
||||
));
|
||||
}
|
||||
|
||||
// Per-recipient DM channel membership check.
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(
|
||||
E023_CHANNEL_NOT_FOUND,
|
||||
"channel not found",
|
||||
));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let recipient_other = (rk == *a && caller == b.as_slice())
|
||||
|| (rk == *b && caller == a.as_slice());
|
||||
if !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"recipient is not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
match self.store.queue_depth(&rk, &channel_id) {
|
||||
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
|
||||
return Promise::err(coded_error(
|
||||
E015_QUEUE_FULL,
|
||||
format!("queue depth {} exceeds limit {}", depth, MAX_QUEUE_DEPTH),
|
||||
));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let seq = match self
|
||||
.store
|
||||
.enqueue(&rk, &channel_id, payload.clone())
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(seq) => seq,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
seqs.push(seq);
|
||||
|
||||
metrics::record_enqueue_total();
|
||||
metrics::record_enqueue_bytes(payload.len() as u64);
|
||||
|
||||
crate::auth::waiter(&self.waiters, &rk).notify_waiters();
|
||||
}
|
||||
|
||||
let mut list = results.get().init_seqs(seqs.len() as u32);
|
||||
for (i, seq) in seqs.iter().enumerate() {
|
||||
list.set(i as u32, *seq);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
recipient_count = recipient_keys.len(),
|
||||
payload_len = payload.len(),
|
||||
"audit: batch_enqueue"
|
||||
);
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
294
crates/quicproquo-server/src/node_service/key_ops.rs
Normal file
294
crates/quicproquo-server/src/node_service/key_ops.rs
Normal file
@@ -0,0 +1,294 @@
|
||||
use capnp::capability::Promise;
|
||||
use quicproquo_proto::node_capnp::node_service;
|
||||
|
||||
use crate::auth::{coded_error, fmt_hex, require_identity_or_request, validate_auth_context};
|
||||
use crate::error_codes::*;
|
||||
use crate::metrics;
|
||||
use crate::storage::StorageError;
|
||||
|
||||
use super::NodeServiceImpl;
|
||||
|
||||
fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
const MAX_KEYPACKAGE_BYTES: usize = 1 * 1024 * 1024; // 1 MB cap per KeyPackage
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn handle_upload_key_package(
|
||||
&mut self,
|
||||
params: node_service::UploadKeyPackageParams,
|
||||
mut results: node_service::UploadKeyPackageResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let (auth_ctx, identity_key, package) = match params.get() {
|
||||
Ok(p) => {
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let ik = match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let pkg = match p.get_package() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
(auth_ctx, ik, pkg)
|
||||
}
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if identity_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
|
||||
));
|
||||
}
|
||||
if package.is_empty() {
|
||||
return Promise::err(coded_error(E007_PACKAGE_EMPTY, "package must not be empty"));
|
||||
}
|
||||
if package.len() > MAX_KEYPACKAGE_BYTES {
|
||||
return Promise::err(coded_error(
|
||||
E008_PACKAGE_TOO_LARGE,
|
||||
format!("package exceeds max size ({} bytes)", MAX_KEYPACKAGE_BYTES),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&identity_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if let Err(e) = quicproquo_core::validate_keypackage_ciphersuite(&package) {
|
||||
return Promise::err(coded_error(
|
||||
E021_CIPHERSUITE_NOT_ALLOWED,
|
||||
format!("KeyPackage ciphersuite not allowed: {e}"),
|
||||
));
|
||||
}
|
||||
|
||||
let fingerprint: Vec<u8> = crate::auth::fingerprint(&package);
|
||||
if let Err(e) = self
|
||||
.store
|
||||
.upload_key_package(&identity_key, package)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
results.get().set_fingerprint(&fingerprint);
|
||||
|
||||
metrics::record_key_package_upload_total();
|
||||
// Audit: KeyPackage upload — only fingerprint prefix, no secrets.
|
||||
tracing::info!(
|
||||
identity_prefix = %fmt_hex(&identity_key[..4]),
|
||||
fingerprint_prefix = %fmt_hex(&fingerprint[..4]),
|
||||
"audit: key_package_upload"
|
||||
);
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_fetch_key_package(
|
||||
&mut self,
|
||||
params: node_service::FetchKeyPackageParams,
|
||||
mut results: node_service::FetchKeyPackageResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let identity_key = match params.get() {
|
||||
Ok(p) => match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
},
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
if let Err(e) = params
|
||||
.get()
|
||||
.ok()
|
||||
.map(|p| crate::auth::validate_auth(&self.auth_cfg, &self.sessions, p.get_auth()))
|
||||
.transpose()
|
||||
{
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if identity_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
|
||||
));
|
||||
}
|
||||
|
||||
let package = match self
|
||||
.store
|
||||
.fetch_key_package(&identity_key)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
match package {
|
||||
Some(pkg) => {
|
||||
tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "KeyPackage fetched");
|
||||
results.get().set_package(&pkg);
|
||||
}
|
||||
None => {
|
||||
tracing::debug!(
|
||||
identity = %fmt_hex(&identity_key[..4]),
|
||||
"no KeyPackage available for identity"
|
||||
);
|
||||
results.get().set_package(&[]);
|
||||
}
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_upload_hybrid_key(
|
||||
&mut self,
|
||||
params: node_service::UploadHybridKeyParams,
|
||||
_results: node_service::UploadHybridKeyResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_key = match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let hybrid_pk = match p.get_hybrid_public_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if identity_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
|
||||
));
|
||||
}
|
||||
if hybrid_pk.is_empty() {
|
||||
return Promise::err(coded_error(E013_HYBRID_KEY_EMPTY, "hybridPublicKey must not be empty"));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&identity_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if let Err(e) = self
|
||||
.store
|
||||
.upload_hybrid_key(&identity_key, hybrid_pk)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "hybrid public key uploaded");
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_fetch_hybrid_key(
|
||||
&mut self,
|
||||
params: node_service::FetchHybridKeyParams,
|
||||
mut results: node_service::FetchHybridKeyResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_key = match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
|
||||
// Auth check only — any authenticated user can fetch any peer's hybrid public key.
|
||||
if let Err(e) = validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if identity_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
|
||||
));
|
||||
}
|
||||
|
||||
let hybrid_pk = match self
|
||||
.store
|
||||
.fetch_hybrid_key(&identity_key)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
match hybrid_pk {
|
||||
Some(pk) => {
|
||||
tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "hybrid key fetched");
|
||||
results.get().set_hybrid_public_key(&pk);
|
||||
}
|
||||
None => {
|
||||
tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "no hybrid key for identity");
|
||||
results.get().set_hybrid_public_key(&[]);
|
||||
}
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_fetch_hybrid_keys(
|
||||
&mut self,
|
||||
params: node_service::FetchHybridKeysParams,
|
||||
mut results: node_service::FetchHybridKeysResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_keys = match p.get_identity_keys() {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
|
||||
if let Err(e) = validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let count = identity_keys.len() as usize;
|
||||
let mut key_data: Vec<Vec<u8>> = Vec::with_capacity(count);
|
||||
for i in 0..identity_keys.len() {
|
||||
let ik = match identity_keys.get(i) {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let pk = match self.store.fetch_hybrid_key(&ik).map_err(storage_err) {
|
||||
Ok(Some(pk)) => pk,
|
||||
Ok(None) => vec![],
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
key_data.push(pk);
|
||||
}
|
||||
|
||||
let mut list = results.get().init_keys(key_data.len() as u32);
|
||||
for (i, pk) in key_data.iter().enumerate() {
|
||||
list.set(i as u32, pk);
|
||||
}
|
||||
|
||||
tracing::debug!(count = count, "batch hybrid key fetch");
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
349
crates/quicproquo-server/src/node_service/mod.rs
Normal file
349
crates/quicproquo-server/src/node_service/mod.rs
Normal file
@@ -0,0 +1,349 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use capnp_rpc::RpcSystem;
|
||||
use dashmap::DashMap;
|
||||
use opaque_ke::ServerSetup;
|
||||
use quicproquo_core::opaque_auth::OpaqueSuite;
|
||||
use quicproquo_proto::node_capnp::node_service;
|
||||
use tokio::sync::Notify;
|
||||
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
|
||||
|
||||
use crate::auth::{
|
||||
current_timestamp, AuthConfig, PendingLogin, RateEntry, SessionInfo,
|
||||
PENDING_LOGIN_TTL_SECS, RATE_LIMIT_WINDOW_SECS,
|
||||
};
|
||||
use crate::storage::Store;
|
||||
|
||||
/// Cap'n Proto traversal limit (words). 4 Mi words = 32 MiB; bounds DoS from deeply nested or large messages.
|
||||
const CAPNP_TRAVERSAL_LIMIT_WORDS: usize = 4 * 1024 * 1024;
|
||||
|
||||
mod auth_ops;
|
||||
mod channel_ops;
|
||||
mod delivery;
|
||||
mod key_ops;
|
||||
mod p2p_ops;
|
||||
mod user_ops;
|
||||
|
||||
impl node_service::Server for NodeServiceImpl {
|
||||
fn upload_key_package(
|
||||
&mut self,
|
||||
params: node_service::UploadKeyPackageParams,
|
||||
results: node_service::UploadKeyPackageResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_upload_key_package(params, results)
|
||||
}
|
||||
|
||||
fn fetch_key_package(
|
||||
&mut self,
|
||||
params: node_service::FetchKeyPackageParams,
|
||||
results: node_service::FetchKeyPackageResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_fetch_key_package(params, results)
|
||||
}
|
||||
|
||||
fn enqueue(
|
||||
&mut self,
|
||||
params: node_service::EnqueueParams,
|
||||
results: node_service::EnqueueResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_enqueue(params, results)
|
||||
}
|
||||
|
||||
fn fetch(
|
||||
&mut self,
|
||||
params: node_service::FetchParams,
|
||||
results: node_service::FetchResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_fetch(params, results)
|
||||
}
|
||||
|
||||
fn fetch_wait(
|
||||
&mut self,
|
||||
params: node_service::FetchWaitParams,
|
||||
results: node_service::FetchWaitResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_fetch_wait(params, results)
|
||||
}
|
||||
|
||||
fn health(
|
||||
&mut self,
|
||||
params: node_service::HealthParams,
|
||||
results: node_service::HealthResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_health(params, results)
|
||||
}
|
||||
|
||||
fn upload_hybrid_key(
|
||||
&mut self,
|
||||
params: node_service::UploadHybridKeyParams,
|
||||
results: node_service::UploadHybridKeyResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_upload_hybrid_key(params, results)
|
||||
}
|
||||
|
||||
fn fetch_hybrid_key(
|
||||
&mut self,
|
||||
params: node_service::FetchHybridKeyParams,
|
||||
results: node_service::FetchHybridKeyResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_fetch_hybrid_key(params, results)
|
||||
}
|
||||
|
||||
fn opaque_login_start(
|
||||
&mut self,
|
||||
params: node_service::OpaqueLoginStartParams,
|
||||
results: node_service::OpaqueLoginStartResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_opaque_login_start(params, results)
|
||||
}
|
||||
|
||||
fn opaque_register_start(
|
||||
&mut self,
|
||||
params: node_service::OpaqueRegisterStartParams,
|
||||
results: node_service::OpaqueRegisterStartResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_opaque_register_start(params, results)
|
||||
}
|
||||
|
||||
fn opaque_login_finish(
|
||||
&mut self,
|
||||
params: node_service::OpaqueLoginFinishParams,
|
||||
results: node_service::OpaqueLoginFinishResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_opaque_login_finish(params, results)
|
||||
}
|
||||
|
||||
fn opaque_register_finish(
|
||||
&mut self,
|
||||
params: node_service::OpaqueRegisterFinishParams,
|
||||
results: node_service::OpaqueRegisterFinishResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_opaque_register_finish(params, results)
|
||||
}
|
||||
|
||||
fn publish_endpoint(
|
||||
&mut self,
|
||||
params: node_service::PublishEndpointParams,
|
||||
results: node_service::PublishEndpointResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_publish_endpoint(params, results)
|
||||
}
|
||||
|
||||
fn resolve_endpoint(
|
||||
&mut self,
|
||||
params: node_service::ResolveEndpointParams,
|
||||
results: node_service::ResolveEndpointResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_resolve_endpoint(params, results)
|
||||
}
|
||||
|
||||
fn peek(
|
||||
&mut self,
|
||||
params: node_service::PeekParams,
|
||||
results: node_service::PeekResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_peek(params, results)
|
||||
}
|
||||
|
||||
fn ack(
|
||||
&mut self,
|
||||
params: node_service::AckParams,
|
||||
results: node_service::AckResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_ack(params, results)
|
||||
}
|
||||
|
||||
fn fetch_hybrid_keys(
|
||||
&mut self,
|
||||
params: node_service::FetchHybridKeysParams,
|
||||
results: node_service::FetchHybridKeysResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_fetch_hybrid_keys(params, results)
|
||||
}
|
||||
|
||||
fn batch_enqueue(
|
||||
&mut self,
|
||||
params: node_service::BatchEnqueueParams,
|
||||
results: node_service::BatchEnqueueResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_batch_enqueue(params, results)
|
||||
}
|
||||
|
||||
fn create_channel(
|
||||
&mut self,
|
||||
params: node_service::CreateChannelParams,
|
||||
results: node_service::CreateChannelResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_create_channel(params, results)
|
||||
}
|
||||
|
||||
fn resolve_user(
|
||||
&mut self,
|
||||
params: node_service::ResolveUserParams,
|
||||
results: node_service::ResolveUserResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_resolve_user(params, results)
|
||||
}
|
||||
|
||||
fn resolve_identity(
|
||||
&mut self,
|
||||
params: node_service::ResolveIdentityParams,
|
||||
results: node_service::ResolveIdentityResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_resolve_identity(params, results)
|
||||
}
|
||||
}
|
||||
|
||||
pub const CURRENT_WIRE_VERSION: u16 = 1;
|
||||
|
||||
pub struct NodeServiceImpl {
|
||||
pub store: Arc<dyn Store>,
|
||||
pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
pub auth_cfg: Arc<AuthConfig>,
|
||||
pub opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
|
||||
pub pending_logins: Arc<DashMap<String, PendingLogin>>,
|
||||
pub sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
|
||||
pub rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
|
||||
/// When true, enqueue does not require identity-bound session (Sealed Sender).
|
||||
pub sealed_sender: bool,
|
||||
/// Outbound federation client for relaying to remote servers (None if federation disabled).
|
||||
pub federation_client: Option<Arc<crate::federation::FederationClient>>,
|
||||
/// This server's federation domain (empty if federation disabled).
|
||||
pub local_domain: Option<String>,
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn new(
|
||||
store: Arc<dyn Store>,
|
||||
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
auth_cfg: Arc<AuthConfig>,
|
||||
opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
|
||||
pending_logins: Arc<DashMap<String, PendingLogin>>,
|
||||
sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
|
||||
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
|
||||
sealed_sender: bool,
|
||||
federation_client: Option<Arc<crate::federation::FederationClient>>,
|
||||
local_domain: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
waiters,
|
||||
auth_cfg,
|
||||
opaque_setup,
|
||||
pending_logins,
|
||||
sessions,
|
||||
rate_limits,
|
||||
sealed_sender,
|
||||
federation_client,
|
||||
local_domain,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_node_connection(
|
||||
connecting: quinn::Connecting,
|
||||
store: Arc<dyn Store>,
|
||||
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
auth_cfg: Arc<AuthConfig>,
|
||||
opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
|
||||
pending_logins: Arc<DashMap<String, PendingLogin>>,
|
||||
sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
|
||||
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
|
||||
sealed_sender: bool,
|
||||
federation_client: Option<Arc<crate::federation::FederationClient>>,
|
||||
local_domain: Option<String>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let connection = connecting.await?;
|
||||
|
||||
tracing::info!(peer = %connection.remote_address(), "QUIC connected");
|
||||
|
||||
let (send, recv) = connection
|
||||
.accept_bi()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("failed to accept bi stream: {e}"))?;
|
||||
let (reader, writer) = (recv.compat(), send.compat_write());
|
||||
|
||||
let mut reader_opts = capnp::message::ReaderOptions::new();
|
||||
reader_opts.traversal_limit_in_words(Some(CAPNP_TRAVERSAL_LIMIT_WORDS));
|
||||
let network = capnp_rpc::twoparty::VatNetwork::new(
|
||||
reader,
|
||||
writer,
|
||||
capnp_rpc::rpc_twoparty_capnp::Side::Server,
|
||||
reader_opts,
|
||||
);
|
||||
|
||||
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl::new(
|
||||
store,
|
||||
waiters,
|
||||
auth_cfg,
|
||||
opaque_setup,
|
||||
pending_logins,
|
||||
sessions,
|
||||
rate_limits,
|
||||
sealed_sender,
|
||||
federation_client,
|
||||
local_domain,
|
||||
));
|
||||
|
||||
RpcSystem::new(Box::new(network), Some(service.client))
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("NodeService RPC error: {e}"))
|
||||
}
|
||||
|
||||
const MESSAGE_TTL_SECS: u64 = 7 * 24 * 60 * 60; // 7 days
|
||||
|
||||
pub fn spawn_cleanup_task(
|
||||
sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
|
||||
pending_logins: Arc<DashMap<String, PendingLogin>>,
|
||||
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
|
||||
conn_rate_limits: Arc<DashMap<std::net::IpAddr, RateEntry>>,
|
||||
store: Arc<dyn Store>,
|
||||
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let now = current_timestamp();
|
||||
|
||||
sessions.retain(|_, info| info.expires_at > now);
|
||||
pending_logins.retain(|_, pl| now - pl.created_at < PENDING_LOGIN_TTL_SECS);
|
||||
rate_limits.retain(|_, entry| now - entry.window_start < RATE_LIMIT_WINDOW_SECS * 2);
|
||||
conn_rate_limits.retain(|_, entry| {
|
||||
now - entry.window_start < crate::auth::CONN_RATE_LIMIT_WINDOW_SECS * 2
|
||||
});
|
||||
|
||||
// Bound map sizes to prevent unbounded growth from malicious clients.
|
||||
const MAX_SESSIONS: usize = 100_000;
|
||||
const MAX_WAITERS: usize = 100_000;
|
||||
if sessions.len() > MAX_SESSIONS {
|
||||
let overflow = sessions.len() - MAX_SESSIONS;
|
||||
let mut entries: Vec<_> = sessions
|
||||
.iter()
|
||||
.map(|e| (e.key().clone(), e.expires_at))
|
||||
.collect();
|
||||
entries.sort_by_key(|(_, exp)| *exp);
|
||||
for (key, _) in entries.into_iter().take(overflow) {
|
||||
sessions.remove(&key);
|
||||
}
|
||||
}
|
||||
if waiters.len() > MAX_WAITERS {
|
||||
let overflow = waiters.len() - MAX_WAITERS;
|
||||
let keys: Vec<_> =
|
||||
waiters.iter().take(overflow).map(|e| e.key().clone()).collect();
|
||||
for key in keys {
|
||||
waiters.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
match store.gc_expired_messages(MESSAGE_TTL_SECS) {
|
||||
Ok(n) if n > 0 => {
|
||||
tracing::debug!(expired = n, "garbage collected expired messages")
|
||||
}
|
||||
Err(e) => tracing::warn!(error = %e, "message GC failed"),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
119
crates/quicproquo-server/src/node_service/p2p_ops.rs
Normal file
119
crates/quicproquo-server/src/node_service/p2p_ops.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use capnp::capability::Promise;
|
||||
use quicproquo_proto::node_capnp::node_service;
|
||||
|
||||
use crate::auth::{
|
||||
coded_error, fmt_hex, require_identity_or_request, validate_auth, validate_auth_context,
|
||||
};
|
||||
use crate::error_codes::*;
|
||||
use crate::storage::StorageError;
|
||||
|
||||
use super::NodeServiceImpl;
|
||||
|
||||
fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
/// Health check: unauthenticated by design for liveness probes and load balancers.
|
||||
pub fn handle_health(
|
||||
&mut self,
|
||||
_params: node_service::HealthParams,
|
||||
mut results: node_service::HealthResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
results.get().set_status("ok");
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_publish_endpoint(
|
||||
&mut self,
|
||||
params: node_service::PublishEndpointParams,
|
||||
_results: node_service::PublishEndpointResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_key = match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let node_addr = match p.get_node_addr() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if identity_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&identity_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if let Err(e) = self
|
||||
.store
|
||||
.publish_endpoint(&identity_key, node_addr)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "endpoint published");
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_resolve_endpoint(
|
||||
&mut self,
|
||||
params: node_service::ResolveEndpointParams,
|
||||
mut results: node_service::ResolveEndpointResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_key = match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
if let Err(e) = validate_auth(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if identity_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
|
||||
));
|
||||
}
|
||||
|
||||
let endpoint = match self
|
||||
.store
|
||||
.resolve_endpoint(&identity_key)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(e) => e,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if let Some(ep) = endpoint {
|
||||
tracing::debug!(identity = %fmt_hex(&identity_key[..4]), "endpoint resolved");
|
||||
results.get().set_node_addr(&ep);
|
||||
} else {
|
||||
results.get().set_node_addr(&[]);
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
131
crates/quicproquo-server/src/node_service/user_ops.rs
Normal file
131
crates/quicproquo-server/src/node_service/user_ops.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
//! resolveUser / resolveIdentity RPCs: bidirectional username ↔ identity key lookup.
|
||||
|
||||
use capnp::capability::Promise;
|
||||
use quicproquo_proto::node_capnp::node_service;
|
||||
|
||||
use crate::auth::{coded_error, validate_auth_context};
|
||||
use crate::error_codes::*;
|
||||
use crate::storage::StorageError;
|
||||
|
||||
use super::NodeServiceImpl;
|
||||
|
||||
fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn handle_resolve_user(
|
||||
&mut self,
|
||||
params: node_service::ResolveUserParams,
|
||||
mut results: node_service::ResolveUserResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match p.get_username() {
|
||||
Ok(u) => u,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let _auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
let username_str = match username.to_str() {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
|
||||
if username_str.is_empty() {
|
||||
return Promise::err(coded_error(E020_BAD_PARAMS, "username must not be empty"));
|
||||
}
|
||||
|
||||
// Federation: parse user@domain format.
|
||||
let addr = crate::federation::address::FederatedAddress::parse(username_str);
|
||||
let is_remote = match (&addr.domain, &self.local_domain) {
|
||||
(Some(d), Some(ld)) => d != ld,
|
||||
(Some(_), None) => true,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if is_remote {
|
||||
// Proxy to remote server via federation.
|
||||
if let (Some(ref fed_client), Some(ref domain)) = (&self.federation_client, &addr.domain) {
|
||||
if fed_client.has_peer(domain) {
|
||||
let fed = fed_client.clone();
|
||||
let user = addr.username.clone();
|
||||
let dom = domain.clone();
|
||||
return Promise::from_future(async move {
|
||||
match fed.proxy_resolve_user(&dom, &user).await {
|
||||
Ok(Some(key)) => {
|
||||
results.get().set_identity_key(&key);
|
||||
}
|
||||
Ok(None) => {
|
||||
// Not found on remote — return empty.
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "federation proxy_resolve_user failed");
|
||||
// Fall through — return empty (not found).
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
||||
// No federation client or unknown peer — return empty (not found).
|
||||
return Promise::ok(());
|
||||
}
|
||||
|
||||
// Local resolution.
|
||||
match self.store.get_user_identity_key(&addr.username) {
|
||||
Ok(Some(key)) => {
|
||||
results.get().set_identity_key(&key);
|
||||
}
|
||||
Ok(None) => {
|
||||
// Return empty Data — caller checks length to detect "not found".
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_resolve_identity(
|
||||
&mut self,
|
||||
params: node_service::ResolveIdentityParams,
|
||||
mut results: node_service::ResolveIdentityResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_key = match p.get_identity_key() {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let _auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if identity_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("identityKey must be exactly 32 bytes, got {}", identity_key.len()),
|
||||
));
|
||||
}
|
||||
|
||||
match self.store.resolve_identity_key(identity_key) {
|
||||
Ok(Some(username)) => {
|
||||
results.get().set_username(&username);
|
||||
}
|
||||
Ok(None) => {
|
||||
// Return empty string — caller checks length to detect "not found".
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
724
crates/quicproquo-server/src/sql_store.rs
Normal file
724
crates/quicproquo-server/src/sql_store.rs
Normal file
@@ -0,0 +1,724 @@
|
||||
//! SQLCipher-backed persistent storage.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use rand::RngCore;
|
||||
use rusqlite::{params, Connection};
|
||||
|
||||
use crate::storage::{StorageError, Store};
|
||||
|
||||
/// Schema version after introducing the migration runner (existing DBs had 1).
|
||||
const SCHEMA_VERSION: i32 = 5;
|
||||
|
||||
/// Migrations: (migration_number, SQL). Files named NNN_name.sql, applied in order when N > user_version.
|
||||
const MIGRATIONS: &[(i32, &str)] = &[
|
||||
(1, include_str!("../migrations/001_initial.sql")),
|
||||
(3, include_str!("../migrations/002_add_seq.sql")),
|
||||
(4, include_str!("../migrations/003_channels.sql")),
|
||||
(5, include_str!("../migrations/004_federation.sql")),
|
||||
];
|
||||
|
||||
/// Runs pending migrations on an open connection: applies any migration whose number is greater
|
||||
/// than the current PRAGMA user_version, then sets user_version to SCHEMA_VERSION.
|
||||
fn run_migrations(conn: &Connection) -> Result<(), StorageError> {
|
||||
let current_version: i32 = conn
|
||||
.pragma_query_value(None, "user_version", |row| row.get(0))
|
||||
.map_err(|e| StorageError::Db(format!("PRAGMA user_version failed: {e}")))?;
|
||||
|
||||
for (migration_num, sql) in MIGRATIONS {
|
||||
if *migration_num > current_version {
|
||||
conn.execute_batch(sql).map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
}
|
||||
}
|
||||
|
||||
conn.pragma_update(None, "user_version", SCHEMA_VERSION)
|
||||
.map_err(|e| StorageError::Db(format!("set user_version failed: {e}")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// SQLCipher-encrypted storage backend.
|
||||
pub struct SqlStore {
|
||||
conn: Mutex<Connection>,
|
||||
}
|
||||
|
||||
impl SqlStore {
|
||||
fn lock_conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, StorageError> {
|
||||
self.conn
|
||||
.lock()
|
||||
.map_err(|e| StorageError::Db(format!("lock poisoned: {e}")))
|
||||
}
|
||||
|
||||
pub fn open(path: impl AsRef<Path>, key: &str) -> Result<Self, StorageError> {
|
||||
let conn = Connection::open(path).map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
if !key.is_empty() {
|
||||
conn.pragma_update(None, "key", key)
|
||||
.map_err(|e| StorageError::Db(format!("PRAGMA key failed: {e}")))?;
|
||||
}
|
||||
|
||||
conn.execute_batch(
|
||||
"PRAGMA journal_mode = WAL;
|
||||
PRAGMA synchronous = NORMAL;
|
||||
PRAGMA foreign_keys = ON;",
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
let current_version: i32 = conn
|
||||
.pragma_query_value(None, "user_version", |row| row.get(0))
|
||||
.map_err(|e| StorageError::Db(format!("PRAGMA user_version failed: {e}")))?;
|
||||
|
||||
if current_version > SCHEMA_VERSION {
|
||||
return Err(StorageError::Db(format!(
|
||||
"database schema version {current_version} is newer than supported {SCHEMA_VERSION}"
|
||||
)));
|
||||
}
|
||||
|
||||
run_migrations(&conn)?;
|
||||
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Store for SqlStore {
|
||||
fn upload_key_package(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
package: Vec<u8>,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.execute(
|
||||
"INSERT INTO key_packages (identity_key, package_data) VALUES (?1, ?2)",
|
||||
params![identity_key, package],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
|
||||
let mut stmt = conn
|
||||
.prepare(
|
||||
"SELECT id, package_data FROM key_packages
|
||||
WHERE identity_key = ?1
|
||||
ORDER BY id ASC
|
||||
LIMIT 1",
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
let row = stmt
|
||||
.query_row(params![identity_key], |row| {
|
||||
Ok((row.get::<_, i64>(0)?, row.get::<_, Vec<u8>>(1)?))
|
||||
})
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
match row {
|
||||
Some((id, package)) => {
|
||||
conn.execute("DELETE FROM key_packages WHERE id = ?1", params![id])
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(Some(package))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn enqueue(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
payload: Vec<u8>,
|
||||
) -> Result<u64, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
// Atomically get-and-increment the per-inbox sequence counter.
|
||||
// RETURNING gives us the post-update next_seq; the assigned seq is next_seq - 1.
|
||||
let seq: i64 = conn
|
||||
.query_row(
|
||||
"INSERT INTO delivery_seq_counters (recipient_key, channel_id, next_seq)
|
||||
VALUES (?1, ?2, 1)
|
||||
ON CONFLICT(recipient_key, channel_id) DO UPDATE SET next_seq = next_seq + 1
|
||||
RETURNING next_seq - 1",
|
||||
params![recipient_key, channel_id],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
conn.execute(
|
||||
"INSERT INTO deliveries (recipient_key, channel_id, seq, payload) VALUES (?1, ?2, ?3, ?4)",
|
||||
params![recipient_key, channel_id, seq, payload],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(seq as u64)
|
||||
}
|
||||
|
||||
fn fetch(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
|
||||
let mut stmt = conn
|
||||
.prepare(
|
||||
"SELECT id, seq, payload FROM deliveries
|
||||
WHERE recipient_key = ?1 AND channel_id = ?2
|
||||
ORDER BY seq ASC",
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
let rows: Vec<(i64, i64, Vec<u8>)> = stmt
|
||||
.query_map(params![recipient_key, channel_id], |row| {
|
||||
Ok((row.get(0)?, row.get(1)?, row.get(2)?))
|
||||
})
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
if !rows.is_empty() {
|
||||
let ids: Vec<i64> = rows.iter().map(|(id, _, _)| *id).collect();
|
||||
let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
|
||||
let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})");
|
||||
let params: Vec<&dyn rusqlite::types::ToSql> = ids
|
||||
.iter()
|
||||
.map(|id| id as &dyn rusqlite::types::ToSql)
|
||||
.collect();
|
||||
conn.execute(&sql, params.as_slice())
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(rows.into_iter().map(|(_, seq, payload)| (seq as u64, payload)).collect())
|
||||
}
|
||||
|
||||
fn fetch_limited(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
|
||||
let mut stmt = conn
|
||||
.prepare(
|
||||
"SELECT id, seq, payload FROM deliveries
|
||||
WHERE recipient_key = ?1 AND channel_id = ?2
|
||||
ORDER BY seq ASC
|
||||
LIMIT ?3",
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
let rows: Vec<(i64, i64, Vec<u8>)> = stmt
|
||||
.query_map(params![recipient_key, channel_id, limit as i64], |row| {
|
||||
Ok((row.get(0)?, row.get(1)?, row.get(2)?))
|
||||
})
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
if !rows.is_empty() {
|
||||
let ids: Vec<i64> = rows.iter().map(|(id, _, _)| *id).collect();
|
||||
let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
|
||||
let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})");
|
||||
let params: Vec<&dyn rusqlite::types::ToSql> = ids
|
||||
.iter()
|
||||
.map(|id| id as &dyn rusqlite::types::ToSql)
|
||||
.collect();
|
||||
conn.execute(&sql, params.as_slice())
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
}
|
||||
|
||||
Ok(rows.into_iter().map(|(_, seq, payload)| (seq as u64, payload)).collect())
|
||||
}
|
||||
|
||||
fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<usize, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2",
|
||||
params![recipient_key, channel_id],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(count as usize)
|
||||
}
|
||||
|
||||
fn gc_expired_messages(&self, max_age_secs: u64) -> Result<usize, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let cutoff = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
.saturating_sub(max_age_secs);
|
||||
let deleted = conn
|
||||
.execute(
|
||||
"DELETE FROM deliveries WHERE created_at < ?1",
|
||||
params![cutoff as i64],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(deleted)
|
||||
}
|
||||
|
||||
fn upload_hybrid_key(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
hybrid_pk: Vec<u8>,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO hybrid_keys (identity_key, hybrid_public_key) VALUES (?1, ?2)",
|
||||
params![identity_key, hybrid_pk],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT hybrid_public_key FROM hybrid_keys WHERE identity_key = ?1")
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
stmt.query_row(params![identity_key], |row| row.get(0))
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn store_server_setup(&self, setup: Vec<u8>) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO server_setup (id, setup_data) VALUES (1, ?1)",
|
||||
params![setup],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT setup_data FROM server_setup WHERE id = 1")
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
stmt.query_row([], |row| row.get(0))
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.execute(
|
||||
"INSERT INTO users (username, opaque_record) VALUES (?1, ?2)",
|
||||
params![username, record],
|
||||
)
|
||||
.map_err(|e| {
|
||||
if let rusqlite::Error::SqliteFailure(ref err, _) = &e {
|
||||
if err.code == rusqlite::ErrorCode::ConstraintViolation {
|
||||
return StorageError::DuplicateUser(username.to_string());
|
||||
}
|
||||
}
|
||||
StorageError::Db(e.to_string())
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_user_record(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT opaque_record FROM users WHERE username = ?1")
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
stmt.query_row(params![username], |row| row.get(0))
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn has_user_record(&self, username: &str) -> Result<bool, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let exists: bool = conn
|
||||
.query_row(
|
||||
"SELECT EXISTS(SELECT 1 FROM users WHERE username = ?1)",
|
||||
params![username],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(exists)
|
||||
}
|
||||
|
||||
fn store_user_identity_key(
|
||||
&self,
|
||||
username: &str,
|
||||
identity_key: Vec<u8>,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO user_identity_keys (username, identity_key) VALUES (?1, ?2)",
|
||||
params![username, identity_key],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT identity_key FROM user_identity_keys WHERE username = ?1")
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
stmt.query_row(params![username], |row| row.get(0))
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn resolve_identity_key(&self, identity_key: &[u8]) -> Result<Option<String>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT username FROM user_identity_keys WHERE identity_key = ?1")
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
stmt.query_row(params![identity_key], |row| row.get(0))
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn peek(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
|
||||
let sql = if limit == 0 {
|
||||
"SELECT seq, payload FROM deliveries
|
||||
WHERE recipient_key = ?1 AND channel_id = ?2
|
||||
ORDER BY seq ASC".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"SELECT seq, payload FROM deliveries
|
||||
WHERE recipient_key = ?1 AND channel_id = ?2
|
||||
ORDER BY seq ASC
|
||||
LIMIT {}",
|
||||
limit
|
||||
)
|
||||
};
|
||||
|
||||
let mut stmt = conn.prepare(&sql).map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
let rows: Vec<(i64, Vec<u8>)> = stmt
|
||||
.query_map(params![recipient_key, channel_id], |row| {
|
||||
Ok((row.get(0)?, row.get(1)?))
|
||||
})
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
Ok(rows.into_iter().map(|(seq, payload)| (seq as u64, payload)).collect())
|
||||
}
|
||||
|
||||
fn ack(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> Result<usize, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let deleted = conn
|
||||
.execute(
|
||||
"DELETE FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND seq <= ?3",
|
||||
params![recipient_key, channel_id, seq_up_to as i64],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(deleted)
|
||||
}
|
||||
|
||||
fn publish_endpoint(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
node_addr: Vec<u8>,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO endpoints (identity_key, node_addr) VALUES (?1, ?2)",
|
||||
params![identity_key, node_addr],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT node_addr FROM endpoints WHERE identity_key = ?1")
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
stmt.query_row(params![identity_key], |row| row.get(0))
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
|
||||
let (a, b) = if member_a < member_b {
|
||||
(member_a.to_vec(), member_b.to_vec())
|
||||
} else {
|
||||
(member_b.to_vec(), member_a.to_vec())
|
||||
};
|
||||
let conn = self.lock_conn()?;
|
||||
let existing: Option<Vec<u8>> = conn
|
||||
.query_row(
|
||||
"SELECT channel_id FROM channels WHERE member_a = ?1 AND member_b = ?2",
|
||||
params![a, b],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
if let Some(id) = existing {
|
||||
return Ok(id);
|
||||
}
|
||||
let mut channel_id = [0u8; 16];
|
||||
rand::thread_rng().fill_bytes(&mut channel_id);
|
||||
conn.execute(
|
||||
"INSERT INTO channels (channel_id, member_a, member_b) VALUES (?1, ?2, ?3)",
|
||||
params![channel_id.as_slice(), a, b],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(channel_id.to_vec())
|
||||
}
|
||||
|
||||
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.query_row(
|
||||
"SELECT member_a, member_b FROM channels WHERE channel_id = ?1",
|
||||
params![channel_id],
|
||||
|row| Ok((row.get::<_, Vec<u8>>(0)?, row.get::<_, Vec<u8>>(1)?)),
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn store_identity_home_server(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
home_server: &str,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs() as i64;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO identity_home_servers (identity_key, home_server, updated_at) VALUES (?1, ?2, ?3)",
|
||||
params![identity_key, home_server, now],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_identity_home_server(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
) -> Result<Option<String>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT home_server FROM identity_home_servers WHERE identity_key = ?1")
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
stmt.query_row(params![identity_key], |row| row.get(0))
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn upsert_federation_peer(
|
||||
&self,
|
||||
domain: &str,
|
||||
is_active: bool,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs() as i64;
|
||||
conn.execute(
|
||||
"INSERT INTO federation_peers (domain, last_seen, is_active) VALUES (?1, ?2, ?3)
|
||||
ON CONFLICT(domain) DO UPDATE SET last_seen = ?2, is_active = ?3",
|
||||
params![domain, now, is_active as i32],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn list_federation_peers(&self) -> Result<Vec<(String, bool)>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let mut stmt = conn
|
||||
.prepare("SELECT domain, is_active FROM federation_peers WHERE is_active = 1")
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
let rows = stmt
|
||||
.query_map([], |row| {
|
||||
Ok((row.get::<_, String>(0)?, row.get::<_, i32>(1)? != 0))
|
||||
})
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(rows)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience extension for `rusqlite::OptionalExtension`.
|
||||
trait OptionalExt<T> {
|
||||
fn optional(self) -> Result<Option<T>, rusqlite::Error>;
|
||||
}
|
||||
|
||||
impl<T> OptionalExt<T> for Result<T, rusqlite::Error> {
|
||||
fn optional(self) -> Result<Option<T>, rusqlite::Error> {
|
||||
match self {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn open_in_memory() -> SqlStore {
|
||||
SqlStore::open(":memory:", "").unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sets_user_version_after_migrate() {
|
||||
let dir = tempfile::tempdir().expect("tempdir");
|
||||
let db_path: PathBuf = dir.path().join("store.db");
|
||||
|
||||
{
|
||||
let store = SqlStore::open(&db_path, "").expect("open store");
|
||||
let _guard = store.lock_conn().unwrap();
|
||||
}
|
||||
|
||||
let conn = rusqlite::Connection::open(&db_path).expect("reopen db");
|
||||
let version: i32 = conn
|
||||
.pragma_query_value(None, "user_version", |row| row.get(0))
|
||||
.expect("read user_version");
|
||||
|
||||
assert_eq!(version, SCHEMA_VERSION);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_package_fifo() {
|
||||
let store = open_in_memory();
|
||||
let identity = [1u8; 32];
|
||||
|
||||
store
|
||||
.upload_key_package(&identity, b"kp1".to_vec())
|
||||
.unwrap();
|
||||
store
|
||||
.upload_key_package(&identity, b"kp2".to_vec())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
store.fetch_key_package(&identity).unwrap(),
|
||||
Some(b"kp1".to_vec())
|
||||
);
|
||||
assert_eq!(
|
||||
store.fetch_key_package(&identity).unwrap(),
|
||||
Some(b"kp2".to_vec())
|
||||
);
|
||||
assert_eq!(store.fetch_key_package(&identity).unwrap(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delivery_round_trip() {
|
||||
let store = open_in_memory();
|
||||
let rk = [1u8; 32];
|
||||
let ch = b"channel-1";
|
||||
|
||||
let seq0 = store.enqueue(&rk, ch, b"msg1".to_vec()).unwrap();
|
||||
let seq1 = store.enqueue(&rk, ch, b"msg2".to_vec()).unwrap();
|
||||
assert_eq!(seq0, 0);
|
||||
assert_eq!(seq1, 1);
|
||||
|
||||
let msgs = store.fetch(&rk, ch).unwrap();
|
||||
assert_eq!(msgs, vec![(0u64, b"msg1".to_vec()), (1u64, b"msg2".to_vec())]);
|
||||
|
||||
assert!(store.fetch(&rk, ch).unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fetch_limited_partial_drain() {
|
||||
let store = open_in_memory();
|
||||
let rk = [5u8; 32];
|
||||
let ch = b"ch";
|
||||
|
||||
store.enqueue(&rk, ch, b"a".to_vec()).unwrap();
|
||||
store.enqueue(&rk, ch, b"b".to_vec()).unwrap();
|
||||
store.enqueue(&rk, ch, b"c".to_vec()).unwrap();
|
||||
|
||||
let msgs = store.fetch_limited(&rk, ch, 2).unwrap();
|
||||
assert_eq!(msgs, vec![(0u64, b"a".to_vec()), (1u64, b"b".to_vec())]);
|
||||
|
||||
let remaining = store.fetch(&rk, ch).unwrap();
|
||||
assert_eq!(remaining, vec![(2u64, b"c".to_vec())]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn queue_depth_count() {
|
||||
let store = open_in_memory();
|
||||
let rk = [6u8; 32];
|
||||
let ch = b"ch";
|
||||
|
||||
assert_eq!(store.queue_depth(&rk, ch).unwrap(), 0);
|
||||
store.enqueue(&rk, ch, b"x".to_vec()).unwrap();
|
||||
store.enqueue(&rk, ch, b"y".to_vec()).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, ch).unwrap(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn has_user_record_check() {
|
||||
let store = open_in_memory();
|
||||
assert!(!store.has_user_record("user1").unwrap());
|
||||
store
|
||||
.store_user_record("user1", b"record".to_vec())
|
||||
.unwrap();
|
||||
assert!(store.has_user_record("user1").unwrap());
|
||||
assert!(!store.has_user_record("user2").unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_identity_key_round_trip() {
|
||||
let store = open_in_memory();
|
||||
assert!(store.get_user_identity_key("user1").unwrap().is_none());
|
||||
store
|
||||
.store_user_identity_key("user1", vec![1u8; 32])
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
store.get_user_identity_key("user1").unwrap(),
|
||||
Some(vec![1u8; 32])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hybrid_key_round_trip() {
|
||||
let store = open_in_memory();
|
||||
let ik = [2u8; 32];
|
||||
let pk = b"hybrid_public_key_data".to_vec();
|
||||
|
||||
store.upload_hybrid_key(&ik, pk.clone()).unwrap();
|
||||
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(pk));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn separate_channels_isolated() {
|
||||
let store = open_in_memory();
|
||||
let rk = [4u8; 32];
|
||||
|
||||
store.enqueue(&rk, b"ch-a", b"a1".to_vec()).unwrap();
|
||||
store.enqueue(&rk, b"ch-b", b"b1".to_vec()).unwrap();
|
||||
|
||||
let a_msgs = store.fetch(&rk, b"ch-a").unwrap();
|
||||
assert_eq!(a_msgs, vec![(0u64, b"a1".to_vec())]);
|
||||
|
||||
let b_msgs = store.fetch(&rk, b"ch-b").unwrap();
|
||||
assert_eq!(b_msgs, vec![(0u64, b"b1".to_vec())]);
|
||||
}
|
||||
}
|
||||
823
crates/quicproquo-server/src/storage.rs
Normal file
823
crates/quicproquo-server/src/storage.rs
Normal file
@@ -0,0 +1,823 @@
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fs,
|
||||
hash::Hash,
|
||||
path::{Path, PathBuf},
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum StorageError {
|
||||
#[error("io error: {0}")]
|
||||
Io(String),
|
||||
#[error("serialization error")]
|
||||
Serde,
|
||||
#[error("database error: {0}")]
|
||||
Db(String),
|
||||
/// Unique constraint violation (e.g. user already exists).
|
||||
#[error("duplicate user: {0}")]
|
||||
DuplicateUser(String),
|
||||
}
|
||||
|
||||
fn lock<T>(m: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>, StorageError> {
|
||||
m.lock()
|
||||
.map_err(|e| StorageError::Io(format!("lock poisoned: {e}")))
|
||||
}
|
||||
|
||||
// ── Store trait ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Abstraction over storage backends (file-backed, SQLCipher, etc.).
|
||||
pub trait Store: Send + Sync {
|
||||
fn upload_key_package(&self, identity_key: &[u8], package: Vec<u8>)
|
||||
-> Result<(), StorageError>;
|
||||
|
||||
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Enqueue a payload and return the monotonically increasing per-inbox sequence number
|
||||
/// assigned to this message. Clients sort by seq before MLS processing.
|
||||
fn enqueue(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
payload: Vec<u8>,
|
||||
) -> Result<u64, StorageError>;
|
||||
|
||||
/// Fetch and drain all queued messages, returning `(seq, payload)` pairs ordered by seq.
|
||||
fn fetch(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;
|
||||
|
||||
/// Fetch up to `limit` messages without draining the entire queue (Fix 8).
|
||||
/// Returns `(seq, payload)` pairs ordered by seq.
|
||||
fn fetch_limited(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;
|
||||
|
||||
/// Return the number of queued messages for (recipient, channel) (Fix 7).
|
||||
fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<usize, StorageError>;
|
||||
|
||||
/// Delete messages older than `max_age_secs`. Returns count deleted (Fix 7).
|
||||
fn gc_expired_messages(&self, max_age_secs: u64) -> Result<usize, StorageError>;
|
||||
|
||||
fn upload_hybrid_key(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
hybrid_pk: Vec<u8>,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Store the OPAQUE `ServerSetup` (generated once, loaded on restart).
|
||||
fn store_server_setup(&self, setup: Vec<u8>) -> Result<(), StorageError>;
|
||||
|
||||
/// Load the persisted `ServerSetup`, if any.
|
||||
fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Store an OPAQUE user record (serialized `ServerRegistration`).
|
||||
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError>;
|
||||
|
||||
/// Retrieve an OPAQUE user record by username.
|
||||
fn get_user_record(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Check if a user record already exists (Fix 5).
|
||||
fn has_user_record(&self, username: &str) -> Result<bool, StorageError>;
|
||||
|
||||
/// Store identity key for a user (Fix 2).
|
||||
fn store_user_identity_key(
|
||||
&self,
|
||||
username: &str,
|
||||
identity_key: Vec<u8>,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// Retrieve identity key for a user (Fix 2).
|
||||
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Reverse lookup: resolve an identity key to the registered username.
|
||||
fn resolve_identity_key(&self, identity_key: &[u8]) -> Result<Option<String>, StorageError>;
|
||||
|
||||
/// Peek at queued messages without removing them (non-destructive).
|
||||
/// Returns `(seq, payload)` pairs ordered by seq.
|
||||
fn peek(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;
|
||||
|
||||
/// Acknowledge (remove) all messages with seq <= seq_up_to.
|
||||
fn ack(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> Result<usize, StorageError>;
|
||||
|
||||
/// Publish a P2P endpoint address for an identity key.
|
||||
fn publish_endpoint(&self, identity_key: &[u8], node_addr: Vec<u8>)
|
||||
-> Result<(), StorageError>;
|
||||
|
||||
/// Resolve a peer's P2P endpoint address.
|
||||
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Create a 1:1 channel between two members. Returns 16-byte channel_id (UUID).
|
||||
/// Members are stored in sorted order for deterministic lookup.
|
||||
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError>;
|
||||
|
||||
/// Get the two members of a channel by channel_id (16 bytes). Returns (member_a, member_b) in sorted order.
|
||||
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError>;
|
||||
|
||||
// ── Federation ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Store the home server domain for an identity key.
|
||||
fn store_identity_home_server(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
home_server: &str,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// Get the home server domain for an identity key.
|
||||
fn get_identity_home_server(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
) -> Result<Option<String>, StorageError>;
|
||||
|
||||
/// Insert or update a federation peer.
|
||||
fn upsert_federation_peer(
|
||||
&self,
|
||||
domain: &str,
|
||||
is_active: bool,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// List all active federation peers.
|
||||
fn list_federation_peers(&self) -> Result<Vec<(String, bool)>, StorageError>;
|
||||
}
|
||||
|
||||
// ── ChannelKey ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
|
||||
pub struct ChannelKey {
|
||||
pub channel_id: Vec<u8>,
|
||||
pub recipient_key: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Hash for ChannelKey {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.channel_id.hash(state);
|
||||
self.recipient_key.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
// ── FileBackedStore ──────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
struct QueueMapV1 {
|
||||
map: HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
struct QueueMapV2 {
|
||||
map: HashMap<ChannelKey, VecDeque<Vec<u8>>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
struct SeqEntry {
|
||||
seq: u64,
|
||||
data: Vec<u8>,
|
||||
}
|
||||
|
||||
/// V3 delivery store: each queue entry carries a monotonic per-inbox sequence number.
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
struct QueueMapV3 {
|
||||
map: HashMap<ChannelKey, VecDeque<SeqEntry>>,
|
||||
next_seq: HashMap<ChannelKey, u64>,
|
||||
}
|
||||
|
||||
/// File-backed storage for KeyPackages and delivery queues.
|
||||
///
|
||||
/// Each mutation flushes the entire map to disk. Suitable for MVP-scale loads.
|
||||
pub struct FileBackedStore {
|
||||
kp_path: PathBuf,
|
||||
ds_path: PathBuf,
|
||||
hk_path: PathBuf,
|
||||
setup_path: PathBuf,
|
||||
users_path: PathBuf,
|
||||
identity_keys_path: PathBuf,
|
||||
channels_path: PathBuf,
|
||||
key_packages: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
|
||||
deliveries: Mutex<QueueMapV3>,
|
||||
channels: Mutex<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>>,
|
||||
hybrid_keys: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
|
||||
users: Mutex<HashMap<String, Vec<u8>>>,
|
||||
identity_keys: Mutex<HashMap<String, Vec<u8>>>,
|
||||
endpoints: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl FileBackedStore {
|
||||
pub fn open(dir: impl AsRef<Path>) -> Result<Self, StorageError> {
|
||||
let dir = dir.as_ref();
|
||||
if !dir.exists() {
|
||||
fs::create_dir_all(dir).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
let kp_path = dir.join("keypackages.bin");
|
||||
let ds_path = dir.join("deliveries.bin");
|
||||
let hk_path = dir.join("hybridkeys.bin");
|
||||
let setup_path = dir.join("server_setup.bin");
|
||||
let users_path = dir.join("users.bin");
|
||||
let identity_keys_path = dir.join("identity_keys.bin");
|
||||
let channels_path = dir.join("channels.bin");
|
||||
|
||||
let key_packages = Mutex::new(Self::load_kp_map(&kp_path)?);
|
||||
let deliveries = Mutex::new(Self::load_delivery_map_v3(&ds_path)?);
|
||||
let hybrid_keys = Mutex::new(Self::load_hybrid_keys(&hk_path)?);
|
||||
let users = Mutex::new(Self::load_users(&users_path)?);
|
||||
let identity_keys = Mutex::new(Self::load_map_string_bytes(&identity_keys_path)?);
|
||||
let channels = Mutex::new(Self::load_channels(&channels_path)?);
|
||||
|
||||
Ok(Self {
|
||||
kp_path,
|
||||
ds_path,
|
||||
hk_path,
|
||||
setup_path,
|
||||
users_path,
|
||||
identity_keys_path,
|
||||
channels_path,
|
||||
key_packages,
|
||||
deliveries,
|
||||
channels,
|
||||
hybrid_keys,
|
||||
users,
|
||||
identity_keys,
|
||||
endpoints: Mutex::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
fn load_channels(
|
||||
path: &Path,
|
||||
) -> Result<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
if bytes.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
|
||||
}
|
||||
|
||||
fn flush_channels(
|
||||
&self,
|
||||
path: &Path,
|
||||
map: &HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>,
|
||||
) -> Result<(), StorageError> {
|
||||
let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn load_kp_map(path: &Path) -> Result<HashMap<Vec<u8>, VecDeque<Vec<u8>>>, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
if bytes.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let map: QueueMapV1 = bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)?;
|
||||
Ok(map.map)
|
||||
}
|
||||
|
||||
fn flush_kp_map(
|
||||
&self,
|
||||
path: &Path,
|
||||
map: &HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
|
||||
) -> Result<(), StorageError> {
|
||||
let payload = QueueMapV1 { map: map.clone() };
|
||||
let bytes = bincode::serialize(&payload).map_err(|_| StorageError::Serde)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
/// Load deliveries as V3. Falls back to V2 format (assigns seqs starting at 0).
|
||||
fn load_delivery_map_v3(path: &Path) -> Result<QueueMapV3, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(QueueMapV3::default());
|
||||
}
|
||||
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
if bytes.is_empty() {
|
||||
return Ok(QueueMapV3::default());
|
||||
}
|
||||
// Try V3 first.
|
||||
if let Ok(v3) = bincode::deserialize::<QueueMapV3>(&bytes) {
|
||||
return Ok(v3);
|
||||
}
|
||||
// Fall back to V2: assign ascending seqs starting at 0 per channel.
|
||||
let v2 = bincode::deserialize::<QueueMapV2>(&bytes)
|
||||
.map_err(|_| StorageError::Io("deliveries file: unrecognised format".into()))?;
|
||||
let mut v3 = QueueMapV3::default();
|
||||
for (key, queue) in v2.map {
|
||||
let entries: VecDeque<SeqEntry> = queue
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, data)| SeqEntry { seq: i as u64, data })
|
||||
.collect();
|
||||
let next = entries.len() as u64;
|
||||
v3.next_seq.insert(key.clone(), next);
|
||||
v3.map.insert(key, entries);
|
||||
}
|
||||
Ok(v3)
|
||||
}
|
||||
|
||||
fn flush_delivery_map(&self, path: &Path, map: &QueueMapV3) -> Result<(), StorageError> {
|
||||
let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn load_hybrid_keys(path: &Path) -> Result<HashMap<Vec<u8>, Vec<u8>>, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
if bytes.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
|
||||
}
|
||||
|
||||
fn flush_hybrid_keys(
|
||||
&self,
|
||||
path: &Path,
|
||||
map: &HashMap<Vec<u8>, Vec<u8>>,
|
||||
) -> Result<(), StorageError> {
|
||||
let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn load_users(path: &Path) -> Result<HashMap<String, Vec<u8>>, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
if bytes.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
|
||||
}
|
||||
|
||||
fn flush_users(&self, path: &Path, map: &HashMap<String, Vec<u8>>) -> Result<(), StorageError> {
|
||||
let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn load_map_string_bytes(path: &Path) -> Result<HashMap<String, Vec<u8>>, StorageError> {
|
||||
Self::load_users(path)
|
||||
}
|
||||
|
||||
fn flush_map_string_bytes(
|
||||
&self,
|
||||
path: &Path,
|
||||
map: &HashMap<String, Vec<u8>>,
|
||||
) -> Result<(), StorageError> {
|
||||
self.flush_users(path, map)
|
||||
}
|
||||
}
|
||||
|
||||
impl Store for FileBackedStore {
|
||||
fn upload_key_package(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
package: Vec<u8>,
|
||||
) -> Result<(), StorageError> {
|
||||
let mut map = lock(&self.key_packages)?;
|
||||
map.entry(identity_key.to_vec())
|
||||
.or_default()
|
||||
.push_back(package);
|
||||
self.flush_kp_map(&self.kp_path, &*map)
|
||||
}
|
||||
|
||||
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let mut map = lock(&self.key_packages)?;
|
||||
let package = map.get_mut(identity_key).and_then(|q| q.pop_front());
|
||||
self.flush_kp_map(&self.kp_path, &*map)?;
|
||||
Ok(package)
|
||||
}
|
||||
|
||||
fn enqueue(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
payload: Vec<u8>,
|
||||
) -> Result<u64, StorageError> {
|
||||
let mut inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let entry = inner.next_seq.entry(key.clone()).or_insert(0);
|
||||
let seq = *entry;
|
||||
*entry = seq + 1;
|
||||
inner.map.entry(key).or_default().push_back(SeqEntry { seq, data: payload });
|
||||
self.flush_delivery_map(&self.ds_path, &*inner)?;
|
||||
Ok(seq)
|
||||
}
|
||||
|
||||
fn fetch(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let mut inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let messages: Vec<(u64, Vec<u8>)> = inner
|
||||
.map
|
||||
.get_mut(&key)
|
||||
.map(|q| q.drain(..).map(|e| (e.seq, e.data)).collect())
|
||||
.unwrap_or_default();
|
||||
self.flush_delivery_map(&self.ds_path, &*inner)?;
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
fn fetch_limited(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let mut inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let messages: Vec<(u64, Vec<u8>)> = inner
|
||||
.map
|
||||
.get_mut(&key)
|
||||
.map(|q| {
|
||||
let count = limit.min(q.len());
|
||||
q.drain(..count).map(|e| (e.seq, e.data)).collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
self.flush_delivery_map(&self.ds_path, &*inner)?;
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
fn queue_depth(&self, recipient_key: &[u8], channel_id: &[u8]) -> Result<usize, StorageError> {
|
||||
let inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
Ok(inner.map.get(&key).map(|q| q.len()).unwrap_or(0))
|
||||
}
|
||||
|
||||
fn gc_expired_messages(&self, _max_age_secs: u64) -> Result<usize, StorageError> {
|
||||
// FileBackedStore does not track timestamps per message — no-op.
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn upload_hybrid_key(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
hybrid_pk: Vec<u8>,
|
||||
) -> Result<(), StorageError> {
|
||||
let mut map = lock(&self.hybrid_keys)?;
|
||||
map.insert(identity_key.to_vec(), hybrid_pk);
|
||||
self.flush_hybrid_keys(&self.hk_path, &*map)
|
||||
}
|
||||
|
||||
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let map = lock(&self.hybrid_keys)?;
|
||||
Ok(map.get(identity_key).cloned())
|
||||
}
|
||||
|
||||
fn store_server_setup(&self, setup: Vec<u8>) -> Result<(), StorageError> {
|
||||
if let Some(parent) = self.setup_path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(&self.setup_path, std::fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
if !self.setup_path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
let bytes = fs::read(&self.setup_path).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
if bytes.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(bytes))
|
||||
}
|
||||
|
||||
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
|
||||
let mut map = lock(&self.users)?;
|
||||
match map.entry(username.to_string()) {
|
||||
std::collections::hash_map::Entry::Occupied(_) => {
|
||||
return Err(StorageError::DuplicateUser(username.to_string()))
|
||||
}
|
||||
std::collections::hash_map::Entry::Vacant(v) => {
|
||||
v.insert(record);
|
||||
}
|
||||
}
|
||||
self.flush_users(&self.users_path, &*map)
|
||||
}
|
||||
|
||||
fn get_user_record(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let map = lock(&self.users)?;
|
||||
Ok(map.get(username).cloned())
|
||||
}
|
||||
|
||||
fn has_user_record(&self, username: &str) -> Result<bool, StorageError> {
|
||||
let map = lock(&self.users)?;
|
||||
Ok(map.contains_key(username))
|
||||
}
|
||||
|
||||
fn store_user_identity_key(
|
||||
&self,
|
||||
username: &str,
|
||||
identity_key: Vec<u8>,
|
||||
) -> Result<(), StorageError> {
|
||||
let mut map = lock(&self.identity_keys)?;
|
||||
map.insert(username.to_string(), identity_key);
|
||||
self.flush_map_string_bytes(&self.identity_keys_path, &*map)
|
||||
}
|
||||
|
||||
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let map = lock(&self.identity_keys)?;
|
||||
Ok(map.get(username).cloned())
|
||||
}
|
||||
|
||||
fn resolve_identity_key(&self, identity_key: &[u8]) -> Result<Option<String>, StorageError> {
|
||||
let map = lock(&self.identity_keys)?;
|
||||
for (username, ik) in map.iter() {
|
||||
if ik.as_slice() == identity_key {
|
||||
return Ok(Some(username.clone()));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn peek(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let messages: Vec<(u64, Vec<u8>)> = inner
|
||||
.map
|
||||
.get(&key)
|
||||
.map(|q| {
|
||||
let count = if limit == 0 { q.len() } else { limit.min(q.len()) };
|
||||
q.iter()
|
||||
.take(count)
|
||||
.map(|e| (e.seq, e.data.clone()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
// Non-destructive: do NOT flush.
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
fn ack(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> Result<usize, StorageError> {
|
||||
let mut inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let removed = if let Some(q) = inner.map.get_mut(&key) {
|
||||
let before = q.len();
|
||||
q.retain(|e| e.seq > seq_up_to);
|
||||
before - q.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
self.flush_delivery_map(&self.ds_path, &*inner)?;
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
fn publish_endpoint(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
node_addr: Vec<u8>,
|
||||
) -> Result<(), StorageError> {
|
||||
let mut map = lock(&self.endpoints)?;
|
||||
map.insert(identity_key.to_vec(), node_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let map = lock(&self.endpoints)?;
|
||||
Ok(map.get(identity_key).cloned())
|
||||
}
|
||||
|
||||
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
|
||||
let (a, b) = if member_a < member_b {
|
||||
(member_a.to_vec(), member_b.to_vec())
|
||||
} else {
|
||||
(member_b.to_vec(), member_a.to_vec())
|
||||
};
|
||||
let mut map = lock(&self.channels)?;
|
||||
if let Some((channel_id, _)) = map.iter().find(|(_, (ma, mb))| ma == &a && mb == &b) {
|
||||
return Ok(channel_id.clone());
|
||||
}
|
||||
let mut channel_id = [0u8; 16];
|
||||
rand::thread_rng().fill_bytes(&mut channel_id);
|
||||
let channel_id = channel_id.to_vec();
|
||||
map.insert(channel_id.clone(), (a, b));
|
||||
self.flush_channels(&self.channels_path, &*map)?;
|
||||
Ok(channel_id)
|
||||
}
|
||||
|
||||
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
|
||||
let map = lock(&self.channels)?;
|
||||
Ok(map.get(channel_id).cloned())
|
||||
}
|
||||
|
||||
fn store_identity_home_server(
|
||||
&self,
|
||||
_identity_key: &[u8],
|
||||
_home_server: &str,
|
||||
) -> Result<(), StorageError> {
|
||||
// File-backed store: federation mappings are ephemeral (in-memory only).
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_identity_home_server(
|
||||
&self,
|
||||
_identity_key: &[u8],
|
||||
) -> Result<Option<String>, StorageError> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn upsert_federation_peer(
|
||||
&self,
|
||||
_domain: &str,
|
||||
_is_active: bool,
|
||||
) -> Result<(), StorageError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn list_federation_peers(&self) -> Result<Vec<(String, bool)>, StorageError> {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn temp_store() -> (TempDir, FileBackedStore) {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let store = FileBackedStore::open(dir.path()).unwrap();
|
||||
(dir, store)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_package_upload_fetch() {
|
||||
let (_dir, store) = temp_store();
|
||||
let ik = vec![1u8; 32];
|
||||
store.upload_key_package(&ik, vec![10, 20, 30]).unwrap();
|
||||
let pkg = store.fetch_key_package(&ik).unwrap();
|
||||
assert_eq!(pkg, Some(vec![10, 20, 30]));
|
||||
// Second fetch should return None (consumed)
|
||||
let pkg2 = store.fetch_key_package(&ik).unwrap();
|
||||
assert_eq!(pkg2, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enqueue_fetch_with_seq() {
|
||||
let (_dir, store) = temp_store();
|
||||
let rk = vec![2u8; 32];
|
||||
let ch = vec![];
|
||||
let seq0 = store.enqueue(&rk, &ch, vec![1]).unwrap();
|
||||
let seq1 = store.enqueue(&rk, &ch, vec![2]).unwrap();
|
||||
assert_eq!(seq0, 0);
|
||||
assert_eq!(seq1, 1);
|
||||
let msgs = store.fetch(&rk, &ch).unwrap();
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0], (0, vec![1]));
|
||||
assert_eq!(msgs[1], (1, vec![2]));
|
||||
// After fetch, queue should be empty
|
||||
let msgs2 = store.fetch(&rk, &ch).unwrap();
|
||||
assert!(msgs2.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fetch_limited_respects_limit() {
|
||||
let (_dir, store) = temp_store();
|
||||
let rk = vec![3u8; 32];
|
||||
let ch = vec![];
|
||||
for i in 0..5 {
|
||||
store.enqueue(&rk, &ch, vec![i]).unwrap();
|
||||
}
|
||||
let msgs = store.fetch_limited(&rk, &ch, 2).unwrap();
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0].1, vec![0]);
|
||||
assert_eq!(msgs[1].1, vec![1]);
|
||||
// Remaining 3 should still be there
|
||||
let depth = store.queue_depth(&rk, &ch).unwrap();
|
||||
assert_eq!(depth, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn queue_depth_tracking() {
|
||||
let (_dir, store) = temp_store();
|
||||
let rk = vec![4u8; 32];
|
||||
let ch = vec![];
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
|
||||
store.enqueue(&rk, &ch, vec![1]).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 1);
|
||||
store.enqueue(&rk, &ch, vec![2]).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 2);
|
||||
store.fetch(&rk, &ch).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hybrid_key_upload_fetch() {
|
||||
let (_dir, store) = temp_store();
|
||||
let ik = vec![5u8; 32];
|
||||
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), None);
|
||||
store.upload_hybrid_key(&ik, vec![99; 100]).unwrap();
|
||||
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(vec![99; 100]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_record_crud() {
|
||||
let (_dir, store) = temp_store();
|
||||
assert!(!store.has_user_record("alice").unwrap());
|
||||
store.store_user_record("alice", vec![1, 2, 3]).unwrap();
|
||||
assert!(store.has_user_record("alice").unwrap());
|
||||
assert_eq!(store.get_user_record("alice").unwrap(), Some(vec![1, 2, 3]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_identity_key_crud() {
|
||||
let (_dir, store) = temp_store();
|
||||
assert_eq!(store.get_user_identity_key("bob").unwrap(), None);
|
||||
store.store_user_identity_key("bob", vec![7u8; 32]).unwrap();
|
||||
assert_eq!(store.get_user_identity_key("bob").unwrap(), Some(vec![7u8; 32]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn endpoint_publish_resolve() {
|
||||
let (_dir, store) = temp_store();
|
||||
let ik = vec![8u8; 32];
|
||||
assert_eq!(store.resolve_endpoint(&ik).unwrap(), None);
|
||||
store.publish_endpoint(&ik, vec![10, 20]).unwrap();
|
||||
assert_eq!(store.resolve_endpoint(&ik).unwrap(), Some(vec![10, 20]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_channel_and_members() {
|
||||
let (_dir, store) = temp_store();
|
||||
let a = vec![1u8; 32];
|
||||
let b = vec![2u8; 32];
|
||||
assert_eq!(store.get_channel_members(&[0u8; 16]).unwrap(), None);
|
||||
let id1 = store.create_channel(&a, &b).unwrap();
|
||||
assert_eq!(id1.len(), 16);
|
||||
let members = store.get_channel_members(&id1).unwrap().unwrap();
|
||||
assert_eq!(members.0, a);
|
||||
assert_eq!(members.1, b);
|
||||
let id2 = store.create_channel(&b, &a).unwrap();
|
||||
assert_eq!(id1, id2);
|
||||
}
|
||||
}
|
||||
78
crates/quicproquo-server/src/tls.rs
Normal file
78
crates/quicproquo-server/src/tls.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::Context;
|
||||
use quinn::ServerConfig;
|
||||
use quinn_proto::crypto::rustls::QuicServerConfig;
|
||||
use rcgen::generate_simple_self_signed;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use rustls::version::TLS13;
|
||||
|
||||
/// Ensure a self-signed certificate exists on disk and return a QUIC server config.
|
||||
/// When `production` is true, cert and key must already exist (no auto-generation).
|
||||
pub fn build_server_config(
|
||||
cert_path: &PathBuf,
|
||||
key_path: &PathBuf,
|
||||
production: bool,
|
||||
) -> anyhow::Result<ServerConfig> {
|
||||
if !cert_path.exists() || !key_path.exists() {
|
||||
if production {
|
||||
anyhow::bail!(
|
||||
"TLS cert or key missing at {:?} / {:?}; production mode forbids auto-generation",
|
||||
cert_path,
|
||||
key_path
|
||||
);
|
||||
}
|
||||
generate_self_signed_cert(cert_path, key_path)?;
|
||||
}
|
||||
|
||||
let cert_bytes = std::fs::read(cert_path).context("read cert")?;
|
||||
let key_bytes = std::fs::read(key_path).context("read key")?;
|
||||
|
||||
let cert_chain = vec![CertificateDer::from(cert_bytes)];
|
||||
let key = PrivateKeyDer::try_from(key_bytes).map_err(|_| anyhow::anyhow!("invalid key"))?;
|
||||
|
||||
let mut tls = rustls::ServerConfig::builder_with_protocol_versions(&[&TLS13])
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?;
|
||||
tls.alpn_protocols = vec![b"capnp".to_vec()];
|
||||
|
||||
let crypto = QuicServerConfig::try_from(tls)
|
||||
.map_err(|e| anyhow::anyhow!("invalid server TLS config: {e}"))?;
|
||||
|
||||
Ok(ServerConfig::with_crypto(std::sync::Arc::new(crypto)))
|
||||
}
|
||||
|
||||
fn generate_self_signed_cert(cert_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<()> {
|
||||
if let Some(parent) = cert_path.parent() {
|
||||
std::fs::create_dir_all(parent).context("create cert dir")?;
|
||||
}
|
||||
if let Some(parent) = key_path.parent() {
|
||||
std::fs::create_dir_all(parent).context("create key dir")?;
|
||||
}
|
||||
|
||||
let subject_alt_names = vec![
|
||||
"localhost".to_string(),
|
||||
"127.0.0.1".to_string(),
|
||||
"::1".to_string(),
|
||||
];
|
||||
|
||||
let issued = generate_simple_self_signed(subject_alt_names)?;
|
||||
let key_der = issued.key_pair.serialize_der();
|
||||
|
||||
std::fs::write(cert_path, issued.cert.der()).context("write cert")?;
|
||||
std::fs::write(key_path, &key_der).context("write key")?;
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let perms = std::fs::Permissions::from_mode(0o600);
|
||||
std::fs::set_permissions(key_path, perms).context("set key permissions")?;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
cert = %cert_path.display(),
|
||||
key = %key_path.display(),
|
||||
"generated self-signed TLS certificate"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user