Rename all crate directories, package names, binary names, proto package/module paths, ALPN strings, env var prefixes, config filenames, mDNS service names, and plugin ABI symbols from quicproquo/qpq to quicprochat/qpc.
208 lines
6.6 KiB
Rust
208 lines
6.6 KiB
Rust
//! Retry with exponential backoff for transient RPC failures.
|
|
|
|
use std::future::Future;
|
|
use std::time::Duration;
|
|
|
|
use rand::Rng;
|
|
use tracing::warn;
|
|
|
|
/// Total number of attempts made by default, counting the initial call as
/// the first attempt.
pub const DEFAULT_MAX_RETRIES: u32 = 3;

/// Base backoff delay used by default, expressed in milliseconds.
pub const DEFAULT_BASE_DELAY_MS: u64 = 500;
|
|
|
|
/// Runs an async operation with retries. On `Ok(t)` returns immediately.
|
|
/// On `Err(e)`: if `is_retriable(&e)` and `attempt < max_retries`, sleeps with
|
|
/// exponential backoff (plus jitter) then retries; otherwise returns the last error.
|
|
pub async fn retry_async<F, Fut, T, E, P>(
|
|
op: F,
|
|
max_retries: u32,
|
|
base_delay_ms: u64,
|
|
is_retriable: P,
|
|
) -> Result<T, E>
|
|
where
|
|
F: Fn() -> Fut,
|
|
Fut: Future<Output = Result<T, E>>,
|
|
P: Fn(&E) -> bool,
|
|
{
|
|
let mut last_err: Option<E> = None;
|
|
for attempt in 0..max_retries {
|
|
match op().await {
|
|
Ok(t) => return Ok(t),
|
|
Err(e) => {
|
|
if !is_retriable(&e) || attempt + 1 >= max_retries {
|
|
return Err(e);
|
|
}
|
|
let delay_ms = base_delay_ms * 2u64.saturating_pow(attempt);
|
|
let jitter_ms = rand::thread_rng().gen_range(0..=delay_ms / 2);
|
|
let total_ms = delay_ms + jitter_ms;
|
|
warn!(
|
|
attempt = attempt + 1,
|
|
max_retries,
|
|
delay_ms = total_ms,
|
|
"RPC failed, retrying after backoff"
|
|
);
|
|
last_err = Some(e);
|
|
tokio::time::sleep(Duration::from_millis(total_ms)).await;
|
|
}
|
|
}
|
|
}
|
|
match last_err {
|
|
Some(e) => Err(e),
|
|
None => unreachable!(
|
|
"retry_async: last_err is always Some when loop exits after an Err"
|
|
),
|
|
}
|
|
}
|
|
|
|
/// Classifies `anyhow::Error` for retry: returns `false` for auth or invalid-param
|
|
/// errors (do not retry), `true` for transient errors (network, timeout, server 5xx).
|
|
/// When in doubt, returns `true` (retry).
|
|
pub fn anyhow_is_retriable(err: &anyhow::Error) -> bool {
|
|
let s = format!("{:#}", err);
|
|
let s_lower = s.to_lowercase();
|
|
// Do not retry: auth / permission
|
|
if s_lower.contains("unauthorized")
|
|
|| s_lower.contains("auth failed")
|
|
|| s_lower.contains("access denied")
|
|
|| s_lower.contains("401")
|
|
|| s_lower.contains("forbidden")
|
|
|| s_lower.contains("403")
|
|
|| s_lower.contains("token")
|
|
{
|
|
return false;
|
|
}
|
|
// Do not retry: bad request / invalid params
|
|
if s_lower.contains("bad request")
|
|
|| s_lower.contains("400")
|
|
|| s_lower.contains("invalid param")
|
|
|| s_lower.contains("fingerprint mismatch")
|
|
{
|
|
return false;
|
|
}
|
|
// Retry: network, timeout, connection, server error, or anything else
|
|
true
|
|
}
|
|
|
|
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// A successful first attempt is returned as-is.
    #[tokio::test]
    async fn retry_success_first_attempt() {
        let outcome = retry_async(|| async { Ok::<_, String>(42) }, 3, 10, |_| true).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    /// One failure followed by a success yields the success after two calls.
    #[tokio::test]
    async fn retry_succeeds_after_one_failure() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);
        let outcome = retry_async(
            || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    // Only the very first call fails.
                    match calls.fetch_add(1, Ordering::SeqCst) {
                        0 => Err("transient failure".to_string()),
                        _ => Ok(99),
                    }
                }
            },
            3,
            1, // minimal delay for test speed
            |_| true,
        )
        .await;

        assert_eq!(outcome.unwrap(), 99);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// A non-retriable error stops after a single call, despite a budget of five.
    #[tokio::test]
    async fn retry_non_retriable_fails_immediately() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);
        let outcome = retry_async(
            || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("permanent error")
                }
            },
            5,
            1,
            |_: &&str| false, // nothing is retriable
        )
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// An always-failing retriable operation is called exactly `max_retries` times.
    #[tokio::test]
    async fn retry_exhausts_all_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);
        let outcome = retry_async(
            || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("still failing")
                }
            },
            3,
            1,
            |_| true,
        )
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// Message-based classification of permanent versus transient errors.
    #[test]
    fn anyhow_is_retriable_classifications() {
        // Auth errors should NOT be retriable
        let auth_failures = [
            "unauthorized access",
            "HTTP 401 Unauthorized",
            "forbidden resource",
            "HTTP 403 Forbidden",
            "auth failed for user",
            "access denied",
            "invalid token",
        ];
        // Bad-request errors should NOT be retriable
        let bad_requests = [
            "bad request: missing field",
            "HTTP 400 Bad Request",
            "invalid param: username",
            "fingerprint mismatch",
        ];
        for msg in auth_failures.iter().chain(bad_requests.iter()) {
            let err = anyhow::anyhow!("{msg}");
            assert!(!anyhow_is_retriable(&err), "expected non-retriable: {msg}");
        }

        // Transient errors SHOULD be retriable
        let transient_failures = [
            "connection refused",
            "network timeout",
            "server error 500",
            "stream reset",
            "something unknown happened",
        ];
        for msg in transient_failures.iter() {
            let err = anyhow::anyhow!("{msg}");
            assert!(anyhow_is_retriable(&err), "expected retriable: {msg}");
        }
    }
}
|