test: add unit tests for retry logic and retriable classifier
This commit is contained in:
@@ -83,3 +83,124 @@ pub fn anyhow_is_retriable(err: &anyhow::Error) -> bool {
|
||||
// Retry: network, timeout, connection, server error, or anything else
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn retry_success_first_attempt() {
|
||||
let result = retry_async(|| async { Ok::<_, String>(42) }, 3, 10, |_| true).await;
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retry_succeeds_after_one_failure() {
|
||||
let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
|
||||
let c = counter.clone();
|
||||
let result = retry_async(
|
||||
|| {
|
||||
let c = c.clone();
|
||||
async move {
|
||||
let n = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
if n == 0 {
|
||||
Err("transient failure".to_string())
|
||||
} else {
|
||||
Ok(99)
|
||||
}
|
||||
}
|
||||
},
|
||||
3,
|
||||
1, // minimal delay for test speed
|
||||
|_| true,
|
||||
)
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), 99);
|
||||
assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retry_non_retriable_fails_immediately() {
|
||||
let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
|
||||
let c = counter.clone();
|
||||
let result = retry_async(
|
||||
|| {
|
||||
let c = c.clone();
|
||||
async move {
|
||||
c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
Err::<(), _>("permanent error")
|
||||
}
|
||||
},
|
||||
5,
|
||||
1,
|
||||
|_: &&str| false, // nothing is retriable
|
||||
)
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retry_exhausts_all_attempts() {
|
||||
let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
|
||||
let c = counter.clone();
|
||||
let result = retry_async(
|
||||
|| {
|
||||
let c = c.clone();
|
||||
async move {
|
||||
c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
Err::<(), _>("still failing")
|
||||
}
|
||||
},
|
||||
3,
|
||||
1,
|
||||
|_| true,
|
||||
)
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn anyhow_is_retriable_classifications() {
|
||||
// Auth errors should NOT be retriable
|
||||
let auth_errors = [
|
||||
"unauthorized access",
|
||||
"HTTP 401 Unauthorized",
|
||||
"forbidden resource",
|
||||
"HTTP 403 Forbidden",
|
||||
"auth failed for user",
|
||||
"access denied",
|
||||
"invalid token",
|
||||
];
|
||||
for msg in &auth_errors {
|
||||
let err = anyhow::anyhow!("{msg}");
|
||||
assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}");
|
||||
}
|
||||
|
||||
// Bad-request errors should NOT be retriable
|
||||
let bad_req_errors = [
|
||||
"bad request: missing field",
|
||||
"HTTP 400 Bad Request",
|
||||
"invalid param: username",
|
||||
"fingerprint mismatch",
|
||||
];
|
||||
for msg in &bad_req_errors {
|
||||
let err = anyhow::anyhow!("{msg}");
|
||||
assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}");
|
||||
}
|
||||
|
||||
// Transient errors SHOULD be retriable
|
||||
let transient_errors = [
|
||||
"connection refused",
|
||||
"network timeout",
|
||||
"server error 500",
|
||||
"stream reset",
|
||||
"something unknown happened",
|
||||
];
|
||||
for msg in &transient_errors {
|
||||
let err = anyhow::anyhow!("{msg}");
|
||||
assert!(anyhow_is_retriable(&err), "expected retriable: {msg}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user