Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/bot/services/ai/gemini_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
transcribe_system_prompt,
)
from services.ai.wtf_prompts import WTFPromptStyle, get_wtf_system_prompt, wtf_explain_user_text
from services.repositories.rate_limit import RateLimitRepository
from services.repositories.rate_limit import RateLimitRepository, RateLimitUnavailableError
from urllib3.exceptions import HTTPError

logger = LoggerAdapter(get_logger(__name__), {})
Expand Down Expand Up @@ -85,7 +85,10 @@ def explain_term(self, term: str, lang: str = "kk", style: WTFPromptStyle = "ang
GeminiRPDExhaustedError: daily RPD limit reached after increment.
GeminiUnavailableError: 429, 5xx, timeout, or bad response body.
"""
count, within_limit = self._rate_repo.increment_and_check()
try:
count, within_limit = self._rate_repo.increment_and_check()
except RateLimitUnavailableError as exc:
raise GeminiUnavailableError("Gemini RPD counter unavailable") from exc

if not within_limit:
logger.warning(
Expand Down Expand Up @@ -214,7 +217,10 @@ def explain_media(
Raises:
GeminiRPDExhaustedError, GeminiUnavailableError: same as ``explain_term``.
"""
count, within_limit = self._rate_repo.increment_and_check()
try:
count, within_limit = self._rate_repo.increment_and_check()
except RateLimitUnavailableError as exc:
raise GeminiUnavailableError("Gemini RPD counter unavailable") from exc
if not within_limit:
logger.warning(
"Gemini RPD limit reached (multimodal)",
Expand Down
8 changes: 6 additions & 2 deletions src/bot/services/repositories/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
_TTL_DELTA = timedelta(hours=48)


class RateLimitUnavailableError(RuntimeError):
"""Raised when the RPD counter cannot be updated safely."""


class RateLimitRepository:
"""Atomic RPD counter in the shared stats DynamoDB table.

Expand Down Expand Up @@ -71,9 +75,9 @@ def increment_and_check(self) -> tuple[int, bool]:
},
ReturnValues="UPDATED_NEW",
)
except ClientError:
except ClientError as exc:
logger.exception("Failed to increment Gemini RPD counter")
return 0, True
raise RateLimitUnavailableError("Gemini RPD counter unavailable") from exc

count = int(resp["Attributes"]["request_count"])
return count, count <= self.rpd_limit
Expand Down
7 changes: 5 additions & 2 deletions src/quiz/services/llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from google import genai
from google.genai import errors as genai_errors
from google.genai import types
from services.rate_limit_repository import QuizRateLimitRepository
from services.rate_limit_repository import QuizRateLimitRepository, QuizRateLimitUnavailableError
from zerde_common.ai_errors import (
ProviderRateLimitError,
ProviderResponseError,
Expand Down Expand Up @@ -81,7 +81,10 @@ def get_rpd_status(self) -> tuple[int, int]:
_RETRY_DELAYS = (5, 15, 30) # seconds; quiz runs on schedule, has time

def generate_json(self, prompt: str, temperature: float = 0.3) -> dict:
count, within_limit = self._rate_repo.increment_and_check()
try:
count, within_limit = self._rate_repo.increment_and_check()
except QuizRateLimitUnavailableError as exc:
raise ProviderTransportError("Quiz Gemini RPD counter unavailable") from exc
if not within_limit:
logger.warning(
"Quiz Gemini RPD limit reached",
Expand Down
8 changes: 6 additions & 2 deletions src/quiz/services/rate_limit_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
_TTL_DELTA = timedelta(hours=48)


class QuizRateLimitUnavailableError(RuntimeError):
"""Raised when the quiz RPD counter cannot be updated safely."""


class QuizRateLimitRepository:
"""DynamoDB-backed daily counter for quiz Gemini requests."""

Expand Down Expand Up @@ -47,9 +51,9 @@ def increment_and_check(self) -> tuple[int, bool]:
ExpressionAttributeValues={":inc": 1, ":zero": 0, ":ttl": ttl_epoch},
ReturnValues="UPDATED_NEW",
)
except ClientError:
except ClientError as exc:
logger.exception("Failed to increment quiz Gemini RPD counter")
return 0, True
raise QuizRateLimitUnavailableError("Quiz Gemini RPD counter unavailable") from exc

count = int(resp["Attributes"]["request_count"])
return count, count <= self.rpd_limit
Expand Down
105 changes: 105 additions & 0 deletions tests/test_rate_limit_fail_closed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""RPD counter failures must fail closed before calling Gemini."""

from __future__ import annotations

import os
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
from botocore.exceptions import ClientError

from services.ai.gemini_client import GeminiClient, GeminiUnavailableError
from services.repositories.rate_limit import RateLimitRepository, RateLimitUnavailableError


def _client_error() -> ClientError:
return ClientError(
{"Error": {"Code": "InternalServerError", "Message": "dynamodb unavailable"}},
"UpdateItem",
)


def test_bot_rate_limit_increment_failure_fails_closed() -> None:
table = MagicMock()
table.update_item.side_effect = _client_error()

repo = RateLimitRepository()
with patch("services.repositories.rate_limit.get_dynamodb") as mock_get_dynamodb:
mock_get_dynamodb.return_value.Table.return_value = table

with pytest.raises(RateLimitUnavailableError):
repo.increment_and_check()


def test_gemini_text_explain_does_not_call_api_when_counter_unavailable() -> None:
client = GeminiClient.__new__(GeminiClient)
client._api_key = "test-key"
client._model = "test-model"
client._rate_repo = MagicMock()
client._rate_repo.increment_and_check.side_effect = RateLimitUnavailableError("counter down")
client._rate_repo.rpd_limit = 1000

with patch("services.ai.gemini_client._http.request") as mock_request:
with pytest.raises(GeminiUnavailableError):
client.explain_term("kubernetes", "en")

mock_request.assert_not_called()


def test_gemini_media_explain_does_not_call_api_when_counter_unavailable() -> None:
client = GeminiClient.__new__(GeminiClient)
client._api_key = "test-key"
client._model = "test-model"
client._rate_repo = MagicMock()
client._rate_repo.increment_and_check.side_effect = RateLimitUnavailableError("counter down")
client._rate_repo.rpd_limit = 1000

with patch("services.ai.gemini_client._http_multimodal.request") as mock_request:
with pytest.raises(GeminiUnavailableError):
client.explain_media(
media_kind="photo",
file_bytes=b"image",
mime_type="image/jpeg",
lang="en",
)

mock_request.assert_not_called()


def test_quiz_rate_limit_increment_failure_fails_closed() -> None:
os.environ.setdefault("BOT_TOKEN", "test-bot-token")
os.environ.setdefault("TABLE_NAME", "test-quiz-table")
os.environ.setdefault("QUIZ_LLM_RPD", "1000")

quiz_dir = Path(__file__).resolve().parents[1] / "src" / "quiz"
saved_modules = {
name: module
for name, module in sys.modules.items()
if name in {"core", "services"} or name.startswith(("core.", "services."))
}

try:
for name in list(saved_modules):
sys.modules.pop(name, None)
sys.path.insert(0, str(quiz_dir))
from services.rate_limit_repository import ( # noqa: PLC0415
QuizRateLimitRepository,
QuizRateLimitUnavailableError,
)

repo = QuizRateLimitRepository.__new__(QuizRateLimitRepository)
repo._table = MagicMock()
repo._table.update_item.side_effect = _client_error()
repo.rpd_limit = 1000

with pytest.raises(QuizRateLimitUnavailableError):
repo.increment_and_check()
finally:
if str(quiz_dir) in sys.path:
sys.path.remove(str(quiz_dir))
for name in list(sys.modules):
if name in {"core", "services"} or name.startswith(("core.", "services.")):
sys.modules.pop(name, None)
sys.modules.update(saved_modules)
Loading