DM channels (createChannel), channel authz, security/docs, future improvements

- Add createChannel RPC (node.capnp @18): create 1:1 channel, returns 16-byte channelId
- Store: create_channel(member_a, member_b), get_channel_members(channel_id)
- FileBackedStore: channels.bin; SqlStore: migration 003_channels, schema v4
- channel_ops: handle_create_channel (auth + identity, peerKey 32 bytes)
- Delivery authz: when channel_id.len() == 16, require caller and recipient are channel members (E022/E023)
- Error codes E022 CHANNEL_ACCESS_DENIED, E023 CHANNEL_NOT_FOUND
- SUMMARY: link Certificate lifecycle; security audit, future improvements, multi-agent plan docs
- Certificate lifecycle doc, SECURITY-AUDIT, FUTURE-IMPROVEMENTS, MULTI-AGENT-WORK-PLAN
- Client/core/tls/auth/server main: assorted fixes and updates from review and audit

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
2026-02-23 22:54:28 +01:00
parent 6b8b61c6ae
commit 750b794342
40 changed files with 4715 additions and 152 deletions

View File

@@ -38,6 +38,7 @@ thiserror = { workspace = true }
sha2 = { workspace = true }
argon2 = { workspace = true }
chacha20poly1305 = { workspace = true }
zeroize = { workspace = true }
quinn = { workspace = true }
quinn-proto = { workspace = true }
rustls = { workspace = true }
@@ -49,10 +50,12 @@ tracing-subscriber = { workspace = true }
# CLI
clap = { workspace = true }
# Hex encoding/decoding
hex = "0.4"
[dev-dependencies]
dashmap = { workspace = true }
assert_cmd = "2"
tempfile = "3"
portpicker = "0.1"
rand = "0.8"
hex = "0.4"

View File

