diff --git a/app/core/lock.py b/app/core/lock.py index 4146702..460c28d 100644 --- a/app/core/lock.py +++ b/app/core/lock.py @@ -1,10 +1,28 @@ """Store-level coarse lock helpers.""" -from _thread import LockType import threading +from types import TracebackType +from typing import Optional, Protocol -def create_store_lock() -> LockType: +class StoreLock(Protocol): + """Minimal lock contract used by the store.""" + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: ... + + def release(self) -> None: ... + + def __enter__(self) -> bool: ... + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: ... + + +def create_store_lock() -> StoreLock: """Create the single coarse lock shared by store operations.""" return threading.Lock() diff --git a/app/core/store.py b/app/core/store.py index 34d64ea..14b6ed0 100644 --- a/app/core/store.py +++ b/app/core/store.py @@ -1,7 +1,6 @@ """Store API and key/value storage implementation.""" import time -from _thread import LockType from typing import Callable, Optional, Tuple from app.core.expiration import ( @@ -10,7 +9,7 @@ is_expired, ttl_seconds, ) -from app.core.lock import create_store_lock +from app.core.lock import StoreLock, create_store_lock class Store: @@ -19,7 +18,7 @@ class Store: def __init__(self, clock: Optional[Callable[[], float]] = None) -> None: self.data_map: dict[str, str] = {} self.expire_map: dict[str, float] = {} - self.lock: LockType = create_store_lock() + self.lock: StoreLock = create_store_lock() self._clock = clock if clock is not None else time.time def get(self, key: str) -> Tuple[bool, Optional[str]]: diff --git a/app/persistence/replay.py b/app/persistence/replay.py index 677dc88..69e42da 100644 --- a/app/persistence/replay.py +++ b/app/persistence/replay.py @@ -11,9 +11,7 @@ from .aof import AofEntry, AofParseError -def apply_aof_entry_to_store( - store: StoreProtocol, entry: AofEntry, now: float -) -> None: +def apply_aof_entry_to_store(store: StoreProtocol, entry: AofEntry, now: float) -> None: """Apply a single AOF entry to the store. Store.expireat(..., past) deletes the key.""" if entry.command == "SET": diff --git a/app/protocol/resp_parser.py b/app/protocol/resp_parser.py index 3fcc39f..eeb2144 100644 --- a/app/protocol/resp_parser.py +++ b/app/protocol/resp_parser.py @@ -2,14 +2,22 @@ from __future__ import annotations -from typing import BinaryIO +from typing import Protocol class RespProtocolError(Exception): """Raised when a RESP request does not match the supported subset.""" -def parse_command_frame(stream: BinaryIO) -> list[str] | None: +class RespReadableStream(Protocol): + """Binary stream shape required by the RESP parser.""" + + def read(self, size: int = -1, /) -> bytes | None: ... + + def readline(self, size: int = -1, /) -> bytes: ... + + +def parse_command_frame(stream: RespReadableStream) -> list[str] | None: """Parse one RESP command frame from a binary stream. Returns ``None`` when the peer closes the stream cleanly before sending @@ -53,7 +61,7 @@ def parse_command_frame(stream: BinaryIO) -> list[str] | None: return parts -def _parse_length(stream: BinaryIO, error_message: str) -> int: +def _parse_length(stream: RespReadableStream, error_message: str) -> int: line = _readline(stream, error_message) try: return int(line) @@ -61,20 +69,20 @@ def _parse_length(stream: BinaryIO, error_message: str) -> int: raise RespProtocolError(error_message) from error -def _readline(stream: BinaryIO, error_message: str) -> bytes: +def _readline(stream: RespReadableStream, error_message: str) -> bytes: line = stream.readline() if line == b"" or not line.endswith(b"\r\n"): raise RespProtocolError(error_message) return line[:-2] -def _read_exact(stream: BinaryIO, size: int) -> bytes: +def _read_exact(stream: RespReadableStream, size: int) -> bytes: payload = stream.read(size) if payload is None or len(payload) != size: raise RespProtocolError("protocol error: incomplete bulk string") return payload -def _expect_crlf(stream: BinaryIO, error_message: str) -> None: +def _expect_crlf(stream: RespReadableStream, error_message: str) -> None: if stream.read(2) != b"\r\n": raise RespProtocolError(error_message) diff --git a/tests/integration/test_protocol_http.py b/tests/integration/test_protocol_http.py index e3be94f..81a7a4e 100644 --- a/tests/integration/test_protocol_http.py +++ b/tests/integration/test_protocol_http.py @@ -193,7 +193,9 @@ async def test_malformed_json_maps_to_invalid_request() -> None: assert fake_executor.calls == [] -def test_default_app_starts_with_missing_aof_file(tmp_path: Any, monkeypatch: pytest.MonkeyPatch) -> None: +def test_default_app_starts_with_missing_aof_file( + tmp_path: Any, monkeypatch: pytest.MonkeyPatch +) -> None: monkeypatch.chdir(tmp_path) with TestClient(create_app()) as client: @@ -251,7 +253,9 @@ def test_default_app_fails_startup_for_malformed_aof( ) -> None: monkeypatch.chdir(tmp_path) aof_path = tmp_path / "appendonly.aof" - aof_path.write_text('{"command":"SET","args":["a","1"]}\nnot-json\n', encoding="utf-8") + aof_path.write_text( + '{"command":"SET","args":["a","1"]}\nnot-json\n', encoding="utf-8" + ) with pytest.raises(AofParseError, match="line 2"): with TestClient(create_app()): diff --git a/tests/unit/test_persistence_aof.py b/tests/unit/test_persistence_aof.py index ca3736f..a81be0e 100644 --- a/tests/unit/test_persistence_aof.py +++ b/tests/unit/test_persistence_aof.py @@ -198,13 +198,9 @@ def sweep_expired(self) -> int: store = RecordingStore() apply_aof_entry_to_store(store, AofEntry("SET", ("k", "v")), 100.0) - apply_aof_entry_to_store( - store, AofEntry("EXPIREAT", ("k", 50.0)), 100.0 - ) + apply_aof_entry_to_store(store, AofEntry("EXPIREAT", ("k", 50.0)), 100.0) assert store.calls == [("set", ("k", "v")), ("expireat", ("k", 50.0))] store.calls.clear() - apply_aof_entry_to_store( - store, AofEntry("EXPIREAT", ("k", 150.0)), 100.0 - ) + apply_aof_entry_to_store(store, AofEntry("EXPIREAT", ("k", 150.0)), 100.0) assert store.calls == [("expireat", ("k", 150.0))]