Cursor: Apply local changes for cloud agent

This commit is contained in:
2026-02-22 22:29:52 +01:00
parent 6b8b61c6ae
commit 41c57a1181
21 changed files with 616 additions and 142 deletions

1
.cursor/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
plans/

103
Cargo.lock generated
View File

@@ -804,6 +804,15 @@ dependencies = [
"strsim", "strsim",
] ]
[[package]]
name = "clap_complete"
version = "4.5.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c757a3b7e39161a4e56f9365141ada2a6c915a8622c408ab6bb4b5d047371031"
dependencies = [
"clap",
]
[[package]] [[package]]
name = "clap_derive" name = "clap_derive"
version = "4.5.55" version = "4.5.55"
@@ -856,6 +865,19 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "console"
version = "0.15.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8"
dependencies = [
"encode_unicode",
"libc",
"once_cell",
"unicode-width",
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "const-oid" name = "const-oid"
version = "0.9.6" version = "0.9.6"
@@ -1691,6 +1713,12 @@ version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
[[package]]
name = "encode_unicode"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]] [[package]]
name = "enum-as-inner" name = "enum-as-inner"
version = "0.6.1" version = "0.6.1"
@@ -1995,6 +2023,12 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.32" version = "0.3.32"
@@ -2334,6 +2368,26 @@ dependencies = [
"system-deps", "system-deps",
] ]
[[package]]
name = "governor"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a7f542ee6b35af73b06abc0dad1c1bae89964e4e253bc4b587b91c9637867b"
dependencies = [
"cfg-if",
"dashmap",
"futures",
"futures-timer",
"no-std-compat",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.8.5",
"smallvec",
"spinning_top",
]
[[package]] [[package]]
name = "group" name = "group"
version = "0.13.0" version = "0.13.0"
@@ -2962,6 +3016,19 @@ dependencies = [
"serde_core", "serde_core",
] ]
[[package]]
name = "indicatif"
version = "0.17.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235"
dependencies = [
"console",
"number_prefix",
"portable-atomic",
"unicode-width",
"web-time",
]
[[package]] [[package]]
name = "infer" name = "infer"
version = "0.19.0" version = "0.19.0"
@@ -3903,12 +3970,24 @@ version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]] [[package]]
name = "nodrop" name = "nodrop"
version = "0.1.14" version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]] [[package]]
name = "ntimestamp" name = "ntimestamp"
version = "1.0.0" version = "1.0.0"
@@ -3989,6 +4068,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "number_prefix"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]] [[package]]
name = "objc2" name = "objc2"
version = "0.6.3" version = "0.6.3"
@@ -5038,9 +5123,11 @@ dependencies = [
"capnp-rpc", "capnp-rpc",
"chacha20poly1305 0.10.1", "chacha20poly1305 0.10.1",
"clap", "clap",
"clap_complete",
"dashmap", "dashmap",
"futures", "futures",
"hex", "hex",
"indicatif",
"opaque-ke", "opaque-ke",
"openmls_rust_crypto", "openmls_rust_crypto",
"portpicker", "portpicker",
@@ -5131,6 +5218,7 @@ dependencies = [
"clap", "clap",
"dashmap", "dashmap",
"futures", "futures",
"governor",
"metrics 0.22.4", "metrics 0.22.4",
"metrics-exporter-prometheus", "metrics-exporter-prometheus",
"opaque-ke", "opaque-ke",
@@ -6310,6 +6398,15 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591"
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.7.3" version = "0.7.3"
@@ -7327,6 +7424,12 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]] [[package]]
name = "unicode-xid" name = "unicode-xid"
version = "0.2.6" version = "0.2.6"

View File

@@ -54,6 +54,7 @@ rusqlite = { version = "0.31", features = ["bundled-sqlcipher"] }
# ── Server utilities ────────────────────────────────────────────────────────── # ── Server utilities ──────────────────────────────────────────────────────────
dashmap = { version = "5" } dashmap = { version = "5" }
governor = { version = "0.6" }
tracing = { version = "0.1" } tracing = { version = "0.1" }
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
@@ -63,6 +64,8 @@ thiserror = { version = "1" }
# ── CLI ─────────────────────────────────────────────────────────────────────── # ── CLI ───────────────────────────────────────────────────────────────────────
clap = { version = "4", features = ["derive", "env"] } clap = { version = "4", features = ["derive", "env"] }
clap_complete = { version = "4" }
indicatif = { version = "0.17" }
# ── Build-time ──────────────────────────────────────────────────────────────── # ── Build-time ────────────────────────────────────────────────────────────────
capnpc = { version = "0.19" } capnpc = { version = "0.19" }

View File

@@ -48,6 +48,8 @@ tracing-subscriber = { workspace = true }
# CLI # CLI
clap = { workspace = true } clap = { workspace = true }
clap_complete = { workspace = true }
indicatif = { workspace = true }
[dev-dependencies] [dev-dependencies]
dashmap = { workspace = true } dashmap = { workspace = true }

View File

