Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@


# LLM provider
# LLM provider: openai | gemini
LLM_PROVIDER=openai
OPENAI_API_KEY=
OPENAI_MODEL=gpt-4.1-mini

# Optional Gemini settings (only read when LLM_PROVIDER=gemini); unset values fall back to the defaults in app/config.py
# GEMINI_API_KEY=
# GEMINI_MODEL=gemini-2.5-flash
# Throttle to stay under Google free-tier RPM; set 0 to disable
# GEMINI_MAX_REQUESTS_PER_MINUTE=5
# GEMINI_RATE_LIMIT_WINDOW_SECONDS=60

10 changes: 9 additions & 1 deletion app/agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import time
from functools import lru_cache
from typing import Any

Expand Down Expand Up @@ -41,7 +42,14 @@ def run_analysis(query: str, source_ids: list[str] | None = None) -> dict[str, A
"""Compatibility entrypoint used by the API service layer."""

workflow = build_graph()
return workflow.invoke(create_initial_state(query, source_ids=source_ids))
started = time.perf_counter()
result = workflow.invoke(create_initial_state(query, source_ids=source_ids))
if not isinstance(result, dict):
result = dict(result)
else:
result = dict(result)
result["runtime_ms"] = max(0, int((time.perf_counter() - started) * 1000))
return result


def load_schema_context_node(state: dict[str, Any]) -> dict[str, Any]:
Expand Down
2 changes: 2 additions & 0 deletions app/api/chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _assistant_metadata(result: AnalyzeResponse) -> dict:
"executed_steps": serialized["executed_steps"],
"errors": serialized["errors"],
"inspection_id": serialized["inspection_id"],
"runtime_ms": serialized.get("runtime_ms"),
}


Expand Down Expand Up @@ -222,4 +223,5 @@ def chat_turn(
executed_steps=analysis_result.executed_steps,
errors=analysis_result.errors,
inspection_id=analysis_result.inspection_id,
runtime_ms=analysis_result.runtime_ms,
)
2 changes: 1 addition & 1 deletion app/api/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _build_inspection(inspection_id: str, prompt: str, response: AnalyzeResponse
and len(response.errors) == 0
and any(step.status == "success" for step in executed_steps)
)
runtime_ms = None
runtime_ms = response.runtime_ms

