Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions app/agent/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,18 @@ def _register_artifacts(conn: duckdb.DuckDBPyConnection, state: AnalysisState) -


def _execute_sql(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary:
conn = new_duckdb_connection()
_register_artifacts(conn, state)
frame = conn.execute(step["code"]).fetchdf()
state["artifacts"][step["output_alias"]] = frame
return _summarize_artifact(step["output_alias"], frame)
conn = new_duckdb_connection(state.get("dataset_context"))
try:
_register_artifacts(conn, state)
frame = conn.execute(step["code"]).fetchdf()
state["artifacts"][step["output_alias"]] = frame
return _summarize_artifact(step["output_alias"], frame)
finally:
conn.close()


def _execute_pandas(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary:
context = get_semantic_context()
context = get_semantic_context(state.get("source_ids"))
local_env: dict[str, Any] = {
**context.raw_views,
**context.semantic_views,
Expand Down Expand Up @@ -112,26 +115,29 @@ def preflight_compiled_plan(state: AnalysisState, compiled_plan: dict[str, Any])
Steps are checked in order so later queries can reference earlier output aliases.
"""

conn = new_duckdb_connection()
_register_artifacts(conn, state)
rows = list(compiled_plan.get("plan") or [])
rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id)

for row in rows:
internal = compiled_plan_row_to_internal(row)
sql = internal["code"].strip().rstrip(";")
try:
preview = conn.execute(f"SELECT * FROM ({sql}) AS __planera_preflight LIMIT 0").fetchdf()
conn.register(internal["output_alias"], preview)
except Exception as exc:
return {
"status": "failed",
"failed_step_id": internal["id"],
"error": str(exc),
"query": internal["code"],
}

return {"status": "success"}
conn = new_duckdb_connection(state.get("dataset_context"))
try:
_register_artifacts(conn, state)
rows = list(compiled_plan.get("plan") or [])
rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id)

for row in rows:
internal = compiled_plan_row_to_internal(row)
sql = internal["code"].strip().rstrip(";")
try:
preview = conn.execute(f"SELECT * FROM ({sql}) AS __planera_preflight LIMIT 0").fetchdf()
conn.register(internal["output_alias"], preview)
except Exception as exc:
return {
"status": "failed",
"failed_step_id": internal["id"],
"error": str(exc),
"query": internal["code"],
}

return {"status": "success"}
finally:
conn.close()


def _try_sql_step(
Expand Down
6 changes: 3 additions & 3 deletions app/agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _append_error(state: AnalysisState, step: str, message: str, recoverable: bo
def load_schema_context_node(state: AnalysisState) -> AnalysisState:
step_name = "load_schema_context_node"
_append_trace(state, step_name, "started", {})
context = get_semantic_context()
context = get_semantic_context(state.get("source_ids"))
state["dataset_context"] = context.schema_manifest
_append_trace(
state,
Expand Down Expand Up @@ -162,8 +162,8 @@ def build_graph():
return graph.compile()


def run_analysis(query: str) -> AnalysisState:
def run_analysis(query: str, source_ids: list[str] | None = None) -> AnalysisState:
"""Execute the full workflow for a single user query."""

workflow = build_graph()
return workflow.invoke(create_initial_state(query))
return workflow.invoke(create_initial_state(query, source_ids=source_ids))
18 changes: 16 additions & 2 deletions app/agent/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def _field_terms(*values: str) -> set[str]:


def _column_relevance_score(column: dict[str, Any], question_terms: set[str]) -> int:
column_terms = _field_terms(column.get("name", ""))
column_terms = _field_terms(
column.get("name", ""),
column.get("original_name", ""),
column.get("source_path", ""),
)
for hint in column.get("semantic_hints") or []:
column_terms.update(_field_terms(hint))

Expand All @@ -50,10 +54,20 @@ def _column_relevance_score(column: dict[str, Any], question_terms: set[str]) ->


def _relation_relevance_score(relation: dict[str, Any], question_terms: set[str]) -> int:
score = len(_field_terms(relation.get("name", ""), relation.get("grain", "")) & question_terms) * 4
score = len(
_field_terms(
relation.get("name", ""),
relation.get("grain", ""),
relation.get("source_name", ""),
json.dumps(relation.get("lineage", {}), default=str),
)
& question_terms
) * 4
score += sum(_column_relevance_score(column, question_terms) for column in relation.get("columns") or [])
for mapping in relation.get("semantic_mappings") or []:
score += len(_field_terms(mapping.get("concept", "")) & question_terms) * 5
if relation.get("is_primary"):
score += 3
return score


Expand Down
4 changes: 3 additions & 1 deletion app/agent/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class AnalysisState(TypedDict):
"""Explicit state carried through the analytics workflow."""

query: str
source_ids: list[str]
dataset_context: dict[str, Any]
intent: str
metric: str
Expand All @@ -27,11 +28,12 @@ class AnalysisState(TypedDict):
errors: list[dict[str, Any]]


def create_initial_state(query: str) -> AnalysisState:
def create_initial_state(query: str, source_ids: list[str] | None = None) -> AnalysisState:
"""Return the initial workflow state for a new request."""

return AnalysisState(
query=query,
source_ids=list(source_ids or []),
dataset_context={},
intent="",
metric="",
Expand Down
14 changes: 13 additions & 1 deletion app/api/chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from app.services.analysis_run import run_stored_analysis
from app.services.inspection_persistence import save_inspection_for_assistant_message
from app.uploads.service import get_authorized_source_ids
from app.utils.logging import get_logger


Expand Down Expand Up @@ -146,8 +147,19 @@ def chat_turn(
db.add(user_msg)
db.flush()

requested_source_ids = list(dict.fromkeys(body.source_ids or []))
if requested_source_ids:
valid_source_ids = get_authorized_source_ids(db, current_user, requested_source_ids)
if len(valid_source_ids) != len(requested_source_ids):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": "Attach a valid uploaded data source before running analysis."},
)
else:
valid_source_ids = None

try:
analysis_run = run_stored_analysis(body.query)
analysis_run = run_stored_analysis(body.query, source_ids=valid_source_ids)
analysis_result = analysis_run.response
inspection_payload = analysis_run.inspection
except Exception as exc: # pragma: no cover - defensive API fallback
Expand Down
73 changes: 63 additions & 10 deletions app/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@

from __future__ import annotations

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi import APIRouter, Depends, File, HTTPException, Response, UploadFile, status
from sqlalchemy.orm import Session

from app.api.workspace import get_inspection, profile_upload
from app.auth.deps import get_current_user_optional
from app.api.workspace import get_inspection
from app.auth.deps import get_current_user, get_current_user_optional
from app.config import get_settings
from app.db.session import get_db
from app.models.conversation import Conversation
from app.models.inspection_snapshot import InspectionSnapshot
from app.models.user import User
from app.schemas import AnalyzeRequest, AnalyzeResponse, HealthResponse, InspectionData, InspectionResponse, SampleQuestionsResponse, UploadResponse
from app.schemas import AnalyzeRequest, AnalyzeResponse, HealthResponse, InspectionData, InspectionResponse, SampleQuestionsResponse, UploadedAsset, UploadResponse
from app.services.analysis_run import run_stored_analysis
from app.uploads.service import create_user_upload, delete_user_upload, get_authorized_source_ids, list_user_uploads
from app.utils.constants import SAMPLE_QUESTIONS
from app.utils.logging import get_logger

Expand All @@ -37,18 +38,52 @@ def sample_questions() -> SampleQuestionsResponse:
return SampleQuestionsResponse(questions=SAMPLE_QUESTIONS)


@router.get("/uploads", response_model=list[UploadedAsset])
def list_uploads(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> list[UploadedAsset]:
"""Return uploads owned by the signed-in user."""

return list_user_uploads(db, current_user)


@router.post("/uploads", response_model=UploadResponse)
async def upload_dataset(file: UploadFile = File(...)) -> UploadResponse:
async def upload_dataset(
file: UploadFile = File(...),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> UploadResponse:
"""Accept a workspace upload and return a profiled asset summary."""

contents = await file.read()
try:
asset = profile_upload(file.filename or "upload.csv", contents)
asset = create_user_upload(
db,
current_user,
filename=file.filename or "upload.csv",
content_type=file.content_type,
content=contents,
)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail={"message": str(exc)}) from exc
return UploadResponse(asset=asset, fallback=False)


@router.delete("/uploads/{source_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_upload(
source_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Response:
"""Delete one upload owned by the signed-in user."""

deleted = delete_user_upload(db, current_user, source_id)
if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail={"message": "Upload not found."})
return Response(status_code=status.HTTP_204_NO_CONTENT)


@router.get("/inspections/{inspection_id}", response_model=InspectionResponse)
def inspection_details(
inspection_id: str,
Expand Down Expand Up @@ -94,17 +129,35 @@ def inspection_details(
description=(
"**Not the primary product API.** For normal use, authenticated clients should call "
"`POST /chat`, which persists conversations, messages, and inspection snapshots.\n\n"
"This endpoint runs the same analytics pipeline without auth, without database persistence, "
"This endpoint runs the same analytics pipeline with auth but without conversation/database persistence, "
"and keeps the inspection payload only in process memory (lost on restart). Use it for "
"local debugging, Swagger/Postman checks, and quick stateless demos.\n\n"
"Response shape aligns with the analysis/trace/steps portion of `POST /chat`."
),
)
def analyze(request: AnalyzeRequest) -> AnalyzeResponse:
"""Run analytics without persistence (see OpenAPI ``description`` — prefer ``POST /chat`` for product flows)."""
def analyze(
request: AnalyzeRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> AnalyzeResponse:
"""Run analytics with auth but without persistence (see OpenAPI ``description`` — prefer ``POST /chat``)."""

requested_source_ids = list(dict.fromkeys(request.source_ids or []))
if not requested_source_ids:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": "Upload and attach at least one CSV or JSON data source before running analysis."},
)

valid_source_ids = get_authorized_source_ids(db, current_user, requested_source_ids)
if len(valid_source_ids) != len(requested_source_ids):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": "Attach a valid uploaded data source before running analysis."},
)

try:
return run_stored_analysis(request.query).response
return run_stored_analysis(request.query, source_ids=valid_source_ids).response
except Exception as exc: # pragma: no cover - defensive API fallback
logger.exception("Analyze request failed", extra={"query": request.query})
settings = get_settings()
Expand Down
Loading
Loading