@@ -7,7 +7,7 @@ use opaque_ke::{
}; };
use quicnprotochat_core::{ use quicnprotochat_core::{
generate_key_package, hybrid_decrypt, hybrid_encrypt, opaque_auth::OpaqueSuite, generate_key_package, hybrid_decrypt, hybrid_encrypt, opaque_auth::OpaqueSuite,
GroupMember, HybridKeypair, IdentityKeypair, HybridKeypair, IdentityKeypair,
}; };
use super::{ use super::{
@@ -16,7 +16,10 @@ use super::{
connect_node, current_timestamp_ms, enqueue, fetch_all, fetch_hybrid_key, connect_node, current_timestamp_ms, enqueue, fetch_all, fetch_hybrid_key,
fetch_key_package, fetch_wait, try_hybrid_decrypt, upload_hybrid_key, upload_key_package, fetch_key_package, fetch_wait, try_hybrid_decrypt, upload_hybrid_key, upload_key_package,
}, },
state::{decode_identity_key, load_existing_state, load_or_init_state, save_state, sha256}, state::{
decode_identity_key, load_existing_state, load_or_init_state, save_state, sha256,
MemberBackend,
},
}; };
/// Print local identity information from the state file (no server connection). /// Print local identity information from the state file (no server connection).
@@ -45,6 +48,14 @@ pub fn cmd_whoami(state_path: &Path, password: Option<&str>) -> anyhow::Result<(
"none" "none"
} }
); );
println!(
"pq_backend : {}",
if state.use_pq_backend {
"yes (MLS HPKE: X25519 + ML-KEM-768)"
} else {
"no (classical)"
}
);
println!("state_file : {}", state_path.display()); println!("state_file : {}", state_path.display());
Ok(()) Ok(())
@@ -365,7 +376,7 @@ async fn do_upload_keypackage(
ca_cert: &Path, ca_cert: &Path,
server_name: &str, server_name: &str,
password: Option<&str>, password: Option<&str>,
member: &mut GroupMember, member: &mut MemberBackend,
hybrid_kp: Option<&HybridKeypair>, hybrid_kp: Option<&HybridKeypair>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let tls_bytes = member let tls_bytes = member
@@ -428,8 +439,9 @@ pub async fn cmd_register_state(
ca_cert: &Path, ca_cert: &Path,
server_name: &str, server_name: &str,
password: Option<&str>, password: Option<&str>,
use_pq_backend: bool,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let state = load_or_init_state(state_path, password)?; let state = load_or_init_state(state_path, password, use_pq_backend)?;
let (mut member, hybrid_kp) = state.into_parts(state_path)?; let (mut member, hybrid_kp) = state.into_parts(state_path)?;
do_upload_keypackage( do_upload_keypackage(
state_path, state_path,
@@ -522,15 +534,37 @@ pub async fn cmd_fetch_key(
} }
/// Run a two-party MLS demo against the unified server. /// Run a two-party MLS demo against the unified server.
pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) -> anyhow::Result<()> { pub async fn cmd_demo_group(
server: &str,
ca_cert: &Path,
server_name: &str,
use_pq_backend: bool,
) -> anyhow::Result<()> {
use indicatif::{ProgressBar, ProgressStyle};
let creator_state_path = PathBuf::from("quicnprotochat-demo-creator.bin"); let creator_state_path = PathBuf::from("quicnprotochat-demo-creator.bin");
let joiner_state_path = PathBuf::from("quicnprotochat-demo-joiner.bin"); let joiner_state_path = PathBuf::from("quicnprotochat-demo-joiner.bin");
let (mut creator, creator_hybrid_opt) = let pb = ProgressBar::new(5);
load_or_init_state(&creator_state_path, None)?.into_parts(&creator_state_path)?; pb.set_style(
let (mut joiner, joiner_hybrid_opt) = ProgressStyle::with_template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}")
load_or_init_state(&joiner_state_path, None)?.into_parts(&joiner_state_path)?; .expect("demo progress template is valid")
.tick_chars("\u{2801}\u{2802}\u{2804}\u{2840}\u{2820}\u{2810}\u{2808} ")
.progress_chars("=>-"),
);
pb.enable_steady_tick(std::time::Duration::from_millis(80));
pb.set_message("Generating Alice keys\u{2026}");
let (mut creator, creator_hybrid_opt) =
load_or_init_state(&creator_state_path, None, use_pq_backend)?.into_parts(&creator_state_path)?;
pb.inc(1);
pb.set_message("Generating Bob keys\u{2026}");
let (mut joiner, joiner_hybrid_opt) =
load_or_init_state(&joiner_state_path, None, use_pq_backend)?.into_parts(&joiner_state_path)?;
pb.inc(1);
pb.set_message("Creating group\u{2026}");
let creator_hybrid = creator_hybrid_opt.unwrap_or_else(HybridKeypair::generate); let creator_hybrid = creator_hybrid_opt.unwrap_or_else(HybridKeypair::generate);
let joiner_hybrid = joiner_hybrid_opt.unwrap_or_else(HybridKeypair::generate); let joiner_hybrid = joiner_hybrid_opt.unwrap_or_else(HybridKeypair::generate);
@@ -552,8 +586,6 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
upload_hybrid_key(&creator_node, &creator_identity, &creator_hybrid.public_key()).await?; upload_hybrid_key(&creator_node, &creator_identity, &creator_hybrid.public_key()).await?;
upload_hybrid_key(&joiner_node, &joiner_identity, &joiner_hybrid.public_key()).await?; upload_hybrid_key(&joiner_node, &joiner_identity, &joiner_hybrid.public_key()).await?;
println!("hybrid public keys uploaded for creator and joiner");
let fetched_joiner_kp = fetch_key_package(&creator_node, &joiner_identity).await?; let fetched_joiner_kp = fetch_key_package(&creator_node, &joiner_identity).await?;
anyhow::ensure!( anyhow::ensure!(
!fetched_joiner_kp.is_empty(), !fetched_joiner_kp.is_empty(),
@@ -566,7 +598,9 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
let (_commit, welcome) = creator let (_commit, welcome) = creator
.add_member(&fetched_joiner_kp) .add_member(&fetched_joiner_kp)
.context("add_member failed")?; .context("add_member failed")?;
pb.inc(1);
pb.set_message("Encrypting\u{2026}");
let creator_ds = creator_node.clone(); let creator_ds = creator_node.clone();
let joiner_ds = joiner_node.clone(); let joiner_ds = joiner_node.clone();
@@ -576,7 +610,9 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
let wrapped_welcome = let wrapped_welcome =
hybrid_encrypt(&joiner_hybrid_pk, &welcome).context("hybrid encrypt welcome")?; hybrid_encrypt(&joiner_hybrid_pk, &welcome).context("hybrid encrypt welcome")?;
enqueue(&creator_ds, &joiner_identity, &wrapped_welcome).await?; enqueue(&creator_ds, &joiner_identity, &wrapped_welcome).await?;
pb.inc(1);
pb.set_message("Delivering\u{2026}");
let welcome_payloads = fetch_all(&joiner_ds, &joiner_identity).await?; let welcome_payloads = fetch_all(&joiner_ds, &joiner_identity).await?;
let raw_welcome = welcome_payloads let raw_welcome = welcome_payloads
.first() .first()
@@ -605,10 +641,6 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
let plaintext_creator_joiner = joiner let plaintext_creator_joiner = joiner
.receive_message(&inner_creator_joiner)? .receive_message(&inner_creator_joiner)?
.context("expected application message")?; .context("expected application message")?;
println!(
"creator -> joiner plaintext: {}",
String::from_utf8_lossy(&plaintext_creator_joiner)
);
let creator_hybrid_pk = fetch_hybrid_key(&joiner_node, &creator_identity) let creator_hybrid_pk = fetch_hybrid_key(&joiner_node, &creator_identity)
.await? .await?
@@ -629,11 +661,17 @@ pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) ->
let plaintext_joiner_creator = creator let plaintext_joiner_creator = creator
.receive_message(&inner_joiner_creator)? .receive_message(&inner_joiner_creator)?
.context("expected application message")?; .context("expected application message")?;
pb.inc(1);
pb.finish_and_clear();
println!( println!(
"joiner -> creator plaintext: {}", "creator -> joiner: {}",
String::from_utf8_lossy(&plaintext_creator_joiner)
);
println!(
"joiner -> creator: {}",
String::from_utf8_lossy(&plaintext_joiner_creator) String::from_utf8_lossy(&plaintext_joiner_creator)
); );
println!("demo-group complete (hybrid PQ envelope active)"); println!("demo-group complete (hybrid PQ envelope active)");
Ok(()) Ok(())
@@ -645,8 +683,9 @@ pub async fn cmd_create_group(
_server: &str, _server: &str,
group_id: &str, group_id: &str,
password: Option<&str>, password: Option<&str>,
use_pq_backend: bool,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let state = load_or_init_state(state_path, password)?; let state = load_or_init_state(state_path, password, use_pq_backend)?;
let (mut member, hybrid_kp) = state.into_parts(state_path)?; let (mut member, hybrid_kp) = state.into_parts(state_path)?;
anyhow::ensure!( anyhow::ensure!(
@@ -850,11 +889,29 @@ pub async fn cmd_recv(
stream: bool, stream: bool,
password: Option<&str>, password: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
use indicatif::{ProgressBar, ProgressStyle};
let state = load_existing_state(state_path, password)?; let state = load_existing_state(state_path, password)?;
let (mut member, hybrid_kp) = state.into_parts(state_path)?; let (mut member, hybrid_kp) = state.into_parts(state_path)?;
let client = connect_node(server, ca_cert, server_name).await?; let client = connect_node(server, ca_cert, server_name).await?;
let stream_pb: Option<ProgressBar> = if stream {
let pb = ProgressBar::new_spinner();
pb.set_style(
ProgressStyle::with_template("{spinner:.green} {msg}")
.expect("recv progress template is valid")
.tick_chars("\u{2801}\u{2802}\u{2804}\u{2840}\u{2820}\u{2810}\u{2808} "),
);
pb.set_message("Listening for messages (0 received)\u{2026}");
pb.enable_steady_tick(std::time::Duration::from_millis(100));
Some(pb)
} else {
None
};
let mut total_received: usize = 0;
loop { loop {
let mut payloads = let mut payloads =
fetch_wait(&client, &member.identity().public_key_bytes(), wait_ms).await?; fetch_wait(&client, &member.identity().public_key_bytes(), wait_ms).await?;
@@ -876,13 +933,29 @@ pub async fn cmd_recv(
let mls_payload = match try_hybrid_decrypt(hybrid_kp.as_ref(), payload) { let mls_payload = match try_hybrid_decrypt(hybrid_kp.as_ref(), payload) {
Ok(b) => b, Ok(b) => b,
Err(e) => { Err(e) => {
println!("[{idx}] decrypt error: {e}"); match &stream_pb {
Some(pb) => pb.println(format!("[{idx}] decrypt error: {e}")),
None => println!("[{idx}] decrypt error: {e}"),
}
continue; continue;
} }
}; };
match member.receive_message(&mls_payload) { match member.receive_message(&mls_payload) {
Ok(Some(pt)) => println!("[{idx}] plaintext: {}", String::from_utf8_lossy(&pt)), Ok(Some(pt)) => {
Ok(None) => println!("[{idx}] commit applied"), total_received += 1;
let line = format!("[{idx}] plaintext: {}", String::from_utf8_lossy(&pt));
match &stream_pb {
Some(pb) => pb.println(line),
None => println!("{line}"),
}
}
Ok(None) => {
let line = format!("[{idx}] commit applied");
match &stream_pb {
Some(pb) => pb.println(line),
None => println!("{line}"),
}
}
Err(_) => retry_mls.push(mls_payload), Err(_) => retry_mls.push(mls_payload),
} }
} }
@@ -890,14 +963,33 @@ pub async fn cmd_recv(
// epoch was not yet advanced until a commit earlier in the batch was applied). // epoch was not yet advanced until a commit earlier in the batch was applied).
for mls_payload in &retry_mls { for mls_payload in &retry_mls {
match member.receive_message(mls_payload) { match member.receive_message(mls_payload) {
Ok(Some(pt)) => println!("[retry] plaintext: {}", String::from_utf8_lossy(&pt)), Ok(Some(pt)) => {
total_received += 1;
let line = format!("[retry] plaintext: {}", String::from_utf8_lossy(&pt));
match &stream_pb {
Some(pb) => pb.println(line),
None => println!("{line}"),
}
}
Ok(None) => {} Ok(None) => {}
Err(e) => println!("[retry] error: {e}"), Err(e) => {
let line = format!("[retry] error: {e}");
match &stream_pb {
Some(pb) => pb.println(line),
None => println!("{line}"),
}
}
} }
} }
save_state(state_path, &member, hybrid_kp.as_ref(), password)?; save_state(state_path, &member, hybrid_kp.as_ref(), password)?;
if let Some(ref pb) = stream_pb {
pb.set_message(format!(
"Listening for messages ({total_received} received)\u{2026}"
));
}
if !stream { if !stream {
return Ok(()); return Ok(());
} }

View File

@@ -1,4 +1,8 @@
//! Retry with exponential backoff for transient RPC failures. //! Retry with exponential backoff for transient RPC failures.
//!
//! Used for `enqueue`, `fetch_all`, and `fetch_wait`. Auth and invalid-param
//! errors are not retried. Configure via `QUICNPROTOCHAT_MAX_RETRIES` and
//! `QUICNPROTOCHAT_BASE_DELAY_MS` (optional).
use std::future::Future; use std::future::Future;
use std::time::Duration; use std::time::Duration;
@@ -11,6 +15,22 @@ pub const DEFAULT_MAX_RETRIES: u32 = 3;
/// Default base delay in milliseconds for exponential backoff. /// Default base delay in milliseconds for exponential backoff.
pub const DEFAULT_BASE_DELAY_MS: u64 = 500; pub const DEFAULT_BASE_DELAY_MS: u64 = 500;
/// Read max retries from env or use default.
pub fn max_retries_from_env() -> u32 {
std::env::var("QUICNPROTOCHAT_MAX_RETRIES")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_MAX_RETRIES)
}
/// Read base delay (ms) from env or use default.
pub fn base_delay_ms_from_env() -> u64 {
std::env::var("QUICNPROTOCHAT_BASE_DELAY_MS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_BASE_DELAY_MS)
}
/// Runs an async operation with retries. On `Ok(t)` returns immediately. /// Runs an async operation with retries. On `Ok(t)` returns immediately.
/// On `Err(e)`: if `is_retriable(&e)` and `attempt < max_retries`, sleeps with /// On `Err(e)`: if `is_retriable(&e)` and `attempt < max_retries`, sleeps with
/// exponential backoff (plus jitter) then retries; otherwise returns the last error. /// exponential backoff (plus jitter) then retries; otherwise returns the last error.
@@ -31,7 +51,7 @@ where
Ok(t) => return Ok(t), Ok(t) => return Ok(t),
Err(e) => { Err(e) => {
last_err = Some(e); last_err = Some(e);
let err = last_err.as_ref().unwrap(); let err = last_err.as_ref().expect("last_err just set in Err branch");
if !is_retriable(err) || attempt + 1 >= max_retries { if !is_retriable(err) || attempt + 1 >= max_retries {
break; break;
} }
@@ -48,7 +68,8 @@ where
} }
} }
} }
Err(last_err.expect("retry_async: last_err set when we break after Err")) // Loop runs at least once (max_retries >= 1) and we only break after storing an Err, so this is always Some.
Err(last_err.expect("retry_async: last_err is Some when breaking after Err"))
} }
/// Classifies `anyhow::Error` for retry: returns `false` for auth or invalid-param /// Classifies `anyhow::Error` for retry: returns `false` for auth or invalid-param

View File

@@ -15,7 +15,9 @@ use quicnprotochat_proto::node_capnp::{auth, node_service};
use crate::AUTH_CONTEXT; use crate::AUTH_CONTEXT;
use super::retry::{anyhow_is_retriable, retry_async, DEFAULT_BASE_DELAY_MS, DEFAULT_MAX_RETRIES}; use super::retry::{
anyhow_is_retriable, base_delay_ms_from_env, max_retries_from_env, retry_async,
};
/// Establish a QUIC/TLS connection and return a `NodeService` client. /// Establish a QUIC/TLS connection and return a `NodeService` client.
/// ///
@@ -174,8 +176,8 @@ pub async fn enqueue(
Ok(seq) Ok(seq)
} }
}, },
DEFAULT_MAX_RETRIES, max_retries_from_env(),
DEFAULT_BASE_DELAY_MS, base_delay_ms_from_env(),
anyhow_is_retriable, anyhow_is_retriable,
) )
.await .await
@@ -228,8 +230,8 @@ pub async fn fetch_all(
Ok(payloads) Ok(payloads)
} }
}, },
DEFAULT_MAX_RETRIES, max_retries_from_env(),
DEFAULT_BASE_DELAY_MS, base_delay_ms_from_env(),
anyhow_is_retriable, anyhow_is_retriable,
) )
.await .await
@@ -285,8 +287,8 @@ pub async fn fetch_wait(
Ok(payloads) Ok(payloads)
} }
}, },
DEFAULT_MAX_RETRIES, max_retries_from_env(),
DEFAULT_BASE_DELAY_MS, base_delay_ms_from_env(),
anyhow_is_retriable, anyhow_is_retriable,
) )
.await .await

