Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions backend/secuscan/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,110 @@
import json
from typing import Any, Optional, Dict
import time
import logging

from .config import settings

logger = logging.getLogger(__name__)

# Default upper bound on the number of entries a CacheClient holds.
DEFAULT_MAX_ENTRIES = 10_000
# Fraction of ``max_entries`` removed in one batch when the cache is full.
SWEEP_EVICT_FRACTION = 0.25
# An expired-entry sweep runs once every this many writes.
OPPORTUNISTIC_SWEEP_INTERVAL = 50


class CacheClient:
    """In-memory dictionary based cache client with TTL, size limit, and LRU eviction.

    Three parallel dictionaries, keyed by cache key, hold the stored value,
    its absolute expiry timestamp, and the timestamp of its last access.
    Expired entries are removed lazily on read, every
    ``OPPORTUNISTIC_SWEEP_INTERVAL`` writes, and whenever a new key arrives
    while the cache is full. If the cache is still full after that sweep,
    the least recently used entries are evicted in a batch.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        max_entries: int = DEFAULT_MAX_ENTRIES,
    ):
        """Create an empty cache.

        Args:
            url: Stored on the instance; not otherwise used by this class.
            max_entries: Entry count at which inserting a new key triggers
                eviction.
        """
        self.url = url
        self._data: Dict[str, Any] = {}
        self._expires: Dict[str, float] = {}
        self._access_order: Dict[str, float] = {}
        self.max_entries = max_entries
        self._eviction_count = 0
        self._sweep_count = 0
        self._write_count = 0

    async def connect(self) -> None:
        """No connection needed for in-memory cache."""

    async def disconnect(self) -> None:
        """Clear cache on disconnect."""
        self._data.clear()
        self._expires.clear()
        self._access_order.clear()

    def _remove(self, key: str) -> None:
        """Drop ``key`` from all three internal dictionaries, if present."""
        self._data.pop(key, None)
        self._expires.pop(key, None)
        self._access_order.pop(key, None)

    def _touch(self, key: str, now: float) -> None:
        """Record ``now`` as the last access time of ``key``.

        The key is re-inserted so that it moves to the end of the dictionary.
        ``_evict_lru`` uses a stable sort, so when two keys carry the same
        timestamp (coarse clock), the one touched earlier is evicted first.
        """
        self._access_order.pop(key, None)
        self._access_order[key] = now

    def _sweep_expired(self) -> None:
        """Remove every entry whose expiry time has been reached."""
        now = time.time()
        expired = [k for k, exp in self._expires.items() if exp <= now]
        for k in expired:
            self._remove(k)
        self._sweep_count += len(expired)

    def _evict_lru(self) -> None:
        """Evict the least recently used entries when at or over capacity.

        Removes up to ``max(1, int(max_entries * SWEEP_EVICT_FRACTION))``
        entries, oldest access first. Does nothing while the cache holds
        fewer than ``max_entries`` entries.
        """
        if len(self._data) < self.max_entries:
            return
        evict_count = max(1, int(self.max_entries * SWEEP_EVICT_FRACTION))
        victims = sorted(
            self._access_order, key=self._access_order.__getitem__
        )[:evict_count]
        for k in victims:
            self._remove(k)
        # Count what was actually removed, which may be fewer than the batch
        # size (for example when the cache is empty).
        self._eviction_count += len(victims)

    async def get_json(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key``, respecting TTL.

        The value is returned exactly as it was stored; no JSON decoding is
        performed. An entry whose expiry time has been reached is removed
        and ``None`` is returned. A successful lookup refreshes the entry's
        last-access time.
        """
        now = time.time()
        expiry = self._expires.get(key)

        # Same boundary as _sweep_expired: expired once now reaches expiry.
        if expiry is not None and now >= expiry:
            self._remove(key)
            return None

        if key in self._data:
            self._touch(key, now)

        return self._data.get(key)

    async def set_json(self, key: str, value: Any, ttl: Optional[int] = None):
        """Store ``value`` under ``key`` with an optional TTL in seconds.

        A falsy ``ttl`` (``None`` or ``0``) falls back to
        ``settings.cache_ttl_seconds``. When a new key arrives while the
        cache is full, expired entries are dropped first and live entries
        are evicted only if that did not free any room.
        """
        if key not in self._data and len(self._data) >= self.max_entries:
            self._sweep_expired()
            self._evict_lru()

        now = time.time()
        self._data[key] = value
        actual_ttl = ttl or settings.cache_ttl_seconds
        self._expires[key] = now + actual_ttl
        self._touch(key, now)
        self._write_count += 1

        if self._write_count % OPPORTUNISTIC_SWEEP_INTERVAL == 0:
            self._sweep_expired()

    async def delete_prefix(self, prefix: str):
        """Delete all keys starting with prefix."""
        doomed = [k for k in self._data if k.startswith(prefix)]
        for k in doomed:
            self._remove(k)

    @property
    def size(self) -> int:
        """Number of entries currently stored (expired ones included until removed)."""
        return len(self._data)

    @property
    def stats(self) -> Dict[str, Any]:
        """Snapshot of the current size, the size limit, and removal counters."""
        return {
            "size": self.size,
            "max_entries": self.max_entries,
            "eviction_count": self._eviction_count,
            "sweep_count": self._sweep_count,
        }


# Global cache instance
Expand Down
156 changes: 156 additions & 0 deletions testing/backend/unit/test_cache_helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import time
from unittest.mock import AsyncMock, patch

