feat: add post-quantum hybrid KEM + SQLCipher persistence

Feature 1 — Post-Quantum Hybrid KEM (X25519 + ML-KEM-768):
- Create hybrid_kem.rs with keygen, encrypt, decrypt + 11 unit tests
- Wire format: version(1) | x25519_eph_pk(32) | mlkem_ct(1088) | nonce(12) | ct
- Add uploadHybridKey/fetchHybridKey RPCs to node.capnp schema
- Server: hybrid key storage in FileBackedStore + RPC handlers
- Client: hybrid keypair in StoredState, auto-wrap/unwrap in send/recv/invite/join
- demo-group runs full hybrid PQ envelope round-trip

Feature 2 — SQLCipher Persistence:
- Extract Store trait from FileBackedStore API
- Create SqlStore (rusqlite + bundled-sqlcipher) with encrypted-at-rest SQLite
- Schema: key_packages, deliveries, hybrid_keys tables with indexes
- Server CLI: --store-backend=sql, --db-path, --db-key flags
- 5 unit tests for SqlStore (FIFO, round-trip, upsert, channel isolation)

Also includes: client lib.rs refactor, auth config, TOML config file support,
mdBook documentation, and various cleanups by user.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-22 08:07:48 +01:00
parent d1ddef4cea
commit f334ed3d43
81 changed files with 14502 additions and 2289 deletions

View File

@@ -0,0 +1,971 @@
use std::fs;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock};
use anyhow::Context;
use capnp_rpc::{rpc_twoparty_capnp::Side, twoparty, RpcSystem};
use serde::{Deserialize, Serialize};
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use quinn::{ClientConfig, Endpoint};
use quinn_proto::crypto::rustls::QuicClientConfig;
use rustls::pki_types::CertificateDer;
use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
use quicnprotochat_core::{
generate_key_package, hybrid_decrypt, hybrid_encrypt, DiskKeyStore, GroupMember,
HybridKeypair, HybridKeypairBytes, HybridPublicKey, IdentityKeypair,
};
use quicnprotochat_proto::node_capnp::{auth, node_service};
// Global auth context initialized once per process.
// Written exactly once via `init_auth`; read by `set_auth` on every RPC.
static AUTH_CONTEXT: OnceLock<ClientAuth> = OnceLock::new();
/// Per-process authentication material attached to every RPC.
///
/// Holds the auth wire `version` (always 1 — there is no legacy mode),
/// the bearer `access_token`, and an optional `device_id`, all as raw bytes.
#[derive(Clone, Debug)]
pub struct ClientAuth {
    version: u16,
    access_token: Vec<u8>,
    device_id: Vec<u8>,
}

impl ClientAuth {
    /// Build a client auth context from a token and an optional device id.
    ///
    /// The token is expected to be non-empty (the CLI marks the flag
    /// required); a missing device id is stored as an empty byte string.
    pub fn from_parts(access_token: String, device_id: Option<String>) -> Self {
        ClientAuth {
            version: 1,
            access_token: access_token.into_bytes(),
            device_id: device_id.map(String::into_bytes).unwrap_or_default(),
        }
    }
}
/// Initialize the global auth context; subsequent calls are ignored.
///
/// Only the first call wins: `OnceLock::set` errors on later calls and
/// that error is deliberately discarded.
pub fn init_auth(ctx: ClientAuth) {
    if AUTH_CONTEXT.set(ctx).is_err() {
        // Already initialized — keep the existing context.
    }
}
// ── Subcommand implementations ───────────────────────────────────────────────
/// Connect to `server`, call health, and print RTT over QUIC/TLS.
///
/// Note the reported RTT spans connection setup plus the health RPC,
/// since the start timestamp is taken before `connect_node`.
pub async fn cmd_ping(server: &str, ca_cert: &Path, server_name: &str) -> anyhow::Result<()> {
    let sent_at = current_timestamp_ms();
    let client = connect_node(server, ca_cert, server_name).await?;
    let req = client.health_request();
    let resp = req.send().promise.await.context("health RPC failed")?;
    // Non-UTF-8 status text degrades to "invalid" rather than failing.
    let status = resp
        .get()
        .context("health: bad response")?
        .get_status()
        .context("health: missing status")?
        .to_str()
        .unwrap_or("invalid");
    // saturating_sub guards against a clock that steps backwards.
    let rtt_ms = current_timestamp_ms().saturating_sub(sent_at);
    println!("health={status} rtt={rtt_ms}ms");
    Ok(())
}
/// Generate a KeyPackage for a fresh identity and upload it to the AS.
///
/// Must run on a `LocalSet` because capnp-rpc is `!Send`.
pub async fn cmd_register(server: &str, ca_cert: &Path, server_name: &str) -> anyhow::Result<()> {
    // Ephemeral identity: nothing is persisted locally here (use
    // cmd_register_state for a persistent identity).
    let identity = IdentityKeypair::generate();
    let (tls_bytes, fingerprint) =
        generate_key_package(&identity).context("KeyPackage generation failed")?;
    let node_client = connect_node(server, ca_cert, server_name).await?;
    let mut req = node_client.upload_key_package_request();
    {
        let mut p = req.get();
        p.set_identity_key(&identity.public_key_bytes());
        p.set_package(&tls_bytes);
        let mut auth = p.reborrow().init_auth();
        set_auth(&mut auth);
    }
    let response = req
        .send()
        .promise
        .await
        .context("upload_key_package RPC failed")?;
    let server_fp = response
        .get()
        .context("upload_key_package: bad response")?
        .get_fingerprint()
        .context("upload_key_package: missing fingerprint")?
        .to_vec();
    // Verify the server echoed the locally computed fingerprint.
    anyhow::ensure!(
        server_fp == fingerprint,
        "fingerprint mismatch: local={} server={}",
        hex::encode(&fingerprint),
        hex::encode(&server_fp),
    );
    println!(
        "identity_key : {}",
        hex::encode(identity.public_key_bytes())
    );
    println!("fingerprint : {}", hex::encode(&fingerprint));
    println!("KeyPackage uploaded successfully.");
    Ok(())
}
/// Upload the stored identity's KeyPackage to the AS (persists backend state).
///
/// Creates the state file on first use (including a hybrid keypair), then
/// uploads the KeyPackage and — when present — the hybrid public key.
pub async fn cmd_register_state(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
) -> anyhow::Result<()> {
    // load_or_init_state creates a fresh identity + hybrid keypair if the
    // state file does not exist yet.
    let state = load_or_init_state(state_path)?;
    let (mut member, hybrid_kp) = state.into_parts(state_path)?;
    let tls_bytes = member
        .generate_key_package()
        .context("KeyPackage generation failed")?;
    let fingerprint = sha256(&tls_bytes);
    let node_client = connect_node(server, ca_cert, server_name).await?;
    let mut req = node_client.upload_key_package_request();
    {
        let mut p = req.get();
        p.set_identity_key(&member.identity().public_key_bytes());
        p.set_package(&tls_bytes);
        let mut auth = p.reborrow().init_auth();
        set_auth(&mut auth);
    }
    let response = req
        .send()
        .promise
        .await
        .context("upload_key_package RPC failed")?;
    let server_fp = response
        .get()
        .context("upload_key_package: bad response")?
        .get_fingerprint()
        .context("upload_key_package: missing fingerprint")?
        .to_vec();
    // The AS must echo SHA-256(package) for a successful upload.
    anyhow::ensure!(server_fp == fingerprint, "fingerprint mismatch");
    // Upload hybrid public key alongside the KeyPackage.
    if let Some(ref hkp) = hybrid_kp {
        upload_hybrid_key(
            &node_client,
            &member.identity().public_key_bytes(),
            &hkp.public_key(),
        )
        .await?;
        println!("hybrid_key : uploaded (X25519 + ML-KEM-768)");
    }
    println!(
        "identity_key : {}",
        hex::encode(member.identity().public_key_bytes())
    );
    println!("fingerprint : {}", hex::encode(&fingerprint));
    println!("KeyPackage uploaded successfully.");
    // Persist any state mutated by KeyPackage generation.
    save_state(state_path, &member, hybrid_kp.as_ref())?;
    Ok(())
}
/// Fetch a peer's KeyPackage from the AS by their hex-encoded identity key.
///
/// Must run on a `LocalSet` because capnp-rpc is `!Send`.
pub async fn cmd_fetch_key(
server: &str,
ca_cert: &Path,
server_name: &str,
identity_key_hex: &str,
) -> anyhow::Result<()> {
let identity_key = hex::decode(identity_key_hex)
.map_err(|e| anyhow::anyhow!(e))
.context("identity_key must be 64 hex characters (32 bytes)")?;
anyhow::ensure!(
identity_key.len() == 32,
"identity_key must be exactly 32 bytes, got {}",
identity_key.len()
);
let node_client = connect_node(server, ca_cert, server_name).await?;
let mut req = node_client.fetch_key_package_request();
{
let mut p = req.get();
p.set_identity_key(&identity_key);
let mut auth = p.reborrow().init_auth();
set_auth(&mut auth);
}
let response = req
.send()
.promise
.await
.context("fetch_key_package RPC failed")?;
let package = response
.get()
.context("fetch_key_package: bad response")?
.get_package()
.context("fetch_key_package: missing package field")?
.to_vec();
if package.is_empty() {
println!("No KeyPackage available for this identity.");
return Ok(());
}
use sha2::{Digest, Sha256};
let fingerprint = Sha256::digest(&package);
println!("fingerprint : {}", hex::encode(fingerprint));
println!("package_len : {} bytes", package.len());
println!("KeyPackage fetched successfully.");
Ok(())
}
/// Run a complete Alice↔Bob MLS round-trip using the unified server endpoint.
///
/// All payloads are wrapped in post-quantum hybrid envelopes (X25519 + ML-KEM-768).
///
/// Flow: generate identities → upload KeyPackages + hybrid public keys →
/// Alice creates the group and queues a hybrid-wrapped Welcome for Bob →
/// one application message in each direction, each hybrid-wrapped.
pub async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) -> anyhow::Result<()> {
    // Identities and MLS state must be tied to the same backend instance.
    let alice_id = Arc::new(IdentityKeypair::generate());
    let bob_id = Arc::new(IdentityKeypair::generate());
    // Generate hybrid keypairs for both participants.
    let alice_hybrid = HybridKeypair::generate();
    let bob_hybrid = HybridKeypair::generate();
    let mut alice = GroupMember::new(Arc::clone(&alice_id));
    let mut bob = GroupMember::new(Arc::clone(&bob_id));
    let alice_kp = alice
        .generate_key_package()
        .context("Alice KeyPackage generation failed")?;
    let bob_kp = bob
        .generate_key_package()
        .context("Bob KeyPackage generation failed")?;
    // Upload both KeyPackages and hybrid public keys to the server.
    // Each participant gets its own connection to the unified endpoint.
    let alice_node = connect_node(server, ca_cert, server_name).await?;
    let bob_node = connect_node(server, ca_cert, server_name).await?;
    upload_key_package(&alice_node, &alice_id.public_key_bytes(), &alice_kp).await?;
    upload_key_package(&bob_node, &bob_id.public_key_bytes(), &bob_kp).await?;
    upload_hybrid_key(
        &alice_node,
        &alice_id.public_key_bytes(),
        &alice_hybrid.public_key(),
    )
    .await?;
    upload_hybrid_key(
        &bob_node,
        &bob_id.public_key_bytes(),
        &bob_hybrid.public_key(),
    )
    .await?;
    println!("hybrid public keys uploaded for Alice and Bob");
    // Alice fetches Bob's KeyPackage and creates the group.
    let fetched_bob_kp = fetch_key_package(&alice_node, &bob_id.public_key_bytes()).await?;
    anyhow::ensure!(
        !fetched_bob_kp.is_empty(),
        "AS returned an empty KeyPackage for Bob",
    );
    alice
        .create_group(b"demo-group")
        .context("Alice create_group failed")?;
    let (_commit, welcome) = alice
        .add_member(&fetched_bob_kp)
        .context("Alice add_member failed")?;
    // AS and DS share the same connection in the unified endpoint.
    let alice_ds = alice_node.clone();
    let bob_ds = bob_node.clone();
    // Fetch Bob's hybrid PK and wrap the welcome.
    let bob_hybrid_pk = fetch_hybrid_key(&alice_node, &bob_id.public_key_bytes())
        .await?
        .context("Bob hybrid key not found")?;
    let wrapped_welcome =
        hybrid_encrypt(&bob_hybrid_pk, &welcome).context("hybrid encrypt welcome")?;
    enqueue(&alice_ds, &bob_id.public_key_bytes(), &wrapped_welcome).await?;
    let welcome_payloads = fetch_all(&bob_ds, &bob_id.public_key_bytes()).await?;
    let raw_welcome = welcome_payloads
        .first()
        .cloned()
        .context("Welcome was not delivered to Bob via DS")?;
    // Bob unwraps the hybrid envelope and joins the group.
    let welcome_bytes = hybrid_decrypt(&bob_hybrid, &raw_welcome)
        .context("Bob: hybrid decrypt welcome failed")?;
    bob.join_group(&welcome_bytes)
        .context("Bob join_group failed")?;
    // Alice → Bob (hybrid-wrapped MLS ciphertext)
    let ct_ab = alice
        .send_message(b"hello bob")
        .context("Alice send_message failed")?;
    let wrapped_ab =
        hybrid_encrypt(&bob_hybrid_pk, &ct_ab).context("hybrid encrypt Alice→Bob")?;
    enqueue(&alice_ds, &bob_id.public_key_bytes(), &wrapped_ab).await?;
    let bob_msgs = fetch_all(&bob_ds, &bob_id.public_key_bytes()).await?;
    let raw_ab = bob_msgs
        .first()
        .context("Bob: missing Alice ciphertext from DS")?;
    let inner_ab = hybrid_decrypt(&bob_hybrid, raw_ab).context("Bob: hybrid decrypt failed")?;
    let ab_plaintext = bob
        .receive_message(&inner_ab)?
        .context("Bob expected application message from Alice")?;
    println!(
        "Alice → Bob plaintext: {}",
        String::from_utf8_lossy(&ab_plaintext)
    );
    // Bob → Alice (hybrid-wrapped)
    let alice_hybrid_pk = fetch_hybrid_key(&bob_node, &alice_id.public_key_bytes())
        .await?
        .context("Alice hybrid key not found")?;
    let ct_ba = bob
        .send_message(b"hello alice")
        .context("Bob send_message failed")?;
    let wrapped_ba =
        hybrid_encrypt(&alice_hybrid_pk, &ct_ba).context("hybrid encrypt Bob→Alice")?;
    enqueue(&bob_ds, &alice_id.public_key_bytes(), &wrapped_ba).await?;
    let alice_msgs = fetch_all(&alice_ds, &alice_id.public_key_bytes()).await?;
    let raw_ba = alice_msgs
        .first()
        .context("Alice: missing Bob ciphertext from DS")?;
    let inner_ba =
        hybrid_decrypt(&alice_hybrid, raw_ba).context("Alice: hybrid decrypt failed")?;
    let ba_plaintext = alice
        .receive_message(&inner_ba)?
        .context("Alice expected application message from Bob")?;
    println!(
        "Bob → Alice plaintext: {}",
        String::from_utf8_lossy(&ba_plaintext)
    );
    println!("demo-group complete (hybrid PQ envelope active)");
    Ok(())
}
/// Create a new group and persist state.
///
/// `_server` is accepted for CLI symmetry with other subcommands but no
/// network call is made — group creation is purely local.
pub async fn cmd_create_group(
    state_path: &Path,
    _server: &str,
    group_id: &str,
) -> anyhow::Result<()> {
    let state = load_or_init_state(state_path)?;
    let (mut member, hybrid_kp) = state.into_parts(state_path)?;
    // Refuse to overwrite an existing group in the state file.
    anyhow::ensure!(
        member.group_ref().is_none(),
        "group already exists in state"
    );
    member
        .create_group(group_id.as_bytes())
        .context("create_group failed")?;
    save_state(state_path, &member, hybrid_kp.as_ref())?;
    println!("group created: {group_id}");
    Ok(())
}
/// Invite a peer: fetch their KeyPackage, add to group, enqueue Welcome.
///
/// If the peer has a hybrid public key on the server, the Welcome is wrapped
/// in a post-quantum hybrid envelope (X25519 + ML-KEM-768).
pub async fn cmd_invite(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
    peer_key_hex: &str,
) -> anyhow::Result<()> {
    // Requires an existing state file (register-state must have run first).
    let state = load_existing_state(state_path)?;
    let (mut member, hybrid_kp) = state.into_parts(state_path)?;
    let peer_key = decode_identity_key(peer_key_hex)?;
    let node_client = connect_node(server, ca_cert, server_name).await?;
    let peer_kp = fetch_key_package(&node_client, &peer_key).await?;
    anyhow::ensure!(
        !peer_kp.is_empty(),
        "server returned empty KeyPackage for peer"
    );
    // Fail early if no group is active, before mutating group state.
    let _ = member
        .group_ref()
        .context("no active group; run create-group first")?;
    let (_, welcome) = member.add_member(&peer_kp).context("add_member failed")?;
    // Wrap welcome in hybrid envelope if peer has a hybrid public key.
    let peer_hybrid_pk = fetch_hybrid_key(&node_client, &peer_key).await?;
    let payload = if let Some(ref pk) = peer_hybrid_pk {
        hybrid_encrypt(pk, &welcome).context("hybrid encrypt welcome failed")?
    } else {
        // Legacy peer: deliver the raw MLS Welcome.
        welcome
    };
    enqueue(&node_client, &peer_key, &payload).await?;
    // Persist the group state updated by add_member.
    save_state(state_path, &member, hybrid_kp.as_ref())?;
    println!(
        "invited peer (welcome queued{})",
        if peer_hybrid_pk.is_some() { ", hybrid-encrypted" } else { "" }
    );
    Ok(())
}
/// Join a group by consuming a Welcome from the server queue.
///
/// Automatically detects and decrypts hybrid-wrapped Welcomes.
pub async fn cmd_join(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
) -> anyhow::Result<()> {
    let state = load_existing_state(state_path)?;
    let (mut member, hybrid_kp) = state.into_parts(state_path)?;
    // Refuse to clobber an existing group membership.
    anyhow::ensure!(
        member.group_ref().is_none(),
        "group already active in state"
    );
    let node_client = connect_node(server, ca_cert, server_name).await?;
    let welcomes = fetch_all(&node_client, &member.identity().public_key_bytes()).await?;
    // Only the first queued payload is treated as the Welcome. fetch_all
    // drains the queue, so any further queued payloads are discarded here.
    let raw_welcome = welcomes
        .first()
        .cloned()
        .context("no Welcome found in DS for this identity")?;
    // Try hybrid decryption first, fall back to raw MLS welcome.
    let welcome_bytes = try_hybrid_unwrap(hybrid_kp.as_ref(), &raw_welcome);
    member
        .join_group(&welcome_bytes)
        .context("join_group failed")?;
    save_state(state_path, &member, hybrid_kp.as_ref())?;
    println!("joined group successfully");
    Ok(())
}
/// Send an application message via DS.
///
/// If the peer has a hybrid public key, the MLS ciphertext is additionally
/// wrapped in a post-quantum hybrid envelope.
pub async fn cmd_send(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
    peer_key_hex: &str,
    msg: &str,
) -> anyhow::Result<()> {
    let state = load_existing_state(state_path)?;
    let (mut member, hybrid_kp) = state.into_parts(state_path)?;
    let peer_key = decode_identity_key(peer_key_hex)?;
    let node_client = connect_node(server, ca_cert, server_name).await?;
    // MLS-encrypt first; the hybrid envelope (if any) wraps the MLS
    // ciphertext, not the plaintext.
    let ct = member
        .send_message(msg.as_bytes())
        .context("send_message failed")?;
    // Wrap in hybrid envelope if peer has a hybrid public key.
    let peer_hybrid_pk = fetch_hybrid_key(&node_client, &peer_key).await?;
    let payload = if let Some(ref pk) = peer_hybrid_pk {
        hybrid_encrypt(pk, &ct).context("hybrid encrypt failed")?
    } else {
        ct
    };
    enqueue(&node_client, &peer_key, &payload).await?;
    // Persist group state mutated by send_message.
    save_state(state_path, &member, hybrid_kp.as_ref())?;
    println!(
        "message sent{}",
        if peer_hybrid_pk.is_some() { " (hybrid-encrypted)" } else { "" }
    );
    Ok(())
}
/// Receive and decrypt all pending messages from the server.
///
/// Automatically detects and decrypts hybrid-wrapped payloads. With
/// `stream = true` the loop long-polls forever; otherwise it returns after
/// the first (possibly empty) batch.
pub async fn cmd_recv(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
    wait_ms: u64,
    stream: bool,
) -> anyhow::Result<()> {
    let state = load_existing_state(state_path)?;
    let (mut member, hybrid_kp) = state.into_parts(state_path)?;
    let client = connect_node(server, ca_cert, server_name).await?;
    loop {
        // Long-poll up to wait_ms for the next batch of payloads.
        let payloads = fetch_wait(&client, &member.identity().public_key_bytes(), wait_ms).await?;
        if payloads.is_empty() {
            if !stream {
                println!("no messages");
                return Ok(());
            }
            // Streaming mode: poll again.
            continue;
        }
        for (idx, payload) in payloads.iter().enumerate() {
            // Try hybrid decryption, fall back to raw MLS payload.
            let mls_payload = try_hybrid_unwrap(hybrid_kp.as_ref(), payload);
            // Per-message errors are printed but do not abort the batch.
            match member.receive_message(&mls_payload) {
                Ok(Some(pt)) => println!("[{idx}] plaintext: {}", String::from_utf8_lossy(&pt)),
                Ok(None) => println!("[{idx}] commit applied"),
                Err(e) => println!("[{idx}] error: {e}"),
            }
        }
        // Persist state after each processed batch.
        save_state(state_path, &member, hybrid_kp.as_ref())?;
        if !stream {
            return Ok(());
        }
    }
}
// ── Shared helpers ───────────────────────────────────────────────────────────
/// Establish a QUIC/TLS connection and return a `NodeService` client.
///
/// Must be called from within a `LocalSet` because capnp-rpc is `!Send`;
/// the RPC event loop is driven by a `spawn_local` background task.
pub async fn connect_node(
    server: &str,
    ca_cert: &Path,
    server_name: &str,
) -> anyhow::Result<node_service::Client> {
    let addr: SocketAddr = server
        .parse()
        .with_context(|| format!("server must be host:port, got {server}"))?;
    // Trust exactly the CA certificate supplied on the command line.
    let cert_bytes = fs::read(ca_cert).with_context(|| format!("read ca_cert {ca_cert:?}"))?;
    let mut roots = RootCertStore::empty();
    roots
        .add(CertificateDer::from(cert_bytes))
        .context("add root cert")?;
    let tls = RustlsClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();
    let crypto = QuicClientConfig::try_from(tls)
        .map_err(|e| anyhow::anyhow!("invalid client TLS config: {e}"))?;
    // Bind an ephemeral UDP port; the literal address cannot fail to parse.
    let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())?;
    endpoint.set_default_client_config(ClientConfig::new(Arc::new(crypto)));
    let connection = endpoint
        .connect(addr, server_name)
        .context("quic connect init")?
        .await
        .context("quic connect failed")?;
    // One bidirectional stream carries the whole capnp RPC session.
    let (send, recv) = connection.open_bi().await.context("open bi stream")?;
    let network = twoparty::VatNetwork::new(
        recv.compat(),
        send.compat_write(),
        Side::Client,
        Default::default(),
    );
    let mut rpc_system = RpcSystem::new(Box::new(network), None);
    let client: node_service::Client = rpc_system.bootstrap(Side::Server);
    // Drive the RPC state machine in the background on the LocalSet.
    tokio::task::spawn_local(rpc_system);
    Ok(client)
}
/// Upload a KeyPackage and verify the fingerprint echoed by the AS.
///
/// The AS is expected to echo SHA-256 of the uploaded package; a mismatch
/// is treated as an error.
pub async fn upload_key_package(
    client: &node_service::Client,
    identity_key: &[u8],
    package: &[u8],
) -> anyhow::Result<()> {
    let mut request = client.upload_key_package_request();
    {
        let mut params = request.get();
        params.set_identity_key(identity_key);
        params.set_package(package);
        let mut auth_builder = params.reborrow().init_auth();
        set_auth(&mut auth_builder);
    }
    let reply = request
        .send()
        .promise
        .await
        .context("upload_key_package RPC failed")?;
    let echoed = reply
        .get()
        .context("upload_key_package: bad response")?
        .get_fingerprint()
        .context("upload_key_package: missing fingerprint")?
        .to_vec();
    // Recompute locally and compare against the server's echo.
    anyhow::ensure!(sha256(package) == echoed, "fingerprint mismatch");
    Ok(())
}
/// Fetch a KeyPackage for `identity_key` from the AS.
///
/// Returns an empty `Vec` when the server has no package for this identity;
/// callers check for emptiness rather than an error.
pub async fn fetch_key_package(
    client: &node_service::Client,
    identity_key: &[u8],
) -> anyhow::Result<Vec<u8>> {
    let mut req = client.fetch_key_package_request();
    {
        let mut p = req.get();
        p.set_identity_key(identity_key);
        let mut auth = p.reborrow().init_auth();
        set_auth(&mut auth);
    }
    let resp = req
        .send()
        .promise
        .await
        .context("fetch_key_package RPC failed")?;
    let pkg = resp
        .get()
        .context("fetch_key_package: bad response")?
        .get_package()
        .context("fetch_key_package: missing package field")?
        .to_vec();
    Ok(pkg)
}
/// Enqueue an opaque payload to the DS for `recipient_key`.
///
/// Uses the default channel (empty channel id) and wire version 1; the DS
/// never inspects the payload bytes.
pub async fn enqueue(
    client: &node_service::Client,
    recipient_key: &[u8],
    payload: &[u8],
) -> anyhow::Result<()> {
    let mut req = client.enqueue_request();
    {
        let mut p = req.get();
        p.set_recipient_key(recipient_key);
        p.set_payload(payload);
        // Empty channel id selects the default channel.
        p.set_channel_id(&[]);
        p.set_version(1);
        let mut auth = p.reborrow().init_auth();
        set_auth(&mut auth);
    }
    req.send().promise.await.context("enqueue RPC failed")?;
    Ok(())
}
/// Fetch and drain all payloads for `recipient_key`.
///
/// Queries the default channel (empty channel id) at wire version 1 and
/// copies every payload out of the capnp arena into owned buffers.
pub async fn fetch_all(
    client: &node_service::Client,
    recipient_key: &[u8],
) -> anyhow::Result<Vec<Vec<u8>>> {
    let mut request = client.fetch_request();
    {
        let mut params = request.get();
        params.set_recipient_key(recipient_key);
        params.set_channel_id(&[]);
        params.set_version(1);
        let mut auth_builder = params.reborrow().init_auth();
        set_auth(&mut auth_builder);
    }
    let reply = request.send().promise.await.context("fetch RPC failed")?;
    let list = reply
        .get()
        .context("fetch: bad response")?
        .get_payloads()
        .context("fetch: missing payloads")?;
    (0..list.len())
        .map(|i| {
            list.get(i)
                .map(|p| p.to_vec())
                .context("fetch: payload read failed")
        })
        .collect()
}
/// Long-poll for payloads with optional timeout (ms).
///
/// Like `fetch_all` but the server may hold the request up to `timeout_ms`
/// waiting for messages to arrive before answering.
pub async fn fetch_wait(
    client: &node_service::Client,
    recipient_key: &[u8],
    timeout_ms: u64,
) -> anyhow::Result<Vec<Vec<u8>>> {
    let mut req = client.fetch_wait_request();
    {
        let mut p = req.get();
        p.set_recipient_key(recipient_key);
        p.set_timeout_ms(timeout_ms);
        // Empty channel id selects the default channel.
        p.set_channel_id(&[]);
        p.set_version(1);
        let mut auth = p.reborrow().init_auth();
        set_auth(&mut auth);
    }
    let resp = req.send().promise.await.context("fetch_wait RPC failed")?;
    let list = resp
        .get()
        .context("fetch_wait: bad response")?
        .get_payloads()
        .context("fetch_wait: missing payloads")?;
    // Copy each payload out of the capnp arena into owned buffers.
    let mut payloads = Vec::with_capacity(list.len() as usize);
    for i in 0..list.len() {
        payloads.push(
            list.get(i)
                .context("fetch_wait: payload read failed")?
                .to_vec(),
        );
    }
    Ok(payloads)
}
/// Upload a hybrid (X25519 + ML-KEM-768) public key for an identity.
///
/// NOTE(review): unlike every other RPC helper in this file, no auth block
/// is attached to this request — confirm whether the uploadHybridKey schema
/// has an auth field and whether the server enforces it; if so, wire
/// `set_auth` through here as well.
pub async fn upload_hybrid_key(
    client: &node_service::Client,
    identity_key: &[u8],
    hybrid_pk: &HybridPublicKey,
) -> anyhow::Result<()> {
    let mut req = client.upload_hybrid_key_request();
    {
        let mut p = req.get();
        p.set_identity_key(identity_key);
        p.set_hybrid_public_key(&hybrid_pk.to_bytes());
    }
    req.send()
        .promise
        .await
        .context("upload_hybrid_key RPC failed")?;
    Ok(())
}
/// Fetch a peer's hybrid public key from the server.
///
/// Returns `None` if the peer has not uploaded a hybrid key.
///
/// NOTE(review): like `upload_hybrid_key`, this request carries no auth
/// block — confirm that is intentional for the fetchHybridKey schema.
pub async fn fetch_hybrid_key(
    client: &node_service::Client,
    identity_key: &[u8],
) -> anyhow::Result<Option<HybridPublicKey>> {
    let mut req = client.fetch_hybrid_key_request();
    req.get().set_identity_key(identity_key);
    let resp = req
        .send()
        .promise
        .await
        .context("fetch_hybrid_key RPC failed")?;
    let pk_bytes = resp
        .get()
        .context("fetch_hybrid_key: bad response")?
        .get_hybrid_public_key()
        .context("fetch_hybrid_key: missing field")?
        .to_vec();
    // The server signals "no key uploaded" with an empty byte string.
    if pk_bytes.is_empty() {
        return Ok(None);
    }
    let pk = HybridPublicKey::from_bytes(&pk_bytes).context("invalid hybrid public key")?;
    Ok(Some(pk))
}
/// Try to decrypt a hybrid envelope. If no keypair is available, the payload
/// is not a hybrid envelope, or decryption fails, return the original bytes
/// unchanged (legacy plaintext MLS).
fn try_hybrid_unwrap(hybrid_kp: Option<&HybridKeypair>, payload: &[u8]) -> Vec<u8> {
    hybrid_kp
        .and_then(|kp| hybrid_decrypt(kp, payload).ok())
        .unwrap_or_else(|| payload.to_vec())
}
/// SHA-256 digest of `bytes`, returned as an owned byte vector.
fn sha256(bytes: &[u8]) -> Vec<u8> {
    use sha2::{Digest, Sha256};
    let mut hasher = Sha256::new();
    hasher.update(bytes);
    hasher.finalize().to_vec()
}
/// Copy the process-wide auth context into a request's `auth` sub-struct.
///
/// # Panics
/// Panics if `init_auth` has not been called yet; the CLI entry point
/// initializes the context before issuing any RPC.
fn set_auth(auth: &mut auth::Builder<'_>) {
    let ctx = AUTH_CONTEXT
        .get()
        .expect("init_auth must be called with a non-empty token before RPCs");
    auth.set_version(ctx.version);
    auth.set_access_token(&ctx.access_token);
    auth.set_device_id(&ctx.device_id);
}
/// On-disk client state, serialized with bincode.
#[derive(Serialize, Deserialize)]
struct StoredState {
    // 32-byte seed from which the identity keypair is re-derived on load.
    identity_seed: [u8; 32],
    // bincode-encoded MLS group state; `None` before create/join.
    group: Option<Vec<u8>>,
    /// Post-quantum hybrid keypair (X25519 + ML-KEM-768). `None` for legacy state files.
    #[serde(default)]
    hybrid_key: Option<HybridKeypairBytes>,
}
impl StoredState {
    /// Rehydrate runtime objects from the serialized state.
    ///
    /// Re-derives the identity keypair from its seed, decodes the MLS group
    /// (if any), re-opens the key store stored next to `state_path`, and
    /// decodes the optional hybrid keypair.
    fn into_parts(self, state_path: &Path) -> anyhow::Result<(GroupMember, Option<HybridKeypair>)> {
        let identity = Arc::new(IdentityKeypair::from_seed(self.identity_seed));
        let group = self
            .group
            .map(|bytes| bincode::deserialize(&bytes).context("decode group"))
            .transpose()?;
        let key_store = DiskKeyStore::persistent(keystore_path(state_path))?;
        let member = GroupMember::new_with_state(identity, key_store, group);
        let hybrid_kp = self
            .hybrid_key
            .map(|bytes| HybridKeypair::from_bytes(&bytes).context("decode hybrid key"))
            .transpose()?;
        Ok((member, hybrid_kp))
    }
    /// Snapshot the live `member` (and optional hybrid keypair) into a
    /// serializable state value.
    fn from_parts(
        member: &GroupMember,
        hybrid_kp: Option<&HybridKeypair>,
    ) -> anyhow::Result<Self> {
        let group = member
            .group_ref()
            .map(|g| bincode::serialize(g).context("serialize group"))
            .transpose()?;
        Ok(Self {
            identity_seed: member.identity_seed(),
            group,
            hybrid_key: hybrid_kp.map(|kp| kp.to_bytes()),
        })
    }
}
/// Load state from `path`, creating a fresh identity + hybrid keypair (and
/// writing the file) if it does not exist yet.
fn load_or_init_state(path: &Path) -> anyhow::Result<StoredState> {
    if path.exists() {
        let mut state = load_existing_state(path)?;
        // Upgrade legacy state files: generate hybrid keypair if missing.
        if state.hybrid_key.is_none() {
            state.hybrid_key = Some(HybridKeypair::generate().to_bytes());
            write_state(path, &state)?;
        }
        return Ok(state);
    }
    // Fresh initialization: new identity, new hybrid keypair, empty group.
    let identity = IdentityKeypair::generate();
    let hybrid_kp = HybridKeypair::generate();
    // Creating the member also initializes the on-disk key store.
    let key_store = DiskKeyStore::persistent(keystore_path(path))?;
    let member = GroupMember::new_with_state(Arc::new(identity), key_store, None);
    let state = StoredState::from_parts(&member, Some(&hybrid_kp))?;
    write_state(path, &state)?;
    Ok(state)
}
/// Read and decode a `StoredState` from `path`; fails if the file is missing.
fn load_existing_state(path: &Path) -> anyhow::Result<StoredState> {
    let raw = std::fs::read(path).with_context(|| format!("read state file {path:?}"))?;
    bincode::deserialize(&raw).context("decode state")
}
/// Serialize `member` (plus the optional hybrid keypair) and persist to `path`.
fn save_state(
    path: &Path,
    member: &GroupMember,
    hybrid_kp: Option<&HybridKeypair>,
) -> anyhow::Result<()> {
    write_state(path, &StoredState::from_parts(member, hybrid_kp)?)
}
/// Encode `state` with bincode and write it to `path`, creating parent dirs.
fn write_state(path: &Path, state: &StoredState) -> anyhow::Result<()> {
    // Make sure the parent directory exists before writing the file.
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent).with_context(|| format!("create dir {parent:?}"))?;
    }
    let encoded = bincode::serialize(state).context("encode state")?;
    std::fs::write(path, encoded).with_context(|| format!("write state {path:?}"))
}
/// Decode a 64-hex-char identity key into its 32 raw bytes.
fn decode_identity_key(hex_str: &str) -> anyhow::Result<Vec<u8>> {
    let bytes = match hex::decode(hex_str) {
        Ok(b) => b,
        Err(e) => return Err(anyhow::anyhow!(e)).context("identity key must be hex"),
    };
    anyhow::ensure!(bytes.len() == 32, "identity key must be 32 bytes");
    Ok(bytes)
}
/// Derive the on-disk key-store path from the state-file path by swapping
/// (or appending) the extension to `ks`.
fn keystore_path(state_path: &Path) -> PathBuf {
    state_path.with_extension("ks")
}
/// Return the current Unix timestamp in milliseconds.
///
/// A system clock earlier than the Unix epoch yields 0 instead of panicking.
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
// ── Hex encoding helper ─────────────────────────────────────────────────────
//
// We use a tiny inline module rather than adding `hex` as a dependency.
mod hex {
    /// Lowercase hex encoding of `bytes`.
    pub fn encode(bytes: impl AsRef<[u8]>) -> String {
        bytes.as_ref().iter().map(|b| format!("{b:02x}")).collect()
    }

