Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions sigflow/cache/lru.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,39 @@
from collections import OrderedDict
from threading import Lock


class LRUCache:
    """Thread-safe Least-Recently-Used cache implementation.

    Uses a Lock to ensure all operations (get, set, len) are atomic and safe
    for concurrent access from multiple threads. The lock serializes access
    to the underlying OrderedDict to prevent data corruption, eviction races,
    and concurrent modification errors.

    Entries are kept in recency order: the least recently used entry sits at
    the front of the OrderedDict and is the first to be evicted.

    Args:
        capacity: Maximum number of items to cache (must be > 0)

    Raises:
        ValueError: If ``capacity`` is zero or negative.
    """

    def __init__(self, capacity: int = 1024):
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self.capacity = capacity
        self._data = OrderedDict()
        # Plain (non-reentrant) lock: no method calls another locked method
        # while holding it, so re-entrancy is not needed.
        self._lock = Lock()

    def get(self, key, default=None):
        """Return the value stored under *key*, or *default* on a miss.

        A hit marks the entry as most recently used. Note that a stored
        value equal to *default* cannot be told apart from a miss.
        """
        with self._lock:
            if key not in self._data:
                return default
            # Refresh recency and read the value under the same lock so the
            # entry cannot be evicted in between.
            self._data.move_to_end(key)
            return self._data[key]

    def set(self, key, value) -> None:
        """Store *value* under *key* and evict old entries if over capacity.

        The written entry becomes the most recently used one, whether it is
        new or an overwrite of an existing key.
        """
        with self._lock:
            self._data[key] = value
            # Assigning to an existing key keeps its old position, so move it
            # to the end explicitly.
            self._data.move_to_end(key)
            # Evict from the front (least recently used) until within bounds.
            while len(self._data) > self.capacity:
                self._data.popitem(last=False)

    def __len__(self):
        """Return the number of entries currently cached."""
        with self._lock:
            return len(self._data)
116 changes: 116 additions & 0 deletions tests/test_cache_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Thread-safety tests for LRUCache concurrent access patterns."""

import threading
from sigflow.cache.lru import LRUCache


def test_lru_thread_safety():
    """Verify LRUCache is safe for concurrent access from multiple threads.

    Tests that:
    - No data corruption occurs during concurrent set/get operations
    - A value read back is either the value stored or a miss (evicted)
    - Eviction doesn't cause race conditions or KeyError exceptions
    - Cache never ends up holding more entries than its capacity
    """
    capacity = 50
    num_threads = 10
    ops_per_thread = 200

    cache = LRUCache(capacity)
    errors = []

    def worker(thread_id):
        """Worker thread that performs concurrent cache operations."""
        try:
            for i in range(ops_per_thread):
                key = f"t{thread_id}_k{i}"
                value = f"v{thread_id}_{i}"

                # Concurrent set operations
                cache.set(key, value)

                # Concurrent get operations. The other threads may insert
                # more than `capacity` keys between our set and get, so a
                # miss (None) is legitimate; only a *different* value would
                # indicate corruption.
                retrieved = cache.get(key)
                if retrieved is not None and retrieved != value:
                    errors.append(
                        f"Thread {thread_id}: Data mismatch for {key}. "
                        f"Expected {value}, got {retrieved}"
                    )

                # Also test getting a neighbouring thread's keys
                # (cache hits/misses)
                other_key = f"t{(thread_id + 1) % num_threads}_k{i}"
                cache.get(other_key)

        except Exception as exc:
            # Exceptions raised inside a thread would otherwise be lost, so
            # record them for the assertion in the main thread.
            errors.append(f"Thread {thread_id}: {type(exc).__name__}: {exc}")

    # Launch the concurrent worker threads
    threads = [
        threading.Thread(target=worker, args=(i,)) for i in range(num_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Verify no errors occurred
    assert not errors, "Thread safety violations:\n" + "\n".join(errors)
    # Eviction must have kept the cache within bounds
    assert len(cache) <= capacity


def test_lru_concurrent_len():
    """Verify len() is thread-safe during concurrent modifications.

    Three readers sample the cache length while three writers insert the
    same set of keys. Every sample must succeed and stay within capacity.
    """
    capacity = 100
    iterations = 100
    num_readers = 3
    num_writers = 3

    cache = LRUCache(capacity)
    lengths = []

    def reader():
        """Continuously read cache length."""
        for _ in range(iterations):
            lengths.append(len(cache))

    def writer():
        """Continuously modify cache."""
        for i in range(iterations):
            cache.set(f"key_{i}", f"value_{i}")

    threads = [
        *[threading.Thread(target=reader) for _ in range(num_readers)],
        *[threading.Thread(target=writer) for _ in range(num_writers)],
    ]

    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # All length reads should have succeeded without error; a reader that
    # raised would have stopped early and left fewer samples.
    assert len(lengths) == num_readers * iterations
    # Every observed length must be a valid size for this cache
    assert all(0 <= n <= capacity for n in lengths)


def test_lru_eviction_race():
    """Test that eviction under concurrent load doesn't cause corruption.

    Writers insert far more distinct keys than the cache can hold while
    readers look up those same keys. The test passes when no thread raises
    and the cache ends up no larger than its capacity.
    """
    capacity = 10  # Small capacity to force evictions
    num_threads = 5  # Used for both the writer and the reader group
    ops_per_thread = 500

    cache = LRUCache(capacity)
    errors = []

    def aggressive_writer(thread_id):
        """Aggressively write to trigger evictions."""
        try:
            for i in range(ops_per_thread):
                cache.set(f"t{thread_id}_k{i}", i)
        except Exception as exc:
            errors.append(f"Thread {thread_id} write error: {exc}")

    def aggressive_reader(thread_id):
        """Aggressively read during eviction."""
        try:
            for i in range(ops_per_thread):
                # Read the keys written by the previous writer thread
                cache.get(f"t{(thread_id - 1) % num_threads}_k{i}")
        except Exception as exc:
            errors.append(f"Thread {thread_id} read error: {exc}")

    threads = [
        *[
            threading.Thread(target=aggressive_writer, args=(i,))
            for i in range(num_threads)
        ],
        *[
            threading.Thread(target=aggressive_reader, args=(i,))
            for i in range(num_threads)
        ],
    ]

    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Should complete without crashes or data corruption
    assert not errors, "Eviction race condition:\n" + "\n".join(errors)
    # Racing evictions must not leave the cache over capacity
    assert len(cache) <= capacity
Loading