Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion app/core/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,70 @@
DEFAULT_RETRY_BACKOFFS = (1.0, 2.0, 4.0)


def _usage_value(usage: Any, name: str) -> Any:
return getattr(usage, name, None)


def log_gemini_usage(
response: Any,
*,
operation_name: str,
model: str,
attempt: int | None = None,
log_context: dict[str, Any] | None = None,
) -> None:
"""Gemini 응답 usage_metadata를 원문 없이 구조화 로그로 남긴다."""
usage = getattr(response, "usage_metadata", None)
if usage is None:
return

logger.info(
"Gemini usage: operation=%s model=%s attempt=%s "
"prompt_tokens=%s candidate_tokens=%s total_tokens=%s "
"thoughts_tokens=%s cached_tokens=%s context=%s",
operation_name,
model,
attempt,
_usage_value(usage, "prompt_token_count"),
_usage_value(usage, "candidates_token_count"),
_usage_value(usage, "total_token_count"),
_usage_value(usage, "thoughts_token_count"),
_usage_value(usage, "cached_content_token_count"),
log_context or {},
)


async def generate_content_with_retry(
client: genai.Client,
*,
contents: Any,
config: types.GenerateContentConfig,
timeout: float,
operation_name: str,
log_context: dict[str, Any] | None = None,
backoffs: Sequence[float] = DEFAULT_RETRY_BACKOFFS,
) -> Any:
"""Gemini generate_content 호출에 timeout/retry 정책을 공통 적용한다."""
last_error: Exception | None = None

for attempt in range(len(backoffs) + 1):
try:
return await asyncio.wait_for(
response = await asyncio.wait_for(
client.aio.models.generate_content(
model=settings.gemini_llm_model,
contents=contents,
config=config,
),
timeout=timeout,
)
log_gemini_usage(
response,
operation_name=operation_name,
model=settings.gemini_llm_model,
attempt=attempt + 1,
log_context=log_context,
)
return response
except TimeoutError as e:
last_error = e
if attempt >= len(backoffs):
Expand Down
127 changes: 118 additions & 9 deletions app/core/logging_middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
import time
from typing import Any

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
Expand All @@ -8,25 +10,132 @@
logger = logging.getLogger("api.access")

_MAX_BODY_SIZE = 10_000 # 10KB 초과 시 truncate
_MAX_STRING_LOG_CHARS = 500
_MAX_LIST_LOG_ITEMS = 5
_SENSITIVE_KEY_PARTS = {
"api_key",
"apikey",
"authorization",
"bearer",
"client_secret",
"cookie",
"credential",
"password",
"refresh_token",
"secret",
"token",
}
_VERBOSE_KEY_PARTS = {
"audio",
"changed_code",
"content",
"live_messages",
"text",
"utterance",
}
_PII_KEYS = {
"from_name",
"fromname",
"speaker",
"speaker_name",
"speakername",
}


def _key_contains(key: str, candidates: set[str]) -> bool:
normalized = key.lower()
return any(candidate in normalized for candidate in candidates)


def _summarize_string(value: str) -> str:
return f"<{len(value)} chars>"


def _sanitize_json(value: Any, parent_key: str | None = None) -> Any:
if isinstance(value, dict):
sanitized: dict[str, Any] = {}
for key, child in value.items():
normalized_key = key.lower()
if _key_contains(normalized_key, _SENSITIVE_KEY_PARTS):
sanitized[key] = "<masked>"
elif normalized_key in _PII_KEYS:
sanitized[key] = "<masked>"
elif _key_contains(normalized_key, _VERBOSE_KEY_PARTS):
sanitized[key] = (
_summarize_string(child)
if isinstance(child, str)
else _sanitize_json(child, key)
)
else:
sanitized[key] = _sanitize_json(child, key)
return sanitized

if isinstance(value, list):
items = [
_sanitize_json(item, parent_key) for item in value[:_MAX_LIST_LOG_ITEMS]
]
if len(value) > _MAX_LIST_LOG_ITEMS:
items.append(f"<{len(value) - _MAX_LIST_LOG_ITEMS} more items>")
return items

if isinstance(value, str) and len(value) > _MAX_STRING_LOG_CHARS:
return _summarize_string(value)

return value


def _truncate(value: str) -> str:
if len(value) <= _MAX_BODY_SIZE:
return value
return value[:_MAX_BODY_SIZE] + f"... (truncated, total {len(value)} chars)"


def _media_type(content_type: str) -> str:
return content_type.split(";", 1)[0].strip().lower()


def _is_json_media_type(content_type: str) -> bool:
media_type = _media_type(content_type)
return media_type == "application/json" or media_type.endswith("+json")


async def _format_body_for_log(request: Request) -> str:
if request.method in {"GET", "HEAD", "OPTIONS"}:
return "(empty)"

content_type = request.headers.get("content-type", "")
content_length = request.headers.get("content-length", "unknown")
if content_type.startswith("multipart/form-data"):
return f"(multipart/form-data omitted, content_length={content_length})"

if not _is_json_media_type(content_type):
media_type = _media_type(content_type) or "non-json"
return f"({media_type} body omitted, content_length={content_length})"

body_bytes = await request.body()
if not body_bytes:
return "(empty)"

try:
parsed = json.loads(body_bytes)
sanitized = _sanitize_json(parsed)
return _truncate(
json.dumps(sanitized, ensure_ascii=False, separators=(",", ":"))
)
except json.JSONDecodeError:
return f"(malformed json, content_length={len(body_bytes)})"


class RequestLoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next) -> Response:
start = time.perf_counter()