    /// Decode a hex string into bytes.
    ///
    /// # Errors
    /// Returns an error for odd-length input or any non-hex character.
    ///
    /// Works on raw bytes instead of `&s[i..i + 2]` string slices: byte
    /// slicing panicked on multi-byte UTF-8 input (e.g. "€a" is 4 bytes,
    /// passes the even-length check, then splits a char boundary). It also
    /// rejects a leading '+', which `from_str_radix` silently accepted.
    pub fn decode(s: &str) -> Result<Vec<u8>, &'static str> {
        if s.len() % 2 != 0 {
            return Err("odd-length hex string");
        }
        s.as_bytes()
            .chunks(2)
            .map(|pair| Ok(nibble(pair[0])? << 4 | nibble(pair[1])?))
            .collect()
    }

    /// Map one ASCII hex digit (either case) to its 4-bit value.
    fn nibble(b: u8) -> Result<u8, &'static str> {
        (b as char)
            .to_digit(16)
            .map(|v| v as u8)
            .ok_or("invalid hex character")
    }
}

View File

@@ -1,38 +1,13 @@
//! quicnprotochat CLI client.
//!
//! # Subcommands
//!
//! | Subcommand | Description |
//! |--------------|----------------------------------------------------------|
//! | `ping` | Send a Ping to the server, print RTT |
//! | `register` | Generate a KeyPackage and upload it to the AS |
//! | `fetch-key` | Fetch a peer's KeyPackage from the AS by identity key |
//!
//! # Configuration
//!
//! | Env var | CLI flag | Default |
//! |-----------------|--------------|---------------------|
//! | `QUICNPROTOCHAT_SERVER`| `--server` | `127.0.0.1:4201` |
//! | `RUST_LOG` | — | `warn` |
use std::fs;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::path::PathBuf;
use anyhow::Context;
use capnp_rpc::{rpc_twoparty_capnp::Side, twoparty, RpcSystem};
use clap::{Parser, Subcommand};
use serde::{Deserialize, Serialize};
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use quinn::{ClientConfig, Endpoint};
use quinn_proto::crypto::rustls::QuicClientConfig;
use rustls::pki_types::CertificateDer;
use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
use quicnprotochat_core::{generate_key_package, DiskKeyStore, GroupMember, IdentityKeypair};
use quicnprotochat_proto::node_capnp::node_service;
use quicnprotochat_client::{
cmd_create_group, cmd_demo_group, cmd_fetch_key, cmd_invite, cmd_join, cmd_ping, cmd_recv,
cmd_register, cmd_register_state, cmd_send, ClientAuth, init_auth,
};
// ── CLI ───────────────────────────────────────────────────────────────────────
@@ -57,6 +32,14 @@ struct Args {
)]
server_name: String,
/// Bearer token for authenticated requests (version 1, required).
#[arg(long, global = true, env = "QUICNPROTOCHAT_ACCESS_TOKEN", required = true)]
access_token: String,
/// Optional device identifier (UUID bytes encoded as hex or raw string).
#[arg(long, global = true, env = "QUICNPROTOCHAT_DEVICE_ID")]
device_id: Option<String>,
#[command(subcommand)]
command: Command,
}
@@ -66,7 +49,7 @@ enum Command {
/// Send a Ping to the server and print the round-trip time.
Ping {
/// Server address (host:port).
#[arg(long, default_value = "127.0.0.1:4201", env = "QUICNPROTOCHAT_SERVER")]
#[arg(long, default_value = "127.0.0.1:7000", env = "QUICNPROTOCHAT_SERVER")]
server: String,
},
@@ -76,7 +59,7 @@ enum Command {
/// Ed25519 identity public key bytes (hex), which peers need to fetch it.
Register {
/// Server address (host:port).
#[arg(long, default_value = "127.0.0.1:4201", env = "QUICNPROTOCHAT_SERVER")]
#[arg(long, default_value = "127.0.0.1:7000", env = "QUICNPROTOCHAT_SERVER")]
server: String,
},
@@ -86,7 +69,7 @@ enum Command {
/// hex characters (32 bytes).
FetchKey {
/// Server address (host:port).
#[arg(long, default_value = "127.0.0.1:4201", env = "QUICNPROTOCHAT_SERVER")]
#[arg(long, default_value = "127.0.0.1:7000", env = "QUICNPROTOCHAT_SERVER")]
server: String,
/// Target peer's Ed25519 identity public key (64 hex chars = 32 bytes).
@@ -96,7 +79,7 @@ enum Command {
/// Run a full Alice↔Bob MLS round-trip against live AS and DS endpoints.
DemoGroup {
/// Server address (host:port).
#[arg(long, default_value = "127.0.0.1:4201", env = "QUICNPROTOCHAT_SERVER")]
#[arg(long, default_value = "127.0.0.1:7000", env = "QUICNPROTOCHAT_SERVER")]
server: String,
},
@@ -111,7 +94,7 @@ enum Command {
state: PathBuf,
/// Authentication Service address (host:port).
#[arg(long, default_value = "127.0.0.1:4201", env = "QUICNPROTOCHAT_SERVER")]
#[arg(long, default_value = "127.0.0.1:7000", env = "QUICNPROTOCHAT_SERVER")]
server: String,
},
@@ -126,7 +109,7 @@ enum Command {
state: PathBuf,
/// Server address (host:port).
#[arg(long, default_value = "127.0.0.1:4201", env = "QUICNPROTOCHAT_SERVER")]
#[arg(long, default_value = "127.0.0.1:7000", env = "QUICNPROTOCHAT_SERVER")]
server: String,
/// Group identifier (arbitrary bytes, typically a human-readable name).
@@ -142,7 +125,7 @@ enum Command {
env = "QUICNPROTOCHAT_STATE"
)]
state: PathBuf,
#[arg(long, default_value = "127.0.0.1:4201", env = "QUICNPROTOCHAT_SERVER")]
#[arg(long, default_value = "127.0.0.1:7000", env = "QUICNPROTOCHAT_SERVER")]
server: String,
/// Peer identity public key (64 hex chars = 32 bytes).
#[arg(long)]
@@ -213,6 +196,10 @@ async fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Initialize auth context once for all RPCs.
let auth_ctx = ClientAuth::from_parts(args.access_token.clone(), args.device_id.clone());
init_auth(auth_ctx);
match args.command {
Command::Ping { server } => cmd_ping(&server, &args.ca_cert, &args.server_name).await,
Command::Register { server } => {
@@ -321,699 +308,4 @@ async fn main() -> anyhow::Result<()> {
.await
}
}
}
// ── Subcommand implementations ────────────────────────────────────────────────
/// Connect to `server`, call `health`, and print the RPC round-trip time.
///
/// The timer starts *after* the QUIC/TLS connection is established so the
/// printed RTT measures the health RPC itself, not connection setup.
async fn cmd_ping(server: &str, ca_cert: &Path, server_name: &str) -> anyhow::Result<()> {
    let client = connect_node(server, ca_cert, server_name).await?;
    // Previously the timestamp was taken before connect_node, so the reported
    // "rtt" silently included the whole QUIC + TLS handshake.
    let sent_at = current_timestamp_ms();
    let resp = client
        .health_request()
        .send()
        .promise
        .await
        .context("health RPC failed")?;
    let status = resp
        .get()
        .context("health: bad response")?
        .get_status()
        .context("health: missing status")?
        .to_str()
        .unwrap_or("invalid");
    // Saturating: a clock step backwards must not panic or wrap.
    let rtt_ms = current_timestamp_ms().saturating_sub(sent_at);
    println!("health={status} rtt={rtt_ms}ms");
    Ok(())
}
/// Generate a KeyPackage for a fresh identity and upload it to the AS.
///
/// The identity is ephemeral: it is generated here, used for the upload, and
/// dropped on return — nothing is written to disk. Use the state-backed
/// register command for a durable identity.
///
/// Must run on a `LocalSet` because capnp-rpc is `!Send`.
async fn cmd_register(server: &str, ca_cert: &Path, server_name: &str) -> anyhow::Result<()> {
    let identity = IdentityKeypair::generate();
    let (tls_bytes, fingerprint) =
        generate_key_package(&identity).context("KeyPackage generation failed")?;
    let node_client = connect_node(server, ca_cert, server_name).await?;
    let mut req = node_client.upload_key_package_request();
    req.get().set_identity_key(&identity.public_key_bytes());
    req.get().set_package(&tls_bytes);
    let response = req
        .send()
        .promise
        .await
        .context("upload_key_package RPC failed")?;
    let server_fp = response
        .get()
        .context("upload_key_package: bad response")?
        .get_fingerprint()
        .context("upload_key_package: missing fingerprint")?
        .to_vec();
    // Verify the server echoed the same fingerprint.
    anyhow::ensure!(
        server_fp == fingerprint,
        "fingerprint mismatch: local={} server={}",
        hex::encode(&fingerprint),
        hex::encode(&server_fp),
    );
    println!(
        "identity_key : {}",
        hex::encode(identity.public_key_bytes())
    );
    println!("fingerprint : {}", hex::encode(&fingerprint));
    println!("KeyPackage uploaded successfully.");
    Ok(())
}
/// Upload the stored identity's KeyPackage to the AS (persists backend state).
///
/// Loads (or initializes) client state at `state_path`, generates a fresh
/// KeyPackage for that identity, uploads it, verifies the echoed fingerprint,
/// then saves the state back to disk.
async fn cmd_register_state(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
) -> anyhow::Result<()> {
    let state = load_or_init_state(state_path)?;
    let mut member = state.into_member(state_path)?;
    let tls_bytes = member
        .generate_key_package()
        .context("KeyPackage generation failed")?;
    // Fingerprint convention: SHA-256 over the TLS-serialized KeyPackage.
    let fingerprint = sha256(&tls_bytes);
    let node_client = connect_node(server, ca_cert, server_name).await?;
    let mut req = node_client.upload_key_package_request();
    req.get()
        .set_identity_key(&member.identity().public_key_bytes());
    req.get().set_package(&tls_bytes);
    let response = req
        .send()
        .promise
        .await
        .context("upload_key_package RPC failed")?;
    let server_fp = response
        .get()
        .context("upload_key_package: bad response")?
        .get_fingerprint()
        .context("upload_key_package: missing fingerprint")?
        .to_vec();
    // The server must echo the fingerprint we computed locally.
    anyhow::ensure!(server_fp == fingerprint, "fingerprint mismatch");
    println!(
        "identity_key : {}",
        hex::encode(member.identity().public_key_bytes())
    );
    println!("fingerprint : {}", hex::encode(&fingerprint));
    println!("KeyPackage uploaded successfully.");
    save_state(state_path, &member)?;
    Ok(())
}
/// Fetch a peer's KeyPackage from the AS by their hex-encoded identity key.
///
/// Prints the SHA-256 fingerprint and size of the fetched package so the
/// operator can compare it out-of-band.
///
/// Must run on a `LocalSet` because capnp-rpc is `!Send`.
async fn cmd_fetch_key(
    server: &str,
    ca_cert: &Path,
    server_name: &str,
    identity_key_hex: &str,
) -> anyhow::Result<()> {
    let identity_key = hex::decode(identity_key_hex)
        .map_err(|e| anyhow::anyhow!(e))
        .context("identity_key must be 64 hex characters (32 bytes)")?;
    anyhow::ensure!(
        identity_key.len() == 32,
        "identity_key must be exactly 32 bytes, got {}",
        identity_key.len()
    );
    let node_client = connect_node(server, ca_cert, server_name).await?;
    let mut req = node_client.fetch_key_package_request();
    req.get().set_identity_key(&identity_key);
    let response = req
        .send()
        .promise
        .await
        .context("fetch_key_package RPC failed")?;
    let package = response
        .get()
        .context("fetch_key_package: bad response")?
        .get_package()
        .context("fetch_key_package: missing package field")?
        .to_vec();
    // An empty package is treated as "nothing stored", not as an RPC failure.
    if package.is_empty() {
        println!("No KeyPackage available for this identity.");
        return Ok(());
    }
    use sha2::{Digest, Sha256};
    let fingerprint = Sha256::digest(&package);
    println!("fingerprint : {}", hex::encode(fingerprint));
    println!("package_len : {} bytes", package.len());
    println!("KeyPackage fetched successfully.");
    Ok(())
}
/// Run a complete Alice↔Bob MLS round-trip using the unified server endpoint.
///
/// Exercises the full flow against a live server: KeyPackage upload/fetch,
/// group creation, Welcome delivery via the DS queue, then one encrypted
/// application message in each direction.
async fn cmd_demo_group(server: &str, ca_cert: &Path, server_name: &str) -> anyhow::Result<()> {
    // Identities and MLS state must be tied to the same backend instance.
    let alice_id = Arc::new(IdentityKeypair::generate());
    let bob_id = Arc::new(IdentityKeypair::generate());
    let mut alice = GroupMember::new(Arc::clone(&alice_id));
    let mut bob = GroupMember::new(Arc::clone(&bob_id));
    let alice_kp = alice
        .generate_key_package()
        .context("Alice KeyPackage generation failed")?;
    let bob_kp = bob
        .generate_key_package()
        .context("Bob KeyPackage generation failed")?;
    // Upload both KeyPackages to the server.
    let alice_node = connect_node(server, ca_cert, server_name).await?;
    let bob_node = connect_node(server, ca_cert, server_name).await?;
    upload_key_package(&alice_node, &alice_id.public_key_bytes(), &alice_kp).await?;
    upload_key_package(&bob_node, &bob_id.public_key_bytes(), &bob_kp).await?;
    // Alice fetches Bob's KeyPackage and creates the group.
    let fetched_bob_kp = fetch_key_package(&alice_node, &bob_id.public_key_bytes()).await?;
    anyhow::ensure!(
        !fetched_bob_kp.is_empty(),
        "AS returned an empty KeyPackage for Bob",
    );
    alice
        .create_group(b"demo-group")
        .context("Alice create_group failed")?;
    let (_commit, welcome) = alice
        .add_member(&fetched_bob_kp)
        .context("Alice add_member failed")?;
    // DS queues are keyed by the recipient's identity public key: the Welcome
    // is enqueued under Bob's key and drained by Bob below.
    let alice_ds = alice_node.clone();
    let bob_ds = bob_node.clone();
    enqueue(&alice_ds, &bob_id.public_key_bytes(), &welcome).await?;
    let welcome_payloads = fetch_all(&bob_ds, &bob_id.public_key_bytes()).await?;
    let welcome_bytes = welcome_payloads
        .first()
        .cloned()
        .context("Welcome was not delivered to Bob via DS")?;
    bob.join_group(&welcome_bytes)
        .context("Bob join_group failed")?;
    // Alice → Bob
    let ct_ab = alice
        .send_message(b"hello bob")
        .context("Alice send_message failed")?;
    enqueue(&alice_ds, &bob_id.public_key_bytes(), &ct_ab).await?;
    let bob_msgs = fetch_all(&bob_ds, &bob_id.public_key_bytes()).await?;
    let ab_plaintext = bob
        .receive_message(
            bob_msgs
                .first()
                .context("Bob: missing Alice ciphertext from DS")?,
        )?
        .context("Bob expected application message from Alice")?;
    println!(
        "Alice → Bob plaintext: {}",
        String::from_utf8_lossy(&ab_plaintext)
    );
    // Bob → Alice
    let ct_ba = bob
        .send_message(b"hello alice")
        .context("Bob send_message failed")?;
    enqueue(&bob_ds, &alice_id.public_key_bytes(), &ct_ba).await?;
    let alice_msgs = fetch_all(&alice_ds, &alice_id.public_key_bytes()).await?;
    let ba_plaintext = alice
        .receive_message(
            alice_msgs
                .first()
                .context("Alice: missing Bob ciphertext from DS")?,
        )?
        .context("Alice expected application message from Bob")?;
    println!(
        "Bob → Alice plaintext: {}",
        String::from_utf8_lossy(&ba_plaintext)
    );
    println!("demo-group complete ✔");
    Ok(())
}
/// Create a new MLS group locally and persist the updated client state.
///
/// `_server` is accepted for CLI symmetry with the other subcommands but is
/// unused here: group creation is a purely local operation.
async fn cmd_create_group(state_path: &Path, _server: &str, group_id: &str) -> anyhow::Result<()> {
    let mut member = load_or_init_state(state_path)?.into_member(state_path)?;
    // Refuse to overwrite a group that is already present in the state file.
    anyhow::ensure!(
        member.group_ref().is_none(),
        "group already exists in state"
    );
    member
        .create_group(group_id.as_bytes())
        .context("create_group failed")?;
    save_state(state_path, &member)?;
    println!("group created: {group_id}");
    Ok(())
}
/// Invite a peer: fetch their KeyPackage, add to group, enqueue Welcome.
async fn cmd_invite(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
    peer_key_hex: &str,
) -> anyhow::Result<()> {
    let state = load_existing_state(state_path)?;
    let mut member = state.into_member(state_path)?;
    let peer_key = decode_identity_key(peer_key_hex)?;
    let node_client = connect_node(server, ca_cert, server_name).await?;
    let peer_kp = fetch_key_package(&node_client, &peer_key).await?;
    anyhow::ensure!(
        !peer_kp.is_empty(),
        "server returned empty KeyPackage for peer"
    );
    // Guard only: require an active group before mutating; the value is unused.
    let _ = member
        .group_ref()
        .context("no active group; run create-group first")?;
    // The Commit returned by add_member is discarded and only the Welcome is
    // delivered. NOTE(review): any *other* existing members would need that
    // Commit — confirm the two-party/inviter-only assumption.
    let (_, welcome) = member.add_member(&peer_kp).context("add_member failed")?;
    enqueue(&node_client, &peer_key, &welcome).await?;
    save_state(state_path, &member)?;
    println!("invited peer (welcome queued)");
    Ok(())
}
/// Join a group by consuming a Welcome from the server queue.
async fn cmd_join(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
) -> anyhow::Result<()> {
    let state = load_existing_state(state_path)?;
    let mut member = state.into_member(state_path)?;
    // A state file may hold at most one active group.
    anyhow::ensure!(
        member.group_ref().is_none(),
        "group already active in state"
    );
    let node_client = connect_node(server, ca_cert, server_name).await?;
    // fetch_all drains the whole queue; only the first payload is used as the
    // Welcome. NOTE(review): any additional drained payloads are dropped here
    // — confirm that is intended.
    let welcomes = fetch_all(&node_client, &member.identity().public_key_bytes()).await?;
    let welcome_bytes = welcomes
        .first()
        .cloned()
        .context("no Welcome found in DS for this identity")?;
    member
        .join_group(&welcome_bytes)
        .context("join_group failed")?;
    save_state(state_path, &member)?;
    println!("joined group successfully");
    Ok(())
}
/// Send an application message via DS.
///
/// Encrypts `msg` with the active group state and enqueues the ciphertext on
/// the server for `peer_key_hex`.
async fn cmd_send(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
    peer_key_hex: &str,
    msg: &str,
) -> anyhow::Result<()> {
    let state = load_existing_state(state_path)?;
    let mut member = state.into_member(state_path)?;
    let peer_key = decode_identity_key(peer_key_hex)?;
    let node_client = connect_node(server, ca_cert, server_name).await?;
    let ct = member
        .send_message(msg.as_bytes())
        .context("send_message failed")?;
    enqueue(&node_client, &peer_key, &ct).await?;
    // Persist after the send so the advanced group state is not lost.
    save_state(state_path, &member)?;
    println!("message sent");
    Ok(())
}
/// Receive and decrypt all pending messages from the server.
///
/// With `stream == false`: one long-poll of up to `wait_ms`, print results,
/// return. With `stream == true`: loop forever, printing batches as they
/// arrive.
async fn cmd_recv(
    state_path: &Path,
    server: &str,
    ca_cert: &Path,
    server_name: &str,
    wait_ms: u64,
    stream: bool,
) -> anyhow::Result<()> {
    let state = load_existing_state(state_path)?;
    let mut member = state.into_member(state_path)?;
    let client = connect_node(server, ca_cert, server_name).await?;
    loop {
        let payloads = fetch_wait(&client, &member.identity().public_key_bytes(), wait_ms).await?;
        if payloads.is_empty() {
            if !stream {
                println!("no messages");
                return Ok(());
            }
            continue;
        }
        // Each payload may be an application message, a commit, or garbage;
        // a per-item error is printed but does not abort the batch.
        for (idx, payload) in payloads.iter().enumerate() {
            match member.receive_message(payload) {
                Ok(Some(pt)) => println!("[{idx}] plaintext: {}", String::from_utf8_lossy(&pt)),
                Ok(None) => println!("[{idx}] commit applied"),
                Err(e) => println!("[{idx}] error: {e}"),
            }
        }
        // Persist after every processed batch so state survives a crash.
        save_state(state_path, &member)?;
        if !stream {
            return Ok(());
        }
    }
}
// ── Shared helpers ────────────────────────────────────────────────────────────
/// Establish a QUIC/TLS connection and return a `NodeService` client.
///
/// `ca_cert` is a DER certificate trusted as the sole root; `server_name` is
/// the TLS SNI / certificate name to verify against.
///
/// Must be called from within a `LocalSet` because capnp-rpc is `!Send`.
async fn connect_node(
    server: &str,
    ca_cert: &Path,
    server_name: &str,
) -> anyhow::Result<node_service::Client> {
    let addr: SocketAddr = server
        .parse()
        .with_context(|| format!("server must be host:port, got {server}"))?;
    let cert_bytes = fs::read(ca_cert).with_context(|| format!("read ca_cert {ca_cert:?}"))?;
    let mut roots = RootCertStore::empty();
    roots
        .add(CertificateDer::from(cert_bytes))
        .context("add root cert")?;
    let tls = RustlsClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();
    let crypto = QuicClientConfig::try_from(tls)
        .map_err(|e| anyhow::anyhow!("invalid client TLS config: {e}"))?;
    // Bind an ephemeral UDP port on all interfaces for the client endpoint.
    let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())?;
    endpoint.set_default_client_config(ClientConfig::new(Arc::new(crypto)));
    let connection = endpoint
        .connect(addr, server_name)
        .context("quic connect init")?
        .await
        .context("quic connect failed")?;
    // A single bidirectional stream carries the whole Cap'n Proto session.
    let (send, recv) = connection.open_bi().await.context("open bi stream")?;
    let network = twoparty::VatNetwork::new(
        recv.compat(),
        send.compat_write(),
        Side::Client,
        Default::default(),
    );
    let mut rpc_system = RpcSystem::new(Box::new(network), None);
    // Bootstrap the capability exported by the remote (server) vat.
    let client: node_service::Client = rpc_system.bootstrap(Side::Server);
    // Drive the RPC connection in the background; the returned client stays
    // usable for as long as this spawned task runs on the LocalSet.
    tokio::task::spawn_local(rpc_system);
    Ok(client)
}
/// Upload a KeyPackage and verify the fingerprint echoed by the AS.
///
/// Fails if the server's echoed fingerprint differs from the SHA-256 of
/// `package` computed locally.
async fn upload_key_package(
    client: &node_service::Client,
    identity_key: &[u8],
    package: &[u8],
) -> anyhow::Result<()> {
    // Compute the expected fingerprint up front, independent of the RPC.
    let expected_fp = sha256(package);
    let mut request = client.upload_key_package_request();
    {
        let mut params = request.get();
        params.set_identity_key(identity_key);
        params.set_package(package);
    }
    let response = request
        .send()
        .promise
        .await
        .context("upload_key_package RPC failed")?;
    let echoed_fp = response
        .get()
        .context("upload_key_package: bad response")?
        .get_fingerprint()
        .context("upload_key_package: missing fingerprint")?
        .to_vec();
    anyhow::ensure!(echoed_fp == expected_fp, "fingerprint mismatch");
    Ok(())
}
/// Fetch a KeyPackage for `identity_key` from the AS.
///
/// Returns the raw package bytes; callers in this file treat an empty vector
/// as "nothing stored for that identity".
async fn fetch_key_package(
    client: &node_service::Client,
    identity_key: &[u8],
) -> anyhow::Result<Vec<u8>> {
    let mut request = client.fetch_key_package_request();
    request.get().set_identity_key(identity_key);
    let response = request
        .send()
        .promise
        .await
        .context("fetch_key_package RPC failed")?;
    let reply = response.get().context("fetch_key_package: bad response")?;
    let package = reply
        .get_package()
        .context("fetch_key_package: missing package field")?;
    Ok(package.to_vec())
}
/// Enqueue an opaque payload on the DS queue for `recipient_key`.
async fn enqueue(
    client: &node_service::Client,
    recipient_key: &[u8],
    payload: &[u8],
) -> anyhow::Result<()> {
    let mut request = client.enqueue_request();
    {
        let mut params = request.get();
        params.set_recipient_key(recipient_key);
        params.set_payload(payload);
    }
    // The response carries no data we need; success of the RPC is enough.
    request
        .send()
        .promise
        .await
        .context("enqueue RPC failed")
        .map(|_| ())
}
/// Fetch and drain all payloads queued for `recipient_key`.
async fn fetch_all(
    client: &node_service::Client,
    recipient_key: &[u8],
) -> anyhow::Result<Vec<Vec<u8>>> {
    let mut request = client.fetch_request();
    request.get().set_recipient_key(recipient_key);
    let response = request.send().promise.await.context("fetch RPC failed")?;
    let list = response
        .get()
        .context("fetch: bad response")?
        .get_payloads()
        .context("fetch: missing payloads")?;
    // Copy every capnp data element out into an owned buffer, failing fast on
    // the first unreadable element.
    (0..list.len())
        .map(|i| Ok(list.get(i).context("fetch: payload read failed")?.to_vec()))
        .collect()
}
/// Long-poll the DS for payloads, waiting up to `timeout_ms` milliseconds.
async fn fetch_wait(
    client: &node_service::Client,
    recipient_key: &[u8],
    timeout_ms: u64,
) -> anyhow::Result<Vec<Vec<u8>>> {
    let mut request = client.fetch_wait_request();
    {
        let mut params = request.get();
        params.set_recipient_key(recipient_key);
        params.set_timeout_ms(timeout_ms);
    }
    let response = request
        .send()
        .promise
        .await
        .context("fetch_wait RPC failed")?;
    let list = response
        .get()
        .context("fetch_wait: bad response")?
        .get_payloads()
        .context("fetch_wait: missing payloads")?;
    // Materialize each element as an owned Vec<u8>, propagating read errors.
    (0..list.len())
        .map(|i| {
            Ok(list
                .get(i)
                .context("fetch_wait: payload read failed")?
                .to_vec())
        })
        .collect()
}
/// SHA-256 digest of `bytes` as an owned 32-byte `Vec<u8>`.
fn sha256(bytes: &[u8]) -> Vec<u8> {
    use sha2::{Digest, Sha256};
    Sha256::digest(bytes).to_vec()
}
/// On-disk client state: the identity seed plus the optional serialized group.
#[derive(Serialize, Deserialize)]
struct StoredState {
    /// 32-byte seed from which the identity keypair is rebuilt via
    /// `IdentityKeypair::from_seed`.
    identity_seed: [u8; 32],
    /// bincode-encoded MLS group state; `None` until a group is created/joined.
    group: Option<Vec<u8>>,
}
impl StoredState {
    /// Rebuild a live `GroupMember` from this snapshot, reopening the sibling
    /// on-disk key store (`<state>.ks`, see `keystore_path`).
    fn into_member(self, state_path: &Path) -> anyhow::Result<GroupMember> {
        let identity = Arc::new(IdentityKeypair::from_seed(self.identity_seed));
        let group = self
            .group
            .map(|bytes| bincode::deserialize(&bytes).context("decode group"))
            .transpose()?;
        let key_store = DiskKeyStore::persistent(keystore_path(state_path))?;
        Ok(GroupMember::new_with_state(identity, key_store, group))
    }
    /// Snapshot a `GroupMember` into the serializable on-disk form.
    fn from_member(member: &GroupMember) -> anyhow::Result<Self> {
        let group = member
            .group_ref()
            .map(|g| bincode::serialize(g).context("serialize group"))
            .transpose()?;
        Ok(Self {
            identity_seed: member.identity_seed(),
            group,
        })
    }
}
/// Load state from `path`, or generate a fresh identity + state file if none
/// exists yet (also creating the sibling key store on disk).
fn load_or_init_state(path: &Path) -> anyhow::Result<StoredState> {
    if path.exists() {
        return load_existing_state(path);
    }
    let identity = IdentityKeypair::generate();
    let key_store = DiskKeyStore::persistent(keystore_path(path))?;
    let member = GroupMember::new_with_state(Arc::new(identity), key_store, None);
    let state = StoredState::from_member(&member)?;
    write_state(path, &state)?;
    Ok(state)
}
/// Read and decode an existing state file; errors if missing or corrupt.
fn load_existing_state(path: &Path) -> anyhow::Result<StoredState> {
    let bytes = std::fs::read(path).with_context(|| format!("read state file {path:?}"))?;
    bincode::deserialize(&bytes).context("decode state")
}
/// Snapshot `member` and write it to `path`, replacing any previous state.
fn save_state(path: &Path, member: &GroupMember) -> anyhow::Result<()> {
    let state = StoredState::from_member(member)?;
    write_state(path, &state)
}
/// bincode-encode `state` and write it to `path`, creating parent directories
/// as needed.
fn write_state(path: &Path, state: &StoredState) -> anyhow::Result<()> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent).with_context(|| format!("create dir {parent:?}"))?;
    }
    let bytes = bincode::serialize(state).context("encode state")?;
    std::fs::write(path, bytes).with_context(|| format!("write state {path:?}"))?;
    Ok(())
}
/// Decode a 64-hex-char identity key into its 32 raw bytes.
fn decode_identity_key(hex_str: &str) -> anyhow::Result<Vec<u8>> {
    let bytes = hex::decode(hex_str)
        .map_err(|e| anyhow::anyhow!(e))
        .context("identity key must be hex")?;
    anyhow::ensure!(bytes.len() == 32, "identity key must be 32 bytes");
    Ok(bytes)
}
/// Derive the key-store path from the state file path by swapping the
/// extension for `ks` (e.g. `alice.state` → `alice.ks`).
fn keystore_path(state_path: &Path) -> PathBuf {
    state_path.with_extension("ks")
}
/// Format `bytes` as a lowercase hex string.
///
/// (The previous doc comment promised "the first `n` bytes … with a trailing
/// `…`", but the implementation has always rendered every byte with no
/// suffix; the documentation now matches the behavior.)
fn fmt_hex(bytes: &[u8]) -> String {
    bytes.iter().map(|b| format!("{b:02x}")).collect()
}
/// Current Unix time in milliseconds; `0` if the clock reads before the epoch.
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis() as u64)
        .unwrap_or(0)
}
// ── Hex encoding helper ───────────────────────────────────────────────────────
//
// We use a tiny inline module rather than adding `hex` as a dependency.
mod hex {
    /// Encode `bytes` as a lowercase hex string.
    pub fn encode(bytes: impl AsRef<[u8]>) -> String {
        bytes.as_ref().iter().map(|b| format!("{b:02x}")).collect()
    }

