chore: rename quicproquo → quicprochat in Rust workspace

Rename all crate directories, package names, binary names, proto
package/module paths, ALPN strings, env var prefixes, config filenames,
mDNS service names, and plugin ABI symbols from quicproquo/qpq to
quicprochat/qpc.
This commit is contained in:
2026-03-07 18:24:52 +01:00
parent d8c1392587
commit a710037dde
212 changed files with 609 additions and 609 deletions

View File

@@ -0,0 +1,88 @@
[package]
name = "quicprochat-core"
version = "0.1.0"
edition.workspace = true
description = "Crypto primitives, MLS state machine, and hybrid post-quantum KEM for quicprochat."
license = "Apache-2.0 OR MIT"
repository.workspace = true
[features]
default = ["native"]
# The "native" feature enables MLS (openmls), OPAQUE, Cap'n Proto, tokio, and
# filesystem-backed key storage. Disable it (--no-default-features) to compile
# the pure-crypto subset to wasm32-unknown-unknown.
native = [
"dep:openmls",
"dep:openmls_rust_crypto",
"dep:openmls_traits",
"dep:tls_codec",
"dep:opaque-ke",
"dep:bincode",
"dep:capnp",
"dep:quicprochat-proto",
"dep:tokio",
]
[dependencies]
# Crypto — classical (always available, WASM-safe)
x25519-dalek = { workspace = true }
ed25519-dalek = { workspace = true }
sha2 = { workspace = true }
hmac = { workspace = true }
hkdf = { workspace = true }
ciborium = { workspace = true }
chacha20poly1305 = { workspace = true }
zeroize = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
argon2 = { workspace = true }
thiserror = { workspace = true }
# Crypto — post-quantum hybrid KEM (M7) — always available, WASM-safe
ml-kem = { workspace = true }
# Crypto — OPAQUE password-authenticated key exchange (native only)
opaque-ke = { workspace = true, optional = true }
# Crypto — MLS (M2) (native only)
openmls = { workspace = true, optional = true }
openmls_rust_crypto = { workspace = true, optional = true }
openmls_traits = { workspace = true, optional = true }
tls_codec = { workspace = true, optional = true }
bincode = { workspace = true, optional = true }
# Serialisation (native only)
capnp = { workspace = true, optional = true }
quicprochat-proto = { path = "../quicprochat-proto", optional = true }
# Async runtime (native only)
tokio = { workspace = true, optional = true }
# WASM: provide getrandom with js backend
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }
[lints]
workspace = true
[dev-dependencies]
tokio = { workspace = true }
criterion = { version = "0.5", features = ["html_reports"] }
prost = "0.13"
[[bench]]
name = "serialization"
harness = false
[[bench]]
name = "mls_operations"
harness = false
[[bench]]
name = "hybrid_kem_bench"
harness = false
[[bench]]
name = "crypto_benchmarks"
harness = false

View File

@@ -0,0 +1,150 @@
#![allow(clippy::unwrap_used)]
//! Benchmark: Identity keypair operations, sealed sender, and message padding.
//!
//! Covers:
//! - [`IdentityKeypair`] generation, signing, and signature verification
//! - Sealed sender `seal` / `unseal` (Ed25519 sign + verify overhead)
//! - Message padding `pad` / `unpad` at various payload sizes
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use quicprochat_core::{compute_safety_number, IdentityKeypair, padding};
// ── Identity keypair benchmarks ──────────────────────────────────────────────
/// Measures generation of a fresh identity keypair (includes the CSPRNG draw).
fn bench_identity_keygen(c: &mut Criterion) {
    let routine = || black_box(IdentityKeypair::generate());
    c.bench_function("identity_keygen", move |b| b.iter(routine));
}
/// Measures Ed25519 signing over a fixed ~60-byte payload.
fn bench_identity_sign(c: &mut Criterion) {
    let keypair = IdentityKeypair::generate();
    let msg = b"benchmark signing payload -- 32+ bytes of realistic data here";
    c.bench_function("identity_sign", |b| {
        b.iter(|| black_box(keypair.sign_raw(black_box(msg))))
    });
}
/// Measures Ed25519 verification against a signature computed once in setup.
fn bench_identity_verify(c: &mut Criterion) {
    let keypair = IdentityKeypair::generate();
    let msg = b"benchmark signing payload -- 32+ bytes of realistic data here";
    let signature = keypair.sign_raw(msg);
    let public = keypair.public_key_bytes();
    c.bench_function("identity_verify", |b| {
        b.iter(|| {
            IdentityKeypair::verify_raw(black_box(&public), black_box(msg), black_box(&signature))
                .unwrap();
        });
    });
}
// ── Sealed sender benchmarks ─────────────────────────────────────────────────
/// Measures sealed-sender `seal` and `unseal` across several payload sizes
/// (captures the Ed25519 sign/verify overhead per envelope).
fn bench_sealed_sender(c: &mut Criterion) {
    use quicprochat_core::sealed_sender::{seal, unseal};
    const SIZES: [(&str, usize); 4] = [("32B", 32), ("256B", 256), ("1KB", 1024), ("4KB", 4096)];
    let identity = IdentityKeypair::generate();

    let mut seal_group = c.benchmark_group("sealed_sender_seal");
    for (label, size) in SIZES {
        let data = vec![0xABu8; size];
        seal_group.bench_with_input(BenchmarkId::from_parameter(label), &data, |b, data| {
            b.iter(|| black_box(seal(black_box(&identity), black_box(data))));
        });
    }
    seal_group.finish();

    let mut unseal_group = c.benchmark_group("sealed_sender_unseal");
    for (label, size) in SIZES {
        // Seal once in setup; only unseal is timed.
        let data = vec![0xABu8; size];
        let sealed = seal(&identity, &data);
        unseal_group.bench_with_input(BenchmarkId::from_parameter(label), &sealed, |b, sealed| {
            b.iter(|| black_box(unseal(black_box(sealed)).unwrap()));
        });
    }
    unseal_group.finish();
}
// ── Message padding benchmarks ────────────────────────────────────────────────
/// Measures bucket padding (`pad`) and its inverse (`unpad`) at sizes chosen
/// to land in each padding bucket plus one oversized payload.
fn bench_padding(c: &mut Criterion) {
    const SIZES: [(&str, usize); 5] = [
        ("50B", 50),     // → 256 bucket
        ("512B", 512),   // → 1024 bucket
        ("2KB", 2048),   // → 4096 bucket
        ("8KB", 8192),   // → 16384 bucket
        ("20KB", 20480), // → 32768 (oversized)
    ];

    let mut pad_group = c.benchmark_group("padding_pad");
    for (label, size) in SIZES {
        let data = vec![0xABu8; size];
        pad_group.bench_with_input(BenchmarkId::from_parameter(label), &data, |b, data| {
            b.iter(|| black_box(padding::pad(black_box(data))));
        });
    }
    pad_group.finish();

    let mut unpad_group = c.benchmark_group("padding_unpad");
    for (label, size) in SIZES {
        // Pad once in setup; only unpad is timed.
        let data = vec![0xABu8; size];
        let padded = padding::pad(&data);
        unpad_group.bench_with_input(BenchmarkId::from_parameter(label), &padded, |b, padded| {
            b.iter(|| black_box(padding::unpad(black_box(padded)).unwrap()));
        });
    }
    unpad_group.finish();
}
// ── Safety number benchmarks ─────────────────────────────────────────────────
/// Measures safety-number derivation from two fixed 32-byte public keys.
fn bench_safety_number(c: &mut Criterion) {
    let (left, right) = ([0x1au8; 32], [0x2bu8; 32]);
    c.bench_function("safety_number", |b| {
        b.iter(|| black_box(compute_safety_number(black_box(&left), black_box(&right))))
    });
}
// Register every benchmark above with Criterion and emit its `main`
// (this bench target sets `harness = false` in Cargo.toml).
criterion_group!(
    benches,
    bench_identity_keygen,
    bench_identity_sign,
    bench_identity_verify,
    bench_sealed_sender,
    bench_padding,
    bench_safety_number,
);
criterion_main!(benches);

View File

@@ -0,0 +1,153 @@
#![allow(clippy::unwrap_used)]
//! Benchmark: Hybrid KEM (X25519 + ML-KEM-768) vs classical-only encryption.
//!
//! Compares keypair generation, encryption, and decryption times for the
//! hybrid post-quantum scheme against classical X25519 + ChaCha20-Poly1305.
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use quicprochat_core::{hybrid_encrypt, hybrid_decrypt, HybridKeypair};
// ── Classical baseline (X25519 + ChaCha20-Poly1305) ─────────────────────────
use chacha20poly1305::{
aead::{Aead, KeyInit},
ChaCha20Poly1305, Key, Nonce,
};
use hkdf::Hkdf;
use rand::{rngs::OsRng, RngCore};
use sha2::Sha256;
use x25519_dalek::{EphemeralSecret, PublicKey as X25519Public, StaticSecret};
/// Classical-only baseline keypair: a static X25519 secret plus its public key.
struct ClassicalKeypair {
    secret: StaticSecret,
    public: X25519Public,
}
impl ClassicalKeypair {
    /// Generate a fresh random X25519 keypair from the OS RNG.
    fn generate() -> Self {
        let secret = StaticSecret::random_from_rng(OsRng);
        let public = X25519Public::from(&secret);
        Self { secret, public }
    }
}
/// Classical ECIES baseline: ephemeral X25519 → HKDF-SHA256 → ChaCha20-Poly1305.
/// Wire format: eph_pk(32) || nonce(12) || ciphertext.
fn classical_encrypt(recipient_pk: &X25519Public, plaintext: &[u8]) -> Vec<u8> {
    // Fresh ephemeral key per message; DH against the recipient's static key.
    let ephemeral = EphemeralSecret::random_from_rng(OsRng);
    let ephemeral_pk = X25519Public::from(&ephemeral);
    let shared = ephemeral.diffie_hellman(recipient_pk);

    // Derive the symmetric key with HKDF-SHA256 under a fixed info label.
    let mut key = [0u8; 32];
    Hkdf::<Sha256>::new(None, shared.as_bytes())
        .expand(b"classical-bench", &mut key)
        .unwrap();

    // Random 96-bit nonce, then AEAD-seal the plaintext.
    let mut nonce = [0u8; 12];
    OsRng.fill_bytes(&mut nonce);
    let sealed = ChaCha20Poly1305::new(Key::from_slice(&key))
        .encrypt(Nonce::from_slice(&nonce), plaintext)
        .unwrap();

    // Wire: eph_pk(32) || nonce(12) || ciphertext
    let mut envelope = Vec::with_capacity(32 + 12 + sealed.len());
    envelope.extend_from_slice(ephemeral_pk.as_bytes());
    envelope.extend_from_slice(&nonce);
    envelope.extend_from_slice(&sealed);
    envelope
}
/// Inverse of `classical_encrypt`: parse the envelope, re-derive the key, open.
/// Panics on malformed input — acceptable for benchmark-only data.
fn classical_decrypt(keypair: &ClassicalKeypair, envelope: &[u8]) -> Vec<u8> {
    // Envelope layout: eph_pk(32) || nonce(12) || ciphertext.
    let ephemeral_pk = X25519Public::from(<[u8; 32]>::try_from(&envelope[..32]).unwrap());
    let nonce: [u8; 12] = envelope[32..44].try_into().unwrap();
    let ciphertext = &envelope[44..];

    let shared = keypair.secret.diffie_hellman(&ephemeral_pk);
    let mut key = [0u8; 32];
    Hkdf::<Sha256>::new(None, shared.as_bytes())
        .expand(b"classical-bench", &mut key)
        .unwrap();

    ChaCha20Poly1305::new(Key::from_slice(&key))
        .decrypt(Nonce::from_slice(&nonce), ciphertext)
        .unwrap()
}
// ── Benchmarks ──────────────────────────────────────────────────────────────
/// Compares keypair generation: hybrid (X25519 + ML-KEM-768) vs classical X25519.
fn bench_keygen(c: &mut Criterion) {
    let mut group = c.benchmark_group("kem_keygen");
    group.bench_function("hybrid", |b| b.iter(|| black_box(HybridKeypair::generate())));
    group.bench_function("classical", |b| b.iter(|| black_box(ClassicalKeypair::generate())));
    group.finish();
}
/// Compares encryption throughput (hybrid vs classical) across payload sizes.
fn bench_encrypt(c: &mut Criterion) {
    const SIZES: [(&str, usize); 4] =
        [("100B", 100), ("1KB", 1024), ("4KB", 4096), ("64KB", 65536)];
    let mut group = c.benchmark_group("kem_encrypt");
    let hybrid_kp = HybridKeypair::generate();
    let hybrid_pk = hybrid_kp.public_key();
    let classical_kp = ClassicalKeypair::generate();
    for (label, size) in SIZES {
        let data = vec![0xABu8; size];
        group.bench_with_input(BenchmarkId::new("hybrid", label), &data, |b, data| {
            b.iter(|| hybrid_encrypt(&hybrid_pk, black_box(data), b"", b"").unwrap());
        });
        group.bench_with_input(BenchmarkId::new("classical", label), &data, |b, data| {
            b.iter(|| classical_encrypt(&classical_kp.public, black_box(data)));
        });
    }
    group.finish();
}
/// Compares decryption throughput (hybrid vs classical) across payload sizes.
fn bench_decrypt(c: &mut Criterion) {
    const SIZES: [(&str, usize); 4] =
        [("100B", 100), ("1KB", 1024), ("4KB", 4096), ("64KB", 65536)];
    let mut group = c.benchmark_group("kem_decrypt");
    let hybrid_kp = HybridKeypair::generate();
    let hybrid_pk = hybrid_kp.public_key();
    let classical_kp = ClassicalKeypair::generate();
    for (label, size) in SIZES {
        // Encrypt once per size in setup; only decryption is timed.
        let data = vec![0xABu8; size];
        let hybrid_ct = hybrid_encrypt(&hybrid_pk, &data, b"", b"").unwrap();
        let classical_ct = classical_encrypt(&classical_kp.public, &data);
        group.bench_with_input(BenchmarkId::new("hybrid", label), &hybrid_ct, |b, ct| {
            b.iter(|| hybrid_decrypt(&hybrid_kp, black_box(ct), b"", b"").unwrap());
        });
        group.bench_with_input(BenchmarkId::new("classical", label), &classical_ct, |b, ct| {
            b.iter(|| classical_decrypt(&classical_kp, black_box(ct)));
        });
    }
    group.finish();
}
// Register all KEM benchmarks; Criterion supplies `main` (harness = false).
criterion_group!(benches, bench_keygen, bench_encrypt, bench_decrypt);
criterion_main!(benches);

View File

@@ -0,0 +1,157 @@
#![allow(clippy::unwrap_used)]
//! Benchmark: MLS group operations at various group sizes.
//!
//! Measures KeyPackage generation, group creation, member addition,
//! message encryption, and message decryption.
use std::sync::Arc;
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use quicprochat_core::{GroupMember, IdentityKeypair};
/// Build an MLS group with `size` total participants.
/// Returns the creator plus the `size - 1` members who joined via Welcome.
fn setup_group(size: usize) -> (GroupMember, Vec<GroupMember>) {
    let mut creator = GroupMember::new(Arc::new(IdentityKeypair::generate()));
    creator.create_group(b"bench-group").unwrap();
    let mut joiners = Vec::with_capacity(size.saturating_sub(1));
    for _ in 1..size {
        // Each joiner publishes a KeyPackage; the creator commits the Add and
        // hands back a Welcome that the joiner uses to enter the group.
        let mut joiner = GroupMember::new(Arc::new(IdentityKeypair::generate()));
        let key_package = joiner.generate_key_package().unwrap();
        let (_commit, welcome) = creator.add_member(&key_package).unwrap();
        joiner.join_group(&welcome).unwrap();
        joiners.push(joiner);
    }
    (creator, joiners)
}
/// Measures KeyPackage generation for a freshly created member.
fn bench_keygen(c: &mut Criterion) {
    c.bench_function("mls_keygen", |b| {
        // Member construction happens in the untimed setup closure.
        let fresh_member = || GroupMember::new(Arc::new(IdentityKeypair::generate()));
        b.iter_batched(
            fresh_member,
            |mut member| {
                member.generate_key_package().unwrap();
            },
            BatchSize::SmallInput,
        );
    });
}
/// Measures creating a brand-new single-member group.
fn bench_group_create(c: &mut Criterion) {
    c.bench_function("mls_group_create", |b| {
        // Member construction happens in the untimed setup closure.
        let fresh_member = || GroupMember::new(Arc::new(IdentityKeypair::generate()));
        b.iter_batched(
            fresh_member,
            |mut member| {
                member.create_group(b"bench-group").unwrap();
            },
            BatchSize::SmallInput,
        );
    });
}
/// Measures adding one member to groups of increasing size.
fn bench_add_member(c: &mut Criterion) {
    let mut group = c.benchmark_group("mls_add_member");
    group.sample_size(10); // group setup is expensive; keep the sample count low
    for size in [2, 10, 50, 100] {
        group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| {
            // Untimed setup: an existing group plus one prospective joiner
            // whose KeyPackage is ready to be committed.
            let fresh_state = || {
                let (creator, members) = setup_group(size);
                let mut joiner = GroupMember::new(Arc::new(IdentityKeypair::generate()));
                let key_package = joiner.generate_key_package().unwrap();
                (creator, members, joiner, key_package)
            };
            b.iter_batched(
                fresh_state,
                |(mut creator, _members, _joiner, key_package)| {
                    creator.add_member(&key_package).unwrap();
                },
                BatchSize::SmallInput,
            );
        });
    }
    group.finish();
}
/// Measures committing a pending self-update (one full epoch rotation) at
/// several group sizes. The proposal is created in the untimed setup closure;
/// only `commit_pending_proposals` is measured.
fn bench_epoch_rotation(c: &mut Criterion) {
    let mut group = c.benchmark_group("mls_epoch_rotation");
    group.sample_size(10);
    for size in [2, 10, 50] {
        group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| {
            b.iter_batched(
                || {
                    let (mut creator, members) = setup_group(size);
                    // Propose a self-update to simulate epoch rotation
                    let proposal = creator.propose_self_update().unwrap();
                    (creator, members, proposal)
                },
                |(mut creator, _members, _proposal)| {
                    // Commit pending proposals (the self-update) to advance the epoch
                    creator.commit_pending_proposals().unwrap();
                },
                BatchSize::SmallInput,
            );
        });
    }
    group.finish();
}
/// Measures application-message encryption by the group creator.
fn bench_send_message(c: &mut Criterion) {
    const MSG: &[u8] = b"hello benchmark message";
    let mut group = c.benchmark_group("mls_send_message");
    for size in [2, 10, 50] {
        group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| {
            let (mut creator, _members) = setup_group(size);
            b.iter(|| {
                creator.send_message(MSG).unwrap();
            });
        });
    }
    group.finish();
}
/// Measures decryption of an incoming application message on a joiner.
///
/// Uses `iter_batched` because MLS processing is stateful: each ciphertext
/// can only be consumed once, so a fresh one is produced per iteration in
/// the untimed setup closure.
fn bench_receive_message(c: &mut Criterion) {
    let mut group = c.benchmark_group("mls_receive_message");
    for size in [2, 10, 50] {
        group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| {
            // For receive, we need a fresh ciphertext each iteration since
            // MLS message processing is destructive (epoch state changes).
            // We pre-generate a batch and consume them.
            let (mut creator, mut members) = setup_group(size);
            if members.is_empty() {
                return;
            }
            let payload = b"hello benchmark message";
            b.iter_batched(
                || creator.send_message(payload).unwrap(),
                |ct| {
                    // Receive on the first joiner
                    let _ = members[0].receive_message(&ct);
                },
                BatchSize::SmallInput,
            );
        });
    }
    group.finish();
}
// Register all MLS benchmarks; Criterion supplies `main` (harness = false).
criterion_group!(
    benches,
    bench_keygen,
    bench_group_create,
    bench_add_member,
    bench_epoch_rotation,
    bench_send_message,
    bench_receive_message,
);
criterion_main!(benches);

View File

@@ -0,0 +1,171 @@
#![allow(clippy::unwrap_used)]
//! Benchmark: Cap'n Proto vs Protobuf serialization for chat message envelopes.
//!
//! Compares serialization/deserialization speed and encoded size at three
//! payload sizes (100 B, 1 KB, 4 KB) for a typical Envelope{seq, data} message.
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
// ── Cap'n Proto path ────────────────────────────────────────────────────────
/// Encode an `Envelope { seq, data }` via Cap'n Proto and return the wire bytes.
fn capnp_serialize_envelope(seq: u64, data: &[u8]) -> Vec<u8> {
    let mut message = capnp::message::Builder::new_default();
    {
        // Scope the root builder borrow so `message` can be serialized below.
        let mut root = message.init_root::<quicprochat_proto::node_capnp::envelope::Builder>();
        root.set_seq(seq);
        root.set_data(data);
    }
    quicprochat_proto::to_bytes(&message).unwrap()
}
/// Decode Cap'n Proto wire bytes back into `(seq, data)`.
fn capnp_deserialize_envelope(bytes: &[u8]) -> (u64, Vec<u8>) {
    let message = quicprochat_proto::from_bytes(bytes).unwrap();
    let root = message
        .get_root::<quicprochat_proto::node_capnp::envelope::Reader>()
        .unwrap();
    (root.get_seq(), root.get_data().unwrap().to_vec())
}
// ── Protobuf path (hand-coded prost encoding to avoid build-dep) ────────────
//
// Envelope { seq: uint64 (field 1), data: bytes (field 2) }
// Wire format: varint tag + varint seq + len-delimited data
/// Encode `Envelope { seq, data }` in protobuf wire format using prost helpers.
///
/// Field 1: uint64 seq, wire type 0 (varint), tag = (1 << 3) | 0 = 0x08
/// Field 2: bytes data, wire type 2 (length-delimited), tag = (2 << 3) | 2 = 0x12
fn protobuf_serialize_envelope(seq: u64, data: &[u8]) -> Vec<u8> {
    let mut buf = Vec::with_capacity(10 + data.len());
    // Encode field 1 (seq)
    prost::encoding::uint64::encode(1, &seq, &mut buf);
    // Encode field 2 (data): write key + length + payload directly.
    // The previous `bytes::encode(2, &data.to_vec(), ..)` copied the whole
    // payload into a temporary Vec on every call, inflating the measured
    // protobuf serialization time in this codec comparison. The output
    // bytes are identical either way.
    prost::encoding::encode_key(2, prost::encoding::WireType::LengthDelimited, &mut buf);
    prost::encoding::encode_varint(data.len() as u64, &mut buf);
    buf.extend_from_slice(data);
    buf
}
/// Decode protobuf wire bytes back into `(seq, data)` without a derived
/// message type, walking key/value pairs until the buffer is exhausted.
fn protobuf_deserialize_envelope(bytes: &[u8]) -> (u64, Vec<u8>) {
    // Decode manually using prost wire format
    let mut seq: u64 = 0;
    let mut data: Vec<u8> = Vec::new();
    let mut buf = bytes;
    while !buf.is_empty() {
        let (tag, wire_type) =
            prost::encoding::decode_key(&mut buf).expect("decode key");
        match tag {
            1 => {
                prost::encoding::uint64::merge(wire_type, &mut seq, &mut buf, Default::default())
                    .expect("decode seq");
            }
            2 => {
                prost::encoding::bytes::merge(wire_type, &mut data, &mut buf, Default::default())
                    .expect("decode data");
            }
            _ => {
                // Tolerate unknown fields, as a generated decoder would.
                prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())
                    .expect("skip unknown field");
            }
        }
    }
    (seq, data)
}
// ── Benchmarks ──────────────────────────────────────────────────────────────
/// Times Cap'n Proto vs protobuf encoding of `Envelope { seq, data }`.
fn bench_serialize(c: &mut Criterion) {
    const SIZES: [(&str, usize); 3] = [("100B", 100), ("1KB", 1024), ("4KB", 4096)];
    let mut group = c.benchmark_group("serialize_envelope");
    for (label, size) in SIZES {
        let data = vec![0xABu8; size];
        let seq = 42u64;
        group.bench_with_input(
            BenchmarkId::new("capnp", label),
            &(&seq, &data),
            |b, &(seq, data)| {
                b.iter(|| capnp_serialize_envelope(black_box(*seq), black_box(data)));
            },
        );
        group.bench_with_input(
            BenchmarkId::new("protobuf", label),
            &(&seq, &data),
            |b, &(seq, data)| {
                b.iter(|| protobuf_serialize_envelope(black_box(*seq), black_box(data)));
            },
        );
    }
    group.finish();
}
/// Times Cap'n Proto vs protobuf decoding of pre-encoded envelopes.
fn bench_deserialize(c: &mut Criterion) {
    const SIZES: [(&str, usize); 3] = [("100B", 100), ("1KB", 1024), ("4KB", 4096)];
    let mut group = c.benchmark_group("deserialize_envelope");
    for (label, size) in SIZES {
        // Encode once per size in setup; only decoding is timed.
        let data = vec![0xABu8; size];
        let capnp_bytes = capnp_serialize_envelope(42, &data);
        let proto_bytes = protobuf_serialize_envelope(42, &data);
        group.bench_with_input(BenchmarkId::new("capnp", label), &capnp_bytes, |b, bytes| {
            b.iter(|| capnp_deserialize_envelope(black_box(bytes)));
        });
        group.bench_with_input(BenchmarkId::new("protobuf", label), &proto_bytes, |b, bytes| {
            b.iter(|| protobuf_deserialize_envelope(black_box(bytes)));
        });
    }
    group.finish();
}
/// Reports encoded envelope sizes for both codecs.
///
/// The timed closures are deliberately trivial (`len()` only); the useful
/// output is the `eprintln!` size comparison emitted once per size label
/// during setup.
fn bench_encoded_sizes(c: &mut Criterion) {
    let sizes: &[(&str, usize)] = &[("100B", 100), ("1KB", 1024), ("4KB", 4096)];
    let mut group = c.benchmark_group("encoded_size");
    for (label, size) in sizes {
        let payload = vec![0xABu8; *size];
        let capnp_bytes = capnp_serialize_envelope(42, &payload);
        let proto_bytes = protobuf_serialize_envelope(42, &payload);
        // Use a trivial benchmark that just returns the size -- the point
        // is to get criterion to print the iteration count and allow
        // comparison. The real value is in the eprintln below.
        group.bench_with_input(
            BenchmarkId::new("capnp", label),
            &capnp_bytes,
            |b, bytes| {
                b.iter(|| black_box(bytes.len()));
            },
        );
        group.bench_with_input(
            BenchmarkId::new("protobuf", label),
            &proto_bytes,
            |b, bytes| {
                b.iter(|| black_box(bytes.len()));
            },
        );
        eprintln!(
            " {label}: capnp={} bytes, protobuf={} bytes, overhead={:+} bytes",
            capnp_bytes.len(),
            proto_bytes.len(),
            capnp_bytes.len() as isize - proto_bytes.len() as isize,
        );
    }
    group.finish();
}
// Register serialization benchmarks; Criterion supplies `main` (harness = false).
criterion_group!(benches, bench_serialize, bench_deserialize, bench_encoded_sizes);
criterion_main!(benches);