View File

@@ -10,13 +10,21 @@ use chacha20poly1305::{
use rand::RngCore; use rand::RngCore;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use quicnprotochat_core::{DiskKeyStore, GroupMember, HybridKeypair, HybridKeypairBytes, IdentityKeypair}; use quicnprotochat_core::{
CoreError, DiskKeyStore, GroupMember, HybridCryptoProvider, HybridKeypair, HybridKeypairBytes,
IdentityKeypair, MlsGroup, StoreCrypto,
};
/// Magic bytes for encrypted client state files. /// Magic bytes for encrypted client state files.
const STATE_MAGIC: &[u8; 4] = b"QPCE"; const STATE_MAGIC: &[u8; 4] = b"QPCE";
const STATE_SALT_LEN: usize = 16; const STATE_SALT_LEN: usize = 16;
const STATE_NONCE_LEN: usize = 12; const STATE_NONCE_LEN: usize = 12;
/// Persisted client state (identity, MLS group, optional PQ key).
///
/// **Production note:** When loading state, use the same `use_pq_backend` value that was used when
/// the state was created. Loading PQ state with classical backend (or vice versa) will fail or
/// produce incorrect behavior.
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct StoredState { pub struct StoredState {
pub identity_seed: [u8; 32], pub identity_seed: [u8; 32],
@@ -27,17 +35,115 @@ pub struct StoredState {
/// Cached member public keys for group participants. /// Cached member public keys for group participants.
#[serde(default)] #[serde(default)]
pub member_keys: Vec<Vec<u8>>, pub member_keys: Vec<Vec<u8>>,
/// If true, MLS uses post-quantum hybrid KEM (HybridCryptoProvider) for HPKE. M7.
#[serde(default)]
pub use_pq_backend: bool,
}
/// MLS member backend: classical (StoreCrypto) or post-quantum hybrid (HybridCryptoProvider).
pub enum MemberBackend {
Classical(GroupMember<StoreCrypto>),
Hybrid(GroupMember<HybridCryptoProvider>),
}
impl MemberBackend {
pub fn generate_key_package(&mut self) -> Result<Vec<u8>, CoreError> {
match self {
MemberBackend::Classical(m) => m.generate_key_package(),
MemberBackend::Hybrid(m) => m.generate_key_package(),
}
}
pub fn create_group(&mut self, group_id: &[u8]) -> Result<(), CoreError> {
match self {
MemberBackend::Classical(m) => m.create_group(group_id),
MemberBackend::Hybrid(m) => m.create_group(group_id),
}
}
pub fn add_member(&mut self, key_package_bytes: &[u8]) -> Result<(Vec<u8>, Vec<u8>), CoreError> {
match self {
MemberBackend::Classical(m) => m.add_member(key_package_bytes),
MemberBackend::Hybrid(m) => m.add_member(key_package_bytes),
}
}
pub fn join_group(&mut self, welcome: &[u8]) -> Result<(), CoreError> {
match self {
MemberBackend::Classical(m) => m.join_group(welcome),
MemberBackend::Hybrid(m) => m.join_group(welcome),
}
}
pub fn send_message(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, CoreError> {
match self {
MemberBackend::Classical(m) => m.send_message(plaintext),
MemberBackend::Hybrid(m) => m.send_message(plaintext),
}
}
pub fn receive_message(&mut self, bytes: &[u8]) -> Result<Option<Vec<u8>>, CoreError> {
match self {
MemberBackend::Classical(m) => m.receive_message(bytes),
MemberBackend::Hybrid(m) => m.receive_message(bytes),
}
}
pub fn receive_message_with_sender(
&mut self,
bytes: &[u8],
) -> Result<Option<(Vec<u8>, Vec<u8>)>, CoreError> {
match self {
MemberBackend::Classical(m) => m.receive_message_with_sender(bytes),
MemberBackend::Hybrid(m) => m.receive_message_with_sender(bytes),
}
}
pub fn group_id(&self) -> Option<Vec<u8>> {
match self {
MemberBackend::Classical(m) => m.group_id(),
MemberBackend::Hybrid(m) => m.group_id(),
}
}
pub fn identity(&self) -> &IdentityKeypair {
match self {
MemberBackend::Classical(m) => m.identity(),
MemberBackend::Hybrid(m) => m.identity(),
}
}
pub fn identity_seed(&self) -> [u8; 32] {
match self {
MemberBackend::Classical(m) => m.identity_seed(),
MemberBackend::Hybrid(m) => m.identity_seed(),
}
}
pub fn group_ref(&self) -> Option<&MlsGroup> {
match self {
MemberBackend::Classical(m) => m.group_ref(),
MemberBackend::Hybrid(m) => m.group_ref(),
}
}
pub fn member_identities(&self) -> Vec<Vec<u8>> {
match self {
MemberBackend::Classical(m) => m.member_identities(),
MemberBackend::Hybrid(m) => m.member_identities(),
}
}
pub fn is_pq(&self) -> bool {
matches!(self, MemberBackend::Hybrid(_))
}
} }
impl StoredState { impl StoredState {
pub fn into_parts(self, state_path: &Path) -> anyhow::Result<(GroupMember, Option<HybridKeypair>)> { /// Rebuild member and hybrid key from stored state. Uses PQ backend if `use_pq_backend` is true.
pub fn into_parts(self, state_path: &Path) -> anyhow::Result<(MemberBackend, Option<HybridKeypair>)> {
let identity = Arc::new(IdentityKeypair::from_seed(self.identity_seed)); let identity = Arc::new(IdentityKeypair::from_seed(self.identity_seed));
let group = self let group = self
.group .group
.map(|bytes| bincode::deserialize(&bytes).context("decode group")) .map(|bytes| bincode::deserialize(&bytes).context("decode group"))
.transpose()?; .transpose()?;
let key_store = DiskKeyStore::persistent(keystore_path(state_path))?; let key_store = DiskKeyStore::persistent(keystore_path(state_path))?;
let member = GroupMember::new_with_state(identity, key_store, group);
let member = if self.use_pq_backend {
MemberBackend::Hybrid(GroupMember::<HybridCryptoProvider>::new_with_state_hybrid(
identity, key_store, group,
))
} else {
MemberBackend::Classical(GroupMember::new_with_state(identity, key_store, group))
};
let hybrid_kp = self let hybrid_kp = self
.hybrid_key .hybrid_key
@@ -47,7 +153,11 @@ impl StoredState {
Ok((member, hybrid_kp)) Ok((member, hybrid_kp))
} }
pub fn from_parts(member: &GroupMember, hybrid_kp: Option<&HybridKeypair>) -> anyhow::Result<Self> { /// Build state from a classical GroupMember (backward compat / tests). Prefer [`from_member_backend`](Self::from_member_backend) in production.
pub fn from_parts(
member: &GroupMember<StoreCrypto>,
hybrid_kp: Option<&HybridKeypair>,
) -> anyhow::Result<Self> {
let group = member let group = member
.group_ref() .group_ref()
.map(|g| bincode::serialize(g).context("serialize group")) .map(|g| bincode::serialize(g).context("serialize group"))
@@ -58,6 +168,26 @@ impl StoredState {
group, group,
hybrid_key: hybrid_kp.map(|kp| kp.to_bytes()), hybrid_key: hybrid_kp.map(|kp| kp.to_bytes()),
member_keys: Vec::new(), member_keys: Vec::new(),
use_pq_backend: false,
})
}
/// Build state from MemberBackend (classical or PQ).
pub fn from_member_backend(
member: &MemberBackend,
hybrid_kp: Option<&HybridKeypair>,
) -> anyhow::Result<Self> {
let group = member
.group_ref()
.map(|g| bincode::serialize(g).context("serialize group"))
.transpose()?;
Ok(Self {
identity_seed: member.identity_seed(),
group,
hybrid_key: hybrid_kp.map(|kp| kp.to_bytes()),
member_keys: Vec::new(),
use_pq_backend: member.is_pq(),
}) })
} }
} }
@@ -124,22 +254,49 @@ pub fn is_encrypted_state(bytes: &[u8]) -> bool {
bytes.len() >= 4 && &bytes[..4] == STATE_MAGIC bytes.len() >= 4 && &bytes[..4] == STATE_MAGIC
} }
pub fn load_or_init_state(path: &Path, password: Option<&str>) -> anyhow::Result<StoredState> { /// Create new state with optional post-quantum MLS backend (M7). When `use_pq_backend` is true,
/// new state uses `HybridCryptoProvider` for MLS HPKE (X25519 + ML-KEM-768).
pub fn load_or_init_state(
path: &Path,
password: Option<&str>,
use_pq_backend: bool,
) -> anyhow::Result<StoredState> {
if path.exists() { if path.exists() {
let mut state = load_existing_state(path, password)?; let mut state = load_existing_state(path, password)?;
// Generate hybrid keypair if missing (upgrade from older state). // Generate hybrid keypair if missing (upgrade from older state).
if state.hybrid_key.is_none() { if state.hybrid_key.is_none() {
let pb = indicatif::ProgressBar::new_spinner();
pb.set_message("Generating post-quantum keypair\u{2026}");
pb.enable_steady_tick(std::time::Duration::from_millis(80));
state.hybrid_key = Some(HybridKeypair::generate().to_bytes()); state.hybrid_key = Some(HybridKeypair::generate().to_bytes());
pb.finish_and_clear();
write_state(path, &state, password)?; write_state(path, &state, password)?;
} }
return Ok(state); return Ok(state);
} }
let pb = indicatif::ProgressBar::new_spinner();
pb.set_message("Generating post-quantum keypair\u{2026}");
pb.enable_steady_tick(std::time::Duration::from_millis(80));
let identity = IdentityKeypair::generate(); let identity = IdentityKeypair::generate();
let hybrid_kp = HybridKeypair::generate(); let hybrid_kp = HybridKeypair::generate();
pb.finish_and_clear();
let key_store = DiskKeyStore::persistent(keystore_path(path))?; let key_store = DiskKeyStore::persistent(keystore_path(path))?;
let member = GroupMember::new_with_state(Arc::new(identity), key_store, None); let member = if use_pq_backend {
let state = StoredState::from_parts(&member, Some(&hybrid_kp))?; MemberBackend::Hybrid(GroupMember::<HybridCryptoProvider>::new_with_state_hybrid(
Arc::new(identity),
key_store,
None,
))
} else {
MemberBackend::Classical(GroupMember::new_with_state(
Arc::new(identity),
key_store,
None,
))
};
let state = StoredState::from_member_backend(&member, Some(&hybrid_kp))?;
write_state(path, &state, password)?; write_state(path, &state, password)?;
Ok(state) Ok(state)
} }
@@ -159,11 +316,11 @@ pub fn load_existing_state(path: &Path, password: Option<&str>) -> anyhow::Resul
pub fn save_state( pub fn save_state(
path: &Path, path: &Path,
member: &GroupMember, member: &MemberBackend,
hybrid_kp: Option<&HybridKeypair>, hybrid_kp: Option<&HybridKeypair>,
password: Option<&str>, password: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let state = StoredState::from_parts(member, hybrid_kp)?; let state = StoredState::from_member_backend(member, hybrid_kp)?;
write_state(path, &state, password) write_state(path, &state, password)
} }

