use std::net::SocketAddr; use std::path::Path; use std::sync::Arc; use anyhow::Context; use quinn::{ClientConfig, Endpoint}; use quinn_proto::crypto::rustls::QuicClientConfig; use rustls::pki_types::CertificateDer; use rustls::{ClientConfig as RustlsClientConfig, RootCertStore}; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use capnp_rpc::{rpc_twoparty_capnp::Side, twoparty, RpcSystem}; use quicnprotochat_core::HybridPublicKey; use quicnprotochat_proto::node_capnp::{auth, node_service}; use crate::AUTH_CONTEXT; use super::retry::{ anyhow_is_retriable, base_delay_ms_from_env, max_retries_from_env, retry_async, }; /// Establish a QUIC/TLS connection and return a `NodeService` client. /// /// Must be called from within a `LocalSet` because capnp-rpc is `!Send`. pub async fn connect_node( server: &str, ca_cert: &Path, server_name: &str, ) -> anyhow::Result { let addr: SocketAddr = server .parse() .with_context(|| format!("server must be host:port, got {server}"))?; let cert_bytes = std::fs::read(ca_cert).with_context(|| format!("read ca_cert {ca_cert:?}"))?; let mut roots = RootCertStore::empty(); roots .add(CertificateDer::from(cert_bytes)) .context("add root cert")?; let mut tls = RustlsClientConfig::builder() .with_root_certificates(roots) .with_no_client_auth(); tls.alpn_protocols = vec![b"capnp".to_vec()]; let crypto = QuicClientConfig::try_from(tls) .map_err(|e| anyhow::anyhow!("invalid client TLS config: {e}"))?; let bind_addr: SocketAddr = "0.0.0.0:0".parse().context("parse client bind address")?; let mut endpoint = Endpoint::client(bind_addr)?; endpoint.set_default_client_config(ClientConfig::new(Arc::new(crypto))); let connection = endpoint .connect(addr, server_name) .context("quic connect init")? .await .context("quic connect failed")?; let (send, recv) = connection.open_bi().await.context("open bi stream")?; let network = twoparty::VatNetwork::new( recv.compat(), send.compat_write(), Side::Client, Default::default(), ); let mut rpc_system = RpcSystem::new(Box::new(network), None); let client: node_service::Client = rpc_system.bootstrap(Side::Server); tokio::task::spawn_local(rpc_system); Ok(client) } pub fn set_auth(auth: &mut auth::Builder<'_>) -> anyhow::Result<()> { let ctx = AUTH_CONTEXT.get().ok_or_else(|| { anyhow::anyhow!("init_auth must be called with a non-empty token before RPCs") })?; auth.set_version(ctx.version); auth.set_access_token(&ctx.access_token); auth.set_device_id(&ctx.device_id); Ok(()) } /// Upload a KeyPackage and verify the fingerprint echoed by the AS. pub async fn upload_key_package( client: &node_service::Client, identity_key: &[u8], package: &[u8], ) -> anyhow::Result<()> { let mut req = client.upload_key_package_request(); { let mut p = req.get(); p.set_identity_key(identity_key); p.set_package(package); let mut auth = p.reborrow().init_auth(); set_auth(&mut auth)?; } let resp = req .send() .promise .await .context("upload_key_package RPC failed")?; let server_fp = resp .get() .context("upload_key_package: bad response")? .get_fingerprint() .context("upload_key_package: missing fingerprint")? .to_vec(); let local_fp = super::state::sha256(package); anyhow::ensure!(server_fp == local_fp, "fingerprint mismatch"); Ok(()) } /// Fetch a KeyPackage for `identity_key` from the AS. pub async fn fetch_key_package( client: &node_service::Client, identity_key: &[u8], ) -> anyhow::Result> { let mut req = client.fetch_key_package_request(); { let mut p = req.get(); p.set_identity_key(identity_key); let mut auth = p.reborrow().init_auth(); set_auth(&mut auth)?; } let resp = req .send() .promise .await .context("fetch_key_package RPC failed")?; let pkg = resp .get() .context("fetch_key_package: bad response")? .get_package() .context("fetch_key_package: missing package field")? .to_vec(); Ok(pkg) } /// Enqueue an opaque payload to the DS for `recipient_key`. /// Returns the per-inbox sequence number assigned by the server. /// Retries on transient failures with exponential backoff. pub async fn enqueue( client: &node_service::Client, recipient_key: &[u8], payload: &[u8], ) -> anyhow::Result { let client = client.clone(); let recipient_key = recipient_key.to_vec(); let payload = payload.to_vec(); retry_async( || { let client = client.clone(); let recipient_key = recipient_key.clone(); let payload = payload.clone(); async move { let mut req = client.enqueue_request(); { let mut p = req.get(); p.set_recipient_key(&recipient_key); p.set_payload(&payload); p.set_channel_id(&[]); p.set_version(1); let mut auth = p.reborrow().init_auth(); set_auth(&mut auth)?; } let resp = req.send().promise.await.context("enqueue RPC failed")?; let seq = resp.get().context("enqueue: bad response")?.get_seq(); Ok(seq) } }, max_retries_from_env(), base_delay_ms_from_env(), anyhow_is_retriable, ) .await } /// Fetch and drain all payloads for `recipient_key`. /// Returns `(seq, payload)` pairs — sort by `seq` before MLS processing. /// Retries on transient failures with exponential backoff. pub async fn fetch_all( client: &node_service::Client, recipient_key: &[u8], ) -> anyhow::Result)>> { let client = client.clone(); let recipient_key = recipient_key.to_vec(); retry_async( || { let client = client.clone(); let recipient_key = recipient_key.clone(); async move { let mut req = client.fetch_request(); { let mut p = req.get(); p.set_recipient_key(&recipient_key); p.set_channel_id(&[]); p.set_version(1); p.set_limit(0); // fetch all let mut auth = p.reborrow().init_auth(); set_auth(&mut auth)?; } let resp = req.send().promise.await.context("fetch RPC failed")?; let list = resp .get() .context("fetch: bad response")? .get_payloads() .context("fetch: missing payloads")?; let mut payloads = Vec::with_capacity(list.len() as usize); for i in 0..list.len() { let entry = list.get(i); let seq = entry.get_seq(); let data = entry .get_data() .context("fetch: envelope data read failed")? .to_vec(); payloads.push((seq, data)); } Ok(payloads) } }, max_retries_from_env(), base_delay_ms_from_env(), anyhow_is_retriable, ) .await } /// Long-poll for payloads with optional timeout (ms). /// Returns `(seq, payload)` pairs — sort by `seq` before MLS processing. /// Retries on transient failures with exponential backoff. pub async fn fetch_wait( client: &node_service::Client, recipient_key: &[u8], timeout_ms: u64, ) -> anyhow::Result)>> { let client = client.clone(); let recipient_key = recipient_key.to_vec(); retry_async( || { let client = client.clone(); let recipient_key = recipient_key.clone(); let timeout_ms = timeout_ms; async move { let mut req = client.fetch_wait_request(); { let mut p = req.get(); p.set_recipient_key(&recipient_key); p.set_timeout_ms(timeout_ms); p.set_channel_id(&[]); p.set_version(1); p.set_limit(0); // fetch all let mut auth = p.reborrow().init_auth(); set_auth(&mut auth)?; } let resp = req.send().promise.await.context("fetch_wait RPC failed")?; let list = resp .get() .context("fetch_wait: bad response")? .get_payloads() .context("fetch_wait: missing payloads")?; let mut payloads = Vec::with_capacity(list.len() as usize); for i in 0..list.len() { let entry = list.get(i); let seq = entry.get_seq(); let data = entry .get_data() .context("fetch_wait: envelope data read failed")? .to_vec(); payloads.push((seq, data)); } Ok(payloads) } }, max_retries_from_env(), base_delay_ms_from_env(), anyhow_is_retriable, ) .await } /// Upload a hybrid (X25519 + ML-KEM-768) public key for an identity. pub async fn upload_hybrid_key( client: &node_service::Client, identity_key: &[u8], hybrid_pk: &HybridPublicKey, ) -> anyhow::Result<()> { let mut req = client.upload_hybrid_key_request(); { let mut p = req.get(); p.set_identity_key(identity_key); p.set_hybrid_public_key(&hybrid_pk.to_bytes()); let mut auth = p.reborrow().init_auth(); set_auth(&mut auth)?; } req.send() .promise .await .context("upload_hybrid_key RPC failed")?; Ok(()) } /// Fetch a peer's hybrid public key from the server. /// /// Returns `None` if the peer has not uploaded a hybrid key. pub async fn fetch_hybrid_key( client: &node_service::Client, identity_key: &[u8], ) -> anyhow::Result> { let mut req = client.fetch_hybrid_key_request(); { let mut p = req.get(); p.set_identity_key(identity_key); let mut auth = p.reborrow().init_auth(); set_auth(&mut auth)?; } let resp = req .send() .promise .await .context("fetch_hybrid_key RPC failed")?; let pk_bytes = resp .get() .context("fetch_hybrid_key: bad response")? .get_hybrid_public_key() .context("fetch_hybrid_key: missing field")? .to_vec(); if pk_bytes.is_empty() { return Ok(None); } let pk = HybridPublicKey::from_bytes(&pk_bytes).context("invalid hybrid public key")?; Ok(Some(pk)) } /// Decrypt a hybrid envelope. Requires a hybrid key; no fallback to plaintext MLS. pub fn try_hybrid_decrypt( hybrid_kp: Option<&quicnprotochat_core::HybridKeypair>, payload: &[u8], ) -> anyhow::Result> { let kp = hybrid_kp.ok_or_else(|| anyhow::anyhow!("hybrid key required for decryption"))?; quicnprotochat_core::hybrid_decrypt(kp, payload).map_err(|e| anyhow::anyhow!("{e}")) } /// Return the current Unix timestamp in milliseconds. pub fn current_timestamp_ms() -> u64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_millis() as u64 }