DM channels (createChannel), channel authz, security/docs, future improvements

- Add createChannel RPC (node.capnp @18): create 1:1 channel, returns 16-byte channelId
- Store: create_channel(member_a, member_b), get_channel_members(channel_id)
- FileBackedStore: channels.bin; SqlStore: migration 003_channels, schema v4
- channel_ops: handle_create_channel (auth + identity, peerKey 32 bytes)
- Delivery authz: when channel_id.len() == 16, require caller and recipient are channel members (E022/E023)
- Error codes E022 CHANNEL_ACCESS_DENIED, E023 CHANNEL_NOT_FOUND
- SUMMARY: link Certificate lifecycle; security audit, future improvements, multi-agent plan docs
- Certificate lifecycle doc, SECURITY-AUDIT, FUTURE-IMPROVEMENTS, MULTI-AGENT-WORK-PLAN
- Client/core/tls/auth/server main: assorted fixes and updates from review and audit

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
2026-02-23 22:54:28 +01:00
parent 6b8b61c6ae
commit 750b794342
40 changed files with 4715 additions and 152 deletions

View File

@@ -0,0 +1,13 @@
-- Migration 003: 1:1 DM channels.
-- channel_id is 16 bytes (UUID); member_a and member_b are identity keys in sorted order.
-- Unique on (member_a, member_b) prevents duplicate channels between the same pair.
CREATE TABLE IF NOT EXISTS channels (
    channel_id BLOB PRIMARY KEY,
    member_a BLOB NOT NULL,
    member_b BLOB NOT NULL,
    UNIQUE(member_a, member_b)
);
-- Supports channel lookup by member pair (create_channel's SELECT).
-- NOTE(review): the UNIQUE constraint above already creates an implicit index on
-- (member_a, member_b), so this explicit index is likely redundant — confirm
-- before dropping it, since removal would need a new migration.
CREATE INDEX IF NOT EXISTS idx_channels_members
    ON channels(member_a, member_b);

View File

@@ -17,6 +17,7 @@ pub const RATE_LIMIT_MAX_ENQUEUES: u32 = 100;
pub struct AuthConfig {
pub required_token: Option<Vec<u8>>,
/// When true, a valid bearer token (no session) is accepted and the request's identity/key is used (dev/e2e only).
/// CLI flag: --allow-insecure-auth / QUICNPROTOCHAT_ALLOW_INSECURE_AUTH.
pub allow_insecure_identity_from_request: bool,
}
@@ -59,10 +60,13 @@ pub struct AuthContext {
}
pub fn current_timestamp() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
Ok(d) => d.as_secs(),
Err(_) => {
tracing::warn!("system time is before UNIX_EPOCH; using 0 for session/rate-limit timestamps");
0
}
}
}
pub fn check_rate_limit(
@@ -174,7 +178,7 @@ pub fn require_identity<'a>(auth_ctx: &'a AuthContext) -> Result<&'a [u8], capnp
pub fn require_identity_match(auth_ctx: &AuthContext, expected: &[u8]) -> Result<(), capnp::Error> {
let ik = require_identity(auth_ctx)?;
if ik != expected {
if ik.len() != expected.len() || !bool::from(ik.ct_eq(expected)) {
return Err(crate::error_codes::coded_error(
E016_IDENTITY_MISMATCH,
"access token is bound to a different identity",

View File

@@ -24,6 +24,8 @@ pub const E018_USER_EXISTS: &str = "E018";
pub const E019_NO_PENDING_LOGIN: &str = "E019";
pub const E020_BAD_PARAMS: &str = "E020";
pub const E021_CIPHERSUITE_NOT_ALLOWED: &str = "E021";
pub const E022_CHANNEL_ACCESS_DENIED: &str = "E022";
pub const E023_CHANNEL_NOT_FOUND: &str = "E023";
/// Build a `capnp::Error::failed()` with the structured code prefix.
pub fn coded_error(code: &str, msg: impl std::fmt::Display) -> capnp::Error {

View File

@@ -162,9 +162,16 @@ async fn main() -> anyhow::Result<()> {
.parse()
.context("--listen must be host:port")?;
let server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
let mut server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
.context("failed to build TLS/QUIC server config")?;
// Harden QUIC transport: idle timeout, limit stream concurrency.
let mut transport = quinn::TransportConfig::default();
transport.max_idle_timeout(Some(std::time::Duration::from_secs(300).try_into().unwrap()));
transport.max_concurrent_bidi_streams(1u32.into());
transport.max_concurrent_uni_streams(0u32.into());
server_config.transport_config(Arc::new(transport));
// Shared storage — persisted to disk for restart safety.
let store: Arc<dyn Store> = match effective.store_backend.as_str() {
"sql" => {
@@ -223,6 +230,7 @@ async fn main() -> anyhow::Result<()> {
Arc::clone(&pending_logins),
Arc::clone(&rate_limits),
Arc::clone(&store),
Arc::clone(&waiters),
);
let endpoint = Endpoint::server(server_config, listen)?;

View File

@@ -19,6 +19,16 @@ fn storage_err(err: StorageError) -> capnp::Error {
coded_error(E009_STORAGE_ERROR, err)
}
/// Parse a username parameter from a Cap'n Proto text reader.
///
/// Both failure modes — a malformed reader and non-UTF-8 text — surface as
/// `E020` bad-params errors so callers can propagate them uniformly.
fn parse_username_param(
    result: Result<capnp::text::Reader<'_>, capnp::Error>,
) -> Result<String, capnp::Error> {
    let text = result.map_err(|e| coded_error(E020_BAD_PARAMS, e))?;
    match text.to_string() {
        Ok(owned) => Ok(owned),
        Err(_) => Err(coded_error(E020_BAD_PARAMS, "username must be valid UTF-8")),
    }
}
impl NodeServiceImpl {
pub fn handle_opaque_login_start(
&mut self,
@@ -29,9 +39,9 @@ impl NodeServiceImpl {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match p.get_username() {
Ok(v) => v.to_string().unwrap_or_default().to_string(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let request_bytes = match p.get_request() {
Ok(v) => v.to_vec(),
@@ -42,6 +52,14 @@ impl NodeServiceImpl {
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
}
// Check for existing recent pending login before expensive OPAQUE/storage work (DoS mitigation).
if let Some(existing) = self.pending_logins.get(&username) {
let age = current_timestamp().saturating_sub(existing.created_at);
if age < 60 {
return Promise::err(coded_error(E010_OPAQUE_ERROR, "login already in progress"));
}
}
let credential_request = match CredentialRequest::<OpaqueSuite>::deserialize(&request_bytes) {
Ok(r) => r,
Err(e) => {
@@ -62,9 +80,7 @@ impl NodeServiceImpl {
))
}
},
Ok(None) => {
return Promise::err(coded_error(E010_OPAQUE_ERROR, "user not registered"))
}
Ok(None) => None,
Err(e) => return Promise::err(storage_err(e)),
};
@@ -111,9 +127,9 @@ impl NodeServiceImpl {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match p.get_username() {
Ok(v) => v.to_string().unwrap_or_default().to_string(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let request_bytes = match p.get_request() {
Ok(v) => v.to_vec(),
@@ -171,9 +187,9 @@ impl NodeServiceImpl {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match p.get_username() {
Ok(v) => v.to_string().unwrap_or_default().to_string(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let finalization_bytes = match p.get_finalization() {
Ok(v) => v.to_vec(),
@@ -278,9 +294,9 @@ impl NodeServiceImpl {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match p.get_username() {
Ok(v) => v.to_string().unwrap_or_default().to_string(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let upload_bytes = match p.get_upload() {
Ok(v) => v.to_vec(),
@@ -326,12 +342,18 @@ impl NodeServiceImpl {
let password_file = ServerRegistration::<OpaqueSuite>::finish(upload);
let record_bytes = password_file.serialize().to_vec();
if let Err(e) = self
match self
.store
.store_user_record(&username, record_bytes)
.map_err(storage_err)
{
return Promise::err(e);
Ok(()) => {}
Err(crate::storage::StorageError::DuplicateUser(_)) => {
return Promise::err(coded_error(
E018_USER_EXISTS,
format!("user '{}' already registered", username),
))
}
Err(e) => return Promise::err(storage_err(e)),
}
if !identity_key.is_empty() {

View File

@@ -0,0 +1,62 @@
//! createChannel RPC: create or look up a 1:1 DM channel.
use capnp::capability::Promise;
use quicnprotochat_proto::node_capnp::node_service;
use crate::auth::{coded_error, require_identity, validate_auth_context};
use crate::error_codes::*;
use crate::storage::StorageError;
use super::NodeServiceImpl;
/// Wrap a storage-layer error as a coded RPC error (E009 STORAGE_ERROR).
fn storage_err(err: StorageError) -> capnp::Error {
    coded_error(E009_STORAGE_ERROR, err)
}
impl NodeServiceImpl {
pub fn handle_create_channel(
&mut self,
params: node_service::CreateChannelParams,
mut results: node_service::CreateChannelResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let peer_key = match p.get_peer_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
let identity = match require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
if peer_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("peerKey must be exactly 32 bytes, got {}", peer_key.len()),
));
}
if identity == peer_key {
return Promise::err(coded_error(
E020_BAD_PARAMS,
"peerKey must not equal caller identity",
));
}
let channel_id = match self.store.create_channel(&identity, &peer_key) {
Ok(id) => id,
Err(e) => return Promise::err(storage_err(e)),
};
results.get().set_channel_id(&channel_id);
Promise::ok(())
}
}

View File

@@ -77,10 +77,10 @@ impl NodeServiceImpl {
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
));
}
if version != CURRENT_WIRE_VERSION {
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
@@ -102,6 +102,31 @@ impl NodeServiceImpl {
}
}
// DM channel authz: channel_id.len() == 16 means a created channel; caller and recipient must be the two members.
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
let recipient_other = (recipient_key == *a && caller == b.as_slice())
|| (recipient_key == *b && caller == a.as_slice());
if !caller_in || !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller or recipient not a member of this channel",
));
}
}
match self.store.queue_depth(&recipient_key, &channel_id) {
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
return Promise::err(coded_error(
@@ -183,10 +208,10 @@ impl NodeServiceImpl {
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version != CURRENT_WIRE_VERSION {
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
@@ -203,6 +228,30 @@ impl NodeServiceImpl {
return Promise::err(e);
}
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
if !caller_in || !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller or recipient not a member of this channel",
));
}
}
let messages = if limit > 0 {
match self
.store
@@ -269,10 +318,10 @@ impl NodeServiceImpl {
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version != CURRENT_WIRE_VERSION {
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
@@ -284,6 +333,30 @@ impl NodeServiceImpl {
return Promise::err(e);
}
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
if !caller_in || !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller or recipient not a member of this channel",
));
}
}
let store = Arc::clone(&self.store);
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = self.waiters.clone();
@@ -315,4 +388,232 @@ impl NodeServiceImpl {
Ok(())
})
}
/// peek RPC: read queued messages for `recipientKey`/`channelId` without
/// removing them. Returns `(seq, data)` entries; the client is expected to
/// later call `ack` with the highest seq it has durably processed.
pub fn handle_peek(
    &mut self,
    params: node_service::PeekParams,
    mut results: node_service::PeekResults,
) -> Promise<(), capnp::Error> {
    let p = match params.get() {
        Ok(p) => p,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    let recipient_key = match p.get_recipient_key() {
        Ok(v) => v.to_vec(),
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    // A missing channelId is treated as the default (empty) channel.
    let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
    let version = p.get_version();
    let limit = p.get_limit();
    let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
        Ok(ctx) => ctx,
        Err(e) => return Promise::err(e),
    };
    if recipient_key.len() != 32 {
        return Promise::err(coded_error(
            E004_IDENTITY_KEY_LENGTH,
            format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
        ));
    }
    // Accept any wire version up to the server's current one; only reject
    // versions newer than we understand.
    if version > CURRENT_WIRE_VERSION {
        return Promise::err(coded_error(
            E012_WIRE_VERSION,
            format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
        ));
    }
    // Only the recipient may read its own queue (or the insecure dev/e2e
    // identity-from-request override when enabled).
    if let Err(e) = require_identity_or_request(
        &auth_ctx,
        &recipient_key,
        self.auth_cfg.allow_insecure_identity_from_request,
    ) {
        return Promise::err(e);
    }
    // NOTE(review): `limit` semantics (0 = unlimited?) are decided by the
    // store's `peek` implementation — confirm against the Store trait docs.
    let messages = match self
        .store
        .peek(&recipient_key, &channel_id, limit as usize)
        .map_err(storage_err)
    {
        Ok(m) => m,
        Err(e) => return Promise::err(e),
    };
    // Audit log uses only a 4-byte key prefix to avoid logging full identity keys.
    tracing::info!(
        recipient_prefix = %fmt_hex(&recipient_key[..4]),
        count = messages.len(),
        "audit: peek"
    );
    let mut list = results.get().init_payloads(messages.len() as u32);
    for (i, (seq, data)) in messages.iter().enumerate() {
        let mut entry = list.reborrow().get(i as u32);
        entry.set_seq(*seq);
        entry.set_data(data);
    }
    Promise::ok(())
}
/// ack RPC: remove all queued messages for `recipientKey`/`channelId` with
/// seq <= `seqUpTo`. Counterpart to `peek` — peek is non-destructive, ack is
/// the explicit delete. Returns nothing; removal count is only audit-logged.
pub fn handle_ack(
    &mut self,
    params: node_service::AckParams,
    _results: node_service::AckResults,
) -> Promise<(), capnp::Error> {
    let p = match params.get() {
        Ok(p) => p,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    let recipient_key = match p.get_recipient_key() {
        Ok(v) => v.to_vec(),
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    // A missing channelId is treated as the default (empty) channel.
    let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
    let version = p.get_version();
    let seq_up_to = p.get_seq_up_to();
    let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
        Ok(ctx) => ctx,
        Err(e) => return Promise::err(e),
    };
    if recipient_key.len() != 32 {
        return Promise::err(coded_error(
            E004_IDENTITY_KEY_LENGTH,
            format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
        ));
    }
    // Accept any wire version up to the server's current one.
    if version > CURRENT_WIRE_VERSION {
        return Promise::err(coded_error(
            E012_WIRE_VERSION,
            format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
        ));
    }
    // Only the recipient may ack its own queue (or the insecure dev/e2e
    // identity-from-request override when enabled).
    if let Err(e) = require_identity_or_request(
        &auth_ctx,
        &recipient_key,
        self.auth_cfg.allow_insecure_identity_from_request,
    ) {
        return Promise::err(e);
    }
    match self
        .store
        .ack(&recipient_key, &channel_id, seq_up_to)
        .map_err(storage_err)
    {
        Ok(removed) => {
            // Audit log uses only a 4-byte key prefix to avoid logging full keys.
            tracing::info!(
                recipient_prefix = %fmt_hex(&recipient_key[..4]),
                seq_up_to = seq_up_to,
                removed = removed,
                "audit: ack"
            );
        }
        Err(e) => return Promise::err(e),
    }
    Promise::ok(())
}
pub fn handle_batch_enqueue(
&mut self,
params: node_service::BatchEnqueueParams,
mut results: node_service::BatchEnqueueResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_keys = match p.get_recipient_keys() {
Ok(v) => v,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let payload = match p.get_payload() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if payload.is_empty() {
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
}
if payload.len() > MAX_PAYLOAD_BYTES {
return Promise::err(coded_error(
E006_PAYLOAD_TOO_LARGE,
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
tracing::warn!("rate_limit_hit");
metrics::record_rate_limit_hit_total();
return Promise::err(e);
}
let mut seqs = Vec::with_capacity(recipient_keys.len() as usize);
for i in 0..recipient_keys.len() {
let rk = match recipient_keys.get(i) {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
if rk.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey[{}] must be exactly 32 bytes, got {}", i, rk.len()),
));
}
match self.store.queue_depth(&rk, &channel_id) {
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
return Promise::err(coded_error(
E015_QUEUE_FULL,
format!("queue depth {} exceeds limit {}", depth, MAX_QUEUE_DEPTH),
));
}
Err(e) => return Promise::err(storage_err(e)),
_ => {}
}
let seq = match self
.store
.enqueue(&rk, &channel_id, payload.clone())
.map_err(storage_err)
{
Ok(seq) => seq,
Err(e) => return Promise::err(e),
};
seqs.push(seq);
metrics::record_enqueue_total();
metrics::record_enqueue_bytes(payload.len() as u64);
crate::auth::waiter(&self.waiters, &rk).notify_waiters();
}
let mut list = results.get().init_seqs(seqs.len() as u32);
for (i, seq) in seqs.iter().enumerate() {
list.set(i as u32, *seq);
}
tracing::info!(
recipient_count = recipient_keys.len(),
payload_len = payload.len(),
"audit: batch_enqueue"
);
Promise::ok(())
}
}

View File

@@ -256,4 +256,47 @@ impl NodeServiceImpl {
Promise::ok(())
}
/// fetchHybridKeys RPC: batch lookup of hybrid public keys by identity key.
/// The result list is positional — entry i corresponds to identityKeys[i];
/// an unknown identity key yields an empty byte string rather than an error,
/// so one missing peer does not fail the whole batch.
pub fn handle_fetch_hybrid_keys(
    &mut self,
    params: node_service::FetchHybridKeysParams,
    mut results: node_service::FetchHybridKeysResults,
) -> Promise<(), capnp::Error> {
    let p = match params.get() {
        Ok(p) => p,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    let identity_keys = match p.get_identity_keys() {
        Ok(v) => v,
        Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
    };
    // Auth is required but the context is otherwise unused: any valid session
    // may look up public keys.
    if let Err(e) = validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
        return Promise::err(e);
    }
    let count = identity_keys.len() as usize;
    let mut key_data: Vec<Vec<u8>> = Vec::with_capacity(count);
    // NOTE(review): unlike other handlers, identity keys are not checked for
    // the 32-byte length here; malformed keys just miss in the store. Confirm
    // whether an E004 per-entry check is wanted.
    for i in 0..identity_keys.len() {
        let ik = match identity_keys.get(i) {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
        };
        let pk = match self.store.fetch_hybrid_key(&ik).map_err(storage_err) {
            Ok(Some(pk)) => pk,
            // Unknown key: keep the slot with empty bytes to preserve ordering.
            Ok(None) => vec![],
            Err(e) => return Promise::err(e),
        };
        key_data.push(pk);
    }
    let mut list = results.get().init_keys(key_data.len() as u32);
    for (i, pk) in key_data.iter().enumerate() {
        list.set(i as u32, pk);
    }
    tracing::debug!(count = count, "batch hybrid key fetch");
    Promise::ok(())
}
}

View File

@@ -15,7 +15,11 @@ use crate::auth::{
};
use crate::storage::Store;
/// Cap'n Proto traversal limit (words). 4 Mi words = 32 MiB; bounds DoS from deeply nested or large messages.
const CAPNP_TRAVERSAL_LIMIT_WORDS: usize = 4 * 1024 * 1024;
mod auth_ops;
mod channel_ops;
mod delivery;
mod key_ops;
mod p2p_ops;
@@ -132,6 +136,46 @@ impl node_service::Server for NodeServiceImpl {
) -> capnp::capability::Promise<(), capnp::Error> {
self.handle_resolve_endpoint(params, results)
}
/// Trait shim: forwards the peek RPC to `handle_peek`.
fn peek(
    &mut self,
    params: node_service::PeekParams,
    results: node_service::PeekResults,
) -> capnp::capability::Promise<(), capnp::Error> {
    self.handle_peek(params, results)
}
/// Trait shim: forwards the ack RPC to `handle_ack`.
fn ack(
    &mut self,
    params: node_service::AckParams,
    results: node_service::AckResults,
) -> capnp::capability::Promise<(), capnp::Error> {
    self.handle_ack(params, results)
}
/// Trait shim: forwards the fetchHybridKeys RPC to `handle_fetch_hybrid_keys`.
fn fetch_hybrid_keys(
    &mut self,
    params: node_service::FetchHybridKeysParams,
    results: node_service::FetchHybridKeysResults,
) -> capnp::capability::Promise<(), capnp::Error> {
    self.handle_fetch_hybrid_keys(params, results)
}
/// Trait shim: forwards the batchEnqueue RPC to `handle_batch_enqueue`.
fn batch_enqueue(
    &mut self,
    params: node_service::BatchEnqueueParams,
    results: node_service::BatchEnqueueResults,
) -> capnp::capability::Promise<(), capnp::Error> {
    self.handle_batch_enqueue(params, results)
}
/// Trait shim: forwards the createChannel RPC to `handle_create_channel`.
fn create_channel(
    &mut self,
    params: node_service::CreateChannelParams,
    results: node_service::CreateChannelResults,
) -> capnp::capability::Promise<(), capnp::Error> {
    self.handle_create_channel(params, results)
}
}
pub const CURRENT_WIRE_VERSION: u16 = 1;
@@ -193,11 +237,13 @@ pub async fn handle_node_connection(
.map_err(|e| anyhow::anyhow!("failed to accept bi stream: {e}"))?;
let (reader, writer) = (recv.compat(), send.compat_write());
let mut reader_opts = capnp::message::ReaderOptions::new();
reader_opts.traversal_limit_in_words(Some(CAPNP_TRAVERSAL_LIMIT_WORDS));
let network = capnp_rpc::twoparty::VatNetwork::new(
reader,
writer,
capnp_rpc::rpc_twoparty_capnp::Side::Server,
Default::default(),
reader_opts,
);
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl::new(
@@ -223,6 +269,7 @@ pub fn spawn_cleanup_task(
pending_logins: Arc<DashMap<String, PendingLogin>>,
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
store: Arc<dyn Store>,
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
@@ -234,6 +281,29 @@ pub fn spawn_cleanup_task(
pending_logins.retain(|_, pl| now - pl.created_at < PENDING_LOGIN_TTL_SECS);
rate_limits.retain(|_, entry| now - entry.window_start < RATE_LIMIT_WINDOW_SECS * 2);
// Bound map sizes to prevent unbounded growth from malicious clients.
const MAX_SESSIONS: usize = 100_000;
const MAX_WAITERS: usize = 100_000;
if sessions.len() > MAX_SESSIONS {
let overflow = sessions.len() - MAX_SESSIONS;
let mut entries: Vec<_> = sessions
.iter()
.map(|e| (e.key().clone(), e.expires_at))
.collect();
entries.sort_by_key(|(_, exp)| *exp);
for (key, _) in entries.into_iter().take(overflow) {
sessions.remove(&key);
}
}
if waiters.len() > MAX_WAITERS {
let overflow = waiters.len() - MAX_WAITERS;
let keys: Vec<_> =
waiters.iter().take(overflow).map(|e| e.key().clone()).collect();
for key in keys {
waiters.remove(&key);
}
}
match store.gc_expired_messages(MESSAGE_TTL_SECS) {
Ok(n) if n > 0 => {
tracing::debug!(expired = n, "garbage collected expired messages")

View File

@@ -14,6 +14,7 @@ fn storage_err(err: StorageError) -> capnp::Error {
}
impl NodeServiceImpl {
/// Health check: unauthenticated by design for liveness probes and load balancers.
pub fn handle_health(
&mut self,
_params: node_service::HealthParams,

View File

@@ -3,17 +3,19 @@
use std::path::Path;
use std::sync::Mutex;
use rand::RngCore;
use rusqlite::{params, Connection};
use crate::storage::{StorageError, Store};
/// Schema version after introducing the migration runner (existing DBs had 1).
const SCHEMA_VERSION: i32 = 3;
const SCHEMA_VERSION: i32 = 4;
/// Migrations: (migration_number, SQL). Files named NNN_name.sql, applied in order when N > user_version.
const MIGRATIONS: &[(i32, &str)] = &[
(1, include_str!("../migrations/001_initial.sql")),
(3, include_str!("../migrations/002_add_seq.sql")),
(4, include_str!("../migrations/003_channels.sql")),
];
/// Runs pending migrations on an open connection: applies any migration whose number is greater
@@ -305,10 +307,17 @@ impl Store for SqlStore {
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
let conn = self.lock_conn()?;
conn.execute(
"INSERT OR REPLACE INTO users (username, opaque_record) VALUES (?1, ?2)",
"INSERT INTO users (username, opaque_record) VALUES (?1, ?2)",
params![username, record],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
.map_err(|e| {
if let rusqlite::Error::SqliteFailure(ref err, _) = &e {
if err.code == rusqlite::ErrorCode::ConstraintViolation {
return StorageError::DuplicateUser(username.to_string());
}
}
StorageError::Db(e.to_string())
})?;
Ok(())
}
@@ -360,6 +369,57 @@ impl Store for SqlStore {
.map_err(|e| StorageError::Db(e.to_string()))
}
/// Non-destructive read of queued deliveries, ordered by seq ascending.
///
/// `limit == 0` means "no limit". Improvement over the original: the LIMIT
/// is bound as a parameter instead of spliced into the SQL with `format!` —
/// SQLite treats a negative LIMIT as unbounded, so a single prepared
/// statement covers both cases and no SQL is ever string-built.
fn peek(
    &self,
    recipient_key: &[u8],
    channel_id: &[u8],
    limit: usize,
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
    let conn = self.lock_conn()?;
    // -1 => unlimited (SQLite: negative LIMIT means no upper bound).
    // Values above i64::MAX are unrepresentable anyway; clamp via cast is fine
    // because such limits are never passed by callers (RPC limit is u32-sized).
    let limit_param: i64 = if limit == 0 { -1 } else { limit as i64 };
    let mut stmt = conn
        .prepare(
            "SELECT seq, payload FROM deliveries
             WHERE recipient_key = ?1 AND channel_id = ?2
             ORDER BY seq ASC
             LIMIT ?3",
        )
        .map_err(|e| StorageError::Db(e.to_string()))?;
    let rows: Vec<(i64, Vec<u8>)> = stmt
        .query_map(params![recipient_key, channel_id, limit_param], |row| {
            Ok((row.get(0)?, row.get(1)?))
        })
        .map_err(|e| StorageError::Db(e.to_string()))?
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| StorageError::Db(e.to_string()))?;
    // seq is stored as i64 (SQLite INTEGER); re-widen to the u64 API type.
    Ok(rows.into_iter().map(|(seq, payload)| (seq as u64, payload)).collect())
}
/// Acknowledge deliveries: delete every row for this recipient/channel with
/// seq <= `seq_up_to`. Returns the number of rows removed.
fn ack(
    &self,
    recipient_key: &[u8],
    channel_id: &[u8],
    seq_up_to: u64,
) -> Result<usize, StorageError> {
    let conn = self.lock_conn()?;
    // NOTE(review): `seq_up_to as i64` wraps negative for values > i64::MAX,
    // which would delete nothing. Seqs that large are never issued, but a
    // `try_into` would make that explicit — confirm intent.
    let deleted = conn
        .execute(
            "DELETE FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND seq <= ?3",
            params![recipient_key, channel_id, seq_up_to as i64],
        )
        .map_err(|e| StorageError::Db(e.to_string()))?;
    Ok(deleted)
}
fn publish_endpoint(
&self,
identity_key: &[u8],
@@ -384,6 +444,45 @@ impl Store for SqlStore {
.optional()
.map_err(|e| StorageError::Db(e.to_string()))
}
/// Create (or return the existing) 1:1 channel between two members.
/// Returns the 16-byte channel id. Idempotent: calling again with the same
/// pair (in either order) returns the same id.
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
    // Canonicalize member order so (A,B) and (B,A) resolve to the same row,
    // matching the UNIQUE(member_a, member_b) constraint.
    let (a, b) = if member_a < member_b {
        (member_a.to_vec(), member_b.to_vec())
    } else {
        (member_b.to_vec(), member_a.to_vec())
    };
    // The select-then-insert pair runs while holding the connection lock, so
    // it cannot race within this process. NOTE(review): presumably there is a
    // single process per DB file — with multiple writers the INSERT could
    // still hit the UNIQUE constraint and surface as Db instead of returning
    // the existing id; confirm deployment model.
    let conn = self.lock_conn()?;
    let existing: Option<Vec<u8>> = conn
        .query_row(
            "SELECT channel_id FROM channels WHERE member_a = ?1 AND member_b = ?2",
            params![a, b],
            |row| row.get(0),
        )
        .optional()
        .map_err(|e| StorageError::Db(e.to_string()))?;
    if let Some(id) = existing {
        return Ok(id);
    }
    // Fresh random 16-byte id (UUID-sized; not RFC 4122 formatted).
    let mut channel_id = [0u8; 16];
    rand::thread_rng().fill_bytes(&mut channel_id);
    conn.execute(
        "INSERT INTO channels (channel_id, member_a, member_b) VALUES (?1, ?2, ?3)",
        params![channel_id.as_slice(), a, b],
    )
    .map_err(|e| StorageError::Db(e.to_string()))?;
    Ok(channel_id.to_vec())
}
/// Look up the two members of a channel by its 16-byte id.
/// Returns `(member_a, member_b)` in the stored (sorted) order, or `None` if
/// no such channel exists.
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
    let conn = self.lock_conn()?;
    conn.query_row(
        "SELECT member_a, member_b FROM channels WHERE channel_id = ?1",
        params![channel_id],
        |row| Ok((row.get::<_, Vec<u8>>(0)?, row.get::<_, Vec<u8>>(1)?)),
    )
    .optional()
    .map_err(|e| StorageError::Db(e.to_string()))
}
}
/// Convenience extension for `rusqlite::OptionalExtension`.

View File

@@ -6,6 +6,7 @@ use std::{
sync::Mutex,
};
use rand::RngCore;
use serde::{Deserialize, Serialize};
#[derive(thiserror::Error, Debug)]
@@ -16,6 +17,9 @@ pub enum StorageError {
Serde,
#[error("database error: {0}")]
Db(String),
/// Unique constraint violation (e.g. user already exists).
#[error("duplicate user: {0}")]
DuplicateUser(String),
}
fn lock<T>(m: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>, StorageError> {
@@ -96,12 +100,36 @@ pub trait Store: Send + Sync {
/// Retrieve identity key for a user (Fix 2).
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError>;
/// Peek at queued messages without removing them (non-destructive).
/// Returns `(seq, payload)` pairs ordered by seq.
fn peek(
&self,
recipient_key: &[u8],
channel_id: &[u8],
limit: usize,
) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;
/// Acknowledge (remove) all messages with seq <= seq_up_to.
fn ack(
&self,
recipient_key: &[u8],
channel_id: &[u8],
seq_up_to: u64,
) -> Result<usize, StorageError>;
/// Publish a P2P endpoint address for an identity key.
fn publish_endpoint(&self, identity_key: &[u8], node_addr: Vec<u8>)
-> Result<(), StorageError>;
/// Resolve a peer's P2P endpoint address.
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
/// Create a 1:1 channel between two members. Returns 16-byte channel_id (UUID).
/// Members are stored in sorted order for deterministic lookup.
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError>;
/// Get the two members of a channel by channel_id (16 bytes). Returns (member_a, member_b) in sorted order.
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError>;
}
// ── ChannelKey ───────────────────────────────────────────────────────────────
@@ -154,8 +182,10 @@ pub struct FileBackedStore {
setup_path: PathBuf,
users_path: PathBuf,
identity_keys_path: PathBuf,
channels_path: PathBuf,
key_packages: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
deliveries: Mutex<QueueMapV3>,
channels: Mutex<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>>,
hybrid_keys: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
users: Mutex<HashMap<String, Vec<u8>>>,
identity_keys: Mutex<HashMap<String, Vec<u8>>>,
@@ -174,12 +204,14 @@ impl FileBackedStore {
let setup_path = dir.join("server_setup.bin");
let users_path = dir.join("users.bin");
let identity_keys_path = dir.join("identity_keys.bin");
let channels_path = dir.join("channels.bin");
let key_packages = Mutex::new(Self::load_kp_map(&kp_path)?);
let deliveries = Mutex::new(Self::load_delivery_map_v3(&ds_path)?);
let hybrid_keys = Mutex::new(Self::load_hybrid_keys(&hk_path)?);
let users = Mutex::new(Self::load_users(&users_path)?);
let identity_keys = Mutex::new(Self::load_map_string_bytes(&identity_keys_path)?);
let channels = Mutex::new(Self::load_channels(&channels_path)?);
Ok(Self {
kp_path,
@@ -188,8 +220,10 @@ impl FileBackedStore {
setup_path,
users_path,
identity_keys_path,
channels_path,
key_packages,
deliveries,
channels,
hybrid_keys,
users,
identity_keys,
@@ -197,6 +231,31 @@ impl FileBackedStore {
})
}
/// Load the channel map (channel_id -> (member_a, member_b)) from disk.
/// A missing or zero-length file is treated as "no channels yet" rather than
/// an error, so first startup needs no special casing.
fn load_channels(
    path: &Path,
) -> Result<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>, StorageError> {
    if !path.exists() {
        return Ok(HashMap::new());
    }
    let raw = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
    if raw.is_empty() {
        Ok(HashMap::new())
    } else {
        bincode::deserialize(&raw).map_err(|_| StorageError::Serde)
    }
}
/// Persist the full channel map to disk, creating the data directory first
/// if needed. Serialization happens before any filesystem work so an
/// encoding failure leaves the existing file untouched.
fn flush_channels(
    &self,
    path: &Path,
    map: &HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>,
) -> Result<(), StorageError> {
    let encoded = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
    if let Some(dir) = path.parent() {
        fs::create_dir_all(dir).map_err(|e| StorageError::Io(e.to_string()))?;
    }
    fs::write(path, encoded).map_err(|e| StorageError::Io(e.to_string()))
}
fn load_kp_map(path: &Path) -> Result<HashMap<Vec<u8>, VecDeque<Vec<u8>>>, StorageError> {
if !path.exists() {
return Ok(HashMap::new());
@@ -346,8 +405,9 @@ impl Store for FileBackedStore {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
let seq = *inner.next_seq.entry(key.clone()).or_insert(0);
*inner.next_seq.get_mut(&key).unwrap() = seq + 1;
let entry = inner.next_seq.entry(key.clone()).or_insert(0);
let seq = *entry;
*entry = seq + 1;
inner.map.entry(key).or_default().push_back(SeqEntry { seq, data: payload });
self.flush_delivery_map(&self.ds_path, &*inner)?;
Ok(seq)
@@ -428,7 +488,13 @@ impl Store for FileBackedStore {
if let Some(parent) = self.setup_path.parent() {
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
}
fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))
fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(&self.setup_path, std::fs::Permissions::from_mode(0o600));
}
Ok(())
}
fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError> {
@@ -444,7 +510,14 @@ impl Store for FileBackedStore {
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
let mut map = lock(&self.users)?;
map.insert(username.to_string(), record);
match map.entry(username.to_string()) {
std::collections::hash_map::Entry::Occupied(_) => {
return Err(StorageError::DuplicateUser(username.to_string()))
}
std::collections::hash_map::Entry::Vacant(v) => {
v.insert(record);
}
}
self.flush_users(&self.users_path, &*map)
}
@@ -473,6 +546,54 @@ impl Store for FileBackedStore {
Ok(map.get(username).cloned())
}
/// Non-destructively read up to `limit` queued messages for
/// `(recipient_key, channel_id)`; `limit == 0` means "all".
///
/// Returns `(seq, payload)` pairs in queue order. Nothing is removed,
/// so no flush to disk is performed.
fn peek(
    &self,
    recipient_key: &[u8],
    channel_id: &[u8],
    limit: usize,
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
    let inner = lock(&self.deliveries)?;
    let key = ChannelKey {
        channel_id: channel_id.to_vec(),
        recipient_key: recipient_key.to_vec(),
    };
    match inner.map.get(&key) {
        None => Ok(Vec::new()),
        Some(queue) => {
            let take = if limit == 0 { queue.len() } else { limit.min(queue.len()) };
            Ok(queue
                .iter()
                .take(take)
                .map(|entry| (entry.seq, entry.data.clone()))
                .collect())
        }
    }
}
/// Acknowledge (drop) all queued messages with `seq <= seq_up_to` for
/// `(recipient_key, channel_id)`. Returns how many entries were removed.
///
/// The delivery map is rewritten to disk only when something was actually
/// removed, so a redundant/duplicate ack is a cheap in-memory no-op.
fn ack(
    &self,
    recipient_key: &[u8],
    channel_id: &[u8],
    seq_up_to: u64,
) -> Result<usize, StorageError> {
    let mut inner = lock(&self.deliveries)?;
    let key = ChannelKey {
        channel_id: channel_id.to_vec(),
        recipient_key: recipient_key.to_vec(),
    };
    let removed = match inner.map.get_mut(&key) {
        Some(q) => {
            let before = q.len();
            // Keep only entries newer than the acked watermark.
            q.retain(|e| e.seq > seq_up_to);
            before - q.len()
        }
        None => 0,
    };
    // Skip the full-map disk rewrite when the ack removed nothing.
    if removed > 0 {
        self.flush_delivery_map(&self.ds_path, &*inner)?;
    }
    Ok(removed)
}
fn publish_endpoint(
&self,
identity_key: &[u8],
@@ -487,4 +608,150 @@ impl Store for FileBackedStore {
let map = lock(&self.endpoints)?;
Ok(map.get(identity_key).cloned())
}
/// Create (or return the existing) 1:1 channel between two identity keys.
///
/// The member pair is canonicalized into sorted order, so `(a, b)` and
/// `(b, a)` resolve to the same channel; an existing pair returns its
/// previously issued 16-byte channel id instead of minting a new one.
/// NOTE: the duplicate check is a linear scan over all channels.
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
    // Canonical ordering: argument order must never matter.
    let (lo, hi) = if member_a < member_b {
        (member_a.to_vec(), member_b.to_vec())
    } else {
        (member_b.to_vec(), member_a.to_vec())
    };
    let mut channels = lock(&self.channels)?;
    let existing = channels
        .iter()
        .find_map(|(id, (ma, mb))| (ma == &lo && mb == &hi).then(|| id.clone()));
    if let Some(id) = existing {
        return Ok(id);
    }
    // Mint a fresh random 16-byte (UUID-sized) channel id.
    let mut id_bytes = [0u8; 16];
    rand::thread_rng().fill_bytes(&mut id_bytes);
    let new_id = id_bytes.to_vec();
    channels.insert(new_id.clone(), (lo, hi));
    self.flush_channels(&self.channels_path, &*channels)?;
    Ok(new_id)
}
/// Look up the (sorted) member pair for a 16-byte channel id; `None` if
/// no such channel exists.
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
    let channels = lock(&self.channels)?;
    let members = channels.get(channel_id).cloned();
    Ok(members)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a fresh `FileBackedStore` in a temp dir; the returned `TempDir`
    /// guard keeps the directory alive for the duration of the test.
    fn temp_store() -> (TempDir, FileBackedStore) {
        let dir = TempDir::new().unwrap();
        let store = FileBackedStore::open(dir.path()).unwrap();
        (dir, store)
    }

    #[test]
    fn key_package_upload_fetch() {
        let (_dir, store) = temp_store();
        let ik = vec![1u8; 32];
        store.upload_key_package(&ik, vec![10, 20, 30]).unwrap();
        let pkg = store.fetch_key_package(&ik).unwrap();
        assert_eq!(pkg, Some(vec![10, 20, 30]));
        // Key packages are single-use: the second fetch must find nothing.
        let pkg2 = store.fetch_key_package(&ik).unwrap();
        assert_eq!(pkg2, None);
    }

    #[test]
    fn enqueue_fetch_with_seq() {
        let (_dir, store) = temp_store();
        let rk = vec![2u8; 32];
        let ch = vec![];
        let seq0 = store.enqueue(&rk, &ch, vec![1]).unwrap();
        let seq1 = store.enqueue(&rk, &ch, vec![2]).unwrap();
        // Sequence numbers are dense, starting at 0 per (recipient, channel).
        assert_eq!(seq0, 0);
        assert_eq!(seq1, 1);
        let msgs = store.fetch(&rk, &ch).unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0], (0, vec![1]));
        assert_eq!(msgs[1], (1, vec![2]));
        // fetch is destructive: the queue is drained afterwards.
        let msgs2 = store.fetch(&rk, &ch).unwrap();
        assert!(msgs2.is_empty());
    }

    #[test]
    fn fetch_limited_respects_limit() {
        let (_dir, store) = temp_store();
        let rk = vec![3u8; 32];
        let ch = vec![];
        for i in 0..5 {
            store.enqueue(&rk, &ch, vec![i]).unwrap();
        }
        let msgs = store.fetch_limited(&rk, &ch, 2).unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].1, vec![0]);
        assert_eq!(msgs[1].1, vec![1]);
        // Messages beyond the limit stay queued.
        let depth = store.queue_depth(&rk, &ch).unwrap();
        assert_eq!(depth, 3);
    }

    #[test]
    fn queue_depth_tracking() {
        let (_dir, store) = temp_store();
        let rk = vec![4u8; 32];
        let ch = vec![];
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
        store.enqueue(&rk, &ch, vec![1]).unwrap();
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 1);
        store.enqueue(&rk, &ch, vec![2]).unwrap();
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 2);
        store.fetch(&rk, &ch).unwrap();
        assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
    }

    #[test]
    fn hybrid_key_upload_fetch() {
        let (_dir, store) = temp_store();
        let ik = vec![5u8; 32];
        assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), None);
        store.upload_hybrid_key(&ik, vec![99; 100]).unwrap();
        assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(vec![99; 100]));
    }

    #[test]
    fn user_record_crud() {
        let (_dir, store) = temp_store();
        assert!(!store.has_user_record("alice").unwrap());
        store.store_user_record("alice", vec![1, 2, 3]).unwrap();
        assert!(store.has_user_record("alice").unwrap());
        assert_eq!(store.get_user_record("alice").unwrap(), Some(vec![1, 2, 3]));
    }

    #[test]
    fn user_identity_key_crud() {
        let (_dir, store) = temp_store();
        assert_eq!(store.get_user_identity_key("bob").unwrap(), None);
        store.store_user_identity_key("bob", vec![7u8; 32]).unwrap();
        assert_eq!(store.get_user_identity_key("bob").unwrap(), Some(vec![7u8; 32]));
    }

    #[test]
    fn endpoint_publish_resolve() {
        let (_dir, store) = temp_store();
        let ik = vec![8u8; 32];
        assert_eq!(store.resolve_endpoint(&ik).unwrap(), None);
        store.publish_endpoint(&ik, vec![10, 20]).unwrap();
        assert_eq!(store.resolve_endpoint(&ik).unwrap(), Some(vec![10, 20]));
    }

    #[test]
    fn create_channel_and_members() {
        let (_dir, store) = temp_store();
        let a = vec![1u8; 32];
        let b = vec![2u8; 32];
        assert_eq!(store.get_channel_members(&[0u8; 16]).unwrap(), None);
        let id1 = store.create_channel(&a, &b).unwrap();
        assert_eq!(id1.len(), 16);
        // Members come back in sorted order regardless of argument order.
        let members = store.get_channel_members(&id1).unwrap().unwrap();
        assert_eq!(members.0, a);
        assert_eq!(members.1, b);
        // Creating with the arguments swapped must reuse the same channel.
        let id2 = store.create_channel(&b, &a).unwrap();
        assert_eq!(id1, id2);
    }

    #[test]
    fn channels_persist_across_reopen() {
        // flush_channels writes channels.bin; a store reopened over the same
        // directory must see the channel created by the previous instance.
        let dir = TempDir::new().unwrap();
        let a = vec![1u8; 32];
        let b = vec![2u8; 32];
        let id = {
            let store = FileBackedStore::open(dir.path()).unwrap();
            store.create_channel(&a, &b).unwrap()
        };
        let reopened = FileBackedStore::open(dir.path()).unwrap();
        assert_eq!(reopened.get_channel_members(&id).unwrap(), Some((a, b)));
    }
}

View File

@@ -61,6 +61,12 @@ fn generate_self_signed_cert(cert_path: &PathBuf, key_path: &PathBuf) -> anyhow:
std::fs::write(cert_path, issued.cert.der()).context("write cert")?;
std::fs::write(key_path, &key_der).context("write key")?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(0o600);
std::fs::set_permissions(key_path, perms).context("set key permissions")?;
}
tracing::info!(
cert = %cert_path.display(),