from backend.secuscan.cache import CacheClient
Expand Down Expand Up @@ -95,3 +96,158 @@ async def run():

result = _run(run())
assert result is None


# ---------------------------------------------------------------------------
# LRU eviction order
# ---------------------------------------------------------------------------


def test_lru_eviction_evicts_oldest_when_over_capacity():
    """A fourth write into a three-slot cache drops the first key written."""
    cache = CacheClient(max_entries=3)

    async def fill():
        for idx in range(1, 5):
            await cache.set_json(f"key:{idx}", f"val{idx}")

    _run(fill())
    assert cache.size == 3
    assert _run(cache.get_json("key:1")) is None
    for idx in (2, 3, 4):
        assert _run(cache.get_json(f"key:{idx}")) == f"val{idx}"


def test_lru_eviction_skips_when_under_capacity():
    """Two writes into a five-slot cache keep both entries and evict nothing."""
    cache = CacheClient(max_entries=5)

    async def fill():
        for idx in (1, 2):
            await cache.set_json(f"key:{idx}", f"val{idx}")

    _run(fill())
    assert cache.size == 2
    assert _run(cache.get_json("key:1")) == "val1"
    assert cache._eviction_count == 0


def test_lru_eviction_preserves_recently_accessed():
    """A key read just before the overflowing write survives; the stalest key goes."""
    cache = CacheClient(max_entries=3)

    async def scenario():
        for idx in (1, 2, 3):
            await cache.set_json(f"key:{idx}", f"val{idx}")
        # Backdate two entries so the access order is unambiguous.
        cache._access_order["key:2"] = 1.0  # Oldest
        cache._access_order["key:3"] = 2.0  # Middle
        # Reading key:1 refreshes its access time to roughly now.
        await cache.get_json("key:1")
        await cache.set_json("key:4", "val4")

    _run(scenario())
    assert cache.size == 3
    # Recently accessed, preserved.
    assert _run(cache.get_json("key:1")) == "val1"
    # Oldest, evicted.
    assert _run(cache.get_json("key:2")) is None
    assert _run(cache.get_json("key:4")) == "val4"


# ---------------------------------------------------------------------------
# Expiry cleanup
# ---------------------------------------------------------------------------


def test_expiry_sweep_removes_access_order_entries():
    """A sweep purges expired keys from every internal dict and counts them."""
    cache = CacheClient()
    names = ("key:1", "key:2")

    async def scenario():
        for name, payload in zip(names, ("val1", "val2")):
            await cache.set_json(name, payload, ttl=10)
        # Force both entries into the past so the sweep sees them as expired.
        for name in names:
            cache._expires[name] = time.time() - 1
        cache._sweep_expired()

    _run(scenario())
    for store in (cache._data, cache._expires, cache._access_order):
        assert "key:1" not in store
    assert cache._sweep_count == 2


def test_expired_entry_get_returns_none_and_cleans_access_order():
    """Reading an expired key yields None and removes it from every internal dict."""
    cache = CacheClient()

    async def read_after_expiry():
        await cache.set_json("key:1", "val1", ttl=10)
        # Move the expiry into the past so the next read treats it as stale.
        cache._expires["key:1"] = time.time() - 1
        return await cache.get_json("key:1")

    fetched = _run(read_after_expiry())
    assert fetched is None
    for store in (cache._data, cache._expires, cache._access_order):
        assert "key:1" not in store


def test_opportunistic_sweep_triggers_on_write_interval():
    """After 51 writes of hand-expired entries, the sweep counter is non-zero."""
    cache = CacheClient()
    cache.max_entries = 1000

    async def write_and_expire():
        for i in range(51):
            name = f"exp:{i}"
            await cache.set_json(name, f"val{i}", ttl=0)
            # Expire the entry by hand immediately after writing it.
            cache._expires[name] = time.time() - 1
        assert cache._sweep_count > 0

    _run(write_and_expire())


# ---------------------------------------------------------------------------
# delete_prefix cleanup
# ---------------------------------------------------------------------------


def test_delete_prefix_removes_from_all_internal_dicts():
    """delete_prefix drops matching keys from every dict and leaves the rest alone."""
    cache = CacheClient()

    async def scenario():
        for name, payload in (
            ("prefix:a", "val_a"),
            ("prefix:b", "val_b"),
            ("other:c", "val_c"),
        ):
            await cache.set_json(name, payload)
        await cache.delete_prefix("prefix:")

    _run(scenario())
    for store in (cache._data, cache._expires, cache._access_order):
        assert "prefix:a" not in store
    assert "prefix:b" not in cache._data
    assert "other:c" in cache._data
    assert cache.size == 1


# ---------------------------------------------------------------------------
# Edge cases: max_entries <= 0
# ---------------------------------------------------------------------------


def test_max_entries_zero_does_not_crash():
    """Writes to a cache built with max_entries=0 complete without raising."""
    cache = CacheClient(max_entries=0)

    async def write_two():
        for idx in (1, 2):
            await cache.set_json(f"key:{idx}", f"val{idx}")

    _run(write_two())
    # The size check is a smoke assertion; the real check is that no
    # exception was raised above.
    assert cache.size >= 0


def test_max_entries_negative_does_not_crash():
    """Writes to a cache built with max_entries=-1 complete without raising."""
    cache = CacheClient(max_entries=-1)

    async def write_two():
        for idx in (1, 2):
            await cache.set_json(f"key:{idx}", f"val{idx}")

    _run(write_two())
    # The size check is a smoke assertion; the real check is that no
    # exception was raised above.
    assert cache.size >= 0
Loading