return InspectionData(
id=inspection_id,
Expand Down
9 changes: 9 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ class Settings(BaseSettings):
llm_provider: str = Field(default="openai", alias="LLM_PROVIDER")
gemini_api_key: str | None = Field(default=None, alias="GEMINI_API_KEY")
gemini_model: str = "gemini-2.0-flash"
# Free-tier limits are often 5 RPM; set to 0 to disable throttling (e.g. paid / higher limits).
gemini_max_requests_per_minute: int = Field(
default=5,
alias="GEMINI_MAX_REQUESTS_PER_MINUTE",
)
gemini_rate_limit_window_seconds: float = Field(
default=60.0,
alias="GEMINI_RATE_LIMIT_WINDOW_SECONDS",
)
openai_api_key: str | None = Field(default=None, alias="OPENAI_API_KEY")
openai_model: str = "gpt-4.1-mini"
log_level: str = "INFO"
Expand Down
70 changes: 68 additions & 2 deletions app/llm/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,80 @@

from __future__ import annotations

from typing import TypeVar
import threading
import time
from collections import deque
from typing import Any, TypeVar

from google import genai
from pydantic import BaseModel

from app.config import get_settings
from app.llm.json_response import validate_structured_output
from app.utils.logging import get_logger

SchemaT = TypeVar("SchemaT", bound=BaseModel)

logger = get_logger(__name__)

# Shared across all GeminiClient instances and requests (planner, query writer, etc.).
_rate_state_lock = threading.Lock()
_request_times: deque[float] = deque()


def _acquire_gemini_request_slot() -> None:
    """Reserve one Gemini request slot, sleeping until the sliding window has room.

    The limit and window length come from the application settings. A limit of
    zero or less disables throttling, and the window is clamped to at least
    0.1 seconds. Request timestamps live in the module-level deque, which is
    only touched while holding the module-level lock.
    """

    cfg = get_settings()
    limit = cfg.gemini_max_requests_per_minute
    if limit <= 0:
        # Throttling is switched off by configuration.
        return
    span = max(cfg.gemini_rate_limit_window_seconds, 0.1)

    while True:
        with _rate_state_lock:
            current = time.monotonic()
            # Forget requests that have aged out of the window.
            while _request_times and current - _request_times[0] >= span:
                _request_times.popleft()
            if len(_request_times) < limit:
                _request_times.append(current)
                return
            # Window is full: wait until the oldest entry expires, plus a small margin.
            oldest_age = current - _request_times[0]
            delay = max(0.0, span - oldest_age + 0.05)
        # Sleep with the lock released so other callers are not held up by it.
        if delay > 0.5:
            logger.info(
                "Throttling Gemini calls: waiting %.1fs to respect max %d requests / %.0fs (RPM limit).",
                delay,
                limit,
                span,
            )
        time.sleep(delay if delay > 0 else 0.01)


def _json_schema_for_gemini_structured_output(model: type[BaseModel]) -> dict[str, Any]:
    """Produce the JSON Schema for ``model`` in a form Gemini accepts.

    Pydantic emits ``additionalProperties: false`` (and similar keywords) for
    models configured with ``extra="forbid"``. The Gemini API rejects those
    keywords, so they are removed from the schema here; the response is still
    validated with Pydantic after the call.
    """
    pydantic_schema = model.model_json_schema()
    return _strip_gemini_incompatible_json_schema_keywords(pydantic_schema)


def _strip_gemini_incompatible_json_schema_keywords(value: object) -> Any:
if isinstance(value, dict):
return {
k: _strip_gemini_incompatible_json_schema_keywords(v)
for k, v in value.items()
if k
not in (
"additionalProperties",
"unevaluatedProperties", # OpenAPI / draft 2020-12 style
)
}
if isinstance(value, list):
return [_strip_gemini_incompatible_json_schema_keywords(v) for v in value]
return value


class GeminiClient:
"""Thin wrapper around the Gemini API."""
Expand All @@ -26,12 +90,13 @@ def __init__(self) -> None:
def generate_json(self, prompt: str, schema: type[SchemaT]) -> SchemaT:
"""Generate schema-constrained JSON and return a validated model."""

_acquire_gemini_request_slot()
response = self.client.models.generate_content(
model=self.model,
contents=prompt,
config={
"response_mime_type": "application/json",
"response_schema": schema,
"response_json_schema": _json_schema_for_gemini_structured_output(schema),
},
)
text = response.text or ""
Expand All @@ -40,5 +105,6 @@ def generate_json(self, prompt: str, schema: type[SchemaT]) -> SchemaT:
def generate_text(self, prompt: str) -> str:
    """Generate free text for final user-facing output.

    Acquires a rate-limit slot before calling the model. The reply text is
    returned with surrounding whitespace stripped; a reply without text
    yields an empty string.
    """

    _acquire_gemini_request_slot()
    reply = self.client.models.generate_content(model=self.model, contents=prompt)
    body = reply.text or ""
    return body.strip()
2 changes: 2 additions & 0 deletions app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ class AnalyzeResponse(BaseModel):
executed_steps: list[ExecutedStep]
errors: list[ErrorItem]
inspection_id: str | None = None
runtime_ms: int | None = None


MessageRoleLiteral = Literal["user", "assistant"]
Expand Down Expand Up @@ -546,6 +547,7 @@ class ChatTurnResponse(BaseModel):
executed_steps: list[ExecutedStep]
errors: list[ErrorItem]
inspection_id: str | None = None
runtime_ms: int | None = None


class HealthResponse(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions app/services/analysis_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ def run_stored_analysis(query: str, source_ids: list[str] | None = None) -> Stor
"""Run `run_analysis`, persist inspection to process memory, return API + inspection objects."""

state = run_analysis(query, source_ids=source_ids)
raw_ms = state.get("runtime_ms")
runtime_ms: int | None = int(raw_ms) if raw_ms is not None else None
base = AnalyzeResponse(
analysis=state["analysis"],
trace=state.get("trace", []),
executed_steps=state.get("executed_steps", []),
errors=state.get("errors", []),
runtime_ms=runtime_ms,
)
inspection_id, inspection = store_inspection(query, base)
response = AnalyzeResponse(
Expand All @@ -42,5 +45,6 @@ def run_stored_analysis(query: str, source_ids: list[str] | None = None) -> Stor
executed_steps=base.executed_steps,
errors=base.errors,
inspection_id=inspection_id,
runtime_ms=base.runtime_ms,
)
return StoredAnalysisRun(response=response, inspection=inspection)
4 changes: 3 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001
}
],
"errors": [],
"runtime_ms": 7,
}

