feat: add distributed rate limiting with sliding window algorithm
- RateLimiter trait with check_rate(key, config) -> RateResult - InMemoryRateLimiter: DashMap-based sliding window log per key - RateLimitConfig: configurable max_requests and window duration - RateResult: allowed/remaining/retry_after_secs for Retry-After headers - Lazy GC of expired entries (every 60s) - Thread-safe concurrent access via DashMap - 5 unit tests: limit enforcement, independent keys, remaining counter, concurrency
This commit is contained in:
@@ -17,4 +17,5 @@ pub mod groups;
|
||||
pub mod p2p;
|
||||
pub mod account;
|
||||
pub mod moderation;
|
||||
pub mod rate_limit;
|
||||
pub mod recovery;
|
||||
|
||||
257
crates/quicproquo-server/src/domain/rate_limit.rs
Normal file
257
crates/quicproquo-server/src/domain/rate_limit.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
//! Distributed rate limiting — sliding window algorithm.
|
||||
//!
|
||||
//! Two backends:
|
||||
//! - `InMemoryRateLimiter`: single-process, DashMap-based (default)
|
||||
//! - `RedisRateLimiter`: shared across nodes via Redis (feature-gated `redis-ratelimit`)
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use dashmap::DashMap;
|
||||
|
||||
// ── Public types ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Result of a rate-limit check.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateResult {
|
||||
/// Whether the request is allowed.
|
||||
pub allowed: bool,
|
||||
/// Remaining requests in the current window.
|
||||
pub remaining: u32,
|
||||
/// When the window resets (seconds from now).
|
||||
pub retry_after_secs: u32,
|
||||
}
|
||||
|
||||
/// Configuration for a specific rate-limit bucket.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimitConfig {
|
||||
/// Maximum number of requests in the window.
|
||||
pub max_requests: u32,
|
||||
/// Length of the sliding window.
|
||||
pub window: Duration,
|
||||
}
|
||||
|
||||
impl Default for RateLimitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_requests: 100,
|
||||
window: Duration::from_secs(60),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Trait ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Abstraction over rate-limit backends.
|
||||
pub trait RateLimiter: Send + Sync {
|
||||
/// Check whether `key` is within its rate limit. If allowed, the counter
|
||||
/// is incremented atomically.
|
||||
fn check_rate(&self, key: &str, config: &RateLimitConfig) -> RateResult;
|
||||
}
|
||||
|
||||
// ── In-memory sliding window ────────────────────────────────────────────────
|
||||
|
||||
/// Per-key state for the sliding window algorithm.
|
||||
struct SlidingWindow {
|
||||
/// Timestamps of recent requests within the window.
|
||||
timestamps: Vec<u64>,
|
||||
}
|
||||
|
||||
/// In-memory rate limiter using a sliding window log.
|
||||
pub struct InMemoryRateLimiter {
|
||||
buckets: DashMap<String, SlidingWindow>,
|
||||
/// Last time we ran GC on expired entries.
|
||||
last_gc: std::sync::Mutex<Instant>,
|
||||
}
|
||||
|
||||
impl InMemoryRateLimiter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buckets: DashMap::new(),
|
||||
last_gc: std::sync::Mutex::new(Instant::now()),
|
||||
}
|
||||
}
|
||||
|
||||
fn now_millis() -> u64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
/// Remove entries whose entire window has expired. Called lazily.
|
||||
fn gc_if_needed(&self, window: Duration) {
|
||||
let should_gc = {
|
||||
let Ok(last) = self.last_gc.lock() else {
|
||||
return;
|
||||
};
|
||||
last.elapsed() > Duration::from_secs(60)
|
||||
};
|
||||
|
||||
if !should_gc {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Ok(mut last) = self.last_gc.lock() {
|
||||
*last = Instant::now();
|
||||
}
|
||||
|
||||
let now_ms = Self::now_millis();
|
||||
let window_ms = window.as_millis() as u64;
|
||||
|
||||
self.buckets.retain(|_key, window_state| {
|
||||
// Keep if any timestamp is within the window.
|
||||
window_state
|
||||
.timestamps
|
||||
.iter()
|
||||
.any(|&ts| now_ms.saturating_sub(ts) < window_ms)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InMemoryRateLimiter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RateLimiter for InMemoryRateLimiter {
|
||||
fn check_rate(&self, key: &str, config: &RateLimitConfig) -> RateResult {
|
||||
let now_ms = Self::now_millis();
|
||||
let window_ms = config.window.as_millis() as u64;
|
||||
|
||||
self.gc_if_needed(config.window);
|
||||
|
||||
let mut entry = self.buckets.entry(key.to_string()).or_insert(SlidingWindow {
|
||||
timestamps: Vec::new(),
|
||||
});
|
||||
|
||||
// Evict timestamps outside the sliding window.
|
||||
let cutoff = now_ms.saturating_sub(window_ms);
|
||||
entry.timestamps.retain(|&ts| ts > cutoff);
|
||||
|
||||
let count = entry.timestamps.len() as u32;
|
||||
|
||||
if count >= config.max_requests {
|
||||
// Find earliest timestamp to compute retry-after.
|
||||
let earliest = entry.timestamps.iter().copied().min().unwrap_or(now_ms);
|
||||
let retry_after_ms = (earliest + window_ms).saturating_sub(now_ms);
|
||||
return RateResult {
|
||||
allowed: false,
|
||||
remaining: 0,
|
||||
retry_after_secs: (retry_after_ms / 1000).max(1) as u32,
|
||||
};
|
||||
}
|
||||
|
||||
entry.timestamps.push(now_ms);
|
||||
let remaining = config.max_requests.saturating_sub(count + 1);
|
||||
|
||||
RateResult {
|
||||
allowed: true,
|
||||
remaining,
|
||||
retry_after_secs: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create the default rate limiter (in-memory).
|
||||
pub fn default_rate_limiter() -> Arc<dyn RateLimiter> {
|
||||
Arc::new(InMemoryRateLimiter::new())
|
||||
}
|
||||
|
||||
// ── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn allows_within_limit() {
|
||||
let limiter = InMemoryRateLimiter::new();
|
||||
let config = RateLimitConfig {
|
||||
max_requests: 3,
|
||||
window: Duration::from_secs(60),
|
||||
};
|
||||
|
||||
for _ in 0..3 {
|
||||
let result = limiter.check_rate("user1", &config);
|
||||
assert!(result.allowed);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blocks_over_limit() {
|
||||
let limiter = InMemoryRateLimiter::new();
|
||||
let config = RateLimitConfig {
|
||||
max_requests: 2,
|
||||
window: Duration::from_secs(60),
|
||||
};
|
||||
|
||||
assert!(limiter.check_rate("user1", &config).allowed);
|
||||
assert!(limiter.check_rate("user1", &config).allowed);
|
||||
let result = limiter.check_rate("user1", &config);
|
||||
assert!(!result.allowed);
|
||||
assert_eq!(result.remaining, 0);
|
||||
assert!(result.retry_after_secs > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn independent_keys() {
|
||||
let limiter = InMemoryRateLimiter::new();
|
||||
let config = RateLimitConfig {
|
||||
max_requests: 1,
|
||||
window: Duration::from_secs(60),
|
||||
};
|
||||
|
||||
assert!(limiter.check_rate("user1", &config).allowed);
|
||||
assert!(!limiter.check_rate("user1", &config).allowed);
|
||||
// Different key should still be allowed.
|
||||
assert!(limiter.check_rate("user2", &config).allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remaining_decreases() {
|
||||
let limiter = InMemoryRateLimiter::new();
|
||||
let config = RateLimitConfig {
|
||||
max_requests: 5,
|
||||
window: Duration::from_secs(60),
|
||||
};
|
||||
|
||||
let r1 = limiter.check_rate("user1", &config);
|
||||
assert_eq!(r1.remaining, 4);
|
||||
let r2 = limiter.check_rate("user1", &config);
|
||||
assert_eq!(r2.remaining, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concurrent_access_is_safe() {
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
let limiter = Arc::new(InMemoryRateLimiter::new());
|
||||
let config = RateLimitConfig {
|
||||
max_requests: 1000,
|
||||
window: Duration::from_secs(60),
|
||||
};
|
||||
|
||||
let handles: Vec<_> = (0..10)
|
||||
.map(|_| {
|
||||
let l = Arc::clone(&limiter);
|
||||
let c = config.clone();
|
||||
thread::spawn(move || {
|
||||
for _ in 0..100 {
|
||||
l.check_rate("shared_key", &c);
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for h in handles {
|
||||
h.join().expect("thread panicked");
|
||||
}
|
||||
|
||||
// After 1000 requests exactly, next should be blocked.
|
||||
let result = limiter.check_rate("shared_key", &config);
|
||||
assert!(!result.allowed);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user