//! Retry with exponential backoff for transient RPC failures. use std::future::Future; use std::time::Duration; use rand::Rng; use tracing::warn; /// Default maximum number of retry attempts (including the first try). pub const DEFAULT_MAX_RETRIES: u32 = 3; /// Default base delay in milliseconds for exponential backoff. pub const DEFAULT_BASE_DELAY_MS: u64 = 500; /// Runs an async operation with retries. On `Ok(t)` returns immediately. /// On `Err(e)`: if `is_retriable(&e)` and `attempt < max_retries`, sleeps with /// exponential backoff (plus jitter) then retries; otherwise returns the last error. pub async fn retry_async( op: F, max_retries: u32, base_delay_ms: u64, is_retriable: P, ) -> Result where F: Fn() -> Fut, Fut: Future>, P: Fn(&E) -> bool, { let mut last_err: Option = None; for attempt in 0..max_retries { match op().await { Ok(t) => return Ok(t), Err(e) => { if !is_retriable(&e) || attempt + 1 >= max_retries { return Err(e); } let delay_ms = base_delay_ms * 2u64.saturating_pow(attempt); let jitter_ms = rand::thread_rng().gen_range(0..=delay_ms / 2); let total_ms = delay_ms + jitter_ms; warn!( attempt = attempt + 1, max_retries, delay_ms = total_ms, "RPC failed, retrying after backoff" ); last_err = Some(e); tokio::time::sleep(Duration::from_millis(total_ms)).await; } } } match last_err { Some(e) => Err(e), None => unreachable!( "retry_async: last_err is always Some when loop exits after an Err" ), } } /// Classifies `anyhow::Error` for retry: returns `false` for auth or invalid-param /// errors (do not retry), `true` for transient errors (network, timeout, server 5xx). /// When in doubt, returns `true` (retry). pub fn anyhow_is_retriable(err: &anyhow::Error) -> bool { let s = format!("{:#}", err); let s_lower = s.to_lowercase(); // Do not retry: auth / permission if s_lower.contains("unauthorized") || s_lower.contains("auth failed") || s_lower.contains("access denied") || s_lower.contains("401") || s_lower.contains("forbidden") || s_lower.contains("403") || s_lower.contains("token") { return false; } // Do not retry: bad request / invalid params if s_lower.contains("bad request") || s_lower.contains("400") || s_lower.contains("invalid param") || s_lower.contains("fingerprint mismatch") { return false; } // Retry: network, timeout, connection, server error, or anything else true } #[cfg(test)] #[allow(clippy::unwrap_used)] mod tests { use super::*; #[tokio::test] async fn retry_success_first_attempt() { let result = retry_async(|| async { Ok::<_, String>(42) }, 3, 10, |_| true).await; assert_eq!(result.unwrap(), 42); } #[tokio::test] async fn retry_succeeds_after_one_failure() { let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); let c = counter.clone(); let result = retry_async( || { let c = c.clone(); async move { let n = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); if n == 0 { Err("transient failure".to_string()) } else { Ok(99) } } }, 3, 1, // minimal delay for test speed |_| true, ) .await; assert_eq!(result.unwrap(), 99); assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 2); } #[tokio::test] async fn retry_non_retriable_fails_immediately() { let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); let c = counter.clone(); let result = retry_async( || { let c = c.clone(); async move { c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); Err::<(), _>("permanent error") } }, 5, 1, |_: &&str| false, // nothing is retriable ) .await; assert!(result.is_err()); assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1); } #[tokio::test] async fn retry_exhausts_all_attempts() { let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); let c = counter.clone(); let result = retry_async( || { let c = c.clone(); async move { c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); Err::<(), _>("still failing") } }, 3, 1, |_| true, ) .await; assert!(result.is_err()); assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3); } #[test] fn anyhow_is_retriable_classifications() { // Auth errors should NOT be retriable let auth_errors = [ "unauthorized access", "HTTP 401 Unauthorized", "forbidden resource", "HTTP 403 Forbidden", "auth failed for user", "access denied", "invalid token", ]; for msg in &auth_errors { let err = anyhow::anyhow!("{msg}"); assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}"); } // Bad-request errors should NOT be retriable let bad_req_errors = [ "bad request: missing field", "HTTP 400 Bad Request", "invalid param: username", "fingerprint mismatch", ]; for msg in &bad_req_errors { let err = anyhow::anyhow!("{msg}"); assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}"); } // Transient errors SHOULD be retriable let transient_errors = [ "connection refused", "network timeout", "server error 500", "stream reset", "something unknown happened", ]; for msg in &transient_errors { let err = anyhow::anyhow!("{msg}"); assert!(anyhow_is_retriable(&err), "expected retriable: {msg}"); } } }