feat: v2 Phase 1 — foundation, proto schemas, RPC framework, SDK skeleton

New workspace structure with 9 crates. Adds:

- proto/qpq/v1/*.proto: 11 protobuf schemas covering all 33 RPC methods
- quicproquo-proto: dual codegen (capnp legacy + prost v2)
- quicproquo-rpc: QUIC RPC framework (framing, server, client, middleware)
- quicproquo-sdk: client SDK (QpqClient, events, conversation store)
- quicproquo-server/domain/: protocol-agnostic domain types and services
- justfile: build commands

Wire format: [method_id:u16][req_id:u32][len:u32][protobuf] per QUIC stream.
All 151 existing tests pass. Backward compatible with v1 capnp code.
This commit is contained in:
2026-03-04 12:02:07 +01:00
parent 394199b19b
commit a5864127d1
37 changed files with 3115 additions and 2778 deletions

View File

@@ -0,0 +1,25 @@
[package]
name = "quicproquo-rpc"
version = "0.1.0"
edition = "2021"
description = "QUIC RPC framework for quicproquo v2 — framing, dispatch, tower middleware"
[dependencies]
quicproquo-proto = { path = "../quicproquo-proto" }
prost = { workspace = true }
bytes = { workspace = true }
quinn = { workspace = true }
rustls = { workspace = true }
rcgen = { workspace = true }
tokio = { workspace = true }
futures = { workspace = true }
tower = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
dashmap = { workspace = true }
[dev-dependencies]
tokio = { workspace = true, features = ["test-util"] }
[lints]
workspace = true

View File

@@ -0,0 +1,175 @@
//! QUIC RPC client — connect to server, send requests, receive push events.
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use quinn::{Connection, Endpoint};
use tokio::sync::mpsc;
use tracing::{debug, warn};
use crate::error::{RpcError, RpcStatus};
use crate::framing::{PushFrame, RequestFrame, ResponseFrame};
/// Configuration for the RPC client.
pub struct RpcClientConfig {
/// Server address to connect to.
pub server_addr: std::net::SocketAddr,
/// Server name for TLS verification.
pub server_name: String,
/// TLS client config (rustls).
pub tls_config: Arc<rustls::ClientConfig>,
/// ALPN protocol.
pub alpn: Vec<u8>,
}
/// A QUIC RPC client connection.
pub struct RpcClient {
connection: Connection,
next_request_id: AtomicU32,
}
impl RpcClient {
/// Connect to the RPC server.
pub async fn connect(config: RpcClientConfig) -> Result<Self, RpcError> {
let mut tls = (*config.tls_config).clone();
tls.alpn_protocols = vec![config.alpn];
let quic_tls = quinn::crypto::rustls::QuicClientConfig::try_from(tls)
.map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?;
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().expect("valid addr"))
.map_err(|e| RpcError::Connection(e.to_string()))?;
endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(quic_tls)));
let connection = endpoint
.connect(config.server_addr, &config.server_name)
.map_err(|e| RpcError::Connection(e.to_string()))?
.await
.map_err(|e| RpcError::Connection(e.to_string()))?;
debug!(remote = %connection.remote_address(), "connected to RPC server");
Ok(Self {
connection,
next_request_id: AtomicU32::new(1),
})
}
/// Send an RPC request and wait for the response.
pub async fn call(
&self,
method_id: u16,
payload: Bytes,
) -> Result<Bytes, RpcError> {
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
let (mut send, mut recv) = self
.connection
.open_bi()
.await
.map_err(|e| RpcError::Connection(e.to_string()))?;
// Send request.
let frame = RequestFrame {
method_id,
request_id,
payload,
};
let encoded = frame.encode();
send.write_all(&encoded)
.await
.map_err(|e| RpcError::Connection(e.to_string()))?;
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
// Read response.
let mut buf = BytesMut::new();
while let Some(chunk) = recv
.read_chunk(65536, true)
.await
.map_err(|e| RpcError::Connection(e.to_string()))?
{
buf.extend_from_slice(&chunk.bytes);
if buf.len() > crate::framing::MAX_PAYLOAD_SIZE + crate::framing::RESPONSE_HEADER_SIZE {
return Err(RpcError::PayloadTooLarge {
size: buf.len(),
max: crate::framing::MAX_PAYLOAD_SIZE,
});
}
}
let response = ResponseFrame::decode(&mut buf)?
.ok_or_else(|| RpcError::Decode("incomplete response frame".into()))?;
if response.request_id != request_id {
return Err(RpcError::Decode(format!(
"request_id mismatch: sent {request_id}, got {}",
response.request_id
)));
}
match RpcStatus::from_u8(response.status) {
Some(RpcStatus::Ok) => Ok(response.payload),
Some(status) => Err(RpcError::Server {
status,
message: String::from_utf8_lossy(&response.payload).into_owned(),
}),
None => Err(RpcError::Decode(format!(
"unknown status byte: {}",
response.status
))),
}
}
/// Subscribe to server-push events. Returns a receiver channel.
/// Spawns a background task that reads uni-streams.
pub fn subscribe_push(&self) -> mpsc::UnboundedReceiver<PushFrame> {
let (tx, rx) = mpsc::unbounded_channel();
let conn = self.connection.clone();
tokio::spawn(async move {
loop {
match conn.accept_uni().await {
Ok(mut recv) => {
let mut buf = BytesMut::new();
loop {
match recv.read_chunk(65536, true).await {
Ok(Some(chunk)) => buf.extend_from_slice(&chunk.bytes),
Ok(None) => break,
Err(e) => {
debug!("push stream read error: {e}");
break;
}
}
}
match PushFrame::decode(&mut buf) {
Ok(Some(frame)) => {
if tx.send(frame).is_err() {
return; // receiver dropped
}
}
Ok(None) => debug!("incomplete push frame"),
Err(e) => debug!("push decode error: {e}"),
}
}
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
Err(e) => {
warn!("accept_uni error: {e}");
break;
}
}
}
});
rx
}
/// Close the connection gracefully.
pub fn close(&self) {
self.connection.close(0u32.into(), b"bye");
}
/// Get the underlying QUIC connection (for advanced use).
pub fn connection(&self) -> &Connection {
&self.connection
}
}

View File

@@ -0,0 +1,68 @@
//! RPC error types.
/// Status codes for RPC responses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RpcStatus {
/// Request succeeded.
Ok = 0,
/// Client sent a malformed request.
BadRequest = 1,
/// Authentication required or token invalid.
Unauthorized = 2,
/// Caller lacks permission for this operation.
Forbidden = 3,
/// Requested resource not found.
NotFound = 4,
/// Rate limit exceeded.
RateLimited = 5,
/// Internal server error.
Internal = 10,
/// Method not recognized.
UnknownMethod = 11,
}
impl RpcStatus {
/// Decode a status byte. Returns `None` for unknown values.
pub fn from_u8(v: u8) -> Option<Self> {
match v {
0 => Some(Self::Ok),
1 => Some(Self::BadRequest),
2 => Some(Self::Unauthorized),
3 => Some(Self::Forbidden),
4 => Some(Self::NotFound),
5 => Some(Self::RateLimited),
10 => Some(Self::Internal),
11 => Some(Self::UnknownMethod),
_ => None,
}
}
}
/// Errors that can occur in the RPC layer.
#[derive(Debug, thiserror::Error)]
pub enum RpcError {
#[error("connection error: {0}")]
Connection(String),
#[error("encoding error: {0}")]
Encode(String),
#[error("decoding error: {0}")]
Decode(String),
#[error("server returned error status {status:?}: {message}")]
Server {
status: RpcStatus,
message: String,
},
#[error("request timed out")]
Timeout,
#[error("stream closed unexpectedly")]
StreamClosed,
#[error("payload too large: {size} bytes (max {max})")]
PayloadTooLarge { size: usize, max: usize },
}

View File

@@ -0,0 +1,280 @@
//! Wire format encoding and decoding for the quicproquo v2 RPC protocol.
//!
//! ## Request frame
//! ```text
//! [method_id: u16 BE][request_id: u32 BE][payload_len: u32 BE][protobuf bytes]
//! ```
//!
//! ## Response frame
//! ```text
//! [status: u8][request_id: u32 BE][payload_len: u32 BE][protobuf bytes]
//! ```
//!
//! ## Push frame (server → client, uni-stream)
//! ```text
//! [event_type: u16 BE][payload_len: u32 BE][protobuf bytes]
//! ```
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::error::{RpcError, RpcStatus};
/// Maximum payload size: 4 MiB.
pub const MAX_PAYLOAD_SIZE: usize = 4 * 1024 * 1024;
/// Request header size: 2 (method) + 4 (req_id) + 4 (len) = 10 bytes.
pub const REQUEST_HEADER_SIZE: usize = 10;
/// Response header size: 1 (status) + 4 (req_id) + 4 (len) = 9 bytes.
pub const RESPONSE_HEADER_SIZE: usize = 9;
/// Push header size: 2 (event_type) + 4 (len) = 6 bytes.
pub const PUSH_HEADER_SIZE: usize = 6;
// ── Request ──────────────────────────────────────────────────────────────────
/// A decoded RPC request frame.
#[derive(Debug, Clone)]
pub struct RequestFrame {
pub method_id: u16,
pub request_id: u32,
pub payload: Bytes,
}
impl RequestFrame {
/// Encode this request into a byte buffer.
pub fn encode(&self) -> Bytes {
let mut buf = BytesMut::with_capacity(REQUEST_HEADER_SIZE + self.payload.len());
buf.put_u16(self.method_id);
buf.put_u32(self.request_id);
buf.put_u32(self.payload.len() as u32);
buf.put(self.payload.clone());
buf.freeze()
}
/// Decode a request frame from a byte buffer.
/// Returns `None` if the buffer does not contain a complete frame.
pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, RpcError> {
if buf.len() < REQUEST_HEADER_SIZE {
return Ok(None);
}
// Peek at payload_len without consuming.
let payload_len =
u32::from_be_bytes([buf[6], buf[7], buf[8], buf[9]]) as usize;
if payload_len > MAX_PAYLOAD_SIZE {
return Err(RpcError::PayloadTooLarge {
size: payload_len,
max: MAX_PAYLOAD_SIZE,
});
}
let total = REQUEST_HEADER_SIZE + payload_len;
if buf.len() < total {
return Ok(None);
}
let method_id = buf.get_u16();
let request_id = buf.get_u32();
let _len = buf.get_u32();
let payload = buf.split_to(payload_len).freeze();
Ok(Some(Self {
method_id,
request_id,
payload,
}))
}
}
// ── Response ─────────────────────────────────────────────────────────────────
/// A decoded RPC response frame.
#[derive(Debug, Clone)]
pub struct ResponseFrame {
pub status: u8,
pub request_id: u32,
pub payload: Bytes,
}
impl ResponseFrame {
/// Encode this response into a byte buffer.
pub fn encode(&self) -> Bytes {
let mut buf = BytesMut::with_capacity(RESPONSE_HEADER_SIZE + self.payload.len());
buf.put_u8(self.status);
buf.put_u32(self.request_id);
buf.put_u32(self.payload.len() as u32);
buf.put(self.payload.clone());
buf.freeze()
}
/// Decode a response frame from a byte buffer.
pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, RpcError> {
if buf.len() < RESPONSE_HEADER_SIZE {
return Ok(None);
}
let payload_len =
u32::from_be_bytes([buf[5], buf[6], buf[7], buf[8]]) as usize;
if payload_len > MAX_PAYLOAD_SIZE {
return Err(RpcError::PayloadTooLarge {
size: payload_len,
max: MAX_PAYLOAD_SIZE,
});
}
let total = RESPONSE_HEADER_SIZE + payload_len;
if buf.len() < total {
return Ok(None);
}
let status = buf.get_u8();
let request_id = buf.get_u32();
let _len = buf.get_u32();
let payload = buf.split_to(payload_len).freeze();
Ok(Some(Self {
status,
request_id,
payload,
}))
}
/// Convert the status byte to an `RpcStatus`.
pub fn rpc_status(&self) -> Option<RpcStatus> {
RpcStatus::from_u8(self.status)
}
}
// ── Push ─────────────────────────────────────────────────────────────────────
/// A decoded server-push event frame (sent on QUIC uni-streams).
#[derive(Debug, Clone)]
pub struct PushFrame {
pub event_type: u16,
pub payload: Bytes,
}
impl PushFrame {
/// Encode this push frame into a byte buffer.
pub fn encode(&self) -> Bytes {
let mut buf = BytesMut::with_capacity(PUSH_HEADER_SIZE + self.payload.len());
buf.put_u16(self.event_type);
buf.put_u32(self.payload.len() as u32);
buf.put(self.payload.clone());
buf.freeze()
}
/// Decode a push frame from a byte buffer.
pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, RpcError> {
if buf.len() < PUSH_HEADER_SIZE {
return Ok(None);
}
let payload_len =
u32::from_be_bytes([buf[2], buf[3], buf[4], buf[5]]) as usize;
if payload_len > MAX_PAYLOAD_SIZE {
return Err(RpcError::PayloadTooLarge {
size: payload_len,
max: MAX_PAYLOAD_SIZE,
});
}
let total = PUSH_HEADER_SIZE + payload_len;
if buf.len() < total {
return Ok(None);
}
let event_type = buf.get_u16();
let _len = buf.get_u32();
let payload = buf.split_to(payload_len).freeze();
Ok(Some(Self {
event_type,
payload,
}))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn request_roundtrip() {
let frame = RequestFrame {
method_id: 42,
request_id: 1001,
payload: Bytes::from_static(b"hello"),
};
let encoded = frame.encode();
let mut buf = BytesMut::from(encoded.as_ref());
let decoded = RequestFrame::decode(&mut buf).expect("decode").expect("complete");
assert_eq!(decoded.method_id, 42);
assert_eq!(decoded.request_id, 1001);
assert_eq!(decoded.payload, Bytes::from_static(b"hello"));
assert!(buf.is_empty());
}
#[test]
fn response_roundtrip() {
let frame = ResponseFrame {
status: RpcStatus::Ok as u8,
request_id: 2002,
payload: Bytes::from_static(b"world"),
};
let encoded = frame.encode();
let mut buf = BytesMut::from(encoded.as_ref());
let decoded = ResponseFrame::decode(&mut buf).expect("decode").expect("complete");
assert_eq!(decoded.status, 0);
assert_eq!(decoded.request_id, 2002);
assert_eq!(decoded.payload, Bytes::from_static(b"world"));
}
#[test]
fn push_roundtrip() {
let frame = PushFrame {
event_type: 7,
payload: Bytes::from_static(b"event-data"),
};
let encoded = frame.encode();
let mut buf = BytesMut::from(encoded.as_ref());
let decoded = PushFrame::decode(&mut buf).expect("decode").expect("complete");
assert_eq!(decoded.event_type, 7);
assert_eq!(decoded.payload, Bytes::from_static(b"event-data"));
}
#[test]
fn incomplete_request_returns_none() {
let mut buf = BytesMut::from(&[0u8; 5][..]);
assert!(RequestFrame::decode(&mut buf).expect("no error").is_none());
}
#[test]
fn payload_too_large_rejected() {
// Craft a request header with payload_len = MAX + 1.
let mut buf = BytesMut::new();
buf.put_u16(1);
buf.put_u32(1);
buf.put_u32((MAX_PAYLOAD_SIZE + 1) as u32);
let result = RequestFrame::decode(&mut buf);
assert!(matches!(result, Err(RpcError::PayloadTooLarge { .. })));
}
#[test]
fn empty_payload_request() {
let frame = RequestFrame {
method_id: 0,
request_id: 0,
payload: Bytes::new(),
};
let encoded = frame.encode();
assert_eq!(encoded.len(), REQUEST_HEADER_SIZE);
let mut buf = BytesMut::from(encoded.as_ref());
let decoded = RequestFrame::decode(&mut buf).expect("decode").expect("complete");
assert!(decoded.payload.is_empty());
}
}

View File

@@ -0,0 +1,13 @@
//! QUIC RPC framework for quicproquo v2.
//!
//! Wire format per QUIC stream:
//! - Request: `[method_id: u16][request_id: u32][payload_len: u32][protobuf bytes]`
//! - Response: `[status: u8][request_id: u32][payload_len: u32][protobuf bytes]`
//! - Push: `[event_type: u16][payload_len: u32][protobuf bytes]` (uni-stream)
pub mod framing;
pub mod method;
pub mod server;
pub mod client;
pub mod middleware;
pub mod error;

View File

@@ -0,0 +1,102 @@
//! Method registry — maps method IDs to handler functions.
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use crate::error::RpcStatus;
/// The result of handling an RPC request.
pub struct HandlerResult {
pub status: RpcStatus,
pub payload: Bytes,
}
impl HandlerResult {
/// Shorthand for a successful response.
pub fn ok(payload: Bytes) -> Self {
Self {
status: RpcStatus::Ok,
payload,
}
}
/// Shorthand for an error response.
pub fn err(status: RpcStatus, message: &str) -> Self {
Self {
status,
payload: Bytes::copy_from_slice(message.as_bytes()),
}
}
}
/// Context passed to every RPC handler.
pub struct RequestContext {
/// The authenticated identity key of the caller, if any.
pub identity_key: Option<Vec<u8>>,
/// The session token, if provided.
pub session_token: Option<Vec<u8>>,
/// The raw request payload (protobuf-encoded).
pub payload: Bytes,
}
/// Type-erased async handler function.
pub type HandlerFn<S> = Arc<
dyn Fn(Arc<S>, RequestContext) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>>
+ Send
+ Sync,
>;
/// Registry mapping method IDs to handler functions.
pub struct MethodRegistry<S> {
handlers: HashMap<u16, (HandlerFn<S>, &'static str)>,
}
impl<S: Send + Sync + 'static> MethodRegistry<S> {
pub fn new() -> Self {
Self {
handlers: HashMap::new(),
}
}
/// Register a handler for a method ID.
pub fn register<F, Fut>(&mut self, method_id: u16, name: &'static str, handler: F)
where
F: Fn(Arc<S>, RequestContext) -> Fut + Send + Sync + 'static,
Fut: Future<Output = HandlerResult> + Send + 'static,
{
let handler = Arc::new(move |state: Arc<S>, ctx: RequestContext| {
Box::pin(handler(state, ctx)) as Pin<Box<dyn Future<Output = HandlerResult> + Send>>
});
self.handlers.insert(method_id, (handler, name));
}
/// Look up a handler by method ID.
pub fn get(&self, method_id: u16) -> Option<&(HandlerFn<S>, &'static str)> {
self.handlers.get(&method_id)
}
/// Return the number of registered methods.
pub fn len(&self) -> usize {
self.handlers.len()
}
/// Whether the registry is empty.
pub fn is_empty(&self) -> bool {
self.handlers.is_empty()
}
/// Iterate over all registered (method_id, name) pairs.
pub fn methods(&self) -> impl Iterator<Item = (u16, &'static str)> + '_ {
self.handlers.iter().map(|(&id, (_, name))| (id, *name))
}
}
impl<S: Send + Sync + 'static> Default for MethodRegistry<S> {
fn default() -> Self {
Self::new()
}
}

View File

@@ -0,0 +1,96 @@
//! Tower-based middleware layers for the RPC server.
//!
//! - `AuthLayer`: validates session tokens and attaches identity to context.
//! - `RateLimitLayer`: per-IP request rate limiting.
use std::time::{Duration, Instant};
use dashmap::DashMap;
// ── Auth middleware ──────────────────────────────────────────────────────────
/// Validates bearer tokens and resolves identity keys.
pub trait SessionValidator: Send + Sync + 'static {
/// Validate a session token, returning the identity key if valid.
fn validate(&self, token: &[u8]) -> Option<Vec<u8>>;
}
/// Auth context extracted from a validated session.
#[derive(Debug, Clone)]
pub struct AuthContext {
/// The Ed25519 identity key of the authenticated caller.
pub identity_key: Vec<u8>,
}
// ── Rate limiter ─────────────────────────────────────────────────────────────
/// Simple per-key sliding-window rate limiter.
pub struct RateLimiter {
/// Max requests per window.
max_requests: u32,
/// Window duration.
window: Duration,
/// Map from key → (count, window_start).
state: DashMap<Vec<u8>, (u32, Instant)>,
}
impl RateLimiter {
/// Create a new rate limiter.
pub fn new(max_requests: u32, window: Duration) -> Self {
Self {
max_requests,
window,
state: DashMap::new(),
}
}
/// Check if a request from `key` is allowed. Returns `true` if allowed.
pub fn check(&self, key: &[u8]) -> bool {
let now = Instant::now();
let mut entry = self.state.entry(key.to_vec()).or_insert((0, now));
let (count, window_start) = entry.value_mut();
if now.duration_since(*window_start) >= self.window {
// Reset window.
*count = 1;
*window_start = now;
true
} else if *count < self.max_requests {
*count += 1;
true
} else {
false
}
}
/// Remove expired entries (call periodically for memory hygiene).
pub fn gc(&self) {
let now = Instant::now();
self.state.retain(|_, (_, start)| now.duration_since(*start) < self.window * 2);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rate_limiter_allows_within_limit() {
let rl = RateLimiter::new(3, Duration::from_secs(60));
let key = b"test-key";
assert!(rl.check(key));
assert!(rl.check(key));
assert!(rl.check(key));
assert!(!rl.check(key)); // 4th request denied
}
#[test]
fn rate_limiter_resets_after_window() {
let rl = RateLimiter::new(1, Duration::from_millis(1));
let key = b"test-key";
assert!(rl.check(key));
assert!(!rl.check(key));
std::thread::sleep(Duration::from_millis(5));
assert!(rl.check(key)); // window expired
}
}

View File

@@ -0,0 +1,198 @@
//! QUIC RPC server — accepts connections, dispatches requests to handlers.
use std::sync::Arc;
use bytes::BytesMut;
use quinn::{Endpoint, Incoming, RecvStream, SendStream};
use tracing::{debug, info, warn};
use crate::error::{RpcError, RpcStatus};
use crate::framing::{RequestFrame, ResponseFrame, PushFrame};
use crate::method::{HandlerResult, MethodRegistry, RequestContext};
/// Configuration for the RPC server.
pub struct RpcServerConfig {
/// QUIC listen address.
pub listen_addr: std::net::SocketAddr,
/// TLS server config (rustls).
pub tls_config: Arc<rustls::ServerConfig>,
/// ALPN protocol for the RPC service.
pub alpn: Vec<u8>,
}
/// The QUIC RPC server.
pub struct RpcServer<S: Send + Sync + 'static> {
endpoint: Endpoint,
state: Arc<S>,
registry: Arc<MethodRegistry<S>>,
}
impl<S: Send + Sync + 'static> RpcServer<S> {
/// Create and bind the QUIC endpoint. Does not start accepting yet.
pub fn bind(
config: RpcServerConfig,
state: Arc<S>,
registry: MethodRegistry<S>,
) -> Result<Self, RpcError> {
let mut tls = (*config.tls_config).clone();
tls.alpn_protocols = vec![config.alpn];
let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(tls)
.map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?;
let server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_tls));
let endpoint = Endpoint::server(server_config, config.listen_addr)
.map_err(|e| RpcError::Connection(format!("bind {}: {e}", config.listen_addr)))?;
info!(addr = %config.listen_addr, "RPC server bound");
Ok(Self {
endpoint,
state,
registry: Arc::new(registry),
})
}
/// Accept connections in a loop. Spawns a task per connection.
pub async fn serve(self) -> Result<(), RpcError> {
info!("RPC server accepting connections");
while let Some(incoming) = self.endpoint.accept().await {
let state = Arc::clone(&self.state);
let registry = Arc::clone(&self.registry);
tokio::spawn(async move {
if let Err(e) = handle_connection(incoming, state, registry).await {
warn!("connection error: {e}");
}
});
}
Ok(())
}
/// Get the local address the server is listening on.
pub fn local_addr(&self) -> Result<std::net::SocketAddr, RpcError> {
self.endpoint
.local_addr()
.map_err(|e| RpcError::Connection(e.to_string()))
}
}
/// Handle a single QUIC connection: accept bi-directional streams for RPCs.
async fn handle_connection<S: Send + Sync + 'static>(
incoming: Incoming,
state: Arc<S>,
registry: Arc<MethodRegistry<S>>,
) -> Result<(), RpcError> {
let connection = incoming
.await
.map_err(|e| RpcError::Connection(e.to_string()))?;
let remote = connection.remote_address();
debug!(remote = %remote, "new connection");
loop {
let stream = connection.accept_bi().await;
match stream {
Ok((send, recv)) => {
let state = Arc::clone(&state);
let registry = Arc::clone(&registry);
tokio::spawn(async move {
if let Err(e) = handle_stream(send, recv, state, registry).await {
debug!("stream error: {e}");
}
});
}
Err(quinn::ConnectionError::ApplicationClosed(_)) => {
debug!(remote = %remote, "connection closed by peer");
break;
}
Err(e) => {
debug!(remote = %remote, "accept_bi error: {e}");
break;
}
}
}
Ok(())
}
/// Handle a single bi-directional stream: read request, dispatch, write response.
async fn handle_stream<S: Send + Sync + 'static>(
mut send: SendStream,
mut recv: RecvStream,
state: Arc<S>,
registry: Arc<MethodRegistry<S>>,
) -> Result<(), RpcError> {
// Read the complete request from the stream.
let mut buf = BytesMut::new();
while let Some(chunk) = recv
.read_chunk(65536, true)
.await
.map_err(|e| RpcError::Connection(e.to_string()))?
{
buf.extend_from_slice(&chunk.bytes);
if buf.len() > crate::framing::MAX_PAYLOAD_SIZE + crate::framing::REQUEST_HEADER_SIZE {
return Err(RpcError::PayloadTooLarge {
size: buf.len(),
max: crate::framing::MAX_PAYLOAD_SIZE,
});
}
}
let frame = match RequestFrame::decode(&mut buf)? {
Some(f) => f,
None => return Err(RpcError::Decode("incomplete request frame".into())),
};
let result = match registry.get(frame.method_id) {
Some((handler, name)) => {
debug!(method_id = frame.method_id, method = name, req_id = frame.request_id, "dispatching");
let ctx = RequestContext {
identity_key: None, // populated by auth middleware
session_token: None,
payload: frame.payload,
};
handler(Arc::clone(&state), ctx).await
}
None => {
warn!(method_id = frame.method_id, "unknown method");
HandlerResult::err(RpcStatus::UnknownMethod, "unknown method")
}
};
let response = ResponseFrame {
status: result.status as u8,
request_id: frame.request_id,
payload: result.payload,
};
let encoded = response.encode();
send.write_all(&encoded)
.await
.map_err(|e| RpcError::Connection(e.to_string()))?;
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
Ok(())
}
/// Send a push event to a client via a QUIC uni-stream.
pub async fn send_push(
connection: &quinn::Connection,
event_type: u16,
payload: bytes::Bytes,
) -> Result<(), RpcError> {
let mut send = connection
.open_uni()
.await
.map_err(|e| RpcError::Connection(e.to_string()))?;
let frame = PushFrame {
event_type,
payload,
};
let encoded = frame.encode();
send.write_all(&encoded)
.await
.map_err(|e| RpcError::Connection(e.to_string()))?;
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
Ok(())
}