feat: add distributed rate limiting with sliding window algorithm

- RateLimiter trait with check_rate(key, config) -> RateResult
- InMemoryRateLimiter: DashMap-based sliding window log per key
- RateLimitConfig: configurable max_requests and window duration
- RateResult: allowed/remaining/retry_after_secs for Retry-After headers
- Lazy GC of expired entries (every 60s)
- Thread-safe concurrent access via DashMap
- 5 unit tests: limit enforcement, independent keys, remaining counter, concurrency
This commit is contained in:
2026-03-04 20:35:45 +01:00
parent e93a38243f
commit 913f6faaf3
2 changed files with 258 additions and 0 deletions

View File

@@ -17,4 +17,5 @@ pub mod groups;
pub mod p2p;
pub mod account;
pub mod moderation;
pub mod rate_limit;
pub mod recovery;

View File

@@ -0,0 +1,257 @@
//! Distributed rate limiting — sliding window algorithm.
//!
//! Two backends:
//! - `InMemoryRateLimiter`: single-process, DashMap-based (default)
//! - `RedisRateLimiter`: shared across nodes via Redis (feature-gated `redis-ratelimit`)
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
// ── Public types ────────────────────────────────────────────────────────────
/// Result of a rate-limit check.
///
/// `retry_after_secs` is designed to be surfaced directly as an HTTP
/// `Retry-After` header value when `allowed` is false.
//
// All fields are small scalars, so the type is `Copy`; `PartialEq`/`Eq`
// make results directly assertable in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateResult {
    /// Whether the request is allowed.
    pub allowed: bool,
    /// Remaining requests in the current window (0 when blocked).
    pub remaining: u32,
    /// When the window resets (seconds from now; 0 when allowed).
    pub retry_after_secs: u32,
}
/// Configuration for a specific rate-limit bucket.
///
/// `Duration` is `Eq`, so configs can be compared/deduplicated directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum number of requests allowed within one window.
    pub max_requests: u32,
    /// Length of the sliding window.
    pub window: Duration,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
max_requests: 100,
window: Duration::from_secs(60),
}
}
}
// ── Trait ────────────────────────────────────────────────────────────────────
/// Abstraction over rate-limit backends.
///
/// `Send + Sync` is required so one limiter instance can be shared across
/// threads (e.g. behind an `Arc`, as [`default_rate_limiter`] does).
pub trait RateLimiter: Send + Sync {
    /// Check whether `key` is within its rate limit. If allowed, the counter
    /// is incremented atomically.
    ///
    /// Returns a [`RateResult`] with the decision, the remaining budget in
    /// the current window, and — when blocked — a retry delay in seconds.
    fn check_rate(&self, key: &str, config: &RateLimitConfig) -> RateResult;
}
// ── In-memory sliding window ────────────────────────────────────────────────
/// Per-key state for the sliding window algorithm.
///
/// One instance lives per key in `InMemoryRateLimiter::buckets`. A timestamp
/// is appended on each allowed request; stale ones are pruned on the next
/// check for the same key.
struct SlidingWindow {
    /// Timestamps (milliseconds since the Unix epoch, per
    /// `InMemoryRateLimiter::now_millis`) of recent requests in the window.
    timestamps: Vec<u64>,
}
/// In-memory rate limiter using a sliding window log.
///
/// Keeps one [`SlidingWindow`] per key in a `DashMap`, giving lock-striped
/// concurrent access without an external lock. Expired buckets are swept
/// lazily from `check_rate` (see `gc_if_needed`).
pub struct InMemoryRateLimiter {
    /// Per-key request logs.
    buckets: DashMap<String, SlidingWindow>,
    /// Last time we ran GC on expired entries; consulted by `gc_if_needed`
    /// to rate-limit the sweep itself to roughly once per minute.
    last_gc: std::sync::Mutex<Instant>,
}
impl InMemoryRateLimiter {
    /// Create an empty limiter with no tracked keys.
    pub fn new() -> Self {
        Self {
            buckets: DashMap::new(),
            last_gc: std::sync::Mutex::new(Instant::now()),
        }
    }

