diff --git a/.env.example b/.env.example index 7cf769b..ad66488 100644 --- a/.env.example +++ b/.env.example @@ -1,8 +1,13 @@ -# LLM provider +# LLM provider: openai | gemini LLM_PROVIDER=openai OPENAI_API_KEY= OPENAI_MODEL=gpt-4.1-mini - +# When using Gemini, optional (defaults in app/config.py) +# GEMINI_API_KEY= +# GEMINI_MODEL=gemini-2.5-flash +# Throttle to stay under Google free-tier RPM; set 0 to disable +# GEMINI_MAX_REQUESTS_PER_MINUTE=5 +# GEMINI_RATE_LIMIT_WINDOW_SECONDS=60 diff --git a/app/agent/graph.py b/app/agent/graph.py index a5a2afe..7d4edd1 100644 --- a/app/agent/graph.py +++ b/app/agent/graph.py @@ -2,6 +2,7 @@ from __future__ import annotations +import time from functools import lru_cache from typing import Any @@ -41,7 +42,14 @@ def run_analysis(query: str, source_ids: list[str] | None = None) -> dict[str, A """Compatibility entrypoint used by the API service layer.""" workflow = build_graph() - return workflow.invoke(create_initial_state(query, source_ids=source_ids)) + started = time.perf_counter() + result = workflow.invoke(create_initial_state(query, source_ids=source_ids)) + if not isinstance(result, dict): + result = dict(result) + else: + result = dict(result) + result["runtime_ms"] = max(0, int((time.perf_counter() - started) * 1000)) + return result def load_schema_context_node(state: dict[str, Any]) -> dict[str, Any]: diff --git a/app/api/chat_routes.py b/app/api/chat_routes.py index 47164e2..9562d55 100644 --- a/app/api/chat_routes.py +++ b/app/api/chat_routes.py @@ -62,6 +62,7 @@ def _assistant_metadata(result: AnalyzeResponse) -> dict: "executed_steps": serialized["executed_steps"], "errors": serialized["errors"], "inspection_id": serialized["inspection_id"], + "runtime_ms": serialized.get("runtime_ms"), } @@ -222,4 +223,5 @@ def chat_turn( executed_steps=analysis_result.executed_steps, errors=analysis_result.errors, inspection_id=analysis_result.inspection_id, + runtime_ms=analysis_result.runtime_ms, ) diff --git a/app/api/workspace.py b/app/api/workspace.py index 4e52ca2..7afa198 100644 --- a/app/api/workspace.py +++ b/app/api/workspace.py @@ -78,7 +78,7 @@ def _build_inspection(inspection_id: str, prompt: str, response: AnalyzeResponse and len(response.errors) == 0 and any(step.status == "success" for step in executed_steps) ) - runtime_ms = None + runtime_ms = response.runtime_ms return InspectionData( id=inspection_id, diff --git a/app/config.py b/app/config.py index 52a683f..2cd6c29 100644 --- a/app/config.py +++ b/app/config.py @@ -29,6 +29,15 @@ class Settings(BaseSettings): llm_provider: str = Field(default="openai", alias="LLM_PROVIDER") gemini_api_key: str | None = Field(default=None, alias="GEMINI_API_KEY") gemini_model: str = "gemini-2.0-flash" + # Free-tier limits are often 5 RPM; set to 0 to disable throttling (e.g. paid / higher limits). + gemini_max_requests_per_minute: int = Field( + default=5, + alias="GEMINI_MAX_REQUESTS_PER_MINUTE", + ) + gemini_rate_limit_window_seconds: float = Field( + default=60.0, + alias="GEMINI_RATE_LIMIT_WINDOW_SECONDS", + ) openai_api_key: str | None = Field(default=None, alias="OPENAI_API_KEY") openai_model: str = "gpt-4.1-mini" log_level: str = "INFO" diff --git a/app/llm/gemini.py b/app/llm/gemini.py index 02881d7..a9a2dc0 100644 --- a/app/llm/gemini.py +++ b/app/llm/gemini.py @@ -2,16 +2,80 @@ from __future__ import annotations -from typing import TypeVar +import threading +import time +from collections import deque +from typing import Any, TypeVar from google import genai from pydantic import BaseModel from app.config import get_settings from app.llm.json_response import validate_structured_output +from app.utils.logging import get_logger SchemaT = TypeVar("SchemaT", bound=BaseModel) +logger = get_logger(__name__) + +# Shared across all GeminiClient instances and requests (planner, query writer, etc.). +_rate_state_lock = threading.Lock() +_request_times: deque[float] = deque() + + +def _acquire_gemini_request_slot() -> None: + """Block until a generate_content call is allowed (sliding window over the free-tier RPM).""" + + settings = get_settings() + max_n = settings.gemini_max_requests_per_minute + if max_n <= 0: + return + window = max(settings.gemini_rate_limit_window_seconds, 0.1) + + while True: + wait_s = 0.0 + with _rate_state_lock: + now = time.monotonic() + while _request_times and now - _request_times[0] >= window: + _request_times.popleft() + if len(_request_times) < max_n: + _request_times.append(now) + return + wait_s = max(0.0, window - (now - _request_times[0]) + 0.05) + if wait_s > 0.5: + logger.info( + "Throttling Gemini calls: waiting %.1fs to respect max %d requests / %.0fs (RPM limit).", + wait_s, + max_n, + window, + ) + time.sleep(wait_s if wait_s > 0 else 0.01) + + +def _json_schema_for_gemini_structured_output(model: type[BaseModel]) -> dict[str, Any]: + """Build a JSON Schema that Gemini accepts. + + Pydantic's default schema for ``extra="forbid"`` includes ``additionalProperties: false`` and + similar keywords. The Gemini API rejects those; we still validate with Pydantic after the call. + """ + return _strip_gemini_incompatible_json_schema_keywords(model.model_json_schema()) + + +def _strip_gemini_incompatible_json_schema_keywords(value: object) -> Any: + if isinstance(value, dict): + return { + k: _strip_gemini_incompatible_json_schema_keywords(v) + for k, v in value.items() + if k + not in ( + "additionalProperties", + "unevaluatedProperties", # OpenAPI / draft 2020-12 style + ) + } + if isinstance(value, list): + return [_strip_gemini_incompatible_json_schema_keywords(v) for v in value] + return value + class GeminiClient: """Thin wrapper around the Gemini API.""" @@ -26,12 +90,13 @@ def __init__(self) -> None: def generate_json(self, prompt: str, schema: type[SchemaT]) -> SchemaT: """Generate schema-constrained JSON and return a validated model.""" + _acquire_gemini_request_slot() response = self.client.models.generate_content( model=self.model, contents=prompt, config={ "response_mime_type": "application/json", - "response_schema": schema, + "response_json_schema": _json_schema_for_gemini_structured_output(schema), }, ) text = response.text or "" @@ -40,5 +105,6 @@ def generate_json(self, prompt: str, schema: type[SchemaT]) -> SchemaT: def generate_text(self, prompt: str) -> str: """Generate free text for final user-facing output.""" + _acquire_gemini_request_slot() response = self.client.models.generate_content(model=self.model, contents=prompt) return (response.text or "").strip() diff --git a/app/schemas.py b/app/schemas.py index 7546bcc..a839b92 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -451,6 +451,7 @@ class AnalyzeResponse(BaseModel): executed_steps: list[ExecutedStep] errors: list[ErrorItem] inspection_id: str | None = None + runtime_ms: int | None = None MessageRoleLiteral = Literal["user", "assistant"] @@ -546,6 +547,7 @@ class ChatTurnResponse(BaseModel): executed_steps: list[ExecutedStep] errors: list[ErrorItem] inspection_id: str | None = None + runtime_ms: int | None = None class HealthResponse(BaseModel): diff --git a/app/services/analysis_run.py b/app/services/analysis_run.py index 8a55000..cae7969 100644 --- a/app/services/analysis_run.py +++ b/app/services/analysis_run.py @@ -29,11 +29,14 @@ def run_stored_analysis(query: str, source_ids: list[str] | None = None) -> Stor """Run `run_analysis`, persist inspection to process memory, return API + inspection objects.""" state = run_analysis(query, source_ids=source_ids) + raw_ms = state.get("runtime_ms") + runtime_ms: int | None = int(raw_ms) if raw_ms is not None else None base = AnalyzeResponse( analysis=state["analysis"], trace=state.get("trace", []), executed_steps=state.get("executed_steps", []), errors=state.get("errors", []), + runtime_ms=runtime_ms, ) inspection_id, inspection = store_inspection(query, base) response = AnalyzeResponse( @@ -42,5 +45,6 @@ def run_stored_analysis(query: str, source_ids: list[str] | None = None) -> Stor executed_steps=base.executed_steps, errors=base.errors, inspection_id=inspection_id, + runtime_ms=base.runtime_ms, ) return StoredAnalysisRun(response=response, inspection=inspection) diff --git a/tests/test_api.py b/tests/test_api.py index fc70620..c4be543 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -96,6 +96,7 @@ def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001 } ], "errors": [], + "runtime_ms": 7, } app.dependency_overrides = {} @@ -121,11 +122,12 @@ def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001 analysis_run.run_analysis = original assert response.status_code == 200 payload = response.json() - assert {"analysis", "trace", "executed_steps", "errors", "inspection_id"} <= payload.keys() + assert {"analysis", "trace", "executed_steps", "errors", "inspection_id", "runtime_ms"} <= payload.keys() assert isinstance(payload["trace"], list) assert isinstance(payload["executed_steps"], list) assert isinstance(payload["analysis"], str) assert isinstance(payload["inspection_id"], str) + assert payload["runtime_ms"] == 7 def test_analyze_endpoint_returns_http_500_on_failure(client: TestClient) -> None: diff --git a/tests/test_conversations.py b/tests/test_conversations.py index 7136f1c..346f956 100644 --- a/tests/test_conversations.py +++ b/tests/test_conversations.py @@ -36,6 +36,7 @@ def _fake_analysis_state(query: str, source_ids=None) -> dict: # noqa: ARG001 "trace": [{"step": "planner_compiled_node", "status": "completed", "details": {}}], "executed_steps": [], "errors": [], + "runtime_ms": 42, } @@ -62,7 +63,9 @@ def test_chat_creates_conversation_on_first_prompt(chat_client: TestClient) -> N assert body["assistant_message"]["role"] == "assistant" assert body["assistant_message"]["content"] == "## Demo\nHello from fake analysis.\n" assert body["analysis"] == body["assistant_message"]["content"] + assert body["runtime_ms"] == 42 assert body["assistant_message"]["metadata_json"]["inspection_id"] == body["inspection_id"] + assert body["assistant_message"]["metadata_json"]["runtime_ms"] == 42 lst = chat_client.get("/conversations", headers={"Authorization": f"Bearer {token}"}) assert lst.status_code == 200 diff --git a/tests/test_gemini_rate_limit.py b/tests/test_gemini_rate_limit.py new file mode 100644 index 0000000..8b5c38e --- /dev/null +++ b/tests/test_gemini_rate_limit.py @@ -0,0 +1,48 @@ +"""Sliding-window rate limit for Gemini API (free-tier RPM).""" + +from __future__ import annotations + +import time + +import pytest + +from app.config import get_settings +from app.llm import gemini as gemini_mod + + +@pytest.fixture +def clear_gemini_rate_deque() -> None: + gemini_mod._request_times.clear() + yield + gemini_mod._request_times.clear() + get_settings.cache_clear() + + +def test_acquire_allows_bursts_up_to_max_then_waits( + monkeypatch: pytest.MonkeyPatch, + clear_gemini_rate_deque: None, +) -> None: + """With max=2 in a 150ms window, the third slot must wait until the window slides.""" + monkeypatch.setenv("GEMINI_MAX_REQUESTS_PER_MINUTE", "2") + monkeypatch.setenv("GEMINI_RATE_LIMIT_WINDOW_SECONDS", "0.15") + get_settings.cache_clear() + t0 = time.perf_counter() + gemini_mod._acquire_gemini_request_slot() + gemini_mod._acquire_gemini_request_slot() + gemini_mod._acquire_gemini_request_slot() + elapsed = time.perf_counter() - t0 + get_settings.cache_clear() + assert elapsed >= 0.12, "third acquire should block until a slot is released" + + +def test_acquire_is_noop_when_max_zero( + monkeypatch: pytest.MonkeyPatch, + clear_gemini_rate_deque: None, +) -> None: + monkeypatch.setenv("GEMINI_MAX_REQUESTS_PER_MINUTE", "0") + get_settings.cache_clear() + t0 = time.perf_counter() + for _ in range(5): + gemini_mod._acquire_gemini_request_slot() + assert time.perf_counter() - t0 < 0.1 + get_settings.cache_clear() diff --git a/tests/test_llm_json.py b/tests/test_llm_json.py index de520a8..7ec8dfb 100644 --- a/tests/test_llm_json.py +++ b/tests/test_llm_json.py @@ -5,6 +5,7 @@ import pytest from pydantic import BaseModel, ConfigDict, ValidationError +from app.llm.gemini import _json_schema_for_gemini_structured_output from app.llm.json_response import validate_structured_output @@ -31,3 +32,20 @@ def test_validate_structured_output_rejects_extra_keys() -> None: schema=DemoPayload, source="test", ) + + +def _assert_no_additional_properties(obj: object) -> None: + if isinstance(obj, dict): + assert "additionalProperties" not in obj + assert "unevaluatedProperties" not in obj + for v in obj.values(): + _assert_no_additional_properties(v) + elif isinstance(obj, list): + for v in obj: + _assert_no_additional_properties(v) + + +def test_json_schema_for_gemini_strips_additional_properties() -> None: + """Gemini rejects JSON Schema with additionalProperties; our adapter removes them before the API call.""" + s = _json_schema_for_gemini_structured_output(DemoPayload) + _assert_no_additional_properties(s) diff --git a/ui/src/api/chat.ts b/ui/src/api/chat.ts index bb8b7eb..f0cf082 100644 --- a/ui/src/api/chat.ts +++ b/ui/src/api/chat.ts @@ -109,6 +109,7 @@ function mapApiMessagesToChatMessages(messages: ApiMessage[]): ChatMessage[] { } function mapPersistedAssistantToChatMessage(m: ApiMessage, promptForInspection: string): ChatMessage { + const v = m.metadata_json?.runtime_ms; const response: AnalyzeApiResponse = { analysis: m.content, trace: (m.metadata_json?.trace as AnalyzeApiResponse["trace"]) ?? [], @@ -116,6 +117,7 @@ function mapPersistedAssistantToChatMessage(m: ApiMessage, promptForInspection: errors: (m.metadata_json?.errors as AnalyzeApiResponse["errors"]) ?? [], inspection_id: typeof m.metadata_json?.inspection_id === "string" ? m.metadata_json.inspection_id : undefined, + runtime_ms: typeof v === "number" && Number.isFinite(v) ? v : null, }; const mapped = mapAnalyzeResponseToUi(promptForInspection, response); cacheInspection(mapped.inspection); @@ -163,6 +165,7 @@ export async function submitChatPrompt(payload: ChatRequest, accessToken: string executed_steps: raw.executed_steps, errors: raw.errors, inspection_id: raw.inspection_id ?? undefined, + runtime_ms: raw.runtime_ms, }; const mapped = mapAnalyzeResponseToUi(payload.prompt, analyzeLike); cacheInspection(mapped.inspection); diff --git a/ui/src/api/mappers.ts b/ui/src/api/mappers.ts index e2544a8..f2081c2 100644 --- a/ui/src/api/mappers.ts +++ b/ui/src/api/mappers.ts @@ -62,7 +62,7 @@ function buildInspection(id: string, prompt: string, response: AnalyzeApiRespons const status = deriveInspectionStatus(response, executedSteps); const verified = status === "valid" && response.errors.length === 0 && executedSteps.some((step) => step.status === "success"); const query = buildCodeBundle(executedSteps); - const runtimeMs = null; + const runtimeMs = response.runtime_ms ?? null; return { id, diff --git a/ui/src/api/types.ts b/ui/src/api/types.ts index 78a6818..ad590c9 100644 --- a/ui/src/api/types.ts +++ b/ui/src/api/types.ts @@ -81,6 +81,8 @@ export interface AnalyzeApiResponse { executed_steps: AnalyzeExecutedStep[]; errors: AnalyzeErrorItem[]; inspection_id?: string; + /** Wall-clock run time of the graph (ms), from the backend. */ + runtime_ms?: number | null; } /** GET /conversations row (backend snake_case). */ @@ -132,4 +134,6 @@ export interface ApiChatTurnResponse { executed_steps: AnalyzeExecutedStep[]; errors: AnalyzeErrorItem[]; inspection_id: string | null; + /** Omitted on very old clients; use null when absent. */ + runtime_ms?: number | null; }