370 lines
12 KiB
Rust
370 lines
12 KiB
Rust
use std::net::SocketAddr;
|
|
use std::path::Path;
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::Context;
|
|
use quinn::{ClientConfig, Endpoint};
|
|
use quinn_proto::crypto::rustls::QuicClientConfig;
|
|
use rustls::pki_types::CertificateDer;
|
|
use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
|
|
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
|
|
use capnp_rpc::{rpc_twoparty_capnp::Side, twoparty, RpcSystem};
|
|
|
|
use quicnprotochat_core::HybridPublicKey;
|
|
use quicnprotochat_proto::node_capnp::{auth, node_service};
|
|
|
|
use crate::AUTH_CONTEXT;
|
|
|
|
use super::retry::{
|
|
anyhow_is_retriable, base_delay_ms_from_env, max_retries_from_env, retry_async,
|
|
};
|
|
|
|
/// Establish a QUIC/TLS connection and return a `NodeService` client.
|
|
///
|
|
/// Must be called from within a `LocalSet` because capnp-rpc is `!Send`.
|
|
pub async fn connect_node(
|
|
server: &str,
|
|
ca_cert: &Path,
|
|
server_name: &str,
|
|
) -> anyhow::Result<node_service::Client> {
|
|
let addr: SocketAddr = server
|
|
.parse()
|
|
.with_context(|| format!("server must be host:port, got {server}"))?;
|
|
|
|
let cert_bytes = std::fs::read(ca_cert).with_context(|| format!("read ca_cert {ca_cert:?}"))?;
|
|
let mut roots = RootCertStore::empty();
|
|
roots
|
|
.add(CertificateDer::from(cert_bytes))
|
|
.context("add root cert")?;
|
|
|
|
let mut tls = RustlsClientConfig::builder()
|
|
.with_root_certificates(roots)
|
|
.with_no_client_auth();
|
|
tls.alpn_protocols = vec![b"capnp".to_vec()];
|
|
|
|
let crypto = QuicClientConfig::try_from(tls)
|
|
.map_err(|e| anyhow::anyhow!("invalid client TLS config: {e}"))?;
|
|
|
|
let bind_addr: SocketAddr = "0.0.0.0:0".parse().context("parse client bind address")?;
|
|
let mut endpoint = Endpoint::client(bind_addr)?;
|
|
endpoint.set_default_client_config(ClientConfig::new(Arc::new(crypto)));
|
|
|
|
let connection = endpoint
|
|
.connect(addr, server_name)
|
|
.context("quic connect init")?
|
|
.await
|
|
.context("quic connect failed")?;
|
|
|
|
let (send, recv) = connection.open_bi().await.context("open bi stream")?;
|
|
|
|
let network = twoparty::VatNetwork::new(
|
|
recv.compat(),
|
|
send.compat_write(),
|
|
Side::Client,
|
|
Default::default(),
|
|
);
|
|
|
|
let mut rpc_system = RpcSystem::new(Box::new(network), None);
|
|
let client: node_service::Client = rpc_system.bootstrap(Side::Server);
|
|
|
|
tokio::task::spawn_local(rpc_system);
|
|
|
|
Ok(client)
|
|
}
|
|
|
|
pub fn set_auth(auth: &mut auth::Builder<'_>) -> anyhow::Result<()> {
|
|
let ctx = AUTH_CONTEXT.get().ok_or_else(|| {
|
|
anyhow::anyhow!("init_auth must be called with a non-empty token before RPCs")
|
|
})?;
|
|
auth.set_version(ctx.version);
|
|
auth.set_access_token(&ctx.access_token);
|
|
auth.set_device_id(&ctx.device_id);
|
|
Ok(())
|
|
}
|
|
|
|
/// Upload a KeyPackage and verify the fingerprint echoed by the AS.
|
|
pub async fn upload_key_package(
|
|
client: &node_service::Client,
|
|
identity_key: &[u8],
|
|
package: &[u8],
|
|
) -> anyhow::Result<()> {
|
|
let mut req = client.upload_key_package_request();
|
|
{
|
|
let mut p = req.get();
|
|
p.set_identity_key(identity_key);
|
|
p.set_package(package);
|
|
let mut auth = p.reborrow().init_auth();
|
|
set_auth(&mut auth)?;
|
|
}
|
|
|
|
let resp = req
|
|
.send()
|
|
.promise
|
|
.await
|
|
.context("upload_key_package RPC failed")?;
|
|
|
|
let server_fp = resp
|
|
.get()
|
|
.context("upload_key_package: bad response")?
|
|
.get_fingerprint()
|
|
.context("upload_key_package: missing fingerprint")?
|
|
.to_vec();
|
|
|
|
let local_fp = super::state::sha256(package);
|
|
anyhow::ensure!(server_fp == local_fp, "fingerprint mismatch");
|
|
Ok(())
|
|
}
|
|
|
|
/// Fetch a KeyPackage for `identity_key` from the AS.
|
|
pub async fn fetch_key_package(
|
|
client: &node_service::Client,
|
|
identity_key: &[u8],
|
|
) -> anyhow::Result<Vec<u8>> {
|
|
let mut req = client.fetch_key_package_request();
|
|
{
|
|
let mut p = req.get();
|
|
p.set_identity_key(identity_key);
|
|
let mut auth = p.reborrow().init_auth();
|
|
set_auth(&mut auth)?;
|
|
}
|
|
|
|
let resp = req
|
|
.send()
|
|
.promise
|
|
.await
|
|
.context("fetch_key_package RPC failed")?;
|
|
|
|
let pkg = resp
|
|
.get()
|
|
.context("fetch_key_package: bad response")?
|
|
.get_package()
|
|
.context("fetch_key_package: missing package field")?
|
|
.to_vec();
|
|
|
|
Ok(pkg)
|
|
}
|
|
|
|
/// Enqueue an opaque payload to the DS for `recipient_key`.
|
|
/// Returns the per-inbox sequence number assigned by the server.
|
|
/// Retries on transient failures with exponential backoff.
|
|
pub async fn enqueue(
|
|
client: &node_service::Client,
|
|
recipient_key: &[u8],
|
|
payload: &[u8],
|
|
) -> anyhow::Result<u64> {
|
|
let client = client.clone();
|
|
let recipient_key = recipient_key.to_vec();
|
|
let payload = payload.to_vec();
|
|
retry_async(
|
|
|| {
|
|
let client = client.clone();
|
|
let recipient_key = recipient_key.clone();
|
|
let payload = payload.clone();
|
|
async move {
|
|
let mut req = client.enqueue_request();
|
|
{
|
|
let mut p = req.get();
|
|
p.set_recipient_key(&recipient_key);
|
|
p.set_payload(&payload);
|
|
p.set_channel_id(&[]);
|
|
p.set_version(1);
|
|
let mut auth = p.reborrow().init_auth();
|
|
set_auth(&mut auth)?;
|
|
}
|
|
let resp = req.send().promise.await.context("enqueue RPC failed")?;
|
|
let seq = resp.get().context("enqueue: bad response")?.get_seq();
|
|
Ok(seq)
|
|
}
|
|
},
|
|
max_retries_from_env(),
|
|
base_delay_ms_from_env(),
|
|
anyhow_is_retriable,
|
|
)
|
|
.await
|
|
}
|
|
|
|
/// Fetch and drain all payloads for `recipient_key`.
|
|
/// Returns `(seq, payload)` pairs — sort by `seq` before MLS processing.
|
|
/// Retries on transient failures with exponential backoff.
|
|
pub async fn fetch_all(
|
|
client: &node_service::Client,
|
|
recipient_key: &[u8],
|
|
) -> anyhow::Result<Vec<(u64, Vec<u8>)>> {
|
|
let client = client.clone();
|
|
let recipient_key = recipient_key.to_vec();
|
|
retry_async(
|
|
|| {
|
|
let client = client.clone();
|
|
let recipient_key = recipient_key.clone();
|
|
async move {
|
|
let mut req = client.fetch_request();
|
|
{
|
|
let mut p = req.get();
|
|
p.set_recipient_key(&recipient_key);
|
|
p.set_channel_id(&[]);
|
|
p.set_version(1);
|
|
p.set_limit(0); // fetch all
|
|
let mut auth = p.reborrow().init_auth();
|
|
set_auth(&mut auth)?;
|
|
}
|
|
|
|
let resp = req.send().promise.await.context("fetch RPC failed")?;
|
|
|
|
let list = resp
|
|
.get()
|
|
.context("fetch: bad response")?
|
|
.get_payloads()
|
|
.context("fetch: missing payloads")?;
|
|
|
|
let mut payloads = Vec::with_capacity(list.len() as usize);
|
|
for i in 0..list.len() {
|
|
let entry = list.get(i);
|
|
let seq = entry.get_seq();
|
|
let data = entry
|
|
.get_data()
|
|
.context("fetch: envelope data read failed")?
|
|
.to_vec();
|
|
payloads.push((seq, data));
|
|
}
|
|
|
|
Ok(payloads)
|
|
}
|
|
},
|
|
max_retries_from_env(),
|
|
base_delay_ms_from_env(),
|
|
anyhow_is_retriable,
|
|
)
|
|
.await
|
|
}
|
|
|
|
/// Long-poll for payloads with optional timeout (ms).
|
|
/// Returns `(seq, payload)` pairs — sort by `seq` before MLS processing.
|
|
/// Retries on transient failures with exponential backoff.
|
|
pub async fn fetch_wait(
|
|
client: &node_service::Client,
|
|
recipient_key: &[u8],
|
|
timeout_ms: u64,
|
|
) -> anyhow::Result<Vec<(u64, Vec<u8>)>> {
|
|
let client = client.clone();
|
|
let recipient_key = recipient_key.to_vec();
|
|
retry_async(
|
|
|| {
|
|
let client = client.clone();
|
|
let recipient_key = recipient_key.clone();
|
|
let timeout_ms = timeout_ms;
|
|
async move {
|
|
let mut req = client.fetch_wait_request();
|
|
{
|
|
let mut p = req.get();
|
|
p.set_recipient_key(&recipient_key);
|
|
p.set_timeout_ms(timeout_ms);
|
|
p.set_channel_id(&[]);
|
|
p.set_version(1);
|
|
p.set_limit(0); // fetch all
|
|
let mut auth = p.reborrow().init_auth();
|
|
set_auth(&mut auth)?;
|
|
}
|
|
|
|
let resp = req.send().promise.await.context("fetch_wait RPC failed")?;
|
|
|
|
let list = resp
|
|
.get()
|
|
.context("fetch_wait: bad response")?
|
|
.get_payloads()
|
|
.context("fetch_wait: missing payloads")?;
|
|
|
|
let mut payloads = Vec::with_capacity(list.len() as usize);
|
|
for i in 0..list.len() {
|
|
let entry = list.get(i);
|
|
let seq = entry.get_seq();
|
|
let data = entry
|
|
.get_data()
|
|
.context("fetch_wait: envelope data read failed")?
|
|
.to_vec();
|
|
payloads.push((seq, data));
|
|
}
|
|
|
|
Ok(payloads)
|
|
}
|
|
},
|
|
max_retries_from_env(),
|
|
base_delay_ms_from_env(),
|
|
anyhow_is_retriable,
|
|
)
|
|
.await
|
|
}
|
|
|
|
/// Upload a hybrid (X25519 + ML-KEM-768) public key for an identity.
|
|
pub async fn upload_hybrid_key(
|
|
client: &node_service::Client,
|
|
identity_key: &[u8],
|
|
hybrid_pk: &HybridPublicKey,
|
|
) -> anyhow::Result<()> {
|
|
let mut req = client.upload_hybrid_key_request();
|
|
{
|
|
let mut p = req.get();
|
|
p.set_identity_key(identity_key);
|
|
p.set_hybrid_public_key(&hybrid_pk.to_bytes());
|
|
let mut auth = p.reborrow().init_auth();
|
|
set_auth(&mut auth)?;
|
|
}
|
|
req.send()
|
|
.promise
|
|
.await
|
|
.context("upload_hybrid_key RPC failed")?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Fetch a peer's hybrid public key from the server.
|
|
///
|
|
/// Returns `None` if the peer has not uploaded a hybrid key.
|
|
pub async fn fetch_hybrid_key(
|
|
client: &node_service::Client,
|
|
identity_key: &[u8],
|
|
) -> anyhow::Result<Option<HybridPublicKey>> {
|
|
let mut req = client.fetch_hybrid_key_request();
|
|
{
|
|
let mut p = req.get();
|
|
p.set_identity_key(identity_key);
|
|
let mut auth = p.reborrow().init_auth();
|
|
set_auth(&mut auth)?;
|
|
}
|
|
|
|
let resp = req
|
|
.send()
|
|
.promise
|
|
.await
|
|
.context("fetch_hybrid_key RPC failed")?;
|
|
|
|
let pk_bytes = resp
|
|
.get()
|
|
.context("fetch_hybrid_key: bad response")?
|
|
.get_hybrid_public_key()
|
|
.context("fetch_hybrid_key: missing field")?
|
|
.to_vec();
|
|
|
|
if pk_bytes.is_empty() {
|
|
return Ok(None);
|
|
}
|
|
|
|
let pk = HybridPublicKey::from_bytes(&pk_bytes).context("invalid hybrid public key")?;
|
|
Ok(Some(pk))
|
|
}
|
|
|
|
/// Decrypt a hybrid envelope. Requires a hybrid key; no fallback to plaintext MLS.
|
|
pub fn try_hybrid_decrypt(
|
|
hybrid_kp: Option<&quicnprotochat_core::HybridKeypair>,
|
|
payload: &[u8],
|
|
) -> anyhow::Result<Vec<u8>> {
|
|
let kp = hybrid_kp.ok_or_else(|| anyhow::anyhow!("hybrid key required for decryption"))?;
|
|
quicnprotochat_core::hybrid_decrypt(kp, payload).map_err(|e| anyhow::anyhow!("{e}"))
|
|
}
|
|
|
|
/// Return the current Unix timestamp in milliseconds.
///
/// Returns `0` if the system clock reports a time before the Unix epoch. The `u128`
/// millisecond count from `Duration::as_millis` saturates at `u64::MAX` instead of
/// being silently truncated by an `as` cast.
pub fn current_timestamp_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
|