    /// Decode a hex string (upper- or lowercase) into bytes.
    ///
    /// Returns an error for odd-length input or any non-hex character.
    /// Works on raw bytes so that non-ASCII input is rejected cleanly: the
    /// previous `&s[i..i + 2]` slicing panicked on multi-byte UTF-8
    /// characters, and `u8::from_str_radix` silently accepted a leading `+`.
    pub fn decode(s: &str) -> Result<Vec<u8>, &'static str> {
        if s.len() % 2 != 0 {
            return Err("odd-length hex string");
        }
        // Map one ASCII hex digit to its 4-bit value.
        fn nibble(b: u8) -> Result<u8, &'static str> {
            match b {
                b'0'..=b'9' => Ok(b - b'0'),
                b'a'..=b'f' => Ok(b - b'a' + 10),
                b'A'..=b'F' => Ok(b - b'A' + 10),
                _ => Err("invalid hex character"),
            }
        }
        s.as_bytes()
            .chunks_exact(2)
            .map(|pair| Ok((nibble(pair[0])? << 4) | nibble(pair[1])?))
            .collect()
    }
}
}

View File

@@ -1,201 +0,0 @@
//! M1 integration test: Noise_XX handshake + Ping/Pong round-trip.
//!
//! Both the server-side and client-side logic run in the same Tokio runtime
//! using `tokio::spawn`. The test verifies:
//!
//! 1. The Noise_XX handshake completes from both sides.
//! 2. A Ping sent by the client arrives as a Ping on the server side.
//! 3. The server's Pong arrives correctly on the client side.
//! 4. Mutual authentication: each peer's observed remote static key matches the
//! other peer's actual public key (the core security property of XX).
use std::sync::Arc;
use tokio::net::TcpListener;
use quicnprotochat_core::{handshake_initiator, handshake_responder, NoiseKeypair};
use quicnprotochat_proto::{MsgType, ParsedEnvelope};
/// Completes a full Noise_XX handshake and Ping/Pong exchange, then verifies
/// mutual authentication by comparing observed vs. actual static public keys.
#[tokio::test]
async fn noise_xx_ping_pong_round_trip() {
    let server_keypair = Arc::new(NoiseKeypair::generate());
    let client_keypair = NoiseKeypair::generate();
    // Bind the listener *before* spawning so the port is ready when the client
    // calls connect — no sleep or retry needed.
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("failed to bind test listener");
    let server_addr = listener.local_addr().expect("failed to get local addr");
    // ── Server task ───────────────────────────────────────────────────────────
    //
    // Handles exactly one connection: completes the handshake, asserts that it
    // receives a Ping, sends a Pong, then returns the client's observed key.
    let server_kp = Arc::clone(&server_keypair);
    let server_task = tokio::spawn(async move {
        let (stream, _peer) = listener.accept().await.expect("server accept failed");
        let mut transport = handshake_responder(stream, &server_kp)
            .await
            .expect("server Noise_XX handshake failed");
        let env = transport
            .recv_envelope()
            .await
            .expect("server recv_envelope failed");
        match env.msg_type {
            MsgType::Ping => {}
            _ => panic!("server expected Ping, received a different message type"),
        }
        transport
            .send_envelope(&ParsedEnvelope {
                msg_type: MsgType::Pong,
                group_id: vec![],
                sender_id: vec![],
                payload: vec![],
                timestamp_ms: 0,
            })
            .await
            .expect("server send_envelope failed");
        // Return the client's public key as authenticated by the server.
        transport
            .remote_static_public_key()
            .expect("server: no remote static key after completed XX handshake")
            .to_vec()
    });
    // ── Client side ───────────────────────────────────────────────────────────
    let stream = tokio::net::TcpStream::connect(server_addr)
        .await
        .expect("client connect failed");
    let mut transport = handshake_initiator(stream, &client_keypair)
        .await
        .expect("client Noise_XX handshake failed");
    // Capture the server's public key as authenticated by the client.
    let server_key_seen_by_client = transport
        .remote_static_public_key()
        .expect("client: no remote static key after completed XX handshake")
        .to_vec();
    transport
        .send_envelope(&ParsedEnvelope {
            msg_type: MsgType::Ping,
            group_id: vec![],
            sender_id: vec![],
            payload: vec![],
            timestamp_ms: 1_700_000_000_000,
        })
        .await
        .expect("client send_envelope failed");
    // The 5 s timeout keeps the test from hanging forever if the server task
    // panicked before replying.
    let pong = tokio::time::timeout(std::time::Duration::from_secs(5), transport.recv_envelope())
        .await
        .expect("timed out waiting for Pong — server task likely panicked")
        .expect("client recv_envelope failed");
    match pong.msg_type {
        MsgType::Pong => {}
        _ => panic!("client expected Pong, received a different message type"),
    }
    // ── Mutual authentication assertions ──────────────────────────────────────
    let client_key_seen_by_server = server_task
        .await
        .expect("server task panicked — see output above");
    // The server authenticated the client's static public key correctly.
    assert_eq!(
        client_key_seen_by_server,
        client_keypair.public_bytes().to_vec(),
        "server's authenticated view of client key does not match client's actual public key"
    );
    // The client authenticated the server's static public key correctly.
    assert_eq!(
        server_key_seen_by_client,
        server_keypair.public_bytes().to_vec(),
        "client's authenticated view of server key does not match server's actual public key"
    );
}
/// A second independent connection on the same server must also succeed,
/// confirming that the server keypair reuse across connections is correct.
#[tokio::test]
async fn two_sequential_connections_both_authenticate() {
    let server_keypair = Arc::new(NoiseKeypair::generate());
    let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind failed");
    let server_addr = listener.local_addr().expect("local_addr failed");
    let server_kp = Arc::clone(&server_keypair);
    // Accept loop: exactly two connections, each served on its own task, all
    // sharing the single server static keypair.
    tokio::spawn(async move {
        for _ in 0..2_u8 {
            let (stream, _) = listener.accept().await.expect("accept failed");
            let kp = Arc::clone(&server_kp);
            tokio::spawn(async move {
                let mut t = handshake_responder(stream, &kp)
                    .await
                    .expect("server handshake failed");
                let env = t.recv_envelope().await.expect("recv failed");
                match env.msg_type {
                    MsgType::Ping => {}
                    _ => panic!("expected Ping"),
                }
                t.send_envelope(&ParsedEnvelope {
                    msg_type: MsgType::Pong,
                    group_id: vec![],
                    sender_id: vec![],
                    payload: vec![],
                    timestamp_ms: 0,
                })
                .await
                .expect("server send failed");
            });
        }
    });
    // Two sequential clients, each with its own fresh keypair.
    for _ in 0..2_u8 {
        let kp = NoiseKeypair::generate();
        let stream = tokio::net::TcpStream::connect(server_addr)
            .await
            .expect("connect failed");
        let mut t = handshake_initiator(stream, &kp)
            .await
            .expect("client handshake failed");
        t.send_envelope(&ParsedEnvelope {
            msg_type: MsgType::Ping,
            group_id: vec![],
            sender_id: vec![],
            payload: vec![],
            timestamp_ms: 0,
        })
        .await
        .expect("client send failed");
        let pong = tokio::time::timeout(std::time::Duration::from_secs(5), t.recv_envelope())
            .await
            .expect("timeout")
            .expect("recv failed");
        match pong.msg_type {
            MsgType::Pong => {}
            _ => panic!("expected Pong"),
        }
        // Each client sees the *same* server public key (key reuse across connections).
        let seen = t
            .remote_static_public_key()
            .expect("no remote key")
            .to_vec();
        assert_eq!(seen, server_keypair.public_bytes().to_vec());
    }
}