View File

@@ -26,6 +26,7 @@ pub use client::commands::{
}; };
pub use client::rpc::{connect_node, enqueue, fetch_wait}; pub use client::rpc::{connect_node, enqueue, fetch_wait};
pub use client::state::{load_existing_state, StoredState};
// Global auth context initialized once per process. // Global auth context initialized once per process.
pub(crate) static AUTH_CONTEXT: OnceLock<ClientAuth> = OnceLock::new(); pub(crate) static AUTH_CONTEXT: OnceLock<ClientAuth> = OnceLock::new();

View File

@@ -52,6 +52,10 @@ struct Args {
#[arg(long, global = true, env = "QUICNPROTOCHAT_STATE_PASSWORD")] #[arg(long, global = true, env = "QUICNPROTOCHAT_STATE_PASSWORD")]
state_password: Option<String>, state_password: Option<String>,
/// Use post-quantum MLS backend (X25519 + ML-KEM-768) for new state. M7.
#[arg(long, global = true, env = "QUICNPROTOCHAT_PQ")]
pq: bool,
#[command(subcommand)] #[command(subcommand)]
command: Command, command: Command,
} }
@@ -284,6 +288,13 @@ enum Command {
#[arg(long, default_value_t = 500)] #[arg(long, default_value_t = 500)]
poll_interval_ms: u64, poll_interval_ms: u64,
}, },
/// Generate shell completions for the given shell and print to stdout.
#[command(hide = true)]
Completions {
shell: clap_complete::Shell,
},
} }
// ── Entry point ─────────────────────────────────────────────────────────────── // ── Entry point ───────────────────────────────────────────────────────────────
@@ -390,7 +401,7 @@ async fn main() -> anyhow::Result<()> {
Command::DemoGroup { server } => { Command::DemoGroup { server } => {
let local = tokio::task::LocalSet::new(); let local = tokio::task::LocalSet::new();
local local
.run_until(cmd_demo_group(&server, &args.ca_cert, &args.server_name)) .run_until(cmd_demo_group(&server, &args.ca_cert, &args.server_name, args.pq))
.await .await
} }
Command::RegisterState { state, server } => { Command::RegisterState { state, server } => {
@@ -402,6 +413,7 @@ async fn main() -> anyhow::Result<()> {
&args.ca_cert, &args.ca_cert,
&args.server_name, &args.server_name,
state_pw, state_pw,
args.pq,
)) ))
.await .await
} }
@@ -424,7 +436,7 @@ async fn main() -> anyhow::Result<()> {
} => { } => {
let local = tokio::task::LocalSet::new(); let local = tokio::task::LocalSet::new();
local local
.run_until(cmd_create_group(&state, &server, &group_id, state_pw)) .run_until(cmd_create_group(&state, &server, &group_id, state_pw, args.pq))
.await .await
} }
Command::Invite { Command::Invite {
@@ -515,5 +527,15 @@ async fn main() -> anyhow::Result<()> {
)) ))
.await .await
} }
Command::Completions { shell } => {
use clap::CommandFactory;
clap_complete::generate(
shell,
&mut Args::command(),
"quicnprotochat",
&mut std::io::stdout(),
);
Ok(())
}
} }
} }

