//! Messaging pipeline: send and receive messages through the MLS + sealed sender //! + hybrid KEM stack. //! //! This module wraps the full encryption pipeline: //! 1. **Send**: serialize → MLS encrypt → sealed sender → hybrid wrap → enqueue //! 2. **Receive**: fetch → hybrid unwrap → MLS decrypt → unseal → parse use bytes::Bytes; use prost::Message; use tracing::debug; use quicproquo_core::{ AppMessage, GroupMember, HybridKeypair, HybridPublicKey, IdentityKeypair, ReceivedMessage, }; use quicproquo_proto::method_ids; use quicproquo_proto::qpq::v1::{ AckRequest, AckResponse, BatchEnqueueRequest, BatchEnqueueResponse, EnqueueRequest, EnqueueResponse, FetchRequest, FetchResponse, FetchWaitRequest, FetchWaitResponse, }; use quicproquo_rpc::client::RpcClient; use crate::error::SdkError; // ── Types ───────────────────────────────────────────────────────────────────── /// A successfully decrypted application message with sender info. #[derive(Debug)] pub struct ReceivedPlaintext { /// Sender's Ed25519 identity key (from sealed sender envelope). pub sender_key: [u8; 32], /// The parsed application message (Chat, Reply, Reaction, etc.). pub message: AppMessage, /// Server-assigned sequence number. pub seq: u64, } /// Default TTL for enqueued messages (24 hours). const DEFAULT_TTL_SECS: u32 = 86400; // ── Send Pipeline ───────────────────────────────────────────────────────────── /// Encrypt and send a message to a conversation. /// /// Pipeline: generate_message_id → serialize → MLS encrypt → seal → per-recipient /// hybrid wrap → batch enqueue. /// /// Returns the server-assigned sequence numbers (one per recipient). pub async fn send_message( rpc: &RpcClient, member: &mut GroupMember, identity: &IdentityKeypair, body: &str, recipient_keys: &[Vec], hybrid_keys: &[Option], channel_id: &[u8], ) -> Result, SdkError> { // 1. Generate message ID. let message_id = quicproquo_core::generate_message_id(); // 2. Serialize application payload. 
let serialized = quicproquo_core::serialize_chat(body.as_bytes(), Some(message_id)) .map_err(|e| SdkError::Crypto(format!("serialize_chat: {e}")))?; // 3. MLS encrypt. let mls_ciphertext = member .send_message(&serialized) .map_err(|e| SdkError::Crypto(format!("MLS encrypt: {e}")))?; // 4. Sealed sender wrap. let sealed = quicproquo_core::sealed_sender::seal(identity, &mls_ciphertext); // 5. Per-recipient hybrid wrap + enqueue. // If all recipients can share the same payload (no hybrid keys), use batch enqueue. // Otherwise, enqueue individually with per-recipient hybrid wrapping. let all_no_hybrid = hybrid_keys.iter().all(|k| k.is_none()); if all_no_hybrid { // Batch enqueue — same payload for all recipients. let seqs = batch_enqueue(rpc, recipient_keys, channel_id, &sealed, DEFAULT_TTL_SECS).await?; debug!(count = seqs.len(), "batch enqueue complete"); Ok(seqs) } else { // Per-recipient enqueue with optional hybrid wrapping. let mut seqs = Vec::with_capacity(recipient_keys.len()); for (i, recipient_key) in recipient_keys.iter().enumerate() { let payload = if let Some(Some(ref pk)) = hybrid_keys.get(i) { quicproquo_core::hybrid_encrypt(pk, &sealed, b"", b"") .map_err(|e| SdkError::Crypto(format!("hybrid encrypt: {e}")))? } else { sealed.clone() }; let seq = enqueue(rpc, recipient_key, channel_id, &payload, DEFAULT_TTL_SECS).await?; seqs.push(seq); } debug!(count = seqs.len(), "per-recipient enqueue complete"); Ok(seqs) } } // ── Receive Pipeline ────────────────────────────────────────────────────────── /// Receive and decrypt pending messages from the server. /// /// Pipeline: fetch → sort by seq → for each: hybrid unwrap → MLS decrypt → /// unseal → parse. Includes retry loop for multi-epoch batches where commits /// must apply before application messages can be decrypted. 
pub async fn receive_messages( rpc: &RpcClient, member: &mut GroupMember, my_identity_key: &[u8], hybrid_kp: Option<&HybridKeypair>, channel_id: &[u8], device_id: &[u8], ) -> Result, SdkError> { let payloads = fetch(rpc, my_identity_key, channel_id, 0, device_id).await?; process_payloads(member, hybrid_kp, payloads) } /// Long-poll for new messages with timeout. /// /// Same pipeline as [`receive_messages`] but uses the FETCH_WAIT RPC which /// blocks server-side until messages arrive or the timeout expires. pub async fn receive_messages_wait( rpc: &RpcClient, member: &mut GroupMember, my_identity_key: &[u8], hybrid_kp: Option<&HybridKeypair>, channel_id: &[u8], timeout_ms: u64, device_id: &[u8], ) -> Result, SdkError> { let payloads = fetch_wait(rpc, my_identity_key, channel_id, timeout_ms, device_id).await?; process_payloads(member, hybrid_kp, payloads) } /// Shared processing logic for received payloads. /// /// Sorts by sequence number, then processes each payload through the decryption /// pipeline. Uses a retry loop to handle multi-epoch batches where MLS commits /// must be applied before subsequent application messages can be decrypted. fn process_payloads( member: &mut GroupMember, hybrid_kp: Option<&HybridKeypair>, mut payloads: Vec<(u64, Vec)>, ) -> Result, SdkError> { if payloads.is_empty() { return Ok(Vec::new()); } // Sort by server-assigned sequence number — commits must arrive before // application messages that depend on the resulting epoch. payloads.sort_by_key(|(seq, _)| *seq); let mut results = Vec::new(); let mut pending: Vec<(u64, Vec)> = Vec::new(); for (seq, raw_payload) in &payloads { // (a) Try hybrid decrypt; fall back to raw bytes if not hybrid-wrapped. let mls_bytes = try_hybrid_unwrap(hybrid_kp, raw_payload); // (b) MLS decrypt. 
match member.receive_message(&mls_bytes) { Ok(ReceivedMessage::Application(plaintext)) => { if let Some(rp) = try_unseal_and_parse(*seq, &plaintext) { results.push(rp); } } Ok(ReceivedMessage::StateChanged | ReceivedMessage::SelfRemoved) => { debug!(seq, "commit/state-change applied"); } Err(_) => { // MLS decryption failed — likely an epoch mismatch. // Stash for retry after commits are applied. pending.push((*seq, mls_bytes)); } } } // Retry loop: keep retrying pending messages until no more progress. // This handles multi-epoch batches where commits must apply first. loop { let before = pending.len(); pending.retain_mut(|(seq, mls_bytes)| { match member.receive_message(mls_bytes) { Ok(ReceivedMessage::Application(plaintext)) => { if let Some(rp) = try_unseal_and_parse(*seq, &plaintext) { results.push(rp); } false // processed } Ok(ReceivedMessage::StateChanged | ReceivedMessage::SelfRemoved) => { debug!(seq, "commit applied (retry)"); false // processed } Err(_) => true, // still pending } }); if pending.len() == before { break; // no progress — remaining messages are unprocessable } } if !pending.is_empty() { debug!( remaining = pending.len(), "unprocessable messages after all retries" ); } Ok(results) } /// Try to hybrid-decrypt a payload. If the caller has a hybrid keypair, attempt /// decryption. If it fails (payload might not be hybrid-wrapped), return the /// raw bytes as-is. fn try_hybrid_unwrap(hybrid_kp: Option<&HybridKeypair>, payload: &[u8]) -> Vec { if let Some(kp) = hybrid_kp { match quicproquo_core::hybrid_decrypt(kp, payload, b"", b"") { Ok(inner) => inner, Err(_) => payload.to_vec(), // not hybrid-wrapped, use raw } } else { payload.to_vec() } } /// Unseal (verify sender identity + Ed25519 signature) then parse the inner /// application message. Returns None on failure (logged as debug). 
fn try_unseal_and_parse(seq: u64, plaintext: &[u8]) -> Option { let (sender_key, inner) = match quicproquo_core::sealed_sender::unseal(plaintext) { Ok(pair) => pair, Err(e) => { debug!(seq, error = %e, "unseal failed"); return None; } }; let (_msg_type, message) = match quicproquo_core::parse(&inner) { Ok(pair) => pair, Err(e) => { debug!(seq, error = %e, "app_message parse failed"); return None; } }; Some(ReceivedPlaintext { sender_key, message, seq, }) } // ── Gap Detection ──────────────────────────────────────────────────────────── /// A gap detected in server-side sequence numbers. #[derive(Debug, Clone)] pub struct SeqGap { /// The expected next sequence number. pub expected_seq: u64, /// The sequence number that was actually received. pub received_seq: u64, } /// Detect gaps in a sorted list of `(seq, payload)` pairs relative to the /// last known sequence number. Returns a list of gaps and the new highest seq. /// /// Callers should update their stored `last_seen_seq` to the returned value /// and emit `ClientEvent::MessageGap` for each gap. pub fn detect_gaps(last_seen_seq: u64, payloads: &[(u64, Vec)]) -> (Vec, u64) { if payloads.is_empty() { return (Vec::new(), last_seen_seq); } let mut gaps = Vec::new(); let mut expected = last_seen_seq + 1; for &(seq, _) in payloads { if seq > expected { gaps.push(SeqGap { expected_seq: expected, received_seq: seq, }); } if seq >= expected { expected = seq + 1; } } // The new last_seen_seq is the highest seq we received. let new_last_seen = payloads.iter().map(|(s, _)| *s).max().unwrap_or(last_seen_seq); (gaps, new_last_seen) } // ── RPC Helpers ─────────────────────────────────────────────────────────────── /// Enqueue a single payload to one recipient via RPC. /// /// Returns the server-assigned sequence number. 
pub async fn enqueue(
    rpc: &RpcClient,
    recipient_key: &[u8],
    channel_id: &[u8],
    payload: &[u8],
    ttl_secs: u32,
) -> Result<u64, SdkError> {
    let request = EnqueueRequest {
        recipient_key: recipient_key.to_vec(),
        channel_id: channel_id.to_vec(),
        payload: payload.to_vec(),
        message_id: Vec::new(),
        ttl_secs,
    };
    let body = Bytes::from(request.encode_to_vec());
    let raw = rpc.call(method_ids::ENQUEUE, body).await?;

    EnqueueResponse::decode(raw)
        .map(|response| response.seq)
        .map_err(|e| SdkError::Crypto(format!("decode EnqueueResponse: {e}")))
}

/// Send one payload to several recipients with a single BATCH_ENQUEUE RPC.
///
/// Every recipient receives identical bytes. The sequence numbers are returned
/// exactly as the server reported them in the response.
pub async fn batch_enqueue(
    rpc: &RpcClient,
    recipient_keys: &[Vec<u8>],
    channel_id: &[u8],
    payload: &[u8],
    ttl_secs: u32,
) -> Result<Vec<u64>, SdkError> {
    let request = BatchEnqueueRequest {
        recipient_keys: recipient_keys.to_vec(),
        channel_id: channel_id.to_vec(),
        payload: payload.to_vec(),
        message_id: Vec::new(),
        ttl_secs,
    };
    let body = Bytes::from(request.encode_to_vec());
    let raw = rpc.call(method_ids::BATCH_ENQUEUE, body).await?;

    BatchEnqueueResponse::decode(raw)
        .map(|response| response.seqs)
        .map_err(|e| SdkError::Crypto(format!("decode BatchEnqueueResponse: {e}")))
}

/// Fetch messages from server (destructive — removes from queue).
///
/// When `device_id` is non-empty, the server scopes the fetch to the
/// device-specific queue (identity_key + device_id).
///
/// Returns `(seq, payload)` pairs sorted by sequence number.
pub async fn fetch( rpc: &RpcClient, my_identity_key: &[u8], channel_id: &[u8], limit: u32, device_id: &[u8], ) -> Result)>, SdkError> { let req = FetchRequest { recipient_key: my_identity_key.to_vec(), channel_id: channel_id.to_vec(), limit, device_id: device_id.to_vec(), }; let resp_bytes = rpc .call(method_ids::FETCH, Bytes::from(req.encode_to_vec())) .await?; let resp = FetchResponse::decode(resp_bytes) .map_err(|e| SdkError::Crypto(format!("decode FetchResponse: {e}")))?; let mut payloads: Vec<(u64, Vec)> = resp .payloads .into_iter() .map(|env| (env.seq, env.data)) .collect(); payloads.sort_by_key(|(seq, _)| *seq); Ok(payloads) } /// Long-poll fetch: blocks server-side until messages arrive or timeout expires. /// /// When `device_id` is non-empty, the server scopes the fetch to the /// device-specific queue (identity_key + device_id). /// /// Returns `(seq, payload)` pairs sorted by sequence number. async fn fetch_wait( rpc: &RpcClient, my_identity_key: &[u8], channel_id: &[u8], timeout_ms: u64, device_id: &[u8], ) -> Result)>, SdkError> { let req = FetchWaitRequest { recipient_key: my_identity_key.to_vec(), channel_id: channel_id.to_vec(), timeout_ms, limit: 0, // fetch all device_id: device_id.to_vec(), }; let resp_bytes = rpc .call(method_ids::FETCH_WAIT, Bytes::from(req.encode_to_vec())) .await?; let resp = FetchWaitResponse::decode(resp_bytes) .map_err(|e| SdkError::Crypto(format!("decode FetchWaitResponse: {e}")))?; let mut payloads: Vec<(u64, Vec)> = resp .payloads .into_iter() .map(|env| (env.seq, env.data)) .collect(); payloads.sort_by_key(|(seq, _)| *seq); Ok(payloads) } // ── Device-aware fetch ────────────────────────────────────────────────────── /// Fetch messages for a specific device. /// /// When `device_id` is non-empty, the server uses the composite queue key /// `identity_key + device_id`. When empty, falls back to the bare identity key. 
pub async fn fetch_for_device( rpc: &RpcClient, my_identity_key: &[u8], device_id: &[u8], channel_id: &[u8], limit: u32, ) -> Result)>, SdkError> { let req = FetchRequest { recipient_key: my_identity_key.to_vec(), channel_id: channel_id.to_vec(), limit, device_id: device_id.to_vec(), }; let resp_bytes = rpc .call(method_ids::FETCH, Bytes::from(req.encode_to_vec())) .await?; let resp = FetchResponse::decode(resp_bytes) .map_err(|e| SdkError::Crypto(format!("decode FetchResponse: {e}")))?; let mut payloads: Vec<(u64, Vec)> = resp .payloads .into_iter() .map(|env| (env.seq, env.data)) .collect(); payloads.sort_by_key(|(seq, _)| *seq); Ok(payloads) } // ── Acknowledge ───────────────────────────────────────────────────────────── /// Acknowledge messages up to a sequence number. /// /// When `device_id` is non-empty, the server acks on the device-scoped queue. pub async fn ack( rpc: &RpcClient, my_identity_key: &[u8], device_id: &[u8], channel_id: &[u8], seq_up_to: u64, ) -> Result<(), SdkError> { let req = AckRequest { recipient_key: my_identity_key.to_vec(), channel_id: channel_id.to_vec(), seq_up_to, device_id: device_id.to_vec(), }; let resp_bytes = rpc .call(method_ids::ACK, Bytes::from(req.encode_to_vec())) .await?; let _resp = AckResponse::decode(resp_bytes) .map_err(|e| SdkError::Crypto(format!("decode AckResponse: {e}")))?; Ok(()) } #[cfg(test)] #[allow(clippy::unwrap_used)] mod tests { use super::*; #[test] fn detect_gaps_empty() { let (gaps, last) = detect_gaps(0, &[]); assert!(gaps.is_empty()); assert_eq!(last, 0); } #[test] fn detect_gaps_contiguous_from_zero() { let payloads = vec![ (1, vec![]), (2, vec![]), (3, vec![]), ]; let (gaps, last) = detect_gaps(0, &payloads); assert!(gaps.is_empty()); assert_eq!(last, 3); } #[test] fn detect_gaps_contiguous_from_nonzero() { let payloads = vec![ (6, vec![]), (7, vec![]), (8, vec![]), ]; let (gaps, last) = detect_gaps(5, &payloads); assert!(gaps.is_empty()); assert_eq!(last, 8); } #[test] fn 
detect_gaps_single_gap() { let payloads = vec![ (1, vec![]), (2, vec![]), (5, vec![]), // gap: expected 3, got 5 (6, vec![]), ]; let (gaps, last) = detect_gaps(0, &payloads); assert_eq!(gaps.len(), 1); assert_eq!(gaps[0].expected_seq, 3); assert_eq!(gaps[0].received_seq, 5); assert_eq!(last, 6); } #[test] fn detect_gaps_multiple_gaps() { let payloads = vec![ (3, vec![]), // gap from 1 to 3 (7, vec![]), // gap from 4 to 7 (8, vec![]), ]; let (gaps, last) = detect_gaps(0, &payloads); assert_eq!(gaps.len(), 2); assert_eq!(gaps[0].expected_seq, 1); assert_eq!(gaps[0].received_seq, 3); assert_eq!(gaps[1].expected_seq, 4); assert_eq!(gaps[1].received_seq, 7); assert_eq!(last, 8); } #[test] fn detect_gaps_initial_gap() { // last_seen_seq = 5, but first received is 10 let payloads = vec![ (10, vec![]), (11, vec![]), ]; let (gaps, last) = detect_gaps(5, &payloads); assert_eq!(gaps.len(), 1); assert_eq!(gaps[0].expected_seq, 6); assert_eq!(gaps[0].received_seq, 10); assert_eq!(last, 11); } }