View File

@@ -2,20 +2,23 @@
name = "quicnprotochat-core"
version = "0.1.0"
edition = "2021"
description = "Crypto primitives, TLS/QUIC transport, MLS state machine, and Cap'n Proto frame codec for quicnprotochat."
description = "Crypto primitives, MLS state machine, and hybrid post-quantum KEM for quicnprotochat."
license = "MIT"
[dependencies]
# Crypto — classical
x25519-dalek = { workspace = true }
ed25519-dalek = { workspace = true }
snow = { workspace = true }
sha2 = { workspace = true }
hkdf = { workspace = true }
chacha20poly1305 = { workspace = true }
zeroize = { workspace = true }
rand = { workspace = true }
# Crypto — MLS (M2); ml-kem added in M5
# Crypto — post-quantum hybrid KEM (M7)
ml-kem = { workspace = true }
# Crypto — MLS (M2)
openmls = { workspace = true }
openmls_rust_crypto = { workspace = true }
openmls_traits = { workspace = true }
@@ -28,11 +31,8 @@ serde_json = { workspace = true }
capnp = { workspace = true }
quicnprotochat-proto = { path = "../quicnprotochat-proto" }
# Async runtime + codec
# Async runtime
tokio = { workspace = true }
tokio-util = { workspace = true }
futures = { workspace = true }
bytes = { version = "1" }
# Error handling
thiserror = { workspace = true }

View File

@@ -1,203 +0,0 @@
//! Length-prefixed byte frame codec for Tokio's `Framed` adapter.
//!
//! # Wire format
//!
//! ```text
//! ┌──────────────────────────┬──────────────────────────────────────┐
//! │ length (4 bytes, LE u32)│ payload (length bytes) │
//! └──────────────────────────┴──────────────────────────────────────┘
//! ```
//!
//! Little-endian was chosen over big-endian for consistency with Cap'n Proto's
//! own segment table encoding. Both sides of the connection use the same codec.
//!
//! # Usage
//!
//! This codec is transport-agnostic: during the Noise handshake it frames raw
//! Noise handshake messages; after the handshake it frames Noise-encrypted
//! application data. In both cases the payload is opaque bytes from the
//! codec's perspective.
//!
//! # Frame size limit
//!
//! The Noise protocol specifies a maximum message size of 65 535 bytes.
//! Frames larger than [`NOISE_MAX_MSG`] are rejected as protocol violations.
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use crate::error::CodecError;
/// Maximum Noise protocol message size in bytes (per RFC / Noise spec §3).
pub const NOISE_MAX_MSG: usize = 65_535;
/// A stateless codec that prepends / reads a 4-byte little-endian length field.
///
/// Implements both [`Encoder<Bytes>`] and [`Decoder`] so it can be used with
/// `tokio_util::codec::Framed`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LengthPrefixedCodec;
impl LengthPrefixedCodec {
    /// Construct the codec; equivalent to `Default::default()` since the
    /// codec carries no state.
    pub fn new() -> Self {
        Self
    }
}
impl Encoder<Bytes> for LengthPrefixedCodec {
    type Error = CodecError;
    /// Write `item` into `dst` as a 4-byte LE length header followed by the
    /// payload bytes.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::FrameTooLarge`] when the payload exceeds
    /// [`NOISE_MAX_MSG`], and [`CodecError::Io`] when the underlying stream
    /// write fails (propagated by `tokio-util`).
    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let payload_len = item.len();
        if payload_len > NOISE_MAX_MSG {
            Err(CodecError::FrameTooLarge {
                len: payload_len,
                max: NOISE_MAX_MSG,
            })
        } else {
            // One reservation covers the header and the payload together.
            dst.reserve(4 + payload_len);
            dst.put_u32_le(payload_len as u32);
            dst.put_slice(&item);
            Ok(())
        }
    }
}
impl Decoder for LengthPrefixedCodec {
    type Item = BytesMut;
    type Error = CodecError;
    /// Try to extract one length-prefixed frame from `src`.
    ///
    /// Standard `Decoder` contract: `Ok(None)` means more bytes are needed,
    /// `Ok(Some(frame))` yields one complete payload.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::FrameTooLarge`] when the advertised length
    /// exceeds [`NOISE_MAX_MSG`] — a protocol violation; callers should close
    /// the connection.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        const HEADER_LEN: usize = 4;
        // Not even a complete header yet.
        if src.len() < HEADER_LEN {
            src.reserve(HEADER_LEN - src.len());
            return Ok(None);
        }
        // Peek at the length field without consuming it, so `src` is left
        // untouched on the `Ok(None)` path below.
        let frame_len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
        if frame_len > NOISE_MAX_MSG {
            return Err(CodecError::FrameTooLarge {
                len: frame_len,
                max: NOISE_MAX_MSG,
            });
        }
        let needed = HEADER_LEN + frame_len;
        if src.len() < needed {
            // Hint the transport how many more bytes to read before re-polling.
            src.reserve(needed - src.len());
            return Ok(None);
        }
        // Consume the header, then hand back exactly the payload bytes.
        src.advance(HEADER_LEN);
        Ok(Some(src.split_to(frame_len)))
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    /// Encode `payload` then immediately decode it back, panicking if a
    /// complete frame is not produced.
    fn encode_then_decode(payload: &[u8]) -> BytesMut {
        let mut codec = LengthPrefixedCodec::new();
        let mut buf = BytesMut::new();
        codec
            .encode(Bytes::copy_from_slice(payload), &mut buf)
            .expect("encode failed");
        let decoded = codec.decode(&mut buf).expect("decode error");
        decoded.expect("expected a complete frame")
    }
    #[test]
    fn round_trip_empty_payload() {
        let result = encode_then_decode(&[]);
        assert!(result.is_empty());
    }
    #[test]
    fn round_trip_small_payload() {
        let payload = b"hello quicnprotochat";
        let result = encode_then_decode(payload);
        assert_eq!(&result[..], payload);
    }
    #[test]
    fn round_trip_max_size_payload() {
        // Exactly at the NOISE_MAX_MSG boundary: must still succeed.
        let payload = vec![0xAB_u8; NOISE_MAX_MSG];
        let result = encode_then_decode(&payload);
        assert_eq!(&result[..], &payload[..]);
    }
    #[test]
    fn oversized_encode_returns_error() {
        let mut codec = LengthPrefixedCodec::new();
        let mut buf = BytesMut::new();
        let oversized = Bytes::from(vec![0u8; NOISE_MAX_MSG + 1]);
        let err = codec.encode(oversized, &mut buf).unwrap_err();
        assert!(matches!(err, CodecError::FrameTooLarge { .. }));
    }
    #[test]
    fn oversized_length_field_decode_returns_error() {
        let mut codec = LengthPrefixedCodec::new();
        let mut buf = BytesMut::new();
        // Encode a fake length field that exceeds NOISE_MAX_MSG.
        buf.put_u32_le((NOISE_MAX_MSG + 1) as u32);
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, CodecError::FrameTooLarge { .. }));
    }
    #[test]
    fn partial_payload_returns_none() {
        let mut codec = LengthPrefixedCodec::new();
        let mut buf = BytesMut::new();
        // Length header says 10 bytes but we only provide 5.
        buf.put_u32_le(10);
        buf.extend_from_slice(&[0u8; 5]);
        let result = codec.decode(&mut buf).expect("decode error");
        assert!(result.is_none());
    }
    #[test]
    fn partial_header_returns_none() {
        let mut codec = LengthPrefixedCodec::new();
        // Only 2 bytes of the 4-byte header are available.
        let mut buf = BytesMut::from(&[0x00_u8, 0x01][..]);
        let result = codec.decode(&mut buf).expect("decode error");
        assert!(result.is_none());
    }
    #[test]
    fn length_field_is_little_endian() {
        let payload = b"le-check";
        let mut codec = LengthPrefixedCodec::new();
        let mut buf = BytesMut::new();
        codec
            .encode(Bytes::from_static(payload), &mut buf)
            .expect("encode failed");
        // First 4 bytes are the LE length: 8 in LE is [0x08, 0x00, 0x00, 0x00].
        assert_eq!(&buf[..4], &[8, 0, 0, 0]);
    }
}

View File