View File

@@ -18,7 +18,7 @@ fn ensure_rustls_provider() {
use quicnprotochat_client::{ use quicnprotochat_client::{
cmd_create_group, cmd_invite, cmd_join, cmd_login, cmd_ping, cmd_register_state, cmd_create_group, cmd_invite, cmd_join, cmd_login, cmd_ping, cmd_register_state,
cmd_register_user, cmd_send, connect_node, enqueue, fetch_wait, init_auth, cmd_register_user, cmd_send, connect_node, enqueue, fetch_wait, init_auth,
receive_pending_plaintexts, ClientAuth, load_existing_state, receive_pending_plaintexts, ClientAuth,
}; };
use quicnprotochat_core::IdentityKeypair; use quicnprotochat_core::IdentityKeypair;
@@ -26,12 +26,6 @@ fn hex_encode(bytes: &[u8]) -> String {
bytes.iter().map(|b| format!("{b:02x}")).collect() bytes.iter().map(|b| format!("{b:02x}")).collect()
} }
#[derive(serde::Deserialize)]
struct StoredStateCompat {
identity_seed: [u8; 32],
#[allow(dead_code)]
group: Option<Vec<u8>>,
}
async fn wait_for_health(server: &str, ca_cert: &PathBuf, server_name: &str) -> anyhow::Result<()> { async fn wait_for_health(server: &str, ca_cert: &PathBuf, server_name: &str) -> anyhow::Result<()> {
let local = tokio::task::LocalSet::new(); let local = tokio::task::LocalSet::new();
@@ -109,6 +103,7 @@ async fn e2e_happy_path_register_invite_join_send_recv() -> anyhow::Result<()> {
&ca_cert, &ca_cert,
"localhost", "localhost",
None, None,
false,
)) ))
.await?; .await?;
@@ -119,16 +114,16 @@ async fn e2e_happy_path_register_invite_join_send_recv() -> anyhow::Result<()> {
&ca_cert, &ca_cert,
"localhost", "localhost",
None, None,
false,
)) ))
.await?; .await?;
local local
.run_until(cmd_create_group(&creator_state, &server, "test-group", None)) .run_until(cmd_create_group(&creator_state, &server, "test-group", None, false))
.await?; .await?;
let joiner_bytes = std::fs::read(&joiner_state)?; let joiner_state_loaded = load_existing_state(&joiner_state, None)?;
let joiner_state_compat: StoredStateCompat = bincode::deserialize(&joiner_bytes)?; let joiner_identity = IdentityKeypair::from_seed(joiner_state_loaded.identity_seed);
let joiner_identity = IdentityKeypair::from_seed(joiner_state_compat.identity_seed);
let joiner_pk_hex = hex_encode(&joiner_identity.public_key_bytes()); let joiner_pk_hex = hex_encode(&joiner_identity.public_key_bytes());
local local
@@ -227,6 +222,7 @@ async fn e2e_three_party_group_invite_join_send_recv() -> anyhow::Result<()> {
&ca_cert, &ca_cert,
"localhost", "localhost",
None, None,
false,
)) ))
.await?; .await?;
local local
@@ -236,6 +232,7 @@ async fn e2e_three_party_group_invite_join_send_recv() -> anyhow::Result<()> {
&ca_cert, &ca_cert,
"localhost", "localhost",
None, None,
false,
)) ))
.await?; .await?;
local local
@@ -245,19 +242,18 @@ async fn e2e_three_party_group_invite_join_send_recv() -> anyhow::Result<()> {
&ca_cert, &ca_cert,
"localhost", "localhost",
None, None,
false,
)) ))
.await?; .await?;
let b_bytes = std::fs::read(&b_state)?; let b_loaded = load_existing_state(&b_state, None)?;
let b_compat: StoredStateCompat = bincode::deserialize(&b_bytes)?; let b_pk_hex = hex_encode(&IdentityKeypair::from_seed(b_loaded.identity_seed).public_key_bytes());
let b_pk_hex = hex_encode(&IdentityKeypair::from_seed(b_compat.identity_seed).public_key_bytes());
let c_bytes = std::fs::read(&c_state)?; let c_loaded = load_existing_state(&c_state, None)?;
let c_compat: StoredStateCompat = bincode::deserialize(&c_bytes)?; let c_pk_hex = hex_encode(&IdentityKeypair::from_seed(c_loaded.identity_seed).public_key_bytes());
let c_pk_hex = hex_encode(&IdentityKeypair::from_seed(c_compat.identity_seed).public_key_bytes());
local local
.run_until(cmd_create_group(&creator_state, &server, "test-group", None)) .run_until(cmd_create_group(&creator_state, &server, "test-group", None, false))
.await?; .await?;
local local
@@ -440,12 +436,12 @@ async fn e2e_login_rejects_mismatched_identity() -> anyhow::Result<()> {
&ca_cert, &ca_cert,
"localhost", "localhost",
None, None,
false,
)) ))
.await?; .await?;
// Register the user with the bound identity so login can enforce mismatches. // Register the user with the bound identity so login can enforce mismatches.
let state_bytes = std::fs::read(&state_path)?; let stored_state = load_existing_state(&state_path, None)?;
let stored_state: StoredStateCompat = bincode::deserialize(&state_bytes)?;
let identity_hex = hex::encode( let identity_hex = hex::encode(
IdentityKeypair::from_seed(stored_state.identity_seed).public_key_bytes(), IdentityKeypair::from_seed(stored_state.identity_seed).public_key_bytes(),
); );
@@ -547,11 +543,11 @@ async fn e2e_sealed_sender_enqueue_then_fetch() -> anyhow::Result<()> {
&ca_cert, &ca_cert,
"localhost", "localhost",
None, None,
false,
)) ))
.await?; .await?;
let state_bytes = std::fs::read(&state_path)?; let stored = load_existing_state(&state_path, None)?;
let stored: StoredStateCompat = bincode::deserialize(&state_bytes)?;
let recipient_key = IdentityKeypair::from_seed(stored.identity_seed).public_key_bytes(); let recipient_key = IdentityKeypair::from_seed(stored.identity_seed).public_key_bytes();
let identity_hex = hex_encode(&recipient_key); let identity_hex = hex_encode(&recipient_key);