View File

@@ -0,0 +1,21 @@
// Benchmark-only protobuf schema mirroring the Cap'n Proto delivery types,
// used to compare encoded sizes and codec speed in the serialization bench.
syntax = "proto3";
package quicprochat.bench;
// Equivalent to the Envelope struct in delivery.capnp
message Envelope {
  uint64 seq = 1;
  bytes data = 2;
}
// Equivalent to a chat message payload (app_message.rs Chat variant)
message ChatMessage {
  bytes message_id = 1; // 16 bytes
  string body = 2; // UTF-8 text
  uint64 timestamp_ms = 3;
  bytes sender_key = 4; // 32 bytes Ed25519 public key
}
// Batch fetch response (equivalent to fetch returning List(Envelope))
message FetchResponse {
  repeated Envelope payloads = 1;
}

View File

@@ -0,0 +1,524 @@
//! Rich application-layer message format for MLS application payloads.
//!
//! The server sees only opaque ciphertext; structure lives in this client-defined
//! plaintext schema. All messages use: version byte (1) + message_type byte + type-specific payload.
//!
//! # Message ID
//!
//! `message_id` is assigned by the sender (16 random bytes) and included in the
//! serialized payload for Chat (and implied for Reply/Reaction/ReadReceipt via ref_msg_id).
//! Recipients can store message_ids to reference them in replies or reactions.
use crate::error::CoreError;
use rand::RngCore;
/// Current schema version.
pub const VERSION: u8 = 1;
/// Message type discriminant (one byte).
///
/// The `u8` value of each variant is the second byte of every serialized
/// application message (immediately after the version byte), so these
/// values are wire format and must never change.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageType {
    Chat = 0x01,
    Reply = 0x02,
    Reaction = 0x03,
    ReadReceipt = 0x04,
    Typing = 0x05,
    Edit = 0x06,
    Delete = 0x07,
    FileRef = 0x08,
    Dummy = 0x09,
}
impl MessageType {
fn from_byte(b: u8) -> Option<Self> {
match b {
0x01 => Some(MessageType::Chat),
0x02 => Some(MessageType::Reply),
0x03 => Some(MessageType::Reaction),
0x04 => Some(MessageType::ReadReceipt),
0x05 => Some(MessageType::Typing),
0x06 => Some(MessageType::Edit),
0x07 => Some(MessageType::Delete),
0x08 => Some(MessageType::FileRef),
0x09 => Some(MessageType::Dummy),
_ => None,
}
}
}
/// Parsed application message (one of the rich types).
///
/// Field widths mirror the wire layout documented below: message ids are
/// 16 random bytes, blob ids 32 bytes, and text bodies raw byte vectors
/// (expected to be UTF-8 but not validated at this layer).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AppMessage {
    /// Plain chat: body (UTF-8). message_id is included so recipients can store and reference it.
    Chat {
        message_id: [u8; 16],
        body: Vec<u8>,
    },
    /// Threaded reply to the message identified by `ref_msg_id`.
    Reply {
        ref_msg_id: [u8; 16],
        body: Vec<u8>,
    },
    /// Emoji reaction to the message identified by `ref_msg_id`.
    Reaction {
        ref_msg_id: [u8; 16],
        emoji: Vec<u8>,
    },
    /// Acknowledges that the message with `msg_id` was read.
    ReadReceipt {
        msg_id: [u8; 16],
    },
    /// Typing indicator.
    Typing {
        /// 0 = stopped, 1 = typing
        active: u8,
    },
    /// Edit a previously sent message (identified by ref_msg_id).
    Edit {
        ref_msg_id: [u8; 16],
        body: Vec<u8>,
    },
    /// Delete a previously sent message (identified by ref_msg_id).
    Delete {
        ref_msg_id: [u8; 16],
    },
    /// File reference: metadata pointing to a blob stored on the server.
    FileRef {
        blob_id: [u8; 32],
        filename: Vec<u8>,
        file_size: u64,
        mime_type: Vec<u8>,
    },
    /// Dummy message for traffic analysis resistance (no user-visible content).
    Dummy,
}
/// Create a fresh random 16-byte message id (e.g. for Chat/Reply) so
/// recipients can reference it in replies, reactions, or receipts.
pub fn generate_message_id() -> [u8; 16] {
    let mut buf = [0u8; 16];
    rand::rngs::OsRng.fill_bytes(&mut buf);
    buf
}
// ── Layout (minimal, no Cap'n Proto) ─────────────────────────────────────────
//
// All messages: [version: 1][type: 1][payload...]
//
// Chat: [msg_id: 16][body_len: 2 BE][body]
// Reply: [ref_msg_id: 16][body_len: 2 BE][body]
// Reaction: [ref_msg_id: 16][emoji_len: 1][emoji]
// ReadReceipt: [msg_id: 16]
// Typing: [active: 1] 0 = stopped, 1 = typing
// Edit: [ref_msg_id: 16][body_len: 2 BE][body]
// Delete: [ref_msg_id: 16]
// FileRef: [blob_id: 32][filename_len: 2 BE][filename][file_size: 8 BE][mime_len: 2 BE][mime_type]
/// Frame a type-specific payload with the common `[version][type]` header.
pub fn serialize(msg_type: MessageType, payload: &[u8]) -> Vec<u8> {
    let mut buf = Vec::with_capacity(2 + payload.len());
    buf.extend_from_slice(&[VERSION, msg_type as u8]);
    buf.extend_from_slice(payload);
    buf
}
/// Serialize a Chat message. Pass `None` to generate a fresh message id, or
/// `Some(id)` to serialize with a caller-chosen id.
pub fn serialize_chat(body: &[u8], message_id: Option<[u8; 16]>) -> Result<Vec<u8>, CoreError> {
    // Body length travels as a 2-byte big-endian field, capping it at u16::MAX.
    if body.len() > u16::MAX as usize {
        return Err(CoreError::AppMessage("chat body exceeds maximum length (65535 bytes)".into()));
    }
    let id = message_id.unwrap_or_else(generate_message_id);
    // Layout: [msg_id: 16][body_len: 2 BE][body]
    let mut buf = Vec::with_capacity(16 + 2 + body.len());
    buf.extend_from_slice(&id);
    buf.extend_from_slice(&(body.len() as u16).to_be_bytes());
    buf.extend_from_slice(body);
    Ok(serialize(MessageType::Chat, &buf))
}
/// Serialize a Reply referencing an earlier message by its 16-byte id.
pub fn serialize_reply(ref_msg_id: [u8; 16], body: &[u8]) -> Result<Vec<u8>, CoreError> {
    // Body length travels as a 2-byte big-endian field, capping it at u16::MAX.
    if body.len() > u16::MAX as usize {
        return Err(CoreError::AppMessage("reply body exceeds maximum length (65535 bytes)".into()));
    }
    // Layout: [ref_msg_id: 16][body_len: 2 BE][body]
    let mut buf = Vec::with_capacity(16 + 2 + body.len());
    buf.extend_from_slice(&ref_msg_id);
    buf.extend_from_slice(&(body.len() as u16).to_be_bytes());
    buf.extend_from_slice(body);
    Ok(serialize(MessageType::Reply, &buf))
}
/// Serialize a Reaction (emoji bytes, at most 255) referencing an earlier message.
pub fn serialize_reaction(ref_msg_id: [u8; 16], emoji: &[u8]) -> Result<Vec<u8>, CoreError> {
    // The emoji length travels as a single byte, so reject anything over 255.
    let emoji_len: u8 = emoji
        .len()
        .try_into()
        .map_err(|_| CoreError::AppMessage("emoji length > 255".into()))?;
    // Layout: [ref_msg_id: 16][emoji_len: 1][emoji]
    let mut buf = Vec::with_capacity(16 + 1 + emoji.len());
    buf.extend_from_slice(&ref_msg_id);
    buf.push(emoji_len);
    buf.extend_from_slice(emoji);
    Ok(serialize(MessageType::Reaction, &buf))
}
/// Serialize a ReadReceipt acknowledging the message with the given id.
pub fn serialize_read_receipt(msg_id: [u8; 16]) -> Vec<u8> {
    serialize(MessageType::ReadReceipt, msg_id.as_slice())
}
/// Serialize a Typing indicator (active: 0 = stopped, 1 = typing).
pub fn serialize_typing(active: u8) -> Vec<u8> {
    serialize(MessageType::Typing, &[active])
}
/// Serialize an Edit that replaces the body of a previously sent message.
pub fn serialize_edit(ref_msg_id: &[u8; 16], body: &[u8]) -> Result<Vec<u8>, CoreError> {
    // Body length travels as a 2-byte big-endian field, capping it at u16::MAX.
    if body.len() > u16::MAX as usize {
        return Err(CoreError::AppMessage("edit body exceeds maximum length (65535 bytes)".into()));
    }
    // Layout: [ref_msg_id: 16][body_len: 2 BE][body]
    let mut buf = Vec::with_capacity(16 + 2 + body.len());
    buf.extend_from_slice(ref_msg_id);
    buf.extend_from_slice(&(body.len() as u16).to_be_bytes());
    buf.extend_from_slice(body);
    Ok(serialize(MessageType::Edit, &buf))
}
/// Serialize a Delete marking a previously sent message as removed.
pub fn serialize_delete(ref_msg_id: &[u8; 16]) -> Vec<u8> {
    serialize(MessageType::Delete, ref_msg_id.as_slice())
}
/// Serialize a FileRef: metadata pointing at a blob stored on the server.
pub fn serialize_file_ref(
    blob_id: &[u8; 32],
    filename: &[u8],
    file_size: u64,
    mime_type: &[u8],
) -> Result<Vec<u8>, CoreError> {
    // Both variable-length fields travel as 2-byte big-endian lengths.
    let filename_len: u16 = filename
        .len()
        .try_into()
        .map_err(|_| CoreError::AppMessage("filename exceeds maximum length".into()))?;
    let mime_len: u16 = mime_type
        .len()
        .try_into()
        .map_err(|_| CoreError::AppMessage("mime_type exceeds maximum length".into()))?;
    // Layout: [blob_id: 32][filename_len: 2 BE][filename][file_size: 8 BE][mime_len: 2 BE][mime]
    let mut buf = Vec::with_capacity(32 + 2 + filename.len() + 8 + 2 + mime_type.len());
    buf.extend_from_slice(blob_id);
    buf.extend_from_slice(&filename_len.to_be_bytes());
    buf.extend_from_slice(filename);
    buf.extend_from_slice(&file_size.to_be_bytes());
    buf.extend_from_slice(&mime_len.to_be_bytes());
    buf.extend_from_slice(mime_type);
    Ok(serialize(MessageType::FileRef, &buf))
}
/// Serialize a Dummy message (traffic padding — carries no user content).
pub fn serialize_dummy() -> Vec<u8> {
    serialize(MessageType::Dummy, b"")
}
/// Parse bytes into `(MessageType, AppMessage)`.
///
/// Fails when the buffer is shorter than the two-byte header, the version
/// byte is not [`VERSION`], the type byte is unknown, or the type-specific
/// payload is malformed.
pub fn parse(bytes: &[u8]) -> Result<(MessageType, AppMessage), CoreError> {
    // Split off the [version][type] header; everything after it is payload.
    let [version, type_byte, payload @ ..] = bytes else {
        return Err(CoreError::AppMessage("payload too short (need version + type)".into()));
    };
    if *version != VERSION {
        return Err(CoreError::AppMessage(format!("unsupported version {version}")));
    }
    let msg_type = MessageType::from_byte(*type_byte)
        .ok_or_else(|| CoreError::AppMessage(format!("unknown message type {type_byte}")))?;
    let app = match msg_type {
        MessageType::Chat => parse_chat(payload)?,
        MessageType::Reply => parse_reply(payload)?,
        MessageType::Reaction => parse_reaction(payload)?,
        MessageType::ReadReceipt => parse_read_receipt(payload)?,
        MessageType::Typing => parse_typing(payload)?,
        MessageType::Edit => parse_edit(payload)?,
        MessageType::Delete => parse_delete(payload)?,
        MessageType::FileRef => parse_file_ref(payload)?,
        MessageType::Dummy => AppMessage::Dummy,
    };
    Ok((msg_type, app))
}
/// Decode a Chat payload: `[msg_id: 16][body_len: 2 BE][body]`.
fn parse_chat(payload: &[u8]) -> Result<AppMessage, CoreError> {
    if payload.len() < 18 {
        return Err(CoreError::AppMessage("Chat payload too short".into()));
    }
    let mut message_id = [0u8; 16];
    message_id.copy_from_slice(&payload[..16]);
    let body_len = usize::from(u16::from_be_bytes([payload[16], payload[17]]));
    match payload.get(18..18 + body_len) {
        Some(body) => Ok(AppMessage::Chat { message_id, body: body.to_vec() }),
        None => Err(CoreError::AppMessage("Chat body length exceeds payload".into())),
    }
}
/// Decode a Reply payload: `[ref_msg_id: 16][body_len: 2 BE][body]`.
fn parse_reply(payload: &[u8]) -> Result<AppMessage, CoreError> {
    if payload.len() < 18 {
        return Err(CoreError::AppMessage("Reply payload too short".into()));
    }
    let mut ref_msg_id = [0u8; 16];
    ref_msg_id.copy_from_slice(&payload[..16]);
    let body_len = usize::from(u16::from_be_bytes([payload[16], payload[17]]));
    match payload.get(18..18 + body_len) {
        Some(body) => Ok(AppMessage::Reply { ref_msg_id, body: body.to_vec() }),
        None => Err(CoreError::AppMessage("Reply body length exceeds payload".into())),
    }
}
/// Decode a Reaction payload: `[ref_msg_id: 16][emoji_len: 1][emoji]`.
fn parse_reaction(payload: &[u8]) -> Result<AppMessage, CoreError> {
    if payload.len() < 17 {
        return Err(CoreError::AppMessage("Reaction payload too short".into()));
    }
    let mut ref_msg_id = [0u8; 16];
    ref_msg_id.copy_from_slice(&payload[..16]);
    let emoji_len = usize::from(payload[16]);
    match payload.get(17..17 + emoji_len) {
        Some(emoji) => Ok(AppMessage::Reaction { ref_msg_id, emoji: emoji.to_vec() }),
        None => Err(CoreError::AppMessage("Reaction emoji length exceeds payload".into())),
    }
}
/// Decode a ReadReceipt payload: `[msg_id: 16]` (trailing bytes ignored).
fn parse_read_receipt(payload: &[u8]) -> Result<AppMessage, CoreError> {
    match payload.get(..16) {
        Some(id) => {
            let mut msg_id = [0u8; 16];
            msg_id.copy_from_slice(id);
            Ok(AppMessage::ReadReceipt { msg_id })
        }
        None => Err(CoreError::AppMessage("ReadReceipt payload too short".into())),
    }
}
/// Decode a Typing payload: `[active: 1]` (0 = stopped, 1 = typing).
fn parse_typing(payload: &[u8]) -> Result<AppMessage, CoreError> {
    match payload.first() {
        Some(&active) => Ok(AppMessage::Typing { active }),
        None => Err(CoreError::AppMessage("Typing payload empty".into())),
    }
}
/// Decode an Edit payload: `[ref_msg_id: 16][body_len: 2 BE][body]`.
fn parse_edit(payload: &[u8]) -> Result<AppMessage, CoreError> {
    if payload.len() < 18 {
        return Err(CoreError::AppMessage("Edit payload too short".into()));
    }
    let mut ref_msg_id = [0u8; 16];
    ref_msg_id.copy_from_slice(&payload[..16]);
    let body_len = usize::from(u16::from_be_bytes([payload[16], payload[17]]));
    match payload.get(18..18 + body_len) {
        Some(body) => Ok(AppMessage::Edit { ref_msg_id, body: body.to_vec() }),
        None => Err(CoreError::AppMessage("Edit body length exceeds payload".into())),
    }
}
/// Decode a Delete payload: `[ref_msg_id: 16]` (trailing bytes ignored).
fn parse_delete(payload: &[u8]) -> Result<AppMessage, CoreError> {
    match payload.get(..16) {
        Some(id) => {
            let mut ref_msg_id = [0u8; 16];
            ref_msg_id.copy_from_slice(id);
            Ok(AppMessage::Delete { ref_msg_id })
        }
        None => Err(CoreError::AppMessage("Delete payload too short".into())),
    }
}
/// Decode a FileRef payload:
/// `[blob_id: 32][filename_len: 2 BE][filename][file_size: 8 BE][mime_len: 2 BE][mime_type]`.
fn parse_file_ref(payload: &[u8]) -> Result<AppMessage, CoreError> {
    // blob_id(32) + filename_len(2) minimum
    if payload.len() < 34 {
        return Err(CoreError::AppMessage("FileRef payload too short".into()));
    }
    let mut blob_id = [0u8; 32];
    blob_id.copy_from_slice(&payload[..32]);
    let filename_len = u16::from_be_bytes([payload[32], payload[33]]) as usize;
    let pos = 34;
    // One combined bound check covers filename + file_size(8) + mime_len(2).
    if payload.len() < pos + filename_len + 8 + 2 {
        return Err(CoreError::AppMessage("FileRef payload truncated after filename_len".into()));
    }
    let filename = payload[pos..pos + filename_len].to_vec();
    let pos = pos + filename_len;
    let file_size = u64::from_be_bytes([
        payload[pos], payload[pos + 1], payload[pos + 2], payload[pos + 3],
        payload[pos + 4], payload[pos + 5], payload[pos + 6], payload[pos + 7],
    ]);
    let pos = pos + 8;
    let mime_len = u16::from_be_bytes([payload[pos], payload[pos + 1]]) as usize;
    let pos = pos + 2;
    if payload.len() < pos + mime_len {
        return Err(CoreError::AppMessage("FileRef payload truncated after mime_len".into()));
    }
    let mime_type = payload[pos..pos + mime_len].to_vec();
    Ok(AppMessage::FileRef { blob_id, filename, file_size, mime_type })
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    // Positive round-trips: serialize each message kind, parse it back, and
    // assert both the wire type tag and the decoded fields.
    #[test]
    fn roundtrip_chat() {
        let body = b"hello";
        let encoded = serialize_chat(body, None).unwrap();
        let (t, msg) = parse(&encoded).unwrap();
        assert_eq!(t, MessageType::Chat);
        match &msg {
            AppMessage::Chat { message_id: _, body: b } => assert_eq!(b.as_slice(), body),
            _ => panic!("expected Chat"),
        }
    }
    #[test]
    fn roundtrip_reply() {
        let ref_id = [1u8; 16];
        let body = b"reply text";
        let encoded = serialize_reply(ref_id, body).unwrap();
        let (t, msg) = parse(&encoded).unwrap();
        assert_eq!(t, MessageType::Reply);
        match &msg {
            AppMessage::Reply { ref_msg_id, body: b } => {
                assert_eq!(ref_msg_id, &ref_id);
                assert_eq!(b.as_slice(), body);
            }
            _ => panic!("expected Reply"),
        }
    }
    #[test]
    fn roundtrip_typing() {
        let encoded = serialize_typing(1);
        let (t, msg) = parse(&encoded).unwrap();
        assert_eq!(t, MessageType::Typing);
        match &msg {
            AppMessage::Typing { active } => assert_eq!(*active, 1),
            _ => panic!("expected Typing"),
        }
    }
    #[test]
    fn roundtrip_reaction() {
        // Multi-byte UTF-8 emoji exercises the length-prefixed emoji field.
        let ref_id = [2u8; 16];
        let emoji = "\u{1f44d}".as_bytes();
        let encoded = serialize_reaction(ref_id, emoji).unwrap();
        let (t, msg) = parse(&encoded).unwrap();
        assert_eq!(t, MessageType::Reaction);
        match &msg {
            AppMessage::Reaction { ref_msg_id, emoji: e } => {
                assert_eq!(ref_msg_id, &ref_id);
                assert_eq!(e.as_slice(), emoji);
            }
            _ => panic!("expected Reaction"),
        }
    }
    #[test]
    fn roundtrip_read_receipt() {
        let msg_id = [3u8; 16];
        let encoded = serialize_read_receipt(msg_id);
        let (t, msg) = parse(&encoded).unwrap();
        assert_eq!(t, MessageType::ReadReceipt);
        match &msg {
            AppMessage::ReadReceipt { msg_id: id } => assert_eq!(id, &msg_id),
            _ => panic!("expected ReadReceipt"),
        }
    }
    #[test]
    fn roundtrip_edit() {
        let ref_id = [4u8; 16];
        let body = b"edited text";
        let encoded = serialize_edit(&ref_id, body).unwrap();
        let (t, msg) = parse(&encoded).unwrap();
        assert_eq!(t, MessageType::Edit);
        match &msg {
            AppMessage::Edit { ref_msg_id, body: b } => {
                assert_eq!(ref_msg_id, &ref_id);
                assert_eq!(b.as_slice(), body);
            }
            _ => panic!("expected Edit"),
        }
    }
    #[test]
    fn roundtrip_delete() {
        let ref_id = [5u8; 16];
        let encoded = serialize_delete(&ref_id);
        let (t, msg) = parse(&encoded).unwrap();
        assert_eq!(t, MessageType::Delete);
        match &msg {
            AppMessage::Delete { ref_msg_id } => assert_eq!(ref_msg_id, &ref_id),
            _ => panic!("expected Delete"),
        }
    }
    // Negative cases: serializers reject oversized fields, the parser rejects
    // malformed or truncated frames.
    #[test]
    fn edit_body_too_long() {
        let body = vec![0u8; 65536];
        assert!(serialize_edit(&[0; 16], &body).is_err());
    }
    #[test]
    fn parse_empty_fails() {
        assert!(parse(&[]).is_err());
    }
    #[test]
    fn parse_bad_version_fails() {
        assert!(parse(&[99, 0x01]).is_err());
    }
    #[test]
    fn parse_bad_type_fails() {
        assert!(parse(&[1, 0xFF]).is_err());
    }
    #[test]
    fn chat_body_too_long() {
        let body = vec![0u8; 65536]; // exceeds u16::MAX
        assert!(serialize_chat(&body, None).is_err());
    }
    #[test]
    fn reaction_emoji_too_long() {
        let emoji = vec![0u8; 256];
        assert!(serialize_reaction([0; 16], &emoji).is_err());
    }
    #[test]
    fn parse_truncated_chat_payload() {
        // Version + type + only 10 bytes of payload (needs 18 minimum for chat)
        let mut data = vec![1, 0x01];
        data.extend_from_slice(&[0u8; 10]);
        assert!(parse(&data).is_err());
    }
    #[test]
    fn roundtrip_file_ref() {
        let blob_id = [7u8; 32];
        let filename = b"report.pdf";
        let file_size = 123456u64;
        let mime_type = b"application/pdf";
        let encoded = serialize_file_ref(&blob_id, filename, file_size, mime_type).unwrap();
        let (t, msg) = parse(&encoded).unwrap();
        assert_eq!(t, MessageType::FileRef);
        match &msg {
            AppMessage::FileRef {
                blob_id: bid,
                filename: fname,
                file_size: fsize,
                mime_type: mtype,
            } => {
                assert_eq!(bid, &blob_id);
                assert_eq!(fname.as_slice(), filename);
                assert_eq!(*fsize, file_size);
                assert_eq!(mtype.as_slice(), mime_type);
            }
            _ => panic!("expected FileRef"),
        }
    }
    #[test]
    fn roundtrip_dummy() {
        // Dummy messages carry no payload; they exist for traffic padding-style use.
        // NOTE(review): purpose inferred from the name — confirm against serialize_dummy.
        let encoded = serialize_dummy();
        let (t, msg) = parse(&encoded).unwrap();
        assert_eq!(t, MessageType::Dummy);
        assert_eq!(msg, AppMessage::Dummy);
    }
}

View File