@@ -1,77 +1,21 @@
//! Error types for `quicnprotochat-core`.
//!
//! Two separate error types are used to preserve type-level separation of concerns:
//!
//! - [`CodecError`] — errors from the length-prefixed frame codec (I/O and framing only).
//! `tokio-util` requires the codec error implement `From<io::Error>`.
//!
//! - [`CoreError`] — errors from the Noise handshake and transport layer.
use thiserror::Error;
/// Maximum plaintext bytes per Noise transport frame.
///
/// Noise limits each message to 65 535 bytes. ChaCha20-Poly1305 consumes
/// 16 bytes for the authentication tag, leaving 65 519 bytes for plaintext.
/// Callers with larger payloads must fragment at a higher layer.
pub const MAX_PLAINTEXT_LEN: usize = 65_519;
// ── Codec errors ──────────────────────────────────────────────────────────────
/// Errors produced by [`LengthPrefixedCodec`](crate::LengthPrefixedCodec).
///
/// Kept separate from [`CoreError`] so the codec layer deals only in I/O
/// and framing concerns.
#[derive(Debug, Error)]
pub enum CodecError {
    /// The underlying TCP stream returned an I/O error.
    ///
    /// This variant satisfies the `tokio-util` requirement that codec error
    /// types implement `From<std::io::Error>`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// A frame length field exceeded the Noise protocol maximum (65 535 bytes).
    ///
    /// This is treated as a protocol violation and the connection should be
    /// closed rather than retried.
    #[error("frame length {len} exceeds maximum {max} bytes")]
    FrameTooLarge { len: usize, max: usize },
}
// ── Core errors ───────────────────────────────────────────────────────────────
/// Errors produced by core cryptographic and MLS operations, including the
/// Noise handshake and transport layer.
#[derive(Debug, Error)]
pub enum CoreError {
    /// The `snow` Noise protocol engine returned an error.
    ///
    /// This covers DH failures, decryption failures, state machine violations,
    /// and pattern parse errors.
    #[error("Noise protocol error: {0}")]
    Noise(#[from] snow::Error),
    /// The frame codec reported an I/O or framing error.
    #[error("frame codec error: {0}")]
    Codec(#[from] CodecError),
    /// Cap'n Proto serialisation or deserialisation failed.
    #[error("Cap'n Proto error: {0}")]
    Capnp(#[from] capnp::Error),
    /// The remote peer closed the connection before the handshake completed.
    #[error("peer closed connection during Noise handshake")]
    HandshakeIncomplete,
    /// The remote peer closed the connection during normal operation.
    #[error("peer closed connection")]
    ConnectionClosed,
    /// The caller attempted to send a plaintext larger than the Noise maximum.
    ///
    /// The limit is [`MAX_PLAINTEXT_LEN`] bytes per frame.
    #[error("plaintext {size} B exceeds Noise frame limit of {MAX_PLAINTEXT_LEN} B")]
    MessageTooLarge { size: usize },
    /// An MLS operation failed.
    ///
    /// The inner string is the debug representation of the openmls error.
    #[error("MLS error: {0}")]
    Mls(String),
    /// A hybrid KEM (X25519 + ML-KEM-768) operation failed.
    #[error("hybrid KEM error: {0}")]
    HybridKem(#[from] crate::hybrid_kem::HybridKemError),
}

View File

@@ -0,0 +1,452 @@
//! Post-quantum hybrid KEM: X25519 + ML-KEM-768.
//!
//! Wraps MLS payloads in an outer encryption layer using a hybrid key
//! encapsulation mechanism. The X25519 component provides classical
//! ECDH security; the ML-KEM-768 component (FIPS 203) provides
//! post-quantum security.
//!
//! # Wire format
//!
//! ```text
//! version(1) | x25519_eph_pk(32) | mlkem_ct(1088) | aead_nonce(12) | aead_ct(var)
//! ```
//!
//! # Key derivation
//!
//! ```text
//! ikm = X25519_shared(32) || ML-KEM_shared(32)
//! key = HKDF-SHA256(salt=[], ikm, info="quicnprotochat-hybrid-v1", L=32)
//! ```
use chacha20poly1305::{
aead::{Aead, KeyInit},
ChaCha20Poly1305, Key, Nonce,
};
use hkdf::Hkdf;
use ml_kem::{
array::Array,
kem::{Decapsulate, Encapsulate},
EncodedSizeUser, KemCore, MlKem768, MlKem768Params,
};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use x25519_dalek::{EphemeralSecret, PublicKey as X25519Public, StaticSecret};
use zeroize::Zeroizing;
// Re-import the concrete key types from the kem sub-module.
use ml_kem::kem::{DecapsulationKey, EncapsulationKey};
/// Current hybrid envelope version byte (first byte on the wire).
const HYBRID_VERSION: u8 = 0x01;
/// HKDF info string for domain separation of the AEAD key derivation.
const HKDF_INFO: &[u8] = b"quicnprotochat-hybrid-v1";
/// ML-KEM-768 ciphertext size in bytes (FIPS 203).
const MLKEM_CT_LEN: usize = 1088;
/// ML-KEM-768 encapsulation (public) key size in bytes (FIPS 203).
pub const MLKEM_EK_LEN: usize = 1184;
/// ML-KEM-768 decapsulation (private) key size in bytes (FIPS 203).
pub const MLKEM_DK_LEN: usize = 2400;
/// Envelope header: version(1) + x25519 eph pk(32) + mlkem ct(1088) + nonce(12).
const HEADER_LEN: usize = 1 + 32 + MLKEM_CT_LEN + 12;
// ── Error type ──────────────────────────────────────────────────────────────
/// Errors from hybrid envelope construction and opening.
#[derive(Debug, thiserror::Error)]
pub enum HybridKemError {
    /// AEAD sealing or ML-KEM encapsulation failed during encryption.
    #[error("AEAD encryption failed")]
    EncryptionFailed,
    /// AEAD open failed — wrong recipient key or tampering anywhere in the envelope.
    #[error("AEAD decryption failed (wrong recipient or tampered)")]
    DecryptionFailed,
    /// The envelope's leading version byte is not a known version.
    #[error("unsupported hybrid envelope version: {0}")]
    UnsupportedVersion(u8),
    /// Input shorter than the fixed layout requires; payload is the observed length.
    #[error("envelope too short ({0} bytes, minimum {HEADER_LEN})")]
    TooShort(usize),
    /// ML-KEM key bytes had the wrong length or failed to parse.
    #[error("invalid ML-KEM encapsulation key")]
    InvalidMlKemKey,
    /// ML-KEM ciphertext had the wrong length or decapsulation errored.
    #[error("ML-KEM decapsulation failed")]
    MlKemDecapsFailed,
}
// ── Keypair types ───────────────────────────────────────────────────────────
/// A hybrid keypair combining X25519 (classical) + ML-KEM-768 (post-quantum).
///
/// Each peer holds one of these. The public portion is distributed so
/// senders can encrypt payloads with post-quantum protection.
pub struct HybridKeypair {
    // X25519 static secret (zeroize-on-drop via x25519-dalek's StaticSecret).
    x25519_sk: StaticSecret,
    // Public key derived from `x25519_sk` at construction time.
    x25519_pk: X25519Public,
    // ML-KEM-768 private (decapsulation) half.
    mlkem_dk: DecapsulationKey<MlKem768Params>,
    // ML-KEM-768 public (encapsulation) half.
    mlkem_ek: EncapsulationKey<MlKem768Params>,
}
/// Serialisable form of a [`HybridKeypair`] for persistence.
///
/// NOTE(review): this struct carries raw private key material and is not
/// zeroized on drop — confirm callers wipe it or store it encrypted at rest.
#[derive(Serialize, Deserialize)]
pub struct HybridKeypairBytes {
    // Raw X25519 static secret (32 bytes).
    pub x25519_sk: [u8; 32],
    // ML-KEM-768 decapsulation key bytes (expected length MLKEM_DK_LEN).
    pub mlkem_dk: Vec<u8>,
    // ML-KEM-768 encapsulation key bytes (expected length MLKEM_EK_LEN).
    pub mlkem_ek: Vec<u8>,
}
/// The public portion of a hybrid keypair, sent to peers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HybridPublicKey {
    // X25519 public key (32 bytes); not secret.
    pub x25519_pk: [u8; 32],
    // ML-KEM-768 encapsulation key bytes (MLKEM_EK_LEN = 1184); not secret.
    pub mlkem_ek: Vec<u8>,
}
impl HybridKeypair {
    /// Generate a fresh hybrid keypair from OS CSPRNG.
    pub fn generate() -> Self {
        let (mlkem_dk, mlkem_ek) = MlKem768::generate(&mut OsRng);
        let x25519_sk = StaticSecret::random_from_rng(OsRng);
        Self {
            x25519_pk: X25519Public::from(&x25519_sk),
            x25519_sk,
            mlkem_dk,
            mlkem_ek,
        }
    }

    /// Reconstruct a keypair from its serialised form.
    ///
    /// # Errors
    ///
    /// [`HybridKemError::InvalidMlKemKey`] if either ML-KEM component has
    /// the wrong length.
    pub fn from_bytes(bytes: &HybridKeypairBytes) -> Result<Self, HybridKemError> {
        let dk_arr = Array::try_from(bytes.mlkem_dk.as_slice())
            .map_err(|_| HybridKemError::InvalidMlKemKey)?;
        let ek_arr = Array::try_from(bytes.mlkem_ek.as_slice())
            .map_err(|_| HybridKemError::InvalidMlKemKey)?;
        let sk = StaticSecret::from(bytes.x25519_sk);
        Ok(Self {
            x25519_pk: X25519Public::from(&sk),
            x25519_sk: sk,
            mlkem_dk: DecapsulationKey::<MlKem768Params>::from_bytes(&dk_arr),
            mlkem_ek: EncapsulationKey::<MlKem768Params>::from_bytes(&ek_arr),
        })
    }

    /// Serialise the keypair for persistence.
    pub fn to_bytes(&self) -> HybridKeypairBytes {
        HybridKeypairBytes {
            x25519_sk: self.x25519_sk.to_bytes(),
            mlkem_dk: self.mlkem_dk.as_bytes().to_vec(),
            mlkem_ek: self.mlkem_ek.as_bytes().to_vec(),
        }
    }

    /// Extract the public portion for distribution to peers.
    pub fn public_key(&self) -> HybridPublicKey {
        HybridPublicKey {
            x25519_pk: self.x25519_pk.to_bytes(),
            mlkem_ek: self.mlkem_ek.as_bytes().to_vec(),
        }
    }
}
impl HybridPublicKey {
    /// Serialise to a single byte blob: x25519_pk(32) || mlkem_ek(1184).
    pub fn to_bytes(&self) -> Vec<u8> {
        [self.x25519_pk.as_slice(), self.mlkem_ek.as_slice()].concat()
    }

    /// Deserialise from a single byte blob.
    ///
    /// Bytes beyond the fixed 32 + [`MLKEM_EK_LEN`] layout are ignored.
    ///
    /// # Errors
    ///
    /// [`HybridKemError::TooShort`] if `bytes` cannot hold both components.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, HybridKemError> {
        let mlkem_ek = bytes
            .get(32..32 + MLKEM_EK_LEN)
            .ok_or(HybridKemError::TooShort(bytes.len()))?
            .to_vec();
        // The `get` above guarantees at least 32 leading bytes exist.
        let x25519_pk: [u8; 32] = bytes[..32]
            .try_into()
            .expect("slice is exactly 32 bytes");
        Ok(Self { x25519_pk, mlkem_ek })
    }
}
// ── Encrypt / Decrypt ───────────────────────────────────────────────────────
/// Encrypt `plaintext` to `recipient_pk` using X25519 + ML-KEM-768 hybrid KEM.
///
/// Returns the complete hybrid envelope:
/// `version(1) || x25519_eph_pk(32) || mlkem_ct(1088) || nonce(12) || aead_ct`.
///
/// # Errors
///
/// - [`HybridKemError::InvalidMlKemKey`] if the recipient's ML-KEM key bytes
///   do not parse.
/// - [`HybridKemError::EncryptionFailed`] if encapsulation or AEAD sealing fails.
pub fn hybrid_encrypt(
    recipient_pk: &HybridPublicKey,
    plaintext: &[u8],
) -> Result<Vec<u8>, HybridKemError> {
    // Classical half: fresh ephemeral X25519 against the recipient's static key.
    let ephemeral = EphemeralSecret::random_from_rng(OsRng);
    let ephemeral_pub = X25519Public::from(&ephemeral);
    let dh_shared = ephemeral.diffie_hellman(&X25519Public::from(recipient_pk.x25519_pk));

    // Post-quantum half: ML-KEM-768 encapsulation to the recipient's key.
    let ek_arr = Array::try_from(recipient_pk.mlkem_ek.as_slice())
        .map_err(|_| HybridKemError::InvalidMlKemKey)?;
    let ek = EncapsulationKey::<MlKem768Params>::from_bytes(&ek_arr);
    let (kem_ct, kem_shared) = ek
        .encapsulate(&mut OsRng)
        .map_err(|_| HybridKemError::EncryptionFailed)?;

    // Bind both shared secrets into one AEAD key + nonce via HKDF.
    let (key, nonce) = derive_aead_material(dh_shared.as_bytes(), kem_shared.as_slice());

    let sealed = ChaCha20Poly1305::new(&key)
        .encrypt(&nonce, plaintext)
        .map_err(|_| HybridKemError::EncryptionFailed)?;

    // Assemble: version || x25519_eph_pk || mlkem_ct || nonce || aead_ct.
    let mut envelope = Vec::with_capacity(HEADER_LEN + sealed.len());
    envelope.push(HYBRID_VERSION);
    envelope.extend_from_slice(&ephemeral_pub.to_bytes());
    envelope.extend_from_slice(kem_ct.as_slice());
    envelope.extend_from_slice(nonce.as_slice());
    envelope.extend_from_slice(&sealed);
    Ok(envelope)
}
/// Decrypt a hybrid envelope using the recipient's private key.
///
/// # Errors
///
/// - [`HybridKemError::TooShort`] if the envelope cannot hold the fixed
///   header plus the 16-byte AEAD tag.
/// - [`HybridKemError::UnsupportedVersion`] on an unknown version byte.
/// - [`HybridKemError::MlKemDecapsFailed`] if decapsulation errors outright.
/// - [`HybridKemError::DecryptionFailed`] if AEAD authentication fails
///   (wrong recipient or tampering anywhere in the envelope).
pub fn hybrid_decrypt(
    keypair: &HybridKeypair,
    envelope: &[u8],
) -> Result<Vec<u8>, HybridKemError> {
    // Must hold the fixed header plus at least the 16-byte AEAD tag.
    if envelope.len() < HEADER_LEN + 16 {
        return Err(HybridKemError::TooShort(envelope.len()));
    }
    if envelope[0] != HYBRID_VERSION {
        return Err(HybridKemError::UnsupportedVersion(envelope[0]));
    }

    // Fixed-layout fields; the length check above makes all slices in-bounds.
    let eph_pk_bytes: [u8; 32] = envelope[1..33]
        .try_into()
        .expect("fixed 32-byte field");
    let kem_ct_bytes = &envelope[33..33 + MLKEM_CT_LEN];
    let nonce_start = 33 + MLKEM_CT_LEN;
    let nonce = Nonce::from_slice(&envelope[nonce_start..nonce_start + 12]);
    let sealed = &envelope[nonce_start + 12..];

    // Classical half: DH between our static secret and the sender's ephemeral key.
    let dh_shared = keypair
        .x25519_sk
        .diffie_hellman(&X25519Public::from(eph_pk_bytes));

    // Post-quantum half: decapsulate the ML-KEM ciphertext.
    let ct_arr =
        Array::try_from(kem_ct_bytes).map_err(|_| HybridKemError::MlKemDecapsFailed)?;
    let kem_shared = keypair
        .mlkem_dk
        .decapsulate(&ct_arr)
        .map_err(|_| HybridKemError::MlKemDecapsFailed)?;

    // Re-derive the AEAD key; the nonce travels on the wire.
    let (key, _) = derive_aead_material(dh_shared.as_bytes(), kem_shared.as_slice());
    ChaCha20Poly1305::new(&key)
        .decrypt(nonce, sealed)
        .map_err(|_| HybridKemError::DecryptionFailed)
}
/// Derive AEAD key + nonce from the combined X25519 + ML-KEM shared secrets.
///
/// `ikm = x25519_ss || mlkem_ss`; two HKDF-SHA256 expansions with distinct
/// info strings yield the 32-byte key and the 12-byte nonce. Intermediate
/// secret buffers are wrapped in `Zeroizing` so they are wiped on drop.
fn derive_aead_material(
    x25519_ss: &[u8],
    mlkem_ss: &[u8],
) -> (Key, Nonce) {
    let ikm = Zeroizing::new([x25519_ss, mlkem_ss].concat());
    let hkdf = Hkdf::<Sha256>::new(None, &ikm);

    let mut key_bytes = Zeroizing::new([0u8; 32]);
    hkdf.expand(HKDF_INFO, &mut *key_bytes)
        .expect("32 bytes is valid HKDF-SHA256 output length");

    let mut nonce_bytes = [0u8; 12];
    hkdf.expand(b"quicnprotochat-hybrid-nonce-v1", &mut nonce_bytes)
        .expect("12 bytes is valid HKDF-SHA256 output length");

    (*Key::from_slice(&*key_bytes), *Nonce::from_slice(&nonce_bytes))
}
// ── Tests ───────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    /// Public halves must have the fixed sizes the wire format assumes.
    #[test]
    fn keygen_produces_valid_public_key() {
        let kp = HybridKeypair::generate();
        let pk = kp.public_key();
        assert_eq!(pk.x25519_pk.len(), 32);
        assert_eq!(pk.mlkem_ek.len(), MLKEM_EK_LEN);
    }
    /// Seal with the public half, open with the matching keypair.
    #[test]
    fn encrypt_decrypt_round_trip() {
        let kp = HybridKeypair::generate();
        let pk = kp.public_key();
        let plaintext = b"hello post-quantum world!";
        let envelope = hybrid_encrypt(&pk, plaintext).unwrap();
        let recovered = hybrid_decrypt(&kp, &envelope).unwrap();
        assert_eq!(recovered, plaintext);
    }
    /// Opening with an unrelated keypair must fail.
    #[test]
    fn wrong_key_decryption_fails() {
        let kp_sender_target = HybridKeypair::generate();
        let kp_wrong = HybridKeypair::generate();
        let pk = kp_sender_target.public_key();
        let envelope = hybrid_encrypt(&pk, b"secret").unwrap();
        let result = hybrid_decrypt(&kp_wrong, &envelope);
        assert!(result.is_err());
    }
    /// Any bit-flip in the AEAD ciphertext must break authentication.
    #[test]
    fn tampered_aead_ciphertext_fails() {
        let kp = HybridKeypair::generate();
        let pk = kp.public_key();
        let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
        let last = envelope.len() - 1;
        envelope[last] ^= 0x01;
        assert!(matches!(
            hybrid_decrypt(&kp, &envelope),
            Err(HybridKemError::DecryptionFailed)
        ));
    }
    /// Corrupting the ML-KEM ciphertext changes the derived key, so the
    /// AEAD open must fail downstream.
    #[test]
    fn tampered_mlkem_ct_fails() {
        let kp = HybridKeypair::generate();
        let pk = kp.public_key();
        let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
        // Flip a byte in the ML-KEM ciphertext region (starts at offset 33)
        envelope[40] ^= 0xFF;
        assert!(hybrid_decrypt(&kp, &envelope).is_err());
    }
    /// Corrupting the ephemeral X25519 key changes the DH output, so the
    /// AEAD open must fail downstream.
    #[test]
    fn tampered_x25519_eph_pk_fails() {
        let kp = HybridKeypair::generate();
        let pk = kp.public_key();
        let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
        // Flip a byte in the X25519 ephemeral pk region (offset 1..33)
        envelope[5] ^= 0xFF;
        assert!(hybrid_decrypt(&kp, &envelope).is_err());
    }
    /// An unknown version byte is rejected before any crypto runs.
    #[test]
    fn unsupported_version_rejected() {
        let kp = HybridKeypair::generate();
        let pk = kp.public_key();
        let mut envelope = hybrid_encrypt(&pk, b"payload").unwrap();
        envelope[0] = 0xFF;
        assert!(matches!(
            hybrid_decrypt(&kp, &envelope),
            Err(HybridKemError::UnsupportedVersion(0xFF))
        ));
    }
    /// Truncated input is rejected with the observed length reported.
    #[test]
    fn envelope_too_short_rejected() {
        let kp = HybridKeypair::generate();
        assert!(matches!(
            hybrid_decrypt(&kp, &[0x01; 10]),
            Err(HybridKemError::TooShort(10))
        ));
    }
    /// A keypair restored from persisted bytes must equal the original
    /// and still be able to decrypt.
    #[test]
    fn keypair_serialisation_round_trip() {
        let kp = HybridKeypair::generate();
        let bytes = kp.to_bytes();
        let restored = HybridKeypair::from_bytes(&bytes).unwrap();
        assert_eq!(kp.x25519_pk.to_bytes(), restored.x25519_pk.to_bytes());
        assert_eq!(
            kp.public_key().mlkem_ek,
            restored.public_key().mlkem_ek
        );
        // Verify restored keypair can decrypt
        let pk = kp.public_key();
        let ct = hybrid_encrypt(&pk, b"test").unwrap();
        let pt = hybrid_decrypt(&restored, &ct).unwrap();
        assert_eq!(pt, b"test");
    }
    /// Public-key blob round trip preserves both components.
    #[test]
    fn public_key_serialisation_round_trip() {
        let kp = HybridKeypair::generate();
        let pk = kp.public_key();
        let bytes = pk.to_bytes();
        let restored = HybridPublicKey::from_bytes(&bytes).unwrap();
        assert_eq!(pk.x25519_pk, restored.x25519_pk);
        assert_eq!(pk.mlkem_ek, restored.mlkem_ek);
    }
    /// Payloads well above typical message size still round-trip.
    #[test]
    fn large_payload_round_trip() {
        let kp = HybridKeypair::generate();
        let pk = kp.public_key();
        let plaintext = vec![0xAB; 50_000]; // 50 KB
        let envelope = hybrid_encrypt(&pk, &plaintext).unwrap();
        let recovered = hybrid_decrypt(&kp, &envelope).unwrap();
        assert_eq!(recovered, plaintext);
    }
}

View File

@@ -1,11 +1,8 @@
//! Ed25519 identity keypair for MLS credentials and AS registration.
//!
//! # Relationship to the Noise keypair
//!
//! The X25519 [`NoiseKeypair`](crate::NoiseKeypair) is the transport-layer
//! static key used in the Noise_XX handshake. The Ed25519 [`IdentityKeypair`]
//! is the long-term identity key embedded in MLS `BasicCredential`s. The two
//! keys serve different roles and must not be confused.
//! The [`IdentityKeypair`] is the long-term identity key embedded in MLS
//! `BasicCredential`s. It is used for signing MLS messages and as the
//! indexing key for the Authentication Service.
//!
//! # Zeroize
//!

View File

@@ -1,121 +0,0 @@
//! Static X25519 keypair for the Noise_XX handshake.
//!
//! # Security properties
//!
//! - The private key is stored as [`x25519_dalek::StaticSecret`], which
//! implements [`ZeroizeOnDrop`](zeroize::ZeroizeOnDrop) — the key material
//! is overwritten with zeros when the `StaticSecret` is dropped.
//!
//! - [`NoiseKeypair::private_bytes`] returns a [`Zeroizing`](zeroize::Zeroizing)
//! wrapper so the caller's copy of the raw bytes is also cleared on drop.
//! Pass it directly to `snow::Builder::local_private_key` and let it fall
//! out of scope immediately after.
//!
//! - The public key is not secret and may be freely cloned or logged.
//!
//! # Persistence
//!
//! `NoiseKeypair` does not implement `Serialize` intentionally. Key persistence
//! to disk is handled at the application layer (M6) with appropriate file
//! permission checks and, optionally, passphrase-based encryption.
use rand::rngs::OsRng;
use x25519_dalek::{PublicKey, StaticSecret};
use zeroize::Zeroizing;
/// A static X25519 keypair used for Noise_XX mutual authentication.
///
/// Generate once per node identity and reuse across connections.
/// The private scalar is zeroized when this value is dropped.
/// Distinct from the Ed25519 [`IdentityKeypair`](crate::IdentityKeypair),
/// which is the MLS credential key — the two roles must not be confused.
pub struct NoiseKeypair {
    /// Private scalar — zeroized on drop via `x25519_dalek`'s `ZeroizeOnDrop` impl.
    private: StaticSecret,
    /// Corresponding public key — derived from `private` at construction time.
    public: PublicKey,
}
impl NoiseKeypair {
    /// Generate a fresh keypair from the OS CSPRNG.
    ///
    /// This calls `getrandom` on Linux (via `OsRng`) and is suitable for
    /// generating long-lived static identity keys.
    pub fn generate() -> Self {
        let private = StaticSecret::random_from_rng(OsRng);
        Self {
            public: PublicKey::from(&private),
            private,
        }
    }

