// Files
// quicproquo/crates/quicproquo-server/src/node_service/delivery.rs
// Chris Nennemann 9ab306d891 feat: Sprint 2 — security hardening, MLS key rotation, E2E tests
// - DS sender identity binding (Phase 4.3): explicit audit logging of
//   sender_prefix in enqueue/batch_enqueue, documenting that sender
//   identity is always derived from authenticated session
// - Username enumeration mitigation (Phase 4.5): 5ms timing floor on
//   resolveUser responses + rate limiting to prevent bulk enumeration
// - Add /update-key REPL command for MLS leaf key rotation via
//   propose_self_update + auto-commit + fan-out to group members
// - Add 4 new E2E tests: message delivery round-trip, key rotation
//   update path, oversized payload rejection, multi-party group (12 total)
// 2026-03-03 23:37:24 +01:00
//
// 849 lines
// 32 KiB
// Rust

use std::sync::Arc;
use std::time::Duration;
use capnp::capability::Promise;
use dashmap::DashMap;
use quicproquo_proto::node_capnp::node_service;
use tokio::sync::Notify;
use tokio::time::timeout;
use sha2::{Digest, Sha256};
use crate::auth::{
check_rate_limit, coded_error, fmt_hex, require_identity_or_request, validate_auth_context,
};
use crate::error_codes::*;
use crate::metrics;
use crate::storage::{StorageError, Store};
use super::{NodeServiceImpl, CURRENT_WIRE_VERSION};
use crate::hooks::{HookAction, MessageEvent, FetchEvent};
// Audit events here must not include secrets: no payload content, no full recipient/token bytes (prefix only).

/// Largest single payload accepted by `enqueue` / `batchEnqueue`: 5 MiB (5 × 2^20 bytes).
const MAX_PAYLOAD_BYTES: usize = 5 << 20;

