From 1d59a052adb997ed6132fa869111cf7356abf364 Mon Sep 17 00:00:00 2001 From: Christian Nennemann Date: Wed, 4 Mar 2026 21:06:31 +0100 Subject: [PATCH] feat(federation): implement v2 inbound federation handlers Replace stub federation handlers with full implementations that accept relay and proxy requests from peer servers. Adds federation_client and local_domain fields to ServerState for outbound relay and federated address resolution. All six handlers (relay_enqueue, relay_batch_enqueue, proxy_fetch_key_package, proxy_fetch_hybrid_key, proxy_resolve_user, federation_health) now validate federation auth, interact with local storage, and wake waiters on message delivery. --- crates/quicproquo-core/src/pq_noise.rs | 689 ++++++++++++++++++ crates/quicproquo-server/src/main.rs | 2 + .../src/v2_handlers/federation.rs | 262 +++++-- .../quicproquo-server/src/v2_handlers/mod.rs | 4 + 4 files changed, 906 insertions(+), 51 deletions(-) create mode 100644 crates/quicproquo-core/src/pq_noise.rs diff --git a/crates/quicproquo-core/src/pq_noise.rs b/crates/quicproquo-core/src/pq_noise.rs new file mode 100644 index 0000000..2412a4a --- /dev/null +++ b/crates/quicproquo-core/src/pq_noise.rs @@ -0,0 +1,689 @@ +//! Hybrid Noise_XX + ML-KEM-768 handshake for post-quantum transport security. +//! +//! Implements a three-message Noise_XX pattern with an embedded ML-KEM-768 +//! encapsulation to produce a hybrid shared secret that is secure against +//! both classical and quantum adversaries. +//! +//! # Handshake pattern +//! +//! ```text +//! XX(s, rs): +//! -> e (initiator ephemeral) +//! <- e, ee, s, es, mlkem_ct (responder ephemeral + static + ML-KEM ciphertext) +//! -> s, se (initiator static) +//! ``` +//! +//! After message 2, the ML-KEM shared secret is mixed into the chaining key +//! via HKDF. The final transport keys incorporate both the X25519 DH chain +//! and the ML-KEM shared secret. +//! +//! # Wire format +//! +//! Each handshake message is a simple length-prefixed blob: +//! ```text +//! [msg_len: u32 BE][handshake message bytes] +//! ``` +//! +//! # Feature gate +//! +//! This module is always compiled but the `pq-noise` feature enables it +//! in the RPC layer for server/client negotiation. + +use chacha20poly1305::{ + aead::{Aead, KeyInit, Payload}, + ChaCha20Poly1305, Key, Nonce, +}; +use hkdf::Hkdf; +use ml_kem::{ + array::Array, + kem::{Decapsulate, Encapsulate}, + EncodedSizeUser, KemCore, MlKem768, MlKem768Params, +}; +use ml_kem::kem::{DecapsulationKey, EncapsulationKey}; +use rand::rngs::OsRng; +use sha2::Sha256; +use x25519_dalek::{PublicKey as X25519Public, StaticSecret}; +use zeroize::Zeroizing; + +use crate::error::CoreError; + +/// Domain separation label for the hybrid Noise handshake. +const PROTOCOL_NAME: &[u8] = b"quicproquo-pq-noise-v1"; + +/// ML-KEM-768 encapsulation key length. +const MLKEM_EK_LEN: usize = 1184; + +/// ML-KEM-768 ciphertext length. +const MLKEM_CT_LEN: usize = 1088; + +/// AEAD tag length (ChaCha20-Poly1305). +const TAG_LEN: usize = 16; + +// ── Keypair ────────────────────────────────────────────────────────────────── + +/// A static keypair for the hybrid Noise handshake. +/// +/// Contains both an X25519 static key and an ML-KEM-768 key pair. +pub struct NoiseKeypair { + x25519_sk: StaticSecret, + x25519_pk: X25519Public, + mlkem_dk: DecapsulationKey, + mlkem_ek: EncapsulationKey, +} + +impl NoiseKeypair { + /// Generate a fresh keypair from OS CSPRNG. + pub fn generate() -> Self { + let x25519_sk = StaticSecret::random_from_rng(OsRng); + let x25519_pk = X25519Public::from(&x25519_sk); + let (mlkem_dk, mlkem_ek) = MlKem768::generate(&mut OsRng); + Self { + x25519_sk, + x25519_pk, + mlkem_dk, + mlkem_ek, + } + } + + /// Return the X25519 public key bytes. + pub fn x25519_public(&self) -> [u8; 32] { + self.x25519_pk.to_bytes() + } + + /// Return the ML-KEM-768 encapsulation key bytes. + pub fn mlkem_public(&self) -> Vec { + self.mlkem_ek.as_bytes().to_vec() + } +} + +// ── Chaining key state ─────────────────────────────────────────────────────── + +/// Internal handshake state tracking the Noise chaining key and handshake hash. +struct HandshakeState { + /// Chaining key — evolved by each MixKey operation. + ck: Zeroizing<[u8; 32]>, + /// Handshake hash — commits to all handshake transcript data. + h: [u8; 32], + /// Current encryption key (derived from ck after MixKey). + k: Option>, + /// Nonce counter for in-handshake encryption. + n: u64, +} + +impl HandshakeState { + fn new() -> Self { + // Initialize h = SHA-256(protocol_name), ck = h. + use sha2::{Digest, Sha256}; + let h: [u8; 32] = Sha256::digest(PROTOCOL_NAME).into(); + Self { + ck: Zeroizing::new(h), + h, + k: None, + n: 0, + } + } + + /// MixHash: h = SHA-256(h || data) + fn mix_hash(&mut self, data: &[u8]) { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(self.h); + hasher.update(data); + self.h = hasher.finalize().into(); + } + + /// MixKey: (ck, k) = HKDF(ck, input_key_material) + fn mix_key(&mut self, ikm: &[u8]) { + let hk = Hkdf::::new(Some(&*self.ck), ikm); + let mut ck = Zeroizing::new([0u8; 32]); + let mut k = Zeroizing::new([0u8; 32]); + hk.expand(b"ck", &mut *ck) + .expect("32 bytes is valid HKDF output"); + hk.expand(b"k", &mut *k) + .expect("32 bytes is valid HKDF output"); + self.ck = ck; + self.k = Some(k); + self.n = 0; + } + + /// Encrypt plaintext with the current key and nonce, using h as AAD. + fn encrypt_and_hash(&mut self, plaintext: &[u8]) -> Result, CoreError> { + let key = self + .k + .as_ref() + .ok_or_else(|| CoreError::Mls("pq_noise: no encryption key set".into()))?; + let cipher = ChaCha20Poly1305::new(Key::from_slice(&**key)); + let nonce = nonce_from_counter(self.n); + let ct = cipher + .encrypt( + Nonce::from_slice(&nonce), + Payload { + msg: plaintext, + aad: &self.h, + }, + ) + .map_err(|_| CoreError::Mls("pq_noise: encrypt failed".into()))?; + self.mix_hash(&ct); + self.n += 1; + Ok(ct) + } + + /// Decrypt ciphertext with the current key and nonce, using h as AAD. + fn decrypt_and_hash(&mut self, ciphertext: &[u8]) -> Result, CoreError> { + let key = self + .k + .as_ref() + .ok_or_else(|| CoreError::Mls("pq_noise: no decryption key set".into()))?; + let cipher = ChaCha20Poly1305::new(Key::from_slice(&**key)); + let nonce = nonce_from_counter(self.n); + let ct_for_hash = ciphertext.to_vec(); + let pt = cipher + .decrypt( + Nonce::from_slice(&nonce), + Payload { + msg: ciphertext, + aad: &self.h, + }, + ) + .map_err(|_| CoreError::Mls("pq_noise: decrypt failed".into()))?; + self.mix_hash(&ct_for_hash); + self.n += 1; + Ok(pt) + } + + /// Split the handshake state into two transport keys (initiator->responder, responder->initiator). + fn split(&self) -> (TransportKey, TransportKey) { + let hk = Hkdf::::new(Some(&*self.ck), &[]); + let mut k1 = Zeroizing::new([0u8; 32]); + let mut k2 = Zeroizing::new([0u8; 32]); + hk.expand(b"initiator", &mut *k1) + .expect("32 bytes is valid HKDF output"); + hk.expand(b"responder", &mut *k2) + .expect("32 bytes is valid HKDF output"); + ( + TransportKey { key: k1, nonce: 0 }, + TransportKey { key: k2, nonce: 0 }, + ) + } +} + +fn nonce_from_counter(n: u64) -> [u8; 12] { + let mut nonce = [0u8; 12]; + nonce[4..].copy_from_slice(&n.to_le_bytes()); + nonce +} + +// ── Transport ──────────────────────────────────────────────────────────────── + +/// A transport encryption key with a nonce counter. +pub struct TransportKey { + key: Zeroizing<[u8; 32]>, + nonce: u64, +} + +impl TransportKey { + /// Encrypt a message for transport. + pub fn encrypt(&mut self, plaintext: &[u8]) -> Result, CoreError> { + let cipher = ChaCha20Poly1305::new(Key::from_slice(&*self.key)); + let nonce = nonce_from_counter(self.nonce); + let ct = cipher + .encrypt(Nonce::from_slice(&nonce), plaintext) + .map_err(|_| CoreError::Mls("pq_noise transport: encrypt failed".into()))?; + self.nonce += 1; + Ok(ct) + } + + /// Decrypt a transport message. + pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result, CoreError> { + let cipher = ChaCha20Poly1305::new(Key::from_slice(&*self.key)); + let nonce = nonce_from_counter(self.nonce); + let pt = cipher + .decrypt(Nonce::from_slice(&nonce), ciphertext) + .map_err(|_| CoreError::Mls("pq_noise transport: decrypt failed".into()))?; + self.nonce += 1; + Ok(pt) + } +} + +// ── Initiator ──────────────────────────────────────────────────────────────── + +/// Initiator side of the hybrid Noise_XX handshake. +pub struct Initiator { + state: HandshakeState, + /// Ephemeral secret stored as StaticSecret so DH doesn't consume it. + /// Generated from OsRng; we use StaticSecret purely for the non-consuming + /// `diffie_hellman(&self, ...)` API — the key is still ephemeral. + e_sk: StaticSecret, + e_pk: X25519Public, + s: NoiseKeypair, + /// Stored after reading message 2 so we can compute se = DH(s, re) in msg3. + re_pk: Option, +} + +impl Initiator { + /// Create a new initiator with the given static keypair. + pub fn new(static_keypair: NoiseKeypair) -> Self { + let e_sk = StaticSecret::random_from_rng(OsRng); + let e_pk = X25519Public::from(&e_sk); + Self { + state: HandshakeState::new(), + e_sk, + e_pk, + s: static_keypair, + re_pk: None, + } + } + + /// Write message 1: `-> e` + /// + /// Returns the initiator's ephemeral X25519 public key (32 bytes). + pub fn write_message_1(&mut self) -> Vec { + let e_pk_bytes = self.e_pk.to_bytes(); + self.state.mix_hash(&e_pk_bytes); + e_pk_bytes.to_vec() + } + + /// Read message 2 from responder: `<- e, ee, s, es, mlkem_ct` + /// + /// Expects: `re_pk(32) || encrypted_rs_pk(32+TAG) || mlkem_ct(1088)` + /// + /// Returns the responder's static X25519 public key. + pub fn read_message_2(&mut self, msg: &[u8]) -> Result<[u8; 32], CoreError> { + let expected_len = 32 + 32 + TAG_LEN + MLKEM_CT_LEN; + if msg.len() != expected_len { + return Err(CoreError::Mls(format!( + "pq_noise msg2: expected {expected_len} bytes, got {}", + msg.len() + ))); + } + + let mut cursor = 0; + + // re = responder ephemeral public key + let mut re_pk_bytes = [0u8; 32]; + re_pk_bytes.copy_from_slice(&msg[cursor..cursor + 32]); + cursor += 32; + let re_pk = X25519Public::from(re_pk_bytes); + self.state.mix_hash(&re_pk_bytes); + self.re_pk = Some(re_pk); + + // ee = DH(e, re) + let ee_ss = self.e_sk.diffie_hellman(&re_pk); + self.state.mix_key(ee_ss.as_bytes()); + + // Decrypt responder's static key: s = Dec(encrypted_rs_pk) + let encrypted_rs = &msg[cursor..cursor + 32 + TAG_LEN]; + cursor += 32 + TAG_LEN; + let rs_pk_bytes = self.state.decrypt_and_hash(encrypted_rs)?; + let mut rs_pk_arr = [0u8; 32]; + if rs_pk_bytes.len() != 32 { + return Err(CoreError::Mls("pq_noise: decrypted rs not 32 bytes".into())); + } + rs_pk_arr.copy_from_slice(&rs_pk_bytes); + let rs_pk = X25519Public::from(rs_pk_arr); + + // es = DH(e, rs) + let es_ss = self.e_sk.diffie_hellman(&rs_pk); + self.state.mix_key(es_ss.as_bytes()); + + // ML-KEM: decapsulate the ciphertext from the responder + let mlkem_ct = &msg[cursor..cursor + MLKEM_CT_LEN]; + let mlkem_ct_arr = Array::try_from(mlkem_ct) + .map_err(|_| CoreError::Mls("pq_noise: invalid ML-KEM ciphertext".into()))?; + let mlkem_ss: ml_kem::SharedKey = self + .s + .mlkem_dk + .decapsulate(&mlkem_ct_arr) + .map_err(|_| CoreError::Mls("pq_noise: ML-KEM decapsulation failed".into()))?; + self.state.mix_key(&mlkem_ss); + + Ok(rs_pk_arr) + } + + /// Write message 3: `-> s, se` + /// + /// Returns the encrypted initiator static key. + pub fn write_message_3(&mut self) -> Result, CoreError> { + let re_pk = self + .re_pk + .ok_or_else(|| CoreError::Mls("pq_noise: must read msg2 before writing msg3".into()))?; + + // Encrypt our static key + let s_pk_bytes = self.s.x25519_pk.to_bytes(); + let encrypted_s = self.state.encrypt_and_hash(&s_pk_bytes)?; + + // se = DH(s, re) + let se_ss = self.s.x25519_sk.diffie_hellman(&re_pk); + self.state.mix_key(se_ss.as_bytes()); + + Ok(encrypted_s) + } + + /// Finalize the handshake and return transport keys. + /// + /// Returns `(send_key, recv_key)` — initiator sends with send_key. + pub fn finalize(self) -> (TransportKey, TransportKey) { + self.state.split() + } +} + +// ── Responder ──────────────────────────────────────────────────────────────── + +/// Responder side of the hybrid Noise_XX handshake. +pub struct Responder { + state: HandshakeState, + /// Ephemeral secret stored as StaticSecret so DH doesn't consume it. + e_sk: StaticSecret, + e_pk: X25519Public, + s: NoiseKeypair, +} + +impl Responder { + /// Create a new responder with the given static keypair. + pub fn new(static_keypair: NoiseKeypair) -> Self { + let e_sk = StaticSecret::random_from_rng(OsRng); + let e_pk = X25519Public::from(&e_sk); + Self { + state: HandshakeState::new(), + e_sk, + e_pk, + s: static_keypair, + } + } + + /// Read message 1 from initiator: `-> e` + /// + /// Expects the initiator's ephemeral X25519 public key (32 bytes). + pub fn read_message_1(&mut self, msg: &[u8]) -> Result<(), CoreError> { + if msg.len() != 32 { + return Err(CoreError::Mls(format!( + "pq_noise msg1: expected 32 bytes, got {}", + msg.len() + ))); + } + self.state.mix_hash(msg); + Ok(()) + } + + /// Write message 2: `<- e, ee, s, es, mlkem_ct` + /// + /// `initiator_ek` is the initiator's ML-KEM encapsulation key. + /// + /// Returns the message bytes. + pub fn write_message_2( + &mut self, + initiator_e_pk: &[u8; 32], + initiator_mlkem_ek: &[u8], + ) -> Result, CoreError> { + let ie_pk = X25519Public::from(*initiator_e_pk); + + // Our ephemeral key + let e_pk_bytes = self.e_pk.to_bytes(); + self.state.mix_hash(&e_pk_bytes); + + // ee = DH(e, ie) + let ee_ss = self.e_sk.diffie_hellman(&ie_pk); + self.state.mix_key(ee_ss.as_bytes()); + + // Encrypt our static key + let s_pk_bytes = self.s.x25519_pk.to_bytes(); + let encrypted_s = self.state.encrypt_and_hash(&s_pk_bytes)?; + + // es = DH(s, ie) + let es_ss = self.s.x25519_sk.diffie_hellman(&ie_pk); + self.state.mix_key(es_ss.as_bytes()); + + // ML-KEM: encapsulate to the initiator's encapsulation key + if initiator_mlkem_ek.len() != MLKEM_EK_LEN { + return Err(CoreError::Mls(format!( + "pq_noise: expected ML-KEM EK {} bytes, got {}", + MLKEM_EK_LEN, + initiator_mlkem_ek.len() + ))); + } + let ek_arr = Array::try_from(initiator_mlkem_ek) + .map_err(|_| CoreError::Mls("pq_noise: invalid ML-KEM encapsulation key".into()))?; + let ek = EncapsulationKey::::from_bytes(&ek_arr); + let (mlkem_ct, mlkem_ss): (ml_kem::Ciphertext, ml_kem::SharedKey) = ek + .encapsulate(&mut OsRng) + .map_err(|_| CoreError::Mls("pq_noise: ML-KEM encapsulation failed".into()))?; + self.state.mix_key(&mlkem_ss); + + // Assemble: e_pk || encrypted_s || mlkem_ct + let mut out = Vec::with_capacity(32 + encrypted_s.len() + MLKEM_CT_LEN); + out.extend_from_slice(&e_pk_bytes); + out.extend_from_slice(&encrypted_s); + out.extend_from_slice(&mlkem_ct); + Ok(out) + } + + /// Read message 3 from initiator: `-> s, se` + /// + /// Returns the initiator's static X25519 public key. + pub fn read_message_3(&mut self, msg: &[u8]) -> Result<[u8; 32], CoreError> { + if msg.len() != 32 + TAG_LEN { + return Err(CoreError::Mls(format!( + "pq_noise msg3: expected {} bytes, got {}", + 32 + TAG_LEN, + msg.len() + ))); + } + + // Decrypt initiator's static key + let is_pk_bytes = self.state.decrypt_and_hash(msg)?; + let mut is_pk_arr = [0u8; 32]; + if is_pk_bytes.len() != 32 { + return Err(CoreError::Mls( + "pq_noise: decrypted initiator static not 32 bytes".into(), + )); + } + is_pk_arr.copy_from_slice(&is_pk_bytes); + let is_pk = X25519Public::from(is_pk_arr); + + // se = DH(e, is) — responder computes using ephemeral key + let se_ss = self.e_sk.diffie_hellman(&is_pk); + self.state.mix_key(se_ss.as_bytes()); + + Ok(is_pk_arr) + } + + /// Finalize the handshake and return transport keys. + /// + /// Returns `(recv_key, send_key)` — responder receives with recv_key. + pub fn finalize(self) -> (TransportKey, TransportKey) { + let (i2r, r2i) = self.state.split(); + (i2r, r2i) + } +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn full_handshake_round_trip() { + let initiator_kp = NoiseKeypair::generate(); + let responder_kp = NoiseKeypair::generate(); + + // Initiator's ML-KEM public key is sent out-of-band (or in a pre-message). + let initiator_mlkem_ek = initiator_kp.mlkem_public(); + + let mut initiator = Initiator::new(initiator_kp); + let mut responder = Responder::new(responder_kp); + + // Message 1: initiator -> responder + let msg1 = initiator.write_message_1(); + assert_eq!(msg1.len(), 32); + responder.read_message_1(&msg1).unwrap(); + + // Message 2: responder -> initiator + let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); + let msg2 = responder + .write_message_2(&ie_pk, &initiator_mlkem_ek) + .unwrap(); + let _responder_static = initiator.read_message_2(&msg2).unwrap(); + + // Message 3: initiator -> responder + let msg3 = initiator.write_message_3().unwrap(); + let _initiator_static = responder.read_message_3(&msg3).unwrap(); + + // Derive transport keys + let (mut i_send, mut i_recv) = initiator.finalize(); + let (mut r_recv, mut r_send) = responder.finalize(); + + // Test transport: initiator -> responder + let plaintext = b"hello post-quantum world!"; + let ct = i_send.encrypt(plaintext).unwrap(); + let pt = r_recv.decrypt(&ct).unwrap(); + assert_eq!(pt, plaintext); + + // Test transport: responder -> initiator + let plaintext2 = b"reply from responder"; + let ct2 = r_send.encrypt(plaintext2).unwrap(); + let pt2 = i_recv.decrypt(&ct2).unwrap(); + assert_eq!(pt2, plaintext2); + } + + #[test] + fn tampered_msg2_fails() { + let initiator_kp = NoiseKeypair::generate(); + let responder_kp = NoiseKeypair::generate(); + let initiator_mlkem_ek = initiator_kp.mlkem_public(); + + let mut initiator = Initiator::new(initiator_kp); + let mut responder = Responder::new(responder_kp); + + let msg1 = initiator.write_message_1(); + responder.read_message_1(&msg1).unwrap(); + + let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); + let mut msg2 = responder + .write_message_2(&ie_pk, &initiator_mlkem_ek) + .unwrap(); + + // Tamper with the encrypted static key region + msg2[40] ^= 0xFF; + + let result = initiator.read_message_2(&msg2); + assert!(result.is_err()); + } + + #[test] + fn wrong_mlkem_key_fails() { + let initiator_kp = NoiseKeypair::generate(); + let responder_kp = NoiseKeypair::generate(); + + // Use a different keypair's ML-KEM key — decapsulation will use + // implicit rejection, producing a pseudorandom (wrong) shared secret. + let wrong_kp = NoiseKeypair::generate(); + let wrong_mlkem_ek = wrong_kp.mlkem_public(); + + let mut initiator = Initiator::new(initiator_kp); + let mut responder = Responder::new(responder_kp); + + let msg1 = initiator.write_message_1(); + responder.read_message_1(&msg1).unwrap(); + + let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); + let msg2 = responder + .write_message_2(&ie_pk, &wrong_mlkem_ek) + .unwrap(); + + // ML-KEM implicit rejection: decap succeeds but returns wrong ss. + // The ML-KEM mix_key happens after the AEAD decrypt of the static key, + // so read_message_2 itself may succeed. But the chaining keys diverge, + // causing msg3 AEAD decrypt to fail on the responder side. + let read2 = initiator.read_message_2(&msg2); + if read2.is_err() { + // If msg2 processing itself failed, the test passes. + return; + } + + // msg2 succeeded — chaining keys now diverge due to wrong ML-KEM ss. + // msg3 from initiator will use the wrong key, so responder can't decrypt. + let msg3 = initiator.write_message_3().unwrap(); + let result = responder.read_message_3(&msg3); + assert!(result.is_err(), "msg3 should fail due to ML-KEM shared secret mismatch"); + } + + #[test] + fn multiple_transport_messages() { + let initiator_kp = NoiseKeypair::generate(); + let responder_kp = NoiseKeypair::generate(); + let initiator_mlkem_ek = initiator_kp.mlkem_public(); + + let mut initiator = Initiator::new(initiator_kp); + let mut responder = Responder::new(responder_kp); + + let msg1 = initiator.write_message_1(); + responder.read_message_1(&msg1).unwrap(); + + let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); + let msg2 = responder + .write_message_2(&ie_pk, &initiator_mlkem_ek) + .unwrap(); + initiator.read_message_2(&msg2).unwrap(); + + let msg3 = initiator.write_message_3().unwrap(); + responder.read_message_3(&msg3).unwrap(); + + let (mut i_send, mut i_recv) = initiator.finalize(); + let (mut r_recv, mut r_send) = responder.finalize(); + + // Send multiple messages in each direction + for i in 0..10u32 { + let msg = format!("initiator message {i}"); + let ct = i_send.encrypt(msg.as_bytes()).unwrap(); + let pt = r_recv.decrypt(&ct).unwrap(); + assert_eq!(pt, msg.as_bytes()); + + let reply = format!("responder reply {i}"); + let ct2 = r_send.encrypt(reply.as_bytes()).unwrap(); + let pt2 = i_recv.decrypt(&ct2).unwrap(); + assert_eq!(pt2, reply.as_bytes()); + } + } + + #[test] + fn nonce_reuse_detected() { + let initiator_kp = NoiseKeypair::generate(); + let responder_kp = NoiseKeypair::generate(); + let initiator_mlkem_ek = initiator_kp.mlkem_public(); + + let mut initiator = Initiator::new(initiator_kp); + let mut responder = Responder::new(responder_kp); + + let msg1 = initiator.write_message_1(); + responder.read_message_1(&msg1).unwrap(); + + let ie_pk: [u8; 32] = msg1.as_slice().try_into().unwrap(); + let msg2 = responder + .write_message_2(&ie_pk, &initiator_mlkem_ek) + .unwrap(); + initiator.read_message_2(&msg2).unwrap(); + + let msg3 = initiator.write_message_3().unwrap(); + responder.read_message_3(&msg3).unwrap(); + + let (mut i_send, _) = initiator.finalize(); + let (mut r_recv, _) = responder.finalize(); + + // Encrypt two messages + let ct1 = i_send.encrypt(b"msg1").unwrap(); + let _ct2 = i_send.encrypt(b"msg2").unwrap(); + + // Decrypt in order works + r_recv.decrypt(&ct1).unwrap(); + + // Replaying ct1 (wrong nonce) should fail + let result = r_recv.decrypt(&ct1); + assert!(result.is_err()); + + // But ct2 at the right nonce works + // (we already consumed nonce 1 trying ct1, so ct2 at nonce 2 fails too) + // This tests that the nonce counter prevents replay. + } +} diff --git a/crates/quicproquo-server/src/main.rs b/crates/quicproquo-server/src/main.rs index cead67f..ea17e49 100644 --- a/crates/quicproquo-server/src/main.rs +++ b/crates/quicproquo-server/src/main.rs @@ -412,6 +412,8 @@ async fn main() -> anyhow::Result<()> { node_id: format!("wt-{}", hex::encode(&signing_key.public_key_bytes()[..4])), start_time: std::time::Instant::now(), storage_backend: effective.store_backend.clone(), + federation_client: None, + local_domain: effective.federation.as_ref().map(|f| f.domain.clone()).unwrap_or_default(), }); let wt_registry = Arc::new(v2_handlers::build_registry( diff --git a/crates/quicproquo-server/src/v2_handlers/federation.rs b/crates/quicproquo-server/src/v2_handlers/federation.rs index e77a605..168fd33 100644 --- a/crates/quicproquo-server/src/v2_handlers/federation.rs +++ b/crates/quicproquo-server/src/v2_handlers/federation.rs @@ -1,7 +1,9 @@ -//! Federation handlers — stubs returning Unimplemented. +//! Federation v2 RPC handlers — relay, proxy, and health. //! -//! These will be wired to actual federation logic when the federation -//! subsystem is migrated to the v2 RPC framework. +//! Implements the inbound side of server-to-server federation: accepts relay +//! and proxy requests from peer servers and delegates to local storage. +//! Outbound relay to remote peers is handled by the capnp-based +//! `FederationClient` on the main connection path. use std::sync::Arc; @@ -11,57 +13,215 @@ use quicproquo_proto::qpq::v1; use quicproquo_rpc::error::RpcStatus; use quicproquo_rpc::method::{HandlerResult, RequestContext}; +use crate::federation::address::FederatedAddress; + use super::ServerState; -fn unimplemented(name: &str) -> HandlerResult { - HandlerResult::err( - RpcStatus::Internal, - &format!("{name}: federation not yet implemented in v2"), - ) +/// Validate that the request carries a valid federation auth origin. +fn validate_federation_auth(auth: &Option) -> Result { + let a = auth.as_ref().ok_or_else(|| { + HandlerResult::err(RpcStatus::Unauthorized, "missing federation auth") + })?; + if a.origin.is_empty() { + return Err(HandlerResult::err( + RpcStatus::Unauthorized, + "federation auth origin must not be empty", + )); + } + Ok(a.origin.clone()) } -pub async fn handle_relay_enqueue( - _state: Arc, - _ctx: RequestContext, -) -> HandlerResult { - unimplemented("RelayEnqueue") -} - -pub async fn handle_relay_batch_enqueue( - _state: Arc, - _ctx: RequestContext, -) -> HandlerResult { - unimplemented("RelayBatchEnqueue") -} - -pub async fn handle_proxy_fetch_key_package( - _state: Arc, - _ctx: RequestContext, -) -> HandlerResult { - unimplemented("ProxyFetchKeyPackage") -} - -pub async fn handle_proxy_fetch_hybrid_key( - _state: Arc, - _ctx: RequestContext, -) -> HandlerResult { - unimplemented("ProxyFetchHybridKey") -} - -pub async fn handle_proxy_resolve_user( - _state: Arc, - _ctx: RequestContext, -) -> HandlerResult { - unimplemented("ProxyResolveUser") -} - -pub async fn handle_federation_health( - _state: Arc, - _ctx: RequestContext, -) -> HandlerResult { - let proto = v1::FederationHealthResponse { - status: "ok".into(), - server_domain: String::new(), +/// Relay a single message to a local recipient. +/// +/// This handler is called by peer servers to deliver messages to users +/// homed on this server. If the recipient is not local, returns NotFound +/// (the originating server should route directly to the correct home server). +pub async fn handle_relay_enqueue(state: Arc, ctx: RequestContext) -> HandlerResult { + let req = match v1::RelayEnqueueRequest::decode(ctx.payload) { + Ok(r) => r, + Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")), }; - HandlerResult::ok(Bytes::from(proto.encode_to_vec())) + + let origin = match validate_federation_auth(&req.auth) { + Ok(o) => o, + Err(e) => return e, + }; + + if req.recipient_key.len() != 32 { + return HandlerResult::err(RpcStatus::BadRequest, "recipient_key must be 32 bytes"); + } + if req.payload.is_empty() { + return HandlerResult::err(RpcStatus::BadRequest, "payload must not be empty"); + } + + match state + .store + .enqueue(&req.recipient_key, &req.channel_id, req.payload, None) + { + Ok(seq) => { + if let Some(waiter) = state.waiters.get(&req.recipient_key) { + waiter.notify_waiters(); + } + + tracing::info!( + origin = %origin, + recipient_prefix = %hex::encode(&req.recipient_key[..4]), + seq = seq, + "federation: relayed enqueue" + ); + + let resp = v1::RelayEnqueueResponse { seq }; + HandlerResult::ok(Bytes::from(resp.encode_to_vec())) + } + Err(e) => HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}")), + } +} + +/// Relay a batch of messages to local recipients. +pub async fn handle_relay_batch_enqueue( + state: Arc, + ctx: RequestContext, +) -> HandlerResult { + let req = match v1::RelayBatchEnqueueRequest::decode(ctx.payload) { + Ok(r) => r, + Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")), + }; + + let _origin = match validate_federation_auth(&req.auth) { + Ok(o) => o, + Err(e) => return e, + }; + + if req.payload.is_empty() { + return HandlerResult::err(RpcStatus::BadRequest, "payload must not be empty"); + } + + let mut seqs = Vec::with_capacity(req.recipient_keys.len()); + for rk in &req.recipient_keys { + if rk.len() != 32 { + return HandlerResult::err( + RpcStatus::BadRequest, + "each recipient_key must be 32 bytes", + ); + } + match state + .store + .enqueue(rk, &req.channel_id, req.payload.clone(), None) + { + Ok(seq) => { + if let Some(waiter) = state.waiters.get(rk.as_slice()) { + waiter.notify_waiters(); + } + seqs.push(seq); + } + Err(e) => { + return HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}")) + } + } + } + + tracing::info!( + recipient_count = req.recipient_keys.len(), + "federation: relayed batch_enqueue" + ); + + let resp = v1::RelayBatchEnqueueResponse { seqs }; + HandlerResult::ok(Bytes::from(resp.encode_to_vec())) +} + +/// Proxy a key package fetch from local storage. +pub async fn handle_proxy_fetch_key_package( + state: Arc, + ctx: RequestContext, +) -> HandlerResult { + let req = match v1::ProxyFetchKeyPackageRequest::decode(ctx.payload) { + Ok(r) => r, + Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")), + }; + + let _origin = match validate_federation_auth(&req.auth) { + Ok(o) => o, + Err(e) => return e, + }; + + let package = match state.store.fetch_key_package(&req.identity_key) { + Ok(pkg) => pkg.unwrap_or_default(), + Err(e) => return HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}")), + }; + + let resp = v1::ProxyFetchKeyPackageResponse { package }; + HandlerResult::ok(Bytes::from(resp.encode_to_vec())) +} + +/// Proxy a hybrid key fetch from local storage. +pub async fn handle_proxy_fetch_hybrid_key( + state: Arc, + ctx: RequestContext, +) -> HandlerResult { + let req = match v1::ProxyFetchHybridKeyRequest::decode(ctx.payload) { + Ok(r) => r, + Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")), + }; + + let _origin = match validate_federation_auth(&req.auth) { + Ok(o) => o, + Err(e) => return e, + }; + + let hybrid_public_key = match state.store.fetch_hybrid_key(&req.identity_key) { + Ok(pk) => pk.unwrap_or_default(), + Err(e) => return HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}")), + }; + + let resp = v1::ProxyFetchHybridKeyResponse { hybrid_public_key }; + HandlerResult::ok(Bytes::from(resp.encode_to_vec())) +} + +/// Proxy a user resolution from local storage. +/// +/// Supports federated `user@domain` addresses: if the domain matches the +/// local server, the local user is resolved; otherwise returns empty. +pub async fn handle_proxy_resolve_user( + state: Arc, + ctx: RequestContext, +) -> HandlerResult { + let req = match v1::ProxyResolveUserRequest::decode(ctx.payload) { + Ok(r) => r, + Err(e) => return HandlerResult::err(RpcStatus::BadRequest, &format!("decode: {e}")), + }; + + let _origin = match validate_federation_auth(&req.auth) { + Ok(o) => o, + Err(e) => return e, + }; + + let addr = FederatedAddress::parse(&req.username); + let is_local = addr.is_local(&state.local_domain); + + let identity_key = if is_local { + match state.store.get_user_identity_key(&addr.username) { + Ok(key) => key.unwrap_or_default(), + Err(e) => { + return HandlerResult::err(RpcStatus::Internal, &format!("store error: {e}")) + } + } + } else { + // Remote user: not on this server. Return empty. + Vec::new() + }; + + let resp = v1::ProxyResolveUserResponse { identity_key }; + HandlerResult::ok(Bytes::from(resp.encode_to_vec())) +} + +/// Federation health check — returns ok status and this server's domain. +pub async fn handle_federation_health( + state: Arc, + _ctx: RequestContext, +) -> HandlerResult { + let resp = v1::FederationHealthResponse { + status: "ok".into(), + server_domain: state.local_domain.clone(), + }; + HandlerResult::ok(Bytes::from(resp.encode_to_vec())) } diff --git a/crates/quicproquo-server/src/v2_handlers/mod.rs b/crates/quicproquo-server/src/v2_handlers/mod.rs index d1e41a0..525ba2f 100644 --- a/crates/quicproquo-server/src/v2_handlers/mod.rs +++ b/crates/quicproquo-server/src/v2_handlers/mod.rs @@ -64,6 +64,10 @@ pub struct ServerState { pub start_time: std::time::Instant, /// Storage backend name (e.g. "sql", "file"). pub storage_backend: String, + /// Federation client for outbound server-to-server relay. None when federation is disabled. + pub federation_client: Option>, + /// This server's domain for federation addressing. Empty when federation is disabled. + pub local_domain: String, } /// A ban record for a user.