diff --git a/backend/secuscan/cache.py b/backend/secuscan/cache.py index a517b867..a42c3edb 100644 --- a/backend/secuscan/cache.py +++ b/backend/secuscan/cache.py @@ -5,45 +5,89 @@ import json from typing import Any, Optional, Dict import time +import logging from .config import settings +logger = logging.getLogger(__name__) + +DEFAULT_MAX_ENTRIES = 10_000 +SWEEP_EVICT_FRACTION = 0.25 +OPPORTUNISTIC_SWEEP_INTERVAL = 50 + class CacheClient: - """In-memory dictionary based cache client.""" + """In-memory dictionary based cache client with TTL, size limit, and LRU eviction.""" - def __init__(self, url: Optional[str] = None): + def __init__(self, url: Optional[str] = None, max_entries: int = DEFAULT_MAX_ENTRIES): self.url = url self._data: Dict[str, Any] = {} self._expires: Dict[str, float] = {} + self._access_order: Dict[str, float] = {} + self.max_entries = max_entries + self._eviction_count = 0 + self._sweep_count = 0 + self._write_count = 0 async def connect(self): - """No connection needed for in-memory cache.""" pass async def disconnect(self): - """Clear cache on disconnect.""" self._data.clear() self._expires.clear() + self._access_order.clear() + + def _sweep_expired(self): + now = time.time() + keys = [k for k, exp in list(self._expires.items()) if exp <= now] + for k in keys: + self._data.pop(k, None) + self._expires.pop(k, None) + self._access_order.pop(k, None) + if keys: + self._sweep_count += len(keys) + + def _evict_lru(self): + """Evict the least recently used entries when over capacity.""" + if len(self._data) < self.max_entries: + return + sorted_keys = sorted(self._access_order, key=lambda k: self._access_order[k]) + evict_count = max(1, int(self.max_entries * SWEEP_EVICT_FRACTION)) + for k in sorted_keys[:evict_count]: + self._data.pop(k, None) + self._expires.pop(k, None) + self._access_order.pop(k, None) + self._eviction_count += evict_count async def get_json(self, key: str) -> Optional[Any]: """Retrieve and parse JSON from memory, respecting TTL.""" now = time.time() expiry = self._expires.get(key) - + if expiry and now > expiry: - # Clean up expired item self._data.pop(key, None) self._expires.pop(key, None) + self._access_order.pop(key, None) return None - + + if key in self._data: + self._access_order[key] = now + return self._data.get(key) async def set_json(self, key: str, value: Any, ttl: Optional[int] = None): """Store value in memory with optional TTL.""" + if len(self._data) >= self.max_entries and key not in self._data: + self._evict_lru() + self._data[key] = value actual_ttl = ttl or settings.cache_ttl_seconds self._expires[key] = time.time() + actual_ttl + self._access_order[key] = time.time() + self._write_count += 1 + + if self._write_count % OPPORTUNISTIC_SWEEP_INTERVAL == 0: + self._sweep_expired() async def delete_prefix(self, prefix: str): """Delete all keys starting with prefix.""" @@ -51,6 +95,20 @@ async def delete_prefix(self, prefix: str): for k in to_delete: self._data.pop(k, None) self._expires.pop(k, None) + self._access_order.pop(k, None) + + @property + def size(self) -> int: + return len(self._data) + + @property + def stats(self) -> Dict[str, Any]: + return { + "size": self.size, + "max_entries": self.max_entries, + "eviction_count": self._eviction_count, + "sweep_count": self._sweep_count, + } # Global cache instance diff --git a/testing/backend/unit/test_cache_helpers.py b/testing/backend/unit/test_cache_helpers.py index eff67e73..27e0cd48 100644 --- a/testing/backend/unit/test_cache_helpers.py +++ b/testing/backend/unit/test_cache_helpers.py @@ -1,4 +1,5 @@ import asyncio +import time from unittest.mock import AsyncMock, patch from backend.secuscan.cache import CacheClient @@ -95,3 +96,158 @@ async def run(): result = _run(run()) assert result is None + + +# --------------------------------------------------------------------------- +# LRU eviction order +# --------------------------------------------------------------------------- + + +def test_lru_eviction_evicts_oldest_when_over_capacity(): + cache = CacheClient(max_entries=3) + + async def run(): + await cache.set_json("key:1", "val1") + await cache.set_json("key:2", "val2") + await cache.set_json("key:3", "val3") + await cache.set_json("key:4", "val4") + + _run(run()) + assert cache.size == 3 + assert _run(cache.get_json("key:1")) is None + assert _run(cache.get_json("key:2")) == "val2" + assert _run(cache.get_json("key:3")) == "val3" + assert _run(cache.get_json("key:4")) == "val4" + + +def test_lru_eviction_skips_when_under_capacity(): + cache = CacheClient(max_entries=5) + + async def run(): + await cache.set_json("key:1", "val1") + await cache.set_json("key:2", "val2") + + _run(run()) + assert cache.size == 2 + assert _run(cache.get_json("key:1")) == "val1" + assert cache._eviction_count == 0 + + +def test_lru_eviction_preserves_recently_accessed(): + cache = CacheClient(max_entries=3) + + async def run(): + await cache.set_json("key:1", "val1") + await cache.set_json("key:2", "val2") + await cache.set_json("key:3", "val3") + cache._access_order["key:2"] = 1.0 # Oldest + cache._access_order["key:3"] = 2.0 # Middle + await cache.get_json("key:1") # Refreshes to ~now (most recent) + await cache.set_json("key:4", "val4") + + _run(run()) + assert cache.size == 3 + assert _run(cache.get_json("key:1")) == "val1" # Recently accessed, preserved + assert _run(cache.get_json("key:2")) is None # Oldest, evicted + assert _run(cache.get_json("key:4")) == "val4" + + +# --------------------------------------------------------------------------- +# Expiry cleanup +# --------------------------------------------------------------------------- + + +def test_expiry_sweep_removes_access_order_entries(): + cache = CacheClient() + + async def run(): + await cache.set_json("key:1", "val1", ttl=10) + await cache.set_json("key:2", "val2", ttl=10) + cache._expires["key:1"] = time.time() - 1 + cache._expires["key:2"] = time.time() - 1 + cache._sweep_expired() + + _run(run()) + assert "key:1" not in cache._data + assert "key:1" not in cache._expires + assert "key:1" not in cache._access_order + assert cache._sweep_count == 2 + + +def test_expired_entry_get_returns_none_and_cleans_access_order(): + cache = CacheClient() + + async def run(): + await cache.set_json("key:1", "val1", ttl=10) + cache._expires["key:1"] = time.time() - 1 + result = await cache.get_json("key:1") + return result + + result = _run(run()) + assert result is None + assert "key:1" not in cache._data + assert "key:1" not in cache._expires + assert "key:1" not in cache._access_order + + +def test_opportunistic_sweep_triggers_on_write_interval(): + cache = CacheClient() + cache.max_entries = 1000 + + async def run(): + for i in range(51): + await cache.set_json(f"exp:{i}", f"val{i}", ttl=0) + cache._expires[f"exp:{i}"] = time.time() - 1 + assert cache._sweep_count > 0 + + _run(run()) + + +# --------------------------------------------------------------------------- +# delete_prefix cleanup +# --------------------------------------------------------------------------- + + +def test_delete_prefix_removes_from_all_internal_dicts(): + cache = CacheClient() + + async def run(): + await cache.set_json("prefix:a", "val_a") + await cache.set_json("prefix:b", "val_b") + await cache.set_json("other:c", "val_c") + await cache.delete_prefix("prefix:") + + _run(run()) + assert "prefix:a" not in cache._data + assert "prefix:a" not in cache._expires + assert "prefix:a" not in cache._access_order + assert "prefix:b" not in cache._data + assert "other:c" in cache._data + assert cache.size == 1 + + +# --------------------------------------------------------------------------- +# Edge cases: max_entries <= 0 +# --------------------------------------------------------------------------- + + +def test_max_entries_zero_does_not_crash(): + cache = CacheClient(max_entries=0) + + async def run(): + await cache.set_json("key:1", "val1") + await cache.set_json("key:2", "val2") + + _run(run()) + assert cache.size >= 0 + + +def test_max_entries_negative_does_not_crash(): + cache = CacheClient(max_entries=-1) + + async def run(): + await cache.set_json("key:1", "val1") + await cache.set_json("key:2", "val2") + + _run(run()) + assert cache.size >= 0