app.dependency_overrides = {}
Expand All @@ -121,11 +122,12 @@ def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001
analysis_run.run_analysis = original
assert response.status_code == 200
payload = response.json()
assert {"analysis", "trace", "executed_steps", "errors", "inspection_id"} <= payload.keys()
assert {"analysis", "trace", "executed_steps", "errors", "inspection_id", "runtime_ms"} <= payload.keys()
assert isinstance(payload["trace"], list)
assert isinstance(payload["executed_steps"], list)
assert isinstance(payload["analysis"], str)
assert isinstance(payload["inspection_id"], str)
assert payload["runtime_ms"] == 7


def test_analyze_endpoint_returns_http_500_on_failure(client: TestClient) -> None:
Expand Down
3 changes: 3 additions & 0 deletions tests/test_conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _fake_analysis_state(query: str, source_ids=None) -> dict: # noqa: ARG001
"trace": [{"step": "planner_compiled_node", "status": "completed", "details": {}}],
"executed_steps": [],
"errors": [],
"runtime_ms": 42,
}


Expand All @@ -62,7 +63,9 @@ def test_chat_creates_conversation_on_first_prompt(chat_client: TestClient) -> N
assert body["assistant_message"]["role"] == "assistant"
assert body["assistant_message"]["content"] == "## Demo\nHello from fake analysis.\n"
assert body["analysis"] == body["assistant_message"]["content"]
assert body["runtime_ms"] == 42
assert body["assistant_message"]["metadata_json"]["inspection_id"] == body["inspection_id"]
assert body["assistant_message"]["metadata_json"]["runtime_ms"] == 42