@@ -574,7 +574,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
.await?
.context("joiner hybrid key not found")?;
let wrapped_welcome =
hybrid_encrypt(&joiner_hybrid_pk, &welcome).context("hybrid encrypt welcome")?;
hybrid_encrypt(&joiner_hybrid_pk, &welcome, b"", b"").context("hybrid encrypt welcome")?;
enqueue(&creator_ds, &joiner_identity, &wrapped_welcome).await?;
let welcome_payloads = fetch_all(&joiner_ds, &joiner_identity).await?;
@@ -584,7 +584,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
.context("Welcome was not delivered to joiner via DS")?;
let welcome_bytes =
hybrid_decrypt(&joiner_hybrid, &raw_welcome).context("hybrid decrypt welcome failed")?;
hybrid_decrypt(&joiner_hybrid, &raw_welcome, b"", b"").context("hybrid decrypt welcome failed")?;
joiner
.join_group(&welcome_bytes)
.context("join_group failed")?;
@@ -593,7 +593,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
.send_message(b"hello")
.context("send_message failed")?;
let wrapped_creator_joiner =
hybrid_encrypt(&joiner_hybrid_pk, &ct_creator_to_joiner).context("hybrid encrypt failed")?;
hybrid_encrypt(&joiner_hybrid_pk, &ct_creator_to_joiner, b"", b"").context("hybrid encrypt failed")?;
enqueue(&creator_ds, &joiner_identity, &wrapped_creator_joiner).await?;
let joiner_msgs = fetch_all(&joiner_ds, &joiner_identity).await?;
@@ -601,7 +601,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
.first()
.context("joiner: missing ciphertext from DS")?;
let inner_creator_joiner =
hybrid_decrypt(&joiner_hybrid, raw_creator_joiner).context("hybrid decrypt failed")?;
hybrid_decrypt(&joiner_hybrid, raw_creator_joiner, b"", b"").context("hybrid decrypt failed")?;
let plaintext_creator_joiner = joiner
.receive_message(&inner_creator_joiner)?
.context("expected application message")?;
@@ -617,7 +617,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
.send_message(b"hello back")
.context("send_message failed")?;
let wrapped_joiner_creator =
hybrid_encrypt(&creator_hybrid_pk, &ct_joiner_to_creator).context("hybrid encrypt failed")?;
hybrid_encrypt(&creator_hybrid_pk, &ct_joiner_to_creator, b"", b"").context("hybrid encrypt failed")?;
enqueue(&joiner_ds, &creator_identity, &wrapped_joiner_creator).await?;
let creator_msgs = fetch_all(&creator_ds, &creator_identity).await?;
@@ -625,7 +625,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
.first()
.context("creator: missing ciphertext from DS")?;
let inner_joiner_creator =
hybrid_decrypt(&creator_hybrid, raw_joiner_creator).context("hybrid decrypt failed")?;
hybrid_decrypt(&creator_hybrid, raw_joiner_creator, b"", b"").context("hybrid decrypt failed")?;
let plaintext_joiner_creator = creator
.receive_message(&inner_joiner_creator)?
.context("expected application message")?;
@@ -701,7 +701,7 @@ pub async fn cmd_invite(
}
let peer_hpk = fetch_hybrid_key(&node_client, mk).await?;
let commit_payload = if let Some(ref pk) = peer_hpk {
hybrid_encrypt(pk, &commit).context("hybrid encrypt commit")?
hybrid_encrypt(pk, &commit, b"", b"").context("hybrid encrypt commit")?
} else {
commit.clone()
};
@@ -710,7 +710,7 @@ pub async fn cmd_invite(
let peer_hybrid_pk = fetch_hybrid_key(&node_client, &peer_key).await?;
let payload = if let Some(ref pk) = peer_hybrid_pk {
hybrid_encrypt(pk, &welcome).context("hybrid encrypt welcome failed")?
hybrid_encrypt(pk, &welcome, b"", b"").context("hybrid encrypt welcome failed")?
} else {
welcome
};
@@ -774,6 +774,15 @@ pub async fn cmd_join(
let _ = member.receive_message(&mls_payload);
}
// Auto-replenish KeyPackage after join consumed the original one.
let tls_bytes = member
.generate_key_package()
.context("KeyPackage replenishment failed")?;
upload_key_package(&node_client, &member.identity().public_key_bytes(), &tls_bytes)
.await
.context("KeyPackage replenishment upload failed")?;
println!("KeyPackage auto-replenished after join");
save_state(state_path, &member, hybrid_kp.as_ref(), password)?;
println!("joined group successfully");
Ok(())
@@ -820,7 +829,7 @@ pub async fn cmd_send(
for recipient in &recipients {
let peer_hybrid_pk = fetch_hybrid_key(&node_client, recipient).await?;
let payload = if let Some(ref pk) = peer_hybrid_pk {
hybrid_encrypt(pk, &ct).context("hybrid encrypt failed")?
hybrid_encrypt(pk, &ct, b"", b"").context("hybrid encrypt failed")?
} else {
ct.clone()
};
@@ -871,7 +880,7 @@ pub async fn cmd_recv(
// application messages that depend on the resulting epoch.
payloads.sort_by_key(|(seq, _)| *seq);
let mut retry_mls: Vec<Vec<u8>> = Vec::new();
let mut pending: Vec<(usize, Vec<u8>)> = Vec::new();
for (idx, (_, payload)) in payloads.iter().enumerate() {
let mls_payload = match try_hybrid_decrypt(hybrid_kp.as_ref(), payload) {
Ok(b) => b,
@@ -883,18 +892,32 @@ pub async fn cmd_recv(
match member.receive_message(&mls_payload) {
Ok(Some(pt)) => println!("[{idx}] plaintext: {}", String::from_utf8_lossy(&pt)),
Ok(None) => println!("[{idx}] commit applied"),
Err(_) => retry_mls.push(mls_payload),
Err(_) => pending.push((idx, mls_payload)),
}
}
// Retry messages that failed on the first pass (e.g. app messages whose
// epoch was not yet advanced until a commit earlier in the batch was applied).
for mls_payload in &retry_mls {
match member.receive_message(mls_payload) {
Ok(Some(pt)) => println!("[retry] plaintext: {}", String::from_utf8_lossy(&pt)),
Ok(None) => {}
Err(e) => println!("[retry] error: {e}"),
// Retry until no more progress (handles multi-epoch batches).
loop {
let before = pending.len();
pending.retain(|(idx, mls_payload)| {
match member.receive_message(mls_payload) {
Ok(Some(pt)) => {
println!("[{idx}/retry] plaintext: {}", String::from_utf8_lossy(&pt));
false
}
Ok(None) => {
println!("[{idx}/retry] commit applied");
false
}
Err(_) => true,
}
});
if pending.len() == before {
break; // No progress — remaining messages are unprocessable
}
}
for (idx, _) in &pending {
println!("[{idx}] error: unprocessable after all retries");
}
save_state(state_path, &member, hybrid_kp.as_ref(), password)?;
@@ -906,8 +929,8 @@ pub async fn cmd_recv(
/// Fetch pending payloads, process in order (merge commits, collect plaintexts), save state.
/// Returns only application-message plaintexts. Used by E2E tests and callers that need returned messages.
/// Uses two passes so that if the server delivers an application message before a Commit, the second pass
/// processes it after commits are merged.
/// Retries in a loop until no more progress, handling multi-epoch batches where commits must be
/// applied before later application messages can be decrypted.
pub async fn receive_pending_plaintexts(
state_path: &Path,
server: &str,
@@ -925,7 +948,7 @@ pub async fn receive_pending_plaintexts(
payloads.sort_by_key(|(seq, _)| *seq);
let mut plaintexts = Vec::new();
let mut retry_mls: Vec<Vec<u8>> = Vec::new();
let mut pending: Vec<Vec<u8>> = Vec::new();
for (_, payload) in &payloads {
let mls_payload = match try_hybrid_decrypt(hybrid_kp.as_ref(), payload) {
Ok(b) => b,
@@ -934,12 +957,24 @@ pub async fn receive_pending_plaintexts(
match member.receive_message(&mls_payload) {
Ok(Some(pt)) => plaintexts.push(pt),
Ok(None) => {}
Err(_) => retry_mls.push(mls_payload),
Err(_) => pending.push(mls_payload),
}
}
for mls_payload in &retry_mls {
if let Ok(Some(pt)) = member.receive_message(mls_payload) {
plaintexts.push(pt);
// Retry until no more progress (handles multi-epoch batches).
loop {
let before = pending.len();
pending.retain(|mls_payload| {
match member.receive_message(mls_payload) {
Ok(Some(pt)) => {
plaintexts.push(pt);
false
}
Ok(None) => false,
Err(_) => true,
}
});
if pending.len() == before {
break;
}
}
@@ -1069,7 +1104,7 @@ pub async fn cmd_chat(
.context("send_message failed")?;
let peer_hybrid_pk = fetch_hybrid_key(&client, &peer_key).await?;
let payload = if let Some(ref pk) = peer_hybrid_pk {
hybrid_encrypt(pk, &ct).context("hybrid encrypt failed")?
hybrid_encrypt(pk, &ct, b"", b"").context("hybrid encrypt failed")?
} else {
ct
};
@@ -1085,6 +1120,7 @@ pub async fn cmd_chat(
_ = poll.tick() => {
let mut payloads = fetch_wait(&client, &identity_bytes, 0).await?;
payloads.sort_by_key(|(seq, _)| *seq);
let mut retry_payloads: Vec<Vec<u8>> = Vec::new();
for (_, payload) in &payloads {
let mls_payload = match try_hybrid_decrypt(hybrid_kp.as_ref(), payload) {
Ok(b) => b,
@@ -1097,9 +1133,26 @@ pub async fn cmd_chat(
std::io::stdout().flush().context("flush stdout")?;
}
Ok(None) => {}
Err(_) => {}
Err(_) => retry_payloads.push(mls_payload),
}
}
// Retry failed messages (epoch may have advanced from commits in this batch)
loop {
let before = retry_payloads.len();
retry_payloads.retain(|mls_payload| {
match member.receive_message(mls_payload) {
Ok(Some(pt)) => {
let s = String::from_utf8_lossy(&pt);
println!("\r\n[peer] {s}\n> ");
let _ = std::io::stdout().flush();
false
}
Ok(None) => false,
Err(_) => true,
}
});
if retry_payloads.len() == before { break; }
}
if !payloads.is_empty() {
save_state(state_path, &member, hybrid_kp.as_ref(), password)?;
}

View File

@@ -1,13 +1,7 @@
pub fn encode(bytes: impl AsRef<[u8]>) -> String {
bytes.as_ref().iter().map(|b| format!("{b:02x}")).collect()
hex::encode(bytes)
}
pub fn decode(s: &str) -> Result<Vec<u8>, &'static str> {
if s.len() % 2 != 0 {
return Err("odd-length hex string");
}
(0..s.len())
.step_by(2)
.map(|i| u8::from_str_radix(&s[i..i + 2], 16).map_err(|_| "invalid hex character"))
.collect()
hex::decode(s).map_err(|_| "invalid hex string")
}

View File

@@ -48,7 +48,12 @@ where
}
}
}
Err(last_err.expect("retry_async: last_err set when we break after Err"))
match last_err {
Some(e) => Err(e),
None => unreachable!(
"retry_async: last_err is always Some when loop exits after an Err"
),
}
}
/// Classifies `anyhow::Error` for retry: returns `false` for auth or invalid-param

View File

@@ -17,6 +17,9 @@ use crate::AUTH_CONTEXT;
use super::retry::{anyhow_is_retriable, retry_async, DEFAULT_BASE_DELAY_MS, DEFAULT_MAX_RETRIES};
/// Cap'n Proto traversal limit (words). 4 Mi words = 32 MiB; bounds DoS from deeply nested or large messages.
const CAPNP_TRAVERSAL_LIMIT_WORDS: usize = 4 * 1024 * 1024;
/// Establish a QUIC/TLS connection and return a `NodeService` client.
///
/// Must be called from within a `LocalSet` because capnp-rpc is `!Send`.
@@ -55,11 +58,13 @@ pub async fn connect_node(
let (send, recv) = connection.open_bi().await.context("open bi stream")?;
let mut reader_opts = capnp::message::ReaderOptions::new();
reader_opts.traversal_limit_in_words(Some(CAPNP_TRAVERSAL_LIMIT_WORDS));
let network = twoparty::VatNetwork::new(
recv.compat(),
send.compat_write(),
Side::Client,
Default::default(),
reader_opts,
);
let mut rpc_system = RpcSystem::new(Box::new(network), None);
@@ -72,7 +77,9 @@ pub async fn connect_node(
pub fn set_auth(auth: &mut auth::Builder<'_>) -> anyhow::Result<()> {
let ctx = AUTH_CONTEXT.get().ok_or_else(|| {
anyhow::anyhow!("init_auth must be called with a non-empty token before RPCs")
anyhow::anyhow!(
"init_auth must be called before RPCs (use a bearer or session token for authenticated commands)"
)
})?;
auth.set_version(ctx.version);
auth.set_access_token(&ctx.access_token);
@@ -355,7 +362,216 @@ pub fn try_hybrid_decrypt(
payload: &[u8],
) -> anyhow::Result<Vec<u8>> {
let kp = hybrid_kp.ok_or_else(|| anyhow::anyhow!("hybrid key required for decryption"))?;
quicnprotochat_core::hybrid_decrypt(kp, payload).map_err(|e| anyhow::anyhow!("{e}"))
quicnprotochat_core::hybrid_decrypt(kp, payload, b"", b"").map_err(|e| anyhow::anyhow!("{e}"))
}
/// Peek at queued payloads without removing them.
/// Returns `(seq, payload)` pairs sorted by seq.
/// Retries on transient failures with exponential backoff.
pub async fn peek(
client: &node_service::Client,
recipient_key: &[u8],
) -> anyhow::Result<Vec<(u64, Vec<u8>)>> {
let client = client.clone();
let recipient_key = recipient_key.to_vec();
retry_async(
|| {
let client = client.clone();
let recipient_key = recipient_key.clone();
async move {
let mut req = client.peek_request();
{
let mut p = req.get();
p.set_recipient_key(&recipient_key);
p.set_channel_id(&[]);
p.set_version(1);
p.set_limit(0); // peek all
let mut auth = p.reborrow().init_auth();
set_auth(&mut auth)?;
}
let resp = req.send().promise.await.context("peek RPC failed")?;
let list = resp
.get()
.context("peek: bad response")?
.get_payloads()
.context("peek: missing payloads")?;
let mut payloads = Vec::with_capacity(list.len() as usize);
for i in 0..list.len() {
let entry = list.get(i);
let seq = entry.get_seq();
let data = entry
.get_data()
.context("peek: envelope data read failed")?
.to_vec();
payloads.push((seq, data));
}
Ok(payloads)
}
},
DEFAULT_MAX_RETRIES,
DEFAULT_BASE_DELAY_MS,
anyhow_is_retriable,
)
.await
}
/// Acknowledge all messages up to and including `seq_up_to`.
/// Retries on transient failures with exponential backoff.
pub async fn ack(
client: &node_service::Client,
recipient_key: &[u8],
seq_up_to: u64,
) -> anyhow::Result<()> {
let client = client.clone();
let recipient_key = recipient_key.to_vec();
retry_async(
|| {
let client = client.clone();
let recipient_key = recipient_key.clone();
async move {
let mut req = client.ack_request();
{
let mut p = req.get();
p.set_recipient_key(&recipient_key);
p.set_channel_id(&[]);
p.set_version(1);
p.set_seq_up_to(seq_up_to);
let mut auth = p.reborrow().init_auth();
set_auth(&mut auth)?;
}
req.send().promise.await.context("ack RPC failed")?;
Ok(())
}
},
DEFAULT_MAX_RETRIES,
DEFAULT_BASE_DELAY_MS,
anyhow_is_retriable,
)
.await
}
/// Fetch multiple peers' hybrid keys in a single round-trip.
/// Returns `None` for peers who have not uploaded a hybrid key.
/// Retries on transient failures with exponential backoff.
pub async fn fetch_hybrid_keys(
client: &node_service::Client,
identity_keys: &[&[u8]],
) -> anyhow::Result<Vec<Option<HybridPublicKey>>> {
let client = client.clone();
let identity_keys: Vec<Vec<u8>> = identity_keys.iter().map(|k| k.to_vec()).collect();
retry_async(
|| {
let client = client.clone();
let identity_keys = identity_keys.clone();
async move {
let mut req = client.fetch_hybrid_keys_request();
{
let mut p = req.get();
let mut list = p.reborrow().init_identity_keys(identity_keys.len() as u32);
for (i, ik) in identity_keys.iter().enumerate() {
list.set(i as u32, ik);
}
let mut auth = p.reborrow().init_auth();
set_auth(&mut auth)?;
}
let resp = req
.send()
.promise
.await
.context("fetch_hybrid_keys RPC failed")?;
let keys = resp
.get()
.context("fetch_hybrid_keys: bad response")?
.get_keys()
.context("fetch_hybrid_keys: missing keys")?;
let mut result = Vec::with_capacity(keys.len() as usize);
for i in 0..keys.len() {
let pk_bytes = keys
.get(i)
.context("fetch_hybrid_keys: key read failed")?
.to_vec();
if pk_bytes.is_empty() {
result.push(None);
} else {
let pk = HybridPublicKey::from_bytes(&pk_bytes)
.context("invalid hybrid public key")?;
result.push(Some(pk));
}
}
Ok(result)
}
},
DEFAULT_MAX_RETRIES,
DEFAULT_BASE_DELAY_MS,
anyhow_is_retriable,
)
.await
}
/// Enqueue the same payload to multiple recipients in a single round-trip.
/// Returns per-recipient sequence numbers.
/// Retries on transient failures with exponential backoff.
pub async fn batch_enqueue(
client: &node_service::Client,
recipient_keys: &[&[u8]],
payload: &[u8],
) -> anyhow::Result<Vec<u64>> {
let client = client.clone();
let recipient_keys: Vec<Vec<u8>> = recipient_keys.iter().map(|k| k.to_vec()).collect();
let payload = payload.to_vec();
retry_async(
|| {
let client = client.clone();
let recipient_keys = recipient_keys.clone();
let payload = payload.clone();
async move {
let mut req = client.batch_enqueue_request();
{
let mut p = req.get();
let mut list = p.reborrow().init_recipient_keys(recipient_keys.len() as u32);
for (i, rk) in recipient_keys.iter().enumerate() {
list.set(i as u32, rk);
}
p.set_payload(&payload);
p.set_channel_id(&[]);
p.set_version(1);
let mut auth = p.reborrow().init_auth();
set_auth(&mut auth)?;
}
let resp = req
.send()
.promise
.await
.context("batch_enqueue RPC failed")?;
let seqs = resp
.get()
.context("batch_enqueue: bad response")?
.get_seqs()
.context("batch_enqueue: missing seqs")?;
let mut result = Vec::with_capacity(seqs.len() as usize);
for i in 0..seqs.len() {
result.push(seqs.get(i));
}
Ok(result)
}
},
DEFAULT_MAX_RETRIES,
DEFAULT_BASE_DELAY_MS,
anyhow_is_retriable,
)
.await
}
/// Return the current Unix timestamp in milliseconds.

View File

@@ -2,7 +2,7 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::Context;
use argon2::Argon2;
use argon2::{Algorithm, Argon2, Params, Version};
use chacha20poly1305::{
aead::{Aead, KeyInit},
ChaCha20Poly1305, Key, Nonce,
@@ -62,10 +62,21 @@ impl StoredState {
}
}
/// Derive a 32-byte key from a password and salt using Argon2id.
/// Argon2id parameters for client state key derivation (auditable; matches argon2 crate defaults).
/// - Memory: 19 MiB (m_cost = 19*1024 KiB)
/// - Time: 2 iterations
/// - Parallelism: 1 lane
const ARGON2_STATE_M_COST: u32 = 19 * 1024;
const ARGON2_STATE_T_COST: u32 = 2;
const ARGON2_STATE_P_COST: u32 = 1;
/// Derive a 32-byte key from a password and salt using Argon2id with explicit parameters.
fn derive_state_key(password: &str, salt: &[u8]) -> anyhow::Result<[u8; 32]> {
let params = Params::new(ARGON2_STATE_M_COST, ARGON2_STATE_T_COST, ARGON2_STATE_P_COST, Some(32))
.map_err(|e| anyhow::anyhow!("argon2 params: {e}"))?;
let argon2 = Argon2::new(Algorithm::Argon2id, Version::default(), params);
let mut key = [0u8; 32];
Argon2::default()
argon2
.hash_password_into(password.as_bytes(), salt, &mut key)
.map_err(|e| anyhow::anyhow!("argon2 key derivation failed: {e}"))?;
Ok(key)
@@ -79,8 +90,8 @@ pub fn encrypt_state(password: &str, plaintext: &[u8]) -> anyhow::Result<Vec<u8>
let mut nonce_bytes = [0u8; STATE_NONCE_LEN];
rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
let key = derive_state_key(password, &salt)?;
let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
let key = zeroize::Zeroizing::new(derive_state_key(password, &salt)?);
let cipher = ChaCha20Poly1305::new(Key::from_slice(&*key));
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
@@ -108,8 +119,8 @@ pub fn decrypt_state(password: &str, data: &[u8]) -> anyhow::Result<Vec<u8>> {
let nonce_bytes = &data[4 + STATE_SALT_LEN..header_len];
let ciphertext = &data[header_len..];
let key = derive_state_key(password, salt)?;
let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
let key = zeroize::Zeroizing::new(derive_state_key(password, salt)?);
let cipher = ChaCha20Poly1305::new(Key::from_slice(&*key));
let nonce = Nonce::from_slice(nonce_bytes);
let plaintext = cipher
@@ -179,7 +190,9 @@ pub fn write_state(path: &Path, state: &StoredState, password: Option<&str>) ->
plaintext
};
std::fs::write(path, bytes).with_context(|| format!("write state {path:?}"))?;
let tmp = path.with_extension("tmp");
std::fs::write(&tmp, bytes).with_context(|| format!("write state temp {tmp:?}"))?;
std::fs::rename(&tmp, path).with_context(|| format!("rename state {tmp:?} -> {path:?}"))?;
Ok(())
}
@@ -222,4 +235,57 @@ mod tests {
let encrypted = encrypt_state("correct", plaintext).unwrap();
assert!(decrypt_state("wrong", &encrypted).is_err());
}
#[test]
fn state_encrypt_decrypt_round_trip() {
let state = StoredState {
identity_seed: [42u8; 32],
hybrid_key: None,
group: None,
member_keys: Vec::new(),
};
let password = "test-password";
let plaintext = bincode::serialize(&state).unwrap();
let encrypted = encrypt_state(password, &plaintext).unwrap();
let decrypted = decrypt_state(password, &encrypted).unwrap();
let recovered: StoredState = bincode::deserialize(&decrypted).unwrap();
assert_eq!(recovered.identity_seed, state.identity_seed);
assert!(recovered.hybrid_key.is_none());
assert!(recovered.group.is_none());
}
#[test]
fn state_encrypt_decrypt_with_hybrid_key() {
use zeroize::Zeroizing;
let state = StoredState {
identity_seed: [7u8; 32],
hybrid_key: Some(HybridKeypairBytes {
x25519_sk: Zeroizing::new([1u8; 32]),
mlkem_dk: Zeroizing::new(vec![3u8; 2400]),
mlkem_ek: vec![4u8; 1184],
}),
group: None,
member_keys: Vec::new(),
};
let password = "another-password";
let plaintext = bincode::serialize(&state).unwrap();
let encrypted = encrypt_state(password, &plaintext).unwrap();
let decrypted = decrypt_state(password, &encrypted).unwrap();
let recovered: StoredState = bincode::deserialize(&decrypted).unwrap();
assert_eq!(recovered.identity_seed, state.identity_seed);
assert!(recovered.hybrid_key.is_some());
}
#[test]
fn state_wrong_password_fails() {
let state = StoredState {
identity_seed: [99u8; 32],
hybrid_key: None,
group: None,
member_keys: Vec::new(),
};
let plaintext = bincode::serialize(&state).unwrap();
let encrypted = encrypt_state("correct", &plaintext).unwrap();
assert!(decrypt_state("wrong", &encrypted).is_err());
}
}

View File

@@ -91,22 +91,28 @@ pub fn serialize(msg_type: MessageType, payload: &[u8]) -> Vec<u8> {
}
/// Serialize a Chat message (generates message_id internally; pass None to generate, or Some(id) when replying with a known id).
pub fn serialize_chat(body: &[u8], message_id: Option<[u8; 16]>) -> Vec<u8> {
pub fn serialize_chat(body: &[u8], message_id: Option<[u8; 16]>) -> Result<Vec<u8>, CoreError> {
if body.len() > u16::MAX as usize {
return Err(CoreError::AppMessage("chat body exceeds maximum length (65535 bytes)".into()));
}
let id = message_id.unwrap_or_else(generate_message_id);
let mut payload = Vec::with_capacity(16 + 2 + body.len());
payload.extend_from_slice(&id);
payload.extend_from_slice(&(body.len() as u16).to_be_bytes());
payload.extend_from_slice(body);
serialize(MessageType::Chat, &payload)
Ok(serialize(MessageType::Chat, &payload))
}
/// Serialize a Reply message.
pub fn serialize_reply(ref_msg_id: [u8; 16], body: &[u8]) -> Vec<u8> {
pub fn serialize_reply(ref_msg_id: [u8; 16], body: &[u8]) -> Result<Vec<u8>, CoreError> {
if body.len() > u16::MAX as usize {
return Err(CoreError::AppMessage("reply body exceeds maximum length (65535 bytes)".into()));
}
let mut payload = Vec::with_capacity(16 + 2 + body.len());
payload.extend_from_slice(&ref_msg_id);
payload.extend_from_slice(&(body.len() as u16).to_be_bytes());
payload.extend_from_slice(body);
serialize(MessageType::Reply, &payload)
Ok(serialize(MessageType::Reply, &payload))
}
/// Serialize a Reaction message.
@@ -220,7 +226,7 @@ mod tests {
#[test]
fn roundtrip_chat() {
let body = b"hello";
let encoded = serialize_chat(body, None);
let encoded = serialize_chat(body, None).unwrap();
let (t, msg) = parse(&encoded).unwrap();
assert_eq!(t, MessageType::Chat);
match &msg {
@@ -233,7 +239,7 @@ mod tests {
fn roundtrip_reply() {
let ref_id = [1u8; 16];
let body = b"reply text";
let encoded = serialize_reply(ref_id, body);
let encoded = serialize_reply(ref_id, body).unwrap();
let (t, msg) = parse(&encoded).unwrap();
assert_eq!(t, MessageType::Reply);
match &msg {
@@ -255,4 +261,67 @@ mod tests {
_ => panic!("expected Typing"),
}
}
#[test]
fn roundtrip_reaction() {
let ref_id = [2u8; 16];
let emoji = "\u{1f44d}".as_bytes();
let encoded = serialize_reaction(ref_id, emoji).unwrap();
let (t, msg) = parse(&encoded).unwrap();
assert_eq!(t, MessageType::Reaction);
match &msg {
AppMessage::Reaction { ref_msg_id, emoji: e } => {
assert_eq!(ref_msg_id, &ref_id);
assert_eq!(e.as_slice(), emoji);
}
_ => panic!("expected Reaction"),
}
}
#[test]
fn roundtrip_read_receipt() {
let msg_id = [3u8; 16];
let encoded = serialize_read_receipt(msg_id);
let (t, msg) = parse(&encoded).unwrap();
assert_eq!(t, MessageType::ReadReceipt);
match &msg {
AppMessage::ReadReceipt { msg_id: id } => assert_eq!(id, &msg_id),
_ => panic!("expected ReadReceipt"),
}
}
#[test]
fn parse_empty_fails() {
assert!(parse(&[]).is_err());
}
#[test]
fn parse_bad_version_fails() {
assert!(parse(&[99, 0x01]).is_err());
}
#[test]
fn parse_bad_type_fails() {
assert!(parse(&[1, 0xFF]).is_err());
}
#[test]
fn chat_body_too_long() {
let body = vec![0u8; 65536]; // exceeds u16::MAX
assert!(serialize_chat(&body, None).is_err());
}
#[test]
fn reaction_emoji_too_long() {
let emoji = vec![0u8; 256];
assert!(serialize_reaction([0; 16], &emoji).is_err());
}
#[test]
fn parse_truncated_chat_payload() {
// Version + type + only 10 bytes of payload (needs 18 minimum for chat)
let mut data = vec![1, 0x01];
data.extend_from_slice(&[0u8; 10]);
assert!(parse(&data).is_err());
}
}

View File

@@ -161,9 +161,15 @@ impl OpenMlsCrypto for HybridCrypto {
if Self::is_hybrid_public_key(pk_r) {
let recipient_pk = match HybridPublicKey::from_bytes(pk_r) {
Ok(pk) => pk,
Err(_) => return self.rust_crypto.hpke_seal(config, pk_r, info, aad, ptxt),
// Key parsed as hybrid length but failed to deserialize — this is
// a real error, not a reason to silently fall back to classical HPKE.
Err(_) => return HpkeCiphertext {
kem_output: Vec::new().into(),
ciphertext: Vec::new().into(),
},
};
match hybrid_encrypt(&recipient_pk, ptxt) {
// Pass HPKE info and aad through for proper context binding (RFC 9180).
match hybrid_encrypt(&recipient_pk, ptxt, info, aad) {
Ok(envelope) => {
let kem_output = envelope[..HYBRID_KEM_OUTPUT_LEN].to_vec();
let ciphertext = envelope[HYBRID_KEM_OUTPUT_LEN..].to_vec();
@@ -172,7 +178,13 @@ impl OpenMlsCrypto for HybridCrypto {
ciphertext: ciphertext.into(),
}
}
Err(_) => self.rust_crypto.hpke_seal(config, pk_r, info, aad, ptxt),
// Encryption failed with a hybrid key — return empty ciphertext
// rather than silently falling back to classical HPKE with an
// incompatible key.
Err(_) => HpkeCiphertext {
kem_output: Vec::new().into(),
ciphertext: Vec::new().into(),
},
}
} else {
self.rust_crypto.hpke_seal(config, pk_r, info, aad, ptxt)
@@ -188,17 +200,17 @@ impl OpenMlsCrypto for HybridCrypto {
aad: &[u8],
) -> Result<Vec<u8>, CryptoError> {
if Self::is_hybrid_private_key(sk_r) {
let keypair = match HybridKeypair::from_private_bytes(sk_r) {
Ok(kp) => kp,
Err(_) => return self.rust_crypto.hpke_open(config, input, sk_r, info, aad),
};
let keypair = HybridKeypair::from_private_bytes(sk_r)
.map_err(|_| CryptoError::HpkeDecryptionError)?;
let envelope: Vec<u8> = input
.kem_output.as_slice()
.iter()
.chain(input.ciphertext.as_slice())
.copied()
.collect();
hybrid_decrypt(&keypair, &envelope).map_err(|_| CryptoError::HpkeDecryptionError)
// Pass HPKE info and aad through for proper context binding (RFC 9180).
hybrid_decrypt(&keypair, &envelope, info, aad)
.map_err(|_| CryptoError::HpkeDecryptionError)
} else {
self.rust_crypto.hpke_open(config, input, sk_r, info, aad)
}

View File

@@ -43,6 +43,9 @@ const HYBRID_VERSION: u8 = 0x01;
/// HKDF info string for domain separation.
const HKDF_INFO: &[u8] = b"quicnprotochat-hybrid-v1";
/// HKDF salt for domain separation (defence-in-depth; IKM already has 64 bytes of entropy).
const HKDF_SALT: &[u8] = b"quicnprotochat-hybrid-v1-salt";
/// ML-KEM-768 ciphertext size in bytes.
const MLKEM_CT_LEN: usize = 1088;
@@ -164,7 +167,8 @@ impl HybridKeypair {
if bytes.len() != HYBRID_PRIVATE_KEY_LEN {
return Err(HybridKemError::TooShort(bytes.len()));
}
let x25519_sk = StaticSecret::from(<[u8; 32]>::try_from(&bytes[0..32]).unwrap());
let x25519_sk = StaticSecret::from(<[u8; 32]>::try_from(&bytes[0..32])
.expect("slice is exactly 32 bytes (guaranteed by HYBRID_PRIVATE_KEY_LEN check)"));
let x25519_pk = X25519Public::from(&x25519_sk);
let mlkem_dk_arr = Array::try_from(&bytes[32..32 + MLKEM_DK_LEN])
@@ -247,10 +251,15 @@ impl HybridPublicKey {
/// Encrypt `plaintext` to `recipient_pk` using X25519 + ML-KEM-768 hybrid KEM.
///
/// `info` is optional HPKE context info incorporated into key derivation.
/// `aad` is optional additional authenticated data bound to the AEAD ciphertext.
///
/// Returns the complete hybrid envelope as a byte vector.
pub fn hybrid_encrypt(
recipient_pk: &HybridPublicKey,
plaintext: &[u8],
info: &[u8],
aad: &[u8],
) -> Result<Vec<u8>, HybridKemError> {
// 1. Ephemeral X25519 DH
let eph_secret = EphemeralSecret::random_from_rng(OsRng);
@@ -266,18 +275,19 @@ pub fn hybrid_encrypt(
.encapsulate(&mut OsRng)
.map_err(|_| HybridKemError::EncryptionFailed)?;
// 3. Derive AEAD key from combined shared secrets
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice());
// 3. Derive AEAD key from combined shared secrets (with caller info for context binding)
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), info);
// Generate a random 12-byte nonce (not derived from HKDF).
let mut nonce_bytes = [0u8; 12];
OsRng.fill_bytes(&mut nonce_bytes);
let aead_nonce = *Nonce::from_slice(&nonce_bytes);
// 4. AEAD encrypt
// 4. AEAD encrypt with caller-supplied AAD
let cipher = ChaCha20Poly1305::new(&aead_key);
let aead_payload = chacha20poly1305::aead::Payload { msg: plaintext, aad };
let ct = cipher
.encrypt(&aead_nonce, plaintext)
.encrypt(&aead_nonce, aead_payload)
.map_err(|_| HybridKemError::EncryptionFailed)?;
// 5. Assemble envelope: version || x25519_eph_pk || mlkem_ct || nonce || aead_ct
@@ -292,7 +302,14 @@ pub fn hybrid_encrypt(
}
/// Decrypt a hybrid envelope using the recipient's private key.
pub fn hybrid_decrypt(keypair: &HybridKeypair, envelope: &[u8]) -> Result<Vec<u8>, HybridKemError> {
///
/// `info` and `aad` must match what was passed to `hybrid_encrypt`.
pub fn hybrid_decrypt(
keypair: &HybridKeypair,
envelope: &[u8],
info: &[u8],
aad: &[u8],
) -> Result<Vec<u8>, HybridKemError> {
if envelope.len() < HEADER_LEN + 16 {
// 16 = minimum AEAD tag
return Err(HybridKemError::TooShort(envelope.len()));
@@ -334,13 +351,14 @@ pub fn hybrid_decrypt(keypair: &HybridKeypair, envelope: &[u8]) -> Result<Vec<u8
.decapsulate(&mlkem_ct_arr)
.map_err(|_| HybridKemError::MlKemDecapsFailed)?;
// 3. Derive AEAD key
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice());
// 3. Derive AEAD key (with caller info for context binding)
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), info);
// 4. Decrypt
// 4. Decrypt with caller-supplied AAD
let cipher = ChaCha20Poly1305::new(&aead_key);
let aead_payload = chacha20poly1305::aead::Payload { msg: aead_ct, aad };
let plaintext = cipher
.decrypt(nonce, aead_ct)
.decrypt(nonce, aead_payload)
.map_err(|_| HybridKemError::DecryptionFailed)?;
Ok(plaintext)
@@ -366,8 +384,9 @@ pub fn hybrid_encapsulate_only(
.encapsulate(&mut OsRng)
.map_err(|_| HybridKemError::EncryptionFailed)?;
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice());
let shared_secret = aead_key.as_slice().try_into().unwrap();
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), b"");
let shared_secret: [u8; 32] = aead_key.as_slice().try_into()
.expect("AEAD key is always exactly 32 bytes");
let mut kem_output = Vec::with_capacity(HYBRID_KEM_OUTPUT_LEN);
kem_output.push(HYBRID_VERSION);
@@ -390,7 +409,8 @@ pub fn hybrid_decapsulate_only(
return Err(HybridKemError::UnsupportedVersion(kem_output[0]));
}
let eph_pk_bytes: [u8; 32] = kem_output[1..33].try_into().unwrap();
let eph_pk_bytes: [u8; 32] = kem_output[1..33].try_into()
.expect("slice is exactly 32 bytes (guaranteed by HYBRID_KEM_OUTPUT_LEN check)");
let eph_pk = X25519Public::from(eph_pk_bytes);
let x25519_ss = keypair.x25519_sk.diffie_hellman(&eph_pk);
@@ -401,8 +421,9 @@ pub fn hybrid_decapsulate_only(
.decapsulate(&mlkem_ct_arr)
.map_err(|_| HybridKemError::MlKemDecapsFailed)?;
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice());
Ok(aead_key.as_slice().try_into().unwrap())
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), b"");
Ok(aead_key.as_slice().try_into()
.expect("AEAD key is always exactly 32 bytes"))
}
/// Export a secret from shared secret (MLS HPKE exporter compatibility).
@@ -412,7 +433,7 @@ pub fn hybrid_export(
exporter_context: &[u8],
length: usize,
) -> Vec<u8> {
let hk = Hkdf::<Sha256>::new(None, shared_secret);
let hk = Hkdf::<Sha256>::new(Some(HKDF_SALT), shared_secret);
let mut out = vec![0u8; length];
hk.expand(exporter_context, &mut out).expect("valid length");
out
@@ -420,18 +441,26 @@ pub fn hybrid_export(
/// Derive AEAD key from the combined X25519 + ML-KEM shared secrets.
///
/// `extra_info` is optional caller-supplied context (e.g. HPKE `info`) that is
/// appended to the domain-separation label for additional binding.
///
/// The nonce is generated randomly per-encryption rather than derived from
/// HKDF, preventing nonce reuse when the same shared secret is (accidentally)
/// used more than once.
fn derive_aead_key(x25519_ss: &[u8], mlkem_ss: &[u8]) -> Key {
fn derive_aead_key(x25519_ss: &[u8], mlkem_ss: &[u8], extra_info: &[u8]) -> Key {
let mut ikm = Zeroizing::new(vec![0u8; x25519_ss.len() + mlkem_ss.len()]);
ikm[..x25519_ss.len()].copy_from_slice(x25519_ss);
ikm[x25519_ss.len()..].copy_from_slice(mlkem_ss);
let hk = Hkdf::<Sha256>::new(None, &ikm);
let hk = Hkdf::<Sha256>::new(Some(HKDF_SALT), &ikm);
// Combine domain-separation label with caller-supplied context.
let mut info = Vec::with_capacity(HKDF_INFO.len() + extra_info.len());
info.extend_from_slice(HKDF_INFO);
info.extend_from_slice(extra_info);
let mut key_bytes = Zeroizing::new([0u8; 32]);
hk.expand(HKDF_INFO, &mut *key_bytes)
hk.expand(&info, &mut *key_bytes)
.expect("32 bytes is valid HKDF-SHA256 output length");
*Key::from_slice(&*key_bytes)
@@ -457,21 +486,39 @@ mod tests {
let pk = kp.public_key();
let plaintext = b"hello post-quantum world!";
let envelope = hybrid_encrypt(&pk, plaintext).unwrap();
let recovered = hybrid_decrypt(&kp, &envelope).unwrap();
let envelope = hybrid_encrypt(&pk, plaintext, b"", b"").unwrap();
let recovered = hybrid_decrypt(&kp, &envelope, b"", b"").unwrap();
assert_eq!(recovered, plaintext);
}
#[test]
fn encrypt_decrypt_with_info_aad() {
let kp = HybridKeypair::generate();
let pk = kp.public_key();
let plaintext = b"context-bound payload";
let info = b"mls epoch 42";
let aad = b"group-id-abc";
let envelope = hybrid_encrypt(&pk, plaintext, info, aad).unwrap();
let recovered = hybrid_decrypt(&kp, &envelope, info, aad).unwrap();
assert_eq!(recovered, plaintext);
// Mismatched info must fail
assert!(hybrid_decrypt(&kp, &envelope, b"wrong info", aad).is_err());
// Mismatched aad must fail
assert!(hybrid_decrypt(&kp, &envelope, info, b"wrong aad").is_err());
}
#[test]
fn wrong_key_decryption_fails() {
let kp_sender_target = HybridKeypair::generate();
let kp_wrong = HybridKeypair::generate();
let pk = kp_sender_target.public_key();
let envelope = hybrid_encrypt(&pk, b"secret").unwrap();
let envelope = hybrid_encrypt(&pk, b"secret", b"", b"").unwrap();
let result = hybrid_decrypt(&kp_wrong, &envelope);
let result = hybrid_decrypt(&kp_wrong, &envelope, b"", b"");
assert!(result.is_err());
}
@@ -480,12 +527,12 @@ mod tests {
let kp = HybridKeypair::generate();
let pk = kp.public_key();
let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
let mut envelope = hybrid_encrypt(&pk, b"payload", b"", b"").unwrap();
let last = envelope.len() - 1;
envelope[last] ^= 0x01;
assert!(matches!(
hybrid_decrypt(&kp, &envelope),
hybrid_decrypt(&kp, &envelope, b"", b""),
Err(HybridKemError::DecryptionFailed)
));
}
@@ -495,11 +542,11 @@ mod tests {
let kp = HybridKeypair::generate();
let pk = kp.public_key();
let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
let mut envelope = hybrid_encrypt(&pk, b"payload", b"", b"").unwrap();
// Flip a byte in the ML-KEM ciphertext region (starts at offset 33)
envelope[40] ^= 0xFF;
assert!(hybrid_decrypt(&kp, &envelope).is_err());
assert!(hybrid_decrypt(&kp, &envelope, b"", b"").is_err());
}
#[test]
@@ -507,11 +554,11 @@ mod tests {
let kp = HybridKeypair::generate();
let pk = kp.public_key();
let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
let mut envelope = hybrid_encrypt(&pk, b"payload", b"", b"").unwrap();
// Flip a byte in the X25519 ephemeral pk region (offset 1..33)
envelope[5] ^= 0xFF;
assert!(hybrid_decrypt(&kp, &envelope).is_err());
assert!(hybrid_decrypt(&kp, &envelope, b"", b"").is_err());
}
#[test]
@@ -519,11 +566,11 @@ mod tests {
let kp = HybridKeypair::generate();
let pk = kp.public_key();
let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
let mut envelope = hybrid_encrypt(&pk, b"payload", b"", b"").unwrap();
envelope[0] = 0xFF;
assert!(matches!(
hybrid_decrypt(&kp, &envelope),
hybrid_decrypt(&kp, &envelope, b"", b""),
Err(HybridKemError::UnsupportedVersion(0xFF))
));
}
@@ -532,7 +579,7 @@ mod tests {
fn envelope_too_short_rejected() {
let kp = HybridKeypair::generate();
assert!(matches!(
hybrid_decrypt(&kp, &[0x01; 10]),
hybrid_decrypt(&kp, &[0x01; 10], b"", b""),
Err(HybridKemError::TooShort(10))
));
}
@@ -548,8 +595,8 @@ mod tests {
// Verify restored keypair can decrypt
let pk = kp.public_key();
let ct = hybrid_encrypt(&pk, b"test").unwrap();
let pt = hybrid_decrypt(&restored, &ct).unwrap();
let ct = hybrid_encrypt(&pk, b"test", b"", b"").unwrap();
let pt = hybrid_decrypt(&restored, &ct, b"", b"").unwrap();
assert_eq!(pt, b"test");
}
@@ -570,8 +617,8 @@ mod tests {
let pk = kp.public_key();
let plaintext = vec![0xAB; 50_000]; // 50 KB
let envelope = hybrid_encrypt(&pk, &plaintext).unwrap();
let recovered = hybrid_decrypt(&kp, &envelope).unwrap();
let envelope = hybrid_encrypt(&pk, &plaintext, b"", b"").unwrap();
let recovered = hybrid_decrypt(&kp, &envelope, b"", b"").unwrap();
assert_eq!(recovered, plaintext);
}

View File

@@ -62,7 +62,7 @@ impl DiskKeyStore {
let Some(path) = &self.path else {
return Ok(());
};
let values = self.values.read().unwrap();
let values = self.values.read().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
let bytes = bincode::serialize(&*values).map_err(|_| DiskKeyStoreError::Serialization)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?;
@@ -82,21 +82,24 @@ impl OpenMlsKeyStore for DiskKeyStore {
fn store<V: MlsEntity>(&self, k: &[u8], v: &V) -> Result<(), Self::Error> {
let value = serde_json::to_vec(v).map_err(|_| DiskKeyStoreError::Serialization)?;
let mut values = self.values.write().unwrap();
let mut values = self.values.write().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
values.insert(k.to_vec(), value);
drop(values);
self.flush()
}
fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V> {
let values = self.values.read().unwrap();
let values = match self.values.read() {
Ok(v) => v,
Err(_) => return None,
};
values
.get(k)
.and_then(|bytes| serde_json::from_slice(bytes).ok())
}
fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
let mut values = self.values.write().unwrap();
let mut values = self.values.write().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
values.remove(k);
drop(values);
self.flush()

View File

@@ -16,8 +16,8 @@
mod app_message;
mod error;
mod group;
pub mod hybrid_crypto;
pub mod hybrid_kem;
mod hybrid_crypto;
mod hybrid_kem;
mod identity;
mod keypackage;
mod keystore;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,13 @@
-- Migration 003: 1:1 DM channels.
-- channel_id is 16 bytes (UUID); member_a and member_b are identity keys in sorted order.
-- Unique on (member_a, member_b) prevents duplicate channels between the same pair.
CREATE TABLE IF NOT EXISTS channels (
channel_id BLOB PRIMARY KEY,
member_a BLOB NOT NULL,
member_b BLOB NOT NULL,
UNIQUE(member_a, member_b)
);
CREATE INDEX IF NOT EXISTS idx_channels_members
ON channels(member_a, member_b);

View File

@@ -17,6 +17,7 @@ pub const RATE_LIMIT_MAX_ENQUEUES: u32 = 100;
pub struct AuthConfig {
pub required_token: Option<Vec<u8>>,
/// When true, a valid bearer token (no session) is accepted and the request's identity/key is used (dev/e2e only).
/// CLI flag: --allow-insecure-auth / QUICNPROTOCHAT_ALLOW_INSECURE_AUTH.
pub allow_insecure_identity_from_request: bool,
}
@@ -59,10 +60,13 @@ pub struct AuthContext {
}
pub fn current_timestamp() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
Ok(d) => d.as_secs(),
Err(_) => {
tracing::warn!("system time is before UNIX_EPOCH; using 0 for session/rate-limit timestamps");
0
}
}
}
pub fn check_rate_limit(
@@ -174,7 +178,7 @@ pub fn require_identity<'a>(auth_ctx: &'a AuthContext) -> Result<&'a [u8], capnp
pub fn require_identity_match(auth_ctx: &AuthContext, expected: &[u8]) -> Result<(), capnp::Error> {
let ik = require_identity(auth_ctx)?;
if ik != expected {
if ik.len() != expected.len() || !bool::from(ik.ct_eq(expected)) {
return Err(crate::error_codes::coded_error(
E016_IDENTITY_MISMATCH,
"access token is bound to a different identity",

View File

@@ -24,6 +24,8 @@ pub const E018_USER_EXISTS: &str = "E018";
pub const E019_NO_PENDING_LOGIN: &str = "E019";
pub const E020_BAD_PARAMS: &str = "E020";
pub const E021_CIPHERSUITE_NOT_ALLOWED: &str = "E021";
pub const E022_CHANNEL_ACCESS_DENIED: &str = "E022";
pub const E023_CHANNEL_NOT_FOUND: &str = "E023";
/// Build a `capnp::Error::failed()` with the structured code prefix.
pub fn coded_error(code: &str, msg: impl std::fmt::Display) -> capnp::Error {

View File

@@ -162,9 +162,16 @@ async fn main() -> anyhow::Result<()> {
.parse()
.context("--listen must be host:port")?;
let server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
let mut server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
.context("failed to build TLS/QUIC server config")?;
// Harden QUIC transport: idle timeout, limit stream concurrency.
let mut transport = quinn::TransportConfig::default();
transport.max_idle_timeout(Some(std::time::Duration::from_secs(300).try_into().unwrap()));
transport.max_concurrent_bidi_streams(1u32.into());
transport.max_concurrent_uni_streams(0u32.into());
server_config.transport_config(Arc::new(transport));
// Shared storage — persisted to disk for restart safety.
let store: Arc<dyn Store> = match effective.store_backend.as_str() {
"sql" => {
@@ -223,6 +230,7 @@ async fn main() -> anyhow::Result<()> {
Arc::clone(&pending_logins),
Arc::clone(&rate_limits),
Arc::clone(&store),
Arc::clone(&waiters),
);
let endpoint = Endpoint::server(server_config, listen)?;

View File

@@ -19,6 +19,16 @@ fn storage_err(err: StorageError) -> capnp::Error {
coded_error(E009_STORAGE_ERROR, err)
}
/// Parse username from Cap'n Proto reader; requires valid UTF-8.
fn parse_username_param(
result: Result<capnp::text::Reader<'_>, capnp::Error>,
) -> Result<String, capnp::Error> {
let reader = result.map_err(|e| coded_error(E020_BAD_PARAMS, e))?;
reader
.to_string()
.map_err(|_| coded_error(E020_BAD_PARAMS, "username must be valid UTF-8"))
}
impl NodeServiceImpl {
pub fn handle_opaque_login_start(
&mut self,
@@ -29,9 +39,9 @@ impl NodeServiceImpl {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match p.get_username() {
Ok(v) => v.to_string().unwrap_or_default().to_string(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let request_bytes = match p.get_request() {
Ok(v) => v.to_vec(),
@@ -42,6 +52,14 @@ impl NodeServiceImpl {
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
}
// Check for existing recent pending login before expensive OPAQUE/storage work (DoS mitigation).
if let Some(existing) = self.pending_logins.get(&username) {
let age = current_timestamp().saturating_sub(existing.created_at);
if age < 60 {
return Promise::err(coded_error(E010_OPAQUE_ERROR, "login already in progress"));
}
}
let credential_request = match CredentialRequest::<OpaqueSuite>::deserialize(&request_bytes) {
Ok(r) => r,
Err(e) => {
@@ -62,9 +80,7 @@ impl NodeServiceImpl {
))
}
},
Ok(None) => {
return Promise::err(coded_error(E010_OPAQUE_ERROR, "user not registered"))
}
Ok(None) => None,
Err(e) => return Promise::err(storage_err(e)),
};
@@ -111,9 +127,9 @@ impl NodeServiceImpl {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match p.get_username() {
Ok(v) => v.to_string().unwrap_or_default().to_string(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let request_bytes = match p.get_request() {
Ok(v) => v.to_vec(),
@@ -171,9 +187,9 @@ impl NodeServiceImpl {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match p.get_username() {
Ok(v) => v.to_string().unwrap_or_default().to_string(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let finalization_bytes = match p.get_finalization() {
Ok(v) => v.to_vec(),
@@ -278,9 +294,9 @@ impl NodeServiceImpl {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let username = match p.get_username() {
Ok(v) => v.to_string().unwrap_or_default().to_string(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
let username = match parse_username_param(p.get_username()) {
Ok(s) => s,
Err(e) => return Promise::err(e),
};
let upload_bytes = match p.get_upload() {
Ok(v) => v.to_vec(),
@@ -326,12 +342,18 @@ impl NodeServiceImpl {
let password_file = ServerRegistration::<OpaqueSuite>::finish(upload);
let record_bytes = password_file.serialize().to_vec();
if let Err(e) = self
match self
.store
.store_user_record(&username, record_bytes)
.map_err(storage_err)
{
return Promise::err(e);
Ok(()) => {}
Err(crate::storage::StorageError::DuplicateUser(_)) => {
return Promise::err(coded_error(
E018_USER_EXISTS,
format!("user '{}' already registered", username),
))
}
Err(e) => return Promise::err(storage_err(e)),
}
if !identity_key.is_empty() {

View File

@@ -0,0 +1,62 @@
//! createChannel RPC: create or look up a 1:1 DM channel.
use capnp::capability::Promise;
use quicnprotochat_proto::node_capnp::node_service;
use crate::auth::{coded_error, require_identity, validate_auth_context};
use crate::error_codes::*;
use crate::storage::StorageError;
use super::NodeServiceImpl;
fn storage_err(err: StorageError) -> capnp::Error {
coded_error(E009_STORAGE_ERROR, err)
}
impl NodeServiceImpl {
pub fn handle_create_channel(
&mut self,
params: node_service::CreateChannelParams,
mut results: node_service::CreateChannelResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let peer_key = match p.get_peer_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
let identity = match require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
if peer_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("peerKey must be exactly 32 bytes, got {}", peer_key.len()),
));
}
if identity == peer_key {
return Promise::err(coded_error(
E020_BAD_PARAMS,
"peerKey must not equal caller identity",
));
}
let channel_id = match self.store.create_channel(&identity, &peer_key) {
Ok(id) => id,
Err(e) => return Promise::err(storage_err(e)),
};
results.get().set_channel_id(&channel_id);
Promise::ok(())
}
}

View File

@@ -77,10 +77,10 @@ impl NodeServiceImpl {
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
));
}
if version != CURRENT_WIRE_VERSION {
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
@@ -102,6 +102,31 @@ impl NodeServiceImpl {
}
}
// DM channel authz: channel_id.len() == 16 means a created channel; caller and recipient must be the two members.
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
let recipient_other = (recipient_key == *a && caller == b.as_slice())
|| (recipient_key == *b && caller == a.as_slice());
if !caller_in || !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller or recipient not a member of this channel",
));
}
}
match self.store.queue_depth(&recipient_key, &channel_id) {
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
return Promise::err(coded_error(
@@ -183,10 +208,10 @@ impl NodeServiceImpl {
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version != CURRENT_WIRE_VERSION {
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
@@ -203,6 +228,30 @@ impl NodeServiceImpl {
return Promise::err(e);
}
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
if !caller_in || !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller or recipient not a member of this channel",
));
}
}
let messages = if limit > 0 {
match self
.store
@@ -269,10 +318,10 @@ impl NodeServiceImpl {
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version != CURRENT_WIRE_VERSION {
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
@@ -284,6 +333,30 @@ impl NodeServiceImpl {
return Promise::err(e);
}
if channel_id.len() == 16 {
let members = match self.store.get_channel_members(&channel_id) {
Ok(Some(m)) => m,
Ok(None) => {
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
}
Err(e) => return Promise::err(storage_err(e)),
};
let caller = match crate::auth::require_identity(&auth_ctx) {
Ok(id) => id,
Err(e) => return Promise::err(e),
};
let (a, b) = &members;
let caller_in = caller == a.as_slice() || caller == b.as_slice();
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
if !caller_in || !recipient_other {
return Promise::err(coded_error(
E022_CHANNEL_ACCESS_DENIED,
"caller or recipient not a member of this channel",
));
}
}
let store = Arc::clone(&self.store);
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = self.waiters.clone();
@@ -315,4 +388,232 @@ impl NodeServiceImpl {
Ok(())
})
}
pub fn handle_peek(
&mut self,
params: node_service::PeekParams,
mut results: node_service::PeekResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_key = match p.get_recipient_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let limit = p.get_limit();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = require_identity_or_request(
&auth_ctx,
&recipient_key,
self.auth_cfg.allow_insecure_identity_from_request,
) {
return Promise::err(e);
}
let messages = match self
.store
.peek(&recipient_key, &channel_id, limit as usize)
.map_err(storage_err)
{
Ok(m) => m,
Err(e) => return Promise::err(e),
};
tracing::info!(
recipient_prefix = %fmt_hex(&recipient_key[..4]),
count = messages.len(),
"audit: peek"
);
let mut list = results.get().init_payloads(messages.len() as u32);
for (i, (seq, data)) in messages.iter().enumerate() {
let mut entry = list.reborrow().get(i as u32);
entry.set_seq(*seq);
entry.set_data(data);
}
Promise::ok(())
}
pub fn handle_ack(
&mut self,
params: node_service::AckParams,
_results: node_service::AckResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_key = match p.get_recipient_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let seq_up_to = p.get_seq_up_to();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if recipient_key.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = require_identity_or_request(
&auth_ctx,
&recipient_key,
self.auth_cfg.allow_insecure_identity_from_request,
) {
return Promise::err(e);
}
match self
.store
.ack(&recipient_key, &channel_id, seq_up_to)
.map_err(storage_err)
{
Ok(removed) => {
tracing::info!(
recipient_prefix = %fmt_hex(&recipient_key[..4]),
seq_up_to = seq_up_to,
removed = removed,
"audit: ack"
);
}
Err(e) => return Promise::err(e),
}
Promise::ok(())
}
pub fn handle_batch_enqueue(
&mut self,
params: node_service::BatchEnqueueParams,
mut results: node_service::BatchEnqueueResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let recipient_keys = match p.get_recipient_keys() {
Ok(v) => v,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let payload = match p.get_payload() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
Ok(ctx) => ctx,
Err(e) => return Promise::err(e),
};
if payload.is_empty() {
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
}
if payload.len() > MAX_PAYLOAD_BYTES {
return Promise::err(coded_error(
E006_PAYLOAD_TOO_LARGE,
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
));
}
if version > CURRENT_WIRE_VERSION {
return Promise::err(coded_error(
E012_WIRE_VERSION,
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
));
}
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
tracing::warn!("rate_limit_hit");
metrics::record_rate_limit_hit_total();
return Promise::err(e);
}
let mut seqs = Vec::with_capacity(recipient_keys.len() as usize);
for i in 0..recipient_keys.len() {
let rk = match recipient_keys.get(i) {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
if rk.len() != 32 {
return Promise::err(coded_error(
E004_IDENTITY_KEY_LENGTH,
format!("recipientKey[{}] must be exactly 32 bytes, got {}", i, rk.len()),
));
}
match self.store.queue_depth(&rk, &channel_id) {
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
return Promise::err(coded_error(
E015_QUEUE_FULL,
format!("queue depth {} exceeds limit {}", depth, MAX_QUEUE_DEPTH),
));
}
Err(e) => return Promise::err(storage_err(e)),
_ => {}
}
let seq = match self
.store
.enqueue(&rk, &channel_id, payload.clone())
.map_err(storage_err)
{
Ok(seq) => seq,
Err(e) => return Promise::err(e),
};
seqs.push(seq);
metrics::record_enqueue_total();
metrics::record_enqueue_bytes(payload.len() as u64);
crate::auth::waiter(&self.waiters, &rk).notify_waiters();
}
let mut list = results.get().init_seqs(seqs.len() as u32);
for (i, seq) in seqs.iter().enumerate() {
list.set(i as u32, *seq);
}
tracing::info!(
recipient_count = recipient_keys.len(),
payload_len = payload.len(),
"audit: batch_enqueue"
);
Promise::ok(())
}
}

View File

@@ -256,4 +256,47 @@ impl NodeServiceImpl {
Promise::ok(())
}
pub fn handle_fetch_hybrid_keys(
&mut self,
params: node_service::FetchHybridKeysParams,
mut results: node_service::FetchHybridKeysResults,
) -> Promise<(), capnp::Error> {
let p = match params.get() {
Ok(p) => p,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let identity_keys = match p.get_identity_keys() {
Ok(v) => v,
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
if let Err(e) = validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
return Promise::err(e);
}
let count = identity_keys.len() as usize;
let mut key_data: Vec<Vec<u8>> = Vec::with_capacity(count);
for i in 0..identity_keys.len() {
let ik = match identity_keys.get(i) {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
};
let pk = match self.store.fetch_hybrid_key(&ik).map_err(storage_err) {
Ok(Some(pk)) => pk,
Ok(None) => vec![],
Err(e) => return Promise::err(e),
};
key_data.push(pk);
}
let mut list = results.get().init_keys(key_data.len() as u32);
for (i, pk) in key_data.iter().enumerate() {
list.set(i as u32, pk);
}
tracing::debug!(count = count, "batch hybrid key fetch");
Promise::ok(())
}
}

View File

@@ -15,7 +15,11 @@ use crate::auth::{
};
use crate::storage::Store;
/// Cap'n Proto traversal limit (words). 4 Mi words = 32 MiB; bounds DoS from deeply nested or large messages.
const CAPNP_TRAVERSAL_LIMIT_WORDS: usize = 4 * 1024 * 1024;
mod auth_ops;
mod channel_ops;
mod delivery;
mod key_ops;
mod p2p_ops;
@@ -132,6 +136,46 @@ impl node_service::Server for NodeServiceImpl {
) -> capnp::capability::Promise<(), capnp::Error> {
self.handle_resolve_endpoint(params, results)
}
fn peek(
&mut self,
params: node_service::PeekParams,
results: node_service::PeekResults,
) -> capnp::capability::Promise<(), capnp::Error> {
self.handle_peek(params, results)
}
fn ack(
&mut self,
params: node_service::AckParams,
results: node_service::AckResults,
) -> capnp::capability::Promise<(), capnp::Error> {
self.handle_ack(params, results)
}
fn fetch_hybrid_keys(
&mut self,
params: node_service::FetchHybridKeysParams,
results: node_service::FetchHybridKeysResults,
) -> capnp::capability::Promise<(), capnp::Error> {
self.handle_fetch_hybrid_keys(params, results)
}
fn batch_enqueue(
&mut self,
params: node_service::BatchEnqueueParams,
results: node_service::BatchEnqueueResults,
) -> capnp::capability::Promise<(), capnp::Error> {
self.handle_batch_enqueue(params, results)
}
fn create_channel(
&mut self,
params: node_service::CreateChannelParams,
results: node_service::CreateChannelResults,
) -> capnp::capability::Promise<(), capnp::Error> {
self.handle_create_channel(params, results)
}
}
pub const CURRENT_WIRE_VERSION: u16 = 1;
@@ -193,11 +237,13 @@ pub async fn handle_node_connection(
.map_err(|e| anyhow::anyhow!("failed to accept bi stream: {e}"))?;
let (reader, writer) = (recv.compat(), send.compat_write());
let mut reader_opts = capnp::message::ReaderOptions::new();
reader_opts.traversal_limit_in_words(Some(CAPNP_TRAVERSAL_LIMIT_WORDS));
let network = capnp_rpc::twoparty::VatNetwork::new(
reader,
writer,
capnp_rpc::rpc_twoparty_capnp::Side::Server,
Default::default(),
reader_opts,
);
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl::new(
@@ -223,6 +269,7 @@ pub fn spawn_cleanup_task(
pending_logins: Arc<DashMap<String, PendingLogin>>,
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
store: Arc<dyn Store>,
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
@@ -234,6 +281,29 @@ pub fn spawn_cleanup_task(
pending_logins.retain(|_, pl| now - pl.created_at < PENDING_LOGIN_TTL_SECS);
rate_limits.retain(|_, entry| now - entry.window_start < RATE_LIMIT_WINDOW_SECS * 2);
// Bound map sizes to prevent unbounded growth from malicious clients.
const MAX_SESSIONS: usize = 100_000;
const MAX_WAITERS: usize = 100_000;
if sessions.len() > MAX_SESSIONS {
let overflow = sessions.len() - MAX_SESSIONS;
let mut entries: Vec<_> = sessions
.iter()
.map(|e| (e.key().clone(), e.expires_at))
.collect();
entries.sort_by_key(|(_, exp)| *exp);
for (key, _) in entries.into_iter().take(overflow) {
sessions.remove(&key);
}
}
if waiters.len() > MAX_WAITERS {
let overflow = waiters.len() - MAX_WAITERS;
let keys: Vec<_> =
waiters.iter().take(overflow).map(|e| e.key().clone()).collect();
for key in keys {
waiters.remove(&key);
}
}
match store.gc_expired_messages(MESSAGE_TTL_SECS) {
Ok(n) if n > 0 => {
tracing::debug!(expired = n, "garbage collected expired messages")

View File

@@ -14,6 +14,7 @@ fn storage_err(err: StorageError) -> capnp::Error {
}
impl NodeServiceImpl {
/// Health check: unauthenticated by design for liveness probes and load balancers.
pub fn handle_health(
&mut self,
_params: node_service::HealthParams,

View File

@@ -3,17 +3,19 @@
use std::path::Path;
use std::sync::Mutex;
use rand::RngCore;
use rusqlite::{params, Connection};
use crate::storage::{StorageError, Store};
/// Schema version after introducing the migration runner (existing DBs had 1).
const SCHEMA_VERSION: i32 = 3;
const SCHEMA_VERSION: i32 = 4;
/// Migrations: (migration_number, SQL). Files named NNN_name.sql, applied in order when N > user_version.
const MIGRATIONS: &[(i32, &str)] = &[
(1, include_str!("../migrations/001_initial.sql")),
(3, include_str!("../migrations/002_add_seq.sql")),
(4, include_str!("../migrations/003_channels.sql")),
];
/// Runs pending migrations on an open connection: applies any migration whose number is greater
@@ -305,10 +307,17 @@ impl Store for SqlStore {
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
let conn = self.lock_conn()?;
conn.execute(
"INSERT OR REPLACE INTO users (username, opaque_record) VALUES (?1, ?2)",
"INSERT INTO users (username, opaque_record) VALUES (?1, ?2)",
params![username, record],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
.map_err(|e| {
if let rusqlite::Error::SqliteFailure(ref err, _) = &e {
if err.code == rusqlite::ErrorCode::ConstraintViolation {
return StorageError::DuplicateUser(username.to_string());
}
}
StorageError::Db(e.to_string())
})?;
Ok(())
}
@@ -360,6 +369,57 @@ impl Store for SqlStore {
.map_err(|e| StorageError::Db(e.to_string()))
}
fn peek(
&self,
recipient_key: &[u8],
channel_id: &[u8],
limit: usize,
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
let conn = self.lock_conn()?;
let sql = if limit == 0 {
"SELECT seq, payload FROM deliveries
WHERE recipient_key = ?1 AND channel_id = ?2
ORDER BY seq ASC".to_string()
} else {
format!(
"SELECT seq, payload FROM deliveries
WHERE recipient_key = ?1 AND channel_id = ?2
ORDER BY seq ASC
LIMIT {}",
limit
)
};
let mut stmt = conn.prepare(&sql).map_err(|e| StorageError::Db(e.to_string()))?;
let rows: Vec<(i64, Vec<u8>)> = stmt
.query_map(params![recipient_key, channel_id], |row| {
Ok((row.get(0)?, row.get(1)?))
})
.map_err(|e| StorageError::Db(e.to_string()))?
.collect::<Result<Vec<_>, _>>()
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(rows.into_iter().map(|(seq, payload)| (seq as u64, payload)).collect())
}
fn ack(
&self,
recipient_key: &[u8],
channel_id: &[u8],
seq_up_to: u64,
) -> Result<usize, StorageError> {
let conn = self.lock_conn()?;
let deleted = conn
.execute(
"DELETE FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND seq <= ?3",
params![recipient_key, channel_id, seq_up_to as i64],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(deleted)
}
fn publish_endpoint(
&self,
identity_key: &[u8],
@@ -384,6 +444,45 @@ impl Store for SqlStore {
.optional()
.map_err(|e| StorageError::Db(e.to_string()))
}
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
let (a, b) = if member_a < member_b {
(member_a.to_vec(), member_b.to_vec())
} else {
(member_b.to_vec(), member_a.to_vec())
};
let conn = self.lock_conn()?;
let existing: Option<Vec<u8>> = conn
.query_row(
"SELECT channel_id FROM channels WHERE member_a = ?1 AND member_b = ?2",
params![a, b],
|row| row.get(0),
)
.optional()
.map_err(|e| StorageError::Db(e.to_string()))?;
if let Some(id) = existing {
return Ok(id);
}
let mut channel_id = [0u8; 16];
rand::thread_rng().fill_bytes(&mut channel_id);
conn.execute(
"INSERT INTO channels (channel_id, member_a, member_b) VALUES (?1, ?2, ?3)",
params![channel_id.as_slice(), a, b],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(channel_id.to_vec())
}
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
let conn = self.lock_conn()?;
conn.query_row(
"SELECT member_a, member_b FROM channels WHERE channel_id = ?1",
params![channel_id],
|row| Ok((row.get::<_, Vec<u8>>(0)?, row.get::<_, Vec<u8>>(1)?)),
)
.optional()
.map_err(|e| StorageError::Db(e.to_string()))
}
}
/// Convenience extension for `rusqlite::OptionalExtension`.

View File

@@ -6,6 +6,7 @@ use std::{
sync::Mutex,
};
use rand::RngCore;
use serde::{Deserialize, Serialize};
#[derive(thiserror::Error, Debug)]
@@ -16,6 +17,9 @@ pub enum StorageError {
Serde,
#[error("database error: {0}")]
Db(String),
/// Unique constraint violation (e.g. user already exists).
#[error("duplicate user: {0}")]
DuplicateUser(String),
}
fn lock<T>(m: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>, StorageError> {
@@ -96,12 +100,36 @@ pub trait Store: Send + Sync {
/// Retrieve identity key for a user (Fix 2).
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError>;
/// Peek at queued messages without removing them (non-destructive).
/// Returns `(seq, payload)` pairs ordered by seq.
fn peek(
&self,
recipient_key: &[u8],
channel_id: &[u8],
limit: usize,
) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;
/// Acknowledge (remove) all messages with seq <= seq_up_to.
fn ack(
&self,
recipient_key: &[u8],
channel_id: &[u8],
seq_up_to: u64,
) -> Result<usize, StorageError>;
/// Publish a P2P endpoint address for an identity key.
fn publish_endpoint(&self, identity_key: &[u8], node_addr: Vec<u8>)
-> Result<(), StorageError>;
/// Resolve a peer's P2P endpoint address.
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
/// Create a 1:1 channel between two members. Returns 16-byte channel_id (UUID).
/// Members are stored in sorted order for deterministic lookup.
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError>;
/// Get the two members of a channel by channel_id (16 bytes). Returns (member_a, member_b) in sorted order.
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError>;
}
// ── ChannelKey ───────────────────────────────────────────────────────────────
@@ -154,8 +182,10 @@ pub struct FileBackedStore {
setup_path: PathBuf,
users_path: PathBuf,
identity_keys_path: PathBuf,
channels_path: PathBuf,
key_packages: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
deliveries: Mutex<QueueMapV3>,
channels: Mutex<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>>,
hybrid_keys: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
users: Mutex<HashMap<String, Vec<u8>>>,
identity_keys: Mutex<HashMap<String, Vec<u8>>>,
@@ -174,12 +204,14 @@ impl FileBackedStore {
let setup_path = dir.join("server_setup.bin");
let users_path = dir.join("users.bin");
let identity_keys_path = dir.join("identity_keys.bin");
let channels_path = dir.join("channels.bin");
let key_packages = Mutex::new(Self::load_kp_map(&kp_path)?);
let deliveries = Mutex::new(Self::load_delivery_map_v3(&ds_path)?);
let hybrid_keys = Mutex::new(Self::load_hybrid_keys(&hk_path)?);
let users = Mutex::new(Self::load_users(&users_path)?);
let identity_keys = Mutex::new(Self::load_map_string_bytes(&identity_keys_path)?);
let channels = Mutex::new(Self::load_channels(&channels_path)?);
Ok(Self {
kp_path,
@@ -188,8 +220,10 @@ impl FileBackedStore {
setup_path,
users_path,
identity_keys_path,
channels_path,
key_packages,
deliveries,
channels,
hybrid_keys,
users,
identity_keys,
@@ -197,6 +231,31 @@ impl FileBackedStore {
})
}
fn load_channels(
path: &Path,
) -> Result<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>, StorageError> {
if !path.exists() {
return Ok(HashMap::new());
}
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
if bytes.is_empty() {
return Ok(HashMap::new());
}
bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
}
fn flush_channels(
&self,
path: &Path,
map: &HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>,
) -> Result<(), StorageError> {
let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
}
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
}
fn load_kp_map(path: &Path) -> Result<HashMap<Vec<u8>, VecDeque<Vec<u8>>>, StorageError> {
if !path.exists() {
return Ok(HashMap::new());
@@ -346,8 +405,9 @@ impl Store for FileBackedStore {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
let seq = *inner.next_seq.entry(key.clone()).or_insert(0);
*inner.next_seq.get_mut(&key).unwrap() = seq + 1;
let entry = inner.next_seq.entry(key.clone()).or_insert(0);
let seq = *entry;
*entry = seq + 1;
inner.map.entry(key).or_default().push_back(SeqEntry { seq, data: payload });
self.flush_delivery_map(&self.ds_path, &*inner)?;
Ok(seq)
@@ -428,7 +488,13 @@ impl Store for FileBackedStore {
if let Some(parent) = self.setup_path.parent() {
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
}
fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))
fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(&self.setup_path, std::fs::Permissions::from_mode(0o600));
}
Ok(())
}
fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError> {
@@ -444,7 +510,14 @@ impl Store for FileBackedStore {
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
let mut map = lock(&self.users)?;
map.insert(username.to_string(), record);
match map.entry(username.to_string()) {
std::collections::hash_map::Entry::Occupied(_) => {
return Err(StorageError::DuplicateUser(username.to_string()))
}
std::collections::hash_map::Entry::Vacant(v) => {
v.insert(record);
}
}
self.flush_users(&self.users_path, &*map)
}
@@ -473,6 +546,54 @@ impl Store for FileBackedStore {
Ok(map.get(username).cloned())
}
fn peek(
&self,
recipient_key: &[u8],
channel_id: &[u8],
limit: usize,
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
let inner = lock(&self.deliveries)?;
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
let messages: Vec<(u64, Vec<u8>)> = inner
.map
.get(&key)
.map(|q| {
let count = if limit == 0 { q.len() } else { limit.min(q.len()) };
q.iter()
.take(count)
.map(|e| (e.seq, e.data.clone()))
.collect()
})
.unwrap_or_default();
// Non-destructive: do NOT flush.
Ok(messages)
}
fn ack(
&self,
recipient_key: &[u8],
channel_id: &[u8],
seq_up_to: u64,
) -> Result<usize, StorageError> {
let mut inner = lock(&self.deliveries)?;
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
let removed = if let Some(q) = inner.map.get_mut(&key) {
let before = q.len();
q.retain(|e| e.seq > seq_up_to);
before - q.len()
} else {
0
};
self.flush_delivery_map(&self.ds_path, &*inner)?;
Ok(removed)
}
fn publish_endpoint(
&self,
identity_key: &[u8],
@@ -487,4 +608,150 @@ impl Store for FileBackedStore {
let map = lock(&self.endpoints)?;
Ok(map.get(identity_key).cloned())
}
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
let (a, b) = if member_a < member_b {
(member_a.to_vec(), member_b.to_vec())
} else {
(member_b.to_vec(), member_a.to_vec())
};
let mut map = lock(&self.channels)?;
if let Some((channel_id, _)) = map.iter().find(|(_, (ma, mb))| ma == &a && mb == &b) {
return Ok(channel_id.clone());
}
let mut channel_id = [0u8; 16];
rand::thread_rng().fill_bytes(&mut channel_id);
let channel_id = channel_id.to_vec();
map.insert(channel_id.clone(), (a, b));
self.flush_channels(&self.channels_path, &*map)?;
Ok(channel_id)
}
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
let map = lock(&self.channels)?;
Ok(map.get(channel_id).cloned())
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn temp_store() -> (TempDir, FileBackedStore) {
let dir = TempDir::new().unwrap();
let store = FileBackedStore::open(dir.path()).unwrap();
(dir, store)
}
#[test]
fn key_package_upload_fetch() {
let (_dir, store) = temp_store();
let ik = vec![1u8; 32];
store.upload_key_package(&ik, vec![10, 20, 30]).unwrap();
let pkg = store.fetch_key_package(&ik).unwrap();
assert_eq!(pkg, Some(vec![10, 20, 30]));
// Second fetch should return None (consumed)
let pkg2 = store.fetch_key_package(&ik).unwrap();
assert_eq!(pkg2, None);
}
#[test]
fn enqueue_fetch_with_seq() {
let (_dir, store) = temp_store();
let rk = vec![2u8; 32];
let ch = vec![];
let seq0 = store.enqueue(&rk, &ch, vec![1]).unwrap();
let seq1 = store.enqueue(&rk, &ch, vec![2]).unwrap();
assert_eq!(seq0, 0);
assert_eq!(seq1, 1);
let msgs = store.fetch(&rk, &ch).unwrap();
assert_eq!(msgs.len(), 2);
assert_eq!(msgs[0], (0, vec![1]));
assert_eq!(msgs[1], (1, vec![2]));
// After fetch, queue should be empty
let msgs2 = store.fetch(&rk, &ch).unwrap();
assert!(msgs2.is_empty());
}
#[test]
fn fetch_limited_respects_limit() {
let (_dir, store) = temp_store();
let rk = vec![3u8; 32];
let ch = vec![];
for i in 0..5 {
store.enqueue(&rk, &ch, vec![i]).unwrap();
}
let msgs = store.fetch_limited(&rk, &ch, 2).unwrap();
assert_eq!(msgs.len(), 2);
assert_eq!(msgs[0].1, vec![0]);
assert_eq!(msgs[1].1, vec![1]);
// Remaining 3 should still be there
let depth = store.queue_depth(&rk, &ch).unwrap();
assert_eq!(depth, 3);
}
#[test]
fn queue_depth_tracking() {
let (_dir, store) = temp_store();
let rk = vec![4u8; 32];
let ch = vec![];
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
store.enqueue(&rk, &ch, vec![1]).unwrap();
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 1);
store.enqueue(&rk, &ch, vec![2]).unwrap();
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 2);
store.fetch(&rk, &ch).unwrap();
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
}
#[test]
fn hybrid_key_upload_fetch() {
let (_dir, store) = temp_store();
let ik = vec![5u8; 32];
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), None);
store.upload_hybrid_key(&ik, vec![99; 100]).unwrap();
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(vec![99; 100]));
}
#[test]
fn user_record_crud() {
let (_dir, store) = temp_store();
assert!(!store.has_user_record("alice").unwrap());
store.store_user_record("alice", vec![1, 2, 3]).unwrap();
assert!(store.has_user_record("alice").unwrap());
assert_eq!(store.get_user_record("alice").unwrap(), Some(vec![1, 2, 3]));
}
#[test]
fn user_identity_key_crud() {
let (_dir, store) = temp_store();
assert_eq!(store.get_user_identity_key("bob").unwrap(), None);
store.store_user_identity_key("bob", vec![7u8; 32]).unwrap();
assert_eq!(store.get_user_identity_key("bob").unwrap(), Some(vec![7u8; 32]));
}
#[test]
fn endpoint_publish_resolve() {
let (_dir, store) = temp_store();
let ik = vec![8u8; 32];
assert_eq!(store.resolve_endpoint(&ik).unwrap(), None);
store.publish_endpoint(&ik, vec![10, 20]).unwrap();
assert_eq!(store.resolve_endpoint(&ik).unwrap(), Some(vec![10, 20]));
}
#[test]
fn create_channel_and_members() {
let (_dir, store) = temp_store();
let a = vec![1u8; 32];
let b = vec![2u8; 32];
assert_eq!(store.get_channel_members(&[0u8; 16]).unwrap(), None);
let id1 = store.create_channel(&a, &b).unwrap();
assert_eq!(id1.len(), 16);
let members = store.get_channel_members(&id1).unwrap().unwrap();
assert_eq!(members.0, a);
assert_eq!(members.1, b);
let id2 = store.create_channel(&b, &a).unwrap();
assert_eq!(id1, id2);
}
}

View File

@@ -61,6 +61,12 @@ fn generate_self_signed_cert(cert_path: &PathBuf, key_path: &PathBuf) -> anyhow:
std::fs::write(cert_path, issued.cert.der()).context("write cert")?;
std::fs::write(key_path, &key_der).context("write key")?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(0o600);
std::fs::set_permissions(key_path, perms).context("set key permissions")?;
}
tracing::info!(
cert = %cert_path.display(),