@@ -0,0 +1,38 @@
//! Error types for `quicprochat-core`.
use thiserror::Error;
/// Errors produced by core cryptographic and MLS operations.
///
/// This is the crate-wide error type returned by the message codec, MLS
/// wrappers, and hybrid-KEM integration.
#[derive(Debug, Error)]
pub enum CoreError {
    /// Cap'n Proto serialisation or deserialisation failed.
    /// Only available with the "native" feature (capnp is not WASM-safe here).
    #[cfg(feature = "native")]
    #[error("Cap'n Proto error: {0}")]
    Capnp(#[from] capnp::Error),
    /// An MLS operation failed (string description).
    ///
    /// Preserved for backward compatibility. Prefer [`CoreError::MlsError`]
    /// for new code that wraps typed openmls errors.
    #[error("MLS error: {0}")]
    Mls(String),
    /// An MLS operation failed (typed, boxed error).
    ///
    /// Wraps the underlying openmls error so callers can downcast to specific
    /// error types when needed.
    #[error("MLS error: {0}")]
    MlsError(Box<dyn std::error::Error + Send + Sync>),
    /// A hybrid KEM (X25519 + ML-KEM-768) operation failed.
    #[error("hybrid KEM error: {0}")]
    HybridKem(#[from] crate::hybrid_kem::HybridKemError),
    /// IO or persistence failure.
    /// NOTE(review): carried as a String rather than std::io::Error —
    /// presumably to keep the type usable in non-native builds; confirm.
    #[error("io error: {0}")]
    Io(String),
    /// Application message (rich payload) parse or serialisation error.
    #[error("app message: {0}")]
    AppMessage(String),
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,541 @@
//! Post-quantum hybrid crypto provider for OpenMLS (M7 PoC).
//!
//! Uses X25519 + ML-KEM-768 hybrid KEM for HPKE operations where openmls
//! would use DHKEM(X25519), and delegates all other operations (AEAD, hash,
//! signatures, KDF, randomness) to `openmls_rust_crypto::RustCrypto`.
//!
//! # Key format
//!
//! When the provider sees a **hybrid public key** (length `HYBRID_PUBLIC_KEY_LEN` =
//! 32 + 1184 bytes) or **hybrid private key** (length `HYBRID_PRIVATE_KEY_LEN` =
//! 32 + 2400 bytes), it uses `hybrid_kem` for HPKE. Otherwise it delegates to
//! RustCrypto (classical X25519 HPKE).
//!
//! # MLS compatibility
//!
//! The current MLS ciphersuite (MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519)
//! uses 32-byte X25519 init keys in the wire format. This provider can produce
//! and consume **hybrid** init keys (1216-byte public, 2432-byte private), but
//! that is a non-standard extension: other MLS implementations will not
//! accept KeyPackages with hybrid init keys unless they implement the same
//! extension. This PoC validates that the OpenMLS trait surface is satisfiable
//! with a custom HPKE backend; full interoperability would require a new
//! ciphersuite or protocol extension.
use openmls_rust_crypto::RustCrypto;
use openmls_traits::{
crypto::OpenMlsCrypto,
types::{
CryptoError, ExporterSecret, HpkeCiphertext, HpkeConfig, HpkeKeyPair, HpkeKemType,
},
OpenMlsCryptoProvider,
};
use tls_codec::SecretVLBytes;
use crate::hybrid_kem::{
hybrid_decapsulate_only, hybrid_decrypt, hybrid_encapsulate_only, hybrid_encrypt,
hybrid_export, HybridKeypair, HybridPublicKey,
HYBRID_KEM_OUTPUT_LEN, HYBRID_PRIVATE_KEY_LEN, HYBRID_PUBLIC_KEY_LEN,
};
use crate::keystore::DiskKeyStore;
// Re-export types used by OpenMlsCrypto (full path for clarity).
use openmls_traits::types::{
AeadType, Ciphersuite, HashType, SignatureScheme,
};
/// Crypto backend that uses hybrid KEM for HPKE when keys are in hybrid format,
/// and delegates everything else to RustCrypto.
///
/// When `hybrid_enabled` is `true`, `derive_hpke_keypair` produces hybrid keys
/// (1216-byte public, 2432-byte private). When `false`, it delegates to
/// RustCrypto and produces classical 32-byte X25519 keys.
///
/// The `hpke_seal` / `hpke_open` methods always detect the key format by length,
/// so they work correctly regardless of the flag — a hybrid-length key will use
/// hybrid KEM, a classical-length key will use RustCrypto.
#[derive(Debug)]
pub struct HybridCrypto {
    /// Classical backend handling all non-HPKE operations (AEAD, hash, KDF,
    /// signatures) and HPKE for classical-format keys.
    rust_crypto: RustCrypto,
    /// When true, `derive_hpke_keypair` produces hybrid (X25519 + ML-KEM-768)
    /// keys. When false, it produces classical X25519 keys via RustCrypto.
    hybrid_enabled: bool,
}
impl HybridCrypto {
/// Create a hybrid-enabled crypto backend (derive_hpke_keypair produces hybrid keys).
pub fn new() -> Self {
Self {
rust_crypto: RustCrypto::default(),
hybrid_enabled: true,
}
}
/// Alias for `new()` — hybrid mode enabled.
pub fn new_hybrid() -> Self {
Self::new()
}
/// Create a classical crypto backend (derive_hpke_keypair produces standard
/// X25519 keys, but seal/open still accept hybrid keys by length detection).
pub fn new_classical() -> Self {
Self {
rust_crypto: RustCrypto::default(),
hybrid_enabled: false,
}
}
/// Whether this backend produces hybrid keys from `derive_hpke_keypair`.
pub fn is_hybrid_enabled(&self) -> bool {
self.hybrid_enabled
}
/// Expose the underlying RustCrypto for rand() and delegation.
pub fn rust_crypto(&self) -> &RustCrypto {
&self.rust_crypto
}
fn is_hybrid_public_key(pk_r: &[u8]) -> bool {
pk_r.len() == HYBRID_PUBLIC_KEY_LEN
}
fn is_hybrid_private_key(sk_r: &[u8]) -> bool {
sk_r.len() == HYBRID_PRIVATE_KEY_LEN
}
}
impl Default for HybridCrypto {
    /// Defaults to hybrid mode — identical to [`HybridCrypto::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl OpenMlsCrypto for HybridCrypto {
    // ── Pure delegation: everything except HPKE goes straight to RustCrypto ──
    fn supports(&self, ciphersuite: Ciphersuite) -> Result<(), CryptoError> {
        self.rust_crypto.supports(ciphersuite)
    }
    fn supported_ciphersuites(&self) -> Vec<Ciphersuite> {
        self.rust_crypto.supported_ciphersuites()
    }
    fn hkdf_extract(
        &self,
        hash_type: HashType,
        salt: &[u8],
        ikm: &[u8],
    ) -> Result<SecretVLBytes, CryptoError> {
        self.rust_crypto.hkdf_extract(hash_type, salt, ikm)
    }
    fn hkdf_expand(
        &self,
        hash_type: HashType,
        prk: &[u8],
        info: &[u8],
        okm_len: usize,
    ) -> Result<SecretVLBytes, CryptoError> {
        self.rust_crypto.hkdf_expand(hash_type, prk, info, okm_len)
    }
    fn hash(&self, hash_type: HashType, data: &[u8]) -> Result<Vec<u8>, CryptoError> {
        self.rust_crypto.hash(hash_type, data)
    }
    fn aead_encrypt(
        &self,
        alg: AeadType,
        key: &[u8],
        data: &[u8],
        nonce: &[u8],
        aad: &[u8],
    ) -> Result<Vec<u8>, CryptoError> {
        self.rust_crypto.aead_encrypt(alg, key, data, nonce, aad)
    }
    fn aead_decrypt(
        &self,
        alg: AeadType,
        key: &[u8],
        ct_tag: &[u8],
        nonce: &[u8],
        aad: &[u8],
    ) -> Result<Vec<u8>, CryptoError> {
        self.rust_crypto.aead_decrypt(alg, key, ct_tag, nonce, aad)
    }
    fn signature_key_gen(&self, alg: SignatureScheme) -> Result<(Vec<u8>, Vec<u8>), CryptoError> {
        self.rust_crypto.signature_key_gen(alg)
    }
    fn verify_signature(
        &self,
        alg: SignatureScheme,
        data: &[u8],
        pk: &[u8],
        signature: &[u8],
    ) -> Result<(), CryptoError> {
        self.rust_crypto.verify_signature(alg, data, pk, signature)
    }
    fn sign(&self, alg: SignatureScheme, data: &[u8], key: &[u8]) -> Result<Vec<u8>, CryptoError> {
        self.rust_crypto.sign(alg, data, key)
    }
    // ── HPKE: dispatch on key length — hybrid-format keys take the hybrid
    //    KEM path, everything else falls through to RustCrypto ──
    fn hpke_seal(
        &self,
        config: HpkeConfig,
        pk_r: &[u8],
        info: &[u8],
        aad: &[u8],
        ptxt: &[u8],
    ) -> HpkeCiphertext {
        if Self::is_hybrid_public_key(pk_r) {
            // The trait `OpenMlsCrypto::hpke_seal` returns `HpkeCiphertext` (not
            // `Result`), so we cannot propagate errors through the return type.
            // Returning an empty ciphertext would silently cause data loss.
            // Instead, panic on failure — a hybrid key that passes the length
            // check but fails deserialization or encryption indicates a critical
            // bug (corrupted key material), not a recoverable condition.
            let recipient_pk = HybridPublicKey::from_bytes(pk_r)
                .expect("hybrid public key deserialization failed — key material is corrupted");
            // Pass HPKE info and aad through for proper context binding (RFC 9180).
            let envelope = hybrid_encrypt(&recipient_pk, ptxt, info, aad)
                .expect("hybrid HPKE encryption failed — critical crypto error");
            // Split the envelope into the HpkeCiphertext halves: KEM output
            // (version || eph_pk || mlkem_ct) vs AEAD part (nonce || ct).
            let kem_output = envelope[..HYBRID_KEM_OUTPUT_LEN].to_vec();
            let ciphertext = envelope[HYBRID_KEM_OUTPUT_LEN..].to_vec();
            HpkeCiphertext {
                kem_output: kem_output.into(),
                ciphertext: ciphertext.into(),
            }
        } else {
            self.rust_crypto.hpke_seal(config, pk_r, info, aad, ptxt)
        }
    }
    fn hpke_open(
        &self,
        config: HpkeConfig,
        input: &HpkeCiphertext,
        sk_r: &[u8],
        info: &[u8],
        aad: &[u8],
    ) -> Result<Vec<u8>, CryptoError> {
        if Self::is_hybrid_private_key(sk_r) {
            let keypair = HybridKeypair::from_private_bytes(sk_r)
                .map_err(|_| CryptoError::HpkeDecryptionError)?;
            // Reassemble the full envelope (kem_output || ciphertext) that
            // hybrid_decrypt expects — the inverse of the split in hpke_seal.
            let envelope: Vec<u8> = input
                .kem_output.as_slice()
                .iter()
                .chain(input.ciphertext.as_slice())
                .copied()
                .collect();
            // Pass HPKE info and aad through for proper context binding (RFC 9180).
            hybrid_decrypt(&keypair, &envelope, info, aad)
                .map_err(|_| CryptoError::HpkeDecryptionError)
        } else {
            self.rust_crypto.hpke_open(config, input, sk_r, info, aad)
        }
    }
    fn hpke_setup_sender_and_export(
        &self,
        config: HpkeConfig,
        pk_r: &[u8],
        info: &[u8],
        exporter_context: &[u8],
        exporter_length: usize,
    ) -> Result<(Vec<u8>, ExporterSecret), CryptoError> {
        if Self::is_hybrid_public_key(pk_r) {
            // A key that passes the hybrid length check but fails deserialization
            // is corrupted — return an error instead of silently downgrading to
            // classical crypto (which would defeat PQ protection).
            let recipient_pk = HybridPublicKey::from_bytes(pk_r)
                .map_err(|_| CryptoError::SenderSetupError)?;
            let (kem_output, shared_secret) =
                hybrid_encapsulate_only(&recipient_pk).map_err(|_| CryptoError::SenderSetupError)?;
            // NOTE(review): `info` is not used on this hybrid path (only
            // exporter_context); the receiver side mirrors this — confirm intended.
            let exported = hybrid_export(&shared_secret, exporter_context, exporter_length);
            Ok((kem_output, exported.into()))
        } else {
            self.rust_crypto.hpke_setup_sender_and_export(
                config, pk_r, info, exporter_context, exporter_length,
            )
        }
    }
    fn hpke_setup_receiver_and_export(
        &self,
        config: HpkeConfig,
        enc: &[u8],
        sk_r: &[u8],
        info: &[u8],
        exporter_context: &[u8],
        exporter_length: usize,
    ) -> Result<ExporterSecret, CryptoError> {
        if Self::is_hybrid_private_key(sk_r) {
            let keypair = HybridKeypair::from_private_bytes(sk_r)
                .map_err(|_| CryptoError::ReceiverSetupError)?;
            let shared_secret =
                hybrid_decapsulate_only(&keypair, enc).map_err(|_| CryptoError::ReceiverSetupError)?;
            let exported = hybrid_export(&shared_secret, exporter_context, exporter_length);
            Ok(exported.into())
        } else {
            self.rust_crypto.hpke_setup_receiver_and_export(
                config, enc, sk_r, info, exporter_context, exporter_length,
            )
        }
    }
    fn derive_hpke_keypair(&self, config: HpkeConfig, ikm: &[u8]) -> HpkeKeyPair {
        // Only intercept DHKEM(X25519) in hybrid mode; any other KEM type
        // (and classical mode) always delegates to RustCrypto.
        if self.hybrid_enabled && config.0 == HpkeKemType::DhKem25519 {
            let kp = HybridKeypair::derive_from_ikm(ikm);
            let private_bytes = kp.private_to_bytes();
            HpkeKeyPair {
                private: private_bytes.as_slice().into(),
                public: kp.public_key().to_bytes(),
            }
        } else {
            self.rust_crypto.derive_hpke_keypair(config, ikm)
        }
    }
}
/// OpenMLS crypto provider that uses hybrid KEM for HPKE (when keys are in
/// hybrid format) and delegates the rest to RustCrypto.
#[derive(Debug)]
pub struct HybridCryptoProvider {
    /// HPKE/crypto backend — hybrid or classical mode, see [`HybridCrypto`].
    crypto: HybridCrypto,
    /// Key store backing `OpenMlsCryptoProvider::key_store`.
    key_store: DiskKeyStore,
}
impl HybridCryptoProvider {
/// Create a hybrid-enabled provider (KeyPackages will contain hybrid init keys).
pub fn new(key_store: DiskKeyStore) -> Self {
Self {
crypto: HybridCrypto::new_hybrid(),
key_store,
}
}
/// Alias for `new()` — hybrid mode enabled.
pub fn new_hybrid(key_store: DiskKeyStore) -> Self {
Self::new(key_store)
}
/// Create a classical-mode provider (KeyPackages use standard X25519 init keys,
/// but seal/open still accept hybrid keys by length detection).
pub fn new_classical(key_store: DiskKeyStore) -> Self {
Self {
crypto: HybridCrypto::new_classical(),
key_store,
}
}
/// Whether this provider produces hybrid keys from `derive_hpke_keypair`.
pub fn is_hybrid_enabled(&self) -> bool {
self.crypto.is_hybrid_enabled()
}
}
impl Default for HybridCryptoProvider {
    /// Defaults to hybrid mode with an ephemeral (non-persisted) key store.
    fn default() -> Self {
        Self::new(DiskKeyStore::ephemeral())
    }
}
impl OpenMlsCryptoProvider for HybridCryptoProvider {
    type CryptoProvider = HybridCrypto;
    type RandProvider = RustCrypto;
    type KeyStoreProvider = DiskKeyStore;
    fn crypto(&self) -> &Self::CryptoProvider {
        &self.crypto
    }
    // Randomness is sourced from the embedded RustCrypto backend rather than
    // HybridCrypto itself (HybridCrypto only customises HPKE).
    fn rand(&self) -> &Self::RandProvider {
        self.crypto.rust_crypto()
    }
    fn key_store(&self) -> &Self::KeyStoreProvider {
        &self.key_store
    }
}
// ── Tests ───────────────────────────────────────────────────────────────────
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use openmls_traits::types::HpkeKdfType;
    // Standard HPKE config for the DHKEM(X25519) slot that the hybrid
    // provider intercepts when hybrid mode is enabled.
    fn hpke_config_dhkem_x25519() -> HpkeConfig {
        HpkeConfig(
            HpkeKemType::DhKem25519,
            HpkeKdfType::HkdfSha256,
            openmls_traits::types::HpkeAeadType::AesGcm128,
        )
    }
    /// HPKE path with hybrid keys: derive_hpke_keypair (hybrid) -> hpke_seal -> hpke_open.
    #[test]
    fn hybrid_hpke_seal_open_round_trip() {
        let crypto = HybridCrypto::new();
        let ikm = b"test-ikm-for-hybrid-hpke-keypair";
        let keypair = crypto.derive_hpke_keypair(hpke_config_dhkem_x25519(), ikm);
        // Hybrid mode must emit hybrid-length keys (1216 / 2432 bytes).
        assert_eq!(keypair.public.len(), HYBRID_PUBLIC_KEY_LEN);
        assert_eq!(keypair.private.as_ref().len(), HYBRID_PRIVATE_KEY_LEN);
        let plaintext = b"hello post-quantum MLS";
        let info = b"mls 1.0 test";
        let aad = b"additional data";
        let ct = crypto.hpke_seal(
            hpke_config_dhkem_x25519(),
            &keypair.public,
            info,
            aad,
            plaintext,
        );
        assert!(!ct.kem_output.as_slice().is_empty());
        assert!(!ct.ciphertext.as_slice().is_empty());
        let decrypted = crypto
            .hpke_open(
                hpke_config_dhkem_x25519(),
                &ct,
                keypair.private.as_ref(),
                info,
                aad,
            )
            .expect("hpke_open with hybrid keys");
        assert_eq!(decrypted.as_slice(), plaintext);
    }
    /// HPKE exporter path: setup_sender_and_export then setup_receiver_and_export.
    #[test]
    fn hybrid_hpke_setup_sender_receiver_export() {
        let crypto = HybridCrypto::new();
        let ikm = b"exporter-ikm";
        let keypair = crypto.derive_hpke_keypair(hpke_config_dhkem_x25519(), ikm);
        let info = b"";
        let exporter_context = b"MLS 1.0 external init";
        let exporter_length = 32;
        let (kem_output, sender_exported) = crypto
            .hpke_setup_sender_and_export(
                hpke_config_dhkem_x25519(),
                &keypair.public,
                info,
                exporter_context,
                exporter_length,
            )
            .expect("sender and export");
        assert_eq!(kem_output.len(), HYBRID_KEM_OUTPUT_LEN);
        assert_eq!(sender_exported.as_ref().len(), exporter_length);
        let receiver_exported = crypto
            .hpke_setup_receiver_and_export(
                hpke_config_dhkem_x25519(),
                &kem_output,
                keypair.private.as_ref(),
                info,
                exporter_context,
                exporter_length,
            )
            .expect("receiver and export");
        // Both sides must derive the same exporter secret.
        assert_eq!(sender_exported.as_ref(), receiver_exported.as_ref());
    }
    /// Classical mode: derive_hpke_keypair produces standard 32-byte X25519 keys.
    #[test]
    fn classical_mode_produces_standard_keys() {
        let crypto = HybridCrypto::new_classical();
        let ikm = b"test-ikm-for-classical-hpke";
        let keypair = crypto.derive_hpke_keypair(hpke_config_dhkem_x25519(), ikm);
        // Classical X25519 keys are 32 bytes
        assert_eq!(keypair.public.len(), 32);
        assert_eq!(keypair.private.as_ref().len(), 32);
    }
    /// Classical mode round-trip: seal/open works with classical keys.
    #[test]
    fn classical_mode_seal_open_round_trip() {
        let crypto = HybridCrypto::new_classical();
        let ikm = b"test-ikm-for-classical-round-trip";
        let keypair = crypto.derive_hpke_keypair(hpke_config_dhkem_x25519(), ikm);
        assert_eq!(keypair.public.len(), 32); // classical key
        let plaintext = b"hello classical MLS";
        let info = b"mls 1.0 test";
        let aad = b"additional data";
        let ct = crypto.hpke_seal(
            hpke_config_dhkem_x25519(),
            &keypair.public,
            info,
            aad,
            plaintext,
        );
        assert!(!ct.kem_output.as_slice().is_empty());
        let decrypted = crypto
            .hpke_open(
                hpke_config_dhkem_x25519(),
                &ct,
                keypair.private.as_ref(),
                info,
                aad,
            )
            .expect("hpke_open with classical keys");
        assert_eq!(decrypted.as_slice(), plaintext);
    }
    /// KeyPackage generation with HybridCryptoProvider (validates full HPKE path in MLS).
    #[test]
    fn key_package_generation_with_hybrid_provider() {
        use openmls::prelude::{
            Credential, CredentialType, CredentialWithKey, CryptoConfig, KeyPackage,
        };
        use std::sync::Arc;
        use tls_codec::Serialize;
        use crate::identity::IdentityKeypair;
        const CIPHERSUITE: Ciphersuite =
            Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
        let provider = HybridCryptoProvider::default();
        let identity = Arc::new(IdentityKeypair::generate());
        let credential = Credential::new(
            identity.public_key_bytes().to_vec(),
            CredentialType::Basic,
        )
        .unwrap();
        let credential_with_key = CredentialWithKey {
            credential,
            signature_key: identity.public_key_bytes().to_vec().into(),
        };
        let key_package = KeyPackage::builder()
            .build(
                CryptoConfig::with_default_version(CIPHERSUITE),
                &provider,
                identity.as_ref(),
                credential_with_key,
            )
            .expect("KeyPackage with hybrid HPKE");
        // A serialisable, non-empty KeyPackage proves the OpenMLS trait surface
        // is satisfied end-to-end by the hybrid HPKE backend.
        let bytes = key_package
            .tls_serialize_detached()
            .expect("serialize KeyPackage");
        assert!(!bytes.is_empty());
    }
}

View File

@@ -0,0 +1,633 @@
//! Post-quantum hybrid KEM: X25519 + ML-KEM-768.
//!
//! Wraps MLS payloads in an outer encryption layer using a hybrid key
//! encapsulation mechanism. The X25519 component provides classical
//! ECDH security; the ML-KEM-768 component (FIPS 203) provides
//! post-quantum security.
//!
//! # Wire format
//!
//! ```text
//! version(1) | x25519_eph_pk(32) | mlkem_ct(1088) | aead_nonce(12) | aead_ct(var)
//! ```
//!
//! # Key derivation
//!
//! ```text
//! ikm = X25519_shared(32) || ML-KEM_shared(32)
//! key = HKDF-SHA256(salt=[], ikm, info="quicnprotochat-hybrid-v1", L=32)
//! ```
use chacha20poly1305::{
aead::{Aead, KeyInit},
ChaCha20Poly1305, Key, Nonce,
};
use hkdf::Hkdf;
use ml_kem::{
array::Array,
kem::{Decapsulate, Encapsulate},
EncodedSizeUser, KemCore, MlKem768, MlKem768Params,
};
use rand::{rngs::OsRng, rngs::StdRng, CryptoRng, RngCore, SeedableRng};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use x25519_dalek::{EphemeralSecret, PublicKey as X25519Public, StaticSecret};
use zeroize::Zeroizing;
// Re-import the concrete key types from the kem sub-module.
use ml_kem::kem::{DecapsulationKey, EncapsulationKey};
/// Current hybrid envelope version byte.
const HYBRID_VERSION: u8 = 0x01;
/// HKDF info string for domain separation.
/// Frozen at the original project name for backward compatibility with existing
/// encrypted state files and messages. Do not change.
const HKDF_INFO: &[u8] = b"quicnprotochat-hybrid-v1";
/// HKDF salt for domain separation (defence-in-depth; IKM already has 64 bytes of entropy).
/// Frozen — see [`HKDF_INFO`].
const HKDF_SALT: &[u8] = b"quicnprotochat-hybrid-v1-salt";
/// ML-KEM-768 ciphertext size in bytes (FIPS 203 parameter set, via the ml-kem crate).
const MLKEM_CT_LEN: usize = 1088;
/// ML-KEM-768 encapsulation key size in bytes.
pub const MLKEM_EK_LEN: usize = 1184;
/// ML-KEM-768 decapsulation key size in bytes.
pub const MLKEM_DK_LEN: usize = 2400;
/// Envelope header: version(1) + x25519 eph pk(32) + mlkem ct(1088) + nonce(12).
const HEADER_LEN: usize = 1 + 32 + MLKEM_CT_LEN + 12;
/// KEM output length (version + x25519 eph pk + mlkem ct) for HPKE adapter.
/// Equals `HEADER_LEN` minus the 12-byte AEAD nonce.
pub const HYBRID_KEM_OUTPUT_LEN: usize = 1 + 32 + MLKEM_CT_LEN;
/// Hybrid public key length: x25519(32) + mlkem_ek(1184). Used to detect hybrid keys in MLS.
pub const HYBRID_PUBLIC_KEY_LEN: usize = 32 + MLKEM_EK_LEN;
/// Hybrid private key length: x25519(32) + mlkem_dk(2400). Used to detect hybrid keys in MLS.
pub const HYBRID_PRIVATE_KEY_LEN: usize = 32 + MLKEM_DK_LEN;
// ── Error type ──────────────────────────────────────────────────────────────
/// Errors produced by hybrid KEM key handling, encapsulation, and envelope
/// encryption/decryption.
#[derive(Debug, thiserror::Error)]
pub enum HybridKemError {
    /// AEAD seal (or ML-KEM encapsulation) failed during envelope creation.
    #[error("AEAD encryption failed")]
    EncryptionFailed,
    /// AEAD open failed — wrong recipient key, mismatched info/aad, or tampering.
    #[error("AEAD decryption failed (wrong recipient or tampered)")]
    DecryptionFailed,
    /// Envelope's leading version byte is not a version this code understands.
    #[error("unsupported hybrid envelope version: {0}")]
    UnsupportedVersion(u8),
    /// Input byte blob shorter than the minimum required length.
    #[error("envelope too short ({0} bytes, minimum {HEADER_LEN})")]
    TooShort(usize),
    /// ML-KEM key bytes did not parse into a valid key.
    #[error("invalid ML-KEM encapsulation key")]
    InvalidMlKemKey,
    /// ML-KEM decapsulation rejected the ciphertext.
    #[error("ML-KEM decapsulation failed")]
    MlKemDecapsFailed,
}
// ── Keypair types ───────────────────────────────────────────────────────────
/// A hybrid keypair combining X25519 (classical) + ML-KEM-768 (post-quantum).
///
/// Each peer holds one of these. The public portion is distributed so
/// senders can encrypt payloads with post-quantum protection.
pub struct HybridKeypair {
    /// X25519 static secret (classical ECDH component).
    x25519_sk: StaticSecret,
    /// X25519 public key derived from `x25519_sk`.
    x25519_pk: X25519Public,
    /// ML-KEM-768 decapsulation (secret) key.
    mlkem_dk: DecapsulationKey<MlKem768Params>,
    /// ML-KEM-768 encapsulation (public) key.
    mlkem_ek: EncapsulationKey<MlKem768Params>,
}
/// Serialisable form of a [`HybridKeypair`] for persistence.
///
/// Secret fields are wrapped in [`Zeroizing`] so they are securely erased
/// when the struct is dropped.
#[derive(Serialize, Deserialize)]
pub struct HybridKeypairBytes {
    /// X25519 static secret (32 bytes) — zeroised on drop.
    pub x25519_sk: Zeroizing<[u8; 32]>,
    /// ML-KEM-768 decapsulation key (2400 bytes) — zeroised on drop.
    pub mlkem_dk: Zeroizing<Vec<u8>>,
    /// ML-KEM-768 encapsulation key (1184 bytes); public material, no zeroisation needed.
    pub mlkem_ek: Vec<u8>,
}
/// The public portion of a hybrid keypair, sent to peers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HybridPublicKey {
    /// X25519 public key (32 bytes).
    pub x25519_pk: [u8; 32],
    /// ML-KEM-768 encapsulation key (1184 bytes).
    pub mlkem_ek: Vec<u8>,
}
/// HKDF info for deriving HPKE keypair seed from IKM (MLS compatibility).
/// Distinct from [`HKDF_INFO`] so key derivation is domain-separated from
/// envelope encryption. Frozen — see [`HKDF_INFO`].
const HKDF_INFO_HPKE_KEYPAIR: &[u8] = b"quicnprotochat-hybrid-hpke-keypair-v1";
impl HybridKeypair {
    /// Generate a fresh hybrid keypair from OS CSPRNG.
    pub fn generate() -> Self {
        Self::generate_from_rng(&mut OsRng)
    }
    /// Generate a hybrid keypair from a seeded RNG (deterministic).
    pub fn generate_from_rng<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
        let x25519_sk = StaticSecret::random_from_rng(&mut *rng);
        let x25519_pk = X25519Public::from(&x25519_sk);
        let (mlkem_dk, mlkem_ek) = MlKem768::generate(rng);
        Self {
            x25519_sk,
            x25519_pk,
            mlkem_dk,
            mlkem_ek,
        }
    }
    /// Derive a deterministic hybrid keypair from IKM (for MLS HPKE key schedule).
    pub fn derive_from_ikm(ikm: &[u8]) -> Self {
        // HKDF-expand the IKM into a 32-byte seed, then drive a seeded CSPRNG:
        // the same IKM always yields the same keypair.
        let mut seed = [0u8; 32];
        let hk = Hkdf::<Sha256>::new(None, ikm);
        hk.expand(HKDF_INFO_HPKE_KEYPAIR, &mut seed)
            .expect("32 bytes is valid HKDF output");
        let mut rng = StdRng::from_seed(seed);
        Self::generate_from_rng(&mut rng)
    }
    /// Serialise private key for MLS key store: x25519_sk(32) || mlkem_dk(2400).
    ///
    /// The returned value is wrapped in [`Zeroizing`] so secret key material
    /// is securely erased when dropped.
    pub fn private_to_bytes(&self) -> Zeroizing<Vec<u8>> {
        let mut out = Vec::with_capacity(HYBRID_PRIVATE_KEY_LEN);
        out.extend_from_slice(self.x25519_sk.as_bytes());
        out.extend_from_slice(self.mlkem_dk.as_bytes().as_slice());
        Zeroizing::new(out)
    }
    /// Reconstruct a hybrid keypair from private key bytes (from MLS key store).
    ///
    /// Public keys are re-derived from the secrets, so only the private halves
    /// need to be stored.
    pub fn from_private_bytes(bytes: &[u8]) -> Result<Self, HybridKemError> {
        if bytes.len() != HYBRID_PRIVATE_KEY_LEN {
            return Err(HybridKemError::TooShort(bytes.len()));
        }
        let x25519_sk = StaticSecret::from(<[u8; 32]>::try_from(&bytes[0..32])
            .expect("slice is exactly 32 bytes (guaranteed by HYBRID_PRIVATE_KEY_LEN check)"));
        let x25519_pk = X25519Public::from(&x25519_sk);
        let mlkem_dk_arr = Array::try_from(&bytes[32..32 + MLKEM_DK_LEN])
            .map_err(|_| HybridKemError::InvalidMlKemKey)?;
        let mlkem_dk = DecapsulationKey::<MlKem768Params>::from_bytes(&mlkem_dk_arr);
        let mlkem_ek = mlkem_dk.encapsulation_key().clone();
        Ok(Self {
            x25519_sk,
            x25519_pk,
            mlkem_dk,
            mlkem_ek,
        })
    }
    /// Reconstruct from serialised bytes.
    pub fn from_bytes(bytes: &HybridKeypairBytes) -> Result<Self, HybridKemError> {
        let x25519_sk = StaticSecret::from(*bytes.x25519_sk);
        let x25519_pk = X25519Public::from(&x25519_sk);
        let mlkem_dk_arr = Array::try_from(bytes.mlkem_dk.as_slice())
            .map_err(|_| HybridKemError::InvalidMlKemKey)?;
        let mlkem_dk = DecapsulationKey::<MlKem768Params>::from_bytes(&mlkem_dk_arr);
        let mlkem_ek_arr = Array::try_from(bytes.mlkem_ek.as_slice())
            .map_err(|_| HybridKemError::InvalidMlKemKey)?;
        let mlkem_ek = EncapsulationKey::<MlKem768Params>::from_bytes(&mlkem_ek_arr);
        Ok(Self {
            x25519_sk,
            x25519_pk,
            mlkem_dk,
            mlkem_ek,
        })
    }
    /// Serialise the keypair for persistence.
    pub fn to_bytes(&self) -> HybridKeypairBytes {
        HybridKeypairBytes {
            x25519_sk: Zeroizing::new(self.x25519_sk.to_bytes()),
            mlkem_dk: Zeroizing::new(self.mlkem_dk.as_bytes().to_vec()),
            mlkem_ek: self.mlkem_ek.as_bytes().to_vec(),
        }
    }
    /// Extract the public portion for distribution to peers.
    pub fn public_key(&self) -> HybridPublicKey {
        HybridPublicKey {
            x25519_pk: self.x25519_pk.to_bytes(),
            mlkem_ek: self.mlkem_ek.as_bytes().to_vec(),
        }
    }
}
impl HybridPublicKey {
    /// Serialise to a single byte blob: x25519_pk(32) || mlkem_ek(1184).
    pub fn to_bytes(&self) -> Vec<u8> {
        [&self.x25519_pk[..], self.mlkem_ek.as_slice()].concat()
    }
    /// Deserialise from a single byte blob.
    ///
    /// Accepts any blob of at least 32 + [`MLKEM_EK_LEN`] bytes; trailing
    /// bytes beyond that length are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, HybridKemError> {
        if bytes.len() < 32 + MLKEM_EK_LEN {
            return Err(HybridKemError::TooShort(bytes.len()));
        }
        let x25519_pk: [u8; 32] = bytes[..32].try_into().expect("checked length above");
        let mlkem_ek = bytes[32..32 + MLKEM_EK_LEN].to_vec();
        Ok(Self { x25519_pk, mlkem_ek })
    }
}
// ── Encrypt / Decrypt ───────────────────────────────────────────────────────
/// Encrypt `plaintext` to `recipient_pk` using X25519 + ML-KEM-768 hybrid KEM.
///
/// `info` is optional HPKE context info incorporated into key derivation.
/// `aad` is optional additional authenticated data bound to the AEAD ciphertext.
///
/// Returns the complete hybrid envelope as a byte vector:
/// `version(1) || x25519_eph_pk(32) || mlkem_ct(1088) || aead_nonce(12) || aead_ct`.
pub fn hybrid_encrypt(
    recipient_pk: &HybridPublicKey,
    plaintext: &[u8],
    info: &[u8],
    aad: &[u8],
) -> Result<Vec<u8>, HybridKemError> {
    // 1. Ephemeral X25519 DH — fresh ephemeral key per envelope (forward secrecy
    //    on the classical component).
    let eph_secret = EphemeralSecret::random_from_rng(OsRng);
    let eph_public = X25519Public::from(&eph_secret);
    let x25519_recipient = X25519Public::from(recipient_pk.x25519_pk);
    let x25519_ss = eph_secret.diffie_hellman(&x25519_recipient);
    // 2. ML-KEM-768 encapsulation
    let mlkem_ek_arr = Array::try_from(recipient_pk.mlkem_ek.as_slice())
        .map_err(|_| HybridKemError::InvalidMlKemKey)?;
    let mlkem_ek = EncapsulationKey::<MlKem768Params>::from_bytes(&mlkem_ek_arr);
    let (mlkem_ct, mlkem_ss) = mlkem_ek
        .encapsulate(&mut OsRng)
        .map_err(|_| HybridKemError::EncryptionFailed)?;
    // 3. Derive AEAD key from combined shared secrets (with caller info for context binding)
    let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), info);
    // Generate a random 12-byte nonce (not derived from HKDF).
    let mut nonce_bytes = [0u8; 12];
    OsRng.fill_bytes(&mut nonce_bytes);
    let aead_nonce = *Nonce::from_slice(&nonce_bytes);
    // 4. AEAD encrypt with caller-supplied AAD
    let cipher = ChaCha20Poly1305::new(&aead_key);
    let aead_payload = chacha20poly1305::aead::Payload { msg: plaintext, aad };
    let ct = cipher
        .encrypt(&aead_nonce, aead_payload)
        .map_err(|_| HybridKemError::EncryptionFailed)?;
    // 5. Assemble envelope: version || x25519_eph_pk || mlkem_ct || nonce || aead_ct
    let mut out = Vec::with_capacity(HEADER_LEN + ct.len());
    out.push(HYBRID_VERSION);
    out.extend_from_slice(&eph_public.to_bytes());
    out.extend_from_slice(mlkem_ct.as_slice());
    out.extend_from_slice(aead_nonce.as_slice());
    out.extend_from_slice(&ct);
    Ok(out)
}
/// Decrypt a hybrid envelope using the recipient's private key.
///
/// `info` and `aad` must match what was passed to `hybrid_encrypt` —
/// a mismatch surfaces as [`HybridKemError::DecryptionFailed`].
pub fn hybrid_decrypt(
    keypair: &HybridKeypair,
    envelope: &[u8],
    info: &[u8],
    aad: &[u8],
) -> Result<Vec<u8>, HybridKemError> {
    if envelope.len() < HEADER_LEN + 16 {
        // 16 = minimum AEAD tag
        return Err(HybridKemError::TooShort(envelope.len()));
    }
    let version = envelope[0];
    if version != HYBRID_VERSION {
        return Err(HybridKemError::UnsupportedVersion(version));
    }
    // Walk the fixed-layout header; all slice bounds below are guaranteed by
    // the HEADER_LEN length check above.
    let mut cursor = 1;
    // X25519 ephemeral public key
    let mut eph_pk_bytes = [0u8; 32];
    eph_pk_bytes.copy_from_slice(&envelope[cursor..cursor + 32]);
    cursor += 32;
    // ML-KEM ciphertext
    let mlkem_ct_bytes = &envelope[cursor..cursor + MLKEM_CT_LEN];
    cursor += MLKEM_CT_LEN;
    // AEAD nonce
    let nonce = Nonce::from_slice(&envelope[cursor..cursor + 12]);
    cursor += 12;
    // AEAD ciphertext
    let aead_ct = &envelope[cursor..];
    // 1. X25519 DH with ephemeral public key
    let eph_pk = X25519Public::from(eph_pk_bytes);
    let x25519_ss = keypair.x25519_sk.diffie_hellman(&eph_pk);
    // 2. ML-KEM decapsulation — convert bytes to the ciphertext array type
    //    that `DecapsulationKey::decapsulate` expects.
    let mlkem_ct_arr =
        Array::try_from(mlkem_ct_bytes).map_err(|_| HybridKemError::MlKemDecapsFailed)?;
    let mlkem_ss = keypair
        .mlkem_dk
        .decapsulate(&mlkem_ct_arr)
        .map_err(|_| HybridKemError::MlKemDecapsFailed)?;
    // 3. Derive AEAD key (with caller info for context binding)
    let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), info);
    // 4. Decrypt with caller-supplied AAD
    let cipher = ChaCha20Poly1305::new(&aead_key);
    let aead_payload = chacha20poly1305::aead::Payload { msg: aead_ct, aad };
    let plaintext = cipher
        .decrypt(nonce, aead_payload)
        .map_err(|_| HybridKemError::DecryptionFailed)?;
    Ok(plaintext)
}
/// Encapsulate only: compute shared secret and KEM output (no AEAD).
/// Returns `(kem_output, shared_secret)` where `kem_output` is the first
/// `HYBRID_KEM_OUTPUT_LEN` bytes of the hybrid envelope and `shared_secret`
/// is the 32-byte derived key (same as used for AEAD in `hybrid_encrypt`).
/// Used by MLS HPKE exporter (setup_sender_and_export).
pub fn hybrid_encapsulate_only(
    recipient_pk: &HybridPublicKey,
) -> Result<(Vec<u8>, [u8; 32]), HybridKemError> {
    // Classical half: ephemeral X25519 DH with the recipient's static key.
    let eph_secret = EphemeralSecret::random_from_rng(OsRng);
    let eph_public = X25519Public::from(&eph_secret);
    let x25519_ss = eph_secret.diffie_hellman(&X25519Public::from(recipient_pk.x25519_pk));
    // Post-quantum half: ML-KEM-768 encapsulation to the recipient's key.
    let mlkem_ek_arr = Array::try_from(recipient_pk.mlkem_ek.as_slice())
        .map_err(|_| HybridKemError::InvalidMlKemKey)?;
    let (mlkem_ct, mlkem_ss) = EncapsulationKey::<MlKem768Params>::from_bytes(&mlkem_ek_arr)
        .encapsulate(&mut OsRng)
        .map_err(|_| HybridKemError::EncryptionFailed)?;
    // Combine both shared secrets into the exported 32-byte key (empty info).
    let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), b"");
    let shared_secret: [u8; 32] = aead_key
        .as_slice()
        .try_into()
        .expect("AEAD key is always exactly 32 bytes");
    // KEM output mirrors the envelope header: version || eph_pk || mlkem_ct.
    let mut kem_output = Vec::with_capacity(HYBRID_KEM_OUTPUT_LEN);
    kem_output.push(HYBRID_VERSION);
    kem_output.extend_from_slice(&eph_public.to_bytes());
    kem_output.extend_from_slice(mlkem_ct.as_slice());
    Ok((kem_output, shared_secret))
}
/// Decapsulate only: recover shared secret from KEM output (no AEAD).
/// Used by MLS HPKE exporter (setup_receiver_and_export).
pub fn hybrid_decapsulate_only(
    keypair: &HybridKeypair,
    kem_output: &[u8],
) -> Result<[u8; 32], HybridKemError> {
    if kem_output.len() < HYBRID_KEM_OUTPUT_LEN {
        return Err(HybridKemError::TooShort(kem_output.len()));
    }
    let version = kem_output[0];
    if version != HYBRID_VERSION {
        return Err(HybridKemError::UnsupportedVersion(version));
    }
    // KEM output layout: version(1) || x25519_eph_pk(32) || mlkem_ct.
    let eph_pk_bytes: [u8; 32] = kem_output[1..33]
        .try_into()
        .expect("slice is exactly 32 bytes (guaranteed by HYBRID_KEM_OUTPUT_LEN check)");
    let x25519_ss = keypair
        .x25519_sk
        .diffie_hellman(&X25519Public::from(eph_pk_bytes));
    let mlkem_ct_arr = Array::try_from(&kem_output[33..33 + MLKEM_CT_LEN])
        .map_err(|_| HybridKemError::MlKemDecapsFailed)?;
    let mlkem_ss = keypair
        .mlkem_dk
        .decapsulate(&mlkem_ct_arr)
        .map_err(|_| HybridKemError::MlKemDecapsFailed)?;
    // Same derivation as `hybrid_encapsulate_only` (empty info string).
    let aead_key = derive_aead_key(x25519_ss.as_bytes(), mlkem_ss.as_slice(), b"");
    Ok(aead_key
        .as_slice()
        .try_into()
        .expect("AEAD key is always exactly 32 bytes"))
}
/// Export a secret from shared secret (MLS HPKE exporter compatibility).
///
/// Computes `HKDF-Expand(prk, exporter_context, length)` where
/// `prk = HKDF-Extract(HKDF_SALT, shared_secret)` — the same fixed salt used
/// by `derive_aead_key`, so exporter output stays consistent with the rest of
/// this module. (The earlier doc claimed a zero salt; the code has always
/// used `HKDF_SALT`.)
///
/// # Panics
///
/// Panics if `length` exceeds the HKDF-SHA256 output limit (255 × 32 bytes),
/// via the `expect` on `expand`.
pub fn hybrid_export(
    shared_secret: &[u8; 32],
    exporter_context: &[u8],
    length: usize,
) -> Vec<u8> {
    let hk = Hkdf::<Sha256>::new(Some(HKDF_SALT), shared_secret);
    let mut out = vec![0u8; length];
    hk.expand(exporter_context, &mut out).expect("valid length");
    out
}
/// Derive AEAD key from the combined X25519 + ML-KEM shared secrets.
///
/// `extra_info` is optional caller-supplied context (e.g. HPKE `info`) that is
/// appended to the domain-separation label for additional binding.
///
/// The nonce is generated randomly per-encryption rather than derived from
/// HKDF, preventing nonce reuse when the same shared secret is (accidentally)
/// used more than once.
fn derive_aead_key(x25519_ss: &[u8], mlkem_ss: &[u8], extra_info: &[u8]) -> Key {
    // IKM = x25519_ss || mlkem_ss; the buffer is zeroized on drop.
    let mut ikm = Zeroizing::new(Vec::with_capacity(x25519_ss.len() + mlkem_ss.len()));
    ikm.extend_from_slice(x25519_ss);
    ikm.extend_from_slice(mlkem_ss);
    let hk = Hkdf::<Sha256>::new(Some(HKDF_SALT), &ikm);
    // HKDF info = domain-separation label || caller-supplied context.
    let info = [HKDF_INFO, extra_info].concat();
    let mut key_bytes = Zeroizing::new([0u8; 32]);
    hk.expand(&info, &mut *key_bytes)
        .expect("32 bytes is valid HKDF-SHA256 output length");
    *Key::from_slice(&*key_bytes)
}
// ── Tests ───────────────────────────────────────────────────────────────────
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    #[test]
    fn keygen_produces_valid_public_key() {
        let keypair = HybridKeypair::generate();
        let public = keypair.public_key();
        assert_eq!(public.x25519_pk.len(), 32);
        assert_eq!(public.mlkem_ek.len(), MLKEM_EK_LEN);
    }
    #[test]
    fn encrypt_decrypt_round_trip() {
        let keypair = HybridKeypair::generate();
        let message = b"hello post-quantum world!";
        let envelope = hybrid_encrypt(&keypair.public_key(), message, b"", b"").unwrap();
        assert_eq!(hybrid_decrypt(&keypair, &envelope, b"", b"").unwrap(), message);
    }
    #[test]
    fn encrypt_decrypt_with_info_aad() {
        let keypair = HybridKeypair::generate();
        let message = b"context-bound payload";
        let info = b"mls epoch 42";
        let aad = b"group-id-abc";
        let envelope = hybrid_encrypt(&keypair.public_key(), message, info, aad).unwrap();
        assert_eq!(hybrid_decrypt(&keypair, &envelope, info, aad).unwrap(), message);
        // Decryption must be bound to both context strings.
        assert!(hybrid_decrypt(&keypair, &envelope, b"wrong info", aad).is_err());
        assert!(hybrid_decrypt(&keypair, &envelope, info, b"wrong aad").is_err());
    }
    #[test]
    fn wrong_key_decryption_fails() {
        let intended = HybridKeypair::generate();
        let other = HybridKeypair::generate();
        let envelope = hybrid_encrypt(&intended.public_key(), b"secret", b"", b"").unwrap();
        assert!(hybrid_decrypt(&other, &envelope, b"", b"").is_err());
    }
    #[test]
    fn tampered_aead_ciphertext_fails() {
        let keypair = HybridKeypair::generate();
        let mut envelope = hybrid_encrypt(&keypair.public_key(), b"payload", b"", b"").unwrap();
        // Flip the final byte (inside the AEAD ciphertext/tag region).
        *envelope.last_mut().unwrap() ^= 0x01;
        assert!(matches!(
            hybrid_decrypt(&keypair, &envelope, b"", b""),
            Err(HybridKemError::DecryptionFailed)
        ));
    }
    #[test]
    fn tampered_mlkem_ct_fails() {
        let keypair = HybridKeypair::generate();
        let mut envelope = hybrid_encrypt(&keypair.public_key(), b"payload", b"", b"").unwrap();
        // Corrupt a byte inside the ML-KEM ciphertext region (starts at offset 33).
        envelope[40] ^= 0xFF;
        assert!(hybrid_decrypt(&keypair, &envelope, b"", b"").is_err());
    }
    #[test]
    fn tampered_x25519_eph_pk_fails() {
        let keypair = HybridKeypair::generate();
        let mut envelope = hybrid_encrypt(&keypair.public_key(), b"payload", b"", b"").unwrap();
        // Corrupt a byte inside the X25519 ephemeral pk region (offset 1..33).
        envelope[5] ^= 0xFF;
        assert!(hybrid_decrypt(&keypair, &envelope, b"", b"").is_err());
    }
    #[test]
    fn unsupported_version_rejected() {
        let keypair = HybridKeypair::generate();
        let mut envelope = hybrid_encrypt(&keypair.public_key(), b"payload", b"", b"").unwrap();
        envelope[0] = 0xFF;
        assert!(matches!(
            hybrid_decrypt(&keypair, &envelope, b"", b""),
            Err(HybridKemError::UnsupportedVersion(0xFF))
        ));
    }
    #[test]
    fn envelope_too_short_rejected() {
        let keypair = HybridKeypair::generate();
        assert!(matches!(
            hybrid_decrypt(&keypair, &[0x01; 10], b"", b""),
            Err(HybridKemError::TooShort(10))
        ));
    }
    #[test]
    fn keypair_serialisation_round_trip() {
        let original = HybridKeypair::generate();
        let restored = HybridKeypair::from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(original.x25519_pk.to_bytes(), restored.x25519_pk.to_bytes());
        assert_eq!(original.public_key().mlkem_ek, restored.public_key().mlkem_ek);
        // The restored keypair must still be able to decrypt.
        let envelope = hybrid_encrypt(&original.public_key(), b"test", b"", b"").unwrap();
        assert_eq!(hybrid_decrypt(&restored, &envelope, b"", b"").unwrap(), b"test");
    }
    #[test]
    fn public_key_serialisation_round_trip() {
        let public = HybridKeypair::generate().public_key();
        let restored = HybridPublicKey::from_bytes(&public.to_bytes()).unwrap();
        assert_eq!(public.x25519_pk, restored.x25519_pk);
        assert_eq!(public.mlkem_ek, restored.mlkem_ek);
    }
    #[test]
    fn large_payload_round_trip() {
        let keypair = HybridKeypair::generate();
        let message = vec![0xAB; 50_000]; // 50 KB
        let envelope = hybrid_encrypt(&keypair.public_key(), &message, b"", b"").unwrap();
        assert_eq!(hybrid_decrypt(&keypair, &envelope, b"", b"").unwrap(), message);
    }
}

