diff --git a/app/core/gemini.py b/app/core/gemini.py index 16e5e81..5c9e3db 100644 --- a/app/core/gemini.py +++ b/app/core/gemini.py @@ -15,6 +15,39 @@ DEFAULT_RETRY_BACKOFFS = (1.0, 2.0, 4.0) +def _usage_value(usage: Any, name: str) -> Any: + return getattr(usage, name, None) + + +def log_gemini_usage( + response: Any, + *, + operation_name: str, + model: str, + attempt: int | None = None, + log_context: dict[str, Any] | None = None, +) -> None: + """Log Gemini response usage_metadata as a structured entry, without raw text.""" + usage = getattr(response, "usage_metadata", None) + if usage is None: + return + + logger.info( + "Gemini usage: operation=%s model=%s attempt=%s " + "prompt_tokens=%s candidate_tokens=%s total_tokens=%s " + "thoughts_tokens=%s cached_tokens=%s context=%s", + operation_name, + model, + attempt, + _usage_value(usage, "prompt_token_count"), + _usage_value(usage, "candidates_token_count"), + _usage_value(usage, "total_token_count"), + _usage_value(usage, "thoughts_token_count"), + _usage_value(usage, "cached_content_token_count"), + log_context or {}, + ) + + async def generate_content_with_retry( client: genai.Client, *, @@ -22,6 +55,7 @@ async def generate_content_with_retry( config: types.GenerateContentConfig, timeout: float, operation_name: str, + log_context: dict[str, Any] | None = None, backoffs: Sequence[float] = DEFAULT_RETRY_BACKOFFS, ) -> Any: """Gemini generate_content 호출에 timeout/retry 정책을 공통 적용한다.""" @@ -29,7 +63,7 @@ async def generate_content_with_retry( for attempt in range(len(backoffs) + 1): try: - return await asyncio.wait_for( + response = await asyncio.wait_for( client.aio.models.generate_content( model=settings.gemini_llm_model, contents=contents, @@ -37,6 +71,14 @@ async def generate_content_with_retry( ), timeout=timeout, ) + log_gemini_usage( + response, + operation_name=operation_name, + model=settings.gemini_llm_model, + attempt=attempt + 1, + log_context=log_context, + ) + return response except TimeoutError as e: last_error = e if attempt >= 
len(backoffs): diff --git a/app/core/logging_middleware.py b/app/core/logging_middleware.py index c872de3..eabe1c6 100644 --- a/app/core/logging_middleware.py +++ b/app/core/logging_middleware.py @@ -1,5 +1,7 @@ +import json import logging import time +from typing import Any from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request @@ -8,25 +10,132 @@ logger = logging.getLogger("api.access") _MAX_BODY_SIZE = 10_000 # 10KB 초과 시 truncate +_MAX_STRING_LOG_CHARS = 500 +_MAX_LIST_LOG_ITEMS = 5 +_SENSITIVE_KEY_PARTS = { + "api_key", + "apikey", + "authorization", + "bearer", + "client_secret", + "cookie", + "credential", + "password", + "refresh_token", + "secret", + "token", +} +_VERBOSE_KEY_PARTS = { + "audio", + "changed_code", + "content", + "live_messages", + "text", + "utterance", +} +_PII_KEYS = { + "from_name", + "fromname", + "speaker", + "speaker_name", + "speakername", +} + + +def _key_contains(key: str, candidates: set[str]) -> bool: + normalized = key.lower() + return any(candidate in normalized for candidate in candidates) + + +def _summarize_string(value: str) -> str: + return f"<{len(value)} chars>" + + +def _sanitize_json(value: Any, parent_key: str | None = None) -> Any: + if isinstance(value, dict): + sanitized: dict[str, Any] = {} + for key, child in value.items(): + normalized_key = key.lower() + if _key_contains(normalized_key, _SENSITIVE_KEY_PARTS): + sanitized[key] = "" + elif normalized_key in _PII_KEYS: + sanitized[key] = "" + elif _key_contains(normalized_key, _VERBOSE_KEY_PARTS): + sanitized[key] = ( + _summarize_string(child) + if isinstance(child, str) + else _sanitize_json(child, key) + ) + else: + sanitized[key] = _sanitize_json(child, key) + return sanitized + + if isinstance(value, list): + items = [ + _sanitize_json(item, parent_key) for item in value[:_MAX_LIST_LOG_ITEMS] + ] + if len(value) > _MAX_LIST_LOG_ITEMS: + items.append(f"<{len(value) - _MAX_LIST_LOG_ITEMS} more items>") + return items + + 
if isinstance(value, str) and len(value) > _MAX_STRING_LOG_CHARS: + return _summarize_string(value) + + return value + + +def _truncate(value: str) -> str: + if len(value) <= _MAX_BODY_SIZE: + return value + return value[:_MAX_BODY_SIZE] + f"... (truncated, total {len(value)} chars)" + + +def _media_type(content_type: str) -> str: + return content_type.split(";", 1)[0].strip().lower() + + +def _is_json_media_type(content_type: str) -> bool: + media_type = _media_type(content_type) + return media_type == "application/json" or media_type.endswith("+json") + + +async def _format_body_for_log(request: Request) -> str: + if request.method in {"GET", "HEAD", "OPTIONS"}: + return "(empty)" + + content_type = request.headers.get("content-type", "") + content_length = request.headers.get("content-length", "unknown") + if content_type.startswith("multipart/form-data"): + return f"(multipart/form-data omitted, content_length={content_length})" + + if not _is_json_media_type(content_type): + media_type = _media_type(content_type) or "non-json" + return f"({media_type} body omitted, content_length={content_length})" + + body_bytes = await request.body() + if not body_bytes: + return "(empty)" + + try: + parsed = json.loads(body_bytes) + sanitized = _sanitize_json(parsed) + return _truncate( + json.dumps(sanitized, ensure_ascii=False, separators=(",", ":")) + ) + except ValueError:  # JSONDecodeError, or UnicodeDecodeError from non-UTF-8 bytes + return f"(malformed json, content_length={len(body_bytes)})" class RequestLoggingMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next) -> Response: start = time.perf_counter() - - body_bytes = await request.body() - body_text = body_bytes.decode("utf-8", errors="replace") - if len(body_text) > _MAX_BODY_SIZE: - body_text = ( - body_text[:_MAX_BODY_SIZE] - + f"... 
(truncated, total {len(body_text)} chars)" - ) + body_text = await _format_body_for_log(request) logger.info( ">>> %s %s\nBody: %s", request.method, request.url.path, - body_text or "(empty)", + body_text, ) response = await call_next(request) diff --git a/app/domains/transcribe/services/transcript_correction.py b/app/domains/transcribe/services/transcript_correction.py index 4fd6051..a8a8f3c 100644 --- a/app/domains/transcribe/services/transcript_correction.py +++ b/app/domains/transcribe/services/transcript_correction.py @@ -108,6 +108,10 @@ def _validate_raw_segments(raw_segments: list[dict]) -> list[TranscribeSegment]: ), timeout=90.0, operation_name="Gemini 전사 후처리", + log_context={ + "segment_count": len(segments), + "num_speakers": num_speakers, + }, ) parsed = json.loads(response.text) return _validate_raw_segments(parsed) diff --git a/tests/test_gemini_retry.py b/tests/test_gemini_retry.py index 41f53af..addd10e 100644 --- a/tests/test_gemini_retry.py +++ b/tests/test_gemini_retry.py @@ -4,7 +4,7 @@ import pytest from google.genai import types -from app.core.gemini import generate_content_with_retry +from app.core.gemini import generate_content_with_retry, log_gemini_usage class _SlowThenSuccessModels: @@ -35,3 +35,32 @@ async def test_generate_content_with_retry_retries_timeout() -> None: assert response.text == "ok" assert models.calls == 2 + + +def test_log_gemini_usage_records_metadata(caplog) -> None: + response = SimpleNamespace( + usage_metadata=SimpleNamespace( + prompt_token_count=10, + candidates_token_count=3, + total_token_count=13, + thoughts_token_count=2, + cached_content_token_count=1, + ) + ) + + with caplog.at_level("INFO", logger="app.core.gemini"): + log_gemini_usage( + response, + operation_name="Gemini 테스트", + model="gemini-test", + attempt=2, + log_context={"meeting_id": 1}, + ) + + assert "operation=Gemini 테스트" in caplog.text + assert "model=gemini-test" in caplog.text + assert "attempt=2" in caplog.text + assert "prompt_tokens=10" 
in caplog.text + assert "candidate_tokens=3" in caplog.text + assert "total_tokens=13" in caplog.text + assert "context={'meeting_id': 1}" in caplog.text diff --git a/tests/test_logging_middleware.py b/tests/test_logging_middleware.py new file mode 100644 index 0000000..70ede0a --- /dev/null +++ b/tests/test_logging_middleware.py @@ -0,0 +1,155 @@ +import pytest +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from app.core.logging_middleware import RequestLoggingMiddleware, _sanitize_json + + +def test_sanitize_json_masks_sensitive_and_verbose_fields() -> None: + payload = { + "apiKey": "secret-key", + "nested": { + "accessToken": "token-value", + "fromName": "유상완", + "changed_code": "+GEMINI_API_KEY=secret", + }, + "live_messages": [ + {"text": "회의 원문입니다", "from_name": "유진"}, + {"text": "두 번째 발화입니다", "from_name": "김준용"}, + {"text": "세 번째 발화입니다", "from_name": "조윤지"}, + {"text": "네 번째 발화입니다", "from_name": "meme"}, + {"text": "다섯 번째 발화입니다", "from_name": "상완"}, + {"text": "여섯 번째 발화입니다", "from_name": "extra"}, + ], + } + + sanitized = _sanitize_json(payload) + + assert sanitized["apiKey"] == "" + assert sanitized["nested"]["accessToken"] == "" + assert sanitized["nested"]["fromName"] == "" + assert sanitized["nested"]["changed_code"] == "<22 chars>" + assert len(sanitized["live_messages"]) == 6 + assert sanitized["live_messages"][-1] == "<1 more items>" + assert sanitized["live_messages"][0]["text"] == "<8 chars>" + assert sanitized["live_messages"][0]["from_name"] == "" + + +def test_logging_middleware_sanitizes_body_without_consuming_request( + caplog, +) -> None: + app = FastAPI() + app.add_middleware(RequestLoggingMiddleware) + + @app.post("/echo") + async def echo(request: Request) -> dict: + body = await request.json() + return {"received": body["message"]} + + client = TestClient(app) + + with caplog.at_level("INFO", logger="api.access"): + response = client.post( + "/echo", + json={ + "message": "ok", + "password": 
"should-not-log", + "changed_code": "+secret=abc", + }, + ) + + assert response.status_code == 200 + assert response.json() == {"received": "ok"} + log_text = caplog.text + assert "should-not-log" not in log_text + assert "+secret=abc" not in log_text + assert '"password":""' in log_text + assert '"changed_code":"<11 chars>"' in log_text + + +def test_logging_middleware_omits_multipart_body(caplog) -> None: + app = FastAPI() + app.add_middleware(RequestLoggingMiddleware) + + @app.post("/upload") + async def upload() -> dict: + return {"ok": True} + + client = TestClient(app) + + with caplog.at_level("INFO", logger="api.access"): + response = client.post( + "/upload", + files={"audio": ("meeting.ogg", b"raw-audio-bytes", "audio/ogg")}, + ) + + assert response.status_code == 200 + assert "raw-audio-bytes" not in caplog.text + assert "multipart/form-data omitted" in caplog.text + + +def test_logging_middleware_sanitizes_json_suffix_media_type(caplog) -> None: + app = FastAPI() + app.add_middleware(RequestLoggingMiddleware) + + @app.post("/json-api") + async def json_api(request: Request) -> dict: + body = await request.json() + return {"received": body["message"]} + + client = TestClient(app) + + with caplog.at_level("INFO", logger="api.access"): + response = client.post( + "/json-api", + content='{"message":"ok","accessToken":"secret-token"}', + headers={"content-type": "application/vnd.api+json; charset=utf-8"}, + ) + + assert response.status_code == 200 + assert response.json() == {"received": "ok"} + assert "secret-token" not in caplog.text + assert '"accessToken":""' in caplog.text + + +@pytest.mark.parametrize( + ("content_type", "body", "expected_log"), + [ + ( + "application/x-www-form-urlencoded", + "password=secret-password&token=secret-token", + "application/x-www-form-urlencoded body omitted", + ), + ( + "text/plain", + "password=secret-password token=secret-token", + "text/plain body omitted", + ), + ], +) +def test_logging_middleware_omits_non_json_body( 
+ caplog, + content_type: str, + body: str, + expected_log: str, +) -> None: + app = FastAPI() + app.add_middleware(RequestLoggingMiddleware) + + @app.post("/plain") + async def plain() -> dict: + return {"ok": True} + + client = TestClient(app) + + with caplog.at_level("INFO", logger="api.access"): + response = client.post( + "/plain", + content=body, + headers={"content-type": content_type}, + ) + + assert response.status_code == 200 + assert "secret-password" not in caplog.text + assert "secret-token" not in caplog.text + assert expected_log in caplog.text