diff --git a/sigflow/cache/lru.py b/sigflow/cache/lru.py
index 72832fe..65e7219 100644
--- a/sigflow/cache/lru.py
+++ b/sigflow/cache/lru.py
@@ -1,24 +1,39 @@
 from collections import OrderedDict
+from threading import Lock
 
 
 class LRUCache:
+    """Thread-safe Least-Recently-Used cache implementation.
+
+    Uses a Lock to ensure all operations (get, set, len) are atomic and safe
+    for concurrent access from multiple threads. The lock serializes access
+    to the underlying OrderedDict to prevent data corruption, eviction races,
+    and concurrent modification errors.
+
+    Args:
+        capacity: Maximum number of items to cache (must be > 0)
+    """
     def __init__(self, capacity: int = 1024):
         if capacity <= 0:
             raise ValueError("capacity must be positive")
         self.capacity = capacity
         self._data = OrderedDict()
+        self._lock = Lock()
 
     def get(self, key, default=None):
-        if key not in self._data:
-            return default
-        self._data.move_to_end(key)
-        return self._data[key]
+        with self._lock:
+            if key not in self._data:
+                return default
+            self._data.move_to_end(key)
+            return self._data[key]
 
     def set(self, key, value) -> None:
-        self._data[key] = value
-        self._data.move_to_end(key)
-        while len(self._data) > self.capacity:
-            self._data.popitem(last=False)
+        with self._lock:
+            self._data[key] = value
+            self._data.move_to_end(key)
+            while len(self._data) > self.capacity:
+                self._data.popitem(last=False)
 
     def __len__(self):
-        return len(self._data)
+        with self._lock:
+            return len(self._data)
diff --git a/tests/test_cache_concurrency.py b/tests/test_cache_concurrency.py
new file mode 100644
index 0000000..1fd00ea
--- /dev/null
+++ b/tests/test_cache_concurrency.py
@@ -0,0 +1,130 @@
+"""Thread-safety tests for LRUCache concurrent access patterns."""
+
+import threading
+from sigflow.cache.lru import LRUCache
+
+
+def test_lru_thread_safety():
+    """Verify LRUCache is safe for concurrent access from multiple threads.
+
+    Tests that:
+    - No data corruption occurs during concurrent set/get operations
+    - A value read back is either the value stored or a miss (evicted)
+    - Eviction doesn't cause race conditions or KeyError exceptions
+    - Cache never ends up larger than its capacity
+    """
+    capacity = 50
+    num_threads = 10
+    cache = LRUCache(capacity)
+    errors = []
+
+    def worker(thread_id):
+        """Worker thread that performs concurrent cache operations."""
+        try:
+            for i in range(200):
+                key = f"t{thread_id}_k{i}"
+                value = f"v{thread_id}_{i}"
+
+                # Concurrent set operations
+                cache.set(key, value)
+
+                # Concurrent get operations. Other threads may insert more
+                # than `capacity` keys between our set() and get(), so a
+                # miss (None) is a legitimate eviction, not corruption.
+                retrieved = cache.get(key)
+                if retrieved is not None and retrieved != value:
+                    errors.append(
+                        f"Thread {thread_id}: Data mismatch for {key}. "
+                        f"Expected {value}, got {retrieved}"
+                    )
+
+                # Also test getting other keys (cache hits/misses)
+                other_key = f"t{(thread_id + 1) % num_threads}_k{i}"
+                cache.get(other_key)
+
+        except Exception as exc:
+            errors.append(f"Thread {thread_id}: {type(exc).__name__}: {exc}")
+
+    # Launch the worker threads concurrently
+    threads = [
+        threading.Thread(target=worker, args=(i,)) for i in range(num_threads)
+    ]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+    # Verify no errors occurred
+    assert not errors, "Thread safety violations:\n" + "\n".join(errors)
+    # Far more distinct keys than fit were inserted, so the cache ends full
+    assert len(cache) == capacity
+
+
+def test_lru_concurrent_len():
+    """Verify __len__() is thread-safe during concurrent modifications."""
+    capacity = 100
+    cache = LRUCache(capacity)
+    lengths = []
+
+    def reader():
+        """Continuously read cache length."""
+        for _ in range(100):
+            lengths.append(len(cache))
+
+    def writer():
+        """Continuously modify cache."""
+        for i in range(100):
+            cache.set(f"key_{i}", f"value_{i}")
+
+    threads = [
+        *[threading.Thread(target=reader) for _ in range(3)],
+        *[threading.Thread(target=writer) for _ in range(3)],
+    ]
+
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+    # All length reads should have succeeded without error
+    assert len(lengths) == 300
+    # Every observed length must be a valid size for this cache
+    assert all(0 <= n <= capacity for n in lengths)
+
+
+def test_lru_eviction_race():
+    """Test that eviction under concurrent load doesn't cause corruption."""
+    capacity = 10  # Small capacity to force evictions
+    cache = LRUCache(capacity)
+    errors = []
+
+    def aggressive_writer(thread_id):
+        """Aggressively write to trigger evictions."""
+        try:
+            for i in range(500):
+                cache.set(f"t{thread_id}_k{i}", i)
+        except Exception as exc:
+            errors.append(f"Thread {thread_id} write error: {exc}")
+
+    def aggressive_reader(thread_id):
+        """Aggressively read during eviction."""
+        try:
+            for i in range(500):
+                cache.get(f"t{(thread_id - 1) % 5}_k{i}")
+        except Exception as exc:
+            errors.append(f"Thread {thread_id} read error: {exc}")
+
+    threads = [
+        *[threading.Thread(target=aggressive_writer, args=(i,)) for i in range(5)],
+        *[threading.Thread(target=aggressive_reader, args=(i,)) for i in range(5)],
+    ]
+
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+    # Should complete without crashes or data corruption
+    assert not errors, "Eviction race condition:\n" + "\n".join(errors)
+    # Writers insert far more distinct keys than fit, so the cache ends full
+    assert len(cache) == capacity