diff --git a/app/agent/analysis.py b/app/agent/analysis.py index f13e160..87ad11f 100644 --- a/app/agent/analysis.py +++ b/app/agent/analysis.py @@ -8,9 +8,11 @@ from app.agent.state import AnalysisState from app.llm import get_llm_client from app.prompts import render_prompt -from app.schemas import AnalysisRenderResponse +from app.schemas import AnalysisRenderResponse, ApprovedClaim +from app.utils.logging import get_logger _ANALYSIS_RENDER_ATTEMPTS = 2 +logger = get_logger(__name__) def _build_analysis_render_prompt( @@ -26,18 +28,62 @@ def _build_analysis_render_prompt( ) +def _render_fallback_analysis(claims: list[ApprovedClaim], answer_status: str) -> str: + substantive = [claim for claim in claims if claim.kind != "caveat"] + caveats = [claim for claim in claims if claim.kind == "caveat"] + + if answer_status == "contradicted_premise" and substantive: + lines = [substantive[0].statement] + lines.extend(f"- {claim.statement}" for claim in substantive[1:3]) + if caveats: + lines.append("") + lines.append("Some requested breakdowns remain unresolved:") + lines.extend(f"- {claim.statement}" for claim in caveats[:2]) + return "\n".join(lines) + + if answer_status == "partial_answer" and substantive: + lines = ["The available evidence establishes part of the answer, but not the full requested breakdown."] + lines.extend(f"- {claim.statement}" for claim in substantive[:3]) + if caveats: + lines.append("") + lines.append("Unresolved parts:") + lines.extend(f"- {claim.statement}" for claim in caveats[:2]) + return "\n".join(lines) + + if answer_status == "conflicting_evidence": + return "The available evidence is internally inconsistent, so Planera cannot validate a reliable conclusion." + + if substantive: + lines = [substantive[0].statement] + lines.extend(f"- {claim.statement}" for claim in substantive[1:3]) + return "\n".join(lines) + + if caveats: + lines = [caveats[0].statement] + if len(caveats) > 1: + lines.extend(f"- {claim.statement}" for claim in caveats[1:3]) + return "\n".join(lines) + + return "The workflow could not validate a reliable comparison from the available results." + + def run_analysis_narrative(state: AnalysisState) -> AnalysisState: """Produce markdown-friendly analysis from query, objective, and step outputs.""" - workflow = state.get("workflow_status", "") - if workflow in ("planner_failed", "execution_failed"): - state["analysis"] = "The available evidence is insufficient because the workflow did not complete successfully." - return state - evidence = build_analysis_evidence(state) - approved_claims, expected_status = build_approved_claims(evidence) + caveat_steps = [ + step + for step in state.get("executed_steps") or [] + if step.get("status") in {"failed", "invalid"} + or (step.get("status") == "success" and step.get("validation_status") == "partial") + ] + approved_claims, expected_status = build_approved_claims(evidence, unresolved_steps=caveat_steps) + state["answer_status"] = expected_status if not approved_claims: - state["analysis"] = "The approved claims are insufficient to answer the question with the available evidence." + state["analysis"] = _render_fallback_analysis([], expected_status) + return state + if all(claim.kind == "caveat" for claim in approved_claims): + state["analysis"] = _render_fallback_analysis(approved_claims, expected_status) return state approved_claims_json = json.dumps([claim.model_dump() for claim in approved_claims], indent=2) @@ -56,11 +102,10 @@ def run_analysis_narrative(state: AnalysisState) -> AnalysisState: raise continue + state["answer_status"] = parsed.answer_status state["analysis"] = parsed.analysis_markdown.strip() or "No analysis text was returned." return state except Exception as exc: # pragma: no cover - defensive - state["analysis"] = ( - f"The analysis step could not complete ({exc!s}). " - "Review the executed steps and trace for raw outputs." - ) + logger.warning("Analysis rendering fell back to deterministic summary: %s", exc, exc_info=True) + state["analysis"] = _render_fallback_analysis(approved_claims, expected_status) return state diff --git a/app/agent/analysis_grounding.py b/app/agent/analysis_grounding.py index 6911a9e..c5731bb 100644 --- a/app/agent/analysis_grounding.py +++ b/app/agent/analysis_grounding.py @@ -9,7 +9,7 @@ from typing import Any from app.agent.state import AnalysisState -from app.schemas import AnalysisEvidence, AnalysisRenderResponse, ApprovedClaim, EvidenceItem, EvidenceValue +from app.schemas import AnalysisEvidence, AnalysisRenderResponse, ApprovedClaim, EvidenceItem, EvidenceValue, StepExpectation _NEGATIVE_PREMISE_TERMS = ( "drop", @@ -33,6 +33,10 @@ ) _CURRENT_TERMS = ("current", "latest", "this") _PREVIOUS_TERMS = ("previous", "prior", "last") +_DEFAULT_METRIC_DIRECTIONS = { + "pipeline_velocity_days": "lower_is_better", + "avg_pipeline_velocity_days": "lower_is_better", +} _BLOCKED_TERMS = ( "stable", "strong", @@ -58,6 +62,14 @@ "Evidence", "Question", } +_INTERNAL_LEAK_TERMS = ( + "answer_status must be", + "review executed steps and trace for raw outputs", + "validation feedback", + "validator", + "traceback", + "exception", +) def _is_number(value: Any) -> bool: @@ -83,24 +95,144 @@ def _infer_premise_hint(question: str) -> str: return "" -def _row_label(row: dict[str, Any], non_numeric_columns: list[str], fallback: str) -> str: - values = [str(row[column]) for column in non_numeric_columns if row.get(column) not in (None, "")] - if not values: - return fallback - if len(values) == 1: - return values[0] - return " | ".join(values) +def _normalize_metric_key(metric: str) -> str: + return re.sub(r"[^a-z0-9]+", "_", metric.lower()).strip("_") -def _extract_entities(row: dict[str, Any], non_numeric_columns: list[str]) -> list[str]: - seen: list[str] = [] +def _as_expectation(value: dict[str, Any] | StepExpectation | None) -> StepExpectation: + if isinstance(value, StepExpectation): + return value + return StepExpectation.model_validate(value or {}) + + +def _candidate_primary_metrics(state: AnalysisState) -> list[str]: + candidates: list[str] = [] + for candidate in ( + state.get("metric"), + (state.get("compiled_plan") or {}).get("metric"), + ): + if candidate and candidate not in candidates: + candidates.append(str(candidate)) + + plan = state.get("compiled_plan") or {} + for step in plan.get("plan") or []: + expectation = _as_expectation(step.get("expectation")) + if expectation.step_category != "premise_check": + continue + for metric in expectation.expected_metric_columns: + if metric and metric not in candidates: + candidates.append(metric) + return candidates + + +def _resolve_primary_metric(state: AnalysisState, available_metrics: set[str]) -> str: + candidates = _candidate_primary_metrics(state) + normalized_available = {_normalize_metric_key(metric): metric for metric in available_metrics} + for candidate in candidates: + if candidate in available_metrics: + return candidate + normalized = _normalize_metric_key(candidate) + if normalized in normalized_available: + return normalized_available[normalized] + for available_key, available_metric in normalized_available.items(): + if normalized and (normalized in available_key or available_key in normalized): + return available_metric + return candidates[0] if candidates else "" + + +def _resolve_metric_direction(state: AnalysisState, primary_metric: str) -> str: + plan = state.get("compiled_plan") or {} + metric_direction = str(plan.get("metric_direction") or "").strip() + if metric_direction: + return metric_direction + + normalized = _normalize_metric_key(primary_metric) + return _DEFAULT_METRIC_DIRECTIONS.get(normalized, "") + + +def _normalize_period_label(column: str, value: Any, expectation: StepExpectation) -> str: + if value in (None, ""): + return "" + + text = str(value).strip() + lowered_column = column.lower() + if isinstance(value, bool): + if "current_period" in lowered_column: + return "current_week" if value else "previous_week" + if "previous_period" in lowered_column: + return "previous_week" if value else "current_week" + if expectation.expected_period_column and expectation.expected_period_column == column and "period" in lowered_column: + return "current_period" if value else "previous_period" + return text.lower() + + lowered_text = text.lower().replace(" ", "_") + if any(term in lowered_text for term in _CURRENT_TERMS): + return "current_week" + if any(term in lowered_text for term in _PREVIOUS_TERMS): + return "previous_week" + return lowered_text + + +def _extract_period_label( + row: dict[str, Any], + non_numeric_columns: list[str], + expectation: StepExpectation, +) -> tuple[str, set[str]]: + period_columns: set[str] = set() + if expectation.expected_period_column and expectation.expected_period_column in row: + period_columns.add(expectation.expected_period_column) + for column in non_numeric_columns: + lowered = column.lower() + if lowered in {"current_period", "previous_period"} or "period" in lowered: + period_columns.add(column) + for column in non_numeric_columns: + if column not in period_columns: + continue + label = _normalize_period_label(column, row.get(column), expectation) + if label and label not in {"false", "true"}: + return label, period_columns + return "", period_columns + + +def _normalized_dimensions( + row: dict[str, Any], + non_numeric_columns: list[str], + period_columns: set[str], +) -> dict[str, str]: + dimensions: dict[str, str] = {} + for column in non_numeric_columns: + if column in period_columns: + continue value = row.get(column) if value in (None, ""): continue - as_text = str(value) - if as_text not in seen: - seen.append(as_text) + dimensions[column] = str(value) + return dimensions + + +def _row_label( + dimensions: dict[str, str], + period_label: str, + fallback: str, +) -> str: + values = list(dimensions.values()) + if values and period_label: + return f"{' | '.join(values)} | {period_label}" + if values: + return values[0] if len(values) == 1 else " | ".join(values) + if period_label: + return period_label + return fallback + + +def _extract_entities(dimensions: dict[str, str], period_label: str, row_label: str) -> list[str]: + seen: list[str] = [] + for value in [*dimensions.values(), period_label, row_label]: + if not value: + continue + if value not in seen: + seen.append(value) return seen @@ -109,10 +241,14 @@ def build_analysis_evidence(state: AnalysisState) -> AnalysisEvidence: items: list[EvidenceItem] = [] allowed_entities: list[str] = [] + available_metrics: set[str] = set() for step in state.get("executed_steps") or []: if step.get("status") != "success": continue + if step.get("validation_status") not in (None, "valid", "partial"): + continue + expectation = _as_expectation(step.get("expectation")) artifact = step.get("artifact") or {} preview_rows = artifact.get("preview_rows") or [] columns = artifact.get("columns") or [] @@ -125,9 +261,13 @@ def build_analysis_evidence(state: AnalysisState) -> AnalysisEvidence: if any(_is_number(row.get(column)) for row in preview_rows) ] non_numeric_columns = [column for column in columns if column not in numeric_columns] + available_metrics.update(numeric_columns) for index, row in enumerate(preview_rows, start=1): - entities = _extract_entities(row, non_numeric_columns) + period_label, period_columns = _extract_period_label(row, non_numeric_columns, expectation) + dimensions = _normalized_dimensions(row, non_numeric_columns, period_columns) + row_label = _row_label(dimensions, period_label, fallback=f"row_{index}") + entities = _extract_entities(dimensions, period_label, row_label) for entity in entities: if entity not in allowed_entities: allowed_entities.append(entity) @@ -137,17 +277,22 @@ def build_analysis_evidence(state: AnalysisState) -> AnalysisEvidence: id=f"{artifact.get('alias', step['output_alias'])}_row_{index}", source_alias=artifact.get("alias", step["output_alias"]), source_purpose=step["purpose"], - row_label=_row_label(row, non_numeric_columns, fallback=f"row_{index}"), + row_label=row_label, + step_category=expectation.step_category, + comparison_type=expectation.comparison_type, + period_label=period_label, + dimensions=dimensions, entities=entities, metrics=numeric_columns, values=values, ) ) + primary_metric = _resolve_primary_metric(state, available_metrics) return AnalysisEvidence( question=state["query"], - primary_metric=state.get("metric", ""), - metric_direction=(state.get("compiled_plan") or {}).get("metric_direction", ""), + primary_metric=primary_metric, + metric_direction=_resolve_metric_direction(state, primary_metric), premise_hint=_infer_premise_hint(state["query"]), items=items, allowed_entities=allowed_entities, @@ -167,7 +312,7 @@ def _group_items_by_source(evidence: AnalysisEvidence) -> dict[str, list[Evidenc def _sort_period_pair(items: list[EvidenceItem]) -> list[EvidenceItem]: def score(item: EvidenceItem) -> int: - lowered = item.row_label.lower() + lowered = (item.period_label or item.row_label).lower() if any(term in lowered for term in _PREVIOUS_TERMS): return 0 if any(term in lowered for term in _CURRENT_TERMS): @@ -182,13 +327,24 @@ def _build_premise_claim(evidence: AnalysisEvidence) -> tuple[ApprovedClaim | No return None, None for source_alias, items in _group_items_by_source(evidence).items(): - matching = [item for item in items if evidence.primary_metric in item.metrics] + matching = [ + item + for item in items + if evidence.primary_metric in item.metrics + and item.step_category == "premise_check" + ] + if len(matching) < 2: + matching = [ + item + for item in items + if evidence.primary_metric in item.metrics and not item.dimensions + ] if len(matching) < 2: continue ordered = _sort_period_pair(matching[:2]) left, right = ordered[0], ordered[1] - left_value = _value_map(left).get(evidence.primary_metric) - right_value = _value_map(right).get(evidence.primary_metric) + left_value = _value_map(left).get(evidence.primary_metric) or _value_map(left).get(f"{left.row_label}.{evidence.primary_metric}") + right_value = _value_map(right).get(evidence.primary_metric) or _value_map(right).get(f"{right.row_label}.{evidence.primary_metric}") if left_value is None or right_value is None: continue @@ -205,12 +361,16 @@ def _build_premise_claim(evidence: AnalysisEvidence) -> tuple[ApprovedClaim | No (evidence.premise_hint == "deterioration" and performance_change == "improved") or (evidence.premise_hint == "improvement" and performance_change == "deteriorated") ) - statement = ( - f"The primary metric {evidence.primary_metric} was {left_value} for {left.row_label} and {right_value} for {right.row_label}. " - f"The metric direction is {evidence.metric_direction}." - ) if contradicted: - statement += f" This does not support a {evidence.premise_hint} premise." + statement = ( + f"The premise is not supported. {evidence.primary_metric} was {left_value} for {left.row_label} and " + f"{right_value} for {right.row_label}, and {evidence.metric_direction.replace('_', ' ')}." + ) + else: + statement = ( + f"The available evidence supports the primary comparison. {evidence.primary_metric} was {left_value} " + f"for {left.row_label} and {right_value} for {right.row_label}, and {evidence.metric_direction.replace('_', ' ')}." + ) return ( ApprovedClaim( @@ -233,7 +393,7 @@ def _build_premise_claim(evidence: AnalysisEvidence) -> tuple[ApprovedClaim | No def _build_comparison_claims(evidence: AnalysisEvidence) -> list[ApprovedClaim]: claims: list[ApprovedClaim] = [] for source_alias, items in _group_items_by_source(evidence).items(): - if len(items) != 2: + if len(items) != 2 or any(item.dimensions for item in items): continue ordered = _sort_period_pair(items) left, right = ordered[0], ordered[1] @@ -261,9 +421,65 @@ def _build_comparison_claims(evidence: AnalysisEvidence) -> list[ApprovedClaim]: return claims -def _build_row_observation_claims(evidence: AnalysisEvidence) -> list[ApprovedClaim]: +def _build_grouped_period_comparison_claims(evidence: AnalysisEvidence) -> tuple[list[ApprovedClaim], set[str]]: + claims: list[ApprovedClaim] = [] + covered_item_ids: set[str] = set() + + for source_alias, items in _group_items_by_source(evidence).items(): + comparable = [ + item + for item in items + if item.period_label and item.dimensions + ] + if len(comparable) < 2: + continue + + grouped: dict[tuple[tuple[str, str], ...], list[EvidenceItem]] = defaultdict(list) + for item in comparable: + grouped[tuple(sorted(item.dimensions.items()))].append(item) + + for group_items in grouped.values(): + period_values = {item.period_label for item in group_items if item.period_label} + if len(period_values) < 2: + continue + ordered = _sort_period_pair(group_items)[:2] + left, right = ordered[0], ordered[1] + common_metrics = [metric for metric in left.metrics if metric in right.metrics] + if not common_metrics: + continue + group_label = " | ".join(left.dimensions.values()) or left.row_label + for metric in common_metrics[:4]: + left_value = _value_map(left).get(metric) + right_value = _value_map(right).get(metric) + if left_value is None or right_value is None: + continue + claims.append( + ApprovedClaim( + id=f"claim_{source_alias}_{group_label}_{metric}".replace(" ", "_").replace("|", "_").lower(), + kind="comparison", + statement=( + f"For {group_label}, {metric} was {left_value} for {left.period_label} and " + f"{right_value} for {right.period_label}." + ), + entities=[entity for entity in [group_label, left.period_label, right.period_label, *left.entities, *right.entities] if entity], + metrics=[metric], + source_aliases=[source_alias], + values=[ + EvidenceValue(label=f"{group_label}.{left.period_label}.{metric}", value=left_value), + EvidenceValue(label=f"{group_label}.{right.period_label}.{metric}", value=right_value), + ], + ) + ) + covered_item_ids.update({left.id, right.id}) + return claims[:8], covered_item_ids + + +def _build_row_observation_claims(evidence: AnalysisEvidence, covered_item_ids: set[str] | None = None) -> list[ApprovedClaim]: claims: list[ApprovedClaim] = [] + covered_item_ids = covered_item_ids or set() for item in evidence.items: + if item.id in covered_item_ids: + continue if not item.metrics: continue value_map = _value_map(item) @@ -279,17 +495,81 @@ def _build_row_observation_claims(evidence: AnalysisEvidence) -> list[ApprovedCl entities=item.entities or [item.row_label], metrics=item.metrics, source_aliases=[item.source_alias], - values=[EvidenceValue(label=metric, value=value_map[metric]) for metric in item.metrics if metric in value_map], + values=[ + EvidenceValue(label=f"{item.row_label}.{metric}", value=value_map[metric]) + for metric in item.metrics + if metric in value_map + ], ) ) return claims[:8] -def build_approved_claims(evidence: AnalysisEvidence) -> tuple[list[ApprovedClaim], str]: +def _build_caveat_claims(unresolved_steps: list[dict[str, Any]]) -> list[ApprovedClaim]: + def _step_expectation(step: dict[str, Any]) -> StepExpectation: + return _as_expectation(step.get("expectation")) + + def _metric_text(expectation: StepExpectation) -> str: + metrics = expectation.expected_metric_columns + if not metrics: + return "the required metric" + return ", ".join(metrics) + + claims: list[ApprovedClaim] = [] + for step in unresolved_steps: + status = step.get("status", "failed") + purpose = str(step.get("purpose") or "a workflow step").strip() + expectation = _step_expectation(step) + validation_reason = str(step.get("validation_reason") or step.get("error") or "").lower() + repaired = int(step.get("attempt") or 1) > 1 + + if ( + expectation.step_category == "premise_check" + and "missing expected columns" in validation_reason + and expectation.expected_metric_columns + ): + subject = "the premise-check query" if not repaired else "the repaired premise-check query" + statement = f"The workflow could not validate a reliable comparison because {subject} did not return the required metric {_metric_text(expectation)}." + elif ( + expectation.step_category == "premise_check" + and ("distinct periods" in validation_reason or "one comparable row per period" in validation_reason) + ): + subject = "the premise-check query" if not repaired else "the repaired premise-check query" + statement = f"The workflow could not validate a reliable comparison because {subject} did not return a valid period-by-period result structure." + elif status == "partial": + statement = f"The workflow only validated a partially comparable result for {purpose}." + elif status == "invalid": + statement = f"The workflow could not validate a reliable result for {purpose}." + else: + statement = f"The workflow could not fully establish {purpose} from the available execution results." + claims.append( + ApprovedClaim( + id=f"claim_caveat_{step.get('id', 'step')}", + kind="caveat", + statement=statement, + source_aliases=[str(step.get("output_alias") or "")] if step.get("output_alias") else [], + ) + ) + return claims + + +def _claims_conflict(claims: list[ApprovedClaim]) -> bool: + values_by_label: dict[str, set[str]] = defaultdict(set) + for claim in claims: + for value in claim.values: + values_by_label[value.label].add(value.value) + return any(len(values) > 1 for values in values_by_label.values()) + + +def build_approved_claims( + evidence: AnalysisEvidence, + unresolved_steps: list[dict[str, Any]] | None = None, +) -> tuple[list[ApprovedClaim], str]: """Build deterministic claims and the expected final answer status.""" if not evidence.items: - return [], "insufficient_evidence" + caveat_claims = _build_caveat_claims(unresolved_steps or []) + return caveat_claims, "insufficient_evidence" claims: list[ApprovedClaim] = [] expected_status = "answered" @@ -297,11 +577,13 @@ def build_approved_claims(evidence: AnalysisEvidence) -> tuple[list[ApprovedClai premise_claim, premise_status = _build_premise_claim(evidence) if premise_claim is not None: claims.append(premise_claim) - if premise_status is not None: - expected_status = premise_status claims.extend(_build_comparison_claims(evidence)) - claims.extend(_build_row_observation_claims(evidence)) + grouped_claims, covered_item_ids = _build_grouped_period_comparison_claims(evidence) + claims.extend(grouped_claims) + claims.extend(_build_row_observation_claims(evidence, covered_item_ids=covered_item_ids)) + caveat_claims = _build_caveat_claims(unresolved_steps or []) + claims.extend(caveat_claims) deduped: list[ApprovedClaim] = [] seen_statements: set[str] = set() @@ -311,6 +593,18 @@ def build_approved_claims(evidence: AnalysisEvidence) -> tuple[list[ApprovedClai seen_statements.add(claim.statement) deduped.append(claim) + substantive_claims = [claim for claim in deduped if claim.kind != "caveat"] + if _claims_conflict(substantive_claims): + expected_status = "conflicting_evidence" + elif premise_status is not None: + expected_status = premise_status + elif substantive_claims and caveat_claims: + expected_status = "partial_answer" + elif substantive_claims: + expected_status = "answered" + else: + expected_status = "insufficient_evidence" + return deduped, expected_status @@ -329,6 +623,10 @@ def _candidate_entity_phrases(text: str) -> set[str]: return candidates +def _first_nonempty_line(text: str) -> str: + return next((line.strip() for line in text.splitlines() if line.strip()), "") + + def validate_rendered_analysis( response: AnalysisRenderResponse, approved_claims: list[ApprovedClaim], @@ -347,6 +645,8 @@ def validate_rendered_analysis( raise ValueError(f"answer_status must be {expected_status}, got {response.answer_status}") selected_claims = [claim_by_id[claim_id] for claim_id in used_ids if claim_id in claim_by_id] + if expected_status == "partial_answer" and not any(claim.kind == "caveat" for claim in selected_claims): + raise ValueError("A partial answer must cite at least one caveat claim.") allowed_numbers = { token for claim in selected_claims @@ -360,6 +660,11 @@ def validate_rendered_analysis( for term in _BLOCKED_TERMS: if term in lowered_analysis: raise ValueError(f"Analysis used unsupported wording: {term}") + for term in _INTERNAL_LEAK_TERMS: + if term in lowered_analysis: + raise ValueError(f"Analysis leaked internal workflow text: {term}") + if re.search(r"\bclaim_[a-zA-Z0-9_]+\b", response.analysis_markdown): + raise ValueError("Analysis leaked internal claim identifiers.") allowed_entities = { entity @@ -380,8 +685,11 @@ def validate_rendered_analysis( if similar >= 0.82: raise ValueError(f"Analysis changed an approved entity name: {candidate}") + first_line = _first_nonempty_line(response.analysis_markdown) + lowered_first_line = first_line.lower() if expected_status == "contradicted_premise": - first_line = next((line.strip() for line in response.analysis_markdown.splitlines() if line.strip()), "") - lowered_first_line = first_line.lower() if "does not support" not in lowered_first_line and "contradict" not in lowered_first_line: raise ValueError("A contradicted premise must be stated clearly in the first sentence.") + if expected_status == "answered" and any(claim.kind == "premise_check" for claim in selected_claims): + if "supports" not in lowered_first_line and "answers" not in lowered_first_line and "shows" not in lowered_first_line: + raise ValueError("A fully answered verdict should be stated clearly in the first sentence.") diff --git a/app/agent/executor.py b/app/agent/executor.py index b9915e5..45f5bd5 100644 --- a/app/agent/executor.py +++ b/app/agent/executor.py @@ -7,9 +7,10 @@ import duckdb import pandas as pd +from app.agent.metric_aliases import canonical_metric_alias, canonical_metric_aliases, canonicalize_sql_metric_aliases from app.agent.state import AnalysisState from app.data.semantic_model import get_semantic_context, new_duckdb_connection -from app.schemas import ArtifactSummary, CompiledPlanStep, ExecutedStep +from app.schemas import ArtifactSummary, CompiledPlanStep, ExecutedStep, StepExpectation SAFE_BUILTINS: dict[str, Any] = { @@ -60,12 +61,10 @@ def _register_artifacts(conn: duckdb.DuckDBPyConnection, state: AnalysisState) - conn.register(alias, artifact) -def _execute_sql(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary: +def _execute_sql(state: AnalysisState, step: dict[str, Any]) -> pd.DataFrame: conn = new_duckdb_connection() _register_artifacts(conn, state) - frame = conn.execute(step["code"]).fetchdf() - state["artifacts"][step["output_alias"]] = frame - return _summarize_artifact(step["output_alias"], frame) + return conn.execute(step["code"]).fetchdf() def _execute_pandas(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary: @@ -89,22 +88,153 @@ def _empty_table_failure(artifact: ArtifactSummary) -> bool: return artifact.artifact_type == "table" and artifact.row_count == 0 -def compiled_plan_row_to_internal(row: dict[str, Any] | CompiledPlanStep) -> dict[str, Any]: +def compiled_plan_row_to_internal( + row: dict[str, Any] | CompiledPlanStep, + dataset_context: dict[str, Any] | None = None, +) -> dict[str, Any]: """Map a compiled plan step to the executor shape used by `_execute_sql`.""" if isinstance(row, CompiledPlanStep): row = row.model_dump() sid = row["id"] alias = row.get("output_alias") or f"step_{sid}" + expectation = _as_expectation(row["expectation"]) + expectation.expected_metric_columns = canonical_metric_aliases(expectation.expected_metric_columns, dataset_context) return { "id": str(sid), "kind": "sql", "purpose": row["purpose"], - "code": row["query"], + "code": canonicalize_sql_metric_aliases(row["query"], dataset_context), + "expectation": expectation, + "dataset_context": dataset_context, "output_alias": alias, } +def _as_expectation(value: dict[str, Any] | StepExpectation) -> StepExpectation: + if isinstance(value, StepExpectation): + return value + return StepExpectation.model_validate(value) + + +def _missing_expected_columns( + columns: list[str], + expectation: StepExpectation, + dataset_context: dict[str, Any] | None = None, +) -> list[str]: + present = {canonical_metric_alias(column, dataset_context) for column in columns} + expected_columns = [ + *expectation.expected_grouping_columns, + *canonical_metric_aliases(expectation.expected_metric_columns, dataset_context), + *( [expectation.expected_period_column] if expectation.expected_period_column else [] ), + ] + seen: set[str] = set() + missing: list[str] = [] + for column in expected_columns: + if not column or column in seen: + continue + seen.add(column) + if column not in present: + missing.append(column) + return missing + + +def _normalized_validation_frame(frame: pd.DataFrame) -> pd.DataFrame: + """Return a validation-safe frame with duplicate column names removed.""" + + if frame.columns.is_unique: + return frame + return frame.loc[:, ~frame.columns.duplicated()].copy() + + +def _effective_grouping_columns(expectation: StepExpectation) -> list[str]: + """Return grouping columns relevant for comparison, excluding the period column itself.""" + + seen: set[str] = set() + result: list[str] = [] + for column in expectation.expected_grouping_columns: + if not column or column == expectation.expected_period_column or column in seen: + continue + seen.add(column) + result.append(column) + return result + + +def _validate_step_expectation(expectation: StepExpectation) -> str | None: + if expectation.step_category == "premise_check" and not expectation.expected_metric_columns: + return "Premise-check steps must declare at least one expected metric column." + if expectation.step_category == "premise_check" and expectation.comparison_type != "period_comparison": + return "Premise-check steps must use period_comparison expectations." + if expectation.comparison_type == "period_comparison" and not expectation.expected_period_column: + return "Period comparison steps must declare an expected period column." + if expectation.comparison_type == "grouped_breakdown" and not expectation.expected_grouping_columns: + return "Grouped breakdown steps must declare expected grouping columns." + if expectation.requires_distinct_periods and not expectation.expected_period_column: + return "Distinct-period validation requires an expected period column." + if expectation.min_expected_rows < 1: + return "Step expectations must require at least one result row." + return None + + +def _validate_preview_shape( + internal: dict[str, Any], + preview_columns: list[str], + dataset_context: dict[str, Any] | None = None, + original_expectation: StepExpectation | None = None, +) -> str | None: + expectation = _as_expectation(internal["expectation"]) + if original_expectation is not None and expectation.model_dump() != original_expectation.model_dump(): + return "The repaired step changed the expected output shape." + expectation_error = _validate_step_expectation(expectation) + if expectation_error: + return expectation_error + + missing = _missing_expected_columns(preview_columns, expectation, dataset_context) + if missing: + return f"The step preview is missing expected columns: {', '.join(missing)}." + return None + + +def _validate_result_shape(frame: pd.DataFrame, internal: dict[str, Any]) -> tuple[str, str | None]: + expectation = _as_expectation(internal["expectation"]) + dataset_context = internal.get("dataset_context") + frame = _normalized_validation_frame(frame) + expectation_error = _validate_step_expectation(expectation) + if expectation_error: + return "invalid", expectation_error + + if frame.empty: + return "invalid", "The step returned no rows." + + missing = _missing_expected_columns(list(frame.columns), expectation, dataset_context) + if missing: + return "invalid", f"The result is missing expected columns: {', '.join(missing)}." + + if expectation.requires_distinct_periods and expectation.expected_period_column: + distinct_periods = frame[expectation.expected_period_column].dropna().astype(str).nunique() + if distinct_periods < 2: + return "invalid", "The comparison result did not return at least two comparable periods." + effective_grouping_columns = _effective_grouping_columns(expectation) + if expectation.step_category == "premise_check" and not effective_grouping_columns: + if distinct_periods != len(frame): + return "invalid", "The premise-check result did not preserve one comparable row per period." + if effective_grouping_columns: + comparable_columns = [*effective_grouping_columns, expectation.expected_period_column] + comparable_frame = frame[comparable_columns].dropna(subset=effective_grouping_columns) + grouped_counts = comparable_frame.groupby(effective_grouping_columns, dropna=False)[expectation.expected_period_column].nunique() + comparable_groups = int((grouped_counts >= 2).sum()) + total_groups = int(len(grouped_counts)) + if comparable_groups == 0: + return "invalid", "The grouped comparison did not return any groups with comparable periods." + if comparable_groups < total_groups: + return "partial", f"Only {comparable_groups} of {total_groups} groups returned comparable periods." + + if len(frame) < expectation.min_expected_rows: + return "invalid", f"The step returned {len(frame)} rows, below the expected minimum of {expectation.min_expected_rows}." + + return "valid", None + + def preflight_compiled_plan(state: AnalysisState, compiled_plan: dict[str, Any]) -> dict[str, Any]: """ Validate compiled SQL steps against the active runtime before execution. @@ -116,12 +246,30 @@ def preflight_compiled_plan(state: AnalysisState, compiled_plan: dict[str, Any]) _register_artifacts(conn, state) rows = list(compiled_plan.get("plan") or []) rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id) + previews: list[dict[str, Any]] = [] for row in rows: - internal = compiled_plan_row_to_internal(row) + internal = compiled_plan_row_to_internal(row, state.get("dataset_context")) sql = internal["code"].strip().rstrip(";") try: preview = conn.execute(f"SELECT * FROM ({sql}) AS __planera_preflight LIMIT 0").fetchdf() + preview_columns = list(preview.columns) + validation_reason = _validate_preview_shape(internal, preview_columns, state.get("dataset_context")) + if validation_reason: + return { + "status": "failed", + "failed_step_id": internal["id"], + "error": validation_reason, + "query": internal["code"], + "step_previews": previews, + } + previews.append( + { + "step_id": internal["id"], + "columns": preview_columns, + "validation_status": "valid", + } + ) conn.register(internal["output_alias"], preview) except Exception as exc: return { @@ -129,24 +277,49 @@ def preflight_compiled_plan(state: AnalysisState, compiled_plan: dict[str, Any]) "failed_step_id": internal["id"], "error": str(exc), "query": internal["code"], + "step_previews": previews, } - return {"status": "success"} + return {"status": "success", "step_previews": previews} def _try_sql_step( state: AnalysisState, internal: dict[str, Any], attempt: int, -) -> tuple[Literal["success", "failed"], ExecutedStep]: - """Run one SQL step with post-execution validation (non-empty table).""" +) -> tuple[Literal["success", "invalid", "failed"], ExecutedStep]: + """Run one SQL step with semantic/result-shape validation.""" state["total_steps"] += 1 try: - artifact = _execute_sql(state, internal) - if _empty_table_failure(artifact): + frame = _execute_sql(state, internal) + artifact = _summarize_artifact(internal["output_alias"], frame) + validation_status, validation_reason = _validate_result_shape(frame, internal) + if validation_status == "invalid": state["artifacts"].pop(internal["output_alias"], None) - raise ValueError("Step returned an empty result set.") + executed = ExecutedStep( + id=internal["id"], + kind="sql", + purpose=internal["purpose"], + code=internal["code"], + output_alias=internal["output_alias"], + attempt=attempt, + status="invalid", + validation_status="invalid", + validation_reason=validation_reason, + expectation=_as_expectation(internal["expectation"]), + error=validation_reason, + ) + state["executed_steps"].append(executed.model_dump()) + state["last_error"] = { + "step_id": internal["id"], + "message": validation_reason, + "code": internal["code"], + "kind": "invalid", + } + return "invalid", executed + + state["artifacts"][internal["output_alias"]] = frame executed = ExecutedStep( id=internal["id"], kind="sql", @@ -155,6 +328,9 @@ def _try_sql_step( output_alias=internal["output_alias"], attempt=attempt, status="success", + validation_status=validation_status, + validation_reason=validation_reason, + expectation=_as_expectation(internal["expectation"]), artifact=artifact, ) state["executed_steps"].append(executed.model_dump()) @@ -169,10 +345,17 @@ def _try_sql_step( output_alias=internal["output_alias"], attempt=attempt, status="failed", + validation_status="invalid", + expectation=_as_expectation(internal["expectation"]), error=str(exc), ) state["executed_steps"].append(executed.model_dump()) - state["last_error"] = {"step_id": internal["id"], "message": str(exc), "code": internal["code"]} + state["last_error"] = { + "step_id": internal["id"], + "message": str(exc), + "code": internal["code"], + "kind": "failed", + } return "failed", executed @@ -186,12 +369,12 @@ def execute_plan(state: AnalysisState, compiled_plan: dict[str, Any]) -> dict[st rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id) for row in rows: - internal = compiled_plan_row_to_internal(row) + internal = compiled_plan_row_to_internal(row, state.get("dataset_context")) status, _ = _try_sql_step(state, internal, attempt=1) - if status == "failed": + if status in {"failed", "invalid"}: sid = internal["id"] return { - "status": "failed", + "status": status, "failed_step_id": sid, "error": state["last_error"]["message"] if state["last_error"] else "Unknown error", } @@ -206,11 +389,11 @@ def execute_single_plan_step( ) -> dict[str, Any]: """Re-run a single compiled step (e.g. after repair).""" - internal = compiled_plan_row_to_internal(compiled_step) + internal = compiled_plan_row_to_internal(compiled_step, state.get("dataset_context")) status, _ = _try_sql_step(state, internal, attempt=attempt) - if status == "failed": + if status in {"failed", "invalid"}: return { - "status": "failed", + "status": status, "failed_step_id": internal["id"], "error": state["last_error"]["message"] if state["last_error"] else "Unknown error", } diff --git a/app/agent/graph.py b/app/agent/graph.py index 0bc9ef3..e8f7068 100644 --- a/app/agent/graph.py +++ b/app/agent/graph.py @@ -25,6 +25,13 @@ def _append_error(state: AnalysisState, step: str, message: str, recoverable: bo state["errors"].append({"step": step, "message": message, "recoverable": recoverable, "details": details or {}}) +def _has_usable_evidence(state: AnalysisState) -> bool: + return any( + step.get("status") == "success" and step.get("validation_status") in (None, "valid", "partial") + for step in state.get("executed_steps") or [] + ) + + def load_schema_context_node(state: AnalysisState) -> AnalysisState: step_name = "load_schema_context_node" _append_trace(state, step_name, "started", {}) @@ -83,6 +90,7 @@ def execute_plan_node(state: AnalysisState) -> AnalysisState: if outcome["status"] == "success": state["workflow_status"] = "ready_to_analyze" + state["unresolved_step_ids"] = [] _append_trace(state, step_name, "completed", {"status": "success"}) return state @@ -93,11 +101,12 @@ def execute_plan_node(state: AnalysisState) -> AnalysisState: state = repair_failed_step(state, failed_id, err) except Exception as exc: _append_error(state, "repair_planner", str(exc), recoverable=False, details={"failed_step_id": failed_id}) - state["workflow_status"] = "execution_failed" + state["unresolved_step_ids"] = [failed_id] + state["workflow_status"] = "partial_execution" if _has_usable_evidence(state) else "execution_failed" _append_trace(state, step_name, "failed", {"phase": "repair", "message": str(exc)}) return state - if state["executed_steps"] and state["executed_steps"][-1]["status"] == "failed": + if state["executed_steps"] and state["executed_steps"][-1]["status"] in {"failed", "invalid"}: state["executed_steps"].pop() plan = state["compiled_plan"] or {} @@ -108,7 +117,7 @@ def execute_plan_node(state: AnalysisState) -> AnalysisState: return state retry = execute_single_plan_step(state, step_row, attempt=2) - if retry["status"] == "failed": + if retry["status"] in {"failed", "invalid"}: _append_error( state, step_name, @@ -116,10 +125,12 @@ def execute_plan_node(state: AnalysisState) -> AnalysisState: recoverable=False, details={"failed_step_id": failed_id}, ) - state["workflow_status"] = "execution_failed" + state["unresolved_step_ids"] = [failed_id] + state["workflow_status"] = "partial_execution" if _has_usable_evidence(state) else "execution_failed" _append_trace(state, step_name, "failed", {"phase": "retry", "failed_step_id": failed_id}) else: state["workflow_status"] = "ready_to_analyze" + state["unresolved_step_ids"] = [] _append_trace(state, step_name, "completed", {"status": "success_after_repair"}) return state @@ -132,9 +143,21 @@ def analysis_node(state: AnalysisState) -> AnalysisState: state = run_analysis_narrative(state) _append_trace(state, step_name, "completed", {"length": len(state["analysis"])}) except Exception as exc: - state["analysis"] = f"The analysis step failed: {exc}" - _append_error(state, step_name, str(exc), recoverable=False) - _append_trace(state, step_name, "failed", {"message": str(exc)}) + logger.warning("%s failed: %s", step_name, exc, exc_info=True) + state["analysis"] = "The available evidence is incomplete for a full answer." + state["answer_status"] = "partial_answer" if _has_usable_evidence(state) else "insufficient_evidence" + _append_error( + state, + step_name, + "The workflow could not render a fully validated narrative for this run.", + recoverable=False, + ) + _append_trace( + state, + step_name, + "failed", + {"message": "The workflow could not render a fully validated narrative for this run."}, + ) return state diff --git a/app/agent/metric_aliases.py b/app/agent/metric_aliases.py new file mode 100644 index 0000000..574e185 --- /dev/null +++ b/app/agent/metric_aliases.py @@ -0,0 +1,101 @@ +"""Schema-aware metric alias helpers shared by planner and executor.""" + +from __future__ import annotations + +import re +from collections import defaultdict +from typing import Any + +_AGGREGATE_PREFIXES = {"avg", "sum", "count", "min", "max", "median"} +_AGGREGATE_ALIAS_PATTERN = re.compile( + r"(?i)\b(?Pavg|sum|count|min|max|median)\s*\(\s*(?P[A-Za-z_][A-Za-z0-9_\.\"`]*)\s*\)\s+AS\s+(?P[A-Za-z_][A-Za-z0-9_]*)" +) + + +def _normalize_metric_key(value: str) -> str: + return re.sub(r"[^a-z0-9]+", "", value.lower()) + + +def _strip_identifier_quotes(value: str) -> str: + return value.strip().strip('"').strip("`") + + +def _iter_measure_columns(dataset_context: dict[str, Any] | None) -> list[tuple[str, list[str]]]: + if not dataset_context: + return [] + + measures: list[tuple[str, list[str]]] = [] + for relation in dataset_context.get("relations") or []: + measure_names = set(relation.get("measure_columns") or []) + for column in relation.get("columns") or []: + column_name = column.get("name", "") + if not column_name or column_name not in measure_names: + continue + hints = list(column.get("semantic_hints") or []) + measures.append((column_name, hints)) + return measures + + +def _schema_metric_lookup(dataset_context: dict[str, Any] | None) -> dict[str, str]: + matches: dict[str, set[str]] = defaultdict(set) + for column_name, hints in _iter_measure_columns(dataset_context): + matches[_normalize_metric_key(column_name)].add(column_name) + for hint in hints: + normalized = _normalize_metric_key(hint) + if normalized: + matches[normalized].add(column_name) + + resolved: dict[str, str] = {} + for normalized, candidates in matches.items(): + if len(candidates) == 1: + resolved[normalized] = next(iter(candidates)) + return resolved + + +def resolve_metric_base_name(name: str, dataset_context: dict[str, Any] | None) -> str: + """Resolve a metric-like identifier to a canonical schema-backed base name when uniquely possible.""" + + cleaned = _strip_identifier_quotes(name) + lookup = _schema_metric_lookup(dataset_context) + return lookup.get(_normalize_metric_key(cleaned), cleaned) + + +def canonical_metric_alias(name: str, dataset_context: dict[str, Any] | None = None) -> str: + """Canonicalize aggregate metric aliases from planner expectations or result columns.""" + + cleaned = _strip_identifier_quotes(name) + match = re.match(rf"^(?P{'|'.join(sorted(_AGGREGATE_PREFIXES))})_(?P.+)$", cleaned, flags=re.IGNORECASE) + if not match: + return cleaned + prefix = match.group("prefix").lower() + base = resolve_metric_base_name(match.group("base"), dataset_context) + return f"{prefix}_{base}" + + +def canonical_metric_aliases(names: list[str], dataset_context: dict[str, Any] | None = None) -> list[str]: + """Canonicalize a list of metric aliases while preserving order.""" + + seen: set[str] = set() + result: list[str] = [] + for name in names: + canonical = canonical_metric_alias(name, dataset_context) + if canonical in seen: + continue + seen.add(canonical) + result.append(canonical) + return result + + +def canonicalize_sql_metric_aliases(query: str, dataset_context: dict[str, Any] | None = None) -> str: + """Rewrite aggregate metric aliases in SQL to canonical schema-aware names.""" + + def repl(match: re.Match[str]) -> str: + func = match.group("func") + expr = match.group("expr") + base_expr = _strip_identifier_quotes(expr.split(".")[-1]) + canonical_base = resolve_metric_base_name(base_expr, dataset_context) + canonical_alias = f"{func.lower()}_{canonical_base}" + return f"{func}({expr}) AS {canonical_alias}" + + return _AGGREGATE_ALIAS_PATTERN.sub(repl, query) + diff --git a/app/agent/planner.py b/app/agent/planner.py index 0c9892a..17858bb 100644 --- a/app/agent/planner.py +++ b/app/agent/planner.py @@ -10,10 +10,11 @@ from pydantic import ValidationError from app.agent.executor import preflight_compiled_plan +from app.agent.metric_aliases import canonical_metric_alias, canonical_metric_aliases, canonicalize_sql_metric_aliases from app.agent.state import AnalysisState from app.llm import get_llm_client from app.prompts import render_prompt -from app.schemas import CompiledPlan, RepairDecision +from app.schemas import CompiledPlan, RepairDecision, StepExpectation from app.utils.logging import get_logger logger = get_logger(__name__) @@ -21,6 +22,23 @@ _COMPILED_PLANNER_ATTEMPTS = 3 _MAX_PROMPT_RELATIONS = 4 _MAX_COLUMNS_PER_RELATION = 18 +_PREMISE_TERMS = ("drop", "decline", "decrease", "increase", "improve", "growth", "faster", "slower") + + +def _canonicalize_step_contract(step: Any, dataset_context: dict[str, Any]) -> Any: + """Normalize metric aliases in a compiled or repaired step.""" + + step.query = canonicalize_sql_metric_aliases(step.query, dataset_context) + step.expectation.expected_metric_columns = canonical_metric_aliases(step.expectation.expected_metric_columns, dataset_context) + return step + + +def _canonicalize_plan_contract(parsed: CompiledPlan, dataset_context: dict[str, Any]) -> CompiledPlan: + """Normalize metric aliases in the planner output before validation/execution.""" + + parsed.metric = canonical_metric_alias(parsed.metric, dataset_context) + parsed.plan = [_canonicalize_step_contract(step, dataset_context) for step in parsed.plan] + return parsed def _query_terms(question: str) -> set[str]: @@ -46,6 +64,8 @@ def _column_relevance_score(column: dict[str, Any], question_terms: set[str]) -> score = overlap * 3 if column.get("name", "").lower() in question_terms: score += 4 + if column.get("name") in {"period", "current_period", "previous_period"}: + score += 2 return score @@ -57,15 +77,50 @@ def _relation_relevance_score(relation: dict[str, Any], question_terms: set[str] return score -def _trim_relation_for_prompt(relation: dict[str, Any], question_terms: set[str]) -> dict[str, Any]: +def _relation_join_keys(dataset_context: dict[str, Any], relation_name: str) -> set[str]: + join_keys: set[str] = set() + for relationship in dataset_context.get("relationships") or []: + if relationship.get("left_relation") == relation_name: + join_keys.update(relationship.get("left_on") or []) + if relationship.get("right_relation") == relation_name: + join_keys.update(relationship.get("right_on") or []) + return join_keys + + +def _mapped_columns_for_question(relation: dict[str, Any], question_terms: set[str]) -> set[str]: + columns: set[str] = set() + for mapping in relation.get("semantic_mappings") or []: + mapping_terms = _field_terms(mapping.get("concept", "")) + if mapping_terms & question_terms: + columns.update(mapping.get("columns") or []) + return columns + + +def _protected_column_names(relation: dict[str, Any], dataset_context: dict[str, Any], question_terms: set[str]) -> set[str]: + protected = set(relation.get("identifier_columns") or []) + protected.update(relation.get("time_columns") or []) + protected.update(relation.get("measure_columns") or []) + protected.update(_relation_join_keys(dataset_context, relation.get("name", ""))) + protected.update(_mapped_columns_for_question(relation, question_terms)) + return protected + + +def _trim_relation_for_prompt( + relation: dict[str, Any], + question_terms: set[str], + dataset_context: dict[str, Any], +) -> dict[str, Any]: trimmed = deepcopy(relation) columns = list(trimmed.get("columns") or []) if len(columns) <= _MAX_COLUMNS_PER_RELATION: return trimmed + protected_names = _protected_column_names(trimmed, dataset_context, question_terms) + ranked = sorted( columns, key=lambda column: ( + column.get("name") in protected_names, _column_relevance_score(column, question_terms), column.get("name") in (trimmed.get("identifier_columns") or []), column.get("name") in (trimmed.get("time_columns") or []), @@ -78,9 +133,11 @@ def _trim_relation_for_prompt(relation: dict[str, Any], question_terms: set[str] name = column.get("name", "") if not name or name in selected_names: continue + if len(selected) >= _MAX_COLUMNS_PER_RELATION and name not in protected_names: + continue selected.append(column) selected_names.add(name) - if len(selected) >= _MAX_COLUMNS_PER_RELATION: + if len(selected) >= _MAX_COLUMNS_PER_RELATION and protected_names.issubset(selected_names): break trimmed["columns"] = selected @@ -109,16 +166,23 @@ def _schema_subset_for_question(dataset_context: dict[str, Any], question: str) "source": dataset_context.get("source", ""), "dialect": dataset_context.get("dialect", ""), "relations": relations, + "relationships": dataset_context.get("relationships") or [], } question_terms = _query_terms(question) ranked_relations = sorted(relations, key=lambda relation: _relation_relevance_score(relation, question_terms), reverse=True) selected = ranked_relations[:_MAX_PROMPT_RELATIONS] + selected_names = {relation["name"] for relation in selected} return { "reference_date": dataset_context.get("reference_date", ""), "source": dataset_context.get("source", ""), "dialect": dataset_context.get("dialect", ""), - "relations": [_trim_relation_for_prompt(relation, question_terms) for relation in selected], + "relations": [_trim_relation_for_prompt(relation, question_terms, dataset_context) for relation in selected], + "relationships": [ + relationship + for relationship in (dataset_context.get("relationships") or []) + if relationship.get("left_relation") in selected_names and relationship.get("right_relation") in selected_names + ], } @@ -144,6 +208,46 @@ def _planner_preflight_feedback(outcome: dict[str, Any], schema_subset: dict[str ) +def _validate_compiled_plan_semantics(parsed: CompiledPlan, query: str) -> str | None: + steps = parsed.plan + if not steps: + return "The plan must contain at least one step." + + first_expectation = steps[0].expectation + if first_expectation.step_category != "premise_check": + return "Step 1 must be a premise_check step." + if first_expectation.comparison_type != "period_comparison": + return "Step 1 must verify the primary comparison with a period_comparison step." + if not first_expectation.expected_metric_columns: + return "Step 1 must declare expected metric columns." + if not first_expectation.expected_period_column or not first_expectation.requires_distinct_periods: + return "Step 1 must declare a period column and require distinct periods." + if not parsed.metric_direction.strip(): + return "Premise-check plans must declare metric_direction." + + step_ids = {step.id for step in steps} + for index, step in enumerate(steps, start=1): + expectation = step.expectation + if index > 1 and expectation.step_category == "premise_check": + return "Only the first step may be a premise_check step." + if expectation.comparison_type == "grouped_breakdown" and not expectation.expected_grouping_columns: + return f"Step {step.id} must declare expected grouping columns for a grouped_breakdown." + if expectation.requires_distinct_periods and not expectation.expected_period_column: + return f"Step {step.id} must declare a period column when distinct periods are required." + if expectation.preserve_population_from_step_id is not None and expectation.preserve_population_from_step_id not in step_ids: + return f"Step {step.id} references an unknown preserve_population_from_step_id." + + if any(term in query.lower() for term in _PREMISE_TERMS): + later_premise_like = [ + step.id + for step in steps[1:] + if step.expectation.comparison_type == "period_comparison" and not step.expectation.expected_grouping_columns + ] + if later_premise_like: + return f"Only Step 1 should handle the top-level premise comparison, found another comparison-only step: {later_premise_like[0]}." + return None + + def _build_compiled_planner_prompt(state: AnalysisState, validation_feedback: str | None = None) -> str: schema_subset = _schema_subset_for_question(state["dataset_context"], state["query"]) relation_names = _relation_names(schema_subset) @@ -159,11 +263,15 @@ def _build_compiled_planner_prompt(state: AnalysisState, validation_feedback: st def _build_repair_prompt(state: AnalysisState, failed_step_id: str, error_message: str) -> str: plan = state.get("compiled_plan") or {} schema_subset = _schema_subset_for_question(state["dataset_context"], state["query"]) + failed_step = next((row for row in (plan.get("plan") or []) if str(row.get("id")) == str(failed_step_id)), {}) return render_prompt( "planner_repair.j2", + query=state["query"], failed_step_id=failed_step_id, error_message=error_message, plan_json=json.dumps(plan, indent=2), + failed_step_json=json.dumps(failed_step, indent=2), + failed_step_expectation_json=json.dumps((failed_step or {}).get("expectation", {}), indent=2), schema_subset_json=json.dumps(schema_subset, indent=2), ) @@ -179,6 +287,7 @@ def plan_compiled_query(state: AnalysisState) -> AnalysisState: try: decision = client.generate_json(prompt, schema=CompiledPlan) parsed = decision if isinstance(decision, CompiledPlan) else CompiledPlan.model_validate(decision) + parsed = _canonicalize_plan_contract(parsed, state["dataset_context"]) except (ValidationError, ValueError) as exc: feedback = exc.json(indent=2) if isinstance(exc, ValidationError) else str(exc) logger.warning( @@ -191,6 +300,19 @@ def plan_compiled_query(state: AnalysisState) -> AnalysisState: raise continue + semantic_error = _validate_compiled_plan_semantics(parsed, state["query"]) + if semantic_error: + feedback = semantic_error + logger.warning( + "Compiled plan semantic validation failed (attempt %s/%s): %s", + attempt, + _COMPILED_PLANNER_ATTEMPTS, + semantic_error, + ) + if attempt >= _COMPILED_PLANNER_ATTEMPTS: + raise ValueError(semantic_error) + continue + preflight = preflight_compiled_plan(state, parsed.model_dump()) if preflight["status"] == "failed": feedback = _planner_preflight_feedback(preflight, _schema_subset_for_question(state["dataset_context"], state["query"])) @@ -220,6 +342,7 @@ def repair_failed_step(state: AnalysisState, failed_step_id: str, error_message: schema=RepairDecision, ) parsed = raw if isinstance(raw, RepairDecision) else RepairDecision.model_validate(raw) + parsed.updated_step = _canonicalize_step_contract(parsed.updated_step, state["dataset_context"]) if str(parsed.updated_step.id) != str(failed_step_id): raise ValueError(f"Repair returned mismatched step id: expected {failed_step_id}, got {parsed.updated_step.id}") @@ -233,6 +356,9 @@ def repair_failed_step(state: AnalysisState, failed_step_id: str, error_message: for i, row in enumerate(steps): sid = row.get("id") if isinstance(row, dict) else row["id"] if str(sid) == str(failed_step_id): + original_expectation = StepExpectation.model_validate(row.get("expectation", {})) + if parsed.updated_step.expectation.model_dump() != original_expectation.model_dump(): + raise ValueError("Repair changed the original step expectation.") steps[i] = parsed.updated_step.model_dump() replaced = True break @@ -240,6 +366,12 @@ def repair_failed_step(state: AnalysisState, failed_step_id: str, error_message: raise ValueError(f"Failed step id {failed_step_id} not found in compiled plan.") plan["plan"] = steps + preflight = preflight_compiled_plan(state, plan) + if preflight["status"] == "failed": + raise ValueError( + "Repair did not preserve the original analytical intent: " + f"{preflight['error']}" + ) state["compiled_plan"] = plan state["repair_attempted"] = True return state diff --git a/app/agent/state.py b/app/agent/state.py index 18a1a7d..bf30564 100644 --- a/app/agent/state.py +++ b/app/agent/state.py @@ -20,9 +20,11 @@ class AnalysisState(TypedDict): artifacts: dict[str, Any] executed_steps: list[dict[str, Any]] analysis: str + answer_status: str total_steps: int last_error: dict[str, Any] | None workflow_status: str + unresolved_step_ids: list[str] trace: list[dict[str, Any]] errors: list[dict[str, Any]] @@ -41,9 +43,11 @@ def create_initial_state(query: str) -> AnalysisState: artifacts={}, executed_steps=[], analysis="", + answer_status="insufficient_evidence", total_steps=0, last_error=None, workflow_status="planning", + unresolved_step_ids=[], trace=[], errors=[], ) diff --git a/app/api/routes.py b/app/api/routes.py index 91a3ec6..d398024 100644 --- a/app/api/routes.py +++ b/app/api/routes.py @@ -1,4 +1,4 @@ -"""API routes for GTM Analytics Copilot.""" +"""API routes for the Planera analytics copilot.""" from __future__ import annotations @@ -60,6 +60,7 @@ def analyze(request: AnalyzeRequest) -> AnalyzeResponse: try: state = run_analysis(request.query) base_response = AnalyzeResponse( + answer_status=state.get("answer_status", "insufficient_evidence"), analysis=state["analysis"], trace=state.get("trace", []), executed_steps=state.get("executed_steps", []), @@ -67,6 +68,7 @@ def analyze(request: AnalyzeRequest) -> AnalyzeResponse: ) inspection_id = store_inspection(request.query, base_response) return AnalyzeResponse( + answer_status=base_response.answer_status, analysis=base_response.analysis, trace=base_response.trace, executed_steps=base_response.executed_steps, diff --git a/app/api/workspace.py b/app/api/workspace.py index 2613ac3..53390ca 100644 --- a/app/api/workspace.py +++ b/app/api/workspace.py @@ -179,6 +179,7 @@ def _build_execution_chips(response: AnalyzeResponse, primary_artifact: Artifact f"{executed_steps} workflow step{'' if executed_steps == 1 else 's'}" if executed_steps else "No executed steps", f"Output: {primary_artifact.alias}" if primary_artifact else "No output alias", f"{retry_count} retry attempt{'' if retry_count == 1 else 's'}" if retry_count else "No retries", + f"Answer: {response.answer_status.replace('_', ' ')}", f"{len(response.errors)} issue{'' if len(response.errors) == 1 else 's'}" if response.errors else "No recorded errors", ] )[:4] @@ -265,10 +266,11 @@ def _build_metadata( "Failed" if any(not item.recoverable for item in response.errors) else "Completed with review notes" - if response.errors or any(step.attempt > 1 for step in response.executed_steps) + if response.errors or any(step.attempt > 1 for step in response.executed_steps) or response.answer_status in {"partial_answer", "insufficient_evidence", "conflicting_evidence"} else "Complete" ), ), + MetadataItem(label="Answer status", value=response.answer_status.replace("_", " ")), MetadataItem(label="Verification", value="Verified" if verified else "Needs analyst review"), MetadataItem( label="Output shape", @@ -356,10 +358,12 @@ def _derive_inspection_status(response: AnalyzeResponse) -> str: steps = response.executed_steps if any(not item.recoverable for item in response.errors) or (steps and not any(step.status == "success" for step in steps)): return "error" + if response.answer_status in {"partial_answer", "insufficient_evidence", "conflicting_evidence"}: + return "warning" if ( response.errors or any(event.status in {"failed", "skipped"} for event in response.trace) - or any(step.status == "failed" or step.attempt > 1 for step in steps) + or any(step.status in {"failed", "invalid"} or step.attempt > 1 for step in steps) ): return "warning" return "valid" @@ -372,7 +376,16 @@ def _derive_confidence(response: AnalyzeResponse, primary_artifact: ArtifactSumm preview_bonus = 0.16 if primary_artifact and primary_artifact.row_count else 0.0 trace_bonus = 0.07 if any(event.status == "completed" for event in response.trace) else 0.0 error_penalty = 0.18 if any(not item.recoverable for item in response.errors) else 0.08 if response.errors else 0.0 - score = 0.46 + success_ratio * 0.24 + preview_bonus + trace_bonus - error_penalty + answer_penalty = ( + 0.15 + if response.answer_status == "conflicting_evidence" + else 0.12 + if response.answer_status == "insufficient_evidence" + else 0.06 + if response.answer_status == "partial_answer" + else 0.0 + ) + score = 0.46 + success_ratio * 0.24 + preview_bonus + trace_bonus - error_penalty - answer_penalty return _clamp(score, 0.35, 0.95) diff --git a/app/data/semantic_model.py b/app/data/semantic_model.py index 3cdcef6..f6d73fa 100644 --- a/app/data/semantic_model.py +++ b/app/data/semantic_model.py @@ -5,13 +5,14 @@ import re from dataclasses import dataclass from functools import lru_cache +from itertools import combinations from typing import Any import duckdb import pandas as pd from app.data.loader import load_data -from app.schemas import SchemaColumn, SchemaConceptMapping, SchemaManifest, SchemaRelation +from app.schemas import SchemaColumn, SchemaConceptMapping, SchemaManifest, SchemaRelation, SchemaRelationship @dataclass(frozen=True) @@ -116,12 +117,29 @@ def _build_semantic_mappings(columns: list[SchemaColumn]) -> list[SchemaConceptM return mappings[:20] -def _relation_for_frame(name: str, frame: pd.DataFrame, kind: str = "view") -> SchemaRelation: +def _derive_sources_for_column( + column_name: str, + field_origin: str, + source_lookup: dict[str, list[str]], +) -> list[str]: + if field_origin != "derived": + return [] + return [f"{relation}.{column_name}" for relation in source_lookup.get(column_name, [])][:4] + + +def _relation_for_frame( + name: str, + frame: pd.DataFrame, + kind: str = "view", + field_origin: str = "source_backed", + source_lookup: dict[str, list[str]] | None = None, +) -> SchemaRelation: columns: list[SchemaColumn] = [] identifier_columns: list[str] = [] time_columns: list[str] = [] measure_columns: list[str] = [] dimension_columns: list[str] = [] + source_lookup = source_lookup or {} for column_name in frame.columns: dtype = str(frame[column_name].dtype) @@ -130,6 +148,8 @@ def _relation_for_frame(name: str, frame: pd.DataFrame, kind: str = "view") -> S name=column_name, dtype=dtype, type_family=family, + field_origin=field_origin, + derived_from=_derive_sources_for_column(column_name, field_origin, source_lookup), semantic_hints=_semantic_hints(column_name), ) columns.append(column) @@ -157,6 +177,42 @@ def _relation_for_frame(name: str, frame: pd.DataFrame, kind: str = "view") -> S ) +def _relationship_type(left: SchemaRelation, right: SchemaRelation, join_keys: list[str]) -> str: + left_unique = all(key in left.identifier_columns for key in join_keys) + right_unique = all(key in right.identifier_columns for key in join_keys) + if left_unique and right_unique: + return "one_to_one" + if right_unique: + return "many_to_one" + if left_unique: + return "one_to_many" + return "many_to_many" + + +def _build_relationships(relations: list[SchemaRelation]) -> list[SchemaRelationship]: + relationships: list[SchemaRelationship] = [] + for left, right in combinations(relations, 2): + left_columns = {column.name for column in left.columns} + right_columns = {column.name for column in right.columns} + join_keys = sorted( + column + for column in (left_columns & right_columns) + if column.endswith("_id") or column in left.identifier_columns or column in right.identifier_columns + ) + if not join_keys: + continue + relationships.append( + SchemaRelationship( + left_relation=left.name, + right_relation=right.name, + left_on=join_keys, + right_on=join_keys, + relationship_type=_relationship_type(left, right, join_keys), + ) + ) + return relationships + + @lru_cache(maxsize=1) def get_semantic_context() -> SemanticContext: """Build and cache views plus a schema-only manifest for planning.""" @@ -164,13 +220,34 @@ def get_semantic_context() -> SemanticContext: bundle = load_data() raw_views = {name: frame.copy() for name, frame in bundle.raw_views.items()} semantic_views = {"opportunities_enriched": bundle.crm.copy()} - all_frames = {**raw_views, **semantic_views} - relations = [_relation_for_frame(name, frame, kind="view") for name, frame in all_frames.items()] + source_lookup: dict[str, list[str]] = {} + for relation_name, frame in raw_views.items(): + for column_name in frame.columns: + source_lookup.setdefault(column_name, []).append(relation_name) + + relations = [ + *[ + _relation_for_frame(name, frame, kind="table", field_origin="source_backed") + for name, frame in raw_views.items() + ], + *[ + _relation_for_frame( + name, + frame, + kind="view", + field_origin="derived", + source_lookup=source_lookup, + ) + for name, frame in semantic_views.items() + ], + ] + relationships = _build_relationships(relations) schema_manifest = SchemaManifest( reference_date=bundle.reference_date, source=bundle.source, dialect="duckdb", relations=relations, + relationships=relationships, views=[ { "name": relation.name, diff --git a/app/prompts/analysis_render.j2 b/app/prompts/analysis_render.j2 index 9121462..b4801ef 100644 --- a/app/prompts/analysis_render.j2 +++ b/app/prompts/analysis_render.j2 @@ -10,18 +10,22 @@ Hard rules: - Use entity strings exactly as provided. Do not abbreviate, normalize, paraphrase, or rename them. - Do not infer causes, drivers, intent, explanations, or business meaning unless explicitly stated in an approved claim. - Do not derive new metrics, percentages, rankings, or comparisons unless they are explicitly stated in an approved claim. +- The first sentence must answer the question directly when the approved claims establish a clear verdict. - If an approved claim contradicts the question premise, state that in the first sentence clearly and directly. +- If the approved claims support the question, state that in the first sentence before listing supporting comparisons. +- If the approved claims support only part of the question, answer the established part and clearly state what remains unresolved. - Prefer direct numeric comparisons that already appear in approved claims. - Do not use unsupported adjectives such as "stable", "strong", "healthy", "significant", "improving", or "worsening" unless the exact wording appears in an approved claim. - If the approved claims are insufficient to answer the question, say so explicitly. - If the approved claims conflict with each other, say the evidence is internally inconsistent. - Use only the minimum set of approved claims needed to answer the question. +- Do not mention claim ids or internal validation/orchestration details in the analysis text. Output rules: - Return only valid JSON. - The JSON must match this schema exactly: { - "answer_status": "answered" | "insufficient_evidence" | "contradicted_premise" | "conflicting_evidence", + "answer_status": "answered" | "partial_answer" | "insufficient_evidence" | "contradicted_premise" | "conflicting_evidence", "analysis_markdown": string, "used_claim_ids": string[] } diff --git a/app/prompts/planner_compiled.j2 b/app/prompts/planner_compiled.j2 index d57d691..1941f9b 100644 --- a/app/prompts/planner_compiled.j2 +++ b/app/prompts/planner_compiled.j2 @@ -8,10 +8,13 @@ Rules: - CRITICAL — the "max_steps" field: set it to the integer 3 always. It is the platform's fixed ceiling, not the count of steps you return. Do not set max_steps to 1 or 2 even if the plan has only one or two queries. - Every step must use "type": "sql" and put the full SQL statement in "query". - Use only these registered relation names: {{ relation_names_json }} -- Use exact column names only. Never invent, normalize, paraphrase, or rename a column. -- Use semantic mappings only to translate user language into exact schema fields. -- Do not assume the question premise is true. First verify the main metric or comparison before planning causal or grouped breakdowns. -- Prefer SQL over multiple trivial splits; combine logic when one query suffices. +{% include "planner_invariants.j2" %} +- Step 1 must be a `premise_check` step that verifies the primary comparison before any breakdown steps. +- For premise checks, preserve comparable populations and period logic. Use stable comparison shapes such as one row per period when the schema allows it. +- Later steps may investigate grounded breakdowns only if they add distinct explanatory value beyond Step 1. +- If the schema subset does not support a requested entity or dimension, do not guess a substitute. Fall back to a simpler grounded query. +- Maintain consistent aggregation grain between comparable rows and periods. +- When you alias an aggregate metric in SQL, use the canonical form `_` such as `avg_pipeline_velocity_days`. - No imports, file I/O, network calls, or plotting. - Optional "output_alias" per step for stable names; if omitted, the executor uses `step_`. @@ -23,7 +26,17 @@ Return JSON in this exact shape: "id": 1, "purpose": "string", "type": "sql", - "query": "SQL query string" + "query": "SQL query string", + "expectation": { + "step_category": "premise_check | breakdown | follow_up", + "comparison_type": "period_comparison | grouped_breakdown | distribution | single_result", + "expected_grouping_columns": ["column_name"], + "expected_metric_columns": ["column_name"], + "expected_period_column": "column_name or empty string", + "min_expected_rows": 1, + "requires_distinct_periods": false, + "preserve_population_from_step_id": null + } } ], "max_steps": 3, diff --git a/app/prompts/planner_invariants.j2 b/app/prompts/planner_invariants.j2 new file mode 100644 index 0000000..3cc496b --- /dev/null +++ b/app/prompts/planner_invariants.j2 @@ -0,0 +1,7 @@ +- Use only exact relation and column names from the schema subset. +- Never invent, rename, abbreviate, normalize, or paraphrase fields. +- Use semantic mappings only to resolve business language into exact schema fields. +- Do not assume the user premise is true. +- Do not replace an invalid identifier with a merely plausible valid one unless justified by semantic mappings. +- Preserve analytical intent, comparable populations, filters, and grain. +- Prefer fewer, higher-value SQL steps. diff --git a/app/prompts/planner_repair.j2 b/app/prompts/planner_repair.j2 index 8e697d2..548988c 100644 --- a/app/prompts/planner_repair.j2 +++ b/app/prompts/planner_repair.j2 @@ -6,13 +6,29 @@ Rules: - Follow the response schema exactly. - repair_action must be "replace_step". - updated_step must use "type": "sql", the same id as the failed step ({{ failed_step_id }}), and a fixed "query". -- Use only exact relation and column names from the schema manifest. -- Use semantic mappings only to translate user language into exact schema fields. -- Do not invent fields or rename columns. +{% include "planner_invariants.j2" %} +- Preserve the analytical intent of the original user query. +- Preserve the original comparison logic, grouping logic, population/filter logic, output shape, period logic, and minimum row requirements. +- Preserve the required metric columns from the original expectation block. +- Preserve any distinct-period requirement from the original expectation block. +- When the repaired SQL aliases an aggregate metric, use the canonical form `_`. +- Do not replace an invalid identifier with a semantically unrelated valid column. +- If the schema mapping does not justify a replacement, prefer a simpler valid query over a misleading one. +- Keep the original step expectation block unchanged. +- Do not weaken a premise-check query into a trivial result such as selecting only the period column, limiting to one row, or dropping the required aggregate metric. + +Original user query: +{{ query }} Original plan: {{ plan_json }} +Original failed step: +{{ failed_step_json }} + +Original step expectation: +{{ failed_step_expectation_json }} + Failed step id: {{ failed_step_id }} Error message: @@ -28,6 +44,7 @@ Return JSON in this shape: "id": , "purpose": "string", "type": "sql", - "query": "corrected SQL" + "query": "corrected SQL", + "expectation": {{ failed_step_expectation_json }} } } diff --git a/app/schemas.py b/app/schemas.py index 279b89c..ef9c1af 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -50,6 +50,7 @@ class CompiledPlanStep(BaseModel): purpose: str = Field(..., min_length=1) type: Literal["sql"] query: str = Field(..., min_length=1) + expectation: StepExpectation output_alias: str | None = None @@ -80,6 +81,21 @@ class RepairDecision(BaseModel): updated_step: CompiledPlanStep +class StepExpectation(BaseModel): + """Deterministic analytical contract for one compiled plan step.""" + + model_config = ConfigDict(extra="forbid") + + step_category: Literal["premise_check", "breakdown", "follow_up"] = "follow_up" + comparison_type: Literal["period_comparison", "grouped_breakdown", "distribution", "single_result"] = "single_result" + expected_grouping_columns: list[str] = Field(default_factory=list) + expected_metric_columns: list[str] = Field(default_factory=list) + expected_period_column: str = "" + min_expected_rows: int = 1 + requires_distinct_periods: bool = False + preserve_population_from_step_id: int | None = None + + class SchemaConceptMapping(BaseModel): """Heuristic business-language alias mapped to exact schema fields.""" @@ -98,9 +114,24 @@ class SchemaColumn(BaseModel): name: str = Field(..., min_length=1) dtype: str = Field(..., min_length=1) type_family: Literal["string", "number", "boolean", "datetime", "unknown"] = "unknown" + field_origin: Literal["source_backed", "derived"] = "source_backed" + derived_from: list[str] = Field(default_factory=list) semantic_hints: list[str] = Field(default_factory=list) +class SchemaRelationship(BaseModel): + """Optional relationship metadata for joining normalized relations.""" + + model_config = ConfigDict(extra="forbid") + + left_relation: str = Field(..., min_length=1) + right_relation: str = Field(..., min_length=1) + left_on: list[str] = Field(default_factory=list) + right_on: list[str] = Field(default_factory=list) + relationship_type: Literal["one_to_one", "one_to_many", "many_to_one", "many_to_many", "unknown"] = "unknown" + confidence: Literal["heuristic", "explicit"] = "heuristic" + + class SchemaRelation(BaseModel): """One normalized table or view available to the planner.""" @@ -127,6 +158,7 @@ class SchemaManifest(BaseModel): source: str = "" dialect: str = "" relations: list[SchemaRelation] = Field(default_factory=list) + relationships: list[SchemaRelationship] = Field(default_factory=list) views: list[dict[str, Any]] = Field(default_factory=list) @@ -148,6 +180,10 @@ class EvidenceItem(BaseModel): source_alias: str = Field(..., min_length=1) source_purpose: str = Field(..., min_length=1) row_label: str = Field(..., min_length=1) + step_category: Literal["premise_check", "breakdown", "follow_up"] = "follow_up" + comparison_type: Literal["period_comparison", "grouped_breakdown", "distribution", "single_result"] = "single_result" + period_label: str = "" + dimensions: dict[str, str] = Field(default_factory=dict) entities: list[str] = Field(default_factory=list) metrics: list[str] = Field(default_factory=list) values: list[EvidenceValue] = Field(default_factory=list) @@ -185,7 +221,7 @@ class AnalysisRenderResponse(BaseModel): model_config = ConfigDict(extra="forbid") - answer_status: Literal["answered", "insufficient_evidence", "contradicted_premise", "conflicting_evidence"] + answer_status: Literal["answered", "partial_answer", "insufficient_evidence", "contradicted_premise", "conflicting_evidence"] analysis_markdown: str = Field(..., min_length=1) used_claim_ids: list[str] = Field(default_factory=list) @@ -199,7 +235,10 @@ class ExecutedStep(BaseModel): code: str output_alias: str attempt: int - status: Literal["success", "failed"] + status: Literal["success", "invalid", "failed"] + validation_status: Literal["valid", "partial", "invalid"] | None = None + validation_reason: str | None = None + expectation: StepExpectation | None = None artifact: ArtifactSummary | None = None error: str | None = None @@ -207,6 +246,7 @@ class ExecutedStep(BaseModel): class AnalyzeResponse(BaseModel): """Final API response for a completed analysis run.""" + answer_status: Literal["answered", "partial_answer", "insufficient_evidence", "contradicted_premise", "conflicting_evidence"] analysis: str trace: list[TraceEvent] executed_steps: list[ExecutedStep] diff --git a/tests/test_analysis_grounding.py b/tests/test_analysis_grounding.py index 7135bea..59378bd 100644 --- a/tests/test_analysis_grounding.py +++ b/tests/test_analysis_grounding.py @@ -11,13 +11,30 @@ def test_build_approved_claims_marks_contradicted_premise() -> None: state = create_initial_state("Why did pipeline velocity drop this week?") - state["metric"] = "avg_pipeline_velocity_days" - state["compiled_plan"] = {"metric_direction": "lower_is_better"} + state["compiled_plan"] = { + "metric": "avg_pipeline_velocity_days", + "metric_direction": "", + "plan": [ + { + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["avg_pipeline_velocity_days"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + } + } + ], + } state["executed_steps"] = [ { "id": "1", "purpose": "Compare weekly metrics", "status": "success", + "validation_status": "valid", "output_alias": "weekly_pipeline_metrics", "artifact": { "alias": "weekly_pipeline_metrics", @@ -35,7 +52,7 @@ def test_build_approved_claims_marks_contradicted_premise() -> None: assert status == "contradicted_premise" assert claims[0].kind == "premise_check" - assert "does not support a deterioration premise" in claims[0].statement + assert claims[0].statement.startswith("The premise is not supported.") def test_validate_rendered_analysis_rejects_changed_entity_name() -> None: @@ -76,3 +93,139 @@ def test_validate_rendered_analysis_rejects_unapproved_numbers() -> None: with pytest.raises(ValueError, match="introduced numbers not present"): validate_rendered_analysis(response, [claim], expected_status="answered") + + +def test_build_approved_claims_marks_partial_answer_when_caveats_remain() -> None: + state = create_initial_state("How did revenue change this week?") + state["metric"] = "revenue" + state["compiled_plan"] = {"metric_direction": "higher_is_better"} + state["executed_steps"] = [ + { + "id": "1", + "purpose": "Compare weekly revenue", + "status": "success", + "validation_status": "valid", + "output_alias": "weekly_revenue", + "artifact": { + "alias": "weekly_revenue", + "columns": ["period", "revenue"], + "preview_rows": [ + {"period": "Previous Week", "revenue": 100}, + {"period": "Current Week", "revenue": 120}, + ], + }, + } + ] + + evidence = build_analysis_evidence(state) + claims, status = build_approved_claims( + evidence, + unresolved_steps=[ + { + "id": "2", + "purpose": "Break revenue out by owner", + "status": "invalid", + "output_alias": "owner_revenue", + } + ], + ) + + assert status == "partial_answer" + assert any(claim.kind == "comparison" for claim in claims) + assert any(claim.kind == "caveat" for claim in claims) + + +def test_build_analysis_evidence_normalizes_boolean_period_labels() -> None: + state = create_initial_state("How did pipeline velocity differ by manager this week versus last week?") + state["compiled_plan"] = { + "metric": "avg_pipeline_velocity_days", + "metric_direction": "lower_is_better", + "plan": [ + { + "expectation": { + "step_category": "breakdown", + "comparison_type": "grouped_breakdown", + "expected_grouping_columns": ["manager"], + "expected_metric_columns": ["avg_pipeline_velocity_days"], + "expected_period_column": "current_period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": 1, + } + } + ], + } + state["executed_steps"] = [ + { + "id": "2", + "purpose": "Compare manager velocity by week", + "status": "success", + "validation_status": "valid", + "output_alias": "manager_velocity", + "expectation": state["compiled_plan"]["plan"][0]["expectation"], + "artifact": { + "alias": "manager_velocity", + "columns": ["manager", "current_period", "avg_pipeline_velocity_days"], + "preview_rows": [ + {"manager": "Cara Losch", "current_period": False, "avg_pipeline_velocity_days": 71.09}, + {"manager": "Cara Losch", "current_period": True, "avg_pipeline_velocity_days": 63.56}, + ], + }, + } + ] + + evidence = build_analysis_evidence(state) + + labels = [item.row_label for item in evidence.items] + assert "Cara Losch | previous_week" in labels + assert "Cara Losch | current_week" in labels + assert all("True" not in label and "False" not in label for label in labels) + + +def test_validate_rendered_analysis_requires_verdict_first_for_contradiction() -> None: + claim = ApprovedClaim( + id="claim_premise_check", + kind="premise_check", + statement=( + "The premise is not supported. avg_pipeline_velocity_days was 69.94 for previous_week and 64.13 " + "for current_week, and lower is better." + ), + entities=["previous_week", "current_week"], + metrics=["avg_pipeline_velocity_days"], + source_aliases=["weekly_pipeline_metrics"], + values=[ + EvidenceValue(label="previous_week.avg_pipeline_velocity_days", value="69.94"), + EvidenceValue(label="current_week.avg_pipeline_velocity_days", value="64.13"), + ], + ) + response = AnalysisRenderResponse( + answer_status="contradicted_premise", + analysis_markdown=( + "avg_pipeline_velocity_days was 69.94 for previous_week and 64.13 for current_week.\n" + "The premise is not supported." + ), + used_claim_ids=["claim_premise_check"], + ) + + with pytest.raises(ValueError, match="first sentence"): + validate_rendered_analysis(response, [claim], expected_status="contradicted_premise") + + +def test_validate_rendered_analysis_rejects_internal_workflow_leakage() -> None: + claim = ApprovedClaim( + id="claim_manager", + kind="row_observation", + statement="For Celia Rouche, current_week_avg_velocity = 64.65.", + entities=["Celia Rouche"], + metrics=["current_week_avg_velocity"], + source_aliases=["regional_manager_velocity"], + values=[EvidenceValue(label="current_week_avg_velocity", value="64.65")], + ) + response = AnalysisRenderResponse( + answer_status="answered", + analysis_markdown="## Summary\nValidator exception review executed steps and trace for raw outputs.", + used_claim_ids=["claim_manager"], + ) + + with pytest.raises(ValueError, match="internal workflow text"): + validate_rendered_analysis(response, [claim], expected_status="answered") diff --git a/tests/test_api.py b/tests/test_api.py index e8a7368..109379d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -23,6 +23,7 @@ def reset_workspace_state() -> None: def test_analyze_response_accepts_skipped_trace() -> None: """execute_plan_node emits skipped when there is no plan; API must still serialize.""" resp = AnalyzeResponse( + answer_status="insufficient_evidence", analysis="Planner failed; no execution.", trace=[ { @@ -54,6 +55,7 @@ def test_sample_questions_endpoint() -> None: def test_analyze_endpoint_structure() -> None: def fake_run_analysis(query: str) -> dict: # noqa: ARG001 return { + "answer_status": "answered", "analysis": "## Summary\nPipeline velocity improved.\n", "trace": [{"step": "planner_compiled_node", "status": "completed", "details": {"objective": "x"}}], "executed_steps": [ @@ -90,10 +92,11 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001 routes.run_analysis = original assert response.status_code == 200 payload = response.json() - assert {"analysis", "trace", "executed_steps", "errors", "inspection_id"} <= payload.keys() + assert {"answer_status", "analysis", "trace", "executed_steps", "errors", "inspection_id"} <= payload.keys() assert isinstance(payload["trace"], list) assert isinstance(payload["executed_steps"], list) assert isinstance(payload["analysis"], str) + assert payload["answer_status"] == "answered" assert isinstance(payload["inspection_id"], str) @@ -135,6 +138,7 @@ def test_upload_endpoint_profiles_csv() -> None: def test_inspection_endpoint_returns_stored_inspection() -> None: def fake_run_analysis(query: str) -> dict: # noqa: ARG001 return { + "answer_status": "answered", "analysis": "## Summary\nPipeline velocity improved.\n", "trace": [{"step": "planner_compiled_node", "status": "completed", "details": {"objective": "x"}}], "executed_steps": [ diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 0000000..9d9bbe1 --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,116 @@ +"""Graph-node regressions for execution status transitions.""" + +from app.agent.graph import execute_plan_node +from app.agent.state import create_initial_state + + +def test_execute_plan_node_keeps_partial_execution_when_repaired_step_is_still_invalid(monkeypatch) -> None: + state = create_initial_state("Why did revenue drop this week?") + state["workflow_status"] = "ready_to_execute" + state["compiled_plan"] = { + "objective": "Investigate revenue change", + "plan": [ + { + "id": 1, + "purpose": "Compare weekly revenue.", + "type": "sql", + "query": "SELECT 'Previous Week' AS period, 100 AS revenue UNION ALL SELECT 'Current Week' AS period, 120 AS revenue", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["revenue"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "weekly_revenue", + }, + { + "id": 2, + "purpose": "Break revenue out by owner.", + "type": "sql", + "query": "SELECT owner, SUM(revenue) AS revenue FROM opportunities_enriched GROUP BY owner", + "expectation": { + "step_category": "breakdown", + "comparison_type": "grouped_breakdown", + "expected_grouping_columns": ["owner"], + "expected_metric_columns": ["revenue"], + "expected_period_column": "", + "min_expected_rows": 1, + "requires_distinct_periods": False, + "preserve_population_from_step_id": 1, + }, + "output_alias": "owner_revenue", + }, + ], + } + + def fake_execute_plan(local_state, plan): # noqa: ANN001 + local_state["executed_steps"].append( + { + "id": "1", + "purpose": "Compare weekly revenue.", + "status": "success", + "validation_status": "valid", + "output_alias": "weekly_revenue", + "artifact": { + "alias": "weekly_revenue", + "columns": ["period", "revenue"], + "preview_rows": [ + {"period": "Previous Week", "revenue": 100}, + {"period": "Current Week", "revenue": 120}, + ], + }, + } + ) + local_state["executed_steps"].append( + { + "id": "2", + "purpose": "Break revenue out by owner.", + "status": "invalid", + "validation_status": "invalid", + "validation_reason": "The result is missing expected columns: owner.", + "output_alias": "owner_revenue", + "artifact": None, + } + ) + return {"status": "invalid", "failed_step_id": "2", "error": "The result is missing expected columns: owner."} + + def fake_repair(local_state, failed_step_id, error_message): # noqa: ANN001 + assert failed_step_id == "2" + assert "owner" in error_message + return local_state + + def fake_execute_single(local_state, step_row, attempt): # noqa: ANN001 + assert attempt == 2 + assert str(step_row["id"]) == "2" + local_state["executed_steps"].append( + { + "id": "2", + "purpose": "Break revenue out by owner.", + "status": "invalid", + "attempt": 2, + "validation_status": "invalid", + "validation_reason": "The workflow could not validate a reliable grouped result for Break revenue out by owner.", + "output_alias": "owner_revenue", + "artifact": None, + } + ) + return { + "status": "invalid", + "failed_step_id": "2", + "error": "The workflow could not validate a reliable grouped result for Break revenue out by owner.", + } + + monkeypatch.setattr("app.agent.graph.execute_plan", fake_execute_plan) + monkeypatch.setattr("app.agent.graph.repair_failed_step", fake_repair) + monkeypatch.setattr("app.agent.graph.execute_single_plan_step", fake_execute_single) + + state = execute_plan_node(state) + + assert state["workflow_status"] == "partial_execution" + assert state["unresolved_step_ids"] == ["2"] + assert any(step["id"] == "1" and step["status"] == "success" for step in state["executed_steps"]) + assert state["trace"][-1]["status"] == "failed" diff --git a/tests/test_intent.py b/tests/test_intent.py index db126b0..a785993 100644 --- a/tests/test_intent.py +++ b/tests/test_intent.py @@ -17,7 +17,17 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 "id": 1, "purpose": "Compare current and previous pipeline velocity.", "type": "sql", - "query": "SELECT 1 AS value", + "query": "SELECT 'Previous Week' AS period, 2 AS pipeline_velocity UNION ALL SELECT 'Current Week' AS period, 1 AS pipeline_velocity", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["pipeline_velocity"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, "output_alias": "comparison_result", } ], @@ -38,15 +48,59 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 } +class LeakyAnalysisLLM: + """Return deliberately invalid rendered analysis to exercise fallback handling.""" + + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + if '"analysis_markdown": string' in prompt and "approved claims" in prompt.lower(): + return { + "answer_status": "answered", + "analysis_markdown": "Validator exception. Review executed steps and trace for raw outputs.", + "used_claim_ids": ["claim_premise_check"], + } + return { + "answer_status": "insufficient_evidence", + "analysis_markdown": "The approved claims are insufficient to answer the question.", + "used_claim_ids": [], + } + + def test_planner_returns_compiled_plan(monkeypatch) -> None: monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: FakeLLM()) state = create_initial_state("Why did pipeline velocity drop this week?") - state["dataset_context"] = {"reference_date": "2017-12-31", "views": [{"name": "opportunities_enriched"}]} + state["dataset_context"] = { + "reference_date": "2017-12-31", + "relations": [ + { + "name": "opportunities_enriched", + "columns": [ + {"name": "period", "dtype": "object", "type_family": "string", "field_origin": "derived", "derived_from": [], "semantic_hints": ["period"]}, + { + "name": "pipeline_velocity", + "dtype": "int64", + "type_family": "number", + "field_origin": "derived", + "derived_from": [], + "semantic_hints": ["pipeline velocity"], + }, + ], + "identifier_columns": [], + "time_columns": [], + "measure_columns": ["pipeline_velocity"], + "dimension_columns": ["period"], + "semantic_mappings": [], + "grain": "Rows can be keyed by period", + } + ], + "relationships": [], + "views": [{"name": "opportunities_enriched"}], + } state = plan_compiled_query(state) assert state["compiled_plan"] is not None assert state["compiled_plan"]["plan"][0]["type"] == "sql" - assert state["compiled_plan"]["plan"][0]["query"] == "SELECT 1 AS value" + assert "SELECT 'Previous Week'" in state["compiled_plan"]["plan"][0]["query"] assert state["compiled_plan"]["plan"][0]["output_alias"] == "comparison_result" + assert state["compiled_plan"]["plan"][0]["expectation"]["step_category"] == "premise_check" def test_analysis_narrative_uses_llm(monkeypatch) -> None: @@ -59,6 +113,7 @@ def test_analysis_narrative_uses_llm(monkeypatch) -> None: "id": "step_1", "purpose": "Compare", "status": "success", + "validation_status": "valid", "output_alias": "comparison_result", "artifact": { "alias": "comparison_result", @@ -72,3 +127,81 @@ def test_analysis_narrative_uses_llm(monkeypatch) -> None: state = run_analysis_narrative(state) assert "value = 1" in state["analysis"] assert "SMB" in state["analysis"] + + +def test_analysis_narrative_preserves_partial_evidence_without_leaking_internal_text(monkeypatch) -> None: + monkeypatch.setattr("app.agent.analysis.get_llm_client", lambda: LeakyAnalysisLLM()) + state = create_initial_state("Why did pipeline velocity drop this week?") + state["metric"] = "avg_pipeline_velocity_days" + state["compiled_plan"] = {"objective": "Test", "metric": "avg_pipeline_velocity_days", "metric_direction": "lower_is_better"} + state["executed_steps"] = [ + { + "id": "step_1", + "purpose": "Compare weekly metrics", + "status": "success", + "validation_status": "valid", + "output_alias": "weekly_pipeline_metrics", + "artifact": { + "alias": "weekly_pipeline_metrics", + "row_count": 2, + "columns": ["period", "avg_pipeline_velocity_days"], + "preview_rows": [ + {"period": "Previous Week", "avg_pipeline_velocity_days": 69.94}, + {"period": "Current Week", "avg_pipeline_velocity_days": 64.13}, + ], + "summary": {}, + }, + }, + { + "id": "step_2", + "purpose": "Break velocity out by owner", + "status": "failed", + "output_alias": "owner_velocity", + "error": "Binder error", + }, + ] + + state = run_analysis_narrative(state) + + assert state["answer_status"] == "contradicted_premise" + assert state["analysis"].startswith("The premise is not supported.") + assert "validator" not in state["analysis"].lower() + assert "trace for raw outputs" not in state["analysis"].lower() + + +def test_analysis_narrative_uses_clean_specific_fallback_for_failed_repaired_premise_check() -> None: + state = create_initial_state("Why did pipeline velocity drop this week?") + state["compiled_plan"] = { + "objective": "Compare weekly velocity", + "metric": "pipeline_velocity_days", + "metric_direction": "lower_is_better", + } + state["executed_steps"] = [ + { + "id": "step_1", + "purpose": "Compare weekly pipeline velocity.", + "status": "invalid", + "attempt": 2, + "validation_status": "invalid", + "validation_reason": "The result is missing expected columns: avg_pipeline_velocity_days.", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["avg_pipeline_velocity_days"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "velocity_comparison", + "error": "The result is missing expected columns: avg_pipeline_velocity_days.", + } + ] + + state = run_analysis_narrative(state) + + assert state["answer_status"] == "insufficient_evidence" + assert "repaired premise-check query did not return the required metric avg_pipeline_velocity_days" in state["analysis"] + assert "missing expected columns" not in state["analysis"].lower() + assert "validator" not in state["analysis"].lower() diff --git a/tests/test_planner_schema.py b/tests/test_planner_schema.py index a82500e..48278f1 100644 --- a/tests/test_planner_schema.py +++ b/tests/test_planner_schema.py @@ -1,6 +1,6 @@ """Compiled plan schema and planner retry behavior.""" -from app.agent.planner import plan_compiled_query +from app.agent.planner import _build_repair_prompt, _schema_subset_for_question, plan_compiled_query, repair_failed_step from app.agent.state import create_initial_state from app.data.semantic_model import get_semantic_context from app.schemas import CompiledPlan @@ -16,6 +16,16 @@ def test_compiled_plan_normalizes_max_steps() -> None: "purpose": "One step", "type": "sql", "query": "SELECT 1", + "expectation": { + "step_category": "follow_up", + "comparison_type": "single_result", + "expected_grouping_columns": [], + "expected_metric_columns": ["value"], + "expected_period_column": "", + "min_expected_rows": 1, + "requires_distinct_periods": False, + "preserve_population_from_step_id": None, + }, } ], "max_steps": 1, @@ -28,19 +38,29 @@ def test_compiled_plan_normalizes_max_steps() -> None: def test_planner_retries_after_validation_error(monkeypatch) -> None: good = { - "objective": "Segment counts", + "objective": "Revenue comparison", "plan": [ { "id": 1, - "purpose": "Count by segment", + "purpose": "Compare current and previous revenue.", "type": "sql", - "query": "SELECT 1 AS value", - "output_alias": "counts", + "query": "SELECT 'Previous Week' AS period, 10 AS revenue UNION ALL SELECT 'Current Week' AS period, 12 AS revenue", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["revenue"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "weekly_revenue", } ], "max_steps": 3, - "metric": "", - "metric_direction": "", + "metric": "revenue", + "metric_direction": "higher_is_better", } bad = { "objective": "Too many", @@ -63,8 +83,8 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 stub = FlakyPlannerLLM() monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: stub) - state = create_initial_state("Compare segments") - state["dataset_context"] = {"reference_date": "2017-12-31", "views": [{"name": "opportunities_enriched"}]} + state = create_initial_state("Why did revenue drop this week?") + state["dataset_context"] = get_semantic_context().schema_manifest state = plan_compiled_query(state) assert stub.calls == 2 @@ -77,10 +97,12 @@ def test_semantic_context_exposes_normalized_manifest() -> None: assert manifest["dialect"] == "duckdb" assert manifest["relations"] + assert manifest["relationships"] is not None relation = next(relation for relation in manifest["relations"] if relation["name"] == "opportunities_enriched") assert relation["columns"] assert relation["identifier_columns"] assert relation["grain"] + assert relation["columns"][0]["field_origin"] in {"source_backed", "derived"} def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None: @@ -91,7 +113,17 @@ def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None: "id": 1, "purpose": "Break out velocity by agent", "type": "sql", - "query": "SELECT sales_agent, AVG(pipeline_velocity_days) AS avg_velocity_days FROM opportunities_enriched GROUP BY sales_agent", + "query": "SELECT sales_agent, AVG(pipeline_velocity_days) AS avg_pipeline_velocity_days FROM opportunities_enriched GROUP BY sales_agent", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["avg_pipeline_velocity_days"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, "output_alias": "velocity_by_agent", } ], @@ -106,7 +138,17 @@ def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None: "id": 1, "purpose": "Break out velocity by owner", "type": "sql", - "query": "SELECT owner, AVG(pipeline_velocity_days) AS avg_velocity_days FROM opportunities_enriched GROUP BY owner", + "query": "SELECT current_period AS period, AVG(pipeline_velocity_days) AS avg_pipeline_velocity_days FROM opportunities_enriched GROUP BY current_period", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["avg_pipeline_velocity_days"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, "output_alias": "velocity_by_owner", } ], @@ -137,4 +179,166 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 assert stub.calls == 2 assert "failed SQL preflight validation" in stub.prompts[1] assert state["compiled_plan"] is not None - assert "owner" in state["compiled_plan"]["plan"][0]["query"] + assert "current_period AS period" in state["compiled_plan"]["plan"][0]["query"] + + +def test_schema_subset_preserves_structural_fields() -> None: + manifest = get_semantic_context().schema_manifest + subset = _schema_subset_for_question(manifest, "Why did pipeline velocity drop this week by sales rep and region?") + + relation = next(relation for relation in subset["relations"] if relation["name"] == "opportunities_enriched") + selected_names = {column["name"] for column in relation["columns"]} + + assert set(relation["identifier_columns"]).issubset(selected_names) + assert set(relation["time_columns"]).issubset(selected_names) + assert set(relation["measure_columns"]).issubset(selected_names) + + +def test_repair_prompt_includes_query_and_expectation_context() -> None: + state = create_initial_state("Why did revenue drop this week?") + state["dataset_context"] = get_semantic_context().schema_manifest + state["compiled_plan"] = { + "objective": "Compare revenue", + "plan": [ + { + "id": 1, + "purpose": "Compare revenue by period.", + "type": "sql", + "query": "SELECT sales_agent, SUM(deal_value) AS revenue FROM opportunities_enriched GROUP BY sales_agent", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["revenue"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "revenue_by_period", + } + ], + } + + prompt = _build_repair_prompt(state, "1", 'Binder Error: Referenced column "sales_agent" not found') + + assert "Why did revenue drop this week?" in prompt + assert '"expected_period_column": "period"' in prompt + assert 'Referenced column "sales_agent" not found' in prompt + + +def test_repair_rejects_expectation_drift(monkeypatch) -> None: + class RepairLLM: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + return { + "repair_action": "replace_step", + "updated_step": { + "id": 1, + "purpose": "Compare revenue by period.", + "type": "sql", + "query": "SELECT current_period AS period, SUM(deal_value) AS revenue FROM opportunities_enriched GROUP BY current_period", + "expectation": { + "step_category": "breakdown", + "comparison_type": "grouped_breakdown", + "expected_grouping_columns": ["segment"], + "expected_metric_columns": ["revenue"], + "expected_period_column": "", + "min_expected_rows": 1, + "requires_distinct_periods": False, + "preserve_population_from_step_id": None, + }, + "output_alias": "revenue_by_period", + }, + } + + monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: RepairLLM()) + state = create_initial_state("Why did revenue drop this week?") + state["dataset_context"] = get_semantic_context().schema_manifest + state["compiled_plan"] = { + "objective": "Compare revenue", + "plan": [ + { + "id": 1, + "purpose": "Compare revenue by period.", + "type": "sql", + "query": "SELECT sales_agent, SUM(deal_value) AS revenue FROM opportunities_enriched GROUP BY sales_agent", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["revenue"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "revenue_by_period", + } + ], + } + + try: + repair_failed_step(state, "1", 'Binder Error: Referenced column "sales_agent" not found') + except ValueError as exc: + assert "changed the original step expectation" in str(exc) + else: # pragma: no cover - defensive + raise AssertionError("Repair should have rejected expectation drift.") + + +def test_repair_rejects_premise_check_that_drops_required_metric(monkeypatch) -> None: + class RepairLLM: + def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 + return { + "repair_action": "replace_step", + "updated_step": { + "id": 1, + "purpose": "Compare velocity by period.", + "type": "sql", + "query": "SELECT current_period AS period FROM opportunities_enriched LIMIT 1", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["avg_pipeline_velocity"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "velocity_by_period", + }, + } + + monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: RepairLLM()) + state = create_initial_state("Why did pipeline velocity drop this week?") + state["dataset_context"] = get_semantic_context().schema_manifest + state["compiled_plan"] = { + "objective": "Compare velocity", + "plan": [ + { + "id": 1, + "purpose": "Compare velocity by period.", + "type": "sql", + "query": "SELECT current_period AS period, AVG(pipeline_velocity_days) AS avg_pipeline_velocity_days FROM opportunities_enriched GROUP BY current_period", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["avg_pipeline_velocity_days"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "velocity_by_period", + } + ], + } + + try: + repair_failed_step(state, "1", "Binder error") + except ValueError as exc: + assert "did not preserve the original analytical intent" in str(exc) + assert "missing expected columns" in str(exc) + else: # pragma: no cover - defensive + raise AssertionError("Repair should have been rejected when it dropped the required metric.") diff --git a/tests/test_tools.py b/tests/test_tools.py index d2fa90f..fe4b4fb 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -17,6 +17,16 @@ def test_execute_sql_step_returns_artifact_summary() -> None: "purpose": "Get one sample aggregation.", "type": "sql", "query": "SELECT segment, COUNT(*) AS deals FROM opportunities_enriched GROUP BY segment ORDER BY deals DESC", + "expectation": { + "step_category": "breakdown", + "comparison_type": "grouped_breakdown", + "expected_grouping_columns": ["segment"], + "expected_metric_columns": ["deals"], + "expected_period_column": "", + "min_expected_rows": 1, + "requires_distinct_periods": False, + "preserve_population_from_step_id": None, + }, "output_alias": "segment_counts", } ], @@ -40,14 +50,24 @@ def test_execute_plan_marks_empty_table_as_failed() -> None: "id": 1, "purpose": "Return no rows.", "type": "sql", - "query": "SELECT 1 WHERE 1=0", + "query": "SELECT 1 AS value WHERE 1=0", + "expectation": { + "step_category": "follow_up", + "comparison_type": "single_result", + "expected_grouping_columns": [], + "expected_metric_columns": ["value"], + "expected_period_column": "", + "min_expected_rows": 1, + "requires_distinct_periods": False, + "preserve_population_from_step_id": None, + }, "output_alias": "empty_result", } ], } outcome = execute_plan(state, compiled_plan) - assert outcome["status"] == "failed" - assert state["executed_steps"][-1]["status"] == "failed" + assert outcome["status"] == "invalid" + assert state["executed_steps"][-1]["status"] == "invalid" def test_execute_single_plan_step_retry() -> None: @@ -58,8 +78,184 @@ def test_execute_single_plan_step_retry() -> None: "purpose": "Get one row.", "type": "sql", "query": "SELECT 1 AS value", + "expectation": { + "step_category": "follow_up", + "comparison_type": "single_result", + "expected_grouping_columns": [], + "expected_metric_columns": ["value"], + "expected_period_column": "", + "min_expected_rows": 1, + "requires_distinct_periods": False, + "preserve_population_from_step_id": None, + }, "output_alias": "r1", } out = execute_single_plan_step(state, step, attempt=2) assert out["status"] == "success" assert state["executed_steps"][-1]["attempt"] == 2 + + +def test_execute_plan_rejects_one_period_comparison() -> None: + state = create_initial_state("Did revenue change this week?") + state["dataset_context"] = get_semantic_context().schema_manifest + compiled_plan = { + "objective": "Compare one period", + "max_steps": 3, + "plan": [ + { + "id": 1, + "purpose": "Compare current and previous revenue.", + "type": "sql", + "query": "SELECT 'Current Week' AS period, 120 AS revenue", + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["revenue"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "weekly_revenue", + } + ], + } + + outcome = execute_plan(state, compiled_plan) + + assert outcome["status"] == "invalid" + assert state["executed_steps"][-1]["validation_reason"] == "The comparison result did not return at least two comparable periods." + + +def test_execute_single_step_rejects_wrong_group_shape() -> None: + state = create_initial_state("Break revenue out by segment") + state["dataset_context"] = get_semantic_context().schema_manifest + step = { + "id": 2, + "purpose": "Break revenue out by segment.", + "type": "sql", + "query": "SELECT stage, COUNT(*) AS deals FROM opportunities_enriched GROUP BY stage", + "expectation": { + "step_category": "breakdown", + "comparison_type": "grouped_breakdown", + "expected_grouping_columns": ["segment"], + "expected_metric_columns": ["deals"], + "expected_period_column": "", + "min_expected_rows": 1, + "requires_distinct_periods": False, + "preserve_population_from_step_id": 1, + }, + "output_alias": "segment_breakdown", + } + + outcome = execute_single_plan_step(state, step, attempt=2) + + assert outcome["status"] == "invalid" + assert "missing expected columns: segment" in state["executed_steps"][-1]["validation_reason"] + + +def test_execute_single_step_marks_grouped_period_comparison_partial_when_only_some_groups_match() -> None: + state = create_initial_state("How did revenue differ by manager this week versus last week?") + state["dataset_context"] = get_semantic_context().schema_manifest + step = { + "id": 3, + "purpose": "Compare revenue by manager across periods.", + "type": "sql", + "query": ( + "SELECT * FROM (" + "SELECT 'Cara Losch' AS manager, 'previous_week' AS period, 100 AS revenue " + "UNION ALL SELECT 'Cara Losch' AS manager, 'current_week' AS period, 120 AS revenue " + "UNION ALL SELECT 'Rocco Neubert' AS manager, 'current_week' AS period, 80 AS revenue" + ")" + ), + "expectation": { + "step_category": "breakdown", + "comparison_type": "grouped_breakdown", + "expected_grouping_columns": ["manager"], + "expected_metric_columns": ["revenue"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": 1, + }, + "output_alias": "manager_revenue", + } + + outcome = execute_single_plan_step(state, step, attempt=1) + + assert outcome["status"] == "success" + assert state["executed_steps"][-1]["validation_status"] == "partial" + assert "Only 1 of 2 groups returned comparable periods." == state["executed_steps"][-1]["validation_reason"] + + +def test_execute_single_step_accepts_canonical_metric_alias_for_pipeline_velocity_average() -> None: + state = create_initial_state("Why did pipeline velocity drop this week?") + state["dataset_context"] = get_semantic_context().schema_manifest + step = { + "id": 4, + "purpose": "Compare average pipeline velocity by period.", + "type": "sql", + "query": ( + "SELECT " + "CASE WHEN current_period = TRUE THEN 'current_week' " + "WHEN previous_period = TRUE THEN 'previous_week' END AS period, " + "AVG(pipeline_velocity_days) AS avg_pipeline_velocity " + "FROM opportunities_enriched " + "WHERE current_period = TRUE OR previous_period = TRUE " + "GROUP BY 1 ORDER BY 1" + ), + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": [], + "expected_metric_columns": ["avg_pipeline_velocity"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "velocity_comparison", + } + + outcome = execute_single_plan_step(state, step, attempt=1) + + assert outcome["status"] == "success" + assert state["executed_steps"][-1]["artifact"]["columns"] == ["period", "avg_pipeline_velocity_days"] + assert state["executed_steps"][-1]["validation_status"] == "valid" + + +def test_execute_single_step_accepts_period_comparison_when_period_is_listed_as_grouping_column() -> None: + state = create_initial_state("Why did pipeline velocity drop this week?") + state["dataset_context"] = get_semantic_context().schema_manifest + step = { + "id": 5, + "purpose": "Premise check: Compare average pipeline_velocity_days for current week vs previous week to confirm pipeline velocity drop", + "type": "sql", + "query": ( + "SELECT " + "CASE WHEN current_period = TRUE THEN 'current_week' " + "WHEN previous_period = TRUE THEN 'previous_week' END AS period, " + "AVG(pipeline_velocity_days) AS avg_pipeline_velocity_days " + "FROM opportunities_enriched " + "WHERE current_period = TRUE OR previous_period = TRUE " + "GROUP BY 1 ORDER BY 1" + ), + "expectation": { + "step_category": "premise_check", + "comparison_type": "period_comparison", + "expected_grouping_columns": ["period"], + "expected_metric_columns": ["avg_pipeline_velocity_days"], + "expected_period_column": "period", + "min_expected_rows": 2, + "requires_distinct_periods": True, + "preserve_population_from_step_id": None, + }, + "output_alias": "pipeline_velocity_comparison", + } + + outcome = execute_single_plan_step(state, step, attempt=2) + + assert outcome["status"] == "success" + assert state["executed_steps"][-1]["validation_status"] == "valid" + assert state["executed_steps"][-1]["artifact"]["columns"] == ["period", "avg_pipeline_velocity_days"] diff --git a/ui/src/api/mappers.ts b/ui/src/api/mappers.ts index b81bf9d..a5c3abf 100644 --- a/ui/src/api/mappers.ts +++ b/ui/src/api/mappers.ts @@ -91,6 +91,7 @@ function buildAssistantPayload(response: AnalyzeApiResponse, inspection: Inspect const metrics = buildMetrics(successCount, totalSteps, inspection.rowsReturned, inspection); const detailPool = [ + `Answer status: ${response.answer_status.replace(/_/g, " ")}.`, ...parsed.details, primaryArtifact ? `Preview ready for ${primaryArtifact.alias} with ${inspection.rowsReturned} row${inspection.rowsReturned === 1 ? "" : "s"} available in the inspection drawer.` @@ -258,6 +259,7 @@ function buildExecutionChips(response: AnalyzeApiResponse, primaryArtifact: Anal executedSteps ? `${executedSteps} workflow step${executedSteps === 1 ? "" : "s"}` : "No executed steps", primaryArtifact ? `Output: ${primaryArtifact.alias}` : "No output alias", retryCount ? `${retryCount} retry attempt${retryCount === 1 ? "" : "s"}` : "No retries", + `Answer: ${response.answer_status.replace(/_/g, " ")}`, response.errors.length ? `${response.errors.length} issue${response.errors.length === 1 ? "" : "s"}` : "No recorded errors", ].filter(Boolean), ).slice(0, 4); @@ -285,6 +287,17 @@ function buildValidation( : "No successful execution steps were returned for this prompt.", status: fatalErrors ? "fail" : successCount > 0 ? "pass" : "warn", }, + { + id: "answer_status", + label: "Answer status", + detail: `The backend classified this run as ${response.answer_status.replace(/_/g, " ")}.`, + status: + response.answer_status === "answered" || response.answer_status === "contradicted_premise" + ? "pass" + : response.answer_status === "partial_answer" + ? "warn" + : "fail", + }, { id: "step_coverage", label: "Step coverage", @@ -335,10 +348,14 @@ function buildMetadata( label: "Execution status", value: response.errors.some((item) => !item.recoverable) ? "Failed" - : response.errors.length || response.executed_steps.some((step) => step.attempt > 1) + : response.errors.length || response.executed_steps.some((step) => step.attempt > 1) || response.answer_status !== "answered" ? "Completed with review notes" : "Complete", }, + { + label: "Answer status", + value: response.answer_status.replace(/_/g, " "), + }, { label: "Verification", value: verified ? "Verified" : "Needs analyst review", @@ -473,10 +490,14 @@ function deriveInspectionStatus(response: AnalyzeApiResponse, steps: AnalyzeExec return "error"; } + if (response.answer_status === "partial_answer" || response.answer_status === "insufficient_evidence" || response.answer_status === "conflicting_evidence") { + return "warning"; + } + if ( response.errors.length > 0 || response.trace.some((event) => event.status === "failed" || event.status === "skipped") || - steps.some((step) => step.status === "failed" || step.attempt > 1) + steps.some((step) => step.status === "failed" || step.status === "invalid" || step.attempt > 1) ) { return "warning"; } @@ -491,7 +512,15 @@ function deriveConfidence(response: AnalyzeApiResponse, primaryArtifact: Analyze const previewBonus = primaryArtifact?.row_count ? 0.16 : 0; const traceBonus = response.trace.some((event) => event.status === "completed") ? 0.07 : 0; const errorPenalty = response.errors.some((item) => !item.recoverable) ? 0.18 : response.errors.length ? 0.08 : 0; - const score = 0.46 + successRatio * 0.24 + previewBonus + traceBonus - errorPenalty; + const answerPenalty = + response.answer_status === "conflicting_evidence" + ? 0.15 + : response.answer_status === "insufficient_evidence" + ? 0.12 + : response.answer_status === "partial_answer" + ? 0.06 + : 0; + const score = 0.46 + successRatio * 0.24 + previewBonus + traceBonus - errorPenalty - answerPenalty; return clamp(score, 0.35, 0.95); } diff --git a/ui/src/api/types.ts b/ui/src/api/types.ts index 8153d72..561ca3d 100644 --- a/ui/src/api/types.ts +++ b/ui/src/api/types.ts @@ -64,12 +64,15 @@ export interface AnalyzeExecutedStep { code: string; output_alias: string; attempt: number; - status: "success" | "failed"; + status: "success" | "invalid" | "failed"; + validation_status?: "valid" | "partial" | "invalid" | null; + validation_reason?: string | null; artifact?: AnalyzeArtifactSummary | null; error?: string | null; } export interface AnalyzeApiResponse { + answer_status: "answered" | "partial_answer" | "insufficient_evidence" | "contradicted_premise" | "conflicting_evidence"; analysis: string; trace: AnalyzeTraceEvent[]; executed_steps: AnalyzeExecutedStep[];