From faf1b772a4a48479d6595558b552b4216714fb0c Mon Sep 17 00:00:00 2001 From: d3vyce Date: Fri, 15 May 2026 15:46:37 -0400 Subject: [PATCH] refactor: cleanup middleware helpers, remove type hacks, and cache Lua script registration --- src/taskiq_deduplication/middleware.py | 28 ++++++++++++++------------ src/taskiq_deduplication/utils.py | 8 +++----- tests/test_middleware.py | 6 +++++- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/taskiq_deduplication/middleware.py b/src/taskiq_deduplication/middleware.py index e29a4dc..27f5ef7 100644 --- a/src/taskiq_deduplication/middleware.py +++ b/src/taskiq_deduplication/middleware.py @@ -8,7 +8,7 @@ from taskiq import TaskiqMessage, TaskiqResult from taskiq.abc.middleware import TaskiqMiddleware -from .utils import check_and_delete +from .utils import RELEASE_LUA_SCRIPT, check_and_delete logger = logging.getLogger(__name__) @@ -55,6 +55,7 @@ def __init__( self.startup_retries = startup_retries self.startup_retry_delay = startup_retry_delay self._redis: Redis | None = None + self._release_script: Any = None async def startup(self) -> None: last_error: BaseException | None = None @@ -62,6 +63,7 @@ async def startup(self) -> None: try: self._redis = Redis.from_url(self.redis_url) await cast(Awaitable[bool], self._redis.ping()) + self._release_script = self._redis.register_script(RELEASE_LUA_SCRIPT) return except Exception as exc: last_error = exc @@ -133,7 +135,9 @@ async def _release_if_owned(self, key: str, task_id: str) -> None: raise RuntimeError( "RedisDeduplicationMiddleware.startup() was never called." ) - released = await check_and_delete(self._redis, key, task_id) + if self._release_script is None: + self._release_script = self._redis.register_script(RELEASE_LUA_SCRIPT) + released = await check_and_delete(self._release_script, key, task_id) if released: logger.debug("Released lock %s", key) else: @@ -181,11 +185,7 @@ async def pre_send(self, message: TaskiqMessage) -> TaskiqMessage: logger.debug("Lock %s acquired for task %s", key, message.task_name) return message - async def post_execute( - self, - message: TaskiqMessage, - result: TaskiqResult, - ) -> None: + async def _release_lock(self, message: TaskiqMessage) -> None: if not self._is_enabled(message.labels): return key = self._get_cached_key(message) @@ -193,15 +193,17 @@ async def post_execute( return await self._release_if_owned(key, message.task_id) + async def post_execute( + self, + message: TaskiqMessage, + result: TaskiqResult, + ) -> None: + await self._release_lock(message) + async def on_error( self, message: TaskiqMessage, result: TaskiqResult, exception: BaseException, ) -> None: - if not self._is_enabled(message.labels): - return - key = self._get_cached_key(message) - if key is None: - return - await self._release_if_owned(key, message.task_id) + await self._release_lock(message) diff --git a/src/taskiq_deduplication/utils.py b/src/taskiq_deduplication/utils.py index 2539699..dfe7b0a 100644 --- a/src/taskiq_deduplication/utils.py +++ b/src/taskiq_deduplication/utils.py @@ -1,5 +1,4 @@ -from redis.asyncio import Redis -from redis.commands.core import AsyncScript +from typing import Any RELEASE_LUA_SCRIPT = """ if redis.call('get', KEYS[1]) == ARGV[1] then @@ -10,17 +9,16 @@ """ -async def check_and_delete(redis: Redis, key: str, owner: str) -> bool: +async def check_and_delete(script: Any, key: str, owner: str) -> bool: """Delete *key* only if its value equals *owner*. Args: - redis: Async Redis client. + script: Pre-registered Lua script object (from ``Redis.register_script``). key: Lock key to delete. owner: Expected value of the key (task_id). Returns: True if the key was deleted, False otherwise. """ - script: AsyncScript = redis.register_script(RELEASE_LUA_SCRIPT) released: int = await script(keys=[key], args=[owner]) return bool(released) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 22f402f..9b63bbe 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -309,6 +309,7 @@ async def test_startup_creates_redis_client(self): assert mw._redis is None with patch("redis.asyncio.Redis.from_url") as mock_from_url: mock_client = AsyncMock() + mock_client.register_script = MagicMock() mock_from_url.return_value = mock_client await mw.startup() mock_from_url.assert_called_once_with("redis://localhost") @@ -349,6 +350,7 @@ async def test_startup_succeeds_after_retries(self): ConnectionError("fail"), None, ] + mock_client.register_script = MagicMock() mock_from_url.return_value = mock_client await mw.startup() assert mw._redis is mock_client @@ -378,6 +380,7 @@ async def test_startup_no_retry_on_first_success(self): ) with patch("redis.asyncio.Redis.from_url") as mock_from_url: mock_client = AsyncMock() + mock_client.register_script = MagicMock() mock_from_url.return_value = mock_client await mw.startup() mock_client.ping.assert_called_once() @@ -399,6 +402,7 @@ async def test_startup_retry_delay_exponential(self): ConnectionError("fail"), None, ] + mock_client.register_script = MagicMock() mock_from_url.return_value = mock_client await mw.startup() assert mock_sleep.call_count == 2