body_bytes = await request.body()
body_text = body_bytes.decode("utf-8", errors="replace")
if len(body_text) > _MAX_BODY_SIZE:
body_text = (
body_text[:_MAX_BODY_SIZE]
+ f"... (truncated, total {len(body_text)} chars)"
)
body_text = await _format_body_for_log(request)

logger.info(
">>> %s %s\nBody: %s",
request.method,
request.url.path,
body_text or "(empty)",
body_text,
)

response = await call_next(request)
Expand Down
4 changes: 4 additions & 0 deletions app/domains/transcribe/services/transcript_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def _validate_raw_segments(raw_segments: list[dict]) -> list[TranscribeSegment]:
),
timeout=90.0,
operation_name="Gemini 전사 후처리",
log_context={
"segment_count": len(segments),
"num_speakers": num_speakers,
},
)
parsed = json.loads(response.text)
return _validate_raw_segments(parsed)
Expand Down
31 changes: 30 additions & 1 deletion tests/test_gemini_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from google.genai import types

from app.core.gemini import generate_content_with_retry
from app.core.gemini import generate_content_with_retry, log_gemini_usage


class _SlowThenSuccessModels:
Expand Down Expand Up @@ -35,3 +35,32 @@ async def test_generate_content_with_retry_retries_timeout() -> None:

assert response.text == "ok"
assert models.calls == 2


def test_log_gemini_usage_records_metadata(caplog) -> None:
response = SimpleNamespace(
usage_metadata=SimpleNamespace(
prompt_token_count=10,
candidates_token_count=3,
total_token_count=13,
thoughts_token_count=2,
cached_content_token_count=1,
)
)

with caplog.at_level("INFO", logger="app.core.gemini"):
log_gemini_usage(
response,
operation_name="Gemini 테스트",
model="gemini-test",
attempt=2,
log_context={"meeting_id": 1},
)

assert "operation=Gemini 테스트" in caplog.text
assert "model=gemini-test" in caplog.text
assert "attempt=2" in caplog.text
assert "prompt_tokens=10" in caplog.text
assert "candidate_tokens=3" in caplog.text
assert "total_tokens=13" in caplog.text
assert "context={'meeting_id': 1}" in caplog.text
Loading
Loading