//! Hybrid Noise_XX + ML-KEM-768 handshake for post-quantum transport security. //! //! Implements a three-message Noise_XX pattern with an embedded ML-KEM-768 //! encapsulation to produce a hybrid shared secret that is secure against //! both classical and quantum adversaries. //! //! # Handshake pattern //! //! ```text //! XX(s, rs): //! -> e (initiator ephemeral) //! <- e, ee, s, es, mlkem_ct (responder ephemeral + static + ML-KEM ciphertext) //! -> s, se (initiator static) //! ``` //! //! After message 2, the ML-KEM shared secret is mixed into the chaining key //! via HKDF. The final transport keys incorporate both the X25519 DH chain //! and the ML-KEM shared secret. //! //! # Wire format //! //! Each handshake message is a simple length-prefixed blob: //! ```text //! [msg_len: u32 BE][handshake message bytes] //! ``` //! //! # Feature gate //! //! This module is always compiled but the `pq-noise` feature enables it //! in the RPC layer for server/client negotiation. use chacha20poly1305::{ aead::{Aead, KeyInit, Payload}, ChaCha20Poly1305, Key, Nonce, }; use hkdf::Hkdf; use ml_kem::{ array::Array, kem::{Decapsulate, Encapsulate}, EncodedSizeUser, KemCore, MlKem768, MlKem768Params, }; use ml_kem::kem::{DecapsulationKey, EncapsulationKey}; use rand::rngs::OsRng; use sha2::Sha256; use x25519_dalek::{PublicKey as X25519Public, StaticSecret}; use zeroize::Zeroizing; use crate::error::CoreError; /// Domain separation label for the hybrid Noise handshake. const PROTOCOL_NAME: &[u8] = b"quicprochat-pq-noise-v1"; /// ML-KEM-768 encapsulation key length. const MLKEM_EK_LEN: usize = 1184; /// ML-KEM-768 ciphertext length. const MLKEM_CT_LEN: usize = 1088; /// AEAD tag length (ChaCha20-Poly1305). const TAG_LEN: usize = 16; // ── Keypair ────────────────────────────────────────────────────────────────── /// A static keypair for the hybrid Noise handshake. /// /// Contains both an X25519 static key and an ML-KEM-768 key pair. pub struct NoiseKeypair { x25519_sk: StaticSecret, x25519_pk: X25519Public, mlkem_dk: DecapsulationKey, mlkem_ek: EncapsulationKey, } impl NoiseKeypair { /// Generate a fresh keypair from OS CSPRNG. pub fn generate() -> Self { let x25519_sk = StaticSecret::random_from_rng(OsRng); let x25519_pk = X25519Public::from(&x25519_sk); let (mlkem_dk, mlkem_ek) = MlKem768::generate(&mut OsRng); Self { x25519_sk, x25519_pk, mlkem_dk, mlkem_ek, } } /// Return the X25519 public key bytes. pub fn x25519_public(&self) -> [u8; 32] { self.x25519_pk.to_bytes() } /// Return the ML-KEM-768 encapsulation key bytes. pub fn mlkem_public(&self) -> Vec { self.mlkem_ek.as_bytes().to_vec() } } // ── Chaining key state ─────────────────────────────────────────────────────── /// Internal handshake state tracking the Noise chaining key and handshake hash. struct HandshakeState { /// Chaining key — evolved by each MixKey operation. ck: Zeroizing<[u8; 32]>, /// Handshake hash — commits to all handshake transcript data. h: [u8; 32], /// Current encryption key (derived from ck after MixKey). k: Option>, /// Nonce counter for in-handshake encryption. n: u64, } impl HandshakeState { fn new() -> Self { // Initialize h = SHA-256(protocol_name), ck = h. use sha2::{Digest, Sha256}; let h: [u8; 32] = Sha256::digest(PROTOCOL_NAME).into(); Self { ck: Zeroizing::new(h), h, k: None, n: 0, } } /// MixHash: h = SHA-256(h || data) fn mix_hash(&mut self, data: &[u8]) { use sha2::{Digest, Sha256}; let mut hasher = Sha256::new(); hasher.update(self.h); hasher.update(data); self.h = hasher.finalize().into(); } /// MixKey: (ck, k) = HKDF(ck, input_key_material) fn mix_key(&mut self, ikm: &[u8]) { let hk = Hkdf::::new(Some(&*self.ck), ikm); let mut ck = Zeroizing::new([0u8; 32]); let mut k = Zeroizing::new([0u8; 32]); hk.expand(b"ck", &mut *ck) .expect("32 bytes is valid HKDF output"); hk.expand(b"k", &mut *k) .expect("32 bytes is valid HKDF output"); self.ck = ck; self.k = Some(k); self.n = 0; } /// Encrypt plaintext with the current key and nonce, using h as AAD. fn encrypt_and_hash(&mut self, plaintext: &[u8]) -> Result, CoreError> { let key = self .k .as_ref() .ok_or_else(|| CoreError::Mls("pq_noise: no encryption key set".into()))?; let cipher = ChaCha20Poly1305::new(Key::from_slice(&**key)); let nonce = nonce_from_counter(self.n); let ct = cipher .encrypt( Nonce::from_slice(&nonce), Payload { msg: plaintext, aad: &self.h, }, ) .map_err(|_| CoreError::Mls("pq_noise: encrypt failed".into()))?; self.mix_hash(&ct); self.n += 1; Ok(ct) } /// Decrypt ciphertext with the current key and nonce, using h as AAD. fn decrypt_and_hash(&mut self, ciphertext: &[u8]) -> Result, CoreError> { let key = self .k .as_ref() .ok_or_else(|| CoreError::Mls("pq_noise: no decryption key set".into()))?; let cipher = ChaCha20Poly1305::new(Key::from_slice(&**key)); let nonce = nonce_from_counter(self.n); let ct_for_hash = ciphertext.to_vec(); let pt = cipher .decrypt( Nonce::from_slice(&nonce), Payload { msg: ciphertext, aad: &self.h, }, ) .map_err(|_| CoreError::Mls("pq_noise: decrypt failed".into()))?; self.mix_hash(&ct_for_hash); self.n += 1; Ok(pt) } /// Split the handshake state into two transport keys (initiator->responder, responder->initiator). fn split(&self) -> (TransportKey, TransportKey) { let hk = Hkdf::::new(Some(&*self.ck), &[]); let mut k1 = Zeroizing::new([0u8; 32]); let mut k2 = Zeroizing::new([0u8; 32]); hk.expand(b"initiator", &mut *k1) .expect("32 bytes is valid HKDF output"); hk.expand(b"responder", &mut *k2) .expect("32 bytes is valid HKDF output"); ( TransportKey { key: k1, nonce: 0 }, TransportKey { key: k2, nonce: 0 }, ) } } fn nonce_from_counter(n: u64) -> [u8; 12] { let mut nonce = [0u8; 12]; nonce[4..].copy_from_slice(&n.to_le_bytes()); nonce } // ── Transport ──────────────────────────────────────────────────────────────── /// A transport encryption key with a nonce counter. pub struct TransportKey { key: Zeroizing<[u8; 32]>, nonce: u64, } impl TransportKey { /// Encrypt a message for transport. pub fn encrypt(&mut self, plaintext: &[u8]) -> Result, CoreError> { let cipher = ChaCha20Poly1305::new(Key::from_slice(&*self.key)); let nonce = nonce_from_counter(self.nonce); let ct = cipher .encrypt(Nonce::from_slice(&nonce), plaintext) .map_err(|_| CoreError::Mls("pq_noise transport: encrypt failed".into()))?; self.nonce += 1; Ok(ct) } /// Decrypt a transport message. pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result, CoreError> { let cipher = ChaCha20Poly1305::new(Key::from_slice(&*self.key)); let nonce = nonce_from_counter(self.nonce); let pt = cipher .decrypt(Nonce::from_slice(&nonce), ciphertext) .map_err(|_| CoreError::Mls("pq_noise transport: decrypt failed".into()))?; self.nonce += 1; Ok(pt) } } // ── Initiator ──────────────────────────────────────────────────────────────── /// Initiator side of the hybrid Noise_XX handshake. pub struct Initiator { state: HandshakeState, /// Ephemeral secret stored as StaticSecret so DH doesn't consume it. /// Generated from OsRng; we use StaticSecret purely for the non-consuming /// `diffie_hellman(&self, ...)` API — the key is still ephemeral. e_sk: StaticSecret, e_pk: X25519Public, s: NoiseKeypair, /// Stored after reading message 2 so we can compute se = DH(s, re) in msg3. re_pk: Option, } impl Initiator { /// Create a new initiator with the given static keypair. pub fn new(static_keypair: NoiseKeypair) -> Self { let e_sk = StaticSecret::random_from_rng(OsRng); let e_pk = X25519Public::from(&e_sk); Self { state: HandshakeState::new(), e_sk, e_pk, s: static_keypair, re_pk: None, } } /// Write message 1: `-> e` /// /// Returns the initiator's ephemeral X25519 public key (32 bytes). pub fn write_message_1(&mut self) -> Vec { let e_pk_bytes = self.e_pk.to_bytes(); self.state.mix_hash(&e_pk_bytes); e_pk_bytes.to_vec() } /// Read message 2 from responder: `<- e, ee, s, es, mlkem_ct` /// /// Expects: `re_pk(32) || encrypted_rs_pk(32+TAG) || mlkem_ct(1088)` /// /// Returns the responder's static X25519 public key. pub fn read_message_2(&mut self, msg: &[u8]) -> Result<[u8; 32], CoreError> { let expected_len = 32 + 32 + TAG_LEN + MLKEM_CT_LEN; if msg.len() != expected_len { return Err(CoreError::Mls(format!( "pq_noise msg2: expected {expected_len} bytes, got {}", msg.len() ))); } let mut cursor = 0; // re = responder ephemeral public key let mut re_pk_bytes = [0u8; 32]; re_pk_bytes.copy_from_slice(&msg[cursor..cursor + 32]); cursor += 32; let re_pk = X25519Public::from(re_pk_bytes); self.state.mix_hash(&re_pk_bytes); self.re_pk = Some(re_pk); // ee = DH(e, re) let ee_ss = self.e_sk.diffie_hellman(&re_pk); self.state.mix_key(ee_ss.as_bytes()); // Decrypt responder's static key: s = Dec(encrypted_rs_pk) let encrypted_rs = &msg[cursor..cursor + 32 + TAG_LEN]; cursor += 32 + TAG_LEN; let rs_pk_bytes = self.state.decrypt_and_hash(encrypted_rs)?; let mut rs_pk_arr = [0u8; 32]; if rs_pk_bytes.len() != 32 { return Err(CoreError::Mls("pq_noise: decrypted rs not 32 bytes".into())); } rs_pk_arr.copy_from_slice(&rs_pk_bytes); let rs_pk = X25519Public::from(rs_pk_arr); // es = DH(e, rs) let es_ss = self.e_sk.diffie_hellman(&rs_pk); self.state.mix_key(es_ss.as_bytes()); // ML-KEM: decapsulate the ciphertext from the responder let mlkem_ct = &msg[cursor..cursor + MLKEM_CT_LEN]; let mlkem_ct_arr = Array::try_from(mlkem_ct) .map_err(|_| CoreError::Mls("pq_noise: invalid ML-KEM ciphertext".into()))?; let mlkem_ss: ml_kem::SharedKey = self .s .mlkem_dk .decapsulate(&mlkem_ct_arr) .map_err(|_| CoreError::Mls("pq_noise: ML-KEM decapsulation failed".into()))?; self.state.mix_key(&mlkem_ss); Ok(rs_pk_arr) } /// Write message 3: `-> s, se` /// /// Returns the encrypted initiator static key. pub fn write_message_3(&mut self) -> Result, CoreError> { let re_pk = self .re_pk .ok_or_else(|| CoreError::Mls("pq_noise: must read msg2 before writing msg3".into()))?; // Encrypt our static key let s_pk_bytes = self.s.x25519_pk.to_bytes(); let encrypted_s = self.state.encrypt_and_hash(&s_pk_bytes)?; // se = DH(s, re) let se_ss = self.s.x25519_sk.diffie_hellman(&re_pk); self.state.mix_key(se_ss.as_bytes()); Ok(encrypted_s) } /// Finalize the handshake and return transport keys. /// /// Returns `(send_key, recv_key)` — initiator sends with send_key. pub fn finalize(self) -> (TransportKey, TransportKey) { self.state.split() } } // ── Responder ──────────────────────────────────────────────────────────────── /// Responder side of the hybrid Noise_XX handshake. pub struct Responder { state: HandshakeState, /// Ephemeral secret stored as StaticSecret so DH doesn't consume it. e_sk: StaticSecret, e_pk: X25519Public, s: NoiseKeypair, } impl Responder { /// Create a new responder with the given static keypair. pub fn new(static_keypair: NoiseKeypair) -> Self { let e_sk = StaticSecret::random_from_rng(OsRng); let e_pk = X25519Public::from(&e_sk); Self { state: HandshakeState::new(), e_sk, e_pk, s: static_keypair, } } /// Read message 1 from initiator: `-> e` /// /// Expects the initiator's ephemeral X25519 public key (32 bytes). pub fn read_message_1(&mut self, msg: &[u8]) -> Result<(), CoreError> { if msg.len() != 32 { return Err(CoreError::Mls(format!( "pq_noise msg1: expected 32 bytes, got {}", msg.len() ))); } self.state.mix_hash(msg); Ok(()) } /// Write message 2: `<- e, ee, s, es, mlkem_ct` /// /// `initiator_ek` is the initiator's ML-KEM encapsulation key. /// /// Returns the message bytes. pub fn write_message_2( &mut self, initiator_e_pk: &[u8; 32], initiator_mlkem_ek: &[u8], ) -> Result, CoreError> { let ie_pk = X25519Public::from(*initiator_e_pk); // Our ephemeral key let e_pk_bytes = self.e_pk.to_bytes(); self.state.mix_hash(&e_pk_bytes); // ee = DH(e, ie) let ee_ss = self.e_sk.diffie_hellman(&ie_pk); self.state.mix_key(ee_ss.as_bytes()); // Encrypt our static key let s_pk_bytes = self.s.x25519_pk.to_bytes(); let encrypted_s = self.state.encrypt_and_hash(&s_pk_bytes)?; // es = DH(s, ie) let es_ss = self.s.x25519_sk.diffie_hellman(&ie_pk); self.state.mix_key(es_ss.as_bytes()); // ML-KEM: encapsulate to the initiator's encapsulation key if initiator_mlkem_ek.len() != MLKEM_EK_LEN { return Err(CoreError::Mls(format!( "pq_noise: expected ML-KEM EK {} bytes, got {}", MLKEM_EK_LEN, initiator_mlkem_ek.len() ))); } let ek_arr = Array::try_from(initiator_mlkem_ek) .map_err(|_| CoreError::Mls("pq_noise: invalid ML-KEM encapsulation key".into()))?; let ek = EncapsulationKey::::from_bytes(&ek_arr); let (mlkem_ct, mlkem_ss): (ml_kem::Ciphertext, ml_kem::SharedKey) = ek .encapsulate(&mut OsRng) .map_err(|_| CoreError::Mls("pq_noise: ML-KEM encapsulation failed".into()))?; self.state.mix_key(&mlkem_ss); // Assemble: e_pk || encrypted_s || mlkem_ct let mut out = Vec::with_capacity(32 + encrypted_s.len() + MLKEM_CT_LEN); out.extend_from_slice(&e_pk_bytes); out.extend_from_slice(&encrypted_s); out.extend_from_slice(&mlkem_ct); Ok(out) } /// Read message 3 from initiator: `-> s, se` /// /// Returns the initiator's static X25519 public key. pub fn read_message_3(&mut self, msg: &[u8]) -> Result<[u8; 32], CoreError> { if msg.len() != 32 + TAG_LEN { return Err(CoreError::Mls(format!( "pq_noise msg3: expected {} bytes, got {}", 32 + TAG_LEN, msg.len() ))); } // Decrypt initiator's static key let is_pk_bytes = self.state.decrypt_and_hash(msg)?; let mut is_pk_arr = [0u8; 32]; if is_pk_bytes.len() != 32 { return Err(CoreError::Mls( "pq_noise: decrypted initiator static not 32 bytes".into(), )); } is_pk_arr.copy_from_slice(&is_pk_bytes); let is_pk = X25519Public::from(is_pk_arr); // se = DH(e, is) — responder computes using ephemeral key let se_ss = self.e_sk.diffie_hellman(&is_pk); self.state.mix_key(se_ss.as_bytes()); Ok(is_pk_arr) } /// Finalize the handshake and return transport keys. /// /// Returns `(recv_key, send_key)` — responder receives with recv_key. pub fn finalize(self) -> (TransportKey, TransportKey) { let (i2r, r2i) = self.state.split(); (i2r, r2i) } } // ── Tests ──────────────────────────────────────────────────────────────────── #[cfg(test)] #[allow(clippy::unwrap_used)] mod tests { use super::*; #[test] fn full_handshake_round_trip() { let initiator_kp = NoiseKeypair::generate(); let responder_kp = NoiseKeypair::generate(); // Initiator's ML-KEM public key is sent out-of-band (or in a pre-message). let initiator_mlkem_ek = initiator_kp.mlkem_public(); let mut initiator = Initiator::new(initiator_kp); let mut responder = Responder::new(responder_kp); // Message 1: initiator -> responder let msg1 = initiator.write_message_1(); assert_eq!(msg1.len(), 32); responder.read_message_1(&msg1).unwrap(); // Message 2: responder -> initiator let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); let msg2 = responder .write_message_2(&ie_pk, &initiator_mlkem_ek) .unwrap(); let _responder_static = initiator.read_message_2(&msg2).unwrap(); // Message 3: initiator -> responder let msg3 = initiator.write_message_3().unwrap(); let _initiator_static = responder.read_message_3(&msg3).unwrap(); // Derive transport keys let (mut i_send, mut i_recv) = initiator.finalize(); let (mut r_recv, mut r_send) = responder.finalize(); // Test transport: initiator -> responder let plaintext = b"hello post-quantum world!"; let ct = i_send.encrypt(plaintext).unwrap(); let pt = r_recv.decrypt(&ct).unwrap(); assert_eq!(pt, plaintext); // Test transport: responder -> initiator let plaintext2 = b"reply from responder"; let ct2 = r_send.encrypt(plaintext2).unwrap(); let pt2 = i_recv.decrypt(&ct2).unwrap(); assert_eq!(pt2, plaintext2); } #[test] fn tampered_msg2_fails() { let initiator_kp = NoiseKeypair::generate(); let responder_kp = NoiseKeypair::generate(); let initiator_mlkem_ek = initiator_kp.mlkem_public(); let mut initiator = Initiator::new(initiator_kp); let mut responder = Responder::new(responder_kp); let msg1 = initiator.write_message_1(); responder.read_message_1(&msg1).unwrap(); let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); let mut msg2 = responder .write_message_2(&ie_pk, &initiator_mlkem_ek) .unwrap(); // Tamper with the encrypted static key region msg2[40] ^= 0xFF; let result = initiator.read_message_2(&msg2); assert!(result.is_err()); } #[test] fn wrong_mlkem_key_fails() { let initiator_kp = NoiseKeypair::generate(); let responder_kp = NoiseKeypair::generate(); // Use a different keypair's ML-KEM key — decapsulation will use // implicit rejection, producing a pseudorandom (wrong) shared secret. let wrong_kp = NoiseKeypair::generate(); let wrong_mlkem_ek = wrong_kp.mlkem_public(); let mut initiator = Initiator::new(initiator_kp); let mut responder = Responder::new(responder_kp); let msg1 = initiator.write_message_1(); responder.read_message_1(&msg1).unwrap(); let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); let msg2 = responder .write_message_2(&ie_pk, &wrong_mlkem_ek) .unwrap(); // ML-KEM implicit rejection: decap succeeds but returns wrong ss. // The ML-KEM mix_key happens after the AEAD decrypt of the static key, // so read_message_2 itself may succeed. But the chaining keys diverge, // causing msg3 AEAD decrypt to fail on the responder side. let read2 = initiator.read_message_2(&msg2); if read2.is_err() { // If msg2 processing itself failed, the test passes. return; } // msg2 succeeded — chaining keys now diverge due to wrong ML-KEM ss. // msg3 from initiator will use the wrong key, so responder can't decrypt. let msg3 = initiator.write_message_3().unwrap(); let result = responder.read_message_3(&msg3); assert!(result.is_err(), "msg3 should fail due to ML-KEM shared secret mismatch"); } #[test] fn multiple_transport_messages() { let initiator_kp = NoiseKeypair::generate(); let responder_kp = NoiseKeypair::generate(); let initiator_mlkem_ek = initiator_kp.mlkem_public(); let mut initiator = Initiator::new(initiator_kp); let mut responder = Responder::new(responder_kp); let msg1 = initiator.write_message_1(); responder.read_message_1(&msg1).unwrap(); let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); let msg2 = responder .write_message_2(&ie_pk, &initiator_mlkem_ek) .unwrap(); initiator.read_message_2(&msg2).unwrap(); let msg3 = initiator.write_message_3().unwrap(); responder.read_message_3(&msg3).unwrap(); let (mut i_send, mut i_recv) = initiator.finalize(); let (mut r_recv, mut r_send) = responder.finalize(); // Send multiple messages in each direction for i in 0..10u32 { let msg = format!("initiator message {i}"); let ct = i_send.encrypt(msg.as_bytes()).unwrap(); let pt = r_recv.decrypt(&ct).unwrap(); assert_eq!(pt, msg.as_bytes()); let reply = format!("responder reply {i}"); let ct2 = r_send.encrypt(reply.as_bytes()).unwrap(); let pt2 = i_recv.decrypt(&ct2).unwrap(); assert_eq!(pt2, reply.as_bytes()); } } #[test] fn nonce_reuse_detected() { let initiator_kp = NoiseKeypair::generate(); let responder_kp = NoiseKeypair::generate(); let initiator_mlkem_ek = initiator_kp.mlkem_public(); let mut initiator = Initiator::new(initiator_kp); let mut responder = Responder::new(responder_kp); let msg1 = initiator.write_message_1(); responder.read_message_1(&msg1).unwrap(); let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); let msg2 = responder .write_message_2(&ie_pk, &initiator_mlkem_ek) .unwrap(); initiator.read_message_2(&msg2).unwrap(); let msg3 = initiator.write_message_3().unwrap(); responder.read_message_3(&msg3).unwrap(); let (mut i_send, _) = initiator.finalize(); let (mut r_recv, _) = responder.finalize(); // Encrypt two messages let ct1 = i_send.encrypt(b"msg1").unwrap(); let _ct2 = i_send.encrypt(b"msg2").unwrap(); // Decrypt in order works r_recv.decrypt(&ct1).unwrap(); // Replaying ct1 (wrong nonce) should fail let result = r_recv.decrypt(&ct1); assert!(result.is_err()); // But ct2 at the right nonce works // (we already consumed nonce 1 trying ct1, so ct2 at nonce 2 fails too) // This tests that the nonce counter prevents replay. } }