DM channels (createChannel), channel authz, security/docs, future improvements
- Add createChannel RPC (node.capnp @18): create 1:1 channel, returns 16-byte channelId - Store: create_channel(member_a, member_b), get_channel_members(channel_id) - FileBackedStore: channels.bin; SqlStore: migration 003_channels, schema v4 - channel_ops: handle_create_channel (auth + identity, peerKey 32 bytes) - Delivery authz: when channel_id.len() == 16, require caller and recipient are channel members (E022/E023) - Error codes E022 CHANNEL_ACCESS_DENIED, E023 CHANNEL_NOT_FOUND - SUMMARY: link Certificate lifecycle; security audit, future improvements, multi-agent plan docs - Certificate lifecycle doc, SECURITY-AUDIT, FUTURE-IMPROVEMENTS, MULTI-AGENT-WORK-PLAN - Client/core/tls/auth/server main: assorted fixes and updates from review and audit Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -38,6 +38,7 @@ thiserror = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
argon2 = { workspace = true }
|
||||
chacha20poly1305 = { workspace = true }
|
||||
zeroize = { workspace = true }
|
||||
quinn = { workspace = true }
|
||||
quinn-proto = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
@@ -49,10 +50,12 @@ tracing-subscriber = { workspace = true }
|
||||
# CLI
|
||||
clap = { workspace = true }
|
||||
|
||||
# Hex encoding/decoding
|
||||
hex = "0.4"
|
||||
|
||||
[dev-dependencies]
|
||||
dashmap = { workspace = true }
|
||||
assert_cmd = "2"
|
||||
tempfile = "3"
|
||||
portpicker = "0.1"
|
||||
rand = "0.8"
|
||||
hex = "0.4"
|
||||
|
||||
@@ -574,7 +574,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
|
||||
.await?
|
||||
.context("joiner hybrid key not found")?;
|
||||
let wrapped_welcome =
|
||||
hybrid_encrypt(&joiner_hybrid_pk, &welcome).context("hybrid encrypt welcome")?;
|
||||
hybrid_encrypt(&joiner_hybrid_pk, &welcome, b"", b"").context("hybrid encrypt welcome")?;
|
||||
enqueue(&creator_ds, &joiner_identity, &wrapped_welcome).await?;
|
||||
|
||||
let welcome_payloads = fetch_all(&joiner_ds, &joiner_identity).await?;
|
||||
@@ -584,7 +584,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
|
||||
.context("Welcome was not delivered to joiner via DS")?;
|
||||
|
||||
let welcome_bytes =
|
||||
hybrid_decrypt(&joiner_hybrid, &raw_welcome).context("hybrid decrypt welcome failed")?;
|
||||
hybrid_decrypt(&joiner_hybrid, &raw_welcome, b"", b"").context("hybrid decrypt welcome failed")?;
|
||||
joiner
|
||||
.join_group(&welcome_bytes)
|
||||
.context("join_group failed")?;
|
||||
@@ -593,7 +593,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
|
||||
.send_message(b"hello")
|
||||
.context("send_message failed")?;
|
||||
let wrapped_creator_joiner =
|
||||
hybrid_encrypt(&joiner_hybrid_pk, &ct_creator_to_joiner).context("hybrid encrypt failed")?;
|
||||
hybrid_encrypt(&joiner_hybrid_pk, &ct_creator_to_joiner, b"", b"").context("hybrid encrypt failed")?;
|
||||
enqueue(&creator_ds, &joiner_identity, &wrapped_creator_joiner).await?;
|
||||
|
||||
let joiner_msgs = fetch_all(&joiner_ds, &joiner_identity).await?;
|
||||
@@ -601,7 +601,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
|
||||
.first()
|
||||
.context("joiner: missing ciphertext from DS")?;
|
||||
let inner_creator_joiner =
|
||||
hybrid_decrypt(&joiner_hybrid, raw_creator_joiner).context("hybrid decrypt failed")?;
|
||||
hybrid_decrypt(&joiner_hybrid, raw_creator_joiner, b"", b"").context("hybrid decrypt failed")?;
|
||||
let plaintext_creator_joiner = joiner
|
||||
.receive_message(&inner_creator_joiner)?
|
||||
.context("expected application message")?;
|
||||
@@ -617,7 +617,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
|
||||
.send_message(b"hello back")
|
||||
.context("send_message failed")?;
|
||||
let wrapped_joiner_creator =
|
||||
hybrid_encrypt(&creator_hybrid_pk, &ct_joiner_to_creator).context("hybrid encrypt failed")?;
|
||||
hybrid_encrypt(&creator_hybrid_pk, &ct_joiner_to_creator, b"", b"").context("hybrid encrypt failed")?;
|
||||
enqueue(&joiner_ds, &creator_identity, &wrapped_joiner_creator).await?;
|
||||
|
||||
let creator_msgs = fetch_all(&creator_ds, &creator_identity).await?;
|
||||
@@ -625,7 +625,7 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
|
||||
.first()
|
||||
.context("creator: missing ciphertext from DS")?;
|
||||
let inner_joiner_creator =
|
||||
hybrid_decrypt(&creator_hybrid, raw_joiner_creator).context("hybrid decrypt failed")?;
|
||||
hybrid_decrypt(&creator_hybrid, raw_joiner_creator, b"", b"").context("hybrid decrypt failed")?;
|
||||
let plaintext_joiner_creator = creator
|
||||
.receive_message(&inner_joiner_creator)?
|
||||
.context("expected application message")?;
|
||||
@@ -701,7 +701,7 @@ pub async fn cmd_invite(
|
||||
}
|
||||
let peer_hpk = fetch_hybrid_key(&node_client, mk).await?;
|
||||
let commit_payload = if let Some(ref pk) = peer_hpk {
|
||||
hybrid_encrypt(pk, &commit).context("hybrid encrypt commit")?
|
||||
hybrid_encrypt(pk, &commit, b"", b"").context("hybrid encrypt commit")?
|
||||
} else {
|
||||
commit.clone()
|
||||
};
|
||||
@@ -710,7 +710,7 @@ pub async fn cmd_invite(
|
||||
|
||||
let peer_hybrid_pk = fetch_hybrid_key(&node_client, &peer_key).await?;
|
||||
let payload = if let Some(ref pk) = peer_hybrid_pk {
|
||||
hybrid_encrypt(pk, &welcome).context("hybrid encrypt welcome failed")?
|
||||
hybrid_encrypt(pk, &welcome, b"", b"").context("hybrid encrypt welcome failed")?
|
||||
} else {
|
||||
welcome
|
||||
};
|
||||
@@ -774,6 +774,15 @@ pub async fn cmd_join(
|
||||
let _ = member.receive_message(&mls_payload);
|
||||
}
|
||||
|
||||
// Auto-replenish KeyPackage after join consumed the original one.
|
||||
let tls_bytes = member
|
||||
.generate_key_package()
|
||||
.context("KeyPackage replenishment failed")?;
|
||||
upload_key_package(&node_client, &member.identity().public_key_bytes(), &tls_bytes)
|
||||
.await
|
||||
.context("KeyPackage replenishment upload failed")?;
|
||||
println!("KeyPackage auto-replenished after join");
|
||||
|
||||
save_state(state_path, &member, hybrid_kp.as_ref(), password)?;
|
||||
println!("joined group successfully");
|
||||
Ok(())
|
||||
@@ -820,7 +829,7 @@ pub async fn cmd_send(
|
||||
for recipient in &recipients {
|
||||
let peer_hybrid_pk = fetch_hybrid_key(&node_client, recipient).await?;
|
||||
let payload = if let Some(ref pk) = peer_hybrid_pk {
|
||||
hybrid_encrypt(pk, &ct).context("hybrid encrypt failed")?
|
||||
hybrid_encrypt(pk, &ct, b"", b"").context("hybrid encrypt failed")?
|
||||
} else {
|
||||
ct.clone()
|
||||
};
|
||||
@@ -871,7 +880,7 @@ pub async fn cmd_recv(
|
||||
// application messages that depend on the resulting epoch.
|
||||
payloads.sort_by_key(|(seq, _)| *seq);
|
||||
|
||||
let mut retry_mls: Vec<Vec<u8>> = Vec::new();
|
||||
let mut pending: Vec<(usize, Vec<u8>)> = Vec::new();
|
||||
for (idx, (_, payload)) in payloads.iter().enumerate() {
|
||||
let mls_payload = match try_hybrid_decrypt(hybrid_kp.as_ref(), payload) {
|
||||
Ok(b) => b,
|
||||
@@ -883,18 +892,32 @@ pub async fn cmd_recv(
|
||||
match member.receive_message(&mls_payload) {
|
||||
Ok(Some(pt)) => println!("[{idx}] plaintext: {}", String::from_utf8_lossy(&pt)),
|
||||
Ok(None) => println!("[{idx}] commit applied"),
|
||||
Err(_) => retry_mls.push(mls_payload),
|
||||
Err(_) => pending.push((idx, mls_payload)),
|
||||
}
|
||||
}
|
||||
// Retry messages that failed on the first pass (e.g. app messages whose
|
||||
// epoch was not yet advanced until a commit earlier in the batch was applied).
|
||||
for mls_payload in &retry_mls {
|
||||
match member.receive_message(mls_payload) {
|
||||
Ok(Some(pt)) => println!("[retry] plaintext: {}", String::from_utf8_lossy(&pt)),
|
||||
Ok(None) => {}
|
||||
Err(e) => println!("[retry] error: {e}"),
|
||||
// Retry until no more progress (handles multi-epoch batches).
|
||||
loop {
|
||||
let before = pending.len();
|
||||
pending.retain(|(idx, mls_payload)| {
|
||||
match member.receive_message(mls_payload) {
|
||||
Ok(Some(pt)) => {
|
||||
println!("[{idx}/retry] plaintext: {}", String::from_utf8_lossy(&pt));
|
||||
false
|
||||
}
|
||||
Ok(None) => {
|
||||
println!("[{idx}/retry] commit applied");
|
||||
false
|
||||
}
|
||||
Err(_) => true,
|
||||
}
|
||||
});
|
||||
if pending.len() == before {
|
||||
break; // No progress — remaining messages are unprocessable
|
||||
}
|
||||
}
|
||||
for (idx, _) in &pending {
|
||||
println!("[{idx}] error: unprocessable after all retries");
|
||||
}
|
||||
|
||||
save_state(state_path, &member, hybrid_kp.as_ref(), password)?;
|
||||
|
||||
@@ -906,8 +929,8 @@ pub async fn cmd_recv(
|
||||
|
||||
/// Fetch pending payloads, process in order (merge commits, collect plaintexts), save state.
|
||||
/// Returns only application-message plaintexts. Used by E2E tests and callers that need returned messages.
|
||||
/// Uses two passes so that if the server delivers an application message before a Commit, the second pass
|
||||
/// processes it after commits are merged.
|
||||
/// Retries in a loop until no more progress, handling multi-epoch batches where commits must be
|
||||
/// applied before later application messages can be decrypted.
|
||||
pub async fn receive_pending_plaintexts(
|
||||
state_path: &Path,
|
||||
server: &str,
|
||||
@@ -925,7 +948,7 @@ pub async fn receive_pending_plaintexts(
|
||||
payloads.sort_by_key(|(seq, _)| *seq);
|
||||
|
||||
let mut plaintexts = Vec::new();
|
||||
let mut retry_mls: Vec<Vec<u8>> = Vec::new();
|
||||
let mut pending: Vec<Vec<u8>> = Vec::new();
|
||||
for (_, payload) in &payloads {
|
||||
let mls_payload = match try_hybrid_decrypt(hybrid_kp.as_ref(), payload) {
|
||||
Ok(b) => b,
|
||||
@@ -934,12 +957,24 @@ pub async fn receive_pending_plaintexts(
|
||||
match member.receive_message(&mls_payload) {
|
||||
Ok(Some(pt)) => plaintexts.push(pt),
|
||||
Ok(None) => {}
|
||||
Err(_) => retry_mls.push(mls_payload),
|
||||
Err(_) => pending.push(mls_payload),
|
||||
}
|
||||
}
|
||||
for mls_payload in &retry_mls {
|
||||
if let Ok(Some(pt)) = member.receive_message(mls_payload) {
|
||||
plaintexts.push(pt);
|
||||
// Retry until no more progress (handles multi-epoch batches).
|
||||
loop {
|
||||
let before = pending.len();
|
||||
pending.retain(|mls_payload| {
|
||||
match member.receive_message(mls_payload) {
|
||||
Ok(Some(pt)) => {
|
||||
plaintexts.push(pt);
|
||||
false
|
||||
}
|
||||
Ok(None) => false,
|
||||
Err(_) => true,
|
||||
}
|
||||
});
|
||||
if pending.len() == before {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1069,7 +1104,7 @@ pub async fn cmd_chat(
|
||||
.context("send_message failed")?;
|
||||
let peer_hybrid_pk = fetch_hybrid_key(&client, &peer_key).await?;
|
||||
let payload = if let Some(ref pk) = peer_hybrid_pk {
|
||||
hybrid_encrypt(pk, &ct).context("hybrid encrypt failed")?
|
||||
hybrid_encrypt(pk, &ct, b"", b"").context("hybrid encrypt failed")?
|
||||
} else {
|
||||
ct
|
||||
};
|
||||
@@ -1085,6 +1120,7 @@ pub async fn cmd_chat(
|
||||
_ = poll.tick() => {
|
||||
let mut payloads = fetch_wait(&client, &identity_bytes, 0).await?;
|
||||
payloads.sort_by_key(|(seq, _)| *seq);
|
||||
let mut retry_payloads: Vec<Vec<u8>> = Vec::new();
|
||||
for (_, payload) in &payloads {
|
||||
let mls_payload = match try_hybrid_decrypt(hybrid_kp.as_ref(), payload) {
|
||||
Ok(b) => b,
|
||||
@@ -1097,9 +1133,26 @@ pub async fn cmd_chat(
|
||||
std::io::stdout().flush().context("flush stdout")?;
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(_) => {}
|
||||
Err(_) => retry_payloads.push(mls_payload),
|
||||
}
|
||||
}
|
||||
// Retry failed messages (epoch may have advanced from commits in this batch)
|
||||
loop {
|
||||
let before = retry_payloads.len();
|
||||
retry_payloads.retain(|mls_payload| {
|
||||
match member.receive_message(mls_payload) {
|
||||
Ok(Some(pt)) => {
|
||||
let s = String::from_utf8_lossy(&pt);
|
||||
println!("\r\n[peer] {s}\n> ");
|
||||
let _ = std::io::stdout().flush();
|
||||
false
|
||||
}
|
||||
Ok(None) => false,
|
||||
Err(_) => true,
|
||||
}
|
||||
});
|
||||
if retry_payloads.len() == before { break; }
|
||||
}
|
||||
if !payloads.is_empty() {
|
||||
save_state(state_path, &member, hybrid_kp.as_ref(), password)?;
|
||||
}
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
pub fn encode(bytes: impl AsRef<[u8]>) -> String {
|
||||
bytes.as_ref().iter().map(|b| format!("{b:02x}")).collect()
|
||||
hex::encode(bytes)
|
||||
}
|
||||
|
||||
pub fn decode(s: &str) -> Result<Vec<u8>, &'static str> {
|
||||
if s.len() % 2 != 0 {
|
||||
return Err("odd-length hex string");
|
||||
}
|
||||
(0..s.len())
|
||||
.step_by(2)
|
||||
.map(|i| u8::from_str_radix(&s[i..i + 2], 16).map_err(|_| "invalid hex character"))
|
||||
.collect()
|
||||
hex::decode(s).map_err(|_| "invalid hex string")
|
||||
}
|
||||
|
||||
@@ -48,7 +48,12 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(last_err.expect("retry_async: last_err set when we break after Err"))
|
||||
match last_err {
|
||||
Some(e) => Err(e),
|
||||
None => unreachable!(
|
||||
"retry_async: last_err is always Some when loop exits after an Err"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Classifies `anyhow::Error` for retry: returns `false` for auth or invalid-param
|
||||
|
||||
@@ -17,6 +17,9 @@ use crate::AUTH_CONTEXT;
|
||||
|
||||
use super::retry::{anyhow_is_retriable, retry_async, DEFAULT_BASE_DELAY_MS, DEFAULT_MAX_RETRIES};
|
||||
|
||||
/// Cap'n Proto traversal limit (words). 4 Mi words = 32 MiB; bounds DoS from deeply nested or large messages.
|
||||
const CAPNP_TRAVERSAL_LIMIT_WORDS: usize = 4 * 1024 * 1024;
|
||||
|
||||
/// Establish a QUIC/TLS connection and return a `NodeService` client.
|
||||
///
|
||||
/// Must be called from within a `LocalSet` because capnp-rpc is `!Send`.
|
||||
@@ -55,11 +58,13 @@ pub async fn connect_node(
|
||||
|
||||
let (send, recv) = connection.open_bi().await.context("open bi stream")?;
|
||||
|
||||
let mut reader_opts = capnp::message::ReaderOptions::new();
|
||||
reader_opts.traversal_limit_in_words(Some(CAPNP_TRAVERSAL_LIMIT_WORDS));
|
||||
let network = twoparty::VatNetwork::new(
|
||||
recv.compat(),
|
||||
send.compat_write(),
|
||||
Side::Client,
|
||||
Default::default(),
|
||||
reader_opts,
|
||||
);
|
||||
|
||||
let mut rpc_system = RpcSystem::new(Box::new(network), None);
|
||||
@@ -72,7 +77,9 @@ pub async fn connect_node(
|
||||
|
||||
pub fn set_auth(auth: &mut auth::Builder<'_>) -> anyhow::Result<()> {
|
||||
let ctx = AUTH_CONTEXT.get().ok_or_else(|| {
|
||||
anyhow::anyhow!("init_auth must be called with a non-empty token before RPCs")
|
||||
anyhow::anyhow!(
|
||||
"init_auth must be called before RPCs (use a bearer or session token for authenticated commands)"
|
||||
)
|
||||
})?;
|
||||
auth.set_version(ctx.version);
|
||||
auth.set_access_token(&ctx.access_token);
|
||||
@@ -355,7 +362,216 @@ pub fn try_hybrid_decrypt(
|
||||
payload: &[u8],
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
let kp = hybrid_kp.ok_or_else(|| anyhow::anyhow!("hybrid key required for decryption"))?;
|
||||
quicnprotochat_core::hybrid_decrypt(kp, payload).map_err(|e| anyhow::anyhow!("{e}"))
|
||||
quicnprotochat_core::hybrid_decrypt(kp, payload, b"", b"").map_err(|e| anyhow::anyhow!("{e}"))
|
||||
}
|
||||
|
||||
/// Peek at queued payloads without removing them.
|
||||
/// Returns `(seq, payload)` pairs sorted by seq.
|
||||
/// Retries on transient failures with exponential backoff.
|
||||
pub async fn peek(
|
||||
client: &node_service::Client,
|
||||
recipient_key: &[u8],
|
||||
) -> anyhow::Result<Vec<(u64, Vec<u8>)>> {
|
||||
let client = client.clone();
|
||||
let recipient_key = recipient_key.to_vec();
|
||||
retry_async(
|
||||
|| {
|
||||
let client = client.clone();
|
||||
let recipient_key = recipient_key.clone();
|
||||
async move {
|
||||
let mut req = client.peek_request();
|
||||
{
|
||||
let mut p = req.get();
|
||||
p.set_recipient_key(&recipient_key);
|
||||
p.set_channel_id(&[]);
|
||||
p.set_version(1);
|
||||
p.set_limit(0); // peek all
|
||||
let mut auth = p.reborrow().init_auth();
|
||||
set_auth(&mut auth)?;
|
||||
}
|
||||
|
||||
let resp = req.send().promise.await.context("peek RPC failed")?;
|
||||
|
||||
let list = resp
|
||||
.get()
|
||||
.context("peek: bad response")?
|
||||
.get_payloads()
|
||||
.context("peek: missing payloads")?;
|
||||
|
||||
let mut payloads = Vec::with_capacity(list.len() as usize);
|
||||
for i in 0..list.len() {
|
||||
let entry = list.get(i);
|
||||
let seq = entry.get_seq();
|
||||
let data = entry
|
||||
.get_data()
|
||||
.context("peek: envelope data read failed")?
|
||||
.to_vec();
|
||||
payloads.push((seq, data));
|
||||
}
|
||||
|
||||
Ok(payloads)
|
||||
}
|
||||
},
|
||||
DEFAULT_MAX_RETRIES,
|
||||
DEFAULT_BASE_DELAY_MS,
|
||||
anyhow_is_retriable,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Acknowledge all messages up to and including `seq_up_to`.
|
||||
/// Retries on transient failures with exponential backoff.
|
||||
pub async fn ack(
|
||||
client: &node_service::Client,
|
||||
recipient_key: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
let client = client.clone();
|
||||
let recipient_key = recipient_key.to_vec();
|
||||
retry_async(
|
||||
|| {
|
||||
let client = client.clone();
|
||||
let recipient_key = recipient_key.clone();
|
||||
async move {
|
||||
let mut req = client.ack_request();
|
||||
{
|
||||
let mut p = req.get();
|
||||
p.set_recipient_key(&recipient_key);
|
||||
p.set_channel_id(&[]);
|
||||
p.set_version(1);
|
||||
p.set_seq_up_to(seq_up_to);
|
||||
let mut auth = p.reborrow().init_auth();
|
||||
set_auth(&mut auth)?;
|
||||
}
|
||||
req.send().promise.await.context("ack RPC failed")?;
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
DEFAULT_MAX_RETRIES,
|
||||
DEFAULT_BASE_DELAY_MS,
|
||||
anyhow_is_retriable,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Fetch multiple peers' hybrid keys in a single round-trip.
|
||||
/// Returns `None` for peers who have not uploaded a hybrid key.
|
||||
/// Retries on transient failures with exponential backoff.
|
||||
pub async fn fetch_hybrid_keys(
|
||||
client: &node_service::Client,
|
||||
identity_keys: &[&[u8]],
|
||||
) -> anyhow::Result<Vec<Option<HybridPublicKey>>> {
|
||||
let client = client.clone();
|
||||
let identity_keys: Vec<Vec<u8>> = identity_keys.iter().map(|k| k.to_vec()).collect();
|
||||
retry_async(
|
||||
|| {
|
||||
let client = client.clone();
|
||||
let identity_keys = identity_keys.clone();
|
||||
async move {
|
||||
let mut req = client.fetch_hybrid_keys_request();
|
||||
{
|
||||
let mut p = req.get();
|
||||
let mut list = p.reborrow().init_identity_keys(identity_keys.len() as u32);
|
||||
for (i, ik) in identity_keys.iter().enumerate() {
|
||||
list.set(i as u32, ik);
|
||||
}
|
||||
let mut auth = p.reborrow().init_auth();
|
||||
set_auth(&mut auth)?;
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.promise
|
||||
.await
|
||||
.context("fetch_hybrid_keys RPC failed")?;
|
||||
|
||||
let keys = resp
|
||||
.get()
|
||||
.context("fetch_hybrid_keys: bad response")?
|
||||
.get_keys()
|
||||
.context("fetch_hybrid_keys: missing keys")?;
|
||||
|
||||
let mut result = Vec::with_capacity(keys.len() as usize);
|
||||
for i in 0..keys.len() {
|
||||
let pk_bytes = keys
|
||||
.get(i)
|
||||
.context("fetch_hybrid_keys: key read failed")?
|
||||
.to_vec();
|
||||
if pk_bytes.is_empty() {
|
||||
result.push(None);
|
||||
} else {
|
||||
let pk = HybridPublicKey::from_bytes(&pk_bytes)
|
||||
.context("invalid hybrid public key")?;
|
||||
result.push(Some(pk));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
},
|
||||
DEFAULT_MAX_RETRIES,
|
||||
DEFAULT_BASE_DELAY_MS,
|
||||
anyhow_is_retriable,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Enqueue the same payload to multiple recipients in a single round-trip.
|
||||
/// Returns per-recipient sequence numbers.
|
||||
/// Retries on transient failures with exponential backoff.
|
||||
pub async fn batch_enqueue(
|
||||
client: &node_service::Client,
|
||||
recipient_keys: &[&[u8]],
|
||||
payload: &[u8],
|
||||
) -> anyhow::Result<Vec<u64>> {
|
||||
let client = client.clone();
|
||||
let recipient_keys: Vec<Vec<u8>> = recipient_keys.iter().map(|k| k.to_vec()).collect();
|
||||
let payload = payload.to_vec();
|
||||
retry_async(
|
||||
|| {
|
||||
let client = client.clone();
|
||||
let recipient_keys = recipient_keys.clone();
|
||||
let payload = payload.clone();
|
||||
async move {
|
||||
let mut req = client.batch_enqueue_request();
|
||||
{
|
||||
let mut p = req.get();
|
||||
let mut list = p.reborrow().init_recipient_keys(recipient_keys.len() as u32);
|
||||
for (i, rk) in recipient_keys.iter().enumerate() {
|
||||
list.set(i as u32, rk);
|
||||
}
|
||||
p.set_payload(&payload);
|
||||
p.set_channel_id(&[]);
|
||||
p.set_version(1);
|
||||
let mut auth = p.reborrow().init_auth();
|
||||
set_auth(&mut auth)?;
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.promise
|
||||
.await
|
||||
.context("batch_enqueue RPC failed")?;
|
||||
|
||||
let seqs = resp
|
||||
.get()
|
||||
.context("batch_enqueue: bad response")?
|
||||
.get_seqs()
|
||||
.context("batch_enqueue: missing seqs")?;
|
||||
|
||||
let mut result = Vec::with_capacity(seqs.len() as usize);
|
||||
for i in 0..seqs.len() {
|
||||
result.push(seqs.get(i));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
},
|
||||
DEFAULT_MAX_RETRIES,
|
||||
DEFAULT_BASE_DELAY_MS,
|
||||
anyhow_is_retriable,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Return the current Unix timestamp in milliseconds.
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use argon2::Argon2;
|
||||
use argon2::{Algorithm, Argon2, Params, Version};
|
||||
use chacha20poly1305::{
|
||||
aead::{Aead, KeyInit},
|
||||
ChaCha20Poly1305, Key, Nonce,
|
||||
@@ -62,10 +62,21 @@ impl StoredState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Derive a 32-byte key from a password and salt using Argon2id.
|
||||
/// Argon2id parameters for client state key derivation (auditable; matches argon2 crate defaults).
|
||||
/// - Memory: 19 MiB (m_cost = 19*1024 KiB)
|
||||
/// - Time: 2 iterations
|
||||
/// - Parallelism: 1 lane
|
||||
const ARGON2_STATE_M_COST: u32 = 19 * 1024;
|
||||
const ARGON2_STATE_T_COST: u32 = 2;
|
||||
const ARGON2_STATE_P_COST: u32 = 1;
|
||||
|
||||
/// Derive a 32-byte key from a password and salt using Argon2id with explicit parameters.
|
||||
fn derive_state_key(password: &str, salt: &[u8]) -> anyhow::Result<[u8; 32]> {
|
||||
let params = Params::new(ARGON2_STATE_M_COST, ARGON2_STATE_T_COST, ARGON2_STATE_P_COST, Some(32))
|
||||
.map_err(|e| anyhow::anyhow!("argon2 params: {e}"))?;
|
||||
let argon2 = Argon2::new(Algorithm::Argon2id, Version::default(), params);
|
||||
let mut key = [0u8; 32];
|
||||
Argon2::default()
|
||||
argon2
|
||||
.hash_password_into(password.as_bytes(), salt, &mut key)
|
||||
.map_err(|e| anyhow::anyhow!("argon2 key derivation failed: {e}"))?;
|
||||
Ok(key)
|
||||
@@ -79,8 +90,8 @@ pub fn encrypt_state(password: &str, plaintext: &[u8]) -> anyhow::Result<Vec<u8>
|
||||
let mut nonce_bytes = [0u8; STATE_NONCE_LEN];
|
||||
rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
|
||||
|
||||
let key = derive_state_key(password, &salt)?;
|
||||
let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
|
||||
let key = zeroize::Zeroizing::new(derive_state_key(password, &salt)?);
|
||||
let cipher = ChaCha20Poly1305::new(Key::from_slice(&*key));
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
let ciphertext = cipher
|
||||
@@ -108,8 +119,8 @@ pub fn decrypt_state(password: &str, data: &[u8]) -> anyhow::Result<Vec<u8>> {
|
||||
let nonce_bytes = &data[4 + STATE_SALT_LEN..header_len];
|
||||
let ciphertext = &data[header_len..];
|
||||
|
||||
let key = derive_state_key(password, salt)?;
|
||||
let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
|
||||
let key = zeroize::Zeroizing::new(derive_state_key(password, salt)?);
|
||||
let cipher = ChaCha20Poly1305::new(Key::from_slice(&*key));
|
||||
let nonce = Nonce::from_slice(nonce_bytes);
|
||||
|
||||
let plaintext = cipher
|
||||
@@ -179,7 +190,9 @@ pub fn write_state(path: &Path, state: &StoredState, password: Option<&str>) ->
|
||||
plaintext
|
||||
};
|
||||
|
||||
std::fs::write(path, bytes).with_context(|| format!("write state {path:?}"))?;
|
||||
let tmp = path.with_extension("tmp");
|
||||
std::fs::write(&tmp, bytes).with_context(|| format!("write state temp {tmp:?}"))?;
|
||||
std::fs::rename(&tmp, path).with_context(|| format!("rename state {tmp:?} -> {path:?}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -222,4 +235,57 @@ mod tests {
|
||||
let encrypted = encrypt_state("correct", plaintext).unwrap();
|
||||
assert!(decrypt_state("wrong", &encrypted).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_encrypt_decrypt_round_trip() {
|
||||
let state = StoredState {
|
||||
identity_seed: [42u8; 32],
|
||||
hybrid_key: None,
|
||||
group: None,
|
||||
member_keys: Vec::new(),
|
||||
};
|
||||
let password = "test-password";
|
||||
let plaintext = bincode::serialize(&state).unwrap();
|
||||
let encrypted = encrypt_state(password, &plaintext).unwrap();
|
||||
let decrypted = decrypt_state(password, &encrypted).unwrap();
|
||||
let recovered: StoredState = bincode::deserialize(&decrypted).unwrap();
|
||||
assert_eq!(recovered.identity_seed, state.identity_seed);
|
||||
assert!(recovered.hybrid_key.is_none());
|
||||
assert!(recovered.group.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_encrypt_decrypt_with_hybrid_key() {
|
||||
use zeroize::Zeroizing;
|
||||
let state = StoredState {
|
||||
identity_seed: [7u8; 32],
|
||||
hybrid_key: Some(HybridKeypairBytes {
|
||||
x25519_sk: Zeroizing::new([1u8; 32]),
|
||||
mlkem_dk: Zeroizing::new(vec![3u8; 2400]),
|
||||
mlkem_ek: vec![4u8; 1184],
|
||||
}),
|
||||
group: None,
|
||||
member_keys: Vec::new(),
|
||||
};
|
||||
let password = "another-password";
|
||||
let plaintext = bincode::serialize(&state).unwrap();
|
||||
let encrypted = encrypt_state(password, &plaintext).unwrap();
|
||||
let decrypted = decrypt_state(password, &encrypted).unwrap();
|
||||
let recovered: StoredState = bincode::deserialize(&decrypted).unwrap();
|
||||
assert_eq!(recovered.identity_seed, state.identity_seed);
|
||||
assert!(recovered.hybrid_key.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_wrong_password_fails() {
|
||||
let state = StoredState {
|
||||
identity_seed: [99u8; 32],
|
||||
hybrid_key: None,
|
||||
group: None,
|
||||
member_keys: Vec::new(),
|
||||
};
|
||||
let plaintext = bincode::serialize(&state).unwrap();
|
||||
let encrypted = encrypt_state("correct", &plaintext).unwrap();
|
||||
assert!(decrypt_state("wrong", &encrypted).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,22 +91,28 @@ pub fn serialize(msg_type: MessageType, payload: &[u8]) -> Vec<u8> {
|
||||
}
|
||||
|
||||
/// Serialize a Chat message (generates message_id internally; pass None to generate, or Some(id) when replying with a known id).
|
||||
pub fn serialize_chat(body: &[u8], message_id: Option<[u8; 16]>) -> Vec<u8> {
|
||||
pub fn serialize_chat(body: &[u8], message_id: Option<[u8; 16]>) -> Result<Vec<u8>, CoreError> {
|
||||
if body.len() > u16::MAX as usize {
|
||||
return Err(CoreError::AppMessage("chat body exceeds maximum length (65535 bytes)".into()));
|
||||
}
|
||||
let id = message_id.unwrap_or_else(generate_message_id);
|
||||
let mut payload = Vec::with_capacity(16 + 2 + body.len());
|
||||
payload.extend_from_slice(&id);
|
||||
payload.extend_from_slice(&(body.len() as u16).to_be_bytes());
|
||||
payload.extend_from_slice(body);
|
||||
serialize(MessageType::Chat, &payload)
|
||||
Ok(serialize(MessageType::Chat, &payload))
|
||||
}
|
||||
|
||||
/// Serialize a Reply message.
|
||||
pub fn serialize_reply(ref_msg_id: [u8; 16], body: &[u8]) -> Vec<u8> {
|
||||
pub fn serialize_reply(ref_msg_id: [u8; 16], body: &[u8]) -> Result<Vec<u8>, CoreError> {
|
||||
if body.len() > u16::MAX as usize {
|
||||
return Err(CoreError::AppMessage("reply body exceeds maximum length (65535 bytes)".into()));
|
||||
}
|
||||
let mut payload = Vec::with_capacity(16 + 2 + body.len());
|
||||
payload.extend_from_slice(&ref_msg_id);
|
||||
payload.extend_from_slice(&(body.len() as u16).to_be_bytes());
|
||||
payload.extend_from_slice(body);
|
||||
serialize(MessageType::Reply, &payload)
|
||||
Ok(serialize(MessageType::Reply, &payload))
|
||||
}
|
||||
|
||||
/// Serialize a Reaction message.
|
||||
@@ -220,7 +226,7 @@ mod tests {
|
||||
#[test]
|
||||
fn roundtrip_chat() {
|
||||
let body = b"hello";
|
||||
let encoded = serialize_chat(body, None);
|
||||
let encoded = serialize_chat(body, None).unwrap();
|
||||
let (t, msg) = parse(&encoded).unwrap();
|
||||
assert_eq!(t, MessageType::Chat);
|
||||
match &msg {
|
||||
@@ -233,7 +239,7 @@ mod tests {
|
||||
fn roundtrip_reply() {
|
||||
let ref_id = [1u8; 16];
|
||||
let body = b"reply text";
|
||||
let encoded = serialize_reply(ref_id, body);
|
||||
let encoded = serialize_reply(ref_id, body).unwrap();
|
||||
let (t, msg) = parse(&encoded).unwrap();
|
||||
assert_eq!(t, MessageType::Reply);
|
||||
match &msg {
|
||||
@@ -255,4 +261,67 @@ mod tests {
|
||||
_ => panic!("expected Typing"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_reaction() {
|
||||
let ref_id = [2u8; 16];
|
||||
let emoji = "\u{1f44d}".as_bytes();
|
||||
let encoded = serialize_reaction(ref_id, emoji).unwrap();
|
||||
let (t, msg) = parse(&encoded).unwrap();
|
||||
assert_eq!(t, MessageType::Reaction);
|
||||
match &msg {
|
||||
AppMessage::Reaction { ref_msg_id, emoji: e } => {
|
||||
assert_eq!(ref_msg_id, &ref_id);
|
||||
assert_eq!(e.as_slice(), emoji);
|
||||
}
|
||||
_ => panic!("expected Reaction"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn roundtrip_read_receipt() {
|
||||
let msg_id = [3u8; 16];
|
||||
let encoded = serialize_read_receipt(msg_id);
|
||||
let (t, msg) = parse(&encoded).unwrap();
|
||||
assert_eq!(t, MessageType::ReadReceipt);
|
||||
match &msg {
|
||||
AppMessage::ReadReceipt { msg_id: id } => assert_eq!(id, &msg_id),
|
||||
_ => panic!("expected ReadReceipt"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_empty_fails() {
|
||||
assert!(parse(&[]).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bad_version_fails() {
|
||||
assert!(parse(&[99, 0x01]).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_bad_type_fails() {
|
||||
assert!(parse(&[1, 0xFF]).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_body_too_long() {
|
||||
let body = vec![0u8; 65536]; // exceeds u16::MAX
|
||||
assert!(serialize_chat(&body, None).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reaction_emoji_too_long() {
|
||||
let emoji = vec![0u8; 256];
|
||||
assert!(serialize_reaction([0; 16], &emoji).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_truncated_chat_payload() {
|
||||
// Version + type + only 10 bytes of payload (needs 18 minimum for chat)
|
||||
let mut data = vec![1, 0x01];
|
||||
data.extend_from_slice(&[0u8; 10]);
|
||||
assert!(parse(&data).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,9 +161,15 @@ impl OpenMlsCrypto for HybridCrypto {
|
||||
if Self::is_hybrid_public_key(pk_r) {
|
||||
let recipient_pk = match HybridPublicKey::from_bytes(pk_r) {
|
||||
Ok(pk) => pk,
|
||||
Err(_) => return self.rust_crypto.hpke_seal(config, pk_r, info, aad, ptxt),
|
||||
// Key parsed as hybrid length but failed to deserialize — this is
|
||||
// a real error, not a reason to silently fall back to classical HPKE.
|
||||
Err(_) => return HpkeCiphertext {
|
||||
kem_output: Vec::new().into(),
|
||||
ciphertext: Vec::new().into(),
|
||||
},
|
||||
};
|
||||
match hybrid_encrypt(&recipient_pk, ptxt) {
|
||||
// Pass HPKE info and aad through for proper context binding (RFC 9180).
|
||||
match hybrid_encrypt(&recipient_pk, ptxt, info, aad) {
|
||||
Ok(envelope) => {
|
||||
let kem_output = envelope[..HYBRID_KEM_OUTPUT_LEN].to_vec();
|
||||
let ciphertext = envelope[HYBRID_KEM_OUTPUT_LEN..].to_vec();
|
||||
@@ -172,7 +178,13 @@ impl OpenMlsCrypto for HybridCrypto {
|
||||
ciphertext: ciphertext.into(),
|
||||
}
|
||||
}
|
||||
Err(_) => self.rust_crypto.hpke_seal(config, pk_r, info, aad, ptxt),
|
||||
// Encryption failed with a hybrid key — return empty ciphertext
|
||||
// rather than silently falling back to classical HPKE with an
|
||||
// incompatible key.
|
||||
Err(_) => HpkeCiphertext {
|
||||
kem_output: Vec::new().into(),
|
||||
ciphertext: Vec::new().into(),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
self.rust_crypto.hpke_seal(config, pk_r, info, aad, ptxt)
|
||||
@@ -188,17 +200,17 @@ impl OpenMlsCrypto for HybridCrypto {
|
||||
aad: &[u8],
|
||||
) -> Result<Vec<u8>, CryptoError> {
|
||||
if Self::is_hybrid_private_key(sk_r) {
|
||||
let keypair = match HybridKeypair::from_private_bytes(sk_r) {
|
||||
Ok(kp) => kp,
|
||||
Err(_) => return self.rust_crypto.hpke_open(config, input, sk_r, info, aad),
|
||||
};
|
||||
let keypair = HybridKeypair::from_private_bytes(sk_r)
|
||||
.map_err(|_| CryptoError::HpkeDecryptionError)?;
|
||||
let envelope: Vec<u8> = input
|
||||
.kem_output.as_slice()
|
||||
.iter()
|
||||
.chain(input.ciphertext.as_slice())
|
||||
.copied()
|
||||
.collect();
|
||||
hybrid_decrypt(&keypair, &envelope).map_err(|_| CryptoError::HpkeDecryptionError)
|
||||
// Pass HPKE info and aad through for proper context binding (RFC 9180).
|
||||
hybrid_decrypt(&keypair, &envelope, info, aad)
|
||||
.map_err(|_| CryptoError::HpkeDecryptionError)
|
||||
} else {
|
||||
self.rust_crypto.hpke_open(config, input, sk_r, info, aad)
|
||||
}
|
||||
|
||||
@@ -43,6 +43,9 @@ const HYBRID_VERSION: u8 = 0x01;
|
||||
/// HKDF info string for domain separation.
|
||||
const HKDF_INFO: &[u8] = b"quicnprotochat-hybrid-v1";
|
||||
|
||||
/// HKDF salt for domain separation (defence-in-depth; IKM already has 64 bytes of entropy).
|
||||
const HKDF_SALT: &[u8] = b"quicnprotochat-hybrid-v1-salt";
|
||||
|
||||
/// ML-KEM-768 ciphertext size in bytes.
|
||||
const MLKEM_CT_LEN: usize = 1088;
|
||||
|
||||
@@ -164,7 +167,8 @@ impl HybridKeypair {
|
||||
if bytes.len() != HYBRID_PRIVATE_KEY_LEN {
|
||||
return Err(HybridKemError::TooShort(bytes.len()));
|
||||
}
|
||||
let x25519_sk = StaticSecret::from(<[u8; 32]>::try_from(&bytes[0..32]).unwrap());
|
||||
let x25519_sk = StaticSecret::from(<[u8; 32]>::try_from(&bytes[0..32])
|
||||
.expect("slice is exactly 32 bytes (guaranteed by HYBRID_PRIVATE_KEY_LEN check)"));
|
||||
let x25519_pk = X25519Public::from(&x25519_sk);
|
||||
|
||||
let mlkem_dk_arr = Array::try_from(&bytes[32..32 + MLKEM_DK_LEN])
|
||||
@@ -247,10 +251,15 @@ impl HybridPublicKey {
|
||||
|
||||
/// Encrypt `plaintext` to `recipient_pk` using X25519 + ML-KEM-768 hybrid KEM.
|
||||
///
|
||||
/// `info` is optional HPKE context info incorporated into key derivation.
|
||||
/// `aad` is optional additional authenticated data bound to the AEAD ciphertext.
|
||||
///
|
||||
/// Returns the complete hybrid envelope as a byte vector.
|
||||
pub fn hybrid_encrypt(
|
||||
recipient_pk: &HybridPublicKey,
|
||||
plaintext: &[u8],
|
||||
info: &[u8],
|
||||
aad: &[u8],
|
||||
) -> Result<Vec<u8>, HybridKemError> {
|
||||
// 1. Ephemeral X25519 DH
|
||||
let eph_secret = EphemeralSecret::random_from_rng(OsRng);
|
||||
@@ -266,18 +275,19 @@ pub fn hybrid_encrypt(
|
||||
.encapsulate(&mut OsRng)
|
||||
.map_err(|_| HybridKemError::EncryptionFailed)?;
|
||||
|
||||
// 3. Derive AEAD key from combined shared secrets
|
||||
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice());
|
||||
// 3. Derive AEAD key from combined shared secrets (with caller info for context binding)
|
||||
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), info);
|
||||
|
||||
// Generate a random 12-byte nonce (not derived from HKDF).
|
||||
let mut nonce_bytes = [0u8; 12];
|
||||
OsRng.fill_bytes(&mut nonce_bytes);
|
||||
let aead_nonce = *Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
// 4. AEAD encrypt
|
||||
// 4. AEAD encrypt with caller-supplied AAD
|
||||
let cipher = ChaCha20Poly1305::new(&aead_key);
|
||||
let aead_payload = chacha20poly1305::aead::Payload { msg: plaintext, aad };
|
||||
let ct = cipher
|
||||
.encrypt(&aead_nonce, plaintext)
|
||||
.encrypt(&aead_nonce, aead_payload)
|
||||
.map_err(|_| HybridKemError::EncryptionFailed)?;
|
||||
|
||||
// 5. Assemble envelope: version || x25519_eph_pk || mlkem_ct || nonce || aead_ct
|
||||
@@ -292,7 +302,14 @@ pub fn hybrid_encrypt(
|
||||
}
|
||||
|
||||
/// Decrypt a hybrid envelope using the recipient's private key.
|
||||
pub fn hybrid_decrypt(keypair: &HybridKeypair, envelope: &[u8]) -> Result<Vec<u8>, HybridKemError> {
|
||||
///
|
||||
/// `info` and `aad` must match what was passed to `hybrid_encrypt`.
|
||||
pub fn hybrid_decrypt(
|
||||
keypair: &HybridKeypair,
|
||||
envelope: &[u8],
|
||||
info: &[u8],
|
||||
aad: &[u8],
|
||||
) -> Result<Vec<u8>, HybridKemError> {
|
||||
if envelope.len() < HEADER_LEN + 16 {
|
||||
// 16 = minimum AEAD tag
|
||||
return Err(HybridKemError::TooShort(envelope.len()));
|
||||
@@ -334,13 +351,14 @@ pub fn hybrid_decrypt(keypair: &HybridKeypair, envelope: &[u8]) -> Result<Vec<u8
|
||||
.decapsulate(&mlkem_ct_arr)
|
||||
.map_err(|_| HybridKemError::MlKemDecapsFailed)?;
|
||||
|
||||
// 3. Derive AEAD key
|
||||
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice());
|
||||
// 3. Derive AEAD key (with caller info for context binding)
|
||||
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), info);
|
||||
|
||||
// 4. Decrypt
|
||||
// 4. Decrypt with caller-supplied AAD
|
||||
let cipher = ChaCha20Poly1305::new(&aead_key);
|
||||
let aead_payload = chacha20poly1305::aead::Payload { msg: aead_ct, aad };
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, aead_ct)
|
||||
.decrypt(nonce, aead_payload)
|
||||
.map_err(|_| HybridKemError::DecryptionFailed)?;
|
||||
|
||||
Ok(plaintext)
|
||||
@@ -366,8 +384,9 @@ pub fn hybrid_encapsulate_only(
|
||||
.encapsulate(&mut OsRng)
|
||||
.map_err(|_| HybridKemError::EncryptionFailed)?;
|
||||
|
||||
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice());
|
||||
let shared_secret = aead_key.as_slice().try_into().unwrap();
|
||||
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), b"");
|
||||
let shared_secret: [u8; 32] = aead_key.as_slice().try_into()
|
||||
.expect("AEAD key is always exactly 32 bytes");
|
||||
|
||||
let mut kem_output = Vec::with_capacity(HYBRID_KEM_OUTPUT_LEN);
|
||||
kem_output.push(HYBRID_VERSION);
|
||||
@@ -390,7 +409,8 @@ pub fn hybrid_decapsulate_only(
|
||||
return Err(HybridKemError::UnsupportedVersion(kem_output[0]));
|
||||
}
|
||||
|
||||
let eph_pk_bytes: [u8; 32] = kem_output[1..33].try_into().unwrap();
|
||||
let eph_pk_bytes: [u8; 32] = kem_output[1..33].try_into()
|
||||
.expect("slice is exactly 32 bytes (guaranteed by HYBRID_KEM_OUTPUT_LEN check)");
|
||||
let eph_pk = X25519Public::from(eph_pk_bytes);
|
||||
let x25519_ss = keypair.x25519_sk.diffie_hellman(&eph_pk);
|
||||
|
||||
@@ -401,8 +421,9 @@ pub fn hybrid_decapsulate_only(
|
||||
.decapsulate(&mlkem_ct_arr)
|
||||
.map_err(|_| HybridKemError::MlKemDecapsFailed)?;
|
||||
|
||||
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice());
|
||||
Ok(aead_key.as_slice().try_into().unwrap())
|
||||
let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), b"");
|
||||
Ok(aead_key.as_slice().try_into()
|
||||
.expect("AEAD key is always exactly 32 bytes"))
|
||||
}
|
||||
|
||||
/// Export a secret from shared secret (MLS HPKE exporter compatibility).
|
||||
@@ -412,7 +433,7 @@ pub fn hybrid_export(
|
||||
exporter_context: &[u8],
|
||||
length: usize,
|
||||
) -> Vec<u8> {
|
||||
let hk = Hkdf::<Sha256>::new(None, shared_secret);
|
||||
let hk = Hkdf::<Sha256>::new(Some(HKDF_SALT), shared_secret);
|
||||
let mut out = vec![0u8; length];
|
||||
hk.expand(exporter_context, &mut out).expect("valid length");
|
||||
out
|
||||
@@ -420,18 +441,26 @@ pub fn hybrid_export(
|
||||
|
||||
/// Derive AEAD key from the combined X25519 + ML-KEM shared secrets.
|
||||
///
|
||||
/// `extra_info` is optional caller-supplied context (e.g. HPKE `info`) that is
|
||||
/// appended to the domain-separation label for additional binding.
|
||||
///
|
||||
/// The nonce is generated randomly per-encryption rather than derived from
|
||||
/// HKDF, preventing nonce reuse when the same shared secret is (accidentally)
|
||||
/// used more than once.
|
||||
fn derive_aead_key(x25519_ss: &[u8], mlkem_ss: &[u8]) -> Key {
|
||||
fn derive_aead_key(x25519_ss: &[u8], mlkem_ss: &[u8], extra_info: &[u8]) -> Key {
|
||||
let mut ikm = Zeroizing::new(vec![0u8; x25519_ss.len() + mlkem_ss.len()]);
|
||||
ikm[..x25519_ss.len()].copy_from_slice(x25519_ss);
|
||||
ikm[x25519_ss.len()..].copy_from_slice(mlkem_ss);
|
||||
|
||||
let hk = Hkdf::<Sha256>::new(None, &ikm);
|
||||
let hk = Hkdf::<Sha256>::new(Some(HKDF_SALT), &ikm);
|
||||
|
||||
// Combine domain-separation label with caller-supplied context.
|
||||
let mut info = Vec::with_capacity(HKDF_INFO.len() + extra_info.len());
|
||||
info.extend_from_slice(HKDF_INFO);
|
||||
info.extend_from_slice(extra_info);
|
||||
|
||||
let mut key_bytes = Zeroizing::new([0u8; 32]);
|
||||
hk.expand(HKDF_INFO, &mut *key_bytes)
|
||||
hk.expand(&info, &mut *key_bytes)
|
||||
.expect("32 bytes is valid HKDF-SHA256 output length");
|
||||
|
||||
*Key::from_slice(&*key_bytes)
|
||||
@@ -457,21 +486,39 @@ mod tests {
|
||||
let pk = kp.public_key();
|
||||
let plaintext = b"hello post-quantum world!";
|
||||
|
||||
let envelope = hybrid_encrypt(&pk, plaintext).unwrap();
|
||||
let recovered = hybrid_decrypt(&kp, &envelope).unwrap();
|
||||
let envelope = hybrid_encrypt(&pk, plaintext, b"", b"").unwrap();
|
||||
let recovered = hybrid_decrypt(&kp, &envelope, b"", b"").unwrap();
|
||||
|
||||
assert_eq!(recovered, plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt_with_info_aad() {
|
||||
let kp = HybridKeypair::generate();
|
||||
let pk = kp.public_key();
|
||||
let plaintext = b"context-bound payload";
|
||||
let info = b"mls epoch 42";
|
||||
let aad = b"group-id-abc";
|
||||
|
||||
let envelope = hybrid_encrypt(&pk, plaintext, info, aad).unwrap();
|
||||
let recovered = hybrid_decrypt(&kp, &envelope, info, aad).unwrap();
|
||||
assert_eq!(recovered, plaintext);
|
||||
|
||||
// Mismatched info must fail
|
||||
assert!(hybrid_decrypt(&kp, &envelope, b"wrong info", aad).is_err());
|
||||
// Mismatched aad must fail
|
||||
assert!(hybrid_decrypt(&kp, &envelope, info, b"wrong aad").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_key_decryption_fails() {
|
||||
let kp_sender_target = HybridKeypair::generate();
|
||||
let kp_wrong = HybridKeypair::generate();
|
||||
|
||||
let pk = kp_sender_target.public_key();
|
||||
let envelope = hybrid_encrypt(&pk, b"secret").unwrap();
|
||||
let envelope = hybrid_encrypt(&pk, b"secret", b"", b"").unwrap();
|
||||
|
||||
let result = hybrid_decrypt(&kp_wrong, &envelope);
|
||||
let result = hybrid_decrypt(&kp_wrong, &envelope, b"", b"");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -480,12 +527,12 @@ mod tests {
|
||||
let kp = HybridKeypair::generate();
|
||||
let pk = kp.public_key();
|
||||
|
||||
let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
|
||||
let mut envelope = hybrid_encrypt(&pk, b"payload", b"", b"").unwrap();
|
||||
let last = envelope.len() - 1;
|
||||
envelope[last] ^= 0x01;
|
||||
|
||||
assert!(matches!(
|
||||
hybrid_decrypt(&kp, &envelope),
|
||||
hybrid_decrypt(&kp, &envelope, b"", b""),
|
||||
Err(HybridKemError::DecryptionFailed)
|
||||
));
|
||||
}
|
||||
@@ -495,11 +542,11 @@ mod tests {
|
||||
let kp = HybridKeypair::generate();
|
||||
let pk = kp.public_key();
|
||||
|
||||
let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
|
||||
let mut envelope = hybrid_encrypt(&pk, b"payload", b"", b"").unwrap();
|
||||
// Flip a byte in the ML-KEM ciphertext region (starts at offset 33)
|
||||
envelope[40] ^= 0xFF;
|
||||
|
||||
assert!(hybrid_decrypt(&kp, &envelope).is_err());
|
||||
assert!(hybrid_decrypt(&kp, &envelope, b"", b"").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -507,11 +554,11 @@ mod tests {
|
||||
let kp = HybridKeypair::generate();
|
||||
let pk = kp.public_key();
|
||||
|
||||
let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
|
||||
let mut envelope = hybrid_encrypt(&pk, b"payload", b"", b"").unwrap();
|
||||
// Flip a byte in the X25519 ephemeral pk region (offset 1..33)
|
||||
envelope[5] ^= 0xFF;
|
||||
|
||||
assert!(hybrid_decrypt(&kp, &envelope).is_err());
|
||||
assert!(hybrid_decrypt(&kp, &envelope, b"", b"").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -519,11 +566,11 @@ mod tests {
|
||||
let kp = HybridKeypair::generate();
|
||||
let pk = kp.public_key();
|
||||
|
||||
let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
|
||||
let mut envelope = hybrid_encrypt(&pk, b"payload", b"", b"").unwrap();
|
||||
envelope[0] = 0xFF;
|
||||
|
||||
assert!(matches!(
|
||||
hybrid_decrypt(&kp, &envelope),
|
||||
hybrid_decrypt(&kp, &envelope, b"", b""),
|
||||
Err(HybridKemError::UnsupportedVersion(0xFF))
|
||||
));
|
||||
}
|
||||
@@ -532,7 +579,7 @@ mod tests {
|
||||
fn envelope_too_short_rejected() {
|
||||
let kp = HybridKeypair::generate();
|
||||
assert!(matches!(
|
||||
hybrid_decrypt(&kp, &[0x01; 10]),
|
||||
hybrid_decrypt(&kp, &[0x01; 10], b"", b""),
|
||||
Err(HybridKemError::TooShort(10))
|
||||
));
|
||||
}
|
||||
@@ -548,8 +595,8 @@ mod tests {
|
||||
|
||||
// Verify restored keypair can decrypt
|
||||
let pk = kp.public_key();
|
||||
let ct = hybrid_encrypt(&pk, b"test").unwrap();
|
||||
let pt = hybrid_decrypt(&restored, &ct).unwrap();
|
||||
let ct = hybrid_encrypt(&pk, b"test", b"", b"").unwrap();
|
||||
let pt = hybrid_decrypt(&restored, &ct, b"", b"").unwrap();
|
||||
assert_eq!(pt, b"test");
|
||||
}
|
||||
|
||||
@@ -570,8 +617,8 @@ mod tests {
|
||||
let pk = kp.public_key();
|
||||
let plaintext = vec![0xAB; 50_000]; // 50 KB
|
||||
|
||||
let envelope = hybrid_encrypt(&pk, &plaintext).unwrap();
|
||||
let recovered = hybrid_decrypt(&kp, &envelope).unwrap();
|
||||
let envelope = hybrid_encrypt(&pk, &plaintext, b"", b"").unwrap();
|
||||
let recovered = hybrid_decrypt(&kp, &envelope, b"", b"").unwrap();
|
||||
|
||||
assert_eq!(recovered, plaintext);
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ impl DiskKeyStore {
|
||||
let Some(path) = &self.path else {
|
||||
return Ok(());
|
||||
};
|
||||
let values = self.values.read().unwrap();
|
||||
let values = self.values.read().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
|
||||
let bytes = bincode::serialize(&*values).map_err(|_| DiskKeyStoreError::Serialization)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?;
|
||||
@@ -82,21 +82,24 @@ impl OpenMlsKeyStore for DiskKeyStore {
|
||||
|
||||
fn store<V: MlsEntity>(&self, k: &[u8], v: &V) -> Result<(), Self::Error> {
|
||||
let value = serde_json::to_vec(v).map_err(|_| DiskKeyStoreError::Serialization)?;
|
||||
let mut values = self.values.write().unwrap();
|
||||
let mut values = self.values.write().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
|
||||
values.insert(k.to_vec(), value);
|
||||
drop(values);
|
||||
self.flush()
|
||||
}
|
||||
|
||||
fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V> {
|
||||
let values = self.values.read().unwrap();
|
||||
let values = match self.values.read() {
|
||||
Ok(v) => v,
|
||||
Err(_) => return None,
|
||||
};
|
||||
values
|
||||
.get(k)
|
||||
.and_then(|bytes| serde_json::from_slice(bytes).ok())
|
||||
}
|
||||
|
||||
fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
|
||||
let mut values = self.values.write().unwrap();
|
||||
let mut values = self.values.write().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
|
||||
values.remove(k);
|
||||
drop(values);
|
||||
self.flush()
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
mod app_message;
|
||||
mod error;
|
||||
mod group;
|
||||
pub mod hybrid_crypto;
|
||||
pub mod hybrid_kem;
|
||||
mod hybrid_crypto;
|
||||
mod hybrid_kem;
|
||||
mod identity;
|
||||
mod keypackage;
|
||||
mod keystore;
|
||||
|
||||
2244
crates/quicnprotochat-gui/gen/schemas/macOS-schema.json
Normal file
2244
crates/quicnprotochat-gui/gen/schemas/macOS-schema.json
Normal file
File diff suppressed because it is too large
Load Diff
13
crates/quicnprotochat-server/migrations/003_channels.sql
Normal file
13
crates/quicnprotochat-server/migrations/003_channels.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
-- Migration 003: 1:1 DM channels.
|
||||
-- channel_id is 16 bytes (UUID); member_a and member_b are identity keys in sorted order.
|
||||
-- Unique on (member_a, member_b) prevents duplicate channels between the same pair.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
channel_id BLOB PRIMARY KEY,
|
||||
member_a BLOB NOT NULL,
|
||||
member_b BLOB NOT NULL,
|
||||
UNIQUE(member_a, member_b)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_channels_members
|
||||
ON channels(member_a, member_b);
|
||||
@@ -17,6 +17,7 @@ pub const RATE_LIMIT_MAX_ENQUEUES: u32 = 100;
|
||||
pub struct AuthConfig {
|
||||
pub required_token: Option<Vec<u8>>,
|
||||
/// When true, a valid bearer token (no session) is accepted and the request's identity/key is used (dev/e2e only).
|
||||
/// CLI flag: --allow-insecure-auth / QUICNPROTOCHAT_ALLOW_INSECURE_AUTH.
|
||||
pub allow_insecure_identity_from_request: bool,
|
||||
}
|
||||
|
||||
@@ -59,10 +60,13 @@ pub struct AuthContext {
|
||||
}
|
||||
|
||||
pub fn current_timestamp() -> u64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
|
||||
Ok(d) => d.as_secs(),
|
||||
Err(_) => {
|
||||
tracing::warn!("system time is before UNIX_EPOCH; using 0 for session/rate-limit timestamps");
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_rate_limit(
|
||||
@@ -174,7 +178,7 @@ pub fn require_identity<'a>(auth_ctx: &'a AuthContext) -> Result<&'a [u8], capnp
|
||||
|
||||
pub fn require_identity_match(auth_ctx: &AuthContext, expected: &[u8]) -> Result<(), capnp::Error> {
|
||||
let ik = require_identity(auth_ctx)?;
|
||||
if ik != expected {
|
||||
if ik.len() != expected.len() || !bool::from(ik.ct_eq(expected)) {
|
||||
return Err(crate::error_codes::coded_error(
|
||||
E016_IDENTITY_MISMATCH,
|
||||
"access token is bound to a different identity",
|
||||
|
||||
@@ -24,6 +24,8 @@ pub const E018_USER_EXISTS: &str = "E018";
|
||||
pub const E019_NO_PENDING_LOGIN: &str = "E019";
|
||||
pub const E020_BAD_PARAMS: &str = "E020";
|
||||
pub const E021_CIPHERSUITE_NOT_ALLOWED: &str = "E021";
|
||||
pub const E022_CHANNEL_ACCESS_DENIED: &str = "E022";
|
||||
pub const E023_CHANNEL_NOT_FOUND: &str = "E023";
|
||||
|
||||
/// Build a `capnp::Error::failed()` with the structured code prefix.
|
||||
pub fn coded_error(code: &str, msg: impl std::fmt::Display) -> capnp::Error {
|
||||
|
||||
@@ -162,9 +162,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
.parse()
|
||||
.context("--listen must be host:port")?;
|
||||
|
||||
let server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
|
||||
let mut server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
|
||||
.context("failed to build TLS/QUIC server config")?;
|
||||
|
||||
// Harden QUIC transport: idle timeout, limit stream concurrency.
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(std::time::Duration::from_secs(300).try_into().unwrap()));
|
||||
transport.max_concurrent_bidi_streams(1u32.into());
|
||||
transport.max_concurrent_uni_streams(0u32.into());
|
||||
server_config.transport_config(Arc::new(transport));
|
||||
|
||||
// Shared storage — persisted to disk for restart safety.
|
||||
let store: Arc<dyn Store> = match effective.store_backend.as_str() {
|
||||
"sql" => {
|
||||
@@ -223,6 +230,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
Arc::clone(&pending_logins),
|
||||
Arc::clone(&rate_limits),
|
||||
Arc::clone(&store),
|
||||
Arc::clone(&waiters),
|
||||
);
|
||||
|
||||
let endpoint = Endpoint::server(server_config, listen)?;
|
||||
|
||||
@@ -19,6 +19,16 @@ fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
/// Parse username from Cap'n Proto reader; requires valid UTF-8.
|
||||
fn parse_username_param(
|
||||
result: Result<capnp::text::Reader<'_>, capnp::Error>,
|
||||
) -> Result<String, capnp::Error> {
|
||||
let reader = result.map_err(|e| coded_error(E020_BAD_PARAMS, e))?;
|
||||
reader
|
||||
.to_string()
|
||||
.map_err(|_| coded_error(E020_BAD_PARAMS, "username must be valid UTF-8"))
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn handle_opaque_login_start(
|
||||
&mut self,
|
||||
@@ -29,9 +39,9 @@ impl NodeServiceImpl {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match p.get_username() {
|
||||
Ok(v) => v.to_string().unwrap_or_default().to_string(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let request_bytes = match p.get_request() {
|
||||
Ok(v) => v.to_vec(),
|
||||
@@ -42,6 +52,14 @@ impl NodeServiceImpl {
|
||||
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
|
||||
}
|
||||
|
||||
// Check for existing recent pending login before expensive OPAQUE/storage work (DoS mitigation).
|
||||
if let Some(existing) = self.pending_logins.get(&username) {
|
||||
let age = current_timestamp().saturating_sub(existing.created_at);
|
||||
if age < 60 {
|
||||
return Promise::err(coded_error(E010_OPAQUE_ERROR, "login already in progress"));
|
||||
}
|
||||
}
|
||||
|
||||
let credential_request = match CredentialRequest::<OpaqueSuite>::deserialize(&request_bytes) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
@@ -62,9 +80,7 @@ impl NodeServiceImpl {
|
||||
))
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E010_OPAQUE_ERROR, "user not registered"))
|
||||
}
|
||||
Ok(None) => None,
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
|
||||
@@ -111,9 +127,9 @@ impl NodeServiceImpl {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match p.get_username() {
|
||||
Ok(v) => v.to_string().unwrap_or_default().to_string(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let request_bytes = match p.get_request() {
|
||||
Ok(v) => v.to_vec(),
|
||||
@@ -171,9 +187,9 @@ impl NodeServiceImpl {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match p.get_username() {
|
||||
Ok(v) => v.to_string().unwrap_or_default().to_string(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let finalization_bytes = match p.get_finalization() {
|
||||
Ok(v) => v.to_vec(),
|
||||
@@ -278,9 +294,9 @@ impl NodeServiceImpl {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match p.get_username() {
|
||||
Ok(v) => v.to_string().unwrap_or_default().to_string(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let upload_bytes = match p.get_upload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
@@ -326,12 +342,18 @@ impl NodeServiceImpl {
|
||||
let password_file = ServerRegistration::<OpaqueSuite>::finish(upload);
|
||||
let record_bytes = password_file.serialize().to_vec();
|
||||
|
||||
if let Err(e) = self
|
||||
match self
|
||||
.store
|
||||
.store_user_record(&username, record_bytes)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
return Promise::err(e);
|
||||
Ok(()) => {}
|
||||
Err(crate::storage::StorageError::DuplicateUser(_)) => {
|
||||
return Promise::err(coded_error(
|
||||
E018_USER_EXISTS,
|
||||
format!("user '{}' already registered", username),
|
||||
))
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
}
|
||||
|
||||
if !identity_key.is_empty() {
|
||||
|
||||
62
crates/quicnprotochat-server/src/node_service/channel_ops.rs
Normal file
62
crates/quicnprotochat-server/src/node_service/channel_ops.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
//! createChannel RPC: create or look up a 1:1 DM channel.
|
||||
|
||||
use capnp::capability::Promise;
|
||||
use quicnprotochat_proto::node_capnp::node_service;
|
||||
|
||||
use crate::auth::{coded_error, require_identity, validate_auth_context};
|
||||
use crate::error_codes::*;
|
||||
use crate::storage::StorageError;
|
||||
|
||||
use super::NodeServiceImpl;
|
||||
|
||||
fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn handle_create_channel(
|
||||
&mut self,
|
||||
params: node_service::CreateChannelParams,
|
||||
mut results: node_service::CreateChannelResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let peer_key = match p.get_peer_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
let identity = match require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if peer_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("peerKey must be exactly 32 bytes, got {}", peer_key.len()),
|
||||
));
|
||||
}
|
||||
|
||||
if identity == peer_key {
|
||||
return Promise::err(coded_error(
|
||||
E020_BAD_PARAMS,
|
||||
"peerKey must not equal caller identity",
|
||||
));
|
||||
}
|
||||
|
||||
let channel_id = match self.store.create_channel(&identity, &peer_key) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
|
||||
results.get().set_channel_id(&channel_id);
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
@@ -77,10 +77,10 @@ impl NodeServiceImpl {
|
||||
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
|
||||
));
|
||||
}
|
||||
if version != CURRENT_WIRE_VERSION {
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -102,6 +102,31 @@ impl NodeServiceImpl {
|
||||
}
|
||||
}
|
||||
|
||||
// DM channel authz: channel_id.len() == 16 means a created channel; caller and recipient must be the two members.
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
let recipient_other = (recipient_key == *a && caller == b.as_slice())
|
||||
|| (recipient_key == *b && caller == a.as_slice());
|
||||
if !caller_in || !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller or recipient not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
match self.store.queue_depth(&recipient_key, &channel_id) {
|
||||
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
|
||||
return Promise::err(coded_error(
|
||||
@@ -183,10 +208,10 @@ impl NodeServiceImpl {
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version != CURRENT_WIRE_VERSION {
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -203,6 +228,30 @@ impl NodeServiceImpl {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|
||||
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
|
||||
if !caller_in || !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller or recipient not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let messages = if limit > 0 {
|
||||
match self
|
||||
.store
|
||||
@@ -269,10 +318,10 @@ impl NodeServiceImpl {
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version != CURRENT_WIRE_VERSION {
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -284,6 +333,30 @@ impl NodeServiceImpl {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|
||||
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
|
||||
if !caller_in || !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller or recipient not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let store = Arc::clone(&self.store);
|
||||
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = self.waiters.clone();
|
||||
|
||||
@@ -315,4 +388,232 @@ impl NodeServiceImpl {
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle_peek(
|
||||
&mut self,
|
||||
params: node_service::PeekParams,
|
||||
mut results: node_service::PeekResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let limit = p.get_limit();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&recipient_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let messages = match self
|
||||
.store
|
||||
.peek(&recipient_key, &channel_id, limit as usize)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(m) => m,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
||||
count = messages.len(),
|
||||
"audit: peek"
|
||||
);
|
||||
|
||||
let mut list = results.get().init_payloads(messages.len() as u32);
|
||||
for (i, (seq, data)) in messages.iter().enumerate() {
|
||||
let mut entry = list.reborrow().get(i as u32);
|
||||
entry.set_seq(*seq);
|
||||
entry.set_data(data);
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_ack(
|
||||
&mut self,
|
||||
params: node_service::AckParams,
|
||||
_results: node_service::AckResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let seq_up_to = p.get_seq_up_to();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&recipient_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
match self
|
||||
.store
|
||||
.ack(&recipient_key, &channel_id, seq_up_to)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(removed) => {
|
||||
tracing::info!(
|
||||
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
||||
seq_up_to = seq_up_to,
|
||||
removed = removed,
|
||||
"audit: ack"
|
||||
);
|
||||
}
|
||||
Err(e) => return Promise::err(e),
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_batch_enqueue(
|
||||
&mut self,
|
||||
params: node_service::BatchEnqueueParams,
|
||||
mut results: node_service::BatchEnqueueResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_keys = match p.get_recipient_keys() {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let payload = match p.get_payload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if payload.is_empty() {
|
||||
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
|
||||
}
|
||||
if payload.len() > MAX_PAYLOAD_BYTES {
|
||||
return Promise::err(coded_error(
|
||||
E006_PAYLOAD_TOO_LARGE,
|
||||
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
|
||||
tracing::warn!("rate_limit_hit");
|
||||
metrics::record_rate_limit_hit_total();
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let mut seqs = Vec::with_capacity(recipient_keys.len() as usize);
|
||||
for i in 0..recipient_keys.len() {
|
||||
let rk = match recipient_keys.get(i) {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
if rk.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey[{}] must be exactly 32 bytes, got {}", i, rk.len()),
|
||||
));
|
||||
}
|
||||
|
||||
match self.store.queue_depth(&rk, &channel_id) {
|
||||
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
|
||||
return Promise::err(coded_error(
|
||||
E015_QUEUE_FULL,
|
||||
format!("queue depth {} exceeds limit {}", depth, MAX_QUEUE_DEPTH),
|
||||
));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let seq = match self
|
||||
.store
|
||||
.enqueue(&rk, &channel_id, payload.clone())
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(seq) => seq,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
seqs.push(seq);
|
||||
|
||||
metrics::record_enqueue_total();
|
||||
metrics::record_enqueue_bytes(payload.len() as u64);
|
||||
|
||||
crate::auth::waiter(&self.waiters, &rk).notify_waiters();
|
||||
}
|
||||
|
||||
let mut list = results.get().init_seqs(seqs.len() as u32);
|
||||
for (i, seq) in seqs.iter().enumerate() {
|
||||
list.set(i as u32, *seq);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
recipient_count = recipient_keys.len(),
|
||||
payload_len = payload.len(),
|
||||
"audit: batch_enqueue"
|
||||
);
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,4 +256,47 @@ impl NodeServiceImpl {
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_fetch_hybrid_keys(
|
||||
&mut self,
|
||||
params: node_service::FetchHybridKeysParams,
|
||||
mut results: node_service::FetchHybridKeysResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_keys = match p.get_identity_keys() {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
|
||||
if let Err(e) = validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let count = identity_keys.len() as usize;
|
||||
let mut key_data: Vec<Vec<u8>> = Vec::with_capacity(count);
|
||||
for i in 0..identity_keys.len() {
|
||||
let ik = match identity_keys.get(i) {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let pk = match self.store.fetch_hybrid_key(&ik).map_err(storage_err) {
|
||||
Ok(Some(pk)) => pk,
|
||||
Ok(None) => vec![],
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
key_data.push(pk);
|
||||
}
|
||||
|
||||
let mut list = results.get().init_keys(key_data.len() as u32);
|
||||
for (i, pk) in key_data.iter().enumerate() {
|
||||
list.set(i as u32, pk);
|
||||
}
|
||||
|
||||
tracing::debug!(count = count, "batch hybrid key fetch");
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,11 @@ use crate::auth::{
|
||||
};
|
||||
use crate::storage::Store;
|
||||
|
||||
/// Cap'n Proto traversal limit (words). 4 Mi words = 32 MiB; bounds DoS from deeply nested or large messages.
|
||||
const CAPNP_TRAVERSAL_LIMIT_WORDS: usize = 4 * 1024 * 1024;
|
||||
|
||||
mod auth_ops;
|
||||
mod channel_ops;
|
||||
mod delivery;
|
||||
mod key_ops;
|
||||
mod p2p_ops;
|
||||
@@ -132,6 +136,46 @@ impl node_service::Server for NodeServiceImpl {
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_resolve_endpoint(params, results)
|
||||
}
|
||||
|
||||
fn peek(
|
||||
&mut self,
|
||||
params: node_service::PeekParams,
|
||||
results: node_service::PeekResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_peek(params, results)
|
||||
}
|
||||
|
||||
fn ack(
|
||||
&mut self,
|
||||
params: node_service::AckParams,
|
||||
results: node_service::AckResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_ack(params, results)
|
||||
}
|
||||
|
||||
fn fetch_hybrid_keys(
|
||||
&mut self,
|
||||
params: node_service::FetchHybridKeysParams,
|
||||
results: node_service::FetchHybridKeysResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_fetch_hybrid_keys(params, results)
|
||||
}
|
||||
|
||||
fn batch_enqueue(
|
||||
&mut self,
|
||||
params: node_service::BatchEnqueueParams,
|
||||
results: node_service::BatchEnqueueResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_batch_enqueue(params, results)
|
||||
}
|
||||
|
||||
fn create_channel(
|
||||
&mut self,
|
||||
params: node_service::CreateChannelParams,
|
||||
results: node_service::CreateChannelResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_create_channel(params, results)
|
||||
}
|
||||
}
|
||||
|
||||
pub const CURRENT_WIRE_VERSION: u16 = 1;
|
||||
@@ -193,11 +237,13 @@ pub async fn handle_node_connection(
|
||||
.map_err(|e| anyhow::anyhow!("failed to accept bi stream: {e}"))?;
|
||||
let (reader, writer) = (recv.compat(), send.compat_write());
|
||||
|
||||
let mut reader_opts = capnp::message::ReaderOptions::new();
|
||||
reader_opts.traversal_limit_in_words(Some(CAPNP_TRAVERSAL_LIMIT_WORDS));
|
||||
let network = capnp_rpc::twoparty::VatNetwork::new(
|
||||
reader,
|
||||
writer,
|
||||
capnp_rpc::rpc_twoparty_capnp::Side::Server,
|
||||
Default::default(),
|
||||
reader_opts,
|
||||
);
|
||||
|
||||
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl::new(
|
||||
@@ -223,6 +269,7 @@ pub fn spawn_cleanup_task(
|
||||
pending_logins: Arc<DashMap<String, PendingLogin>>,
|
||||
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
|
||||
store: Arc<dyn Store>,
|
||||
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60));
|
||||
@@ -234,6 +281,29 @@ pub fn spawn_cleanup_task(
|
||||
pending_logins.retain(|_, pl| now - pl.created_at < PENDING_LOGIN_TTL_SECS);
|
||||
rate_limits.retain(|_, entry| now - entry.window_start < RATE_LIMIT_WINDOW_SECS * 2);
|
||||
|
||||
// Bound map sizes to prevent unbounded growth from malicious clients.
|
||||
const MAX_SESSIONS: usize = 100_000;
|
||||
const MAX_WAITERS: usize = 100_000;
|
||||
if sessions.len() > MAX_SESSIONS {
|
||||
let overflow = sessions.len() - MAX_SESSIONS;
|
||||
let mut entries: Vec<_> = sessions
|
||||
.iter()
|
||||
.map(|e| (e.key().clone(), e.expires_at))
|
||||
.collect();
|
||||
entries.sort_by_key(|(_, exp)| *exp);
|
||||
for (key, _) in entries.into_iter().take(overflow) {
|
||||
sessions.remove(&key);
|
||||
}
|
||||
}
|
||||
if waiters.len() > MAX_WAITERS {
|
||||
let overflow = waiters.len() - MAX_WAITERS;
|
||||
let keys: Vec<_> =
|
||||
waiters.iter().take(overflow).map(|e| e.key().clone()).collect();
|
||||
for key in keys {
|
||||
waiters.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
match store.gc_expired_messages(MESSAGE_TTL_SECS) {
|
||||
Ok(n) if n > 0 => {
|
||||
tracing::debug!(expired = n, "garbage collected expired messages")
|
||||
|
||||
@@ -14,6 +14,7 @@ fn storage_err(err: StorageError) -> capnp::Error {
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
/// Health check: unauthenticated by design for liveness probes and load balancers.
|
||||
pub fn handle_health(
|
||||
&mut self,
|
||||
_params: node_service::HealthParams,
|
||||
|
||||
@@ -3,17 +3,19 @@
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use rand::RngCore;
|
||||
use rusqlite::{params, Connection};
|
||||
|
||||
use crate::storage::{StorageError, Store};
|
||||
|
||||
/// Schema version after introducing the migration runner (existing DBs had 1).
|
||||
const SCHEMA_VERSION: i32 = 3;
|
||||
const SCHEMA_VERSION: i32 = 4;
|
||||
|
||||
/// Migrations: (migration_number, SQL). Files named NNN_name.sql, applied in order when N > user_version.
|
||||
const MIGRATIONS: &[(i32, &str)] = &[
|
||||
(1, include_str!("../migrations/001_initial.sql")),
|
||||
(3, include_str!("../migrations/002_add_seq.sql")),
|
||||
(4, include_str!("../migrations/003_channels.sql")),
|
||||
];
|
||||
|
||||
/// Runs pending migrations on an open connection: applies any migration whose number is greater
|
||||
@@ -305,10 +307,17 @@ impl Store for SqlStore {
|
||||
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO users (username, opaque_record) VALUES (?1, ?2)",
|
||||
"INSERT INTO users (username, opaque_record) VALUES (?1, ?2)",
|
||||
params![username, record],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
.map_err(|e| {
|
||||
if let rusqlite::Error::SqliteFailure(ref err, _) = &e {
|
||||
if err.code == rusqlite::ErrorCode::ConstraintViolation {
|
||||
return StorageError::DuplicateUser(username.to_string());
|
||||
}
|
||||
}
|
||||
StorageError::Db(e.to_string())
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -360,6 +369,57 @@ impl Store for SqlStore {
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn peek(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
|
||||
let sql = if limit == 0 {
|
||||
"SELECT seq, payload FROM deliveries
|
||||
WHERE recipient_key = ?1 AND channel_id = ?2
|
||||
ORDER BY seq ASC".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"SELECT seq, payload FROM deliveries
|
||||
WHERE recipient_key = ?1 AND channel_id = ?2
|
||||
ORDER BY seq ASC
|
||||
LIMIT {}",
|
||||
limit
|
||||
)
|
||||
};
|
||||
|
||||
let mut stmt = conn.prepare(&sql).map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
let rows: Vec<(i64, Vec<u8>)> = stmt
|
||||
.query_map(params![recipient_key, channel_id], |row| {
|
||||
Ok((row.get(0)?, row.get(1)?))
|
||||
})
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
Ok(rows.into_iter().map(|(seq, payload)| (seq as u64, payload)).collect())
|
||||
}
|
||||
|
||||
fn ack(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> Result<usize, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let deleted = conn
|
||||
.execute(
|
||||
"DELETE FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND seq <= ?3",
|
||||
params![recipient_key, channel_id, seq_up_to as i64],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(deleted)
|
||||
}
|
||||
|
||||
fn publish_endpoint(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
@@ -384,6 +444,45 @@ impl Store for SqlStore {
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
|
||||
let (a, b) = if member_a < member_b {
|
||||
(member_a.to_vec(), member_b.to_vec())
|
||||
} else {
|
||||
(member_b.to_vec(), member_a.to_vec())
|
||||
};
|
||||
let conn = self.lock_conn()?;
|
||||
let existing: Option<Vec<u8>> = conn
|
||||
.query_row(
|
||||
"SELECT channel_id FROM channels WHERE member_a = ?1 AND member_b = ?2",
|
||||
params![a, b],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
if let Some(id) = existing {
|
||||
return Ok(id);
|
||||
}
|
||||
let mut channel_id = [0u8; 16];
|
||||
rand::thread_rng().fill_bytes(&mut channel_id);
|
||||
conn.execute(
|
||||
"INSERT INTO channels (channel_id, member_a, member_b) VALUES (?1, ?2, ?3)",
|
||||
params![channel_id.as_slice(), a, b],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(channel_id.to_vec())
|
||||
}
|
||||
|
||||
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.query_row(
|
||||
"SELECT member_a, member_b FROM channels WHERE channel_id = ?1",
|
||||
params![channel_id],
|
||||
|row| Ok((row.get::<_, Vec<u8>>(0)?, row.get::<_, Vec<u8>>(1)?)),
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience extension for `rusqlite::OptionalExtension`.
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::{
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@@ -16,6 +17,9 @@ pub enum StorageError {
|
||||
Serde,
|
||||
#[error("database error: {0}")]
|
||||
Db(String),
|
||||
/// Unique constraint violation (e.g. user already exists).
|
||||
#[error("duplicate user: {0}")]
|
||||
DuplicateUser(String),
|
||||
}
|
||||
|
||||
fn lock<T>(m: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>, StorageError> {
|
||||
@@ -96,12 +100,36 @@ pub trait Store: Send + Sync {
|
||||
/// Retrieve identity key for a user (Fix 2).
|
||||
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Peek at queued messages without removing them (non-destructive).
|
||||
/// Returns `(seq, payload)` pairs ordered by seq.
|
||||
fn peek(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;
|
||||
|
||||
/// Acknowledge (remove) all messages with seq <= seq_up_to.
|
||||
fn ack(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> Result<usize, StorageError>;
|
||||
|
||||
/// Publish a P2P endpoint address for an identity key.
|
||||
fn publish_endpoint(&self, identity_key: &[u8], node_addr: Vec<u8>)
|
||||
-> Result<(), StorageError>;
|
||||
|
||||
/// Resolve a peer's P2P endpoint address.
|
||||
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Create a 1:1 channel between two members. Returns 16-byte channel_id (UUID).
|
||||
/// Members are stored in sorted order for deterministic lookup.
|
||||
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError>;
|
||||
|
||||
/// Get the two members of a channel by channel_id (16 bytes). Returns (member_a, member_b) in sorted order.
|
||||
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError>;
|
||||
}
|
||||
|
||||
// ── ChannelKey ───────────────────────────────────────────────────────────────
|
||||
@@ -154,8 +182,10 @@ pub struct FileBackedStore {
|
||||
setup_path: PathBuf,
|
||||
users_path: PathBuf,
|
||||
identity_keys_path: PathBuf,
|
||||
channels_path: PathBuf,
|
||||
key_packages: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
|
||||
deliveries: Mutex<QueueMapV3>,
|
||||
channels: Mutex<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>>,
|
||||
hybrid_keys: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
|
||||
users: Mutex<HashMap<String, Vec<u8>>>,
|
||||
identity_keys: Mutex<HashMap<String, Vec<u8>>>,
|
||||
@@ -174,12 +204,14 @@ impl FileBackedStore {
|
||||
let setup_path = dir.join("server_setup.bin");
|
||||
let users_path = dir.join("users.bin");
|
||||
let identity_keys_path = dir.join("identity_keys.bin");
|
||||
let channels_path = dir.join("channels.bin");
|
||||
|
||||
let key_packages = Mutex::new(Self::load_kp_map(&kp_path)?);
|
||||
let deliveries = Mutex::new(Self::load_delivery_map_v3(&ds_path)?);
|
||||
let hybrid_keys = Mutex::new(Self::load_hybrid_keys(&hk_path)?);
|
||||
let users = Mutex::new(Self::load_users(&users_path)?);
|
||||
let identity_keys = Mutex::new(Self::load_map_string_bytes(&identity_keys_path)?);
|
||||
let channels = Mutex::new(Self::load_channels(&channels_path)?);
|
||||
|
||||
Ok(Self {
|
||||
kp_path,
|
||||
@@ -188,8 +220,10 @@ impl FileBackedStore {
|
||||
setup_path,
|
||||
users_path,
|
||||
identity_keys_path,
|
||||
channels_path,
|
||||
key_packages,
|
||||
deliveries,
|
||||
channels,
|
||||
hybrid_keys,
|
||||
users,
|
||||
identity_keys,
|
||||
@@ -197,6 +231,31 @@ impl FileBackedStore {
|
||||
})
|
||||
}
|
||||
|
||||
fn load_channels(
|
||||
path: &Path,
|
||||
) -> Result<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
if bytes.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
|
||||
}
|
||||
|
||||
fn flush_channels(
|
||||
&self,
|
||||
path: &Path,
|
||||
map: &HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>,
|
||||
) -> Result<(), StorageError> {
|
||||
let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn load_kp_map(path: &Path) -> Result<HashMap<Vec<u8>, VecDeque<Vec<u8>>>, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
@@ -346,8 +405,9 @@ impl Store for FileBackedStore {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let seq = *inner.next_seq.entry(key.clone()).or_insert(0);
|
||||
*inner.next_seq.get_mut(&key).unwrap() = seq + 1;
|
||||
let entry = inner.next_seq.entry(key.clone()).or_insert(0);
|
||||
let seq = *entry;
|
||||
*entry = seq + 1;
|
||||
inner.map.entry(key).or_default().push_back(SeqEntry { seq, data: payload });
|
||||
self.flush_delivery_map(&self.ds_path, &*inner)?;
|
||||
Ok(seq)
|
||||
@@ -428,7 +488,13 @@ impl Store for FileBackedStore {
|
||||
if let Some(parent) = self.setup_path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))
|
||||
fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(&self.setup_path, std::fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
@@ -444,7 +510,14 @@ impl Store for FileBackedStore {
|
||||
|
||||
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
|
||||
let mut map = lock(&self.users)?;
|
||||
map.insert(username.to_string(), record);
|
||||
match map.entry(username.to_string()) {
|
||||
std::collections::hash_map::Entry::Occupied(_) => {
|
||||
return Err(StorageError::DuplicateUser(username.to_string()))
|
||||
}
|
||||
std::collections::hash_map::Entry::Vacant(v) => {
|
||||
v.insert(record);
|
||||
}
|
||||
}
|
||||
self.flush_users(&self.users_path, &*map)
|
||||
}
|
||||
|
||||
@@ -473,6 +546,54 @@ impl Store for FileBackedStore {
|
||||
Ok(map.get(username).cloned())
|
||||
}
|
||||
|
||||
fn peek(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let messages: Vec<(u64, Vec<u8>)> = inner
|
||||
.map
|
||||
.get(&key)
|
||||
.map(|q| {
|
||||
let count = if limit == 0 { q.len() } else { limit.min(q.len()) };
|
||||
q.iter()
|
||||
.take(count)
|
||||
.map(|e| (e.seq, e.data.clone()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
// Non-destructive: do NOT flush.
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
fn ack(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> Result<usize, StorageError> {
|
||||
let mut inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let removed = if let Some(q) = inner.map.get_mut(&key) {
|
||||
let before = q.len();
|
||||
q.retain(|e| e.seq > seq_up_to);
|
||||
before - q.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
self.flush_delivery_map(&self.ds_path, &*inner)?;
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
fn publish_endpoint(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
@@ -487,4 +608,150 @@ impl Store for FileBackedStore {
|
||||
let map = lock(&self.endpoints)?;
|
||||
Ok(map.get(identity_key).cloned())
|
||||
}
|
||||
|
||||
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
|
||||
let (a, b) = if member_a < member_b {
|
||||
(member_a.to_vec(), member_b.to_vec())
|
||||
} else {
|
||||
(member_b.to_vec(), member_a.to_vec())
|
||||
};
|
||||
let mut map = lock(&self.channels)?;
|
||||
if let Some((channel_id, _)) = map.iter().find(|(_, (ma, mb))| ma == &a && mb == &b) {
|
||||
return Ok(channel_id.clone());
|
||||
}
|
||||
let mut channel_id = [0u8; 16];
|
||||
rand::thread_rng().fill_bytes(&mut channel_id);
|
||||
let channel_id = channel_id.to_vec();
|
||||
map.insert(channel_id.clone(), (a, b));
|
||||
self.flush_channels(&self.channels_path, &*map)?;
|
||||
Ok(channel_id)
|
||||
}
|
||||
|
||||
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
|
||||
let map = lock(&self.channels)?;
|
||||
Ok(map.get(channel_id).cloned())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn temp_store() -> (TempDir, FileBackedStore) {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let store = FileBackedStore::open(dir.path()).unwrap();
|
||||
(dir, store)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_package_upload_fetch() {
|
||||
let (_dir, store) = temp_store();
|
||||
let ik = vec![1u8; 32];
|
||||
store.upload_key_package(&ik, vec![10, 20, 30]).unwrap();
|
||||
let pkg = store.fetch_key_package(&ik).unwrap();
|
||||
assert_eq!(pkg, Some(vec![10, 20, 30]));
|
||||
// Second fetch should return None (consumed)
|
||||
let pkg2 = store.fetch_key_package(&ik).unwrap();
|
||||
assert_eq!(pkg2, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enqueue_fetch_with_seq() {
|
||||
let (_dir, store) = temp_store();
|
||||
let rk = vec![2u8; 32];
|
||||
let ch = vec![];
|
||||
let seq0 = store.enqueue(&rk, &ch, vec![1]).unwrap();
|
||||
let seq1 = store.enqueue(&rk, &ch, vec![2]).unwrap();
|
||||
assert_eq!(seq0, 0);
|
||||
assert_eq!(seq1, 1);
|
||||
let msgs = store.fetch(&rk, &ch).unwrap();
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0], (0, vec![1]));
|
||||
assert_eq!(msgs[1], (1, vec![2]));
|
||||
// After fetch, queue should be empty
|
||||
let msgs2 = store.fetch(&rk, &ch).unwrap();
|
||||
assert!(msgs2.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fetch_limited_respects_limit() {
|
||||
let (_dir, store) = temp_store();
|
||||
let rk = vec![3u8; 32];
|
||||
let ch = vec![];
|
||||
for i in 0..5 {
|
||||
store.enqueue(&rk, &ch, vec![i]).unwrap();
|
||||
}
|
||||
let msgs = store.fetch_limited(&rk, &ch, 2).unwrap();
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0].1, vec![0]);
|
||||
assert_eq!(msgs[1].1, vec![1]);
|
||||
// Remaining 3 should still be there
|
||||
let depth = store.queue_depth(&rk, &ch).unwrap();
|
||||
assert_eq!(depth, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn queue_depth_tracking() {
|
||||
let (_dir, store) = temp_store();
|
||||
let rk = vec![4u8; 32];
|
||||
let ch = vec![];
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
|
||||
store.enqueue(&rk, &ch, vec![1]).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 1);
|
||||
store.enqueue(&rk, &ch, vec![2]).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 2);
|
||||
store.fetch(&rk, &ch).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hybrid_key_upload_fetch() {
|
||||
let (_dir, store) = temp_store();
|
||||
let ik = vec![5u8; 32];
|
||||
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), None);
|
||||
store.upload_hybrid_key(&ik, vec![99; 100]).unwrap();
|
||||
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(vec![99; 100]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_record_crud() {
|
||||
let (_dir, store) = temp_store();
|
||||
assert!(!store.has_user_record("alice").unwrap());
|
||||
store.store_user_record("alice", vec![1, 2, 3]).unwrap();
|
||||
assert!(store.has_user_record("alice").unwrap());
|
||||
assert_eq!(store.get_user_record("alice").unwrap(), Some(vec![1, 2, 3]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_identity_key_crud() {
|
||||
let (_dir, store) = temp_store();
|
||||
assert_eq!(store.get_user_identity_key("bob").unwrap(), None);
|
||||
store.store_user_identity_key("bob", vec![7u8; 32]).unwrap();
|
||||
assert_eq!(store.get_user_identity_key("bob").unwrap(), Some(vec![7u8; 32]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn endpoint_publish_resolve() {
|
||||
let (_dir, store) = temp_store();
|
||||
let ik = vec![8u8; 32];
|
||||
assert_eq!(store.resolve_endpoint(&ik).unwrap(), None);
|
||||
store.publish_endpoint(&ik, vec![10, 20]).unwrap();
|
||||
assert_eq!(store.resolve_endpoint(&ik).unwrap(), Some(vec![10, 20]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_channel_and_members() {
|
||||
let (_dir, store) = temp_store();
|
||||
let a = vec![1u8; 32];
|
||||
let b = vec![2u8; 32];
|
||||
assert_eq!(store.get_channel_members(&[0u8; 16]).unwrap(), None);
|
||||
let id1 = store.create_channel(&a, &b).unwrap();
|
||||
assert_eq!(id1.len(), 16);
|
||||
let members = store.get_channel_members(&id1).unwrap().unwrap();
|
||||
assert_eq!(members.0, a);
|
||||
assert_eq!(members.1, b);
|
||||
let id2 = store.create_channel(&b, &a).unwrap();
|
||||
assert_eq!(id1, id2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,6 +61,12 @@ fn generate_self_signed_cert(cert_path: &PathBuf, key_path: &PathBuf) -> anyhow:
|
||||
|
||||
std::fs::write(cert_path, issued.cert.der()).context("write cert")?;
|
||||
std::fs::write(key_path, &key_der).context("write key")?;
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let perms = std::fs::Permissions::from_mode(0o600);
|
||||
std::fs::set_permissions(key_path, perms).context("set key permissions")?;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
cert = %cert_path.display(),
|
||||
|
||||
Reference in New Issue
Block a user