    /// Return the raw private key bytes in a [`Zeroizing`] wrapper.
    ///
    /// The 32-byte copy is cleared when the wrapper drops. Hand it to a
    /// `snow::Builder` and let it fall out of scope immediately after:
    ///
    /// ```rust,ignore
    /// let private = keypair.private_bytes();
    /// let session = snow::Builder::new(params)
    ///     .local_private_key(&private[..])
    ///     .build_initiator()?;
    /// // `private` is zeroized here.
    /// ```
    pub fn private_bytes(&self) -> Zeroizing<[u8; 32]> {
        Zeroizing::new(self.private.to_bytes())
    }

    /// Return the public key bytes — not secret, safe to log or transmit.
    pub fn public_bytes(&self) -> [u8; 32] {
        self.public.to_bytes()
    }
}
// Prevent accidental `{:?}` printing of the private key.
impl std::fmt::Debug for NoiseKeypair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render only a 4-byte public-key prefix as a sanity identifier;
        // the private scalar is never formatted.
        let p = self.public_bytes();
        write!(
            f,
            "NoiseKeypair {{ public: {:02x}{:02x}{:02x}{:02x}…, private: [redacted] }}",
            p[0], p[1], p[2], p[3],
        )
    }
}
// ── Tests ─────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    /// The stored public key must equal the one re-derived from the
    /// private scalar.
    #[test]
    fn generated_public_key_matches_private() {
        let kp = NoiseKeypair::generate();
        let secret = StaticSecret::from(*kp.private_bytes());
        assert_eq!(PublicKey::from(&secret).to_bytes(), kp.public_bytes());
    }
    /// Two independent generations must not collide.
    #[test]
    fn two_keypairs_differ() {
        assert_ne!(
            NoiseKeypair::generate().public_bytes(),
            NoiseKeypair::generate().public_bytes()
        );
    }
    /// `private_bytes` returns a `Zeroizing` wrapper holding non-zero key
    /// material (zeroization after drop cannot be observed without unsafe).
    #[test]
    fn private_bytes_is_zeroizing() {
        let private = NoiseKeypair::generate().private_bytes();
        assert!(
            private.iter().any(|&b| b != 0),
            "freshly generated private key should not be all zeros"
        );
    }
}

View File

@@ -1,34 +1,32 @@
//! Core cryptographic primitives, Noise_XX transport, MLS group state machine,
//! and frame codec for quicnprotochat.
//! Core cryptographic primitives, MLS group state machine, and hybrid
//! post-quantum KEM for quicnprotochat.
//!
//! # Module layout
//!
//! | Module | Responsibility |
//! |--------------|------------------------------------------------------------------|
//! | `error` | [`CoreError`] and [`CodecError`] types |
//! | `keypair` | [`NoiseKeypair`] — static X25519 key, zeroize-on-drop |
//! | `codec` | [`LengthPrefixedCodec`] — Tokio Encoder + Decoder |
//! | `noise` | [`handshake_initiator`], [`handshake_responder`], [`NoiseTransport`] |
//! | `error` | [`CoreError`] type |
//! | `identity` | [`IdentityKeypair`] — Ed25519 identity key for MLS credentials |
//! | `keypackage` | [`generate_key_package`] — standalone KeyPackage generation |
//! | `group` | [`GroupMember`] — MLS group lifecycle (create/join/send/recv) |
//! | `hybrid_kem` | Hybrid X25519 + ML-KEM-768 key encapsulation |
//! | `keystore` | [`DiskKeyStore`] — OpenMLS key store with optional persistence |
mod codec;
mod error;
mod group;
pub mod hybrid_kem;
mod identity;
mod keypackage;
mod keypair;
mod keystore;
mod noise;
// ── Public API ────────────────────────────────────────────────────────────────
pub use codec::{LengthPrefixedCodec, NOISE_MAX_MSG};
pub use error::{CodecError, CoreError, MAX_PLAINTEXT_LEN};
pub use error::CoreError;
pub use group::GroupMember;
pub use hybrid_kem::{
hybrid_decrypt, hybrid_encrypt, HybridKeypair, HybridKeypairBytes, HybridKemError,
HybridPublicKey,
};
pub use identity::IdentityKeypair;
pub use keypackage::generate_key_package;
pub use keypair::NoiseKeypair;
pub use keystore::DiskKeyStore;
pub use noise::{handshake_initiator, handshake_responder, NoiseTransport};

View File

@@ -1,400 +0,0 @@
//! Noise_XX handshake and encrypted transport.
//!
//! # Protocol
//!
//! Pattern: `Noise_XX_25519_ChaChaPoly_BLAKE2s`
//!
//! ```text
//! XX handshake (3 messages):
//! -> e (initiator sends ephemeral public key)
//! <- e, ee, s, es (responder replies; mutual DH + responder static)
//! -> s, se (initiator sends static key; final DH)
//! ```
//!
//! After the handshake both peers have authenticated each other's static X25519
//! keys and negotiated a symmetric session with ChaCha20-Poly1305.
//!
//! # Framing
//!
//! All messages — handshake and application — are carried in length-prefixed
//! frames (see [`LengthPrefixedCodec`](crate::LengthPrefixedCodec)).
//!
//! In the handshake phase the frame payload is the raw Noise handshake bytes
//! produced by `snow`. In the transport phase the frame payload is a
//! Noise-encrypted Cap'n Proto message.
//!
//! # Post-quantum gap (ADR-006)
//!
//! The Noise transport uses classical X25519. PQ-Noise is not yet standardised
//! in `snow`. MLS application data is PQ-protected from M5 onward. The residual
//! risk (metadata exposure via handshake harvest) is accepted for M1M5.
use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use tokio::{
io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf},
net::TcpStream,
};
use tokio_util::codec::Framed;
use crate::{
codec::{LengthPrefixedCodec, NOISE_MAX_MSG},
error::{CoreError, MAX_PLAINTEXT_LEN},
keypair::NoiseKeypair,
};
use quicnprotochat_proto::{build_envelope, parse_envelope, ParsedEnvelope};
/// Noise parameters used throughout quicnprotochat.
///
/// `Noise_XX_25519_ChaChaPoly_BLAKE2s` — both parties authenticate each
/// other's static X25519 keys; ChaCha20-Poly1305 for AEAD; BLAKE2s as PRF.
const NOISE_PARAMS: &str = "Noise_XX_25519_ChaChaPoly_BLAKE2s";
/// ChaCha20-Poly1305 authentication tag overhead per Noise message:
/// ciphertext length is always plaintext length plus this constant.
const NOISE_TAG_LEN: usize = 16;
// ── Public type ───────────────────────────────────────────────────────────────
/// An authenticated, encrypted Noise transport session.
///
/// Obtained by completing a [`handshake_initiator`] or [`handshake_responder`]
/// call. All subsequent I/O is through [`send_frame`](Self::send_frame) and
/// [`recv_frame`](Self::recv_frame), or the higher-level envelope helpers.
/// One length-prefixed frame carries exactly one Noise message.
///
/// # Thread safety
///
/// `NoiseTransport` is `Send` but not `Clone` or `Sync`. Use one instance per
/// Tokio task; use message passing to share data across tasks.
pub struct NoiseTransport {
    /// The TCP stream wrapped in the length-prefix codec.
    framed: Framed<TcpStream, LengthPrefixedCodec>,
    /// The Noise session in transport mode — encrypts and decrypts frames.
    session: snow::TransportState,
    /// Remote peer's static X25519 public key, captured from the HandshakeState
    /// before `into_transport_mode()` consumes it.
    ///
    /// Stored here explicitly rather than via `TransportState::get_remote_static()`
    /// because snow does not guarantee the method survives the mode transition.
    remote_static: Option<Vec<u8>>,
}
impl NoiseTransport {
    // ── Transport-layer I/O ───────────────────────────────────────────────────
    /// Encrypt `plaintext` and send it as a single length-prefixed frame.
    ///
    /// # Errors
    ///
    /// - [`CoreError::MessageTooLarge`] if `plaintext` exceeds
    /// [`MAX_PLAINTEXT_LEN`] bytes.
    /// - [`CoreError::Noise`] if the Noise session fails to encrypt.
    /// - [`CoreError::Codec`] if the underlying TCP write fails.
    pub async fn send_frame(&mut self, plaintext: &[u8]) -> Result<(), CoreError> {
        if plaintext.len() > MAX_PLAINTEXT_LEN {
            return Err(CoreError::MessageTooLarge {
                size: plaintext.len(),
            });
        }
        // Allocate exactly the right amount: plaintext + AEAD tag.
        let mut ciphertext = vec![0u8; plaintext.len() + NOISE_TAG_LEN];
        let len = self
            .session
            .write_message(plaintext, &mut ciphertext)
            .map_err(CoreError::Noise)?;
        self.framed
            .send(Bytes::copy_from_slice(&ciphertext[..len]))
            .await
            .map_err(CoreError::Codec)?;
        Ok(())
    }
    /// Receive the next length-prefixed frame and decrypt it.
    ///
    /// Awaits until a complete frame arrives on the TCP stream.
    ///
    /// # Errors
    ///
    /// - [`CoreError::ConnectionClosed`] if the peer closed the connection.
    /// - [`CoreError::Noise`] if decryption or authentication fails.
    /// - [`CoreError::Codec`] if the underlying TCP read or framing fails.
    pub async fn recv_frame(&mut self) -> Result<Vec<u8>, CoreError> {
        let ciphertext = self
            .framed
            .next()
            .await
            .ok_or(CoreError::ConnectionClosed)?
            .map_err(CoreError::Codec)?;
        // Plaintext is always shorter than ciphertext (AEAD tag is stripped),
        // so a ciphertext-sized buffer can never be too small.
        let mut plaintext = vec![0u8; ciphertext.len()];
        let len = self
            .session
            .read_message(&ciphertext, &mut plaintext)
            .map_err(CoreError::Noise)?;
        plaintext.truncate(len);
        Ok(plaintext)
    }
    // ── Envelope-level I/O ────────────────────────────────────────────────────
    /// Serialise and encrypt a [`ParsedEnvelope`], then send it.
    ///
    /// This is the primary application-level send method. The Cap'n Proto
    /// encoding is done by [`quicnprotochat_proto::build_envelope`] before encryption.
    pub async fn send_envelope(&mut self, env: &ParsedEnvelope) -> Result<(), CoreError> {
        let bytes = build_envelope(env).map_err(CoreError::Capnp)?;
        self.send_frame(&bytes).await
    }
    /// Receive a frame, decrypt it, and deserialise it as a [`ParsedEnvelope`].
    ///
    /// This is the primary application-level receive method.
    pub async fn recv_envelope(&mut self) -> Result<ParsedEnvelope, CoreError> {
        let bytes = self.recv_frame().await?;
        parse_envelope(&bytes).map_err(CoreError::Capnp)
    }
    // ── capnp-rpc bridge ─────────────────────────────────────────────────────
    /// Consume the transport and return a byte-stream pair suitable for
    /// `capnp-rpc`'s `twoparty::VatNetwork`.
    ///
    /// # Why this exists
    ///
    /// `capnp-rpc` expects `AsyncRead + AsyncWrite` byte streams, but
    /// `NoiseTransport` is message-based (each call to `send_frame` /
    /// `recv_frame` encrypts/decrypts one Noise message). This method bridges
    /// the two models by:
    ///
    /// 1. Creating a `tokio::io::duplex` pipe (an in-process byte channel).
    /// 2. Spawning a background task that shuttles bytes between the pipe and
    ///    the Noise framed transport using `tokio::select!`.
    ///
    /// The returned `(ReadHalf, WriteHalf)` are the **application** ends of the
    /// pipe; `capnp-rpc` reads from `ReadHalf` and writes to `WriteHalf`. The
    /// bridge task owns the **transport** end and the `NoiseTransport`.
    ///
    /// # Framing
    ///
    /// Each Noise frame carries at most [`MAX_PLAINTEXT_LEN`] bytes of
    /// plaintext. The bridge uses that as the read buffer size so that one
    /// frame is never split across multiple pipe writes.
    ///
    /// # Lifetime
    ///
    /// The bridge task runs until either side of the pipe closes. When the
    /// capnp-rpc system drops the pipe halves, the bridge exits cleanly.
    ///
    /// NOTE(review): both `select!` arms poll futures (`StreamExt::next`,
    /// `AsyncReadExt::read`) that are expected to be cancellation-safe —
    /// confirm against tokio/tokio-util docs if either type changes.
    pub fn into_capnp_io(mut self) -> (ReadHalf<DuplexStream>, WriteHalf<DuplexStream>) {
        // Choose a pipe capacity large enough for one max-size Noise frame.
        let (app_stream, mut transport_stream) = duplex(MAX_PLAINTEXT_LEN);
        tokio::spawn(async move {
            let mut buf = vec![0u8; MAX_PLAINTEXT_LEN];
            loop {
                tokio::select! {
                    // Noise → app: receive an encrypted frame and write decrypted
                    // plaintext into the pipe.
                    noise_result = self.recv_frame() => {
                        match noise_result {
                            Ok(plaintext) => {
                                if transport_stream.write_all(&plaintext).await.is_err() {
                                    break; // app side closed
                                }
                            }
                            Err(_) => break, // peer closed or Noise error
                        }
                    }
                    // app → Noise: read bytes from the pipe and send as an
                    // encrypted Noise frame.
                    read_result = transport_stream.read(&mut buf) => {
                        match read_result {
                            Ok(0) | Err(_) => break, // app side closed
                            Ok(n) => {
                                if self.send_frame(&buf[..n]).await.is_err() {
                                    break; // peer closed or Noise error
                                }
                            }
                        }
                    }
                }
            }
        });
        tokio::io::split(app_stream)
    }
    // ── Session metadata ──────────────────────────────────────────────────────
    /// Return the remote peer's static X25519 public key (32 bytes), as
    /// authenticated during the Noise_XX handshake.
    ///
    /// Returns `None` only in the impossible case where the XX handshake
    /// completed without exchanging static keys (a snow implementation bug).
    /// In practice this is always `Some` after a successful handshake.
    pub fn remote_static_public_key(&self) -> Option<&[u8]> {
        self.remote_static.as_deref()
    }
}
impl std::fmt::Debug for NoiseTransport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render only a 4-byte prefix of the peer's static key as an
        // identifier; session state and key material are never printed.
        let remote = match self.remote_static.as_deref() {
            Some(k) => Some(format!("{:02x}{:02x}{:02x}{:02x}", k[0], k[1], k[2], k[3])),
            None => None,
        };
        f.debug_struct("NoiseTransport")
            .field("remote_static", &remote)
            .finish_non_exhaustive()
    }
}
// ── Handshake functions ───────────────────────────────────────────────────────
/// Complete a Noise_XX handshake as the **initiator** over `stream`.
///
/// The initiator sends the first handshake message. After the three-message
/// exchange completes, the function returns an authenticated [`NoiseTransport`]
/// ready for application data.
///
/// # Errors
///
/// - [`CoreError::HandshakeIncomplete`] if the peer closes the connection mid-handshake.
/// - [`CoreError::Noise`] if any Noise operation fails (pattern mismatch, bad DH, etc.).
/// - [`CoreError::Codec`] if any TCP I/O fails during the handshake.
pub async fn handshake_initiator(
    stream: TcpStream,
    keypair: &NoiseKeypair,
) -> Result<NoiseTransport, CoreError> {
    let params: snow::params::NoiseParams = NOISE_PARAMS
        .parse()
        .expect("NOISE_PARAMS is a compile-time constant and must parse successfully");
    // The private key bytes are held in a Zeroizing wrapper and cleared after
    // snow clones them internally during build_initiator().
    let private = keypair.private_bytes();
    let mut session = snow::Builder::new(params)
        .local_private_key(&private[..])
        .build_initiator()
        .map_err(CoreError::Noise)?;
    drop(private); // zeroize our copy; snow holds its own internal copy
    let mut framed = Framed::new(stream, LengthPrefixedCodec::new());
    // Scratch buffer reused for every handshake message in both directions.
    let mut buf = vec![0u8; NOISE_MAX_MSG];
    // ── Message 1: -> e ──────────────────────────────────────────────────────
    let len = session
        .write_message(&[], &mut buf)
        .map_err(CoreError::Noise)?;
    framed
        .send(Bytes::copy_from_slice(&buf[..len]))
        .await
        .map_err(CoreError::Codec)?;
    // ── Message 2: <- e, ee, s, es ───────────────────────────────────────────
    let msg2 = recv_handshake_frame(&mut framed).await?;
    session
        .read_message(&msg2, &mut buf)
        .map_err(CoreError::Noise)?;
    // ── Message 3: -> s, se ──────────────────────────────────────────────────
    let len = session
        .write_message(&[], &mut buf)
        .map_err(CoreError::Noise)?;
    framed
        .send(Bytes::copy_from_slice(&buf[..len]))
        .await
        .map_err(CoreError::Codec)?;
    // Zeroize the scratch buffer — it contained plaintext key material during
    // the handshake (ephemeral key bytes in message 2 payload).
    zeroize::Zeroize::zeroize(&mut buf);
    // Capture the remote static key from HandshakeState before consuming it;
    // into_transport_mode() takes the session by value.
    let remote_static = session.get_remote_static().map(|k| k.to_vec());
    let transport_session = session.into_transport_mode().map_err(CoreError::Noise)?;
    Ok(NoiseTransport {
        framed,
        session: transport_session,
        remote_static,
    })
}
/// Complete a Noise_XX handshake as the **responder** over `stream`.
///
/// The responder waits for the initiator's first message. After the
/// three-message exchange completes, the session is promoted to transport
/// mode and returned as an authenticated [`NoiseTransport`].
///
/// # Errors
///
/// Same as [`handshake_initiator`].
pub async fn handshake_responder(
    stream: TcpStream,
    keypair: &NoiseKeypair,
) -> Result<NoiseTransport, CoreError> {
    let pattern: snow::params::NoiseParams = NOISE_PARAMS
        .parse()
        .expect("NOISE_PARAMS is a compile-time constant and must parse successfully");

    // Zeroizing wrapper clears our key copy on drop; snow keeps its own.
    let secret = keypair.private_bytes();
    let mut hs = snow::Builder::new(pattern)
        .local_private_key(&secret[..])
        .build_responder()
        .map_err(CoreError::Noise)?;
    drop(secret);

    let mut link = Framed::new(stream, LengthPrefixedCodec::new());
    let mut scratch = vec![0u8; NOISE_MAX_MSG];

    // ── Message 1: <- e ──────────────────────────────────────────────────────
    let first = recv_handshake_frame(&mut link).await?;
    hs.read_message(&first, &mut scratch)
        .map_err(CoreError::Noise)?;

    // ── Message 2: -> e, ee, s, es ───────────────────────────────────────────
    let n = hs.write_message(&[], &mut scratch).map_err(CoreError::Noise)?;
    link.send(Bytes::copy_from_slice(&scratch[..n]))
        .await
        .map_err(CoreError::Codec)?;

    // ── Message 3: <- s, se ──────────────────────────────────────────────────
    let third = recv_handshake_frame(&mut link).await?;
    hs.read_message(&third, &mut scratch)
        .map_err(CoreError::Noise)?;

    // Wipe the scratch buffer — it held key material during the handshake.
    zeroize::Zeroize::zeroize(&mut scratch);

    // Grab the peer's static key before the HandshakeState is consumed.
    let remote_static = hs.get_remote_static().map(|k| k.to_vec());
    let session = hs.into_transport_mode().map_err(CoreError::Noise)?;
    Ok(NoiseTransport {
        framed: link,
        session,
        remote_static,
    })
}
// ── Private helpers ───────────────────────────────────────────────────────────
/// Read one handshake frame from `framed`, mapping stream closure to
/// [`CoreError::HandshakeIncomplete`].
async fn recv_handshake_frame(
    framed: &mut Framed<TcpStream, LengthPrefixedCodec>,
) -> Result<bytes::BytesMut, CoreError> {
    match framed.next().await {
        // A frame arrived — surface any codec-level failure as CoreError::Codec.
        Some(frame) => frame.map_err(CoreError::Codec),
        // Stream closed before the handshake finished.
        None => Err(CoreError::HandshakeIncomplete),
    }
}

