Add device_id parameter to fetch, fetch_wait, ack, receive_messages, and receive_messages_wait SDK functions. QpqClient gains device_id field with register_device/list_devices/revoke_device convenience methods. Client REPL passes empty device_id for backwards compat.
577 lines
19 KiB
Rust
577 lines
19 KiB
Rust
//! Messaging pipeline: send and receive messages through the MLS + sealed sender
|
|
//! + hybrid KEM stack.
|
|
//!
|
|
//! This module wraps the full encryption pipeline:
|
|
//! 1. **Send**: serialize → MLS encrypt → sealed sender → hybrid wrap → enqueue
|
|
//! 2. **Receive**: fetch → hybrid unwrap → MLS decrypt → unseal → parse
|
|
|
|
use bytes::Bytes;
|
|
use prost::Message;
|
|
use tracing::debug;
|
|
|
|
use quicproquo_core::{
|
|
AppMessage, GroupMember, HybridKeypair, HybridPublicKey, IdentityKeypair, ReceivedMessage,
|
|
};
|
|
use quicproquo_proto::method_ids;
|
|
use quicproquo_proto::qpq::v1::{
|
|
AckRequest, AckResponse, BatchEnqueueRequest, BatchEnqueueResponse, EnqueueRequest,
|
|
EnqueueResponse, FetchRequest, FetchResponse, FetchWaitRequest, FetchWaitResponse,
|
|
};
|
|
use quicproquo_rpc::client::RpcClient;
|
|
|
|
use crate::error::SdkError;
|
|
|
|
// ── Types ─────────────────────────────────────────────────────────────────────
|
|
|
|
/// A successfully decrypted application message with sender info.
|
|
#[derive(Debug)]
|
|
pub struct ReceivedPlaintext {
|
|
/// Sender's Ed25519 identity key (from sealed sender envelope).
|
|
pub sender_key: [u8; 32],
|
|
/// The parsed application message (Chat, Reply, Reaction, etc.).
|
|
pub message: AppMessage,
|
|
/// Server-assigned sequence number.
|
|
pub seq: u64,
|
|
}
|
|
|
|
/// Time-to-live applied to every message this module enqueues: one day, in seconds.
const DEFAULT_TTL_SECS: u32 = 24 * 60 * 60;
|
|
|
|
// ── Send Pipeline ─────────────────────────────────────────────────────────────
|
|
|
|
/// Encrypt and send a message to a conversation.
|
|
///
|
|
/// Pipeline: generate_message_id → serialize → MLS encrypt → seal → per-recipient
|
|
/// hybrid wrap → batch enqueue.
|
|
///
|
|
/// Returns the server-assigned sequence numbers (one per recipient).
|
|
pub async fn send_message(
|
|
rpc: &RpcClient,
|
|
member: &mut GroupMember,
|
|
identity: &IdentityKeypair,
|
|
body: &str,
|
|
recipient_keys: &[Vec<u8>],
|
|
hybrid_keys: &[Option<HybridPublicKey>],
|
|
channel_id: &[u8],
|
|
) -> Result<Vec<u64>, SdkError> {
|
|
// 1. Generate message ID.
|
|
let message_id = quicproquo_core::generate_message_id();
|
|
|
|
// 2. Serialize application payload.
|
|
let serialized = quicproquo_core::serialize_chat(body.as_bytes(), Some(message_id))
|
|
.map_err(|e| SdkError::Crypto(format!("serialize_chat: {e}")))?;
|
|
|
|
// 3. MLS encrypt.
|
|
let mls_ciphertext = member
|
|
.send_message(&serialized)
|
|
.map_err(|e| SdkError::Crypto(format!("MLS encrypt: {e}")))?;
|
|
|
|
// 4. Sealed sender wrap.
|
|
let sealed = quicproquo_core::sealed_sender::seal(identity, &mls_ciphertext);
|
|
|
|
// 5. Per-recipient hybrid wrap + enqueue.
|
|
// If all recipients can share the same payload (no hybrid keys), use batch enqueue.
|
|
// Otherwise, enqueue individually with per-recipient hybrid wrapping.
|
|
let all_no_hybrid = hybrid_keys.iter().all(|k| k.is_none());
|
|
|
|
if all_no_hybrid {
|
|
// Batch enqueue — same payload for all recipients.
|
|
let seqs = batch_enqueue(rpc, recipient_keys, channel_id, &sealed, DEFAULT_TTL_SECS).await?;
|
|
debug!(count = seqs.len(), "batch enqueue complete");
|
|
Ok(seqs)
|
|
} else {
|
|
// Per-recipient enqueue with optional hybrid wrapping.
|
|
let mut seqs = Vec::with_capacity(recipient_keys.len());
|
|
for (i, recipient_key) in recipient_keys.iter().enumerate() {
|
|
let payload = if let Some(Some(ref pk)) = hybrid_keys.get(i) {
|
|
quicproquo_core::hybrid_encrypt(pk, &sealed, b"", b"")
|
|
.map_err(|e| SdkError::Crypto(format!("hybrid encrypt: {e}")))?
|
|
} else {
|
|
sealed.clone()
|
|
};
|
|
let seq = enqueue(rpc, recipient_key, channel_id, &payload, DEFAULT_TTL_SECS).await?;
|
|
seqs.push(seq);
|
|
}
|
|
debug!(count = seqs.len(), "per-recipient enqueue complete");
|
|
Ok(seqs)
|
|
}
|
|
}
|
|
|
|
// ── Receive Pipeline ──────────────────────────────────────────────────────────
|
|
|
|
/// Receive and decrypt pending messages from the server.
|
|
///
|
|
/// Pipeline: fetch → sort by seq → for each: hybrid unwrap → MLS decrypt →
|
|
/// unseal → parse. Includes retry loop for multi-epoch batches where commits
|
|
/// must apply before application messages can be decrypted.
|
|
pub async fn receive_messages(
|
|
rpc: &RpcClient,
|
|
member: &mut GroupMember,
|
|
my_identity_key: &[u8],
|
|
hybrid_kp: Option<&HybridKeypair>,
|
|
channel_id: &[u8],
|
|
device_id: &[u8],
|
|
) -> Result<Vec<ReceivedPlaintext>, SdkError> {
|
|
let payloads = fetch(rpc, my_identity_key, channel_id, 0, device_id).await?;
|
|
process_payloads(member, hybrid_kp, payloads)
|
|
}
|
|
|
|
/// Long-poll for new messages with timeout.
|
|
///
|
|
/// Same pipeline as [`receive_messages`] but uses the FETCH_WAIT RPC which
|
|
/// blocks server-side until messages arrive or the timeout expires.
|
|
pub async fn receive_messages_wait(
|
|
rpc: &RpcClient,
|
|
member: &mut GroupMember,
|
|
my_identity_key: &[u8],
|
|
hybrid_kp: Option<&HybridKeypair>,
|
|
channel_id: &[u8],
|
|
timeout_ms: u64,
|
|
device_id: &[u8],
|
|
) -> Result<Vec<ReceivedPlaintext>, SdkError> {
|
|
let payloads = fetch_wait(rpc, my_identity_key, channel_id, timeout_ms, device_id).await?;
|
|
process_payloads(member, hybrid_kp, payloads)
|
|
}
|
|
|
|
/// Shared processing logic for received payloads.
|
|
///
|
|
/// Sorts by sequence number, then processes each payload through the decryption
|
|
/// pipeline. Uses a retry loop to handle multi-epoch batches where MLS commits
|
|
/// must be applied before subsequent application messages can be decrypted.
|
|
fn process_payloads(
|
|
member: &mut GroupMember,
|
|
hybrid_kp: Option<&HybridKeypair>,
|
|
mut payloads: Vec<(u64, Vec<u8>)>,
|
|
) -> Result<Vec<ReceivedPlaintext>, SdkError> {
|
|
if payloads.is_empty() {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
// Sort by server-assigned sequence number — commits must arrive before
|
|
// application messages that depend on the resulting epoch.
|
|
payloads.sort_by_key(|(seq, _)| *seq);
|
|
|
|
let mut results = Vec::new();
|
|
let mut pending: Vec<(u64, Vec<u8>)> = Vec::new();
|
|
|
|
for (seq, raw_payload) in &payloads {
|
|
// (a) Try hybrid decrypt; fall back to raw bytes if not hybrid-wrapped.
|
|
let mls_bytes = try_hybrid_unwrap(hybrid_kp, raw_payload);
|
|
|
|
// (b) MLS decrypt.
|
|
match member.receive_message(&mls_bytes) {
|
|
Ok(ReceivedMessage::Application(plaintext)) => {
|
|
if let Some(rp) = try_unseal_and_parse(*seq, &plaintext) {
|
|
results.push(rp);
|
|
}
|
|
}
|
|
Ok(ReceivedMessage::StateChanged | ReceivedMessage::SelfRemoved) => {
|
|
debug!(seq, "commit/state-change applied");
|
|
}
|
|
Err(_) => {
|
|
// MLS decryption failed — likely an epoch mismatch.
|
|
// Stash for retry after commits are applied.
|
|
pending.push((*seq, mls_bytes));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Retry loop: keep retrying pending messages until no more progress.
|
|
// This handles multi-epoch batches where commits must apply first.
|
|
loop {
|
|
let before = pending.len();
|
|
pending.retain_mut(|(seq, mls_bytes)| {
|
|
match member.receive_message(mls_bytes) {
|
|
Ok(ReceivedMessage::Application(plaintext)) => {
|
|
if let Some(rp) = try_unseal_and_parse(*seq, &plaintext) {
|
|
results.push(rp);
|
|
}
|
|
false // processed
|
|
}
|
|
Ok(ReceivedMessage::StateChanged | ReceivedMessage::SelfRemoved) => {
|
|
debug!(seq, "commit applied (retry)");
|
|
false // processed
|
|
}
|
|
Err(_) => true, // still pending
|
|
}
|
|
});
|
|
if pending.len() == before {
|
|
break; // no progress — remaining messages are unprocessable
|
|
}
|
|
}
|
|
|
|
if !pending.is_empty() {
|
|
debug!(
|
|
remaining = pending.len(),
|
|
"unprocessable messages after all retries"
|
|
);
|
|
}
|
|
|
|
Ok(results)
|
|
}
|
|
|
|
/// Try to hybrid-decrypt a payload. If the caller has a hybrid keypair, attempt
|
|
/// decryption. If it fails (payload might not be hybrid-wrapped), return the
|
|
/// raw bytes as-is.
|
|
fn try_hybrid_unwrap(hybrid_kp: Option<&HybridKeypair>, payload: &[u8]) -> Vec<u8> {
|
|
if let Some(kp) = hybrid_kp {
|
|
match quicproquo_core::hybrid_decrypt(kp, payload, b"", b"") {
|
|
Ok(inner) => inner,
|
|
Err(_) => payload.to_vec(), // not hybrid-wrapped, use raw
|
|
}
|
|
} else {
|
|
payload.to_vec()
|
|
}
|
|
}
|
|
|
|
/// Unseal (verify sender identity + Ed25519 signature) then parse the inner
|
|
/// application message. Returns None on failure (logged as debug).
|
|
fn try_unseal_and_parse(seq: u64, plaintext: &[u8]) -> Option<ReceivedPlaintext> {
|
|
let (sender_key, inner) = match quicproquo_core::sealed_sender::unseal(plaintext) {
|
|
Ok(pair) => pair,
|
|
Err(e) => {
|
|
debug!(seq, error = %e, "unseal failed");
|
|
return None;
|
|
}
|
|
};
|
|
|
|
let (_msg_type, message) = match quicproquo_core::parse(&inner) {
|
|
Ok(pair) => pair,
|
|
Err(e) => {
|
|
debug!(seq, error = %e, "app_message parse failed");
|
|
return None;
|
|
}
|
|
};
|
|
|
|
Some(ReceivedPlaintext {
|
|
sender_key,
|
|
message,
|
|
seq,
|
|
})
|
|
}
|
|
|
|
// ── Gap Detection ────────────────────────────────────────────────────────────
|
|
|
|
/// A gap detected in server-side sequence numbers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SeqGap {
    /// The expected next sequence number.
    pub expected_seq: u64,
    /// The sequence number that was actually received.
    pub received_seq: u64,
}

/// Detect gaps in a sorted list of `(seq, payload)` pairs relative to the
/// last known sequence number. Returns a list of gaps and the new highest seq.
///
/// `payloads` should be in ascending `seq` order. An entry whose seq is at or
/// below the running high-water mark (a duplicate, replay or stale message) is
/// ignored for gap purposes.
///
/// The returned high-water mark never moves backwards. It is
/// `max(last_seen_seq, highest received seq)`, so a batch containing only stale
/// seqs leaves it unchanged instead of rewinding it and triggering spurious gaps
/// on the next batch. Arithmetic saturates at `u64::MAX` rather than overflowing.
///
/// Callers should update their stored `last_seen_seq` to the returned value
/// and emit `ClientEvent::MessageGap` for each gap.
pub fn detect_gaps(last_seen_seq: u64, payloads: &[(u64, Vec<u8>)]) -> (Vec<SeqGap>, u64) {
    let mut gaps = Vec::new();
    let mut highest = last_seen_seq;

    for &(seq, _) in payloads {
        let expected = highest.saturating_add(1);
        if seq > expected {
            gaps.push(SeqGap {
                expected_seq: expected,
                received_seq: seq,
            });
        }
        highest = highest.max(seq);
    }

    (gaps, highest)
}
|
|
|
|
// ── RPC Helpers ───────────────────────────────────────────────────────────────
|
|
|
|
/// Enqueue a single payload to one recipient via RPC.
|
|
///
|
|
/// Returns the server-assigned sequence number.
|
|
pub async fn enqueue(
|
|
rpc: &RpcClient,
|
|
recipient_key: &[u8],
|
|
channel_id: &[u8],
|
|
payload: &[u8],
|
|
ttl_secs: u32,
|
|
) -> Result<u64, SdkError> {
|
|
let req = EnqueueRequest {
|
|
recipient_key: recipient_key.to_vec(),
|
|
payload: payload.to_vec(),
|
|
channel_id: channel_id.to_vec(),
|
|
ttl_secs,
|
|
message_id: Vec::new(),
|
|
};
|
|
|
|
let resp_bytes = rpc
|
|
.call(method_ids::ENQUEUE, Bytes::from(req.encode_to_vec()))
|
|
.await?;
|
|
|
|
let resp = EnqueueResponse::decode(resp_bytes)
|
|
.map_err(|e| SdkError::Crypto(format!("decode EnqueueResponse: {e}")))?;
|
|
|
|
Ok(resp.seq)
|
|
}
|
|
|
|
/// Batch enqueue the same payload to multiple recipients via RPC.
|
|
///
|
|
/// Returns per-recipient sequence numbers.
|
|
pub async fn batch_enqueue(
|
|
rpc: &RpcClient,
|
|
recipient_keys: &[Vec<u8>],
|
|
channel_id: &[u8],
|
|
payload: &[u8],
|
|
ttl_secs: u32,
|
|
) -> Result<Vec<u64>, SdkError> {
|
|
let req = BatchEnqueueRequest {
|
|
recipient_keys: recipient_keys.to_vec(),
|
|
payload: payload.to_vec(),
|
|
channel_id: channel_id.to_vec(),
|
|
ttl_secs,
|
|
message_id: Vec::new(),
|
|
};
|
|
|
|
let resp_bytes = rpc
|
|
.call(
|
|
method_ids::BATCH_ENQUEUE,
|
|
Bytes::from(req.encode_to_vec()),
|
|
)
|
|
.await?;
|
|
|
|
let resp = BatchEnqueueResponse::decode(resp_bytes)
|
|
.map_err(|e| SdkError::Crypto(format!("decode BatchEnqueueResponse: {e}")))?;
|
|
|
|
Ok(resp.seqs)
|
|
}
|
|
|
|
/// Fetch messages from server (destructive — removes from queue).
|
|
///
|
|
/// When `device_id` is non-empty, the server scopes the fetch to the
|
|
/// device-specific queue (identity_key + device_id).
|
|
///
|
|
/// Returns `(seq, payload)` pairs sorted by sequence number.
|
|
pub async fn fetch(
|
|
rpc: &RpcClient,
|
|
my_identity_key: &[u8],
|
|
channel_id: &[u8],
|
|
limit: u32,
|
|
device_id: &[u8],
|
|
) -> Result<Vec<(u64, Vec<u8>)>, SdkError> {
|
|
let req = FetchRequest {
|
|
recipient_key: my_identity_key.to_vec(),
|
|
channel_id: channel_id.to_vec(),
|
|
limit,
|
|
device_id: device_id.to_vec(),
|
|
};
|
|
|
|
let resp_bytes = rpc
|
|
.call(method_ids::FETCH, Bytes::from(req.encode_to_vec()))
|
|
.await?;
|
|
|
|
let resp = FetchResponse::decode(resp_bytes)
|
|
.map_err(|e| SdkError::Crypto(format!("decode FetchResponse: {e}")))?;
|
|
|
|
let mut payloads: Vec<(u64, Vec<u8>)> = resp
|
|
.payloads
|
|
.into_iter()
|
|
.map(|env| (env.seq, env.data))
|
|
.collect();
|
|
|
|
payloads.sort_by_key(|(seq, _)| *seq);
|
|
Ok(payloads)
|
|
}
|
|
|
|
/// Long-poll fetch: blocks server-side until messages arrive or timeout expires.
|
|
///
|
|
/// When `device_id` is non-empty, the server scopes the fetch to the
|
|
/// device-specific queue (identity_key + device_id).
|
|
///
|
|
/// Returns `(seq, payload)` pairs sorted by sequence number.
|
|
async fn fetch_wait(
|
|
rpc: &RpcClient,
|
|
my_identity_key: &[u8],
|
|
channel_id: &[u8],
|
|
timeout_ms: u64,
|
|
device_id: &[u8],
|
|
) -> Result<Vec<(u64, Vec<u8>)>, SdkError> {
|
|
let req = FetchWaitRequest {
|
|
recipient_key: my_identity_key.to_vec(),
|
|
channel_id: channel_id.to_vec(),
|
|
timeout_ms,
|
|
limit: 0, // fetch all
|
|
device_id: device_id.to_vec(),
|
|
};
|
|
|
|
let resp_bytes = rpc
|
|
.call(method_ids::FETCH_WAIT, Bytes::from(req.encode_to_vec()))
|
|
.await?;
|
|
|
|
let resp = FetchWaitResponse::decode(resp_bytes)
|
|
.map_err(|e| SdkError::Crypto(format!("decode FetchWaitResponse: {e}")))?;
|
|
|
|
let mut payloads: Vec<(u64, Vec<u8>)> = resp
|
|
.payloads
|
|
.into_iter()
|
|
.map(|env| (env.seq, env.data))
|
|
.collect();
|
|
|
|
payloads.sort_by_key(|(seq, _)| *seq);
|
|
Ok(payloads)
|
|
}
|
|
|
|
// ── Device-aware fetch ──────────────────────────────────────────────────────
|
|
|
|
/// Fetch messages for a specific device.
|
|
///
|
|
/// When `device_id` is non-empty, the server uses the composite queue key
|
|
/// `identity_key + device_id`. When empty, falls back to the bare identity key.
|
|
pub async fn fetch_for_device(
|
|
rpc: &RpcClient,
|
|
my_identity_key: &[u8],
|
|
device_id: &[u8],
|
|
channel_id: &[u8],
|
|
limit: u32,
|
|
) -> Result<Vec<(u64, Vec<u8>)>, SdkError> {
|
|
let req = FetchRequest {
|
|
recipient_key: my_identity_key.to_vec(),
|
|
channel_id: channel_id.to_vec(),
|
|
limit,
|
|
device_id: device_id.to_vec(),
|
|
};
|
|
|
|
let resp_bytes = rpc
|
|
.call(method_ids::FETCH, Bytes::from(req.encode_to_vec()))
|
|
.await?;
|
|
|
|
let resp = FetchResponse::decode(resp_bytes)
|
|
.map_err(|e| SdkError::Crypto(format!("decode FetchResponse: {e}")))?;
|
|
|
|
let mut payloads: Vec<(u64, Vec<u8>)> = resp
|
|
.payloads
|
|
.into_iter()
|
|
.map(|env| (env.seq, env.data))
|
|
.collect();
|
|
|
|
payloads.sort_by_key(|(seq, _)| *seq);
|
|
Ok(payloads)
|
|
}
|
|
|
|
// ── Acknowledge ─────────────────────────────────────────────────────────────
|
|
|
|
/// Acknowledge messages up to a sequence number.
|
|
///
|
|
/// When `device_id` is non-empty, the server acks on the device-scoped queue.
|
|
pub async fn ack(
|
|
rpc: &RpcClient,
|
|
my_identity_key: &[u8],
|
|
device_id: &[u8],
|
|
channel_id: &[u8],
|
|
seq_up_to: u64,
|
|
) -> Result<(), SdkError> {
|
|
let req = AckRequest {
|
|
recipient_key: my_identity_key.to_vec(),
|
|
channel_id: channel_id.to_vec(),
|
|
seq_up_to,
|
|
device_id: device_id.to_vec(),
|
|
};
|
|
|
|
let resp_bytes = rpc
|
|
.call(method_ids::ACK, Bytes::from(req.encode_to_vec()))
|
|
.await?;
|
|
|
|
let _resp = AckResponse::decode(resp_bytes)
|
|
.map_err(|e| SdkError::Crypto(format!("decode AckResponse: {e}")))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build a payload list that carries only the given sequence numbers.
    fn batch(seqs: &[u64]) -> Vec<(u64, Vec<u8>)> {
        seqs.iter().map(|&seq| (seq, Vec::new())).collect()
    }

    /// Flatten gaps into `(expected, received)` pairs for compact assertions.
    fn pairs(gaps: &[SeqGap]) -> Vec<(u64, u64)> {
        gaps.iter()
            .map(|g| (g.expected_seq, g.received_seq))
            .collect()
    }

    #[test]
    fn detect_gaps_empty() {
        let (gaps, last) = detect_gaps(0, &[]);
        assert!(gaps.is_empty());
        assert_eq!(last, 0);
    }

    #[test]
    fn detect_gaps_contiguous_from_zero() {
        let (gaps, last) = detect_gaps(0, &batch(&[1, 2, 3]));
        assert!(gaps.is_empty());
        assert_eq!(last, 3);
    }

    #[test]
    fn detect_gaps_contiguous_from_nonzero() {
        let (gaps, last) = detect_gaps(5, &batch(&[6, 7, 8]));
        assert!(gaps.is_empty());
        assert_eq!(last, 8);
    }

    #[test]
    fn detect_gaps_single_gap() {
        // 3 and 4 are missing between 2 and 5.
        let (gaps, last) = detect_gaps(0, &batch(&[1, 2, 5, 6]));
        assert_eq!(pairs(&gaps), vec![(3, 5)]);
        assert_eq!(last, 6);
    }

    #[test]
    fn detect_gaps_multiple_gaps() {
        // 1..3 and 4..7 are missing.
        let (gaps, last) = detect_gaps(0, &batch(&[3, 7, 8]));
        assert_eq!(pairs(&gaps), vec![(1, 3), (4, 7)]);
        assert_eq!(last, 8);
    }

    #[test]
    fn detect_gaps_initial_gap() {
        // Last seen was 5, but the first delivered seq is 10.
        let (gaps, last) = detect_gaps(5, &batch(&[10, 11]));
        assert_eq!(pairs(&gaps), vec![(6, 10)]);
        assert_eq!(last, 11);
    }
}
|