DM channels (createChannel), channel authz, security/docs, future improvements
- Add createChannel RPC (node.capnp @18): create 1:1 channel, returns 16-byte channelId - Store: create_channel(member_a, member_b), get_channel_members(channel_id) - FileBackedStore: channels.bin; SqlStore: migration 003_channels, schema v4 - channel_ops: handle_create_channel (auth + identity, peerKey 32 bytes) - Delivery authz: when channel_id.len() == 16, require caller and recipient are channel members (E022/E023) - Error codes E022 CHANNEL_ACCESS_DENIED, E023 CHANNEL_NOT_FOUND - SUMMARY: link Certificate lifecycle; security audit, future improvements, multi-agent plan docs - Certificate lifecycle doc, SECURITY-AUDIT, FUTURE-IMPROVEMENTS, MULTI-AGENT-WORK-PLAN - Client/core/tls/auth/server main: assorted fixes and updates from review and audit Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
13
crates/quicnprotochat-server/migrations/003_channels.sql
Normal file
13
crates/quicnprotochat-server/migrations/003_channels.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
-- Migration 003: 1:1 DM channels.
|
||||
-- channel_id is 16 bytes (UUID); member_a and member_b are identity keys in sorted order.
|
||||
-- Unique on (member_a, member_b) prevents duplicate channels between the same pair.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS channels (
|
||||
channel_id BLOB PRIMARY KEY,
|
||||
member_a BLOB NOT NULL,
|
||||
member_b BLOB NOT NULL,
|
||||
UNIQUE(member_a, member_b)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_channels_members
|
||||
ON channels(member_a, member_b);
|
||||
@@ -17,6 +17,7 @@ pub const RATE_LIMIT_MAX_ENQUEUES: u32 = 100;
|
||||
pub struct AuthConfig {
|
||||
pub required_token: Option<Vec<u8>>,
|
||||
/// When true, a valid bearer token (no session) is accepted and the request's identity/key is used (dev/e2e only).
|
||||
/// CLI flag: --allow-insecure-auth / QUICNPROTOCHAT_ALLOW_INSECURE_AUTH.
|
||||
pub allow_insecure_identity_from_request: bool,
|
||||
}
|
||||
|
||||
@@ -59,10 +60,13 @@ pub struct AuthContext {
|
||||
}
|
||||
|
||||
pub fn current_timestamp() -> u64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
|
||||
Ok(d) => d.as_secs(),
|
||||
Err(_) => {
|
||||
tracing::warn!("system time is before UNIX_EPOCH; using 0 for session/rate-limit timestamps");
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_rate_limit(
|
||||
@@ -174,7 +178,7 @@ pub fn require_identity<'a>(auth_ctx: &'a AuthContext) -> Result<&'a [u8], capnp
|
||||
|
||||
pub fn require_identity_match(auth_ctx: &AuthContext, expected: &[u8]) -> Result<(), capnp::Error> {
|
||||
let ik = require_identity(auth_ctx)?;
|
||||
if ik != expected {
|
||||
if ik.len() != expected.len() || !bool::from(ik.ct_eq(expected)) {
|
||||
return Err(crate::error_codes::coded_error(
|
||||
E016_IDENTITY_MISMATCH,
|
||||
"access token is bound to a different identity",
|
||||
|
||||
@@ -24,6 +24,8 @@ pub const E018_USER_EXISTS: &str = "E018";
|
||||
pub const E019_NO_PENDING_LOGIN: &str = "E019";
|
||||
pub const E020_BAD_PARAMS: &str = "E020";
|
||||
pub const E021_CIPHERSUITE_NOT_ALLOWED: &str = "E021";
|
||||
pub const E022_CHANNEL_ACCESS_DENIED: &str = "E022";
|
||||
pub const E023_CHANNEL_NOT_FOUND: &str = "E023";
|
||||
|
||||
/// Build a `capnp::Error::failed()` with the structured code prefix.
|
||||
pub fn coded_error(code: &str, msg: impl std::fmt::Display) -> capnp::Error {
|
||||
|
||||
@@ -162,9 +162,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
.parse()
|
||||
.context("--listen must be host:port")?;
|
||||
|
||||
let server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
|
||||
let mut server_config = build_server_config(&effective.tls_cert, &effective.tls_key, production)
|
||||
.context("failed to build TLS/QUIC server config")?;
|
||||
|
||||
// Harden QUIC transport: idle timeout, limit stream concurrency.
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(std::time::Duration::from_secs(300).try_into().unwrap()));
|
||||
transport.max_concurrent_bidi_streams(1u32.into());
|
||||
transport.max_concurrent_uni_streams(0u32.into());
|
||||
server_config.transport_config(Arc::new(transport));
|
||||
|
||||
// Shared storage — persisted to disk for restart safety.
|
||||
let store: Arc<dyn Store> = match effective.store_backend.as_str() {
|
||||
"sql" => {
|
||||
@@ -223,6 +230,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
Arc::clone(&pending_logins),
|
||||
Arc::clone(&rate_limits),
|
||||
Arc::clone(&store),
|
||||
Arc::clone(&waiters),
|
||||
);
|
||||
|
||||
let endpoint = Endpoint::server(server_config, listen)?;
|
||||
|
||||
@@ -19,6 +19,16 @@ fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
/// Parse username from Cap'n Proto reader; requires valid UTF-8.
|
||||
fn parse_username_param(
|
||||
result: Result<capnp::text::Reader<'_>, capnp::Error>,
|
||||
) -> Result<String, capnp::Error> {
|
||||
let reader = result.map_err(|e| coded_error(E020_BAD_PARAMS, e))?;
|
||||
reader
|
||||
.to_string()
|
||||
.map_err(|_| coded_error(E020_BAD_PARAMS, "username must be valid UTF-8"))
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn handle_opaque_login_start(
|
||||
&mut self,
|
||||
@@ -29,9 +39,9 @@ impl NodeServiceImpl {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match p.get_username() {
|
||||
Ok(v) => v.to_string().unwrap_or_default().to_string(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let request_bytes = match p.get_request() {
|
||||
Ok(v) => v.to_vec(),
|
||||
@@ -42,6 +52,14 @@ impl NodeServiceImpl {
|
||||
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
|
||||
}
|
||||
|
||||
// Check for existing recent pending login before expensive OPAQUE/storage work (DoS mitigation).
|
||||
if let Some(existing) = self.pending_logins.get(&username) {
|
||||
let age = current_timestamp().saturating_sub(existing.created_at);
|
||||
if age < 60 {
|
||||
return Promise::err(coded_error(E010_OPAQUE_ERROR, "login already in progress"));
|
||||
}
|
||||
}
|
||||
|
||||
let credential_request = match CredentialRequest::<OpaqueSuite>::deserialize(&request_bytes) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
@@ -62,9 +80,7 @@ impl NodeServiceImpl {
|
||||
))
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E010_OPAQUE_ERROR, "user not registered"))
|
||||
}
|
||||
Ok(None) => None,
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
|
||||
@@ -111,9 +127,9 @@ impl NodeServiceImpl {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match p.get_username() {
|
||||
Ok(v) => v.to_string().unwrap_or_default().to_string(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let request_bytes = match p.get_request() {
|
||||
Ok(v) => v.to_vec(),
|
||||
@@ -171,9 +187,9 @@ impl NodeServiceImpl {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match p.get_username() {
|
||||
Ok(v) => v.to_string().unwrap_or_default().to_string(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let finalization_bytes = match p.get_finalization() {
|
||||
Ok(v) => v.to_vec(),
|
||||
@@ -278,9 +294,9 @@ impl NodeServiceImpl {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let username = match p.get_username() {
|
||||
Ok(v) => v.to_string().unwrap_or_default().to_string(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
let username = match parse_username_param(p.get_username()) {
|
||||
Ok(s) => s,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let upload_bytes = match p.get_upload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
@@ -326,12 +342,18 @@ impl NodeServiceImpl {
|
||||
let password_file = ServerRegistration::<OpaqueSuite>::finish(upload);
|
||||
let record_bytes = password_file.serialize().to_vec();
|
||||
|
||||
if let Err(e) = self
|
||||
match self
|
||||
.store
|
||||
.store_user_record(&username, record_bytes)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
return Promise::err(e);
|
||||
Ok(()) => {}
|
||||
Err(crate::storage::StorageError::DuplicateUser(_)) => {
|
||||
return Promise::err(coded_error(
|
||||
E018_USER_EXISTS,
|
||||
format!("user '{}' already registered", username),
|
||||
))
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
}
|
||||
|
||||
if !identity_key.is_empty() {
|
||||
|
||||
62
crates/quicnprotochat-server/src/node_service/channel_ops.rs
Normal file
62
crates/quicnprotochat-server/src/node_service/channel_ops.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
//! createChannel RPC: create or look up a 1:1 DM channel.
|
||||
|
||||
use capnp::capability::Promise;
|
||||
use quicnprotochat_proto::node_capnp::node_service;
|
||||
|
||||
use crate::auth::{coded_error, require_identity, validate_auth_context};
|
||||
use crate::error_codes::*;
|
||||
use crate::storage::StorageError;
|
||||
|
||||
use super::NodeServiceImpl;
|
||||
|
||||
fn storage_err(err: StorageError) -> capnp::Error {
|
||||
coded_error(E009_STORAGE_ERROR, err)
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
pub fn handle_create_channel(
|
||||
&mut self,
|
||||
params: node_service::CreateChannelParams,
|
||||
mut results: node_service::CreateChannelResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let peer_key = match p.get_peer_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
let identity = match require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if peer_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("peerKey must be exactly 32 bytes, got {}", peer_key.len()),
|
||||
));
|
||||
}
|
||||
|
||||
if identity == peer_key {
|
||||
return Promise::err(coded_error(
|
||||
E020_BAD_PARAMS,
|
||||
"peerKey must not equal caller identity",
|
||||
));
|
||||
}
|
||||
|
||||
let channel_id = match self.store.create_channel(&identity, &peer_key) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
|
||||
results.get().set_channel_id(&channel_id);
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
@@ -77,10 +77,10 @@ impl NodeServiceImpl {
|
||||
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
|
||||
));
|
||||
}
|
||||
if version != CURRENT_WIRE_VERSION {
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -102,6 +102,31 @@ impl NodeServiceImpl {
|
||||
}
|
||||
}
|
||||
|
||||
// DM channel authz: channel_id.len() == 16 means a created channel; caller and recipient must be the two members.
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
let recipient_other = (recipient_key == *a && caller == b.as_slice())
|
||||
|| (recipient_key == *b && caller == a.as_slice());
|
||||
if !caller_in || !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller or recipient not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
match self.store.queue_depth(&recipient_key, &channel_id) {
|
||||
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
|
||||
return Promise::err(coded_error(
|
||||
@@ -183,10 +208,10 @@ impl NodeServiceImpl {
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version != CURRENT_WIRE_VERSION {
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -203,6 +228,30 @@ impl NodeServiceImpl {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|
||||
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
|
||||
if !caller_in || !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller or recipient not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let messages = if limit > 0 {
|
||||
match self
|
||||
.store
|
||||
@@ -269,10 +318,10 @@ impl NodeServiceImpl {
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version != CURRENT_WIRE_VERSION {
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("unsupported wire version {} (expected {CURRENT_WIRE_VERSION})", version),
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -284,6 +333,30 @@ impl NodeServiceImpl {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
if channel_id.len() == 16 {
|
||||
let members = match self.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => {
|
||||
return Promise::err(coded_error(E023_CHANNEL_NOT_FOUND, "channel not found"));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
};
|
||||
let caller = match crate::auth::require_identity(&auth_ctx) {
|
||||
Ok(id) => id,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = caller == a.as_slice() || caller == b.as_slice();
|
||||
let recipient_other = (recipient_key.as_slice() == a.as_slice() && caller == b.as_slice())
|
||||
|| (recipient_key.as_slice() == b.as_slice() && caller == a.as_slice());
|
||||
if !caller_in || !recipient_other {
|
||||
return Promise::err(coded_error(
|
||||
E022_CHANNEL_ACCESS_DENIED,
|
||||
"caller or recipient not a member of this channel",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let store = Arc::clone(&self.store);
|
||||
let waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>> = self.waiters.clone();
|
||||
|
||||
@@ -315,4 +388,232 @@ impl NodeServiceImpl {
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle_peek(
|
||||
&mut self,
|
||||
params: node_service::PeekParams,
|
||||
mut results: node_service::PeekResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let limit = p.get_limit();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&recipient_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let messages = match self
|
||||
.store
|
||||
.peek(&recipient_key, &channel_id, limit as usize)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(m) => m,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
||||
count = messages.len(),
|
||||
"audit: peek"
|
||||
);
|
||||
|
||||
let mut list = results.get().init_payloads(messages.len() as u32);
|
||||
for (i, (seq, data)) in messages.iter().enumerate() {
|
||||
let mut entry = list.reborrow().get(i as u32);
|
||||
entry.set_seq(*seq);
|
||||
entry.set_data(data);
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_ack(
|
||||
&mut self,
|
||||
params: node_service::AckParams,
|
||||
_results: node_service::AckResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let seq_up_to = p.get_seq_up_to();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey must be exactly 32 bytes, got {}", recipient_key.len()),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = require_identity_or_request(
|
||||
&auth_ctx,
|
||||
&recipient_key,
|
||||
self.auth_cfg.allow_insecure_identity_from_request,
|
||||
) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
match self
|
||||
.store
|
||||
.ack(&recipient_key, &channel_id, seq_up_to)
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(removed) => {
|
||||
tracing::info!(
|
||||
recipient_prefix = %fmt_hex(&recipient_key[..4]),
|
||||
seq_up_to = seq_up_to,
|
||||
removed = removed,
|
||||
"audit: ack"
|
||||
);
|
||||
}
|
||||
Err(e) => return Promise::err(e),
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_batch_enqueue(
|
||||
&mut self,
|
||||
params: node_service::BatchEnqueueParams,
|
||||
mut results: node_service::BatchEnqueueResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let recipient_keys = match p.get_recipient_keys() {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let payload = match p.get_payload() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
let version = p.get_version();
|
||||
let auth_ctx = match validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
if payload.is_empty() {
|
||||
return Promise::err(coded_error(E005_PAYLOAD_EMPTY, "payload must not be empty"));
|
||||
}
|
||||
if payload.len() > MAX_PAYLOAD_BYTES {
|
||||
return Promise::err(coded_error(
|
||||
E006_PAYLOAD_TOO_LARGE,
|
||||
format!("payload exceeds max size ({} bytes)", MAX_PAYLOAD_BYTES),
|
||||
));
|
||||
}
|
||||
if version > CURRENT_WIRE_VERSION {
|
||||
return Promise::err(coded_error(
|
||||
E012_WIRE_VERSION,
|
||||
format!("wire version {} not supported (max {CURRENT_WIRE_VERSION})", version),
|
||||
));
|
||||
}
|
||||
|
||||
if let Err(e) = check_rate_limit(&self.rate_limits, &auth_ctx.token) {
|
||||
tracing::warn!("rate_limit_hit");
|
||||
metrics::record_rate_limit_hit_total();
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let mut seqs = Vec::with_capacity(recipient_keys.len() as usize);
|
||||
for i in 0..recipient_keys.len() {
|
||||
let rk = match recipient_keys.get(i) {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
if rk.len() != 32 {
|
||||
return Promise::err(coded_error(
|
||||
E004_IDENTITY_KEY_LENGTH,
|
||||
format!("recipientKey[{}] must be exactly 32 bytes, got {}", i, rk.len()),
|
||||
));
|
||||
}
|
||||
|
||||
match self.store.queue_depth(&rk, &channel_id) {
|
||||
Ok(depth) if depth >= MAX_QUEUE_DEPTH => {
|
||||
return Promise::err(coded_error(
|
||||
E015_QUEUE_FULL,
|
||||
format!("queue depth {} exceeds limit {}", depth, MAX_QUEUE_DEPTH),
|
||||
));
|
||||
}
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let seq = match self
|
||||
.store
|
||||
.enqueue(&rk, &channel_id, payload.clone())
|
||||
.map_err(storage_err)
|
||||
{
|
||||
Ok(seq) => seq,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
seqs.push(seq);
|
||||
|
||||
metrics::record_enqueue_total();
|
||||
metrics::record_enqueue_bytes(payload.len() as u64);
|
||||
|
||||
crate::auth::waiter(&self.waiters, &rk).notify_waiters();
|
||||
}
|
||||
|
||||
let mut list = results.get().init_seqs(seqs.len() as u32);
|
||||
for (i, seq) in seqs.iter().enumerate() {
|
||||
list.set(i as u32, *seq);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
recipient_count = recipient_keys.len(),
|
||||
payload_len = payload.len(),
|
||||
"audit: batch_enqueue"
|
||||
);
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,4 +256,47 @@ impl NodeServiceImpl {
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
|
||||
pub fn handle_fetch_hybrid_keys(
|
||||
&mut self,
|
||||
params: node_service::FetchHybridKeysParams,
|
||||
mut results: node_service::FetchHybridKeysResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let identity_keys = match p.get_identity_keys() {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
|
||||
if let Err(e) = validate_auth_context(&self.auth_cfg, &self.sessions, p.get_auth()) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let count = identity_keys.len() as usize;
|
||||
let mut key_data: Vec<Vec<u8>> = Vec::with_capacity(count);
|
||||
for i in 0..identity_keys.len() {
|
||||
let ik = match identity_keys.get(i) {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
let pk = match self.store.fetch_hybrid_key(&ik).map_err(storage_err) {
|
||||
Ok(Some(pk)) => pk,
|
||||
Ok(None) => vec![],
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
key_data.push(pk);
|
||||
}
|
||||
|
||||
let mut list = results.get().init_keys(key_data.len() as u32);
|
||||
for (i, pk) in key_data.iter().enumerate() {
|
||||
list.set(i as u32, pk);
|
||||
}
|
||||
|
||||
tracing::debug!(count = count, "batch hybrid key fetch");
|
||||
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,11 @@ use crate::auth::{
|
||||
};
|
||||
use crate::storage::Store;
|
||||
|
||||
/// Cap'n Proto traversal limit (words). 4 Mi words = 32 MiB; bounds DoS from deeply nested or large messages.
|
||||
const CAPNP_TRAVERSAL_LIMIT_WORDS: usize = 4 * 1024 * 1024;
|
||||
|
||||
mod auth_ops;
|
||||
mod channel_ops;
|
||||
mod delivery;
|
||||
mod key_ops;
|
||||
mod p2p_ops;
|
||||
@@ -132,6 +136,46 @@ impl node_service::Server for NodeServiceImpl {
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_resolve_endpoint(params, results)
|
||||
}
|
||||
|
||||
fn peek(
|
||||
&mut self,
|
||||
params: node_service::PeekParams,
|
||||
results: node_service::PeekResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_peek(params, results)
|
||||
}
|
||||
|
||||
fn ack(
|
||||
&mut self,
|
||||
params: node_service::AckParams,
|
||||
results: node_service::AckResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_ack(params, results)
|
||||
}
|
||||
|
||||
fn fetch_hybrid_keys(
|
||||
&mut self,
|
||||
params: node_service::FetchHybridKeysParams,
|
||||
results: node_service::FetchHybridKeysResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_fetch_hybrid_keys(params, results)
|
||||
}
|
||||
|
||||
fn batch_enqueue(
|
||||
&mut self,
|
||||
params: node_service::BatchEnqueueParams,
|
||||
results: node_service::BatchEnqueueResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_batch_enqueue(params, results)
|
||||
}
|
||||
|
||||
fn create_channel(
|
||||
&mut self,
|
||||
params: node_service::CreateChannelParams,
|
||||
results: node_service::CreateChannelResults,
|
||||
) -> capnp::capability::Promise<(), capnp::Error> {
|
||||
self.handle_create_channel(params, results)
|
||||
}
|
||||
}
|
||||
|
||||
pub const CURRENT_WIRE_VERSION: u16 = 1;
|
||||
@@ -193,11 +237,13 @@ pub async fn handle_node_connection(
|
||||
.map_err(|e| anyhow::anyhow!("failed to accept bi stream: {e}"))?;
|
||||
let (reader, writer) = (recv.compat(), send.compat_write());
|
||||
|
||||
let mut reader_opts = capnp::message::ReaderOptions::new();
|
||||
reader_opts.traversal_limit_in_words(Some(CAPNP_TRAVERSAL_LIMIT_WORDS));
|
||||
let network = capnp_rpc::twoparty::VatNetwork::new(
|
||||
reader,
|
||||
writer,
|
||||
capnp_rpc::rpc_twoparty_capnp::Side::Server,
|
||||
Default::default(),
|
||||
reader_opts,
|
||||
);
|
||||
|
||||
let service: node_service::Client = capnp_rpc::new_client(NodeServiceImpl::new(
|
||||
@@ -223,6 +269,7 @@ pub fn spawn_cleanup_task(
|
||||
pending_logins: Arc<DashMap<String, PendingLogin>>,
|
||||
rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
|
||||
store: Arc<dyn Store>,
|
||||
waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60));
|
||||
@@ -234,6 +281,29 @@ pub fn spawn_cleanup_task(
|
||||
pending_logins.retain(|_, pl| now - pl.created_at < PENDING_LOGIN_TTL_SECS);
|
||||
rate_limits.retain(|_, entry| now - entry.window_start < RATE_LIMIT_WINDOW_SECS * 2);
|
||||
|
||||
// Bound map sizes to prevent unbounded growth from malicious clients.
|
||||
const MAX_SESSIONS: usize = 100_000;
|
||||
const MAX_WAITERS: usize = 100_000;
|
||||
if sessions.len() > MAX_SESSIONS {
|
||||
let overflow = sessions.len() - MAX_SESSIONS;
|
||||
let mut entries: Vec<_> = sessions
|
||||
.iter()
|
||||
.map(|e| (e.key().clone(), e.expires_at))
|
||||
.collect();
|
||||
entries.sort_by_key(|(_, exp)| *exp);
|
||||
for (key, _) in entries.into_iter().take(overflow) {
|
||||
sessions.remove(&key);
|
||||
}
|
||||
}
|
||||
if waiters.len() > MAX_WAITERS {
|
||||
let overflow = waiters.len() - MAX_WAITERS;
|
||||
let keys: Vec<_> =
|
||||
waiters.iter().take(overflow).map(|e| e.key().clone()).collect();
|
||||
for key in keys {
|
||||
waiters.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
match store.gc_expired_messages(MESSAGE_TTL_SECS) {
|
||||
Ok(n) if n > 0 => {
|
||||
tracing::debug!(expired = n, "garbage collected expired messages")
|
||||
|
||||
@@ -14,6 +14,7 @@ fn storage_err(err: StorageError) -> capnp::Error {
|
||||
}
|
||||
|
||||
impl NodeServiceImpl {
|
||||
/// Health check: unauthenticated by design for liveness probes and load balancers.
|
||||
pub fn handle_health(
|
||||
&mut self,
|
||||
_params: node_service::HealthParams,
|
||||
|
||||
@@ -3,17 +3,19 @@
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use rand::RngCore;
|
||||
use rusqlite::{params, Connection};
|
||||
|
||||
use crate::storage::{StorageError, Store};
|
||||
|
||||
/// Schema version after introducing the migration runner (existing DBs had 1).
|
||||
const SCHEMA_VERSION: i32 = 3;
|
||||
const SCHEMA_VERSION: i32 = 4;
|
||||
|
||||
/// Migrations: (migration_number, SQL). Files named NNN_name.sql, applied in order when N > user_version.
|
||||
const MIGRATIONS: &[(i32, &str)] = &[
|
||||
(1, include_str!("../migrations/001_initial.sql")),
|
||||
(3, include_str!("../migrations/002_add_seq.sql")),
|
||||
(4, include_str!("../migrations/003_channels.sql")),
|
||||
];
|
||||
|
||||
/// Runs pending migrations on an open connection: applies any migration whose number is greater
|
||||
@@ -305,10 +307,17 @@ impl Store for SqlStore {
|
||||
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO users (username, opaque_record) VALUES (?1, ?2)",
|
||||
"INSERT INTO users (username, opaque_record) VALUES (?1, ?2)",
|
||||
params![username, record],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
.map_err(|e| {
|
||||
if let rusqlite::Error::SqliteFailure(ref err, _) = &e {
|
||||
if err.code == rusqlite::ErrorCode::ConstraintViolation {
|
||||
return StorageError::DuplicateUser(username.to_string());
|
||||
}
|
||||
}
|
||||
StorageError::Db(e.to_string())
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -360,6 +369,57 @@ impl Store for SqlStore {
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn peek(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
|
||||
let sql = if limit == 0 {
|
||||
"SELECT seq, payload FROM deliveries
|
||||
WHERE recipient_key = ?1 AND channel_id = ?2
|
||||
ORDER BY seq ASC".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"SELECT seq, payload FROM deliveries
|
||||
WHERE recipient_key = ?1 AND channel_id = ?2
|
||||
ORDER BY seq ASC
|
||||
LIMIT {}",
|
||||
limit
|
||||
)
|
||||
};
|
||||
|
||||
let mut stmt = conn.prepare(&sql).map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
let rows: Vec<(i64, Vec<u8>)> = stmt
|
||||
.query_map(params![recipient_key, channel_id], |row| {
|
||||
Ok((row.get(0)?, row.get(1)?))
|
||||
})
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
|
||||
Ok(rows.into_iter().map(|(seq, payload)| (seq as u64, payload)).collect())
|
||||
}
|
||||
|
||||
fn ack(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> Result<usize, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
let deleted = conn
|
||||
.execute(
|
||||
"DELETE FROM deliveries WHERE recipient_key = ?1 AND channel_id = ?2 AND seq <= ?3",
|
||||
params![recipient_key, channel_id, seq_up_to as i64],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(deleted)
|
||||
}
|
||||
|
||||
fn publish_endpoint(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
@@ -384,6 +444,45 @@ impl Store for SqlStore {
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
|
||||
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
|
||||
let (a, b) = if member_a < member_b {
|
||||
(member_a.to_vec(), member_b.to_vec())
|
||||
} else {
|
||||
(member_b.to_vec(), member_a.to_vec())
|
||||
};
|
||||
let conn = self.lock_conn()?;
|
||||
let existing: Option<Vec<u8>> = conn
|
||||
.query_row(
|
||||
"SELECT channel_id FROM channels WHERE member_a = ?1 AND member_b = ?2",
|
||||
params![a, b],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
if let Some(id) = existing {
|
||||
return Ok(id);
|
||||
}
|
||||
let mut channel_id = [0u8; 16];
|
||||
rand::thread_rng().fill_bytes(&mut channel_id);
|
||||
conn.execute(
|
||||
"INSERT INTO channels (channel_id, member_a, member_b) VALUES (?1, ?2, ?3)",
|
||||
params![channel_id.as_slice(), a, b],
|
||||
)
|
||||
.map_err(|e| StorageError::Db(e.to_string()))?;
|
||||
Ok(channel_id.to_vec())
|
||||
}
|
||||
|
||||
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
|
||||
let conn = self.lock_conn()?;
|
||||
conn.query_row(
|
||||
"SELECT member_a, member_b FROM channels WHERE channel_id = ?1",
|
||||
params![channel_id],
|
||||
|row| Ok((row.get::<_, Vec<u8>>(0)?, row.get::<_, Vec<u8>>(1)?)),
|
||||
)
|
||||
.optional()
|
||||
.map_err(|e| StorageError::Db(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience extension for `rusqlite::OptionalExtension`.
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::{
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@@ -16,6 +17,9 @@ pub enum StorageError {
|
||||
Serde,
|
||||
#[error("database error: {0}")]
|
||||
Db(String),
|
||||
/// Unique constraint violation (e.g. user already exists).
|
||||
#[error("duplicate user: {0}")]
|
||||
DuplicateUser(String),
|
||||
}
|
||||
|
||||
fn lock<T>(m: &Mutex<T>) -> Result<std::sync::MutexGuard<'_, T>, StorageError> {
|
||||
@@ -96,12 +100,36 @@ pub trait Store: Send + Sync {
|
||||
/// Retrieve identity key for a user (Fix 2).
|
||||
fn get_user_identity_key(&self, username: &str) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Peek at queued messages without removing them (non-destructive).
|
||||
/// Returns `(seq, payload)` pairs ordered by seq.
|
||||
fn peek(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError>;
|
||||
|
||||
/// Acknowledge (remove) all messages with seq <= seq_up_to.
|
||||
fn ack(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> Result<usize, StorageError>;
|
||||
|
||||
/// Publish a P2P endpoint address for an identity key.
|
||||
fn publish_endpoint(&self, identity_key: &[u8], node_addr: Vec<u8>)
|
||||
-> Result<(), StorageError>;
|
||||
|
||||
/// Resolve a peer's P2P endpoint address.
|
||||
fn resolve_endpoint(&self, identity_key: &[u8]) -> Result<Option<Vec<u8>>, StorageError>;
|
||||
|
||||
/// Create a 1:1 channel between two members. Returns 16-byte channel_id (UUID).
|
||||
/// Members are stored in sorted order for deterministic lookup.
|
||||
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError>;
|
||||
|
||||
/// Get the two members of a channel by channel_id (16 bytes). Returns (member_a, member_b) in sorted order.
|
||||
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError>;
|
||||
}
|
||||
|
||||
// ── ChannelKey ───────────────────────────────────────────────────────────────
|
||||
@@ -154,8 +182,10 @@ pub struct FileBackedStore {
|
||||
setup_path: PathBuf,
|
||||
users_path: PathBuf,
|
||||
identity_keys_path: PathBuf,
|
||||
channels_path: PathBuf,
|
||||
key_packages: Mutex<HashMap<Vec<u8>, VecDeque<Vec<u8>>>>,
|
||||
deliveries: Mutex<QueueMapV3>,
|
||||
channels: Mutex<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>>,
|
||||
hybrid_keys: Mutex<HashMap<Vec<u8>, Vec<u8>>>,
|
||||
users: Mutex<HashMap<String, Vec<u8>>>,
|
||||
identity_keys: Mutex<HashMap<String, Vec<u8>>>,
|
||||
@@ -174,12 +204,14 @@ impl FileBackedStore {
|
||||
let setup_path = dir.join("server_setup.bin");
|
||||
let users_path = dir.join("users.bin");
|
||||
let identity_keys_path = dir.join("identity_keys.bin");
|
||||
let channels_path = dir.join("channels.bin");
|
||||
|
||||
let key_packages = Mutex::new(Self::load_kp_map(&kp_path)?);
|
||||
let deliveries = Mutex::new(Self::load_delivery_map_v3(&ds_path)?);
|
||||
let hybrid_keys = Mutex::new(Self::load_hybrid_keys(&hk_path)?);
|
||||
let users = Mutex::new(Self::load_users(&users_path)?);
|
||||
let identity_keys = Mutex::new(Self::load_map_string_bytes(&identity_keys_path)?);
|
||||
let channels = Mutex::new(Self::load_channels(&channels_path)?);
|
||||
|
||||
Ok(Self {
|
||||
kp_path,
|
||||
@@ -188,8 +220,10 @@ impl FileBackedStore {
|
||||
setup_path,
|
||||
users_path,
|
||||
identity_keys_path,
|
||||
channels_path,
|
||||
key_packages,
|
||||
deliveries,
|
||||
channels,
|
||||
hybrid_keys,
|
||||
users,
|
||||
identity_keys,
|
||||
@@ -197,6 +231,31 @@ impl FileBackedStore {
|
||||
})
|
||||
}
|
||||
|
||||
fn load_channels(
|
||||
path: &Path,
|
||||
) -> Result<HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
let bytes = fs::read(path).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
if bytes.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
bincode::deserialize(&bytes).map_err(|_| StorageError::Serde)
|
||||
}
|
||||
|
||||
fn flush_channels(
|
||||
&self,
|
||||
path: &Path,
|
||||
map: &HashMap<Vec<u8>, (Vec<u8>, Vec<u8>)>,
|
||||
) -> Result<(), StorageError> {
|
||||
let bytes = bincode::serialize(map).map_err(|_| StorageError::Serde)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(path, bytes).map_err(|e| StorageError::Io(e.to_string()))
|
||||
}
|
||||
|
||||
fn load_kp_map(path: &Path) -> Result<HashMap<Vec<u8>, VecDeque<Vec<u8>>>, StorageError> {
|
||||
if !path.exists() {
|
||||
return Ok(HashMap::new());
|
||||
@@ -346,8 +405,9 @@ impl Store for FileBackedStore {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let seq = *inner.next_seq.entry(key.clone()).or_insert(0);
|
||||
*inner.next_seq.get_mut(&key).unwrap() = seq + 1;
|
||||
let entry = inner.next_seq.entry(key.clone()).or_insert(0);
|
||||
let seq = *entry;
|
||||
*entry = seq + 1;
|
||||
inner.map.entry(key).or_default().push_back(SeqEntry { seq, data: payload });
|
||||
self.flush_delivery_map(&self.ds_path, &*inner)?;
|
||||
Ok(seq)
|
||||
@@ -428,7 +488,13 @@ impl Store for FileBackedStore {
|
||||
if let Some(parent) = self.setup_path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))
|
||||
fs::write(&self.setup_path, setup).map_err(|e| StorageError::Io(e.to_string()))?;
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(&self.setup_path, std::fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_server_setup(&self) -> Result<Option<Vec<u8>>, StorageError> {
|
||||
@@ -444,7 +510,14 @@ impl Store for FileBackedStore {
|
||||
|
||||
fn store_user_record(&self, username: &str, record: Vec<u8>) -> Result<(), StorageError> {
|
||||
let mut map = lock(&self.users)?;
|
||||
map.insert(username.to_string(), record);
|
||||
match map.entry(username.to_string()) {
|
||||
std::collections::hash_map::Entry::Occupied(_) => {
|
||||
return Err(StorageError::DuplicateUser(username.to_string()))
|
||||
}
|
||||
std::collections::hash_map::Entry::Vacant(v) => {
|
||||
v.insert(record);
|
||||
}
|
||||
}
|
||||
self.flush_users(&self.users_path, &*map)
|
||||
}
|
||||
|
||||
@@ -473,6 +546,54 @@ impl Store for FileBackedStore {
|
||||
Ok(map.get(username).cloned())
|
||||
}
|
||||
|
||||
fn peek(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, StorageError> {
|
||||
let inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let messages: Vec<(u64, Vec<u8>)> = inner
|
||||
.map
|
||||
.get(&key)
|
||||
.map(|q| {
|
||||
let count = if limit == 0 { q.len() } else { limit.min(q.len()) };
|
||||
q.iter()
|
||||
.take(count)
|
||||
.map(|e| (e.seq, e.data.clone()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
// Non-destructive: do NOT flush.
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
fn ack(
|
||||
&self,
|
||||
recipient_key: &[u8],
|
||||
channel_id: &[u8],
|
||||
seq_up_to: u64,
|
||||
) -> Result<usize, StorageError> {
|
||||
let mut inner = lock(&self.deliveries)?;
|
||||
let key = ChannelKey {
|
||||
channel_id: channel_id.to_vec(),
|
||||
recipient_key: recipient_key.to_vec(),
|
||||
};
|
||||
let removed = if let Some(q) = inner.map.get_mut(&key) {
|
||||
let before = q.len();
|
||||
q.retain(|e| e.seq > seq_up_to);
|
||||
before - q.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
self.flush_delivery_map(&self.ds_path, &*inner)?;
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
fn publish_endpoint(
|
||||
&self,
|
||||
identity_key: &[u8],
|
||||
@@ -487,4 +608,150 @@ impl Store for FileBackedStore {
|
||||
let map = lock(&self.endpoints)?;
|
||||
Ok(map.get(identity_key).cloned())
|
||||
}
|
||||
|
||||
fn create_channel(&self, member_a: &[u8], member_b: &[u8]) -> Result<Vec<u8>, StorageError> {
|
||||
let (a, b) = if member_a < member_b {
|
||||
(member_a.to_vec(), member_b.to_vec())
|
||||
} else {
|
||||
(member_b.to_vec(), member_a.to_vec())
|
||||
};
|
||||
let mut map = lock(&self.channels)?;
|
||||
if let Some((channel_id, _)) = map.iter().find(|(_, (ma, mb))| ma == &a && mb == &b) {
|
||||
return Ok(channel_id.clone());
|
||||
}
|
||||
let mut channel_id = [0u8; 16];
|
||||
rand::thread_rng().fill_bytes(&mut channel_id);
|
||||
let channel_id = channel_id.to_vec();
|
||||
map.insert(channel_id.clone(), (a, b));
|
||||
self.flush_channels(&self.channels_path, &*map)?;
|
||||
Ok(channel_id)
|
||||
}
|
||||
|
||||
fn get_channel_members(&self, channel_id: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>, StorageError> {
|
||||
let map = lock(&self.channels)?;
|
||||
Ok(map.get(channel_id).cloned())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn temp_store() -> (TempDir, FileBackedStore) {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let store = FileBackedStore::open(dir.path()).unwrap();
|
||||
(dir, store)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_package_upload_fetch() {
|
||||
let (_dir, store) = temp_store();
|
||||
let ik = vec![1u8; 32];
|
||||
store.upload_key_package(&ik, vec![10, 20, 30]).unwrap();
|
||||
let pkg = store.fetch_key_package(&ik).unwrap();
|
||||
assert_eq!(pkg, Some(vec![10, 20, 30]));
|
||||
// Second fetch should return None (consumed)
|
||||
let pkg2 = store.fetch_key_package(&ik).unwrap();
|
||||
assert_eq!(pkg2, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enqueue_fetch_with_seq() {
|
||||
let (_dir, store) = temp_store();
|
||||
let rk = vec![2u8; 32];
|
||||
let ch = vec![];
|
||||
let seq0 = store.enqueue(&rk, &ch, vec![1]).unwrap();
|
||||
let seq1 = store.enqueue(&rk, &ch, vec![2]).unwrap();
|
||||
assert_eq!(seq0, 0);
|
||||
assert_eq!(seq1, 1);
|
||||
let msgs = store.fetch(&rk, &ch).unwrap();
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0], (0, vec![1]));
|
||||
assert_eq!(msgs[1], (1, vec![2]));
|
||||
// After fetch, queue should be empty
|
||||
let msgs2 = store.fetch(&rk, &ch).unwrap();
|
||||
assert!(msgs2.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fetch_limited_respects_limit() {
|
||||
let (_dir, store) = temp_store();
|
||||
let rk = vec![3u8; 32];
|
||||
let ch = vec![];
|
||||
for i in 0..5 {
|
||||
store.enqueue(&rk, &ch, vec![i]).unwrap();
|
||||
}
|
||||
let msgs = store.fetch_limited(&rk, &ch, 2).unwrap();
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0].1, vec![0]);
|
||||
assert_eq!(msgs[1].1, vec![1]);
|
||||
// Remaining 3 should still be there
|
||||
let depth = store.queue_depth(&rk, &ch).unwrap();
|
||||
assert_eq!(depth, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn queue_depth_tracking() {
|
||||
let (_dir, store) = temp_store();
|
||||
let rk = vec![4u8; 32];
|
||||
let ch = vec![];
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
|
||||
store.enqueue(&rk, &ch, vec![1]).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 1);
|
||||
store.enqueue(&rk, &ch, vec![2]).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 2);
|
||||
store.fetch(&rk, &ch).unwrap();
|
||||
assert_eq!(store.queue_depth(&rk, &ch).unwrap(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hybrid_key_upload_fetch() {
|
||||
let (_dir, store) = temp_store();
|
||||
let ik = vec![5u8; 32];
|
||||
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), None);
|
||||
store.upload_hybrid_key(&ik, vec![99; 100]).unwrap();
|
||||
assert_eq!(store.fetch_hybrid_key(&ik).unwrap(), Some(vec![99; 100]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_record_crud() {
|
||||
let (_dir, store) = temp_store();
|
||||
assert!(!store.has_user_record("alice").unwrap());
|
||||
store.store_user_record("alice", vec![1, 2, 3]).unwrap();
|
||||
assert!(store.has_user_record("alice").unwrap());
|
||||
assert_eq!(store.get_user_record("alice").unwrap(), Some(vec![1, 2, 3]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_identity_key_crud() {
|
||||
let (_dir, store) = temp_store();
|
||||
assert_eq!(store.get_user_identity_key("bob").unwrap(), None);
|
||||
store.store_user_identity_key("bob", vec![7u8; 32]).unwrap();
|
||||
assert_eq!(store.get_user_identity_key("bob").unwrap(), Some(vec![7u8; 32]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn endpoint_publish_resolve() {
|
||||
let (_dir, store) = temp_store();
|
||||
let ik = vec![8u8; 32];
|
||||
assert_eq!(store.resolve_endpoint(&ik).unwrap(), None);
|
||||
store.publish_endpoint(&ik, vec![10, 20]).unwrap();
|
||||
assert_eq!(store.resolve_endpoint(&ik).unwrap(), Some(vec![10, 20]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_channel_and_members() {
|
||||
let (_dir, store) = temp_store();
|
||||
let a = vec![1u8; 32];
|
||||
let b = vec![2u8; 32];
|
||||
assert_eq!(store.get_channel_members(&[0u8; 16]).unwrap(), None);
|
||||
let id1 = store.create_channel(&a, &b).unwrap();
|
||||
assert_eq!(id1.len(), 16);
|
||||
let members = store.get_channel_members(&id1).unwrap().unwrap();
|
||||
assert_eq!(members.0, a);
|
||||
assert_eq!(members.1, b);
|
||||
let id2 = store.create_channel(&b, &a).unwrap();
|
||||
assert_eq!(id1, id2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,6 +61,12 @@ fn generate_self_signed_cert(cert_path: &PathBuf, key_path: &PathBuf) -> anyhow:
|
||||
|
||||
std::fs::write(cert_path, issued.cert.der()).context("write cert")?;
|
||||
std::fs::write(key_path, &key_der).context("write key")?;
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let perms = std::fs::Permissions::from_mode(0o600);
|
||||
std::fs::set_permissions(key_path, perms).context("set key permissions")?;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
cert = %cert_path.display(),
|
||||
|
||||
Reference in New Issue
Block a user