Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/taskiq_deduplication/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from taskiq import TaskiqMessage, TaskiqResult
from taskiq.abc.middleware import TaskiqMiddleware

from .utils import check_and_delete
from .utils import RELEASE_LUA_SCRIPT, check_and_delete

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,13 +55,15 @@ def __init__(
self.startup_retries = startup_retries
self.startup_retry_delay = startup_retry_delay
self._redis: Redis | None = None
self._release_script: Any = None

async def startup(self) -> None:
last_error: BaseException | None = None
for attempt in range(self.startup_retries):
try:
self._redis = Redis.from_url(self.redis_url)
await cast(Awaitable[bool], self._redis.ping())
self._release_script = self._redis.register_script(RELEASE_LUA_SCRIPT)
return
except Exception as exc:
last_error = exc
Expand Down Expand Up @@ -133,7 +135,9 @@ async def _release_if_owned(self, key: str, task_id: str) -> None:
raise RuntimeError(
"RedisDeduplicationMiddleware.startup() was never called."
)
released = await check_and_delete(self._redis, key, task_id)
if self._release_script is None:
self._release_script = self._redis.register_script(RELEASE_LUA_SCRIPT)
released = await check_and_delete(self._release_script, key, task_id)
if released:
logger.debug("Released lock %s", key)
else:
Expand Down Expand Up @@ -181,27 +185,25 @@ async def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
logger.debug("Lock %s acquired for task %s", key, message.task_name)
return message

async def post_execute(
self,
message: TaskiqMessage,
result: TaskiqResult,
) -> None:
async def _release_lock(self, message: TaskiqMessage) -> None:
if not self._is_enabled(message.labels):
return
key = self._get_cached_key(message)
if key is None:
return
await self._release_if_owned(key, message.task_id)

async def post_execute(
self,
message: TaskiqMessage,
result: TaskiqResult,
) -> None:
await self._release_lock(message)

async def on_error(
self,
message: TaskiqMessage,
result: TaskiqResult,
exception: BaseException,
) -> None:
if not self._is_enabled(message.labels):
return
key = self._get_cached_key(message)
if key is None:
return
await self._release_if_owned(key, message.task_id)
await self._release_lock(message)
8 changes: 3 additions & 5 deletions src/taskiq_deduplication/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from redis.asyncio import Redis
from redis.commands.core import AsyncScript
from typing import Any

RELEASE_LUA_SCRIPT = """
if redis.call('get', KEYS[1]) == ARGV[1] then
Expand All @@ -10,17 +9,16 @@
"""


async def check_and_delete(redis: Redis, key: str, owner: str) -> bool:
async def check_and_delete(script: Any, key: str, owner: str) -> bool:
"""Delete *key* only if its value equals *owner*.

Args:
redis: Async Redis client.
script: Pre-registered Lua script object (from ``Redis.register_script``).
key: Lock key to delete.
owner: Expected value of the key (task_id).

Returns:
True if the key was deleted, False otherwise.
"""
script: AsyncScript = redis.register_script(RELEASE_LUA_SCRIPT)
released: int = await script(keys=[key], args=[owner])
return bool(released)
6 changes: 5 additions & 1 deletion tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

Expand Down Expand Up @@ -309,6 +309,7 @@ async def test_startup_creates_redis_client(self):
assert mw._redis is None
with patch("redis.asyncio.Redis.from_url") as mock_from_url:
mock_client = AsyncMock()
mock_client.register_script = MagicMock()
mock_from_url.return_value = mock_client
await mw.startup()
mock_from_url.assert_called_once_with("redis://localhost")
Expand Down Expand Up @@ -349,6 +350,7 @@ async def test_startup_succeeds_after_retries(self):
ConnectionError("fail"),
None,
]
mock_client.register_script = MagicMock()
mock_from_url.return_value = mock_client
await mw.startup()
assert mw._redis is mock_client
Expand Down Expand Up @@ -378,6 +380,7 @@ async def test_startup_no_retry_on_first_success(self):
)
with patch("redis.asyncio.Redis.from_url") as mock_from_url:
mock_client = AsyncMock()
mock_client.register_script = MagicMock()
mock_from_url.return_value = mock_client
await mw.startup()
mock_client.ping.assert_called_once()
Expand All @@ -399,6 +402,7 @@ async def test_startup_retry_delay_exponential(self):
ConnectionError("fail"),
None,
]
mock_client.register_script = MagicMock()
mock_from_url.return_value = mock_client
await mw.startup()
assert mock_sleep.call_count == 2
Expand Down