Rename project to quicnprotochat
This commit is contained in:
42
crates/quicnprotochat-server/Cargo.toml
Normal file
42
crates/quicnprotochat-server/Cargo.toml
Normal file
@@ -0,0 +1,42 @@
|
||||
[package]
|
||||
name = "quicnprotochat-server"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
description = "Delivery Service and Authentication Service for quicnprotochat."
|
||||
license = "MIT"
|
||||
|
||||
[[bin]]
|
||||
name = "quicnprotochat-server"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
quicnprotochat-core = { path = "../quicnprotochat-core" }
|
||||
quicnprotochat-proto = { path = "../quicnprotochat-proto" }
|
||||
|
||||
# Serialisation + RPC
|
||||
capnp = { workspace = true }
|
||||
capnp-rpc = { workspace = true }
|
||||
|
||||
# Async
|
||||
tokio = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
# Server utilities
|
||||
dashmap = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
quinn = { workspace = true }
|
||||
quinn-proto = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
rcgen = { workspace = true }
|
||||
|
||||
# Error handling
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
bincode = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
||||
# CLI
|
||||
clap = { workspace = true }
|
||||
508
crates/quicnprotochat-server/src/main.rs
Normal file
508
crates/quicnprotochat-server/src/main.rs
Normal file
@@ -0,0 +1,508 @@
|
||||
//! quicnprotochat-server — unified Authentication + Delivery service.
|
||||
//!
|
||||
//! # M3 scope
|
||||
//!
|
||||
//! The server exposes a single QUIC + TLS 1.3 Cap'n Proto RPC endpoint
|
||||
//! (`NodeService`) combining Authentication and Delivery operations.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! QUIC endpoint (7000)
|
||||
//! └─ TLS 1.3 handshake (self-signed by default)
|
||||
//! └─ capnp-rpc VatNetwork (LocalSet, !Send)
|
||||
//! └─ NodeServiceImpl (KeyPackage + Delivery queues)
|
||||
//! ```
|
||||
//!
|
||||
//! Because `capnp-rpc` uses `Rc<RefCell<>>` internally it is `!Send`.
|
||||
//! The entire RPC stack lives on a `tokio::task::LocalSet` spawned per
|
||||
//! connection.
|
||||
//!
|
||||
//! # Configuration
|
||||
//!
|
||||
//! | Env var | CLI flag | Default |
|
||||
//! |---------------------|----------------|-----------------|
|
||||
//! | `QUICNPROTOCHAT_LISTEN` | `--listen` | `0.0.0.0:7000` |
|
||||
//! | `RUST_LOG` | — | `info` |
|
||||
|
||||
use std::{fs, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use capnp::capability::Promise;
|
||||
use capnp_rpc::{rpc_twoparty_capnp::Side, twoparty, RpcSystem};
|
||||
use clap::Parser;
|
||||
use dashmap::DashMap;
|
||||
use quicnprotochat_proto::node_capnp::node_service;
|
||||
use quinn::{Endpoint, ServerConfig};
|
||||
use quinn_proto::crypto::rustls::QuicServerConfig;
|
||||
use rcgen::generate_simple_self_signed;
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
|
||||
|
||||
mod storage;
|
||||
use storage::{FileBackedStore, StorageError};
|
||||
|
||||
// ── CLI ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(
|
||||
name = "quicnprotochat-server",
|
||||
about = "quicnprotochat Delivery Service + Authentication Service",
|
||||
version
|
||||
)]
|
||||
struct Args {
|
||||
/// QUIC listen address (host:port).
|
||||
#[arg(long, default_value = "0.0.0.0:7000", env = "QUICNPROTOCHAT_LISTEN")]
|
||||
listen: String,
|
||||
|
||||
/// Directory for persisted server data (KeyPackages + delivery queues).
|
||||
#[arg(long, default_value = "data", env = "QUICNPROTOCHAT_DATA_DIR")]
|
||||
data_dir: String,
|
||||
|
||||
/// TLS certificate path (generated automatically if missing).
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "data/server-cert.der",
|
||||
env = "QUICNPROTOCHAT_TLS_CERT"
|
||||
)]
|
||||
tls_cert: PathBuf,
|
||||
|
||||
/// TLS private key path (generated automatically if missing).
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "data/server-key.der",
|
||||
env = "QUICNPROTOCHAT_TLS_KEY"
|
||||
)]
|
||||
tls_key: PathBuf,
|
||||
}
|
||||
|
||||
// ── Node service implementation ─────────────────────────────────────────────
|
||||
|
||||
/// Cap'n Proto RPC server implementation for `NodeService` (Auth + Delivery).
|
||||
struct NodeServiceImpl {
|
||||
store: Arc<FileBackedStore>,
|
||||
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
fn waiter(&self, recipient_key: &[u8]) -> Arc<Notify> {
|
||||
self.waiters
|
||||
.entry(recipient_key.to_vec())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl node_service::Server for NodeServiceImpl {
|
||||
/// Upload a single-use KeyPackage and return its SHA-256 fingerprint.
|
||||
fn upload_key_package(
|
||||
&mut self,
|
||||
params: node_service::UploadKeyPackageParams,
|
||||
mut results: node_service::UploadKeyPackageResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let params = params
|
||||
.get()
|
||||
.map_err(|e| capnp::Error::failed(format!("upload_key_package: bad params: {e}")));
|
||||
|
||||
let (identity_key, package) = match params {
|
||||
Ok(p) => {
|
||||
let ik = match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
};
|
||||
let pkg = match p.get_package() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
};
|
||||
(ik, pkg)
|
||||
}
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if identity_key.len() != 32 {
|
||||
return Promise::err(capnp::Error::failed(format!(
|
||||
"identityKey must be exactly 32 bytes, got {}",
|
||||
identity_key.len()
|
||||
)));
|
||||
}
|
||||
if package.is_empty() {
|
||||
return Promise::err(capnp::Error::failed(
|
||||
"package must not be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let fingerprint: Vec<u8> = Sha256::digest(&package).to_vec();
|
||||
if let Err(e) = self
|
||||
.store
|
||||
.upload_key_package(&identity_key, package)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
results.get().set_fingerprint(&fingerprint);
|
||||
|
||||
tracing::debug!(
|
||||
fingerprint = %fmt_hex(&fingerprint[..4]),
|
||||
"KeyPackage uploaded"
|
||||
);
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
/// Atomically remove and return one KeyPackage for the given identity key.
|
||||
fn fetch_key_package(
|
||||
&mut self,
|
||||
params: node_service::FetchKeyPackageParams,
|
||||
mut results: node_service::FetchKeyPackageResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let identity_key = match params.get() {
|
||||
Ok(p) => match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
},
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
};
|
||||
|
||||
if identity_key.len() != 32 {
|
||||
return Promise::err(capnp::Error::failed(format!(
|
||||
"identityKey must be exactly 32 bytes, got {}",
|
||||
identity_key.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let package = match self
|
||||
.store
|
||||
.fetch_key_package(&identity_key)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
match package {
|
||||
Some(pkg) => {
|
||||
tracing::debug!(
|
||||
identity = %fmt_hex(&identity_key[..4]),
|
||||
"KeyPackage fetched"
|
||||
);
|
||||
results.get().set_package(&pkg);
|
||||
}
|
||||
None => {
|
||||
tracing::debug!(
|
||||
identity = %fmt_hex(&identity_key[..4]),
|
||||
"no KeyPackage available for identity"
|
||||
);
|
||||
results.get().set_package(&[]);
|
||||
}
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
/// Append `payload` to the queue for `recipient_key`.
|
||||
fn enqueue(
|
||||
&mut self,
|
||||
params: node_service::EnqueueParams,
|
||||
_results: node_service::EnqueueResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
};
|
||||
let payload = match p.get_payload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(capnp::Error::failed(format!(
|
||||
"recipientKey must be exactly 32 bytes, got {}",
|
||||
recipient_key.len()
|
||||
)));
|
||||
}
|
||||
if payload.is_empty() {
|
||||
return Promise::err(capnp::Error::failed(
|
||||
"payload must not be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = self
|
||||
.store
|
||||
.enqueue(&recipient_key, payload)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
self.waiter(&recipient_key).notify_waiters();
|
||||
|
||||
tracing::debug!(
|
||||
recipient = %fmt_hex(&recipient_key[..4]),
|
||||
"message enqueued"
|
||||
);
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
/// Atomically drain and return all queued payloads for `recipient_key`.
|
||||
fn fetch(
|
||||
&mut self,
|
||||
params: node_service::FetchParams,
|
||||
mut results: node_service::FetchResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let recipient_key = match params.get() {
|
||||
Ok(p) => match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
},
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(capnp::Error::failed(format!(
|
||||
"recipientKey must be exactly 32 bytes, got {}",
|
||||
recipient_key.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let messages = match self.store.fetch(&recipient_key).map_err(storage_err) {
|
||||
Ok(m) => m,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
recipient = %fmt_hex(&recipient_key[..4]),
|
||||
count = messages.len(),
|
||||
"messages fetched"
|
||||
);
|
||||
|
||||
let mut list = results.get().init_payloads(messages.len() as u32);
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
list.set(i as u32, msg);
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
/// Long-polling fetch with timeout (ms).
|
||||
fn fetch_wait(
|
||||
&mut self,
|
||||
params: node_service::FetchWaitParams,
|
||||
mut results: node_service::FetchWaitResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("{e}"))),
|
||||
};
|
||||
let timeout_ms = p.get_timeout_ms();
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(capnp::Error::failed(format!(
|
||||
"recipientKey must be exactly 32 bytes, got {}",
|
||||
recipient_key.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let store = Arc::clone(&self.store);
|
||||
let waiters = self.waiters.clone();
|
||||
|
||||
Promise::from_future(async move {
|
||||
let messages = store.fetch(&recipient_key).map_err(storage_err)?;
|
||||
|
||||
if messages.is_empty() && timeout_ms > 0 {
|
||||
let waiter = waiters
|
||||
.entry(recipient_key.clone())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone();
|
||||
let _ = timeout(Duration::from_millis(timeout_ms), waiter.notified()).await;
|
||||
let msgs = store.fetch(&recipient_key).map_err(storage_err)?;
|
||||
fill_payloads_wait(&mut results, msgs);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fill_payloads_wait(&mut results, messages);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn health(
|
||||
&mut self,
|
||||
_params: node_service::HealthParams,
|
||||
mut results: node_service::HealthResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
results.get().set_status("ok");
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn fill_payloads_wait(results: &mut node_service::FetchWaitResults, messages: Vec<Vec<u8>>) {
|
||||
let mut list = results.get().init_payloads(messages.len() as u32);
|
||||
for (i, msg) in messages.iter().enumerate() {
|
||||
list.set(i as u32, msg);
|
||||
}
|
||||
}
|
||||
|
||||
fn storage_err(err: StorageError) -> capnp::Error {
|
||||
capnp::Error::failed(format!("{err}"))
|
||||
}
|
||||
|
||||
// ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
|
||||
)
|
||||
.init();
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let listen: SocketAddr = args.listen.parse().context("--listen must be host:port")?;
|
||||
|
||||
let server_config = build_server_config(&args.tls_cert, &args.tls_key)
|
||||
.context("failed to build TLS/QUIC server config")?;
|
||||
|
||||
// Shared storage — persisted to disk for restart safety.
|
||||
let store = Arc::new(FileBackedStore::open(&args.data_dir)?);
|
||||
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = Arc::new(DashMap::new());
|
||||
|
||||
let endpoint = Endpoint::server(server_config, listen)?;
|
||||
|
||||
tracing::info!(
|
||||
addr = %args.listen,
|
||||
"accepting QUIC connections"
|
||||
);
|
||||
|
||||
// capnp-rpc is !Send (Rc internals), so all RPC tasks must stay on a
|
||||
// LocalSet. Both accept loops share one LocalSet.
|
||||
let local = tokio::task::LocalSet::new();
|
||||
local
|
||||
.run_until(async move {
|
||||
loop {
|
||||
let incoming = match endpoint.accept().await {
|
||||
Some(i) => i,
|
||||
None => break,
|
||||
};
|
||||
|
||||
let connecting = match incoming.accept() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to accept incoming connection");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let store = Arc::clone(&store);
|
||||
let waiters = Arc::clone(&waiters);
|
||||
tokio::task::spawn_local(async move {
|
||||
if let Err(e) = handle_node_connection(connecting, store, waiters).await {
|
||||
tracing::warn!(error = %e, "connection error");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok::<(), anyhow::Error>(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
// ── Per-connection handlers ───────────────────────────────────────────────────
|
||||
|
||||
/// Handle one NodeService connection.
|
||||
async fn handle_node_connection(
|
||||
connecting: quinn::Connecting,
|
||||
store: Arc<FileBackedStore>,
|
||||
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let connection = connecting.await?;
|
||||
|
||||
tracing::info!(peer = %connection.remote_address(), "QUIC connected");
|
||||
|
||||
let (send, recv) = connection
|
||||
.accept_bi()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("failed to accept bi stream: {e}"))?;
|
||||
let (reader, writer) = (recv.compat(), send.compat_write());
|
||||
|
||||
let network = twoparty::VatNetwork::new(reader, writer, Side::Server, Default::default());
|
||||
|
||||
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl { store, waiters });
|
||||
|
||||
RpcSystem::new(Box::new(network), Some(service.client))
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("NodeService RPC error: {e}"))
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Format the first `n` bytes of a slice as lowercase hex with a trailing `…`.
|
||||
fn fmt_hex(bytes: &[u8]) -> String {
|
||||
let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
|
||||
format!("{hex}…")
|
||||
}
|
||||
|
||||
/// Ensure a self-signed certificate exists on disk and return a QUIC server config.
|
||||
fn build_server_config(cert_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<ServerConfig> {
|
||||
if !cert_path.exists() || !key_path.exists() {
|
||||
generate_self_signed(cert_path, key_path)?;
|
||||
}
|
||||
|
||||
let cert_bytes = fs::read(cert_path).context("read cert")?;
|
||||
let key_bytes = fs::read(key_path).context("read key")?;
|
||||
|
||||
let cert_chain = vec![CertificateDer::from(cert_bytes)];
|
||||
let key = PrivateKeyDer::try_from(key_bytes).map_err(|_| anyhow::anyhow!("invalid key"))?;
|
||||
|
||||
let mut tls = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?;
|
||||
tls.alpn_protocols = vec![b"capnp".to_vec()];
|
||||
|
||||
let crypto = QuicServerConfig::try_from(tls)
|
||||
.map_err(|e| anyhow::anyhow!("invalid server TLS config: {e}"))?;
|
||||
|
||||
Ok(ServerConfig::with_crypto(Arc::new(crypto)))
|
||||
}
|
||||
|
||||
fn generate_self_signed(cert_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<()> {
|
||||
if let Some(parent) = cert_path.parent() {
|
||||
fs::create_dir_all(parent).context("create cert dir")?;
|
||||
}
|
||||
if let Some(parent) = key_path.parent() {
|
||||
fs::create_dir_all(parent).context("create key dir")?;
|
||||
}
|
||||
|
||||
let subject_alt_names = vec![
|
||||
"localhost".to_string(),
|
||||
"127.0.0.1".to_string(),
|
||||
"::1".to_string(),
|
||||
];
|
||||
|
||||
let issued = generate_simple_self_signed(subject_alt_names)?;
|
||||
let key_der = issued.key_pair.serialize_der();
|
||||
|
||||
fs::write(cert_path, issued.cert.der()).context("write cert")?;
|
||||
fs::write(key_path, &key_der).context("write key")?;
|
||||
|
||||
tracing::info!(
|
||||
cert = %cert_path.display(),
|
||||
key = %key_path.display(),
|
||||
"generated self-signed TLS certificate"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
114
crates/quicnprotochat-server/src/storage.rs
Normal file
114
crates/quicnprotochat-server/src/storage.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum StorageError {
|
||||
#[error("io error: {0}")]
|
||||
Io(String),
|
||||
#[error("serialization error")]
|
||||
Serde,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
struct QueueMap {
|
||||
map: HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
|
||||
}
|
||||
|
||||
/// File-backed storage for KeyPackages and delivery queues.
|
||||
///
|
||||
/// Each mutation flushes the entire map to disk. Suitable for MVP-scale loads.
|
||||
pub struct FileBackedStore {
|
||||
kp_path: PathBuf,
|
||||
ds_path: PathBuf,
|
||||
key_packages: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
|
||||
deliveries: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
|
||||
}
|
||||
|
||||
impl FileBackedStore {
|
||||
pub fn open(dir: impl AsRef<Path>) -> Result<Self, StorageError> {
|
||||
let dir = dir.as_ref();
|
||||
if !dir.exists() {
|
||||
fs::create_dir_all(dir).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
let kp_path = dir.join("keypackages.bin");
|
||||
let ds_path = dir.join("deliveries.bin");
|
||||
|
||||
let key_packages = Mutex::new(Self::load_map(&kp_path)?);
|
||||
let deliveries = Mutex::new(Self::load_map(&ds_path)?);
|
||||
|
||||
Ok(Self {
|
||||
kp_path,
|
||||
ds_path,
|
||||
key_packages,
|
||||
deliveries,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn upload_key_package(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
package: Vec<u8>,
|
||||
) -> Result<(), StorageError> {
|
||||
let mut map = self.key_packages.lock().unwrap();
|
||||
map.entry(identity_key.to_vec())
|
||||
.or_default()
|
||||
.push_back(package);
|
||||
self.flush_map(&self.kp_path, &*map)
|
||||
}
|
||||
|
||||
pub fn fetch_key_package(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
let mut map = self.key_packages.lock().unwrap();
|
||||
let package = map.get_mut(identity_key).and_then(|q| q.pop_front());
|
||||
self.flush_map(&self.kp_path, &*map)?;
|
||||
Ok(package)
|
||||
}
|
||||
|
||||
pub fn enqueue(&self, recipient_key: &[u8], payload: Vec<u8>) -> Result<(), StorageError> {
|
||||
let mut map = self.deliveries.lock().unwrap();
|
||||
map.entry(recipient_key.to_vec())
|
||||
.or_default()
|
||||
.push_back(payload);
|
||||
self.flush_map(&self.ds_path, &*map)
|
||||
}
|
||||
|
||||
pub fn fetch(&self, recipient_key: &[u8]) -> Result<Vec<Vec<u8>>, StorageError> {
|
||||
let mut map = self.deliveries.lock().unwrap();
|
||||
let messages = map
|
||||
.get_mut(recipient_key)
|
||||
.map(|q| q.drain(..).collect())
|
||||
.unwrap_or_default();
|
||||
self.flush_map(&self.ds_path, &*map)?;
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
fn load_map(path: &Path) -> Result<HashMap<Vec<u8>, VecDeque<Vec<u8>>>, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
if bytes.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let map: QueueMap = bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)?;
|
||||
Ok(map.map)
|
||||
}
|
||||
|
||||
fn flush_map(
|
||||
&self,
|
||||
path: &Path,
|
||||
map: &HashMap<Vec<u8>, VecDeque<Vec<u8>>>,
|
||||
) -> Result<(), StorageError> {
|
||||
let payload = QueueMap { map: map.clone() };
|
||||
let bytes = bincode::serialize(&payload).map_err(|_| StorageError::Serde)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user