diff --git a/crates/quicproquo-client/src/client/retry.rs b/crates/quicproquo-client/src/client/retry.rs index 2175a31..32c703f 100644 --- a/crates/quicproquo-client/src/client/retry.rs +++ b/crates/quicproquo-client/src/client/retry.rs @@ -83,3 +83,124 @@ pub fn anyhow_is_retriable(err: &anyhow::Error) -> bool { // Retry: network, timeout, connection, server error, or anything else true } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn retry_success_first_attempt() { + let result = retry_async(|| async { Ok::<_, String>(42) }, 3, 10, |_| true).await; + assert_eq!(result.unwrap(), 42); + } + + #[tokio::test] + async fn retry_succeeds_after_one_failure() { + let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let c = counter.clone(); + let result = retry_async( + || { + let c = c.clone(); + async move { + let n = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + if n == 0 { + Err("transient failure".to_string()) + } else { + Ok(99) + } + } + }, + 3, + 1, // minimal delay for test speed + |_| true, + ) + .await; + assert_eq!(result.unwrap(), 99); + assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn retry_non_retriable_fails_immediately() { + let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let c = counter.clone(); + let result = retry_async( + || { + let c = c.clone(); + async move { + c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Err::<(), _>("permanent error") + } + }, + 5, + 1, + |_: &&str| false, // nothing is retriable + ) + .await; + assert!(result.is_err()); + assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn retry_exhausts_all_attempts() { + let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let c = counter.clone(); + let result = retry_async( + || { + let c = c.clone(); + async move { + c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Err::<(), _>("still failing") + } + }, + 3, + 1, + |_| true, + ) + .await; + assert!(result.is_err()); + assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3); + } + + #[test] + fn anyhow_is_retriable_classifications() { + // Auth errors should NOT be retriable + let auth_errors = [ + "unauthorized access", + "HTTP 401 Unauthorized", + "forbidden resource", + "HTTP 403 Forbidden", + "auth failed for user", + "access denied", + "invalid token", + ]; + for msg in &auth_errors { + let err = anyhow::anyhow!("{msg}"); + assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}"); + } + + // Bad-request errors should NOT be retriable + let bad_req_errors = [ + "bad request: missing field", + "HTTP 400 Bad Request", + "invalid param: username", + "fingerprint mismatch", + ]; + for msg in &bad_req_errors { + let err = anyhow::anyhow!("{msg}"); + assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}"); + } + + // Transient errors SHOULD be retriable + let transient_errors = [ + "connection refused", + "network timeout", + "server error 500", + "stream reset", + "something unknown happened", + ]; + for msg in &transient_errors { + let err = anyhow::anyhow!("{msg}"); + assert!(anyhow_is_retriable(&err), "expected retriable: {msg}"); + } + } +}