chore: rename quicproquo → quicprochat in Rust workspace
Rename all crate directories, package names, binary names, proto package/module paths, ALPN strings, env var prefixes, config filenames, mDNS service names, and plugin ABI symbols from quicproquo/qpq to quicprochat/qpc.
This commit is contained in:
135
crates/quicprochat-rpc/src/auth_handshake.rs
Normal file
135
crates/quicprochat-rpc/src/auth_handshake.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
//! Auth handshake protocol for QUIC connections.
|
||||
//!
|
||||
//! The session token is sent as a "connection init" message on the first
|
||||
//! bi-stream before any RPC calls.
|
||||
//!
|
||||
//! ## Protocol
|
||||
//! ```text
|
||||
//! Client → Server: [0x01 magic][token_len: u16 BE][token bytes]
|
||||
//! Server → Client: [0x01 magic][0x00 status OK]
|
||||
//! ```
|
||||
|
||||
use crate::error::RpcError;
|
||||
|
||||
/// Magic byte identifying an auth init frame.
|
||||
pub const AUTH_INIT_MAGIC: u8 = 0x01;
|
||||
|
||||
/// Status byte: auth accepted.
|
||||
const AUTH_STATUS_OK: u8 = 0x00;
|
||||
|
||||
/// Maximum token length (64 KiB).
|
||||
const MAX_TOKEN_LEN: usize = 65535;
|
||||
|
||||
/// Write an auth init frame to a QUIC send stream.
|
||||
pub async fn send_auth_init(
|
||||
send: &mut quinn::SendStream,
|
||||
token: &[u8],
|
||||
) -> Result<(), RpcError> {
|
||||
if token.len() > MAX_TOKEN_LEN {
|
||||
return Err(RpcError::Encode(format!(
|
||||
"auth token too large: {} bytes (max {MAX_TOKEN_LEN})",
|
||||
token.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let token_len = token.len() as u16;
|
||||
let mut buf = Vec::with_capacity(1 + 2 + token.len());
|
||||
buf.push(AUTH_INIT_MAGIC);
|
||||
buf.extend_from_slice(&token_len.to_be_bytes());
|
||||
buf.extend_from_slice(token);
|
||||
|
||||
send.write_all(&buf)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("send auth init: {e}")))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read an auth init frame from a QUIC recv stream.
|
||||
pub async fn recv_auth_init(
|
||||
recv: &mut quinn::RecvStream,
|
||||
) -> Result<Vec<u8>, RpcError> {
|
||||
// Read magic byte.
|
||||
let mut header = [0u8; 3];
|
||||
recv.read_exact(&mut header)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("read auth init header: {e}")))?;
|
||||
|
||||
if header[0] != AUTH_INIT_MAGIC {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"bad auth init magic: expected 0x{AUTH_INIT_MAGIC:02x}, got 0x{:02x}",
|
||||
header[0]
|
||||
)));
|
||||
}
|
||||
|
||||
let token_len = u16::from_be_bytes([header[1], header[2]]) as usize;
|
||||
if token_len > MAX_TOKEN_LEN {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"auth token length {token_len} exceeds max {MAX_TOKEN_LEN}"
|
||||
)));
|
||||
}
|
||||
|
||||
let mut token = vec![0u8; token_len];
|
||||
if token_len > 0 {
|
||||
recv.read_exact(&mut token)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("read auth token: {e}")))?;
|
||||
}
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Send auth ack (success response).
|
||||
pub async fn send_auth_ack(
|
||||
send: &mut quinn::SendStream,
|
||||
) -> Result<(), RpcError> {
|
||||
let buf = [AUTH_INIT_MAGIC, AUTH_STATUS_OK];
|
||||
send.write_all(&buf)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("send auth ack: {e}")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read auth ack from the server.
|
||||
pub async fn recv_auth_ack(
|
||||
recv: &mut quinn::RecvStream,
|
||||
) -> Result<(), RpcError> {
|
||||
let mut buf = [0u8; 2];
|
||||
recv.read_exact(&mut buf)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("read auth ack: {e}")))?;
|
||||
|
||||
if buf[0] != AUTH_INIT_MAGIC {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"bad auth ack magic: expected 0x{AUTH_INIT_MAGIC:02x}, got 0x{:02x}",
|
||||
buf[0]
|
||||
)));
|
||||
}
|
||||
|
||||
if buf[1] != AUTH_STATUS_OK {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"auth rejected with status 0x{:02x}",
|
||||
buf[1]
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn auth_init_magic_is_0x01() {
|
||||
assert_eq!(AUTH_INIT_MAGIC, 0x01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn token_too_large_returns_encode_error() {
|
||||
// We cannot call send_auth_init without a real stream, but we can
|
||||
// verify the length check logic by constructing the guard condition.
|
||||
let big_token = vec![0u8; MAX_TOKEN_LEN + 1];
|
||||
assert!(big_token.len() > MAX_TOKEN_LEN);
|
||||
}
|
||||
}
|
||||
196
crates/quicprochat-rpc/src/client.rs
Normal file
196
crates/quicprochat-rpc/src/client.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
//! QUIC RPC client — connect to server, send requests, receive push events.
|
||||
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use quinn::{Connection, Endpoint};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::auth_handshake;
|
||||
use crate::error::{RpcError, RpcStatus};
|
||||
use crate::framing::{PushFrame, RequestFrame, ResponseFrame};
|
||||
|
||||
/// Configuration for the RPC client.
|
||||
pub struct RpcClientConfig {
|
||||
/// Server address to connect to.
|
||||
pub server_addr: std::net::SocketAddr,
|
||||
/// Server name for TLS verification.
|
||||
pub server_name: String,
|
||||
/// TLS client config (rustls).
|
||||
pub tls_config: Arc<rustls::ClientConfig>,
|
||||
/// ALPN protocol.
|
||||
pub alpn: Vec<u8>,
|
||||
/// Session token to send during auth handshake.
|
||||
pub session_token: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
/// A QUIC RPC client connection.
|
||||
pub struct RpcClient {
|
||||
connection: Connection,
|
||||
next_request_id: AtomicU32,
|
||||
}
|
||||
|
||||
impl RpcClient {
|
||||
/// Connect to the RPC server.
|
||||
pub async fn connect(config: RpcClientConfig) -> Result<Self, RpcError> {
|
||||
let mut tls = (*config.tls_config).clone();
|
||||
tls.alpn_protocols = vec![config.alpn];
|
||||
let quic_tls = quinn::crypto::rustls::QuicClientConfig::try_from(tls)
|
||||
.map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?;
|
||||
|
||||
let bind_addr = std::net::SocketAddr::new(
|
||||
std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED),
|
||||
0,
|
||||
);
|
||||
let mut endpoint = Endpoint::client(bind_addr)
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(quic_tls)));
|
||||
|
||||
let connection = endpoint
|
||||
.connect(config.server_addr, &config.server_name)
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
debug!(remote = %connection.remote_address(), "connected to RPC server");
|
||||
|
||||
// Perform auth handshake if a session token was provided.
|
||||
if let Some(ref token) = config.session_token {
|
||||
let (mut send, mut recv) = connection
|
||||
.open_bi()
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("open auth stream: {e}")))?;
|
||||
|
||||
auth_handshake::send_auth_init(&mut send, token).await?;
|
||||
send.finish()
|
||||
.map_err(|e| RpcError::Connection(format!("finish auth send: {e}")))?;
|
||||
auth_handshake::recv_auth_ack(&mut recv).await?;
|
||||
debug!("auth handshake complete");
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
connection,
|
||||
next_request_id: AtomicU32::new(1),
|
||||
})
|
||||
}
|
||||
|
||||
/// Send an RPC request and wait for the response.
|
||||
pub async fn call(
|
||||
&self,
|
||||
method_id: u16,
|
||||
payload: Bytes,
|
||||
) -> Result<Bytes, RpcError> {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let (mut send, mut recv) = self
|
||||
.connection
|
||||
.open_bi()
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
// Send request.
|
||||
let frame = RequestFrame {
|
||||
method_id,
|
||||
request_id,
|
||||
payload,
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
send.write_all(&encoded)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
// Read response.
|
||||
let mut buf = BytesMut::new();
|
||||
while let Some(chunk) = recv
|
||||
.read_chunk(65536, true)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?
|
||||
{
|
||||
buf.extend_from_slice(&chunk.bytes);
|
||||
if buf.len() > crate::framing::MAX_PAYLOAD_SIZE + crate::framing::RESPONSE_HEADER_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: buf.len(),
|
||||
max: crate::framing::MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let response = ResponseFrame::decode(&mut buf)?
|
||||
.ok_or_else(|| RpcError::Decode("incomplete response frame".into()))?;
|
||||
|
||||
if response.request_id != request_id {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"request_id mismatch: sent {request_id}, got {}",
|
||||
response.request_id
|
||||
)));
|
||||
}
|
||||
|
||||
match RpcStatus::from_u8(response.status) {
|
||||
Some(RpcStatus::Ok) => Ok(response.payload),
|
||||
Some(status) => Err(RpcError::Server {
|
||||
status,
|
||||
message: String::from_utf8_lossy(&response.payload).into_owned(),
|
||||
}),
|
||||
None => Err(RpcError::Decode(format!(
|
||||
"unknown status byte: {}",
|
||||
response.status
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to server-push events. Returns a receiver channel.
|
||||
/// Spawns a background task that reads uni-streams.
|
||||
pub fn subscribe_push(&self) -> mpsc::UnboundedReceiver<PushFrame> {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
let conn = self.connection.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match conn.accept_uni().await {
|
||||
Ok(mut recv) => {
|
||||
let mut buf = BytesMut::new();
|
||||
loop {
|
||||
match recv.read_chunk(65536, true).await {
|
||||
Ok(Some(chunk)) => buf.extend_from_slice(&chunk.bytes),
|
||||
Ok(None) => break,
|
||||
Err(e) => {
|
||||
debug!("push stream read error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
match PushFrame::decode(&mut buf) {
|
||||
Ok(Some(frame)) => {
|
||||
if tx.send(frame).is_err() {
|
||||
return; // receiver dropped
|
||||
}
|
||||
}
|
||||
Ok(None) => debug!("incomplete push frame"),
|
||||
Err(e) => debug!("push decode error: {e}"),
|
||||
}
|
||||
}
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
|
||||
Err(e) => {
|
||||
warn!("accept_uni error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
/// Close the connection gracefully.
|
||||
pub fn close(&self) {
|
||||
self.connection.close(0u32.into(), b"bye");
|
||||
}
|
||||
|
||||
/// Get the underlying QUIC connection (for advanced use).
|
||||
pub fn connection(&self) -> &Connection {
|
||||
&self.connection
|
||||
}
|
||||
}
|
||||
74
crates/quicprochat-rpc/src/error.rs
Normal file
74
crates/quicprochat-rpc/src/error.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
//! RPC error types.
|
||||
|
||||
/// Status codes for RPC responses.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum RpcStatus {
|
||||
/// Request succeeded.
|
||||
Ok = 0,
|
||||
/// Client sent a malformed request.
|
||||
BadRequest = 1,
|
||||
/// Authentication required or token invalid.
|
||||
Unauthorized = 2,
|
||||
/// Caller lacks permission for this operation.
|
||||
Forbidden = 3,
|
||||
/// Requested resource not found.
|
||||
NotFound = 4,
|
||||
/// Rate limit exceeded.
|
||||
RateLimited = 5,
|
||||
/// Request deadline exceeded (server-side timeout).
|
||||
DeadlineExceeded = 8,
|
||||
/// Server is shutting down (draining).
|
||||
Unavailable = 9,
|
||||
/// Internal server error.
|
||||
Internal = 10,
|
||||
/// Method not recognized.
|
||||
UnknownMethod = 11,
|
||||
}
|
||||
|
||||
impl RpcStatus {
|
||||
/// Decode a status byte. Returns `None` for unknown values.
|
||||
pub fn from_u8(v: u8) -> Option<Self> {
|
||||
match v {
|
||||
0 => Some(Self::Ok),
|
||||
1 => Some(Self::BadRequest),
|
||||
2 => Some(Self::Unauthorized),
|
||||
3 => Some(Self::Forbidden),
|
||||
4 => Some(Self::NotFound),
|
||||
5 => Some(Self::RateLimited),
|
||||
8 => Some(Self::DeadlineExceeded),
|
||||
9 => Some(Self::Unavailable),
|
||||
10 => Some(Self::Internal),
|
||||
11 => Some(Self::UnknownMethod),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can occur in the RPC layer.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RpcError {
|
||||
#[error("connection error: {0}")]
|
||||
Connection(String),
|
||||
|
||||
#[error("encoding error: {0}")]
|
||||
Encode(String),
|
||||
|
||||
#[error("decoding error: {0}")]
|
||||
Decode(String),
|
||||
|
||||
#[error("server returned error status {status:?}: {message}")]
|
||||
Server {
|
||||
status: RpcStatus,
|
||||
message: String,
|
||||
},
|
||||
|
||||
#[error("request timed out")]
|
||||
Timeout,
|
||||
|
||||
#[error("stream closed unexpectedly")]
|
||||
StreamClosed,
|
||||
|
||||
#[error("payload too large: {size} bytes (max {max})")]
|
||||
PayloadTooLarge { size: usize, max: usize },
|
||||
}
|
||||
280
crates/quicprochat-rpc/src/framing.rs
Normal file
280
crates/quicprochat-rpc/src/framing.rs
Normal file
@@ -0,0 +1,280 @@
|
||||
//! Wire format encoding and decoding for the quicprochat v2 RPC protocol.
|
||||
//!
|
||||
//! ## Request frame
|
||||
//! ```text
|
||||
//! [method_id: u16 BE][request_id: u32 BE][payload_len: u32 BE][protobuf bytes]
|
||||
//! ```
|
||||
//!
|
||||
//! ## Response frame
|
||||
//! ```text
|
||||
//! [status: u8][request_id: u32 BE][payload_len: u32 BE][protobuf bytes]
|
||||
//! ```
|
||||
//!
|
||||
//! ## Push frame (server → client, uni-stream)
|
||||
//! ```text
|
||||
//! [event_type: u16 BE][payload_len: u32 BE][protobuf bytes]
|
||||
//! ```
|
||||
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::error::{RpcError, RpcStatus};
|
||||
|
||||
/// Maximum payload size: 4 MiB.
|
||||
pub const MAX_PAYLOAD_SIZE: usize = 4 * 1024 * 1024;
|
||||
|
||||
/// Request header size: 2 (method) + 4 (req_id) + 4 (len) = 10 bytes.
|
||||
pub const REQUEST_HEADER_SIZE: usize = 10;
|
||||
|
||||
/// Response header size: 1 (status) + 4 (req_id) + 4 (len) = 9 bytes.
|
||||
pub const RESPONSE_HEADER_SIZE: usize = 9;
|
||||
|
||||
/// Push header size: 2 (event_type) + 4 (len) = 6 bytes.
|
||||
pub const PUSH_HEADER_SIZE: usize = 6;
|
||||
|
||||
// ── Request ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/// A decoded RPC request frame.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RequestFrame {
|
||||
pub method_id: u16,
|
||||
pub request_id: u32,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
impl RequestFrame {
|
||||
/// Encode this request into a byte buffer.
|
||||
pub fn encode(&self) -> Bytes {
|
||||
let mut buf = BytesMut::with_capacity(REQUEST_HEADER_SIZE + self.payload.len());
|
||||
buf.put_u16(self.method_id);
|
||||
buf.put_u32(self.request_id);
|
||||
buf.put_u32(self.payload.len() as u32);
|
||||
buf.put(self.payload.clone());
|
||||
buf.freeze()
|
||||
}
|
||||
|
||||
/// Decode a request frame from a byte buffer.
|
||||
/// Returns `None` if the buffer does not contain a complete frame.
|
||||
pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, RpcError> {
|
||||
if buf.len() < REQUEST_HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Peek at payload_len without consuming.
|
||||
let payload_len =
|
||||
u32::from_be_bytes([buf[6], buf[7], buf[8], buf[9]]) as usize;
|
||||
|
||||
if payload_len > MAX_PAYLOAD_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: payload_len,
|
||||
max: MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
|
||||
let total = REQUEST_HEADER_SIZE + payload_len;
|
||||
if buf.len() < total {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let method_id = buf.get_u16();
|
||||
let request_id = buf.get_u32();
|
||||
let _len = buf.get_u32();
|
||||
let payload = buf.split_to(payload_len).freeze();
|
||||
|
||||
Ok(Some(Self {
|
||||
method_id,
|
||||
request_id,
|
||||
payload,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Response ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// A decoded RPC response frame.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResponseFrame {
|
||||
pub status: u8,
|
||||
pub request_id: u32,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
impl ResponseFrame {
|
||||
/// Encode this response into a byte buffer.
|
||||
pub fn encode(&self) -> Bytes {
|
||||
let mut buf = BytesMut::with_capacity(RESPONSE_HEADER_SIZE + self.payload.len());
|
||||
buf.put_u8(self.status);
|
||||
buf.put_u32(self.request_id);
|
||||
buf.put_u32(self.payload.len() as u32);
|
||||
buf.put(self.payload.clone());
|
||||
buf.freeze()
|
||||
}
|
||||
|
||||
/// Decode a response frame from a byte buffer.
|
||||
pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, RpcError> {
|
||||
if buf.len() < RESPONSE_HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let payload_len =
|
||||
u32::from_be_bytes([buf[5], buf[6], buf[7], buf[8]]) as usize;
|
||||
|
||||
if payload_len > MAX_PAYLOAD_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: payload_len,
|
||||
max: MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
|
||||
let total = RESPONSE_HEADER_SIZE + payload_len;
|
||||
if buf.len() < total {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let status = buf.get_u8();
|
||||
let request_id = buf.get_u32();
|
||||
let _len = buf.get_u32();
|
||||
let payload = buf.split_to(payload_len).freeze();
|
||||
|
||||
Ok(Some(Self {
|
||||
status,
|
||||
request_id,
|
||||
payload,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Convert the status byte to an `RpcStatus`.
|
||||
pub fn rpc_status(&self) -> Option<RpcStatus> {
|
||||
RpcStatus::from_u8(self.status)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Push ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// A decoded server-push event frame (sent on QUIC uni-streams).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PushFrame {
|
||||
pub event_type: u16,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
impl PushFrame {
|
||||
/// Encode this push frame into a byte buffer.
|
||||
pub fn encode(&self) -> Bytes {
|
||||
let mut buf = BytesMut::with_capacity(PUSH_HEADER_SIZE + self.payload.len());
|
||||
buf.put_u16(self.event_type);
|
||||
buf.put_u32(self.payload.len() as u32);
|
||||
buf.put(self.payload.clone());
|
||||
buf.freeze()
|
||||
}
|
||||
|
||||
/// Decode a push frame from a byte buffer.
|
||||
pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, RpcError> {
|
||||
if buf.len() < PUSH_HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let payload_len =
|
||||
u32::from_be_bytes([buf[2], buf[3], buf[4], buf[5]]) as usize;
|
||||
|
||||
if payload_len > MAX_PAYLOAD_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: payload_len,
|
||||
max: MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
|
||||
let total = PUSH_HEADER_SIZE + payload_len;
|
||||
if buf.len() < total {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let event_type = buf.get_u16();
|
||||
let _len = buf.get_u32();
|
||||
let payload = buf.split_to(payload_len).freeze();
|
||||
|
||||
Ok(Some(Self {
|
||||
event_type,
|
||||
payload,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn request_roundtrip() {
|
||||
let frame = RequestFrame {
|
||||
method_id: 42,
|
||||
request_id: 1001,
|
||||
payload: Bytes::from_static(b"hello"),
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
let mut buf = BytesMut::from(encoded.as_ref());
|
||||
let decoded = RequestFrame::decode(&mut buf).expect("decode").expect("complete");
|
||||
assert_eq!(decoded.method_id, 42);
|
||||
assert_eq!(decoded.request_id, 1001);
|
||||
assert_eq!(decoded.payload, Bytes::from_static(b"hello"));
|
||||
assert!(buf.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_roundtrip() {
|
||||
let frame = ResponseFrame {
|
||||
status: RpcStatus::Ok as u8,
|
||||
request_id: 2002,
|
||||
payload: Bytes::from_static(b"world"),
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
let mut buf = BytesMut::from(encoded.as_ref());
|
||||
let decoded = ResponseFrame::decode(&mut buf).expect("decode").expect("complete");
|
||||
assert_eq!(decoded.status, 0);
|
||||
assert_eq!(decoded.request_id, 2002);
|
||||
assert_eq!(decoded.payload, Bytes::from_static(b"world"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_roundtrip() {
|
||||
let frame = PushFrame {
|
||||
event_type: 7,
|
||||
payload: Bytes::from_static(b"event-data"),
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
let mut buf = BytesMut::from(encoded.as_ref());
|
||||
let decoded = PushFrame::decode(&mut buf).expect("decode").expect("complete");
|
||||
assert_eq!(decoded.event_type, 7);
|
||||
assert_eq!(decoded.payload, Bytes::from_static(b"event-data"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn incomplete_request_returns_none() {
|
||||
let mut buf = BytesMut::from(&[0u8; 5][..]);
|
||||
assert!(RequestFrame::decode(&mut buf).expect("no error").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn payload_too_large_rejected() {
|
||||
// Craft a request header with payload_len = MAX + 1.
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u16(1);
|
||||
buf.put_u32(1);
|
||||
buf.put_u32((MAX_PAYLOAD_SIZE + 1) as u32);
|
||||
let result = RequestFrame::decode(&mut buf);
|
||||
assert!(matches!(result, Err(RpcError::PayloadTooLarge { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_payload_request() {
|
||||
let frame = RequestFrame {
|
||||
method_id: 0,
|
||||
request_id: 0,
|
||||
payload: Bytes::new(),
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
assert_eq!(encoded.len(), REQUEST_HEADER_SIZE);
|
||||
let mut buf = BytesMut::from(encoded.as_ref());
|
||||
let decoded = RequestFrame::decode(&mut buf).expect("decode").expect("complete");
|
||||
assert!(decoded.payload.is_empty());
|
||||
}
|
||||
}
|
||||
15
crates/quicprochat-rpc/src/lib.rs
Normal file
15
crates/quicprochat-rpc/src/lib.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! QUIC RPC framework for quicprochat v2.
|
||||
//!
|
||||
//! Wire format per QUIC stream:
|
||||
//! - Request: `[method_id: u16][request_id: u32][payload_len: u32][protobuf bytes]`
|
||||
//! - Response: `[status: u8][request_id: u32][payload_len: u32][protobuf bytes]`
|
||||
//! - Push: `[event_type: u16][payload_len: u32][protobuf bytes]` (uni-stream)
|
||||
|
||||
pub mod auth_handshake;
|
||||
pub mod framing;
|
||||
pub mod method;
|
||||
pub mod server;
|
||||
pub mod client;
|
||||
pub mod middleware;
|
||||
pub mod push;
|
||||
pub mod error;
|
||||
208
crates/quicprochat-rpc/src/method.rs
Normal file
208
crates/quicprochat-rpc/src/method.rs
Normal file
@@ -0,0 +1,208 @@
|
||||
//! Method registry — maps method IDs to handler functions.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use bytes::Bytes;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::error::RpcStatus;
|
||||
|
||||
/// The result of handling an RPC request.
|
||||
pub struct HandlerResult {
|
||||
pub status: RpcStatus,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
impl HandlerResult {
|
||||
/// Shorthand for a successful response.
|
||||
pub fn ok(payload: Bytes) -> Self {
|
||||
Self {
|
||||
status: RpcStatus::Ok,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
/// Shorthand for an error response.
|
||||
pub fn err(status: RpcStatus, message: &str) -> Self {
|
||||
Self {
|
||||
status,
|
||||
payload: Bytes::copy_from_slice(message.as_bytes()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context passed to every RPC handler.
|
||||
pub struct RequestContext {
|
||||
/// The authenticated identity key of the caller, if any.
|
||||
pub identity_key: Option<Vec<u8>>,
|
||||
/// The session token, if provided.
|
||||
pub session_token: Option<Vec<u8>>,
|
||||
/// The raw request payload (protobuf-encoded).
|
||||
pub payload: Bytes,
|
||||
/// Unique correlation ID for request tracing (UUID v7, monotonic).
|
||||
pub trace_id: String,
|
||||
/// The effective deadline for this request. Handlers can check this to bail
|
||||
/// early on long-running operations. `None` means no deadline.
|
||||
pub deadline: Option<Instant>,
|
||||
}
|
||||
|
||||
/// Type-erased async handler function.
|
||||
pub type HandlerFn<S> = Arc<
|
||||
dyn Fn(Arc<S>, RequestContext) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>;
|
||||
|
||||
/// Per-method registration entry.
|
||||
struct MethodEntry<S> {
|
||||
handler: HandlerFn<S>,
|
||||
name: &'static str,
|
||||
/// Optional per-method timeout override. `None` means use the server default.
|
||||
timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
/// Registry mapping method IDs to handler functions.
|
||||
pub struct MethodRegistry<S> {
|
||||
handlers: HashMap<u16, MethodEntry<S>>,
|
||||
/// Default timeout applied to methods that don't specify their own.
|
||||
default_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl<S: Send + Sync + 'static> MethodRegistry<S> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handlers: HashMap::new(),
|
||||
default_timeout: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the default timeout for all methods that don't have a per-method override.
|
||||
pub fn set_default_timeout(&mut self, timeout: Duration) {
|
||||
self.default_timeout = Some(timeout);
|
||||
}
|
||||
|
||||
/// Register a handler for a method ID.
|
||||
pub fn register<F, Fut>(&mut self, method_id: u16, name: &'static str, handler: F)
|
||||
where
|
||||
F: Fn(Arc<S>, RequestContext) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = HandlerResult> + Send + 'static,
|
||||
{
|
||||
let handler = Arc::new(move |state: Arc<S>, ctx: RequestContext| {
|
||||
Box::pin(handler(state, ctx)) as Pin<Box<dyn Future<Output = HandlerResult> + Send>>
|
||||
});
|
||||
self.handlers.insert(method_id, MethodEntry { handler, name, timeout: None });
|
||||
}
|
||||
|
||||
/// Register a handler with a per-method timeout override.
|
||||
pub fn register_with_timeout<F, Fut>(
|
||||
&mut self,
|
||||
method_id: u16,
|
||||
name: &'static str,
|
||||
timeout: Duration,
|
||||
handler: F,
|
||||
)
|
||||
where
|
||||
F: Fn(Arc<S>, RequestContext) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = HandlerResult> + Send + 'static,
|
||||
{
|
||||
let handler = Arc::new(move |state: Arc<S>, ctx: RequestContext| {
|
||||
Box::pin(handler(state, ctx)) as Pin<Box<dyn Future<Output = HandlerResult> + Send>>
|
||||
});
|
||||
self.handlers.insert(method_id, MethodEntry { handler, name, timeout: Some(timeout) });
|
||||
}
|
||||
|
||||
/// Look up a handler, name, and effective timeout by method ID.
|
||||
pub fn get(&self, method_id: u16) -> Option<(&HandlerFn<S>, &'static str, Option<Duration>)> {
|
||||
self.handlers.get(&method_id).map(|e| {
|
||||
(&e.handler, e.name, e.timeout.or(self.default_timeout))
|
||||
})
|
||||
}
|
||||
|
||||
/// Return the number of registered methods.
|
||||
pub fn len(&self) -> usize {
|
||||
self.handlers.len()
|
||||
}
|
||||
|
||||
/// Whether the registry is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.handlers.is_empty()
|
||||
}
|
||||
|
||||
/// Iterate over all registered (method_id, name) pairs.
|
||||
pub fn methods(&self) -> impl Iterator<Item = (u16, &'static str)> + '_ {
|
||||
self.handlers.iter().map(|(&id, entry)| (id, entry.name))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Send + Sync + 'static> Default for MethodRegistry<S> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn registry_default_timeout_applies_to_methods() {
|
||||
let mut reg = MethodRegistry::<()>::new();
|
||||
reg.set_default_timeout(Duration::from_secs(30));
|
||||
reg.register(1, "Test", |_state: Arc<()>, _ctx| async { HandlerResult::ok(Bytes::new()) });
|
||||
|
||||
let (_, name, timeout) = reg.get(1).expect("registered method");
|
||||
assert_eq!(name, "Test");
|
||||
assert_eq!(timeout, Some(Duration::from_secs(30)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_per_method_timeout_overrides_default() {
|
||||
let mut reg = MethodRegistry::<()>::new();
|
||||
reg.set_default_timeout(Duration::from_secs(30));
|
||||
reg.register_with_timeout(
|
||||
1,
|
||||
"Slow",
|
||||
Duration::from_secs(120),
|
||||
|_state: Arc<()>, _ctx| async { HandlerResult::ok(Bytes::new()) },
|
||||
);
|
||||
|
||||
let (_, _, timeout) = reg.get(1).expect("registered method");
|
||||
assert_eq!(timeout, Some(Duration::from_secs(120)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_no_default_timeout_returns_none() {
|
||||
let mut reg = MethodRegistry::<()>::new();
|
||||
reg.register(1, "NoTimeout", |_state: Arc<()>, _ctx| async {
|
||||
HandlerResult::ok(Bytes::new())
|
||||
});
|
||||
|
||||
let (_, _, timeout) = reg.get(1).expect("registered method");
|
||||
assert_eq!(timeout, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_context_deadline_is_accessible() {
|
||||
let ctx = RequestContext {
|
||||
identity_key: None,
|
||||
session_token: None,
|
||||
payload: Bytes::new(),
|
||||
trace_id: String::new(),
|
||||
deadline: Some(Instant::now() + Duration::from_secs(10)),
|
||||
};
|
||||
assert!(ctx.deadline.is_some());
|
||||
|
||||
let ctx_no_deadline = RequestContext {
|
||||
identity_key: None,
|
||||
session_token: None,
|
||||
payload: Bytes::new(),
|
||||
trace_id: String::new(),
|
||||
deadline: None,
|
||||
};
|
||||
assert!(ctx_no_deadline.deadline.is_none());
|
||||
}
|
||||
}
|
||||
154
crates/quicprochat-rpc/src/middleware.rs
Normal file
154
crates/quicprochat-rpc/src/middleware.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
//! Middleware layers for the RPC server.
|
||||
//!
|
||||
//! - `SessionValidator`: validates session tokens and resolves identity keys.
|
||||
//! - `RateLimiter`: per-key sliding-window rate limiting.
|
||||
//! - `log_rpc_call`: structured audit logging for RPC calls.
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use sha2::Digest;
|
||||
|
||||
use crate::error::RpcStatus;
|
||||
|
||||
// ── Auth middleware ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Validates bearer tokens and resolves identity keys.
|
||||
pub trait SessionValidator: Send + Sync + 'static {
|
||||
/// Validate a session token, returning the identity key if valid.
|
||||
fn validate(&self, token: &[u8]) -> Option<Vec<u8>>;
|
||||
}
|
||||
|
||||
/// Auth context extracted from a validated session.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthContext {
|
||||
/// The Ed25519 identity key of the authenticated caller.
|
||||
pub identity_key: Vec<u8>,
|
||||
}
|
||||
|
||||
// ── Rate limiter ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Simple per-key sliding-window rate limiter.
|
||||
pub struct RateLimiter {
|
||||
/// Max requests per window.
|
||||
max_requests: u32,
|
||||
/// Window duration.
|
||||
window: Duration,
|
||||
/// Map from key → (count, window_start).
|
||||
state: DashMap<Vec<u8>, (u32, Instant)>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
/// Create a new rate limiter.
|
||||
pub fn new(max_requests: u32, window: Duration) -> Self {
|
||||
Self {
|
||||
max_requests,
|
||||
window,
|
||||
state: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a request from `key` is allowed. Returns `true` if allowed.
|
||||
pub fn check(&self, key: &[u8]) -> bool {
|
||||
let now = Instant::now();
|
||||
let mut entry = self.state.entry(key.to_vec()).or_insert((0, now));
|
||||
let (count, window_start) = entry.value_mut();
|
||||
|
||||
if now.duration_since(*window_start) >= self.window {
|
||||
// Reset window.
|
||||
*count = 1;
|
||||
*window_start = now;
|
||||
true
|
||||
} else if *count < self.max_requests {
|
||||
*count += 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove expired entries (call periodically for memory hygiene).
|
||||
pub fn gc(&self) {
|
||||
let now = Instant::now();
|
||||
self.state.retain(|_, (_, start)| now.duration_since(*start) < self.window * 2);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Audit logging ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Log an RPC call with timing and caller info.
|
||||
///
|
||||
/// When `redact` is true, the identity key is hashed before logging so that
|
||||
/// raw keys never appear in log output.
|
||||
pub fn log_rpc_call(
|
||||
method_name: &str,
|
||||
identity_key: Option<&[u8]>,
|
||||
latency: Duration,
|
||||
status: RpcStatus,
|
||||
redact: bool,
|
||||
) {
|
||||
let ik_display = match identity_key {
|
||||
Some(ik) if redact => {
|
||||
let hash_input_len = 8.min(ik.len());
|
||||
let digest = sha2::Sha256::digest(&ik[..hash_input_len]);
|
||||
format!("h:{}", hex_prefix(&digest))
|
||||
}
|
||||
Some(ik) => hex_prefix(ik),
|
||||
None => "anonymous".to_string(),
|
||||
};
|
||||
tracing::info!(
|
||||
method = method_name,
|
||||
identity = %ik_display,
|
||||
latency_ms = latency.as_millis() as u64,
|
||||
status = ?status,
|
||||
"rpc"
|
||||
);
|
||||
}
|
||||
|
||||
fn hex_prefix(bytes: &[u8]) -> String {
|
||||
bytes
|
||||
.iter()
|
||||
.take(4)
|
||||
.map(|b| format!("{b:02x}"))
|
||||
.collect::<String>()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn rate_limiter_allows_within_limit() {
|
||||
let rl = RateLimiter::new(3, Duration::from_secs(60));
|
||||
let key = b"test-key";
|
||||
assert!(rl.check(key));
|
||||
assert!(rl.check(key));
|
||||
assert!(rl.check(key));
|
||||
assert!(!rl.check(key)); // 4th request denied
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_limiter_resets_after_window() {
|
||||
let rl = RateLimiter::new(1, Duration::from_millis(1));
|
||||
let key = b"test-key";
|
||||
assert!(rl.check(key));
|
||||
assert!(!rl.check(key));
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
assert!(rl.check(key)); // window expired
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hex_prefix_formats_first_4_bytes() {
|
||||
assert_eq!(hex_prefix(&[0xab, 0xcd, 0xef, 0x01, 0x99]), "abcdef01");
|
||||
assert_eq!(hex_prefix(&[0x00, 0xff]), "00ff");
|
||||
assert_eq!(hex_prefix(&[]), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_rpc_call_does_not_panic() {
|
||||
// Verify that audit log function does not panic with various inputs.
|
||||
log_rpc_call("test.method", None, Duration::from_millis(42), RpcStatus::Ok, false);
|
||||
log_rpc_call("test.method", Some(&[1, 2, 3, 4, 5, 6, 7, 8]), Duration::from_millis(1), RpcStatus::Internal, true);
|
||||
log_rpc_call("test.method", Some(&[0xab]), Duration::ZERO, RpcStatus::Unauthorized, true);
|
||||
}
|
||||
}
|
||||
114
crates/quicprochat-rpc/src/push.rs
Normal file
114
crates/quicprochat-rpc/src/push.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
//! Server-push infrastructure — manages push connections per identity key.
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use quinn::Connection;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::error::RpcError;
|
||||
|
||||
/// Manages push connections per identity key.
|
||||
pub struct PushBroker {
|
||||
/// Map from identity_key to active QUIC connections.
|
||||
connections: DashMap<Vec<u8>, Vec<Connection>>,
|
||||
}
|
||||
|
||||
impl PushBroker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
connections: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a connection for an identity.
|
||||
pub fn register(&self, identity_key: Vec<u8>, connection: Connection) {
|
||||
self.connections
|
||||
.entry(identity_key)
|
||||
.or_default()
|
||||
.push(connection);
|
||||
}
|
||||
|
||||
/// Remove closed connections from the registry.
|
||||
pub fn gc(&self) {
|
||||
self.connections.alter_all(|_, mut conns| {
|
||||
conns.retain(|c| c.close_reason().is_none());
|
||||
conns
|
||||
});
|
||||
self.connections.retain(|_, conns| !conns.is_empty());
|
||||
}
|
||||
|
||||
/// Send a push event to all connections for an identity.
|
||||
/// Returns the number of successful sends.
|
||||
pub async fn send_to(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
event_type: u16,
|
||||
payload: Bytes,
|
||||
) -> usize {
|
||||
let conns = match self.connections.get(identity_key) {
|
||||
Some(entry) => entry.clone(),
|
||||
None => return 0,
|
||||
};
|
||||
|
||||
let mut sent = 0usize;
|
||||
for conn in &conns {
|
||||
match crate::server::send_push(conn, event_type, payload.clone()).await {
|
||||
Ok(()) => {
|
||||
sent += 1;
|
||||
debug!("push sent to connection {}", conn.remote_address());
|
||||
}
|
||||
Err(RpcError::Connection(e)) => {
|
||||
warn!("push send failed to {}: {e}", conn.remote_address());
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("push send error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
sent
|
||||
}
|
||||
|
||||
/// Send a push event to all members of a channel.
|
||||
/// `member_keys` is the list of identity keys in the channel.
|
||||
/// Returns the total number of successful sends.
|
||||
pub async fn send_to_channel(
|
||||
&self,
|
||||
member_keys: &[Vec<u8>],
|
||||
event_type: u16,
|
||||
payload: Bytes,
|
||||
) -> usize {
|
||||
let mut total = 0usize;
|
||||
for key in member_keys {
|
||||
total += self.send_to(key, event_type, payload.clone()).await;
|
||||
}
|
||||
total
|
||||
}
|
||||
|
||||
/// Number of identities with registered connections.
|
||||
pub fn identity_count(&self) -> usize {
|
||||
self.connections.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PushBroker {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_broker_is_empty() {
|
||||
let broker = PushBroker::new();
|
||||
assert_eq!(broker.identity_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_broker_is_empty() {
|
||||
let broker = PushBroker::default();
|
||||
assert_eq!(broker.identity_count(), 0);
|
||||
}
|
||||
}
|
||||
309
crates/quicprochat-rpc/src/server.rs
Normal file
309
crates/quicprochat-rpc/src/server.rs
Normal file
@@ -0,0 +1,309 @@
|
||||
//! QUIC RPC server — accepts connections, dispatches requests to handlers.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use quinn::{Endpoint, Incoming, RecvStream, SendStream};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth_handshake;
|
||||
use crate::error::{RpcError, RpcStatus};
|
||||
use crate::framing::{RequestFrame, ResponseFrame, PushFrame};
|
||||
use crate::method::{HandlerResult, MethodRegistry, RequestContext};
|
||||
use crate::middleware::SessionValidator;
|
||||
|
||||
/// Per-connection state established during the auth handshake.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ConnectionState {
|
||||
/// The raw session token provided during the auth init handshake.
|
||||
pub session_token: Option<Vec<u8>>,
|
||||
/// The identity key resolved from the session token (populated by validator).
|
||||
pub identity_key: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
/// Configuration for the RPC server.
|
||||
pub struct RpcServerConfig {
|
||||
/// QUIC listen address.
|
||||
pub listen_addr: std::net::SocketAddr,
|
||||
/// TLS server config (rustls).
|
||||
pub tls_config: Arc<rustls::ServerConfig>,
|
||||
/// ALPN protocol for the RPC service.
|
||||
pub alpn: Vec<u8>,
|
||||
}
|
||||
|
||||
/// The QUIC RPC server.
|
||||
pub struct RpcServer<S: Send + Sync + 'static> {
|
||||
endpoint: Endpoint,
|
||||
state: Arc<S>,
|
||||
registry: Arc<MethodRegistry<S>>,
|
||||
/// Optional session validator for the auth handshake.
|
||||
validator: Option<Arc<dyn SessionValidator>>,
|
||||
}
|
||||
|
||||
impl<S: Send + Sync + 'static> RpcServer<S> {
|
||||
/// Create and bind the QUIC endpoint. Does not start accepting yet.
|
||||
pub fn bind(
|
||||
config: RpcServerConfig,
|
||||
state: Arc<S>,
|
||||
registry: MethodRegistry<S>,
|
||||
) -> Result<Self, RpcError> {
|
||||
let mut tls = (*config.tls_config).clone();
|
||||
tls.alpn_protocols = vec![config.alpn];
|
||||
let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(tls)
|
||||
.map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?;
|
||||
let server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_tls));
|
||||
|
||||
let endpoint = Endpoint::server(server_config, config.listen_addr)
|
||||
.map_err(|e| RpcError::Connection(format!("bind {}: {e}", config.listen_addr)))?;
|
||||
|
||||
info!(addr = %config.listen_addr, "RPC server bound");
|
||||
|
||||
Ok(Self {
|
||||
endpoint,
|
||||
state,
|
||||
registry: Arc::new(registry),
|
||||
validator: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set a session validator for the auth handshake.
|
||||
pub fn with_validator(mut self, validator: Arc<dyn SessionValidator>) -> Self {
|
||||
self.validator = Some(validator);
|
||||
self
|
||||
}
|
||||
|
||||
/// Accept connections in a loop. Spawns a task per connection.
|
||||
pub async fn serve(self) -> Result<(), RpcError> {
|
||||
info!("RPC server accepting connections");
|
||||
while let Some(incoming) = self.endpoint.accept().await {
|
||||
let state = Arc::clone(&self.state);
|
||||
let registry = Arc::clone(&self.registry);
|
||||
let validator = self.validator.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
handle_connection(incoming, state, registry, validator).await
|
||||
{
|
||||
warn!("connection error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the local address the server is listening on.
|
||||
pub fn local_addr(&self) -> Result<std::net::SocketAddr, RpcError> {
|
||||
self.endpoint
|
||||
.local_addr()
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single QUIC connection: perform auth handshake, then accept
|
||||
/// bi-directional streams for RPCs.
|
||||
async fn handle_connection<S: Send + Sync + 'static>(
|
||||
incoming: Incoming,
|
||||
state: Arc<S>,
|
||||
registry: Arc<MethodRegistry<S>>,
|
||||
validator: Option<Arc<dyn SessionValidator>>,
|
||||
) -> Result<(), RpcError> {
|
||||
let connection = incoming
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
let remote = connection.remote_address();
|
||||
debug!(remote = %remote, "new connection");
|
||||
|
||||
metrics::gauge!("rpc_active_connections").increment(1.0);
|
||||
metrics::counter!("rpc_connections_total").increment(1);
|
||||
|
||||
// Perform auth handshake on the first bi-stream.
|
||||
let conn_state = {
|
||||
let (mut send, mut recv) = connection
|
||||
.accept_bi()
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(format!("accept auth stream: {e}")))?;
|
||||
|
||||
let token = auth_handshake::recv_auth_init(&mut recv).await?;
|
||||
debug!(remote = %remote, token_len = token.len(), "received auth init");
|
||||
|
||||
let identity_key = validator.as_ref().and_then(|v| v.validate(&token));
|
||||
|
||||
auth_handshake::send_auth_ack(&mut send).await?;
|
||||
send.finish()
|
||||
.map_err(|e| RpcError::Connection(format!("finish auth stream: {e}")))?;
|
||||
|
||||
Arc::new(ConnectionState {
|
||||
session_token: Some(token),
|
||||
identity_key,
|
||||
})
|
||||
};
|
||||
|
||||
// Accept RPC streams.
|
||||
let result = loop {
|
||||
let stream = connection.accept_bi().await;
|
||||
match stream {
|
||||
Ok((send, recv)) => {
|
||||
let state = Arc::clone(&state);
|
||||
let registry = Arc::clone(®istry);
|
||||
let conn_state = Arc::clone(&conn_state);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
handle_stream(send, recv, state, registry, &conn_state).await
|
||||
{
|
||||
debug!("stream error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => {
|
||||
debug!(remote = %remote, "connection closed by peer");
|
||||
break Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(remote = %remote, "accept_bi error: {e}");
|
||||
break Ok(());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
metrics::gauge!("rpc_active_connections").decrement(1.0);
|
||||
result
|
||||
}
|
||||
|
||||
/// Handle a single bi-directional stream: read request, dispatch, write response.
|
||||
async fn handle_stream<S: Send + Sync + 'static>(
|
||||
mut send: SendStream,
|
||||
mut recv: RecvStream,
|
||||
state: Arc<S>,
|
||||
registry: Arc<MethodRegistry<S>>,
|
||||
conn_state: &ConnectionState,
|
||||
) -> Result<(), RpcError> {
|
||||
// Read the complete request from the stream.
|
||||
let mut buf = BytesMut::new();
|
||||
while let Some(chunk) = recv
|
||||
.read_chunk(65536, true)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?
|
||||
{
|
||||
buf.extend_from_slice(&chunk.bytes);
|
||||
if buf.len() > crate::framing::MAX_PAYLOAD_SIZE + crate::framing::REQUEST_HEADER_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: buf.len(),
|
||||
max: crate::framing::MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let frame = match RequestFrame::decode(&mut buf)? {
|
||||
Some(f) => f,
|
||||
None => return Err(RpcError::Decode("incomplete request frame".into())),
|
||||
};
|
||||
|
||||
let trace_id = uuid::Uuid::now_v7().to_string();
|
||||
|
||||
let result = match registry.get(frame.method_id) {
|
||||
Some((handler, name, timeout)) => {
|
||||
let span = tracing::info_span!(
|
||||
"rpc",
|
||||
trace_id = %trace_id,
|
||||
method_id = frame.method_id,
|
||||
method = name,
|
||||
req_id = frame.request_id,
|
||||
);
|
||||
let _guard = span.enter();
|
||||
debug!("dispatching");
|
||||
|
||||
let deadline = timeout.map(|d| tokio::time::Instant::now() + d);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let ctx = RequestContext {
|
||||
identity_key: conn_state.identity_key.clone(),
|
||||
session_token: conn_state.session_token.clone(),
|
||||
payload: frame.payload,
|
||||
trace_id: trace_id.clone(),
|
||||
deadline,
|
||||
};
|
||||
|
||||
let result = if let Some(dur) = timeout {
|
||||
match tokio::time::timeout(dur, handler(Arc::clone(&state), ctx)).await {
|
||||
Ok(r) => r,
|
||||
Err(_) => {
|
||||
warn!(method = name, timeout_ms = dur.as_millis() as u64, "request deadline exceeded");
|
||||
HandlerResult::err(RpcStatus::DeadlineExceeded, "request deadline exceeded")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
handler(Arc::clone(&state), ctx).await
|
||||
};
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
// Per-endpoint latency histogram.
|
||||
metrics::histogram!("rpc_request_duration_seconds", "method" => name)
|
||||
.record(elapsed.as_secs_f64());
|
||||
metrics::counter!("rpc_requests_total", "method" => name, "status" => status_label(result.status))
|
||||
.increment(1);
|
||||
|
||||
result
|
||||
}
|
||||
None => {
|
||||
warn!(method_id = frame.method_id, trace_id = %trace_id, "unknown method");
|
||||
metrics::counter!("rpc_requests_total", "method" => "unknown", "status" => "unknown_method")
|
||||
.increment(1);
|
||||
HandlerResult::err(RpcStatus::UnknownMethod, "unknown method")
|
||||
}
|
||||
};
|
||||
|
||||
let response = ResponseFrame {
|
||||
status: result.status as u8,
|
||||
request_id: frame.request_id,
|
||||
payload: result.payload,
|
||||
};
|
||||
|
||||
let encoded = response.encode();
|
||||
send.write_all(&encoded)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert an RpcStatus to a short label for metrics.
|
||||
fn status_label(status: RpcStatus) -> &'static str {
|
||||
match status {
|
||||
RpcStatus::Ok => "ok",
|
||||
RpcStatus::BadRequest => "bad_request",
|
||||
RpcStatus::Unauthorized => "unauthorized",
|
||||
RpcStatus::Forbidden => "forbidden",
|
||||
RpcStatus::NotFound => "not_found",
|
||||
RpcStatus::RateLimited => "rate_limited",
|
||||
RpcStatus::DeadlineExceeded => "deadline_exceeded",
|
||||
RpcStatus::Unavailable => "unavailable",
|
||||
RpcStatus::Internal => "internal",
|
||||
RpcStatus::UnknownMethod => "unknown_method",
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a push event to a client via a QUIC uni-stream.
|
||||
pub async fn send_push(
|
||||
connection: &quinn::Connection,
|
||||
event_type: u16,
|
||||
payload: bytes::Bytes,
|
||||
) -> Result<(), RpcError> {
|
||||
let mut send = connection
|
||||
.open_uni()
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
let frame = PushFrame {
|
||||
event_type,
|
||||
payload,
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
send.write_all(&encoded)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user