fix: security hardening — 40 findings from full codebase review
Full codebase review by 4 independent agents (security, architecture,
code quality, correctness) identified ~80 findings. This commit fixes 40
of them across all workspace crates.
Critical fixes:
- Federation service: validate origin against mTLS cert CN/SAN (C1)
- WS bridge: add DM channel auth, size limits, rate limiting (C2)
- hpke_seal: panic on error instead of silent empty ciphertext (C3; sketch below)
- hpke_setup_sender_and_export: error on parse fail, no PQ downgrade (C7)
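The C3/C7 pattern, as a hedged sketch with hypothetical types (not the real hpke module's API): an encryption failure must surface as an error, never as an empty ciphertext the caller could treat as valid.

```rust
// Sketch only — `SealError` and `seal_checked` are illustrative names.
#[derive(Debug)]
pub struct SealError(pub String);

pub fn seal_checked(
    encrypt: impl FnOnce(&[u8]) -> Result<Vec<u8>, SealError>,
    plaintext: &[u8],
) -> Result<Vec<u8>, SealError> {
    // Propagate failure instead of unwrap_or_default(), which used to yield an empty Vec.
    let ciphertext = encrypt(plaintext)?;
    // Defensive: refuse to hand back an empty blob even if the backend misbehaves.
    if ciphertext.is_empty() {
        return Err(SealError("empty ciphertext".into()));
    }
    Ok(ciphertext)
}
```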
Security fixes:
- Zeroize: seed_bytes() returns Zeroizing<[u8;32]>, private_to_bytes()
returns Zeroizing<Vec<u8>>, ClientAuth.access_token, SessionState.password,
conversation hex_key all wrapped in Zeroizing
- Keystore: 0o600 file permissions on Unix (sketch below)
- MeshIdentity: 0o600 file permissions on Unix
- Timing floors: resolveIdentity + WS bridge resolve_user get 5ms floor
- Mobile: disabling TLS verification now requires the insecure-dev feature flag
- Proto: from_bytes default limit tightened from 64 MiB to 8 MiB
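Illustrative sketch of the Zeroizing and 0o600 patterns above, using the zeroize crate and std APIs; the function names and paths are examples, not the actual keystore code.

```rust
use std::fs;
use std::io;
use std::path::Path;
use zeroize::Zeroizing;

/// Secret material is wrapped in `Zeroizing` so it is wiped from memory on drop.
fn seed_bytes(raw: [u8; 32]) -> Zeroizing<[u8; 32]> {
    Zeroizing::new(raw)
}

/// Keystore files are written with owner-only permissions on Unix.
fn write_keystore(path: &Path, bytes: &[u8]) -> io::Result<()> {
    fs::write(path, bytes)?;
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        fs::set_permissions(path, fs::Permissions::from_mode(0o600))?;
    }
    Ok(())
}
```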
Correctness fixes:
- fetch_wait: register waiter before fetch to close TOCTOU window
- MeshEnvelope: exclude hop_count from signature (forwarding no longer
invalidates sender signature)
- BroadcastChannel: encrypt returns Result instead of panicking
- transcript: rename verify_transcript_chain → validate_transcript_structure
- group.rs: extract shared process_incoming() for receive_message variants
- auth_ops: remove spurious RegistrationRequest deserialization
- MeshStore.seen: bounded to 100K with FIFO eviction (sketch below)
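A minimal sketch of the bounded FIFO dedupe set behind the MeshStore.seen fix; the type and field names are illustrative, not the crate's actual ones.

```rust
use std::collections::{HashSet, VecDeque};

const MAX_SEEN: usize = 100_000;

struct SeenSet {
    set: HashSet<[u8; 32]>,
    order: VecDeque<[u8; 32]>, // insertion order, used for FIFO eviction
}

impl SeenSet {
    /// Returns true if `id` was newly inserted; evicts the oldest entry at capacity.
    fn insert(&mut self, id: [u8; 32]) -> bool {
        if !self.set.insert(id) {
            return false;
        }
        self.order.push_back(id);
        if self.order.len() > MAX_SEEN {
            if let Some(old) = self.order.pop_front() {
                self.set.remove(&old);
            }
        }
        true
    }
}
```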
Quality fixes:
- FFI error classification: typed downcast instead of string matching (sketch below)
- Plugin HookVTable: SAFETY documentation for unsafe Send+Sync
- clippy::unwrap_used: warn → deny workspace-wide
- Various .unwrap_or("") → proper error returns
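Sketch of the typed-downcast classification that replaces string matching in the FFI layer; `StorageError` and `FfiErrorCode` are stand-ins, not the crate's real types.

```rust
use std::fmt;

#[derive(Debug)]
enum StorageError { NotFound, Corrupt }

impl fmt::Display for StorageError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self:?}") }
}
impl std::error::Error for StorageError {}

enum FfiErrorCode { NotFound, Storage, Other }

fn classify(err: &anyhow::Error) -> FfiErrorCode {
    // Before: err.to_string().contains("not found") — brittle string matching.
    // After: downcast to the concrete error type and match on it.
    match err.downcast_ref::<StorageError>() {
        Some(StorageError::NotFound) => FfiErrorCode::NotFound,
        Some(_) => FfiErrorCode::Storage,
        None => FfiErrorCode::Other,
    }
}
```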
Review report: docs/REVIEW-2026-03-04.md
152 tests passing (72 core + 35 server + 14 E2E + 1 doctest + 30 P2P)
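For reference, one WS bridge round-trip as a Rust client could drive it. The method name and the id/ok/result envelope come from ws_bridge.rs below; the URL and token are placeholders, and this client is a sketch, not part of the commit.

```rust
use futures::{SinkExt, StreamExt};
use tokio_tungstenite::{connect_async, tungstenite::Message};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Placeholder address; matches the --ws-listen example 0.0.0.0:9000.
    let (mut ws, _) = connect_async("ws://127.0.0.1:9000").await?;

    // One JSON-RPC-style request per WebSocket text frame.
    let req = serde_json::json!({
        "id": 1,
        "method": "resolveUser",
        "params": { "token": "dev-token", "username": "alice" }
    });
    ws.send(Message::Text(req.to_string().into())).await?;

    // The bridge replies with { id, ok, result | error }.
    if let Some(reply) = ws.next().await {
        println!("{}", reply?.into_text()?);
    }
    Ok(())
}
```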
@@ -178,6 +178,49 @@ pub fn validate_auth_context(
    Err(crate::error_codes::coded_error(E003_INVALID_TOKEN, "invalid accessToken"))
}

/// Validate a raw bearer token (no Cap'n Proto dependency).
/// Used by the WebSocket JSON-RPC bridge.
pub fn validate_token_raw(
    cfg: &AuthConfig,
    sessions: &DashMap<Vec<u8>, SessionInfo>,
    token: &[u8],
) -> Result<AuthContext, String> {
    if token.is_empty() {
        return Err("empty access token".to_string());
    }

    // Check static bearer token.
    if let Some(expected) = &cfg.required_token {
        if expected.len() == token.len() && bool::from(expected.as_slice().ct_eq(token)) {
            return Ok(AuthContext {
                token: token.to_vec(),
                identity_key: None,
            });
        }
    }

    // Check session tokens.
    if let Some(session) = sessions.get(token) {
        let now = current_timestamp();
        if session.expires_at > now {
            let identity = if session.identity_key.is_empty() {
                None
            } else {
                Some(session.identity_key.clone())
            };
            return Ok(AuthContext {
                token: token.to_vec(),
                identity_key: identity,
            });
        }
        drop(session);
        sessions.remove(token);
        return Err("session token has expired".to_string());
    }

    Err("invalid access token".to_string())
}

pub fn require_identity(auth_ctx: &AuthContext) -> Result<&[u8], capnp::Error> {
    match auth_ctx.identity_key.as_deref() {
        Some(ik) => Ok(ik),
@@ -36,6 +36,8 @@ pub struct FileConfig {
    /// When true, audit logs hash identity key prefixes and omit payload sizes.
    #[serde(default)]
    pub redact_logs: Option<bool>,
    /// WebSocket JSON-RPC bridge listen address (e.g. "0.0.0.0:9000").
    pub ws_listen: Option<String>,
}

#[derive(Debug)]
@@ -60,6 +62,8 @@ pub struct EffectiveConfig {
    pub plugin_dir: Option<PathBuf>,
    /// When true, audit logs hash identity key prefixes and omit payload sizes.
    pub redact_logs: bool,
    /// WebSocket JSON-RPC bridge listen address. If set, the bridge is started.
    pub ws_listen: Option<String>,
}

#[derive(Debug, Default, Deserialize)]
@@ -225,6 +229,10 @@ pub fn merge_config(args: &crate::Args, file: &FileConfig) -> EffectiveConfig {

    let plugin_dir = args.plugin_dir.clone().or_else(|| file.plugin_dir.clone());
    let redact_logs = args.redact_logs || file.redact_logs.unwrap_or(false);
    let ws_listen = args
        .ws_listen
        .clone()
        .or_else(|| file.ws_listen.clone());

    EffectiveConfig {
        listen,
@@ -242,6 +250,7 @@ pub fn merge_config(args: &crate::Args, file: &FileConfig) -> EffectiveConfig {
        federation,
        plugin_dir,
        redact_logs,
        ws_listen,
    }
}

@@ -2,21 +2,99 @@
//!
//! Delegates all operations to the local [`Store`], acting as a trusted relay
//! from authenticated peer servers.
//!
//! **Security:** Each handler validates the request's `origin` field against
//! the `verified_peer_domain` extracted from the mTLS client certificate at
//! connection time. Per-peer rate limits prevent abuse.

use std::sync::Arc;

use capnp::capability::Promise;
use dashmap::DashMap;
use quicproquo_proto::federation_capnp::federation_service;
use tokio::sync::Notify;

use crate::auth::RateEntry;
use crate::storage::Store;

/// Per-peer federation rate limit: max requests within a 60-second window.
const FED_RATE_LIMIT_WINDOW_SECS: u64 = 60;
const FED_RATE_LIMIT_MAX: u32 = 200;

/// Inbound federation RPC handler.
pub struct FederationServiceImpl {
    pub store: Arc<dyn Store>,
    pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
    pub local_domain: String,
    /// The peer domain extracted from the mTLS client certificate's CN/SAN
    /// at connection time. All requests must declare an `origin` matching this.
    pub verified_peer_domain: Option<String>,
    /// Per-peer rate limiter (keyed by peer domain).
    pub rate_limits: Arc<DashMap<String, RateEntry>>,
}

/// Validate that the request's `origin` matches the mTLS-verified peer domain.
fn validate_origin(
    verified: &Option<String>,
    declared: &str,
) -> Result<(), capnp::Error> {
    match verified {
        Some(ref expected) if expected == declared => Ok(()),
        Some(ref expected) => Err(capnp::Error::failed(format!(
            "federation auth: origin '{}' does not match mTLS cert '{}'",
            declared, expected
        ))),
        None => Err(capnp::Error::failed(
            "federation auth: no verified peer domain (mTLS required)".into(),
        )),
    }
}

/// Extract and validate the origin string from the request's auth field.
fn extract_and_validate_origin(
    service: &FederationServiceImpl,
    get_auth: Result<quicproquo_proto::federation_capnp::federation_auth::Reader<'_>, capnp::Error>,
) -> Result<String, capnp::Error> {
    let auth = get_auth
        .map_err(|_| capnp::Error::failed("federation auth: missing auth field".into()))?;
    let origin_reader = auth.get_origin()
        .map_err(|_| capnp::Error::failed("federation auth: missing origin".into()))?;
    let origin = origin_reader.to_str()
        .map_err(|_| capnp::Error::failed("federation auth: origin is not valid UTF-8".into()))?;

    if origin.is_empty() {
        return Err(capnp::Error::failed("federation auth: origin must not be empty".into()));
    }

    validate_origin(&service.verified_peer_domain, origin)?;
    check_federation_rate_limit(&service.rate_limits, origin)?;

    Ok(origin.to_string())
}

/// Per-peer federation rate limiter.
fn check_federation_rate_limit(
    rate_limits: &DashMap<String, RateEntry>,
    peer_domain: &str,
) -> Result<(), capnp::Error> {
    let now = crate::auth::current_timestamp();
    let mut entry = rate_limits.entry(peer_domain.to_string()).or_insert(RateEntry {
        count: 0,
        window_start: now,
    });

    if now - entry.window_start >= FED_RATE_LIMIT_WINDOW_SECS {
        entry.count = 1;
        entry.window_start = now;
    } else {
        entry.count += 1;
        if entry.count > FED_RATE_LIMIT_MAX {
            return Err(capnp::Error::failed(format!(
                "federation rate limit exceeded for peer '{peer_domain}'"
            )));
        }
    }
    Ok(())
}

impl federation_service::Server for FederationServiceImpl {
@@ -30,6 +108,12 @@ impl federation_service::Server for FederationServiceImpl {
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };

        // Validate origin against mTLS cert and apply rate limit.
        let origin = match extract_and_validate_origin(self, p.get_auth()) {
            Ok(o) => o,
            Err(e) => return Promise::err(e),
        };

        let recipient_key = match p.get_recipient_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad recipient_key: {e}"))),
@@ -40,13 +124,6 @@ impl federation_service::Server for FederationServiceImpl {
        };
        let channel_id = p.get_channel_id().unwrap_or_default().to_vec();

        if let Ok(a) = p.get_auth() {
            if let Ok(origin) = a.get_origin() {
                let origin = origin.to_str().unwrap_or("?");
                tracing::debug!(origin = origin, "federation relay_enqueue");
            }
        }

        if recipient_key.len() != 32 {
            return Promise::err(capnp::Error::failed("recipient_key must be 32 bytes".into()));
        }
@@ -67,6 +144,7 @@ impl federation_service::Server for FederationServiceImpl {
        }

        tracing::info!(
            origin = %origin,
            recipient_prefix = %hex::encode(&recipient_key[..4]),
            seq = seq,
            "federation: relayed enqueue"
@@ -85,6 +163,12 @@ impl federation_service::Server for FederationServiceImpl {
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };

        // Validate origin against mTLS cert and apply rate limit.
        let _origin = match extract_and_validate_origin(self, p.get_auth()) {
            Ok(o) => o,
            Err(e) => return Promise::err(e),
        };

        let recipient_keys = match p.get_recipient_keys() {
            Ok(v) => v,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad recipient_keys: {e}"))),
@@ -134,11 +218,21 @@ impl federation_service::Server for FederationServiceImpl {
        params: federation_service::ProxyFetchKeyPackageParams,
        mut results: federation_service::ProxyFetchKeyPackageResults,
    ) -> Promise<(), capnp::Error> {
        let identity_key = match params.get().and_then(|p| p.get_identity_key()) {
            Ok(v) => v.to_vec(),
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };

        // Validate origin against mTLS cert and apply rate limit.
        if let Err(e) = extract_and_validate_origin(self, p.get_auth()) {
            return Promise::err(e);
        }

        let identity_key = match p.get_identity_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad identity_key: {e}"))),
        };

        match self.store.fetch_key_package(&identity_key) {
            Ok(Some(pkg)) => results.get().set_package(&pkg),
            Ok(None) => results.get().set_package(&[]),
@@ -153,11 +247,21 @@ impl federation_service::Server for FederationServiceImpl {
        params: federation_service::ProxyFetchHybridKeyParams,
        mut results: federation_service::ProxyFetchHybridKeyResults,
    ) -> Promise<(), capnp::Error> {
        let identity_key = match params.get().and_then(|p| p.get_identity_key()) {
            Ok(v) => v.to_vec(),
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };

        // Validate origin against mTLS cert and apply rate limit.
        if let Err(e) = extract_and_validate_origin(self, p.get_auth()) {
            return Promise::err(e);
        }

        let identity_key = match p.get_identity_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad identity_key: {e}"))),
        };

        match self.store.fetch_hybrid_key(&identity_key) {
            Ok(Some(pk)) => results.get().set_hybrid_public_key(&pk),
            Ok(None) => results.get().set_hybrid_public_key(&[]),
@@ -172,7 +276,17 @@ impl federation_service::Server for FederationServiceImpl {
        params: federation_service::ProxyResolveUserParams,
        mut results: federation_service::ProxyResolveUserResults,
    ) -> Promise<(), capnp::Error> {
        let username = match params.get().and_then(|p| p.get_username()) {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
        };

        // Validate origin against mTLS cert and apply rate limit.
        if let Err(e) = extract_and_validate_origin(self, p.get_auth()) {
            return Promise::err(e);
        }

        let username = match p.get_username() {
            Ok(u) => match u.to_str() {
                Ok(s) => s.to_string(),
                Err(e) => return Promise::err(capnp::Error::failed(format!("bad utf-8: {e}"))),
@@ -194,8 +308,42 @@ impl federation_service::Server for FederationServiceImpl {
        _params: federation_service::FederationHealthParams,
        mut results: federation_service::FederationHealthResults,
    ) -> Promise<(), capnp::Error> {
        // Health check does not require origin validation (diagnostic endpoint).
        results.get().set_status("ok");
        results.get().set_server_domain(&self.local_domain);
        Promise::ok(())
    }
}

/// Extract the peer domain from the mTLS client certificate's first SAN (DNS name)
/// or CN, given the QUIC connection's peer identity (a certificate chain).
pub fn extract_peer_domain(conn: &quinn::Connection) -> Option<String> {
    let identity = conn.peer_identity()?;
    let certs = identity.downcast::<Vec<rustls::pki_types::CertificateDer<'static>>>().ok()?;
    let first_cert = certs.first()?;

    // Parse the DER certificate to extract SAN DNS names or CN.
    let (_, parsed) = x509_parser::parse_x509_certificate(first_cert.as_ref()).ok()?;

    // Prefer SAN DNS names.
    if let Ok(Some(san)) = parsed.subject_alternative_name() {
        for name in &san.value.general_names {
            if let x509_parser::extensions::GeneralName::DNSName(dns) = name {
                return Some(dns.to_string());
            }
        }
    }

    // Fall back to CN.
    for rdn in parsed.subject().iter() {
        for attr in rdn.iter() {
            if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME {
                if let Ok(cn) = attr.as_str() {
                    return Some(cn.to_string());
                }
            }
        }
    }

    None
}

@@ -27,6 +27,7 @@ mod plugin_loader;
mod sql_store;
mod tls;
mod storage;
mod ws_bridge;

use auth::{AuthConfig, PendingLogin, RateEntry, SessionInfo};
use config::{
@@ -119,6 +120,10 @@ struct Args {
    /// Redact identity key prefixes and payload sizes in audit logs for metadata minimization.
    #[arg(long, env = "QPQ_REDACT_LOGS", default_value_t = false)]
    redact_logs: bool,

    /// WebSocket JSON-RPC bridge listen address (e.g. 0.0.0.0:9000). Enables browser connectivity.
    #[arg(long, env = "QPQ_WS_LISTEN")]
    ws_listen: Option<String>,
}

// ── Entry point ───────────────────────────────────────────────────────────────
@@ -329,6 +334,23 @@ async fn main() -> anyhow::Result<()> {
        Arc::clone(&waiters),
    );

    // ── WebSocket JSON-RPC bridge ──────────────────────────────────────────
    if let Some(ws_addr_str) = &effective.ws_listen {
        let ws_addr: SocketAddr = ws_addr_str
            .parse()
            .context("--ws-listen must be host:port (e.g. 0.0.0.0:9000)")?;
        let ws_state = Arc::new(ws_bridge::WsBridgeState {
            store: Arc::clone(&store),
            waiters: Arc::clone(&waiters),
            auth_cfg: Arc::clone(&auth_cfg),
            sessions: Arc::clone(&sessions),
            rate_limits: Arc::clone(&rate_limits),
            sealed_sender: effective.sealed_sender,
            allow_insecure_auth: effective.allow_insecure_auth,
        });
        ws_bridge::spawn_ws_bridge(ws_addr, ws_state);
    }

    let endpoint = Endpoint::server(server_config, listen)?;

    tracing::info!(
@@ -539,10 +561,20 @@ async fn main() -> anyhow::Result<()> {
        Default::default(),
    );

    let verified_peer_domain =
        federation::service::extract_peer_domain(&conn);
    if let Some(ref peer) = verified_peer_domain {
        tracing::info!(peer_domain = %peer, "federation: mTLS peer authenticated");
    } else {
        tracing::warn!(peer = %conn.remote_address(), "federation: could not extract peer domain from mTLS cert");
    }

    let service_impl = federation::service::FederationServiceImpl {
        store,
        waiters,
        local_domain: domain,
        verified_peer_domain,
        rate_limits: Arc::new(dashmap::DashMap::new()),
    };
    let client: quicproquo_proto::federation_capnp::federation_service::Client =
        capnp_rpc::new_client(service_impl);

@@ -332,16 +332,6 @@ impl NodeServiceImpl {
            return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
        }

        let _request = match RegistrationRequest::<OpaqueSuite>::deserialize(&upload_bytes) {
            Ok(r) => r,
            Err(e) => {
                return Promise::err(coded_error(
                    E010_OPAQUE_ERROR,
                    format!("invalid registration upload: {e}"),
                ))
            }
        };

        match self.store.has_user_record(&username) {
            Ok(true) => {
                return Promise::err(coded_error(

@@ -113,94 +113,106 @@ impl NodeServiceImpl {
        let final_path = dir.join(&blob_hex);
        let meta_path = dir.join(format!("{blob_hex}.meta"));

        // If the blob already exists (fully uploaded), return immediately.
        if final_path.exists() {
            results.get().set_blob_id(&blob_hash);
            return Promise::ok(());
        }
        // All file I/O is delegated to spawn_blocking to avoid stalling the Tokio event loop.
        let uploader_prefix = auth_ctx
            .identity_key
            .as_deref()
            .filter(|k| k.len() >= 4)
            .map(|k| hex::encode(&k[..4]))
            .unwrap_or_default();

        // Write chunk at the given offset.
        let write_result = (|| -> Result<(), String> {
            let mut file = std::fs::OpenOptions::new()
                .create(true)
                .write(true)
                .truncate(false)
                .open(&part_path)
                .map_err(|e| format!("open .part file: {e}"))?;
            file.seek(SeekFrom::Start(offset))
                .map_err(|e| format!("seek: {e}"))?;
            file.write_all(&chunk)
                .map_err(|e| format!("write chunk: {e}"))?;
            file.sync_all()
                .map_err(|e| format!("sync: {e}"))?;
            Ok(())
        })();
        Promise::from_future(async move {
            let blob_hash_clone = blob_hash.clone();
            let part_path_clone = part_path.clone();
            let final_path_clone = final_path.clone();
            let meta_path_clone = meta_path.clone();
            let chunk_clone = chunk;
            let mime_clone = mime_type.clone();
            let prefix_clone = uploader_prefix.clone();

        if let Err(e) = write_result {
            return Promise::err(coded_error(E009_STORAGE_ERROR, e));
        }

        // Check if the blob is complete.
        let end = offset + chunk.len() as u64;
        if end == total_size {
            // Verify SHA-256 of the complete file.
            let verify_result = (|| -> Result<bool, String> {
                let mut file = std::fs::File::open(&part_path)
                    .map_err(|e| format!("open for verify: {e}"))?;
                let mut hasher = Sha256::new();
                let mut buf = [0u8; 64 * 1024];
                loop {
                    let n = file.read(&mut buf).map_err(|e| format!("read: {e}"))?;
                    if n == 0 {
                        break;
                    }
                    hasher.update(&buf[..n]);
            let io_result = tokio::task::spawn_blocking(move || -> Result<Option<bool>, String> {
                // If the blob already exists (fully uploaded), return immediately.
                if final_path_clone.exists() {
                    return Ok(None); // signals "already done"
                }
                let computed: [u8; 32] = hasher.finalize().into();
                Ok(computed == blob_hash.as_slice())
            })();

            match verify_result {
                Ok(true) => {
                    // Hash matches — finalize the blob.
                    if let Err(e) = std::fs::rename(&part_path, &final_path) {
                        return Promise::err(coded_error(
                            E009_STORAGE_ERROR,
                            format!("rename .part to final: {e}"),
                        ));
                // Write chunk at the given offset.
                let mut file = std::fs::OpenOptions::new()
                    .create(true)
                    .write(true)
                    .truncate(false)
                    .open(&part_path_clone)
                    .map_err(|e| format!("open .part file: {e}"))?;
                file.seek(SeekFrom::Start(offset))
                    .map_err(|e| format!("seek: {e}"))?;
                file.write_all(&chunk_clone)
                    .map_err(|e| format!("write chunk: {e}"))?;
                file.sync_all()
                    .map_err(|e| format!("sync: {e}"))?;

                // Check if the blob is complete.
                let end = offset + chunk_clone.len() as u64;
                if end == total_size {
                    // Verify SHA-256 of the complete file.
                    let mut vfile = std::fs::File::open(&part_path_clone)
                        .map_err(|e| format!("open for verify: {e}"))?;
                    let mut hasher = Sha256::new();
                    let mut buf = [0u8; 64 * 1024];
                    loop {
                        let n = vfile.read(&mut buf).map_err(|e| format!("read: {e}"))?;
                        if n == 0 {
                            break;
                        }
                        hasher.update(&buf[..n]);
                    }
                    let computed: [u8; 32] = hasher.finalize().into();
                    if computed != blob_hash_clone.as_slice() {
                        let _ = std::fs::remove_file(&part_path_clone);
                        return Ok(Some(false)); // hash mismatch
                    }

                    // Hash matches — finalize the blob.
                    std::fs::rename(&part_path_clone, &final_path_clone)
                        .map_err(|e| format!("rename .part to final: {e}"))?;

                    // Write metadata file.
                    let uploader_prefix = auth_ctx
                        .identity_key
                        .as_deref()
                        .filter(|k| k.len() >= 4)
                        .map(|k| hex::encode(&k[..4]))
                        .unwrap_or_default();

                    let now = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_default()
                        .as_secs();

                    let meta = BlobMeta {
                        mime_type: mime_type.clone(),
                        mime_type: mime_clone,
                        total_size,
                        uploaded_at: now,
                        uploader_key_prefix: uploader_prefix.clone(),
                        uploader_key_prefix: prefix_clone,
                    };

                    if let Err(e) = (|| -> Result<(), String> {
                        let json = serde_json::to_string_pretty(&meta)
                            .map_err(|e| format!("serialize meta: {e}"))?;
                        std::fs::write(&meta_path, json.as_bytes())
                        std::fs::write(&meta_path_clone, json.as_bytes())
                            .map_err(|e| format!("write meta: {e}"))?;
                        Ok(())
                    })() {
                        // Non-fatal: the blob is already stored; log and continue.
                        tracing::warn!(error = %e, "failed to write blob metadata");
                    }

                    return Ok(Some(true)); // complete + verified
                }

                Ok(None) // chunk written, not yet complete
            })
            .await
            .map_err(|e| capnp::Error::failed(format!("spawn_blocking join: {e}")))?;

            match io_result {
                Ok(None) => {
                    // Already existed or chunk written (not yet complete).
                    results.get().set_blob_id(&blob_hash);
                }
                Ok(Some(true)) => {
                    // Complete and verified.
                    tracing::info!(
                        blob_hash_prefix = %fmt_hex(&blob_hash[..4]),
                        total_size = total_size,
@@ -208,24 +220,21 @@ impl NodeServiceImpl {
                        uploader_prefix = %uploader_prefix,
                        "audit: blob_upload_complete"
                    );
                    results.get().set_blob_id(&blob_hash);
                }
                Ok(false) => {
                    // Hash mismatch — delete the .part file.
                    let _ = std::fs::remove_file(&part_path);
                    return Promise::err(coded_error(
                Ok(Some(false)) => {
                    return Err(coded_error(
                        E026_BLOB_HASH_MISMATCH,
                        "SHA-256 of uploaded data does not match blobHash",
                    ));
                }
                Err(e) => {
                    let _ = std::fs::remove_file(&part_path);
                    return Promise::err(coded_error(E009_STORAGE_ERROR, e));
                    return Err(coded_error(E009_STORAGE_ERROR, e));
                }
            }
        }

        results.get().set_blob_id(&blob_hash);
        Promise::ok(())
            Ok(())
        })
    }

    pub fn handle_download_blob(
@@ -261,65 +270,57 @@ impl NodeServiceImpl {
        let blob_path = dir.join(&blob_hex);
        let meta_path = dir.join(format!("{blob_hex}.meta"));

        // Check that the blob exists.
        if !blob_path.exists() {
            return Promise::err(coded_error(E027_BLOB_NOT_FOUND, "blob not found"));
        }

        // Read metadata.
        let meta: BlobMeta = match std::fs::read_to_string(&meta_path) {
            Ok(json) => match serde_json::from_str(&json) {
                Ok(m) => m,
                Err(e) => {
                    return Promise::err(coded_error(
                        E009_STORAGE_ERROR,
                        format!("corrupt blob metadata: {e}"),
                    ));
        // Delegate all file I/O to spawn_blocking to avoid stalling the event loop.
        Promise::from_future(async move {
            let io_result = tokio::task::spawn_blocking(move || -> Result<(Vec<u8>, BlobMeta), capnp::Error> {
                // Check that the blob exists.
                if !blob_path.exists() {
                    return Err(coded_error(E027_BLOB_NOT_FOUND, "blob not found"));
                }
            },
            Err(e) => {
                return Promise::err(coded_error(
                    E009_STORAGE_ERROR,
                    format!("read blob metadata: {e}"),
                ));
            }
        };

        // Read the requested chunk.
        let read_result = (|| -> Result<Vec<u8>, String> {
            let mut file = std::fs::File::open(&blob_path)
                .map_err(|e| format!("open blob: {e}"))?;
            let file_len = file
                .metadata()
                .map_err(|e| format!("file metadata: {e}"))?
                .len();
                // Read metadata.
                let meta: BlobMeta = match std::fs::read_to_string(&meta_path) {
                    Ok(json) => serde_json::from_str(&json).map_err(|e| {
                        coded_error(E009_STORAGE_ERROR, format!("corrupt blob metadata: {e}"))
                    })?,
                    Err(e) => {
                        return Err(coded_error(
                            E009_STORAGE_ERROR,
                            format!("read blob metadata: {e}"),
                        ));
                    }
                };

            if offset >= file_len {
                return Ok(vec![]);
            }
                // Read the requested chunk.
                let mut file = std::fs::File::open(&blob_path)
                    .map_err(|e| coded_error(E009_STORAGE_ERROR, format!("open blob: {e}")))?;
                let file_len = file
                    .metadata()
                    .map_err(|e| coded_error(E009_STORAGE_ERROR, format!("file metadata: {e}")))?
                    .len();

            file.seek(SeekFrom::Start(offset))
                .map_err(|e| format!("seek: {e}"))?;
            let remaining = (file_len - offset) as usize;
            let to_read = remaining.min(length as usize);
            let mut buf = vec![0u8; to_read];
            file.read_exact(&mut buf)
                .map_err(|e| format!("read chunk: {e}"))?;
            Ok(buf)
        })();
                if offset >= file_len {
                    return Ok((vec![], meta));
                }

        match read_result {
            Ok(chunk) => {
                let mut r = results.get();
                r.set_chunk(&chunk);
                r.set_total_size(meta.total_size);
                r.set_mime_type(&meta.mime_type);
            }
            Err(e) => {
                return Promise::err(coded_error(E009_STORAGE_ERROR, e));
            }
        }
                file.seek(SeekFrom::Start(offset))
                    .map_err(|e| coded_error(E009_STORAGE_ERROR, format!("seek: {e}")))?;
                let remaining = (file_len - offset) as usize;
                let to_read = remaining.min(length as usize);
                let mut buf = vec![0u8; to_read];
                file.read_exact(&mut buf)
                    .map_err(|e| coded_error(E009_STORAGE_ERROR, format!("read chunk: {e}")))?;
                Ok((buf, meta))
            })
            .await
            .map_err(|e| capnp::Error::failed(format!("spawn_blocking join: {e}")))?;

            Promise::ok(())
            let (chunk, meta) = io_result?;
            let mut r = results.get();
            r.set_chunk(&chunk);
            r.set_total_size(meta.total_size);
            r.set_mime_type(&meta.mime_type);
            Ok(())
        })
    }
}

@@ -502,18 +502,30 @@ impl NodeServiceImpl {
            }
        };

        // Register waiter BEFORE the initial fetch to close the TOCTOU window:
        // an enqueue between fetch and registration would fire notify before
        // the waiter exists, causing a missed wakeup.
        let waiter = if timeout_ms > 0 {
            Some(
                waiters
                    .entry(recipient_key.clone())
                    .or_insert_with(|| Arc::new(Notify::new()))
                    .clone(),
            )
        } else {
            None
        };

        let messages = fetch_fn(&store, &recipient_key, &channel_id, limit)?;

        if messages.is_empty() && timeout_ms > 0 {
            let waiter = waiters
                .entry(recipient_key.clone())
                .or_insert_with(|| Arc::new(Notify::new()))
                .clone();
            let _ = timeout(Duration::from_millis(timeout_ms), waiter.notified()).await;
            let msgs = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
            fill_payloads_wait(&mut results, msgs);
            metrics::record_fetch_wait_total();
            return Ok(());
        if messages.is_empty() {
            if let Some(waiter) = waiter {
                let _ = timeout(Duration::from_millis(timeout_ms), waiter.notified()).await;
                let msgs = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
                fill_payloads_wait(&mut results, msgs);
                metrics::record_fetch_wait_total();
                return Ok(());
            }
        }

        fill_payloads_wait(&mut results, messages);

@@ -46,7 +46,10 @@ impl NodeServiceImpl {
        }

        let device_name = match p.get_device_name() {
            Ok(n) => n.to_str().unwrap_or("").to_string(),
            Ok(n) => match n.to_str() {
                Ok(s) => s.to_string(),
                Err(_) => return Promise::err(coded_error(E020_BAD_PARAMS, "deviceName must be valid UTF-8")),
            },
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };

@@ -164,6 +164,9 @@ impl NodeServiceImpl {
            ));
        }

        // Timing floor: mask DB-lookup timing differences (same as resolveUser).
        let deadline = Instant::now() + RESOLVE_TIMING_FLOOR;

        match self.store.resolve_identity_key(identity_key) {
            Ok(Some(username)) => {
                results.get().set_username(&username);
@@ -174,6 +177,10 @@ impl NodeServiceImpl {
            Err(e) => return Promise::err(storage_err(e)),
        }

        Promise::ok(())
        // Pad to timing floor before responding.
        Promise::from_future(async move {
            tokio::time::sleep_until(deadline).await;
            Ok(())
        })
    }
}

@@ -542,7 +542,7 @@ impl Store for SqlStore {
            return Ok((id, false));
        }
        let mut channel_id = [0u8; 16];
        rand::thread_rng().fill_bytes(&mut channel_id);
        rand::rngs::OsRng.fill_bytes(&mut channel_id);
        conn.execute(
            "INSERT INTO channels (channel_id, member_a, member_b) VALUES (?1, ?2, ?3)",
            params![channel_id.as_slice(), a, b],

@@ -757,7 +757,7 @@ impl Store for FileBackedStore {
            return Ok((channel_id.clone(), false));
        }
        let mut channel_id = [0u8; 16];
        rand::thread_rng().fill_bytes(&mut channel_id);
        rand::rngs::OsRng.fill_bytes(&mut channel_id);
        let channel_id = channel_id.to_vec();
        map.insert(channel_id.clone(), (a, b));
        self.flush_channels(&self.channels_path, &map)?;

crates/quicproquo-server/src/ws_bridge.rs (new file, 561 lines)
@@ -0,0 +1,561 @@
//! WebSocket JSON-RPC bridge for browser clients.
//!
//! Provides a lightweight JSON-RPC interface over WebSocket so that browsers
//! can interact with the server without a Cap'n Proto / QUIC stack.
//!
//! Security parity with the Cap'n Proto path:
//! - Rate limiting via `check_rate_limit()` on all mutating handlers
//! - DM channel membership verification on `send`
//! - Payload size limits (5 MB)
//! - Timing floor on `resolveUser` to mask lookup timing

use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use base64::Engine;
use dashmap::DashMap;
use futures::stream::StreamExt;
use futures::SinkExt;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::time::Instant;
use tokio_tungstenite::tungstenite::Message;

use crate::auth::{check_rate_limit, validate_token_raw, AuthConfig, AuthContext, RateEntry, SessionInfo};
use crate::storage::Store;

const B64: base64::engine::general_purpose::GeneralPurpose =
    base64::engine::general_purpose::STANDARD;

/// Maximum payload size for WS bridge (same as Cap'n Proto path).
const MAX_PAYLOAD_BYTES: usize = 5 * 1024 * 1024;

/// Minimum response time for resolveUser to mask DB lookup timing differences.
const RESOLVE_TIMING_FLOOR: Duration = Duration::from_millis(5);

// ── Shared state ────────────────────────────────────────────────────────────

/// Subset of server state needed by the WS bridge (all `Send + Sync`).
#[allow(dead_code)] // sealed_sender plumbed for future use
pub struct WsBridgeState {
    pub store: Arc<dyn Store>,
    pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
    pub auth_cfg: Arc<AuthConfig>,
    pub sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
    pub rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
    pub sealed_sender: bool,
    pub allow_insecure_auth: bool,
}

// ── JSON-RPC types ──────────────────────────────────────────────────────────

#[derive(Deserialize)]
struct RpcRequest {
    id: serde_json::Value,
    method: String,
    #[serde(default)]
    params: serde_json::Value,
}

#[derive(Serialize)]
struct RpcResponse {
    id: serde_json::Value,
    ok: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}

impl RpcResponse {
    fn success(id: serde_json::Value, result: serde_json::Value) -> Self {
        Self {
            id,
            ok: true,
            result: Some(result),
            error: None,
        }
    }

    fn error(id: serde_json::Value, msg: impl Into<String>) -> Self {
        Self {
            id,
            ok: false,
            result: None,
            error: Some(msg.into()),
        }
    }
}

// ── Auth helper ─────────────────────────────────────────────────────────────

/// Extract and validate the "token" field from params. In insecure-auth mode
/// with no token configured, an empty token is accepted as the bearer.
fn extract_auth(
    state: &WsBridgeState,
    params: &serde_json::Value,
) -> Result<AuthContext, String> {
    let token_str = params
        .get("token")
        .and_then(|v| v.as_str())
        .unwrap_or("");

    let token_bytes = token_str.as_bytes().to_vec();

    // In insecure-auth mode with no configured token, accept any (including empty)
    // token as the bearer token. This mirrors the Cap'n Proto path behaviour.
    if state.allow_insecure_auth && state.auth_cfg.required_token.is_none() {
        // Treat the request identity from params as the identity.
        return Ok(AuthContext {
            token: token_bytes,
            identity_key: None,
        });
    }

    validate_token_raw(&state.auth_cfg, &state.sessions, &token_bytes)
}

/// Resolve identity key: either from auth context (session-bound) or from
/// request params (insecure-auth mode). Returns the 32-byte identity key.
fn resolve_identity(
    state: &WsBridgeState,
    auth_ctx: &AuthContext,
    params: &serde_json::Value,
) -> Result<Vec<u8>, String> {
    // If auth context has an identity-bound session, use that.
    if let Some(ref ik) = auth_ctx.identity_key {
        return Ok(ik.clone());
    }

    // In insecure-auth mode, accept identity from params.
    if state.allow_insecure_auth {
        // Try base64-encoded identityKey first.
        if let Some(b64) = params.get("identityKey").and_then(|v| v.as_str()) {
            return B64
                .decode(b64)
                .map_err(|e| format!("bad base64 identityKey: {e}"));
        }
        // Try username lookup.
        if let Some(username) = params.get("username").and_then(|v| v.as_str()) {
            if let Ok(Some(ik)) = state.store.get_user_identity_key(username) {
                return Ok(ik);
            }
            return Err(format!("user not found: {username}"));
        }
    }

    Err("no identity: login required or pass identityKey/username in insecure mode".to_string())
}

/// Apply rate limiting using the auth token. Returns an error string on limit exceeded.
fn ws_check_rate_limit(state: &WsBridgeState, auth_ctx: &AuthContext) -> Result<(), String> {
    check_rate_limit(&state.rate_limits, &auth_ctx.token)
        .map_err(|e| format!("rate limit exceeded: {e}"))
}

// ── Dispatch ────────────────────────────────────────────────────────────────

async fn dispatch(state: &WsBridgeState, req: RpcRequest) -> RpcResponse {
    match req.method.as_str() {
        "health" => handle_health(req.id),
        "resolveUser" => handle_resolve_user(state, req.id, &req.params).await,
        "createChannel" => handle_create_channel(state, req.id, &req.params),
        "send" => handle_send(state, req.id, &req.params),
        "receive" => handle_receive(state, req.id, &req.params),
        "deleteAccount" => handle_delete_account(state, req.id, &req.params),
        _ => RpcResponse::error(req.id, format!("unknown method: {}", req.method)),
    }
}

// ── Handlers ────────────────────────────────────────────────────────────────

fn handle_health(id: serde_json::Value) -> RpcResponse {
    RpcResponse::success(id, serde_json::json!("ok"))
}

async fn handle_resolve_user(
    state: &WsBridgeState,
    id: serde_json::Value,
    params: &serde_json::Value,
) -> RpcResponse {
    let auth_ctx = match extract_auth(state, params) {
        Ok(ctx) => ctx,
        Err(e) => return RpcResponse::error(id, e),
    };

    // Rate limit resolve requests to prevent bulk enumeration.
    if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
        return RpcResponse::error(id, e);
    }

    let username = match params.get("username").and_then(|v| v.as_str()) {
        Some(u) if !u.is_empty() => u,
        _ => return RpcResponse::error(id, "missing or empty 'username' param"),
    };

    // Timing floor: mask DB-lookup timing differences between existing and
    // non-existing usernames (same as Cap'n Proto resolveUser handler).
    let deadline = Instant::now() + RESOLVE_TIMING_FLOOR;

    let response = match state.store.get_user_identity_key(username) {
        Ok(Some(key)) => {
            RpcResponse::success(id, serde_json::json!({ "identityKey": B64.encode(&key) }))
        }
        Ok(None) => RpcResponse::success(id, serde_json::json!({ "identityKey": null })),
        Err(e) => RpcResponse::error(id, format!("storage error: {e}")),
    };

    // Pad to timing floor before responding.
    tokio::time::sleep_until(deadline).await;

    response
}

fn handle_create_channel(
    state: &WsBridgeState,
    id: serde_json::Value,
    params: &serde_json::Value,
) -> RpcResponse {
    let auth_ctx = match extract_auth(state, params) {
        Ok(ctx) => ctx,
        Err(e) => return RpcResponse::error(id, e),
    };

    // Rate limit.
    if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
        return RpcResponse::error(id, e);
    }

    let my_key = match resolve_identity(state, &auth_ctx, params) {
        Ok(k) => k,
        Err(e) => return RpcResponse::error(id, e),
    };

    // Accept peer key as base64 or resolve from username.
    let peer_key = if let Some(b64) = params.get("peerKey").and_then(|v| v.as_str()) {
        match B64.decode(b64) {
            Ok(k) => k,
            Err(e) => return RpcResponse::error(id, format!("bad base64 peerKey: {e}")),
        }
    } else if let Some(username) = params.get("peerUsername").and_then(|v| v.as_str()) {
        match state.store.get_user_identity_key(username) {
            Ok(Some(k)) => k,
            Ok(None) => return RpcResponse::error(id, format!("peer user not found: {username}")),
            Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
        }
    } else {
        return RpcResponse::error(id, "missing 'peerKey' (base64) or 'peerUsername'");
    };

    if peer_key.len() != 32 {
        return RpcResponse::error(id, "peerKey must be 32 bytes");
    }

    match state.store.create_channel(&my_key, &peer_key) {
        Ok((channel_id, was_new)) => RpcResponse::success(
            id,
            serde_json::json!({
                "channelId": B64.encode(&channel_id),
                "wasNew": was_new,
            }),
        ),
        Err(e) => RpcResponse::error(id, format!("storage error: {e}")),
    }
}

fn handle_send(
    state: &WsBridgeState,
    id: serde_json::Value,
    params: &serde_json::Value,
) -> RpcResponse {
    let auth_ctx = match extract_auth(state, params) {
        Ok(ctx) => ctx,
        Err(e) => return RpcResponse::error(id, e),
    };

    // Rate limit (parity with Cap'n Proto enqueue path).
    if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
        return RpcResponse::error(id, e);
    }

    let sender_key = match resolve_identity(state, &auth_ctx, params) {
        Ok(k) => k,
        Err(e) => return RpcResponse::error(id, e),
    };

    // Resolve recipient: base64 key or username.
    let recipient_key =
        if let Some(b64) = params.get("recipientKey").and_then(|v| v.as_str()) {
            match B64.decode(b64) {
                Ok(k) => k,
                Err(e) => {
                    return RpcResponse::error(id, format!("bad base64 recipientKey: {e}"))
                }
            }
        } else if let Some(username) = params.get("recipient").and_then(|v| v.as_str()) {
            match state.store.get_user_identity_key(username) {
                Ok(Some(k)) => k,
                Ok(None) => {
                    return RpcResponse::error(id, format!("recipient not found: {username}"))
                }
                Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
            }
        } else {
            return RpcResponse::error(id, "missing 'recipientKey' (base64) or 'recipient' (username)");
        };

    if recipient_key.len() != 32 {
        return RpcResponse::error(id, "recipientKey must be 32 bytes");
    }

    // Payload: base64-encoded binary or plain text message.
    let payload = if let Some(b64) = params.get("payload").and_then(|v| v.as_str()) {
        match B64.decode(b64) {
            Ok(p) => p,
            Err(e) => return RpcResponse::error(id, format!("bad base64 payload: {e}")),
        }
    } else if let Some(msg) = params.get("message").and_then(|v| v.as_str()) {
        msg.as_bytes().to_vec()
    } else {
        return RpcResponse::error(id, "missing 'payload' (base64) or 'message' (text)");
    };

    if payload.is_empty() {
        return RpcResponse::error(id, "payload must not be empty");
    }

    // Payload size limit (same as Cap'n Proto path: 5 MB).
    if payload.len() > MAX_PAYLOAD_BYTES {
        return RpcResponse::error(
            id,
            format!("payload exceeds max size ({MAX_PAYLOAD_BYTES} bytes)"),
        );
    }

    // Create or look up the DM channel between sender and recipient.
    let channel_id = match state.store.create_channel(&sender_key, &recipient_key) {
        Ok((ch, _)) => ch,
        Err(e) => return RpcResponse::error(id, format!("channel error: {e}")),
    };

    // DM channel membership verification (parity with Cap'n Proto enqueue path).
    if channel_id.len() == 16 {
        let members = match state.store.get_channel_members(&channel_id) {
            Ok(Some(m)) => m,
            Ok(None) => return RpcResponse::error(id, "channel not found"),
            Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
        };
        let (a, b) = &members;
        let caller_in = sender_key == *a || sender_key == *b;
        let recipient_other = (recipient_key == *a && sender_key == *b)
            || (recipient_key == *b && sender_key == *a);
        if !caller_in || !recipient_other {
            return RpcResponse::error(id, "caller or recipient not a member of this channel");
        }
    }

    match state
        .store
        .enqueue(&recipient_key, &channel_id, payload, None)
    {
        Ok(seq) => {
            // Notify any waiting long-poll fetchers.
            if let Some(notify) = state.waiters.get(&recipient_key) {
                notify.notify_waiters();
            }

            // Audit logging (no secrets: no payload, no full keys).
            tracing::info!(
                recipient_prefix = %hex::encode(&recipient_key[..std::cmp::min(4, recipient_key.len())]),
                seq = seq,
                "audit: ws_bridge enqueue"
            );

            RpcResponse::success(id, serde_json::json!({ "seq": seq }))
        }
        Err(e) => RpcResponse::error(id, format!("enqueue error: {e}")),
    }
}

fn handle_receive(
    state: &WsBridgeState,
    id: serde_json::Value,
    params: &serde_json::Value,
) -> RpcResponse {
    let auth_ctx = match extract_auth(state, params) {
        Ok(ctx) => ctx,
        Err(e) => return RpcResponse::error(id, e),
    };

    // Rate limit.
    if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
        return RpcResponse::error(id, e);
    }

    let my_key = match resolve_identity(state, &auth_ctx, params) {
        Ok(k) => k,
        Err(e) => return RpcResponse::error(id, e),
    };

    // Resolve sender/peer: base64 key or username (needed to find the channel).
    let peer_key = if let Some(b64) = params.get("recipientKey").and_then(|v| v.as_str()) {
        match B64.decode(b64) {
            Ok(k) => k,
            Err(e) => return RpcResponse::error(id, format!("bad base64 recipientKey: {e}")),
        }
    } else if let Some(username) = params.get("recipient").and_then(|v| v.as_str()) {
        match state.store.get_user_identity_key(username) {
            Ok(Some(k)) => k,
            Ok(None) => return RpcResponse::error(id, format!("user not found: {username}")),
            Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
        }
    } else {
        return RpcResponse::error(id, "missing 'recipientKey' (base64) or 'recipient' (username)");
    };

    // Find the channel between me and the peer.
    let channel_id = match state.store.create_channel(&my_key, &peer_key) {
        Ok((ch, _)) => ch,
        Err(e) => return RpcResponse::error(id, format!("channel error: {e}")),
    };

    // Fetch (drain) all messages for me in this channel.
    match state.store.fetch(&my_key, &channel_id) {
        Ok(messages) => {
            let items: Vec<serde_json::Value> = messages
                .into_iter()
                .map(|(seq, data)| {
                    // Try to decode as UTF-8 text; fall back to base64.
                    let text = String::from_utf8(data.clone()).ok();
                    serde_json::json!({
                        "seq": seq,
                        "data": B64.encode(&data),
                        "text": text,
                    })
                })
                .collect();
            RpcResponse::success(id, serde_json::json!(items))
        }
        Err(e) => RpcResponse::error(id, format!("fetch error: {e}")),
    }
}

fn handle_delete_account(
    state: &WsBridgeState,
    id: serde_json::Value,
    params: &serde_json::Value,
) -> RpcResponse {
    let auth_ctx = match extract_auth(state, params) {
        Ok(ctx) => ctx,
        Err(e) => return RpcResponse::error(id, e),
    };

    // Rate limit.
    if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
        return RpcResponse::error(id, e);
    }

    let identity_key = match resolve_identity(state, &auth_ctx, params) {
        Ok(k) => k,
        Err(e) => return RpcResponse::error(id, e),
    };

    match state.store.delete_account(&identity_key) {
        Ok(()) => {
            // Invalidate sessions for this identity.
            let tokens_to_remove: Vec<Vec<u8>> = state
                .sessions
                .iter()
                .filter(|entry| entry.value().identity_key == identity_key)
                .map(|entry| entry.key().clone())
                .collect();
            for token in &tokens_to_remove {
                state.sessions.remove(token);
            }
            RpcResponse::success(id, serde_json::json!({ "deleted": true }))
        }
        Err(e) => RpcResponse::error(id, format!("delete failed: {e}")),
    }
}

// ── WebSocket listener ──────────────────────────────────────────────────────

/// Spawn the WebSocket JSON-RPC bridge as a background tokio task.
pub fn spawn_ws_bridge(addr: SocketAddr, state: Arc<WsBridgeState>) {
    tokio::spawn(async move {
        let listener = match TcpListener::bind(addr).await {
            Ok(l) => l,
            Err(e) => {
                tracing::error!(addr = %addr, error = %e, "ws_bridge: failed to bind");
                return;
            }
        };
        tracing::info!(addr = %addr, "ws_bridge: accepting WebSocket connections");

        loop {
            let (stream, peer) = match listener.accept().await {
                Ok(pair) => pair,
                Err(e) => {
                    tracing::warn!(error = %e, "ws_bridge: accept error");
                    continue;
                }
            };

            let state = Arc::clone(&state);
            tokio::spawn(async move {
                let ws = match tokio_tungstenite::accept_async(stream).await {
                    Ok(ws) => ws,
                    Err(e) => {
                        tracing::debug!(peer = %peer, error = %e, "ws_bridge: handshake failed");
                        return;
                    }
                };
                tracing::debug!(peer = %peer, "ws_bridge: client connected");

                let (mut sink, mut stream) = ws.split();

                while let Some(msg) = stream.next().await {
                    let msg = match msg {
                        Ok(m) => m,
                        Err(e) => {
                            tracing::debug!(peer = %peer, error = %e, "ws_bridge: read error");
                            break;
                        }
                    };

                    let text = match msg {
                        Message::Text(t) => t,
                        Message::Close(_) => break,
                        Message::Ping(_) | Message::Pong(_) => continue,
                        _ => continue,
                    };

                    let req: RpcRequest = match serde_json::from_str(&text) {
                        Ok(r) => r,
                        Err(e) => {
                            let resp = RpcResponse::error(
                                serde_json::Value::Null,
                                format!("invalid JSON: {e}"),
                            );
                            let json = serde_json::to_string(&resp).unwrap_or_default();
                            if sink.send(Message::Text(json.into())).await.is_err() {
                                break;
                            }
                            continue;
                        }
                    };

                    let resp = dispatch(&state, req).await;
                    let json = serde_json::to_string(&resp).unwrap_or_default();
                    if sink.send(Message::Text(json.into())).await.is_err() {
                        break;
                    }
                }

                tracing::debug!(peer = %peer, "ws_bridge: client disconnected");
            });
        }
    });
}