View File

@@ -0,0 +1,245 @@
//! Ed25519 identity keypair for MLS credentials and AS registration.
//!
//! The [`IdentityKeypair`] is the long-term identity key embedded in MLS
//! `BasicCredential`s. It is used for signing MLS messages and as the
//! indexing key for the Authentication Service.
//!
//! # Zeroize
//!
//! The 32-byte private seed is stored as `Zeroizing<[u8; 32]>`, which zeroes
//! the bytes on drop. `[u8; 32]` is `Copy + Default` and satisfies zeroize's
//! `DefaultIsZeroes` constraint, avoiding a conflict with ed25519-dalek's
//! `SigningKey` zeroize impl.
//!
//! # Fingerprint
//!
//! A 32-byte SHA-256 digest of the raw public key bytes is used as a compact,
//! collision-resistant identifier for logging.
use ed25519_dalek::{Signer as DalekSigner, SigningKey, VerifyingKey};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use zeroize::Zeroizing;
/// An Ed25519 identity keypair.
///
/// Created with [`IdentityKeypair::generate`]. The private signing key seed
/// is zeroed when this struct is dropped.
pub struct IdentityKeypair {
    /// Raw 32-byte private seed — zeroized on drop.
    ///
    /// Stored as bytes rather than `SigningKey` to satisfy zeroize's
    /// `DefaultIsZeroes` bound on `Zeroizing<T>`. The `SigningKey` is
    /// reconstructed on demand from this seed (see `signing_key`).
    seed: Zeroizing<[u8; 32]>,
    /// Corresponding 32-byte public verifying key, derived from `seed` at
    /// construction time so lookups never touch the private material.
    verifying: VerifyingKey,
}
impl IdentityKeypair {
    /// Recreate an identity keypair from a 32-byte seed.
    pub fn from_seed(seed: [u8; 32]) -> Self {
        // Derive the public half before moving the seed into zeroizing storage.
        let verifying = SigningKey::from_bytes(&seed).verifying_key();
        Self {
            seed: Zeroizing::new(seed),
            verifying,
        }
    }
    /// Return the raw 32-byte private seed (for persistence).
    ///
    /// The returned value is wrapped in [`Zeroizing`] so it is securely
    /// erased when dropped, preventing the seed from lingering in memory.
    pub fn seed_bytes(&self) -> Zeroizing<[u8; 32]> {
        self.seed.clone()
    }
}
impl IdentityKeypair {
    /// Generate a fresh random Ed25519 identity keypair.
    pub fn generate() -> Self {
        use rand::rngs::OsRng;
        let signing = SigningKey::generate(&mut OsRng);
        Self {
            verifying: signing.verifying_key(),
            seed: Zeroizing::new(signing.to_bytes()),
        }
    }
    /// Return the raw 32-byte Ed25519 public key.
    ///
    /// This is the byte array used as `identityKey` in `auth.capnp` calls.
    pub fn public_key_bytes(&self) -> [u8; 32] {
        self.verifying.to_bytes()
    }
    /// Return the SHA-256 fingerprint of the public key (32 bytes).
    pub fn fingerprint(&self) -> [u8; 32] {
        Sha256::digest(self.verifying.to_bytes()).into()
    }
    /// Reconstruct the `SigningKey` from the stored seed bytes.
    fn signing_key(&self) -> SigningKey {
        SigningKey::from_bytes(&self.seed)
    }
}
/// Implement the openmls `Signer` trait so `IdentityKeypair` can be passed
/// directly to `KeyPackage::builder().build(...)` without needing the external
/// `openmls_basic_credential` crate.
#[cfg(feature = "native")]
impl openmls_traits::signatures::Signer for IdentityKeypair {
    fn sign(&self, payload: &[u8]) -> Result<Vec<u8>, openmls_traits::types::Error> {
        // ed25519-dalek signing is infallible once the key exists, so this
        // always returns Ok.
        let signature: ed25519_dalek::Signature = self.signing_key().sign(payload);
        Ok(signature.to_bytes().to_vec())
    }
    fn signature_scheme(&self) -> openmls_traits::types::SignatureScheme {
        openmls_traits::types::SignatureScheme::ED25519
    }
}
impl IdentityKeypair {
    /// Sign arbitrary bytes with the Ed25519 key and return the 64-byte signature.
    ///
    /// Used by sealed sender to sign the inner payload for recipient verification.
    pub fn sign_raw(&self, payload: &[u8]) -> [u8; 64] {
        let signature: ed25519_dalek::Signature = self.signing_key().sign(payload);
        signature.to_bytes()
    }
    /// Verify an Ed25519 signature over `payload` using the given public key.
    pub fn verify_raw(
        public_key: &[u8; 32],
        payload: &[u8],
        signature: &[u8; 64],
    ) -> Result<(), crate::error::CoreError> {
        use ed25519_dalek::Verifier;
        let verifying = VerifyingKey::from_bytes(public_key)
            .map_err(|e| crate::error::CoreError::Mls(format!("invalid public key: {e}")))?;
        verifying
            .verify(payload, &ed25519_dalek::Signature::from_bytes(signature))
            .map_err(|e| crate::error::CoreError::Mls(format!("signature verification failed: {e}")))
    }
}
/// Verify a 96-byte delivery proof produced by the server's `build_delivery_proof`.
///
/// # Layout
/// ```text
/// bytes 0..32  — SHA-256(seq_le || recipient_key || timestamp_ms_le)
/// bytes 32..96 — Ed25519 signature over those 32 bytes
/// ```
///
/// Returns `Ok(true)` when the proof is structurally valid and the signature verifies,
/// `Ok(false)` when the proof length is wrong (graceful degradation for old servers),
/// or `Err` when the signature is structurally invalid / verification fails.
pub fn verify_delivery_proof(
    server_pubkey: &[u8; 32],
    proof: &[u8],
) -> Result<bool, crate::error::CoreError> {
    if proof.len() != 96 {
        return Ok(false);
    }
    let (hash_part, sig_part) = proof.split_at(32);
    let hash: [u8; 32] = hash_part.try_into().expect("split_at(32) yields 32 bytes");
    let sig: [u8; 64] = sig_part.try_into().expect("remainder is 64 bytes");
    IdentityKeypair::verify_raw(server_pubkey, &hash, &sig)?;
    Ok(true)
}
impl Serialize for IdentityKeypair {
    /// Serialize as the raw 32-byte seed; the public key is re-derived on load.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_bytes(self.seed.as_slice())
    }
}
impl<'de> Deserialize<'de> for IdentityKeypair {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let bytes: Vec<u8> = serde::Deserialize::deserialize(deserializer)?;
let seed: [u8; 32] = bytes
.as_slice()
.try_into()
.map_err(|_| serde::de::Error::custom("identity seed must be 32 bytes"))?;
Ok(IdentityKeypair::from_seed(seed))
}
}
impl std::fmt::Debug for IdentityKeypair {
    /// Show only a 4-byte fingerprint prefix — never the seed material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fp = self.fingerprint();
        let short = format!("{:02x}{:02x}{:02x}{:02x}", fp[0], fp[1], fp[2], fp[3]);
        f.debug_struct("IdentityKeypair")
            .field("fingerprint", &short)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod proof_tests {
    use super::*;
    use sha2::{Digest, Sha256};
    /// Build a 96-byte proof the same way the server's `build_delivery_proof` does.
    fn make_proof(kp: &IdentityKeypair, seq: u64, recipient_key: &[u8], timestamp_ms: u64) -> Vec<u8> {
        let digest: [u8; 32] = {
            let mut hasher = Sha256::new();
            hasher.update(seq.to_le_bytes());
            hasher.update(recipient_key);
            hasher.update(timestamp_ms.to_le_bytes());
            hasher.finalize().into()
        };
        let mut proof = Vec::with_capacity(96);
        proof.extend_from_slice(&digest);
        proof.extend_from_slice(&kp.sign_raw(&digest));
        proof
    }
    #[test]
    fn verify_valid_proof() {
        let signer = IdentityKeypair::generate();
        let recipient = [0xabu8; 32];
        let proof = make_proof(&signer, 42, &recipient, 1_700_000_000_000);
        assert!(verify_delivery_proof(&signer.public_key_bytes(), &proof).unwrap());
    }
    #[test]
    fn reject_wrong_length() {
        let pk = IdentityKeypair::generate().public_key_bytes();
        // Any length other than 96 degrades gracefully to Ok(false).
        for bad in [&[0u8; 64][..], &[], &[0u8; 97]] {
            assert!(!verify_delivery_proof(&pk, bad).unwrap());
        }
    }
    #[test]
    fn reject_tampered_hash() {
        let signer = IdentityKeypair::generate();
        let mut proof = make_proof(&signer, 1, &[0x01u8; 32], 999);
        proof[0] ^= 0xff; // corrupt the hash bytes
        assert!(verify_delivery_proof(&signer.public_key_bytes(), &proof).is_err());
    }
    #[test]
    fn reject_wrong_pubkey() {
        let signer = IdentityKeypair::generate();
        let stranger = IdentityKeypair::generate();
        let proof = make_proof(&signer, 5, &[0x02u8; 32], 0);
        assert!(verify_delivery_proof(&stranger.public_key_bytes(), &proof).is_err());
    }
}

View File

@@ -0,0 +1,109 @@
//! MLS KeyPackage generation and TLS serialisation.
//!
//! # Ciphersuite
//!
//! `MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519` (ciphersuite ID `0x0001`).
//! This is the RECOMMENDED ciphersuite from RFC 9420 §17.1.
//!
//! # Single-use semantics
//!
//! Per RFC 9420 §10.1, each KeyPackage MUST be used at most once. The
//! Authentication Service enforces this by atomically removing a package on
//! fetch.
//!
//! # Wire format
//!
//! KeyPackages are TLS-encoded using `tls_codec` (same version as openmls).
//! The resulting bytes are opaque to the quicprochat transport layer.
use openmls::prelude::{
Ciphersuite, Credential, CredentialType, CredentialWithKey, CryptoConfig, KeyPackage,
KeyPackageIn, TlsDeserializeTrait, TlsSerializeTrait,
};
use openmls_rust_crypto::OpenMlsRustCrypto;
use sha2::{Digest, Sha256};
use crate::{error::CoreError, identity::IdentityKeypair};
/// The MLS ciphersuite used throughout quicprochat (RFC 9420 §17.1).
pub const ALLOWED_CIPHERSUITE: Ciphersuite =
    Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519;
/// Wire value of the allowed ciphersuite (KeyPackage TLS encoding: version 2B, ciphersuite 2B).
/// Must stay in sync with `ALLOWED_CIPHERSUITE`'s registered ID (0x0001).
const ALLOWED_CIPHERSUITE_WIRE: u16 = 0x0001;
/// Local alias used by `generate_key_package`.
const CIPHERSUITE: Ciphersuite = ALLOWED_CIPHERSUITE;
/// Validates that the KeyPackage bytes use an allowed ciphersuite (Phase 2: ciphersuite allowlist).
///
/// Parses the TLS-encoded KeyPackage and rejects if the ciphersuite is not
/// `MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519`. Does not verify signatures;
/// the server uses this only to enforce policy before storing.
pub fn validate_keypackage_ciphersuite(bytes: &[u8]) -> Result<(), CoreError> {
    // TLS encoding opens with the protocol version (2 bytes) followed by the
    // ciphersuite (2 bytes, big-endian).
    let &[_, _, hi, lo, ..] = bytes else {
        return Err(CoreError::Mls("KeyPackage too short for version+ciphersuite".into()));
    };
    let cs_wire = u16::from_be_bytes([hi, lo]);
    if cs_wire != ALLOWED_CIPHERSUITE_WIRE {
        return Err(CoreError::Mls(format!(
            "KeyPackage ciphersuite {:#06x} not in allowlist (only {:#06x} allowed)",
            cs_wire, ALLOWED_CIPHERSUITE_WIRE
        )));
    }
    // Confirm a full parse so garbage that happens to carry 0x0001 at offset 2
    // is still rejected.
    let mut cursor = bytes;
    KeyPackageIn::tls_deserialize(&mut cursor)
        .map(|_| ())
        .map_err(|e| CoreError::Mls(format!("KeyPackage parse: {e:?}")))
}
/// Generate a fresh MLS KeyPackage for `identity` and serialise it.
///
/// # Returns
///
/// `(tls_bytes, sha256_fingerprint)` where:
/// - `tls_bytes` is the TLS-encoded KeyPackage blob, suitable for uploading.
/// - `sha256_fingerprint` is the SHA-256 digest of `tls_bytes` for tamper detection.
///
/// # Errors
///
/// Returns [`CoreError::Mls`] if openmls fails to create the KeyPackage or if
/// TLS serialisation fails.
pub fn generate_key_package(identity: &IdentityKeypair) -> Result<(Vec<u8>, Vec<u8>), CoreError> {
    let provider = OpenMlsRustCrypto::default();
    let identity_bytes = identity.public_key_bytes().to_vec();
    // BasicCredential over the raw Ed25519 public key — per RFC 9420, any
    // byte string may serve as the MLS identity.
    let credential = Credential::new(identity_bytes.clone(), CredentialType::Basic)
        .map_err(|e| CoreError::Mls(format!("{e:?}")))?;
    // The `signature_key` is the same Ed25519 public key; it is used to verify
    // the KeyPackage's leaf node signature. `SignaturePublicKey: From<Vec<u8>>`.
    let credential_with_key = CredentialWithKey {
        credential,
        signature_key: identity_bytes.into(),
    };
    // `IdentityKeypair` implements `openmls_traits::signatures::Signer`,
    // so it is handed directly to the builder as the signer.
    let key_package = KeyPackage::builder()
        .build(
            CryptoConfig::with_default_version(CIPHERSUITE),
            &provider,
            identity,
            credential_with_key,
        )
        .map_err(|e| CoreError::Mls(format!("{e:?}")))?;
    // TLS-encode via the tls_codec 0.3 trait re-exported by the openmls
    // prelude (avoids a duplicate-trait conflict with tls_codec 0.4).
    let tls_bytes = key_package
        .tls_serialize_detached()
        .map_err(|e| CoreError::Mls(format!("{e:?}")))?;
    let fingerprint = Sha256::digest(&tls_bytes).to_vec();
    Ok((tls_bytes, fingerprint))
}

View File

@@ -0,0 +1,142 @@
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
sync::RwLock,
};
use openmls_traits::key_store::{MlsEntity, OpenMlsKeyStore};
/// A disk-backed key store implementing `OpenMlsKeyStore`.
///
/// In-memory when `path` is `None`; otherwise flushes the entire map to disk on
/// every store/delete so HPKE init keys survive process restarts.
///
/// # Serialization
///
/// Uses bincode for both individual MLS entity values and the outer HashMap
/// container. This is required because OpenMLS types use bincode-compatible
/// serialization, and `HashMap<Vec<u8>, Vec<u8>>` requires a binary format
/// (JSON mandates string keys).
///
/// # Persistence security
///
/// When `path` is set, file permissions are restricted to owner-only (0o600)
/// on Unix platforms, since the store may contain HPKE private keys.
#[derive(Debug)]
pub struct DiskKeyStore {
    /// `None` for an ephemeral (in-memory) store; `Some` for a disk-backed one.
    path: Option<PathBuf>,
    /// Serialized MLS entities, keyed by the opaque byte keys OpenMLS supplies.
    values: RwLock<HashMap<Vec<u8>, Vec<u8>>>,
}
/// Errors produced by [`DiskKeyStore`] operations.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum DiskKeyStoreError {
    /// bincode (de)serialization of an entity or the store map failed.
    #[error("serialization error")]
    Serialization,
    /// Filesystem or lock failure, carrying a human-readable description.
    #[error("io error: {0}")]
    Io(String),
}
impl DiskKeyStore {
/// In-memory keystore (no persistence).
pub fn ephemeral() -> Self {
Self {
path: None,
values: RwLock::new(HashMap::new()),
}
}
/// Persistent keystore backed by `path`. Creates an empty store if missing.
pub fn persistent(path: impl AsRef<Path>) -> Result<Self, DiskKeyStoreError> {
let path = path.as_ref().to_path_buf();
let values = if path.exists() {
let bytes = fs::read(&path).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?;
if bytes.is_empty() {
HashMap::new()
} else {
bincode::deserialize(&bytes)
.map_err(|_| DiskKeyStoreError::Serialization)?
}
} else {
HashMap::new()
};
let store = Self {
path: Some(path),
values: RwLock::new(values),
};
// Set restrictive file permissions on the keystore file.
store.set_file_permissions()?;
Ok(store)
}
fn flush(&self) -> Result<(), DiskKeyStoreError> {
let Some(path) = &self.path else {
return Ok(());
};
let values = self.values.read().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
let bytes = bincode::serialize(&*values).map_err(|_| DiskKeyStoreError::Serialization)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?;
}
fs::write(path, &bytes).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?;
self.set_file_permissions()?;
Ok(())
}
/// Restrict file permissions to owner-only (0o600) on Unix.
#[cfg(unix)]
fn set_file_permissions(&self) -> Result<(), DiskKeyStoreError> {
use std::os::unix::fs::PermissionsExt;
if let Some(path) = &self.path {
if path.exists() {
let perms = std::fs::Permissions::from_mode(0o600);
fs::set_permissions(path, perms)
.map_err(|e| DiskKeyStoreError::Io(format!("set permissions: {e}")))?;
}
}
Ok(())
}
#[cfg(not(unix))]
fn set_file_permissions(&self) -> Result<(), DiskKeyStoreError> {
Ok(())
}
}
impl Default for DiskKeyStore {
fn default() -> Self {
Self::ephemeral()
}
}
impl OpenMlsKeyStore for DiskKeyStore {
    type Error = DiskKeyStoreError;
    /// Serialize `v` with bincode, insert it under `k`, then persist.
    fn store<V: MlsEntity>(&self, k: &[u8], v: &V) -> Result<(), Self::Error> {
        let serialized = bincode::serialize(v).map_err(|_| DiskKeyStoreError::Serialization)?;
        {
            // Scope the write guard so the lock is released before flushing.
            let mut map = self
                .values
                .write()
                .map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
            map.insert(k.to_vec(), serialized);
        }
        self.flush()
    }
    /// Fetch and decode the entity under `k`; `None` on miss, poisoned lock,
    /// or decode failure.
    fn read<V: MlsEntity>(&self, k: &[u8]) -> Option<V> {
        let map = self.values.read().ok()?;
        bincode::deserialize(map.get(k)?).ok()
    }
    /// Remove the entity under `k` (if present), then persist.
    fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
        {
            let mut map = self
                .values
                .write()
                .map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
            map.remove(k);
        }
        self.flush()
    }
}