View File

@@ -32,6 +32,9 @@ quinn-proto = { workspace = true }
rustls = { workspace = true }
rcgen = { workspace = true }
# Database
rusqlite = { workspace = true }
# Error handling
anyhow = { workspace = true }
thiserror = { workspace = true }
@@ -40,3 +43,4 @@ serde = { workspace = true }
# CLI
clap = { workspace = true }
toml = { version = "0.8" }

View File

@@ -25,14 +25,15 @@
//! | `QUICNPROTOCHAT_LISTEN` | `--listen` | `0.0.0.0:4201` |
//! | `RUST_LOG` | — | `info` |
use std::{fs, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use std::{fs, net::SocketAddr, path::{Path, PathBuf}, sync::Arc, time::Duration};
use anyhow::Context;
use serde::Deserialize;
use capnp::capability::Promise;
use capnp_rpc::{rpc_twoparty_capnp::Side, twoparty, RpcSystem};
use clap::Parser;
use dashmap::DashMap;
use quicnprotochat_proto::node_capnp::node_service;
use quicnprotochat_proto::node_capnp::{auth, node_service};
use quinn::{Endpoint, ServerConfig};
use quinn_proto::crypto::rustls::QuicServerConfig;
use rcgen::generate_simple_self_signed;
@@ -43,12 +44,139 @@ use tokio::sync::Notify;
use tokio::time::timeout;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
mod sql_store;
mod storage;
use storage::{FileBackedStore, StorageError};
use sql_store::SqlStore;
use storage::{FileBackedStore, Store, StorageError};
const MAX_PAYLOAD_BYTES: usize = 5 * 1024 * 1024; // 5 MB cap per message
const MAX_KEYPACKAGE_BYTES: usize = 1 * 1024 * 1024; // 1 MB cap per KeyPackage
const CURRENT_WIRE_VERSION: u16 = 1; // allow 0 (legacy) and 1 (current)
const CURRENT_WIRE_VERSION: u16 = 1; // legacy disabled; current wire version only
const DEFAULT_LISTEN: &str = "0.0.0.0:7000";
const DEFAULT_DATA_DIR: &str = "data";
const DEFAULT_TLS_CERT: &str = "data/server-cert.der";
const DEFAULT_TLS_KEY: &str = "data/server-key.der";
const DEFAULT_STORE_BACKEND: &str = "file";
const DEFAULT_DB_PATH: &str = "data/quicnprotochat.db";
#[derive(Clone, Debug)]
struct AuthConfig {
required_token: Option<Vec<u8>>,
}
impl AuthConfig {
fn new(required_token: Option<String>) -> Self {
let required_token = required_token.filter(|s| !s.is_empty()).map(|s| s.into_bytes());
Self { required_token }
}
}
#[derive(Debug, Default, Deserialize)]
struct FileConfig {
listen: Option<String>,
data_dir: Option<String>,
tls_cert: Option<PathBuf>,
tls_key: Option<PathBuf>,
auth_token: Option<String>,
store_backend: Option<String>,
db_path: Option<PathBuf>,
db_key: Option<String>,
}
#[derive(Debug)]
struct EffectiveConfig {
listen: String,
data_dir: String,
tls_cert: PathBuf,
tls_key: PathBuf,
auth_token: Option<String>,
store_backend: String,
db_path: PathBuf,
db_key: String,
}
fn load_config(path: Option<&Path>) -> anyhow::Result<FileConfig> {
let path = match path {
Some(p) => PathBuf::from(p),
None => PathBuf::from("quicnprotochat-server.toml"),
};
if !path.exists() {
return Ok(FileConfig::default());
}
let contents = fs::read_to_string(&path)
.with_context(|| format!("read config file {path:?}"))?;
let cfg: FileConfig = toml::from_str(&contents)
.with_context(|| format!("parse config file {path:?}"))?;
Ok(cfg)
}
fn merge_config(args: &Args, file: &FileConfig) -> EffectiveConfig {
let listen = if args.listen == DEFAULT_LISTEN {
file.listen.clone().unwrap_or_else(|| DEFAULT_LISTEN.to_string())
} else {
args.listen.clone()
};
let data_dir = if args.data_dir == DEFAULT_DATA_DIR {
file.data_dir.clone().unwrap_or_else(|| DEFAULT_DATA_DIR.to_string())
} else {
args.data_dir.clone()
};
let tls_cert = if args.tls_cert == PathBuf::from(DEFAULT_TLS_CERT) {
file.tls_cert.clone().unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_CERT))
} else {
args.tls_cert.clone()
};
let tls_key = if args.tls_key == PathBuf::from(DEFAULT_TLS_KEY) {
file.tls_key.clone().unwrap_or_else(|| PathBuf::from(DEFAULT_TLS_KEY))
} else {
args.tls_key.clone()
};
let auth_token = if args.auth_token.is_some() {
args.auth_token.clone()
} else {
file.auth_token.clone()
};
let store_backend = if args.store_backend == DEFAULT_STORE_BACKEND {
file.store_backend
.clone()
.unwrap_or_else(|| DEFAULT_STORE_BACKEND.to_string())
} else {
args.store_backend.clone()
};
let db_path = if args.db_path == PathBuf::from(DEFAULT_DB_PATH) {
file.db_path
.clone()
.unwrap_or_else(|| PathBuf::from(DEFAULT_DB_PATH))
} else {
args.db_path.clone()
};
let db_key = if args.db_key.is_empty() {
file.db_key.clone().unwrap_or_else(|| args.db_key.clone())
} else {
args.db_key.clone()
};
EffectiveConfig {
listen,
data_dir,
tls_cert,
tls_key,
auth_token,
store_backend,
db_path,
db_key,
}
}
// ── CLI ───────────────────────────────────────────────────────────────────────
@@ -59,37 +187,50 @@ const CURRENT_WIRE_VERSION: u16 = 1; // allow 0 (legacy) and 1 (current)
version
)]
struct Args {
/// Optional path to a TOML config file (fields map to CLI flags).
#[arg(long, env = "QUICNPROTOCHAT_CONFIG")]
config: Option<PathBuf>,
/// QUIC listen address (host:port).
#[arg(long, default_value = "0.0.0.0:4201", env = "QUICNPROTOCHAT_LISTEN")]
#[arg(long, default_value = DEFAULT_LISTEN, env = "QUICNPROTOCHAT_LISTEN")]
listen: String,
/// Directory for persisted server data (KeyPackages + delivery queues).
#[arg(long, default_value = "data", env = "QUICNPROTOCHAT_DATA_DIR")]
#[arg(long, default_value = DEFAULT_DATA_DIR, env = "QUICNPROTOCHAT_DATA_DIR")]
data_dir: String,
/// TLS certificate path (generated automatically if missing).
#[arg(
long,
default_value = "data/server-cert.der",
env = "QUICNPROTOCHAT_TLS_CERT"
)]
#[arg(long, default_value = DEFAULT_TLS_CERT, env = "QUICNPROTOCHAT_TLS_CERT")]
tls_cert: PathBuf,
/// TLS private key path (generated automatically if missing).
#[arg(
long,
default_value = "data/server-key.der",
env = "QUICNPROTOCHAT_TLS_KEY"
)]
#[arg(long, default_value = DEFAULT_TLS_KEY, env = "QUICNPROTOCHAT_TLS_KEY")]
tls_key: PathBuf,
/// Required bearer token for auth.version=1 requests. If unset, any non-empty token is accepted.
#[arg(long, env = "QUICNPROTOCHAT_AUTH_TOKEN")]
auth_token: Option<String>,
/// Storage backend: "file" (bincode) or "sql" (SQLCipher-encrypted).
#[arg(long, default_value = DEFAULT_STORE_BACKEND, env = "QUICNPROTOCHAT_STORE_BACKEND")]
store_backend: String,
/// Path to the SQLCipher database file (only used when --store-backend=sql).
#[arg(long, default_value = DEFAULT_DB_PATH, env = "QUICNPROTOCHAT_DB_PATH")]
db_path: PathBuf,
/// SQLCipher encryption key. Empty string disables encryption.
#[arg(long, default_value = "", env = "QUICNPROTOCHAT_DB_KEY")]
db_key: String,
}
// ── Node service implementation ─────────────────────────────────────────────
/// Cap'n Proto RPC server implementation for `NodeService` (Auth + Delivery).
struct NodeServiceImpl {
store: Arc<FileBackedStore>,
store: Arc<dyn Store>,
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
auth_cfg: Arc<AuthConfig>,
}
impl NodeServiceImpl {
@@ -114,6 +255,9 @@ impl node_service::Server for NodeServiceImpl {
let (identity_key, package) = match params {
Ok(p) => {
if let Err(e) = validate_auth(&self.auth_cfg, p.get_auth()) {
return Promise::err(e);
}
let ik = match p.get_identity_key() {
Ok(v) => v.to_vec(),
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
@@ -177,6 +321,14 @@ impl node_service::Server for NodeServiceImpl {
},
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
};
if let Err(e) = params
.get()
.ok()
.map(|p| validate_auth(&self.auth_cfg, p.get_auth()))
.transpose()
{
return Promise::err(e);
}
if identity_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
@@ -234,6 +386,9 @@ impl node_service::Server for NodeServiceImpl {
};
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
if let Err(e) = validate_auth(&self.auth_cfg, p.get_auth()) {
return Promise::err(e);
}
if recipient_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
@@ -252,9 +407,9 @@ impl node_service::Server for NodeServiceImpl {
MAX_PAYLOAD_BYTES
)));
}
if version != 0 && version != CURRENT_WIRE_VERSION {
if version != CURRENT_WIRE_VERSION {
return Promise::err(capnp::Error::failed(format!(
"unsupported wire version {} (expected 0 or {CURRENT_WIRE_VERSION})",
"unsupported wire version {} (expected {CURRENT_WIRE_VERSION})",
version
)));
}
@@ -300,7 +455,15 @@ impl node_service::Server for NodeServiceImpl {
.get()
.ok()
.map(|p| p.get_version())
.unwrap_or(0);
.unwrap_or(CURRENT_WIRE_VERSION);
if let Err(e) = params
.get()
.ok()
.map(|p| validate_auth(&self.auth_cfg, p.get_auth()))
.transpose()
{
return Promise::err(e);
}
if recipient_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
@@ -308,9 +471,9 @@ impl node_service::Server for NodeServiceImpl {
recipient_key.len()
)));
}
if version != 0 && version != CURRENT_WIRE_VERSION {
if version != CURRENT_WIRE_VERSION {
return Promise::err(capnp::Error::failed(format!(
"unsupported wire version {} (expected 0 or {CURRENT_WIRE_VERSION})",
"unsupported wire version {} (expected {CURRENT_WIRE_VERSION})",
version
)));
}
@@ -355,6 +518,9 @@ impl node_service::Server for NodeServiceImpl {
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
let version = p.get_version();
let timeout_ms = p.get_timeout_ms();
if let Err(e) = validate_auth(&self.auth_cfg, p.get_auth()) {
return Promise::err(e);
}
if recipient_key.len() != 32 {
return Promise::err(capnp::Error::failed(format!(
@@ -362,9 +528,9 @@ impl node_service::Server for NodeServiceImpl {
recipient_key.len()
)));
}
if version != 0 && version != CURRENT_WIRE_VERSION {
if version != CURRENT_WIRE_VERSION {
return Promise::err(capnp::Error::failed(format!(
"unsupported wire version {} (expected 0 or {CURRENT_WIRE_VERSION})",
"unsupported wire version {} (expected {CURRENT_WIRE_VERSION})",
version
)));
}
@@ -403,6 +569,103 @@ impl node_service::Server for NodeServiceImpl {
results.get().set_status("ok");
Promise::ok(())
}
    /// Store a hybrid (X25519 + ML-KEM-768) public key for an identity.
    ///
    /// Upserts the key into the backing store after checking that
    /// `identityKey` is exactly 32 bytes and the hybrid key is non-empty.
    ///
    /// NOTE(review): unlike the other RPC handlers in this service, this one
    /// never calls `validate_auth` — confirm whether the capnp schema carries
    /// an `auth` field for this RPC and enforce it if so.
    fn upload_hybrid_key(
        &mut self,
        params: node_service::UploadHybridKeyParams,
        _results: node_service::UploadHybridKeyResults,
    ) -> Promise<(), capnp::Error> {
        let p = match params.get() {
            Ok(p) => p,
            Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
        };
        // Copy request fields out of the capnp reader before validating.
        let identity_key = match p.get_identity_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
        };
        let hybrid_pk = match p.get_hybrid_public_key() {
            Ok(v) => v.to_vec(),
            Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
        };
        if identity_key.len() != 32 {
            return Promise::err(capnp::Error::failed(format!(
                "identityKey must be exactly 32 bytes, got {}",
                identity_key.len()
            )));
        }
        if hybrid_pk.is_empty() {
            return Promise::err(capnp::Error::failed(
                "hybridPublicKey must not be empty".to_string(),
            ));
        }
        // Persist (upsert) — storage errors surface as capnp failures.
        if let Err(e) = self
            .store
            .upload_hybrid_key(&identity_key, hybrid_pk)
            .map_err(storage_err)
        {
            return Promise::err(e);
        }
        tracing::debug!(
            identity = %fmt_hex(&identity_key[..4]),
            "hybrid public key uploaded"
        );
        Promise::ok(())
    }
    /// Fetch a peer's hybrid public key.
    ///
    /// Returns the stored key, or an empty byte string when the identity has
    /// no hybrid key on record — callers must treat empty as "not found".
    ///
    /// NOTE(review): this handler also skips `validate_auth`, making the
    /// lookup effectively public — confirm that is intended.
    fn fetch_hybrid_key(
        &mut self,
        params: node_service::FetchHybridKeyParams,
        mut results: node_service::FetchHybridKeyResults,
    ) -> Promise<(), capnp::Error> {
        let identity_key = match params.get() {
            Ok(p) => match p.get_identity_key() {
                Ok(v) => v.to_vec(),
                Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
            },
            Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
        };
        if identity_key.len() != 32 {
            return Promise::err(capnp::Error::failed(format!(
                "identityKey must be exactly 32 bytes, got {}",
                identity_key.len()
            )));
        }
        let hybrid_pk = match self
            .store
            .fetch_hybrid_key(&identity_key)
            .map_err(storage_err)
        {
            Ok(p) => p,
            Err(e) => return Promise::err(e),
        };
        match hybrid_pk {
            Some(pk) => {
                tracing::debug!(
                    identity = %fmt_hex(&identity_key[..4]),
                    "hybrid key fetched"
                );
                results.get().set_hybrid_public_key(&pk);
            }
            None => {
                tracing::debug!(
                    identity = %fmt_hex(&identity_key[..4]),
                    "no hybrid key for identity"
                );
                // Empty byte string signals "no key on record".
                results.get().set_hybrid_public_key(&[]);
            }
        }
        Promise::ok(())
    }
}
fn fill_payloads_wait(results: &mut node_service::FetchWaitResults, messages: Vec<Vec<u8>>) {
@@ -416,6 +679,42 @@ fn storage_err(err: StorageError) -> capnp::Error {
capnp::Error::failed(format!("{err}"))
}
/// Validate a request's `auth` struct against the server's [`AuthConfig`].
///
/// Policy: `auth.version` must be exactly 1 and `auth.accessToken` must be
/// non-empty. When `cfg.required_token` is configured, the presented token
/// must also match it; otherwise any non-empty token is accepted.
///
/// NOTE(review): `&token != expected` is not a constant-time comparison, so
/// token length/prefix can leak via timing — consider a constant-time
/// equality check for bearer tokens before exposing this to untrusted
/// networks.
fn validate_auth(
    cfg: &AuthConfig,
    auth: Result<auth::Reader<'_>, capnp::Error>,
) -> Result<(), capnp::Error> {
    let auth = auth?;
    let version = auth.get_version();
    if version != 1 {
        return Err(capnp::Error::failed(format!(
            "unsupported auth version {} (expected 1)",
            version
        )));
    }
    let token = auth
        .get_access_token()
        .map_err(|e| capnp::Error::failed(format!("auth.accessToken: {e}")))?
        .to_vec();
    if token.is_empty() {
        return Err(capnp::Error::failed(
            "auth.version=1 requires non-empty accessToken".to_string(),
        ));
    }
    if let Some(expected) = &cfg.required_token {
        if &token != expected {
            return Err(capnp::Error::failed("invalid accessToken".to_string()));
        }
    }
    // Early-development stance: no legacy/no-auth path to avoid maintaining divergent behavior.
    Ok(())
}
// ── Entry point ───────────────────────────────────────────────────────────────
#[tokio::main]
@@ -428,20 +727,42 @@ async fn main() -> anyhow::Result<()> {
.init();
let args = Args::parse();
let file_cfg = load_config(args.config.as_deref())?;
let effective = merge_config(&args, &file_cfg);
let listen: SocketAddr = args.listen.parse().context("--listen must be host:port")?;
let listen: SocketAddr = effective
.listen
.parse()
.context("--listen must be host:port")?;
let server_config = build_server_config(&args.tls_cert, &args.tls_key)
let server_config = build_server_config(&effective.tls_cert, &effective.tls_key)
.context("failed to build TLS/QUIC server config")?;
// Shared storage — persisted to disk for restart safety.
let store = Arc::new(FileBackedStore::open(&args.data_dir)?);
let store: Arc<dyn Store> = match effective.store_backend.as_str() {
"sql" => {
if let Some(parent) = effective.db_path.parent() {
std::fs::create_dir_all(parent).context("create db dir")?;
}
tracing::info!(
path = %effective.db_path.display(),
encrypted = !effective.db_key.is_empty(),
"opening SQLCipher store"
);
Arc::new(SqlStore::open(&effective.db_path, &effective.db_key)?)
}
"file" | _ => {
tracing::info!(dir = %effective.data_dir, "opening file-backed store");
Arc::new(FileBackedStore::open(&effective.data_dir)?)
}
};
let auth_cfg = Arc::new(AuthConfig::new(effective.auth_token.clone()));
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = Arc::new(DashMap::new());
let endpoint = Endpoint::server(server_config, listen)?;
tracing::info!(
addr = %args.listen,
addr = %effective.listen,
"accepting QUIC connections"
);
@@ -466,8 +787,9 @@ async fn main() -> anyhow::Result<()> {
let store = Arc::clone(&store);
let waiters = Arc::clone(&waiters);
let auth_cfg = Arc::clone(&auth_cfg);
tokio::task::spawn_local(async move {
if let Err(e) = handle_node_connection(connecting, store, waiters).await {
if let Err(e) = handle_node_connection(connecting, store, waiters, auth_cfg).await {
tracing::warn!(error = %e, "connection error");
}
});
@@ -483,8 +805,9 @@ async fn main() -> anyhow::Result<()> {
/// Handle one NodeService connection.
async fn handle_node_connection(
connecting: quinn::Connecting,
store: Arc<FileBackedStore>,
store: Arc<dyn Store>,
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
auth_cfg: Arc<AuthConfig>,
) -> Result<(), anyhow::Error> {
let connection = connecting.await?;
@@ -498,7 +821,11 @@ async fn handle_node_connection(
let network = twoparty::VatNetwork::new(reader, writer, Side::Server, Default::default());
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl { store, waiters });
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl {
store,
waiters,
auth_cfg,
});
RpcSystem::new(Box::new(network), Some(service.client))
.await

View File

@@ -0,0 +1,315 @@
//! SQLCipher-backed persistent storage.
//!
//! Uses `rusqlite` with `bundled-sqlcipher` for encrypted-at-rest storage.
//! Implements the same [`Store`] trait as [`FileBackedStore`] but with proper
//! ACID transactions and indexed queries.
use std::path::Path;
use std::sync::Mutex;
use rusqlite::{params, Connection};
use crate::storage::{StorageError, Store};
/// SQLCipher-encrypted storage backend.
///
/// All data is stored in a single encrypted SQLite database. The encryption
/// key is set via `PRAGMA key` at open time.
pub struct SqlStore {
conn: Mutex<Connection>,
}
impl SqlStore {
    /// Open (or create) an encrypted database at `path`.
    ///
    /// `key` is the passphrase used by SQLCipher. Pass an empty string for an
    /// unencrypted database (useful for testing).
    ///
    /// # Errors
    ///
    /// Returns [`StorageError::Db`] if the file cannot be opened, the
    /// SQLCipher key is wrong, or the schema migration fails.
    pub fn open(path: impl AsRef<Path>, key: &str) -> Result<Self, StorageError> {
        let conn = Connection::open(path).map_err(|e| StorageError::Db(e.to_string()))?;
        if !key.is_empty() {
            // PRAGMA key must be issued before any other statement touches
            // the database file.
            conn.pragma_update(None, "key", key)
                .map_err(|e| StorageError::Db(format!("PRAGMA key failed: {e}")))?;
            // SQLCipher only detects a wrong key on the first page read, so
            // force a read now — a bad --db-key fails fast at startup
            // instead of surfacing as a corruption error mid-request.
            conn.query_row("SELECT count(*) FROM sqlite_master", [], |_| Ok(()))
                .map_err(|e| {
                    StorageError::Db(format!("SQLCipher key verification failed: {e}"))
                })?;
        }
        // Performance pragmas — safe for a single-writer server.
        conn.execute_batch(
            "PRAGMA journal_mode = WAL;
             PRAGMA synchronous = NORMAL;
             PRAGMA foreign_keys = ON;",
        )
        .map_err(|e| StorageError::Db(e.to_string()))?;
        let store = Self {
            conn: Mutex::new(conn),
        };
        store.migrate()?;
        Ok(store)
    }

    /// Create schema tables and indexes if they don't exist yet.
    ///
    /// Idempotent: every statement uses `IF NOT EXISTS`, so it is safe to run
    /// on every open.
    fn migrate(&self) -> Result<(), StorageError> {
        let conn = self.conn.lock().unwrap();
        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS key_packages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                identity_key BLOB NOT NULL,
                package_data BLOB NOT NULL,
                created_at INTEGER DEFAULT (strftime('%s','now'))
            );
            CREATE TABLE IF NOT EXISTS deliveries (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                recipient_key BLOB NOT NULL,
                channel_id BLOB NOT NULL DEFAULT X'',
                payload BLOB NOT NULL,
                created_at INTEGER DEFAULT (strftime('%s','now'))
            );
            CREATE TABLE IF NOT EXISTS hybrid_keys (
                identity_key BLOB PRIMARY KEY,
                hybrid_public_key BLOB NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_kp_identity
                ON key_packages(identity_key);
            CREATE INDEX IF NOT EXISTS idx_del_recipient_channel
                ON deliveries(recipient_key, channel_id);",
        )
        .map_err(|e| StorageError::Db(e.to_string()))?;
        Ok(())
    }
}
impl Store for SqlStore {
fn upload_key_package(
&self,
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute(
"INSERT INTO key_packages (identity_key, package_data) VALUES (?1, ?2)",
params![identity_key, package],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(())
}
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
// Find the oldest KeyPackage (FIFO) and delete it atomically.
let mut stmt = conn
.prepare(
"SELECT id, package_data FROM key_packages
WHERE identity_key = ?1
ORDER BY id ASC
LIMIT 1",
)
.map_err(|e| StorageError::Db(e.to_string()))?;
let row = stmt
.query_row(params![identity_key], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, Vec<u8>>(1)?))
})
.optional()
.map_err(|e| StorageError::Db(e.to_string()))?;
match row {
Some((id, package)) => {
conn.execute("DELETE FROM key_packages WHERE id = ?1", params![id])
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(Some(package))
}
None => Ok(None),
}
}
fn enqueue(
&self,
recipient_key: &[u8],
channel_id: &[u8],
payload: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute(
"INSERT INTO deliveries (recipient_key, channel_id, payload) VALUES (?1, ?2, ?3)",
params![recipient_key, channel_id, payload],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(())
}
fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let mut stmt = conn
.prepare(
"SELECT id, payload FROM deliveries
WHERE recipient_key = ?1 AND channel_id = ?2
ORDER BY id ASC",
)
.map_err(|e| StorageError::Db(e.to_string()))?;
let rows: Vec<(i64, Vec<u8>)> = stmt
.query_map(params![recipient_key, channel_id], |row| {
Ok((row.get(0)?, row.get(1)?))
})
.map_err(|e| StorageError::Db(e.to_string()))?
.collect::<Result<Vec<_>, _>>()
.map_err(|e| StorageError::Db(e.to_string()))?;
if !rows.is_empty() {
let ids: Vec<i64> = rows.iter().map(|(id, _)| *id).collect();
// Delete fetched rows in a single statement.
let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
let sql = format!("DELETE FROM deliveries WHERE id IN ({placeholders})");
let params: Vec<&dyn rusqlite::types::ToSql> =
ids.iter().map(|id| id as &dyn rusqlite::types::ToSql).collect();
conn.execute(&sql, params.as_slice())
.map_err(|e| StorageError::Db(e.to_string()))?;
}
Ok(rows.into_iter().map(|(_, payload)| payload).collect())
}
fn upload_hybrid_key(
&self,
identity_key: &[u8],
hybrid_pk: Vec<u8>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute(
"INSERT OR REPLACE INTO hybrid_keys (identity_key, hybrid_public_key) VALUES (?1, ?2)",
params![identity_key, hybrid_pk],
)
.map_err(|e| StorageError::Db(e.to_string()))?;
Ok(())
}
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let conn = self.conn.lock().unwrap();
let mut stmt = conn
.prepare("SELECT hybrid_public_key FROM hybrid_keys WHERE identity_key = ?1")
.map_err(|e| StorageError::Db(e.to_string()))?;
stmt.query_row(params![identity_key], |row| row.get(0))
.optional()
.map_err(|e| StorageError::Db(e.to_string()))
}
}
/// Convenience extension mirroring `rusqlite::OptionalExtension`: turns the
/// "no rows" error into `Ok(None)` while passing everything else through.
trait OptionalExt<T> {
    fn optional(self) -> Result<Option<T>, rusqlite::Error>;
}

impl<T> OptionalExt<T> for Result<T, rusqlite::Error> {
    fn optional(self) -> Result<Option<T>, rusqlite::Error> {
        match self {
            // An empty result set is a legitimate "not found", not an error.
            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
            other => other.map(Some),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, unencrypted in-memory store for each test.
    fn open_in_memory() -> SqlStore {
        SqlStore::open(":memory:", "").unwrap()
    }

    #[test]
    fn key_package_fifo() {
        let store = open_in_memory();
        // Build a zero-padded 32-byte identity key to match real usage.
        let label = b"alice_identity_key__32bytes_long";
        let mut identity = [0u8; 32];
        identity[..label.len()].copy_from_slice(label);

        for kp in [b"kp1".as_slice(), b"kp2".as_slice()] {
            store.upload_key_package(&identity, kp.to_vec()).unwrap();
        }

        // Packages come back oldest-first, then the queue is exhausted.
        assert_eq!(
            store.fetch_key_package(&identity).unwrap(),
            Some(b"kp1".to_vec())
        );
        assert_eq!(
            store.fetch_key_package(&identity).unwrap(),
            Some(b"kp2".to_vec())
        );
        assert!(store.fetch_key_package(&identity).unwrap().is_none());
    }

    #[test]
    fn delivery_round_trip() {
        let store = open_in_memory();
        let recipient = [1u8; 32];
        let channel = b"channel-1";
        store.enqueue(&recipient, channel, b"msg1".to_vec()).unwrap();
        store.enqueue(&recipient, channel, b"msg2".to_vec()).unwrap();

        // Fetch drains in insertion order…
        let drained = store.fetch(&recipient, channel).unwrap();
        assert_eq!(drained, vec![b"msg1".to_vec(), b"msg2".to_vec()]);
        // …and a second fetch finds nothing left.
        assert!(store.fetch(&recipient, channel).unwrap().is_empty());
    }

    #[test]
    fn hybrid_key_round_trip() {
        let store = open_in_memory();
        let identity = [2u8; 32];
        let public_key = b"hybrid_public_key_data".to_vec();
        store.upload_hybrid_key(&identity, public_key.clone()).unwrap();
        assert_eq!(store.fetch_hybrid_key(&identity).unwrap(), Some(public_key));
    }

    #[test]
    fn hybrid_key_upsert() {
        let store = open_in_memory();
        let identity = [3u8; 32];
        // A second upload replaces the first (INSERT OR REPLACE semantics).
        store.upload_hybrid_key(&identity, b"v1".to_vec()).unwrap();
        store.upload_hybrid_key(&identity, b"v2".to_vec()).unwrap();
        assert_eq!(
            store.fetch_hybrid_key(&identity).unwrap(),
            Some(b"v2".to_vec())
        );
    }

    #[test]
    fn separate_channels_isolated() {
        let store = open_in_memory();
        let recipient = [4u8; 32];
        store.enqueue(&recipient, b"ch-a", b"a1".to_vec()).unwrap();
        store.enqueue(&recipient, b"ch-b", b"b1".to_vec()).unwrap();

        // Draining one channel must not disturb the other.
        assert_eq!(store.fetch(&recipient, b"ch-a").unwrap(), vec![b"a1".to_vec()]);
        assert_eq!(store.fetch(&recipient, b"ch-b").unwrap(), vec![b"b1".to_vec()]);
    }
}

View File

@@ -1,7 +1,7 @@
use std::{
collections::{HashMap, VecDeque},
fs,
hash::{Hash, Hasher},
hash::Hash,
path::{Path, PathBuf},
sync::Mutex,
};
@@ -14,13 +14,46 @@ pub enum StorageError {
Io(String),
#[error("serialization error")]
Serde,
#[error("database error: {0}")]
Db(String),
}
#[derive(Serialize, Deserialize, Default)]
struct QueueMapV1 {
map: HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
// ── Store trait ──────────────────────────────────────────────────────────────
/// Abstraction over storage backends (file-backed, SQLCipher, etc.).
pub trait Store: Send + Sync {
fn upload_key_package(
&self,
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError>;
fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
fn enqueue(
&self,
recipient_key: &[u8],
channel_id: &[u8],
payload: Vec<u8>,
) -> Result<(), StorageError>;
fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError>;
fn upload_hybrid_key(
&self,
identity_key: &[u8],
hybrid_pk: Vec<u8>,
) -> Result<(), StorageError>;
fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
}
// ── ChannelKey ───────────────────────────────────────────────────────────────
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
pub struct ChannelKey {
pub channel_id: Vec<u8>,
@@ -28,12 +61,19 @@ pub struct ChannelKey {
}
impl Hash for ChannelKey {
fn hash<H: Hasher>(&self, state: &mut H) {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.channel_id.hash(state);
self.recipient_key.hash(state);
}
}
// ── FileBackedStore ──────────────────────────────────────────────────────────
#[derive(Serialize, Deserialize, Default)]
struct QueueMapV1 {
map: HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
}
#[derive(Serialize, Deserialize, Default)]
struct QueueMapV2 {
map: HashMap<ChannelKey, VecDeque<Vec<u8>>>,
@@ -45,8 +85,10 @@ struct QueueMapV2 {
pub struct FileBackedStore {
kp_path: PathBuf,
ds_path: PathBuf,
hk_path: PathBuf,
key_packages: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
deliveries: Mutex<HashMap<ChannelKey, VecDeque<Vec<u8>>>>,
hybrid_keys: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
}
impl FileBackedStore {
@@ -57,73 +99,23 @@ impl FileBackedStore {
}
let kp_path = dir.join("keypackages.bin");
let ds_path = dir.join("deliveries.bin");
let hk_path = dir.join("hybridkeys.bin");
let key_packages = Mutex::new(Self::load_map(&kp_path)?);
let deliveries = Mutex::new(Self::load_map(&ds_path)?);
let key_packages = Mutex::new(Self::load_kp_map(&kp_path)?);
let deliveries = Mutex::new(Self::load_delivery_map(&ds_path)?);
let hybrid_keys = Mutex::new(Self::load_hybrid_keys(&hk_path)?);
Ok(Self {
kp_path,
ds_path,
hk_path,
key_packages,
deliveries,
hybrid_keys,
})
}
pub fn upload_key_package(
&self,
identity_key: &[u8],
package: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.key_packages.lock().unwrap();
map.entry(identity_key.to_vec())
.or_default()
.push_back(package);
self.flush_map(&self.kp_path, &*map)
}
pub fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
let mut map = self.key_packages.lock().unwrap();
let package = map.get_mut(identity_key).and_then(|q| q.pop_front());
self.flush_map(&self.kp_path, &*map)?;
Ok(package)
}
pub fn enqueue(
&self,
recipient_key: &[u8],
channel_id: &[u8],
payload: Vec<u8>,
) -> Result<(), StorageError> {
let mut map = self.deliveries.lock().unwrap();
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
map.entry(key)
.or_default()
.push_back(payload);
self.flush_map(&self.ds_path, &*map)
}
pub fn fetch(
&self,
recipient_key: &[u8],
channel_id: &[u8],
) -> Result<Vec<Vec<u8>>, StorageError> {
let mut map = self.deliveries.lock().unwrap();
let key = ChannelKey {
channel_id: channel_id.to_vec(),
recipient_key: recipient_key.to_vec(),
};
let messages = map
.get_mut(&key)
.map(|q| q.drain(..).collect())
.unwrap_or_default();
self.flush_map(&self.ds_path, &*map)?;
Ok(messages)
}
fn load_map(path: &Path) -> Result<HashMap<ChannelKey, VecDeque<Vec<u8>>>, StorageError> {
fn load_kp_map(path: &Path) -> Result<HashMap<Vec<u8>, VecDeque<Vec<u8>>>, StorageError> {
if !path.exists() {
return Ok(HashMap::new());
}
@@ -131,7 +123,32 @@ impl FileBackedStore {
if bytes.is_empty() {
return Ok(HashMap::new());
}
// Try v2 format (channel-aware). Fallback to legacy v1.
let map: QueueMapV1 = bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)?;
Ok(map.map)
}
fn flush_kp_map(
&self,
path: &Path,
map: &HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
) -> Result<(), StorageError> {
let payload = QueueMapV1 { map: map.clone() };
let bytes = bincode::serialize(&payload).map_err(|_| StorageError::Serde)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
}
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
}
fn load_delivery_map(path: &Path) -> Result<HashMap<ChannelKey, VecDeque<Vec<u8>>>, StorageError> {
if !path.exists() {
return Ok(HashMap::new());
}
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
if bytes.is_empty() {
return Ok(HashMap::new());
}
// Try v2 format (channel-aware). Fallback to legacy v1 for upgrade.
if let Ok(map) = bincode::deserialize::<QueueMapV2>(&bytes) {
return Ok(map.map);
}
@@ -149,7 +166,7 @@ impl FileBackedStore {
Ok(upgraded)
}
fn flush_map(
fn flush_delivery_map(
&self,
path: &Path,
map: &HashMap<ChannelKey, VecDeque<Vec<u8>>>,
@@ -161,4 +178,98 @@ impl FileBackedStore {
}
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
}
fn load_hybrid_keys(path: &Path) -> Result<HashMap<Vec<u8>, Vec<u8>>, StorageError> {
    // A missing or zero-length file simply means no hybrid keys have been
    // stored yet.
    if !path.exists() {
        return Ok(HashMap::new());
    }
    let raw = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
    if raw.is_empty() {
        return Ok(HashMap::new());
    }
    bincode::deserialize(&raw).map_err(|_| StorageError::Serde)
}
fn flush_hybrid_keys(
    &self,
    path: &Path,
    map: &HashMap<Vec<u8>, Vec<u8>>,
) -> Result<(), StorageError> {
    // Serialize the identity-key -> hybrid-public-key map and rewrite the
    // backing file in full, creating parent directories on first use.
    let encoded = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
    if let Some(dir) = path.parent() {
        fs::create_dir_all(dir).map_err(|e| StorageError::Io(e.to_string()))?;
    }
    fs::write(path, encoded).map_err(|e| StorageError::Io(e.to_string()))
}
}
impl Store for FileBackedStore {
    /// Append a key package to the FIFO queue for `identity_key` and persist
    /// the whole map to disk.
    fn upload_key_package(
        &self,
        identity_key: &[u8],
        package: Vec<u8>,
    ) -> Result<(), StorageError> {
        let mut map = self.key_packages.lock().unwrap();
        map.entry(identity_key.to_vec())
            .or_default()
            .push_back(package);
        self.flush_kp_map(&self.kp_path, &*map)
    }

    /// Pop the oldest key package for `identity_key`, if any, persisting the
    /// shrunken queue so the same package is never handed out twice.
    ///
    /// Queues that become empty are removed from the map so the in-memory map
    /// and the persisted file do not accumulate dead identity keys, and the
    /// flush is skipped when nothing was popped (the map is unchanged then).
    fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
        let mut map = self.key_packages.lock().unwrap();
        let package = map.get_mut(identity_key).and_then(|q| q.pop_front());
        if package.is_some() {
            // Drop the entry once its queue is drained; absent key and empty
            // queue are indistinguishable to readers.
            if map.get(identity_key).map_or(false, |q| q.is_empty()) {
                map.remove(identity_key);
            }
            self.flush_kp_map(&self.kp_path, &*map)?;
        }
        Ok(package)
    }

    /// Append one payload to the FIFO queue for this (channel, recipient)
    /// pair and persist the delivery map.
    fn enqueue(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
        payload: Vec<u8>,
    ) -> Result<(), StorageError> {
        let mut map = self.deliveries.lock().unwrap();
        let key = ChannelKey {
            channel_id: channel_id.to_vec(),
            recipient_key: recipient_key.to_vec(),
        };
        map.entry(key).or_default().push_back(payload);
        self.flush_delivery_map(&self.ds_path, &*map)
    }

    /// Drain every queued payload for this (channel, recipient) pair.
    ///
    /// The whole entry is removed (rather than drained in place) so dead
    /// (channel, recipient) keys do not accumulate in memory or in the
    /// persisted file; when there is no entry the map is untouched and the
    /// redundant flush is skipped.
    fn fetch(
        &self,
        recipient_key: &[u8],
        channel_id: &[u8],
    ) -> Result<Vec<Vec<u8>>, StorageError> {
        let mut map = self.deliveries.lock().unwrap();
        let key = ChannelKey {
            channel_id: channel_id.to_vec(),
            recipient_key: recipient_key.to_vec(),
        };
        match map.remove(&key) {
            Some(queue) => {
                let messages: Vec<Vec<u8>> = queue.into_iter().collect();
                self.flush_delivery_map(&self.ds_path, &*map)?;
                Ok(messages)
            }
            None => Ok(Vec::new()),
        }
    }

    /// Store (or replace) the hybrid public key for `identity_key` and
    /// persist the map. `insert` makes this an upsert: a re-upload overwrites
    /// the previous key.
    fn upload_hybrid_key(
        &self,
        identity_key: &[u8],
        hybrid_pk: Vec<u8>,
    ) -> Result<(), StorageError> {
        let mut map = self.hybrid_keys.lock().unwrap();
        map.insert(identity_key.to_vec(), hybrid_pk);
        self.flush_hybrid_keys(&self.hk_path, &*map)
    }

    /// Look up the hybrid public key for `identity_key`. Read-only: the map
    /// is not modified and nothing is flushed.
    fn fetch_hybrid_key(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
        let map = self.hybrid_keys.lock().unwrap();
        Ok(map.get(identity_key).cloned())
    }
}