From e4c5868b310d2cc7b64c5b30bb75a3e95bc24743 Mon Sep 17 00:00:00 2001 From: Christian Nennemann Date: Sun, 8 Mar 2026 18:00:47 +0100 Subject: [PATCH] feat: add client auto-reconnect, heartbeat, and connection status UI RPC layer (quicprochat-rpc): - RpcClient now uses tokio::sync::Mutex for safe reconnection - Auto-reconnect with exponential backoff + jitter on retriable errors - QUIC-level keepalive via quinn TransportConfig - subscribe_push() yields Option-wrapped PushFrame items; None is the sentinel sent when the push stream breaks - RpcError::is_retriable() classifies transient vs permanent errors - ConnectionState enum (Connected/Reconnecting/Disconnected) with Display - Configurable max_retries, base_delay_ms, max_backoff_ms, keepalive_secs SDK layer (quicprochat-sdk): - QpqClient wraps RpcClient in Arc for safe heartbeat task sharing - start_heartbeat() spawns background task checking connection every 30s - connection_state() exposes RPC-layer state to UI - Reconnecting event added to ClientEvent enum - disconnect() aborts heartbeat before closing connection Client UI (quicprochat-client): - TUI status bar shows Connected/Reconnecting.../Offline with color - TUI handles Reconnecting event with attempt count display - REPL event listener prints connection state changes - REPL /status shows connection state instead of bool - Both TUI and REPL call start_heartbeat() on startup --- Cargo.lock | 1 + .../quicprochat-client/src/client/v2_repl.rs | 14 +- .../quicprochat-client/src/client/v2_tui.rs | 89 +++-- crates/quicprochat-rpc/Cargo.toml | 1 + crates/quicprochat-rpc/src/client.rs | 313 ++++++++++++++---- crates/quicprochat-rpc/src/error.rs | 89 +++++ crates/quicprochat-sdk/src/client.rs | 115 +++++-- crates/quicprochat-sdk/src/events.rs | 3 + 8 files changed, 526 insertions(+), 99 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4e39c40..e6ae39f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4491,6 +4491,7 @@ dependencies = [ "prost", "quicprochat-proto", "quinn", + "rand 0.8.5", "rcgen", "rustls", "sha2 
0.10.9", diff --git a/crates/quicprochat-client/src/client/v2_repl.rs b/crates/quicprochat-client/src/client/v2_repl.rs index 9a1ef0a..feedf71 100644 --- a/crates/quicprochat-client/src/client/v2_repl.rs +++ b/crates/quicprochat-client/src/client/v2_repl.rs @@ -294,6 +294,15 @@ fn show_event(event: &ClientEvent) { }; display::print_incoming(&sender, body); } + ClientEvent::Connected => { + display::print_status("connected to server"); + } + ClientEvent::Disconnected { reason } => { + display::print_error(&format!("disconnected: {reason}")); + } + ClientEvent::Reconnecting { attempt } => { + display::print_status(&format!("reconnecting... (attempt {attempt})")); + } ClientEvent::ConversationCreated { display_name, .. } => { display::print_status(&format!("new conversation: {display_name}")); } @@ -397,7 +406,7 @@ async fn dispatch( fn do_status(client: &QpqClient, st: &ReplState) { println!("{BOLD}Status{RESET}"); - println!(" connected: {}", if client.is_connected() { "yes" } else { "no" }); + println!(" connection: {}", client.connection_state()); println!(" authenticated: {}", if client.is_authenticated() { "yes" } else { "no" }); println!(" username: {}", client.username().unwrap_or("(none)")); println!(" conversation: {}", st.current_display_name.as_deref().unwrap_or("(none)")); @@ -990,6 +999,9 @@ pub async fn run_v2_repl( // Connect to server. client.connect().await.context("connect to server")?; + // Start heartbeat for proactive dead-connection detection. + client.start_heartbeat(); + // Background event listener. 
let rx = client.subscribe(); spawn_event_listener(rx); diff --git a/crates/quicprochat-client/src/client/v2_tui.rs b/crates/quicprochat-client/src/client/v2_tui.rs index 765ab36..b5b5782 100644 --- a/crates/quicprochat-client/src/client/v2_tui.rs +++ b/crates/quicprochat-client/src/client/v2_tui.rs @@ -41,7 +41,7 @@ use ratatui::{ }; use tokio::sync::broadcast; -use quicprochat_sdk::client::QpqClient; +use quicprochat_sdk::client::{ConnectionState, QpqClient}; use quicprochat_sdk::conversation::ConversationStore; use quicprochat_sdk::events::ClientEvent; @@ -87,8 +87,8 @@ pub struct TuiApp { focus: Focus, /// Notification line (shown briefly, e.g. "Message sent", "Error: ..."). notification: Option, - /// Whether the client is currently connected. - connected: bool, + /// Current connection state. + conn_state: quicprochat_sdk::client::ConnectionState, /// Current MLS epoch for the active conversation (if available). mls_epoch: Option, } @@ -108,7 +108,7 @@ impl TuiApp { server_addr: server_addr.to_string(), focus: Focus::Input, notification: None, - connected: false, + conn_state: ConnectionState::Disconnected, mls_epoch: None, } } @@ -149,7 +149,15 @@ impl TuiApp { } fn update_status(&mut self) { - let conn_indicator = if self.connected { "Online" } else { "Offline" }; + let conn_indicator = match self.conn_state { + ConnectionState::Connected => "Connected", + ConnectionState::Reconnecting { attempt } => { + // We can't use format! in a match arm and return &str, + // so we'll handle this below. 
+ return self.update_status_reconnecting(attempt); + } + ConnectionState::Disconnected => "Offline", + }; let user = self .username .as_deref() @@ -167,6 +175,25 @@ impl TuiApp { if conv_count == 1 { "" } else { "s" } ); } + + fn update_status_reconnecting(&mut self, attempt: u32) { + let user = self + .username + .as_deref() + .unwrap_or("not logged in"); + let conv_count = self.conversations.len(); + let epoch_str = match self.mls_epoch { + Some(e) => format!("epoch {e}"), + None => "epoch --".to_string(), + }; + self.status_line = format!( + "Reconnecting... (attempt {attempt}) | {} | {} | {} conversation{} | MLS {epoch_str}", + self.server_addr, + user, + conv_count, + if conv_count == 1 { "" } else { "s" } + ); + } } // ── Terminal Drop Guard ───────────────────────────────────────────────────── @@ -198,7 +225,7 @@ pub async fn run_v2_tui(client: &mut QpqClient) -> anyhow::Result<()> { "disconnected" }; let mut app = TuiApp::new(server_addr); - app.connected = client.is_connected(); + app.conn_state = client.connection_state(); // Populate initial state from client. if let Some(name) = client.username() { @@ -225,6 +252,9 @@ pub async fn run_v2_tui(client: &mut QpqClient) -> anyhow::Result<()> { app.update_status(); + // Start heartbeat for proactive dead-connection detection. + client.start_heartbeat(); + // Subscribe to SDK events. 
let mut event_rx = client.subscribe(); @@ -278,15 +308,20 @@ pub async fn run_v2_tui(client: &mut QpqClient) -> anyhow::Result<()> { fn handle_sdk_event(app: &mut TuiApp, event: ClientEvent) { match event { ClientEvent::Connected => { - app.connected = true; + app.conn_state = ConnectionState::Connected; app.notification = Some("Connected to server".to_string()); app.update_status(); } ClientEvent::Disconnected { reason } => { - app.connected = false; + app.conn_state = ConnectionState::Disconnected; app.notification = Some(format!("Disconnected: {reason}")); app.update_status(); } + ClientEvent::Reconnecting { attempt } => { + app.conn_state = ConnectionState::Reconnecting { attempt }; + app.notification = Some(format!("Reconnecting... (attempt {attempt})")); + app.update_status(); + } ClientEvent::Registered { username } => { app.notification = Some(format!("Registered as {username}")); } @@ -839,12 +874,11 @@ fn draw_input(frame: &mut Frame, app: &TuiApp, area: Rect) { } fn draw_status(frame: &mut Frame, app: &TuiApp, area: Rect) { - let conn_color = if app.connected { - Color::Green - } else { - Color::Red + let (conn_color, conn_indicator) = match app.conn_state { + ConnectionState::Connected => (Color::Green, " ON "), + ConnectionState::Reconnecting { .. } => (Color::Yellow, " ... 
"), + ConnectionState::Disconnected => (Color::Red, " OFF "), }; - let conn_indicator = if app.connected { " ON " } else { " OFF " }; let spans = vec![ Span::styled( @@ -1019,7 +1053,7 @@ mod tests { fn make_app() -> TuiApp { let mut app = TuiApp::new("127.0.0.1:7000"); - app.connected = true; + app.conn_state = ConnectionState::Connected; app.username = Some("alice".to_string()); app.conversations.push(ConversationItem { id: [1u8; 16], @@ -1067,12 +1101,12 @@ mod tests { } #[test] - fn status_bar_shows_online() { + fn status_bar_shows_connected() { let mut app = TuiApp::new("127.0.0.1:7000"); - app.connected = true; + app.conn_state = ConnectionState::Connected; app.username = Some("alice".to_string()); app.update_status(); - assert!(app.status_line.contains("Online")); + assert!(app.status_line.contains("Connected")); assert!(app.status_line.contains("alice")); assert!(app.status_line.contains("MLS epoch --")); } @@ -1080,15 +1114,32 @@ mod tests { #[test] fn status_bar_shows_offline() { let mut app = TuiApp::new("127.0.0.1:7000"); - app.connected = false; + app.conn_state = ConnectionState::Disconnected; app.update_status(); assert!(app.status_line.contains("Offline")); } + #[test] + fn status_bar_shows_reconnecting() { + let mut app = TuiApp::new("127.0.0.1:7000"); + app.conn_state = ConnectionState::Reconnecting { attempt: 2 }; + app.update_status(); + assert!( + app.status_line.contains("Reconnecting"), + "expected Reconnecting in: {}", + app.status_line + ); + assert!( + app.status_line.contains("attempt 2"), + "expected attempt count in: {}", + app.status_line + ); + } + #[test] fn status_bar_shows_epoch() { let mut app = TuiApp::new("127.0.0.1:7000"); - app.connected = true; + app.conn_state = ConnectionState::Connected; app.mls_epoch = Some(42); app.update_status(); assert!(app.status_line.contains("MLS epoch 42")); diff --git a/crates/quicprochat-rpc/Cargo.toml b/crates/quicprochat-rpc/Cargo.toml index 7297978..1359b21 100644 --- 
a/crates/quicprochat-rpc/Cargo.toml +++ b/crates/quicprochat-rpc/Cargo.toml @@ -20,6 +20,7 @@ tracing = { workspace = true } thiserror = { workspace = true } dashmap = { workspace = true } sha2 = { workspace = true } +rand = { workspace = true } uuid = { version = "1", features = ["v7"] } metrics = "0.22" diff --git a/crates/quicprochat-rpc/src/client.rs b/crates/quicprochat-rpc/src/client.rs index 5ad26fe..c12f79f 100644 --- a/crates/quicprochat-rpc/src/client.rs +++ b/crates/quicprochat-rpc/src/client.rs @@ -1,18 +1,32 @@ //! QUIC RPC client — connect to server, send requests, receive push events. +//! +//! Supports auto-reconnect with exponential backoff, keepalive pings, and +//! push subscription recovery. use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; +use std::time::Duration; use bytes::{Bytes, BytesMut}; use quinn::{Connection, Endpoint}; -use tokio::sync::mpsc; -use tracing::{debug, warn}; +use tokio::sync::{mpsc, Mutex}; +use tracing::{debug, info, warn}; use crate::auth_handshake; use crate::error::{RpcError, RpcStatus}; use crate::framing::{PushFrame, RequestFrame, ResponseFrame}; +/// Default maximum retries for auto-reconnect (including first attempt). +pub const DEFAULT_MAX_RETRIES: u32 = 3; +/// Default base delay for exponential backoff (milliseconds). +pub const DEFAULT_BASE_DELAY_MS: u64 = 500; +/// Default maximum backoff cap (milliseconds). +pub const DEFAULT_MAX_BACKOFF_MS: u64 = 30_000; +/// Default keepalive interval (seconds). +pub const DEFAULT_KEEPALIVE_SECS: u64 = 30; + /// Configuration for the RPC client. +#[derive(Clone)] pub struct RpcClientConfig { /// Server address to connect to. pub server_addr: std::net::SocketAddr, @@ -24,19 +38,84 @@ pub struct RpcClientConfig { pub alpn: Vec, /// Session token to send during auth handshake. pub session_token: Option>, + /// Max retries on connection failure (default 3). + pub max_retries: u32, + /// Base delay for backoff in milliseconds (default 500). 
+ pub base_delay_ms: u64, + /// Maximum backoff cap in milliseconds (default 30000). + pub max_backoff_ms: u64, + /// Keepalive interval in seconds; 0 selects the default (30) and does not disable it. + pub keepalive_secs: u64, } -/// A QUIC RPC client connection. +impl RpcClientConfig { + /// Fill in default values for zero/unset fields. + fn with_defaults(mut self) -> Self { + if self.max_retries == 0 { + self.max_retries = DEFAULT_MAX_RETRIES; + } + if self.base_delay_ms == 0 { + self.base_delay_ms = DEFAULT_BASE_DELAY_MS; + } + if self.max_backoff_ms == 0 { + self.max_backoff_ms = DEFAULT_MAX_BACKOFF_MS; + } + if self.keepalive_secs == 0 { + self.keepalive_secs = DEFAULT_KEEPALIVE_SECS; + } + self + } +} + +/// Connection state for the RPC client. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionState { + /// Connected and ready to send/receive. + Connected, + /// Connection lost, attempting to reconnect. + Reconnecting { attempt: u32 }, + /// Disconnected (intentional or exhausted retries). + Disconnected, +} + +impl std::fmt::Display for ConnectionState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Connected => write!(f, "Connected"), + Self::Reconnecting { attempt } => write!(f, "Reconnecting (attempt {attempt})"), + Self::Disconnected => write!(f, "Disconnected"), + } + } +} + +/// A QUIC RPC client connection with auto-reconnect support. pub struct RpcClient { - connection: Connection, + connection: Mutex, + endpoint: Endpoint, + config: RpcClientConfig, next_request_id: AtomicU32, + state: std::sync::Mutex, } impl RpcClient { /// Connect to the RPC server. 
pub async fn connect(config: RpcClientConfig) -> Result { + let config = config.with_defaults(); + let (endpoint, connection) = Self::establish(&config).await?; + + Ok(Self { + connection: Mutex::new(connection), + endpoint, + config, + next_request_id: AtomicU32::new(1), + state: std::sync::Mutex::new(ConnectionState::Connected), + }) + } + + /// Establish a new QUIC connection + optional auth handshake. + async fn establish(config: &RpcClientConfig) -> Result<(Endpoint, Connection), RpcError> { let mut tls = (*config.tls_config).clone(); - tls.alpn_protocols = vec![config.alpn]; + tls.alpn_protocols = vec![config.alpn.clone()]; let quic_tls = quinn::crypto::rustls::QuicClientConfig::try_from(tls) .map_err(|e| RpcError::Connection(format!("TLS config: {e}")))?; @@ -46,7 +125,13 @@ impl RpcClient { ); let mut endpoint = Endpoint::client(bind_addr) .map_err(|e| RpcError::Connection(e.to_string()))?; - endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(quic_tls))); + + let mut quinn_config = quinn::ClientConfig::new(Arc::new(quic_tls)); + // Enable QUIC-level keepalive. + let mut transport = quinn::TransportConfig::default(); + transport.keep_alive_interval(Some(Duration::from_secs(config.keepalive_secs))); + quinn_config.transport_config(Arc::new(transport)); + endpoint.set_default_client_config(quinn_config); let connection = endpoint .connect(config.server_addr, &config.server_name) @@ -58,34 +143,115 @@ impl RpcClient { // Perform auth handshake if a session token was provided. 
if let Some(ref token) = config.session_token { - let (mut send, mut recv) = connection - .open_bi() - .await - .map_err(|e| RpcError::Connection(format!("open auth stream: {e}")))?; - - auth_handshake::send_auth_init(&mut send, token).await?; - send.finish() - .map_err(|e| RpcError::Connection(format!("finish auth send: {e}")))?; - auth_handshake::recv_auth_ack(&mut recv).await?; - debug!("auth handshake complete"); + Self::do_auth_handshake(&connection, token).await?; } - Ok(Self { - connection, - next_request_id: AtomicU32::new(1), - }) + Ok((endpoint, connection)) + } + + /// Perform the auth handshake on a connection. + async fn do_auth_handshake(connection: &Connection, token: &[u8]) -> Result<(), RpcError> { + let (mut send, mut recv) = connection + .open_bi() + .await + .map_err(|e| RpcError::Connection(format!("open auth stream: {e}")))?; + + auth_handshake::send_auth_init(&mut send, token).await?; + send.finish() + .map_err(|e| RpcError::Connection(format!("finish auth send: {e}")))?; + auth_handshake::recv_auth_ack(&mut recv).await?; + debug!("auth handshake complete"); + Ok(()) + } + + /// Attempt to reconnect to the server with exponential backoff. + /// Returns `Ok(())` on success, `Err` if all retries exhausted. + async fn reconnect(&self) -> Result<(), RpcError> { + let max = self.config.max_retries; + let base = self.config.base_delay_ms; + let cap = self.config.max_backoff_ms; + + for attempt in 1..=max { + self.set_state(ConnectionState::Reconnecting { attempt }); + info!(attempt, max, "attempting reconnect"); + + // Exponential backoff with jitter, capped. + let delay_ms = (base * 2u64.saturating_pow(attempt.saturating_sub(1))).min(cap); + let jitter_ms = rand::Rng::gen_range(&mut rand::thread_rng(), 0..=delay_ms / 2); + tokio::time::sleep(Duration::from_millis(delay_ms + jitter_ms)).await; + + match self.try_connect_once().await { + Ok(new_conn) => { + // Auth handshake on the new connection. 
+ if let Some(ref token) = self.config.session_token { + if let Err(e) = Self::do_auth_handshake(&new_conn, token).await { + warn!(attempt, "reconnect auth handshake failed: {e}"); + continue; + } + } + + // Swap the connection under the lock. + *self.connection.lock().await = new_conn; + self.set_state(ConnectionState::Connected); + info!("reconnected successfully"); + return Ok(()); + } + Err(e) => { + warn!(attempt, max, "reconnect attempt failed: {e}"); + } + } + } + + self.set_state(ConnectionState::Disconnected); + Err(RpcError::Connection(format!( + "reconnect failed after {max} attempts" + ))) + } + + /// Single connection attempt (no retry). + async fn try_connect_once(&self) -> Result { + let conn = self + .endpoint + .connect(self.config.server_addr, &self.config.server_name) + .map_err(|e| RpcError::Connection(e.to_string()))? + .await + .map_err(|e| RpcError::Connection(e.to_string()))?; + Ok(conn) } /// Send an RPC request and wait for the response. + /// + /// On retriable connection errors, automatically reconnects and retries. pub async fn call( &self, method_id: u16, payload: Bytes, ) -> Result { - let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + let conn = self.connection.lock().await.clone(); + match Self::call_on(&conn, &self.next_request_id, method_id, payload.clone()).await { + Ok(resp) => Ok(resp), + Err(e) if e.is_retriable() && conn.close_reason().is_some() => { + // Connection is dead — try reconnect then retry once. + warn!("connection lost during RPC call, attempting reconnect"); + drop(conn); + self.reconnect().await?; + let conn = self.connection.lock().await.clone(); + Self::call_on(&conn, &self.next_request_id, method_id, payload).await + } + Err(e) => Err(e), + } + } - let (mut send, mut recv) = self - .connection + /// Inner call implementation on a specific connection. 
+ async fn call_on( + connection: &Connection, + next_request_id: &AtomicU32, + method_id: u16, + payload: Bytes, + ) -> Result { + let request_id = next_request_id.fetch_add(1, Ordering::Relaxed); + + let (mut send, mut recv) = connection .open_bi() .await .map_err(|e| RpcError::Connection(e.to_string()))?; @@ -142,55 +308,86 @@ impl RpcClient { } /// Subscribe to server-push events. Returns a receiver channel. - /// Spawns a background task that reads uni-streams. - pub fn subscribe_push(&self) -> mpsc::UnboundedReceiver { + /// + /// Spawns a background task that reads uni-streams. When the push stream + /// breaks (connection error, EOF), a `None` sentinel is sent so the + /// caller can detect the break and resubscribe after reconnection. + /// + /// This is an async method because it needs to clone the current connection. + pub async fn subscribe_push(&self) -> mpsc::UnboundedReceiver> { let (tx, rx) = mpsc::unbounded_channel(); - let conn = self.connection.clone(); + let conn = self.connection.lock().await.clone(); + tokio::spawn(Self::push_loop(conn, tx)); + rx + } - tokio::spawn(async move { - loop { - match conn.accept_uni().await { - Ok(mut recv) => { - let mut buf = BytesMut::new(); - loop { - match recv.read_chunk(65536, true).await { - Ok(Some(chunk)) => buf.extend_from_slice(&chunk.bytes), - Ok(None) => break, - Err(e) => { - debug!("push stream read error: {e}"); - break; - } + async fn push_loop(conn: Connection, tx: mpsc::UnboundedSender>) { + loop { + match conn.accept_uni().await { + Ok(mut recv) => { + let mut buf = BytesMut::new(); + loop { + match recv.read_chunk(65536, true).await { + Ok(Some(chunk)) => buf.extend_from_slice(&chunk.bytes), + Ok(None) => break, + Err(e) => { + debug!("push stream read error: {e}"); + break; } } - match PushFrame::decode(&mut buf) { - Ok(Some(frame)) => { - if tx.send(frame).is_err() { - return; // receiver dropped - } - } - Ok(None) => debug!("incomplete push frame"), - Err(e) => debug!("push decode error: 
{e}"), - } } - Err(quinn::ConnectionError::ApplicationClosed(_)) => break, - Err(e) => { - warn!("accept_uni error: {e}"); - break; + match PushFrame::decode(&mut buf) { + Ok(Some(frame)) => { + if tx.send(Some(frame)).is_err() { + return; // receiver dropped + } + } + Ok(None) => debug!("incomplete push frame"), + Err(e) => debug!("push decode error: {e}"), } } + Err(quinn::ConnectionError::ApplicationClosed(_)) => { + let _ = tx.send(None); + break; + } + Err(e) => { + warn!("accept_uni error: {e}"); + let _ = tx.send(None); + break; + } } - }); - - rx + } } /// Close the connection gracefully. pub fn close(&self) { - self.connection.close(0u32.into(), b"bye"); + self.set_state(ConnectionState::Disconnected); + if let Ok(conn) = self.connection.try_lock() { + conn.close(0u32.into(), b"bye"); + } } /// Get the underlying QUIC connection (for advanced use). - pub fn connection(&self) -> &Connection { - &self.connection + pub async fn connection(&self) -> Connection { + self.connection.lock().await.clone() + } + + /// Get the current connection state. + pub fn connection_state(&self) -> ConnectionState { + *self.state.lock().unwrap_or_else(|e| e.into_inner()) + } + + /// Check if the connection appears alive (no close reason set). 
+ pub fn is_alive(&self) -> bool { + match self.connection.try_lock() { + Ok(conn) => conn.close_reason().is_none(), + Err(_) => true, // locked = likely in use = alive + } + } + + fn set_state(&self, new_state: ConnectionState) { + if let Ok(mut s) = self.state.lock() { + *s = new_state; + } } } diff --git a/crates/quicprochat-rpc/src/error.rs b/crates/quicprochat-rpc/src/error.rs index fdc256d..8842c3d 100644 --- a/crates/quicprochat-rpc/src/error.rs +++ b/crates/quicprochat-rpc/src/error.rs @@ -72,3 +72,92 @@ pub enum RpcError { #[error("payload too large: {size} bytes (max {max})")] PayloadTooLarge { size: usize, max: usize }, } + +impl RpcError { + /// Returns `true` if this error is transient and the operation may succeed + /// on retry (e.g. connection reset, timeout, server 5xx). Returns `false` + /// for permanent failures (auth, bad request, payload limits). + pub fn is_retriable(&self) -> bool { + match self { + Self::Connection(_) | Self::Timeout | Self::StreamClosed => true, + Self::Server { status, .. } => matches!( + status, + RpcStatus::Unavailable + | RpcStatus::DeadlineExceeded + | RpcStatus::Internal + | RpcStatus::RateLimited + ), + Self::Encode(_) | Self::Decode(_) | Self::PayloadTooLarge { .. 
} => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn retriable_errors() { + assert!(RpcError::Connection("reset".into()).is_retriable()); + assert!(RpcError::Timeout.is_retriable()); + assert!(RpcError::StreamClosed.is_retriable()); + assert!(RpcError::Server { + status: RpcStatus::Unavailable, + message: String::new(), + } + .is_retriable()); + assert!(RpcError::Server { + status: RpcStatus::Internal, + message: String::new(), + } + .is_retriable()); + assert!(RpcError::Server { + status: RpcStatus::DeadlineExceeded, + message: String::new(), + } + .is_retriable()); + assert!(RpcError::Server { + status: RpcStatus::RateLimited, + message: String::new(), + } + .is_retriable()); + } + + #[test] + fn non_retriable_errors() { + assert!(!RpcError::Encode("bad".into()).is_retriable()); + assert!(!RpcError::Decode("bad".into()).is_retriable()); + assert!(!RpcError::PayloadTooLarge { size: 100, max: 50 }.is_retriable()); + assert!(!RpcError::Server { + status: RpcStatus::Unauthorized, + message: String::new(), + } + .is_retriable()); + assert!(!RpcError::Server { + status: RpcStatus::BadRequest, + message: String::new(), + } + .is_retriable()); + assert!(!RpcError::Server { + status: RpcStatus::Forbidden, + message: String::new(), + } + .is_retriable()); + assert!(!RpcError::Server { + status: RpcStatus::NotFound, + message: String::new(), + } + .is_retriable()); + } + + #[test] + fn connection_state_display() { + use crate::client::ConnectionState; + assert_eq!(ConnectionState::Connected.to_string(), "Connected"); + assert_eq!(ConnectionState::Disconnected.to_string(), "Disconnected"); + assert_eq!( + ConnectionState::Reconnecting { attempt: 2 }.to_string(), + "Reconnecting (attempt 2)" + ); + } +} diff --git a/crates/quicprochat-sdk/src/client.rs b/crates/quicprochat-sdk/src/client.rs index 363b952..9ecad22 100644 --- a/crates/quicprochat-sdk/src/client.rs +++ b/crates/quicprochat-sdk/src/client.rs @@ -1,19 +1,28 @@ //! 
`QpqClient` — the main entry point for the quicprochat SDK. +//! +//! Provides connection lifecycle management with auto-reconnect, heartbeat +//! monitoring, push subscription recovery, and a connection state machine. use std::sync::Arc; +use std::time::Duration; use tokio::sync::broadcast; use tracing::info; +pub use quicprochat_rpc::client::ConnectionState; + use crate::config::ClientConfig; use crate::conversation::ConversationStore; use crate::error::SdkError; use crate::events::ClientEvent; +/// Default heartbeat interval for proactive dead-connection detection. +const HEARTBEAT_INTERVAL_SECS: u64 = 30; + /// The main SDK client. All state is contained within this struct — no globals. pub struct QpqClient { config: ClientConfig, - rpc: Option, + rpc: Option>, event_tx: broadcast::Sender, /// The authenticated username, if logged in. username: Option, @@ -24,9 +33,9 @@ pub struct QpqClient { /// Local conversation store (SQLCipher). conv_store: Option, /// Device ID for multi-device support. - /// When set, fetch/peek/ack requests include this device_id so the server - /// scopes them to the correct per-device queue. device_id: Option>, + /// Handle to the heartbeat background task (if running). + heartbeat_handle: Option>, } impl QpqClient { @@ -42,6 +51,7 @@ impl QpqClient { session_token: None, conv_store: None, device_id: None, + heartbeat_handle: None, } } @@ -55,10 +65,14 @@ impl QpqClient { tls_config: Arc::new(tls_config), alpn: self.config.alpn.clone(), session_token: self.session_token.clone(), + max_retries: 0, // use defaults + base_delay_ms: 0, + max_backoff_ms: 0, + keepalive_secs: 0, }; let client = quicprochat_rpc::client::RpcClient::connect(rpc_config).await?; - self.rpc = Some(client); + self.rpc = Some(Arc::new(client)); // Open local conversation store. let store = ConversationStore::open( @@ -109,7 +123,7 @@ impl QpqClient { /// Get a reference to the RPC client (for direct calls). 
pub fn rpc(&self) -> Result<&quicprochat_rpc::client::RpcClient, SdkError> { - self.rpc.as_ref().ok_or(SdkError::NotConnected) + self.rpc.as_deref().ok_or(SdkError::NotConnected) } /// Get a reference to the conversation store. @@ -119,12 +133,70 @@ impl QpqClient { .ok_or(SdkError::NotConnected) } - /// Register a new user account via OPAQUE. + /// Get the current connection state from the RPC layer. + pub fn connection_state(&self) -> ConnectionState { + match &self.rpc { + Some(rpc) => rpc.connection_state(), + None => ConnectionState::Disconnected, + } + } + + /// Start a background heartbeat task that monitors the connection and + /// emits events on state changes. Checks QUIC connection liveness every + /// 30 seconds. If a dead connection is detected, emits a `Disconnected` + /// event. /// - /// Generates a fresh identity keypair, registers it with the server, and - /// stores the identity key locally. + /// Call this after `connect()` to enable proactive dead-connection detection. + pub fn start_heartbeat(&mut self) { + // Cancel any existing heartbeat. 
+ if let Some(h) = self.heartbeat_handle.take() { + h.abort(); + } + + let rpc = match self.rpc.clone() { + Some(rpc) => rpc, + None => return, + }; + + let event_tx = self.event_tx.clone(); + + self.heartbeat_handle = Some(tokio::spawn(async move { + let mut last_state = ConnectionState::Connected; + loop { + tokio::time::sleep(Duration::from_secs(HEARTBEAT_INTERVAL_SECS)).await; + + let alive = rpc.is_alive(); + let current_state = rpc.connection_state(); + + if current_state != last_state { + match current_state { + ConnectionState::Connected => { + let _ = event_tx.send(ClientEvent::Connected); + } + ConnectionState::Reconnecting { attempt } => { + let _ = event_tx.send(ClientEvent::Reconnecting { attempt }); + } + ConnectionState::Disconnected => { + let _ = event_tx.send(ClientEvent::Disconnected { + reason: "connection lost".into(), + }); + } + } + last_state = current_state; + } else if !alive && last_state == ConnectionState::Connected { + // Connection died but RPC layer hasn't noticed yet. + let _ = event_tx.send(ClientEvent::Disconnected { + reason: "heartbeat: connection dead".into(), + }); + last_state = ConnectionState::Disconnected; + } + } + })); + } + + /// Register a new user account via OPAQUE. pub async fn register(&mut self, username: &str, password: &str) -> Result<(), SdkError> { - let rpc = self.rpc.as_ref().ok_or(SdkError::NotConnected)?; + let rpc = self.rpc.as_deref().ok_or(SdkError::NotConnected)?; let keypair = crate::auth::opaque_register(rpc, username, password, None).await?; self.identity_key = Some(keypair.public_key_bytes().to_vec()); self.emit(ClientEvent::Registered { @@ -135,10 +207,6 @@ impl QpqClient { } /// Log in via OPAQUE and store the session token. - /// - /// Requires an identity key to be set (either from a previous `register()` - /// call or loaded from state). After login, the client is authenticated - /// and subsequent RPC calls include the session token. 
pub async fn login(&mut self, username: &str, password: &str) -> Result<(), SdkError> { let identity_key = self .identity_key @@ -146,7 +214,7 @@ impl QpqClient { .ok_or_else(|| SdkError::AuthFailed("no identity key — register or load state first".into()))? .clone(); - let rpc = self.rpc.as_ref().ok_or(SdkError::NotConnected)?; + let rpc = self.rpc.as_deref().ok_or(SdkError::NotConnected)?; let session_token = crate::auth::opaque_login(rpc, username, password, &identity_key).await?; self.session_token = Some(session_token); @@ -181,8 +249,7 @@ impl QpqClient { // ── Multi-device ───────────────────────────────────────────────────────── - /// Set the device ID for this client. Subsequent fetch/peek/ack calls - /// will include this ID so the server scopes them to the correct queue. + /// Set the device ID for this client. pub fn set_device_id(&mut self, device_id: Vec) { self.device_id = Some(device_id); } @@ -193,13 +260,12 @@ impl QpqClient { } /// Register this device with the server. - /// Sets the local device_id on success. pub async fn register_device( &mut self, device_id: &[u8], device_name: &str, ) -> Result { - let rpc = self.rpc.as_ref().ok_or(SdkError::NotConnected)?; + let rpc = self.rpc.as_deref().ok_or(SdkError::NotConnected)?; let newly_registered = crate::devices::register_device(rpc, device_id, device_name).await?; self.device_id = Some(device_id.to_vec()); @@ -208,13 +274,13 @@ impl QpqClient { /// List all registered devices for this identity. pub async fn list_devices(&self) -> Result, SdkError> { - let rpc = self.rpc.as_ref().ok_or(SdkError::NotConnected)?; + let rpc = self.rpc.as_deref().ok_or(SdkError::NotConnected)?; crate::devices::list_devices(rpc).await } /// Revoke (remove) a registered device. 
pub async fn revoke_device(&self, device_id: &[u8]) -> Result { - let rpc = self.rpc.as_ref().ok_or(SdkError::NotConnected)?; + let rpc = self.rpc.as_deref().ok_or(SdkError::NotConnected)?; crate::devices::revoke_device(rpc, device_id).await } @@ -258,8 +324,15 @@ impl QpqClient { .map_err(|e| SdkError::Storage(e.to_string())) } - /// Disconnect from the server. + /// Disconnect from the server gracefully. + /// + /// Stops the heartbeat task and closes the QUIC connection. Emits a + /// `Disconnected` event. pub fn disconnect(&mut self) { + // Stop heartbeat first. + if let Some(h) = self.heartbeat_handle.take() { + h.abort(); + } if let Some(rpc) = self.rpc.take() { rpc.close(); self.emit(ClientEvent::Disconnected { diff --git a/crates/quicprochat-sdk/src/events.rs b/crates/quicprochat-sdk/src/events.rs index f6c267d..533ca88 100644 --- a/crates/quicprochat-sdk/src/events.rs +++ b/crates/quicprochat-sdk/src/events.rs @@ -9,6 +9,9 @@ pub enum ClientEvent { /// Disconnected from the server. Disconnected { reason: String }, + /// Connection lost, attempting to reconnect. + Reconnecting { attempt: u32 }, + /// Registration succeeded. Registered { username: String },