View File

@@ -0,0 +1,99 @@
//! Core cryptographic primitives, MLS group state machine, and hybrid
//! post-quantum KEM for quicprochat.
//!
//! # WASM support
//!
//! When compiled with `--no-default-features` (disabling the `native` feature),
//! the following modules are available for `wasm32-unknown-unknown`:
//!
//! - `identity` — Ed25519 identity keypair (generate, sign, verify)
//! - `hybrid_kem` — X25519 + ML-KEM-768 hybrid key encapsulation
//! - `safety_numbers` — Signal-style safety number computation
//! - `sealed_sender` — sender identity + Ed25519 signature envelope
//! - `app_message` — rich application message serialisation/parsing
//! - `padding` — message padding to hide plaintext lengths
//! - `transcript` — encrypted tamper-evident message transcript
//! - `error` — `CoreError` type
//!
//! The following modules require the `native` feature (MLS, OPAQUE, Cap'n Proto):
//!
//! - `group` — MLS group state machine (openmls)
//! - `keypackage` — MLS KeyPackage generation
//! - `hybrid_crypto` — hybrid HPKE provider for OpenMLS
//! - `keystore` — OpenMLS key store with optional disk persistence
//! - `opaque_auth` — OPAQUE cipher suite configuration
//!
//! # Module layout
//!
//! | Module | Responsibility |
//! |---------------|------------------------------------------------------------------|
//! | `app_message` | Rich application payload (Chat, Reply, Reaction, ReadReceipt, Typing) |
//! | `error` | [`CoreError`] type |
//! | `identity` | [`IdentityKeypair`] — Ed25519 identity key for MLS credentials |
//! | `keypackage` | [`generate_key_package`] — standalone KeyPackage generation |
//! | `group` | [`GroupMember`] — MLS group lifecycle (create/join/send/recv) |
//! | `hybrid_kem` | Hybrid X25519 + ML-KEM-768 key encapsulation |
//! | `keystore` | [`DiskKeyStore`] — OpenMLS key store with optional persistence |
mod app_message;
mod error;
mod hybrid_kem;
mod identity;
pub mod padding;
pub mod pq_noise;
#[cfg(feature = "native")]
pub mod recovery;
pub mod safety_numbers;
pub mod sealed_sender;
pub mod transcript;
// ── Native-only modules (MLS, OPAQUE, filesystem) ───────────────────────────
#[cfg(feature = "native")]
mod group;
#[cfg(feature = "native")]
mod hybrid_crypto;
#[cfg(feature = "native")]
mod keypackage;
#[cfg(feature = "native")]
mod keystore;
#[cfg(feature = "native")]
pub mod opaque_auth;
// ── Public API (always available) ───────────────────────────────────────────
pub use app_message::{
serialize, serialize_chat, serialize_delete, serialize_dummy, serialize_edit,
serialize_file_ref, serialize_reaction, serialize_read_receipt, serialize_reply,
serialize_typing, parse, generate_message_id,
AppMessage, MessageType, VERSION as APP_MESSAGE_VERSION,
};
pub use error::CoreError;
pub use hybrid_kem::{
hybrid_decrypt, hybrid_encrypt, HybridKemError, HybridKeypair, HybridKeypairBytes,
HybridPublicKey,
};
pub use identity::{verify_delivery_proof, IdentityKeypair};
#[cfg(feature = "native")]
pub use recovery::{
constant_time_eq, generate_recovery_codes, recover_from_bundle, recovery_token_hash,
RecoveryBundle, RecoveryPayload, RecoverySetup, MAX_BUNDLE_SIZE, RECOVERY_CODE_COUNT,
};
pub use safety_numbers::compute_safety_number;
pub use transcript::{
read_transcript, validate_transcript_structure, ChainVerdict, DecodedRecord, TranscriptRecord,
TranscriptWriter,
};
// Deprecated re-export for backward compatibility.
#[allow(deprecated)]
pub use transcript::verify_transcript_chain;
// ── Public API (native only) ────────────────────────────────────────────────
#[cfg(feature = "native")]
pub use group::{GroupMember, ReceivedMessage, ReceivedMessageWithSender};
#[cfg(feature = "native")]
pub use hybrid_crypto::{HybridCrypto, HybridCryptoProvider};
#[cfg(feature = "native")]
pub use keypackage::{generate_key_package, validate_keypackage_ciphersuite};
#[cfg(feature = "native")]
pub use keystore::DiskKeyStore;

View File

@@ -0,0 +1,20 @@
//! Shared OPAQUE cipher suite configuration (OPRF per RFC 9497; OPAQUE per draft-irtf-cfrg-opaque).
//!
//! Both client and server import this module to ensure they use exactly
//! the same cryptographic parameters during registration and login.
use opaque_ke::CipherSuite;
/// OPAQUE cipher suite for quicprochat.
///
/// - **OPRF**: Ristretto255 (curve25519-based, ~128-bit security)
/// - **Key exchange**: Triple-DH (3DH) over Ristretto255 with SHA-512
/// - **KSF**: Argon2id (memory-hard key stretching)
///
/// Marker type only — it carries no data; `opaque_ke` reads the associated
/// types defined below.
pub struct OpaqueSuite;
impl CipherSuite for OpaqueSuite {
    // Oblivious PRF group.
    type OprfCs = opaque_ke::Ristretto255;
    // Triple Diffie-Hellman over the same group, hashed with SHA-512.
    type KeyExchange =
        opaque_ke::key_exchange::tripledh::TripleDh<opaque_ke::Ristretto255, sha2::Sha512>;
    // Key-stretching function applied to the password.
    type Ksf = argon2::Argon2<'static>;
}

View File

