319 lines
11 KiB
Rust
319 lines
11 KiB
Rust
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use capnp::capability::Promise;
|
|
use dashmap::DashMap;
|
|
use quicnprotochat_proto::node_capnp::node_service;
|
|
use tokio::sync::Notify;
|
|
use tokio::time::timeout;
|
|
|
|
use crate::auth::{
|
|
check_rate_limit, coded_error, fmt_hex, require_identity_or_request, validate_auth_context,
|
|
};
|
|
use crate::error_codes::*;
|
|
use crate::metrics;
|
|
use crate::storage::{StorageError, Store};
|
|
|
|
use super::{NodeServiceImpl, CURRENT_WIRE_VERSION};
|
|
|
|
// Audit events here must not include secrets: no payload content, no full recipient/token bytes (prefix only).
|
|
|
|
/// Upper bound on a single enqueued payload, in bytes (5 × 1024 × 1024).
const MAX_PAYLOAD_BYTES: usize = 5 << 20; // 5 MB cap per message
/// Queue length at which further enqueues to the same recipient/channel are rejected.
const MAX_QUEUE_DEPTH: usize = 1_000;
|
|
|
|
fn storage_err(err: StorageError) -> capnp::Error {
|
|
coded_error(E009_STORAGE_ERROR, err)
|
|
}
|
|
|
|
pub fn fill_payloads_wait(
|
|
results: &mut node_service::FetchWaitResults,
|
|
messages: Vec<(u64, Vec<u8>)>,
|
|
) {
|
|
let mut list = results.get().init_payloads(messages.len() as u32);
|
|
for (i, (seq, data)) in messages.iter().enumerate() {
|
|
let mut entry = list.reborrow().get(i as u32);
|
|
entry.set_seq(*seq);
|
|
entry.set_data(data);
|
|
}
|
|
}
|
|
|
|
impl NodeServiceImpl {
|
|
pub fn handle_enqueue(
|
|
&mut self,
|
|
params: node_service::EnqueueParams,
|
|
mut results: node_service::EnqueueResults,
|
|
) -> Promise<(), capnp::Error> {
|
|
let p = match params.get() {
|
|
Ok(p) => p,
|
|
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
|
};
|
|
let recipient_key = match p.get_recipient_key() {
|
|
Ok(v) => v.to_vec(),
|
|
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
|
};
|
|
let payload = match p.get_payload() {
|
|
Ok(v) => v.to_vec(),
|
|
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
|
};
|
|
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
|
let version = p.get_version();
|
|
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
|
Ok(ctx) => ctx,
|
|
Err(e) => return Promise::err(e),
|
|
};
|
|
|
|
if recipient_key.len() != 32 {
|
|
return Promise::err(coded_error(
|
|
E004_IDENTITY_KEY_LENGTH,
|
|
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
|
));
|
|
}
|
|
if payload.is_empty() {
|
|
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
|
|
}
|
|
if payload.len() > MAX_PAYLOAD_BYTES {
|
|
return Promise::err(coded_error(
|
|
E006_PAYLOAD_TOO_LARGE,
|
|
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
|
|
));
|
|
}
|
|
if version != CURRENT_WIRE_VERSION {
|
|
return Promise::err(coded_error(
|
|
E012_WIRE_VERSION,
|
|
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
|
|
));
|
|
}
|
|
|
|
if let Err(e) = check_rate_limit(&self.rate_limiter, &auth_ctx.token) {
|
|
// Audit: rate limit hit — do not log token or identity.
|
|
tracing::warn!("rate_limit_hit");
|
|
metrics::record_rate_limit_hit_total();
|
|
return Promise::err(e);
|
|
}
|
|
|
|
// When sealed_sender is true, enqueue does not require identity; valid token only.
|
|
if !self.sealed_sender {
|
|
if let Err(e) = require_identity_or_request(
|
|
&auth_ctx,
|
|
&recipient_key,
|
|
self.auth_cfg.allow_insecure_identity_from_request,
|
|
) {
|
|
return Promise::err(e);
|
|
}
|
|
}
|
|
|
|
match self.store.queue_depth(&recipient_key, &channel_id) {
|
|
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
|
|
return Promise::err(coded_error(
|
|
E015_QUEUE_FULL,
|
|
format!("queue depth {} exceeds limit {}", depth, MAX_QUEUE_DEPTH),
|
|
));
|
|
}
|
|
Err(e) => return Promise::err(storage_err(e)),
|
|
_ => {}
|
|
}
|
|
|
|
let payload_len = payload.len();
|
|
let seq = match self
|
|
.store
|
|
.enqueue(&recipient_key, &channel_id, payload)
|
|
.map_err(storage_err)
|
|
{
|
|
Ok(seq) => seq,
|
|
Err(e) => return Promise::err(e),
|
|
};
|
|
|
|
results.get().set_seq(seq);
|
|
|
|
// Metrics and audit. Audit events must not include secrets (no payload, no full keys).
|
|
metrics::record_enqueue_total();
|
|
metrics::record_enqueue_bytes(payload_len as u64);
|
|
if let Ok(depth) = self.store.queue_depth(&recipient_key, &channel_id) {
|
|
metrics::record_delivery_queue_depth(depth);
|
|
}
|
|
tracing::info!(
|
|
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
|
payload_len = payload_len,
|
|
seq = seq,
|
|
"audit: enqueue"
|
|
);
|
|
|
|
crate::auth::waiter(&self.waiters, &recipient_key).notify_waiters();
|
|
|
|
Promise::ok(())
|
|
}
|
|
|
|
pub fn handle_fetch(
|
|
&mut self,
|
|
params: node_service::FetchParams,
|
|
mut results: node_service::FetchResults,
|
|
) -> Promise<(), capnp::Error> {
|
|
let recipient_key = match params.get() {
|
|
Ok(p) => match p.get_recipient_key() {
|
|
Ok(v) => v.to_vec(),
|
|
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
|
},
|
|
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
|
};
|
|
let channel_id = params
|
|
.get()
|
|
.ok()
|
|
.and_then(|p| p.get_channel_id().ok())
|
|
.map(|c| c.to_vec())
|
|
.unwrap_or_default();
|
|
let version = params
|
|
.get()
|
|
.ok()
|
|
.map(|p| p.get_version())
|
|
.unwrap_or(CURRENT_WIRE_VERSION);
|
|
let limit = params.get().ok().map(|p| p.get_limit()).unwrap_or(0);
|
|
let auth_ctx = match params
|
|
.get()
|
|
.ok()
|
|
.map(|p| validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()))
|
|
.transpose()
|
|
{
|
|
Ok(ctx) => ctx,
|
|
Err(e) => return Promise::err(e),
|
|
};
|
|
|
|
if recipient_key.len() != 32 {
|
|
return Promise::err(coded_error(
|
|
E004_IDENTITY_KEY_LENGTH,
|
|
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
|
));
|
|
}
|
|
if version != CURRENT_WIRE_VERSION {
|
|
return Promise::err(coded_error(
|
|
E012_WIRE_VERSION,
|
|
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
|
|
));
|
|
}
|
|
|
|
let auth_ctx = match auth_ctx {
|
|
Some(ctx) => ctx,
|
|
None => return Promise::err(coded_error(E003_INVALID_TOKEN, "auth required")),
|
|
};
|
|
|
|
if let Err(e) = require_identity_or_request(
|
|
&auth_ctx,
|
|
&recipient_key,
|
|
self.auth_cfg.allow_insecure_identity_from_request,
|
|
) {
|
|
return Promise::err(e);
|
|
}
|
|
|
|
let messages = if limit > 0 {
|
|
match self
|
|
.store
|
|
.fetch_limited(&recipient_key, &channel_id, limit as usize)
|
|
.map_err(storage_err)
|
|
{
|
|
Ok(m) => m,
|
|
Err(e) => return Promise::err(e),
|
|
}
|
|
} else {
|
|
match self
|
|
.store
|
|
.fetch(&recipient_key, &channel_id)
|
|
.map_err(storage_err)
|
|
{
|
|
Ok(m) => m,
|
|
Err(e) => return Promise::err(e),
|
|
}
|
|
};
|
|
|
|
// Audit: fetch — do not log payload or full keys.
|
|
metrics::record_fetch_total();
|
|
tracing::info!(
|
|
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
|
count = messages.len(),
|
|
"audit: fetch"
|
|
);
|
|
|
|
let mut list = results.get().init_payloads(messages.len() as u32);
|
|
for (i, (seq, data)) in messages.iter().enumerate() {
|
|
let mut entry = list.reborrow().get(i as u32);
|
|
entry.set_seq(*seq);
|
|
entry.set_data(data);
|
|
}
|
|
|
|
Promise::ok(())
|
|
}
|
|
|
|
pub fn handle_fetch_wait(
|
|
&mut self,
|
|
params: node_service::FetchWaitParams,
|
|
mut results: node_service::FetchWaitResults,
|
|
) -> Promise<(), capnp::Error> {
|
|
let p = match params.get() {
|
|
Ok(p) => p,
|
|
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
|
};
|
|
let recipient_key = match p.get_recipient_key() {
|
|
Ok(v) => v.to_vec(),
|
|
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
|
};
|
|
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
|
let version = p.get_version();
|
|
let timeout_ms = p.get_timeout_ms();
|
|
let limit = p.get_limit();
|
|
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
|
Ok(ctx) => ctx,
|
|
Err(e) => return Promise::err(e),
|
|
};
|
|
|
|
if recipient_key.len() != 32 {
|
|
return Promise::err(coded_error(
|
|
E004_IDENTITY_KEY_LENGTH,
|
|
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
|
));
|
|
}
|
|
if version != CURRENT_WIRE_VERSION {
|
|
return Promise::err(coded_error(
|
|
E012_WIRE_VERSION,
|
|
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
|
|
));
|
|
}
|
|
|
|
if let Err(e) = require_identity_or_request(
|
|
&auth_ctx,
|
|
&recipient_key,
|
|
self.auth_cfg.allow_insecure_identity_from_request,
|
|
) {
|
|
return Promise::err(e);
|
|
}
|
|
|
|
let store = Arc::clone(&self.store);
|
|
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = self.waiters.clone();
|
|
|
|
Promise::from_future(async move {
|
|
let fetch_fn = |s: &Arc<dyn Store>, rk: &[u8], ch: &[u8], lim: u32| -> Result<Vec<(u64, Vec<u8>)>, capnp::Error> {
|
|
if lim > 0 {
|
|
s.fetch_limited(rk, ch, lim as usize).map_err(storage_err)
|
|
} else {
|
|
s.fetch(rk, ch).map_err(storage_err)
|
|
}
|
|
};
|
|
|
|
let messages = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
|
|
|
|
if messages.is_empty() && timeout_ms > 0 {
|
|
let waiter = waiters
|
|
.entry(recipient_key.clone())
|
|
.or_insert_with(|| Arc::new(Notify::new()))
|
|
.clone();
|
|
let _ = timeout(Duration::from_millis(timeout_ms), waiter.notified()).await;
|
|
let msgs = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
|
|
fill_payloads_wait(&mut results, msgs);
|
|
metrics::record_fetch_wait_total();
|
|
return Ok(());
|
|
}
|
|
|
|
fill_payloads_wait(&mut results, messages);
|
|
metrics::record_fetch_wait_total();
|
|
Ok(())
|
|
})
|
|
}
|
|
}
|