"""JTI replay cache for production verification.""" from __future__ import annotations import threading import time from abc import ABC, abstractmethod class JTICache(ABC): @abstractmethod def seen(self, jti: str) -> bool: pass @abstractmethod def add(self, jti: str) -> None: pass class _MemoryJTICache(JTICache): def __init__(self, max_size: int, ttl_sec: int) -> None: self._max_size = max_size self._ttl_sec = ttl_sec self._by_jti: dict[str, float] = {} self._lock = threading.RLock() def seen(self, jti: str) -> bool: with self._lock: exp = self._by_jti.get(jti) if exp is None: return False if time.time() > exp: del self._by_jti[jti] return False return True def add(self, jti: str) -> None: with self._lock: now = time.time() for k, exp in list(self._by_jti.items()): if now > exp: del self._by_jti[k] if self._max_size > 0 and len(self._by_jti) >= self._max_size and jti not in self._by_jti: # evict one for k in self._by_jti: del self._by_jti[k] break self._by_jti[jti] = now + self._ttl_sec def new_jti_cache(max_size: int, ttl_sec: int) -> JTICache: return _MemoryJTICache(max_size, ttl_sec)