    /// Current wall-clock time in milliseconds since the Unix epoch.
    ///
    /// NOTE(review): `SystemTime` is not monotonic — a backwards clock step
    /// can briefly shrink or extend windows. Likely acceptable for rate
    /// limiting, but worth confirming against requirements.
    fn now_millis() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }

    /// Remove entries whose entire window has expired. Called lazily from
    /// `check_rate`, at most about once per 60 seconds.
    ///
    /// Fix vs. the original: the deadline check and the `last_gc` update now
    /// happen under a single lock acquisition. Previously the lock was taken
    /// twice (read, then write), so two threads could both observe a stale
    /// deadline and both run the full sweep.
    ///
    /// NOTE(review): the sweep uses the *caller's* window length. If
    /// different keys are checked with different windows, a GC triggered by
    /// a short-window caller can drop a bucket whose timestamps are still
    /// valid for a longer window, silently resetting that key's count —
    /// confirm all callers share one window, or store the window per bucket.
    fn gc_if_needed(&self, window: Duration) {
        {
            // Single lock scope: observe the deadline and claim this GC run
            // atomically with respect to other callers.
            let Ok(mut last) = self.last_gc.lock() else {
                // Poisoned lock: skip GC rather than panic; limiting still works.
                return;
            };
            if last.elapsed() <= Duration::from_secs(60) {
                return;
            }
            *last = Instant::now();
        }
        let now_ms = Self::now_millis();
        let window_ms = window.as_millis() as u64;
        // Keep a bucket only if at least one timestamp is still inside the
        // window; otherwise drop the entry entirely to bound memory growth.
        self.buckets.retain(|_key, state| {
            state
                .timestamps
                .iter()
                .any(|&ts| now_ms.saturating_sub(ts) < window_ms)
        });
    }
}
impl Default for InMemoryRateLimiter {
fn default() -> Self {
Self::new()
}
}
impl RateLimiter for InMemoryRateLimiter {
    /// Sliding-window-log check: prune stale timestamps for `key`, then
    /// admit the request iff fewer than `max_requests` remain in the window.
    fn check_rate(&self, key: &str, config: &RateLimitConfig) -> RateResult {
        let now_ms = Self::now_millis();
        let window_ms = config.window.as_millis() as u64;
        // GC runs *before* we take the entry lock below; DashMap's `retain`
        // could deadlock if this thread already held a shard guard.
        self.gc_if_needed(config.window);
        let mut entry = self
            .buckets
            .entry(key.to_string())
            // `or_insert_with` defers construction to the miss path only
            // (the original `or_insert` built the value on every call).
            .or_insert_with(|| SlidingWindow {
                timestamps: Vec::new(),
            });
        // Evict timestamps that fell out of the sliding window.
        let cutoff = now_ms.saturating_sub(window_ms);
        entry.timestamps.retain(|&ts| ts > cutoff);
        let count = entry.timestamps.len() as u32;
        if count >= config.max_requests {
            // Retry-After = when the earliest surviving request exits the
            // window. Round UP: the original floor division told clients to
            // retry up to 999 ms before a slot actually frees.
            let earliest = entry.timestamps.iter().copied().min().unwrap_or(now_ms);
            let retry_after_ms = (earliest + window_ms).saturating_sub(now_ms);
            return RateResult {
                allowed: false,
                remaining: 0,
                retry_after_secs: (((retry_after_ms + 999) / 1000).max(1)) as u32,
            };
        }
        entry.timestamps.push(now_ms);
        RateResult {
            allowed: true,
            // `count + 1` accounts for the request we just admitted.
            remaining: config.max_requests.saturating_sub(count + 1),
            retry_after_secs: 0,
        }
    }
}
/// Create the default rate limiter (in-memory), type-erased behind the
/// [`RateLimiter`] trait so callers can swap backends without code changes.
pub fn default_rate_limiter() -> Arc<dyn RateLimiter> {
    let backend = InMemoryRateLimiter::new();
    Arc::new(backend)
}
// ── Tests ───────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: a config allowing `max` requests per 60-second window.
    fn cfg_with(max: u32) -> RateLimitConfig {
        RateLimitConfig {
            max_requests: max,
            window: Duration::from_secs(60),
        }
    }

    #[test]
    fn allows_within_limit() {
        let rl = InMemoryRateLimiter::new();
        let c = cfg_with(3);
        for _ in 0..3 {
            assert!(rl.check_rate("user1", &c).allowed);
        }
    }

    #[test]
    fn blocks_over_limit() {
        let rl = InMemoryRateLimiter::new();
        let c = cfg_with(2);
        assert!(rl.check_rate("user1", &c).allowed);
        assert!(rl.check_rate("user1", &c).allowed);
        let blocked = rl.check_rate("user1", &c);
        assert!(!blocked.allowed);
        assert_eq!(blocked.remaining, 0);
        assert!(blocked.retry_after_secs > 0);
    }

    #[test]
    fn independent_keys() {
        let rl = InMemoryRateLimiter::new();
        let c = cfg_with(1);
        assert!(rl.check_rate("user1", &c).allowed);
        assert!(!rl.check_rate("user1", &c).allowed);
        // A different key has its own bucket and is unaffected.
        assert!(rl.check_rate("user2", &c).allowed);
    }

    #[test]
    fn remaining_decreases() {
        let rl = InMemoryRateLimiter::new();
        let c = cfg_with(5);
        assert_eq!(rl.check_rate("user1", &c).remaining, 4);
        assert_eq!(rl.check_rate("user1", &c).remaining, 3);
    }

    #[test]
    fn concurrent_access_is_safe() {
        use std::sync::Arc;
        use std::thread;
        let rl = Arc::new(InMemoryRateLimiter::new());
        let c = cfg_with(1000);
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let rl = Arc::clone(&rl);
                let c = c.clone();
                thread::spawn(move || {
                    for _ in 0..100 {
                        rl.check_rate("shared_key", &c);
                    }
                })
            })
            .collect();
        for w in workers {
            w.join().expect("thread panicked");
        }
        // 10 threads × 100 calls consumed the whole 1000-request budget.
        assert!(!rl.check_rate("shared_key", &c).allowed);
    }
}