test: add unit tests for retry logic and retriable classifier

This commit is contained in:
2026-03-04 13:31:16 +01:00
parent a3f67aca45
commit 75f11cb76b

View File

@@ -83,3 +83,124 @@ pub fn anyhow_is_retriable(err: &anyhow::Error) -> bool {
// Retry: network, timeout, connection, server error, or anything else
true
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn retry_success_first_attempt() {
let result = retry_async(|| async { Ok::<_, String>(42) }, 3, 10, |_| true).await;
assert_eq!(result.unwrap(), 42);
}
#[tokio::test]
async fn retry_succeeds_after_one_failure() {
let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
let c = counter.clone();
let result = retry_async(
|| {
let c = c.clone();
async move {
let n = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
if n == 0 {
Err("transient failure".to_string())
} else {
Ok(99)
}
}
},
3,
1, // minimal delay for test speed
|_| true,
)
.await;
assert_eq!(result.unwrap(), 99);
assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 2);
}
#[tokio::test]
async fn retry_non_retriable_fails_immediately() {
let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
let c = counter.clone();
let result = retry_async(
|| {
let c = c.clone();
async move {
c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Err::<(), _>("permanent error")
}
},
5,
1,
|_: &&str| false, // nothing is retriable
)
.await;
assert!(result.is_err());
assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
}
#[tokio::test]
async fn retry_exhausts_all_attempts() {
let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
let c = counter.clone();
let result = retry_async(
|| {
let c = c.clone();
async move {
c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Err::<(), _>("still failing")
}
},
3,
1,
|_| true,
)
.await;
assert!(result.is_err());
assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3);
}
#[test]
fn anyhow_is_retriable_classifications() {
// Auth errors should NOT be retriable
let auth_errors = [
"unauthorized access",
"HTTP 401 Unauthorized",
"forbidden resource",
"HTTP 403 Forbidden",
"auth failed for user",
"access denied",
"invalid token",
];
for msg in &auth_errors {
let err = anyhow::anyhow!("{msg}");
assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}");
}
// Bad-request errors should NOT be retriable
let bad_req_errors = [
"bad request: missing field",
"HTTP 400 Bad Request",
"invalid param: username",
"fingerprint mismatch",
];
for msg in &bad_req_errors {
let err = anyhow::anyhow!("{msg}");
assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}");
}
// Transient errors SHOULD be retriable
let transient_errors = [
"connection refused",
"network timeout",
"server error 500",
"stream reset",
"something unknown happened",
];
for msg in &transient_errors {
let err = anyhow::anyhow!("{msg}");
assert!(anyhow_is_retriable(&err), "expected retriable: {msg}");
}
}
}