diff --git a/README.md b/README.md
index 3c86b3e..710774a 100644
--- a/README.md
+++ b/README.md
@@ -1,104 +1,79 @@
-# GTM Analytics Copilot
+# Planera
-GTM Analytics Copilot is an agentic analytics MVP for GTM teams. It takes a business question like "Why did pipeline velocity drop this week?", loads schema context for the dataset, uses an LLM to plan the next SQL or pandas step, executes it over curated views, replans on failure, then runs a single analysis pass to produce a markdown narrative. The API returns that analysis plus trace and executed steps.
+Planera is a chat-first analytics workspace for structured data. You sign in, upload CSV or JSON files, ask a business question, and review both the answer and the execution trail behind it.
-## Why This Is Not "Chat With CSV"
+The product is designed to feel closer to an analytics copilot than a generic "chat with files" tool:
-This project is intentionally constrained:
+- uploads are scoped to the signed-in user
+- analysis runs through a bounded plan/query/execute loop
+- answers come back with trace, SQL, result previews, and validation context
+- conversation history and inspection snapshots are persisted for the main chat flow
-- It uses an LLM planner, but only over curated dataset views.
-- It does not let the model access arbitrary files or external systems.
-- It executes exact SQL or restricted pandas steps instead of vague chain-of-thought.
-- It replans using execution errors instead of silently falling back to rules.
-- It does not run a separate deterministic verification layer; the narrative is grounded in executed step outputs.
-- It exposes every step, code snippet, and output preview in the UI.
+## Current Workflow
-That makes it feel much closer to a production analytics copilot than a generic chatbot sitting on top of a CSV file.
+1. Sign in from the UI
+2. Upload one or more CSV or JSON files
+3. Start a chat or continue an existing conversation
+4. Ask a question against the attached uploads
+5. Review the answer, then open the inspection panel for SQL, results, trace, and validation details
-## MVP Scope
+## Product Surface
-Supported intents:
-
-- `diagnosis`
-- `comparison`
-- `recommendation`
-
-Supported metric:
-
-- `pipeline_velocity`
-
-Supported dimensions:
-
-- `segment`
-- `stage`
-- `owner`
-- `deal_age_bucket`
-- `plan_tier`
-
-Out of scope:
+Backend:
-- churn analytics for the current dataset
-- CRM writes
-- forecasting
-- causal inference
-- broad BI workflows
+- FastAPI API
+- SQLite for users, conversations, messages, and inspection snapshots
+- DuckDB for uploaded data and query execution
+- OpenAI or Gemini for the planning and answer-generation steps
-## Architecture
+Frontend:
-Backend:
+- React + Vite workspace UI
+- authenticated chat experience
+- uploads management
+- inspection drawer for execution details
-- FastAPI for API contracts
-- LangGraph for the planner-executor-replanner loop
-- DuckDB plus pandas over the provided CRM sales dataset
-- OpenAI or Gemini for planning and final analysis text (see `LLM_PROVIDER`)
+## API Overview
-Workflow:
+Primary app flow:
-1. Load curated views and a schema-only manifest (tables, columns, dtypes, row counts)
-2. Ask the LLM for the next executable step (or finish)
-3. Execute SQL (DuckDB) or restricted pandas
-4. On failure or empty results, review and replan
-5. Loop until the planner finishes or limits are hit
-6. Ask the LLM once to turn the executed results into markdown analysis
-7. Return `analysis`, `trace`, `executed_steps`, and `errors`
+- `POST /auth/signup`
+- `POST /auth/login`
+- `GET /auth/me`
+- `GET /uploads`
+- `POST /uploads`
+- `DELETE /uploads/{source_id}`
+- `POST /chat`
+- `GET /conversations`
+- `GET /conversations/{id}`
+- `GET /inspections/{inspection_id}`
-Core modules:
+Debug-only helper:
-- `app/data/semantic_model.py`: curated dataset views and schema manifest
-- `app/llm/`: OpenAI or Gemini client
-- `app/agent/planner.py`: compiled multi-step SQL plan and optional repair
-- `app/agent/executor.py`: SQL execution engine (pandas helpers retained)
-- `app/agent/analysis.py`: single-pass narrative from query + steps
-- `app/agent/graph.py`: LangGraph orchestration
-- `app/api/routes.py`: Shared API surface (health, uploads, inspections, **stateless** `POST /analyze`)
-- `app/api/chat_routes.py`: **Primary product** chat API (`POST /chat`, conversation history)
-- `ui/`: React + Vite frontend
+- `POST /analyze`
-### API: primary chat vs stateless analyze
+Notes:
-| Path | Role |
-|------|------|
-| **`POST /chat`** (with JWT) | **Product path:** persists conversations, messages, and inspection snapshots; use for the React app and any integrated client. |
-| **`POST /analyze`** (no auth) | **Debug / manual testing:** same analytics engine, but **no persistence** and inspection data only in server memory until restart. Marked **deprecated** in OpenAPI; do not treat it as a peer to `/chat`. |
+- `POST /chat` is the main product API and is what the React app uses for real analysis turns.
+- `POST /analyze` is a deprecated debug path. It is stateless, still authenticated, and should not be treated as the normal integration path.
 ## Repo Structure
 ```text
 planera/
-├── app/
-├── ui/
-├── data/
+├── app/ # FastAPI backend
+├── ui/ # React frontend
+├── data/ # sample data, uploads, and DuckDB registry files
 ├── tests/
-├── .env.example
 ├── requirements.txt
-├── Dockerfile
 ├── docker-compose.yml
+├── Dockerfile
 └── README.md
 ```
-## Setup
+## Local Setup
-### 1. Create the environment
+### 1. Backend
 ```bash
 cd planera
@@ -106,165 +81,73 @@ python3 -m venv .venv
 source .venv/bin/activate
 pip install -r requirements.txt
 cp .env.example .env
-```
-
-Use this project virtualenv for all Python commands (`uvicorn`, `pytest`, `pip`). In each new shell, activate it first:
-
-```bash
-source .venv/bin/activate
-```
-
-*(Windows Git Bash: `source .venv/Scripts/activate` — PowerShell: `.venv\Scripts\Activate.ps1`.)*
-
-### 2. Run the API
-
-```bash
-source .venv/bin/activate
 uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
 ```
-API endpoints:
-
-- `GET /health`
-- `GET /sample-questions`
-- `POST /uploads`
-- `GET /inspections/{inspection_id}`
-- **`POST /chat`** — **primary:** authenticated analysis turn; persists thread + inspection snapshot
-- `GET /conversations`, `GET /conversations/{id}` — list/load chat history (authenticated)
-- `POST /analyze` — **deprecated / debug only:** stateless run (see table above)
-- `POST /auth/signup` — create user (SQLite), returns JWT
-- `POST /auth/login` — issue JWT
-- `GET /auth/me` — current user (`Authorization: Bearer `)
-
-**Database:** On API startup the app creates SQLite tables if needed (no separate migration step for this demo). By default the DB file is `planera.db` in the project root (same directory as `requirements.txt`). Override with `DATABASE_PATH` in `.env`. Add a strong `JWT_SECRET_KEY` before any shared deployment; the repo default is for local dev only.
-
-Example (debug — stateless; prefer `/chat` with a JWT for real usage):
-
-```bash
-curl -X POST http://localhost:8000/analyze \
-  -H "Content-Type: application/json" \
-  -d '{"query":"Why did pipeline velocity drop this week?"}'
-```
-
-### 3. Run the React UI
+### 2. Frontend
 In a second terminal:
 ```bash
 cd ui
-nodeenv -p --prebuilt
 npm install
 npm run dev
 ```
