test: add unit tests for retry logic and retriable classifier
This commit is contained in:
@@ -83,3 +83,124 @@ pub fn anyhow_is_retriable(err: &anyhow::Error) -> bool {
|
|||||||
// Retry: network, timeout, connection, server error, or anything else
|
// Retry: network, timeout, connection, server error, or anything else
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retry_success_first_attempt() {
|
||||||
|
let result = retry_async(|| async { Ok::<_, String>(42) }, 3, 10, |_| true).await;
|
||||||
|
assert_eq!(result.unwrap(), 42);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retry_succeeds_after_one_failure() {
|
||||||
|
let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
|
||||||
|
let c = counter.clone();
|
||||||
|
let result = retry_async(
|
||||||
|
|| {
|
||||||
|
let c = c.clone();
|
||||||
|
async move {
|
||||||
|
let n = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||||
|
if n == 0 {
|
||||||
|
Err("transient failure".to_string())
|
||||||
|
} else {
|
||||||
|
Ok(99)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
3,
|
||||||
|
1, // minimal delay for test speed
|
||||||
|
|_| true,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
assert_eq!(result.unwrap(), 99);
|
||||||
|
assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retry_non_retriable_fails_immediately() {
|
||||||
|
let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
|
||||||
|
let c = counter.clone();
|
||||||
|
let result = retry_async(
|
||||||
|
|| {
|
||||||
|
let c = c.clone();
|
||||||
|
async move {
|
||||||
|
c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||||
|
Err::<(), _>("permanent error")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
5,
|
||||||
|
1,
|
||||||
|
|_: &&str| false, // nothing is retriable
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retry_exhausts_all_attempts() {
|
||||||
|
let counter = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
|
||||||
|
let c = counter.clone();
|
||||||
|
let result = retry_async(
|
||||||
|
|| {
|
||||||
|
let c = c.clone();
|
||||||
|
async move {
|
||||||
|
c.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||||
|
Err::<(), _>("still failing")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
3,
|
||||||
|
1,
|
||||||
|
|_| true,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn anyhow_is_retriable_classifications() {
|
||||||
|
// Auth errors should NOT be retriable
|
||||||
|
let auth_errors = [
|
||||||
|
"unauthorized access",
|
||||||
|
"HTTP 401 Unauthorized",
|
||||||
|
"forbidden resource",
|
||||||
|
"HTTP 403 Forbidden",
|
||||||
|
"auth failed for user",
|
||||||
|
"access denied",
|
||||||
|
"invalid token",
|
||||||
|
];
|
||||||
|
for msg in &auth_errors {
|
||||||
|
let err = anyhow::anyhow!("{msg}");
|
||||||
|
assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bad-request errors should NOT be retriable
|
||||||
|
let bad_req_errors = [
|
||||||
|
"bad request: missing field",
|
||||||
|
"HTTP 400 Bad Request",
|
||||||
|
"invalid param: username",
|
||||||
|
"fingerprint mismatch",
|
||||||
|
];
|
||||||
|
for msg in &bad_req_errors {
|
||||||
|
let err = anyhow::anyhow!("{msg}");
|
||||||
|
assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transient errors SHOULD be retriable
|
||||||
|
let transient_errors = [
|
||||||
|
"connection refused",
|
||||||
|
"network timeout",
|
||||||
|
"server error 500",
|
||||||
|
"stream reset",
|
||||||
|
"something unknown happened",
|
||||||
|
];
|
||||||
|
for msg in &transient_errors {
|
||||||
|
let err = anyhow::anyhow!("{msg}");
|
||||||
|
assert!(anyhow_is_retriable(&err), "expected retriable: {msg}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user