feat: v2 Phase 1 — foundation, proto schemas, RPC framework, SDK skeleton
New workspace structure with 9 crates. Adds: - proto/qpq/v1/*.proto: 11 protobuf schemas covering all 33 RPC methods - quicproquo-proto: dual codegen (capnp legacy + prost v2) - quicproquo-rpc: QUIC RPC framework (framing, server, client, middleware) - quicproquo-sdk: client SDK (QpqClient, events, conversation store) - quicproquo-server/domain/: protocol-agnostic domain types and services - justfile: build commands Wire format: [method_id:u16][req_id:u32][len:u32][protobuf] per QUIC stream. All 151 existing tests pass. Backward compatible with v1 capnp code.
This commit is contained in:
25
crates/quicproquo-rpc/Cargo.toml
Normal file
25
crates/quicproquo-rpc/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "quicproquo-rpc"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
description = "QUIC RPC framework for quicproquo v2 — framing, dispatch, tower middleware"
|
||||
|
||||
[dependencies]
|
||||
quicproquo-proto = { path = "../quicproquo-proto" }
|
||||
prost = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
quinn = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
rcgen = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
175
crates/quicproquo-rpc/src/client.rs
Normal file
175
crates/quicproquo-rpc/src/client.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
//! QUIC RPC client — connect to server, send requests, receive push events.
|
||||
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use quinn::{Connection, Endpoint};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::error::{RpcError, RpcStatus};
|
||||
use crate::framing::{PushFrame, RequestFrame, ResponseFrame};
|
||||
|
||||
/// Configuration for the RPC client.
|
||||
pub struct RpcClientConfig {
|
||||
/// Server address to connect to.
|
||||
pub server_addr: std::net::SocketAddr,
|
||||
/// Server name for TLS verification.
|
||||
pub server_name: String,
|
||||
/// TLS client config (rustls).
|
||||
pub tls_config: Arc<rustls::ClientConfig>,
|
||||
/// ALPN protocol.
|
||||
pub alpn: Vec<u8>,
|
||||
}
|
||||
|
||||
/// A QUIC RPC client connection.
|
||||
pub struct RpcClient {
|
||||
connection: Connection,
|
||||
next_request_id: AtomicU32,
|
||||
}
|
||||
|
||||
impl RpcClient {
|
||||
/// Connect to the RPC server.
|
||||
pub async fn connect(config: RpcClientConfig) -> Result<Self, RpcError> {
|
||||
let mut tls = (*config.tls_config).clone();
|
||||
tls.alpn_protocols = vec![config.alpn];
|
||||
let quic_tls = quinn::crypto::rustls::QuicClientConfig::try_from(tls)
|
||||
.map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?;
|
||||
|
||||
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().expect("valid addr"))
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(quic_tls)));
|
||||
|
||||
let connection = endpoint
|
||||
.connect(config.server_addr, &config.server_name)
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
debug!(remote = %connection.remote_address(), "connected to RPC server");
|
||||
|
||||
Ok(Self {
|
||||
connection,
|
||||
next_request_id: AtomicU32::new(1),
|
||||
})
|
||||
}
|
||||
|
||||
/// Send an RPC request and wait for the response.
|
||||
pub async fn call(
|
||||
&self,
|
||||
method_id: u16,
|
||||
payload: Bytes,
|
||||
) -> Result<Bytes, RpcError> {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let (mut send, mut recv) = self
|
||||
.connection
|
||||
.open_bi()
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
// Send request.
|
||||
let frame = RequestFrame {
|
||||
method_id,
|
||||
request_id,
|
||||
payload,
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
send.write_all(&encoded)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
// Read response.
|
||||
let mut buf = BytesMut::new();
|
||||
while let Some(chunk) = recv
|
||||
.read_chunk(65536, true)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?
|
||||
{
|
||||
buf.extend_from_slice(&chunk.bytes);
|
||||
if buf.len() > crate::framing::MAX_PAYLOAD_SIZE + crate::framing::RESPONSE_HEADER_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: buf.len(),
|
||||
max: crate::framing::MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let response = ResponseFrame::decode(&mut buf)?
|
||||
.ok_or_else(|| RpcError::Decode("incomplete response frame".into()))?;
|
||||
|
||||
if response.request_id != request_id {
|
||||
return Err(RpcError::Decode(format!(
|
||||
"request_id mismatch: sent {request_id}, got {}",
|
||||
response.request_id
|
||||
)));
|
||||
}
|
||||
|
||||
match RpcStatus::from_u8(response.status) {
|
||||
Some(RpcStatus::Ok) => Ok(response.payload),
|
||||
Some(status) => Err(RpcError::Server {
|
||||
status,
|
||||
message: String::from_utf8_lossy(&response.payload).into_owned(),
|
||||
}),
|
||||
None => Err(RpcError::Decode(format!(
|
||||
"unknown status byte: {}",
|
||||
response.status
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to server-push events. Returns a receiver channel.
|
||||
/// Spawns a background task that reads uni-streams.
|
||||
pub fn subscribe_push(&self) -> mpsc::UnboundedReceiver<PushFrame> {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
let conn = self.connection.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match conn.accept_uni().await {
|
||||
Ok(mut recv) => {
|
||||
let mut buf = BytesMut::new();
|
||||
loop {
|
||||
match recv.read_chunk(65536, true).await {
|
||||
Ok(Some(chunk)) => buf.extend_from_slice(&chunk.bytes),
|
||||
Ok(None) => break,
|
||||
Err(e) => {
|
||||
debug!("push stream read error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
match PushFrame::decode(&mut buf) {
|
||||
Ok(Some(frame)) => {
|
||||
if tx.send(frame).is_err() {
|
||||
return; // receiver dropped
|
||||
}
|
||||
}
|
||||
Ok(None) => debug!("incomplete push frame"),
|
||||
Err(e) => debug!("push decode error: {e}"),
|
||||
}
|
||||
}
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
|
||||
Err(e) => {
|
||||
warn!("accept_uni error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
/// Close the connection gracefully.
|
||||
pub fn close(&self) {
|
||||
self.connection.close(0u32.into(), b"bye");
|
||||
}
|
||||
|
||||
/// Get the underlying QUIC connection (for advanced use).
|
||||
pub fn connection(&self) -> &Connection {
|
||||
&self.connection
|
||||
}
|
||||
}
|
||||
68
crates/quicproquo-rpc/src/error.rs
Normal file
68
crates/quicproquo-rpc/src/error.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
//! RPC error types.
|
||||
|
||||
/// Status codes for RPC responses.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum RpcStatus {
|
||||
/// Request succeeded.
|
||||
Ok = 0,
|
||||
/// Client sent a malformed request.
|
||||
BadRequest = 1,
|
||||
/// Authentication required or token invalid.
|
||||
Unauthorized = 2,
|
||||
/// Caller lacks permission for this operation.
|
||||
Forbidden = 3,
|
||||
/// Requested resource not found.
|
||||
NotFound = 4,
|
||||
/// Rate limit exceeded.
|
||||
RateLimited = 5,
|
||||
/// Internal server error.
|
||||
Internal = 10,
|
||||
/// Method not recognized.
|
||||
UnknownMethod = 11,
|
||||
}
|
||||
|
||||
impl RpcStatus {
|
||||
/// Decode a status byte. Returns `None` for unknown values.
|
||||
pub fn from_u8(v: u8) -> Option<Self> {
|
||||
match v {
|
||||
0 => Some(Self::Ok),
|
||||
1 => Some(Self::BadRequest),
|
||||
2 => Some(Self::Unauthorized),
|
||||
3 => Some(Self::Forbidden),
|
||||
4 => Some(Self::NotFound),
|
||||
5 => Some(Self::RateLimited),
|
||||
10 => Some(Self::Internal),
|
||||
11 => Some(Self::UnknownMethod),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can occur in the RPC layer.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RpcError {
|
||||
#[error("connection error: {0}")]
|
||||
Connection(String),
|
||||
|
||||
#[error("encoding error: {0}")]
|
||||
Encode(String),
|
||||
|
||||
#[error("decoding error: {0}")]
|
||||
Decode(String),
|
||||
|
||||
#[error("server returned error status {status:?}: {message}")]
|
||||
Server {
|
||||
status: RpcStatus,
|
||||
message: String,
|
||||
},
|
||||
|
||||
#[error("request timed out")]
|
||||
Timeout,
|
||||
|
||||
#[error("stream closed unexpectedly")]
|
||||
StreamClosed,
|
||||
|
||||
#[error("payload too large: {size} bytes (max {max})")]
|
||||
PayloadTooLarge { size: usize, max: usize },
|
||||
}
|
||||
280
crates/quicproquo-rpc/src/framing.rs
Normal file
280
crates/quicproquo-rpc/src/framing.rs
Normal file
@@ -0,0 +1,280 @@
|
||||
//! Wire format encoding and decoding for the quicproquo v2 RPC protocol.
|
||||
//!
|
||||
//! ## Request frame
|
||||
//! ```text
|
||||
//! [method_id: u16 BE][request_id: u32 BE][payload_len: u32 BE][protobuf bytes]
|
||||
//! ```
|
||||
//!
|
||||
//! ## Response frame
|
||||
//! ```text
|
||||
//! [status: u8][request_id: u32 BE][payload_len: u32 BE][protobuf bytes]
|
||||
//! ```
|
||||
//!
|
||||
//! ## Push frame (server → client, uni-stream)
|
||||
//! ```text
|
||||
//! [event_type: u16 BE][payload_len: u32 BE][protobuf bytes]
|
||||
//! ```
|
||||
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
||||
use crate::error::{RpcError, RpcStatus};
|
||||
|
||||
/// Maximum payload size: 4 MiB.
|
||||
pub const MAX_PAYLOAD_SIZE: usize = 4 * 1024 * 1024;
|
||||
|
||||
/// Request header size: 2 (method) + 4 (req_id) + 4 (len) = 10 bytes.
|
||||
pub const REQUEST_HEADER_SIZE: usize = 10;
|
||||
|
||||
/// Response header size: 1 (status) + 4 (req_id) + 4 (len) = 9 bytes.
|
||||
pub const RESPONSE_HEADER_SIZE: usize = 9;
|
||||
|
||||
/// Push header size: 2 (event_type) + 4 (len) = 6 bytes.
|
||||
pub const PUSH_HEADER_SIZE: usize = 6;
|
||||
|
||||
// ── Request ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/// A decoded RPC request frame.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RequestFrame {
|
||||
pub method_id: u16,
|
||||
pub request_id: u32,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
impl RequestFrame {
|
||||
/// Encode this request into a byte buffer.
|
||||
pub fn encode(&self) -> Bytes {
|
||||
let mut buf = BytesMut::with_capacity(REQUEST_HEADER_SIZE + self.payload.len());
|
||||
buf.put_u16(self.method_id);
|
||||
buf.put_u32(self.request_id);
|
||||
buf.put_u32(self.payload.len() as u32);
|
||||
buf.put(self.payload.clone());
|
||||
buf.freeze()
|
||||
}
|
||||
|
||||
/// Decode a request frame from a byte buffer.
|
||||
/// Returns `None` if the buffer does not contain a complete frame.
|
||||
pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, RpcError> {
|
||||
if buf.len() < REQUEST_HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Peek at payload_len without consuming.
|
||||
let payload_len =
|
||||
u32::from_be_bytes([buf[6], buf[7], buf[8], buf[9]]) as usize;
|
||||
|
||||
if payload_len > MAX_PAYLOAD_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: payload_len,
|
||||
max: MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
|
||||
let total = REQUEST_HEADER_SIZE + payload_len;
|
||||
if buf.len() < total {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let method_id = buf.get_u16();
|
||||
let request_id = buf.get_u32();
|
||||
let _len = buf.get_u32();
|
||||
let payload = buf.split_to(payload_len).freeze();
|
||||
|
||||
Ok(Some(Self {
|
||||
method_id,
|
||||
request_id,
|
||||
payload,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Response ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// A decoded RPC response frame.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResponseFrame {
|
||||
pub status: u8,
|
||||
pub request_id: u32,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
impl ResponseFrame {
|
||||
/// Encode this response into a byte buffer.
|
||||
pub fn encode(&self) -> Bytes {
|
||||
let mut buf = BytesMut::with_capacity(RESPONSE_HEADER_SIZE + self.payload.len());
|
||||
buf.put_u8(self.status);
|
||||
buf.put_u32(self.request_id);
|
||||
buf.put_u32(self.payload.len() as u32);
|
||||
buf.put(self.payload.clone());
|
||||
buf.freeze()
|
||||
}
|
||||
|
||||
/// Decode a response frame from a byte buffer.
|
||||
pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, RpcError> {
|
||||
if buf.len() < RESPONSE_HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let payload_len =
|
||||
u32::from_be_bytes([buf[5], buf[6], buf[7], buf[8]]) as usize;
|
||||
|
||||
if payload_len > MAX_PAYLOAD_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: payload_len,
|
||||
max: MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
|
||||
let total = RESPONSE_HEADER_SIZE + payload_len;
|
||||
if buf.len() < total {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let status = buf.get_u8();
|
||||
let request_id = buf.get_u32();
|
||||
let _len = buf.get_u32();
|
||||
let payload = buf.split_to(payload_len).freeze();
|
||||
|
||||
Ok(Some(Self {
|
||||
status,
|
||||
request_id,
|
||||
payload,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Convert the status byte to an `RpcStatus`.
|
||||
pub fn rpc_status(&self) -> Option<RpcStatus> {
|
||||
RpcStatus::from_u8(self.status)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Push ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// A decoded server-push event frame (sent on QUIC uni-streams).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PushFrame {
|
||||
pub event_type: u16,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
impl PushFrame {
|
||||
/// Encode this push frame into a byte buffer.
|
||||
pub fn encode(&self) -> Bytes {
|
||||
let mut buf = BytesMut::with_capacity(PUSH_HEADER_SIZE + self.payload.len());
|
||||
buf.put_u16(self.event_type);
|
||||
buf.put_u32(self.payload.len() as u32);
|
||||
buf.put(self.payload.clone());
|
||||
buf.freeze()
|
||||
}
|
||||
|
||||
/// Decode a push frame from a byte buffer.
|
||||
pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, RpcError> {
|
||||
if buf.len() < PUSH_HEADER_SIZE {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let payload_len =
|
||||
u32::from_be_bytes([buf[2], buf[3], buf[4], buf[5]]) as usize;
|
||||
|
||||
if payload_len > MAX_PAYLOAD_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: payload_len,
|
||||
max: MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
|
||||
let total = PUSH_HEADER_SIZE + payload_len;
|
||||
if buf.len() < total {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let event_type = buf.get_u16();
|
||||
let _len = buf.get_u32();
|
||||
let payload = buf.split_to(payload_len).freeze();
|
||||
|
||||
Ok(Some(Self {
|
||||
event_type,
|
||||
payload,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn request_roundtrip() {
|
||||
let frame = RequestFrame {
|
||||
method_id: 42,
|
||||
request_id: 1001,
|
||||
payload: Bytes::from_static(b"hello"),
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
let mut buf = BytesMut::from(encoded.as_ref());
|
||||
let decoded = RequestFrame::decode(&mut buf).expect("decode").expect("complete");
|
||||
assert_eq!(decoded.method_id, 42);
|
||||
assert_eq!(decoded.request_id, 1001);
|
||||
assert_eq!(decoded.payload, Bytes::from_static(b"hello"));
|
||||
assert!(buf.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_roundtrip() {
|
||||
let frame = ResponseFrame {
|
||||
status: RpcStatus::Ok as u8,
|
||||
request_id: 2002,
|
||||
payload: Bytes::from_static(b"world"),
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
let mut buf = BytesMut::from(encoded.as_ref());
|
||||
let decoded = ResponseFrame::decode(&mut buf).expect("decode").expect("complete");
|
||||
assert_eq!(decoded.status, 0);
|
||||
assert_eq!(decoded.request_id, 2002);
|
||||
assert_eq!(decoded.payload, Bytes::from_static(b"world"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_roundtrip() {
|
||||
let frame = PushFrame {
|
||||
event_type: 7,
|
||||
payload: Bytes::from_static(b"event-data"),
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
let mut buf = BytesMut::from(encoded.as_ref());
|
||||
let decoded = PushFrame::decode(&mut buf).expect("decode").expect("complete");
|
||||
assert_eq!(decoded.event_type, 7);
|
||||
assert_eq!(decoded.payload, Bytes::from_static(b"event-data"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn incomplete_request_returns_none() {
|
||||
let mut buf = BytesMut::from(&[0u8; 5][..]);
|
||||
assert!(RequestFrame::decode(&mut buf).expect("no error").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn payload_too_large_rejected() {
|
||||
// Craft a request header with payload_len = MAX + 1.
|
||||
let mut buf = BytesMut::new();
|
||||
buf.put_u16(1);
|
||||
buf.put_u32(1);
|
||||
buf.put_u32((MAX_PAYLOAD_SIZE + 1) as u32);
|
||||
let result = RequestFrame::decode(&mut buf);
|
||||
assert!(matches!(result, Err(RpcError::PayloadTooLarge { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_payload_request() {
|
||||
let frame = RequestFrame {
|
||||
method_id: 0,
|
||||
request_id: 0,
|
||||
payload: Bytes::new(),
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
assert_eq!(encoded.len(), REQUEST_HEADER_SIZE);
|
||||
let mut buf = BytesMut::from(encoded.as_ref());
|
||||
let decoded = RequestFrame::decode(&mut buf).expect("decode").expect("complete");
|
||||
assert!(decoded.payload.is_empty());
|
||||
}
|
||||
}
|
||||
13
crates/quicproquo-rpc/src/lib.rs
Normal file
13
crates/quicproquo-rpc/src/lib.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
//! QUIC RPC framework for quicproquo v2.
|
||||
//!
|
||||
//! Wire format per QUIC stream:
|
||||
//! - Request: `[method_id: u16][request_id: u32][payload_len: u32][protobuf bytes]`
|
||||
//! - Response: `[status: u8][request_id: u32][payload_len: u32][protobuf bytes]`
|
||||
//! - Push: `[event_type: u16][payload_len: u32][protobuf bytes]` (uni-stream)
|
||||
|
||||
pub mod framing;
|
||||
pub mod method;
|
||||
pub mod server;
|
||||
pub mod client;
|
||||
pub mod middleware;
|
||||
pub mod error;
|
||||
102
crates/quicproquo-rpc/src/method.rs
Normal file
102
crates/quicproquo-rpc/src/method.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
//! Method registry — maps method IDs to handler functions.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
use crate::error::RpcStatus;
|
||||
|
||||
/// The result of handling an RPC request.
|
||||
pub struct HandlerResult {
|
||||
pub status: RpcStatus,
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
impl HandlerResult {
|
||||
/// Shorthand for a successful response.
|
||||
pub fn ok(payload: Bytes) -> Self {
|
||||
Self {
|
||||
status: RpcStatus::Ok,
|
||||
payload,
|
||||
}
|
||||
}
|
||||
|
||||
/// Shorthand for an error response.
|
||||
pub fn err(status: RpcStatus, message: &str) -> Self {
|
||||
Self {
|
||||
status,
|
||||
payload: Bytes::copy_from_slice(message.as_bytes()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context passed to every RPC handler.
|
||||
pub struct RequestContext {
|
||||
/// The authenticated identity key of the caller, if any.
|
||||
pub identity_key: Option<Vec<u8>>,
|
||||
/// The session token, if provided.
|
||||
pub session_token: Option<Vec<u8>>,
|
||||
/// The raw request payload (protobuf-encoded).
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
/// Type-erased async handler function.
|
||||
pub type HandlerFn<S> = Arc<
|
||||
dyn Fn(Arc<S>, RequestContext) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>;
|
||||
|
||||
/// Registry mapping method IDs to handler functions.
|
||||
pub struct MethodRegistry<S> {
|
||||
handlers: HashMap<u16, (HandlerFn<S>, &'static str)>,
|
||||
}
|
||||
|
||||
impl<S: Send + Sync + 'static> MethodRegistry<S> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handlers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a handler for a method ID.
|
||||
pub fn register<F, Fut>(&mut self, method_id: u16, name: &'static str, handler: F)
|
||||
where
|
||||
F: Fn(Arc<S>, RequestContext) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = HandlerResult> + Send + 'static,
|
||||
{
|
||||
let handler = Arc::new(move |state: Arc<S>, ctx: RequestContext| {
|
||||
Box::pin(handler(state, ctx)) as Pin<Box<dyn Future<Output = HandlerResult> + Send>>
|
||||
});
|
||||
self.handlers.insert(method_id, (handler, name));
|
||||
}
|
||||
|
||||
/// Look up a handler by method ID.
|
||||
pub fn get(&self, method_id: u16) -> Option<&(HandlerFn<S>, &'static str)> {
|
||||
self.handlers.get(&method_id)
|
||||
}
|
||||
|
||||
/// Return the number of registered methods.
|
||||
pub fn len(&self) -> usize {
|
||||
self.handlers.len()
|
||||
}
|
||||
|
||||
/// Whether the registry is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.handlers.is_empty()
|
||||
}
|
||||
|
||||
/// Iterate over all registered (method_id, name) pairs.
|
||||
pub fn methods(&self) -> impl Iterator<Item = (u16, &'static str)> + '_ {
|
||||
self.handlers.iter().map(|(&id, (_, name))| (id, *name))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Send + Sync + 'static> Default for MethodRegistry<S> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
96
crates/quicproquo-rpc/src/middleware.rs
Normal file
96
crates/quicproquo-rpc/src/middleware.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
//! Tower-based middleware layers for the RPC server.
|
||||
//!
|
||||
//! - `AuthLayer`: validates session tokens and attaches identity to context.
|
||||
//! - `RateLimitLayer`: per-IP request rate limiting.
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use dashmap::DashMap;
|
||||
|
||||
// ── Auth middleware ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Validates bearer tokens and resolves identity keys.
|
||||
pub trait SessionValidator: Send + Sync + 'static {
|
||||
/// Validate a session token, returning the identity key if valid.
|
||||
fn validate(&self, token: &[u8]) -> Option<Vec<u8>>;
|
||||
}
|
||||
|
||||
/// Auth context extracted from a validated session.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthContext {
|
||||
/// The Ed25519 identity key of the authenticated caller.
|
||||
pub identity_key: Vec<u8>,
|
||||
}
|
||||
|
||||
// ── Rate limiter ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Simple per-key sliding-window rate limiter.
|
||||
pub struct RateLimiter {
|
||||
/// Max requests per window.
|
||||
max_requests: u32,
|
||||
/// Window duration.
|
||||
window: Duration,
|
||||
/// Map from key → (count, window_start).
|
||||
state: DashMap<Vec<u8>, (u32, Instant)>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
/// Create a new rate limiter.
|
||||
pub fn new(max_requests: u32, window: Duration) -> Self {
|
||||
Self {
|
||||
max_requests,
|
||||
window,
|
||||
state: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a request from `key` is allowed. Returns `true` if allowed.
|
||||
pub fn check(&self, key: &[u8]) -> bool {
|
||||
let now = Instant::now();
|
||||
let mut entry = self.state.entry(key.to_vec()).or_insert((0, now));
|
||||
let (count, window_start) = entry.value_mut();
|
||||
|
||||
if now.duration_since(*window_start) >= self.window {
|
||||
// Reset window.
|
||||
*count = 1;
|
||||
*window_start = now;
|
||||
true
|
||||
} else if *count < self.max_requests {
|
||||
*count += 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove expired entries (call periodically for memory hygiene).
|
||||
pub fn gc(&self) {
|
||||
let now = Instant::now();
|
||||
self.state.retain(|_, (_, start)| now.duration_since(*start) < self.window * 2);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn rate_limiter_allows_within_limit() {
|
||||
let rl = RateLimiter::new(3, Duration::from_secs(60));
|
||||
let key = b"test-key";
|
||||
assert!(rl.check(key));
|
||||
assert!(rl.check(key));
|
||||
assert!(rl.check(key));
|
||||
assert!(!rl.check(key)); // 4th request denied
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_limiter_resets_after_window() {
|
||||
let rl = RateLimiter::new(1, Duration::from_millis(1));
|
||||
let key = b"test-key";
|
||||
assert!(rl.check(key));
|
||||
assert!(!rl.check(key));
|
||||
std::thread::sleep(Duration::from_millis(5));
|
||||
assert!(rl.check(key)); // window expired
|
||||
}
|
||||
}
|
||||
198
crates/quicproquo-rpc/src/server.rs
Normal file
198
crates/quicproquo-rpc/src/server.rs
Normal file
@@ -0,0 +1,198 @@
|
||||
//! QUIC RPC server — accepts connections, dispatches requests to handlers.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use quinn::{Endpoint, Incoming, RecvStream, SendStream};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::error::{RpcError, RpcStatus};
|
||||
use crate::framing::{RequestFrame, ResponseFrame, PushFrame};
|
||||
use crate::method::{HandlerResult, MethodRegistry, RequestContext};
|
||||
|
||||
/// Configuration for the RPC server.
|
||||
pub struct RpcServerConfig {
|
||||
/// QUIC listen address.
|
||||
pub listen_addr: std::net::SocketAddr,
|
||||
/// TLS server config (rustls).
|
||||
pub tls_config: Arc<rustls::ServerConfig>,
|
||||
/// ALPN protocol for the RPC service.
|
||||
pub alpn: Vec<u8>,
|
||||
}
|
||||
|
||||
/// The QUIC RPC server.
|
||||
pub struct RpcServer<S: Send + Sync + 'static> {
|
||||
endpoint: Endpoint,
|
||||
state: Arc<S>,
|
||||
registry: Arc<MethodRegistry<S>>,
|
||||
}
|
||||
|
||||
impl<S: Send + Sync + 'static> RpcServer<S> {
|
||||
/// Create and bind the QUIC endpoint. Does not start accepting yet.
|
||||
pub fn bind(
|
||||
config: RpcServerConfig,
|
||||
state: Arc<S>,
|
||||
registry: MethodRegistry<S>,
|
||||
) -> Result<Self, RpcError> {
|
||||
let mut tls = (*config.tls_config).clone();
|
||||
tls.alpn_protocols = vec![config.alpn];
|
||||
let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(tls)
|
||||
.map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?;
|
||||
let server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_tls));
|
||||
|
||||
let endpoint = Endpoint::server(server_config, config.listen_addr)
|
||||
.map_err(|e| RpcError::Connection(format!("bind {}: {e}", config.listen_addr)))?;
|
||||
|
||||
info!(addr = %config.listen_addr, "RPC server bound");
|
||||
|
||||
Ok(Self {
|
||||
endpoint,
|
||||
state,
|
||||
registry: Arc::new(registry),
|
||||
})
|
||||
}
|
||||
|
||||
/// Accept connections in a loop. Spawns a task per connection.
|
||||
pub async fn serve(self) -> Result<(), RpcError> {
|
||||
info!("RPC server accepting connections");
|
||||
while let Some(incoming) = self.endpoint.accept().await {
|
||||
let state = Arc::clone(&self.state);
|
||||
let registry = Arc::clone(&self.registry);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_connection(incoming, state, registry).await {
|
||||
warn!("connection error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the local address the server is listening on.
|
||||
pub fn local_addr(&self) -> Result<std::net::SocketAddr, RpcError> {
|
||||
self.endpoint
|
||||
.local_addr()
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single QUIC connection: accept bi-directional streams for RPCs.
|
||||
async fn handle_connection<S: Send + Sync + 'static>(
|
||||
incoming: Incoming,
|
||||
state: Arc<S>,
|
||||
registry: Arc<MethodRegistry<S>>,
|
||||
) -> Result<(), RpcError> {
|
||||
let connection = incoming
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
let remote = connection.remote_address();
|
||||
debug!(remote = %remote, "new connection");
|
||||
|
||||
loop {
|
||||
let stream = connection.accept_bi().await;
|
||||
match stream {
|
||||
Ok((send, recv)) => {
|
||||
let state = Arc::clone(&state);
|
||||
let registry = Arc::clone(®istry);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_stream(send, recv, state, registry).await {
|
||||
debug!("stream error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => {
|
||||
debug!(remote = %remote, "connection closed by peer");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(remote = %remote, "accept_bi error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a single bi-directional stream: read request, dispatch, write response.
|
||||
async fn handle_stream<S: Send + Sync + 'static>(
|
||||
mut send: SendStream,
|
||||
mut recv: RecvStream,
|
||||
state: Arc<S>,
|
||||
registry: Arc<MethodRegistry<S>>,
|
||||
) -> Result<(), RpcError> {
|
||||
// Read the complete request from the stream.
|
||||
let mut buf = BytesMut::new();
|
||||
while let Some(chunk) = recv
|
||||
.read_chunk(65536, true)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?
|
||||
{
|
||||
buf.extend_from_slice(&chunk.bytes);
|
||||
if buf.len() > crate::framing::MAX_PAYLOAD_SIZE + crate::framing::REQUEST_HEADER_SIZE {
|
||||
return Err(RpcError::PayloadTooLarge {
|
||||
size: buf.len(),
|
||||
max: crate::framing::MAX_PAYLOAD_SIZE,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let frame = match RequestFrame::decode(&mut buf)? {
|
||||
Some(f) => f,
|
||||
None => return Err(RpcError::Decode("incomplete request frame".into())),
|
||||
};
|
||||
|
||||
let result = match registry.get(frame.method_id) {
|
||||
Some((handler, name)) => {
|
||||
debug!(method_id = frame.method_id, method = name, req_id = frame.request_id, "dispatching");
|
||||
let ctx = RequestContext {
|
||||
identity_key: None, // populated by auth middleware
|
||||
session_token: None,
|
||||
payload: frame.payload,
|
||||
};
|
||||
handler(Arc::clone(&state), ctx).await
|
||||
}
|
||||
None => {
|
||||
warn!(method_id = frame.method_id, "unknown method");
|
||||
HandlerResult::err(RpcStatus::UnknownMethod, "unknown method")
|
||||
}
|
||||
};
|
||||
|
||||
let response = ResponseFrame {
|
||||
status: result.status as u8,
|
||||
request_id: frame.request_id,
|
||||
payload: result.payload,
|
||||
};
|
||||
|
||||
let encoded = response.encode();
|
||||
send.write_all(&encoded)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a push event to a client via a QUIC uni-stream.
|
||||
pub async fn send_push(
|
||||
connection: &quinn::Connection,
|
||||
event_type: u16,
|
||||
payload: bytes::Bytes,
|
||||
) -> Result<(), RpcError> {
|
||||
let mut send = connection
|
||||
.open_uni()
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
let frame = PushFrame {
|
||||
event_type,
|
||||
payload,
|
||||
};
|
||||
let encoded = frame.encode();
|
||||
send.write_all(&encoded)
|
||||
.await
|
||||
.map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
send.finish().map_err(|e| RpcError::Connection(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user