diff --git a/src/bot/services/ai/gemini_client.py b/src/bot/services/ai/gemini_client.py index 85647f4..05ef75a 100644 --- a/src/bot/services/ai/gemini_client.py +++ b/src/bot/services/ai/gemini_client.py @@ -23,7 +23,7 @@ transcribe_system_prompt, ) from services.ai.wtf_prompts import WTFPromptStyle, get_wtf_system_prompt, wtf_explain_user_text -from services.repositories.rate_limit import RateLimitRepository +from services.repositories.rate_limit import RateLimitRepository, RateLimitUnavailableError from urllib3.exceptions import HTTPError logger = LoggerAdapter(get_logger(__name__), {}) @@ -85,7 +85,10 @@ def explain_term(self, term: str, lang: str = "kk", style: WTFPromptStyle = "ang GeminiRPDExhaustedError: daily RPD limit reached after increment. GeminiUnavailableError: 429, 5xx, timeout, or bad response body. """ - count, within_limit = self._rate_repo.increment_and_check() + try: + count, within_limit = self._rate_repo.increment_and_check() + except RateLimitUnavailableError as exc: + raise GeminiUnavailableError("Gemini RPD counter unavailable") from exc if not within_limit: logger.warning( @@ -214,7 +217,10 @@ def explain_media( Raises: GeminiRPDExhaustedError, GeminiUnavailableError: same as ``explain_term``. """ - count, within_limit = self._rate_repo.increment_and_check() + try: + count, within_limit = self._rate_repo.increment_and_check() + except RateLimitUnavailableError as exc: + raise GeminiUnavailableError("Gemini RPD counter unavailable") from exc if not within_limit: logger.warning( "Gemini RPD limit reached (multimodal)", diff --git a/src/bot/services/repositories/rate_limit.py b/src/bot/services/repositories/rate_limit.py index a46edde..30a5c5f 100644 --- a/src/bot/services/repositories/rate_limit.py +++ b/src/bot/services/repositories/rate_limit.py @@ -23,6 +23,10 @@ _TTL_DELTA = timedelta(hours=48) +class RateLimitUnavailableError(RuntimeError): + """Raised when the RPD counter cannot be updated safely.""" + + class RateLimitRepository: """Atomic RPD counter in the shared stats DynamoDB table. @@ -71,9 +75,9 @@ def increment_and_check(self) -> tuple[int, bool]: }, ReturnValues="UPDATED_NEW", ) - except ClientError: + except ClientError as exc: logger.exception("Failed to increment Gemini RPD counter") - return 0, True + raise RateLimitUnavailableError("Gemini RPD counter unavailable") from exc count = int(resp["Attributes"]["request_count"]) return count, count <= self.rpd_limit diff --git a/src/quiz/services/llm_provider.py b/src/quiz/services/llm_provider.py index f2a2e8f..a9fcf3c 100644 --- a/src/quiz/services/llm_provider.py +++ b/src/quiz/services/llm_provider.py @@ -12,7 +12,7 @@ from google import genai from google.genai import errors as genai_errors from google.genai import types -from services.rate_limit_repository import QuizRateLimitRepository +from services.rate_limit_repository import QuizRateLimitRepository, QuizRateLimitUnavailableError from zerde_common.ai_errors import ( ProviderRateLimitError, ProviderResponseError, @@ -81,7 +81,10 @@ def get_rpd_status(self) -> tuple[int, int]: _RETRY_DELAYS = (5, 15, 30) # seconds; quiz runs on schedule, has time def generate_json(self, prompt: str, temperature: float = 0.3) -> dict: - count, within_limit = self._rate_repo.increment_and_check() + try: + count, within_limit = self._rate_repo.increment_and_check() + except QuizRateLimitUnavailableError as exc: + raise ProviderTransportError("Quiz Gemini RPD counter unavailable") from exc if not within_limit: logger.warning( "Quiz Gemini RPD limit reached", diff --git a/src/quiz/services/rate_limit_repository.py b/src/quiz/services/rate_limit_repository.py index 6b7780d..650a32f 100644 --- a/src/quiz/services/rate_limit_repository.py +++ b/src/quiz/services/rate_limit_repository.py @@ -15,6 +15,10 @@ _TTL_DELTA = timedelta(hours=48) +class QuizRateLimitUnavailableError(RuntimeError): + """Raised when the quiz RPD counter cannot be updated safely.""" + + class QuizRateLimitRepository: """DynamoDB-backed daily counter for quiz Gemini requests.""" @@ -47,9 +51,9 @@ def increment_and_check(self) -> tuple[int, bool]: ExpressionAttributeValues={":inc": 1, ":zero": 0, ":ttl": ttl_epoch}, ReturnValues="UPDATED_NEW", ) - except ClientError: + except ClientError as exc: logger.exception("Failed to increment quiz Gemini RPD counter") - return 0, True + raise QuizRateLimitUnavailableError("Quiz Gemini RPD counter unavailable") from exc count = int(resp["Attributes"]["request_count"]) return count, count <= self.rpd_limit diff --git a/tests/test_rate_limit_fail_closed.py b/tests/test_rate_limit_fail_closed.py new file mode 100644 index 0000000..3b1ccb9 --- /dev/null +++ b/tests/test_rate_limit_fail_closed.py @@ -0,0 +1,105 @@ +"""RPD counter failures must fail closed before calling Gemini.""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from botocore.exceptions import ClientError + +from services.ai.gemini_client import GeminiClient, GeminiUnavailableError +from services.repositories.rate_limit import RateLimitRepository, RateLimitUnavailableError + + +def _client_error() -> ClientError: + return ClientError( + {"Error": {"Code": "InternalServerError", "Message": "dynamodb unavailable"}}, + "UpdateItem", + ) + + +def test_bot_rate_limit_increment_failure_fails_closed() -> None: + table = MagicMock() + table.update_item.side_effect = _client_error() + + repo = RateLimitRepository() + with patch("services.repositories.rate_limit.get_dynamodb") as mock_get_dynamodb: + mock_get_dynamodb.return_value.Table.return_value = table + + with pytest.raises(RateLimitUnavailableError): + repo.increment_and_check() + + +def test_gemini_text_explain_does_not_call_api_when_counter_unavailable() -> None: + client = GeminiClient.__new__(GeminiClient) + client._api_key = "test-key" + client._model = "test-model" + client._rate_repo = MagicMock() + client._rate_repo.increment_and_check.side_effect = RateLimitUnavailableError("counter down") + client._rate_repo.rpd_limit = 1000 + + with patch("services.ai.gemini_client._http.request") as mock_request: + with pytest.raises(GeminiUnavailableError): + client.explain_term("kubernetes", "en") + + mock_request.assert_not_called() + + +def test_gemini_media_explain_does_not_call_api_when_counter_unavailable() -> None: + client = GeminiClient.__new__(GeminiClient) + client._api_key = "test-key" + client._model = "test-model" + client._rate_repo = MagicMock() + client._rate_repo.increment_and_check.side_effect = RateLimitUnavailableError("counter down") + client._rate_repo.rpd_limit = 1000 + + with patch("services.ai.gemini_client._http_multimodal.request") as mock_request: + with pytest.raises(GeminiUnavailableError): + client.explain_media( + media_kind="photo", + file_bytes=b"image", + mime_type="image/jpeg", + lang="en", + ) + + mock_request.assert_not_called() + + +def test_quiz_rate_limit_increment_failure_fails_closed() -> None: + os.environ.setdefault("BOT_TOKEN", "test-bot-token") + os.environ.setdefault("TABLE_NAME", "test-quiz-table") + os.environ.setdefault("QUIZ_LLM_RPD", "1000") + + quiz_dir = Path(__file__).resolve().parents[1] / "src" / "quiz" + saved_modules = { + name: module + for name, module in sys.modules.items() + if name in {"core", "services"} or name.startswith(("core.", "services.")) + } + + try: + for name in list(saved_modules): + sys.modules.pop(name, None) + sys.path.insert(0, str(quiz_dir)) + from services.rate_limit_repository import ( # noqa: PLC0415 + QuizRateLimitRepository, + QuizRateLimitUnavailableError, + ) + + repo = QuizRateLimitRepository.__new__(QuizRateLimitRepository) + repo._table = MagicMock() + repo._table.update_item.side_effect = _client_error() + repo.rpd_limit = 1000 + + with pytest.raises(QuizRateLimitUnavailableError): + repo.increment_and_check() + finally: + if str(quiz_dir) in sys.path: + sys.path.remove(str(quiz_dir)) + for name in list(sys.modules): + if name in {"core", "services"} or name.startswith(("core.", "services.")): + sys.modules.pop(name, None) + sys.modules.update(saved_modules)