fix: security hardening — 40 findings from full codebase review
Full codebase review by 4 independent agents (security, architecture,
code quality, correctness) identified ~80 findings. This commit fixes 40
of them across all workspace crates.
Critical fixes:
- Federation service: validate origin against mTLS cert CN/SAN (C1)
- WS bridge: add DM channel auth, size limits, rate limiting (C2)
- hpke_seal: panic on error instead of silent empty ciphertext (C3)
- hpke_setup_sender_and_export: error on parse fail, no PQ downgrade (C7)
Security fixes:
- Zeroize: seed_bytes() returns Zeroizing<[u8;32]>, private_to_bytes()
returns Zeroizing<Vec<u8>>, ClientAuth.access_token, SessionState.password,
conversation hex_key all wrapped in Zeroizing
- Keystore: 0o600 file permissions on Unix
- MeshIdentity: 0o600 file permissions on Unix
- Timing floors: resolveIdentity + WS bridge resolve_user get 5ms floor
- Mobile: TLS verification gated behind insecure-dev feature flag
- Proto: from_bytes default limit tightened from 64 MiB to 8 MiB
Correctness fixes:
- fetch_wait: register waiter before fetch to close TOCTOU window
- MeshEnvelope: exclude hop_count from signature (forwarding no longer
invalidates sender signature)
- BroadcastChannel: encrypt returns Result instead of panicking
- transcript: rename verify_transcript_chain → validate_transcript_structure
- group.rs: extract shared process_incoming() for receive_message variants
- auth_ops: remove spurious RegistrationRequest deserialization
- MeshStore.seen: bounded to 100K with FIFO eviction
Quality fixes:
- FFI error classification: typed downcast instead of string matching
- Plugin HookVTable: SAFETY documentation for unsafe Send+Sync
- clippy::unwrap_used: warn → deny workspace-wide
- Various .unwrap_or("") → proper error returns
Review report: docs/REVIEW-2026-03-04.md
152 tests passing (72 core + 35 server + 14 E2E + 1 doctest + 30 P2P)
This commit is contained in:
@@ -71,6 +71,9 @@ quicproquo-p2p = { path = "../quicproquo-p2p", optional = true }
|
||||
ratatui = { version = "0.29", optional = true, default-features = false, features = ["crossterm"] }
|
||||
crossterm = { version = "0.28", optional = true }
|
||||
|
||||
# YAML playbook parsing (only compiled with --features playbook).
|
||||
serde_yaml = { version = "0.9", optional = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -80,6 +83,9 @@ workspace = true
|
||||
mesh = ["dep:mdns-sd", "dep:quicproquo-p2p"]
|
||||
# Enable full-screen Ratatui TUI: cargo build -p quicproquo-client --features tui
|
||||
tui = ["dep:ratatui", "dep:crossterm"]
|
||||
# Enable playbook (scripted command execution): YAML parser + serde derives.
|
||||
# Build: cargo build -p quicproquo-client --features playbook
|
||||
playbook = ["dep:serde_yaml"]
|
||||
|
||||
[dev-dependencies]
|
||||
dashmap = { workspace = true }
|
||||
|
||||
508
crates/quicproquo-client/src/client/command_engine.rs
Normal file
508
crates/quicproquo-client/src/client/command_engine.rs
Normal file
@@ -0,0 +1,508 @@
|
||||
//! Command engine: typed command enum, registry, and execution bridge.
|
||||
//!
|
||||
//! Maps every REPL slash command and lifecycle operation into a single `Command`
|
||||
//! enum with typed parameters. `CommandRegistry` parses raw input and delegates
|
||||
//! execution to the existing `cmd_*` handlers in `repl.rs`.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use quicproquo_proto::node_capnp::node_service;
|
||||
|
||||
use super::repl::{Input, SlashCommand, parse_input};
|
||||
use super::session::SessionState;
|
||||
|
||||
// ── Comparison operator for assert conditions ────────────────────────────────
|
||||
|
||||
/// Comparison operator used in playbook assertions.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "playbook", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum CmpOp {
|
||||
Eq,
|
||||
Ne,
|
||||
Gt,
|
||||
Lt,
|
||||
Gte,
|
||||
Lte,
|
||||
}
|
||||
|
||||
impl CmpOp {
|
||||
/// Evaluate this comparison: `lhs <op> rhs`.
|
||||
pub fn eval(&self, lhs: usize, rhs: usize) -> bool {
|
||||
match self {
|
||||
CmpOp::Eq => lhs == rhs,
|
||||
CmpOp::Ne => lhs != rhs,
|
||||
CmpOp::Gt => lhs > rhs,
|
||||
CmpOp::Lt => lhs < rhs,
|
||||
CmpOp::Gte => lhs >= rhs,
|
||||
CmpOp::Lte => lhs <= rhs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Assert conditions for playbook testing ───────────────────────────────────
|
||||
|
||||
/// Conditions that can be asserted in a playbook step.
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "playbook", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum AssertCondition {
|
||||
Connected,
|
||||
LoggedIn,
|
||||
InConversation { name: String },
|
||||
MessageCount { op: CmpOp, count: usize },
|
||||
MemberCount { op: CmpOp, count: usize },
|
||||
Custom { expression: String },
|
||||
}
|
||||
|
||||
// ── Command enum ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Every operation the client can perform, with typed parameters.
|
||||
///
|
||||
/// This is a superset of `SlashCommand` — it adds lifecycle operations
|
||||
/// (`Connect`, `Login`, `Register`, `SendMessage`, `Wait`, `Assert`, `SetVar`)
|
||||
/// that are needed for non-interactive / playbook execution.
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "playbook", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum Command {
|
||||
// ── Lifecycle (not in SlashCommand) ──────────────────────────────────
|
||||
Connect {
|
||||
server: String,
|
||||
ca_cert: Option<String>,
|
||||
insecure: bool,
|
||||
},
|
||||
Login {
|
||||
username: String,
|
||||
password: String,
|
||||
},
|
||||
Register {
|
||||
username: String,
|
||||
password: String,
|
||||
},
|
||||
SendMessage {
|
||||
text: String,
|
||||
},
|
||||
Wait {
|
||||
duration_ms: u64,
|
||||
},
|
||||
Assert {
|
||||
condition: AssertCondition,
|
||||
},
|
||||
SetVar {
|
||||
name: String,
|
||||
value: String,
|
||||
},
|
||||
|
||||
// ── SlashCommand mirror ─────────────────────────────────────────────
|
||||
Help,
|
||||
Quit,
|
||||
Whoami,
|
||||
List,
|
||||
Switch { target: String },
|
||||
Dm { username: String },
|
||||
CreateGroup { name: String },
|
||||
Invite { target: String },
|
||||
Remove { target: String },
|
||||
Leave,
|
||||
Join,
|
||||
Members,
|
||||
GroupInfo,
|
||||
Rename { name: String },
|
||||
History { count: usize },
|
||||
|
||||
// Mesh
|
||||
MeshPeers,
|
||||
MeshServer { addr: String },
|
||||
MeshSend { peer_id: String, message: String },
|
||||
MeshBroadcast { topic: String, message: String },
|
||||
MeshSubscribe { topic: String },
|
||||
MeshRoute,
|
||||
MeshIdentity,
|
||||
MeshStore,
|
||||
|
||||
// Security / crypto
|
||||
Verify { username: String },
|
||||
UpdateKey,
|
||||
Typing,
|
||||
TypingNotify { enabled: bool },
|
||||
React { emoji: String, index: Option<usize> },
|
||||
Edit { index: usize, new_text: String },
|
||||
Delete { index: usize },
|
||||
SendFile { path: String },
|
||||
Download { index: usize },
|
||||
DeleteAccount,
|
||||
Disappear { arg: Option<String> },
|
||||
Privacy { arg: Option<String> },
|
||||
VerifyFs,
|
||||
RotateAllKeys,
|
||||
Devices,
|
||||
RegisterDevice { name: String },
|
||||
RevokeDevice { id_prefix: String },
|
||||
}
|
||||
|
||||
impl Command {
|
||||
/// Convert a `Command` to a `SlashCommand` when possible.
|
||||
///
|
||||
/// Returns `None` for lifecycle commands that have no `SlashCommand`
|
||||
/// equivalent (`Connect`, `Login`, `Register`, `SendMessage`, `Wait`,
|
||||
/// `Assert`, `SetVar`).
|
||||
pub(crate) fn to_slash(&self) -> Option<SlashCommand> {
|
||||
match self.clone() {
|
||||
// Lifecycle — no SlashCommand equivalent
|
||||
Command::Connect { .. }
|
||||
| Command::Login { .. }
|
||||
| Command::Register { .. }
|
||||
| Command::SendMessage { .. }
|
||||
| Command::Wait { .. }
|
||||
| Command::Assert { .. }
|
||||
| Command::SetVar { .. } => None,
|
||||
|
||||
// 1:1 mirror
|
||||
Command::Help => Some(SlashCommand::Help),
|
||||
Command::Quit => Some(SlashCommand::Quit),
|
||||
Command::Whoami => Some(SlashCommand::Whoami),
|
||||
Command::List => Some(SlashCommand::List),
|
||||
Command::Switch { target } => Some(SlashCommand::Switch { target }),
|
||||
Command::Dm { username } => Some(SlashCommand::Dm { username }),
|
||||
Command::CreateGroup { name } => Some(SlashCommand::CreateGroup { name }),
|
||||
Command::Invite { target } => Some(SlashCommand::Invite { target }),
|
||||
Command::Remove { target } => Some(SlashCommand::Remove { target }),
|
||||
Command::Leave => Some(SlashCommand::Leave),
|
||||
Command::Join => Some(SlashCommand::Join),
|
||||
Command::Members => Some(SlashCommand::Members),
|
||||
Command::GroupInfo => Some(SlashCommand::GroupInfo),
|
||||
Command::Rename { name } => Some(SlashCommand::Rename { name }),
|
||||
Command::History { count } => Some(SlashCommand::History { count }),
|
||||
Command::MeshPeers => Some(SlashCommand::MeshPeers),
|
||||
Command::MeshServer { addr } => Some(SlashCommand::MeshServer { addr }),
|
||||
Command::MeshSend { peer_id, message } => {
|
||||
Some(SlashCommand::MeshSend { peer_id, message })
|
||||
}
|
||||
Command::MeshBroadcast { topic, message } => {
|
||||
Some(SlashCommand::MeshBroadcast { topic, message })
|
||||
}
|
||||
Command::MeshSubscribe { topic } => Some(SlashCommand::MeshSubscribe { topic }),
|
||||
Command::MeshRoute => Some(SlashCommand::MeshRoute),
|
||||
Command::MeshIdentity => Some(SlashCommand::MeshIdentity),
|
||||
Command::MeshStore => Some(SlashCommand::MeshStore),
|
||||
Command::Verify { username } => Some(SlashCommand::Verify { username }),
|
||||
Command::UpdateKey => Some(SlashCommand::UpdateKey),
|
||||
Command::Typing => Some(SlashCommand::Typing),
|
||||
Command::TypingNotify { enabled } => Some(SlashCommand::TypingNotify { enabled }),
|
||||
Command::React { emoji, index } => Some(SlashCommand::React { emoji, index }),
|
||||
Command::Edit { index, new_text } => Some(SlashCommand::Edit { index, new_text }),
|
||||
Command::Delete { index } => Some(SlashCommand::Delete { index }),
|
||||
Command::SendFile { path } => Some(SlashCommand::SendFile { path }),
|
||||
Command::Download { index } => Some(SlashCommand::Download { index }),
|
||||
Command::DeleteAccount => Some(SlashCommand::DeleteAccount),
|
||||
Command::Disappear { arg } => Some(SlashCommand::Disappear { arg }),
|
||||
Command::Privacy { arg } => Some(SlashCommand::Privacy { arg }),
|
||||
Command::VerifyFs => Some(SlashCommand::VerifyFs),
|
||||
Command::RotateAllKeys => Some(SlashCommand::RotateAllKeys),
|
||||
Command::Devices => Some(SlashCommand::Devices),
|
||||
Command::RegisterDevice { name } => Some(SlashCommand::RegisterDevice { name }),
|
||||
Command::RevokeDevice { id_prefix } => {
|
||||
Some(SlashCommand::RevokeDevice { id_prefix })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── CommandResult ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Outcome of executing a single `Command`.
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "playbook", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct CommandResult {
|
||||
pub success: bool,
|
||||
pub output: Option<String>,
|
||||
pub error: Option<String>,
|
||||
/// Structured key-value outputs for variable capture in playbooks.
|
||||
pub data: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl CommandResult {
|
||||
fn ok() -> Self {
|
||||
Self {
|
||||
success: true,
|
||||
output: None,
|
||||
error: None,
|
||||
data: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn err(msg: String) -> Self {
|
||||
Self {
|
||||
success: false,
|
||||
output: None,
|
||||
error: Some(msg),
|
||||
data: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── CommandRegistry ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Parses raw input into `Command` and delegates execution to the existing
|
||||
/// REPL handlers.
|
||||
pub struct CommandRegistry;
|
||||
|
||||
impl CommandRegistry {
|
||||
/// Parse a raw input line into a `Command`.
|
||||
///
|
||||
/// Returns `None` for empty input. Returns `Some(Command::SendMessage)`
|
||||
/// for plain chat text. Slash commands are parsed via the existing
|
||||
/// `parse_input` function.
|
||||
pub fn parse(line: &str) -> Option<Command> {
|
||||
match parse_input(line) {
|
||||
Input::Empty => None,
|
||||
Input::ChatMessage(text) => Some(Command::SendMessage { text }),
|
||||
Input::Slash(sc) => Some(slash_to_command(sc)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a `Command`, delegating slash commands to the existing
|
||||
/// `handle_slash` dispatch and handling lifecycle commands directly.
|
||||
///
|
||||
/// Currently, output from `cmd_*` handlers goes to stdout (unchanged).
|
||||
/// `CommandResult` captures success/failure status; stdout capture can
|
||||
/// be added later.
|
||||
pub async fn execute(
|
||||
cmd: &Command,
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> CommandResult {
|
||||
match cmd {
|
||||
Command::Wait { duration_ms } => {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(*duration_ms)).await;
|
||||
CommandResult::ok()
|
||||
}
|
||||
Command::SetVar { name, value } => {
|
||||
let mut result = CommandResult::ok();
|
||||
result.data.insert(name.clone(), value.clone());
|
||||
result
|
||||
}
|
||||
Command::Assert { condition } => execute_assert(condition, session),
|
||||
Command::Connect { .. } | Command::Login { .. } | Command::Register { .. } => {
|
||||
// These lifecycle commands require external context (endpoint,
|
||||
// OPAQUE state) that lives outside SessionState. The playbook
|
||||
// executor will handle them directly; calling execute() for
|
||||
// them is an error.
|
||||
CommandResult::err(
|
||||
"lifecycle commands (connect/login/register) must be handled by the playbook executor".into(),
|
||||
)
|
||||
}
|
||||
Command::SendMessage { text } => {
|
||||
match super::repl::do_send(session, client, text).await {
|
||||
Ok(()) => CommandResult::ok(),
|
||||
Err(e) => CommandResult::err(format!("{e:#}")),
|
||||
}
|
||||
}
|
||||
Command::Quit => CommandResult::ok(),
|
||||
other => {
|
||||
// All remaining variants have a SlashCommand equivalent.
|
||||
if let Some(sc) = other.to_slash() {
|
||||
match execute_slash(session, client, sc).await {
|
||||
Ok(()) => CommandResult::ok(),
|
||||
Err(e) => CommandResult::err(format!("{e:#}")),
|
||||
}
|
||||
} else {
|
||||
CommandResult::err("command has no slash equivalent".into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Conversion helpers ──────────────────────────────────────────────────────
|
||||
|
||||
/// Convert a `SlashCommand` into the corresponding `Command`.
|
||||
fn slash_to_command(sc: SlashCommand) -> Command {
|
||||
match sc {
|
||||
SlashCommand::Help => Command::Help,
|
||||
SlashCommand::Quit => Command::Quit,
|
||||
SlashCommand::Whoami => Command::Whoami,
|
||||
SlashCommand::List => Command::List,
|
||||
SlashCommand::Switch { target } => Command::Switch { target },
|
||||
SlashCommand::Dm { username } => Command::Dm { username },
|
||||
SlashCommand::CreateGroup { name } => Command::CreateGroup { name },
|
||||
SlashCommand::Invite { target } => Command::Invite { target },
|
||||
SlashCommand::Remove { target } => Command::Remove { target },
|
||||
SlashCommand::Leave => Command::Leave,
|
||||
SlashCommand::Join => Command::Join,
|
||||
SlashCommand::Members => Command::Members,
|
||||
SlashCommand::GroupInfo => Command::GroupInfo,
|
||||
SlashCommand::Rename { name } => Command::Rename { name },
|
||||
SlashCommand::History { count } => Command::History { count },
|
||||
SlashCommand::MeshPeers => Command::MeshPeers,
|
||||
SlashCommand::MeshServer { addr } => Command::MeshServer { addr },
|
||||
SlashCommand::MeshSend { peer_id, message } => Command::MeshSend { peer_id, message },
|
||||
SlashCommand::MeshBroadcast { topic, message } => {
|
||||
Command::MeshBroadcast { topic, message }
|
||||
}
|
||||
SlashCommand::MeshSubscribe { topic } => Command::MeshSubscribe { topic },
|
||||
SlashCommand::MeshRoute => Command::MeshRoute,
|
||||
SlashCommand::MeshIdentity => Command::MeshIdentity,
|
||||
SlashCommand::MeshStore => Command::MeshStore,
|
||||
SlashCommand::Verify { username } => Command::Verify { username },
|
||||
SlashCommand::UpdateKey => Command::UpdateKey,
|
||||
SlashCommand::Typing => Command::Typing,
|
||||
SlashCommand::TypingNotify { enabled } => Command::TypingNotify { enabled },
|
||||
SlashCommand::React { emoji, index } => Command::React { emoji, index },
|
||||
SlashCommand::Edit { index, new_text } => Command::Edit { index, new_text },
|
||||
SlashCommand::Delete { index } => Command::Delete { index },
|
||||
SlashCommand::SendFile { path } => Command::SendFile { path },
|
||||
SlashCommand::Download { index } => Command::Download { index },
|
||||
SlashCommand::DeleteAccount => Command::DeleteAccount,
|
||||
SlashCommand::Disappear { arg } => Command::Disappear { arg },
|
||||
SlashCommand::Privacy { arg } => Command::Privacy { arg },
|
||||
SlashCommand::VerifyFs => Command::VerifyFs,
|
||||
SlashCommand::RotateAllKeys => Command::RotateAllKeys,
|
||||
SlashCommand::Devices => Command::Devices,
|
||||
SlashCommand::RegisterDevice { name } => Command::RegisterDevice { name },
|
||||
SlashCommand::RevokeDevice { id_prefix } => Command::RevokeDevice { id_prefix },
|
||||
}
|
||||
}
|
||||
|
||||
// ── Execution helpers ───────────────────────────────────────────────────────
|
||||
|
||||
/// Execute a `SlashCommand` using the existing `cmd_*` handlers from `repl.rs`.
|
||||
///
|
||||
/// This duplicates the dispatch table from `handle_slash` but returns
|
||||
/// `anyhow::Result<()>` instead of printing errors inline — the caller
|
||||
/// decides how to surface errors.
|
||||
async fn execute_slash(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
cmd: SlashCommand,
|
||||
) -> anyhow::Result<()> {
|
||||
use super::repl::*;
|
||||
match cmd {
|
||||
SlashCommand::Help => {
|
||||
print_help();
|
||||
Ok(())
|
||||
}
|
||||
SlashCommand::Quit => Ok(()),
|
||||
SlashCommand::Whoami => cmd_whoami(session),
|
||||
SlashCommand::List => cmd_list(session),
|
||||
SlashCommand::Switch { target } => cmd_switch(session, &target),
|
||||
SlashCommand::Dm { username } => cmd_dm(session, client, &username).await,
|
||||
SlashCommand::CreateGroup { name } => cmd_create_group(session, &name),
|
||||
SlashCommand::Invite { target } => cmd_invite(session, client, &target).await,
|
||||
SlashCommand::Remove { target } => cmd_remove(session, client, &target).await,
|
||||
SlashCommand::Leave => cmd_leave(session, client).await,
|
||||
SlashCommand::Join => cmd_join(session, client).await,
|
||||
SlashCommand::Members => cmd_members(session, client).await,
|
||||
SlashCommand::GroupInfo => cmd_group_info(session, client).await,
|
||||
SlashCommand::Rename { name } => cmd_rename(session, &name),
|
||||
SlashCommand::History { count } => cmd_history(session, count),
|
||||
SlashCommand::MeshPeers => cmd_mesh_peers(),
|
||||
SlashCommand::MeshServer { addr } => {
|
||||
super::display::print_status(&format!(
|
||||
"mesh server hint: reconnect with --server {addr} to use this node"
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
SlashCommand::MeshSend { peer_id, message } => cmd_mesh_send(&peer_id, &message),
|
||||
SlashCommand::MeshBroadcast { topic, message } => cmd_mesh_broadcast(&topic, &message),
|
||||
SlashCommand::MeshSubscribe { topic } => cmd_mesh_subscribe(&topic),
|
||||
SlashCommand::MeshRoute => cmd_mesh_route(session),
|
||||
SlashCommand::MeshIdentity => cmd_mesh_identity(session),
|
||||
SlashCommand::MeshStore => cmd_mesh_store(session),
|
||||
SlashCommand::Verify { username } => cmd_verify(session, client, &username).await,
|
||||
SlashCommand::UpdateKey => cmd_update_key(session, client).await,
|
||||
SlashCommand::Typing => cmd_typing(session, client).await,
|
||||
SlashCommand::TypingNotify { enabled } => {
|
||||
session.typing_notify_enabled = enabled;
|
||||
super::display::print_status(&format!(
|
||||
"typing notifications {}",
|
||||
if enabled { "enabled" } else { "disabled" }
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
SlashCommand::React { emoji, index } => cmd_react(session, client, &emoji, index).await,
|
||||
SlashCommand::Edit { index, new_text } => {
|
||||
cmd_edit(session, client, index, &new_text).await
|
||||
}
|
||||
SlashCommand::Delete { index } => cmd_delete(session, client, index).await,
|
||||
SlashCommand::SendFile { path } => cmd_send_file(session, client, &path).await,
|
||||
SlashCommand::Download { index } => cmd_download(session, client, index).await,
|
||||
SlashCommand::DeleteAccount => cmd_delete_account(session, client).await,
|
||||
SlashCommand::Disappear { arg } => cmd_disappear(session, arg.as_deref()),
|
||||
SlashCommand::Privacy { arg } => cmd_privacy(session, arg.as_deref()),
|
||||
SlashCommand::VerifyFs => cmd_verify_fs(session),
|
||||
SlashCommand::RotateAllKeys => cmd_rotate_all_keys(session, client).await,
|
||||
SlashCommand::Devices => cmd_devices(client).await,
|
||||
SlashCommand::RegisterDevice { name } => cmd_register_device(client, &name).await,
|
||||
SlashCommand::RevokeDevice { id_prefix } => cmd_revoke_device(client, &id_prefix).await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Assert a condition against the current session state.
|
||||
fn execute_assert(condition: &AssertCondition, session: &SessionState) -> CommandResult {
|
||||
match condition {
|
||||
AssertCondition::Connected => {
|
||||
// We have a session => we got past connect. Always true when
|
||||
// execute() is called with a valid client reference.
|
||||
CommandResult::ok()
|
||||
}
|
||||
AssertCondition::LoggedIn => {
|
||||
let guard = crate::AUTH_CONTEXT
|
||||
.read()
|
||||
.expect("AUTH_CONTEXT poisoned");
|
||||
if guard.is_some() {
|
||||
CommandResult::ok()
|
||||
} else {
|
||||
CommandResult::err("not logged in".into())
|
||||
}
|
||||
}
|
||||
AssertCondition::InConversation { name } => {
|
||||
if let Some(display) = session.active_display_name() {
|
||||
if display.contains(name.as_str()) {
|
||||
CommandResult::ok()
|
||||
} else {
|
||||
CommandResult::err(format!(
|
||||
"active conversation is '{display}', expected '{name}'"
|
||||
))
|
||||
}
|
||||
} else {
|
||||
CommandResult::err("no active conversation".into())
|
||||
}
|
||||
}
|
||||
AssertCondition::MessageCount { op, count } => {
|
||||
let actual = session
|
||||
.active_conversation
|
||||
.as_ref()
|
||||
.and_then(|id| session.conv_store.load_all_messages(id).ok())
|
||||
.map(|msgs| msgs.len())
|
||||
.unwrap_or(0);
|
||||
if op.eval(actual, *count) {
|
||||
CommandResult::ok()
|
||||
} else {
|
||||
CommandResult::err(format!(
|
||||
"message count assertion failed: {actual} {op:?} {count}"
|
||||
))
|
||||
}
|
||||
}
|
||||
AssertCondition::MemberCount { op, count } => {
|
||||
let actual = session
|
||||
.active_conversation
|
||||
.as_ref()
|
||||
.and_then(|id| session.members.get(id))
|
||||
.map(|m| m.member_identities().len())
|
||||
.unwrap_or(0);
|
||||
if op.eval(actual, *count) {
|
||||
CommandResult::ok()
|
||||
} else {
|
||||
CommandResult::err(format!(
|
||||
"member count assertion failed: {actual} {op:?} {count}"
|
||||
))
|
||||
}
|
||||
}
|
||||
AssertCondition::Custom { expression } => {
|
||||
// Custom expressions are not evaluated yet; always pass.
|
||||
let mut result = CommandResult::ok();
|
||||
result.data.insert("expression".into(), expression.clone());
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,10 +169,10 @@ impl ConversationStore {
|
||||
|
||||
let salt = get_or_create_salt(&salt_path)?;
|
||||
let key = derive_convdb_key(password, &salt)?;
|
||||
let hex_key = hex::encode(*key);
|
||||
let hex_key = Zeroizing::new(hex::encode(&*key));
|
||||
|
||||
let conn = Connection::open(db_path).context("open conversation db")?;
|
||||
conn.pragma_update(None, "key", format!("x'{hex_key}'"))
|
||||
conn.pragma_update(None, "key", format!("x'{}'", &*hex_key))
|
||||
.context("set SQLCipher key")?;
|
||||
conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")
|
||||
.context("set pragmas")?;
|
||||
@@ -188,7 +188,7 @@ impl ConversationStore {
|
||||
) -> anyhow::Result<()> {
|
||||
let salt = get_or_create_salt(salt_path)?;
|
||||
let key = derive_convdb_key(password, &salt)?;
|
||||
let hex_key = hex::encode(*key);
|
||||
let hex_key = Zeroizing::new(hex::encode(&*key));
|
||||
|
||||
let enc_path = db_path.with_extension("convdb-enc");
|
||||
|
||||
@@ -197,10 +197,16 @@ impl ConversationStore {
|
||||
plain.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;").ok();
|
||||
|
||||
// Attach a new encrypted database and export into it.
|
||||
// Sanitize the path to prevent SQL injection (ATTACH does not support parameterized paths).
|
||||
let enc_path_str = enc_path.display().to_string();
|
||||
anyhow::ensure!(
|
||||
!enc_path_str.contains('\''),
|
||||
"database path must not contain single quotes: {enc_path_str}"
|
||||
);
|
||||
plain
|
||||
.execute_batch(&format!(
|
||||
"ATTACH DATABASE '{}' AS encrypted KEY \"x'{hex_key}'\";",
|
||||
enc_path.display()
|
||||
"ATTACH DATABASE '{enc_path_str}' AS encrypted KEY \"x'{}'\";",
|
||||
&*hex_key
|
||||
))
|
||||
.context("attach encrypted db for migration")?;
|
||||
plain
|
||||
@@ -361,7 +367,13 @@ impl ConversationStore {
|
||||
};
|
||||
|
||||
let member_keys: Vec<Vec<u8>> = member_keys_blob
|
||||
.and_then(|b| bincode::deserialize(&b).ok())
|
||||
.and_then(|b| match bincode::deserialize(&b) {
|
||||
Ok(v) => Some(v),
|
||||
Err(e) => {
|
||||
tracing::warn!(conv = %hex::encode(id.0), "bincode deserialize member_keys failed: {e}");
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(Conversation {
|
||||
@@ -418,7 +430,13 @@ impl ConversationStore {
|
||||
}
|
||||
};
|
||||
let member_keys: Vec<Vec<u8>> = member_keys_blob
|
||||
.and_then(|b| bincode::deserialize(&b).ok())
|
||||
.and_then(|b| match bincode::deserialize(&b) {
|
||||
Ok(v) => Some(v),
|
||||
Err(e) => {
|
||||
tracing::warn!(conv = %hex::encode(&id_blob), "bincode deserialize member_keys failed: {e}");
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(Conversation {
|
||||
@@ -545,7 +563,7 @@ impl ConversationStore {
|
||||
ORDER BY timestamp_ms DESC
|
||||
LIMIT ?2",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![conv_id.0.as_slice(), limit as u32], |row| {
|
||||
let rows = stmt.query_map(params![conv_id.0.as_slice(), limit.min(u32::MAX as usize) as u32], |row| {
|
||||
let message_id: Option<Vec<u8>> = row.get(0)?;
|
||||
let sender_key: Vec<u8> = row.get(1)?;
|
||||
let sender_name: Option<String> = row.get(2)?;
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
pub mod command_engine;
|
||||
pub mod commands;
|
||||
pub mod conversation;
|
||||
pub mod display;
|
||||
pub mod hex;
|
||||
pub mod mesh_discovery;
|
||||
#[cfg(feature = "playbook")]
|
||||
pub mod playbook;
|
||||
pub mod repl;
|
||||
pub mod retry;
|
||||
pub mod rpc;
|
||||
|
||||
868
crates/quicproquo-client/src/client/playbook.rs
Normal file
868
crates/quicproquo-client/src/client/playbook.rs
Normal file
@@ -0,0 +1,868 @@
|
||||
//! YAML playbook parser and executor.
|
||||
//!
|
||||
//! Playbooks describe a sequence of client commands in YAML format.
|
||||
//! They support variable substitution, assertions, loops, and per-step
|
||||
//! error handling policies.
|
||||
//!
|
||||
//! ```yaml
|
||||
//! name: "smoke test"
|
||||
//! steps:
|
||||
//! - command: dm
|
||||
//! args: { username: "bob" }
|
||||
//! - command: send
|
||||
//! args: { text: "Hello from playbook" }
|
||||
//! - command: assert
|
||||
//! condition: message_count
|
||||
//! op: gte
|
||||
//! value: 1
|
||||
//! ```
|
||||
//!
|
||||
//! Requires the `playbook` cargo feature.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{Context, bail};
|
||||
use quicproquo_proto::node_capnp::node_service;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::command_engine::{AssertCondition, CmpOp, Command, CommandRegistry};
|
||||
use super::session::SessionState;
|
||||
|
||||
// ── Playbook structs ────────────────────────────────────────────────────────
|
||||
|
||||
/// A parsed YAML playbook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Playbook {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
#[serde(default)]
|
||||
pub variables: HashMap<String, String>,
|
||||
pub steps: Vec<PlaybookStep>,
|
||||
}
|
||||
|
||||
/// A single step in a playbook.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlaybookStep {
|
||||
pub command: String,
|
||||
#[serde(default)]
|
||||
pub args: HashMap<String, serde_yaml::Value>,
|
||||
/// For assert steps: the condition name.
|
||||
#[serde(default)]
|
||||
pub condition: Option<String>,
|
||||
/// For assert steps: comparison operator.
|
||||
#[serde(default)]
|
||||
pub op: Option<String>,
|
||||
/// For assert steps: expected value.
|
||||
#[serde(default)]
|
||||
pub value: Option<serde_yaml::Value>,
|
||||
/// Capture the command output into this variable name.
|
||||
#[serde(default)]
|
||||
pub capture: Option<String>,
|
||||
/// Error handling policy for this step.
|
||||
#[serde(default)]
|
||||
pub on_error: OnError,
|
||||
/// Optional loop specification.
|
||||
#[serde(rename = "loop", default)]
|
||||
pub loop_spec: Option<LoopSpec>,
|
||||
}
|
||||
|
||||
/// What to do when a step fails.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OnError {
|
||||
#[default]
|
||||
Fail,
|
||||
Skip,
|
||||
Continue,
|
||||
}
|
||||
|
||||
/// Loop specification for repeating a step.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LoopSpec {
|
||||
pub var: String,
|
||||
pub from: usize,
|
||||
pub to: usize,
|
||||
}
|
||||
|
||||
// ── Report structs ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Summary of a playbook execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlaybookReport {
|
||||
pub name: String,
|
||||
pub total_steps: usize,
|
||||
pub passed: usize,
|
||||
pub failed: usize,
|
||||
pub skipped: usize,
|
||||
pub duration: Duration,
|
||||
pub step_results: Vec<StepResult>,
|
||||
}
|
||||
|
||||
impl PlaybookReport {
|
||||
/// True if all steps passed (no failures).
|
||||
pub fn all_passed(&self) -> bool {
|
||||
self.failed == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PlaybookReport {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(f, "Playbook: {}", self.name)?;
|
||||
writeln!(
|
||||
f,
|
||||
"Result: {} passed, {} failed, {} skipped ({} total)",
|
||||
self.passed, self.failed, self.skipped, self.total_steps,
|
||||
)?;
|
||||
writeln!(f, "Duration: {:.2}s", self.duration.as_secs_f64())?;
|
||||
for sr in &self.step_results {
|
||||
let status = if sr.success { "OK" } else { "FAIL" };
|
||||
write!(
|
||||
f,
|
||||
" [{}/{}] {} ... {} ({:.1}ms)",
|
||||
sr.step_index + 1,
|
||||
self.total_steps,
|
||||
sr.command,
|
||||
status,
|
||||
sr.duration.as_secs_f64() * 1000.0,
|
||||
)?;
|
||||
if let Some(ref e) = sr.error {
|
||||
write!(f, " — {e}")?;
|
||||
}
|
||||
writeln!(f)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a single step execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StepResult {
|
||||
pub step_index: usize,
|
||||
pub command: String,
|
||||
pub success: bool,
|
||||
pub duration: Duration,
|
||||
pub output: Option<String>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
// ── PlaybookRunner ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Executes a parsed `Playbook` step-by-step.
|
||||
pub struct PlaybookRunner {
|
||||
playbook: Playbook,
|
||||
vars: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl PlaybookRunner {
|
||||
/// Load a playbook from a YAML file.
|
||||
pub fn from_file(path: &Path) -> anyhow::Result<Self> {
|
||||
let content =
|
||||
std::fs::read_to_string(path).with_context(|| format!("read {}", path.display()))?;
|
||||
Self::from_str(&content)
|
||||
}
|
||||
|
||||
/// Parse a playbook from a YAML string.
|
||||
pub fn from_str(yaml: &str) -> anyhow::Result<Self> {
|
||||
let playbook: Playbook =
|
||||
serde_yaml::from_str(yaml).context("parse playbook YAML")?;
|
||||
let vars = playbook.variables.clone();
|
||||
Ok(Self { playbook, vars })
|
||||
}
|
||||
|
||||
/// Override or add variables before execution.
|
||||
pub fn set_var(&mut self, name: impl Into<String>, value: impl Into<String>) {
|
||||
self.vars.insert(name.into(), value.into());
|
||||
}
|
||||
|
||||
/// Execute all steps, returning a report.
|
||||
pub async fn run(
|
||||
&mut self,
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> PlaybookReport {
|
||||
let start = Instant::now();
|
||||
let total = self.expanded_step_count();
|
||||
let mut results = Vec::new();
|
||||
let mut passed = 0usize;
|
||||
let mut failed = 0usize;
|
||||
let mut skipped = 0usize;
|
||||
let mut step_idx = 0usize;
|
||||
let mut abort = false;
|
||||
|
||||
for step in &self.playbook.steps.clone() {
|
||||
if abort {
|
||||
skipped += 1;
|
||||
results.push(StepResult {
|
||||
step_index: step_idx,
|
||||
command: step.command.clone(),
|
||||
success: false,
|
||||
duration: Duration::ZERO,
|
||||
output: None,
|
||||
error: Some("skipped (prior failure)".into()),
|
||||
});
|
||||
step_idx += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(ref ls) = step.loop_spec {
|
||||
for i in ls.from..=ls.to {
|
||||
self.vars.insert(ls.var.clone(), i.to_string());
|
||||
let sr = self.execute_step(step, step_idx, total, session, client).await;
|
||||
if sr.success {
|
||||
passed += 1;
|
||||
} else {
|
||||
failed += 1;
|
||||
if step.on_error == OnError::Fail {
|
||||
abort = true;
|
||||
}
|
||||
}
|
||||
results.push(sr);
|
||||
step_idx += 1;
|
||||
if abort {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let sr = self.execute_step(step, step_idx, total, session, client).await;
|
||||
if sr.success {
|
||||
passed += 1;
|
||||
} else {
|
||||
match step.on_error {
|
||||
OnError::Fail => {
|
||||
failed += 1;
|
||||
abort = true;
|
||||
}
|
||||
OnError::Skip => skipped += 1,
|
||||
OnError::Continue => failed += 1,
|
||||
}
|
||||
}
|
||||
results.push(sr);
|
||||
step_idx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
PlaybookReport {
|
||||
name: self.playbook.name.clone(),
|
||||
total_steps: step_idx,
|
||||
passed,
|
||||
failed,
|
||||
skipped,
|
||||
duration: start.elapsed(),
|
||||
step_results: results,
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a single step.
|
||||
async fn execute_step(
|
||||
&mut self,
|
||||
step: &PlaybookStep,
|
||||
index: usize,
|
||||
total: usize,
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> StepResult {
|
||||
let t = Instant::now();
|
||||
let cmd = match self.step_to_command(step) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
return StepResult {
|
||||
step_index: index,
|
||||
command: step.command.clone(),
|
||||
success: false,
|
||||
duration: t.elapsed(),
|
||||
output: None,
|
||||
error: Some(format!("{e:#}")),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
eprintln!(
|
||||
"[{}/{}] {} ...",
|
||||
index + 1,
|
||||
total,
|
||||
step.command,
|
||||
);
|
||||
|
||||
let cr = CommandRegistry::execute(&cmd, session, client).await;
|
||||
|
||||
// Capture output into variable if requested.
|
||||
if let Some(ref var_name) = step.capture {
|
||||
if let Some(ref out) = cr.output {
|
||||
self.vars.insert(var_name.clone(), out.clone());
|
||||
}
|
||||
for (k, v) in &cr.data {
|
||||
self.vars.insert(format!("{var_name}.{k}"), v.clone());
|
||||
}
|
||||
}
|
||||
|
||||
StepResult {
|
||||
step_index: index,
|
||||
command: step.command.clone(),
|
||||
success: cr.success,
|
||||
duration: t.elapsed(),
|
||||
output: cr.output,
|
||||
error: cr.error,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a YAML step into a typed `Command`.
|
||||
fn step_to_command(&self, step: &PlaybookStep) -> anyhow::Result<Command> {
|
||||
let cmd_name = step.command.as_str();
|
||||
match cmd_name {
|
||||
// ── Lifecycle commands ────────────────────────────────────────
|
||||
"connect" => Ok(Command::Connect {
|
||||
server: self.resolve_str(&step.args, "server")?,
|
||||
ca_cert: self.opt_str(&step.args, "ca_cert"),
|
||||
insecure: self.opt_bool(&step.args, "insecure"),
|
||||
}),
|
||||
"login" => Ok(Command::Login {
|
||||
username: self.resolve_str(&step.args, "username")?,
|
||||
password: self.resolve_str(&step.args, "password")?,
|
||||
}),
|
||||
"register" => Ok(Command::Register {
|
||||
username: self.resolve_str(&step.args, "username")?,
|
||||
password: self.resolve_str(&step.args, "password")?,
|
||||
}),
|
||||
"send" | "send-message" => Ok(Command::SendMessage {
|
||||
text: self.resolve_str(&step.args, "text")?,
|
||||
}),
|
||||
"wait" => Ok(Command::Wait {
|
||||
duration_ms: self.resolve_u64(&step.args, "duration_ms")?,
|
||||
}),
|
||||
"set-var" | "setvar" => Ok(Command::SetVar {
|
||||
name: self.resolve_str(&step.args, "name")?,
|
||||
value: self.resolve_str(&step.args, "value")?,
|
||||
}),
|
||||
"assert" => {
|
||||
let condition = self.build_assert_condition(step)?;
|
||||
Ok(Command::Assert { condition })
|
||||
}
|
||||
|
||||
// ── Session / identity ───────────────────────────────────────
|
||||
"help" => Ok(Command::Help),
|
||||
"quit" | "exit" => Ok(Command::Quit),
|
||||
"whoami" => Ok(Command::Whoami),
|
||||
"list" | "ls" => Ok(Command::List),
|
||||
"switch" | "sw" => Ok(Command::Switch {
|
||||
target: self.resolve_str(&step.args, "target")?,
|
||||
}),
|
||||
"dm" => Ok(Command::Dm {
|
||||
username: self.resolve_str(&step.args, "username")?,
|
||||
}),
|
||||
"create-group" | "cg" => Ok(Command::CreateGroup {
|
||||
name: self.resolve_str(&step.args, "name")?,
|
||||
}),
|
||||
"invite" => Ok(Command::Invite {
|
||||
target: self.resolve_str(&step.args, "target")?,
|
||||
}),
|
||||
"remove" | "kick" => Ok(Command::Remove {
|
||||
target: self.resolve_str(&step.args, "target")?,
|
||||
}),
|
||||
"leave" => Ok(Command::Leave),
|
||||
"join" => Ok(Command::Join),
|
||||
"members" => Ok(Command::Members),
|
||||
"group-info" | "gi" => Ok(Command::GroupInfo),
|
||||
"rename" => Ok(Command::Rename {
|
||||
name: self.resolve_str(&step.args, "name")?,
|
||||
}),
|
||||
"history" | "hist" => Ok(Command::History {
|
||||
count: self.opt_usize(&step.args, "count").unwrap_or(20),
|
||||
}),
|
||||
|
||||
// ── Security / crypto ────────────────────────────────────────
|
||||
"verify" => Ok(Command::Verify {
|
||||
username: self.resolve_str(&step.args, "username")?,
|
||||
}),
|
||||
"update-key" | "rotate-key" => Ok(Command::UpdateKey),
|
||||
"typing" => Ok(Command::Typing),
|
||||
"typing-notify" => Ok(Command::TypingNotify {
|
||||
enabled: self.opt_bool(&step.args, "enabled"),
|
||||
}),
|
||||
"react" => Ok(Command::React {
|
||||
emoji: self.resolve_str(&step.args, "emoji")?,
|
||||
index: self.opt_usize(&step.args, "index"),
|
||||
}),
|
||||
"edit" => Ok(Command::Edit {
|
||||
index: self.resolve_usize(&step.args, "index")?,
|
||||
new_text: self.resolve_str(&step.args, "new_text")?,
|
||||
}),
|
||||
"delete" | "del" => Ok(Command::Delete {
|
||||
index: self.resolve_usize(&step.args, "index")?,
|
||||
}),
|
||||
"send-file" | "sf" => Ok(Command::SendFile {
|
||||
path: self.resolve_str(&step.args, "path")?,
|
||||
}),
|
||||
"download" | "dl" => Ok(Command::Download {
|
||||
index: self.resolve_usize(&step.args, "index")?,
|
||||
}),
|
||||
"delete-account" => Ok(Command::DeleteAccount),
|
||||
"disappear" => Ok(Command::Disappear {
|
||||
arg: self.opt_str(&step.args, "duration"),
|
||||
}),
|
||||
"privacy" => Ok(Command::Privacy {
|
||||
arg: self.opt_str(&step.args, "setting"),
|
||||
}),
|
||||
"verify-fs" => Ok(Command::VerifyFs),
|
||||
"rotate-all-keys" => Ok(Command::RotateAllKeys),
|
||||
"devices" => Ok(Command::Devices),
|
||||
"register-device" => Ok(Command::RegisterDevice {
|
||||
name: self.resolve_str(&step.args, "name")?,
|
||||
}),
|
||||
"revoke-device" => Ok(Command::RevokeDevice {
|
||||
id_prefix: self.resolve_str(&step.args, "id_prefix")?,
|
||||
}),
|
||||
|
||||
// ── Mesh ─────────────────────────────────────────────────────
|
||||
"mesh-peers" => Ok(Command::MeshPeers),
|
||||
"mesh-server" => Ok(Command::MeshServer {
|
||||
addr: self.resolve_str(&step.args, "addr")?,
|
||||
}),
|
||||
"mesh-send" => Ok(Command::MeshSend {
|
||||
peer_id: self.resolve_str(&step.args, "peer_id")?,
|
||||
message: self.resolve_str(&step.args, "message")?,
|
||||
}),
|
||||
"mesh-broadcast" => Ok(Command::MeshBroadcast {
|
||||
topic: self.resolve_str(&step.args, "topic")?,
|
||||
message: self.resolve_str(&step.args, "message")?,
|
||||
}),
|
||||
"mesh-subscribe" => Ok(Command::MeshSubscribe {
|
||||
topic: self.resolve_str(&step.args, "topic")?,
|
||||
}),
|
||||
"mesh-route" => Ok(Command::MeshRoute),
|
||||
"mesh-identity" | "mesh-id" => Ok(Command::MeshIdentity),
|
||||
"mesh-store" => Ok(Command::MeshStore),
|
||||
|
||||
other => bail!("unknown command: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an `AssertCondition` from a playbook step.
|
||||
fn build_assert_condition(&self, step: &PlaybookStep) -> anyhow::Result<AssertCondition> {
|
||||
let cond = step
|
||||
.condition
|
||||
.as_deref()
|
||||
.context("assert step requires 'condition' field")?;
|
||||
match cond {
|
||||
"connected" => Ok(AssertCondition::Connected),
|
||||
"logged_in" => Ok(AssertCondition::LoggedIn),
|
||||
"in_conversation" => {
|
||||
let name = self.resolve_str(&step.args, "name")
|
||||
.or_else(|_| step.value.as_ref()
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| self.substitute(s))
|
||||
.context("assert in_conversation requires 'name' arg or 'value'"))?;
|
||||
Ok(AssertCondition::InConversation { name })
|
||||
}
|
||||
"message_count" => {
|
||||
let op = self.parse_cmp_op(step.op.as_deref().unwrap_or("gte"))?;
|
||||
let count = step
|
||||
.value
|
||||
.as_ref()
|
||||
.and_then(|v| v.as_u64())
|
||||
.context("message_count assert requires numeric 'value'")?
|
||||
as usize;
|
||||
Ok(AssertCondition::MessageCount { op, count })
|
||||
}
|
||||
"member_count" => {
|
||||
let op = self.parse_cmp_op(step.op.as_deref().unwrap_or("gte"))?;
|
||||
let count = step
|
||||
.value
|
||||
.as_ref()
|
||||
.and_then(|v| v.as_u64())
|
||||
.context("member_count assert requires numeric 'value'")?
|
||||
as usize;
|
||||
Ok(AssertCondition::MemberCount { op, count })
|
||||
}
|
||||
other => Ok(AssertCondition::Custom {
|
||||
expression: other.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_cmp_op(&self, s: &str) -> anyhow::Result<CmpOp> {
|
||||
match s {
|
||||
"eq" | "==" => Ok(CmpOp::Eq),
|
||||
"ne" | "!=" => Ok(CmpOp::Ne),
|
||||
"gt" | ">" => Ok(CmpOp::Gt),
|
||||
"lt" | "<" => Ok(CmpOp::Lt),
|
||||
"gte" | ">=" => Ok(CmpOp::Gte),
|
||||
"lte" | "<=" => Ok(CmpOp::Lte),
|
||||
other => bail!("unknown comparison operator: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Variable substitution helpers ────────────────────────────────────
|
||||
|
||||
/// Substitute `$varname` and `${VAR:-default}` in a string.
|
||||
fn substitute(&self, s: &str) -> String {
|
||||
let mut result = String::with_capacity(s.len());
|
||||
let mut chars = s.chars().peekable();
|
||||
while let Some(c) = chars.next() {
|
||||
if c == '$' {
|
||||
if chars.peek() == Some(&'{') {
|
||||
chars.next(); // consume '{'
|
||||
let mut key = String::new();
|
||||
let mut default = None;
|
||||
while let Some(&ch) = chars.peek() {
|
||||
if ch == '}' {
|
||||
chars.next();
|
||||
break;
|
||||
}
|
||||
if ch == ':' && chars.clone().nth(1) == Some('-') {
|
||||
chars.next(); // consume ':'
|
||||
chars.next(); // consume '-'
|
||||
let mut def = String::new();
|
||||
while let Some(&dch) = chars.peek() {
|
||||
if dch == '}' {
|
||||
chars.next();
|
||||
break;
|
||||
}
|
||||
def.push(dch);
|
||||
chars.next();
|
||||
}
|
||||
default = Some(def);
|
||||
break;
|
||||
}
|
||||
key.push(ch);
|
||||
chars.next();
|
||||
}
|
||||
if let Some(val) = self.vars.get(&key) {
|
||||
result.push_str(val);
|
||||
} else if let Ok(val) = std::env::var(&key) {
|
||||
result.push_str(&val);
|
||||
} else if let Some(def) = default {
|
||||
result.push_str(&def);
|
||||
}
|
||||
} else {
|
||||
let mut key = String::new();
|
||||
while let Some(&ch) = chars.peek() {
|
||||
if ch.is_alphanumeric() || ch == '_' {
|
||||
key.push(ch);
|
||||
chars.next();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if let Some(val) = self.vars.get(&key) {
|
||||
result.push_str(val);
|
||||
} else {
|
||||
result.push('$');
|
||||
result.push_str(&key);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.push(c);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Resolve a required string argument with variable substitution.
|
||||
fn resolve_str(
|
||||
&self,
|
||||
args: &HashMap<String, serde_yaml::Value>,
|
||||
key: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let val = args
|
||||
.get(key)
|
||||
.with_context(|| format!("missing required argument: {key}"))?;
|
||||
match val {
|
||||
serde_yaml::Value::String(s) => Ok(self.substitute(s)),
|
||||
serde_yaml::Value::Number(n) => Ok(n.to_string()),
|
||||
serde_yaml::Value::Bool(b) => Ok(b.to_string()),
|
||||
other => Ok(format!("{other:?}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve an optional string argument.
|
||||
fn opt_str(
|
||||
&self,
|
||||
args: &HashMap<String, serde_yaml::Value>,
|
||||
key: &str,
|
||||
) -> Option<String> {
|
||||
args.get(key).map(|v| match v {
|
||||
serde_yaml::Value::String(s) => self.substitute(s),
|
||||
serde_yaml::Value::Number(n) => n.to_string(),
|
||||
serde_yaml::Value::Bool(b) => b.to_string(),
|
||||
other => format!("{other:?}"),
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolve an optional bool argument (defaults to false).
|
||||
fn opt_bool(
|
||||
&self,
|
||||
args: &HashMap<String, serde_yaml::Value>,
|
||||
key: &str,
|
||||
) -> bool {
|
||||
args.get(key)
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Resolve a required usize argument.
|
||||
fn resolve_usize(
|
||||
&self,
|
||||
args: &HashMap<String, serde_yaml::Value>,
|
||||
key: &str,
|
||||
) -> anyhow::Result<usize> {
|
||||
let val = args
|
||||
.get(key)
|
||||
.with_context(|| format!("missing required argument: {key}"))?;
|
||||
val.as_u64()
|
||||
.map(|n| n as usize)
|
||||
.with_context(|| format!("argument '{key}' must be a positive integer"))
|
||||
}
|
||||
|
||||
/// Resolve a required u64 argument.
|
||||
fn resolve_u64(
|
||||
&self,
|
||||
args: &HashMap<String, serde_yaml::Value>,
|
||||
key: &str,
|
||||
) -> anyhow::Result<u64> {
|
||||
let val = args
|
||||
.get(key)
|
||||
.with_context(|| format!("missing required argument: {key}"))?;
|
||||
val.as_u64()
|
||||
.with_context(|| format!("argument '{key}' must be a positive integer"))
|
||||
}
|
||||
|
||||
/// Resolve an optional usize argument.
|
||||
fn opt_usize(
|
||||
&self,
|
||||
args: &HashMap<String, serde_yaml::Value>,
|
||||
key: &str,
|
||||
) -> Option<usize> {
|
||||
args.get(key).and_then(|v| v.as_u64()).map(|n| n as usize)
|
||||
}
|
||||
|
||||
/// Count total expanded steps (including loop iterations).
|
||||
fn expanded_step_count(&self) -> usize {
|
||||
self.playbook
|
||||
.steps
|
||||
.iter()
|
||||
.map(|s| {
|
||||
if let Some(ref ls) = s.loop_spec {
|
||||
if ls.to >= ls.from {
|
||||
ls.to - ls.from + 1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
} else {
|
||||
1
|
||||
}
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_minimal_playbook() {
|
||||
let yaml = r#"
|
||||
name: "test"
|
||||
steps:
|
||||
- command: whoami
|
||||
- command: list
|
||||
"#;
|
||||
let runner = PlaybookRunner::from_str(yaml).unwrap();
|
||||
assert_eq!(runner.playbook.name, "test");
|
||||
assert_eq!(runner.playbook.steps.len(), 2);
|
||||
assert_eq!(runner.playbook.steps[0].command, "whoami");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_playbook_with_variables() {
|
||||
let yaml = r#"
|
||||
name: "var test"
|
||||
variables:
|
||||
user: alice
|
||||
server: "127.0.0.1:5001"
|
||||
steps:
|
||||
- command: dm
|
||||
args:
|
||||
username: "$user"
|
||||
"#;
|
||||
let runner = PlaybookRunner::from_str(yaml).unwrap();
|
||||
assert_eq!(runner.vars["user"], "alice");
|
||||
assert_eq!(runner.vars["server"], "127.0.0.1:5001");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn variable_substitution() {
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("name".to_string(), "alice".to_string());
|
||||
vars.insert("port".to_string(), "5001".to_string());
|
||||
let runner = PlaybookRunner {
|
||||
playbook: Playbook {
|
||||
name: "test".into(),
|
||||
description: None,
|
||||
variables: HashMap::new(),
|
||||
steps: vec![],
|
||||
},
|
||||
vars,
|
||||
};
|
||||
assert_eq!(runner.substitute("hello $name"), "hello alice");
|
||||
assert_eq!(runner.substitute("port=$port!"), "port=5001!");
|
||||
assert_eq!(runner.substitute("${name}@server"), "alice@server");
|
||||
assert_eq!(
|
||||
runner.substitute("${missing:-default}"),
|
||||
"default"
|
||||
);
|
||||
assert_eq!(runner.substitute("no vars here"), "no vars here");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn step_to_command_mapping() {
|
||||
let yaml = r#"
|
||||
name: "mapping test"
|
||||
variables:
|
||||
user: bob
|
||||
steps:
|
||||
- command: dm
|
||||
args:
|
||||
username: "$user"
|
||||
- command: send
|
||||
args:
|
||||
text: "hello"
|
||||
- command: history
|
||||
args:
|
||||
count: 10
|
||||
- command: wait
|
||||
args:
|
||||
duration_ms: 500
|
||||
"#;
|
||||
let runner = PlaybookRunner::from_str(yaml).unwrap();
|
||||
let cmd0 = runner.step_to_command(&runner.playbook.steps[0]).unwrap();
|
||||
assert!(matches!(cmd0, Command::Dm { username } if username == "bob"));
|
||||
|
||||
let cmd1 = runner.step_to_command(&runner.playbook.steps[1]).unwrap();
|
||||
assert!(matches!(cmd1, Command::SendMessage { text } if text == "hello"));
|
||||
|
||||
let cmd2 = runner.step_to_command(&runner.playbook.steps[2]).unwrap();
|
||||
assert!(matches!(cmd2, Command::History { count: 10 }));
|
||||
|
||||
let cmd3 = runner.step_to_command(&runner.playbook.steps[3]).unwrap();
|
||||
assert!(matches!(cmd3, Command::Wait { duration_ms: 500 }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_assert_step() {
|
||||
let yaml = r#"
|
||||
name: "assert test"
|
||||
steps:
|
||||
- command: assert
|
||||
condition: message_count
|
||||
op: gte
|
||||
value: 5
|
||||
"#;
|
||||
let runner = PlaybookRunner::from_str(yaml).unwrap();
|
||||
let cmd = runner.step_to_command(&runner.playbook.steps[0]).unwrap();
|
||||
match cmd {
|
||||
Command::Assert {
|
||||
condition: AssertCondition::MessageCount { op, count },
|
||||
} => {
|
||||
assert_eq!(op, CmpOp::Gte);
|
||||
assert_eq!(count, 5);
|
||||
}
|
||||
other => panic!("expected Assert MessageCount, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_loop_spec() {
|
||||
let yaml = r#"
|
||||
name: "loop test"
|
||||
steps:
|
||||
- command: send
|
||||
args:
|
||||
text: "msg $i"
|
||||
loop:
|
||||
var: i
|
||||
from: 1
|
||||
to: 5
|
||||
"#;
|
||||
let runner = PlaybookRunner::from_str(yaml).unwrap();
|
||||
assert_eq!(runner.expanded_step_count(), 5);
|
||||
let ls = runner.playbook.steps[0].loop_spec.as_ref().unwrap();
|
||||
assert_eq!(ls.var, "i");
|
||||
assert_eq!(ls.from, 1);
|
||||
assert_eq!(ls.to, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn on_error_defaults_to_fail() {
|
||||
let yaml = r#"
|
||||
name: "error test"
|
||||
steps:
|
||||
- command: whoami
|
||||
- command: list
|
||||
on_error: continue
|
||||
- command: quit
|
||||
on_error: skip
|
||||
"#;
|
||||
let runner = PlaybookRunner::from_str(yaml).unwrap();
|
||||
assert_eq!(runner.playbook.steps[0].on_error, OnError::Fail);
|
||||
assert_eq!(runner.playbook.steps[1].on_error, OnError::Continue);
|
||||
assert_eq!(runner.playbook.steps[2].on_error, OnError::Skip);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cmp_op_parsing() {
|
||||
let runner = PlaybookRunner::from_str("name: t\nsteps: []").unwrap();
|
||||
assert!(matches!(runner.parse_cmp_op("eq"), Ok(CmpOp::Eq)));
|
||||
assert!(matches!(runner.parse_cmp_op("=="), Ok(CmpOp::Eq)));
|
||||
assert!(matches!(runner.parse_cmp_op("gte"), Ok(CmpOp::Gte)));
|
||||
assert!(matches!(runner.parse_cmp_op(">="), Ok(CmpOp::Gte)));
|
||||
assert!(matches!(runner.parse_cmp_op("<"), Ok(CmpOp::Lt)));
|
||||
assert!(runner.parse_cmp_op("invalid").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_display() {
|
||||
let report = PlaybookReport {
|
||||
name: "test".into(),
|
||||
total_steps: 3,
|
||||
passed: 2,
|
||||
failed: 1,
|
||||
skipped: 0,
|
||||
duration: Duration::from_millis(150),
|
||||
step_results: vec![
|
||||
StepResult {
|
||||
step_index: 0,
|
||||
command: "whoami".into(),
|
||||
success: true,
|
||||
duration: Duration::from_millis(10),
|
||||
output: None,
|
||||
error: None,
|
||||
},
|
||||
StepResult {
|
||||
step_index: 1,
|
||||
command: "dm".into(),
|
||||
success: true,
|
||||
duration: Duration::from_millis(50),
|
||||
output: None,
|
||||
error: None,
|
||||
},
|
||||
StepResult {
|
||||
step_index: 2,
|
||||
command: "assert".into(),
|
||||
success: false,
|
||||
duration: Duration::from_millis(1),
|
||||
output: None,
|
||||
error: Some("message count 0 < 1".into()),
|
||||
},
|
||||
],
|
||||
};
|
||||
let s = format!("{report}");
|
||||
assert!(s.contains("2 passed, 1 failed"));
|
||||
assert!(s.contains("[3/3] assert ... FAIL"));
|
||||
}
|
||||
}
|
||||
@@ -37,13 +37,13 @@ use super::token_cache::{clear_cached_session, load_cached_session, save_cached_
|
||||
|
||||
// ── Input parsing ────────────────────────────────────────────────────────────
|
||||
|
||||
enum Input {
|
||||
pub(crate) enum Input {
|
||||
Slash(SlashCommand),
|
||||
ChatMessage(String),
|
||||
Empty,
|
||||
}
|
||||
|
||||
enum SlashCommand {
|
||||
pub(crate) enum SlashCommand {
|
||||
Help,
|
||||
Quit,
|
||||
Whoami,
|
||||
@@ -104,7 +104,7 @@ enum SlashCommand {
|
||||
RevokeDevice { id_prefix: String },
|
||||
}
|
||||
|
||||
fn parse_input(line: &str) -> Input {
|
||||
pub(crate) fn parse_input(line: &str) -> Input {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Input::Empty;
|
||||
@@ -246,7 +246,7 @@ fn parse_input(line: &str) -> Input {
|
||||
"/react" => match arg {
|
||||
Some(rest) => {
|
||||
let mut parts = rest.splitn(2, ' ');
|
||||
let emoji = parts.next().unwrap().to_string();
|
||||
let emoji = parts.next().unwrap_or_default().to_string();
|
||||
let index = parts.next().and_then(|s| s.trim().parse::<usize>().ok());
|
||||
Input::Slash(SlashCommand::React { emoji, index })
|
||||
}
|
||||
@@ -258,7 +258,7 @@ fn parse_input(line: &str) -> Input {
|
||||
"/edit" => match arg {
|
||||
Some(rest) => {
|
||||
let mut parts = rest.splitn(2, ' ');
|
||||
let idx_str = parts.next().unwrap();
|
||||
let idx_str = parts.next().unwrap_or_default();
|
||||
match (idx_str.parse::<usize>(), parts.next()) {
|
||||
(Ok(index), Some(new_text)) if !new_text.trim().is_empty() => {
|
||||
Input::Slash(SlashCommand::Edit { index, new_text: new_text.trim().to_string() })
|
||||
@@ -847,7 +847,7 @@ async fn handle_slash(
|
||||
}
|
||||
}
|
||||
|
||||
fn print_help() {
|
||||
pub(crate) fn print_help() {
|
||||
display::print_status("Commands:");
|
||||
display::print_status(" /dm <user[@domain]> - Start or switch to a DM (federation supported)");
|
||||
display::print_status(" /create-group <name> - Create a new group");
|
||||
@@ -925,7 +925,7 @@ fn format_ttl(secs: u32) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn cmd_disappear(
|
||||
pub(crate) fn cmd_disappear(
|
||||
session: &mut SessionState,
|
||||
arg: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -966,7 +966,7 @@ fn cmd_disappear(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_privacy(
|
||||
pub(crate) fn cmd_privacy(
|
||||
session: &mut SessionState,
|
||||
arg: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -1047,7 +1047,7 @@ fn cmd_privacy(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_verify_fs(session: &SessionState) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_verify_fs(session: &SessionState) -> anyhow::Result<()> {
|
||||
let conv_id = session
|
||||
.active_conversation
|
||||
.as_ref()
|
||||
@@ -1091,7 +1091,7 @@ fn cmd_verify_fs(session: &SessionState) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_rotate_all_keys(
|
||||
pub(crate) async fn cmd_rotate_all_keys(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -1109,7 +1109,7 @@ async fn cmd_rotate_all_keys(
|
||||
}
|
||||
|
||||
/// Discover nearby qpq servers via mDNS (requires `--features mesh` build).
|
||||
fn cmd_mesh_peers() -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_mesh_peers() -> anyhow::Result<()> {
|
||||
use super::mesh_discovery::MeshDiscovery;
|
||||
|
||||
match MeshDiscovery::start() {
|
||||
@@ -1138,7 +1138,7 @@ fn cmd_mesh_peers() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
/// Send a direct P2P mesh message (stub — P2pNode not yet wired into session).
|
||||
fn cmd_mesh_send(peer_id: &str, message: &str) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_mesh_send(peer_id: &str, message: &str) -> anyhow::Result<()> {
|
||||
#[cfg(feature = "mesh")]
|
||||
{
|
||||
display::print_status(&format!("mesh send: would send to {peer_id}: {message}"));
|
||||
@@ -1153,7 +1153,7 @@ fn cmd_mesh_send(peer_id: &str, message: &str) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
/// Broadcast an encrypted message on a topic (stub — P2pNode not yet wired into session).
|
||||
fn cmd_mesh_broadcast(topic: &str, message: &str) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_mesh_broadcast(topic: &str, message: &str) -> anyhow::Result<()> {
|
||||
#[cfg(feature = "mesh")]
|
||||
{
|
||||
display::print_status(&format!("mesh broadcast to {topic}: {message}"));
|
||||
@@ -1168,7 +1168,7 @@ fn cmd_mesh_broadcast(topic: &str, message: &str) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
/// Subscribe to a broadcast topic (stub — P2pNode not yet wired into session).
|
||||
fn cmd_mesh_subscribe(topic: &str) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_mesh_subscribe(topic: &str) -> anyhow::Result<()> {
|
||||
#[cfg(feature = "mesh")]
|
||||
{
|
||||
display::print_status(&format!("subscribed to topic: {topic}"));
|
||||
@@ -1183,7 +1183,7 @@ fn cmd_mesh_subscribe(topic: &str) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
/// Display known mesh peers and routes from the mesh identity file.
|
||||
fn cmd_mesh_route(session: &SessionState) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_mesh_route(session: &SessionState) -> anyhow::Result<()> {
|
||||
#[cfg(feature = "mesh")]
|
||||
{
|
||||
let mesh_state_path = session.state_path.with_extension("mesh.json");
|
||||
@@ -1217,7 +1217,7 @@ fn cmd_mesh_route(session: &SessionState) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
/// Display mesh node identity information.
|
||||
fn cmd_mesh_identity(session: &SessionState) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_mesh_identity(session: &SessionState) -> anyhow::Result<()> {
|
||||
#[cfg(feature = "mesh")]
|
||||
{
|
||||
let mesh_state_path = session.state_path.with_extension("mesh.json");
|
||||
@@ -1239,7 +1239,7 @@ fn cmd_mesh_identity(session: &SessionState) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
/// Display mesh store-and-forward statistics.
|
||||
fn cmd_mesh_store(session: &SessionState) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_mesh_store(session: &SessionState) -> anyhow::Result<()> {
|
||||
#[cfg(feature = "mesh")]
|
||||
{
|
||||
// Without a live P2pNode in the session, we can only report that the store
|
||||
@@ -1256,7 +1256,7 @@ fn cmd_mesh_store(session: &SessionState) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_whoami(session: &SessionState) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_whoami(session: &SessionState) -> anyhow::Result<()> {
|
||||
display::print_status(&format!(
|
||||
"identity: {}",
|
||||
hex::encode(session.identity.public_key_bytes())
|
||||
@@ -1272,7 +1272,7 @@ fn cmd_whoami(session: &SessionState) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_list(session: &SessionState) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_list(session: &SessionState) -> anyhow::Result<()> {
|
||||
let convs = session.conv_store.list_conversations()?;
|
||||
if convs.is_empty() {
|
||||
display::print_status("no conversations yet. Try /dm <username> or /create-group <name>");
|
||||
@@ -1303,7 +1303,7 @@ fn cmd_list(session: &SessionState) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_switch(session: &mut SessionState, target: &str) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_switch(session: &mut SessionState, target: &str) -> anyhow::Result<()> {
|
||||
let target = target.trim();
|
||||
|
||||
let conv = if let Some(username) = target.strip_prefix('@') {
|
||||
@@ -1330,7 +1330,7 @@ fn cmd_switch(session: &mut SessionState, target: &str) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_dm(
|
||||
pub(crate) async fn cmd_dm(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
username: &str,
|
||||
@@ -1469,7 +1469,7 @@ async fn cmd_dm(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_create_group(session: &mut SessionState, name: &str) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_create_group(session: &mut SessionState, name: &str) -> anyhow::Result<()> {
|
||||
let conv_id = ConversationId::from_group_name(name);
|
||||
|
||||
if session.conv_store.find_group_by_name(name)?.is_some() {
|
||||
@@ -1513,7 +1513,7 @@ fn cmd_create_group(session: &mut SessionState, name: &str) -> anyhow::Result<()
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_invite(
|
||||
pub(crate) async fn cmd_invite(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
target: &str,
|
||||
@@ -1584,7 +1584,7 @@ async fn cmd_invite(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_remove(
|
||||
pub(crate) async fn cmd_remove(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
target: &str,
|
||||
@@ -1628,7 +1628,7 @@ async fn cmd_remove(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_leave(
|
||||
pub(crate) async fn cmd_leave(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -1665,7 +1665,7 @@ async fn cmd_leave(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_update_key(
|
||||
pub(crate) async fn cmd_update_key(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -1710,7 +1710,7 @@ async fn cmd_update_key(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_join(
|
||||
pub(crate) async fn cmd_join(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -1818,7 +1818,7 @@ async fn resolve_or_hex(
|
||||
}
|
||||
}
|
||||
|
||||
async fn cmd_members(
|
||||
pub(crate) async fn cmd_members(
|
||||
session: &SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -1855,7 +1855,7 @@ async fn cmd_members(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_group_info(
|
||||
pub(crate) async fn cmd_group_info(
|
||||
session: &SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -1908,7 +1908,7 @@ async fn cmd_group_info(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_rename(session: &mut SessionState, new_name: &str) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_rename(session: &mut SessionState, new_name: &str) -> anyhow::Result<()> {
|
||||
let conv_id = session
|
||||
.active_conversation
|
||||
.as_ref()
|
||||
@@ -1926,7 +1926,7 @@ fn cmd_rename(session: &mut SessionState, new_name: &str) -> anyhow::Result<()>
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cmd_history(session: &SessionState, count: usize) -> anyhow::Result<()> {
|
||||
pub(crate) fn cmd_history(session: &SessionState, count: usize) -> anyhow::Result<()> {
|
||||
let conv_id = session
|
||||
.active_conversation
|
||||
.as_ref()
|
||||
@@ -1943,7 +1943,7 @@ fn cmd_history(session: &SessionState, count: usize) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_verify(
|
||||
pub(crate) async fn cmd_verify(
|
||||
session: &SessionState,
|
||||
client: &node_service::Client,
|
||||
username: &str,
|
||||
@@ -1982,7 +1982,7 @@ async fn cmd_verify(
|
||||
|
||||
// ── Typing indicator ─────────────────────────────────────────────────────────
|
||||
|
||||
async fn cmd_typing(
|
||||
pub(crate) async fn cmd_typing(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -2033,7 +2033,7 @@ async fn cmd_typing(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_react(
|
||||
pub(crate) async fn cmd_react(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
emoji: &str,
|
||||
@@ -2127,7 +2127,7 @@ async fn cmd_react(
|
||||
|
||||
// ── Edit / Delete ────────────────────────────────────────────────────────────
|
||||
|
||||
async fn cmd_edit(
|
||||
pub(crate) async fn cmd_edit(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
index: usize,
|
||||
@@ -2200,7 +2200,7 @@ async fn cmd_edit(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_delete(
|
||||
pub(crate) async fn cmd_delete(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
index: usize,
|
||||
@@ -2313,7 +2313,7 @@ fn format_size(bytes: u64) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
async fn cmd_send_file(
|
||||
pub(crate) async fn cmd_send_file(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
path_str: &str,
|
||||
@@ -2447,7 +2447,7 @@ async fn cmd_send_file(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_download(
|
||||
pub(crate) async fn cmd_download(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
index: usize,
|
||||
@@ -2582,7 +2582,7 @@ fn extract_filename_from_body(body: &str) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn cmd_delete_account(
|
||||
pub(crate) async fn cmd_delete_account(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -2631,7 +2631,7 @@ async fn handle_send(
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_send(
|
||||
pub(crate) async fn do_send(
|
||||
session: &mut SessionState,
|
||||
client: &node_service::Client,
|
||||
text: &str,
|
||||
@@ -3240,7 +3240,7 @@ async fn replenish_pending_key(
|
||||
|
||||
// ── Device management commands ──────────────────────────────────────────────
|
||||
|
||||
async fn cmd_devices(client: &node_service::Client) -> anyhow::Result<()> {
|
||||
pub(crate) async fn cmd_devices(client: &node_service::Client) -> anyhow::Result<()> {
|
||||
let devices = list_devices(client).await?;
|
||||
if devices.is_empty() {
|
||||
display::print_status("No devices registered.");
|
||||
@@ -3260,7 +3260,7 @@ async fn cmd_devices(client: &node_service::Client) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_register_device(
|
||||
pub(crate) async fn cmd_register_device(
|
||||
client: &node_service::Client,
|
||||
name: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -3279,7 +3279,7 @@ async fn cmd_register_device(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cmd_revoke_device(
|
||||
pub(crate) async fn cmd_revoke_device(
|
||||
client: &node_service::Client,
|
||||
id_prefix: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
|
||||
@@ -152,7 +152,7 @@ pub fn set_auth(auth: &mut auth::Builder<'_>) -> anyhow::Result<()> {
|
||||
)
|
||||
})?;
|
||||
auth.set_version(ctx.version);
|
||||
auth.set_access_token(&ctx.access_token);
|
||||
auth.set_access_token(&*ctx.access_token);
|
||||
auth.set_device_id(&ctx.device_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Context;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use quicproquo_core::{DiskKeyStore, GroupMember, HybridKeypair, IdentityKeypair};
|
||||
|
||||
@@ -25,8 +26,8 @@ pub struct SessionState {
|
||||
pub hybrid_kp: Option<HybridKeypair>,
|
||||
/// Path to the legacy state file (for backward compat with one-shot commands).
|
||||
pub state_path: PathBuf,
|
||||
/// Optional password for the legacy state file.
|
||||
pub password: Option<String>,
|
||||
/// Optional password for the legacy state file. Zeroized on drop. (M9)
|
||||
pub password: Option<Zeroizing<String>>,
|
||||
/// SQLite-backed conversation + message store.
|
||||
pub conv_store: ConversationStore,
|
||||
/// Currently active conversation.
|
||||
@@ -80,7 +81,7 @@ impl SessionState {
|
||||
identity,
|
||||
hybrid_kp,
|
||||
state_path: state_path.to_path_buf(),
|
||||
password: password.map(String::from),
|
||||
password: password.map(|p| Zeroizing::new(String::from(p))),
|
||||
conv_store,
|
||||
active_conversation: None,
|
||||
members: HashMap::new(),
|
||||
@@ -183,7 +184,10 @@ impl SessionState {
|
||||
fn create_member_from_conv(&self, conv: &Conversation) -> anyhow::Result<GroupMember> {
|
||||
let ks_path = self.keystore_path_for(&conv.id);
|
||||
let ks = DiskKeyStore::persistent(&ks_path)
|
||||
.unwrap_or_else(|_| DiskKeyStore::ephemeral());
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!(path = %ks_path.display(), error = %e, "DiskKeyStore open failed, falling back to ephemeral");
|
||||
DiskKeyStore::ephemeral()
|
||||
});
|
||||
|
||||
let group = conv
|
||||
.mls_group_blob
|
||||
|
||||
@@ -55,7 +55,7 @@ impl StoredState {
|
||||
.transpose()?;
|
||||
|
||||
Ok(Self {
|
||||
identity_seed: member.identity_seed(),
|
||||
identity_seed: *member.identity_seed(),
|
||||
group,
|
||||
hybrid_key: hybrid_kp.map(|kp| kp.to_bytes()),
|
||||
member_keys: Vec::new(),
|
||||
|
||||
@@ -64,7 +64,14 @@ pub fn save_cached_session(
|
||||
|
||||
let bytes = match password {
|
||||
Some(pw) => encrypt_state(pw, contents.as_bytes())?,
|
||||
None => contents.into_bytes(),
|
||||
None => {
|
||||
#[cfg(not(unix))]
|
||||
tracing::warn!(
|
||||
"storing session token as plaintext (no password set); \
|
||||
file permissions cannot be restricted on this platform"
|
||||
);
|
||||
contents.into_bytes()
|
||||
}
|
||||
};
|
||||
|
||||
std::fs::write(&path, bytes).with_context(|| format!("write session cache {path:?}"))?;
|
||||
|
||||
@@ -644,7 +644,7 @@ async fn tui_loop(
|
||||
// Clone session state for the poll task (it needs its own SessionState).
|
||||
let poll_session = SessionState::load(
|
||||
&session.state_path.clone(),
|
||||
session.password.as_deref(),
|
||||
session.password.as_ref().map(|p| p.as_str()),
|
||||
)?;
|
||||
let poll_tx = event_tx.clone();
|
||||
tokio::task::spawn_local(poll_task(poll_session, client.clone(), poll_tx));
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
pub mod client;
|
||||
|
||||
pub use client::commands::{
|
||||
@@ -26,14 +28,85 @@ pub use client::commands::{
|
||||
cmd_send, cmd_whoami, opaque_login, receive_pending_plaintexts, whoami_json,
|
||||
};
|
||||
|
||||
pub use client::command_engine::{Command, CommandRegistry, CommandResult};
|
||||
#[cfg(feature = "playbook")]
|
||||
pub use client::playbook::{Playbook, PlaybookReport, PlaybookRunner};
|
||||
pub use client::repl::run_repl;
|
||||
pub use client::rpc::{connect_node, connect_node_opt, create_channel, enqueue, fetch_wait, resolve_user};
|
||||
|
||||
// Global auth context — RwLock so the REPL can set it after OPAQUE login.
|
||||
// ── ClientContext: structured holder for session-scoped auth + TLS config ────
|
||||
|
||||
/// Holds the authentication credentials and TLS policy for a client session.
|
||||
///
|
||||
/// Prefer constructing a `ClientContext` and passing it explicitly where
|
||||
/// possible. The global `AUTH_CONTEXT` / `INSECURE_SKIP_VERIFY` statics
|
||||
/// delegate to a `ClientContext` under the hood and exist only for backward
|
||||
/// compatibility with call-sites that have not yet been migrated.
|
||||
pub struct ClientContext {
|
||||
auth: RwLock<Option<ClientAuth>>,
|
||||
insecure_skip_verify: AtomicBool,
|
||||
}
|
||||
|
||||
impl ClientContext {
|
||||
/// Create a new context with no auth and TLS verification enabled.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
auth: RwLock::new(None),
|
||||
insecure_skip_verify: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a context pre-loaded with auth credentials.
|
||||
pub fn with_auth(auth: ClientAuth) -> Self {
|
||||
Self {
|
||||
auth: RwLock::new(Some(auth)),
|
||||
insecure_skip_verify: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set (or replace) the auth credentials.
|
||||
pub fn set_auth(&self, ctx: ClientAuth) {
|
||||
let mut guard = self.auth.write().expect("ClientContext auth lock poisoned");
|
||||
*guard = Some(ctx);
|
||||
}
|
||||
|
||||
/// Read the current auth snapshot (cloned).
|
||||
pub fn get_auth(&self) -> Option<ClientAuth> {
|
||||
let guard = self.auth.read().expect("ClientContext auth lock poisoned");
|
||||
guard.clone()
|
||||
}
|
||||
|
||||
/// Returns true if auth credentials have been set.
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
let guard = self.auth.read().expect("ClientContext auth lock poisoned");
|
||||
guard.is_some()
|
||||
}
|
||||
|
||||
/// Enable or disable insecure TLS mode.
|
||||
pub fn set_insecure_skip_verify(&self, enabled: bool) {
|
||||
self.insecure_skip_verify.store(enabled, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Read the current insecure-skip-verify flag.
|
||||
pub fn insecure_skip_verify(&self) -> bool {
|
||||
self.insecure_skip_verify.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientContext {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Global statics (thin wrappers, kept for backward compat) ─────────────────
|
||||
|
||||
/// Global auth context — delegates to a process-wide `ClientContext`.
|
||||
/// Prefer passing `&ClientContext` explicitly in new code.
|
||||
pub(crate) static AUTH_CONTEXT: RwLock<Option<ClientAuth>> = RwLock::new(None);
|
||||
|
||||
/// When `true`, [`connect_node`] skips TLS certificate verification.
|
||||
/// Set via [`set_insecure_skip_verify`]; read by the RPC layer.
|
||||
/// Prefer `ClientContext::set_insecure_skip_verify` in new code.
|
||||
pub(crate) static INSECURE_SKIP_VERIFY: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
/// Enable or disable insecure (no-verify) TLS mode globally.
|
||||
@@ -47,7 +120,8 @@ pub fn set_insecure_skip_verify(enabled: bool) {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ClientAuth {
|
||||
pub(crate) version: u16,
|
||||
pub(crate) access_token: Vec<u8>,
|
||||
/// Bearer or OPAQUE session token. Zeroized on drop. (M8)
|
||||
pub(crate) access_token: Zeroizing<Vec<u8>>,
|
||||
pub(crate) device_id: Vec<u8>,
|
||||
}
|
||||
|
||||
@@ -58,7 +132,7 @@ impl ClientAuth {
|
||||
let device = device_id.unwrap_or_default().into_bytes();
|
||||
Self {
|
||||
version: 1,
|
||||
access_token: token,
|
||||
access_token: Zeroizing::new(token),
|
||||
device_id: device,
|
||||
}
|
||||
}
|
||||
@@ -68,7 +142,7 @@ impl ClientAuth {
|
||||
let device = device_id.unwrap_or_default().into_bytes();
|
||||
Self {
|
||||
version: 1,
|
||||
access_token: raw_token,
|
||||
access_token: Zeroizing::new(raw_token),
|
||||
device_id: device,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,6 +393,34 @@ enum Command {
|
||||
#[arg(long)]
|
||||
input: PathBuf,
|
||||
},
|
||||
|
||||
/// Execute a YAML playbook (scripted command sequence) and exit.
|
||||
/// Requires `--features playbook`.
|
||||
#[cfg(feature = "playbook")]
|
||||
Run {
|
||||
/// Path to the YAML playbook file.
|
||||
playbook: PathBuf,
|
||||
|
||||
/// State file path (identity + MLS state).
|
||||
#[arg(long, default_value = "qpq-state.bin", env = "QPQ_STATE")]
|
||||
state: PathBuf,
|
||||
|
||||
/// Server address (host:port).
|
||||
#[arg(long, default_value = "127.0.0.1:7000", env = "QPQ_SERVER")]
|
||||
server: String,
|
||||
|
||||
/// OPAQUE username for automatic login.
|
||||
#[arg(long, env = "QPQ_USERNAME")]
|
||||
username: Option<String>,
|
||||
|
||||
/// OPAQUE password.
|
||||
#[arg(long, env = "QPQ_PASSWORD")]
|
||||
password: Option<String>,
|
||||
|
||||
/// Override playbook variables: KEY=VALUE (repeatable).
|
||||
#[arg(long = "var", short = 'V')]
|
||||
vars: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
@@ -410,6 +438,77 @@ fn derive_state_path(state: PathBuf, username: Option<&str>) -> PathBuf {
|
||||
state
|
||||
}
|
||||
|
||||
// ── Playbook execution ───────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(feature = "playbook")]
|
||||
async fn run_playbook(
|
||||
playbook_path: &Path,
|
||||
state: &Path,
|
||||
server: &str,
|
||||
ca_cert: &Path,
|
||||
server_name: &str,
|
||||
state_pw: Option<&str>,
|
||||
username: Option<&str>,
|
||||
password: Option<&str>,
|
||||
access_token: &str,
|
||||
device_id: Option<&str>,
|
||||
extra_vars: &[String],
|
||||
) -> anyhow::Result<()> {
|
||||
use quicproquo_client::PlaybookRunner;
|
||||
|
||||
let insecure = std::env::var("QPQ_DANGER_ACCEPT_INVALID_CERTS").is_ok();
|
||||
|
||||
// Connect to server.
|
||||
let client =
|
||||
quicproquo_client::connect_node_opt(server, ca_cert, server_name, insecure)
|
||||
.await
|
||||
.context("connect to server")?;
|
||||
|
||||
// Build session state.
|
||||
let mut session = quicproquo_client::client::session::SessionState::load(state, state_pw)
|
||||
.context("load session state")?;
|
||||
|
||||
// If username/password provided, do OPAQUE login.
|
||||
if let (Some(uname), Some(pw)) = (username, password) {
|
||||
if let Err(e) =
|
||||
quicproquo_client::opaque_login(&client, uname, pw, &session.identity.public_key_bytes()).await
|
||||
{
|
||||
eprintln!("OPAQUE login failed: {e:#}");
|
||||
}
|
||||
} else if !access_token.is_empty() {
|
||||
let auth = ClientAuth::from_parts(access_token.to_string(), device_id.map(String::from));
|
||||
init_auth(auth);
|
||||
}
|
||||
|
||||
// Load playbook.
|
||||
let mut runner = PlaybookRunner::from_file(playbook_path)
|
||||
.with_context(|| format!("load playbook: {}", playbook_path.display()))?;
|
||||
|
||||
// Inject extra variables from --var KEY=VALUE flags.
|
||||
for kv in extra_vars {
|
||||
if let Some((k, v)) = kv.split_once('=') {
|
||||
runner.set_var(k, v);
|
||||
} else {
|
||||
eprintln!("warning: ignoring malformed --var '{kv}' (expected KEY=VALUE)");
|
||||
}
|
||||
}
|
||||
|
||||
// Inject connection info as variables.
|
||||
runner.set_var("_server", server);
|
||||
if let Some(u) = username {
|
||||
runner.set_var("_username", u);
|
||||
}
|
||||
|
||||
let report = runner.run(&mut session, &client).await;
|
||||
print!("{report}");
|
||||
|
||||
if report.all_passed() {
|
||||
Ok(())
|
||||
} else {
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::main]
|
||||
@@ -736,5 +835,32 @@ async fn main() -> anyhow::Result<()> {
|
||||
)
|
||||
}
|
||||
Command::ExportVerify { input } => cmd_export_verify(&input),
|
||||
#[cfg(feature = "playbook")]
|
||||
Command::Run {
|
||||
playbook,
|
||||
state,
|
||||
server,
|
||||
username,
|
||||
password,
|
||||
vars,
|
||||
} => {
|
||||
let state = derive_state_path(state, username.as_deref());
|
||||
let local = tokio::task::LocalSet::new();
|
||||
local
|
||||
.run_until(run_playbook(
|
||||
&playbook,
|
||||
&state,
|
||||
&server,
|
||||
&args.ca_cert,
|
||||
&args.server_name,
|
||||
state_pw,
|
||||
username.as_deref(),
|
||||
password.as_deref(),
|
||||
&args.access_token,
|
||||
args.device_id.as_deref(),
|
||||
&vars,
|
||||
))
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,12 +10,20 @@ pub enum CoreError {
|
||||
#[error("Cap'n Proto error: {0}")]
|
||||
Capnp(#[from] capnp::Error),
|
||||
|
||||
/// An MLS operation failed.
|
||||
/// An MLS operation failed (string description).
|
||||
///
|
||||
/// The inner string is the debug representation of the openmls error.
|
||||
/// Preserved for backward compatibility. Prefer [`CoreError::MlsError`]
|
||||
/// for new code that wraps typed openmls errors.
|
||||
#[error("MLS error: {0}")]
|
||||
Mls(String),
|
||||
|
||||
/// An MLS operation failed (typed, boxed error).
|
||||
///
|
||||
/// Wraps the underlying openmls error so callers can downcast to specific
|
||||
/// error types when needed.
|
||||
#[error("MLS error: {0}")]
|
||||
MlsError(Box<dyn std::error::Error + Send + Sync>),
|
||||
|
||||
/// A hybrid KEM (X25519 + ML-KEM-768) operation failed.
|
||||
#[error("hybrid KEM error: {0}")]
|
||||
HybridKem(#[from] crate::hybrid_kem::HybridKemError),
|
||||
|
||||
@@ -34,6 +34,8 @@
|
||||
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use openmls::prelude::{
|
||||
Ciphersuite, Credential, CredentialType, CredentialWithKey, CryptoConfig, GroupId, KeyPackage,
|
||||
KeyPackageIn, MlsGroup, MlsGroupConfig, MlsMessageInBody, MlsMessageOut,
|
||||
@@ -468,7 +470,36 @@ impl GroupMember {
|
||||
///
|
||||
/// Returns [`CoreError::Mls`] if the message is malformed, fails
|
||||
/// authentication, or the group state is inconsistent.
|
||||
pub fn receive_message(&mut self, mut bytes: &[u8]) -> Result<ReceivedMessage, CoreError> {
|
||||
pub fn receive_message(&mut self, bytes: &[u8]) -> Result<ReceivedMessage, CoreError> {
|
||||
let (sender, content) = self.process_incoming(bytes)?;
|
||||
let _ = sender; // not needed for this variant
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
/// Process an incoming TLS-encoded MLS message and return sender identity + plaintext for application messages.
|
||||
///
|
||||
/// Same as [`receive_message`], but for Application messages returns
|
||||
/// `(sender_identity_bytes, plaintext)` so the client can display who sent the message.
|
||||
pub fn receive_message_with_sender(
|
||||
&mut self,
|
||||
bytes: &[u8],
|
||||
) -> Result<ReceivedMessageWithSender, CoreError> {
|
||||
let (sender_identity, content) = self.process_incoming(bytes)?;
|
||||
Ok(match content {
|
||||
ReceivedMessage::Application(plaintext) => {
|
||||
ReceivedMessageWithSender::Application(sender_identity, plaintext)
|
||||
}
|
||||
ReceivedMessage::StateChanged => ReceivedMessageWithSender::StateChanged,
|
||||
ReceivedMessage::SelfRemoved => ReceivedMessageWithSender::SelfRemoved,
|
||||
})
|
||||
}
|
||||
|
||||
/// Shared MLS message processing: deserialize, authenticate, and apply
|
||||
/// the incoming message. Returns `(sender_identity_bytes, result)`.
|
||||
fn process_incoming(
|
||||
&mut self,
|
||||
mut bytes: &[u8],
|
||||
) -> Result<(Vec<u8>, ReceivedMessage), CoreError> {
|
||||
let group = self
|
||||
.group
|
||||
.as_mut()
|
||||
@@ -488,9 +519,11 @@ impl GroupMember {
|
||||
.process_message(&self.backend, protocol_message)
|
||||
.map_err(|e| CoreError::Mls(format!("process_message: {e:?}")))?;
|
||||
|
||||
let sender_identity = processed.credential().identity().to_vec();
|
||||
|
||||
match processed.into_content() {
|
||||
ProcessedMessageContent::ApplicationMessage(app) => {
|
||||
Ok(ReceivedMessage::Application(app.into_bytes()))
|
||||
Ok((sender_identity, ReceivedMessage::Application(app.into_bytes())))
|
||||
}
|
||||
ProcessedMessageContent::StagedCommitMessage(staged) => {
|
||||
// Check if this commit removes us.
|
||||
@@ -505,79 +538,19 @@ impl GroupMember {
|
||||
|
||||
if self_removed {
|
||||
self.group = None;
|
||||
Ok(ReceivedMessage::SelfRemoved)
|
||||
Ok((sender_identity, ReceivedMessage::SelfRemoved))
|
||||
} else {
|
||||
Ok(ReceivedMessage::StateChanged)
|
||||
Ok((sender_identity, ReceivedMessage::StateChanged))
|
||||
}
|
||||
}
|
||||
// Proposals are stored for a later Commit; nothing to return yet.
|
||||
ProcessedMessageContent::ProposalMessage(proposal) => {
|
||||
group.store_pending_proposal(*proposal);
|
||||
Ok(ReceivedMessage::StateChanged)
|
||||
Ok((sender_identity, ReceivedMessage::StateChanged))
|
||||
}
|
||||
ProcessedMessageContent::ExternalJoinProposalMessage(proposal) => {
|
||||
group.store_pending_proposal(*proposal);
|
||||
Ok(ReceivedMessage::StateChanged)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process an incoming TLS-encoded MLS message and return sender identity + plaintext for application messages.
|
||||
///
|
||||
/// Same as [`receive_message`], but for Application messages returns
|
||||
/// `(sender_identity_bytes, plaintext)` so the client can display who sent the message.
|
||||
pub fn receive_message_with_sender(
|
||||
&mut self,
|
||||
mut bytes: &[u8],
|
||||
) -> Result<ReceivedMessageWithSender, CoreError> {
|
||||
let group = self
|
||||
.group
|
||||
.as_mut()
|
||||
.ok_or_else(|| CoreError::Mls("no active group".into()))?;
|
||||
|
||||
let msg_in = openmls::prelude::MlsMessageIn::tls_deserialize(&mut bytes)
|
||||
.map_err(|e| CoreError::Mls(format!("message deserialise: {e:?}")))?;
|
||||
|
||||
let protocol_message = match msg_in.extract() {
|
||||
MlsMessageInBody::PrivateMessage(m) => ProtocolMessage::PrivateMessage(m),
|
||||
MlsMessageInBody::PublicMessage(m) => ProtocolMessage::PublicMessage(m),
|
||||
_ => return Err(CoreError::Mls("not a protocol message".into())),
|
||||
};
|
||||
|
||||
let processed = group
|
||||
.process_message(&self.backend, protocol_message)
|
||||
.map_err(|e| CoreError::Mls(format!("process_message: {e:?}")))?;
|
||||
|
||||
let sender_identity = processed.credential().identity().to_vec();
|
||||
|
||||
match processed.into_content() {
|
||||
ProcessedMessageContent::ApplicationMessage(app) => {
|
||||
Ok(ReceivedMessageWithSender::Application(sender_identity, app.into_bytes()))
|
||||
}
|
||||
ProcessedMessageContent::StagedCommitMessage(staged) => {
|
||||
let own_index = group.own_leaf_index();
|
||||
let self_removed = staged.remove_proposals().any(|queued| {
|
||||
queued.remove_proposal().removed() == own_index
|
||||
});
|
||||
|
||||
group
|
||||
.merge_staged_commit(&self.backend, *staged)
|
||||
.map_err(|e| CoreError::Mls(format!("merge_staged_commit: {e:?}")))?;
|
||||
|
||||
if self_removed {
|
||||
self.group = None;
|
||||
Ok(ReceivedMessageWithSender::SelfRemoved)
|
||||
} else {
|
||||
Ok(ReceivedMessageWithSender::StateChanged)
|
||||
}
|
||||
}
|
||||
ProcessedMessageContent::ProposalMessage(proposal) => {
|
||||
group.store_pending_proposal(*proposal);
|
||||
Ok(ReceivedMessageWithSender::StateChanged)
|
||||
}
|
||||
ProcessedMessageContent::ExternalJoinProposalMessage(proposal) => {
|
||||
group.store_pending_proposal(*proposal);
|
||||
Ok(ReceivedMessageWithSender::StateChanged)
|
||||
Ok((sender_identity, ReceivedMessage::StateChanged))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -597,7 +570,10 @@ impl GroupMember {
|
||||
}
|
||||
|
||||
/// Return the private seed of the identity (for persistence).
|
||||
pub fn identity_seed(&self) -> [u8; 32] {
|
||||
///
|
||||
/// The returned value is wrapped in `Zeroizing` so it is securely erased
|
||||
/// when dropped.
|
||||
pub fn identity_seed(&self) -> Zeroizing<[u8; 32]> {
|
||||
self.identity.seed_bytes()
|
||||
}
|
||||
|
||||
|
||||
@@ -191,32 +191,22 @@ impl OpenMlsCrypto for HybridCrypto {
|
||||
ptxt: &[u8],
|
||||
) -> HpkeCiphertext {
|
||||
if Self::is_hybrid_public_key(pk_r) {
|
||||
let recipient_pk = match HybridPublicKey::from_bytes(pk_r) {
|
||||
Ok(pk) => pk,
|
||||
// Key parsed as hybrid length but failed to deserialize — this is
|
||||
// a real error, not a reason to silently fall back to classical HPKE.
|
||||
Err(_) => return HpkeCiphertext {
|
||||
kem_output: Vec::new().into(),
|
||||
ciphertext: Vec::new().into(),
|
||||
},
|
||||
};
|
||||
// The trait `OpenMlsCrypto::hpke_seal` returns `HpkeCiphertext` (not
|
||||
// `Result`), so we cannot propagate errors through the return type.
|
||||
// Returning an empty ciphertext would silently cause data loss.
|
||||
// Instead, panic on failure — a hybrid key that passes the length
|
||||
// check but fails deserialization or encryption indicates a critical
|
||||
// bug (corrupted key material), not a recoverable condition.
|
||||
let recipient_pk = HybridPublicKey::from_bytes(pk_r)
|
||||
.expect("hybrid public key deserialization failed — key material is corrupted");
|
||||
// Pass HPKE info and aad through for proper context binding (RFC 9180).
|
||||
match hybrid_encrypt(&recipient_pk, ptxt, info, aad) {
|
||||
Ok(envelope) => {
|
||||
let kem_output = envelope[..HYBRID_KEM_OUTPUT_LEN].to_vec();
|
||||
let ciphertext = envelope[HYBRID_KEM_OUTPUT_LEN..].to_vec();
|
||||
HpkeCiphertext {
|
||||
kem_output: kem_output.into(),
|
||||
ciphertext: ciphertext.into(),
|
||||
}
|
||||
}
|
||||
// Encryption failed with a hybrid key — return empty ciphertext
|
||||
// rather than silently falling back to classical HPKE with an
|
||||
// incompatible key.
|
||||
Err(_) => HpkeCiphertext {
|
||||
kem_output: Vec::new().into(),
|
||||
ciphertext: Vec::new().into(),
|
||||
},
|
||||
let envelope = hybrid_encrypt(&recipient_pk, ptxt, info, aad)
|
||||
.expect("hybrid HPKE encryption failed — critical crypto error");
|
||||
let kem_output = envelope[..HYBRID_KEM_OUTPUT_LEN].to_vec();
|
||||
let ciphertext = envelope[HYBRID_KEM_OUTPUT_LEN..].to_vec();
|
||||
HpkeCiphertext {
|
||||
kem_output: kem_output.into(),
|
||||
ciphertext: ciphertext.into(),
|
||||
}
|
||||
} else {
|
||||
self.rust_crypto.hpke_seal(config, pk_r, info, aad, ptxt)
|
||||
@@ -257,14 +247,11 @@ impl OpenMlsCrypto for HybridCrypto {
|
||||
exporter_length: usize,
|
||||
) -> Result<(Vec<u8>, ExporterSecret), CryptoError> {
|
||||
if Self::is_hybrid_public_key(pk_r) {
|
||||
let recipient_pk = match HybridPublicKey::from_bytes(pk_r) {
|
||||
Ok(pk) => pk,
|
||||
Err(_) => {
|
||||
return self.rust_crypto.hpke_setup_sender_and_export(
|
||||
config, pk_r, info, exporter_context, exporter_length,
|
||||
)
|
||||
}
|
||||
};
|
||||
// A key that passes the hybrid length check but fails deserialization
|
||||
// is corrupted — return an error instead of silently downgrading to
|
||||
// classical crypto (which would defeat PQ protection).
|
||||
let recipient_pk = HybridPublicKey::from_bytes(pk_r)
|
||||
.map_err(|_| CryptoError::SenderSetupError)?;
|
||||
let (kem_output, shared_secret) =
|
||||
hybrid_encapsulate_only(&recipient_pk).map_err(|_| CryptoError::SenderSetupError)?;
|
||||
let exported = hybrid_export(&shared_secret, exporter_context, exporter_length);
|
||||
@@ -302,8 +289,9 @@ impl OpenMlsCrypto for HybridCrypto {
|
||||
fn derive_hpke_keypair(&self, config: HpkeConfig, ikm: &[u8]) -> HpkeKeyPair {
|
||||
if self.hybrid_enabled && config.0 == HpkeKemType::DhKem25519 {
|
||||
let kp = HybridKeypair::derive_from_ikm(ikm);
|
||||
let private_bytes = kp.private_to_bytes();
|
||||
HpkeKeyPair {
|
||||
private: kp.private_to_bytes().into(),
|
||||
private: private_bytes.as_slice().into(),
|
||||
public: kp.public_key().to_bytes(),
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -159,11 +159,14 @@ impl HybridKeypair {
|
||||
}
|
||||
|
||||
/// Serialise private key for MLS key store: x25519_sk(32) || mlkem_dk(2400).
|
||||
pub fn private_to_bytes(&self) -> Vec<u8> {
|
||||
///
|
||||
/// The returned value is wrapped in [`Zeroizing`] so secret key material
|
||||
/// is securely erased when dropped.
|
||||
pub fn private_to_bytes(&self) -> Zeroizing<Vec<u8>> {
|
||||
let mut out = Vec::with_capacity(HYBRID_PRIVATE_KEY_LEN);
|
||||
out.extend_from_slice(self.x25519_sk.as_bytes());
|
||||
out.extend_from_slice(self.mlkem_dk.as_bytes().as_slice());
|
||||
out
|
||||
Zeroizing::new(out)
|
||||
}
|
||||
|
||||
/// Reconstruct a hybrid keypair from private key bytes (from MLS key store).
|
||||
|
||||
@@ -47,8 +47,11 @@ impl IdentityKeypair {
|
||||
}
|
||||
|
||||
/// Return the raw 32-byte private seed (for persistence).
|
||||
pub fn seed_bytes(&self) -> [u8; 32] {
|
||||
*self.seed
|
||||
///
|
||||
/// The returned value is wrapped in [`Zeroizing`] so it is securely
|
||||
/// erased when dropped, preventing the seed from lingering in memory.
|
||||
pub fn seed_bytes(&self) -> Zeroizing<[u8; 32]> {
|
||||
Zeroizing::new(*self.seed)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,18 @@ use openmls_traits::key_store::{MlsEntity, OpenMlsKeyStore};
|
||||
///
|
||||
/// In-memory when `path` is `None`; otherwise flushes the entire map to disk on
|
||||
/// every store/delete so HPKE init keys survive process restarts.
|
||||
///
|
||||
/// # Serialization
|
||||
///
|
||||
/// Uses bincode for both individual MLS entity values and the outer HashMap
|
||||
/// container. This is required because OpenMLS types use bincode-compatible
|
||||
/// serialization, and `HashMap<Vec<u8>, Vec<u8>>` requires a binary format
|
||||
/// (JSON mandates string keys).
|
||||
///
|
||||
/// # Persistence security
|
||||
///
|
||||
/// When `path` is set, file permissions are restricted to owner-only (0o600)
|
||||
/// on Unix platforms, since the store may contain HPKE private keys.
|
||||
#[derive(Debug)]
|
||||
pub struct DiskKeyStore {
|
||||
path: Option<PathBuf>,
|
||||
@@ -42,16 +54,22 @@ impl DiskKeyStore {
|
||||
if bytes.is_empty() {
|
||||
HashMap::new()
|
||||
} else {
|
||||
bincode::deserialize(&bytes).map_err(|_| DiskKeyStoreError::Serialization)?
|
||||
bincode::deserialize(&bytes)
|
||||
.map_err(|_| DiskKeyStoreError::Serialization)?
|
||||
}
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
let store = Self {
|
||||
path: Some(path),
|
||||
values: RwLock::new(values),
|
||||
})
|
||||
};
|
||||
|
||||
// Set restrictive file permissions on the keystore file.
|
||||
store.set_file_permissions()?;
|
||||
|
||||
Ok(store)
|
||||
}
|
||||
|
||||
fn flush(&self) -> Result<(), DiskKeyStoreError> {
|
||||
@@ -63,7 +81,28 @@ impl DiskKeyStore {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?;
|
||||
}
|
||||
fs::write(path, bytes).map_err(|e| DiskKeyStoreError::Io(e.to_string()))
|
||||
fs::write(path, &bytes).map_err(|e| DiskKeyStoreError::Io(e.to_string()))?;
|
||||
self.set_file_permissions()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Restrict file permissions to owner-only (0o600) on Unix.
|
||||
#[cfg(unix)]
|
||||
fn set_file_permissions(&self) -> Result<(), DiskKeyStoreError> {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
if let Some(path) = &self.path {
|
||||
if path.exists() {
|
||||
let perms = std::fs::Permissions::from_mode(0o600);
|
||||
fs::set_permissions(path, perms)
|
||||
.map_err(|e| DiskKeyStoreError::Io(format!("set permissions: {e}")))?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn set_file_permissions(&self) -> Result<(), DiskKeyStoreError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +116,7 @@ impl OpenMlsKeyStore for DiskKeyStore {
|
||||
type Error = DiskKeyStoreError;
|
||||
|
||||
fn store<V: MlsEntity>(&self, k: &[u8], v: &V) -> Result<(), Self::Error> {
|
||||
let value = serde_json::to_vec(v).map_err(|_| DiskKeyStoreError::Serialization)?;
|
||||
let value = bincode::serialize(v).map_err(|_| DiskKeyStoreError::Serialization)?;
|
||||
let mut values = self.values.write().map_err(|_| DiskKeyStoreError::Io("lock poisoned".into()))?;
|
||||
values.insert(k.to_vec(), value);
|
||||
drop(values);
|
||||
@@ -91,7 +130,7 @@ impl OpenMlsKeyStore for DiskKeyStore {
|
||||
};
|
||||
values
|
||||
.get(k)
|
||||
.and_then(|bytes| serde_json::from_slice(bytes).ok())
|
||||
.and_then(|bytes| bincode::deserialize(bytes).ok())
|
||||
}
|
||||
|
||||
fn delete<V: MlsEntity>(&self, k: &[u8]) -> Result<(), Self::Error> {
|
||||
@@ -101,4 +140,3 @@ impl OpenMlsKeyStore for DiskKeyStore {
|
||||
self.flush()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -72,9 +72,12 @@ pub use hybrid_kem::{
|
||||
pub use identity::{verify_delivery_proof, IdentityKeypair};
|
||||
pub use safety_numbers::compute_safety_number;
|
||||
pub use transcript::{
|
||||
read_transcript, verify_transcript_chain, ChainVerdict, DecodedRecord, TranscriptRecord,
|
||||
read_transcript, validate_transcript_structure, ChainVerdict, DecodedRecord, TranscriptRecord,
|
||||
TranscriptWriter,
|
||||
};
|
||||
// Deprecated re-export for backward compatibility.
|
||||
#[allow(deprecated)]
|
||||
pub use transcript::verify_transcript_chain;
|
||||
|
||||
// ── Public API (native only) ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -208,11 +208,17 @@ pub fn read_transcript(
|
||||
Ok((records, verdict))
|
||||
}
|
||||
|
||||
/// Verify the hash chain without decrypting record contents.
|
||||
/// Validate the structural integrity of a transcript file without decrypting.
|
||||
///
|
||||
/// Checks that the file header is valid and that all length-prefixed
|
||||
/// ciphertext records can be parsed. Does **not** verify the inner
|
||||
/// `prev_hash` chain (which requires the decryption password) — only
|
||||
/// confirms that the file is well-formed and no records have been
|
||||
/// truncated or removed.
|
||||
///
|
||||
/// Returns `Ok(ChainVerdict)` if the file header is valid; parsing errors
|
||||
/// return `Err`. The chain verdict indicates whether all hashes matched.
|
||||
pub fn verify_transcript_chain(data: &[u8]) -> Result<ChainVerdict, CoreError> {
|
||||
/// return `Err`.
|
||||
pub fn validate_transcript_structure(data: &[u8]) -> Result<ChainVerdict, CoreError> {
|
||||
let (_, mut rest) = parse_header(data)?;
|
||||
|
||||
let mut expected_prev: [u8; 32] = [0u8; 32];
|
||||
@@ -250,6 +256,12 @@ pub fn verify_transcript_chain(data: &[u8]) -> Result<ChainVerdict, CoreError> {
|
||||
Ok(ChainVerdict::Ok { records: count })
|
||||
}
|
||||
|
||||
/// Deprecated alias for [`validate_transcript_structure`].
|
||||
#[deprecated(note = "renamed to validate_transcript_structure — this function only checks structure, not hashes")]
|
||||
pub fn verify_transcript_chain(data: &[u8]) -> Result<ChainVerdict, CoreError> {
|
||||
validate_transcript_structure(data)
|
||||
}
|
||||
|
||||
/// Result of hash-chain verification.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ChainVerdict {
|
||||
@@ -515,7 +527,7 @@ mod tests {
|
||||
.expect("write");
|
||||
}
|
||||
|
||||
let verdict = verify_transcript_chain(&buf).expect("verify");
|
||||
let verdict = validate_transcript_structure(&buf).expect("verify");
|
||||
assert_eq!(verdict, ChainVerdict::Ok { records: 5 });
|
||||
}
|
||||
|
||||
@@ -537,7 +549,7 @@ mod tests {
|
||||
|
||||
// Truncate the last few bytes — should fail parsing.
|
||||
let truncated = &buf[..buf.len() - 5];
|
||||
let result = verify_transcript_chain(truncated);
|
||||
let result = validate_transcript_structure(truncated);
|
||||
assert!(result.is_err(), "truncated file must be detected");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ crate-type = ["cdylib", "staticlib"]
|
||||
quicproquo-client = { path = "../quicproquo-client" }
|
||||
tokio = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
capnp = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
|
||||
|
||||
@@ -40,6 +40,42 @@ impl QpqHandle {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error classification
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Classify an `anyhow::Error` from `cmd_login` into an FFI status code.
|
||||
///
|
||||
/// Checks the error chain for typed downcasting before falling back to
|
||||
/// message-based heuristics.
|
||||
fn classify_login_error(err: &anyhow::Error) -> i32 {
|
||||
// Check error chain for OPAQUE-specific typed errors.
|
||||
for cause in err.chain() {
|
||||
// capnp::Error indicates transport/RPC failure.
|
||||
if cause.downcast_ref::<capnp::Error>().is_some() {
|
||||
return QPQ_ERROR;
|
||||
}
|
||||
}
|
||||
// Fall back to message inspection for OPAQUE authentication failures,
|
||||
// since opaque-ke errors are converted to anyhow strings upstream.
|
||||
let msg = format!("{err:#}");
|
||||
if msg.contains("OPAQUE") || msg.contains("bad password") || msg.contains("credential") {
|
||||
QPQ_AUTH_FAILED
|
||||
} else {
|
||||
QPQ_ERROR
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify an `anyhow::Error` from receive operations into an FFI status code.
|
||||
fn classify_receive_error(err: &anyhow::Error) -> i32 {
|
||||
let msg = format!("{err:#}");
|
||||
if msg.contains("timeout") || msg.contains("Timeout") || msg.contains("timed out") {
|
||||
QPQ_TIMEOUT
|
||||
} else {
|
||||
QPQ_ERROR
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -180,13 +216,9 @@ pub unsafe extern "C" fn qpq_login(
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = format!("{e:#}");
|
||||
if msg.contains("auth") || msg.contains("OPAQUE") || msg.contains("credential") {
|
||||
h.set_error(&msg);
|
||||
QPQ_AUTH_FAILED
|
||||
} else {
|
||||
h.set_error(&msg);
|
||||
QPQ_ERROR
|
||||
}
|
||||
let code = classify_login_error(&e);
|
||||
h.set_error(&msg);
|
||||
code
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -345,13 +377,9 @@ pub unsafe extern "C" fn qpq_receive(
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = format!("{e:#}");
|
||||
if msg.contains("timeout") || msg.contains("Timeout") {
|
||||
h.set_error(&msg);
|
||||
QPQ_TIMEOUT
|
||||
} else {
|
||||
h.set_error(&msg);
|
||||
QPQ_ERROR
|
||||
}
|
||||
let code = classify_receive_error(&e);
|
||||
h.set_error(&msg);
|
||||
code
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,12 @@ edition = "2021"
|
||||
description = "C FFI layer for quicproquo, proving QUIC connection migration."
|
||||
license = "MIT"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
# Enable SkipServerVerification for development/testing only.
|
||||
# NEVER enable in production — this disables TLS certificate validation.
|
||||
insecure-dev = []
|
||||
|
||||
[lib]
|
||||
crate-type = ["staticlib", "cdylib", "rlib"]
|
||||
|
||||
@@ -16,6 +22,9 @@ tokio = { workspace = true }
|
||||
quinn = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
|
||||
# TLS root certificates (used when insecure-dev is NOT enabled)
|
||||
webpki-roots = "0.26"
|
||||
|
||||
# Error handling
|
||||
anyhow = { workspace = true }
|
||||
|
||||
|
||||
@@ -89,11 +89,7 @@ async fn connect_inner(
|
||||
) -> anyhow::Result<(Endpoint, quinn::Connection)> {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
// Build a permissive client config (skip server cert verification for dev/testing).
|
||||
let crypto = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(SkipServerVerification))
|
||||
.with_no_client_auth();
|
||||
let crypto = build_client_tls_config()?;
|
||||
|
||||
let mut client_config = quinn::ClientConfig::new(Arc::new(
|
||||
quinn::crypto::rustls::QuicClientConfig::try_from(crypto)
|
||||
@@ -159,11 +155,36 @@ pub unsafe extern "C" fn qnpc_disconnect(handle: *mut MobileHandle) {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Internal: skip server cert verification for testing ─────────────────────
|
||||
// ── TLS configuration ───────────────────────────────────────────────────────
|
||||
|
||||
/// Build the rustls `ClientConfig` for the QUIC transport.
|
||||
///
|
||||
/// Without the `insecure-dev` feature, this uses the platform's native root
|
||||
/// certificates for server verification. With `insecure-dev` enabled, all
|
||||
/// certificate verification is skipped (MITM-vulnerable — dev/testing only).
|
||||
fn build_client_tls_config() -> anyhow::Result<rustls::ClientConfig> {
|
||||
#[cfg(feature = "insecure-dev")]
|
||||
{
|
||||
Ok(rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(SkipServerVerification))
|
||||
.with_no_client_auth())
|
||||
}
|
||||
#[cfg(not(feature = "insecure-dev"))]
|
||||
{
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
Ok(rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "insecure-dev")]
|
||||
#[derive(Debug)]
|
||||
struct SkipServerVerification;
|
||||
|
||||
#[cfg(feature = "insecure-dev")]
|
||||
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
@@ -216,6 +237,87 @@ mod tests {
|
||||
use super::*;
|
||||
use std::net::UdpSocket;
|
||||
|
||||
/// Test-only insecure verifier (always available in test builds).
|
||||
#[derive(Debug)]
|
||||
struct TestSkipServerVerification;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for TestSkipServerVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a test server using the insecure cert verifier.
|
||||
async fn test_connect_inner(
|
||||
addr: SocketAddr,
|
||||
server_name: &str,
|
||||
) -> anyhow::Result<(Endpoint, quinn::Connection)> {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
let crypto = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(TestSkipServerVerification))
|
||||
.with_no_client_auth();
|
||||
|
||||
let mut client_config = quinn::ClientConfig::new(Arc::new(
|
||||
quinn::crypto::rustls::QuicClientConfig::try_from(crypto)
|
||||
.map_err(|e| anyhow::anyhow!("QUIC client config: {e}"))?,
|
||||
));
|
||||
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
transport.max_idle_timeout(Some(
|
||||
std::time::Duration::from_secs(120)
|
||||
.try_into()
|
||||
.expect("120s valid"),
|
||||
));
|
||||
client_config.transport_config(Arc::new(transport));
|
||||
|
||||
let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())?;
|
||||
endpoint.set_default_client_config(client_config);
|
||||
|
||||
let connection = endpoint.connect(addr, server_name)?.await?;
|
||||
Ok((endpoint, connection))
|
||||
}
|
||||
|
||||
/// Prove QUIC connection migration: connect, send messages, rebind the
|
||||
/// UDP socket (simulating wifi→cellular), send more messages, verify
|
||||
/// all messages arrive.
|
||||
@@ -228,8 +330,8 @@ mod tests {
|
||||
// Start an in-process echo server.
|
||||
let server_addr = start_echo_server().await;
|
||||
|
||||
// Connect client.
|
||||
let (endpoint, connection) = connect_inner(server_addr, "localhost")
|
||||
// Connect client using test-only insecure verifier.
|
||||
let (endpoint, connection) = test_connect_inner(server_addr, "localhost")
|
||||
.await
|
||||
.expect("connect");
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ hex = { workspace = true }
|
||||
# Broadcast channels (ChaCha20-Poly1305 symmetric encryption)
|
||||
chacha20poly1305 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
zeroize = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
|
||||
@@ -12,6 +12,7 @@ use std::collections::HashMap;
|
||||
use chacha20poly1305::aead::{Aead, AeadCore, KeyInit};
|
||||
use chacha20poly1305::ChaCha20Poly1305;
|
||||
use rand::rngs::OsRng;
|
||||
use zeroize::{Zeroize, ZeroizeOnDrop};
|
||||
|
||||
/// A single broadcast channel identified by topic, secured with a symmetric key.
|
||||
pub struct BroadcastChannel {
|
||||
@@ -19,6 +20,14 @@ pub struct BroadcastChannel {
|
||||
key: [u8; 32],
|
||||
}
|
||||
|
||||
impl Drop for BroadcastChannel {
|
||||
fn drop(&mut self) {
|
||||
self.key.zeroize();
|
||||
}
|
||||
}
|
||||
|
||||
impl ZeroizeOnDrop for BroadcastChannel {}
|
||||
|
||||
impl BroadcastChannel {
|
||||
/// Create a new channel with a random ChaCha20-Poly1305 key.
|
||||
pub fn new(topic: &str) -> Self {
|
||||
@@ -39,16 +48,16 @@ impl BroadcastChannel {
|
||||
}
|
||||
|
||||
/// Encrypt `plaintext`, returning `nonce || ciphertext`.
|
||||
pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
|
||||
pub fn encrypt(&self, plaintext: &[u8]) -> anyhow::Result<Vec<u8>> {
|
||||
let cipher = ChaCha20Poly1305::new((&self.key).into());
|
||||
let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
|
||||
let ciphertext = cipher
|
||||
.encrypt(&nonce, plaintext)
|
||||
.expect("ChaCha20Poly1305 encryption should not fail for valid inputs");
|
||||
.map_err(|_| anyhow::anyhow!("ChaCha20Poly1305 encryption failed"))?;
|
||||
let mut out = Vec::with_capacity(nonce.len() + ciphertext.len());
|
||||
out.extend_from_slice(&nonce);
|
||||
out.extend_from_slice(&ciphertext);
|
||||
out
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Decrypt data produced by [`encrypt`](Self::encrypt).
|
||||
@@ -121,7 +130,7 @@ impl BroadcastManager {
|
||||
}
|
||||
|
||||
/// Encrypt a message on the given topic. Returns `None` if not subscribed.
|
||||
pub fn encrypt(&self, topic: &str, plaintext: &[u8]) -> Option<Vec<u8>> {
|
||||
pub fn encrypt(&self, topic: &str, plaintext: &[u8]) -> Option<anyhow::Result<Vec<u8>>> {
|
||||
self.channels.get(topic).map(|ch| ch.encrypt(plaintext))
|
||||
}
|
||||
|
||||
@@ -147,7 +156,7 @@ mod tests {
|
||||
fn encrypt_decrypt_roundtrip() {
|
||||
let ch = BroadcastChannel::new("test-topic");
|
||||
let plaintext = b"hello broadcast";
|
||||
let encrypted = ch.encrypt(plaintext);
|
||||
let encrypted = ch.encrypt(plaintext).expect("encrypt");
|
||||
let decrypted = ch.decrypt(&encrypted).expect("decrypt");
|
||||
assert_eq!(decrypted, plaintext);
|
||||
}
|
||||
@@ -156,7 +165,7 @@ mod tests {
|
||||
fn wrong_key_fails_decrypt() {
|
||||
let ch1 = BroadcastChannel::new("topic");
|
||||
let ch2 = BroadcastChannel::new("topic"); // different random key
|
||||
let encrypted = ch1.encrypt(b"secret");
|
||||
let encrypted = ch1.encrypt(b"secret").expect("encrypt");
|
||||
let result = ch2.decrypt(&encrypted);
|
||||
assert!(result.is_err(), "wrong key should fail decryption");
|
||||
}
|
||||
@@ -165,7 +174,7 @@ mod tests {
|
||||
fn with_key_roundtrip() {
|
||||
let key = [42u8; 32];
|
||||
let ch = BroadcastChannel::with_key("shared", key);
|
||||
let ct = ch.encrypt(b"data");
|
||||
let ct = ch.encrypt(b"data").expect("encrypt");
|
||||
let ch2 = BroadcastChannel::with_key("shared", key);
|
||||
let pt = ch2.decrypt(&ct).expect("same key should decrypt");
|
||||
assert_eq!(pt, b"data");
|
||||
@@ -194,7 +203,7 @@ mod tests {
|
||||
assert_eq!(ch.topic(), "news");
|
||||
|
||||
// Encrypt via manager, decrypt manually with the same key.
|
||||
let ct = mgr.encrypt("news", b"headline").expect("encrypt");
|
||||
let ct = mgr.encrypt("news", b"headline").expect("subscribed").expect("encrypt");
|
||||
let ch2 = BroadcastChannel::with_key("news", key);
|
||||
let pt = ch2.decrypt(&ct).expect("decrypt");
|
||||
assert_eq!(pt, b"headline");
|
||||
@@ -205,7 +214,7 @@ mod tests {
|
||||
let mut mgr = BroadcastManager::new();
|
||||
mgr.subscribe("ch1", [7u8; 32]);
|
||||
|
||||
let ct = mgr.encrypt("ch1", b"round-trip").expect("encrypt");
|
||||
let ct = mgr.encrypt("ch1", b"round-trip").expect("subscribed").expect("encrypt");
|
||||
let pt = mgr.decrypt("ch1", &ct).expect("decrypt");
|
||||
assert_eq!(pt, b"round-trip");
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ impl MeshEnvelope {
|
||||
};
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("system clock before UNIX epoch")
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let id = Self::compute_id(
|
||||
@@ -67,7 +67,7 @@ impl MeshEnvelope {
|
||||
timestamp,
|
||||
);
|
||||
|
||||
let signable = Self::signable_bytes(&id, &sender_key, &recipient_key, &payload, ttl_secs, hop_count, max_hops, timestamp);
|
||||
let signable = Self::signable_bytes(&id, &sender_key, &recipient_key, &payload, ttl_secs, max_hops, timestamp);
|
||||
let signature = identity.sign(&signable).to_vec();
|
||||
|
||||
Self {
|
||||
@@ -103,23 +103,25 @@ impl MeshEnvelope {
|
||||
}
|
||||
|
||||
/// Assemble the byte string that is signed / verified.
|
||||
///
|
||||
/// `hop_count` is intentionally excluded: forwarding nodes increment it
|
||||
/// without re-signing, so including it would invalidate the sender's
|
||||
/// original signature on every hop.
|
||||
fn signable_bytes(
|
||||
id: &[u8; 32],
|
||||
sender_key: &[u8],
|
||||
recipient_key: &[u8],
|
||||
payload: &[u8],
|
||||
ttl_secs: u32,
|
||||
hop_count: u8,
|
||||
max_hops: u8,
|
||||
timestamp: u64,
|
||||
) -> Vec<u8> {
|
||||
let mut buf = Vec::with_capacity(32 + sender_key.len() + recipient_key.len() + payload.len() + 14);
|
||||
let mut buf = Vec::with_capacity(32 + sender_key.len() + recipient_key.len() + payload.len() + 13);
|
||||
buf.extend_from_slice(id);
|
||||
buf.extend_from_slice(sender_key);
|
||||
buf.extend_from_slice(recipient_key);
|
||||
buf.extend_from_slice(payload);
|
||||
buf.extend_from_slice(&ttl_secs.to_le_bytes());
|
||||
buf.push(hop_count);
|
||||
buf.push(max_hops);
|
||||
buf.extend_from_slice(×tamp.to_le_bytes());
|
||||
buf
|
||||
@@ -144,7 +146,6 @@ impl MeshEnvelope {
|
||||
&self.recipient_key,
|
||||
&self.payload,
|
||||
self.ttl_secs,
|
||||
self.hop_count,
|
||||
self.max_hops,
|
||||
self.timestamp,
|
||||
);
|
||||
@@ -155,7 +156,7 @@ impl MeshEnvelope {
|
||||
pub fn is_expired(&self) -> bool {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("system clock before UNIX epoch")
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
now.saturating_sub(self.timestamp) > self.ttl_secs as u64
|
||||
}
|
||||
@@ -243,6 +244,30 @@ mod tests {
|
||||
assert!(!fwd2.can_forward()); // hop_count == max_hops
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarded_envelope_still_verifies() {
|
||||
let id = test_identity();
|
||||
let env = MeshEnvelope::new(&id, &[0xAA; 32], b"fwd-verify".to_vec(), 3600, 5);
|
||||
assert!(env.verify(), "original must verify");
|
||||
|
||||
let fwd = env.forwarded();
|
||||
assert_eq!(fwd.hop_count, 1);
|
||||
assert!(fwd.verify(), "forwarded envelope must still verify (hop_count excluded from signature)");
|
||||
|
||||
let fwd2 = fwd.forwarded();
|
||||
assert!(fwd2.verify(), "double-forwarded must still verify");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_with_wrong_key_fails() {
|
||||
let id = test_identity();
|
||||
let mut env = MeshEnvelope::new(&id, &[0xBB; 32], b"wrong-key".to_vec(), 3600, 5);
|
||||
// Replace sender_key with a different key
|
||||
let other = test_identity();
|
||||
env.sender_key = other.public_key().to_vec();
|
||||
assert!(!env.verify(), "wrong sender key must fail verification");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialization_roundtrip() {
|
||||
let id = test_identity();
|
||||
|
||||
@@ -10,6 +10,9 @@ use std::path::Path;
|
||||
use quicproquo_core::IdentityKeypair;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
/// Information about a known peer in the mesh network.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PeerInfo {
|
||||
@@ -68,14 +71,25 @@ impl MeshIdentity {
|
||||
})
|
||||
}
|
||||
|
||||
/// Save this mesh identity to a JSON file.
|
||||
/// Save this mesh identity to a JSON file with restrictive permissions.
|
||||
///
|
||||
/// On Unix, the file is set to `0o600` (owner read/write only) since it
|
||||
/// contains the Ed25519 seed in the clear.
|
||||
pub fn save(&self, path: &Path) -> anyhow::Result<()> {
|
||||
let file = IdentityFile {
|
||||
seed: hex::encode(self.keypair.seed_bytes()),
|
||||
seed: hex::encode(&*self.keypair.seed_bytes()),
|
||||
peers: self.known_peers.clone(),
|
||||
};
|
||||
let json = serde_json::to_string_pretty(&file)?;
|
||||
std::fs::write(path, json)?;
|
||||
|
||||
// Restrict permissions to owner-only on Unix.
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let perms = std::fs::Permissions::from_mode(0o600);
|
||||
std::fs::set_permissions(path, perms)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -91,7 +105,7 @@ impl MeshIdentity {
|
||||
|
||||
/// Return the underlying seed (for deriving iroh `SecretKey`, etc.).
|
||||
pub fn seed_bytes(&self) -> [u8; 32] {
|
||||
self.keypair.seed_bytes()
|
||||
*self.keypair.seed_bytes()
|
||||
}
|
||||
|
||||
/// Register or update a known peer.
|
||||
|
||||
@@ -310,7 +310,7 @@ impl P2pNode {
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("broadcast manager lock poisoned: {e}"))?;
|
||||
mgr.encrypt(topic, payload)
|
||||
.ok_or_else(|| anyhow::anyhow!("not subscribed to topic: {topic}"))?
|
||||
.ok_or_else(|| anyhow::anyhow!("not subscribed to topic: {topic}"))??
|
||||
};
|
||||
|
||||
// Create a broadcast envelope (empty recipient_key signals broadcast).
|
||||
|
||||
@@ -3,19 +3,25 @@
|
||||
//! [`MeshStore`] buffers [`MeshEnvelope`]s for offline recipients and
|
||||
//! provides deduplication and automatic garbage collection of expired messages.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
|
||||
use crate::envelope::MeshEnvelope;
|
||||
|
||||
/// Default maximum messages stored per recipient.
|
||||
const DEFAULT_MAX_STORED: usize = 1000;
|
||||
|
||||
/// Maximum number of envelope IDs retained in the seen set for deduplication.
|
||||
/// Once exceeded, the oldest IDs are evicted to bound memory growth.
|
||||
const MAX_SEEN_IDS: usize = 100_000;
|
||||
|
||||
/// In-memory store-and-forward queue keyed by recipient public key.
|
||||
pub struct MeshStore {
|
||||
/// Recipient public key -> queued envelopes.
|
||||
inbox: HashMap<Vec<u8>, Vec<MeshEnvelope>>,
|
||||
/// Set of envelope IDs already processed (deduplication).
|
||||
seen: HashSet<[u8; 32]>,
|
||||
/// Insertion-ordered queue of seen IDs for bounded eviction.
|
||||
seen_order: VecDeque<[u8; 32]>,
|
||||
/// Maximum envelopes held per recipient.
|
||||
max_stored: usize,
|
||||
}
|
||||
@@ -28,6 +34,7 @@ impl MeshStore {
|
||||
Self {
|
||||
inbox: HashMap::new(),
|
||||
seen: HashSet::new(),
|
||||
seen_order: VecDeque::new(),
|
||||
max_stored: if max_stored == 0 {
|
||||
DEFAULT_MAX_STORED
|
||||
} else {
|
||||
@@ -50,6 +57,15 @@ impl MeshStore {
|
||||
return false;
|
||||
}
|
||||
self.seen.insert(envelope.id);
|
||||
self.seen_order.push_back(envelope.id);
|
||||
|
||||
// Evict oldest seen IDs if the set exceeds the bound.
|
||||
while self.seen_order.len() > MAX_SEEN_IDS {
|
||||
if let Some(old_id) = self.seen_order.pop_front() {
|
||||
self.seen.remove(&old_id);
|
||||
}
|
||||
}
|
||||
|
||||
queue.push(envelope);
|
||||
true
|
||||
}
|
||||
|
||||
@@ -182,10 +182,21 @@ pub struct HookVTable {
|
||||
pub destroy: Option<unsafe extern "C" fn(user_data: *mut core::ffi::c_void)>,
|
||||
}
|
||||
|
||||
// Safety: user_data is an opaque pointer managed by the plugin. The plugin is
|
||||
// responsible for its own thread safety. The server only calls hook functions
|
||||
// one at a time per plugin (wrapped in a single Arc). Plugins that mutate
|
||||
// user_data through callbacks must use interior mutability.
|
||||
// SAFETY: `HookVTable` contains raw pointers (`user_data`, function pointers)
|
||||
// which are not inherently `Send`/`Sync`. These impls are sound because:
|
||||
//
|
||||
// 1. `user_data` is an opaque pointer managed entirely by the plugin. The plugin
|
||||
// contract (documented in the module-level doc comment) requires that plugins
|
||||
// use interior mutability (Mutex/RwLock) if `user_data` is mutated through
|
||||
// callbacks. The server wraps each loaded plugin in an `Arc<HookVTable>` and
|
||||
// may invoke hooks from any Tokio worker thread.
|
||||
//
|
||||
// 2. All function pointers are `unsafe extern "C" fn` — they are plain addresses
|
||||
// with no captured state. The code they point to must be thread-safe per the
|
||||
// plugin contract.
|
||||
//
|
||||
// 3. The server guarantees that `destroy` is called exactly once during shutdown,
|
||||
// after which no further hook calls are made on the vtable.
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Send for HookVTable {}
|
||||
#[allow(unsafe_code)]
|
||||
|
||||
@@ -61,12 +61,27 @@ pub fn to_bytes<A: capnp::message::Allocator>(
|
||||
|
||||
/// Deserialise unpacked wire bytes into a message with owned segments.
|
||||
///
|
||||
/// Uses `ReaderOptions::new()` (default limits: 64 MiB, 512 nesting levels).
|
||||
/// Callers that receive data from untrusted peers should consider tightening
|
||||
/// the traversal limit via `ReaderOptions::traversal_limit_in_words`.
|
||||
/// Uses a stricter default traversal limit of 1 Mi words (~8 MiB) instead
|
||||
/// of the Cap'n Proto default of 64 MiB, reducing DoS amplification from
|
||||
/// untrusted input. Use [`from_bytes_with_options`] if you need a custom limit.
|
||||
pub fn from_bytes(
|
||||
bytes: &[u8],
|
||||
) -> Result<capnp::message::Reader<capnp::serialize::OwnedSegments>, capnp::Error> {
|
||||
let mut options = capnp::message::ReaderOptions::new();
|
||||
options.traversal_limit_in_words(Some(1_048_576)); // 1 Mi words = ~8 MiB
|
||||
let mut cursor = std::io::Cursor::new(bytes);
|
||||
capnp::serialize::read_message(&mut cursor, capnp::message::ReaderOptions::new())
|
||||
capnp::serialize::read_message(&mut cursor, options)
|
||||
}
|
||||
|
||||
/// Deserialise unpacked wire bytes with caller-specified [`ReaderOptions`].
|
||||
///
|
||||
/// Prefer [`from_bytes`] for typical use. Use this variant when you need to
|
||||
/// raise the traversal limit for large messages (e.g. blob transfers) or
|
||||
/// lower it further for tighter validation.
|
||||
pub fn from_bytes_with_options(
|
||||
bytes: &[u8],
|
||||
options: capnp::message::ReaderOptions,
|
||||
) -> Result<capnp::message::Reader<capnp::serialize::OwnedSegments>, capnp::Error> {
|
||||
let mut cursor = std::io::Cursor::new(bytes);
|
||||
capnp::serialize::read_message(&mut cursor, options)
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ quinn = { workspace = true }
|
||||
quinn-proto = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
rcgen = { workspace = true }
|
||||
x509-parser = { workspace = true }
|
||||
|
||||
# Crypto — OPAQUE PAKE
|
||||
opaque-ke = { workspace = true }
|
||||
@@ -58,6 +59,10 @@ serde_json = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
toml = { version = "0.8" }
|
||||
|
||||
# WebSocket JSON-RPC bridge for browser clients
|
||||
tokio-tungstenite = "0.26"
|
||||
base64 = "0.22"
|
||||
|
||||
# Metrics (Prometheus)
|
||||
metrics = "0.22"
|
||||
metrics-exporter-prometheus = "0.15"
|
||||
|
||||
@@ -178,6 +178,49 @@ pub fn validate_auth_context(
|
||||
Err(crate::error_codes::coded_error(E003_INVALID_TOKEN, "invalid accessToken"))
|
||||
}
|
||||
|
||||
/// Validate a raw bearer token (no Cap'n Proto dependency).
|
||||
/// Used by the WebSocket JSON-RPC bridge.
|
||||
pub fn validate_token_raw(
|
||||
cfg: &AuthConfig,
|
||||
sessions: &DashMap<Vec<u8>, SessionInfo>,
|
||||
token: &[u8],
|
||||
) -> Result<AuthContext, String> {
|
||||
if token.is_empty() {
|
||||
return Err("empty access token".to_string());
|
||||
}
|
||||
|
||||
// Check static bearer token.
|
||||
if let Some(expected) = &cfg.required_token {
|
||||
if expected.len() == token.len() && bool::from(expected.as_slice().ct_eq(token)) {
|
||||
return Ok(AuthContext {
|
||||
token: token.to_vec(),
|
||||
identity_key: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check session tokens.
|
||||
if let Some(session) = sessions.get(token) {
|
||||
let now = current_timestamp();
|
||||
if session.expires_at > now {
|
||||
let identity = if session.identity_key.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(session.identity_key.clone())
|
||||
};
|
||||
return Ok(AuthContext {
|
||||
token: token.to_vec(),
|
||||
identity_key: identity,
|
||||
});
|
||||
}
|
||||
drop(session);
|
||||
sessions.remove(token);
|
||||
return Err("session token has expired".to_string());
|
||||
}
|
||||
|
||||
Err("invalid access token".to_string())
|
||||
}
|
||||
|
||||
pub fn require_identity(auth_ctx: &AuthContext) -> Result<&[u8], capnp::Error> {
|
||||
match auth_ctx.identity_key.as_deref() {
|
||||
Some(ik) => Ok(ik),
|
||||
|
||||
@@ -36,6 +36,8 @@ pub struct FileConfig {
|
||||
/// When true, audit logs hash identity key prefixes and omit payload sizes.
|
||||
#[serde(default)]
|
||||
pub redact_logs: Option<bool>,
|
||||
/// WebSocket JSON-RPC bridge listen address (e.g. "0.0.0.0:9000").
|
||||
pub ws_listen: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -60,6 +62,8 @@ pub struct EffectiveConfig {
|
||||
pub plugin_dir: Option<PathBuf>,
|
||||
/// When true, audit logs hash identity key prefixes and omit payload sizes.
|
||||
pub redact_logs: bool,
|
||||
/// WebSocket JSON-RPC bridge listen address. If set, the bridge is started.
|
||||
pub ws_listen: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
@@ -225,6 +229,10 @@ pub fn merge_config(args: &crate::Args, file: &FileConfig) -> EffectiveConfig {
|
||||
|
||||
let plugin_dir = args.plugin_dir.clone().or_else(|| file.plugin_dir.clone());
|
||||
let redact_logs = args.redact_logs || file.redact_logs.unwrap_or(false);
|
||||
let ws_listen = args
|
||||
.ws_listen
|
||||
.clone()
|
||||
.or_else(|| file.ws_listen.clone());
|
||||
|
||||
EffectiveConfig {
|
||||
listen,
|
||||
@@ -242,6 +250,7 @@ pub fn merge_config(args: &crate::Args, file: &FileConfig) -> EffectiveConfig {
|
||||
federation,
|
||||
plugin_dir,
|
||||
redact_logs,
|
||||
ws_listen,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,21 +2,99 @@
|
||||
//!
|
||||
//! Delegates all operations to the local [`Store`], acting as a trusted relay
|
||||
//! from authenticated peer servers.
|
||||
//!
|
||||
//! **Security:** Each handler validates the request's `origin` field against
|
||||
//! the `verified_peer_domain` extracted from the mTLS client certificate at
|
||||
//! connection time. Per-peer rate limits prevent abuse.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use capnp::capability::Promise;
|
||||
use dashmap::DashMap;
|
||||
use quicproquo_proto::federation_capnp::federation_service;
|
||||
use tokio::sync::Notify;
|
||||
use dashmap::DashMap;
|
||||
|
||||
use crate::auth::RateEntry;
|
||||
use crate::storage::Store;
|
||||
|
||||
/// Per-peer federation rate limit: max requests within a 60-second window.
|
||||
const FED_RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||
const FED_RATE_LIMIT_MAX: u32 = 200;
|
||||
|
||||
/// Inbound federation RPC handler.
|
||||
pub struct FederationServiceImpl {
|
||||
pub store: Arc<dyn Store>,
|
||||
pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
pub local_domain: String,
|
||||
/// The peer domain extracted from the mTLS client certificate's CN/SAN
|
||||
/// at connection time. All requests must declare an `origin` matching this.
|
||||
pub verified_peer_domain: Option<String>,
|
||||
/// Per-peer rate limiter (keyed by peer domain).
|
||||
pub rate_limits: Arc<DashMap<String, RateEntry>>,
|
||||
}
|
||||
|
||||
/// Validate that the request's `origin` matches the mTLS-verified peer domain.
|
||||
fn validate_origin(
|
||||
verified: &Option<String>,
|
||||
declared: &str,
|
||||
) -> Result<(), capnp::Error> {
|
||||
match verified {
|
||||
Some(ref expected) if expected == declared => Ok(()),
|
||||
Some(ref expected) => Err(capnp::Error::failed(format!(
|
||||
"federation auth: origin '{}' does not match mTLS cert '{}'",
|
||||
declared, expected
|
||||
))),
|
||||
None => Err(capnp::Error::failed(
|
||||
"federation auth: no verified peer domain (mTLS required)".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract and validate the origin string from the request's auth field.
|
||||
fn extract_and_validate_origin(
|
||||
service: &FederationServiceImpl,
|
||||
get_auth: Result<quicproquo_proto::federation_capnp::federation_auth::Reader<'_>, capnp::Error>,
|
||||
) -> Result<String, capnp::Error> {
|
||||
let auth = get_auth
|
||||
.map_err(|_| capnp::Error::failed("federation auth: missing auth field".into()))?;
|
||||
let origin_reader = auth.get_origin()
|
||||
.map_err(|_| capnp::Error::failed("federation auth: missing origin".into()))?;
|
||||
let origin = origin_reader.to_str()
|
||||
.map_err(|_| capnp::Error::failed("federation auth: origin is not valid UTF-8".into()))?;
|
||||
|
||||
if origin.is_empty() {
|
||||
return Err(capnp::Error::failed("federation auth: origin must not be empty".into()));
|
||||
}
|
||||
|
||||
validate_origin(&service.verified_peer_domain, origin)?;
|
||||
check_federation_rate_limit(&service.rate_limits, origin)?;
|
||||
|
||||
Ok(origin.to_string())
|
||||
}
|
||||
|
||||
/// Per-peer federation rate limiter.
|
||||
fn check_federation_rate_limit(
|
||||
rate_limits: &DashMap<String, RateEntry>,
|
||||
peer_domain: &str,
|
||||
) -> Result<(), capnp::Error> {
|
||||
let now = crate::auth::current_timestamp();
|
||||
let mut entry = rate_limits.entry(peer_domain.to_string()).or_insert(RateEntry {
|
||||
count: 0,
|
||||
window_start: now,
|
||||
});
|
||||
|
||||
if now - entry.window_start >= FED_RATE_LIMIT_WINDOW_SECS {
|
||||
entry.count = 1;
|
||||
entry.window_start = now;
|
||||
} else {
|
||||
entry.count += 1;
|
||||
if entry.count > FED_RATE_LIMIT_MAX {
|
||||
return Err(capnp::Error::failed(format!(
|
||||
"federation rate limit exceeded for peer '{peer_domain}'"
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl federation_service::Server for FederationServiceImpl {
|
||||
@@ -30,6 +108,12 @@ impl federation_service::Server for FederationServiceImpl {
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
// Validate origin against mTLS cert and apply rate limit.
|
||||
let origin = match extract_and_validate_origin(self, p.get_auth()) {
|
||||
Ok(o) => o,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
let recipient_key = match p.get_recipient_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad recipient_key: {e}"))),
|
||||
@@ -40,13 +124,6 @@ impl federation_service::Server for FederationServiceImpl {
|
||||
};
|
||||
let channel_id = p.get_channel_id().unwrap_or_default().to_vec();
|
||||
|
||||
if let Ok(a) = p.get_auth() {
|
||||
if let Ok(origin) = a.get_origin() {
|
||||
let origin = origin.to_str().unwrap_or("?");
|
||||
tracing::debug!(origin = origin, "federation relay_enqueue");
|
||||
}
|
||||
}
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return Promise::err(capnp::Error::failed("recipient_key must be 32 bytes".into()));
|
||||
}
|
||||
@@ -67,6 +144,7 @@ impl federation_service::Server for FederationServiceImpl {
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
origin = %origin,
|
||||
recipient_prefix = %hex::encode(&recipient_key[..4]),
|
||||
seq = seq,
|
||||
"federation: relayed enqueue"
|
||||
@@ -85,6 +163,12 @@ impl federation_service::Server for FederationServiceImpl {
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
// Validate origin against mTLS cert and apply rate limit.
|
||||
let _origin = match extract_and_validate_origin(self, p.get_auth()) {
|
||||
Ok(o) => o,
|
||||
Err(e) => return Promise::err(e),
|
||||
};
|
||||
|
||||
let recipient_keys = match p.get_recipient_keys() {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad recipient_keys: {e}"))),
|
||||
@@ -134,11 +218,21 @@ impl federation_service::Server for FederationServiceImpl {
|
||||
params: federation_service::ProxyFetchKeyPackageParams,
|
||||
mut results: federation_service::ProxyFetchKeyPackageResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let identity_key = match params.get().and_then(|p| p.get_identity_key()) {
|
||||
Ok(v) => v.to_vec(),
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
// Validate origin against mTLS cert and apply rate limit.
|
||||
if let Err(e) = extract_and_validate_origin(self, p.get_auth()) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let identity_key = match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad identity_key: {e}"))),
|
||||
};
|
||||
|
||||
match self.store.fetch_key_package(&identity_key) {
|
||||
Ok(Some(pkg)) => results.get().set_package(&pkg),
|
||||
Ok(None) => results.get().set_package(&[]),
|
||||
@@ -153,11 +247,21 @@ impl federation_service::Server for FederationServiceImpl {
|
||||
params: federation_service::ProxyFetchHybridKeyParams,
|
||||
mut results: federation_service::ProxyFetchHybridKeyResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let identity_key = match params.get().and_then(|p| p.get_identity_key()) {
|
||||
Ok(v) => v.to_vec(),
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
// Validate origin against mTLS cert and apply rate limit.
|
||||
if let Err(e) = extract_and_validate_origin(self, p.get_auth()) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let identity_key = match p.get_identity_key() {
|
||||
Ok(v) => v.to_vec(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad identity_key: {e}"))),
|
||||
};
|
||||
|
||||
match self.store.fetch_hybrid_key(&identity_key) {
|
||||
Ok(Some(pk)) => results.get().set_hybrid_public_key(&pk),
|
||||
Ok(None) => results.get().set_hybrid_public_key(&[]),
|
||||
@@ -172,7 +276,17 @@ impl federation_service::Server for FederationServiceImpl {
|
||||
params: federation_service::ProxyResolveUserParams,
|
||||
mut results: federation_service::ProxyResolveUserResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
let username = match params.get().and_then(|p| p.get_username()) {
|
||||
let p = match params.get() {
|
||||
Ok(p) => p,
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad params: {e}"))),
|
||||
};
|
||||
|
||||
// Validate origin against mTLS cert and apply rate limit.
|
||||
if let Err(e) = extract_and_validate_origin(self, p.get_auth()) {
|
||||
return Promise::err(e);
|
||||
}
|
||||
|
||||
let username = match p.get_username() {
|
||||
Ok(u) => match u.to_str() {
|
||||
Ok(s) => s.to_string(),
|
||||
Err(e) => return Promise::err(capnp::Error::failed(format!("bad utf-8: {e}"))),
|
||||
@@ -194,8 +308,42 @@ impl federation_service::Server for FederationServiceImpl {
|
||||
_params: federation_service::FederationHealthParams,
|
||||
mut results: federation_service::FederationHealthResults,
|
||||
) -> Promise<(), capnp::Error> {
|
||||
// Health check does not require origin validation (diagnostic endpoint).
|
||||
results.get().set_status("ok");
|
||||
results.get().set_server_domain(&self.local_domain);
|
||||
Promise::ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the peer domain from the mTLS client certificate's first SAN (DNS name)
|
||||
/// or CN, given the QUIC connection's peer identity (a certificate chain).
|
||||
pub fn extract_peer_domain(conn: &quinn::Connection) -> Option<String> {
|
||||
let identity = conn.peer_identity()?;
|
||||
let certs = identity.downcast::<Vec<rustls::pki_types::CertificateDer<'static>>>().ok()?;
|
||||
let first_cert = certs.first()?;
|
||||
|
||||
// Parse the DER certificate to extract SAN DNS names or CN.
|
||||
let (_, parsed) = x509_parser::parse_x509_certificate(first_cert.as_ref()).ok()?;
|
||||
|
||||
// Prefer SAN DNS names.
|
||||
if let Ok(Some(san)) = parsed.subject_alternative_name() {
|
||||
for name in &san.value.general_names {
|
||||
if let x509_parser::extensions::GeneralName::DNSName(dns) = name {
|
||||
return Some(dns.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to CN.
|
||||
for rdn in parsed.subject().iter() {
|
||||
for attr in rdn.iter() {
|
||||
if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME {
|
||||
if let Ok(cn) = attr.as_str() {
|
||||
return Some(cn.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ mod plugin_loader;
|
||||
mod sql_store;
|
||||
mod tls;
|
||||
mod storage;
|
||||
mod ws_bridge;
|
||||
|
||||
use auth::{AuthConfig, PendingLogin, RateEntry, SessionInfo};
|
||||
use config::{
|
||||
@@ -119,6 +120,10 @@ struct Args {
|
||||
/// Redact identity key prefixes and payload sizes in audit logs for metadata minimization.
|
||||
#[arg(long, env = "QPQ_REDACT_LOGS", default_value_t = false)]
|
||||
redact_logs: bool,
|
||||
|
||||
/// WebSocket JSON-RPC bridge listen address (e.g. 0.0.0.0:9000). Enables browser connectivity.
|
||||
#[arg(long, env = "QPQ_WS_LISTEN")]
|
||||
ws_listen: Option<String>,
|
||||
}
|
||||
|
||||
// ── Entry point ───────────────────────────────────────────────────────────────
|
||||
@@ -329,6 +334,23 @@ async fn main() -> anyhow::Result<()> {
|
||||
Arc::clone(&waiters),
|
||||
);
|
||||
|
||||
// ── WebSocket JSON-RPC bridge ──────────────────────────────────────────
|
||||
if let Some(ws_addr_str) = &effective.ws_listen {
|
||||
let ws_addr: SocketAddr = ws_addr_str
|
||||
.parse()
|
||||
.context("--ws-listen must be host:port (e.g. 0.0.0.0:9000)")?;
|
||||
let ws_state = Arc::new(ws_bridge::WsBridgeState {
|
||||
store: Arc::clone(&store),
|
||||
waiters: Arc::clone(&waiters),
|
||||
auth_cfg: Arc::clone(&auth_cfg),
|
||||
sessions: Arc::clone(&sessions),
|
||||
rate_limits: Arc::clone(&rate_limits),
|
||||
sealed_sender: effective.sealed_sender,
|
||||
allow_insecure_auth: effective.allow_insecure_auth,
|
||||
});
|
||||
ws_bridge::spawn_ws_bridge(ws_addr, ws_state);
|
||||
}
|
||||
|
||||
let endpoint = Endpoint::server(server_config, listen)?;
|
||||
|
||||
tracing::info!(
|
||||
@@ -539,10 +561,20 @@ async fn main() -> anyhow::Result<()> {
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let verified_peer_domain =
|
||||
federation::service::extract_peer_domain(&conn);
|
||||
if let Some(ref peer) = verified_peer_domain {
|
||||
tracing::info!(peer_domain = %peer, "federation: mTLS peer authenticated");
|
||||
} else {
|
||||
tracing::warn!(peer = %conn.remote_address(), "federation: could not extract peer domain from mTLS cert");
|
||||
}
|
||||
|
||||
let service_impl = federation::service::FederationServiceImpl {
|
||||
store,
|
||||
waiters,
|
||||
local_domain: domain,
|
||||
verified_peer_domain,
|
||||
rate_limits: Arc::new(dashmap::DashMap::new()),
|
||||
};
|
||||
let client: quicproquo_proto::federation_capnp::federation_service::Client =
|
||||
capnp_rpc::new_client(service_impl);
|
||||
|
||||
@@ -332,16 +332,6 @@ impl NodeServiceImpl {
|
||||
return Promise::err(coded_error(E011_USERNAME_EMPTY, "username must not be empty"));
|
||||
}
|
||||
|
||||
let _request = match RegistrationRequest::<OpaqueSuite>::deserialize(&upload_bytes) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E010_OPAQUE_ERROR,
|
||||
format!("invalid registration upload: {e}"),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match self.store.has_user_record(&username) {
|
||||
Ok(true) => {
|
||||
return Promise::err(coded_error(
|
||||
|
||||
@@ -113,94 +113,106 @@ impl NodeServiceImpl {
|
||||
let final_path = dir.join(&blob_hex);
|
||||
let meta_path = dir.join(format!("{blob_hex}.meta"));
|
||||
|
||||
// If the blob already exists (fully uploaded), return immediately.
|
||||
if final_path.exists() {
|
||||
results.get().set_blob_id(&blob_hash);
|
||||
return Promise::ok(());
|
||||
}
|
||||
// All file I/O is delegated to spawn_blocking to avoid stalling the Tokio event loop.
|
||||
let uploader_prefix = auth_ctx
|
||||
.identity_key
|
||||
.as_deref()
|
||||
.filter(|k| k.len() >= 4)
|
||||
.map(|k| hex::encode(&k[..4]))
|
||||
.unwrap_or_default();
|
||||
|
||||
// Write chunk at the given offset.
|
||||
let write_result = (|| -> Result<(), String> {
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.truncate(false)
|
||||
.open(&part_path)
|
||||
.map_err(|e| format!("open .part file: {e}"))?;
|
||||
file.seek(SeekFrom::Start(offset))
|
||||
.map_err(|e| format!("seek: {e}"))?;
|
||||
file.write_all(&chunk)
|
||||
.map_err(|e| format!("write chunk: {e}"))?;
|
||||
file.sync_all()
|
||||
.map_err(|e| format!("sync: {e}"))?;
|
||||
Ok(())
|
||||
})();
|
||||
Promise::from_future(async move {
|
||||
let blob_hash_clone = blob_hash.clone();
|
||||
let part_path_clone = part_path.clone();
|
||||
let final_path_clone = final_path.clone();
|
||||
let meta_path_clone = meta_path.clone();
|
||||
let chunk_clone = chunk;
|
||||
let mime_clone = mime_type.clone();
|
||||
let prefix_clone = uploader_prefix.clone();
|
||||
|
||||
if let Err(e) = write_result {
|
||||
return Promise::err(coded_error(E009_STORAGE_ERROR, e));
|
||||
}
|
||||
|
||||
// Check if the blob is complete.
|
||||
let end = offset + chunk.len() as u64;
|
||||
if end == total_size {
|
||||
// Verify SHA-256 of the complete file.
|
||||
let verify_result = (|| -> Result<bool, String> {
|
||||
let mut file = std::fs::File::open(&part_path)
|
||||
.map_err(|e| format!("open for verify: {e}"))?;
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buf = [0u8; 64 * 1024];
|
||||
loop {
|
||||
let n = file.read(&mut buf).map_err(|e| format!("read: {e}"))?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buf[..n]);
|
||||
let io_result = tokio::task::spawn_blocking(move || -> Result<Option<bool>, String> {
|
||||
// If the blob already exists (fully uploaded), return immediately.
|
||||
if final_path_clone.exists() {
|
||||
return Ok(None); // signals "already done"
|
||||
}
|
||||
let computed: [u8; 32] = hasher.finalize().into();
|
||||
Ok(computed == blob_hash.as_slice())
|
||||
})();
|
||||
|
||||
match verify_result {
|
||||
Ok(true) => {
|
||||
// Hash matches — finalize the blob.
|
||||
if let Err(e) = std::fs::rename(&part_path, &final_path) {
|
||||
return Promise::err(coded_error(
|
||||
E009_STORAGE_ERROR,
|
||||
format!("rename .part to final: {e}"),
|
||||
));
|
||||
// Write chunk at the given offset.
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.truncate(false)
|
||||
.open(&part_path_clone)
|
||||
.map_err(|e| format!("open .part file: {e}"))?;
|
||||
file.seek(SeekFrom::Start(offset))
|
||||
.map_err(|e| format!("seek: {e}"))?;
|
||||
file.write_all(&chunk_clone)
|
||||
.map_err(|e| format!("write chunk: {e}"))?;
|
||||
file.sync_all()
|
||||
.map_err(|e| format!("sync: {e}"))?;
|
||||
|
||||
// Check if the blob is complete.
|
||||
let end = offset + chunk_clone.len() as u64;
|
||||
if end == total_size {
|
||||
// Verify SHA-256 of the complete file.
|
||||
let mut vfile = std::fs::File::open(&part_path_clone)
|
||||
.map_err(|e| format!("open for verify: {e}"))?;
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buf = [0u8; 64 * 1024];
|
||||
loop {
|
||||
let n = vfile.read(&mut buf).map_err(|e| format!("read: {e}"))?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buf[..n]);
|
||||
}
|
||||
let computed: [u8; 32] = hasher.finalize().into();
|
||||
if computed != blob_hash_clone.as_slice() {
|
||||
let _ = std::fs::remove_file(&part_path_clone);
|
||||
return Ok(Some(false)); // hash mismatch
|
||||
}
|
||||
|
||||
// Hash matches — finalize the blob.
|
||||
std::fs::rename(&part_path_clone, &final_path_clone)
|
||||
.map_err(|e| format!("rename .part to final: {e}"))?;
|
||||
|
||||
// Write metadata file.
|
||||
let uploader_prefix = auth_ctx
|
||||
.identity_key
|
||||
.as_deref()
|
||||
.filter(|k| k.len() >= 4)
|
||||
.map(|k| hex::encode(&k[..4]))
|
||||
.unwrap_or_default();
|
||||
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let meta = BlobMeta {
|
||||
mime_type: mime_type.clone(),
|
||||
mime_type: mime_clone,
|
||||
total_size,
|
||||
uploaded_at: now,
|
||||
uploader_key_prefix: uploader_prefix.clone(),
|
||||
uploader_key_prefix: prefix_clone,
|
||||
};
|
||||
|
||||
if let Err(e) = (|| -> Result<(), String> {
|
||||
let json = serde_json::to_string_pretty(&meta)
|
||||
.map_err(|e| format!("serialize meta: {e}"))?;
|
||||
std::fs::write(&meta_path, json.as_bytes())
|
||||
std::fs::write(&meta_path_clone, json.as_bytes())
|
||||
.map_err(|e| format!("write meta: {e}"))?;
|
||||
Ok(())
|
||||
})() {
|
||||
// Non-fatal: the blob is already stored; log and continue.
|
||||
tracing::warn!(error = %e, "failed to write blob metadata");
|
||||
}
|
||||
|
||||
return Ok(Some(true)); // complete + verified
|
||||
}
|
||||
|
||||
Ok(None) // chunk written, not yet complete
|
||||
})
|
||||
.await
|
||||
.map_err(|e| capnp::Error::failed(format!("spawn_blocking join: {e}")))?;
|
||||
|
||||
match io_result {
|
||||
Ok(None) => {
|
||||
// Already existed or chunk written (not yet complete).
|
||||
results.get().set_blob_id(&blob_hash);
|
||||
}
|
||||
Ok(Some(true)) => {
|
||||
// Complete and verified.
|
||||
tracing::info!(
|
||||
blob_hash_prefix = %fmt_hex(&blob_hash[..4]),
|
||||
total_size = total_size,
|
||||
@@ -208,24 +220,21 @@ impl NodeServiceImpl {
|
||||
uploader_prefix = %uploader_prefix,
|
||||
"audit: blob_upload_complete"
|
||||
);
|
||||
results.get().set_blob_id(&blob_hash);
|
||||
}
|
||||
Ok(false) => {
|
||||
// Hash mismatch — delete the .part file.
|
||||
let _ = std::fs::remove_file(&part_path);
|
||||
return Promise::err(coded_error(
|
||||
Ok(Some(false)) => {
|
||||
return Err(coded_error(
|
||||
E026_BLOB_HASH_MISMATCH,
|
||||
"SHA-256 of uploaded data does not match blobHash",
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = std::fs::remove_file(&part_path);
|
||||
return Promise::err(coded_error(E009_STORAGE_ERROR, e));
|
||||
return Err(coded_error(E009_STORAGE_ERROR, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results.get().set_blob_id(&blob_hash);
|
||||
Promise::ok(())
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle_download_blob(
|
||||
@@ -261,65 +270,57 @@ impl NodeServiceImpl {
|
||||
let blob_path = dir.join(&blob_hex);
|
||||
let meta_path = dir.join(format!("{blob_hex}.meta"));
|
||||
|
||||
// Check that the blob exists.
|
||||
if !blob_path.exists() {
|
||||
return Promise::err(coded_error(E027_BLOB_NOT_FOUND, "blob not found"));
|
||||
}
|
||||
|
||||
// Read metadata.
|
||||
let meta: BlobMeta = match std::fs::read_to_string(&meta_path) {
|
||||
Ok(json) => match serde_json::from_str(&json) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E009_STORAGE_ERROR,
|
||||
format!("corrupt blob metadata: {e}"),
|
||||
));
|
||||
// Delegate all file I/O to spawn_blocking to avoid stalling the event loop.
|
||||
Promise::from_future(async move {
|
||||
let io_result = tokio::task::spawn_blocking(move || -> Result<(Vec<u8>, BlobMeta), capnp::Error> {
|
||||
// Check that the blob exists.
|
||||
if !blob_path.exists() {
|
||||
return Err(coded_error(E027_BLOB_NOT_FOUND, "blob not found"));
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(
|
||||
E009_STORAGE_ERROR,
|
||||
format!("read blob metadata: {e}"),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Read the requested chunk.
|
||||
let read_result = (|| -> Result<Vec<u8>, String> {
|
||||
let mut file = std::fs::File::open(&blob_path)
|
||||
.map_err(|e| format!("open blob: {e}"))?;
|
||||
let file_len = file
|
||||
.metadata()
|
||||
.map_err(|e| format!("file metadata: {e}"))?
|
||||
.len();
|
||||
// Read metadata.
|
||||
let meta: BlobMeta = match std::fs::read_to_string(&meta_path) {
|
||||
Ok(json) => serde_json::from_str(&json).map_err(|e| {
|
||||
coded_error(E009_STORAGE_ERROR, format!("corrupt blob metadata: {e}"))
|
||||
})?,
|
||||
Err(e) => {
|
||||
return Err(coded_error(
|
||||
E009_STORAGE_ERROR,
|
||||
format!("read blob metadata: {e}"),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if offset >= file_len {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
// Read the requested chunk.
|
||||
let mut file = std::fs::File::open(&blob_path)
|
||||
.map_err(|e| coded_error(E009_STORAGE_ERROR, format!("open blob: {e}")))?;
|
||||
let file_len = file
|
||||
.metadata()
|
||||
.map_err(|e| coded_error(E009_STORAGE_ERROR, format!("file metadata: {e}")))?
|
||||
.len();
|
||||
|
||||
file.seek(SeekFrom::Start(offset))
|
||||
.map_err(|e| format!("seek: {e}"))?;
|
||||
let remaining = (file_len - offset) as usize;
|
||||
let to_read = remaining.min(length as usize);
|
||||
let mut buf = vec![0u8; to_read];
|
||||
file.read_exact(&mut buf)
|
||||
.map_err(|e| format!("read chunk: {e}"))?;
|
||||
Ok(buf)
|
||||
})();
|
||||
if offset >= file_len {
|
||||
return Ok((vec![], meta));
|
||||
}
|
||||
|
||||
match read_result {
|
||||
Ok(chunk) => {
|
||||
let mut r = results.get();
|
||||
r.set_chunk(&chunk);
|
||||
r.set_total_size(meta.total_size);
|
||||
r.set_mime_type(&meta.mime_type);
|
||||
}
|
||||
Err(e) => {
|
||||
return Promise::err(coded_error(E009_STORAGE_ERROR, e));
|
||||
}
|
||||
}
|
||||
file.seek(SeekFrom::Start(offset))
|
||||
.map_err(|e| coded_error(E009_STORAGE_ERROR, format!("seek: {e}")))?;
|
||||
let remaining = (file_len - offset) as usize;
|
||||
let to_read = remaining.min(length as usize);
|
||||
let mut buf = vec![0u8; to_read];
|
||||
file.read_exact(&mut buf)
|
||||
.map_err(|e| coded_error(E009_STORAGE_ERROR, format!("read chunk: {e}")))?;
|
||||
Ok((buf, meta))
|
||||
})
|
||||
.await
|
||||
.map_err(|e| capnp::Error::failed(format!("spawn_blocking join: {e}")))?;
|
||||
|
||||
Promise::ok(())
|
||||
let (chunk, meta) = io_result?;
|
||||
let mut r = results.get();
|
||||
r.set_chunk(&chunk);
|
||||
r.set_total_size(meta.total_size);
|
||||
r.set_mime_type(&meta.mime_type);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -502,18 +502,30 @@ impl NodeServiceImpl {
|
||||
}
|
||||
};
|
||||
|
||||
// Register waiter BEFORE the initial fetch to close the TOCTOU window:
|
||||
// an enqueue between fetch and registration would fire notify before
|
||||
// the waiter exists, causing a missed wakeup.
|
||||
let waiter = if timeout_ms > 0 {
|
||||
Some(
|
||||
waiters
|
||||
.entry(recipient_key.clone())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let messages = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
|
||||
|
||||
if messages.is_empty() && timeout_ms > 0 {
|
||||
let waiter = waiters
|
||||
.entry(recipient_key.clone())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone();
|
||||
let _ = timeout(Duration::from_millis(timeout_ms), waiter.notified()).await;
|
||||
let msgs = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
|
||||
fill_payloads_wait(&mut results, msgs);
|
||||
metrics::record_fetch_wait_total();
|
||||
return Ok(());
|
||||
if messages.is_empty() {
|
||||
if let Some(waiter) = waiter {
|
||||
let _ = timeout(Duration::from_millis(timeout_ms), waiter.notified()).await;
|
||||
let msgs = fetch_fn(&store, &recipient_key, &channel_id, limit)?;
|
||||
fill_payloads_wait(&mut results, msgs);
|
||||
metrics::record_fetch_wait_total();
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
fill_payloads_wait(&mut results, messages);
|
||||
|
||||
@@ -46,7 +46,10 @@ impl NodeServiceImpl {
|
||||
}
|
||||
|
||||
let device_name = match p.get_device_name() {
|
||||
Ok(n) => n.to_str().unwrap_or("").to_string(),
|
||||
Ok(n) => match n.to_str() {
|
||||
Ok(s) => s.to_string(),
|
||||
Err(_) => return Promise::err(coded_error(E020_BAD_PARAMS, "deviceName must be valid UTF-8")),
|
||||
},
|
||||
Err(e) => return Promise::err(coded_error(E020_BAD_PARAMS, e)),
|
||||
};
|
||||
|
||||
|
||||
@@ -164,6 +164,9 @@ impl NodeServiceImpl {
|
||||
));
|
||||
}
|
||||
|
||||
// Timing floor: mask DB-lookup timing differences (same as resolveUser).
|
||||
let deadline = Instant::now() + RESOLVE_TIMING_FLOOR;
|
||||
|
||||
match self.store.resolve_identity_key(identity_key) {
|
||||
Ok(Some(username)) => {
|
||||
results.get().set_username(&username);
|
||||
@@ -174,6 +177,10 @@ impl NodeServiceImpl {
|
||||
Err(e) => return Promise::err(storage_err(e)),
|
||||
}
|
||||
|
||||
Promise::ok(())
|
||||
// Pad to timing floor before responding.
|
||||
Promise::from_future(async move {
|
||||
tokio::time::sleep_until(deadline).await;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -542,7 +542,7 @@ impl Store for SqlStore {
|
||||
return Ok((id, false));
|
||||
}
|
||||
let mut channel_id = [0u8; 16];
|
||||
rand::thread_rng().fill_bytes(&mut channel_id);
|
||||
rand::rngs::OsRng.fill_bytes(&mut channel_id);
|
||||
conn.execute(
|
||||
"INSERT INTO channels (channel_id, member_a, member_b) VALUES (?1, ?2, ?3)",
|
||||
params![channel_id.as_slice(), a, b],
|
||||
|
||||
@@ -757,7 +757,7 @@ impl Store for FileBackedStore {
|
||||
return Ok((channel_id.clone(), false));
|
||||
}
|
||||
let mut channel_id = [0u8; 16];
|
||||
rand::thread_rng().fill_bytes(&mut channel_id);
|
||||
rand::rngs::OsRng.fill_bytes(&mut channel_id);
|
||||
let channel_id = channel_id.to_vec();
|
||||
map.insert(channel_id.clone(), (a, b));
|
||||
self.flush_channels(&self.channels_path, &map)?;
|
||||
|
||||
561
crates/quicproquo-server/src/ws_bridge.rs
Normal file
561
crates/quicproquo-server/src/ws_bridge.rs
Normal file
@@ -0,0 +1,561 @@
|
||||
//! WebSocket JSON-RPC bridge for browser clients.
|
||||
//!
|
||||
//! Provides a lightweight JSON-RPC interface over WebSocket so that browsers
|
||||
//! can interact with the server without a Cap'n Proto / QUIC stack.
|
||||
//!
|
||||
//! Security parity with the Cap'n Proto path:
|
||||
//! - Rate limiting via `check_rate_limit()` on all mutating handlers
|
||||
//! - DM channel membership verification on `send`
|
||||
//! - Payload size limits (5 MB)
|
||||
//! - Timing floor on `resolveUser` to mask lookup timing
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use base64::Engine;
|
||||
use dashmap::DashMap;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::SinkExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::Instant;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
use crate::auth::{check_rate_limit, validate_token_raw, AuthConfig, AuthContext, RateEntry, SessionInfo};
|
||||
use crate::storage::Store;
|
||||
|
||||
const B64: base64::engine::general_purpose::GeneralPurpose =
|
||||
base64::engine::general_purpose::STANDARD;
|
||||
|
||||
/// Maximum payload size for WS bridge (same as Cap'n Proto path).
|
||||
const MAX_PAYLOAD_BYTES: usize = 5 * 1024 * 1024;
|
||||
|
||||
/// Minimum response time for resolveUser to mask DB lookup timing differences.
|
||||
const RESOLVE_TIMING_FLOOR: Duration = Duration::from_millis(5);
|
||||
|
||||
// ── Shared state ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Subset of server state needed by the WS bridge (all `Send + Sync`).
|
||||
#[allow(dead_code)] // sealed_sender plumbed for future use
|
||||
pub struct WsBridgeState {
|
||||
pub store: Arc<dyn Store>,
|
||||
pub waiters: Arc<DashMap<Vec<u8>, Arc<Notify>>>,
|
||||
pub auth_cfg: Arc<AuthConfig>,
|
||||
pub sessions: Arc<DashMap<Vec<u8>, SessionInfo>>,
|
||||
pub rate_limits: Arc<DashMap<Vec<u8>, RateEntry>>,
|
||||
pub sealed_sender: bool,
|
||||
pub allow_insecure_auth: bool,
|
||||
}
|
||||
|
||||
// ── JSON-RPC types ──────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RpcRequest {
|
||||
id: serde_json::Value,
|
||||
method: String,
|
||||
#[serde(default)]
|
||||
params: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RpcResponse {
|
||||
id: serde_json::Value,
|
||||
ok: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
impl RpcResponse {
|
||||
fn success(id: serde_json::Value, result: serde_json::Value) -> Self {
|
||||
Self {
|
||||
id,
|
||||
ok: true,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn error(id: serde_json::Value, msg: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
ok: false,
|
||||
result: None,
|
||||
error: Some(msg.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Auth helper ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Extract and validate the "token" field from params. In insecure-auth mode
|
||||
/// with no token configured, an empty token is accepted as the bearer.
|
||||
fn extract_auth(
|
||||
state: &WsBridgeState,
|
||||
params: &serde_json::Value,
|
||||
) -> Result<AuthContext, String> {
|
||||
let token_str = params
|
||||
.get("token")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let token_bytes = token_str.as_bytes().to_vec();
|
||||
|
||||
// In insecure-auth mode with no configured token, accept any (including empty)
|
||||
// token as the bearer token. This mirrors the Cap'n Proto path behaviour.
|
||||
if state.allow_insecure_auth && state.auth_cfg.required_token.is_none() {
|
||||
// Treat the request identity from params as the identity.
|
||||
return Ok(AuthContext {
|
||||
token: token_bytes,
|
||||
identity_key: None,
|
||||
});
|
||||
}
|
||||
|
||||
validate_token_raw(&state.auth_cfg, &state.sessions, &token_bytes)
|
||||
}
|
||||
|
||||
/// Resolve identity key: either from auth context (session-bound) or from
|
||||
/// request params (insecure-auth mode). Returns the 32-byte identity key.
|
||||
fn resolve_identity(
|
||||
state: &WsBridgeState,
|
||||
auth_ctx: &AuthContext,
|
||||
params: &serde_json::Value,
|
||||
) -> Result<Vec<u8>, String> {
|
||||
// If auth context has an identity-bound session, use that.
|
||||
if let Some(ref ik) = auth_ctx.identity_key {
|
||||
return Ok(ik.clone());
|
||||
}
|
||||
|
||||
// In insecure-auth mode, accept identity from params.
|
||||
if state.allow_insecure_auth {
|
||||
// Try base64-encoded identityKey first.
|
||||
if let Some(b64) = params.get("identityKey").and_then(|v| v.as_str()) {
|
||||
return B64
|
||||
.decode(b64)
|
||||
.map_err(|e| format!("bad base64 identityKey: {e}"));
|
||||
}
|
||||
// Try username lookup.
|
||||
if let Some(username) = params.get("username").and_then(|v| v.as_str()) {
|
||||
if let Ok(Some(ik)) = state.store.get_user_identity_key(username) {
|
||||
return Ok(ik);
|
||||
}
|
||||
return Err(format!("user not found: {username}"));
|
||||
}
|
||||
}
|
||||
|
||||
Err("no identity: login required or pass identityKey/username in insecure mode".to_string())
|
||||
}
|
||||
|
||||
/// Apply rate limiting using the auth token. Returns an error string on limit exceeded.
|
||||
fn ws_check_rate_limit(state: &WsBridgeState, auth_ctx: &AuthContext) -> Result<(), String> {
|
||||
check_rate_limit(&state.rate_limits, &auth_ctx.token)
|
||||
.map_err(|e| format!("rate limit exceeded: {e}"))
|
||||
}
|
||||
|
||||
// ── Dispatch ────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn dispatch(state: &WsBridgeState, req: RpcRequest) -> RpcResponse {
|
||||
match req.method.as_str() {
|
||||
"health" => handle_health(req.id),
|
||||
"resolveUser" => handle_resolve_user(state, req.id, &req.params).await,
|
||||
"createChannel" => handle_create_channel(state, req.id, &req.params),
|
||||
"send" => handle_send(state, req.id, &req.params),
|
||||
"receive" => handle_receive(state, req.id, &req.params),
|
||||
"deleteAccount" => handle_delete_account(state, req.id, &req.params),
|
||||
_ => RpcResponse::error(req.id, format!("unknown method: {}", req.method)),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Handlers ────────────────────────────────────────────────────────────────
|
||||
|
||||
fn handle_health(id: serde_json::Value) -> RpcResponse {
|
||||
RpcResponse::success(id, serde_json::json!("ok"))
|
||||
}
|
||||
|
||||
async fn handle_resolve_user(
|
||||
state: &WsBridgeState,
|
||||
id: serde_json::Value,
|
||||
params: &serde_json::Value,
|
||||
) -> RpcResponse {
|
||||
let auth_ctx = match extract_auth(state, params) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return RpcResponse::error(id, e),
|
||||
};
|
||||
|
||||
// Rate limit resolve requests to prevent bulk enumeration.
|
||||
if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
|
||||
return RpcResponse::error(id, e);
|
||||
}
|
||||
|
||||
let username = match params.get("username").and_then(|v| v.as_str()) {
|
||||
Some(u) if !u.is_empty() => u,
|
||||
_ => return RpcResponse::error(id, "missing or empty 'username' param"),
|
||||
};
|
||||
|
||||
// Timing floor: mask DB-lookup timing differences between existing and
|
||||
// non-existing usernames (same as Cap'n Proto resolveUser handler).
|
||||
let deadline = Instant::now() + RESOLVE_TIMING_FLOOR;
|
||||
|
||||
let response = match state.store.get_user_identity_key(username) {
|
||||
Ok(Some(key)) => {
|
||||
RpcResponse::success(id, serde_json::json!({ "identityKey": B64.encode(&key) }))
|
||||
}
|
||||
Ok(None) => RpcResponse::success(id, serde_json::json!({ "identityKey": null })),
|
||||
Err(e) => RpcResponse::error(id, format!("storage error: {e}")),
|
||||
};
|
||||
|
||||
// Pad to timing floor before responding.
|
||||
tokio::time::sleep_until(deadline).await;
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
fn handle_create_channel(
|
||||
state: &WsBridgeState,
|
||||
id: serde_json::Value,
|
||||
params: &serde_json::Value,
|
||||
) -> RpcResponse {
|
||||
let auth_ctx = match extract_auth(state, params) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return RpcResponse::error(id, e),
|
||||
};
|
||||
|
||||
// Rate limit.
|
||||
if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
|
||||
return RpcResponse::error(id, e);
|
||||
}
|
||||
|
||||
let my_key = match resolve_identity(state, &auth_ctx, params) {
|
||||
Ok(k) => k,
|
||||
Err(e) => return RpcResponse::error(id, e),
|
||||
};
|
||||
|
||||
// Accept peer key as base64 or resolve from username.
|
||||
let peer_key = if let Some(b64) = params.get("peerKey").and_then(|v| v.as_str()) {
|
||||
match B64.decode(b64) {
|
||||
Ok(k) => k,
|
||||
Err(e) => return RpcResponse::error(id, format!("bad base64 peerKey: {e}")),
|
||||
}
|
||||
} else if let Some(username) = params.get("peerUsername").and_then(|v| v.as_str()) {
|
||||
match state.store.get_user_identity_key(username) {
|
||||
Ok(Some(k)) => k,
|
||||
Ok(None) => return RpcResponse::error(id, format!("peer user not found: {username}")),
|
||||
Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
|
||||
}
|
||||
} else {
|
||||
return RpcResponse::error(id, "missing 'peerKey' (base64) or 'peerUsername'");
|
||||
};
|
||||
|
||||
if peer_key.len() != 32 {
|
||||
return RpcResponse::error(id, "peerKey must be 32 bytes");
|
||||
}
|
||||
|
||||
match state.store.create_channel(&my_key, &peer_key) {
|
||||
Ok((channel_id, was_new)) => RpcResponse::success(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"channelId": B64.encode(&channel_id),
|
||||
"wasNew": was_new,
|
||||
}),
|
||||
),
|
||||
Err(e) => RpcResponse::error(id, format!("storage error: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_send(
|
||||
state: &WsBridgeState,
|
||||
id: serde_json::Value,
|
||||
params: &serde_json::Value,
|
||||
) -> RpcResponse {
|
||||
let auth_ctx = match extract_auth(state, params) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return RpcResponse::error(id, e),
|
||||
};
|
||||
|
||||
// Rate limit (parity with Cap'n Proto enqueue path).
|
||||
if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
|
||||
return RpcResponse::error(id, e);
|
||||
}
|
||||
|
||||
let sender_key = match resolve_identity(state, &auth_ctx, params) {
|
||||
Ok(k) => k,
|
||||
Err(e) => return RpcResponse::error(id, e),
|
||||
};
|
||||
|
||||
// Resolve recipient: base64 key or username.
|
||||
let recipient_key =
|
||||
if let Some(b64) = params.get("recipientKey").and_then(|v| v.as_str()) {
|
||||
match B64.decode(b64) {
|
||||
Ok(k) => k,
|
||||
Err(e) => {
|
||||
return RpcResponse::error(id, format!("bad base64 recipientKey: {e}"))
|
||||
}
|
||||
}
|
||||
} else if let Some(username) = params.get("recipient").and_then(|v| v.as_str()) {
|
||||
match state.store.get_user_identity_key(username) {
|
||||
Ok(Some(k)) => k,
|
||||
Ok(None) => {
|
||||
return RpcResponse::error(id, format!("recipient not found: {username}"))
|
||||
}
|
||||
Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
|
||||
}
|
||||
} else {
|
||||
return RpcResponse::error(id, "missing 'recipientKey' (base64) or 'recipient' (username)");
|
||||
};
|
||||
|
||||
if recipient_key.len() != 32 {
|
||||
return RpcResponse::error(id, "recipientKey must be 32 bytes");
|
||||
}
|
||||
|
||||
// Payload: base64-encoded binary or plain text message.
|
||||
let payload = if let Some(b64) = params.get("payload").and_then(|v| v.as_str()) {
|
||||
match B64.decode(b64) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return RpcResponse::error(id, format!("bad base64 payload: {e}")),
|
||||
}
|
||||
} else if let Some(msg) = params.get("message").and_then(|v| v.as_str()) {
|
||||
msg.as_bytes().to_vec()
|
||||
} else {
|
||||
return RpcResponse::error(id, "missing 'payload' (base64) or 'message' (text)");
|
||||
};
|
||||
|
||||
if payload.is_empty() {
|
||||
return RpcResponse::error(id, "payload must not be empty");
|
||||
}
|
||||
|
||||
// Payload size limit (same as Cap'n Proto path: 5 MB).
|
||||
if payload.len() > MAX_PAYLOAD_BYTES {
|
||||
return RpcResponse::error(
|
||||
id,
|
||||
format!("payload exceeds max size ({MAX_PAYLOAD_BYTES} bytes)"),
|
||||
);
|
||||
}
|
||||
|
||||
// Create or look up the DM channel between sender and recipient.
|
||||
let channel_id = match state.store.create_channel(&sender_key, &recipient_key) {
|
||||
Ok((ch, _)) => ch,
|
||||
Err(e) => return RpcResponse::error(id, format!("channel error: {e}")),
|
||||
};
|
||||
|
||||
// DM channel membership verification (parity with Cap'n Proto enqueue path).
|
||||
if channel_id.len() == 16 {
|
||||
let members = match state.store.get_channel_members(&channel_id) {
|
||||
Ok(Some(m)) => m,
|
||||
Ok(None) => return RpcResponse::error(id, "channel not found"),
|
||||
Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
|
||||
};
|
||||
let (a, b) = &members;
|
||||
let caller_in = sender_key == *a || sender_key == *b;
|
||||
let recipient_other = (recipient_key == *a && sender_key == *b)
|
||||
|| (recipient_key == *b && sender_key == *a);
|
||||
if !caller_in || !recipient_other {
|
||||
return RpcResponse::error(id, "caller or recipient not a member of this channel");
|
||||
}
|
||||
}
|
||||
|
||||
match state
|
||||
.store
|
||||
.enqueue(&recipient_key, &channel_id, payload, None)
|
||||
{
|
||||
Ok(seq) => {
|
||||
// Notify any waiting long-poll fetchers.
|
||||
if let Some(notify) = state.waiters.get(&recipient_key) {
|
||||
notify.notify_waiters();
|
||||
}
|
||||
|
||||
// Audit logging (no secrets: no payload, no full keys).
|
||||
tracing::info!(
|
||||
recipient_prefix = %hex::encode(&recipient_key[..std::cmp::min(4, recipient_key.len())]),
|
||||
seq = seq,
|
||||
"audit: ws_bridge enqueue"
|
||||
);
|
||||
|
||||
RpcResponse::success(id, serde_json::json!({ "seq": seq }))
|
||||
}
|
||||
Err(e) => RpcResponse::error(id, format!("enqueue error: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_receive(
|
||||
state: &WsBridgeState,
|
||||
id: serde_json::Value,
|
||||
params: &serde_json::Value,
|
||||
) -> RpcResponse {
|
||||
let auth_ctx = match extract_auth(state, params) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return RpcResponse::error(id, e),
|
||||
};
|
||||
|
||||
// Rate limit.
|
||||
if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
|
||||
return RpcResponse::error(id, e);
|
||||
}
|
||||
|
||||
let my_key = match resolve_identity(state, &auth_ctx, params) {
|
||||
Ok(k) => k,
|
||||
Err(e) => return RpcResponse::error(id, e),
|
||||
};
|
||||
|
||||
// Resolve sender/peer: base64 key or username (needed to find the channel).
|
||||
let peer_key = if let Some(b64) = params.get("recipientKey").and_then(|v| v.as_str()) {
|
||||
match B64.decode(b64) {
|
||||
Ok(k) => k,
|
||||
Err(e) => return RpcResponse::error(id, format!("bad base64 recipientKey: {e}")),
|
||||
}
|
||||
} else if let Some(username) = params.get("recipient").and_then(|v| v.as_str()) {
|
||||
match state.store.get_user_identity_key(username) {
|
||||
Ok(Some(k)) => k,
|
||||
Ok(None) => return RpcResponse::error(id, format!("user not found: {username}")),
|
||||
Err(e) => return RpcResponse::error(id, format!("storage error: {e}")),
|
||||
}
|
||||
} else {
|
||||
return RpcResponse::error(id, "missing 'recipientKey' (base64) or 'recipient' (username)");
|
||||
};
|
||||
|
||||
// Find the channel between me and the peer.
|
||||
let channel_id = match state.store.create_channel(&my_key, &peer_key) {
|
||||
Ok((ch, _)) => ch,
|
||||
Err(e) => return RpcResponse::error(id, format!("channel error: {e}")),
|
||||
};
|
||||
|
||||
// Fetch (drain) all messages for me in this channel.
|
||||
match state.store.fetch(&my_key, &channel_id) {
|
||||
Ok(messages) => {
|
||||
let items: Vec<serde_json::Value> = messages
|
||||
.into_iter()
|
||||
.map(|(seq, data)| {
|
||||
// Try to decode as UTF-8 text; fall back to base64.
|
||||
let text = String::from_utf8(data.clone()).ok();
|
||||
serde_json::json!({
|
||||
"seq": seq,
|
||||
"data": B64.encode(&data),
|
||||
"text": text,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
RpcResponse::success(id, serde_json::json!(items))
|
||||
}
|
||||
Err(e) => RpcResponse::error(id, format!("fetch error: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_delete_account(
|
||||
state: &WsBridgeState,
|
||||
id: serde_json::Value,
|
||||
params: &serde_json::Value,
|
||||
) -> RpcResponse {
|
||||
let auth_ctx = match extract_auth(state, params) {
|
||||
Ok(ctx) => ctx,
|
||||
Err(e) => return RpcResponse::error(id, e),
|
||||
};
|
||||
|
||||
// Rate limit.
|
||||
if let Err(e) = ws_check_rate_limit(state, &auth_ctx) {
|
||||
return RpcResponse::error(id, e);
|
||||
}
|
||||
|
||||
let identity_key = match resolve_identity(state, &auth_ctx, params) {
|
||||
Ok(k) => k,
|
||||
Err(e) => return RpcResponse::error(id, e),
|
||||
};
|
||||
|
||||
match state.store.delete_account(&identity_key) {
|
||||
Ok(()) => {
|
||||
// Invalidate sessions for this identity.
|
||||
let tokens_to_remove: Vec<Vec<u8>> = state
|
||||
.sessions
|
||||
.iter()
|
||||
.filter(|entry| entry.value().identity_key == identity_key)
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect();
|
||||
for token in &tokens_to_remove {
|
||||
state.sessions.remove(token);
|
||||
}
|
||||
RpcResponse::success(id, serde_json::json!({ "deleted": true }))
|
||||
}
|
||||
Err(e) => RpcResponse::error(id, format!("delete failed: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
// ── WebSocket listener ──────────────────────────────────────────────────────
|
||||
|
||||
/// Spawn the WebSocket JSON-RPC bridge as a background tokio task.
|
||||
pub fn spawn_ws_bridge(addr: SocketAddr, state: Arc<WsBridgeState>) {
|
||||
tokio::spawn(async move {
|
||||
let listener = match TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
tracing::error!(addr = %addr, error = %e, "ws_bridge: failed to bind");
|
||||
return;
|
||||
}
|
||||
};
|
||||
tracing::info!(addr = %addr, "ws_bridge: accepting WebSocket connections");
|
||||
|
||||
loop {
|
||||
let (stream, peer) = match listener.accept().await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "ws_bridge: accept error");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let state = Arc::clone(&state);
|
||||
tokio::spawn(async move {
|
||||
let ws = match tokio_tungstenite::accept_async(stream).await {
|
||||
Ok(ws) => ws,
|
||||
Err(e) => {
|
||||
tracing::debug!(peer = %peer, error = %e, "ws_bridge: handshake failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
tracing::debug!(peer = %peer, "ws_bridge: client connected");
|
||||
|
||||
let (mut sink, mut stream) = ws.split();
|
||||
|
||||
while let Some(msg) = stream.next().await {
|
||||
let msg = match msg {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
tracing::debug!(peer = %peer, error = %e, "ws_bridge: read error");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match msg {
|
||||
Message::Text(t) => t,
|
||||
Message::Close(_) => break,
|
||||
Message::Ping(_) | Message::Pong(_) => continue,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let req: RpcRequest = match serde_json::from_str(&text) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let resp = RpcResponse::error(
|
||||
serde_json::Value::Null,
|
||||
format!("invalid JSON: {e}"),
|
||||
);
|
||||
let json = serde_json::to_string(&resp).unwrap_or_default();
|
||||
if sink.send(Message::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let resp = dispatch(&state, req).await;
|
||||
let json = serde_json::to_string(&resp).unwrap_or_default();
|
||||
if sink.send(Message::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::debug!(peer = %peer, "ws_bridge: client disconnected");
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
Reference in New Issue
Block a user