@@ -0,0 +1,265 @@
//! Message padding to hide plaintext lengths from the server.
//!
//! Pads payloads to fixed bucket sizes before MLS encryption so that the
//! ciphertext does not reveal the actual message length.
//!
//! # Wire format
//!
//! ```text
//! [real_length: 4 bytes LE (u32)][payload: real_length bytes][random padding]
//! ```
//!
//! The total padded output is one of the bucket sizes 256, 1024, 4096, or 16384 bytes;
//! payloads larger than 16380 bytes instead round up to the nearest 16384-byte multiple.
//!
//! ## Uniform boundary padding (traffic analysis resistance)
//!
//! [`pad_uniform`] / [`unpad_uniform`] pad to a configurable byte boundary
//! (default 256) instead of exponential buckets. This produces more uniform
//! ciphertext sizes at the cost of slightly more padding overhead.
use rand::RngCore;
use crate::error::CoreError;
/// Default uniform padding boundary in bytes.
pub const DEFAULT_PADDING_BOUNDARY: usize = 256;
/// Bucket sizes in bytes. The smallest (256) accommodates a sealed sender
/// envelope (99 bytes overhead) plus a short message.
const BUCKETS: &[usize] = &[256, 1024, 4096, 16384];
/// Select the smallest bucket that fits `content_len + 4` (the 4-byte length prefix).
fn bucket_for(content_len: usize) -> usize {
    let needed = content_len + 4;
    BUCKETS
        .iter()
        .copied()
        .find(|&bucket| needed <= bucket)
        // Larger than the biggest bucket: round up to a 16384-byte multiple.
        .unwrap_or_else(|| needed.div_ceil(16384) * 16384)
}
/// Pad a payload to the next bucket boundary with cryptographic random bytes.
pub fn pad(payload: &[u8]) -> Vec<u8> {
    let target = bucket_for(payload.len());
    let mut padded = Vec::with_capacity(target);
    // Wire format: [real_length: 4 bytes LE][payload][random padding].
    padded.extend_from_slice(&(payload.len() as u32).to_le_bytes());
    padded.extend_from_slice(payload);
    let remaining = target - padded.len();
    if remaining > 0 {
        let mut filler = vec![0u8; remaining];
        rand::rngs::OsRng.fill_bytes(&mut filler);
        padded.extend_from_slice(&filler);
    }
    padded
}
/// Remove padding and return the original payload.
pub fn unpad(padded: &[u8]) -> Result<Vec<u8>, CoreError> {
    // First 4 bytes are the little-endian real length.
    let Some(prefix) = padded.get(..4) else {
        return Err(CoreError::AppMessage("padded message too short".into()));
    };
    let real_len = u32::from_le_bytes(prefix.try_into().expect("slice is 4 bytes")) as usize;
    match padded.get(4..4 + real_len) {
        Some(payload) => Ok(payload.to_vec()),
        None => Err(CoreError::AppMessage(
            "padded real_length exceeds buffer".into(),
        )),
    }
}
/// Pad a payload to the nearest multiple of `boundary` bytes.
///
/// Uses the same wire format as [`pad`]: `[real_length: 4 bytes LE][payload][random padding]`.
/// The total output length is always a multiple of `boundary`. A `boundary` of 0 is
/// treated as [`DEFAULT_PADDING_BOUNDARY`].
pub fn pad_uniform(payload: &[u8], boundary: usize) -> Vec<u8> {
    let step = match boundary {
        0 => DEFAULT_PADDING_BOUNDARY,
        b => b,
    };
    let prefixed_len = payload.len() + 4; // 4-byte length prefix
    let target = prefixed_len.div_ceil(step) * step;
    let mut out = Vec::with_capacity(target);
    out.extend_from_slice(&(payload.len() as u32).to_le_bytes());
    out.extend_from_slice(payload);
    let fill = target - prefixed_len;
    if fill > 0 {
        let mut filler = vec![0u8; fill];
        rand::rngs::OsRng.fill_bytes(&mut filler);
        out.extend_from_slice(&filler);
    }
    out
}
/// Remove uniform padding. Wire format is identical to [`unpad`].
pub fn unpad_uniform(padded: &[u8]) -> Result<Vec<u8>, CoreError> {
    // Both schemes share the `[len: u32 LE][payload][padding]` frame, so the
    // bucket unpadder handles uniform-padded buffers unchanged.
    unpad(padded)
}
/// Generate a decoy payload that looks identical to a real padded message.
///
/// Returns random bytes of length equal to a `boundary`-aligned padded message.
/// The 4-byte length prefix is set to 0, so [`unpad_uniform`] returns an empty payload.
/// A `boundary` of 0 — or any value smaller than the 4-byte length prefix — is
/// treated as [`DEFAULT_PADDING_BOUNDARY`].
pub fn generate_decoy(boundary: usize) -> Vec<u8> {
    // Boundaries below 4 cannot even hold the length prefix; previously the
    // `out[4..]` slice below would panic for boundary 1..=3. Fall back to the
    // default instead (0 already did).
    let boundary = if boundary < 4 { DEFAULT_PADDING_BOUNDARY } else { boundary };
    let mut out = vec![0u8; boundary];
    // Length prefix = 0 (decoy carries no real payload).
    // Fill the rest with random bytes so it is indistinguishable from padding.
    rand::rngs::OsRng.fill_bytes(&mut out[4..]);
    out
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    // ── Bucket padding round-trips: each message lands in the expected bucket
    //    and survives pad → unpad unchanged. ────────────────────────────────
    #[test]
    fn round_trip_small() {
        let msg = b"hello";
        let padded = pad(msg);
        assert_eq!(padded.len(), 256); // smallest bucket
        let unpadded = unpad(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn round_trip_medium() {
        let msg = vec![0xAB; 300];
        let padded = pad(&msg);
        assert_eq!(padded.len(), 1024); // second bucket
        let unpadded = unpad(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn round_trip_large() {
        let msg = vec![0xCD; 2000];
        let padded = pad(&msg);
        assert_eq!(padded.len(), 4096); // third bucket
        let unpadded = unpad(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn round_trip_very_large() {
        let msg = vec![0xEF; 10000];
        let padded = pad(&msg);
        assert_eq!(padded.len(), 16384); // largest bucket
        let unpadded = unpad(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn round_trip_oversized() {
        // Beyond the largest bucket, pad rounds up to 16384-byte multiples.
        let msg = vec![0xFF; 20000];
        let padded = pad(&msg);
        assert_eq!(padded.len(), 32768); // 2 * 16384
        let unpadded = unpad(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn round_trip_empty() {
        let msg = b"";
        let padded = pad(msg);
        assert_eq!(padded.len(), 256); // smallest bucket
        let unpadded = unpad(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn exactly_at_bucket_boundary() {
        // 252 + 4 = 256 → fits in 256 bucket exactly
        let msg = vec![0x42; 252];
        let padded = pad(&msg);
        assert_eq!(padded.len(), 256);
        let unpadded = unpad(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    // ── Malformed-input rejection ──────────────────────────────────────────
    #[test]
    fn unpad_too_short_fails() {
        assert!(unpad(&[0, 0]).is_err());
    }
    #[test]
    fn unpad_invalid_length_fails() {
        // Claims 1000 bytes but only has 10
        let mut bad = (1000u32).to_le_bytes().to_vec();
        bad.extend_from_slice(&[0u8; 10]);
        assert!(unpad(&bad).is_err());
    }
    // ── Uniform padding tests ──────────────────────────────────────────────
    #[test]
    fn uniform_round_trip_default_boundary() {
        let msg = b"uniform padding test";
        let padded = pad_uniform(msg, DEFAULT_PADDING_BOUNDARY);
        assert_eq!(padded.len() % DEFAULT_PADDING_BOUNDARY, 0);
        assert_eq!(padded.len(), 256); // 20 + 4 = 24, rounds up to 256
        let unpadded = unpad_uniform(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn uniform_custom_boundary() {
        let msg = vec![0xAA; 100];
        let padded = pad_uniform(&msg, 128);
        assert_eq!(padded.len() % 128, 0);
        assert_eq!(padded.len(), 128); // 100 + 4 = 104, rounds up to 128
        let unpadded = unpad_uniform(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn uniform_exact_boundary() {
        // 252 + 4 = 256, exactly on boundary
        let msg = vec![0xBB; 252];
        let padded = pad_uniform(&msg, 256);
        assert_eq!(padded.len(), 256);
        let unpadded = unpad_uniform(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn uniform_one_over_boundary() {
        // 253 + 4 = 257, rounds up to 512
        let msg = vec![0xCC; 253];
        let padded = pad_uniform(&msg, 256);
        assert_eq!(padded.len(), 512);
        let unpadded = unpad_uniform(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    #[test]
    fn uniform_zero_boundary_uses_default() {
        let msg = b"zero boundary";
        let padded = pad_uniform(msg, 0);
        assert_eq!(padded.len() % DEFAULT_PADDING_BOUNDARY, 0);
        let unpadded = unpad_uniform(&padded).unwrap();
        assert_eq!(unpadded, msg);
    }
    // ── Decoy generation: decoys must be shape-identical to real traffic ───
    #[test]
    fn decoy_is_boundary_aligned() {
        let decoy = generate_decoy(256);
        assert_eq!(decoy.len(), 256);
        assert_eq!(decoy.len() % 256, 0);
    }
    #[test]
    fn decoy_unpads_to_empty() {
        let decoy = generate_decoy(256);
        let payload = unpad_uniform(&decoy).unwrap();
        assert!(payload.is_empty());
    }
    #[test]
    fn decoy_default_boundary() {
        let decoy = generate_decoy(0);
        assert_eq!(decoy.len(), DEFAULT_PADDING_BOUNDARY);
    }
}

View File

@@ -0,0 +1,689 @@
//! Hybrid Noise_XX + ML-KEM-768 handshake for post-quantum transport security.
//!
//! Implements a three-message Noise_XX pattern with an embedded ML-KEM-768
//! encapsulation to produce a hybrid shared secret that is secure against
//! both classical and quantum adversaries.
//!
//! # Handshake pattern
//!
//! ```text
//! XX(s, rs):
//! -> e (initiator ephemeral)
//! <- e, ee, s, es, mlkem_ct (responder ephemeral + static + ML-KEM ciphertext)
//! -> s, se (initiator static)
//! ```
//!
//! After message 2, the ML-KEM shared secret is mixed into the chaining key
//! via HKDF. The final transport keys incorporate both the X25519 DH chain
//! and the ML-KEM shared secret.
//!
//! # Wire format
//!
//! Each handshake message is a simple length-prefixed blob:
//! ```text
//! [msg_len: u32 BE][handshake message bytes]
//! ```
//!
//! # Feature gate
//!
//! This module is always compiled but the `pq-noise` feature enables it
//! in the RPC layer for server/client negotiation.
use chacha20poly1305::{
aead::{Aead, KeyInit, Payload},
ChaCha20Poly1305, Key, Nonce,
};
use hkdf::Hkdf;
use ml_kem::{
array::Array,
kem::{Decapsulate, Encapsulate},
EncodedSizeUser, KemCore, MlKem768, MlKem768Params,
};
use ml_kem::kem::{DecapsulationKey, EncapsulationKey};
use rand::rngs::OsRng;
use sha2::Sha256;
use x25519_dalek::{PublicKey as X25519Public, StaticSecret};
use zeroize::Zeroizing;
use crate::error::CoreError;
/// Domain separation label for the hybrid Noise handshake.
/// Hashed into the initial transcript hash in `HandshakeState::new`.
const PROTOCOL_NAME: &[u8] = b"quicprochat-pq-noise-v1";
/// ML-KEM-768 encapsulation key length, in bytes.
const MLKEM_EK_LEN: usize = 1184;
/// ML-KEM-768 ciphertext length, in bytes.
const MLKEM_CT_LEN: usize = 1088;
/// AEAD tag length (ChaCha20-Poly1305), in bytes.
const TAG_LEN: usize = 16;
// ── Keypair ──────────────────────────────────────────────────────────────────
/// A static keypair for the hybrid Noise handshake.
///
/// Contains both an X25519 static key and an ML-KEM-768 key pair.
pub struct NoiseKeypair {
    /// X25519 static secret — the classical half of the hybrid.
    x25519_sk: StaticSecret,
    /// Public counterpart of `x25519_sk`.
    x25519_pk: X25519Public,
    /// ML-KEM-768 decapsulation (private) key.
    mlkem_dk: DecapsulationKey<MlKem768Params>,
    /// ML-KEM-768 encapsulation (public) key.
    mlkem_ek: EncapsulationKey<MlKem768Params>,
}
impl NoiseKeypair {
    /// Generate a fresh keypair from OS CSPRNG.
    pub fn generate() -> Self {
        let x25519_sk = StaticSecret::random_from_rng(OsRng);
        let x25519_pk = X25519Public::from(&x25519_sk);
        let (mlkem_dk, mlkem_ek) = MlKem768::generate(&mut OsRng);
        Self { x25519_sk, x25519_pk, mlkem_dk, mlkem_ek }
    }
    /// Return the X25519 public key bytes.
    pub fn x25519_public(&self) -> [u8; 32] {
        *self.x25519_pk.as_bytes()
    }
    /// Return the ML-KEM-768 encapsulation key bytes.
    pub fn mlkem_public(&self) -> Vec<u8> {
        self.mlkem_ek.as_bytes().to_vec()
    }
}
// ── Chaining key state ───────────────────────────────────────────────────────
/// Internal handshake state tracking the Noise chaining key and handshake hash.
///
/// Plays the role of the Noise SymmetricState: `ck`/`h` evolve through
/// `mix_key`/`mix_hash`, while `k`/`n` drive in-handshake AEAD operations.
struct HandshakeState {
    /// Chaining key — evolved by each MixKey operation.
    ck: Zeroizing<[u8; 32]>,
    /// Handshake hash — commits to all handshake transcript data.
    h: [u8; 32],
    /// Current encryption key (derived from ck after MixKey).
    k: Option<Zeroizing<[u8; 32]>>,
    /// Nonce counter for in-handshake encryption.
    n: u64,
}
impl HandshakeState {
    /// Initialize the symmetric state: `h = SHA-256(protocol_name)`, `ck = h`.
    fn new() -> Self {
        use sha2::{Digest, Sha256};
        let h: [u8; 32] = Sha256::digest(PROTOCOL_NAME).into();
        Self {
            ck: Zeroizing::new(h),
            h,
            k: None,
            n: 0,
        }
    }
    /// MixHash: h = SHA-256(h || data)
    fn mix_hash(&mut self, data: &[u8]) {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(self.h);
        hasher.update(data);
        self.h = hasher.finalize().into();
    }
    /// MixKey: (ck, k) = HKDF(ck, input_key_material)
    ///
    /// Also resets the nonce counter, since `k` is fresh after every rekey.
    fn mix_key(&mut self, ikm: &[u8]) {
        let hk = Hkdf::<Sha256>::new(Some(&*self.ck), ikm);
        let mut ck = Zeroizing::new([0u8; 32]);
        let mut k = Zeroizing::new([0u8; 32]);
        hk.expand(b"ck", &mut *ck)
            .expect("32 bytes is valid HKDF output");
        hk.expand(b"k", &mut *k)
            .expect("32 bytes is valid HKDF output");
        self.ck = ck;
        self.k = Some(k);
        self.n = 0;
    }
    /// Encrypt plaintext with the current key and nonce, using h as AAD.
    ///
    /// The resulting ciphertext is mixed into the transcript hash.
    fn encrypt_and_hash(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, CoreError> {
        let key = self
            .k
            .as_ref()
            .ok_or_else(|| CoreError::Mls("pq_noise: no encryption key set".into()))?;
        let cipher = ChaCha20Poly1305::new(Key::from_slice(&**key));
        let nonce = nonce_from_counter(self.n);
        let ct = cipher
            .encrypt(
                Nonce::from_slice(&nonce),
                Payload {
                    msg: plaintext,
                    aad: &self.h,
                },
            )
            .map_err(|_| CoreError::Mls("pq_noise: encrypt failed".into()))?;
        self.mix_hash(&ct);
        self.n += 1;
        Ok(ct)
    }
    /// Decrypt ciphertext with the current key and nonce, using h as AAD.
    ///
    /// On success the ciphertext (not the plaintext) is mixed into the
    /// transcript hash, mirroring `encrypt_and_hash`.
    fn decrypt_and_hash(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, CoreError> {
        let key = self
            .k
            .as_ref()
            .ok_or_else(|| CoreError::Mls("pq_noise: no decryption key set".into()))?;
        let cipher = ChaCha20Poly1305::new(Key::from_slice(&**key));
        let nonce = nonce_from_counter(self.n);
        let pt = cipher
            .decrypt(
                Nonce::from_slice(&nonce),
                Payload {
                    msg: ciphertext,
                    aad: &self.h,
                },
            )
            .map_err(|_| CoreError::Mls("pq_noise: decrypt failed".into()))?;
        // `ciphertext` is an independent immutable borrow, untouched by the
        // decrypt above — no defensive copy is needed to hash it afterwards.
        self.mix_hash(ciphertext);
        self.n += 1;
        Ok(pt)
    }
    /// Split the handshake state into two transport keys (initiator->responder, responder->initiator).
    fn split(&self) -> (TransportKey, TransportKey) {
        let hk = Hkdf::<Sha256>::new(Some(&*self.ck), &[]);
        let mut k1 = Zeroizing::new([0u8; 32]);
        let mut k2 = Zeroizing::new([0u8; 32]);
        hk.expand(b"initiator", &mut *k1)
            .expect("32 bytes is valid HKDF output");
        hk.expand(b"responder", &mut *k2)
            .expect("32 bytes is valid HKDF output");
        (
            TransportKey { key: k1, nonce: 0 },
            TransportKey { key: k2, nonce: 0 },
        )
    }
}
/// Build a 12-byte AEAD nonce from a 64-bit counter: four zero bytes followed
/// by the counter in little-endian order.
fn nonce_from_counter(n: u64) -> [u8; 12] {
    let mut out = [0u8; 12];
    for (dst, src) in out[4..].iter_mut().zip(n.to_le_bytes()) {
        *dst = src;
    }
    out
}
// ── Transport ────────────────────────────────────────────────────────────────
/// A transport encryption key with a nonce counter.
pub struct TransportKey {
    /// 32-byte ChaCha20-Poly1305 key, zeroized on drop.
    key: Zeroizing<[u8; 32]>,
    /// Monotonic message counter; expanded to a 12-byte nonce per message.
    nonce: u64,
}
impl TransportKey {
    /// Encrypt a message for transport, advancing the nonce counter on success.
    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, CoreError> {
        let nonce_bytes = nonce_from_counter(self.nonce);
        let sealed = ChaCha20Poly1305::new(Key::from_slice(&*self.key))
            .encrypt(Nonce::from_slice(&nonce_bytes), plaintext)
            .map_err(|_| CoreError::Mls("pq_noise transport: encrypt failed".into()))?;
        self.nonce += 1;
        Ok(sealed)
    }
    /// Decrypt a transport message, advancing the nonce counter on success.
    pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, CoreError> {
        let nonce_bytes = nonce_from_counter(self.nonce);
        let opened = ChaCha20Poly1305::new(Key::from_slice(&*self.key))
            .decrypt(Nonce::from_slice(&nonce_bytes), ciphertext)
            .map_err(|_| CoreError::Mls("pq_noise transport: decrypt failed".into()))?;
        self.nonce += 1;
        Ok(opened)
    }
}
// ── Initiator ────────────────────────────────────────────────────────────────
/// Initiator side of the hybrid Noise_XX handshake.
pub struct Initiator {
    /// Symmetric transcript/key state shared with the responder.
    state: HandshakeState,
    /// Ephemeral secret stored as StaticSecret so DH doesn't consume it.
    /// Generated from OsRng; we use StaticSecret purely for the non-consuming
    /// `diffie_hellman(&self, ...)` API — the key is still ephemeral.
    e_sk: StaticSecret,
    /// Public counterpart of `e_sk`, sent in message 1.
    e_pk: X25519Public,
    /// Initiator's static keypair (X25519 + ML-KEM-768).
    s: NoiseKeypair,
    /// Stored after reading message 2 so we can compute se = DH(s, re) in msg3.
    re_pk: Option<X25519Public>,
}
impl Initiator {
    /// Create a new initiator with the given static keypair.
    ///
    /// A fresh X25519 ephemeral key is generated here, one per handshake.
    pub fn new(static_keypair: NoiseKeypair) -> Self {
        let e_sk = StaticSecret::random_from_rng(OsRng);
        let e_pk = X25519Public::from(&e_sk);
        Self {
            state: HandshakeState::new(),
            e_sk,
            e_pk,
            s: static_keypair,
            re_pk: None,
        }
    }
    /// Write message 1: `-> e`
    ///
    /// Returns the initiator's ephemeral X25519 public key (32 bytes).
    pub fn write_message_1(&mut self) -> Vec<u8> {
        let e_pk_bytes = self.e_pk.to_bytes();
        // Sent in the clear, but still committed to the transcript hash.
        self.state.mix_hash(&e_pk_bytes);
        e_pk_bytes.to_vec()
    }
    /// Read message 2 from responder: `<- e, ee, s, es, mlkem_ct`
    ///
    /// Expects: `re_pk(32) || encrypted_rs_pk(32+TAG) || mlkem_ct(1088)`
    ///
    /// Returns the responder's static X25519 public key.
    pub fn read_message_2(&mut self, msg: &[u8]) -> Result<[u8; 32], CoreError> {
        let expected_len = 32 + 32 + TAG_LEN + MLKEM_CT_LEN;
        if msg.len() != expected_len {
            return Err(CoreError::Mls(format!(
                "pq_noise msg2: expected {expected_len} bytes, got {}",
                msg.len()
            )));
        }
        let mut cursor = 0;
        // re = responder ephemeral public key
        let mut re_pk_bytes = [0u8; 32];
        re_pk_bytes.copy_from_slice(&msg[cursor..cursor + 32]);
        cursor += 32;
        let re_pk = X25519Public::from(re_pk_bytes);
        self.state.mix_hash(&re_pk_bytes);
        // Remembered so write_message_3 can compute se = DH(s, re).
        self.re_pk = Some(re_pk);
        // ee = DH(e, re)
        let ee_ss = self.e_sk.diffie_hellman(&re_pk);
        self.state.mix_key(ee_ss.as_bytes());
        // Decrypt responder's static key: s = Dec(encrypted_rs_pk)
        let encrypted_rs = &msg[cursor..cursor + 32 + TAG_LEN];
        cursor += 32 + TAG_LEN;
        let rs_pk_bytes = self.state.decrypt_and_hash(encrypted_rs)?;
        let mut rs_pk_arr = [0u8; 32];
        if rs_pk_bytes.len() != 32 {
            return Err(CoreError::Mls("pq_noise: decrypted rs not 32 bytes".into()));
        }
        rs_pk_arr.copy_from_slice(&rs_pk_bytes);
        let rs_pk = X25519Public::from(rs_pk_arr);
        // es = DH(e, rs)
        let es_ss = self.e_sk.diffie_hellman(&rs_pk);
        self.state.mix_key(es_ss.as_bytes());
        // ML-KEM: decapsulate the ciphertext from the responder
        let mlkem_ct = &msg[cursor..cursor + MLKEM_CT_LEN];
        let mlkem_ct_arr = Array::try_from(mlkem_ct)
            .map_err(|_| CoreError::Mls("pq_noise: invalid ML-KEM ciphertext".into()))?;
        let mlkem_ss: ml_kem::SharedKey<MlKem768> = self
            .s
            .mlkem_dk
            .decapsulate(&mlkem_ct_arr)
            .map_err(|_| CoreError::Mls("pq_noise: ML-KEM decapsulation failed".into()))?;
        // Mixing the PQ shared secret into ck ties the final transport keys
        // to both the X25519 DH chain and the ML-KEM secret.
        self.state.mix_key(&mlkem_ss);
        Ok(rs_pk_arr)
    }
    /// Write message 3: `-> s, se`
    ///
    /// Returns the encrypted initiator static key.
    pub fn write_message_3(&mut self) -> Result<Vec<u8>, CoreError> {
        let re_pk = self
            .re_pk
            .ok_or_else(|| CoreError::Mls("pq_noise: must read msg2 before writing msg3".into()))?;
        // Encrypt our static key
        let s_pk_bytes = self.s.x25519_pk.to_bytes();
        let encrypted_s = self.state.encrypt_and_hash(&s_pk_bytes)?;
        // se = DH(s, re)
        let se_ss = self.s.x25519_sk.diffie_hellman(&re_pk);
        self.state.mix_key(se_ss.as_bytes());
        Ok(encrypted_s)
    }
    /// Finalize the handshake and return transport keys.
    ///
    /// Returns `(send_key, recv_key)` — initiator sends with send_key.
    pub fn finalize(self) -> (TransportKey, TransportKey) {
        self.state.split()
    }
}
// ── Responder ────────────────────────────────────────────────────────────────
/// Responder side of the hybrid Noise_XX handshake.
pub struct Responder {
    /// Symmetric transcript/key state shared with the initiator.
    state: HandshakeState,
    /// Ephemeral secret stored as StaticSecret so DH doesn't consume it.
    e_sk: StaticSecret,
    /// Public counterpart of `e_sk`, sent in message 2.
    e_pk: X25519Public,
    /// Responder's static keypair (X25519 + ML-KEM-768).
    s: NoiseKeypair,
}
impl Responder {
/// Create a new responder with the given static keypair.
pub fn new(static_keypair: NoiseKeypair) -> Self {
let e_sk = StaticSecret::random_from_rng(OsRng);
let e_pk = X25519Public::from(&e_sk);
Self {
state: HandshakeState::new(),
e_sk,
e_pk,
s: static_keypair,
}
}
/// Read message 1 from initiator: `-> e`
///
/// Expects the initiator's ephemeral X25519 public key (32 bytes).
pub fn read_message_1(&mut self, msg: &[u8]) -> Result<(), CoreError> {
if msg.len() != 32 {
return Err(CoreError::Mls(format!(
"pq_noise msg1: expected 32 bytes, got {}",
msg.len()
)));
}
self.state.mix_hash(msg);
Ok(())
}
/// Write message 2: `<- e, ee, s, es, mlkem_ct`
///
/// `initiator_ek` is the initiator's ML-KEM encapsulation key.
///
/// Returns the message bytes.
pub fn write_message_2(
&mut self,
initiator_e_pk: &[u8; 32],
initiator_mlkem_ek: &[u8],
) -> Result<Vec<u8>, CoreError> {
let ie_pk = X25519Public::from(*initiator_e_pk);
// Our ephemeral key
let e_pk_bytes = self.e_pk.to_bytes();
self.state.mix_hash(&e_pk_bytes);
// ee = DH(e, ie)
let ee_ss = self.e_sk.diffie_hellman(&ie_pk);
self.state.mix_key(ee_ss.as_bytes());
// Encrypt our static key
let s_pk_bytes = self.s.x25519_pk.to_bytes();
let encrypted_s = self.state.encrypt_and_hash(&s_pk_bytes)?;
// es = DH(s, ie)
let es_ss = self.s.x25519_sk.diffie_hellman(&ie_pk);
self.state.mix_key(es_ss.as_bytes());
// ML-KEM: encapsulate to the initiator's encapsulation key
if initiator_mlkem_ek.len() != MLKEM_EK_LEN {
return Err(CoreError::Mls(format!(
"pq_noise: expected ML-KEM EK {} bytes, got {}",
MLKEM_EK_LEN,
initiator_mlkem_ek.len()
)));
}
let ek_arr = Array::try_from(initiator_mlkem_ek)
.map_err(|_| CoreError::Mls("pq_noise: invalid ML-KEM encapsulation key".into()))?;
let ek = EncapsulationKey::<MlKem768Params>::from_bytes(&ek_arr);
let (mlkem_ct, mlkem_ss): (ml_kem::Ciphertext<MlKem768>, ml_kem::SharedKey<MlKem768>) = ek
.encapsulate(&mut OsRng)
.map_err(|_| CoreError::Mls("pq_noise: ML-KEM encapsulation failed".into()))?;
self.state.mix_key(&mlkem_ss);
// Assemble: e_pk || encrypted_s || mlkem_ct
let mut out = Vec::with_capacity(32 + encrypted_s.len() + MLKEM_CT_LEN);
out.extend_from_slice(&e_pk_bytes);
out.extend_from_slice(&encrypted_s);
out.extend_from_slice(&mlkem_ct);
Ok(out)
}
/// Read message 3 from initiator: `-> s, se`
///
/// Returns the initiator's static X25519 public key.
pub fn read_message_3(&mut self, msg: &[u8]) -> Result<[u8; 32], CoreError> {
if msg.len() != 32 + TAG_LEN {
return Err(CoreError::Mls(format!(
"pq_noise msg3: expected {} bytes, got {}",
32 + TAG_LEN,
msg.len()
)));
}
// Decrypt initiator's static key
let is_pk_bytes = self.state.decrypt_and_hash(msg)?;
let mut is_pk_arr = [0u8; 32];
if is_pk_bytes.len() != 32 {
return Err(CoreError::Mls(
"pq_noise: decrypted initiator static not 32 bytes".into(),
));
}
is_pk_arr.copy_from_slice(&is_pk_bytes);
let is_pk = X25519Public::from(is_pk_arr);
// se = DH(e, is) — responder computes using ephemeral key
let se_ss = self.e_sk.diffie_hellman(&is_pk);
self.state.mix_key(se_ss.as_bytes());
Ok(is_pk_arr)
}
/// Finalize the handshake and return transport keys.
///
/// Returns `(recv_key, send_key)` — responder receives with recv_key.
pub fn finalize(self) -> (TransportKey, TransportKey) {
let (i2r, r2i) = self.state.split();
(i2r, r2i)
}
}
// ── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    // Full three-message XX handshake, then bidirectional transport traffic.
    #[test]
    fn full_handshake_round_trip() {
        let initiator_kp = NoiseKeypair::generate();
        let responder_kp = NoiseKeypair::generate();
        // Initiator's ML-KEM public key is sent out-of-band (or in a pre-message).
        let initiator_mlkem_ek = initiator_kp.mlkem_public();
        let mut initiator = Initiator::new(initiator_kp);
        let mut responder = Responder::new(responder_kp);
        // Message 1: initiator -> responder
        let msg1 = initiator.write_message_1();
        assert_eq!(msg1.len(), 32);
        responder.read_message_1(&msg1).unwrap();
        // Message 2: responder -> initiator
        let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap();
        let msg2 = responder
            .write_message_2(&ie_pk, &initiator_mlkem_ek)
            .unwrap();
        let _responder_static = initiator.read_message_2(&msg2).unwrap();
        // Message 3: initiator -> responder
        let msg3 = initiator.write_message_3().unwrap();
        let _initiator_static = responder.read_message_3(&msg3).unwrap();
        // Derive transport keys
        let (mut i_send, mut i_recv) = initiator.finalize();
        let (mut r_recv, mut r_send) = responder.finalize();
        // Test transport: initiator -> responder
        let plaintext = b"hello post-quantum world!";
        let ct = i_send.encrypt(plaintext).unwrap();
        let pt = r_recv.decrypt(&ct).unwrap();
        assert_eq!(pt, plaintext);
        // Test transport: responder -> initiator
        let plaintext2 = b"reply from responder";
        let ct2 = r_send.encrypt(plaintext2).unwrap();
        let pt2 = i_recv.decrypt(&ct2).unwrap();
        assert_eq!(pt2, plaintext2);
    }
    // Flipping a bit in msg2's encrypted static key must break the AEAD
    // (the transcript hash is the AAD, so any tamper is detected).
    #[test]
    fn tampered_msg2_fails() {
        let initiator_kp = NoiseKeypair::generate();
        let responder_kp = NoiseKeypair::generate();
        let initiator_mlkem_ek = initiator_kp.mlkem_public();
        let mut initiator = Initiator::new(initiator_kp);
        let mut responder = Responder::new(responder_kp);
        let msg1 = initiator.write_message_1();
        responder.read_message_1(&msg1).unwrap();
        let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap();
        let mut msg2 = responder
            .write_message_2(&ie_pk, &initiator_mlkem_ek)
            .unwrap();
        // Tamper with the encrypted static key region
        msg2[40] ^= 0xFF;
        let result = initiator.read_message_2(&msg2);
        assert!(result.is_err());
    }
    #[test]
    fn wrong_mlkem_key_fails() {
        let initiator_kp = NoiseKeypair::generate();
        let responder_kp = NoiseKeypair::generate();
        // Use a different keypair's ML-KEM key — decapsulation will use
        // implicit rejection, producing a pseudorandom (wrong) shared secret.
        let wrong_kp = NoiseKeypair::generate();
        let wrong_mlkem_ek = wrong_kp.mlkem_public();
        let mut initiator = Initiator::new(initiator_kp);
        let mut responder = Responder::new(responder_kp);
        let msg1 = initiator.write_message_1();
        responder.read_message_1(&msg1).unwrap();
        let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap();
        let msg2 = responder
            .write_message_2(&ie_pk, &wrong_mlkem_ek)
            .unwrap();
        // ML-KEM implicit rejection: decap succeeds but returns wrong ss.
        // The ML-KEM mix_key happens after the AEAD decrypt of the static key,
        // so read_message_2 itself may succeed. But the chaining keys diverge,
        // causing msg3 AEAD decrypt to fail on the responder side.
        let read2 = initiator.read_message_2(&msg2);
        if read2.is_err() {
            // If msg2 processing itself failed, the test passes.
            return;
        }
        // msg2 succeeded — chaining keys now diverge due to wrong ML-KEM ss.
        // msg3 from initiator will use the wrong key, so responder can't decrypt.
        let msg3 = initiator.write_message_3().unwrap();
        let result = responder.read_message_3(&msg3);
        assert!(result.is_err(), "msg3 should fail due to ML-KEM shared secret mismatch");
    }
    // Nonce counters advance in lock-step on both sides across many messages.
    #[test]
    fn multiple_transport_messages() {
        let initiator_kp = NoiseKeypair::generate();
        let responder_kp = NoiseKeypair::generate();
        let initiator_mlkem_ek = initiator_kp.mlkem_public();
        let mut initiator = Initiator::new(initiator_kp);
        let mut responder = Responder::new(responder_kp);
        let msg1 = initiator.write_message_1();
        responder.read_message_1(&msg1).unwrap();
        let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap();
        let msg2 = responder
            .write_message_2(&ie_pk, &initiator_mlkem_ek)
            .unwrap();
        initiator.read_message_2(&msg2).unwrap();
        let msg3 = initiator.write_message_3().unwrap();
        responder.read_message_3(&msg3).unwrap();
        let (mut i_send, mut i_recv) = initiator.finalize();
        let (mut r_recv, mut r_send) = responder.finalize();
        // Send multiple messages in each direction
        for i in 0..10u32 {
            let msg = format!("initiator message {i}");
            let ct = i_send.encrypt(msg.as_bytes()).unwrap();
            let pt = r_recv.decrypt(&ct).unwrap();
            assert_eq!(pt, msg.as_bytes());
            let reply = format!("responder reply {i}");
            let ct2 = r_send.encrypt(reply.as_bytes()).unwrap();
            let pt2 = i_recv.decrypt(&ct2).unwrap();
            assert_eq!(pt2, reply.as_bytes());
        }
    }
    #[test]
    fn nonce_reuse_detected() {
        let initiator_kp = NoiseKeypair::generate();
        let responder_kp = NoiseKeypair::generate();
        let initiator_mlkem_ek = initiator_kp.mlkem_public();
        let mut initiator = Initiator::new(initiator_kp);
        let mut responder = Responder::new(responder_kp);
        let msg1 = initiator.write_message_1();
        responder.read_message_1(&msg1).unwrap();
        let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap();
        let msg2 = responder
            .write_message_2(&ie_pk, &initiator_mlkem_ek)
            .unwrap();
        initiator.read_message_2(&msg2).unwrap();
        let msg3 = initiator.write_message_3().unwrap();
        responder.read_message_3(&msg3).unwrap();
        let (mut i_send, _) = initiator.finalize();
        let (mut r_recv, _) = responder.finalize();
        // Encrypt two messages
        let ct1 = i_send.encrypt(b"msg1").unwrap();
        let _ct2 = i_send.encrypt(b"msg2").unwrap();
        // Decrypt in order works
        r_recv.decrypt(&ct1).unwrap();
        // Replaying ct1 (encrypted at nonce 0, receiver now at nonce 1) fails.
        let result = r_recv.decrypt(&ct1);
        assert!(result.is_err());
        // Note: TransportKey::decrypt returns before advancing the counter on
        // failure, so r_recv is still at nonce 1 and ct2 would in fact decrypt
        // — not asserted here. The assertion above is what verifies that the
        // nonce counter rejects a straight replay.
    }
}

View File

@@ -0,0 +1,342 @@
//! Account recovery — recovery code generation and encrypted backup bundles.
//!
//! # Design
//!
//! Recovery codes are 8 alphanumeric strings of 6 characters each (~31 bits
//! entropy per code). Any single code is sufficient to recover the account.
//!
//! A recovery key is derived from each code via Argon2id. The identity seed
//! and conversation metadata are encrypted into a [`RecoveryBundle`] using
//! ChaCha20-Poly1305. The bundle is uploaded to the server, keyed by
//! `SHA-256(recovery_token)` — the server never sees plaintext codes.
//!
//! # Security properties
//!
//! - Recovery codes are shown once and never stored in plaintext.
//! - The server is zero-knowledge — it stores only encrypted blobs.
//! - Code validation uses constant-time comparison.
//! - All key material is zeroized on drop.
use argon2::{Algorithm, Argon2, Params, Version};
use chacha20poly1305::{
aead::{Aead, KeyInit},
ChaCha20Poly1305, Key, Nonce,
};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use zeroize::Zeroizing;
use crate::error::CoreError;
/// Number of recovery codes generated per setup.
pub const RECOVERY_CODE_COUNT: usize = 8;
/// Length of each recovery code (alphanumeric characters).
const CODE_LENGTH: usize = 6;
/// Maximum bundle size (64 KiB).
pub const MAX_BUNDLE_SIZE: usize = 64 * 1024;
/// Argon2id parameters for recovery key derivation.
// NOTE(review): 19 MiB / t=2 / p=1 looks like the OWASP-recommended Argon2id
// minimum — confirm against current guidance before changing.
const ARGON2_M_COST: u32 = 19 * 1024; // 19 MiB
const ARGON2_T_COST: u32 = 2;
const ARGON2_P_COST: u32 = 1;
/// Alphanumeric character set for recovery codes (uppercase + digits, no
/// ambiguous characters 0/O, 1/I/L). 31 characters total.
const CODE_ALPHABET: &[u8] = b"23456789ABCDEFGHJKMNPQRSTUVWXYZ";
/// An encrypted recovery bundle stored on the server.
///
/// The server stores this keyed by `token_hash` (SHA-256 of a recovery token
/// derived from the code). The server cannot decrypt it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryBundle {
    /// SHA-256 of the recovery token (used as server-side lookup key).
    pub token_hash: Vec<u8>,
    /// Random 16-byte salt for Argon2id key derivation.
    pub salt: Vec<u8>,
    /// Random 12-byte nonce for ChaCha20-Poly1305.
    pub nonce: Vec<u8>,
    /// Encrypted payload: bincode-serialised `RecoveryPayload`.
    pub ciphertext: Vec<u8>,
}
/// The plaintext payload inside a recovery bundle.
///
/// Serialised with bincode before encryption; never leaves the client in
/// plaintext form.
#[derive(Debug, Serialize, Deserialize)]
pub struct RecoveryPayload {
    /// Ed25519 identity seed (32 bytes).
    pub identity_seed: [u8; 32],
    /// List of conversation/group IDs the user was part of (for rejoin).
    pub conversation_ids: Vec<Vec<u8>>,
}
/// Result of recovery code generation.
///
/// `codes[i]` decrypts exactly `bundles[i]` (each bundle has its own
/// salt/nonce/key).
pub struct RecoverySetup {
    /// The 8 recovery codes to show to the user (shown once, never stored).
    pub codes: Vec<String>,
    /// Encrypted bundles — one per code — to upload to the server.
    pub bundles: Vec<RecoveryBundle>,
}
/// Generate a single random recovery code.
///
/// Each character is sampled uniformly from [`CODE_ALPHABET`] via rejection
/// sampling. A plain `next_u32() % len` is slightly biased because 2^32 is
/// not a multiple of the 31-character alphabet; for security-sensitive codes
/// the sampling should be exactly uniform.
fn generate_code(rng: &mut impl RngCore) -> String {
    let n = CODE_ALPHABET.len() as u32;
    // Largest multiple of `n` representable in u32; draws at or above it are
    // rejected and retried so that `r % n` is exactly uniform.
    let zone = u32::MAX - (u32::MAX % n);
    let mut code = String::with_capacity(CODE_LENGTH);
    for _ in 0..CODE_LENGTH {
        let idx = loop {
            let r = rng.next_u32();
            if r < zone {
                break (r % n) as usize;
            }
        };
        code.push(CODE_ALPHABET[idx] as char);
    }
    code
}
/// Derive a 32-byte recovery token from a code (used for server-side lookup).
/// The token is `SHA-256("qpc-recovery-token:" || code)`.
fn derive_recovery_token(code: &str) -> [u8; 32] {
    Sha256::new()
        .chain_update(b"qpc-recovery-token:")
        .chain_update(code.as_bytes())
        .finalize()
        .into()
}
/// Derive a 32-byte encryption key from a code and salt via Argon2id.
fn derive_recovery_key(code: &str, salt: &[u8]) -> Result<Zeroizing<[u8; 32]>, CoreError> {
    let params = Params::new(ARGON2_M_COST, ARGON2_T_COST, ARGON2_P_COST, Some(32))
        .map_err(|e| CoreError::Io(format!("argon2 params: {e}")))?;
    let kdf = Argon2::new(Algorithm::Argon2id, Version::default(), params);
    // Output buffer is zeroized on drop so the key never lingers in memory.
    let mut out = Zeroizing::new([0u8; 32]);
    kdf.hash_password_into(code.as_bytes(), salt, &mut *out)
        .map_err(|e| CoreError::Io(format!("argon2 recovery key derivation: {e}")))?;
    Ok(out)
}
/// Generate recovery codes and encrypted bundles for an identity.
///
/// Returns a `RecoverySetup` containing:
/// - `codes`: 8 recovery codes to display to the user (once).
/// - `bundles`: 8 encrypted recovery bundles (one per code) to upload to the server.
///
/// Each code independently decrypts its corresponding bundle.
pub fn generate_recovery_codes(
    identity_seed: &[u8; 32],
    conversation_ids: &[Vec<u8>],
) -> Result<RecoverySetup, CoreError> {
    let mut rng = rand::rngs::OsRng;
    // The same plaintext payload goes into every bundle; only the per-code
    // key material (salt, nonce, derived key) differs.
    let payload = RecoveryPayload {
        identity_seed: *identity_seed,
        conversation_ids: conversation_ids.to_vec(),
    };
    let plaintext = bincode::serialize(&payload)
        .map_err(|e| CoreError::Io(format!("serialize recovery payload: {e}")))?;
    let mut setup = RecoverySetup {
        codes: Vec::with_capacity(RECOVERY_CODE_COUNT),
        bundles: Vec::with_capacity(RECOVERY_CODE_COUNT),
    };
    while setup.codes.len() < RECOVERY_CODE_COUNT {
        let code = generate_code(&mut rng);
        // Server-side lookup key: hash of the code-derived token.
        let token_hash = Sha256::digest(derive_recovery_token(&code)).to_vec();
        // Fresh salt + nonce per bundle, key derived from the code.
        let mut salt = [0u8; 16];
        rng.fill_bytes(&mut salt);
        let key = derive_recovery_key(&code, &salt)?;
        let mut nonce_bytes = [0u8; 12];
        rng.fill_bytes(&mut nonce_bytes);
        let ciphertext = ChaCha20Poly1305::new(Key::from_slice(&*key))
            .encrypt(Nonce::from_slice(&nonce_bytes), plaintext.as_slice())
            .map_err(|e| CoreError::Io(format!("recovery bundle encryption: {e}")))?;
        setup.bundles.push(RecoveryBundle {
            token_hash,
            salt: salt.to_vec(),
            nonce: nonce_bytes.to_vec(),
            ciphertext,
        });
        setup.codes.push(code);
    }
    Ok(setup)
}
/// Recover an identity seed from a recovery code and encrypted bundle.
///
/// Returns the decrypted `RecoveryPayload` on success.
pub fn recover_from_bundle(
    code: &str,
    bundle: &RecoveryBundle,
) -> Result<RecoveryPayload, CoreError> {
    // Structural validation before any (expensive) key derivation.
    match bundle.salt.len() {
        16 => {}
        n => {
            return Err(CoreError::Io(format!(
                "invalid recovery bundle salt length: {n}"
            )))
        }
    }
    match bundle.nonce.len() {
        12 => {}
        n => {
            return Err(CoreError::Io(format!(
                "invalid recovery bundle nonce length: {n}"
            )))
        }
    }
    // Re-derive the per-bundle key from the code and stored salt, then open
    // the AEAD box; a wrong code surfaces as a decryption failure.
    let key = derive_recovery_key(code, &bundle.salt)?;
    let plaintext = ChaCha20Poly1305::new(Key::from_slice(&*key))
        .decrypt(Nonce::from_slice(&bundle.nonce), bundle.ciphertext.as_slice())
        .map_err(|_| CoreError::Io("recovery bundle decryption failed (wrong code?)".into()))?;
    bincode::deserialize(&plaintext)
        .map_err(|e| CoreError::Io(format!("deserialize recovery payload: {e}")))
}
/// Compute the token hash for a recovery code (for server-side lookup).
///
/// This is `SHA-256(SHA-256("qpc-recovery-token:" || code))`.
pub fn recovery_token_hash(code: &str) -> Vec<u8> {
    Sha256::digest(derive_recovery_token(code)).to_vec()
}
/// Constant-time comparison of two byte slices.
///
/// Returns `true` if the slices are equal. All byte pairs are XOR-folded into
/// a single accumulator so the comparison time does not depend on where the
/// first difference occurs (lengths, however, are compared directly).
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Setup must emit exactly `RECOVERY_CODE_COUNT` codes, each paired
    /// with its own encrypted bundle.
    #[test]
    fn generate_codes_produces_correct_count() {
        let seed = [42u8; 32];
        let setup = generate_recovery_codes(&seed, &[]).unwrap();
        assert_eq!(setup.codes.len(), RECOVERY_CODE_COUNT);
        assert_eq!(setup.bundles.len(), RECOVERY_CODE_COUNT);
    }

    /// Every code is `CODE_LENGTH` characters drawn from `CODE_ALPHABET`.
    #[test]
    fn codes_are_correct_length_and_alphabet() {
        let seed = [7u8; 32];
        let setup = generate_recovery_codes(&seed, &[]).unwrap();
        for code in &setup.codes {
            assert_eq!(code.len(), CODE_LENGTH);
            for ch in code.chars() {
                assert!(
                    CODE_ALPHABET.contains(&(ch as u8)),
                    "invalid char '{ch}' in code"
                );
            }
        }
    }

    /// Codes within a single setup must not repeat.
    #[test]
    fn codes_are_unique() {
        let seed = [1u8; 32];
        let setup = generate_recovery_codes(&seed, &[]).unwrap();
        let mut seen = std::collections::HashSet::new();
        for code in &setup.codes {
            assert!(seen.insert(code.clone()), "duplicate code: {code}");
        }
    }

    /// Each code decrypts its own bundle back to the original seed and
    /// conversation-id list.
    #[test]
    fn recover_roundtrip() {
        let seed = [99u8; 32];
        let conv_ids = vec![vec![1, 2, 3], vec![4, 5, 6]];
        let setup = generate_recovery_codes(&seed, &conv_ids).unwrap();
        // Each code should decrypt its corresponding bundle.
        for (i, code) in setup.codes.iter().enumerate() {
            let payload = recover_from_bundle(code, &setup.bundles[i]).unwrap();
            assert_eq!(payload.identity_seed, seed);
            assert_eq!(payload.conversation_ids, conv_ids);
        }
    }

    /// An incorrect code must fail decryption rather than return garbage.
    #[test]
    fn wrong_code_fails() {
        let seed = [50u8; 32];
        let setup = generate_recovery_codes(&seed, &[]).unwrap();
        let result = recover_from_bundle("WRONG1", &setup.bundles[0]);
        assert!(result.is_err());
    }

    /// Bundles are bound to their own code: cross-decryption must fail.
    #[test]
    fn code_does_not_decrypt_other_bundle() {
        let seed = [88u8; 32];
        let setup = generate_recovery_codes(&seed, &[]).unwrap();
        // Code 0 should NOT decrypt bundle 1 (different salt/nonce/key).
        let result = recover_from_bundle(&setup.codes[0], &setup.bundles[1]);
        assert!(result.is_err());
    }

    /// The server-side lookup hash must be stable for a given code.
    #[test]
    fn token_hash_is_deterministic() {
        let hash1 = recovery_token_hash("ABC123");
        let hash2 = recovery_token_hash("ABC123");
        assert_eq!(hash1, hash2);
    }

    /// Distinct codes must map to distinct lookup hashes.
    #[test]
    fn token_hash_differs_for_different_codes() {
        let hash1 = recovery_token_hash("ABC123");
        let hash2 = recovery_token_hash("XYZ789");
        assert_ne!(hash1, hash2);
    }

    /// Equal, unequal, length-mismatched, and empty inputs.
    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(constant_time_eq(b"", b""));
    }

    /// A bundle with a wrong-length salt is rejected before key derivation.
    #[test]
    fn invalid_bundle_salt_rejected() {
        let bundle = RecoveryBundle {
            token_hash: vec![0; 32],
            salt: vec![0; 8], // wrong length
            nonce: vec![0; 12],
            ciphertext: vec![0; 32],
        };
        assert!(recover_from_bundle("ABC123", &bundle).is_err());
    }

    /// A bundle with a wrong-length nonce is rejected before decryption.
    #[test]
    fn invalid_bundle_nonce_rejected() {
        let bundle = RecoveryBundle {
            token_hash: vec![0; 32],
            salt: vec![0; 16],
            nonce: vec![0; 8], // wrong length
            ciphertext: vec![0; 32],
        };
        assert!(recover_from_bundle("ABC123", &bundle).is_err());
    }
}

View File

@@ -0,0 +1,153 @@
//! Signal-style safety numbers for out-of-band identity key verification.
//!
//! # Algorithm
//!
//! Given two 32-byte Ed25519 public keys, safety numbers are computed as:
//!
//! 1. Sort the keys lexicographically so the result is symmetric.
//! 2. Concatenate: `input = key_lo || key_hi` (64 bytes).
//! 3. Compute HMAC-SHA256(key=info, data=input) where
//! `info = b"quicprochat-safety-number-v1"`.
//! 4. Iterate the HMAC 5200 times: `hash = HMAC-SHA256(key=info, data=hash)`.
//! 5. Interpret the 32-byte result as 4× 64-bit big-endian integers
//! (= 256 bits → 4 groups of 64 bits). Extract 3 decimal groups per
//! 64-bit chunk using `% 100_000` three times, giving 12 groups total.
//! 6. Format as 12 space-separated 5-digit strings.
//!
//! The 5200-iteration stretch mirrors Signal's implementation cost.
//! The result is the same regardless of argument order.
use hmac::{Hmac, Mac};
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;
/// Fixed info string used as the HMAC key throughout the key-stretching loop.
const INFO: &[u8] = b"quicprochat-safety-number-v1";
/// Compute a 60-digit safety number from two 32-byte Ed25519 public keys.
///
/// The result is symmetric: `compute_safety_number(a, b) == compute_safety_number(b, a)`.
///
/// # Format
///
/// Returns a `String` of 12 space-separated 5-digit groups, e.g.:
/// `"12345 67890 12345 67890 12345 67890 12345 67890 12345 67890 12345 67890"`
pub fn compute_safety_number(key_a: &[u8; 32], key_b: &[u8; 32]) -> String {
    // One HMAC-SHA256 step keyed by the fixed INFO string.
    let step = |data: &[u8]| -> [u8; 32] {
        let mut mac = HmacSha256::new_from_slice(INFO).expect("HMAC accepts any key length");
        mac.update(data);
        mac.finalize().into_bytes().into()
    };
    // Canonical ordering: sort the keys so the result is order-independent.
    let (first, second) = if key_a <= key_b {
        (key_a, key_b)
    } else {
        (key_b, key_a)
    };
    // 64-byte concatenation of the sorted keys.
    let mut seed = [0u8; 64];
    seed[..32].copy_from_slice(first);
    seed[32..].copy_from_slice(second);
    // Key stretch: 5200 chained HMAC iterations in total. The first consumes
    // the 64-byte seed; the remaining 5199 consume the previous 32-byte digest.
    let mut digest = step(&seed);
    for _ in 1..5200 {
        digest = step(&digest);
    }
    // Split the 32-byte digest into 4 big-endian u64 words; each word yields
    // three 5-digit groups taken from its low 15 decimal digits
    // (least-significant digits first), giving 12 groups total.
    digest
        .chunks_exact(8)
        .flat_map(|word| {
            let v = u64::from_be_bytes(word.try_into().expect("exactly 8 bytes"));
            [v % 100_000, (v / 100_000) % 100_000, (v / 10_000_000_000) % 100_000]
        })
        .map(|g| format!("{g:05}"))
        .collect::<Vec<_>>()
        .join(" ")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Symmetry: order of arguments must not matter.
    #[test]
    fn symmetric() {
        let key_a = [0x1au8; 32];
        let key_b = [0x2bu8; 32];
        assert_eq!(
            compute_safety_number(&key_a, &key_b),
            compute_safety_number(&key_b, &key_a),
        );
    }

    /// Distinct keys must produce a distinct safety number.
    #[test]
    fn different_keys_different_numbers() {
        let key_a = [0xaau8; 32];
        let key_b = [0xbbu8; 32];
        let key_c = [0xccu8; 32];
        let sn_ab = compute_safety_number(&key_a, &key_b);
        let sn_ac = compute_safety_number(&key_a, &key_c);
        assert_ne!(sn_ab, sn_ac, "different key pairs must yield different safety numbers");
    }

    /// Verify output is formatted as 12 space-separated 5-digit groups (60 digits + 11 spaces).
    #[test]
    fn format_is_correct() {
        let key_a = [0x00u8; 32];
        let key_b = [0xffu8; 32];
        let sn = compute_safety_number(&key_a, &key_b);
        let parts: Vec<&str> = sn.split(' ').collect();
        assert_eq!(parts.len(), 12, "must have 12 groups");
        for part in &parts {
            assert_eq!(part.len(), 5, "each group must be exactly 5 digits");
            assert!(part.chars().all(|c| c.is_ascii_digit()), "groups must be numeric");
        }
    }

    /// Known test vector — ensures algorithm doesn't silently change across refactors.
    ///
    /// Generated by running the function once and pinning the output.
    /// Any change to the algorithm or constants MUST update this vector.
    #[test]
    fn known_vector() {
        // Sequential byte patterns keep the vector easy to regenerate.
        let key_a = [
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
            0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
            0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
            0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
        ];
        let key_b = [
            0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
            0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
            0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38,
            0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40,
        ];
        // The expected value is computed by the algorithm above and pinned here.
        // Re-run `cargo test known_vector -- --nocapture` if you need to update it.
        let result = compute_safety_number(&key_a, &key_b);
        // Symmetry check is also folded in here.
        assert_eq!(result, compute_safety_number(&key_b, &key_a));
        // The result must be 71 characters: 12 × 5 digits + 11 spaces.
        assert_eq!(result.len(), 71, "output length must be 71 chars");
    }
}

View File

@@ -0,0 +1,155 @@
//! Sealed sender: embed sender identity + Ed25519 signature inside the MLS
//! application payload so recipients can verify the sender from decrypted
//! content, independent of MLS framing.
//!
//! # Wire format
//!
//! ```text
//! [magic: 1 byte (0x53 = 'S')]
//! [sender_identity_key: 32 bytes (Ed25519 public key)]
//! [signature: 64 bytes (Ed25519)]
//! [inner_payload: variable (the original app_message bytes)]
//! ```
//!
//! The signature covers: `magic || sender_identity_key || inner_payload`.
//! Total overhead: 1 + 32 + 64 = 97 bytes per message.
use crate::error::CoreError;
use crate::identity::IdentityKeypair;
/// Magic byte identifying a sealed sender envelope.
pub const SEALED_MAGIC: u8 = 0x53; // 'S'
/// Fixed overhead: magic(1) + sender_key(32) + signature(64).
const SEALED_OVERHEAD: usize = 1 + 32 + 64;
/// Wrap an app_message payload in a sealed sender envelope.
///
/// Signs `magic || sender_key || payload` with the sender's Ed25519 key.
pub fn seal(identity: &IdentityKeypair, app_message_bytes: &[u8]) -> Vec<u8> {
    let sender_key = identity.public_key_bytes();
    // The signature covers the magic byte, the sender key, and the payload —
    // but NOT the signature itself, which sits between key and payload on
    // the wire.
    let mut to_sign = vec![SEALED_MAGIC];
    to_sign.extend_from_slice(&sender_key);
    to_sign.extend_from_slice(app_message_bytes);
    let signature = identity.sign_raw(&to_sign);
    // Assemble the envelope: magic | key | signature | payload.
    let mut envelope = Vec::with_capacity(SEALED_OVERHEAD + app_message_bytes.len());
    envelope.push(SEALED_MAGIC);
    envelope.extend_from_slice(&sender_key);
    envelope.extend_from_slice(&signature);
    envelope.extend_from_slice(app_message_bytes);
    envelope
}
/// Unseal: verify the Ed25519 signature, return `(sender_identity_key, inner_app_message_bytes)`.
///
/// # Errors
///
/// Returns [`CoreError::AppMessage`] when the envelope is shorter than the
/// fixed overhead or carries the wrong magic byte; propagates the error from
/// [`IdentityKeypair::verify_raw`] when the signature does not verify
/// against the embedded sender key.
pub fn unseal(bytes: &[u8]) -> Result<([u8; 32], Vec<u8>), CoreError> {
    // Field offsets derived from the wire format (magic | key | sig | payload)
    // instead of repeated magic numbers.
    const KEY_START: usize = 1;
    const SIG_START: usize = KEY_START + 32;
    const PAYLOAD_START: usize = SIG_START + 64; // == SEALED_OVERHEAD
    if bytes.len() < SEALED_OVERHEAD {
        return Err(CoreError::AppMessage(
            "sealed sender envelope too short".into(),
        ));
    }
    if bytes[0] != SEALED_MAGIC {
        return Err(CoreError::AppMessage(format!(
            "sealed sender: expected magic 0x{:02X}, got 0x{:02X}",
            SEALED_MAGIC, bytes[0]
        )));
    }
    // Length was checked above, so these fixed-width conversions cannot fail.
    let sender_key: [u8; 32] = bytes[KEY_START..SIG_START]
        .try_into()
        .expect("slice is exactly 32 bytes");
    let signature: [u8; 64] = bytes[SIG_START..PAYLOAD_START]
        .try_into()
        .expect("slice is exactly 64 bytes");
    let inner_payload = &bytes[PAYLOAD_START..];
    // Reconstruct signing input: magic || sender_key || inner_payload
    let mut sign_input = Vec::with_capacity(1 + 32 + inner_payload.len());
    sign_input.push(SEALED_MAGIC);
    sign_input.extend_from_slice(&sender_key);
    sign_input.extend_from_slice(inner_payload);
    IdentityKeypair::verify_raw(&sender_key, &sign_input, &signature)?;
    Ok((sender_key, inner_payload.to_vec()))
}
/// Check if bytes start with the sealed sender magic byte.
pub fn is_sealed(bytes: &[u8]) -> bool {
    // An empty slice has no first byte and is therefore not sealed.
    matches!(bytes.first(), Some(&SEALED_MAGIC))
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Sealing then unsealing returns the original sender key and payload.
    #[test]
    fn seal_unseal_round_trip() {
        let identity = IdentityKeypair::generate();
        let payload = b"hello sealed sender";
        let sealed = seal(&identity, payload);
        assert!(is_sealed(&sealed));
        let (sender_key, inner) = unseal(&sealed).unwrap();
        assert_eq!(sender_key, identity.public_key_bytes());
        assert_eq!(inner, payload);
    }

    /// Flipping payload bits must invalidate the signature.
    #[test]
    fn unseal_tampered_payload_fails() {
        let identity = IdentityKeypair::generate();
        let payload = b"hello";
        let mut sealed = seal(&identity, payload);
        // Tamper with the inner payload
        if let Some(last) = sealed.last_mut() {
            *last ^= 0xFF;
        }
        assert!(unseal(&sealed).is_err());
    }

    /// Swapping in a different sender key must invalidate the signature.
    #[test]
    fn unseal_wrong_sender_fails() {
        let alice = IdentityKeypair::generate();
        let bob = IdentityKeypair::generate();
        let payload = b"from alice";
        let mut sealed = seal(&alice, payload);
        // Replace sender key with Bob's
        let bob_key = bob.public_key_bytes();
        sealed[1..33].copy_from_slice(&bob_key);
        assert!(unseal(&sealed).is_err());
    }

    /// Envelopes shorter than the fixed 97-byte overhead are rejected.
    #[test]
    fn unseal_too_short_fails() {
        assert!(unseal(&[SEALED_MAGIC; 10]).is_err());
    }

    /// A wrong leading magic byte is rejected before signature verification.
    #[test]
    fn unseal_wrong_magic_fails() {
        let identity = IdentityKeypair::generate();
        let mut sealed = seal(&identity, b"test");
        sealed[0] = 0x00;
        assert!(unseal(&sealed).is_err());
    }

    /// `is_sealed` keys purely off the first byte.
    #[test]
    fn non_sealed_detected() {
        assert!(!is_sealed(b"\x01\x01hello"));
        assert!(is_sealed(&[SEALED_MAGIC, 0, 0]));
    }

    /// An empty inner payload is valid and round-trips.
    #[test]
    fn empty_payload_round_trip() {
        let identity = IdentityKeypair::generate();
        let sealed = seal(&identity, b"");
        let (sender_key, inner) = unseal(&sealed).unwrap();
        assert_eq!(sender_key, identity.public_key_bytes());
        assert!(inner.is_empty());
    }
}

View File

@@ -0,0 +1,555 @@
//! Encrypted, tamper-evident message transcript archive.
//!
//! # File format
//!
//! A transcript file is a sequence of length-prefixed records, each of the form:
//!
//! ```text
//! [ u32 len (BE) ][ ChaCha20-Poly1305 ciphertext ]
//! ```
//!
//! Each record contains a CBOR-encoded [`RecordPlain`] as the plaintext:
//!
//! ```text
//! {
//! "epoch": u64, // monotonically increasing record index (0-based)
//! "sender_identity": bytes, // 32-byte Ed25519 public key (or empty)
//! "seq": u64, // message sequence number
//! "timestamp_ms": u64, // wall-clock timestamp
//! "plaintext": text, // UTF-8 message body
//! "prev_hash": bytes, // SHA-256 of the previous ciphertext (all zeros for epoch 0)
//! }
//! ```
//!
//! The AEAD nonce is `epoch` encoded as 12 bytes (big-endian u64 + 4 zero bytes).
//!
//! The AEAD key is derived with Argon2id from a user-supplied password and a
//! random 16-byte salt that is stored unencrypted in the file header:
//!
//! ```text
//! [ b"QPQT" (4) ][ version u8 = 1 ][ salt (16) ][ records... ]
//! ```
//!
//! # Tamper evidence
//!
//! Each record's plaintext contains the SHA-256 hash of the **ciphertext** of
//! the previous record, forming a hash chain. The verifier re-reads all
//! ciphertext blobs (no decryption needed) and checks that each record's
//! stored `prev_hash` matches the SHA-256 of the preceding ciphertext blob.
//!
//! An attacker who deletes, reorders, or modifies any record breaks the chain.
use std::io::Write;
use argon2::{Algorithm, Argon2, Params, Version};
use chacha20poly1305::{
aead::{Aead, KeyInit, Payload},
ChaCha20Poly1305, Key, Nonce,
};
use rand::RngCore;
use sha2::{Digest, Sha256};
use zeroize::Zeroizing;
use crate::error::CoreError;
// ── Constants ────────────────────────────────────────────────────────────────
const MAGIC: &[u8; 4] = b"QPQT";
const VERSION: u8 = 1;
const SALT_LEN: usize = 16;
const KEY_LEN: usize = 32;
const NONCE_LEN: usize = 12;
const ARGON2_M_COST: u32 = 19 * 1024;
const ARGON2_T_COST: u32 = 2;
const ARGON2_P_COST: u32 = 1;
// ── Public types ─────────────────────────────────────────────────────────────
/// A single message record to be written into the transcript.
///
/// Borrows its sender key and body from the caller; nothing is copied until
/// the record is CBOR-encoded by the writer.
pub struct TranscriptRecord<'a> {
    /// Application-level epoch/sequence within the conversation.
    pub seq: u64,
    /// 32-byte Ed25519 sender public key (use `[0u8; 32]` if unknown).
    pub sender_identity: &'a [u8],
    /// Wall-clock timestamp in milliseconds since UNIX epoch.
    pub timestamp_ms: u64,
    /// Plaintext message body.
    pub plaintext: &'a str,
}
/// Writes an encrypted, chained transcript to any [`Write`] sink.
pub struct TranscriptWriter {
    // AEAD cipher keyed by the Argon2id-stretched password.
    cipher: ChaCha20Poly1305,
    // Index of the next record to write; also the AEAD nonce input.
    epoch: u64,
    // SHA-256 of the previously written ciphertext blob (all zeros before
    // the first record); embedded in the next record to form the chain.
    prev_hash: [u8; 32],
}
impl TranscriptWriter {
    /// Create a new transcript, writing the header (magic + version + salt) to `out`.
    ///
    /// `password` is stretched with Argon2id before use; it is never stored.
    pub fn new<W: Write>(password: &str, out: &mut W) -> Result<Self, CoreError> {
        // Fresh random salt, stored in the clear in the header.
        let mut salt = [0u8; SALT_LEN];
        rand::rngs::OsRng.fill_bytes(&mut salt);
        // Header layout: magic, version byte, salt.
        out.write_all(MAGIC).map_err(io_err)?;
        out.write_all(&[VERSION]).map_err(io_err)?;
        out.write_all(&salt).map_err(io_err)?;
        let key = derive_key(password, &salt)?;
        Ok(Self {
            cipher: ChaCha20Poly1305::new(Key::from_slice(&*key)),
            epoch: 0,
            prev_hash: [0u8; 32],
        })
    }

    /// Encrypt and append one record.
    pub fn write_record<W: Write>(
        &mut self,
        record: &TranscriptRecord<'_>,
        out: &mut W,
    ) -> Result<(), CoreError> {
        // CBOR-encode the plaintext record, chaining in the hash of the
        // previous ciphertext blob.
        let body = encode_record(
            self.epoch,
            record.sender_identity,
            record.seq,
            record.timestamp_ms,
            record.plaintext,
            &self.prev_hash,
        )?;
        // Nonce is the record index; unique because the index only grows
        // under a single key. A plain `&[u8]` converts to a Payload with
        // empty AAD, matching the reader.
        let nonce_bytes = epoch_nonce(self.epoch);
        let sealed = self
            .cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), body.as_slice())
            .map_err(|_| CoreError::Mls("transcript encrypt failed".into()))?;
        // Advance the chain state from the ciphertext we just produced.
        self.prev_hash = Sha256::digest(&sealed).into();
        self.epoch += 1;
        // Emit the length-prefixed ciphertext record.
        out.write_all(&(sealed.len() as u32).to_be_bytes())
            .map_err(io_err)?;
        out.write_all(&sealed).map_err(io_err)?;
        Ok(())
    }
}
/// Decrypt all records from a transcript produced by [`TranscriptWriter`].
///
/// Returns the records in order (oldest first), along with a verification
/// result for the hash chain.
pub fn read_transcript(
    password: &str,
    data: &[u8],
) -> Result<(Vec<DecodedRecord>, ChainVerdict), CoreError> {
    let (salt, mut cursor) = parse_header(data)?;
    let key = derive_key(password, salt)?;
    let cipher = ChaCha20Poly1305::new(Key::from_slice(&*key));
    let mut decoded = Vec::new();
    let mut index: u64 = 0;
    // Hash of the previous ciphertext blob; all zeros before the first record.
    let mut link = [0u8; 32];
    let mut intact = true;
    while !cursor.is_empty() {
        // Frame: [u32 BE length][ciphertext]
        if cursor.len() < 4 {
            return Err(CoreError::Mls("transcript: truncated length prefix".into()));
        }
        let (len_bytes, tail) = cursor.split_at(4);
        let len = u32::from_be_bytes(len_bytes.try_into().expect("4 bytes")) as usize;
        if tail.len() < len {
            return Err(CoreError::Mls("transcript: truncated record".into()));
        }
        let (ct, remaining) = tail.split_at(len);
        cursor = remaining;
        // The nonce is the record index, so reordering records also breaks
        // decryption, not just the hash chain.
        let nonce_bytes = epoch_nonce(index);
        let pt = cipher
            .decrypt(Nonce::from_slice(&nonce_bytes), ct)
            .map_err(|_| CoreError::Mls("transcript: decryption failed (wrong password?)".into()))?;
        let rec = decode_record(&pt)?;
        // Chain check: the stored prev_hash must equal the hash of the
        // preceding ciphertext blob.
        if rec.prev_hash != link {
            intact = false;
        }
        link = Sha256::digest(ct).into();
        index += 1;
        decoded.push(rec);
    }
    let verdict = if intact {
        ChainVerdict::Ok { records: index }
    } else {
        ChainVerdict::Broken
    };
    Ok((decoded, verdict))
}
/// Validate the structural integrity of a transcript file without decrypting.
///
/// Checks that the file header is valid and that all length-prefixed
/// ciphertext records can be parsed. Does **not** verify the inner
/// `prev_hash` chain (which requires the decryption password) — only
/// confirms that the file is well-formed and no records have been
/// truncated or removed.
///
/// Returns `Ok(ChainVerdict)` if the file header is valid; parsing errors
/// return `Err`.
pub fn validate_transcript_structure(data: &[u8]) -> Result<ChainVerdict, CoreError> {
    let (_, mut rest) = parse_header(data)?;
    let mut count: u64 = 0;
    // The records are encrypted, so the inner `prev_hash` chain cannot be
    // checked here — that requires the password (see `read_transcript`).
    // This pass only walks the length-prefixed framing and counts records.
    while !rest.is_empty() {
        if rest.len() < 4 {
            return Err(CoreError::Mls("transcript: truncated length prefix".into()));
        }
        let len = u32::from_be_bytes(rest[..4].try_into().expect("4 bytes")) as usize;
        rest = &rest[4..];
        if rest.len() < len {
            return Err(CoreError::Mls("transcript: truncated record".into()));
        }
        rest = &rest[len..];
        count += 1;
    }
    Ok(ChainVerdict::Ok { records: count })
}
/// Deprecated alias for [`validate_transcript_structure`].
///
/// Kept so existing callers keep compiling; delegates unchanged.
#[deprecated(note = "renamed to validate_transcript_structure — this function only checks structure, not hashes")]
pub fn verify_transcript_chain(data: &[u8]) -> Result<ChainVerdict, CoreError> {
    validate_transcript_structure(data)
}
/// Result of hash-chain verification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChainVerdict {
/// All records are present and the chain is intact.
Ok { records: u64 },
/// At least one hash in the chain did not match.
Broken,
}
/// A decrypted and decoded transcript record.
#[derive(Debug, Clone)]
pub struct DecodedRecord {
    /// Record index within the file (0-based, matches the AEAD nonce).
    pub epoch: u64,
    /// Sender's Ed25519 public key bytes (may be all zeros or empty).
    pub sender_identity: Vec<u8>,
    /// Application-level message sequence number.
    pub seq: u64,
    /// Wall-clock timestamp in milliseconds since UNIX epoch.
    pub timestamp_ms: u64,
    /// UTF-8 message body.
    pub plaintext: String,
    /// SHA-256 of the previous record's ciphertext (all zeros for epoch 0).
    pub prev_hash: [u8; 32],
}
// ── Internal helpers ─────────────────────────────────────────────────────────
/// Stretch `password` with Argon2id into a 32-byte AEAD key.
///
/// The returned key is wrapped in [`Zeroizing`] so it is wiped on drop.
fn derive_key(password: &str, salt: &[u8]) -> Result<Zeroizing<[u8; KEY_LEN]>, CoreError> {
    let params = Params::new(ARGON2_M_COST, ARGON2_T_COST, ARGON2_P_COST, Some(KEY_LEN))
        .map_err(|e| CoreError::Mls(format!("argon2 params: {e}")))?;
    let kdf = Argon2::new(Algorithm::Argon2id, Version::default(), params);
    let mut derived = Zeroizing::new([0u8; KEY_LEN]);
    kdf.hash_password_into(password.as_bytes(), salt, &mut *derived)
        .map_err(|e| CoreError::Mls(format!("transcript key derivation: {e}")))?;
    Ok(derived)
}
/// Build the 12-byte AEAD nonce for a record index: the index as a
/// big-endian u64 followed by four zero bytes.
fn epoch_nonce(epoch: u64) -> [u8; NONCE_LEN] {
    let mut nonce = [0u8; NONCE_LEN];
    for (dst, src) in nonce.iter_mut().zip(epoch.to_be_bytes()) {
        *dst = src;
    }
    nonce
}
/// Map a `std::io::Error` into the crate error type with transcript context.
fn io_err(e: std::io::Error) -> CoreError {
    CoreError::Mls(format!("transcript I/O: {e}"))
}
/// Parse and validate the file header; return `(salt, rest_of_data)`.
///
/// Header layout: 4-byte magic, 1-byte version, 16-byte salt.
fn parse_header(data: &[u8]) -> Result<(&[u8], &[u8]), CoreError> {
    const HEADER_LEN: usize = 4 + 1 + SALT_LEN;
    if data.len() < HEADER_LEN {
        return Err(CoreError::Mls("transcript: file too short".into()));
    }
    let (magic, tail) = data.split_at(4);
    if magic != MAGIC {
        return Err(CoreError::Mls("transcript: invalid magic bytes".into()));
    }
    // Length was checked above, so the version byte is present.
    let (version, tail) = tail.split_first().expect("length checked above");
    if *version != VERSION {
        return Err(CoreError::Mls(format!(
            "transcript: unsupported version {version}"
        )));
    }
    Ok(tail.split_at(SALT_LEN))
}
/// Encode one record as CBOR using ciborium.
///
/// Produces the map documented in the module header; `prev_hash` is the
/// SHA-256 of the previous record's ciphertext.
fn encode_record(
    epoch: u64,
    sender_identity: &[u8],
    seq: u64,
    timestamp_ms: u64,
    plaintext: &str,
    prev_hash: &[u8; 32],
) -> Result<Vec<u8>, CoreError> {
    use ciborium::value::Value;
    // Small builder for (text-key, value) map entries.
    let entry = |k: &str, v: Value| (Value::Text(k.into()), v);
    let record = Value::Map(vec![
        entry("epoch", Value::Integer(epoch.into())),
        entry("sender_identity", Value::Bytes(sender_identity.to_vec())),
        entry("seq", Value::Integer(seq.into())),
        entry("timestamp_ms", Value::Integer(timestamp_ms.into())),
        entry("plaintext", Value::Text(plaintext.into())),
        entry("prev_hash", Value::Bytes(prev_hash.to_vec())),
    ]);
    let mut encoded = Vec::new();
    ciborium::into_writer(&record, &mut encoded)
        .map_err(|e| CoreError::Mls(format!("transcript CBOR encode: {e}")))?;
    Ok(encoded)
}
/// Decode a CBOR record.
///
/// Accepts the map produced by [`encode_record`]; unknown keys are skipped so
/// the format can gain fields without breaking old readers. Returns an error
/// if any required field is missing or `prev_hash` is not exactly 32 bytes.
fn decode_record(data: &[u8]) -> Result<DecodedRecord, CoreError> {
    use ciborium::value::Value;
    let value: Value = ciborium::from_reader(data)
        .map_err(|e| CoreError::Mls(format!("transcript CBOR decode: {e}")))?;
    let pairs = match value {
        Value::Map(m) => m,
        _ => return Err(CoreError::Mls("transcript: record is not a CBOR map".into())),
    };
    let mut epoch = None::<u64>;
    let mut sender_identity = Vec::new();
    let mut seq = None::<u64>;
    let mut timestamp_ms = None::<u64>;
    let mut plaintext = None::<String>;
    let mut prev_hash_bytes = None::<Vec<u8>>;
    for (k, v) in pairs {
        // Non-text keys are skipped, same policy as unknown keys.
        let key = match k {
            Value::Text(s) => s,
            _ => continue,
        };
        match key.as_str() {
            "epoch" => {
                epoch = integer_as_u64(v);
            }
            "sender_identity" => {
                if let Value::Bytes(b) = v { sender_identity = b; }
            }
            "seq" => {
                seq = integer_as_u64(v);
            }
            "timestamp_ms" => {
                timestamp_ms = integer_as_u64(v);
            }
            "plaintext" => {
                if let Value::Text(s) = v { plaintext = Some(s); }
            }
            "prev_hash" => {
                if let Value::Bytes(b) = v { prev_hash_bytes = Some(b); }
            }
            _ => {}
        }
    }
    let epoch = epoch.ok_or_else(|| CoreError::Mls("transcript: missing epoch".into()))?;
    let seq = seq.ok_or_else(|| CoreError::Mls("transcript: missing seq".into()))?;
    let timestamp_ms = timestamp_ms
        .ok_or_else(|| CoreError::Mls("transcript: missing timestamp_ms".into()))?;
    let plaintext = plaintext
        .ok_or_else(|| CoreError::Mls("transcript: missing plaintext".into()))?;
    let prev_hash_bytes = prev_hash_bytes
        .ok_or_else(|| CoreError::Mls("transcript: missing prev_hash".into()))?;
    // Checked slice→array conversion replaces the manual length test + copy.
    let prev_hash: [u8; 32] = prev_hash_bytes
        .as_slice()
        .try_into()
        .map_err(|_| CoreError::Mls("transcript: prev_hash must be 32 bytes".into()))?;
    Ok(DecodedRecord {
        epoch,
        sender_identity,
        seq,
        timestamp_ms,
        plaintext,
        prev_hash,
    })
}
/// Convert a CBOR integer value to `u64`, rejecting anything out of range.
///
/// Returns `None` for non-integer values, negative integers, and integers
/// above `u64::MAX` — the original `as u64` cast would have silently
/// truncated the latter.
fn integer_as_u64(v: ciborium::value::Value) -> Option<u64> {
    use ciborium::value::Value;
    match v {
        Value::Integer(i) => u64::try_from(i128::from(i)).ok(),
        _ => None,
    }
}
// ── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// A transcript with only a header reads back as zero records with a
    /// (trivially) intact chain.
    #[test]
    fn round_trip_empty() {
        let password = "test-password";
        let mut buf = Vec::new();
        let _writer = TranscriptWriter::new(password, &mut buf).expect("new writer");
        let (records, verdict) = read_transcript(password, &buf).expect("read");
        assert!(records.is_empty());
        assert_eq!(verdict, ChainVerdict::Ok { records: 0 });
    }

    /// Three records round-trip: bodies, order, and sequential epochs must
    /// all survive encryption + chaining.
    #[test]
    fn round_trip_records() {
        let password = "hunter2";
        let mut buf = Vec::new();
        let mut writer = TranscriptWriter::new(password, &mut buf).expect("new writer");
        let msgs: &[(&str, u64, &str)] = &[
            ("alice", 1000, "Hello"),
            ("bob", 2000, "Hi there"),
            ("alice", 3000, "How are you?"),
        ];
        for (_sender, ts, body) in msgs {
            let sender_key = [0u8; 32];
            writer
                .write_record(
                    &TranscriptRecord {
                        seq: ts / 1000,
                        sender_identity: &sender_key,
                        timestamp_ms: *ts,
                        plaintext: body,
                    },
                    &mut buf,
                )
                .expect("write record");
        }
        let (records, verdict) = read_transcript(password, &buf).expect("read");
        assert_eq!(verdict, ChainVerdict::Ok { records: 3 });
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].plaintext, "Hello");
        assert_eq!(records[1].plaintext, "Hi there");
        assert_eq!(records[2].plaintext, "How are you?");
        assert_eq!(records[0].epoch, 0);
        assert_eq!(records[1].epoch, 1);
        assert_eq!(records[2].epoch, 2);
    }

    /// The wrong password must fail AEAD decryption, not return garbage.
    #[test]
    fn wrong_password_fails() {
        let mut buf = Vec::new();
        let mut writer = TranscriptWriter::new("correct", &mut buf).expect("new writer");
        writer
            .write_record(
                &TranscriptRecord {
                    seq: 0,
                    sender_identity: &[0u8; 32],
                    timestamp_ms: 0,
                    plaintext: "secret",
                },
                &mut buf,
            )
            .expect("write");
        let result = read_transcript("wrong-password", &buf);
        assert!(result.is_err(), "wrong password should fail decryption");
    }

    /// Structural validation succeeds on a well-formed multi-record file.
    #[test]
    fn chain_verify_valid() {
        let mut buf = Vec::new();
        let mut writer = TranscriptWriter::new("pw", &mut buf).expect("new writer");
        for i in 0..5u64 {
            writer
                .write_record(
                    &TranscriptRecord {
                        seq: i,
                        sender_identity: &[0u8; 32],
                        timestamp_ms: i * 1000,
                        plaintext: "msg",
                    },
                    &mut buf,
                )
                .expect("write");
        }
        let verdict = validate_transcript_structure(&buf).expect("verify");
        assert_eq!(verdict, ChainVerdict::Ok { records: 5 });
    }

    /// A truncated trailing record must surface as a parse error.
    #[test]
    fn chain_verify_truncated_record_detected() {
        let mut buf = Vec::new();
        let mut writer = TranscriptWriter::new("pw", &mut buf).expect("new writer");
        writer
            .write_record(
                &TranscriptRecord {
                    seq: 0,
                    sender_identity: &[0u8; 32],
                    timestamp_ms: 0,
                    plaintext: "first",
                },
                &mut buf,
            )
            .expect("write");
        // Truncate the last few bytes — should fail parsing.
        let truncated = &buf[..buf.len() - 5];
        let result = validate_transcript_structure(truncated);
        assert!(result.is_err(), "truncated file must be detected");
    }
}
}