/// Number of pending messages a single (recipient, channel) queue may hold before
/// further local enqueues are refused with `E015_QUEUE_FULL`.
const MAX_QUEUE_DEPTH: usize = 1_000;
/// Produce a 96-byte delivery receipt for message `seq` addressed to `recipient_key`.
///
/// Layout of the returned array:
/// * bytes `0..32`  — `SHA-256(seq_le || recipient_key || timestamp_ms_le)`, with both
///   integers encoded as 8-byte little-endian;
/// * bytes `32..96` — the Ed25519 signature over those 32 bytes, produced by
///   `signing_key.sign_raw`.
///
/// Note: `timestamp_ms` itself is not embedded in the proof. A verifier can check the
/// signature over the digest, but can only recompute the digest if it learns the
/// timestamp some other way.
fn build_delivery_proof(
    signing_key: &quicproquo_core::IdentityKeypair,
    seq: u64,
    recipient_key: &[u8],
    timestamp_ms: u64,
) -> [u8; 96] {
    let digest: [u8; 32] = {
        let mut sha = Sha256::new();
        // Feed the preimage parts in the fixed order seq, recipient key, timestamp.
        for part in [
            &seq.to_le_bytes()[..],
            recipient_key,
            &timestamp_ms.to_le_bytes()[..],
        ] {
            sha.update(part);
        }
        sha.finalize().into()
    };
    let signature = signing_key.sign_raw(&digest);

    let mut receipt = [0u8; 96];
    let (hash_part, sig_part) = receipt.split_at_mut(32);
    hash_part.copy_from_slice(&digest);
    sig_part.copy_from_slice(&signature);
    receipt
}
/// Wrap a [`StorageError`] as an RPC error tagged with the `E009_STORAGE_ERROR` code.
fn storage_err(err: StorageError) -> capnp::Error {
    let code = E009_STORAGE_ERROR;
    coded_error(code, err)
}
/// Copy `(seq, data)` pairs into the `payloads` list of a `fetchWait` result, in order.
///
/// The list is sized to `messages.len()` and each entry gets its sequence number and
/// raw payload bytes.
pub fn fill_payloads_wait(
    results: &mut node_service::FetchWaitResults,
    messages: Vec<(u64, Vec<u8>)>,
) {
    let count = messages.len() as u32;
    let mut payloads = results.get().init_payloads(count);
    for (slot, (seq, bytes)) in (0u32..).zip(messages.iter()) {
        let mut item = payloads.reborrow().get(slot);
        item.set_seq(*seq);
        item.set_data(bytes);
    }
}
impl NodeServiceImpl {
pub fn handle_enqueue(
&mut self,
params: node_service::EnqueueParams,
mut results: node_service::EnqueueResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_key = match p.get_recipient_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let payload = match p.get_payload() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if payload.is_empty() {
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
}
if payload.len() > MAX_PAYLOAD_BYTES {
return Promise::err(coded_error(
E006_PAYLOAD_TOO_LARGE,
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
// Audit: rate limit hit — do not log token or identity.
tracing::warn!("rate_limit_hit");
metrics::record_rate_limit_hit_total();
return Promise::err(e);
}
// Phase 4.3 — DS sender identity binding.
// When sealed_sender is false, the sender MUST have an identity-bound session.
// The sender_identity used for audit/hooks is ALWAYS derived from
// auth_ctx.identity_key (populated by OPAQUE session lookup in validate_auth_context),
// never from any client-supplied field. This guarantees that the server only
// attributes messages to the cryptographically authenticated identity.
if !self.sealed_sender {
if let Err(e) = crate::auth::require_identity(&auth_ctx) {
return Promise::err(e);
}
}
// Federation routing: if the recipient's home server differs from ours, relay the
// message to the remote server instead of enqueueing locally. This enables
// cross-node delivery in a Freifunk / community mesh deployment.
if let (Some(fed_client), Some(local_domain)) =
(&self.federation_client, &self.local_domain)
{
let dest = crate::federation::routing::resolve_destination(
&self.store,
&recipient_key,
local_domain,
);
if let crate::federation::routing::Destination::Remote(remote_domain) = dest {
let fed = Arc::clone(fed_client);
let rk = recipient_key;
let pl = payload;
let ch = channel_id;
tracing::info!(
recipient_prefix = %fmt_hex(&rk[..4]),
domain = %remote_domain,
"federation: routing enqueue to remote server"
);
return Promise::from_future(async move {
let seq = fed
.relay_enqueue(&remote_domain, &rk, &pl, &ch)
.await
.map_err(|e| {
capnp::Error::failed(format!("federation relay failed: {e}"))
})?;
results.get().set_seq(seq);
metrics::record_enqueue_total();
metrics::record_enqueue_bytes(pl.len() as u64);
Ok(())
});
}
}
// DM channel authz: channel_id.len() == 16 means a created channel; caller and recipient must be the two members.
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
let recipient_other = (recipient_key == *a && caller == b.as_slice())
|| (recipient_key == *b && caller == a.as_slice());
if !caller_in || !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller or recipient not a member of this channel",
));
}
}
match self.store.queue_depth(&recipient_key, &channel_id) {
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
return Promise::err(coded_error(
E015_QUEUE_FULL,
format!("queue depth {} exceeds limit {}", depth, MAX_QUEUE_DEPTH),
));
}
Err(e) => return Promise::err(storage_err(e)),
_ => {}
}
let payload_len = payload.len();
// sender_identity is derived solely from auth_ctx (server-side session state).
let sender_identity = if self.sealed_sender {
None
} else {
crate::auth::require_identity(&auth_ctx).ok().map(|v| v.to_vec())
};
let sender_prefix = sender_identity
.as_deref()
.filter(|id| id.len() >= 4)
.map(|id| fmt_hex(&id[..4]));
// Hook: on_message_enqueue — fires after validation, before storage.
let hook_event = MessageEvent {
sender_identity,
recipient_key: recipient_key.clone(),
channel_id: channel_id.clone(),
payload_len,
seq: 0, // not yet assigned
};
if let HookAction::Reject(reason) = self.hooks.on_message_enqueue(&hook_event) {
return Promise::err(capnp::Error::failed(format!("hook rejected enqueue: {reason}")));
}
let seq = match self
.store
.enqueue(&recipient_key, &channel_id, payload)
.map_err(storage_err)
{
Ok(seq) => seq,
Err(e) => return Promise::err(e),
};
let timestamp_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let proof = build_delivery_proof(&self.signing_key, seq, &recipient_key, timestamp_ms);
let mut r = results.get();
r.set_seq(seq);
r.set_delivery_proof(&proof);
// Metrics and audit. Audit events must not include secrets (no payload, no full keys).
metrics::record_enqueue_total();
metrics::record_enqueue_bytes(payload_len as u64);
if let Ok(depth) = self.store.queue_depth(&recipient_key, &channel_id) {
metrics::record_delivery_queue_depth(depth);
}
tracing::info!(
sender_prefix = sender_prefix.as_deref().unwrap_or("sealed"),
recipient_prefix = %fmt_hex(&recipient_key[..4]),
payload_len = payload_len,
seq = seq,
"audit: enqueue"
);
crate::auth::waiter(&self.waiters, &recipient_key).notify_waiters();
Promise::ok(())
}
pub fn handle_fetch(
&mut self,
params: node_service::FetchParams,
mut results: node_service::FetchResults,
) -> Promise<(), capnp::Error> {
let recipient_key = match params.get() {
Ok(p) => match p.get_recipient_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
},
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = params
.get()
.ok()
.and_then(|p| p.get_channel_id().ok())
.map(|c| c.to_vec())
.unwrap_or_default();
let version = params
.get()
.ok()
.map(|p| p.get_version())
.unwrap_or(CURRENT_WIRE_VERSION);
let limit = params.get().ok().map(|p| p.get_limit()).unwrap_or(0);
let auth_ctx = match params
.get()
.ok()
.map(|p| validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()))
.transpose()
{
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
let auth_ctx = match auth_ctx {
Some(ctx) => ctx,
None => return Promise::err(coded_error(E003_INVALID_TOKEN, "auth required")),
};
if let Err(e) = require_identity_or_request(
&auth_ctx,
&recipient_key,
self.auth_cfg.allow_insecure_identity_from_request,
) {
return Promise::err(e);
}
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
if !caller_in || !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller or recipient not a member of this channel",
));
}
}
let messages = if limit > 0 {
match self
.store
.fetch_limited(&recipient_key, &channel_id, limit as usize)
.map_err(storage_err)
{
Ok(m) => m,
Err(e) => return Promise::err(e),
}
} else {
match self
.store
.fetch(&recipient_key, &channel_id)
.map_err(storage_err)
{
Ok(m) => m,
Err(e) => return Promise::err(e),
}
};
// Hook: on_fetch — fires after messages are retrieved.
self.hooks.on_fetch(&FetchEvent {
recipient_key: recipient_key.clone(),
channel_id: channel_id.clone(),
message_count: messages.len(),
});
// Audit: fetch — do not log payload or full keys.
metrics::record_fetch_total();
tracing::info!(
recipient_prefix = %fmt_hex(&recipient_key[..4]),
count = messages.len(),
"audit: fetch"
);
let mut list = results.get().init_payloads(messages.len() as u32);
for (i, (seq, data)) in messages.iter().enumerate() {
let mut entry = list.reborrow().get(i as u32);
entry.set_seq(*seq);
entry.set_data(data);
}
Promise::ok(())
}
pub fn handle_fetch_wait(
&mut self,
params: node_service::FetchWaitParams,
mut results: node_service::FetchWaitResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_key = match p.get_recipient_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let timeout_ms = p.get_timeout_ms();
let limit = p.get_limit();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = require_identity_or_request(
&auth_ctx,
&recipient_key,
self.auth_cfg.allow_insecure_identity_from_request,
) {
return Promise::err(e);
}
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
if !caller_in || !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller or recipient not a member of this channel",
));
}
}
let store = Arc::clone(&self.store);
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = self.waiters.clone();
Promise::from_future(async move {
let fetch_fn = |s: &Arc<dyn Store>, rk: &[u8], ch: &[u8], lim: u32| -> Result<Vec<(u64, Vec<u8>)>, capnp::Error> {
if lim > 0 {
s.fetch_limited(rk, ch, lim as usize).map_err(storage_err)
} else {
s.fetch(rk, ch).map_err(storage_err)
}
};
let messages = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
if messages.is_empty() && timeout_ms > 0 {
let waiter = waiters
.entry(recipient_key.clone())
.or_insert_with(|| Arc::new(Notify::new()))
.clone();
let _ = timeout(Duration::from_millis(timeout_ms), waiter.notified()).await;
let msgs = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
fill_payloads_wait(&mut results, msgs);
metrics::record_fetch_wait_total();
return Ok(());
}
fill_payloads_wait(&mut results, messages);
metrics::record_fetch_wait_total();
Ok(())
})
}
pub fn handle_peek(
&mut self,
params: node_service::PeekParams,
mut results: node_service::PeekResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_key = match p.get_recipient_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let limit = p.get_limit();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = require_identity_or_request(
&auth_ctx,
&recipient_key,
self.auth_cfg.allow_insecure_identity_from_request,
) {
return Promise::err(e);
}
let messages = match self
.store
.peek(&recipient_key, &channel_id, limit as usize)
.map_err(storage_err)
{
Ok(m) => m,
Err(e) => return Promise::err(e),
};
tracing::info!(
recipient_prefix = %fmt_hex(&recipient_key[..4]),
count = messages.len(),
"audit: peek"
);
let mut list = results.get().init_payloads(messages.len() as u32);
for (i, (seq, data)) in messages.iter().enumerate() {
let mut entry = list.reborrow().get(i as u32);
entry.set_seq(*seq);
entry.set_data(data);
}
Promise::ok(())
}
pub fn handle_ack(
&mut self,
params: node_service::AckParams,
_results: node_service::AckResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_key = match p.get_recipient_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let seq_up_to = p.get_seq_up_to();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = require_identity_or_request(
&auth_ctx,
&recipient_key,
self.auth_cfg.allow_insecure_identity_from_request,
) {
return Promise::err(e);
}
match self
.store
.ack(&recipient_key, &channel_id, seq_up_to)
.map_err(storage_err)
{
Ok(removed) => {
tracing::info!(
recipient_prefix = %fmt_hex(&recipient_key[..4]),
seq_up_to = seq_up_to,
removed = removed,
"audit: ack"
);
}
Err(e) => return Promise::err(e),
}
Promise::ok(())
}
pub fn handle_batch_enqueue(
&mut self,
params: node_service::BatchEnqueueParams,
mut results: node_service::BatchEnqueueResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_keys = match p.get_recipient_keys() {
Ok(v) => v,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let payload = match p.get_payload() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if payload.is_empty() {
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
}
if payload.len() > MAX_PAYLOAD_BYTES {
return Promise::err(coded_error(
E006_PAYLOAD_TOO_LARGE,
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
tracing::warn!("rate_limit_hit");
metrics::record_rate_limit_hit_total();
return Promise::err(e);
}
// Phase 4.3 — DS sender identity binding (same guarantee as handle_enqueue).
// sender_identity is derived solely from auth_ctx.identity_key, never client data.
if !self.sealed_sender {
if let Err(e) = crate::auth::require_identity(&auth_ctx) {
return Promise::err(e);
}
}
// DM channel authz: validate caller membership once before the loop.
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
if !caller_in {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller is not a member of this channel",
));
}
}
// Eagerly collect recipient keys so params can be dropped before any async work.
let mut recipient_key_vecs: Vec<Vec<u8>> = Vec::with_capacity(recipient_keys.len() as usize);
for i in 0..recipient_keys.len() {
let rk = match recipient_keys.get(i) {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
if rk.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey[{}] must be exactly 32 bytes, got {}", i, rk.len()),
));
}
// Per-recipient DM channel membership check (only when channel_id is a 16-byte UUID).
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(
E023_CHANNEL_NOT_FOUND,
"channel not found",
));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let recipient_other = (rk == *a && caller == b.as_slice())
|| (rk == *b && caller == a.as_slice());
if !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"recipient is not a member of this channel",
));
}
}
recipient_key_vecs.push(rk);
}
// Hook: on_message_enqueue for each recipient — fires before storage.
// sender_identity is derived solely from auth_ctx (server-side session state).
let sender_identity = if self.sealed_sender {
None
} else {
crate::auth::require_identity(&auth_ctx).ok().map(|v| v.to_vec())
};
let sender_prefix = sender_identity
.as_deref()
.filter(|id| id.len() >= 4)
.map(|id| fmt_hex(&id[..4]));
let mut hook_events = Vec::with_capacity(recipient_key_vecs.len());
for rk in &recipient_key_vecs {
let event = MessageEvent {
sender_identity: sender_identity.clone(),
recipient_key: rk.clone(),
channel_id: channel_id.clone(),
payload_len: payload.len(),
seq: 0,
};
if let HookAction::Reject(reason) = self.hooks.on_message_enqueue(&event) {
return Promise::err(capnp::Error::failed(format!("hook rejected enqueue: {reason}")));
}
hook_events.push(event);
}
let n = recipient_key_vecs.len();
let store = Arc::clone(&self.store);
let waiters = Arc::clone(&self.waiters);
let fed_client = self.federation_client.clone();
let local_domain = self.local_domain.clone();
let hooks = Arc::clone(&self.hooks);
// Use an async future to support federation relay alongside local enqueue.
// All storage operations are synchronous; only federation relay calls are await-ed.
Promise::from_future(async move {
let mut seqs = Vec::with_capacity(n);
for rk in &recipient_key_vecs {
// Federation routing: relay to the recipient's home server when remote.
let dest = if let (Some(ref _fed), Some(ref domain)) = (&fed_client, &local_domain) {
crate::federation::routing::resolve_destination(&store, rk, domain)
} else {
crate::federation::routing::Destination::Local
};
let seq = match dest {
crate::federation::routing::Destination::Remote(ref remote_domain) => {
let fed = fed_client.as_deref().ok_or_else(|| {
capnp::Error::failed("federation client unavailable for remote routing".into())
})?;
tracing::info!(
recipient_prefix = %fmt_hex(&rk[..4]),
domain = %remote_domain,
"federation: routing batch enqueue to remote server"
);
fed.relay_enqueue(remote_domain, rk, &payload, &channel_id)
.await
.map_err(|e| {
capnp::Error::failed(format!("federation relay failed: {e}"))
})?
}
crate::federation::routing::Destination::Local => {
match store.queue_depth(rk, &channel_id) {
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
return Err(coded_error(
E015_QUEUE_FULL,
format!("queue depth {} exceeds limit {MAX_QUEUE_DEPTH}", depth),
));
}
Err(e) => return Err(storage_err(e)),
_ => {}
}
store
.enqueue(rk, &channel_id, payload.clone())
.map_err(storage_err)?
}
};
seqs.push(seq);
metrics::record_enqueue_total();
metrics::record_enqueue_bytes(payload.len() as u64);
crate::auth::waiter(&waiters, rk).notify_waiters();
}
let mut list = results.get().init_seqs(seqs.len() as u32);
for (i, seq) in seqs.iter().enumerate() {
list.set(i as u32, *seq);
}
// Hook: on_batch_enqueue — fires after all messages are stored.
hooks.on_batch_enqueue(&hook_events);
tracing::info!(
sender_prefix = sender_prefix.as_deref().unwrap_or("sealed"),
recipient_count = n,
payload_len = payload.len(),
"audit: batch_enqueue"
);
Ok(())
})
}
}