//! Message padding to hide plaintext lengths from the server. //! //! Pads payloads to fixed bucket sizes before MLS encryption so that the //! ciphertext does not reveal the actual message length. //! //! # Wire format //! //! ```text //! [real_length: 4 bytes LE (u32)][payload: real_length bytes][random padding] //! ``` //! //! The total padded output is always one of the bucket sizes: 256, 1024, 4096, 16384 bytes. //! For payloads larger than 16380 bytes, rounds up to the nearest 16384-byte multiple. use rand::RngCore; use crate::error::CoreError; /// Bucket sizes in bytes. The smallest (256) accommodates a sealed sender /// envelope (99 bytes overhead) plus a short message. const BUCKETS: &[usize] = &[256, 1024, 4096, 16384]; /// Select the smallest bucket that fits `content_len + 4` (the 4-byte length prefix). fn bucket_for(content_len: usize) -> usize { let total = content_len + 4; for &b in BUCKETS { if total <= b { return b; } } // Larger than biggest bucket: round up to nearest 16384-byte multiple. total.div_ceil(16384) * 16384 } /// Pad a payload to the next bucket boundary with cryptographic random bytes. pub fn pad(payload: &[u8]) -> Vec { let bucket = bucket_for(payload.len()); let mut out = Vec::with_capacity(bucket); out.extend_from_slice(&(payload.len() as u32).to_le_bytes()); out.extend_from_slice(payload); let pad_len = bucket - 4 - payload.len(); if pad_len > 0 { let mut padding = vec![0u8; pad_len]; rand::rngs::OsRng.fill_bytes(&mut padding); out.extend_from_slice(&padding); } out } /// Remove padding and return the original payload. pub fn unpad(padded: &[u8]) -> Result, CoreError> { if padded.len() < 4 { return Err(CoreError::AppMessage("padded message too short".into())); } let real_len = u32::from_le_bytes([padded[0], padded[1], padded[2], padded[3]]) as usize; if 4 + real_len > padded.len() { return Err(CoreError::AppMessage( "padded real_length exceeds buffer".into(), )); } Ok(padded[4..4 + real_len].to_vec()) } #[cfg(test)] mod tests { use super::*; #[test] fn round_trip_small() { let msg = b"hello"; let padded = pad(msg); assert_eq!(padded.len(), 256); // smallest bucket let unpadded = unpad(&padded).unwrap(); assert_eq!(unpadded, msg); } #[test] fn round_trip_medium() { let msg = vec![0xAB; 300]; let padded = pad(&msg); assert_eq!(padded.len(), 1024); // second bucket let unpadded = unpad(&padded).unwrap(); assert_eq!(unpadded, msg); } #[test] fn round_trip_large() { let msg = vec![0xCD; 2000]; let padded = pad(&msg); assert_eq!(padded.len(), 4096); // third bucket let unpadded = unpad(&padded).unwrap(); assert_eq!(unpadded, msg); } #[test] fn round_trip_very_large() { let msg = vec![0xEF; 10000]; let padded = pad(&msg); assert_eq!(padded.len(), 16384); // largest bucket let unpadded = unpad(&padded).unwrap(); assert_eq!(unpadded, msg); } #[test] fn round_trip_oversized() { let msg = vec![0xFF; 20000]; let padded = pad(&msg); assert_eq!(padded.len(), 32768); // 2 * 16384 let unpadded = unpad(&padded).unwrap(); assert_eq!(unpadded, msg); } #[test] fn round_trip_empty() { let msg = b""; let padded = pad(msg); assert_eq!(padded.len(), 256); // smallest bucket let unpadded = unpad(&padded).unwrap(); assert_eq!(unpadded, msg); } #[test] fn exactly_at_bucket_boundary() { // 252 + 4 = 256 → fits in 256 bucket exactly let msg = vec![0x42; 252]; let padded = pad(&msg); assert_eq!(padded.len(), 256); let unpadded = unpad(&padded).unwrap(); assert_eq!(unpadded, msg); } #[test] fn unpad_too_short_fails() { assert!(unpad(&[0, 0]).is_err()); } #[test] fn unpad_invalid_length_fails() { // Claims 1000 bytes but only has 10 let mut bad = (1000u32).to_le_bytes().to_vec(); bad.extend_from_slice(&[0u8; 10]); assert!(unpad(&bad).is_err()); } }