View File

@@ -221,11 +221,11 @@ mod tests {
fn roundtrip_chat() { fn roundtrip_chat() {
let body = b"hello"; let body = b"hello";
let encoded = serialize_chat(body, None); let encoded = serialize_chat(body, None);
let (t, msg) = parse(&encoded).unwrap(); let (t, msg) = parse(&encoded).expect("serialize_chat output is valid");
assert_eq!(t, MessageType::Chat); assert_eq!(t, MessageType::Chat);
match &msg { assert!(matches!(&msg, AppMessage::Chat { .. }), "expected Chat, got {:?}", msg);
AppMessage::Chat { message_id: _, body: b } => assert_eq!(b.as_slice(), body), if let AppMessage::Chat { body: b, .. } = &msg {
_ => panic!("expected Chat"), assert_eq!(b.as_slice(), body);
} }
} }
@@ -234,25 +234,23 @@ mod tests {
let ref_id = [1u8; 16]; let ref_id = [1u8; 16];
let body = b"reply text"; let body = b"reply text";
let encoded = serialize_reply(ref_id, body); let encoded = serialize_reply(ref_id, body);
let (t, msg) = parse(&encoded).unwrap(); let (t, msg) = parse(&encoded).expect("serialize_reply output is valid");
assert_eq!(t, MessageType::Reply); assert_eq!(t, MessageType::Reply);
match &msg { assert!(matches!(&msg, AppMessage::Reply { .. }), "expected Reply, got {:?}", msg);
AppMessage::Reply { ref_msg_id, body: b } => { if let AppMessage::Reply { ref_msg_id, body: b } = &msg {
assert_eq!(ref_msg_id, &ref_id); assert_eq!(ref_msg_id, &ref_id);
assert_eq!(b.as_slice(), body); assert_eq!(b.as_slice(), body);
}
_ => panic!("expected Reply"),
} }
} }
#[test] #[test]
fn roundtrip_typing() { fn roundtrip_typing() {
let encoded = serialize_typing(1); let encoded = serialize_typing(1);
let (t, msg) = parse(&encoded).unwrap(); let (t, msg) = parse(&encoded).expect("serialize_typing output is valid");
assert_eq!(t, MessageType::Typing); assert_eq!(t, MessageType::Typing);
match &msg { assert!(matches!(&msg, AppMessage::Typing { .. }), "expected Typing, got {:?}", msg);
AppMessage::Typing { active } => assert_eq!(*active, 1), if let AppMessage::Typing { active } = &msg {
_ => panic!("expected Typing"), assert_eq!(*active, 1);
} }
} }
} }

View File