lst = chat_client.get("/conversations", headers={"Authorization": f"Bearer {token}"})
assert lst.status_code == 200
Expand Down
48 changes: 48 additions & 0 deletions tests/test_gemini_rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Sliding-window rate limit for Gemini API (free-tier RPM)."""

from __future__ import annotations

import time
from collections.abc import Iterator

import pytest

from app.config import get_settings
from app.llm import gemini as gemini_mod


@pytest.fixture
def clear_gemini_rate_deque() -> Iterator[None]:
    """Run a test against an empty rate-limit window and reset shared state afterwards.

    The recorded request times are cleared before the test. After it, they are
    cleared again and the cached settings are dropped, so environment overrides
    made by one test do not leak into the next.

    Yields:
        None. The fixture exists only for its setup and teardown side effects.
    """
    gemini_mod._request_times.clear()
    yield
    gemini_mod._request_times.clear()
    get_settings.cache_clear()


def test_acquire_allows_bursts_up_to_max_then_waits(
    monkeypatch: pytest.MonkeyPatch,
    clear_gemini_rate_deque: None,
) -> None:
    """With two slots per 150 ms window, the third acquisition waits for the window to slide."""
    for env_name, env_value in (
        ("GEMINI_MAX_REQUESTS_PER_MINUTE", "2"),
        ("GEMINI_RATE_LIMIT_WINDOW_SECONDS", "0.15"),
    ):
        monkeypatch.setenv(env_name, env_value)
    # Drop cached settings so the overrides above are picked up.
    get_settings.cache_clear()

    began = time.perf_counter()
    for _ in range(3):
        gemini_mod._acquire_gemini_request_slot()
    took = time.perf_counter() - began

    get_settings.cache_clear()
    assert took >= 0.12, "third acquire should block until a slot is released"


def test_acquire_is_noop_when_max_zero(
    monkeypatch: pytest.MonkeyPatch,
    clear_gemini_rate_deque: None,
) -> None:
    """A limit of zero disables throttling, so repeated acquisitions return immediately."""
    monkeypatch.setenv("GEMINI_MAX_REQUESTS_PER_MINUTE", "0")
    # Drop cached settings so the override above is picked up.
    get_settings.cache_clear()

    began = time.perf_counter()
    calls_left = 5
    while calls_left:
        gemini_mod._acquire_gemini_request_slot()
        calls_left -= 1
    assert time.perf_counter() - began < 0.1
    get_settings.cache_clear()
18 changes: 18 additions & 0 deletions tests/test_llm_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from pydantic import BaseModel, ConfigDict, ValidationError

from app.llm.gemini import _json_schema_for_gemini_structured_output
from app.llm.json_response import validate_structured_output


Expand All @@ -31,3 +32,20 @@ def test_validate_structured_output_rejects_extra_keys() -> None:
schema=DemoPayload,
source="test",
)


def _assert_no_additional_properties(obj: object) -> None:
if isinstance(obj, dict):
assert "additionalProperties" not in obj
assert "unevaluatedProperties" not in obj
for v in obj.values():
_assert_no_additional_properties(v)
elif isinstance(obj, list):
for v in obj:
_assert_no_additional_properties(v)


def test_json_schema_for_gemini_strips_additional_properties() -> None:
    """The Gemini schema adapter must leave no additionalProperties-style keywords in its output."""
    gemini_schema = _json_schema_for_gemini_structured_output(DemoPayload)
    _assert_no_additional_properties(gemini_schema)
3 changes: 3 additions & 0 deletions ui/src/api/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,15 @@ function mapApiMessagesToChatMessages(messages: ApiMessage[]): ChatMessage[] {
}

function mapPersistedAssistantToChatMessage(m: ApiMessage, promptForInspection: string): ChatMessage {
const v = m.metadata_json?.runtime_ms;
const response: AnalyzeApiResponse = {
analysis: m.content,
trace: (m.metadata_json?.trace as AnalyzeApiResponse["trace"]) ?? [],
executed_steps: (m.metadata_json?.executed_steps as AnalyzeApiResponse["executed_steps"]) ?? [],
errors: (m.metadata_json?.errors as AnalyzeApiResponse["errors"]) ?? [],
inspection_id:
typeof m.metadata_json?.inspection_id === "string" ? m.metadata_json.inspection_id : undefined,
runtime_ms: typeof v === "number" && Number.isFinite(v) ? v : null,
};
const mapped = mapAnalyzeResponseToUi(promptForInspection, response);
cacheInspection(mapped.inspection);
Expand Down Expand Up @@ -163,6 +165,7 @@ export async function submitChatPrompt(payload: ChatRequest, accessToken: string
executed_steps: raw.executed_steps,
errors: raw.errors,
inspection_id: raw.inspection_id ?? undefined,
runtime_ms: raw.runtime_ms,
};
const mapped = mapAnalyzeResponseToUi(payload.prompt, analyzeLike);
cacheInspection(mapped.inspection);
Expand Down
2 changes: 1 addition & 1 deletion ui/src/api/mappers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function buildInspection(id: string, prompt: string, response: AnalyzeApiRespons
const status = deriveInspectionStatus(response, executedSteps);
const verified = status === "valid" && response.errors.length === 0 && executedSteps.some((step) => step.status === "success");
const query = buildCodeBundle(executedSteps);
const runtimeMs = null;
const runtimeMs = response.runtime_ms ?? null;

return {
id,
Expand Down
4 changes: 4 additions & 0 deletions ui/src/api/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ export interface AnalyzeApiResponse {
executed_steps: AnalyzeExecutedStep[];
errors: AnalyzeErrorItem[];
inspection_id?: string;
/** Wall-clock run time of the graph (ms), from the backend. */
runtime_ms?: number | null;
}

/** GET /conversations row (backend snake_case). */
Expand Down Expand Up @@ -132,4 +134,6 @@ export interface ApiChatTurnResponse {
executed_steps: AnalyzeExecutedStep[];
errors: AnalyzeErrorItem[];
inspection_id: string | null;
/** Omitted on very old clients; use null when absent. */
runtime_ms?: number | null;
}