-## Environment Variables
-
-Backend settings are defined in `.env.example`:
-
-- `APP_NAME`
-- `APP_ENV`
-- `API_HOST`
-- `API_PORT`
-- `GEMINI_API_KEY`
-- `GEMINI_MODEL`
-- `OPENAI_API_KEY`
-- `OPENAI_MODEL`
-- `LOG_LEVEL`
-- `DATABASE_PATH` (optional; default `planera.db` beside `requirements.txt`)
-- `JWT_SECRET_KEY` (optional for local dev; **required** for non-local use)
-- `JWT_ALGORITHM` (default `HS256`)
-- `ACCESS_TOKEN_EXPIRE_MINUTES` (default `10080`)
-
-Frontend settings live in `ui/.env.example`.
-
-Set `LLM_PROVIDER` to `openai` or `gemini` and provide the matching API key (`OPENAI_API_KEY` or `GEMINI_API_KEY`).
+Open:
-## Data Model
+- API: [http://localhost:8000](http://localhost:8000)
+- UI: [http://localhost:5173](http://localhost:5173)
-Current dataset:
+## Environment Variables
-- `data/CRM+Sales+Opportunities/sales_pipeline.csv`
-- `data/CRM+Sales+Opportunities/accounts.csv`
-- `data/CRM+Sales+Opportunities/products.csv`
-- `data/CRM+Sales+Opportunities/sales_teams.csv`
+Start from `.env.example` for backend setup, then override additional runtime paths or secrets as needed.
-The app builds a semantic view called `opportunities_enriched` from these files and uses the dataset's latest close date as the analysis reference point.
+Most important:
-Key derived fields include:
+- `LLM_PROVIDER`
+- `OPENAI_API_KEY` or `GEMINI_API_KEY`
+- `OPENAI_MODEL` or `GEMINI_MODEL`
+- `DATABASE_PATH`
+- `JWT_SECRET_KEY`
+- `UPLOAD_STORAGE_DIR`
+- `REGISTRY_PATH`
+- `CORS_ALLOW_ORIGINS`
-- `pipeline_velocity_days`
-- `deal_age_days`
-- `stage_age_days`
-- `deal_age_bucket`
-- `segment`
-- `plan_tier`
+Frontend settings live in `ui/.env.example`.
-## Sample Questions
+Most important:
-- Why did pipeline velocity drop this week?
-- Compare SMB vs Enterprise performance
-- Which segment is underperforming?
-- What should we do about this drop?
-- Which deals should we prioritize?
+- `VITE_API_BASE_URL`
+- `VITE_API_FALLBACK_MODE`
 ## Running Tests
-Use the project virtualenv so `pytest` and packages like `passlib` match `requirements.txt` (if you use Conda/base Python, a bare `pytest` may run the wrong interpreter and fail imports):
+Backend:
 ```bash
 source .venv/bin/activate
-pip install -r requirements.txt
 python -m pytest
 ```
-The test suite covers:
+Frontend:
-- mocked LLM planner and analysis contracts
-- executor and review behavior
-- API response shape
+```bash
+cd ui
+npm run check
+```
 ## Docker
-To launch both services together:
+To run both services together:
 ```bash
 docker compose up --build
 ```
-Then open:
-
-- API: [http://localhost:8000](http://localhost:8000)
-- UI: [http://localhost:5173](http://localhost:5173)
-
-## Demo Script
-
-1. Open the Planera UI.
-2. Select "Why did pipeline velocity drop this week?"
-3. Run the analysis and show the planner-executor loop spinner.
-4. Open the executed-steps panel and show the generated SQL or pandas code.
-5. Show the output preview for the most important step.
-6. Expand the trace panel to show replanning-capable agent behavior.
-7. Close on the markdown analysis and next best insights.
-
-## Screenshots
-
-Add screenshots here for:
-
-- main dashboard
-- analysis panel
-- trace panel
-
-## Known Limitations
-
-- The current build is scoped to pipeline analytics on the provided CRM dataset.
-- Churn analysis is out of scope until a real subscriptions or churn dataset is added.
-- The planner only sees registered views; execution is SQL or restricted pandas.
-- An API key is required for the configured `LLM_PROVIDER` (OpenAI or Gemini).
-
-## Future Roadmap
+## Notes
-- Add subscription and churn datasets for a second metric family
-- Richer time window parsing
-- Stronger execution-time linting for generated pandas steps
-- LangSmith or Phoenix integration for trace export
-- Stronger deal-prioritization playbooks
+- The current product flow is upload-first: the UI expects attached CSV or JSON files before submitting an analysis turn.
+- The repository still contains sample CRM-style data under `data/`, but the active app flow is centered on user uploads rather than a built-in warehouse connection.
+- Uploaded sources are scoped correctly, but separate uploads are not automatically joined just because they share similarly named columns.
+- A valid API key is required for whichever LLM provider is configured.
diff --git a/app/agent/__init__.py b/app/agent/__init__.py
index 640e7d2..599ec62 100644
--- a/app/agent/__init__.py
+++ b/app/agent/__init__.py
@@ -1 +1,8 @@
-"""Gemini-driven agent workflow package."""
+"""Agent workflow package for the multi-step analytics runtime."""
+
+from __future__ import annotations
+
+from app.agent.graph import run_analysis
+from app.agent.state import AnalysisState, create_initial_state
+
+__all__ = ["AnalysisState", "create_initial_state", "run_analysis"]
diff --git a/app/agent/_compat.py b/app/agent/_compat.py
new file mode 100644
index 0000000..0a303cf
--- /dev/null
+++ b/app/agent/_compat.py
@@ -0,0 +1,11 @@
+"""Shared helpers for import-safe runtime stubs during staged implementation."""
+
+from __future__ import annotations
+
+from typing import NoReturn
+
+
+def raise_not_implemented(feature: str) -> NoReturn:
+    """Raise a consistent error while later workflow branches are still pending."""
+
+    raise NotImplementedError(f"{feature} is not implemented on this branch yet.")
diff --git a/app/agent/analysis.py b/app/agent/analysis.py
index f13e160..ebc1258 100644
--- a/app/agent/analysis.py
+++ b/app/agent/analysis.py
@@ -1,66 +1,160 @@
-"""Single LLM pass: interpret executed results into an analytics narrative."""
+"""Final analyzer runtime for the schema-grounded workflow."""
 from __future__ import annotations
 import json
+from typing import Any
-from app.agent.analysis_grounding import build_analysis_evidence, build_approved_claims, validate_rendered_analysis
-from app.agent.state import AnalysisState
 from app.llm import get_llm_client
 from app.prompts import render_prompt
-from app.schemas import AnalysisRenderResponse
+from app.schemas import AnalyzerDecision
-_ANALYSIS_RENDER_ATTEMPTS = 2
+def _fallback_best_effort_answer(state: dict[str, Any]) -> str:
+    successful_steps = [step for step in state.get("executed_steps") or [] if step.get("status") == "success"]
+    answered_parts: list[str] = []
+    if successful_steps:
+        last_success = successful_steps[-1]
+        artifact = last_success.get("artifact") or {}
+        answered_parts.append(
+            f"Captured {artifact.get('row_count', 0)} row(s) in {artifact.get('alias', last_success.get('output_alias', 'the final output'))}."
+        )
+
+    unanswered = state.get("failure_summary") or "The workflow could not complete every planned step."
+    lines = [
+        "## Best-effort answer",
+        "",
+        "Answered parts:",
+        *(f"- {item}" for item in answered_parts or ["- No completed step produced a reusable final output."]),
+        "",
+        "Could not answer completely:",
+        f"- {unanswered}",
+    ]
+    return "\n".join(lines).strip()
+
+
+def _fallback_final_answer(state: dict[str, Any]) -> AnalyzerDecision:
+    plan = state.get("current_plan") or {}
+    unsupported = plan.get("unsupported_requirements") or []
+    successful_steps = [step for step in state.get("executed_steps") or [] if step.get("status") == "success"]
-def _build_analysis_render_prompt(
-    question: str,
-    approved_claims_json: str,
-    validation_feedback: str | None = None,
-) -> str:
+    if state.get("workflow_status") == "best_effort_ready":
+        return AnalyzerDecision(
+            decision="final_answer",
+            summary="Returning the bounded best-effort answer.",
+            key_findings=[],
+            important_metrics=[],
+            caveats=[state.get("failure_summary", "")] if state.get("failure_summary") else [],
+            final_answer=state.get("final_answer") or _fallback_best_effort_answer(state),
+            failure_summary="",
+        )
+
+    if unsupported and not successful_steps:
+        caveats = [item.get("description", "") for item in unsupported if item.get("description")]
+        final_answer = "\n".join(
+            [
+                "## Summary",
+                "The available uploaded schema/context does not fully support this question.",
+                *(f"- {item}" for item in caveats),
+            ]
+        ).strip()
+        return AnalyzerDecision(
+            decision="final_answer",
+            summary="The available schema/context does not fully support the request.",
+            key_findings=[],
+            important_metrics=[],
+            caveats=caveats,
+            final_answer=final_answer,
+            failure_summary="",
+        )
+
+    if successful_steps:
+        last_success = successful_steps[-1]
+        artifact = last_success.get("artifact") or {}
+        return AnalyzerDecision(
+            decision="final_answer",
+            summary="The workflow produced a final output ready for review.",
+            key_findings=[f"{artifact.get('row_count', 0)} row(s) returned by the final step."],
+            important_metrics=[],
+            caveats=[state.get("failure_summary", "")] if state.get("failure_summary") else [],
+            final_answer="\n".join(
+                [
+                    "## Summary",
+                    f"The final output alias is {artifact.get('alias', last_success.get('output_alias', 'final_output'))}.",
+                    f"{artifact.get('row_count', 0)} row(s) were returned by the final executed step.",
+                ]
+            ).strip(),
+            failure_summary="",
+        )
+
+    if int(state.get("replan_count", 0) or 0) < 1 and state.get("failure_summary"):
+        return AnalyzerDecision(
+            decision="replan",
+            summary="The collected outputs are not sufficient yet.",
+            key_findings=[],
+            important_metrics=[],
+            caveats=[],
+            final_answer="",
+            failure_summary=state["failure_summary"],
+        )
+
+    best_effort = _fallback_best_effort_answer(state)
+    return AnalyzerDecision(
+        decision="final_answer",
+        summary="Returning the best-effort answer after bounded retries.",
+        key_findings=[],
+        important_metrics=[],
+        caveats=[state.get("failure_summary", "")] if state.get("failure_summary") else [],
+        final_answer=best_effort,
+        failure_summary="",
+    )
+
+
+def _build_analyzer_prompt(state: dict[str, Any]) -> str:
+    plan = state.get("current_plan") or {}
+    payload = {
+        "question": state["query"],
+        "workflow_status": state.get("workflow_status", ""),
+        "replan_count": state.get("replan_count", 0),
+        "failure_summary": state.get("failure_summary", ""),
+        "schema_context_summary": state.get("schema_context_summary") or {},
+        "plan": plan,
+        "executed_steps": state.get("executed_steps") or [],
+        "errors": state.get("errors") or [],
+        "failure_history": state.get("failure_history") or {},
+    }
     return render_prompt(
-        "analysis_render.j2",
-        question=question,
-        approved_claims_json=approved_claims_json,
-        validation_feedback=validation_feedback,
+        "analysis_final.j2",
+        analyzer_input_json=json.dumps(payload, indent=2, default=str),
     )
-def run_analysis_narrative(state: AnalysisState) -> AnalysisState:
-    """Produce markdown-friendly analysis from query, objective, and step outputs."""
+def analyze_workflow(state: dict[str, Any]) -> dict[str, Any]:
+    """Produce the final analyzer decision for the workflow run."""
-    workflow = state.get("workflow_status", "")
-    if workflow in ("planner_failed", "execution_failed"):
-        state["analysis"] = "The available evidence is insufficient because the workflow did not complete successfully."
-        return state
+    try:
+        result = get_llm_client().generate_json(_build_analyzer_prompt(state), schema=AnalyzerDecision)
+        decision = result if isinstance(result, AnalyzerDecision) else AnalyzerDecision.model_validate(result)
+    except Exception:
+        decision = _fallback_final_answer(state)
-    evidence = build_analysis_evidence(state)
-    approved_claims, expected_status = build_approved_claims(evidence)
-    if not approved_claims:
-        state["analysis"] = "The approved claims are insufficient to answer the question with the available evidence."
+    state["analyzer_result"] = decision.model_dump()
+    if decision.decision == "replan" and int(state.get("replan_count", 0) or 0) < 1:
+        state["workflow_status"] = "needs_replan"
+        state["failure_summary"] = decision.failure_summary or state.get("failure_summary", "")
         return state
-    approved_claims_json = json.dumps([claim.model_dump() for claim in approved_claims], indent=2)
-    feedback: str | None = None
-
-    try:
-        for attempt in range(1, _ANALYSIS_RENDER_ATTEMPTS + 1):
-            prompt = _build_analysis_render_prompt(state["query"], approved_claims_json, validation_feedback=feedback)
-            result = get_llm_client().generate_json(prompt, schema=AnalysisRenderResponse)
-            parsed = result if isinstance(result, AnalysisRenderResponse) else AnalysisRenderResponse.model_validate(result)
-            try:
-                validate_rendered_analysis(parsed, approved_claims, expected_status)
-            except ValueError as exc:
-                feedback = str(exc)
-                if attempt >= _ANALYSIS_RENDER_ATTEMPTS:
-                    raise
-                continue
-
-            state["analysis"] = parsed.analysis_markdown.strip() or "No analysis text was returned."
-            return state
-    except Exception as exc:  # pragma: no cover - defensive
-        state["analysis"] = (
-            f"The analysis step could not complete ({exc!s}). "
-            "Review the executed steps and trace for raw outputs."
-        )
+    final_answer = decision.final_answer or _fallback_final_answer(state).final_answer
+    state["analysis"] = final_answer
+    state["final_answer"] = final_answer
+    state["workflow_status"] = "complete"
     return state
+
+
+def run_analysis_narrative(state: dict[str, Any]) -> dict[str, Any]:
+    """Compatibility wrapper retained for test and integration callers."""
+
+    return analyze_workflow(state)
+
+
+__all__ = ["analyze_workflow", "get_llm_client", "run_analysis_narrative"]
diff --git a/app/agent/analysis_grounding.py b/app/agent/analysis_grounding.py
deleted file mode 100644
index 6911a9e..0000000
--- a/app/agent/analysis_grounding.py
+++ /dev/null
@@ -1,387 +0,0 @@
-"""Deterministic evidence, claims, and validation for grounded analysis output."""
-
-from __future__ import annotations
-
-import re
-from collections import defaultdict
-from difflib import SequenceMatcher
-from numbers import Real
-from typing import Any
-
-from app.agent.state import AnalysisState
-from app.schemas import AnalysisEvidence, AnalysisRenderResponse, ApprovedClaim, EvidenceItem, EvidenceValue
-
-_NEGATIVE_PREMISE_TERMS = (
-    "drop",
-    "decline",
-    "decrease",
-    "down",
-    "worse",
-    "slower",
-    "underperform",
-    "fall",
-)
-_POSITIVE_PREMISE_TERMS = (
-    "improve",
-    "increase",
-    "growth",
-    "better",
-    "faster",
-    "higher",
-    "gain",
-    "rise",
-)
-_CURRENT_TERMS = ("current", "latest", "this")
-_PREVIOUS_TERMS = ("previous", "prior", "last")
-_BLOCKED_TERMS = (
-    "stable",
-    "strong",
-    "healthy",
-    "significant",
-    "improving",
-    "worsening",
-    "improved",
-    "worsened",
-    "cause",
-    "caused",
-    "driver",
-    "drivers",
-    "because",
-    "due to",
-    "root cause",
-)
-_GENERIC_PROPER_NOUN_ALLOWLIST = {
-    "Summary",
-    "Conclusion",
-    "Key Findings",
-    "Analysis",
-    "Evidence",
-    "Question",
-}
-
-
-def _is_number(value: Any) -> bool:
-    return isinstance(value, Real) and not isinstance(value, bool)
-
-
-def _format_value(value: Any) -> str:
-    if value is None:
-        return "null"
-    if _is_number(value):
-        if float(value).is_integer():
-            return str(int(value))
-        return f"{float(value):.2f}".rstrip("0").rstrip(".")
-    return str(value)
-
-
-def _infer_premise_hint(question: str) -> str:
-    lowered = question.lower()
-    if any(term in lowered for term in _NEGATIVE_PREMISE_TERMS):
-        return "deterioration"
-    if any(term in lowered for term in _POSITIVE_PREMISE_TERMS):
-        return "improvement"
-    return ""
-
-
-def _row_label(row: dict[str, Any], non_numeric_columns: list[str], fallback: str) -> str:
-    values = [str(row[column]) for column in non_numeric_columns if row.get(column) not in (None, "")]
-    if not values:
-        return fallback
-    if len(values) == 1:
-        return values[0]
-    return " | ".join(values)
-
-
-def _extract_entities(row: dict[str, Any], non_numeric_columns: list[str]) -> list[str]:
-    seen: list[str] = []
-    for column in non_numeric_columns:
-        value = row.get(column)
-        if value in (None, ""):
-            continue
-        as_text = str(value)
-        if as_text not in seen:
-            seen.append(as_text)
-    return seen
-
-
-def build_analysis_evidence(state: AnalysisState) -> AnalysisEvidence:
-    """Build a compact, domain-agnostic evidence packet from executed step previews."""
-
-    items: list[EvidenceItem] = []
-    allowed_entities: list[str] = []
-
-    for step in state.get("executed_steps") or []:
-        if step.get("status") != "success":
-            continue
-        artifact = step.get("artifact") or {}
-        preview_rows = artifact.get("preview_rows") or []
-        columns = artifact.get("columns") or []
-        if not preview_rows or not columns:
-            continue
-
-        numeric_columns = [
-            column
-            for column in columns
-            if any(_is_number(row.get(column)) for row in preview_rows)
-        ]
-        non_numeric_columns = [column for column in columns if column not in numeric_columns]
-
-        for index, row in enumerate(preview_rows, start=1):
-            entities = _extract_entities(row, non_numeric_columns)
-            for entity in entities:
-                if entity not in allowed_entities:
-                    allowed_entities.append(entity)
-            values = [EvidenceValue(label=column, value=_format_value(row.get(column))) for column in columns if row.get(column) is not None]
-            items.append(
-                EvidenceItem(
-                    id=f"{artifact.get('alias', step['output_alias'])}_row_{index}",
-                    source_alias=artifact.get("alias", step["output_alias"]),
-                    source_purpose=step["purpose"],
-                    row_label=_row_label(row, non_numeric_columns, fallback=f"row_{index}"),
-                    entities=entities,
-                    metrics=numeric_columns,
-                    values=values,
-                )
-            )
-
-    return AnalysisEvidence(
-        question=state["query"],
-        primary_metric=state.get("metric", ""),
-        metric_direction=(state.get("compiled_plan") or {}).get("metric_direction", ""),
-        premise_hint=_infer_premise_hint(state["query"]),
-        items=items,
-        allowed_entities=allowed_entities,
-    )
-
-
-def _value_map(item: EvidenceItem) -> dict[str, str]:
-    return {value.label: value.value for value in item.values}
-
-
-def _group_items_by_source(evidence: AnalysisEvidence) -> dict[str, list[EvidenceItem]]:
-    grouped: dict[str, list[EvidenceItem]] = defaultdict(list)
-    for item in evidence.items:
-        grouped[item.source_alias].append(item)
-    return grouped
-
-
-def _sort_period_pair(items: list[EvidenceItem]) -> list[EvidenceItem]:
-    def score(item: EvidenceItem) -> int:
-        lowered = item.row_label.lower()
-        if any(term in lowered for term in _PREVIOUS_TERMS):
-            return 0
-        if any(term in lowered for term in _CURRENT_TERMS):
-            return 1
-        return 2
-
-    return sorted(items, key=score)
-
-
-def _build_premise_claim(evidence: AnalysisEvidence) -> tuple[ApprovedClaim | None, str | None]:
-    if not evidence.primary_metric or not evidence.metric_direction or not evidence.premise_hint:
-        return None, None
-
-    for source_alias, items in _group_items_by_source(evidence).items():
-        matching = [item for item in items if evidence.primary_metric in item.metrics]
-        if len(matching) < 2:
-            continue
-        ordered = _sort_period_pair(matching[:2])
-        left, right = ordered[0], ordered[1]
-        left_value = _value_map(left).get(evidence.primary_metric)
-        right_value = _value_map(right).get(evidence.primary_metric)
-        if left_value is None or right_value is None:
-            continue
-
-        left_number = float(left_value)
-        right_number = float(right_value)
-        if evidence.metric_direction == "lower_is_better":
-            performance_change = "improved" if right_number < left_number else "deteriorated" if right_number > left_number else "flat"
-        elif evidence.metric_direction == "higher_is_better":
-            performance_change = "improved" if right_number > left_number else "deteriorated" if right_number < left_number else "flat"
-        else:
-            performance_change = "flat"
-
-        contradicted = (
-            (evidence.premise_hint == "deterioration" and performance_change == "improved")
-            or (evidence.premise_hint == "improvement" and performance_change == "deteriorated")
-        )
-        statement = (
-            f"The primary metric {evidence.primary_metric} was {left_value} for {left.row_label} and {right_value} for {right.row_label}. "
-            f"The metric direction is {evidence.metric_direction}."
-        )
-        if contradicted:
-            statement += f" This does not support a {evidence.premise_hint} premise."
-
-        return (
-            ApprovedClaim(
-                id="claim_premise_check",
-                kind="premise_check",
-                statement=statement,
-                entities=[entity for entity in [left.row_label, right.row_label, *left.entities, *right.entities] if entity],
-                metrics=[evidence.primary_metric],
-                source_aliases=[source_alias],
-                values=[
-                    EvidenceValue(label=f"{left.row_label}.{evidence.primary_metric}", value=left_value),
-                    EvidenceValue(label=f"{right.row_label}.{evidence.primary_metric}", value=right_value),
-                ],
-            ),
-            "contradicted_premise" if contradicted else None,
-        )
-    return None, None
-
-
-def _build_comparison_claims(evidence: AnalysisEvidence) -> list[ApprovedClaim]:
-    claims: list[ApprovedClaim] = []
-    for source_alias, items in _group_items_by_source(evidence).items():
-        if len(items) != 2:
-            continue
-        ordered = _sort_period_pair(items)
-        left, right = ordered[0], ordered[1]
-        common_metrics = [metric for metric in left.metrics if metric in right.metrics]
-        for metric in common_metrics[:6]:
-            left_value = _value_map(left).get(metric)
-            right_value = _value_map(right).get(metric)
-            if left_value is None or right_value is None:
-                continue
-            claim_id = f"claim_{source_alias}_{metric}"
-            claims.append(
-                ApprovedClaim(
-                    id=claim_id,
-                    kind="comparison",
-                    statement=f"In {source_alias}, {metric} was {left_value} for {left.row_label} and {right_value} for {right.row_label}.",
-                    entities=[entity for entity in [left.row_label, right.row_label, *left.entities, *right.entities] if entity],
-                    metrics=[metric],
-                    source_aliases=[source_alias],
-                    values=[
-                        EvidenceValue(label=f"{left.row_label}.{metric}", value=left_value),
-                        EvidenceValue(label=f"{right.row_label}.{metric}", value=right_value),
-                    ],
-                )
-            )
-    return claims
-
-
-def _build_row_observation_claims(evidence: AnalysisEvidence) -> list[ApprovedClaim]:
-    claims: list[ApprovedClaim] = []
-    for item in evidence.items:
-        if not item.metrics:
-            continue
-        value_map = _value_map(item)
-        metric_pairs = [f"{metric} = {value_map[metric]}" for metric in item.metrics if metric in value_map][:4]
-        if not metric_pairs:
-            continue
-        metric_text = ", ".join(metric_pairs[:-1]) + (f", and {metric_pairs[-1]}" if len(metric_pairs) > 1 else metric_pairs[0])
-        claims.append(
-            ApprovedClaim(
-                id=f"claim_{item.id}",
-                kind="row_observation",
-                statement=f"For {item.row_label}, {metric_text}.",
-                entities=item.entities or [item.row_label],
-                metrics=item.metrics,
-                source_aliases=[item.source_alias],
-                values=[EvidenceValue(label=metric, value=value_map[metric]) for metric in item.metrics if metric in value_map],
-            )
-        )
-    return claims[:8]
-
-
-def build_approved_claims(evidence: AnalysisEvidence) -> tuple[list[ApprovedClaim], str]:
-    """Build deterministic claims and the expected final answer status."""
-
-    if not evidence.items:
-        return [], "insufficient_evidence"
-
-    claims: list[ApprovedClaim] = []
-    expected_status = "answered"
-
-    premise_claim, premise_status = _build_premise_claim(evidence)
-    if premise_claim is not None:
-        claims.append(premise_claim)
-    if premise_status is not None:
-        expected_status = premise_status
-
-    claims.extend(_build_comparison_claims(evidence))
-    claims.extend(_build_row_observation_claims(evidence))
-
-    deduped: list[ApprovedClaim] = []
-    seen_statements: set[str] = set()
-    for claim in claims:
-        if claim.statement in seen_statements:
-            continue
-        seen_statements.add(claim.statement)
-        deduped.append(claim)
-
-    return deduped, expected_status
-
-
-def _extract_numeric_tokens(text: str) -> set[str]:
-    return set(re.findall(r"\b\d+(?:\.\d+)?\b", text))
-
-
-def _candidate_entity_phrases(text: str) -> set[str]:
-    candidates: set[str] = set()
-    for line in text.splitlines():
-        stripped = line.strip()
-        if not stripped or stripped.startswith("#"):
-            continue
-        for match in re.findall(r"\b[A-Za-z][A-Za-z]*(?:\s+[A-Za-z][A-Za-z]*)+\b", stripped):
-            candidates.add(match.strip())
-    return candidates
-
-
-def validate_rendered_analysis(
-    response: AnalysisRenderResponse,
-    approved_claims: list[ApprovedClaim],
-    expected_status: str,
-) -> None:
-    """Reject rendered analysis that steps outside the approved claim set."""
-
-    claim_by_id = {claim.id: claim for claim in approved_claims}
-    used_ids = response.used_claim_ids or []
-    if expected_status != "insufficient_evidence" and not used_ids:
-        raise ValueError("The final analysis must cite at least one approved claim id.")
-    unknown_ids = [claim_id for claim_id in used_ids if claim_id not in claim_by_id]
-    if unknown_ids:
-        raise ValueError(f"Unknown claim ids in used_claim_ids: {unknown_ids}")
-    if response.answer_status != expected_status:
-        raise ValueError(f"answer_status must be {expected_status}, got {response.answer_status}")
-
-    selected_claims = [claim_by_id[claim_id] for claim_id in used_ids if claim_id in claim_by_id]
-    allowed_numbers = {
-        token
-        for claim in selected_claims
-        for token in _extract_numeric_tokens(claim.statement)
-    }
-    unexpected_numbers = sorted(_extract_numeric_tokens(response.analysis_markdown) - allowed_numbers)
-    if unexpected_numbers:
-        raise ValueError(f"Analysis introduced numbers not present in the approved claims: {unexpected_numbers}")
-
-    lowered_analysis = response.analysis_markdown.lower()
-    for term in _BLOCKED_TERMS:
-        if term in lowered_analysis:
-            raise ValueError(f"Analysis used unsupported wording: {term}")
-
-    allowed_entities = {
-        entity
-        for claim in selected_claims
-        for entity in claim.entities
-        if entity
-    }
-    for candidate in _candidate_entity_phrases(response.analysis_markdown):
-        if candidate in allowed_entities or candidate in _GENERIC_PROPER_NOUN_ALLOWLIST:
-            continue
-        similar = max(
-            (
-                SequenceMatcher(None, candidate.lower(), allowed.lower()).ratio()
-                for allowed in allowed_entities
-            ),
-            default=0.0,
-        )
-        if similar >= 0.82:
-            raise ValueError(f"Analysis changed an approved entity name: {candidate}")
-
-    if expected_status == "contradicted_premise":
-        first_line = next((line.strip() for line in response.analysis_markdown.splitlines() if line.strip()), "")
-        lowered_first_line = first_line.lower()
-        if "does not support" not in lowered_first_line and "contradict" not in lowered_first_line:
-            raise ValueError("A contradicted premise must be stated clearly in the first sentence.")
diff --git a/app/agent/executor.py b/app/agent/executor.py
index d8e58af..371026b 100644
--- a/app/agent/executor.py
+++ b/app/agent/executor.py
@@ -1,223 +1,241 @@
-"""Execution engine for compiled SQL plans and legacy pandas helpers."""
+"""Executor runtime for step-by-step SQL execution."""
 from __future__ import annotations
-from typing import Any, Literal
+from typing import Any
-import duckdb
 import pandas as pd
-from app.agent.state import AnalysisState
-from app.data.semantic_model import get_semantic_context, new_duckdb_connection
-from app.schemas import ArtifactSummary, CompiledPlanStep, ExecutedStep
-
-
-SAFE_BUILTINS: dict[str, Any] = {
-    "len": len,
-    "min": min,
-    "max": max,
-    "sum": sum,
-    "round": round,
-    "sorted": sorted,
-}
-
-
-def _summarize_artifact(alias: str, value: Any) -> ArtifactSummary:
-    if isinstance(value, pd.Series):
-        value = value.to_frame()
-
-    if isinstance(value, pd.DataFrame):
-        preview_rows = value.head(5).where(pd.notnull(value), None).to_dict(orient="records")
-        summary: dict[str, Any] = {}
-        numeric = value.select_dtypes(include=["number"])
-        if not numeric.empty:
-            summary["numeric_means"] = numeric.mean().round(2).to_dict()
-        return ArtifactSummary(
-            alias=alias,
-            artifact_type="table",
-            row_count=int(len(value)),
-            columns=list(value.columns),
-            preview_rows=preview_rows,
-            summary=summary,
-        )
-
-    if isinstance(value, (int, float, str, bool)):
-        return ArtifactSummary(
-            alias=alias,
-            artifact_type="scalar" if not isinstance(value, str) else "text",
-            row_count=1,
-            columns=["value"],
-            preview_rows=[{"value": value}],
-            summary={"value": value},
-        )
-
-    return ArtifactSummary(alias=alias, artifact_type="unknown")
-
-
-def _register_artifacts(conn: duckdb.DuckDBPyConnection, state: AnalysisState) -> None:
-    for alias, artifact in state["artifacts"].items():
-        if isinstance(artifact, pd.DataFrame):
-            conn.register(alias, artifact)
-
-
-def _execute_sql(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary:
-    conn = new_duckdb_connection(state.get("dataset_context"))
-    try:
-        _register_artifacts(conn, state)
-        frame = conn.execute(step["code"]).fetchdf()
-        state["artifacts"][step["output_alias"]] = frame
-        return _summarize_artifact(step["output_alias"], frame)
-    finally:
-        conn.close()
-
-
-def _execute_pandas(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary:
-    context = get_semantic_context(state.get("source_ids"))
-    local_env: dict[str, Any] = {
-        **context.raw_views,
-        **context.semantic_views,
-        **state["artifacts"],
-        "pd": pd,
-        "result": None,
-    }
-    exec(step["code"], {"__builtins__": SAFE_BUILTINS}, local_env)
-    result = local_env.get("result")
-    if result is None:
-        raise ValueError("Pandas step did not assign a `result` variable.")
-    state["artifacts"][step["output_alias"]] = result
-    return _summarize_artifact(step["output_alias"], result)
-
-
-def _empty_table_failure(artifact: ArtifactSummary) -> bool:
-    return artifact.artifact_type == "table" and artifact.row_count == 0
-
+from app.schemas import ArtifactSummary, ExecutedStep, StepFailureRecord
+from app.data.semantic_model import new_duckdb_connection
+
+
+def _summarize_artifact(alias: str, frame: pd.DataFrame) -> ArtifactSummary:
+    preview_rows = frame.head(5).where(pd.notnull(frame), None).to_dict(orient="records")
+    return ArtifactSummary(
+        alias=alias,
+        artifact_type="table",
+        row_count=int(len(frame)),
+        columns=list(frame.columns),
+        preview_rows=preview_rows,
+        summary={},
+    )
+
+
+def _register_stored_outputs(conn, state: dict[str, Any]) -> None:  # noqa: ANN001
+    for alias, value in (state.get("stored_outputs") or {}).items():
+        if isinstance(value, pd.DataFrame):
+
conn.register(alias, value) + + +def _current_step(state: dict[str, Any]) -> dict[str, Any]: + plan = state.get("current_plan") or {} + steps = list(plan.get("steps") or []) + step_index = int(state.get("current_step_index", 0) or 0) + if step_index < 0 or step_index >= len(steps): + raise ValueError(f"Current step index {step_index} is out of range for the active plan.") + return steps[step_index] + + +def _record_failure(state: dict[str, Any], *, step: dict[str, Any], attempt: int, error: str, sql: str) -> None: + failure = StepFailureRecord( + step_id=int(step["id"]), + attempt=attempt, + error=error, + query=sql, + details={"output_alias": step["output_alias"]}, + ) + state.setdefault("failure_history", {}).setdefault(str(step["id"]), []).append(failure.model_dump()) + + +def _record_error( + state: dict[str, Any], + *, + step_name: str, + message: str, + recoverable: bool, + details: dict[str, Any], +) -> None: + state.setdefault("errors", []).append( + { + "step": step_name, + "message": message, + "recoverable": recoverable, + "details": details, + } + ) -def compiled_plan_row_to_internal(row: dict[str, Any] | CompiledPlanStep) -> dict[str, Any]: - """Map a compiled plan step to the executor shape used by `_execute_sql`.""" - if isinstance(row, CompiledPlanStep): - row = row.model_dump() - sid = row["id"] - alias = row.get("output_alias") or f"step_{sid}" - return { - "id": str(sid), - "kind": "sql", - "purpose": row["purpose"], - "code": row["query"], - "output_alias": alias, - } +def _failure_summary(step: dict[str, Any], error: str) -> str: + relations = ", ".join(step.get("relations") or []) or "the active relations" + return f"Execution repeatedly failed for step {step['id']} against {relations}: {error}" -def preflight_compiled_plan(state: AnalysisState, compiled_plan: dict[str, Any]) -> dict[str, Any]: - """ - Validate compiled SQL steps against the active runtime before execution. 
+def execute_current_step(state: dict[str, Any]) -> dict[str, Any]: + """Execute the query for the current workflow step.""" - Steps are checked in order so later queries can reference earlier output aliases. - """ + step = _current_step(state) + generated_query = state.get("generated_query") or {} + sql = str(generated_query.get("sql") or "").strip() + if not sql: + raise ValueError("No generated SQL is available for the current step.") + attempt = int((state.get("retry_counts") or {}).get(str(step["id"]), 0) or 0) + 1 conn = new_duckdb_connection(state.get("dataset_context")) try: - _register_artifacts(conn, state) - rows = list(compiled_plan.get("plan") or []) - rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id) - - for row in rows: - internal = compiled_plan_row_to_internal(row) - sql = internal["code"].strip().rstrip(";") - try: - preview = conn.execute(f"SELECT * FROM ({sql}) AS __planera_preflight LIMIT 0").fetchdf() - conn.register(internal["output_alias"], preview) - except Exception as exc: - return { - "status": "failed", - "failed_step_id": internal["id"], - "error": str(exc), - "query": internal["code"], - } - - return {"status": "success"} - finally: - conn.close() - - -def _try_sql_step( - state: AnalysisState, - internal: dict[str, Any], - attempt: int, -) -> tuple[Literal["success", "failed"], ExecutedStep]: - """Run one SQL step with post-execution validation (non-empty table).""" - - state["total_steps"] += 1 - try: - artifact = _execute_sql(state, internal) - if _empty_table_failure(artifact): - state["artifacts"].pop(internal["output_alias"], None) - raise ValueError("Step returned an empty result set.") + _register_stored_outputs(conn, state) + frame = conn.execute(sql).fetchdf() + except Exception as exc: + message = str(exc) executed = ExecutedStep( - id=internal["id"], + id=str(step["id"]), kind="sql", - purpose=internal["purpose"], - code=internal["code"], - output_alias=internal["output_alias"], + purpose=step["purpose"], + 
code=sql, + output_alias=step["output_alias"], attempt=attempt, - status="success", - artifact=artifact, + status="failed", + error=message, ) - state["executed_steps"].append(executed.model_dump()) - state["last_error"] = None - return "success", executed - except Exception as exc: + state.setdefault("executed_steps", []).append(executed.model_dump()) + _record_failure(state, step=step, attempt=attempt, error=message, sql=sql) + + prior_retries = int((state.get("retry_counts") or {}).get(str(step["id"]), 0) or 0) + if prior_retries < 1: + state.setdefault("retry_counts", {})[str(step["id"])] = prior_retries + 1 + _record_error( + state, + step_name="executor_node", + message=message, + recoverable=True, + details={"step_id": step["id"], "attempt": attempt}, + ) + state["workflow_status"] = "retry_same_step" + elif int(state.get("replan_count", 0) or 0) < 1: + state["failure_summary"] = _failure_summary(step, message) + _record_error( + state, + step_name="executor_node", + message=message, + recoverable=True, + details={"step_id": step["id"], "attempt": attempt, "action": "replan"}, + ) + state["workflow_status"] = "needs_replan" + else: + state["failure_summary"] = _failure_summary(step, message) + _record_error( + state, + step_name="executor_node", + message=message, + recoverable=False, + details={"step_id": step["id"], "attempt": attempt, "action": "best_effort"}, + ) + state["workflow_status"] = "best_effort" + + state["generated_query"] = None + return state + finally: + conn.close() + + artifact = _summarize_artifact(step["output_alias"], frame) + if artifact.row_count == 0 and not bool(step.get("allow_empty_result", False)): + message = "Step returned an empty result set." 
executed = ExecutedStep( - id=internal["id"], + id=str(step["id"]), kind="sql", - purpose=internal["purpose"], - code=internal["code"], - output_alias=internal["output_alias"], + purpose=step["purpose"], + code=sql, + output_alias=step["output_alias"], attempt=attempt, status="failed", - error=str(exc), + error=message, ) - state["executed_steps"].append(executed.model_dump()) - state["last_error"] = {"step_id": internal["id"], "message": str(exc), "code": internal["code"]} - return "failed", executed - - -def execute_plan(state: AnalysisState, compiled_plan: dict[str, Any]) -> dict[str, Any]: - """ - Iterate compiled plan steps in order: validate via execute + empty-table check. - No LLM calls. On first failure, stop and return structured status. - """ - - rows = list(compiled_plan.get("plan") or []) - rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id) - - for row in rows: - internal = compiled_plan_row_to_internal(row) - status, _ = _try_sql_step(state, internal, attempt=1) - if status == "failed": - sid = internal["id"] - return { - "status": "failed", - "failed_step_id": sid, - "error": state["last_error"]["message"] if state["last_error"] else "Unknown error", - } - - return {"status": "success"} - - -def execute_single_plan_step( - state: AnalysisState, - compiled_step: dict[str, Any], - attempt: int, -) -> dict[str, Any]: - """Re-run a single compiled step (e.g. 
after repair).""" - - internal = compiled_plan_row_to_internal(compiled_step) - status, _ = _try_sql_step(state, internal, attempt=attempt) - if status == "failed": - return { - "status": "failed", - "failed_step_id": internal["id"], - "error": state["last_error"]["message"] if state["last_error"] else "Unknown error", - } - return {"status": "success"} + state.setdefault("executed_steps", []).append(executed.model_dump()) + _record_failure(state, step=step, attempt=attempt, error=message, sql=sql) + + prior_retries = int((state.get("retry_counts") or {}).get(str(step["id"]), 0) or 0) + if prior_retries < 1: + state.setdefault("retry_counts", {})[str(step["id"])] = prior_retries + 1 + _record_error( + state, + step_name="executor_node", + message=message, + recoverable=True, + details={"step_id": step["id"], "attempt": attempt}, + ) + state["workflow_status"] = "retry_same_step" + elif int(state.get("replan_count", 0) or 0) < 1: + state["failure_summary"] = _failure_summary(step, message) + _record_error( + state, + step_name="executor_node", + message=message, + recoverable=True, + details={"step_id": step["id"], "attempt": attempt, "action": "replan"}, + ) + state["workflow_status"] = "needs_replan" + else: + state["failure_summary"] = _failure_summary(step, message) + _record_error( + state, + step_name="executor_node", + message=message, + recoverable=False, + details={"step_id": step["id"], "attempt": attempt, "action": "best_effort"}, + ) + state["workflow_status"] = "best_effort" + + state["generated_query"] = None + return state + + state.setdefault("stored_outputs", {})[step["output_alias"]] = frame + executed = ExecutedStep( + id=str(step["id"]), + kind="sql", + purpose=step["purpose"], + code=sql, + output_alias=step["output_alias"], + attempt=attempt, + status="success", + artifact=artifact, + ) + state.setdefault("executed_steps", []).append(executed.model_dump()) + state["generated_query"] = None + + plan = state.get("current_plan") or {} + steps = 
list(plan.get("steps") or []) + step_index = int(state.get("current_step_index", 0) or 0) + if step_index + 1 < len(steps): + state["current_step_index"] = step_index + 1 + state["workflow_status"] = "plan_ready" + else: + state["workflow_status"] = "ready_for_analysis" + return state + + +def build_best_effort_state(state: dict[str, Any]) -> dict[str, Any]: + """Populate the best-effort answer path after retry limits are exhausted.""" + + successful_steps = [step for step in state.get("executed_steps") or [] if step.get("status") == "success"] + answered_parts: list[str] = [] + if successful_steps: + last_success = successful_steps[-1] + artifact = last_success.get("artifact") or {} + answered_parts.append( + f"Captured {artifact.get('row_count', 0)} row(s) in {artifact.get('alias', last_success.get('output_alias', 'the final output'))}." + ) + + unanswered = state.get("failure_summary") or "The workflow could not complete every planned step." + lines = [ + "## Best-effort answer", + "", + "Answered parts:", + *(f"- {item}" for item in answered_parts or ["- No completed step produced a reusable final output."]), + "", + "Could not answer completely:", + f"- {unanswered}", + ] + state["final_answer"] = "\n".join(lines).strip() + state["analysis"] = state["final_answer"] + state["workflow_status"] = "best_effort_ready" + return state diff --git a/app/agent/graph.py b/app/agent/graph.py index 076944b..a5a2afe 100644 --- a/app/agent/graph.py +++ b/app/agent/graph.py @@ -1,4 +1,4 @@ -"""LangGraph workflow: schema context, compiled plan, deterministic execution, analysis.""" +"""LangGraph orchestration for the schema-grounded analytics workflow.""" from __future__ import annotations @@ -7,25 +7,46 @@ from langgraph.graph import END, START, StateGraph -from app.agent.analysis import run_analysis_narrative -from app.agent.executor import execute_plan, execute_single_plan_step -from app.agent.planner import plan_compiled_query, repair_failed_step +from app.agent.analysis 
import analyze_workflow +from app.agent.executor import build_best_effort_state, execute_current_step +from app.agent.planner import plan_analysis, replan_analysis +from app.agent.query_writer import write_step_query from app.agent.state import AnalysisState, create_initial_state from app.data.semantic_model import get_semantic_context -from app.utils.logging import get_logger -logger = get_logger(__name__) + +def _append_trace(state: dict[str, Any], step: str, status: str, details: dict[str, Any] | None = None) -> None: + state.setdefault("trace", []).append({"step": step, "status": status, "details": details or {}}) + + +def _append_error( + state: dict[str, Any], + *, + step: str, + message: str, + recoverable: bool, + details: dict[str, Any] | None = None, +) -> None: + state.setdefault("errors", []).append( + { + "step": step, + "message": message, + "recoverable": recoverable, + "details": details or {}, + } + ) -def _append_trace(state: AnalysisState, step: str, status: str, details: dict[str, Any] | None = None) -> None: - state["trace"].append({"step": step, "status": status, "details": details or {}}) +def run_analysis(query: str, source_ids: list[str] | None = None) -> dict[str, Any]: + """Compatibility entrypoint used by the API service layer.""" + workflow = build_graph() + return workflow.invoke(create_initial_state(query, source_ids=source_ids)) -def _append_error(state: AnalysisState, step: str, message: str, recoverable: bool = True, details: dict[str, Any] | None = None) -> None: - state["errors"].append({"step": step, "message": message, "recoverable": recoverable, "details": details or {}}) +def load_schema_context_node(state: dict[str, Any]) -> dict[str, Any]: + """Load schema context into workflow state.""" -def load_schema_context_node(state: AnalysisState) -> AnalysisState: step_name = "load_schema_context_node" _append_trace(state, step_name, "started", {}) context = get_semantic_context(state.get("source_ids")) @@ -34,136 +55,196 @@ def 
load_schema_context_node(state: AnalysisState) -> AnalysisState: state, step_name, "completed", - {"reference_date": context.reference_date, "views": [v["name"] for v in context.schema_manifest.get("views", [])]}, + { + "reference_date": context.reference_date, + "relations": [relation["name"] for relation in context.schema_manifest.get("relations", [])], + }, ) + state["workflow_status"] = "schema_ready" return state -def planner_compiled_node(state: AnalysisState) -> AnalysisState: - step_name = "planner_compiled_node" - _append_trace(state, step_name, "started", {"total_steps": state["total_steps"]}) +def planner_node(state: dict[str, Any]) -> dict[str, Any]: + """Create the full ordered plan for the request.""" + + step_name = "planner_node" + _append_trace(state, step_name, "started", {"replan_count": state.get("replan_count", 0)}) try: - state = plan_compiled_query(state) - plan = state.get("compiled_plan") or {} - _append_trace( - state, - step_name, - "completed", - { - "objective": plan.get("objective"), - "step_count": len(plan.get("plan") or []), - "metric": plan.get("metric"), - }, - ) + state = plan_analysis(state) except Exception as exc: - logger.warning("%s failed: %s", step_name, exc, exc_info=True) - _append_error(state, step_name, str(exc), recoverable=False) - state["workflow_status"] = "planner_failed" - state["compiled_plan"] = None - _append_trace(state, step_name, "failed", {"message": str(exc)}) + message = str(exc) + _append_error(state, step=step_name, message=message, recoverable=False) + _append_trace(state, step_name, "failed", {"message": message}) + state["failure_summary"] = message + state["workflow_status"] = "best_effort" + return state + + steps = list((state.get("current_plan") or {}).get("steps") or []) + if steps: + _append_trace(state, step_name, "completed", {"step_count": len(steps)}) + state["workflow_status"] = "plan_ready" + else: + _append_trace(state, step_name, "completed", {"step_count": 0}) + 
state["workflow_status"] = "ready_for_analysis" return state -def execute_plan_node(state: AnalysisState) -> AnalysisState: - step_name = "execute_plan_node" - if state["workflow_status"] == "planner_failed" or not state.get("compiled_plan"): - _append_trace(state, step_name, "skipped", {"reason": "no compiled plan"}) - return state +def query_writer_node(state: dict[str, Any]) -> dict[str, Any]: + """Generate exactly one query for the current step.""" - plan = state["compiled_plan"] - if not plan.get("plan"): - msg = "Planner returned an empty plan." - _append_error(state, step_name, msg, recoverable=False) - state["workflow_status"] = "execution_failed" - _append_trace(state, step_name, "failed", {"message": msg}) + step_name = "query_writer_node" + _append_trace(state, step_name, "started", {"current_step_index": state.get("current_step_index", 0)}) + try: + state = write_step_query(state) + except Exception as exc: + message = str(exc) + _append_error(state, step=step_name, message=message, recoverable=False) + _append_trace(state, step_name, "failed", {"message": message}) + state["failure_summary"] = message + state["workflow_status"] = "best_effort" return state - _append_trace(state, step_name, "started", {"steps": len(plan["plan"])}) - outcome = execute_plan(state, plan) + generated_query = state.get("generated_query") or {} + _append_trace( + state, + step_name, + "completed", + {"step_id": generated_query.get("step_id"), "query_length": len(generated_query.get("sql", ""))}, + ) + return state - if outcome["status"] == "success": - state["workflow_status"] = "ready_to_analyze" - _append_trace(state, step_name, "completed", {"status": "success"}) - return state - failed_id = outcome["failed_step_id"] - err = outcome.get("error", "Unknown execution error") +def executor_node(state: dict[str, Any]) -> dict[str, Any]: + """Execute the current step and store any successful output.""" + step_name = "executor_node" + current_step_index = 
state.get("current_step_index", 0) + _append_trace(state, step_name, "started", {"current_step_index": current_step_index}) try: - state = repair_failed_step(state, failed_id, err) + state = execute_current_step(state) except Exception as exc: - _append_error(state, "repair_planner", str(exc), recoverable=False, details={"failed_step_id": failed_id}) - state["workflow_status"] = "execution_failed" - _append_trace(state, step_name, "failed", {"phase": "repair", "message": str(exc)}) + message = str(exc) + _append_error(state, step=step_name, message=message, recoverable=False) + _append_trace(state, step_name, "failed", {"message": message}) + state["failure_summary"] = message + state["workflow_status"] = "best_effort" return state - if state["executed_steps"] and state["executed_steps"][-1]["status"] == "failed": - state["executed_steps"].pop() + _append_trace(state, step_name, "completed", {"workflow_status": state.get("workflow_status", "")}) + return state - plan = state["compiled_plan"] or {} - step_row = next((r for r in (plan.get("plan") or []) if str(r.get("id")) == str(failed_id)), None) - if not step_row: - _append_error(state, step_name, "Repaired step missing from plan.", recoverable=False) - state["workflow_status"] = "execution_failed" - return state - retry = execute_single_plan_step(state, step_row, attempt=2) - if retry["status"] == "failed": - _append_error( - state, - step_name, - f"Step {retry.get('failed_step_id', failed_id)} failed after repair: {retry.get('error', err)}", - recoverable=False, - details={"failed_step_id": failed_id}, - ) - state["workflow_status"] = "execution_failed" - _append_trace(state, step_name, "failed", {"phase": "retry", "failed_step_id": failed_id}) - else: - state["workflow_status"] = "ready_to_analyze" - _append_trace(state, step_name, "completed", {"status": "success_after_repair"}) +def analyzer_node(state: dict[str, Any]) -> dict[str, Any]: + """Produce the final answer or request a replan.""" + step_name = 
"analyzer_node" + _append_trace(state, step_name, "started", {"workflow_status": state.get("workflow_status", "")}) + try: + state = analyze_workflow(state) + except Exception as exc: + message = str(exc) + _append_error(state, step=step_name, message=message, recoverable=False) + _append_trace(state, step_name, "failed", {"message": message}) + state["failure_summary"] = message + state["workflow_status"] = "best_effort" + return state + + decision = (state.get("analyzer_result") or {}).get("decision", "") + _append_trace(state, step_name, "completed", {"decision": decision}) return state -def analysis_node(state: AnalysisState) -> AnalysisState: - step_name = "analysis_node" - _append_trace(state, step_name, "started", {"workflow_status": state["workflow_status"]}) +def replan_node(state: dict[str, Any]) -> dict[str, Any]: + """Request one bounded replan after analyzer or execution failure.""" + + step_name = "replan_node" + _append_trace(state, step_name, "started", {"failure_summary": state.get("failure_summary", "")}) try: - state = run_analysis_narrative(state) - _append_trace(state, step_name, "completed", {"length": len(state["analysis"])}) + state = replan_analysis(state) except Exception as exc: - state["analysis"] = f"The analysis step failed: {exc}" - _append_error(state, step_name, str(exc), recoverable=False) - _append_trace(state, step_name, "failed", {"message": str(exc)}) + message = str(exc) + _append_error(state, step=step_name, message=message, recoverable=False) + _append_trace(state, step_name, "failed", {"message": message}) + state["failure_summary"] = message + state["workflow_status"] = "best_effort" + return state + + state["current_step_index"] = 0 + state["generated_query"] = None + state["stored_outputs"] = {} + state["workflow_status"] = "plan_ready" if (state.get("current_plan") or {}).get("steps") else "ready_for_analysis" + _append_trace(state, step_name, "completed", {"replan_count": state.get("replan_count", 0)}) return state 
-def route_after_planner(state: AnalysisState) -> str: - if state["workflow_status"] == "planner_failed": - return "analysis_node" - return "execute_plan_node" +def best_effort_node(state: dict[str, Any]) -> dict[str, Any]: + """Return the final best-effort answer when retry limits are exhausted.""" + + step_name = "best_effort_node" + _append_trace(state, step_name, "started", {"failure_summary": state.get("failure_summary", "")}) + state = build_best_effort_state(state) + _append_trace(state, step_name, "completed", {"workflow_status": state.get("workflow_status", "")}) + state = analyze_workflow(state) + return state + + +def _route_after_planner(state: dict[str, Any]) -> str: + status = state.get("workflow_status") + if status == "plan_ready": + return "query_writer_node" + if status == "best_effort": + return "best_effort_node" + return "analyzer_node" + + +def _route_after_executor(state: dict[str, Any]) -> str: + status = state.get("workflow_status") + if status in {"plan_ready", "retry_same_step"}: + return "query_writer_node" + if status == "needs_replan": + return "replan_node" + if status == "best_effort": + return "best_effort_node" + return "analyzer_node" + + +def _route_after_analyzer(state: dict[str, Any]) -> str: + status = state.get("workflow_status") + if status == "needs_replan": + return "replan_node" + if status == "best_effort": + return "best_effort_node" + return END + + +def _route_after_replan(state: dict[str, Any]) -> str: + status = state.get("workflow_status") + if status == "plan_ready": + return "query_writer_node" + if status == "best_effort": + return "best_effort_node" + return "analyzer_node" @lru_cache(maxsize=1) def build_graph(): - """Compile and cache the LangGraph workflow.""" + """Compile and cache the analytics workflow graph.""" graph = StateGraph(AnalysisState) graph.add_node("load_schema_context_node", load_schema_context_node) - graph.add_node("planner_compiled_node", planner_compiled_node) - 
graph.add_node("execute_plan_node", execute_plan_node) - graph.add_node("analysis_node", analysis_node) + graph.add_node("planner_node", planner_node) + graph.add_node("query_writer_node", query_writer_node) + graph.add_node("executor_node", executor_node) + graph.add_node("analyzer_node", analyzer_node) + graph.add_node("replan_node", replan_node) + graph.add_node("best_effort_node", best_effort_node) graph.add_edge(START, "load_schema_context_node") - graph.add_edge("load_schema_context_node", "planner_compiled_node") - graph.add_conditional_edges("planner_compiled_node", route_after_planner) - graph.add_edge("execute_plan_node", "analysis_node") - graph.add_edge("analysis_node", END) + graph.add_edge("load_schema_context_node", "planner_node") + graph.add_conditional_edges("planner_node", _route_after_planner) + graph.add_edge("query_writer_node", "executor_node") + graph.add_conditional_edges("executor_node", _route_after_executor) + graph.add_conditional_edges("analyzer_node", _route_after_analyzer) + graph.add_conditional_edges("replan_node", _route_after_replan) + graph.add_edge("best_effort_node", END) return graph.compile() - - -def run_analysis(query: str, source_ids: list[str] | None = None) -> AnalysisState: - """Execute the full workflow for a single user query.""" - - workflow = build_graph() - return workflow.invoke(create_initial_state(query, source_ids=source_ids)) diff --git a/app/agent/planner.py b/app/agent/planner.py index 5fc8bdb..b2fbadd 100644 --- a/app/agent/planner.py +++ b/app/agent/planner.py @@ -1,4 +1,4 @@ -"""LLM-driven compiled multi-step planner and repair.""" +"""Planner runtime for the schema-grounded analytics workflow.""" from __future__ import annotations @@ -7,18 +7,11 @@ from copy import deepcopy from typing import Any -from pydantic import ValidationError - -from app.agent.executor import preflight_compiled_plan -from app.agent.state import AnalysisState from app.llm import get_llm_client from app.prompts import 
render_prompt -from app.schemas import CompiledPlan, RepairDecision -from app.utils.logging import get_logger +from app.schemas import AnalysisPlan, CompactSchemaColumn, CompactSchemaContext, CompactSchemaRelation -logger = get_logger(__name__) -_COMPILED_PLANNER_ATTEMPTS = 3 _MAX_PROMPT_RELATIONS = 4 _MAX_COLUMNS_PER_RELATION = 18 @@ -40,11 +33,10 @@ def _field_terms(*values: str) -> set[str]: def _column_relevance_score(column: dict[str, Any], question_terms: set[str]) -> int: column_terms = _field_terms( column.get("name", ""), - column.get("original_name", ""), - column.get("source_path", ""), + column.get("dtype", ""), ) for hint in column.get("semantic_hints") or []: - column_terms.update(_field_terms(hint)) + column_terms.update(_field_terms(str(hint))) overlap = len(column_terms & question_terms) score = overlap * 3 @@ -59,7 +51,6 @@ def _relation_relevance_score(relation: dict[str, Any], question_terms: set[str] relation.get("name", ""), relation.get("grain", ""), relation.get("source_name", ""), - json.dumps(relation.get("lineage", {}), default=str), ) & question_terms ) * 4 @@ -83,6 +74,7 @@ def _trim_relation_for_prompt(relation: dict[str, Any], question_terms: set[str] _column_relevance_score(column, question_terms), column.get("name") in (trimmed.get("identifier_columns") or []), column.get("name") in (trimmed.get("time_columns") or []), + column.get("name") in (trimmed.get("measure_columns") or []), ), reverse=True, ) @@ -107,153 +99,144 @@ def _trim_relation_for_prompt(relation: dict[str, Any], question_terms: set[str] for mapping in (trimmed.get("semantic_mappings") or []) if any(name in selected_names for name in mapping.get("columns") or []) ][:12] - trimmed["omitted_column_count"] = max(len(columns) - len(selected), 0) return trimmed -def _schema_subset_for_question(dataset_context: dict[str, Any], question: str) -> dict[str, Any]: +def _coerce_relation(relation: dict[str, Any]) -> CompactSchemaRelation: + return CompactSchemaRelation( + 
name=relation.get("name", ""), + source_id=relation.get("source_id", ""), + source_name=relation.get("source_name", ""), + is_primary=bool(relation.get("is_primary")), + row_count=int(relation.get("row_count", 0) or 0), + grain=relation.get("grain", "") or "", + identifier_columns=list(relation.get("identifier_columns") or []), + time_columns=list(relation.get("time_columns") or []), + measure_columns=list(relation.get("measure_columns") or []), + dimension_columns=list(relation.get("dimension_columns") or []), + join_keys=list(relation.get("join_keys") or []), + semantic_mappings=list(relation.get("semantic_mappings") or []), + columns=[ + CompactSchemaColumn( + name=column.get("name", ""), + dtype=column.get("dtype", ""), + type_family=column.get("type_family", "unknown"), + nullable=bool(column.get("nullable", True)), + semantic_hints=list(column.get("semantic_hints") or []), + ) + for column in (relation.get("columns") or []) + ], + ) + + +def _render_plan_prompt( + *, + query: str, + schema_context_summary: dict[str, Any], + failure_summary: str = "", + current_plan: dict[str, Any] | None = None, +) -> str: + template_name = "planner_replan.j2" if failure_summary else "planner_plan.j2" + return render_prompt( + template_name, + query=query, + schema_context_json=json.dumps(schema_context_summary, indent=2), + failure_summary=failure_summary, + current_plan_json=json.dumps(current_plan or {}, indent=2), + ) + + +def build_compact_schema_context(dataset_context: dict[str, Any], question: str) -> dict[str, Any]: + """Return a compact schema/context summary for planning prompts.""" + relations = list(dataset_context.get("relations") or []) if not relations: - return dataset_context - - total_columns = sum(len(relation.get("columns") or []) for relation in relations) - if len(relations) <= _MAX_PROMPT_RELATIONS and total_columns <= (_MAX_PROMPT_RELATIONS * _MAX_COLUMNS_PER_RELATION): - return { - "reference_date": dataset_context.get("reference_date", ""), - 
"source": dataset_context.get("source", ""), - "dialect": dataset_context.get("dialect", ""), - "relations": relations, - } + context = CompactSchemaContext( + reference_date=dataset_context.get("reference_date", ""), + source=dataset_context.get("source", ""), + dialect=dataset_context.get("dialect", ""), + relations=[], + ) + return context.model_dump() question_terms = _query_terms(question) - ranked_relations = sorted(relations, key=lambda relation: _relation_relevance_score(relation, question_terms), reverse=True) - selected = ranked_relations[:_MAX_PROMPT_RELATIONS] - return { - "reference_date": dataset_context.get("reference_date", ""), - "source": dataset_context.get("source", ""), - "dialect": dataset_context.get("dialect", ""), - "relations": [_trim_relation_for_prompt(relation, question_terms) for relation in selected], - } - - -def _relation_names(dataset_context: dict[str, Any]) -> list[str]: - relations = dataset_context.get("relations") or [] - if relations: - return [relation["name"] for relation in relations] - return [view["name"] for view in dataset_context.get("views", [])] - - -def _planner_preflight_feedback(outcome: dict[str, Any], schema_subset: dict[str, Any]) -> str: - return ( - "Your previous plan failed SQL preflight validation.\n" - f"Failed step id: {outcome.get('failed_step_id', '')}\n" - f"Error: {outcome.get('error', '')}\n" - f"SQL:\n{outcome.get('query', '').strip()}\n\n" - "Fix guidance:\n" - "- Use only exact relation and column names from the schema subset.\n" - "- Resolve business-language terms through semantic mappings, then use the mapped exact field names in SQL.\n" - "- Do not invent fields or rename columns.\n" - "- If the question premise might be wrong, start with an overall comparison before segment-level breakdowns.\n" - f"- Target SQL dialect: {schema_subset.get('dialect', '') or 'unknown'}.\n" + total_columns = sum(len(relation.get("columns") or []) for relation in relations) + if len(relations) > 
_MAX_PROMPT_RELATIONS or total_columns > (_MAX_PROMPT_RELATIONS * _MAX_COLUMNS_PER_RELATION): + ranked_relations = sorted( + relations, + key=lambda relation: _relation_relevance_score(relation, question_terms), + reverse=True, + ) + selected_relations = ranked_relations[:_MAX_PROMPT_RELATIONS] + else: + selected_relations = relations + + context = CompactSchemaContext( + reference_date=dataset_context.get("reference_date", ""), + source=dataset_context.get("source", ""), + dialect=dataset_context.get("dialect", ""), + relations=[_coerce_relation(_trim_relation_for_prompt(relation, question_terms)) for relation in selected_relations], ) + return context.model_dump() -def _build_compiled_planner_prompt(state: AnalysisState, validation_feedback: str | None = None) -> str: - schema_subset = _schema_subset_for_question(state["dataset_context"], state["query"]) - relation_names = _relation_names(schema_subset) - return render_prompt( - "planner_compiled.j2", - query=state["query"], - relation_names_json=json.dumps(relation_names), - schema_subset_json=json.dumps(schema_subset, indent=2), - validation_feedback=validation_feedback, +def plan_analysis(state: dict[str, Any]) -> dict[str, Any]: + """Return the planner-authored full workflow plan.""" + + schema_context_summary = state.get("schema_context_summary") or build_compact_schema_context( + state.get("dataset_context", {}), + state["query"], ) + state["schema_context_summary"] = schema_context_summary + prompt = _render_plan_prompt(query=state["query"], schema_context_summary=schema_context_summary) + result = get_llm_client().generate_json(prompt, schema=AnalysisPlan) + parsed = result if isinstance(result, AnalysisPlan) else AnalysisPlan.model_validate(result) + state["current_plan"] = parsed.model_dump() + state["current_step_index"] = 0 + state["workflow_status"] = "planned" + return state -def _build_repair_prompt(state: AnalysisState, failed_step_id: str, error_message: str) -> str: - plan = 
state.get("compiled_plan") or {} - schema_subset = _schema_subset_for_question(state["dataset_context"], state["query"]) - return render_prompt( - "planner_repair.j2", - failed_step_id=failed_step_id, - error_message=error_message, - plan_json=json.dumps(plan, indent=2), - schema_subset_json=json.dumps(schema_subset, indent=2), +def replan_analysis(state: dict[str, Any]) -> dict[str, Any]: + """Return one revised workflow plan after failure.""" + + schema_context_summary = state.get("schema_context_summary") or build_compact_schema_context( + state.get("dataset_context", {}), + state["query"], + ) + state["schema_context_summary"] = schema_context_summary + prompt = _render_plan_prompt( + query=state["query"], + schema_context_summary=schema_context_summary, + failure_summary=state.get("failure_summary", ""), + current_plan=state.get("current_plan"), ) + result = get_llm_client().generate_json(prompt, schema=AnalysisPlan) + parsed = result if isinstance(result, AnalysisPlan) else AnalysisPlan.model_validate(result) + state["current_plan"] = parsed.model_dump() + state["current_step_index"] = 0 + state["replan_count"] = int(state.get("replan_count", 0) or 0) + 1 + state["workflow_status"] = "replanned" + return state -def plan_compiled_query(state: AnalysisState) -> AnalysisState: - """Call the LLM to produce a full compiled plan (max 3 SQL steps), with retries on schema errors.""" - - client = get_llm_client() - feedback: str | None = None - - for attempt in range(1, _COMPILED_PLANNER_ATTEMPTS + 1): - prompt = _build_compiled_planner_prompt(state, validation_feedback=feedback) - try: - decision = client.generate_json(prompt, schema=CompiledPlan) - parsed = decision if isinstance(decision, CompiledPlan) else CompiledPlan.model_validate(decision) - except (ValidationError, ValueError) as exc: - feedback = exc.json(indent=2) if isinstance(exc, ValidationError) else str(exc) - logger.warning( - "Compiled plan validation failed (attempt %s/%s): %s", - attempt, - 
_COMPILED_PLANNER_ATTEMPTS, - exc.errors() if isinstance(exc, ValidationError) else str(exc), - ) - if attempt >= _COMPILED_PLANNER_ATTEMPTS: - raise - continue +def _schema_subset_for_question(dataset_context: dict[str, Any], question: str) -> dict[str, Any]: + """Compatibility helper kept for schema-focused tests.""" - preflight = preflight_compiled_plan(state, parsed.model_dump()) - if preflight["status"] == "failed": - feedback = _planner_preflight_feedback(preflight, _schema_subset_for_question(state["dataset_context"], state["query"])) - logger.warning( - "Compiled plan preflight failed (attempt %s/%s): %s", - attempt, - _COMPILED_PLANNER_ATTEMPTS, - preflight["error"], - ) - if attempt >= _COMPILED_PLANNER_ATTEMPTS: - raise ValueError(preflight["error"]) - continue + return build_compact_schema_context(dataset_context, question) - state["compiled_plan"] = parsed.model_dump() - state["planner_reasoning"] = parsed.objective - state["metric"] = parsed.metric - state["intent"] = "diagnosis" - state["workflow_status"] = "ready_to_execute" - return state +def plan_compiled_query(state: dict[str, Any]) -> dict[str, Any]: + """Compatibility wrapper retained during the transition away from compiled SQL planning.""" -def repair_failed_step(state: AnalysisState, failed_step_id: str, error_message: str) -> AnalysisState: - """Call the LLM once to replace a single failed plan step.""" + return plan_analysis(state) - raw = get_llm_client().generate_json( - _build_repair_prompt(state, failed_step_id, error_message), - schema=RepairDecision, - ) - parsed = raw if isinstance(raw, RepairDecision) else RepairDecision.model_validate(raw) - - if str(parsed.updated_step.id) != str(failed_step_id): - raise ValueError(f"Repair returned mismatched step id: expected {failed_step_id}, got {parsed.updated_step.id}") - - plan = state.get("compiled_plan") - if not plan: - raise ValueError("No compiled plan to repair.") - - steps = list(plan.get("plan") or []) - replaced = False - for i, 
row in enumerate(steps): - sid = row.get("id") if isinstance(row, dict) else row["id"] - if str(sid) == str(failed_step_id): - steps[i] = parsed.updated_step.model_dump() - replaced = True - break - if not replaced: - raise ValueError(f"Failed step id {failed_step_id} not found in compiled plan.") - plan["plan"] = steps - state["compiled_plan"] = plan - state["repair_attempted"] = True - return state +__all__ = [ + "_schema_subset_for_question", + "build_compact_schema_context", + "get_llm_client", + "plan_analysis", + "plan_compiled_query", + "replan_analysis", +] diff --git a/app/agent/query_writer.py b/app/agent/query_writer.py new file mode 100644 index 0000000..0cef37e --- /dev/null +++ b/app/agent/query_writer.py @@ -0,0 +1,79 @@ +"""Query-writer runtime for step-by-step SQL generation.""" + +from __future__ import annotations + +import json +from typing import Any + +from app.llm import get_llm_client +from app.prompts import render_prompt +from app.schemas import GeneratedQuery + + +def _get_current_step(state: dict[str, Any]) -> dict[str, Any]: + plan = state.get("current_plan") or {} + steps = list(plan.get("steps") or []) + if not steps: + raise ValueError("No planner-authored steps are available for query writing.") + + step_index = int(state.get("current_step_index", 0) or 0) + if step_index < 0 or step_index >= len(steps): + raise ValueError(f"Current step index {step_index} is out of range for the active plan.") + return steps[step_index] + + +def _prior_output_summaries(state: dict[str, Any], current_step: dict[str, Any]) -> list[dict[str, Any]]: + needed = {str(step_id) for step_id in current_step.get("depends_on") or []} + if not needed: + return [] + + summaries: list[dict[str, Any]] = [] + for step in state.get("executed_steps") or []: + if step.get("status") != "success": + continue + if str(step.get("id")) not in needed: + continue + artifact = step.get("artifact") or {} + summaries.append( + { + "step_id": step.get("id"), + "output_alias": 
step.get("output_alias"), + "purpose": step.get("purpose"), + "artifact": { + "alias": artifact.get("alias"), + "row_count": artifact.get("row_count", 0), + "columns": artifact.get("columns") or [], + "preview_rows": artifact.get("preview_rows") or [], + }, + } + ) + return summaries + + +def write_step_query(state: dict[str, Any]) -> dict[str, Any]: + """Generate one SQL query for the current workflow step.""" + + current_step = _get_current_step(state) + prior_outputs = _prior_output_summaries(state, current_step) + error_context = "" + failures = state.get("failure_history", {}).get(str(current_step["id"]), []) + if failures: + error_context = failures[-1].get("error", "") + + prompt = render_prompt( + "query_writer.j2", + query=state["query"], + step_json=json.dumps(current_step, indent=2), + schema_context_json=json.dumps(state.get("schema_context_summary") or {}, indent=2), + prior_outputs_json=json.dumps(prior_outputs, indent=2), + error_context=error_context, + ) + result = get_llm_client().generate_json(prompt, schema=GeneratedQuery) + parsed = result if isinstance(result, GeneratedQuery) else GeneratedQuery.model_validate(result) + state["generated_query"] = parsed.model_dump() + state.setdefault("step_queries", {}).setdefault(str(current_step["id"]), []).append(parsed.sql) + state["workflow_status"] = "query_ready" + return state + + +__all__ = ["get_llm_client", "write_step_query"] diff --git a/app/agent/state.py b/app/agent/state.py index c940ed7..763e9a2 100644 --- a/app/agent/state.py +++ b/app/agent/state.py @@ -1,4 +1,4 @@ -"""Typed workflow state for the compiled-plan analytics workflow.""" +"""Workflow state helpers for the schema-grounded multi-agent analytics flow.""" from __future__ import annotations @@ -7,24 +7,28 @@ from typing_extensions import TypedDict -class AnalysisState(TypedDict): - """Explicit state carried through the analytics workflow.""" +class AnalysisState(TypedDict, total=False): + """Explicit mutable state carried through 
the analytics workflow.""" query: str source_ids: list[str] dataset_context: dict[str, Any] - intent: str - metric: str - planner_reasoning: str - compiled_plan: dict[str, Any] | None - repair_attempted: bool - artifacts: dict[str, Any] - executed_steps: list[dict[str, Any]] + schema_context_summary: dict[str, Any] + current_plan: dict[str, Any] | None + current_step_index: int + generated_query: dict[str, Any] | None + stored_outputs: dict[str, Any] + step_queries: dict[str, list[str]] + failure_history: dict[str, list[dict[str, Any]]] + retry_counts: dict[str, int] + replan_count: int + analyzer_result: dict[str, Any] | None analysis: str - total_steps: int - last_error: dict[str, Any] | None + final_answer: str + failure_summary: str workflow_status: str trace: list[dict[str, Any]] + executed_steps: list[dict[str, Any]] errors: list[dict[str, Any]] @@ -35,17 +39,21 @@ def create_initial_state(query: str, source_ids: list[str] | None = None) -> Ana query=query, source_ids=list(source_ids or []), dataset_context={}, - intent="", - metric="", - planner_reasoning="", - compiled_plan=None, - repair_attempted=False, - artifacts={}, - executed_steps=[], + schema_context_summary={}, + current_plan=None, + current_step_index=0, + generated_query=None, + stored_outputs={}, + step_queries={}, + failure_history={}, + retry_counts={}, + replan_count=0, + analyzer_result=None, analysis="", - total_steps=0, - last_error=None, - workflow_status="planning", + final_answer="", + failure_summary="", + workflow_status="initializing", trace=[], + executed_steps=[], errors=[], ) diff --git a/app/api/workspace.py b/app/api/workspace.py index c67b238..4e52ca2 100644 --- a/app/api/workspace.py +++ b/app/api/workspace.py @@ -17,10 +17,16 @@ STEP_LABELS: dict[str, str] = { "load_schema_context_node": "Schema Context", "planner_compiled_node": "Query Planning", + "planner_node": "Workflow Planning", + "query_writer_node": "Query Writing", "execute_plan_node": "Execution", + 
"executor_node": "Step Execution", "analysis_node": "Narrative Synthesis", + "analyzer_node": "Final Analysis", "api_analyze": "API Analyze", "repair_planner": "Repair Planning", + "replan_node": "Replanning", + "best_effort_node": "Best Effort Answer", } _STORE_LOCK = Lock() diff --git a/app/prompts/__init__.py b/app/prompts/__init__.py index baf40c6..4c9bb90 100644 --- a/app/prompts/__init__.py +++ b/app/prompts/__init__.py @@ -10,7 +10,7 @@ @lru_cache(maxsize=1) def _prompt_environment() -> Environment: - """Create a strict Jinja environment for agent prompts.""" + """Create a strict Jinja environment for workflow prompts.""" templates_dir = Path(__file__).resolve().parent return Environment( @@ -23,7 +23,7 @@ def _prompt_environment() -> Environment: def render_prompt(template_name: str, **context: object) -> str: - """Render a prompt template with the provided context.""" + """Render one prompt template with the provided context.""" template = _prompt_environment().get_template(template_name) return template.render(**context).strip() diff --git a/app/prompts/analysis_final.j2 b/app/prompts/analysis_final.j2 new file mode 100644 index 0000000..098ae48 --- /dev/null +++ b/app/prompts/analysis_final.j2 @@ -0,0 +1,36 @@ +You are the final analyzer in a bounded analytics workflow. + +You must use only: +- the original question +- the planner's full plan +- the schema/context summary derived from uploaded data +- executed step outputs +- execution metadata and failure history + +Your responsibilities: +- decide whether the question has been answered sufficiently +- provide the final answer when possible +- request one replan only when the outputs are incomplete and a replan is still allowed +- return a clear best-effort answer when the workflow has already exhausted retry/replan limits + +Rules: +- Return only valid JSON matching the schema. +- Ground every statement in the provided plan, schema/context summary, and executed outputs. 
+- Do not invent columns, joins, relationships, numbers, or business claims. +- If you choose `replan`, `failure_summary` must be specific and actionable. +- If the workflow is already in best-effort mode or `replan_count` is already 1, do not choose `replan`. +- When returning a best-effort final answer, clearly separate answered parts from unanswered parts. + +Return JSON in this shape: +{ + "decision": "final_answer" | "replan", + "summary": "string", + "key_findings": ["string"], + "important_metrics": [{"label": "string", "value": "string"}], + "caveats": ["string"], + "final_answer": "string", + "failure_summary": "string" +} + +Analyzer input: +{{ analyzer_input_json }} diff --git a/app/prompts/analysis_render.j2 b/app/prompts/analysis_render.j2 deleted file mode 100644 index 9121462..0000000 --- a/app/prompts/analysis_render.j2 +++ /dev/null @@ -1,39 +0,0 @@ -You are a careful, domain-agnostic data analyst. - -Task: -Answer the question using only the approved claims provided. - -Hard rules: -- Use only the approved claims. -- Do not add new facts, names, labels, metrics, dates, or interpretations. -- Do not restate any entity, label, or metric unless it appears exactly in the approved claims. -- Use entity strings exactly as provided. Do not abbreviate, normalize, paraphrase, or rename them. -- Do not infer causes, drivers, intent, explanations, or business meaning unless explicitly stated in an approved claim. -- Do not derive new metrics, percentages, rankings, or comparisons unless they are explicitly stated in an approved claim. -- If an approved claim contradicts the question premise, state that in the first sentence clearly and directly. -- Prefer direct numeric comparisons that already appear in approved claims. -- Do not use unsupported adjectives such as "stable", "strong", "healthy", "significant", "improving", or "worsening" unless the exact wording appears in an approved claim. 
-- If the approved claims are insufficient to answer the question, say so explicitly. -- If the approved claims conflict with each other, say the evidence is internally inconsistent. -- Use only the minimum set of approved claims needed to answer the question. - -Output rules: -- Return only valid JSON. -- The JSON must match this schema exactly: - { - "answer_status": "answered" | "insufficient_evidence" | "contradicted_premise" | "conflicting_evidence", - "analysis_markdown": string, - "used_claim_ids": string[] - } - -Question: -{{ question }} - -Approved claims: -{{ approved_claims_json }} -{% if validation_feedback %} - -Your previous response was rejected by validation. -Validation feedback: -{{ validation_feedback }} -{% endif %} diff --git a/app/prompts/planner_compiled.j2 b/app/prompts/planner_compiled.j2 deleted file mode 100644 index 74c9587..0000000 --- a/app/prompts/planner_compiled.j2 +++ /dev/null @@ -1,47 +0,0 @@ -You are the planning component of a domain-agnostic data analysis agent. - -Return a single JSON object that describes a full multi-step plan to answer the user's question using only the normalized schema manifest described below. - -Rules: -- Follow the response schema exactly. -- Produce 1 to 3 items in "plan" (at most three SQL steps). Each step must add incremental explanatory value; avoid redundant segmentation. -- CRITICAL — the "max_steps" field: set it to the integer 3 always. It is the platform's fixed ceiling, not the count of steps you return. Do not set max_steps to 1 or 2 even if the plan has only one or two queries. -- Every step must use "type": "sql" and put the full SQL statement in "query". -- Use only these registered relation names: {{ relation_names_json }} -- Use exact column names only. Never invent, normalize, paraphrase, or rename a column. -- Use semantic mappings only to translate user language into exact schema fields. -- Prefer relations where "is_primary" is true before reaching for child relations. 
-- Only join relations when the schema subset exposes an explicit path in "join_keys". -- For nested JSON child relations, use the exact join columns from "join_keys" instead of guessing parent-child keys. -- Do not assume the question premise is true. First verify the main metric or comparison before planning causal or grouped breakdowns. -- Prefer SQL over multiple trivial splits; combine logic when one query suffices. -- No imports, file I/O, network calls, or plotting. -- Optional "output_alias" per step for stable names; if omitted, the executor uses `step_`. - -Return JSON in this exact shape: -{ - "objective": "string — what the plan will establish end-to-end", - "plan": [ - { - "id": 1, - "purpose": "string", - "type": "sql", - "query": "SQL query string" - } - ], - "max_steps": 3, - "metric": "optional short label for the primary metric, or empty string", - "metric_direction": "optional: e.g. higher_is_better or lower_is_better, or empty string" -} - -User query: -{{ query }} - -Normalized schema subset (relations, exact fields, types, grain, semantic mappings): -{{ schema_subset_json }} -{% if validation_feedback %} - -Your previous attempt was rejected. Fix it and try again. -Feedback: -{{ validation_feedback }} -{% endif %} diff --git a/app/prompts/planner_plan.j2 b/app/prompts/planner_plan.j2 new file mode 100644 index 0000000..3d0aba7 --- /dev/null +++ b/app/prompts/planner_plan.j2 @@ -0,0 +1,57 @@ +You are the planner in a bounded analytics workflow. + +You receive: +- the user's natural-language question +- a compact schema/context summary derived from uploaded structured data + +Your job: +- return the full ordered analysis plan in one shot +- Do not write SQL +- use only relations, columns, and join paths present in the schema/context summary +- if the question cannot be fully answered, say so explicitly in `unsupported_requirements` + +Rules: +- Return only valid JSON matching the schema. +- Use at most 3 steps. 
+- Each step must be logically distinct and useful. +- Every `output_alias` must be stable and human-readable. +- `depends_on` must reference earlier step ids only when a later step needs a prior output. +- `required_columns` should use exact qualified field names like `relation_name.column_name`. +- Do not invent columns, joins, metrics, filters, relationships, or time ranges. +- If the question is unsupported, you may return zero steps. +- The planner must never generate SQL or pseudo-SQL. + +Return JSON in this shape: +{ + "objective": "string", + "can_answer_fully": true, + "unsupported_requirements": [ + { + "type": "column | relationship | concept | time_range | filter | metric | other", + "description": "string", + "relation": "optional string", + "column": "optional string" + } + ], + "steps": [ + { + "id": 1, + "purpose": "string", + "depends_on": [], + "output_alias": "string", + "relations": ["string"], + "required_columns": ["relation.column"], + "expected_output": "string", + "allow_empty_result": false + } + ], + "max_steps": 3, + "metric": "string", + "metric_direction": "string" +} + +User question: +{{ query }} + +Compact schema/context summary: +{{ schema_context_json }} diff --git a/app/prompts/planner_repair.j2 b/app/prompts/planner_repair.j2 deleted file mode 100644 index ada2780..0000000 --- a/app/prompts/planner_repair.j2 +++ /dev/null @@ -1,35 +0,0 @@ -You are the planning component of a domain-agnostic data analysis agent. A SQL step from an existing plan failed execution. - -Repair the failed step only: return JSON that replaces that step with corrected SQL. Do not add new steps. - -Rules: -- Follow the response schema exactly. -- repair_action must be "replace_step". -- updated_step must use "type": "sql", the same id as the failed step ({{ failed_step_id }}), and a fixed "query". -- Use only exact relation and column names from the schema manifest. -- Use semantic mappings only to translate user language into exact schema fields. 
-- Prefer relations where "is_primary" is true before using child relations. -- Only join relations when the schema manifest exposes the join path in "join_keys". -- Do not invent fields or rename columns. - -Original plan: -{{ plan_json }} - -Failed step id: {{ failed_step_id }} - -Error message: -{{ error_message }} - -Relevant schema subset: -{{ schema_subset_json }} - -Return JSON in this shape: -{ - "repair_action": "replace_step", - "updated_step": { - "id": , - "purpose": "string", - "type": "sql", - "query": "corrected SQL" - } -} diff --git a/app/prompts/planner_replan.j2 b/app/prompts/planner_replan.j2 new file mode 100644 index 0000000..383ab00 --- /dev/null +++ b/app/prompts/planner_replan.j2 @@ -0,0 +1,25 @@ +You are the planner in a bounded analytics workflow. + +The previous plan or execution path failed and needs one revised full plan. + +Rules: +- Return only valid JSON matching the schema. +- Do not write SQL. +- Use only relations, columns, and join paths present in the schema/context summary. +- Use the failure summary to avoid repeating the same mistake. +- If the question is only partially answerable, set `can_answer_fully` to false and explain the gap in `unsupported_requirements`. +- If the available data can no longer support any safe next step, return zero steps. + +Return JSON in the same shape as the initial planner response. + +User question: +{{ query }} + +Current plan: +{{ current_plan_json }} + +Failure summary: +{{ failure_summary }} + +Compact schema/context summary: +{{ schema_context_json }} diff --git a/app/prompts/query_writer.j2 b/app/prompts/query_writer.j2 new file mode 100644 index 0000000..affb036 --- /dev/null +++ b/app/prompts/query_writer.j2 @@ -0,0 +1,47 @@ +You are the query writer in a bounded analytics workflow. 
+ +You receive: +- the user's original question +- one current plan step +- the compact schema/context summary +- prior successful step outputs only when the current step depends on them +- optional error context if this is a retry + +Your job: +- generate exactly one DuckDB SQL query for the current step +- include a short explanation + +Rules: +- Return only valid JSON matching the schema. +- Write exactly one SQL statement in `sql`. +- Stay aligned to the current step only. +- Do not replan the workflow. +- Do not answer the user's question. +- Use only relations and exact columns present in the schema/context summary. +- If the step depends on prior outputs, use only the provided output aliases from `prior_outputs_json`. +- Do not invent columns, joins, or filters. +- Do not generate multiple query options. + +Return JSON in this shape: +{ + "step_id": 1, + "sql": "SELECT ...", + "explanation": "string" +} + +Original question: +{{ query }} + +Current step: +{{ step_json }} + +Compact schema/context summary: +{{ schema_context_json }} + +Prior successful step outputs: +{{ prior_outputs_json }} +{% if error_context %} + +Execution error context for retry: +{{ error_context }} +{% endif %} diff --git a/app/schemas.py b/app/schemas.py index ba30286..597814f 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -152,6 +152,132 @@ class SchemaManifest(BaseModel): views: list[dict[str, Any]] = Field(default_factory=list) +class CompactSchemaColumn(BaseModel): + """Compact field summary shared with the planner, query writer, and analyzer.""" + + model_config = ConfigDict(extra="forbid") + + name: str = Field(..., min_length=1) + dtype: str = Field(..., min_length=1) + type_family: Literal["string", "number", "boolean", "datetime", "unknown"] = "unknown" + nullable: bool = True + semantic_hints: list[str] = Field(default_factory=list) + + +class CompactSchemaRelation(BaseModel): + """One compact relation summary used as workflow source-of-truth.""" + + model_config = 
ConfigDict(extra="forbid") + + name: str = Field(..., min_length=1) + source_id: str = "" + source_name: str = "" + is_primary: bool = False + row_count: int = 0 + grain: str = "" + identifier_columns: list[str] = Field(default_factory=list) + time_columns: list[str] = Field(default_factory=list) + measure_columns: list[str] = Field(default_factory=list) + dimension_columns: list[str] = Field(default_factory=list) + join_keys: list[SchemaJoinKey] = Field(default_factory=list) + semantic_mappings: list[SchemaConceptMapping] = Field(default_factory=list) + columns: list[CompactSchemaColumn] = Field(default_factory=list) + + +class CompactSchemaContext(BaseModel): + """Compact schema/context summary passed into the multi-agent workflow.""" + + model_config = ConfigDict(extra="forbid") + + reference_date: str = "" + source: str = "" + dialect: str = "" + relations: list[CompactSchemaRelation] = Field(default_factory=list) + + +class UnsupportedRequirement(BaseModel): + """One reason the question may be only partially answerable with the available schema.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["column", "relationship", "concept", "time_range", "filter", "metric", "other"] = "other" + description: str = Field(..., min_length=1) + relation: str | None = None + column: str | None = None + + +class AnalysisPlanStep(BaseModel): + """One planner-authored step without executable query text.""" + + model_config = ConfigDict(extra="forbid") + + id: int + purpose: str = Field(..., min_length=1) + depends_on: list[int] = Field(default_factory=list) + output_alias: str = Field(..., min_length=1) + relations: list[str] = Field(default_factory=list) + required_columns: list[str] = Field(default_factory=list) + expected_output: str = Field(..., min_length=1) + allow_empty_result: bool = False + + +class AnalysisPlan(BaseModel): + """Full ordered workflow plan created by the planner in one shot.""" + + model_config = ConfigDict(extra="forbid") + + objective: 
str = Field(..., min_length=1) + can_answer_fully: bool = True + unsupported_requirements: list[UnsupportedRequirement] = Field(default_factory=list) + steps: list[AnalysisPlanStep] = Field(default_factory=list, max_length=3) + max_steps: int + metric: str = "" + metric_direction: str = "" + + @field_validator("max_steps", mode="before") + @classmethod + def normalize_max_steps(cls, v: Any) -> int: + """The workflow keeps a fixed ceiling of three executable steps.""" + + return 3 + + +class GeneratedQuery(BaseModel): + """One SQL query emitted by the query writer for a single workflow step.""" + + model_config = ConfigDict(extra="forbid") + + step_id: int + sql: str = Field(..., min_length=1) + explanation: str = Field(..., min_length=1) + + +class StepFailureRecord(BaseModel): + """Execution failure details kept in workflow state across retries and replans.""" + + model_config = ConfigDict(extra="forbid") + + step_id: int + attempt: int + error: str = Field(..., min_length=1) + query: str = "" + details: dict[str, Any] = Field(default_factory=dict) + + +class AnalyzerDecision(BaseModel): + """Final analyzer output: answer, replan decision, or best-effort summary.""" + + model_config = ConfigDict(extra="forbid") + + decision: Literal["final_answer", "replan"] + summary: str = Field(..., min_length=1) + key_findings: list[str] = Field(default_factory=list) + important_metrics: list[EvidenceValue] = Field(default_factory=list) + caveats: list[str] = Field(default_factory=list) + final_answer: str = "" + failure_summary: str = "" + + class EvidenceValue(BaseModel): """One exact label/value pair carried into the analysis evidence packet.""" diff --git a/data/source_registry.duckdb b/data/source_registry.duckdb deleted file mode 100644 index f9fc572..0000000 Binary files a/data/source_registry.duckdb and /dev/null differ diff --git a/tests/test_agent_contracts.py b/tests/test_agent_contracts.py new file mode 100644 index 0000000..55d1eb5 --- /dev/null +++ 
b/tests/test_agent_contracts.py @@ -0,0 +1,90 @@ +"""Contract and scaffolding tests for the staged agent runtime restoration.""" + +from __future__ import annotations + +from app.agent.state import create_initial_state +from app.api.workspace import STEP_LABELS +from app.schemas import AnalysisPlan, AnalyzerDecision, GeneratedQuery + + +def test_analysis_plan_normalizes_max_steps() -> None: + plan = AnalysisPlan.model_validate( + { + "objective": "Answer the question with available uploads.", + "can_answer_fully": True, + "unsupported_requirements": [], + "steps": [ + { + "id": 1, + "purpose": "Compute a grouped summary.", + "depends_on": [], + "output_alias": "grouped_summary", + "relations": ["orders_source_1234"], + "required_columns": ["orders_source_1234.amount"], + "expected_output": "A grouped table for downstream use.", + "allow_empty_result": False, + } + ], + "max_steps": 1, + "metric": "amount", + "metric_direction": "higher_is_better", + } + ) + + assert plan.max_steps == 3 + + +def test_generated_query_requires_one_sql_string() -> None: + query = GeneratedQuery.model_validate( + { + "step_id": 2, + "sql": "SELECT 1 AS value", + "explanation": "Returns a placeholder row for the current step.", + } + ) + + assert query.step_id == 2 + assert query.sql == "SELECT 1 AS value" + + +def test_analyzer_decision_supports_replan_shape() -> None: + decision = AnalyzerDecision.model_validate( + { + "decision": "replan", + "summary": "The current outputs only answer part of the question.", + "key_findings": [], + "important_metrics": [], + "caveats": ["A required relationship was not available."], + "final_answer": "", + "failure_summary": "Required relationship not present in schema/context.", + } + ) + + assert decision.decision == "replan" + assert decision.failure_summary != "" + + +def test_create_initial_state_exposes_new_workflow_fields() -> None: + state = create_initial_state("How many orders are overdue?", source_ids=["source_123"]) + + assert 
state["query"] == "How many orders are overdue?" + assert state["source_ids"] == ["source_123"] + assert state["schema_context_summary"] == {} + assert state["current_plan"] is None + assert state["stored_outputs"] == {} + assert state["step_queries"] == {} + assert state["failure_history"] == {} + assert state["retry_counts"] == {} + assert state["replan_count"] == 0 + assert state["trace"] == [] + assert state["executed_steps"] == [] + assert state["errors"] == [] + + +def test_workspace_step_labels_include_new_workflow_nodes() -> None: + assert STEP_LABELS["planner_node"] == "Workflow Planning" + assert STEP_LABELS["query_writer_node"] == "Query Writing" + assert STEP_LABELS["executor_node"] == "Step Execution" + assert STEP_LABELS["analyzer_node"] == "Final Analysis" + assert STEP_LABELS["replan_node"] == "Replanning" + assert STEP_LABELS["best_effort_node"] == "Best Effort Answer" diff --git a/tests/test_analysis_grounding.py b/tests/test_analysis_grounding.py deleted file mode 100644 index 7135bea..0000000 --- a/tests/test_analysis_grounding.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Tests for deterministic analysis evidence, approved claims, and rendering validation.""" - -from __future__ import annotations - -import pytest - -from app.agent.analysis_grounding import build_analysis_evidence, build_approved_claims, validate_rendered_analysis -from app.agent.state import create_initial_state -from app.schemas import AnalysisRenderResponse, ApprovedClaim, EvidenceValue - - -def test_build_approved_claims_marks_contradicted_premise() -> None: - state = create_initial_state("Why did pipeline velocity drop this week?") - state["metric"] = "avg_pipeline_velocity_days" - state["compiled_plan"] = {"metric_direction": "lower_is_better"} - state["executed_steps"] = [ - { - "id": "1", - "purpose": "Compare weekly metrics", - "status": "success", - "output_alias": "weekly_pipeline_metrics", - "artifact": { - "alias": "weekly_pipeline_metrics", - "columns": ["period", 
"avg_pipeline_velocity_days"], - "preview_rows": [ - {"period": "Previous Week", "avg_pipeline_velocity_days": 69.9423076923077}, - {"period": "Current Week", "avg_pipeline_velocity_days": 64.13291139240506}, - ], - }, - } - ] - - evidence = build_analysis_evidence(state) - claims, status = build_approved_claims(evidence) - - assert status == "contradicted_premise" - assert claims[0].kind == "premise_check" - assert "does not support a deterioration premise" in claims[0].statement - - -def test_validate_rendered_analysis_rejects_changed_entity_name() -> None: - claim = ApprovedClaim( - id="claim_manager", - kind="row_observation", - statement="For Celia Rouche, current_week_avg_velocity = 64.65.", - entities=["Celia Rouche"], - metrics=["current_week_avg_velocity"], - source_aliases=["regional_manager_velocity"], - values=[EvidenceValue(label="current_week_avg_velocity", value="64.65")], - ) - response = AnalysisRenderResponse( - answer_status="answered", - analysis_markdown="## Summary\nFor eCelia Rouche, current_week_avg_velocity = 64.65.", - used_claim_ids=["claim_manager"], - ) - - with pytest.raises(ValueError, match="changed an approved entity name"): - validate_rendered_analysis(response, [claim], expected_status="answered") - - -def test_validate_rendered_analysis_rejects_unapproved_numbers() -> None: - claim = ApprovedClaim( - id="claim_manager", - kind="row_observation", - statement="For Celia Rouche, current_week_avg_velocity = 64.65.", - entities=["Celia Rouche"], - metrics=["current_week_avg_velocity"], - source_aliases=["regional_manager_velocity"], - values=[EvidenceValue(label="current_week_avg_velocity", value="64.65")], - ) - response = AnalysisRenderResponse( - answer_status="answered", - analysis_markdown="## Summary\nFor Celia Rouche, current_week_avg_velocity = 60.", - used_claim_ids=["claim_manager"], - ) - - with pytest.raises(ValueError, match="introduced numbers not present"): - validate_rendered_analysis(response, [claim], 
expected_status="answered") diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py new file mode 100644 index 0000000..79e745e --- /dev/null +++ b/tests/test_analyzer.py @@ -0,0 +1,95 @@ +"""Analyzer behavior tests for the final workflow interface.""" + +from __future__ import annotations + +from app.agent.analysis import analyze_workflow +from app.agent.state import create_initial_state + + +def test_analyzer_returns_final_answer_when_llm_succeeds(monkeypatch) -> None: + class FakeAnalyzerLLM: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + return { + "decision": "final_answer", + "summary": "The workflow produced a grounded answer.", + "key_findings": ["The final output contains 2 rows."], + "important_metrics": [{"label": "row_count", "value": "2"}], + "caveats": [], + "final_answer": "## Summary\nThe final output contains 2 rows.", + "failure_summary": "", + } + + monkeypatch.setattr("app.agent.analysis.get_llm_client", lambda: FakeAnalyzerLLM()) + + state = create_initial_state("What happened?") + state["workflow_status"] = "ready_for_analysis" + state["schema_context_summary"] = {"reference_date": "", "relations": []} + state["current_plan"] = {"objective": "Test", "steps": []} + state["executed_steps"] = [ + { + "id": "1", + "status": "success", + "purpose": "Summarize rows", + "output_alias": "final_output", + "artifact": {"alias": "final_output", "row_count": 2, "columns": ["value"], "preview_rows": [{"value": 1}], "summary": {}}, + } + ] + + state = analyze_workflow(state) + + assert state["workflow_status"] == "complete" + assert "2 rows" in state["analysis"] + + +def test_analyzer_requests_replan_when_outputs_are_incomplete(monkeypatch) -> None: + class FakeAnalyzerLLM: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + return { + "decision": "replan", + "summary": "The outputs do not answer the question yet.", + "key_findings": [], + "important_metrics": [], + "caveats": [], + 
"final_answer": "", + "failure_summary": "Final results answered only part of the user question.", + } + + monkeypatch.setattr("app.agent.analysis.get_llm_client", lambda: FakeAnalyzerLLM()) + + state = create_initial_state("What happened?") + state["workflow_status"] = "ready_for_analysis" + state["schema_context_summary"] = {"reference_date": "", "relations": []} + state["current_plan"] = {"objective": "Test", "steps": []} + + state = analyze_workflow(state) + + assert state["workflow_status"] == "needs_replan" + assert state["failure_summary"] == "Final results answered only part of the user question." + + +def test_analyzer_returns_best_effort_after_limits(monkeypatch) -> None: + class FakeAnalyzerLLM: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + return { + "decision": "final_answer", + "summary": "Returning the best-effort answer.", + "key_findings": [], + "important_metrics": [], + "caveats": ["Requested relationship was unavailable."], + "final_answer": "## Best-effort answer\nAnswered parts:\n- Captured 1 row.\n\nCould not answer completely:\n- Requested relationship was unavailable.", + "failure_summary": "", + } + + monkeypatch.setattr("app.agent.analysis.get_llm_client", lambda: FakeAnalyzerLLM()) + + state = create_initial_state("What happened?") + state["workflow_status"] = "best_effort_ready" + state["schema_context_summary"] = {"reference_date": "", "relations": []} + state["current_plan"] = {"objective": "Test", "steps": []} + state["final_answer"] = "## Best-effort answer\nAnswered parts:\n- Captured 1 row." + state["failure_summary"] = "Requested relationship was unavailable." 
+ + state = analyze_workflow(state) + + assert state["workflow_status"] == "complete" + assert "## Best-effort answer" in state["analysis"] diff --git a/tests/test_executor_graph.py b/tests/test_executor_graph.py new file mode 100644 index 0000000..94e61ae --- /dev/null +++ b/tests/test_executor_graph.py @@ -0,0 +1,209 @@ +"""Execution and orchestration tests for the staged workflow control flow.""" + +from __future__ import annotations + +import pytest + +from app.agent.graph import run_analysis +from app.config import get_settings +from app.data.registry import clear_source_registry, ingest_source +from app.data.semantic_model import clear_semantic_context_cache + + +@pytest.fixture(autouse=True) +def isolated_registry(tmp_path, monkeypatch): + monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "executor_graph_registry.duckdb")) + get_settings.cache_clear() + clear_source_registry() + clear_semantic_context_cache() + yield + clear_source_registry() + clear_semantic_context_cache() + get_settings.cache_clear() + + +def test_run_analysis_retries_one_failed_step_then_succeeds(monkeypatch) -> None: + asset = ingest_source("pipeline.csv", b"owner,pipeline_velocity_days\nAda,10\nBen,14\n") + + class PlannerStub: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + return { + "objective": "Summarize velocity by owner.", + "can_answer_fully": True, + "unsupported_requirements": [], + "steps": [ + { + "id": 1, + "purpose": "Summarize velocity by owner.", + "depends_on": [], + "output_alias": "velocity_by_owner", + "relations": [asset.primaryRelationName], + "required_columns": [ + f"{asset.primaryRelationName}.owner", + f"{asset.primaryRelationName}.pipeline_velocity_days", + ], + "expected_output": "A grouped owner summary.", + "allow_empty_result": False, + } + ], + "max_steps": 3, + "metric": "pipeline_velocity_days", + "metric_direction": "lower_is_better", + } + + class QueryWriterStub: + def __init__(self) -> None: + self.calls = 0 + + def 
generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + self.calls += 1 + if self.calls == 1: + return { + "step_id": 1, + "sql": f"SELECT sales_agent, AVG(pipeline_velocity_days) AS avg_velocity_days FROM {asset.primaryRelationName} GROUP BY sales_agent", + "explanation": "First attempt uses the wrong column name.", + } + return { + "step_id": 1, + "sql": f"SELECT owner, AVG(pipeline_velocity_days) AS avg_velocity_days FROM {asset.primaryRelationName} GROUP BY owner", + "explanation": "Retry uses the available owner column.", + } + + query_writer_stub = QueryWriterStub() + monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: PlannerStub()) + monkeypatch.setattr("app.agent.query_writer.get_llm_client", lambda: query_writer_stub) + + state = run_analysis("Which owners are slowest?", source_ids=[asset.id]) + + assert query_writer_stub.calls == 2 + assert state["retry_counts"]["1"] == 1 + assert len(state["failure_history"]["1"]) == 1 + assert state["stored_outputs"]["velocity_by_owner"] is not None + assert state["workflow_status"] == "complete" + + +def test_run_analysis_replans_once_after_second_failure(monkeypatch) -> None: + asset = ingest_source("pipeline.csv", b"owner,pipeline_velocity_days\nAda,10\nBen,14\n") + + class PlannerStub: + def __init__(self) -> None: + self.calls = 0 + + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + self.calls += 1 + if self.calls == 1: + return { + "objective": "Try an owner summary.", + "can_answer_fully": True, + "unsupported_requirements": [], + "steps": [ + { + "id": 1, + "purpose": "Broken owner summary.", + "depends_on": [], + "output_alias": "broken_summary", + "relations": [asset.primaryRelationName], + "required_columns": [f"{asset.primaryRelationName}.owner"], + "expected_output": "A broken output.", + "allow_empty_result": False, + } + ], + "max_steps": 3, + "metric": "", + "metric_direction": "", + } + return { + "objective": "Fallback to a simple row count.", + 
"can_answer_fully": True, + "unsupported_requirements": [], + "steps": [ + { + "id": 1, + "purpose": "Count available rows.", + "depends_on": [], + "output_alias": "row_count", + "relations": [asset.primaryRelationName], + "required_columns": [f"{asset.primaryRelationName}.record_id"], + "expected_output": "A row count for best-effort answering.", + "allow_empty_result": False, + } + ], + "max_steps": 3, + "metric": "", + "metric_direction": "", + } + + class QueryWriterStub: + def __init__(self) -> None: + self.calls = 0 + + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + self.calls += 1 + if self.calls <= 2: + return { + "step_id": 1, + "sql": f"SELECT sales_agent FROM {asset.primaryRelationName}", + "explanation": "This query will fail twice.", + } + return { + "step_id": 1, + "sql": f"SELECT COUNT(*) AS total_rows FROM {asset.primaryRelationName}", + "explanation": "The replanned query succeeds.", + } + + planner_stub = PlannerStub() + query_writer_stub = QueryWriterStub() + monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: planner_stub) + monkeypatch.setattr("app.agent.query_writer.get_llm_client", lambda: query_writer_stub) + + state = run_analysis("How much data is available?", source_ids=[asset.id]) + + assert planner_stub.calls == 2 + assert state["replan_count"] == 1 + assert state["stored_outputs"]["row_count"] is not None + assert state["workflow_status"] == "complete" + + +def test_run_analysis_returns_best_effort_after_replan_limit(monkeypatch) -> None: + asset = ingest_source("pipeline.csv", b"owner,pipeline_velocity_days\nAda,10\nBen,14\n") + + class PlannerStub: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + return { + "objective": "Keep trying the same unsupported summary.", + "can_answer_fully": True, + "unsupported_requirements": [], + "steps": [ + { + "id": 1, + "purpose": "Broken summary.", + "depends_on": [], + "output_alias": "broken_summary", + "relations": 
[asset.primaryRelationName], + "required_columns": [f"{asset.primaryRelationName}.owner"], + "expected_output": "A summary that keeps failing.", + "allow_empty_result": False, + } + ], + "max_steps": 3, + "metric": "", + "metric_direction": "", + } + + class QueryWriterStub: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + return { + "step_id": 1, + "sql": f"SELECT sales_agent FROM {asset.primaryRelationName}", + "explanation": "Always fails because sales_agent is unavailable.", + } + + monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: PlannerStub()) + monkeypatch.setattr("app.agent.query_writer.get_llm_client", lambda: QueryWriterStub()) + + state = run_analysis("Which owners are slowest?", source_ids=[asset.id]) + + assert state["replan_count"] == 1 + assert state["workflow_status"] == "complete" + assert "## Best-effort answer" in state["analysis"] + assert state["failure_summary"] != "" diff --git a/tests/test_intent.py b/tests/test_intent.py index db126b0..f078661 100644 --- a/tests/test_intent.py +++ b/tests/test_intent.py @@ -1,59 +1,103 @@ -"""Planner and analysis contract tests with mocked LLM responses.""" +"""Planner and analyzer contract tests with mocked LLM responses.""" + +from __future__ import annotations from app.agent.analysis import run_analysis_narrative -from app.agent.planner import plan_compiled_query +from app.agent.planner import plan_analysis from app.agent.state import create_initial_state class FakeLLM: - """Minimal stub for planner and analysis tests.""" + """Minimal stub for planner and analyzer tests.""" def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 - if '"max_steps": 3' in prompt and "metric_direction" in prompt: + if '"max_steps": 3' in prompt and '"steps"' in prompt: return { - "objective": "Compare current and previous pipeline velocity.", - "plan": [ + "objective": "Compare current and previous pipeline velocity with one bounded step.", + "can_answer_fully": 
True, + "unsupported_requirements": [], + "steps": [ { "id": 1, - "purpose": "Compare current and previous pipeline velocity.", - "type": "sql", - "query": "SELECT 1 AS value", + "purpose": "Summarize the metric by period.", + "depends_on": [], "output_alias": "comparison_result", + "relations": ["opportunities_enriched"], + "required_columns": ["opportunities_enriched.pipeline_velocity_days"], + "expected_output": "A table with the grouped metric.", + "allow_empty_result": False, } ], "max_steps": 3, - "metric": "pipeline_velocity", + "metric": "pipeline_velocity_days", "metric_direction": "lower_is_better", } - if '"analysis_markdown": string' in prompt and "approved claims" in prompt.lower(): + if '"decision": "final_answer" | "replan"' in prompt: return { - "answer_status": "answered", - "analysis_markdown": "## Summary\nThe available evidence shows value = 1 for SMB.", - "used_claim_ids": ["claim_comparison_result_row_1"], + "decision": "final_answer", + "summary": "The workflow produced a usable final output.", + "key_findings": ["1 row was returned by the final step."], + "important_metrics": [{"label": "row_count", "value": "1"}], + "caveats": [], + "final_answer": "## Summary\nThe final step returned 1 row in comparison_result.", + "failure_summary": "", } return { - "answer_status": "insufficient_evidence", - "analysis_markdown": "The approved claims are insufficient to answer the question.", - "used_claim_ids": [], + "decision": "replan", + "summary": "The current outputs are incomplete.", + "key_findings": [], + "important_metrics": [], + "caveats": [], + "final_answer": "", + "failure_summary": "The collected outputs were incomplete.", } -def test_planner_returns_compiled_plan(monkeypatch) -> None: +def test_planner_returns_full_non_sql_plan(monkeypatch) -> None: monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: FakeLLM()) state = create_initial_state("Why did pipeline velocity drop this week?") - state["dataset_context"] = 
{"reference_date": "2017-12-31", "views": [{"name": "opportunities_enriched"}]} - state = plan_compiled_query(state) - assert state["compiled_plan"] is not None - assert state["compiled_plan"]["plan"][0]["type"] == "sql" - assert state["compiled_plan"]["plan"][0]["query"] == "SELECT 1 AS value" - assert state["compiled_plan"]["plan"][0]["output_alias"] == "comparison_result" + state["dataset_context"] = { + "reference_date": "2017-12-31", + "source": "source_registry", + "dialect": "duckdb", + "relations": [ + { + "name": "opportunities_enriched", + "source_id": "source_1", + "source_name": "opportunities.csv", + "is_primary": True, + "row_count": 10, + "grain": "One row per opportunity", + "identifier_columns": ["record_id"], + "time_columns": [], + "measure_columns": ["pipeline_velocity_days"], + "dimension_columns": [], + "join_keys": [], + "semantic_mappings": [], + "columns": [ + { + "name": "pipeline_velocity_days", + "dtype": "DOUBLE", + "type_family": "number", + "nullable": True, + "semantic_hints": ["pipeline velocity"], + } + ], + } + ], + } + state = plan_analysis(state) + assert state["current_plan"] is not None + assert state["current_plan"]["steps"][0]["output_alias"] == "comparison_result" + assert "query" not in state["current_plan"]["steps"][0] def test_analysis_narrative_uses_llm(monkeypatch) -> None: monkeypatch.setattr("app.agent.analysis.get_llm_client", lambda: FakeLLM()) state = create_initial_state("Why did pipeline velocity drop this week?") - state["dataset_context"] = {"reference_date": "2017-12-31", "views": []} - state["compiled_plan"] = {"objective": "Test", "metric": "", "metric_direction": ""} + state["schema_context_summary"] = {"reference_date": "2017-12-31", "relations": []} + state["current_plan"] = {"objective": "Test", "metric": "", "metric_direction": "", "steps": []} + state["workflow_status"] = "ready_for_analysis" state["executed_steps"] = [ { "id": "step_1", @@ -70,5 +114,5 @@ def 
test_analysis_narrative_uses_llm(monkeypatch) -> None: } ] state = run_analysis_narrative(state) - assert "value = 1" in state["analysis"] - assert "SMB" in state["analysis"] + assert "comparison_result" in state["analysis"] + assert state["workflow_status"] == "complete" diff --git a/tests/test_planner_query_writer.py b/tests/test_planner_query_writer.py new file mode 100644 index 0000000..85e4e12 --- /dev/null +++ b/tests/test_planner_query_writer.py @@ -0,0 +1,197 @@ +"""Planner and query-writer behavior tests for the staged workflow implementation.""" + +from __future__ import annotations + +from app.agent.planner import _schema_subset_for_question, plan_analysis, replan_analysis +from app.agent.query_writer import write_step_query +from app.agent.state import create_initial_state +from app.data.registry import clear_source_registry, ingest_source +from app.data.semantic_model import clear_semantic_context_cache, get_semantic_context + + +def test_schema_subset_prefers_relevant_uploaded_relation(monkeypatch, tmp_path) -> None: + monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "planner_subset.duckdb")) + clear_source_registry() + clear_semantic_context_cache() + try: + ingest_source("inventory.csv", b"sku,stock\nA1,10\nB2,3\n") + invoices = ingest_source("invoices.csv", b"invoice_id,invoice_amount,status\ni1,100,paid\ni2,250,due\n") + ingest_source("campaigns.csv", b"campaign,spend\nspring,1200\nsummer,950\n") + + manifest = get_semantic_context().schema_manifest + subset = _schema_subset_for_question(manifest, "Which invoice amount is highest?") + finally: + clear_source_registry() + clear_semantic_context_cache() + + assert any(relation["name"] == invoices.primaryRelationName for relation in subset["relations"]) + + +def test_plan_analysis_stores_full_non_sql_plan(monkeypatch) -> None: + class FakePlannerLLM: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + assert "Do not write SQL" in prompt + return { + "objective": "Answer 
the question with two bounded steps.", + "can_answer_fully": True, + "unsupported_requirements": [], + "steps": [ + { + "id": 1, + "purpose": "Summarize velocity by owner.", + "depends_on": [], + "output_alias": "velocity_by_owner", + "relations": ["pipeline_source_1234"], + "required_columns": ["pipeline_source_1234.owner", "pipeline_source_1234.pipeline_velocity_days"], + "expected_output": "A grouped table by owner.", + "allow_empty_result": False, + }, + { + "id": 2, + "purpose": "Identify the slowest owners from the summary.", + "depends_on": [1], + "output_alias": "slowest_owners", + "relations": ["velocity_by_owner"], + "required_columns": ["velocity_by_owner.owner", "velocity_by_owner.avg_velocity_days"], + "expected_output": "A ranked subset for the final answer.", + "allow_empty_result": False, + }, + ], + "max_steps": 3, + "metric": "pipeline_velocity_days", + "metric_direction": "lower_is_better", + } + + monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: FakePlannerLLM()) + + state = create_initial_state("Why is velocity worse for some owners?") + state["dataset_context"] = { + "reference_date": "2026-04-10", + "source": "source_registry", + "dialect": "duckdb", + "relations": [ + { + "name": "pipeline_source_1234", + "source_id": "source_1234", + "source_name": "pipeline.csv", + "is_primary": True, + "row_count": 20, + "grain": "One row per opportunity", + "identifier_columns": ["record_id"], + "time_columns": [], + "measure_columns": ["pipeline_velocity_days"], + "dimension_columns": ["owner"], + "join_keys": [], + "semantic_mappings": [], + "columns": [ + {"name": "owner", "dtype": "VARCHAR", "type_family": "string", "nullable": False, "semantic_hints": ["owner"]}, + { + "name": "pipeline_velocity_days", + "dtype": "DOUBLE", + "type_family": "number", + "nullable": True, + "semantic_hints": ["pipeline velocity"], + }, + ], + } + ], + } + + state = plan_analysis(state) + + assert state["workflow_status"] == "planned" + assert 
state["current_plan"] is not None + assert state["current_plan"]["steps"][0]["output_alias"] == "velocity_by_owner" + assert "sql" not in state["current_plan"]["steps"][0] + + +def test_plan_analysis_preserves_unsupported_requirements(monkeypatch) -> None: + class FakePlannerLLM: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + return { + "objective": "Answer the supported part of the question.", + "can_answer_fully": False, + "unsupported_requirements": [ + { + "type": "relationship", + "description": "A relationship between customers and invoices is not available.", + "relation": "customers_source_1234", + } + ], + "steps": [], + "max_steps": 3, + "metric": "", + "metric_direction": "", + } + + monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: FakePlannerLLM()) + + state = create_initial_state("Which customers have the highest invoice amount?") + state["dataset_context"] = {"reference_date": "", "source": "", "dialect": "duckdb", "relations": []} + state = plan_analysis(state) + + assert state["current_plan"]["can_answer_fully"] is False + assert state["current_plan"]["unsupported_requirements"][0]["type"] == "relationship" + + +def test_query_writer_emits_one_query_for_current_step(monkeypatch) -> None: + class FakeQueryWriterLLM: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + assert "exactly one DuckDB SQL query" in prompt + return { + "step_id": 1, + "sql": "SELECT owner, AVG(pipeline_velocity_days) AS avg_velocity_days FROM pipeline_source_1234 GROUP BY owner", + "explanation": "Aggregates velocity by owner for the first step.", + } + + monkeypatch.setattr("app.agent.query_writer.get_llm_client", lambda: FakeQueryWriterLLM()) + + state = create_initial_state("Why is velocity worse for some owners?") + state["schema_context_summary"] = {"reference_date": "", "source": "", "dialect": "duckdb", "relations": []} + state["current_plan"] = { + "steps": [ + { + "id": 1, + "purpose": "Summarize 
velocity by owner.", + "depends_on": [], + "output_alias": "velocity_by_owner", + "relations": ["pipeline_source_1234"], + "required_columns": ["pipeline_source_1234.owner", "pipeline_source_1234.pipeline_velocity_days"], + "expected_output": "A grouped table by owner.", + "allow_empty_result": False, + } + ] + } + + state = write_step_query(state) + + assert state["generated_query"]["sql"].startswith("SELECT owner") + assert state["step_queries"]["1"] == [state["generated_query"]["sql"]] + + +def test_replan_prompt_includes_failure_summary(monkeypatch) -> None: + captured: dict[str, str] = {} + + class FakePlannerLLM: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + captured["prompt"] = prompt + return { + "objective": "Try a safer fallback plan.", + "can_answer_fully": False, + "unsupported_requirements": [], + "steps": [], + "max_steps": 3, + "metric": "", + "metric_direction": "", + } + + monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: FakePlannerLLM()) + + state = create_initial_state("Why is velocity worse for some owners?") + state["dataset_context"] = {"reference_date": "", "source": "", "dialect": "duckdb", "relations": []} + state["current_plan"] = {"objective": "Old plan", "steps": []} + state["failure_summary"] = "Required column not present in schema/context." + + replan_analysis(state) + + assert "Required column not present in schema/context." 
in captured["prompt"] diff --git a/tests/test_planner_schema.py b/tests/test_planner_schema.py index 3659289..18cfff5 100644 --- a/tests/test_planner_schema.py +++ b/tests/test_planner_schema.py @@ -1,13 +1,15 @@ -"""Compiled plan schema and planner retry behavior.""" +"""Planner schema and compact context behavior.""" + +from __future__ import annotations import pytest -from app.agent.planner import plan_compiled_query +from app.agent.planner import build_compact_schema_context, plan_analysis from app.agent.state import create_initial_state from app.config import get_settings from app.data.registry import clear_source_registry, ingest_source from app.data.semantic_model import clear_semantic_context_cache, get_semantic_context -from app.schemas import CompiledPlan +from app.schemas import AnalysisPlan @pytest.fixture(autouse=True) @@ -22,16 +24,22 @@ def isolated_registry(tmp_path, monkeypatch): get_settings.cache_clear() -def test_compiled_plan_normalizes_max_steps() -> None: - plan = CompiledPlan.model_validate( +def test_analysis_plan_normalizes_max_steps() -> None: + plan = AnalysisPlan.model_validate( { - "objective": "Test", - "plan": [ + "objective": "Test the planner contract", + "can_answer_fully": True, + "unsupported_requirements": [], + "steps": [ { "id": 1, "purpose": "One step", - "type": "sql", - "query": "SELECT 1", + "depends_on": [], + "output_alias": "sample_output", + "relations": ["sample_relation"], + "required_columns": ["sample_relation.value"], + "expected_output": "One simple result.", + "allow_empty_result": False, } ], "max_steps": 1, @@ -42,50 +50,38 @@ def test_compiled_plan_normalizes_max_steps() -> None: assert plan.max_steps == 3 -def test_planner_retries_after_validation_error(monkeypatch) -> None: - good = { - "objective": "Segment counts", - "plan": [ - { - "id": 1, - "purpose": "Count by segment", - "type": "sql", - "query": "SELECT 1 AS value", - "output_alias": "counts", - } - ], - "max_steps": 3, - "metric": "", - 
"metric_direction": "",
-    }
-    bad = {
-        "objective": "Too many",
-        "plan": good["plan"] * 4,
-        "max_steps": 3,
-        "metric": "",
-        "metric_direction": "",
-    }
-
-    class FlakyPlannerLLM:
-        def __init__(self) -> None:
-            self.calls = 0
-
+def test_plan_analysis_stores_current_plan(monkeypatch) -> None:
+    class PlannerLLM:
         def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002
-            self.calls += 1
-            if self.calls == 1:
-                return bad
-            return good
+            return {
+                "objective": "Count by segment",
+                "can_answer_fully": True,
+                "unsupported_requirements": [],
+                "steps": [
+                    {
+                        "id": 1,
+                        "purpose": "Count by segment",
+                        "depends_on": [],
+                        "output_alias": "counts",
+                        "relations": ["segments_source_1234"],
+                        "required_columns": ["segments_source_1234.segment"],
+                        "expected_output": "A grouped count table.",
+                        "allow_empty_result": False,
+                    }
+                ],
+                "max_steps": 3,
+                "metric": "",
+                "metric_direction": "",
+            }
 
-    stub = FlakyPlannerLLM()
-    monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: stub)
+    monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: PlannerLLM())
 
     state = create_initial_state("Compare segments")
-    state["dataset_context"] = {"reference_date": "2017-12-31", "relations": [], "views": []}
-    state = plan_compiled_query(state)
+    state["dataset_context"] = {"reference_date": "2017-12-31", "source": "", "dialect": "duckdb", "relations": []}
+    state = plan_analysis(state)
 
-    assert stub.calls == 2
-    assert state["compiled_plan"] is not None
-    assert len(state["compiled_plan"]["plan"]) == 1
+    assert state["current_plan"] is not None
+    assert len(state["current_plan"]["steps"]) == 1
 
 
 def test_semantic_context_exposes_normalized_manifest() -> None:
@@ -100,60 +96,13 @@ def test_semantic_context_exposes_normalized_manifest() -> None:
     assert relation["grain"]
 
 
-def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None:
+def test_compact_schema_context_exposes_relevant_relation() -> None:
     asset = ingest_source("pipeline.csv", b"owner,pipeline_velocity_days\nAda,10\nBen,14\n")
-    relation_name = asset.primaryRelationName
-    bad = {
-        "objective": "Analyze by sales agent",
-        "plan": [
-            {
-                "id": 1,
-                "purpose": "Break out velocity by agent",
-                "type": "sql",
-                "query": f"SELECT sales_agent, AVG(pipeline_velocity_days) AS avg_velocity_days FROM {relation_name} GROUP BY sales_agent",
-                "output_alias": "velocity_by_agent",
-            }
-        ],
-        "max_steps": 3,
-        "metric": "pipeline_velocity_days",
-        "metric_direction": "lower_is_better",
-    }
-    good = {
-        "objective": "Analyze by owner",
-        "plan": [
-            {
-                "id": 1,
-                "purpose": "Break out velocity by owner",
-                "type": "sql",
-                "query": f"SELECT owner, AVG(pipeline_velocity_days) AS avg_velocity_days FROM {relation_name} GROUP BY owner",
-                "output_alias": "velocity_by_owner",
-            }
-        ],
-        "max_steps": 3,
-        "metric": "pipeline_velocity_days",
-        "metric_direction": "lower_is_better",
-    }
-
-    class FlakyPlannerLLM:
-        def __init__(self) -> None:
-            self.calls = 0
-            self.prompts: list[str] = []
+    compact = build_compact_schema_context(
+        get_semantic_context([asset.id]).schema_manifest,
+        "Why did pipeline velocity change by owner?",
+    )
 
-        def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002
-            self.calls += 1
-            self.prompts.append(prompt)
-            if self.calls == 1:
-                return bad
-            return good
-
-    stub = FlakyPlannerLLM()
-    monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: stub)
-
-    state = create_initial_state("Why did pipeline velocity drop this week by sales agent?")
-    state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest
-    state = plan_compiled_query(state)
-
-    assert stub.calls == 2
-    assert "failed SQL preflight validation" in stub.prompts[1]
-    assert state["compiled_plan"] is not None
-    assert "owner" in state["compiled_plan"]["plan"][0]["query"]
+    assert compact["relations"]
+    assert compact["relations"][0]["name"] == asset.primaryRelationName
+    assert any(column["name"] == "owner" for column in compact["relations"][0]["columns"])
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 9811fec..c3b3b5b 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -1,8 +1,10 @@
-"""Executor tests for compiled SQL plans."""
+"""Executor tests for the step-by-step SQL runtime."""
+
+from __future__ import annotations
 
 import pytest
 
-from app.agent.executor import execute_plan, execute_single_plan_step
+from app.agent.executor import execute_current_step
 from app.agent.state import create_initial_state
 from app.config import get_settings
 from app.data.registry import clear_source_registry, ingest_source
@@ -25,21 +27,28 @@ def test_execute_sql_step_returns_artifact_summary() -> None:
     asset = ingest_source("pipeline.csv", b"segment\nSMB\nEnterprise\nSMB\n")
     state = create_initial_state("Compare SMB vs Enterprise performance")
     state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest
-    compiled_plan = {
+    state["current_plan"] = {
         "objective": "Segment counts",
-        "max_steps": 3,
-        "plan": [
+        "steps": [
             {
                 "id": 1,
                 "purpose": "Get one sample aggregation.",
-                "type": "sql",
-                "query": f"SELECT segment, COUNT(*) AS deals FROM {asset.primaryRelationName} GROUP BY segment ORDER BY deals DESC",
                 "output_alias": "segment_counts",
+                "depends_on": [],
+                "relations": [asset.primaryRelationName],
+                "required_columns": [f"{asset.primaryRelationName}.segment"],
+                "expected_output": "A grouped segment count table.",
+                "allow_empty_result": False,
             }
         ],
     }
-    outcome = execute_plan(state, compiled_plan)
-    assert outcome["status"] == "success"
+    state["generated_query"] = {
+        "step_id": 1,
+        "sql": f"SELECT segment, COUNT(*) AS deals FROM {asset.primaryRelationName} GROUP BY segment ORDER BY deals DESC",
+        "explanation": "Counts rows by segment.",
+    }
+    outcome = execute_current_step(state)
+    assert outcome["workflow_status"] == "ready_for_analysis"
     last = state["executed_steps"][-1]
     assert last["status"] == "success"
     assert last["artifact"]["row_count"] > 0
@@ -50,35 +59,55 @@ def test_execute_plan_marks_empty_table_as_failed() -> None:
     asset = ingest_source("pipeline.csv", b"segment\nSMB\nEnterprise\n")
     state = create_initial_state("Why did pipeline velocity drop this week?")
     state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest
-    compiled_plan = {
+    state["current_plan"] = {
         "objective": "Empty query",
-        "max_steps": 3,
-        "plan": [
+        "steps": [
             {
                 "id": 1,
                 "purpose": "Return no rows.",
-                "type": "sql",
-                "query": "SELECT 1 WHERE 1=0",
                 "output_alias": "empty_result",
+                "depends_on": [],
+                "relations": [asset.primaryRelationName],
+                "required_columns": [f"{asset.primaryRelationName}.segment"],
+                "expected_output": "An intentionally empty result.",
+                "allow_empty_result": False,
             }
         ],
     }
-    outcome = execute_plan(state, compiled_plan)
-    assert outcome["status"] == "failed"
+    state["generated_query"] = {
+        "step_id": 1,
+        "sql": "SELECT 1 WHERE 1=0",
+        "explanation": "Returns no rows.",
+    }
+    outcome = execute_current_step(state)
+    assert outcome["workflow_status"] == "retry_same_step"
     assert state["executed_steps"][-1]["status"] == "failed"
 
 
-def test_execute_single_plan_step_retry() -> None:
+def test_execute_current_step_retry_attempt_is_recorded() -> None:
     asset = ingest_source("pipeline.csv", b"segment\nSMB\nEnterprise\n")
     state = create_initial_state("Test retry")
     state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest
-    step = {
-        "id": 1,
-        "purpose": "Get one row.",
-        "type": "sql",
-        "query": "SELECT 1 AS value",
-        "output_alias": "r1",
+    state["current_plan"] = {
+        "steps": [
+            {
+                "id": 1,
+                "purpose": "Get one row.",
+                "output_alias": "r1",
+                "depends_on": [],
+                "relations": [asset.primaryRelationName],
+                "required_columns": [],
+                "expected_output": "A single-row output.",
+                "allow_empty_result": False,
+            }
+        ]
+    }
+    state["retry_counts"]["1"] = 1
+    state["generated_query"] = {
+        "step_id": 1,
+        "sql": "SELECT 1 AS value",
+        "explanation": "Returns one row.",
     }
-    out = execute_single_plan_step(state, step, attempt=2)
-    assert out["status"] == "success"
+    out = execute_current_step(state)
+    assert out["workflow_status"] == "ready_for_analysis"
     assert state["executed_steps"][-1]["attempt"] == 2
diff --git a/ui/src/api/client.ts b/ui/src/api/client.ts
index e56d155..ff6d94c 100644
--- a/ui/src/api/client.ts
+++ b/ui/src/api/client.ts
@@ -88,7 +88,21 @@ export async function request(path: string, options: RequestOptions = {}): Pr
     return (await response.text()) as T;
   }
 
-  return (await response.json()) as T;
+  if (response.status === 204 || response.status === 205) {
+    return undefined as T;
+  }
+
+  const text = await response.text();
+  if (!text) {
+    return undefined as T;
+  }
+
+  const contentType = response.headers.get("Content-Type") ?? "";
+  if (contentType.includes("application/json")) {
+    return JSON.parse(text) as T;
+  }
+
+  return text as T;
 }
 
 /** Authenticated request helper — passes `Authorization: Bearer` for you. */
diff --git a/ui/src/api/uploads.test.ts b/ui/src/api/uploads.test.ts
new file mode 100644
index 0000000..2ae0d14
--- /dev/null
+++ b/ui/src/api/uploads.test.ts
@@ -0,0 +1,27 @@
+import { afterEach, describe, expect, it, vi } from "vitest";
+
+describe("uploads api", () => {
+  afterEach(() => {
+    vi.unstubAllGlobals();
+    vi.restoreAllMocks();
+  });
+
+  it("treats a 204 delete response as success", async () => {
+    const fetchMock = vi.fn().mockResolvedValue(new Response(null, { status: 204 }));
+    vi.stubGlobal("fetch", fetchMock);
+
+    const { deleteUpload } = await import("@/api/uploads");
+
+    await expect(deleteUpload("source_123", "token_123")).resolves.toBeUndefined();
+
+    expect(fetchMock).toHaveBeenCalledTimes(1);
+    expect(fetchMock).toHaveBeenCalledWith(
+      "http://localhost:8000/uploads/source_123",
+      expect.objectContaining({
+        method: "DELETE",
+        headers: expect.any(Headers),
+      }),
+    );
+    expect((fetchMock.mock.calls[0]?.[1]?.headers as Headers).get("Authorization")).toBe("Bearer token_123");
+  });
+});