@@ -2,9 +2,10 @@
//! //!
//! # Design //! # Design
//! //!
//! [`GroupMember`] wraps an openmls [`MlsGroup`] plus the per-client //! [`GroupMember`] wraps an openmls [`MlsGroup`] plus a per-client crypto
//! [`StoreCrypto`] backend. The backend is **persistent** — it holds the //! backend ([`StoreCrypto`] or [`HybridCryptoProvider`] for M7). The backend
//! in-memory key store that maps init-key references to HPKE private keys. //! is **persistent** — it holds the key store that maps init-key references
//! to HPKE private keys (classical or hybrid).
//! openmls's `new_from_welcome` reads those private keys from the key store to //! openmls's `new_from_welcome` reads those private keys from the key store to
//! decrypt the Welcome, so the same backend instance must be used from //! decrypt the Welcome, so the same backend instance must be used from
//! `generate_key_package` through `join_group`. //! `generate_key_package` through `join_group`.
@@ -37,6 +38,7 @@ use openmls_traits::OpenMlsCryptoProvider;
use crate::{ use crate::{
error::CoreError, error::CoreError,
hybrid_crypto::HybridCryptoProvider,
identity::IdentityKeypair, identity::IdentityKeypair,
keystore::{DiskKeyStore, StoreCrypto}, keystore::{DiskKeyStore, StoreCrypto},
}; };
@@ -49,6 +51,9 @@ const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA2
/// Per-client MLS state: identity keypair, crypto backend, and optional group. /// Per-client MLS state: identity keypair, crypto backend, and optional group.
/// ///
/// Generic over the crypto provider `P`: [`StoreCrypto`] (default, classical)
/// or [`HybridCryptoProvider`] (M7, post-quantum hybrid KEM).
///
/// # Lifecycle /// # Lifecycle
/// ///
/// ```text /// ```text
@@ -60,10 +65,10 @@ const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA2
/// ├─ send_message(msg) → encrypt application data /// ├─ send_message(msg) → encrypt application data
/// └─ receive_message(b) → decrypt; returns Some(plaintext) or None /// └─ receive_message(b) → decrypt; returns Some(plaintext) or None
/// ``` /// ```
pub struct GroupMember { pub struct GroupMember<P: OpenMlsCryptoProvider = StoreCrypto> {
/// Persistent crypto backend. Holds the in-memory key store with HPKE /// Crypto backend (classical or hybrid). Holds the key store with HPKE
/// private keys created during `generate_key_package`. /// private keys created during `generate_key_package`.
backend: StoreCrypto, backend: P,
/// Long-term Ed25519 identity keypair. Also used as the MLS `Signer`. /// Long-term Ed25519 identity keypair. Also used as the MLS `Signer`.
identity: Arc<IdentityKeypair>, identity: Arc<IdentityKeypair>,
/// Active MLS group, if any. /// Active MLS group, if any.
@@ -72,8 +77,8 @@ pub struct GroupMember {
config: MlsGroupConfig, config: MlsGroupConfig,
} }
impl GroupMember { impl GroupMember<StoreCrypto> {
/// Create a new `GroupMember` with a fresh crypto backend. /// Create a new `GroupMember` with a fresh crypto backend (classical X25519).
pub fn new(identity: Arc<IdentityKeypair>) -> Self { pub fn new(identity: Arc<IdentityKeypair>) -> Self {
Self::new_with_state(identity, DiskKeyStore::ephemeral(), None) Self::new_with_state(identity, DiskKeyStore::ephemeral(), None)
} }
@@ -105,6 +110,41 @@ impl GroupMember {
config, config,
} }
} }
}
impl GroupMember<HybridCryptoProvider> {
/// Create a `GroupMember` that uses post-quantum hybrid KEM (X25519 + ML-KEM-768) for HPKE.
///
/// All members of a group must use the same provider type: if the creator uses
/// `new_with_hybrid`, KeyPackages will have hybrid init keys and joiners must
/// also use `new_with_hybrid` to decrypt the Welcome.
pub fn new_with_hybrid(
identity: Arc<IdentityKeypair>,
key_store: DiskKeyStore,
) -> Self {
Self::new_with_state_hybrid(identity, key_store, None)
}
/// Create a PQ `GroupMember` from persisted state (identity, key store, optional group).
pub fn new_with_state_hybrid(
identity: Arc<IdentityKeypair>,
key_store: DiskKeyStore,
group: Option<MlsGroup>,
) -> Self {
let config = MlsGroupConfig::builder()
.use_ratchet_tree_extension(true)
.build();
Self {
backend: HybridCryptoProvider::new(key_store),
identity,
group,
config,
}
}
}
impl<P: OpenMlsCryptoProvider> GroupMember<P> {
// ── KeyPackage ──────────────────────────────────────────────────────────── // ── KeyPackage ────────────────────────────────────────────────────────────
@@ -414,7 +454,7 @@ impl GroupMember {
} }
/// Return a reference to the underlying crypto backend. /// Return a reference to the underlying crypto backend.
pub fn backend(&self) -> &StoreCrypto { pub fn backend(&self) -> &P {
&self.backend &self.backend
} }
@@ -498,6 +538,48 @@ mod tests {
assert_eq!(pt_creator, b"hello back"); assert_eq!(pt_creator, b"hello back");
} }
/// M7: Full two-party MLS round-trip with post-quantum hybrid KEM (HybridCryptoProvider).
#[test]
fn two_party_mls_round_trip_hybrid() {
let creator_id = Arc::new(IdentityKeypair::generate());
let joiner_id = Arc::new(IdentityKeypair::generate());
let key_store_creator = DiskKeyStore::ephemeral();
let key_store_joiner = DiskKeyStore::ephemeral();
let mut creator =
GroupMember::<HybridCryptoProvider>::new_with_hybrid(Arc::clone(&creator_id), key_store_creator);
let mut joiner =
GroupMember::<HybridCryptoProvider>::new_with_hybrid(Arc::clone(&joiner_id), key_store_joiner);
let joiner_kp = joiner
.generate_key_package()
.expect("joiner KeyPackage (hybrid)");
creator
.create_group(b"test-group-m7-hybrid")
.expect("creator create group");
let (_, welcome) = creator
.add_member(&joiner_kp)
.expect("creator add joiner");
joiner.join_group(&welcome).expect("joiner join group");
let ct_creator = creator.send_message(b"hello pq").expect("creator send");
let pt_joiner = joiner
.receive_message(&ct_creator)
.expect("joiner recv")
.expect("application message");
assert_eq!(pt_joiner, b"hello pq");
let ct_joiner = joiner.send_message(b"hello back pq").expect("joiner send");
let pt_creator = creator
.receive_message(&ct_joiner)
.expect("creator recv")
.expect("application message");
assert_eq!(pt_creator, b"hello back pq");
}
/// `group_id()` returns None before create_group, Some afterwards. /// `group_id()` returns None before create_group, Some afterwards.
#[test] #[test]
fn group_id_lifecycle() { fn group_id_lifecycle() {

View File

@@ -38,4 +38,5 @@ pub use hybrid_kem::{
pub use hybrid_crypto::{HybridCrypto, HybridCryptoProvider}; pub use hybrid_crypto::{HybridCrypto, HybridCryptoProvider};
pub use identity::IdentityKeypair; pub use identity::IdentityKeypair;
pub use keypackage::{generate_key_package, validate_keypackage_ciphersuite}; pub use keypackage::{generate_key_package, validate_keypackage_ciphersuite};
pub use keystore::DiskKeyStore; pub use keystore::{DiskKeyStore, StoreCrypto};
pub use openmls::prelude::MlsGroup;

View File

@@ -24,6 +24,7 @@ futures = { workspace = true }
# Server utilities # Server utilities
dashmap = { workspace = true } dashmap = { workspace = true }
governor = { workspace = true }
sha2 = { workspace = true } sha2 = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }

View File

@@ -10,8 +10,12 @@ use crate::error_codes::*;
pub const SESSION_TTL_SECS: u64 = 24 * 60 * 60; // 24 hours pub const SESSION_TTL_SECS: u64 = 24 * 60 * 60; // 24 hours
pub const PENDING_LOGIN_TTL_SECS: u64 = 300; // 5 minutes pub const PENDING_LOGIN_TTL_SECS: u64 = 300; // 5 minutes
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60; /// Maximum enqueues per second per token before GCRA rate limiting kicks in.
pub const RATE_LIMIT_MAX_ENQUEUES: u32 = 100; pub const RATE_LIMIT_MAX_PER_SEC: std::num::NonZeroU32 =
std::num::NonZeroU32::new(100).expect("RATE_LIMIT_MAX_PER_SEC must be non-zero");
/// Keyed GCRA rate limiter backed by DashMap (one bucket per session token).
pub type RateLimiter = governor::DefaultKeyedRateLimiter<Vec<u8>>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct AuthConfig { pub struct AuthConfig {
@@ -47,11 +51,6 @@ pub struct PendingLogin {
pub created_at: u64, pub created_at: u64,
} }
pub struct RateEntry {
pub count: u32,
pub window_start: u64,
}
#[derive(Clone)] #[derive(Clone)]
pub struct AuthContext { pub struct AuthContext {
pub token: Vec<u8>, pub token: Vec<u8>,
@@ -65,32 +64,14 @@ pub fn current_timestamp() -> u64 {
.as_secs() .as_secs()
} }
pub fn check_rate_limit( /// Check the GCRA rate limit for a token. Returns an error if the token has exceeded the quota.
rate_limits: &DashMap<Vec<u8>, RateEntry>, pub fn check_rate_limit(limiter: &RateLimiter, token: &[u8]) -> Result<(), capnp::Error> {
token: &[u8], limiter.check_key(&token.to_vec()).map_err(|_| {
) -> Result<(), capnp::Error> { crate::error_codes::coded_error(
let now = current_timestamp(); E014_RATE_LIMITED,
let mut entry = rate_limits.entry(token.to_vec()).or_insert(RateEntry { format!("rate limit exceeded: max {} enqueues/s", RATE_LIMIT_MAX_PER_SEC),
count: 0, )
window_start: now, })
});
if now - entry.window_start >= RATE_LIMIT_WINDOW_SECS {
entry.count = 1;
entry.window_start = now;
} else {
entry.count += 1;
if entry.count > RATE_LIMIT_MAX_ENQUEUES {
return Err(crate::error_codes::coded_error(
E014_RATE_LIMITED,
format!(
"rate limit exceeded: {} enqueues in {}s window",
RATE_LIMIT_MAX_ENQUEUES, RATE_LIMIT_WINDOW_SECS
),
));
}
}
Ok(())
} }
pub fn validate_auth( pub fn validate_auth(

View File

@@ -23,7 +23,7 @@ mod sql_store;
mod tls; mod tls;
mod storage; mod storage;
use auth::{AuthConfig, PendingLogin, RateEntry, SessionInfo}; use auth::{AuthConfig, PendingLogin, RateLimiter, SessionInfo, RATE_LIMIT_MAX_PER_SEC};
use config::{ use config::{
load_config, merge_config, validate_production_config, DEFAULT_DATA_DIR, DEFAULT_DB_PATH, load_config, merge_config, validate_production_config, DEFAULT_DATA_DIR, DEFAULT_DB_PATH,
DEFAULT_LISTEN, DEFAULT_STORE_BACKEND, DEFAULT_TLS_CERT, DEFAULT_TLS_KEY, DEFAULT_LISTEN, DEFAULT_STORE_BACKEND, DEFAULT_TLS_CERT, DEFAULT_TLS_KEY,
@@ -215,13 +215,15 @@ async fn main() -> anyhow::Result<()> {
let pending_logins: Arc<DashMap<String, PendingLogin>> = Arc::new(DashMap::new()); let pending_logins: Arc<DashMap<String, PendingLogin>> = Arc::new(DashMap::new());
let sessions: Arc<DashMap<Vec<u8>, SessionInfo>> = Arc::new(DashMap::new()); let sessions: Arc<DashMap<Vec<u8>, SessionInfo>> = Arc::new(DashMap::new());
let rate_limits: Arc<DashMap<Vec<u8>, RateEntry>> = Arc::new(DashMap::new()); let rate_limiter: Arc<RateLimiter> = Arc::new(governor::RateLimiter::keyed(
governor::Quota::per_second(RATE_LIMIT_MAX_PER_SEC),
));
// Background cleanup task (expire sessions, pending logins, rate limits, and stale messages). // Background cleanup task (expire sessions, pending logins, and stale messages).
// Governor's DashMapStateStore handles rate-limit cleanup automatically.
spawn_cleanup_task( spawn_cleanup_task(
Arc::clone(&sessions), Arc::clone(&sessions),
Arc::clone(&pending_logins), Arc::clone(&pending_logins),
Arc::clone(&rate_limits),
Arc::clone(&store), Arc::clone(&store),
); );
@@ -260,7 +262,7 @@ async fn main() -> anyhow::Result<()> {
let opaque_setup = Arc::clone(&opaque_setup); let opaque_setup = Arc::clone(&opaque_setup);
let pending_logins = Arc::clone(&pending_logins); let pending_logins = Arc::clone(&pending_logins);
let sessions = Arc::clone(&sessions); let sessions = Arc::clone(&sessions);
let rate_limits = Arc::clone(&rate_limits); let rate_limiter = Arc::clone(&rate_limiter);
let sealed_sender = effective.sealed_sender; let sealed_sender = effective.sealed_sender;
tokio::task::spawn_local(async move { tokio::task::spawn_local(async move {
@@ -272,7 +274,7 @@ async fn main() -> anyhow::Result<()> {
opaque_setup, opaque_setup,
pending_logins, pending_logins,
sessions, sessions,
rate_limits, rate_limiter,
sealed_sender, sealed_sender,
) )
.await .await

View File

@@ -84,7 +84,7 @@ impl NodeServiceImpl {
)); ));
} }
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) { if let Err(e) = check_rate_limit(&self.rate_limiter, &auth_ctx.token) {
// Audit: rate limit hit — do not log token or identity. // Audit: rate limit hit — do not log token or identity.
tracing::warn!("rate_limit_hit"); tracing::warn!("rate_limit_hit");
metrics::record_rate_limit_hit_total(); metrics::record_rate_limit_hit_total();

View File

@@ -10,8 +10,7 @@ use tokio::sync::Notify;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use crate::auth::{ use crate::auth::{
current_timestamp, AuthConfig, PendingLogin, RateEntry, SessionInfo, current_timestamp, AuthConfig, PendingLogin, RateLimiter, SessionInfo, PENDING_LOGIN_TTL_SECS,
PENDING_LOGIN_TTL_SECS, RATE_LIMIT_WINDOW_SECS,
}; };
use crate::storage::Store; use crate::storage::Store;
@@ -143,7 +142,7 @@ pub struct NodeServiceImpl {
pub opaque_setup: Arc<ServerSetup<OpaqueSuite>>, pub opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
pub pending_logins: Arc<DashMap<String, PendingLogin>>, pub pending_logins: Arc<DashMap<String, PendingLogin>>,
pub sessions: Arc<DashMap<Vec<u8>, SessionInfo>>, pub sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
pub rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>, pub rate_limiter: Arc<RateLimiter>,
/// When true, enqueue does not require identity-bound session (Sealed Sender). /// When true, enqueue does not require identity-bound session (Sealed Sender).
pub sealed_sender: bool, pub sealed_sender: bool,
} }
@@ -156,7 +155,7 @@ impl NodeServiceImpl {
opaque_setup: Arc<ServerSetup<OpaqueSuite>>, opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
pending_logins: Arc<DashMap<String, PendingLogin>>, pending_logins: Arc<DashMap<String, PendingLogin>>,
sessions: Arc<DashMap<Vec<u8>, SessionInfo>>, sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>, rate_limiter: Arc<RateLimiter>,
sealed_sender: bool, sealed_sender: bool,
) -> Self { ) -> Self {
Self { Self {
@@ -166,7 +165,7 @@ impl NodeServiceImpl {
opaque_setup, opaque_setup,
pending_logins, pending_logins,
sessions, sessions,
rate_limits, rate_limiter,
sealed_sender, sealed_sender,
} }
} }
@@ -180,7 +179,7 @@ pub async fn handle_node_connection(
opaque_setup: Arc<ServerSetup<OpaqueSuite>>, opaque_setup: Arc<ServerSetup<OpaqueSuite>>,
pending_logins: Arc<DashMap<String, PendingLogin>>, pending_logins: Arc<DashMap<String, PendingLogin>>,
sessions: Arc<DashMap<Vec<u8>, SessionInfo>>, sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>, rate_limiter: Arc<RateLimiter>,
sealed_sender: bool, sealed_sender: bool,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let connection = connecting.await?; let connection = connecting.await?;
@@ -207,7 +206,7 @@ pub async fn handle_node_connection(
opaque_setup, opaque_setup,
pending_logins, pending_logins,
sessions, sessions,
rate_limits, rate_limiter,
sealed_sender, sealed_sender,
)); ));
@@ -221,7 +220,6 @@ const MESSAGE_TTL_SECS: u64 = 7 * 24 * 60 * 60; // 7 days
pub fn spawn_cleanup_task( pub fn spawn_cleanup_task(
sessions: Arc<DashMap<Vec<u8>, SessionInfo>>, sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
pending_logins: Arc<DashMap<String, PendingLogin>>, pending_logins: Arc<DashMap<String, PendingLogin>>,
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
store: Arc<dyn Store>, store: Arc<dyn Store>,
) { ) {
tokio::spawn(async move { tokio::spawn(async move {
@@ -232,7 +230,7 @@ pub fn spawn_cleanup_task(
sessions.retain(|_, info| info.expires_at > now); sessions.retain(|_, info| info.expires_at > now);
pending_logins.retain(|_, pl| now - pl.created_at < PENDING_LOGIN_TTL_SECS); pending_logins.retain(|_, pl| now - pl.created_at < PENDING_LOGIN_TTL_SECS);
rate_limits.retain(|_, entry| now - entry.window_start < RATE_LIMIT_WINDOW_SECS * 2); // Rate limit cleanup is handled automatically by governor's DashMapStateStore.
match store.gc_expired_messages(MESSAGE_TTL_SECS) { match store.gc_expired_messages(MESSAGE_TTL_SECS) {
Ok(n) if n > 0 => { Ok(n) if n > 0 => {

View File

@@ -346,8 +346,12 @@ impl Store for FileBackedStore {
channel_id: channel_id.to_vec(), channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(), recipient_key: recipient_key.to_vec(),
}; };
let seq = *inner.next_seq.entry(key.clone()).or_insert(0); let seq = {
*inner.next_seq.get_mut(&key).unwrap() = seq + 1; let entry = inner.next_seq.entry(key.clone()).or_insert(0);
let s = *entry;
*entry = s + 1;
s
};
inner.map.entry(key).or_default().push_back(SeqEntry { seq, data: payload }); inner.map.entry(key).or_default().push_back(SeqEntry { seq, data: payload });
self.flush_delivery_map(&self.ds_path, &*inner)?; self.flush_delivery_map(&self.ds_path, &*inner)?;
Ok(seq) Ok(seq)

View File

@@ -141,9 +141,16 @@ encrypted-at-rest options.
**Goal:** Replace the MLS crypto backend with a hybrid X25519 + ML-KEM-768 KEM, **Goal:** Replace the MLS crypto backend with a hybrid X25519 + ML-KEM-768 KEM,
providing post-quantum confidentiality for all group key material. providing post-quantum confidentiality for all group key material.
**Status:** PoC complete. `HybridCryptoProvider` and `HybridCrypto` implement
`OpenMlsCryptoProvider` / HPKE with hybrid KEM; `GroupMember<HybridCryptoProvider>`
via `new_with_hybrid()` runs full two-party MLS (create, add, join, send, recv).
Unit test `two_party_mls_round_trip_hybrid` passes. Remaining: optional CLI/client
flag to use hybrid provider for new groups; interoperability note (hybrid init
keys are non-standard until MLS adopts PQ).
**Deliverables:** **Deliverables:**
- Custom `OpenMlsCryptoProvider` with hybrid KEM in `quicnprotochat-core` - Custom `OpenMlsCryptoProvider` with hybrid KEM in `quicnprotochat-core` (**done**)
- Hybrid shared secret derivation: - Hybrid shared secret derivation:
``` ```