diff --git a/app/agent/executor.py b/app/agent/executor.py index b9915e5..d8e58af 100644 --- a/app/agent/executor.py +++ b/app/agent/executor.py @@ -61,15 +61,18 @@ def _register_artifacts(conn: duckdb.DuckDBPyConnection, state: AnalysisState) - def _execute_sql(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary: - conn = new_duckdb_connection() - _register_artifacts(conn, state) - frame = conn.execute(step["code"]).fetchdf() - state["artifacts"][step["output_alias"]] = frame - return _summarize_artifact(step["output_alias"], frame) + conn = new_duckdb_connection(state.get("dataset_context")) + try: + _register_artifacts(conn, state) + frame = conn.execute(step["code"]).fetchdf() + state["artifacts"][step["output_alias"]] = frame + return _summarize_artifact(step["output_alias"], frame) + finally: + conn.close() def _execute_pandas(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary: - context = get_semantic_context() + context = get_semantic_context(state.get("source_ids")) local_env: dict[str, Any] = { **context.raw_views, **context.semantic_views, @@ -112,26 +115,29 @@ def preflight_compiled_plan(state: AnalysisState, compiled_plan: dict[str, Any]) Steps are checked in order so later queries can reference earlier output aliases. 
""" - conn = new_duckdb_connection() - _register_artifacts(conn, state) - rows = list(compiled_plan.get("plan") or []) - rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id) - - for row in rows: - internal = compiled_plan_row_to_internal(row) - sql = internal["code"].strip().rstrip(";") - try: - preview = conn.execute(f"SELECT * FROM ({sql}) AS __planera_preflight LIMIT 0").fetchdf() - conn.register(internal["output_alias"], preview) - except Exception as exc: - return { - "status": "failed", - "failed_step_id": internal["id"], - "error": str(exc), - "query": internal["code"], - } - - return {"status": "success"} + conn = new_duckdb_connection(state.get("dataset_context")) + try: + _register_artifacts(conn, state) + rows = list(compiled_plan.get("plan") or []) + rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id) + + for row in rows: + internal = compiled_plan_row_to_internal(row) + sql = internal["code"].strip().rstrip(";") + try: + preview = conn.execute(f"SELECT * FROM ({sql}) AS __planera_preflight LIMIT 0").fetchdf() + conn.register(internal["output_alias"], preview) + except Exception as exc: + return { + "status": "failed", + "failed_step_id": internal["id"], + "error": str(exc), + "query": internal["code"], + } + + return {"status": "success"} + finally: + conn.close() def _try_sql_step( diff --git a/app/agent/graph.py b/app/agent/graph.py index 0bc9ef3..076944b 100644 --- a/app/agent/graph.py +++ b/app/agent/graph.py @@ -28,7 +28,7 @@ def _append_error(state: AnalysisState, step: str, message: str, recoverable: bo def load_schema_context_node(state: AnalysisState) -> AnalysisState: step_name = "load_schema_context_node" _append_trace(state, step_name, "started", {}) - context = get_semantic_context() + context = get_semantic_context(state.get("source_ids")) state["dataset_context"] = context.schema_manifest _append_trace( state, @@ -162,8 +162,8 @@ def build_graph(): return graph.compile() -def run_analysis(query: str) -> 
AnalysisState: +def run_analysis(query: str, source_ids: list[str] | None = None) -> AnalysisState: """Execute the full workflow for a single user query.""" workflow = build_graph() - return workflow.invoke(create_initial_state(query)) + return workflow.invoke(create_initial_state(query, source_ids=source_ids)) diff --git a/app/agent/planner.py b/app/agent/planner.py index 0c9892a..5fc8bdb 100644 --- a/app/agent/planner.py +++ b/app/agent/planner.py @@ -38,7 +38,11 @@ def _field_terms(*values: str) -> set[str]: def _column_relevance_score(column: dict[str, Any], question_terms: set[str]) -> int: - column_terms = _field_terms(column.get("name", "")) + column_terms = _field_terms( + column.get("name", ""), + column.get("original_name", ""), + column.get("source_path", ""), + ) for hint in column.get("semantic_hints") or []: column_terms.update(_field_terms(hint)) @@ -50,10 +54,20 @@ def _column_relevance_score(column: dict[str, Any], question_terms: set[str]) -> def _relation_relevance_score(relation: dict[str, Any], question_terms: set[str]) -> int: - score = len(_field_terms(relation.get("name", ""), relation.get("grain", "")) & question_terms) * 4 + score = len( + _field_terms( + relation.get("name", ""), + relation.get("grain", ""), + relation.get("source_name", ""), + json.dumps(relation.get("lineage", {}), default=str), + ) + & question_terms + ) * 4 score += sum(_column_relevance_score(column, question_terms) for column in relation.get("columns") or []) for mapping in relation.get("semantic_mappings") or []: score += len(_field_terms(mapping.get("concept", "")) & question_terms) * 5 + if relation.get("is_primary"): + score += 3 return score diff --git a/app/agent/state.py b/app/agent/state.py index 18a1a7d..c940ed7 100644 --- a/app/agent/state.py +++ b/app/agent/state.py @@ -11,6 +11,7 @@ class AnalysisState(TypedDict): """Explicit state carried through the analytics workflow.""" query: str + source_ids: list[str] dataset_context: dict[str, Any] intent: str 
metric: str @@ -27,11 +28,12 @@ class AnalysisState(TypedDict): errors: list[dict[str, Any]] -def create_initial_state(query: str) -> AnalysisState: +def create_initial_state(query: str, source_ids: list[str] | None = None) -> AnalysisState: """Return the initial workflow state for a new request.""" return AnalysisState( query=query, + source_ids=list(source_ids or []), dataset_context={}, intent="", metric="", diff --git a/app/api/chat_routes.py b/app/api/chat_routes.py index 0c5f72c..5dfa109 100644 --- a/app/api/chat_routes.py +++ b/app/api/chat_routes.py @@ -30,6 +30,7 @@ ) from app.services.analysis_run import run_stored_analysis from app.services.inspection_persistence import save_inspection_for_assistant_message +from app.uploads.service import get_authorized_source_ids from app.utils.logging import get_logger @@ -146,8 +147,19 @@ def chat_turn( db.add(user_msg) db.flush() + requested_source_ids = list(dict.fromkeys(body.source_ids or [])) + if requested_source_ids: + valid_source_ids = get_authorized_source_ids(db, current_user, requested_source_ids) + if len(valid_source_ids) != len(requested_source_ids): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"message": "Attach a valid uploaded data source before running analysis."}, + ) + else: + valid_source_ids = None + try: - analysis_run = run_stored_analysis(body.query) + analysis_run = run_stored_analysis(body.query, source_ids=valid_source_ids) analysis_result = analysis_run.response inspection_payload = analysis_run.inspection except Exception as exc: # pragma: no cover - defensive API fallback diff --git a/app/api/routes.py b/app/api/routes.py index da1cd92..a3a450c 100644 --- a/app/api/routes.py +++ b/app/api/routes.py @@ -2,18 +2,19 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, Response, UploadFile, status from sqlalchemy.orm import Session 
-from app.api.workspace import get_inspection, profile_upload -from app.auth.deps import get_current_user_optional +from app.api.workspace import get_inspection +from app.auth.deps import get_current_user, get_current_user_optional from app.config import get_settings from app.db.session import get_db from app.models.conversation import Conversation from app.models.inspection_snapshot import InspectionSnapshot from app.models.user import User -from app.schemas import AnalyzeRequest, AnalyzeResponse, HealthResponse, InspectionData, InspectionResponse, SampleQuestionsResponse, UploadResponse +from app.schemas import AnalyzeRequest, AnalyzeResponse, HealthResponse, InspectionData, InspectionResponse, SampleQuestionsResponse, UploadedAsset, UploadResponse from app.services.analysis_run import run_stored_analysis +from app.uploads.service import create_user_upload, delete_user_upload, get_authorized_source_ids, list_user_uploads from app.utils.constants import SAMPLE_QUESTIONS from app.utils.logging import get_logger @@ -37,18 +38,52 @@ def sample_questions() -> SampleQuestionsResponse: return SampleQuestionsResponse(questions=SAMPLE_QUESTIONS) +@router.get("/uploads", response_model=list[UploadedAsset]) +def list_uploads( + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> list[UploadedAsset]: + """Return uploads owned by the signed-in user.""" + + return list_user_uploads(db, current_user) + + @router.post("/uploads", response_model=UploadResponse) -async def upload_dataset(file: UploadFile = File(...)) -> UploadResponse: +async def upload_dataset( + file: UploadFile = File(...), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> UploadResponse: """Accept a workspace upload and return a profiled asset summary.""" contents = await file.read() try: - asset = profile_upload(file.filename or "upload.csv", contents) + asset = create_user_upload( + db, + current_user, + filename=file.filename or 
"upload.csv", + content_type=file.content_type, + content=contents, + ) except ValueError as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail={"message": str(exc)}) from exc return UploadResponse(asset=asset, fallback=False) +@router.delete("/uploads/{source_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_upload( + source_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> Response: + """Delete one upload owned by the signed-in user.""" + + deleted = delete_user_upload(db, current_user, source_id) + if not deleted: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail={"message": "Upload not found."}) + return Response(status_code=status.HTTP_204_NO_CONTENT) + + @router.get("/inspections/{inspection_id}", response_model=InspectionResponse) def inspection_details( inspection_id: str, @@ -94,17 +129,35 @@ def inspection_details( description=( "**Not the primary product API.** For normal use, authenticated clients should call " "`POST /chat`, which persists conversations, messages, and inspection snapshots.\n\n" - "This endpoint runs the same analytics pipeline without auth, without database persistence, " + "This endpoint runs the same analytics pipeline with auth but without conversation/database persistence, " "and keeps the inspection payload only in process memory (lost on restart). Use it for " "local debugging, Swagger/Postman checks, and quick stateless demos.\n\n" "Response shape aligns with the analysis/trace/steps portion of `POST /chat`." 
), ) -def analyze(request: AnalyzeRequest) -> AnalyzeResponse: - """Run analytics without persistence (see OpenAPI ``description`` — prefer ``POST /chat`` for product flows).""" +def analyze( + request: AnalyzeRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> AnalyzeResponse: + """Run analytics with auth but without persistence (see OpenAPI ``description`` — prefer ``POST /chat``).""" + + requested_source_ids = list(dict.fromkeys(request.source_ids or [])) + if not requested_source_ids: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"message": "Upload and attach at least one CSV or JSON data source before running analysis."}, + ) + + valid_source_ids = get_authorized_source_ids(db, current_user, requested_source_ids) + if len(valid_source_ids) != len(requested_source_ids): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"message": "Attach a valid uploaded data source before running analysis."}, + ) try: - return run_stored_analysis(request.query).response + return run_stored_analysis(request.query, source_ids=valid_source_ids).response except Exception as exc: # pragma: no cover - defensive API fallback logger.exception("Analyze request failed", extra={"query": request.query}) settings = get_settings() diff --git a/app/api/workspace.py b/app/api/workspace.py index e9b47a2..c67b238 100644 --- a/app/api/workspace.py +++ b/app/api/workspace.py @@ -4,17 +4,13 @@ import json import re -import tempfile -from dataclasses import dataclass from datetime import datetime, timezone -from io import BytesIO -from pathlib import Path from threading import Lock from typing import Any from uuid import uuid4 -import pandas as pd - +from app.data.registry import clear_source_registry, ingest_source +from app.data.semantic_model import clear_semantic_context_cache from app.schemas import AnalyzeResponse, ArtifactSummary, InspectionData, MetadataItem, ResultTableData, TraceEntry, 
UploadedAsset, ValidationCheck @@ -29,58 +25,22 @@ _STORE_LOCK = Lock() _INSPECTIONS: dict[str, InspectionData] = {} -_UPLOAD_DIR = Path(tempfile.gettempdir()) / "planera_uploads" - - -@dataclass -class StoredUpload: - """Backend-only metadata for uploaded files.""" - - asset: UploadedAsset - file_path: Path - - -_UPLOADS: dict[str, StoredUpload] = {} def clear_workspace_state() -> None: """Reset in-memory upload/inspection storage for tests.""" with _STORE_LOCK: - for stored in _UPLOADS.values(): - if stored.file_path.exists(): - stored.file_path.unlink(missing_ok=True) - _UPLOADS.clear() _INSPECTIONS.clear() + clear_source_registry() + clear_semantic_context_cache() def profile_upload(filename: str, content: bytes) -> UploadedAsset: - """Profile an uploaded CSV/TSV file and persist it to a temp location.""" - - safe_name = Path(filename or "upload.csv").name - frame = _read_uploaded_frame(safe_name, content) - row_count = int(len(frame)) - column_count = int(len(frame.columns)) - asset = UploadedAsset( - id=_short_id("upload"), - name=safe_name, - type=_derive_file_type(safe_name), - source="Workspace upload", - sizeLabel=_bytes_to_size(len(content)), - uploadedAt=_now_iso(), - status="verified", - rows=row_count, - columns=column_count, - summary=_build_upload_summary(frame, row_count, column_count), - ) - - _UPLOAD_DIR.mkdir(parents=True, exist_ok=True) - file_path = _UPLOAD_DIR / f"{asset.id}_{safe_name}" - file_path.write_bytes(content) - - with _STORE_LOCK: - _UPLOADS[asset.id] = StoredUpload(asset=asset, file_path=file_path) + """Persist an uploaded structured dataset to the source registry.""" + asset = ingest_source(filename, content) + clear_semantic_context_cache() return asset @@ -100,35 +60,6 @@ def get_inspection(inspection_id: str) -> InspectionData | None: return _INSPECTIONS.get(inspection_id) -def _read_uploaded_frame(filename: str, content: bytes) -> pd.DataFrame: - if not content: - raise ValueError("Uploaded file is empty.") - - suffix = 
Path(filename).suffix.lower() - buffer = BytesIO(content) - - try: - if suffix in {".csv"}: - return pd.read_csv(buffer) - if suffix in {".tsv", ".tab"}: - return pd.read_csv(buffer, sep="\t") - if suffix in {".txt"}: - return pd.read_csv(buffer, sep=None, engine="python") - except Exception as exc: # pragma: no cover - pandas error details vary - raise ValueError(f"Could not parse {filename} as a structured text dataset.") from exc - - raise ValueError("Only CSV, TSV, TAB, and TXT uploads are currently supported.") - - -def _build_upload_summary(frame: pd.DataFrame, row_count: int, column_count: int) -> str: - preview_columns = [str(column) for column in list(frame.columns[:4])] - column_label = ", ".join(preview_columns) - suffix = "..." if column_count > 4 else "" - if column_label: - return f"Profiled {row_count} rows across {column_count} columns. Leading columns: {column_label}{suffix}." - return f"Profiled {row_count} rows across {column_count} columns." - - def _build_inspection(inspection_id: str, prompt: str, response: AnalyzeResponse) -> InspectionData: executed_steps = response.executed_steps or [] primary_artifact = _pick_primary_artifact(response) @@ -463,20 +394,5 @@ def _short_id(prefix: str) -> str: return f"{prefix}_{uuid4().hex[:8]}" -def _derive_file_type(filename: str) -> str: - suffix = Path(filename).suffix.lstrip(".").upper() - return suffix or "FILE" - - -def _bytes_to_size(num_bytes: int) -> str: - if num_bytes < 1024: - return f"{num_bytes} B" - if num_bytes < 1024 * 1024: - return f"{num_bytes / 1024:.1f} KB" - if num_bytes < 1024 * 1024 * 1024: - return f"{num_bytes / (1024 * 1024):.1f} MB" - return f"{num_bytes / (1024 * 1024 * 1024):.1f} GB" - - def _now_iso() -> str: return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") diff --git a/app/config.py b/app/config.py index 7c9d9d8..52a683f 100644 --- a/app/config.py +++ b/app/config.py @@ -20,6 +20,8 @@ class Settings(BaseSettings): api_host: str = 
"0.0.0.0" api_port: int = 8000 data_dir: Path = Field(default=BASE_DIR / "data") + registry_path: Path = Field(default=BASE_DIR / "data" / "source_registry.duckdb") + upload_storage_dir: Path = Field(default=BASE_DIR / "data" / "uploads", alias="UPLOAD_STORAGE_DIR") crm_path: Path = Field(default=BASE_DIR / "data" / "crm.csv") subscriptions_path: Path = Field(default=BASE_DIR / "data" / "subscriptions.csv") crm_dataset_dir: Path = Field(default=BASE_DIR / "data" / "CRM+Sales+Opportunities") diff --git a/app/data/registry.py b/app/data/registry.py new file mode 100644 index 0000000..71d0e19 --- /dev/null +++ b/app/data/registry.py @@ -0,0 +1,1172 @@ +"""Persistent DuckDB-backed source registry and source-package ingestion.""" + +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone +from hashlib import sha256 +from io import BytesIO +from pathlib import Path +from typing import Any +from uuid import uuid4 + +import duckdb +import pandas as pd + +from app.config import get_settings +from app.schemas import SchemaColumn, SchemaConceptMapping, SchemaJoinKey, SchemaManifest, SchemaRelation, UploadedAsset + + +_SOURCES_TABLE = "__planera_registry_data_sources" +_RELATIONS_TABLE = "__planera_registry_source_relations" +_COLUMNS_TABLE = "__planera_registry_source_columns" +_LINKS_TABLE = "__planera_registry_source_links" +_SYSTEM_COLUMNS = {"record_id", "parent_record_id", "ordinal"} +_SEMANTIC_ALIAS_LEXICON: dict[str, list[str]] = { + "owner": ["agent", "rep", "representative", "sales rep", "assignee"], + "manager": ["manager", "lead", "supervisor", "team lead"], + "regional_office": ["region", "regional office", "office", "territory"], + "account_id": ["account", "customer", "client", "account identifier"], + "deal_id": ["deal", "opportunity", "opportunity identifier"], + "deal_value": ["revenue", "deal size", "amount", "value"], + "stage": ["status stage", "pipeline stage"], + 
"segment": ["customer segment", "market segment"], + "pipeline_velocity_days": ["pipeline velocity", "cycle time", "sales cycle length"], +} + + +@dataclass(frozen=True) +class SourceLinkPackage: + """Explicit relation link stored in the registry.""" + + left_source_id: str + left_relation_name: str + right_source_id: str + right_relation_name: str + join_keys: list[dict[str, str]] + link_type: str = "explicit" + is_explicit: bool = True + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class RelationPackage: + """One materialized relation plus its normalized schema metadata.""" + + relation: SchemaRelation + frame: pd.DataFrame + + +@dataclass(frozen=True) +class SourcePackage: + """Unified internal representation for uploaded or built-in sources.""" + + source_id: str + source_name: str + source_slug: str + source_kind: str + source_format: str + origin: str + file_name: str + file_type: str + size_bytes: int + raw_payload: bytes | None + relations: list[RelationPackage] + links: list[SourceLinkPackage] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def primary_relation_name(self) -> str: + primary = next((item.relation.name for item in self.relations if item.relation.is_primary), self.relations[0].relation.name) + return primary + + +def get_registry_path() -> Path: + """Return the on-disk DuckDB registry path.""" + + return get_settings().registry_path + + +def _connect_registry(read_only: bool = False) -> duckdb.DuckDBPyConnection: + path = get_registry_path() + path.parent.mkdir(parents=True, exist_ok=True) + if read_only and not path.exists(): + conn = duckdb.connect(database=str(path), read_only=False) + _ensure_registry_tables(conn) + return conn + conn = duckdb.connect(database=str(path), read_only=read_only) + if not read_only: + _ensure_registry_tables(conn) + return conn + + +def _ensure_registry_tables(conn: duckdb.DuckDBPyConnection) -> None: + conn.execute( 
+ f""" + CREATE TABLE IF NOT EXISTS {_SOURCES_TABLE} ( + source_id TEXT PRIMARY KEY, + source_name TEXT, + source_slug TEXT, + source_kind TEXT, + source_format TEXT, + origin TEXT, + file_name TEXT, + file_type TEXT, + size_bytes BIGINT, + created_at TEXT, + content_hash TEXT, + primary_relation_name TEXT, + relation_count INTEGER, + row_count INTEGER, + status TEXT, + raw_payload BLOB, + metadata_json TEXT + ) + """ + ) + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {_RELATIONS_TABLE} ( + relation_name TEXT PRIMARY KEY, + source_id TEXT, + kind TEXT, + is_primary BOOLEAN, + parent_relation TEXT, + row_count BIGINT, + grain TEXT, + identifier_columns_json TEXT, + time_columns_json TEXT, + measure_columns_json TEXT, + dimension_columns_json TEXT, + lineage_json TEXT + ) + """ + ) + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {_COLUMNS_TABLE} ( + source_id TEXT, + relation_name TEXT, + ordinal INTEGER, + column_name TEXT, + dtype TEXT, + type_family TEXT, + original_name TEXT, + source_path TEXT, + nullable BOOLEAN, + semantic_hints_json TEXT + ) + """ + ) + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {_LINKS_TABLE} ( + link_id TEXT PRIMARY KEY, + left_source_id TEXT, + left_relation_name TEXT, + right_source_id TEXT, + right_relation_name TEXT, + link_type TEXT, + is_explicit BOOLEAN, + join_keys_json TEXT, + metadata_json TEXT + ) + """ + ) + + +def clear_source_registry() -> None: + """Remove the persisted registry database.""" + + path = get_registry_path() + if path.exists(): + path.unlink() + from app.data.semantic_model import clear_semantic_context_cache + + clear_semantic_context_cache() + + +def _now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +def _short_id(prefix: str) -> str: + return f"{prefix}_{uuid4().hex[:8]}" + + +def _slugify(value: str) -> str: + cleaned = re.sub(r"[^a-zA-Z0-9]+", "_", value.strip().lower()).strip("_") + cleaned = re.sub(r"_+", "_", cleaned) + 
return cleaned[:48] or "source" + + +def _safe_identifier(value: str) -> str: + normalized = _slugify(value) + if normalized[0].isdigit(): + return f"c_{normalized}" + return normalized + + +def _derive_file_type(filename: str, source_format: str) -> str: + suffix = Path(filename).suffix.lstrip(".").upper() + if suffix: + return suffix + return source_format.upper() or "FILE" + + +def _split_identifier(value: str) -> list[str]: + cleaned = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", value) + return [token.lower() for token in re.split(r"[^a-zA-Z0-9]+", cleaned) if token] + + +def _type_family(dtype: str) -> str: + lower = dtype.lower() + if any(token in lower for token in ("int", "float", "double", "decimal")): + return "number" + if "bool" in lower: + return "boolean" + if any(token in lower for token in ("datetime", "timestamp", "date")): + return "datetime" + if any(token in lower for token in ("object", "string", "category")): + return "string" + return "unknown" + + +def _semantic_hints(column_name: str, original_name: str = "", source_path: str = "") -> list[str]: + tokens = _split_identifier(column_name) + hints = {column_name, column_name.replace("_", " ")} + if original_name: + hints.add(original_name) + hints.add(original_name.replace("_", " ")) + if source_path: + hints.add(source_path) + hints.add(source_path.replace(".", " ")) + hints.update(tokens) + if column_name.endswith("_id"): + base = column_name[: -len("_id")].replace("_", " ").strip() + if base: + hints.add(f"{base} id") + hints.add(f"{base} identifier") + + for key, aliases in _SEMANTIC_ALIAS_LEXICON.items(): + key_tokens = set(_split_identifier(key)) + if column_name == key or key_tokens.issubset(set(tokens)): + hints.update(aliases) + + return sorted(hint for hint in hints if hint) + + +def _build_semantic_mappings(columns: list[SchemaColumn]) -> list[SchemaConceptMapping]: + concept_to_columns: dict[str, set[str]] = {} + for column in columns: + for hint in column.semantic_hints: + 
normalized_hint = hint.strip().lower() + if not normalized_hint or normalized_hint == column.name.lower(): + continue + concept_to_columns.setdefault(normalized_hint, set()).add(column.name) + + mappings: list[SchemaConceptMapping] = [] + for concept, mapped_columns in sorted(concept_to_columns.items()): + if len(concept) < 4: + continue + mappings.append(SchemaConceptMapping(concept=concept, columns=sorted(mapped_columns))) + return mappings[:20] + + +def _is_identifier_column(column_name: str, series: pd.Series) -> bool: + lowered = column_name.lower() + if lowered == "id" or lowered.endswith("_id"): + return True + non_null = series.dropna() + return bool(len(non_null) == len(series) and len(non_null) > 0 and non_null.nunique(dropna=False) == len(series)) + + +def _infer_grain(name: str, frame: pd.DataFrame, identifier_columns: list[str]) -> str: + if identifier_columns: + primary = identifier_columns[0] + if primary.lower().endswith("_id"): + entity = primary[: -len("_id")].replace("_", " ").strip() + if entity: + return f"Approximately one row per {entity}" + return f"Rows can be keyed by {primary}" + return f"Rows represent records in {name}" + + +def _normalize_scalar(value: Any) -> Any: + if isinstance(value, (str, int, float, bool)) or value is None: + return value + if isinstance(value, (datetime, pd.Timestamp)): + return pd.Timestamp(value) + return json.dumps(value, default=str) + + +def _parse_text_bool(series: pd.Series) -> pd.Series | None: + non_null = series.dropna().astype(str).str.strip().str.lower() + if non_null.empty: + return None + if set(non_null.unique()).issubset({"true", "false"}): + mapped = series.map( + lambda item: None + if pd.isna(item) + else str(item).strip().lower() == "true" + ) + return mapped.astype("boolean") + return None + + +def _parse_text_numeric(series: pd.Series) -> pd.Series | None: + non_null = series.dropna() + if non_null.empty: + return None + parsed = pd.to_numeric(non_null, errors="coerce") + if 
parsed.notna().sum() != len(non_null): + return None + return pd.to_numeric(series, errors="coerce") + + +def _parse_text_datetime(column_name: str, series: pd.Series) -> pd.Series | None: + lowered_name = column_name.lower() + if not ( + any(token in lowered_name for token in ("date", "time", "timestamp")) + or lowered_name.endswith("_at") + or lowered_name == "at" + ): + return None + non_null = series.dropna() + if non_null.empty: + return None + parsed = pd.to_datetime(non_null, errors="coerce", utc=False) + if parsed.notna().sum() / len(non_null) < 0.8: + return None + return pd.to_datetime(series, errors="coerce", utc=False) + + +def _coerce_frame_types(frame: pd.DataFrame) -> pd.DataFrame: + coerced = frame.copy() + for column_name in coerced.columns: + series = coerced[column_name] + if pd.api.types.is_bool_dtype(series) or pd.api.types.is_numeric_dtype(series) or pd.api.types.is_datetime64_any_dtype(series): + continue + if not (pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series)): + continue + + parsed = _parse_text_bool(series) + if parsed is None: + parsed = _parse_text_numeric(series) + if parsed is None: + parsed = _parse_text_datetime(str(column_name), series) + if parsed is not None: + coerced[column_name] = parsed + return coerced.convert_dtypes() + + +def _rename_system_columns(frame: pd.DataFrame) -> pd.DataFrame: + renamed = frame.copy() + next_names: dict[str, str] = {} + for column in renamed.columns: + safe_name = _safe_identifier(str(column)) + candidate = safe_name + if candidate in _SYSTEM_COLUMNS: + candidate = f"{candidate}_value" + counter = 2 + while candidate in next_names.values(): + candidate = f"{safe_name}_{counter}" + counter += 1 + next_names[column] = candidate + return renamed.rename(columns=next_names) + + +def _ensure_record_id(frame: pd.DataFrame) -> pd.DataFrame: + framed = frame.copy() + if "record_id" in framed.columns: + framed = framed.rename(columns={"record_id": "record_id_value"}) + 
framed.insert(0, "record_id", [f"record_{index + 1}" for index in range(len(framed))]) + return framed + + +def _schema_relation_for_frame( + *, + relation_name: str, + source_id: str, + source_name: str, + frame: pd.DataFrame, + kind: str, + is_primary: bool, + parent_relation: str | None, + join_keys: list[SchemaJoinKey], + lineage: dict[str, Any], + column_paths: dict[str, dict[str, str]] | None = None, +) -> SchemaRelation: + column_paths = column_paths or {} + columns: list[SchemaColumn] = [] + identifier_columns: list[str] = [] + time_columns: list[str] = [] + measure_columns: list[str] = [] + dimension_columns: list[str] = [] + + for column_name in frame.columns: + dtype = str(frame[column_name].dtype) + family = _type_family(dtype) + path_meta = column_paths.get(str(column_name), {}) + column = SchemaColumn( + name=str(column_name), + dtype=dtype, + type_family=family, + original_name=path_meta.get("original_name", str(column_name)), + source_path=path_meta.get("source_path", str(column_name)), + nullable=bool(frame[column_name].isna().any()), + semantic_hints=_semantic_hints( + str(column_name), + original_name=path_meta.get("original_name", str(column_name)), + source_path=path_meta.get("source_path", str(column_name)), + ), + ) + columns.append(column) + + if _is_identifier_column(str(column_name), frame[column_name]): + identifier_columns.append(str(column_name)) + if family == "datetime": + time_columns.append(str(column_name)) + elif family == "number": + measure_columns.append(str(column_name)) + else: + dimension_columns.append(str(column_name)) + + return SchemaRelation( + name=relation_name, + kind=kind, + source_id=source_id, + source_name=source_name, + is_primary=is_primary, + parent_relation=parent_relation, + row_count=int(len(frame)), + grain=_infer_grain(relation_name, frame, identifier_columns), + identifier_columns=identifier_columns, + time_columns=time_columns, + measure_columns=measure_columns, + dimension_columns=dimension_columns, + 
join_keys=join_keys, + lineage=lineage, + columns=columns, + semantic_mappings=_build_semantic_mappings(columns), + ) + + +def _source_filter_sql(source_ids: list[str] | None) -> tuple[str, list[str]]: + if not source_ids: + return "", [] + placeholders = ", ".join(["?"] * len(source_ids)) + return f" WHERE source_id IN ({placeholders})", list(source_ids) + + +def _link_filter_sql(source_ids: list[str] | None) -> tuple[str, list[str]]: + if not source_ids: + return "", [] + placeholders = ", ".join(["?"] * len(source_ids)) + return ( + f" WHERE left_source_id IN ({placeholders}) AND right_source_id IN ({placeholders})", + list(source_ids) + list(source_ids), + ) + + +def _read_uploaded_frame(filename: str, content: bytes) -> pd.DataFrame: + if not content: + raise ValueError("Uploaded file is empty.") + + suffix = Path(filename).suffix.lower() + buffer = BytesIO(content) + try: + if suffix == ".csv": + return pd.read_csv(buffer) + except Exception as exc: # pragma: no cover - pandas errors vary + raise ValueError(f"Could not parse {filename} as a structured text dataset.") from exc + + raise ValueError("Only CSV and JSON uploads are currently supported.") + + +def _build_uploaded_asset(package: SourcePackage) -> UploadedAsset: + primary = next(item for item in package.relations if item.relation.is_primary) + visible_columns = [column.name for column in primary.relation.columns if column.name != "record_id"] + relation_count = len(package.relations) + summary = ( + f"Persisted {primary.relation.row_count} rows across {len(visible_columns)} columns" + f" into {relation_count} relation{'s' if relation_count != 1 else ''}. " + f"Primary relation: {package.primary_relation_name}." 
def _persist_source_package(conn: duckdb.DuckDBPyConnection, package: SourcePackage) -> None:
    """Write one source package into the registry inside a single transaction.

    For every relation in ``package`` the frame is materialized as a table named
    after the relation and its metadata is written to the relations and columns
    tables. Links and the source row are then replaced. Any failure rolls the
    transaction back and re-raises.

    Registry rows and tables left over from a previous ingest of the same
    ``source_id`` are removed first, so that re-ingesting a source whose
    relation set changed does not leave orphaned relations behind.

    Args:
        conn: Open, writable registry connection.
        package: Parsed source with its relations, links and metadata.
    """
    created_at = _now_iso()
    incoming_names = {item.relation.name for item in package.relations}
    # Relations a previous ingest stored under this source id; any that are not
    # part of the new package are stale and must be dropped.
    previous_names = [
        row[0]
        for row in conn.execute(
            f"SELECT relation_name FROM {_RELATIONS_TABLE} WHERE source_id = ?",
            [package.source_id],
        ).fetchall()
    ]

    conn.execute("BEGIN TRANSACTION")
    try:
        conn.execute(f"DELETE FROM {_RELATIONS_TABLE} WHERE source_id = ?", [package.source_id])
        conn.execute(f"DELETE FROM {_COLUMNS_TABLE} WHERE source_id = ?", [package.source_id])
        for stale_name in previous_names:
            if stale_name not in incoming_names:
                conn.execute(f'DROP TABLE IF EXISTS "{stale_name}"')

        for relation in package.relations:
            # Also clear rows keyed by name, in case another source owned this name.
            conn.execute(f'DELETE FROM {_RELATIONS_TABLE} WHERE relation_name = ?', [relation.relation.name])
            conn.execute(f'DELETE FROM {_COLUMNS_TABLE} WHERE relation_name = ?', [relation.relation.name])
            temp_name = f"tmp_{uuid4().hex[:8]}"
            conn.register(temp_name, relation.frame)
            try:
                conn.execute(f'DROP TABLE IF EXISTS "{relation.relation.name}"')
                conn.execute(f'CREATE TABLE "{relation.relation.name}" AS SELECT * FROM "{temp_name}"')
            finally:
                # Release the temporary registration even if table creation fails.
                conn.unregister(temp_name)
            conn.execute(
                f"""
                INSERT INTO {_RELATIONS_TABLE} (
                    relation_name, source_id, kind, is_primary, parent_relation, row_count, grain,
                    identifier_columns_json, time_columns_json, measure_columns_json, dimension_columns_json, lineage_json
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                [
                    relation.relation.name,
                    package.source_id,
                    relation.relation.kind,
                    relation.relation.is_primary,
                    relation.relation.parent_relation,
                    relation.relation.row_count,
                    relation.relation.grain,
                    json.dumps(relation.relation.identifier_columns),
                    json.dumps(relation.relation.time_columns),
                    json.dumps(relation.relation.measure_columns),
                    json.dumps(relation.relation.dimension_columns),
                    json.dumps(relation.relation.lineage),
                ],
            )
            column_rows = [
                (
                    package.source_id,
                    relation.relation.name,
                    ordinal,
                    column.name,
                    column.dtype,
                    column.type_family,
                    column.original_name,
                    column.source_path,
                    column.nullable,
                    json.dumps(column.semantic_hints),
                )
                for ordinal, column in enumerate(relation.relation.columns, start=1)
            ]
            if column_rows:
                conn.executemany(
                    f"""
                    INSERT INTO {_COLUMNS_TABLE} (
                        source_id, relation_name, ordinal, column_name, dtype, type_family,
                        original_name, source_path, nullable, semantic_hints_json
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    column_rows,
                )

        # Replace every link that touches this source on either side.
        conn.execute(
            f'DELETE FROM {_LINKS_TABLE} WHERE left_source_id = ? OR right_source_id = ?',
            [package.source_id, package.source_id],
        )
        for link in package.links:
            conn.execute(
                f"""
                INSERT INTO {_LINKS_TABLE} (
                    link_id, left_source_id, left_relation_name, right_source_id, right_relation_name,
                    link_type, is_explicit, join_keys_json, metadata_json
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                [
                    _short_id("link"),
                    link.left_source_id,
                    link.left_relation_name,
                    link.right_source_id,
                    link.right_relation_name,
                    link.link_type,
                    link.is_explicit,
                    json.dumps(link.join_keys),
                    json.dumps(link.metadata),
                ],
            )

        primary = next(item for item in package.relations if item.relation.is_primary)
        row_count = primary.relation.row_count
        content_hash = sha256(package.raw_payload or b"").hexdigest()
        # Upload payloads are hashed but not stored in the registry row.
        raw_payload = None if package.source_kind == "upload" else package.raw_payload
        conn.execute(f'DELETE FROM {_SOURCES_TABLE} WHERE source_id = ?', [package.source_id])
        conn.execute(
            f"""
            INSERT INTO {_SOURCES_TABLE} (
                source_id, source_name, source_slug, source_kind, source_format, origin, file_name, file_type,
                size_bytes, created_at, content_hash, primary_relation_name, relation_count, row_count, status, raw_payload, metadata_json
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [
                package.source_id,
                package.source_name,
                package.source_slug,
                package.source_kind,
                package.source_format,
                package.origin,
                package.file_name,
                package.file_type,
                package.size_bytes,
                created_at,
                content_hash,
                package.primary_relation_name,
                len(package.relations),
                row_count,
                "verified",
                raw_payload,
                json.dumps(package.metadata),
            ],
        )
        conn.execute("COMMIT")
    except Exception:
        conn.execute("ROLLBACK")
        raise
def _ingest_json_source(filename: str, content: bytes, source_id: str | None = None) -> SourcePackage:
    """Flatten a JSON upload into a primary relation plus one child relation per array.

    Nested objects are flattened into ``parent__child`` columns of the relation
    they belong to. Every array becomes its own child relation whose rows carry
    ``record_id``, ``parent_record_id`` and ``ordinal``; scalar array items are
    stored in a ``value`` column. A parent/child link is recorded once per
    child relation.

    Child relations are keyed by the full key path from the enclosing relation
    to the array. Previously only the array's own key was used, so ``a.items``
    and ``b.items`` were merged into one relation and their columns were
    prefixed with the array key.

    Args:
        filename: Upload name; only its final path component is kept.
        content: UTF-8 encoded JSON bytes.
        source_id: Existing source id to reuse; a new one is generated if omitted.

    Returns:
        The package describing the primary relation, child relations and links.

    Raises:
        ValueError: If the payload is not valid JSON, or is neither an object
            nor an array of objects.
    """
    safe_name = Path(filename or "upload.json").name
    try:
        payload = json.loads(content.decode("utf-8"))
    except Exception as exc:  # pragma: no cover - json errors vary
        raise ValueError(f"Could not parse {safe_name} as JSON.") from exc

    if isinstance(payload, dict):
        records = [payload]
    elif isinstance(payload, list) and all(isinstance(item, dict) for item in payload):
        # ``all`` is true for an empty list, so an empty array is accepted here too.
        records = payload
    else:
        raise ValueError("JSON uploads must contain a top-level object or an array of objects.")

    source_id = source_id or _short_id("source")
    source_slug = _slugify(Path(safe_name).stem)
    primary_relation_name = _build_primary_relation_name(source_slug, source_id)

    # All four structures are keyed by relation path; ``()`` is the primary relation.
    relation_rows: dict[tuple[str, ...], list[dict[str, Any]]] = {(): []}
    relation_column_paths: dict[tuple[str, ...], dict[str, dict[str, str]]] = {
        (): {"record_id": {"original_name": "record_id", "source_path": "$.record_id"}}
    }
    relation_parents: dict[tuple[str, ...], str | None] = {(): None}
    relation_links: list[SourceLinkPackage] = []

    def remember_column(path_key: tuple[str, ...], column_name: str, original_name: str, source_path: str) -> None:
        relation_column_paths.setdefault(path_key, {})
        relation_column_paths[path_key][column_name] = {"original_name": original_name, "source_path": source_path}

    def walk_object(
        obj: dict[str, Any],
        *,
        path_key: tuple[str, ...],
        row: dict[str, Any],
        record_id: str,
        full_path: tuple[str, ...],
    ) -> None:
        for raw_key, value in obj.items():
            safe_key = _safe_identifier(str(raw_key))
            full_column_path = full_path + (str(raw_key),)
            if safe_key in _SYSTEM_COLUMNS:
                safe_key = f"{safe_key}_value"
            if isinstance(value, dict):
                # Nested objects stay in the same relation; only the path grows.
                walk_object(
                    value,
                    path_key=path_key,
                    row=row,
                    record_id=record_id,
                    full_path=full_column_path,
                )
                continue
            if isinstance(value, list):
                # Include the enclosing object keys so the relation path keeps the
                # same length as the JSON path and sibling arrays never collide.
                enclosing_keys = tuple(
                    _safe_identifier(part) for part in full_column_path[len(path_key) : -1]
                )
                handle_array(
                    value,
                    parent_relation_path=path_key,
                    parent_relation_name=primary_relation_name
                    if not path_key
                    else _build_child_relation_name(primary_relation_name, path_key),
                    parent_record_id=record_id,
                    array_path=path_key + enclosing_keys + (safe_key,),
                    full_path=full_column_path,
                )
                continue

            column_name = "__".join(_safe_identifier(part) for part in full_column_path[len(path_key) :])
            if column_name in _SYSTEM_COLUMNS:
                column_name = f"{column_name}_value"
            row[column_name] = _normalize_scalar(value)
            remember_column(path_key, column_name, str(raw_key), ".".join(full_column_path))

    def handle_array(
        values: list[Any],
        *,
        parent_relation_path: tuple[str, ...],
        parent_relation_name: str,
        parent_record_id: str,
        array_path: tuple[str, ...],
        full_path: tuple[str, ...],
    ) -> None:
        relation_rows.setdefault(array_path, [])
        relation_column_paths.setdefault(
            array_path,
            {
                "record_id": {"original_name": "record_id", "source_path": ".".join(full_path) + ".record_id"},
                "parent_record_id": {
                    "original_name": "parent_record_id",
                    "source_path": ".".join(full_path) + ".parent_record_id",
                },
                "ordinal": {"original_name": "ordinal", "source_path": ".".join(full_path) + ".ordinal"},
            },
        )
        relation_parents[array_path] = parent_relation_name
        relation_name = _build_child_relation_name(primary_relation_name, array_path)
        # Record the parent/child link only the first time this relation is seen.
        if not any(link.right_relation_name == relation_name for link in relation_links):
            relation_links.append(
                SourceLinkPackage(
                    left_source_id=source_id,
                    left_relation_name=parent_relation_name,
                    right_source_id=source_id,
                    right_relation_name=relation_name,
                    join_keys=[{"left_column": "record_id", "right_column": "parent_record_id"}],
                    link_type="parent_child",
                    is_explicit=False,
                    metadata={"json_path": ".".join(full_path), "parent_relation_path": ".".join(parent_relation_path)},
                )
            )
        for index, item in enumerate(values):
            child_record_id = f"{parent_record_id}__{_safe_identifier(array_path[-1])}_{index + 1}"
            child_row = {
                "record_id": child_record_id,
                "parent_record_id": parent_record_id,
                "ordinal": index,
            }
            if isinstance(item, dict):
                walk_object(item, path_key=array_path, row=child_row, record_id=child_record_id, full_path=full_path)
            elif isinstance(item, list):
                # An array inside an array becomes a further child relation named "value".
                handle_array(
                    item,
                    parent_relation_path=array_path,
                    parent_relation_name=relation_name,
                    parent_record_id=child_record_id,
                    array_path=array_path + ("value",),
                    full_path=full_path + ("value",),
                )
            else:
                child_row["value"] = _normalize_scalar(item)
                remember_column(array_path, "value", "value", ".".join(full_path))
            relation_rows[array_path].append(child_row)

    # Every record was validated as a dict above, so no per-record type check is needed.
    for index, item in enumerate(records, start=1):
        row = {"record_id": f"record_{index}"}
        walk_object(item, path_key=(), row=row, record_id=row["record_id"], full_path=())
        relation_rows[()].append(row)

    primary_frame = pd.DataFrame(relation_rows[()])
    if "record_id" not in primary_frame.columns:
        # No records at all: still expose the system key column.
        primary_frame["record_id"] = pd.Series(dtype="string")
    primary_frame = _coerce_frame_types(primary_frame)
    relations: list[RelationPackage] = [
        RelationPackage(
            relation=_schema_relation_for_frame(
                relation_name=primary_relation_name,
                source_id=source_id,
                source_name=safe_name,
                frame=primary_frame,
                kind="table",
                is_primary=True,
                parent_relation=None,
                join_keys=[],
                lineage={"format": "json", "json_path": "$"},
                column_paths=relation_column_paths.get((), {}),
            ),
            frame=primary_frame,
        )
    ]

    for relation_path in sorted(key for key in relation_rows if key):
        relation_name = _build_child_relation_name(primary_relation_name, relation_path)
        frame = _coerce_frame_types(pd.DataFrame(relation_rows[relation_path]))
        join_keys = [
            SchemaJoinKey(
                target_relation=relation_parents[relation_path] or primary_relation_name,
                source_column="parent_record_id",
                target_column="record_id",
                link_type="parent_child",
            )
        ]
        relations.append(
            RelationPackage(
                relation=_schema_relation_for_frame(
                    relation_name=relation_name,
                    source_id=source_id,
                    source_name=safe_name,
                    frame=frame,
                    kind="table",
                    is_primary=False,
                    parent_relation=relation_parents[relation_path],
                    join_keys=join_keys,
                    lineage={"format": "json", "json_path": ".".join(relation_path)},
                    column_paths=relation_column_paths.get(relation_path, {}),
                ),
                frame=frame,
            )
        )

    return SourcePackage(
        source_id=source_id,
        source_name=safe_name,
        source_slug=source_slug,
        source_kind="upload",
        source_format="json",
        origin="Workspace upload",
        file_name=safe_name,
        file_type=_derive_file_type(safe_name, "json"),
        size_bytes=len(content),
        raw_payload=content,
        relations=relations,
        links=relation_links,
    )
def ingest_source(filename: str, content: bytes, *, source_id: str | None = None) -> UploadedAsset:
    """Parse an uploaded CSV or JSON file, store it in the registry, and summarize it.

    The parser is picked from the file extension. After the package has been
    persisted, the cached semantic contexts are cleared.

    Raises:
        ValueError: If the extension is neither ``.csv`` nor ``.json``.
    """

    clean_name = Path(filename or "upload.csv").name
    extension = Path(clean_name).suffix.lower()
    builders = {".json": _ingest_json_source, ".csv": _ingest_csv_source}
    build_package = builders.get(extension)
    if build_package is None:
        raise ValueError("Only CSV and JSON uploads are currently supported.")
    package = build_package(clean_name, content, source_id=source_id)

    registry = _connect_registry(read_only=False)
    try:
        _persist_source_package(registry, package)
    finally:
        registry.close()

    # Imported here rather than at module level, as in the rest of this module.
    from app.data.semantic_model import clear_semantic_context_cache

    clear_semantic_context_cache()
    return _build_uploaded_asset(package)
def create_source_link(
    left_relation_name: str,
    left_column: str,
    right_relation_name: str,
    right_column: str,
) -> None:
    """Create an explicit registry join path between two existing relations.

    Both relations must be registered, and each join column must be registered
    on its relation. The column check is new: previously a link naming a
    non-existent column was stored as-is. The cached semantic contexts are
    cleared after a successful insert.

    Raises:
        ValueError: If either relation, or either join column, is not registered.
    """

    endpoints = (
        (left_relation_name, left_column),
        (right_relation_name, right_column),
    )
    conn = _connect_registry(read_only=False)
    try:
        rows = conn.execute(
            f"""
            SELECT relation_name, source_id
            FROM {_RELATIONS_TABLE}
            WHERE relation_name IN (?, ?)
            """,
            [left_relation_name, right_relation_name],
        ).fetchall()
        relation_to_source = {row[0]: row[1] for row in rows}
        if left_relation_name not in relation_to_source or right_relation_name not in relation_to_source:
            raise ValueError("Both relations must exist before creating an explicit source link.")

        # Reject links that reference columns the registry does not know about.
        known_columns = {
            (row[0], row[1])
            for row in conn.execute(
                f"""
                SELECT relation_name, column_name
                FROM {_COLUMNS_TABLE}
                WHERE relation_name IN (?, ?)
                """,
                [left_relation_name, right_relation_name],
            ).fetchall()
        }
        for relation_name, column_name in endpoints:
            if (relation_name, column_name) not in known_columns:
                raise ValueError(f"Column {column_name!r} is not registered on relation {relation_name!r}.")

        conn.execute(
            f"""
            INSERT INTO {_LINKS_TABLE} (
                link_id, left_source_id, left_relation_name, right_source_id, right_relation_name,
                link_type, is_explicit, join_keys_json, metadata_json
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [
                _short_id("link"),
                relation_to_source[left_relation_name],
                left_relation_name,
                relation_to_source[right_relation_name],
                right_relation_name,
                "explicit",
                True,
                json.dumps([{"left_column": left_column, "right_column": right_column}]),
                json.dumps({}),
            ],
        )
    finally:
        conn.close()
    from app.data.semantic_model import clear_semantic_context_cache

    clear_semantic_context_cache()
parent_relation, row_count, grain, + identifier_columns_json, time_columns_json, measure_columns_json, dimension_columns_json, lineage_json + FROM {_RELATIONS_TABLE} + {filter_sql} + ORDER BY source_id, is_primary DESC, relation_name + """, + params, + ).fetchall() + source_names = { + row[0]: row[1] + for row in conn.execute( + f"SELECT source_id, source_name FROM {_SOURCES_TABLE}{filter_sql}", + params, + ).fetchall() + } + link_map = _load_link_map(conn, source_ids) + relations: list[SchemaRelation] = [] + for row in relation_rows: + relation_name, source_id, kind, is_primary, parent_relation, row_count, grain, identifier_columns_json, time_columns_json, measure_columns_json, dimension_columns_json, lineage_json = row + column_rows = conn.execute( + f""" + SELECT column_name, dtype, type_family, original_name, source_path, nullable, semantic_hints_json + FROM {_COLUMNS_TABLE} + WHERE relation_name = ? + ORDER BY ordinal + """, + [relation_name], + ).fetchall() + columns = [ + SchemaColumn( + name=column_name, + dtype=dtype, + type_family=type_family, + original_name=original_name or column_name, + source_path=source_path or column_name, + nullable=bool(nullable), + semantic_hints=json.loads(semantic_hints_json or "[]"), + ) + for column_name, dtype, type_family, original_name, source_path, nullable, semantic_hints_json in column_rows + ] + relation = SchemaRelation( + name=relation_name, + kind=kind, + source_id=source_id, + source_name=source_names.get(source_id, source_id), + is_primary=bool(is_primary), + parent_relation=parent_relation, + row_count=int(row_count), + grain=grain or "", + identifier_columns=json.loads(identifier_columns_json or "[]"), + time_columns=json.loads(time_columns_json or "[]"), + measure_columns=json.loads(measure_columns_json or "[]"), + dimension_columns=json.loads(dimension_columns_json or "[]"), + join_keys=link_map.get(relation_name, []), + lineage=json.loads(lineage_json or "{}"), + columns=columns, + 
def get_upload_source_ids(source_ids: list[str] | None = None) -> list[str]:
    """Return the subset of ``source_ids`` registered as uploads, sorted by id.

    Duplicate ids are collapsed before the query is built. Previously the
    placeholders were counted from the raw list while the parameters were
    de-duplicated, so any repeated id produced a parameter-count mismatch.

    Args:
        source_ids: Candidate ids; ``None`` or an empty list yields ``[]``
            without touching the registry.
    """

    if not source_ids:
        return []
    unique_ids = list(dict.fromkeys(source_ids))
    placeholders = ", ".join(["?"] * len(unique_ids))
    # NOTE(review): this is a read-only query but the connection is opened
    # writable, unlike the other readers in this module — confirm whether
    # read_only=True is safe here.
    conn = _connect_registry(read_only=False)
    try:
        rows = conn.execute(
            f"""
            SELECT source_id
            FROM {_SOURCES_TABLE}
            WHERE source_kind = 'upload' AND source_id IN ({placeholders})
            ORDER BY source_id
            """,
            unique_ids,
        ).fetchall()
        return [row[0] for row in rows]
    finally:
        conn.close()
"decimal")): - return "number" - if "bool" in lower: - return "boolean" - if any(token in lower for token in ("datetime", "timestamp", "date")): - return "datetime" - if any(token in lower for token in ("object", "string", "category")): - return "string" - return "unknown" - - -def _semantic_hints(column_name: str) -> list[str]: - tokens = _split_identifier(column_name) - hints = {column_name, column_name.replace("_", " ")} - hints.update(tokens) - if column_name.endswith("_id"): - base = column_name[: -len("_id")].replace("_", " ").strip() - if base: - hints.add(f"{base} id") - hints.add(f"{base} identifier") - - for key, aliases in _SEMANTIC_ALIAS_LEXICON.items(): - key_tokens = set(_split_identifier(key)) - if column_name == key or key_tokens.issubset(set(tokens)): - hints.update(aliases) - - return sorted(hint for hint in hints if hint) - - -def _is_identifier_column(column_name: str, series: pd.Series) -> bool: - lowered = column_name.lower() - if lowered == "id" or lowered.endswith("_id"): - return True - non_null = series.dropna() - return bool(len(non_null) == len(series) and len(non_null) > 0 and non_null.nunique(dropna=False) == len(series)) - - -def _infer_grain(name: str, frame: pd.DataFrame, identifier_columns: list[str]) -> str: - if identifier_columns: - primary = identifier_columns[0] - if primary.lower().endswith("_id"): - entity = primary[: -len("_id")].replace("_", " ").strip() - if entity: - return f"Approximately one row per {entity}" - return f"Rows can be keyed by {primary}" - return f"Rows represent records in {name}" - - -def _build_semantic_mappings(columns: list[SchemaColumn]) -> list[SchemaConceptMapping]: - concept_to_columns: dict[str, set[str]] = {} - for column in columns: - for hint in column.semantic_hints: - normalized_hint = hint.strip().lower() - if not normalized_hint or normalized_hint == column.name.lower(): - continue - concept_to_columns.setdefault(normalized_hint, set()).add(column.name) - - mappings: 
list[SchemaConceptMapping] = [] - for concept, mapped_columns in sorted(concept_to_columns.items()): - if len(concept) < 4: - continue - mappings.append( - SchemaConceptMapping( - concept=concept, - columns=sorted(mapped_columns), - ) - ) - return mappings[:20] - - -def _relation_for_frame(name: str, frame: pd.DataFrame, kind: str = "view") -> SchemaRelation: - columns: list[SchemaColumn] = [] - identifier_columns: list[str] = [] - time_columns: list[str] = [] - measure_columns: list[str] = [] - dimension_columns: list[str] = [] - - for column_name in frame.columns: - dtype = str(frame[column_name].dtype) - family = _type_family(dtype) - column = SchemaColumn( - name=column_name, - dtype=dtype, - type_family=family, - semantic_hints=_semantic_hints(column_name), - ) - columns.append(column) - - if _is_identifier_column(column_name, frame[column_name]): - identifier_columns.append(column_name) - if family == "datetime": - time_columns.append(column_name) - elif family == "number": - measure_columns.append(column_name) - else: - dimension_columns.append(column_name) - - return SchemaRelation( - name=name, - kind=kind, - row_count=int(len(frame)), - grain=_infer_grain(name, frame, identifier_columns), - identifier_columns=identifier_columns, - time_columns=time_columns, - measure_columns=measure_columns, - dimension_columns=dimension_columns, - columns=columns, - semantic_mappings=_build_semantic_mappings(columns), - ) - +def _source_ids_key(source_ids: list[str] | tuple[str, ...] 
| None) -> tuple[str, ...]: + if not source_ids: + return () + return tuple(sorted(source_ids)) -@lru_cache(maxsize=1) -def get_semantic_context() -> SemanticContext: - """Build and cache views plus a schema-only manifest for planning.""" - bundle = load_data() - raw_views = {name: frame.copy() for name, frame in bundle.raw_views.items()} - semantic_views = {"opportunities_enriched": bundle.crm.copy()} - all_frames = {**raw_views, **semantic_views} - relations = [_relation_for_frame(name, frame, kind="view") for name, frame in all_frames.items()] - schema_manifest = SchemaManifest( - reference_date=bundle.reference_date, - source=bundle.source, - dialect="duckdb", - relations=relations, - views=[ - { - "name": relation.name, - "row_count": relation.row_count, - "columns": [{"name": column.name, "dtype": column.dtype} for column in relation.columns], - } - for relation in relations - ], - ).model_dump() +@lru_cache(maxsize=32) +def _get_semantic_context_cached(source_ids_key: tuple[str, ...]) -> SemanticContext: + source_ids = list(source_ids_key) if source_ids_key else None + manifest = get_schema_manifest(source_ids) + frames = load_relation_frames(source_ids) + primary_names = {relation["name"] for relation in manifest.get("relations", []) if relation.get("is_primary")} + raw_views = {name: frame for name, frame in frames.items() if name not in primary_names} + semantic_views = {name: frame for name, frame in frames.items() if name in primary_names} return SemanticContext( - reference_date=bundle.reference_date, - source=bundle.source, + reference_date=manifest.get("reference_date", ""), + source=manifest.get("source", ""), raw_views=raw_views, semantic_views=semantic_views, - schema_manifest=schema_manifest, + schema_manifest=manifest, ) -def new_duckdb_connection() -> duckdb.DuckDBPyConnection: - """Create an in-memory DuckDB connection with curated views registered.""" +def get_semantic_context(source_ids: list[str] | tuple[str, ...] 
def new_duckdb_connection(dataset_context: dict[str, Any] | None = None) -> duckdb.DuckDBPyConnection:
    """Create an in-memory DuckDB connection exposing only the requested relations.

    Relation names are taken from ``dataset_context["relations"]``, then from
    ``dataset_context["views"]``, and finally from every registered relation
    when neither yields a name. The registry file is attached as
    ``registry_db`` and one view per relation is created over it.

    The caller owns the returned connection and must close it. If attaching
    the registry or creating a view fails, the connection is closed before the
    error propagates.
    """

    context = dataset_context or {}
    relation_names = [
        relation["name"]
        for relation in context.get("relations") or []
        if relation.get("name")
    ]
    if not relation_names:
        relation_names = [
            view["name"]
            for view in context.get("views") or []
            if view.get("name")
        ]
    if not relation_names:
        relation_names = get_registered_relation_names()

    conn = duckdb.connect(database=":memory:")
    try:
        registry_path = str(get_registry_path()).replace("'", "''")
        conn.execute(f"ATTACH '{registry_path}' AS registry_db")

        # De-duplicate while keeping order: a repeated name would make CREATE VIEW fail.
        for relation_name in dict.fromkeys(relation_names):
            escaped = relation_name.replace('"', '""')
            conn.execute(f'CREATE VIEW "{escaped}" AS SELECT * FROM registry_db."{escaped}"')
    except Exception:
        conn.close()
        raise
    return conn
"relation_count", + "primary_relation_name", + "created_at", + "updated_at", +} + + +def ensure_sqlite_schema_compatibility(engine: Engine) -> None: + """Upgrade legacy local SQLite tables that predate the current ORM shape.""" + + _migrate_uploads_table_if_needed(engine) + + +def _migrate_uploads_table_if_needed(engine: Engine) -> None: + inspector = inspect(engine) + if _UPLOADS_TABLE not in inspector.get_table_names(): + return + + columns = inspector.get_columns(_UPLOADS_TABLE) + column_names = {column["name"] for column in columns} + pk_columns = [column["name"] for column in columns if column.get("primary_key")] + if column_names == _EXPECTED_UPLOAD_COLUMNS and pk_columns == ["source_id"]: + return + + if _LEGACY_UPLOADS_TABLE in inspector.get_table_names(): + raise RuntimeError( + "Cannot migrate the uploads table because a previous temporary migration table still exists: " + f"{_LEGACY_UPLOADS_TABLE}." + ) + + select_sql = _build_upload_copy_select(column_names) + with engine.begin() as connection: + connection.execute(text(f"ALTER TABLE {_UPLOADS_TABLE} RENAME TO {_LEGACY_UPLOADS_TABLE}")) + connection.execute( + text( + f""" + CREATE TABLE {_UPLOADS_TABLE} ( + source_id VARCHAR(64) NOT NULL PRIMARY KEY, + user_id INTEGER NOT NULL, + original_filename VARCHAR(512) NOT NULL, + storage_path VARCHAR(1024) NOT NULL, + content_type VARCHAR(255) NOT NULL DEFAULT 'application/octet-stream', + size_bytes INTEGER NOT NULL, + content_hash VARCHAR(128) NOT NULL, + status VARCHAR(32) NOT NULL DEFAULT 'verified', + rows INTEGER, + columns INTEGER, + relation_count INTEGER, + primary_relation_name VARCHAR(255), + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL, + FOREIGN KEY(user_id) REFERENCES users (id) ON DELETE CASCADE + ) + """ + ) + ) + connection.execute( + text( + f""" + INSERT INTO {_UPLOADS_TABLE} ( + source_id, + user_id, + original_filename, + storage_path, + content_type, + size_bytes, + content_hash, + status, + rows, + columns, + 
relation_count, + primary_relation_name, + created_at, + updated_at + ) + {select_sql} + """ + ) + ) + connection.execute(text(f"DROP TABLE {_LEGACY_UPLOADS_TABLE}")) + connection.execute(text("CREATE INDEX IF NOT EXISTS ix_uploads_user_id ON uploads (user_id)")) + + +def _build_upload_copy_select(column_names: set[str]) -> str: + def has_column(name: str) -> bool: + return name in column_names + + def quoted(name: str) -> str: + return f'"{name}"' + + def first_present(*names: str, default: str) -> str: + for name in names: + if has_column(name): + return quoted(name) + return default + + source_id_expr = first_present( + "source_id", + "upload_id", + default="('source_' || lower(hex(randomblob(4))))", + ) + content_type_expr = first_present("content_type", default="'application/octet-stream'") + status_expr = first_present("status", default="'verified'") + created_at_expr = first_present("created_at", default="CURRENT_TIMESTAMP") + updated_at_expr = first_present("updated_at", "created_at", default="CURRENT_TIMESTAMP") + + return f""" + SELECT + COALESCE(NULLIF({source_id_expr}, ''), 'source_' || lower(hex(randomblob(4)))) AS source_id, + {first_present("user_id", default='0')} AS user_id, + {first_present("original_filename", default="'upload.csv'")} AS original_filename, + {first_present("storage_path", default="''")} AS storage_path, + COALESCE({content_type_expr}, 'application/octet-stream') AS content_type, + {first_present("size_bytes", default='0')} AS size_bytes, + {first_present("content_hash", default="''")} AS content_hash, + COALESCE({status_expr}, 'verified') AS status, + {first_present("rows", default='NULL')} AS rows, + {first_present("columns", default='NULL')} AS columns, + {first_present("relation_count", default='NULL')} AS relation_count, + {first_present("primary_relation_name", default='NULL')} AS primary_relation_name, + COALESCE({created_at_expr}, CURRENT_TIMESTAMP) AS created_at, + COALESCE({updated_at_expr}, COALESCE({created_at_expr}, 
CURRENT_TIMESTAMP)) AS updated_at + FROM {_LEGACY_UPLOADS_TABLE} + """ diff --git a/app/main.py b/app/main.py index 4085152..5f495b1 100644 --- a/app/main.py +++ b/app/main.py @@ -23,10 +23,13 @@ async def lifespan(_app: FastAPI): """Create database tables on startup (SQLite file demo; no Alembic in this phase).""" from app.db.base import Base + from app.db.schema_compat import ensure_sqlite_schema_compatibility from app.db.session import get_engine - from app.models import Conversation, InspectionSnapshot, Message, User # noqa: F401 + from app.models import Conversation, InspectionSnapshot, Message, UploadRecord, User # noqa: F401 - Base.metadata.create_all(bind=get_engine()) + engine = get_engine() + Base.metadata.create_all(bind=engine) + ensure_sqlite_schema_compatibility(engine) yield diff --git a/app/models/__init__.py b/app/models/__init__.py index 4649a41..580bf2c 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -2,6 +2,7 @@ from app.models.conversation import Conversation, Message from app.models.inspection_snapshot import InspectionSnapshot +from app.models.upload import UploadRecord from app.models.user import User -__all__ = ["Conversation", "InspectionSnapshot", "Message", "User"] +__all__ = ["Conversation", "InspectionSnapshot", "Message", "UploadRecord", "User"] diff --git a/app/models/upload.py b/app/models/upload.py new file mode 100644 index 0000000..3f16b6b --- /dev/null +++ b/app/models/upload.py @@ -0,0 +1,44 @@ +"""User-owned uploaded source metadata.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import TYPE_CHECKING + +from sqlalchemy import DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.db.base import Base + +if TYPE_CHECKING: + from app.models.user import User + + +def _utc_now() -> datetime: + return datetime.now(timezone.utc) + + +class UploadRecord(Base): + __tablename__ = "uploads" + + source_id: 
Mapped[str] = mapped_column(String(64), primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True, nullable=False) + original_filename: Mapped[str] = mapped_column(String(512), nullable=False) + storage_path: Mapped[str] = mapped_column(String(1024), nullable=False) + content_type: Mapped[str] = mapped_column(String(255), nullable=False, default="application/octet-stream") + size_bytes: Mapped[int] = mapped_column(Integer, nullable=False) + content_hash: Mapped[str] = mapped_column(String(128), nullable=False) + status: Mapped[str] = mapped_column(String(32), nullable=False, default="verified") + rows: Mapped[int | None] = mapped_column(Integer, nullable=True) + columns: Mapped[int | None] = mapped_column(Integer, nullable=True) + relation_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + primary_relation_name: Mapped[str | None] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utc_now, nullable=False) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=_utc_now, + onupdate=_utc_now, + nullable=False, + ) + + user: Mapped[User] = relationship("User", back_populates="uploads") diff --git a/app/models/user.py b/app/models/user.py index dcdd9df..48cea3f 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from app.models.conversation import Conversation + from app.models.upload import UploadRecord def _utc_now() -> datetime: @@ -32,3 +33,8 @@ class User(Base): back_populates="user", cascade="all, delete-orphan", ) + uploads: Mapped[list[UploadRecord]] = relationship( + "UploadRecord", + back_populates="user", + cascade="all, delete-orphan", + ) diff --git a/app/prompts/planner_compiled.j2 b/app/prompts/planner_compiled.j2 index d57d691..74c9587 100644 --- a/app/prompts/planner_compiled.j2 +++ b/app/prompts/planner_compiled.j2 @@ -10,6 +10,9 @@ Rules: 
- Use only these registered relation names: {{ relation_names_json }} - Use exact column names only. Never invent, normalize, paraphrase, or rename a column. - Use semantic mappings only to translate user language into exact schema fields. +- Prefer relations where "is_primary" is true before reaching for child relations. +- Only join relations when the schema subset exposes an explicit path in "join_keys". +- For nested JSON child relations, use the exact join columns from "join_keys" instead of guessing parent-child keys. - Do not assume the question premise is true. First verify the main metric or comparison before planning causal or grouped breakdowns. - Prefer SQL over multiple trivial splits; combine logic when one query suffices. - No imports, file I/O, network calls, or plotting. diff --git a/app/prompts/planner_repair.j2 b/app/prompts/planner_repair.j2 index 8e697d2..ada2780 100644 --- a/app/prompts/planner_repair.j2 +++ b/app/prompts/planner_repair.j2 @@ -8,6 +8,8 @@ Rules: - updated_step must use "type": "sql", the same id as the failed step ({{ failed_step_id }}), and a fixed "query". - Use only exact relation and column names from the schema manifest. - Use semantic mappings only to translate user language into exact schema fields. +- Prefer relations where "is_primary" is true before using child relations. +- Only join relations when the schema manifest exposes the join path in "join_keys". - Do not invent fields or rename columns. 
Original plan: diff --git a/app/schemas.py b/app/schemas.py index 61c27e5..ba30286 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -12,6 +12,7 @@ class AnalyzeRequest(BaseModel): """Incoming analysis request.""" query: str = Field(..., min_length=3, examples=["Why did pipeline velocity drop this week?"]) + source_ids: list[str] = Field(default_factory=list) class TraceEvent(BaseModel): @@ -99,9 +100,23 @@ class SchemaColumn(BaseModel): name: str = Field(..., min_length=1) dtype: str = Field(..., min_length=1) type_family: Literal["string", "number", "boolean", "datetime", "unknown"] = "unknown" + original_name: str = "" + source_path: str = "" + nullable: bool = True semantic_hints: list[str] = Field(default_factory=list) +class SchemaJoinKey(BaseModel): + """One explicit relation-level join path available to the planner.""" + + model_config = ConfigDict(extra="forbid") + + target_relation: str = Field(..., min_length=1) + source_column: str = Field(..., min_length=1) + target_column: str = Field(..., min_length=1) + link_type: Literal["parent_child", "explicit"] = "explicit" + + class SchemaRelation(BaseModel): """One normalized table or view available to the planner.""" @@ -109,12 +124,18 @@ class SchemaRelation(BaseModel): name: str = Field(..., min_length=1) kind: Literal["table", "view"] = "view" + source_id: str = "" + source_name: str = "" + is_primary: bool = False + parent_relation: str | None = None row_count: int = 0 grain: str = "" identifier_columns: list[str] = Field(default_factory=list) time_columns: list[str] = Field(default_factory=list) measure_columns: list[str] = Field(default_factory=list) dimension_columns: list[str] = Field(default_factory=list) + join_keys: list[SchemaJoinKey] = Field(default_factory=list) + lineage: dict[str, Any] = Field(default_factory=dict) columns: list[SchemaColumn] = Field(default_factory=list) semantic_mappings: list[SchemaConceptMapping] = Field(default_factory=list) @@ -266,6 +287,7 @@ class 
ChatSubmitRequest(BaseModel): conversation_id: int | str | None = None query: str = Field(..., min_length=3) + source_ids: list[str] = Field(default_factory=list) @field_validator("conversation_id", mode="before") @classmethod @@ -334,6 +356,8 @@ class UploadedAsset(BaseModel): status: Literal["uploaded", "profiling", "verified", "error"] rows: int | None = None columns: int | None = None + relationCount: int | None = None + primaryRelationName: str | None = None summary: str | None = None diff --git a/app/services/analysis_run.py b/app/services/analysis_run.py index e779925..8a55000 100644 --- a/app/services/analysis_run.py +++ b/app/services/analysis_run.py @@ -25,10 +25,10 @@ class StoredAnalysisRun: inspection: InspectionData -def run_stored_analysis(query: str) -> StoredAnalysisRun: +def run_stored_analysis(query: str, source_ids: list[str] | None = None) -> StoredAnalysisRun: """Run `run_analysis`, persist inspection to process memory, return API + inspection objects.""" - state = run_analysis(query) + state = run_analysis(query, source_ids=source_ids) base = AnalyzeResponse( analysis=state["analysis"], trace=state.get("trace", []), diff --git a/app/uploads/__init__.py b/app/uploads/__init__.py new file mode 100644 index 0000000..dd275ba --- /dev/null +++ b/app/uploads/__init__.py @@ -0,0 +1 @@ +"""Upload persistence services.""" diff --git a/app/uploads/service.py b/app/uploads/service.py new file mode 100644 index 0000000..7f6c91a --- /dev/null +++ b/app/uploads/service.py @@ -0,0 +1,155 @@ +"""User-owned upload persistence and authorization helpers.""" + +from __future__ import annotations + +from datetime import timezone +from hashlib import sha256 +from pathlib import Path +from uuid import uuid4 + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.data.registry import delete_source, ingest_source +from app.models.upload import UploadRecord +from app.models.user import User +from app.schemas import UploadedAsset +from 
app.uploads.storage import LocalUploadBlobStore + + +def _short_source_id() -> str: + return f"source_{uuid4().hex[:8]}" + + +def _to_uploaded_at(value) -> str: + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +def _asset_from_record(record: UploadRecord) -> UploadedAsset: + relation_count = record.relation_count or 0 + summary = None + if record.rows is not None and record.columns is not None: + summary = ( + f"Persisted {record.rows} rows across {record.columns} columns" + f" into {relation_count} relation{'s' if relation_count != 1 else ''}." + ) + if record.primary_relation_name: + summary = f"{summary} Primary relation: {record.primary_relation_name}." + + return UploadedAsset( + id=record.source_id, + name=record.original_filename, + type=Path(record.original_filename).suffix.lstrip(".").upper() or "FILE", + source="Workspace upload", + sizeLabel=_bytes_to_size(record.size_bytes), + uploadedAt=_to_uploaded_at(record.created_at), + status=record.status, # type: ignore[arg-type] + rows=record.rows, + columns=record.columns, + relationCount=record.relation_count, + primaryRelationName=record.primary_relation_name, + summary=summary, + ) + + +def _bytes_to_size(num_bytes: int) -> str: + if num_bytes < 1024: + return f"{num_bytes} B" + if num_bytes < 1024 * 1024: + return f"{num_bytes / 1024:.1f} KB" + if num_bytes < 1024 * 1024 * 1024: + return f"{num_bytes / (1024 * 1024):.1f} MB" + return f"{num_bytes / (1024 * 1024 * 1024):.1f} GB" + + +def list_user_uploads(db: Session, user: User) -> list[UploadedAsset]: + rows = db.execute( + select(UploadRecord) + .where(UploadRecord.user_id == user.id) + .order_by(UploadRecord.created_at.desc(), UploadRecord.source_id.desc()) + ).scalars() + return [_asset_from_record(row) for row in rows] + + +def get_authorized_source_ids(db: Session, user: User, source_ids: list[str] | None = None) -> list[str]: 
+ unique_source_ids = list(dict.fromkeys(source_ids or [])) + if not unique_source_ids: + return [] + + owned_rows = db.execute( + select(UploadRecord.source_id) + .where(UploadRecord.user_id == user.id, UploadRecord.source_id.in_(unique_source_ids)) + ).scalars() + owned_ids = set(owned_rows) + return [source_id for source_id in unique_source_ids if source_id in owned_ids] + + +def create_user_upload( + db: Session, + user: User, + *, + filename: str, + content_type: str | None, + content: bytes, + blob_store: LocalUploadBlobStore | None = None, +) -> UploadedAsset: + store = blob_store or LocalUploadBlobStore() + source_id = _short_source_id() + storage_path = store.save(user_id=user.id, upload_id=source_id, filename=filename, content=content) + + try: + asset = ingest_source(filename, content, source_id=source_id) + except Exception: + store.delete(storage_path) + raise + + record = UploadRecord( + source_id=source_id, + user_id=user.id, + original_filename=filename, + storage_path=str(storage_path), + content_type=content_type or "application/octet-stream", + size_bytes=len(content), + content_hash=sha256(content).hexdigest(), + status=asset.status, + rows=asset.rows, + columns=asset.columns, + relation_count=asset.relationCount, + primary_relation_name=asset.primaryRelationName, + ) + db.add(record) + + try: + db.commit() + except Exception: + db.rollback() + delete_source(source_id) + store.delete(storage_path) + raise + + db.refresh(record) + return _asset_from_record(record) + + +def delete_user_upload( + db: Session, + user: User, + source_id: str, + *, + blob_store: LocalUploadBlobStore | None = None, +) -> bool: + record = db.get(UploadRecord, source_id) + if record is None or record.user_id != user.id: + return False + + store = blob_store or LocalUploadBlobStore() + storage_path = record.storage_path + db.delete(record) + db.commit() + try: + delete_source(source_id) + finally: + store.delete(storage_path) + return True diff --git 
a/app/uploads/storage.py b/app/uploads/storage.py new file mode 100644 index 0000000..53adab3 --- /dev/null +++ b/app/uploads/storage.py @@ -0,0 +1,34 @@ +"""Blob storage for raw uploaded files.""" + +from __future__ import annotations + +from pathlib import Path + +from app.config import get_settings + + +class LocalUploadBlobStore: + """Persist raw uploads under a stable per-user filesystem path.""" + + def __init__(self, root: Path | None = None) -> None: + self.root = root or get_settings().upload_storage_dir + + def save(self, *, user_id: int, upload_id: str, filename: str, content: bytes) -> Path: + suffix = Path(filename).suffix.lower() + upload_dir = self.root / str(user_id) / upload_id + upload_dir.mkdir(parents=True, exist_ok=True) + path = upload_dir / f"original{suffix}" + path.write_bytes(content) + return path + + def delete(self, storage_path: str | Path) -> None: + path = Path(storage_path) + if path.exists(): + path.unlink() + current = path.parent + root = self.root.resolve() + while current.exists() and current != root: + if any(current.iterdir()): + break + current.rmdir() + current = current.parent diff --git a/data/source_registry.duckdb b/data/source_registry.duckdb new file mode 100644 index 0000000..f9fc572 Binary files /dev/null and b/data/source_registry.duckdb differ diff --git a/tests/test_api.py b/tests/test_api.py index c99c639..fbb7da8 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,7 @@ """API response structure tests.""" from io import BytesIO +import sqlite3 import pytest from fastapi.testclient import TestClient @@ -15,6 +16,8 @@ @pytest.fixture def client(tmp_path, monkeypatch): monkeypatch.setenv("DATABASE_PATH", str(tmp_path / "api_test.sqlite")) + monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "api_source_registry.duckdb")) + monkeypatch.setenv("UPLOAD_STORAGE_DIR", str(tmp_path / "uploads")) get_settings.cache_clear() reset_engine_and_session() with TestClient(app) as test_client: @@ -61,8 +64,14 @@ def 
test_sample_questions_endpoint(client: TestClient) -> None: assert len(payload["questions"]) >= 3 +def _signup(client: TestClient, email: str, password: str = "password123") -> str: + response = client.post("/auth/signup", json={"email": email, "password": password}) + assert response.status_code == 200, response.text + return response.json()["access_token"] + + def test_analyze_endpoint_structure(client: TestClient) -> None: - def fake_run_analysis(query: str) -> dict: # noqa: ARG001 + def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001 return { "analysis": "## Summary\nPipeline velocity improved.\n", "trace": [{"step": "planner_compiled_node", "status": "completed", "details": {"objective": "x"}}], @@ -95,7 +104,19 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001 original = analysis_run.run_analysis analysis_run.run_analysis = fake_run_analysis try: - response = client.post("/analyze", json={"query": "Why did pipeline velocity drop this week?"}) + token = _signup(client, "analyze-structure@example.com") + headers = {"Authorization": f"Bearer {token}"} + upload_response = client.post( + "/uploads", + files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")}, + headers=headers, + ) + source_id = upload_response.json()["asset"]["id"] + response = client.post( + "/analyze", + json={"query": "Why did pipeline velocity drop this week?", "source_ids": [source_id]}, + headers=headers, + ) finally: analysis_run.run_analysis = original assert response.status_code == 200 @@ -108,7 +129,7 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001 def test_analyze_endpoint_returns_http_500_on_failure(client: TestClient) -> None: - def fake_run_analysis(query: str) -> dict: # noqa: ARG001 + def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001 raise RuntimeError("planner exploded") import app.services.analysis_run as analysis_run @@ -116,7 +137,19 @@ def fake_run_analysis(query: str) -> 
dict: # noqa: ARG001 original = analysis_run.run_analysis analysis_run.run_analysis = fake_run_analysis try: - response = client.post("/analyze", json={"query": "Why did pipeline velocity drop this week?"}) + token = _signup(client, "analyze-failure@example.com") + headers = {"Authorization": f"Bearer {token}"} + upload_response = client.post( + "/uploads", + files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")}, + headers=headers, + ) + source_id = upload_response.json()["asset"]["id"] + response = client.post( + "/analyze", + json={"query": "Why did pipeline velocity drop this week?", "source_ids": [source_id]}, + headers=headers, + ) finally: analysis_run.run_analysis = original @@ -126,10 +159,21 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001 assert payload["detail"]["error"] == "planner exploded" -def test_upload_endpoint_profiles_csv(client: TestClient) -> None: +def test_upload_endpoint_requires_auth(client: TestClient) -> None: + response = client.post( + "/uploads", + files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\n"), "text/csv")}, + ) + + assert response.status_code == 401 + + +def test_upload_endpoint_profiles_csv(client: TestClient, tmp_path) -> None: + token = _signup(client, "upload-csv@example.com") response = client.post( "/uploads", files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")}, + headers={"Authorization": f"Bearer {token}"}, ) assert response.status_code == 200 @@ -140,10 +184,134 @@ def test_upload_endpoint_profiles_csv(client: TestClient) -> None: assert payload["asset"]["rows"] == 2 assert payload["asset"]["columns"] == 2 assert payload["asset"]["status"] == "verified" + assert payload["asset"]["relationCount"] == 1 + assert payload["asset"]["primaryRelationName"] + assert any((tmp_path / "uploads").rglob("original.csv")) + + uploads_response = client.get("/uploads", headers={"Authorization": f"Bearer {token}"}) + assert 
uploads_response.status_code == 200 + assert [asset["id"] for asset in uploads_response.json()] == [payload["asset"]["id"]] + + +def test_upload_endpoint_profiles_json(client: TestClient) -> None: + token = _signup(client, "upload-json@example.com") + response = client.post( + "/uploads", + files={ + "file": ( + "orders.json", + BytesIO( + b'[{"order_id":"o1","customer":{"name":"Ada"},"items":[{"sku":"A1","qty":2}]}]' + ), + "application/json", + ) + }, + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["asset"]["type"] == "JSON" + assert payload["asset"]["rows"] == 1 + assert payload["asset"]["relationCount"] == 2 + assert payload["asset"]["primaryRelationName"] + + +def test_upload_endpoint_rejects_unsupported_file_types(client: TestClient) -> None: + token = _signup(client, "upload-unsupported@example.com") + response = client.post( + "/uploads", + files={"file": ("pipeline.tsv", BytesIO(b"stage\tamount\nopen\t10\n"), "text/tab-separated-values")}, + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 400 + assert response.json()["detail"]["message"] == "Only CSV and JSON uploads are currently supported." 
+ + +def test_upload_endpoint_migrates_legacy_uploads_table(tmp_path, monkeypatch) -> None: + db_path = tmp_path / "legacy_uploads.sqlite" + registry_path = tmp_path / "legacy_registry.duckdb" + upload_dir = tmp_path / "uploads" + + conn = sqlite3.connect(db_path) + conn.execute( + """ + CREATE TABLE uploads ( + id INTEGER NOT NULL PRIMARY KEY, + upload_id VARCHAR(64) NOT NULL, + user_id INTEGER NOT NULL, + source_id VARCHAR(64) NOT NULL, + original_filename VARCHAR(255) NOT NULL, + file_type VARCHAR(32) NOT NULL, + storage_path VARCHAR(1024) NOT NULL, + content_type VARCHAR(255), + size_bytes BIGINT NOT NULL, + content_hash VARCHAR(64) NOT NULL, + status VARCHAR(32) NOT NULL, + rows INTEGER, + columns INTEGER, + relation_count INTEGER, + primary_relation_name VARCHAR(255), + summary TEXT, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + ) + """ + ) + conn.execute( + """ + INSERT INTO uploads ( + id, upload_id, user_id, source_id, original_filename, file_type, storage_path, content_type, + size_bytes, content_hash, status, rows, columns, relation_count, primary_relation_name, summary, + created_at, updated_at + ) VALUES ( + 1, 'legacy_upload_1', 999, 'legacy_source_1', 'legacy.csv', 'CSV', '/tmp/legacy.csv', 'text/csv', + 12, 'abc123', 'verified', 1, 2, 1, 'legacy_relation', 'legacy summary', + '2026-04-01 12:00:00', '2026-04-01 12:00:00' + ) + """ + ) + conn.commit() + conn.close() + + monkeypatch.setenv("DATABASE_PATH", str(db_path)) + monkeypatch.setenv("REGISTRY_PATH", str(registry_path)) + monkeypatch.setenv("UPLOAD_STORAGE_DIR", str(upload_dir)) + get_settings.cache_clear() + reset_engine_and_session() + + try: + with TestClient(app) as legacy_client: + token = _signup(legacy_client, "legacy-migration@example.com") + response = legacy_client.post( + "/uploads", + files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\n"), "text/csv")}, + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 200, 
response.text + + uploads_response = legacy_client.get("/uploads", headers={"Authorization": f"Bearer {token}"}) + assert uploads_response.status_code == 200 + assert len(uploads_response.json()) == 1 + finally: + get_settings.cache_clear() + reset_engine_and_session() + + migrated_conn = sqlite3.connect(db_path) + columns = [row[1] for row in migrated_conn.execute("PRAGMA table_info(uploads)").fetchall()] + source_ids = [row[0] for row in migrated_conn.execute("SELECT source_id FROM uploads ORDER BY source_id").fetchall()] + migrated_conn.close() + + assert "upload_id" not in columns + assert "file_type" not in columns + assert "source_id" in columns + assert source_ids == ["legacy_source_1", uploads_response.json()[0]["id"]] def test_inspection_endpoint_returns_stored_inspection(client: TestClient) -> None: - def fake_run_analysis(query: str) -> dict: # noqa: ARG001 + def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001 return { "analysis": "## Summary\nPipeline velocity improved.\n", "trace": [{"step": "planner_compiled_node", "status": "completed", "details": {"objective": "x"}}], @@ -175,7 +343,19 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001 original = analysis_run.run_analysis analysis_run.run_analysis = fake_run_analysis try: - analyze_response = client.post("/analyze", json={"query": "Why did pipeline velocity drop this week?"}) + token = _signup(client, "inspection@example.com") + headers = {"Authorization": f"Bearer {token}"} + upload_response = client.post( + "/uploads", + files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")}, + headers=headers, + ) + source_id = upload_response.json()["asset"]["id"] + analyze_response = client.post( + "/analyze", + json={"query": "Why did pipeline velocity drop this week?", "source_ids": [source_id]}, + headers=headers, + ) finally: analysis_run.run_analysis = original @@ -194,3 +374,131 @@ def 
test_inspection_endpoint_returns_404_for_unknown_id(client: TestClient) -> N response = client.get("/inspections/inspect_missing") assert response.status_code == 404 assert response.json()["detail"]["message"] == "Inspection not found." + + +def test_analyze_endpoint_forwards_source_ids(client: TestClient) -> None: + captured: dict[str, object] = {} + + def fake_run_analysis(query: str, source_ids=None) -> dict: + captured["query"] = query + captured["source_ids"] = source_ids + return { + "analysis": "## Summary\nScoped analysis.\n", + "trace": [], + "executed_steps": [], + "errors": [], + } + + import app.services.analysis_run as analysis_run + + original = analysis_run.run_analysis + analysis_run.run_analysis = fake_run_analysis + try: + token = _signup(client, "analyze-forward@example.com") + headers = {"Authorization": f"Bearer {token}"} + upload_response = client.post( + "/uploads", + files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")}, + headers=headers, + ) + source_id = upload_response.json()["asset"]["id"] + response = client.post("/analyze", json={"query": "Use this upload", "source_ids": [source_id]}, headers=headers) + finally: + analysis_run.run_analysis = original + + assert response.status_code == 200 + assert captured["query"] == "Use this upload" + assert captured["source_ids"] == [source_id] + + +def test_analyze_endpoint_requires_uploaded_source_before_run_analysis(client: TestClient) -> None: + import app.services.analysis_run as analysis_run + + called = False + original = analysis_run.run_analysis + + def fake_run_analysis(query: str, source_ids=None) -> dict: + nonlocal called + called = True + return {"analysis": "", "trace": [], "executed_steps": [], "errors": []} + + analysis_run.run_analysis = fake_run_analysis + try: + token = _signup(client, "analyze-requires-upload@example.com") + response = client.post( + "/analyze", + json={"query": "Why did pipeline velocity drop this week?"}, + 
headers={"Authorization": f"Bearer {token}"}, + ) + finally: + analysis_run.run_analysis = original + + assert response.status_code == 400 + assert response.json()["detail"]["message"] == "Upload and attach at least one CSV or JSON data source before running analysis." + assert called is False + + +def test_uploads_and_analysis_are_scoped_to_the_signed_in_user(client: TestClient) -> None: + owner = _signup(client, "owner-uploads@example.com") + intruder = _signup(client, "intruder-uploads@example.com") + owner_headers = {"Authorization": f"Bearer {owner}"} + intruder_headers = {"Authorization": f"Bearer {intruder}"} + + upload_response = client.post( + "/uploads", + files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")}, + headers=owner_headers, + ) + source_id = upload_response.json()["asset"]["id"] + + owner_list = client.get("/uploads", headers=owner_headers) + intruder_list = client.get("/uploads", headers=intruder_headers) + + assert [asset["id"] for asset in owner_list.json()] == [source_id] + assert intruder_list.json() == [] + + import app.services.analysis_run as analysis_run + + called = False + original = analysis_run.run_analysis + + def fake_run_analysis(query: str, source_ids=None) -> dict: + nonlocal called + called = True + return {"analysis": "", "trace": [], "executed_steps": [], "errors": []} + + analysis_run.run_analysis = fake_run_analysis + try: + forbidden = client.post( + "/analyze", + json={"query": "Use someone else's upload", "source_ids": [source_id]}, + headers=intruder_headers, + ) + finally: + analysis_run.run_analysis = original + + assert forbidden.status_code == 400 + assert forbidden.json()["detail"]["message"] == "Attach a valid uploaded data source before running analysis." 
+ assert called is False + + +def test_delete_upload_removes_file_and_listing(client: TestClient, tmp_path) -> None: + owner = _signup(client, "delete-owner@example.com") + intruder = _signup(client, "delete-intruder@example.com") + owner_headers = {"Authorization": f"Bearer {owner}"} + intruder_headers = {"Authorization": f"Bearer {intruder}"} + + upload_response = client.post( + "/uploads", + files={"file": ("orders.json", BytesIO(b'[{"order_id":"o1"}]'), "application/json")}, + headers=owner_headers, + ) + source_id = upload_response.json()["asset"]["id"] + + forbidden = client.delete(f"/uploads/{source_id}", headers=intruder_headers) + assert forbidden.status_code == 404 + + delete_response = client.delete(f"/uploads/{source_id}", headers=owner_headers) + assert delete_response.status_code == 204 + assert client.get("/uploads", headers=owner_headers).json() == [] + assert not any((tmp_path / "uploads").rglob("original.json")) diff --git a/tests/test_conversations.py b/tests/test_conversations.py index d731f9d..e1dfa7b 100644 --- a/tests/test_conversations.py +++ b/tests/test_conversations.py @@ -13,6 +13,8 @@ @pytest.fixture def chat_client(tmp_path, monkeypatch): monkeypatch.setenv("DATABASE_PATH", str(tmp_path / "chat_test.sqlite")) + monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "chat_source_registry.duckdb")) + monkeypatch.setenv("UPLOAD_STORAGE_DIR", str(tmp_path / "uploads")) get_settings.cache_clear() reset_engine_and_session() with TestClient(app) as client: @@ -27,7 +29,7 @@ def _signup(client: TestClient, email: str, password: str = "password123") -> st return r.json()["access_token"] -def _fake_analysis_state(query: str) -> dict: # noqa: ARG001 +def _fake_analysis_state(query: str, source_ids=None) -> dict: # noqa: ARG001 return { "analysis": "## Demo\nHello from fake analysis.\n", "trace": [{"step": "planner_compiled_node", "status": "completed", "details": {}}], @@ -228,3 +230,41 @@ def 
test_persisted_inspection_forbidden_for_other_user(chat_client: TestClient) forbidden = chat_client.get(f"/inspections/{inspection_id}", headers={"Authorization": f"Bearer {intruder}"}) assert forbidden.status_code == 403 + + +def test_chat_rejects_uploads_owned_by_another_user(chat_client: TestClient) -> None: + import app.services.analysis_run as analysis_run + + owner = _signup(chat_client, "owner-upload-chat@example.com") + intruder = _signup(chat_client, "intruder-upload-chat@example.com") + owner_headers = {"Authorization": f"Bearer {owner}"} + intruder_headers = {"Authorization": f"Bearer {intruder}"} + + upload_response = chat_client.post( + "/uploads", + headers=owner_headers, + files={"file": ("pipeline.csv", b"stage,amount\nopen,10\nwon,25\n", "text/csv")}, + ) + source_id = upload_response.json()["asset"]["id"] + + called = False + original = analysis_run.run_analysis + + def fake_run_analysis(query: str, source_ids=None) -> dict: + nonlocal called + called = True + return {"analysis": "", "trace": [], "executed_steps": [], "errors": []} + + analysis_run.run_analysis = fake_run_analysis + try: + response = chat_client.post( + "/chat", + headers=intruder_headers, + json={"query": "Use another user's upload", "source_ids": [source_id]}, + ) + finally: + analysis_run.run_analysis = original + + assert response.status_code == 400 + assert response.json()["detail"]["message"] == "Attach a valid uploaded data source before running analysis." 
+ assert called is False diff --git a/tests/test_planner_schema.py b/tests/test_planner_schema.py index a82500e..3659289 100644 --- a/tests/test_planner_schema.py +++ b/tests/test_planner_schema.py @@ -1,11 +1,27 @@ """Compiled plan schema and planner retry behavior.""" +import pytest + from app.agent.planner import plan_compiled_query from app.agent.state import create_initial_state -from app.data.semantic_model import get_semantic_context +from app.config import get_settings +from app.data.registry import clear_source_registry, ingest_source +from app.data.semantic_model import clear_semantic_context_cache, get_semantic_context from app.schemas import CompiledPlan +@pytest.fixture(autouse=True) +def isolated_registry(tmp_path, monkeypatch): + monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "planner_source_registry.duckdb")) + get_settings.cache_clear() + clear_source_registry() + clear_semantic_context_cache() + yield + clear_source_registry() + clear_semantic_context_cache() + get_settings.cache_clear() + + def test_compiled_plan_normalizes_max_steps() -> None: plan = CompiledPlan.model_validate( { @@ -64,7 +80,7 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: stub) state = create_initial_state("Compare segments") - state["dataset_context"] = {"reference_date": "2017-12-31", "views": [{"name": "opportunities_enriched"}]} + state["dataset_context"] = {"reference_date": "2017-12-31", "relations": [], "views": []} state = plan_compiled_query(state) assert stub.calls == 2 @@ -73,17 +89,20 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 def test_semantic_context_exposes_normalized_manifest() -> None: + asset = ingest_source("pipeline.csv", b"owner,pipeline_velocity_days\nAda,10\nBen,14\n") manifest = get_semantic_context().schema_manifest assert manifest["dialect"] == "duckdb" assert manifest["relations"] - relation = next(relation for relation 
in manifest["relations"] if relation["name"] == "opportunities_enriched") + relation = next(relation for relation in manifest["relations"] if relation["name"] == asset.primaryRelationName) assert relation["columns"] assert relation["identifier_columns"] assert relation["grain"] def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None: + asset = ingest_source("pipeline.csv", b"owner,pipeline_velocity_days\nAda,10\nBen,14\n") + relation_name = asset.primaryRelationName bad = { "objective": "Analyze by sales agent", "plan": [ @@ -91,7 +110,7 @@ def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None: "id": 1, "purpose": "Break out velocity by agent", "type": "sql", - "query": "SELECT sales_agent, AVG(pipeline_velocity_days) AS avg_velocity_days FROM opportunities_enriched GROUP BY sales_agent", + "query": f"SELECT sales_agent, AVG(pipeline_velocity_days) AS avg_velocity_days FROM {relation_name} GROUP BY sales_agent", "output_alias": "velocity_by_agent", } ], @@ -106,7 +125,7 @@ def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None: "id": 1, "purpose": "Break out velocity by owner", "type": "sql", - "query": "SELECT owner, AVG(pipeline_velocity_days) AS avg_velocity_days FROM opportunities_enriched GROUP BY owner", + "query": f"SELECT owner, AVG(pipeline_velocity_days) AS avg_velocity_days FROM {relation_name} GROUP BY owner", "output_alias": "velocity_by_owner", } ], @@ -131,7 +150,7 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002 monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: stub) state = create_initial_state("Why did pipeline velocity drop this week by sales agent?") - state["dataset_context"] = get_semantic_context().schema_manifest + state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest state = plan_compiled_query(state) assert stub.calls == 2 diff --git a/tests/test_source_registry.py b/tests/test_source_registry.py new file mode 100644 
index 0000000..12de3f6 --- /dev/null +++ b/tests/test_source_registry.py @@ -0,0 +1,100 @@ +"""Tests for the persistent source registry and registry-backed manifests.""" + +from __future__ import annotations + +import pytest + +from app.agent.planner import _schema_subset_for_question +from app.config import get_settings +from app.data.registry import clear_source_registry, create_source_link, get_schema_manifest, ingest_source +from app.data.semantic_model import clear_semantic_context_cache, get_semantic_context + + +@pytest.fixture(autouse=True) +def isolated_registry(tmp_path, monkeypatch): + monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "source_registry.duckdb")) + get_settings.cache_clear() + clear_source_registry() + clear_semantic_context_cache() + yield + clear_source_registry() + clear_semantic_context_cache() + get_settings.cache_clear() + + +def _relation_by_name(manifest: dict, relation_name: str) -> dict: + return next(relation for relation in manifest["relations"] if relation["name"] == relation_name) + + +def test_registry_manifest_starts_empty_without_builtin_sources() -> None: + manifest = get_schema_manifest() + + assert manifest["dialect"] == "duckdb" + assert manifest["relations"] == [] + + +def test_nested_json_upload_creates_primary_and_child_relations() -> None: + asset = ingest_source( + "orders.json", + b'[{"order_id":"o1","customer":{"name":"Ada"},"items":[{"sku":"A1","qty":2},{"sku":"B2","qty":1}]}]', + ) + + manifest = get_schema_manifest([asset.id]) + root = _relation_by_name(manifest, asset.primaryRelationName) + child = next(relation for relation in manifest["relations"] if relation["name"].startswith(f"{asset.primaryRelationName}__")) + + assert asset.relationCount == 2 + assert root["is_primary"] is True + assert any(column["name"] == "customer__name" and column["source_path"] == "customer.name" for column in root["columns"]) + assert child["parent_relation"] == asset.primaryRelationName + assert any(join["target_relation"] == 
asset.primaryRelationName for join in child["join_keys"]) + assert {column["name"] for column in child["columns"]} >= {"record_id", "parent_record_id", "ordinal", "sku", "qty"} + + +def test_unrelated_sources_do_not_auto_link_even_with_shared_column_names() -> None: + customers = ingest_source("customers.csv", b"customer_id,name\nc1,Ada\nc2,Ben\n") + orders = ingest_source("orders.csv", b"customer_id,amount\nc1,100\nc2,50\n") + + manifest = get_schema_manifest([customers.id, orders.id]) + customers_relation = _relation_by_name(manifest, customers.primaryRelationName) + orders_relation = _relation_by_name(manifest, orders.primaryRelationName) + + assert customers_relation["join_keys"] == [] + assert orders_relation["join_keys"] == [] + + +def test_explicit_source_links_appear_in_manifest() -> None: + customers = ingest_source("customers.csv", b"customer_id,name\nc1,Ada\nc2,Ben\n") + orders = ingest_source("orders.csv", b"customer_id,amount\nc1,100\nc2,50\n") + + create_source_link(customers.primaryRelationName, "customer_id", orders.primaryRelationName, "customer_id") + manifest = get_schema_manifest([customers.id, orders.id]) + customers_relation = _relation_by_name(manifest, customers.primaryRelationName) + orders_relation = _relation_by_name(manifest, orders.primaryRelationName) + + assert any(join["target_relation"] == orders.primaryRelationName and join["source_column"] == "customer_id" for join in customers_relation["join_keys"]) + assert any(join["target_relation"] == customers.primaryRelationName and join["source_column"] == "customer_id" for join in orders_relation["join_keys"]) + + +def test_semantic_context_scopes_to_selected_sources() -> None: + inventory = ingest_source("inventory.csv", b"sku,stock\nA1,10\nB2,3\n") + invoices = ingest_source("invoices.csv", b"invoice_id,amount\ni1,100\ni2,250\n") + + full_context = get_semantic_context().schema_manifest + scoped_context = get_semantic_context([inventory.id]).schema_manifest + + assert 
any(relation["name"] == inventory.primaryRelationName for relation in full_context["relations"]) + assert any(relation["name"] == invoices.primaryRelationName for relation in full_context["relations"]) + assert not any(relation["name"] == invoices.primaryRelationName for relation in scoped_context["relations"]) + assert any(relation["name"] == inventory.primaryRelationName for relation in scoped_context["relations"]) + + +def test_schema_subset_prefers_relevant_uploaded_primary_relation() -> None: + ingest_source("inventory.csv", b"sku,stock\nA1,10\nB2,3\n") + invoices = ingest_source("invoices.csv", b"invoice_id,invoice_amount,status\ni1,100,paid\ni2,250,due\n") + ingest_source("campaigns.csv", b"campaign,spend\nspring,1200\nsummer,950\n") + + manifest = get_schema_manifest() + subset = _schema_subset_for_question(manifest, "Which invoice amount is highest?") + + assert any(relation["name"] == invoices.primaryRelationName for relation in subset["relations"]) diff --git a/tests/test_tools.py b/tests/test_tools.py index d2fa90f..9811fec 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,13 +1,30 @@ """Executor tests for compiled SQL plans.""" +import pytest + from app.agent.executor import execute_plan, execute_single_plan_step from app.agent.state import create_initial_state -from app.data.semantic_model import get_semantic_context +from app.config import get_settings +from app.data.registry import clear_source_registry, ingest_source +from app.data.semantic_model import clear_semantic_context_cache, get_semantic_context + + +@pytest.fixture(autouse=True) +def isolated_registry(tmp_path, monkeypatch): + monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "tools_source_registry.duckdb")) + get_settings.cache_clear() + clear_source_registry() + clear_semantic_context_cache() + yield + clear_source_registry() + clear_semantic_context_cache() + get_settings.cache_clear() def test_execute_sql_step_returns_artifact_summary() -> None: + asset = 
ingest_source("pipeline.csv", b"segment\nSMB\nEnterprise\nSMB\n") state = create_initial_state("Compare SMB vs Enterprise performance") - state["dataset_context"] = get_semantic_context().schema_manifest + state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest compiled_plan = { "objective": "Segment counts", "max_steps": 3, @@ -16,7 +33,7 @@ def test_execute_sql_step_returns_artifact_summary() -> None: "id": 1, "purpose": "Get one sample aggregation.", "type": "sql", - "query": "SELECT segment, COUNT(*) AS deals FROM opportunities_enriched GROUP BY segment ORDER BY deals DESC", + "query": f"SELECT segment, COUNT(*) AS deals FROM {asset.primaryRelationName} GROUP BY segment ORDER BY deals DESC", "output_alias": "segment_counts", } ], @@ -30,8 +47,9 @@ def test_execute_sql_step_returns_artifact_summary() -> None: def test_execute_plan_marks_empty_table_as_failed() -> None: + asset = ingest_source("pipeline.csv", b"segment\nSMB\nEnterprise\n") state = create_initial_state("Why did pipeline velocity drop this week?") - state["dataset_context"] = get_semantic_context().schema_manifest + state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest compiled_plan = { "objective": "Empty query", "max_steps": 3, @@ -51,8 +69,9 @@ def test_execute_plan_marks_empty_table_as_failed() -> None: def test_execute_single_plan_step_retry() -> None: + asset = ingest_source("pipeline.csv", b"segment\nSMB\nEnterprise\n") state = create_initial_state("Test retry") - state["dataset_context"] = get_semantic_context().schema_manifest + state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest step = { "id": 1, "purpose": "Get one row.", diff --git a/ui/src/api/chat.ts b/ui/src/api/chat.ts index 15daccf..bb8b7eb 100644 --- a/ui/src/api/chat.ts +++ b/ui/src/api/chat.ts @@ -9,6 +9,7 @@ import { requestWithAuth, ApiError } from "@/api/client"; import { cacheInspection } from "@/api/inspections"; import { conversationTitleFromPrompt, 
mapAnalyzeResponseToUi } from "@/api/mappers"; import type { + AnalyzeApiRequest, AnalyzeApiResponse, ApiChatTurnResponse, ApiConversationDetailResponse, @@ -143,10 +144,13 @@ export async function submitChatPrompt(payload: ChatRequest, accessToken: string } try { - const body: Record = { query: payload.prompt }; + const body: AnalyzeApiRequest & { conversation_id?: number } = { query: payload.prompt }; if (payload.conversationId && isBackendConversationId(payload.conversationId)) { body.conversation_id = Number.parseInt(payload.conversationId, 10); } + if (payload.attachmentIds?.length) { + body.source_ids = payload.attachmentIds; + } const raw = await requestWithAuth("/chat", accessToken, { method: "POST", @@ -177,6 +181,9 @@ export async function submitChatPrompt(payload: ChatRequest, accessToken: string fallback: false, }; } catch (error) { + if (error instanceof ApiError && error.status && error.status >= 400 && error.status < 500) { + throw error; + } if (!shouldFallbackToDemo) throw error; await sleep(850); diff --git a/ui/src/api/types.ts b/ui/src/api/types.ts index 282b910..78a6818 100644 --- a/ui/src/api/types.ts +++ b/ui/src/api/types.ts @@ -34,6 +34,7 @@ export interface InspectionResponse { /** Body for the server's stateless `POST /analyze` (debug only — the UI sends `POST /chat` instead). 
*/ export interface AnalyzeApiRequest { query: string; + source_ids?: string[]; } export interface AnalyzeTraceEvent { diff --git a/ui/src/api/uploads.ts b/ui/src/api/uploads.ts index c443f3c..d226595 100644 --- a/ui/src/api/uploads.ts +++ b/ui/src/api/uploads.ts @@ -1,32 +1,42 @@ -import { request } from "@/api/client"; +import { ApiError, requestWithAuth } from "@/api/client"; import type { UploadResponse } from "@/api/types"; -import { createUploadedAsset } from "@/data/mockUploads"; -import { shouldFallbackToDemo } from "@/config/env"; -import { sleep } from "@/lib/utils"; +import type { UploadedAsset } from "@/types/upload"; +import { isSupportedUploadFile } from "@/lib/uploads"; + +export async function fetchUploads(accessToken: string | null): Promise { + if (!accessToken) { + throw new ApiError("Not authenticated."); + } + + return requestWithAuth("/uploads", accessToken); +} + + +export async function uploadDataset(file: File, accessToken: string | null): Promise { + if (!isSupportedUploadFile(file)) { + throw new ApiError("Only CSV and JSON files are supported."); + } + if (!accessToken) { + throw new ApiError("Not authenticated."); + } -export async function uploadDataset(file: File): Promise { const formData = new FormData(); formData.append("file", file); - try { - const response = await request("/uploads", { - method: "POST", - body: formData, - }); - return { ...response, fallback: false }; - } catch (error) { - if (!shouldFallbackToDemo) throw error; - await sleep(480); - return { - asset: { - ...createUploadedAsset(file), - source: "Demo fallback", - status: "verified", - rows: 12840, - columns: 9, - summary: "Demo fallback profiling completed. 
The uploaded file is shown with generated preview statistics.", - }, - fallback: true, - }; + const response = await requestWithAuth("/uploads", accessToken, { + method: "POST", + body: formData, + }); + return { ...response, fallback: false }; +} + + +export async function deleteUpload(sourceId: string, accessToken: string | null): Promise { + if (!accessToken) { + throw new ApiError("Not authenticated."); } + + await requestWithAuth(`/uploads/${sourceId}`, accessToken, { + method: "DELETE", + }); } diff --git a/ui/src/components/app/ChatInput.tsx b/ui/src/components/app/ChatInput.tsx index f8b113e..86cb9d2 100644 --- a/ui/src/components/app/ChatInput.tsx +++ b/ui/src/components/app/ChatInput.tsx @@ -3,6 +3,7 @@ import { Button } from "@/components/shared/Button"; import { Textarea } from "@/components/shared/Textarea"; import { PromptChips } from "@/components/app/PromptChips"; import { Spinner } from "@/components/shared/Spinner"; +import { UPLOAD_ACCEPT } from "@/lib/uploads"; import type { UploadedAsset } from "@/types/upload"; interface ChatInputProps { @@ -110,6 +111,7 @@ export function ChatInput({ ref={fileInputRef} type="file" className="hidden" + accept={UPLOAD_ACCEPT} onChange={(event) => { const file = event.target.files?.[0]; if (file) { @@ -122,7 +124,7 @@ export function ChatInput({ {isUploading ? : null} Attach file - CSV, TSV, SQL exports, or connected data sources + CSV or JSON only - -
- {uploads.slice(0, 2).map((asset) => ( - - ))} -
- - ) : null} - {!collapsed && user ?

{user.email}

: null} + ) : null} + {asset.summary ?

{asset.summary}

: null} {(asset.rows || asset.columns) && (
{asset.rows ? {asset.rows.toLocaleString()} rows : null} {asset.columns ? {asset.columns} columns : null} + {asset.relationCount ? {asset.relationCount} relations : null}
)} + {asset.primaryRelationName ? ( +

Primary relation: {asset.primaryRelationName}

+ ) : null} ); } diff --git a/ui/src/components/app/UploadsPanel.test.tsx b/ui/src/components/app/UploadsPanel.test.tsx new file mode 100644 index 0000000..c699a95 --- /dev/null +++ b/ui/src/components/app/UploadsPanel.test.tsx @@ -0,0 +1,77 @@ +import { cleanup, fireEvent, render, screen } from "@testing-library/react"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { UploadsPanel } from "@/components/app/UploadsPanel"; + +afterEach(() => { + cleanup(); +}); + +describe("UploadsPanel", () => { + it("renders a file picker restricted to csv and json", () => { + render( + , + ); + + expect(screen.getByRole("button", { name: "Upload file" })).toBeInTheDocument(); + expect(screen.getByTestId("uploads-panel-file-input")).toHaveAttribute( + "accept", + ".csv,.json,application/json,text/csv", + ); + }); + + it("forwards selected files to the upload callback", () => { + const onUpload = vi.fn(); + render( + , + ); + + const input = screen.getByTestId("uploads-panel-file-input"); + const file = new File(["order_id,total\n1,25"], "orders.csv", { type: "text/csv" }); + + fireEvent.change(input, { target: { files: [file] } }); + + expect(onUpload).toHaveBeenCalledWith(file); + }); + + it("renders a delete action for each uploaded file", () => { + const onDelete = vi.fn(); + render( + , + ); + + fireEvent.click(screen.getByRole("button", { name: "Delete orders.json" })); + + expect(onDelete).toHaveBeenCalledWith("source_1"); + }); +}); diff --git a/ui/src/components/app/UploadsPanel.tsx b/ui/src/components/app/UploadsPanel.tsx new file mode 100644 index 0000000..35b3ab4 --- /dev/null +++ b/ui/src/components/app/UploadsPanel.tsx @@ -0,0 +1,72 @@ +import { useRef } from "react"; +import { UploadCard } from "@/components/app/UploadCard"; +import { Button } from "@/components/shared/Button"; +import { EmptyState } from "@/components/shared/EmptyState"; +import { ErrorState } from "@/components/shared/ErrorState"; +import { PageContainer } from 
"@/components/shared/PageContainer"; +import { Spinner } from "@/components/shared/Spinner"; +import { UPLOAD_ACCEPT } from "@/lib/uploads"; +import type { UploadedAsset } from "@/types/upload"; + +interface UploadsPanelProps { + uploads: UploadedAsset[]; + error: string | null; + isUploading: boolean; + deletingUploadId: string | null; + onUpload: (file: File) => void; + onDelete: (assetId: string) => void; +} + +export function UploadsPanel({ uploads, error, isUploading, deletingUploadId, onUpload, onDelete }: UploadsPanelProps) { + const fileInputRef = useRef(null); + + return ( + + { + const file = event.target.files?.[0]; + if (file) { + onUpload(file); + event.currentTarget.value = ""; + } + }} + /> + +
+
+

Upload datasets

+

Add a CSV or JSON file to profile it and make it available for analysis.

+
+ +
+ + {error ? : null} + + {uploads.length ? ( +
+ {uploads.map((asset) => ( + + ))} +
+ ) : ( + + )} +
+ ); +} diff --git a/ui/src/components/marketing/HowItWorks.tsx b/ui/src/components/marketing/HowItWorks.tsx index f1b5dc8..90a048e 100644 --- a/ui/src/components/marketing/HowItWorks.tsx +++ b/ui/src/components/marketing/HowItWorks.tsx @@ -10,7 +10,7 @@ const steps = [ { number: "02", title: "Connect or upload data", - description: "Point the workspace at a database or bring in CSV and TSV files for quick analysis.", + description: "Point the workspace at a database or bring in CSV and JSON files for quick analysis.", }, { number: "03", diff --git a/ui/src/data/mockUploads.ts b/ui/src/data/mockUploads.ts deleted file mode 100644 index 421feb4..0000000 --- a/ui/src/data/mockUploads.ts +++ /dev/null @@ -1,42 +0,0 @@ -import type { UploadedAsset } from "@/types/upload"; -import { bytesToSize, shortId } from "@/lib/utils"; - -export const seededUploads: UploadedAsset[] = [ - { - id: "upload_finance_q2", - name: "finance_forecast_q2.csv", - type: "CSV", - source: "Workspace upload", - sizeLabel: "2.4 MB", - uploadedAt: "2026-03-25T09:12:00.000Z", - status: "verified", - rows: 18240, - columns: 12, - summary: "Quarterly revenue, margin, and plan variance by region and product line.", - }, - { - id: "upload_product_health", - name: "product_health.tsv", - type: "TSV", - source: "Workspace upload", - sizeLabel: "864 KB", - uploadedAt: "2026-03-24T18:40:00.000Z", - status: "profiling", - rows: 6420, - columns: 8, - summary: "Engagement, retention, and activation metrics across product surfaces.", - }, -]; - -export function createUploadedAsset(file: File): UploadedAsset { - return { - id: shortId("upload"), - name: file.name, - type: file.name.split(".").pop()?.toUpperCase() ?? 
"FILE", - source: "Workspace upload", - sizeLabel: bytesToSize(file.size), - uploadedAt: new Date().toISOString(), - status: "uploaded", - summary: "Freshly uploaded dataset ready for profiling and analysis.", - }; -} diff --git a/ui/src/hooks/useActiveUploads.test.tsx b/ui/src/hooks/useActiveUploads.test.tsx new file mode 100644 index 0000000..f6d2ff6 --- /dev/null +++ b/ui/src/hooks/useActiveUploads.test.tsx @@ -0,0 +1,51 @@ +import { renderHook, waitFor } from "@testing-library/react"; +import { act } from "react"; +import { describe, expect, it } from "vitest"; +import { useActiveUploads } from "@/hooks/useActiveUploads"; +import type { UploadedAsset } from "@/types/upload"; + +function createAsset(id: string, name: string): UploadedAsset { + return { + id, + name, + type: "CSV", + source: "Workspace upload", + sizeLabel: "1.0 KB", + uploadedAt: new Date().toISOString(), + status: "verified", + }; +} + +describe("useActiveUploads", () => { + it("defaults all visible uploads to active", () => { + const uploads = [createAsset("source_1", "orders.csv"), createAsset("source_2", "customers.csv")]; + const { result } = renderHook(({ items }) => useActiveUploads(items), { + initialProps: { items: uploads }, + }); + + expect(result.current.activeUploadIds).toEqual(["source_1", "source_2"]); + expect(result.current.activeUploads.map((asset) => asset.id)).toEqual(["source_1", "source_2"]); + }); + + it("keeps removals but auto-adds newly uploaded files", async () => { + const first = createAsset("source_1", "orders.csv"); + const second = createAsset("source_2", "customers.csv"); + const third = createAsset("source_3", "inventory.csv"); + const { result, rerender } = renderHook(({ items }) => useActiveUploads(items), { + initialProps: { items: [first, second] }, + }); + + act(() => { + result.current.removeActiveUpload("source_1"); + }); + + act(() => { + rerender({ items: [first, second, third] }); + }); + + await waitFor(() => { + 
expect(result.current.activeUploadIds).toEqual(["source_2", "source_3"]); + expect(result.current.activeUploads.map((asset) => asset.id)).toEqual(["source_2", "source_3"]); + }); + }); +}); diff --git a/ui/src/hooks/useActiveUploads.ts b/ui/src/hooks/useActiveUploads.ts new file mode 100644 index 0000000..bcb6306 --- /dev/null +++ b/ui/src/hooks/useActiveUploads.ts @@ -0,0 +1,52 @@ +import { useEffect, useMemo, useRef, useState } from "react"; +import type { UploadedAsset } from "@/types/upload"; + +function mergeActiveUploadIds( + currentIds: string[], + previousVisibleIds: string[], + uploads: UploadedAsset[], +): string[] { + const visibleIds = uploads.map((asset) => asset.id); + const next = currentIds.filter((id) => visibleIds.includes(id)); + const newIds = visibleIds.filter((id) => !previousVisibleIds.includes(id)); + + for (const id of newIds) { + if (!next.includes(id)) { + next.push(id); + } + } + + return next; +} + +export function useActiveUploads(uploads: UploadedAsset[]) { + const [activeUploadIds, setActiveUploadIds] = useState([]); + const previousVisibleIdsRef = useRef([]); + + useEffect(() => { + const previousVisibleIds = previousVisibleIdsRef.current; + setActiveUploadIds((current) => { + const next = mergeActiveUploadIds(current, previousVisibleIds, uploads); + if (next.length === current.length && next.every((id, index) => id === current[index])) { + return current; + } + return next; + }); + previousVisibleIdsRef.current = uploads.map((asset) => asset.id); + }, [uploads]); + + const activeUploads = useMemo( + () => uploads.filter((asset) => activeUploadIds.includes(asset.id)), + [activeUploadIds, uploads], + ); + + const removeActiveUpload = (assetId: string) => { + setActiveUploadIds((current) => current.filter((id) => id !== assetId)); + }; + + return { + activeUploadIds, + activeUploads, + removeActiveUpload, + }; +} diff --git a/ui/src/hooks/useChat.ts b/ui/src/hooks/useChat.ts index bccfa5b..d15bea6 100644 --- a/ui/src/hooks/useChat.ts 
+++ b/ui/src/hooks/useChat.ts @@ -138,6 +138,10 @@ export function useChat() { const sendPrompt = async (prompt: string, attachments: UploadedAsset[] = []) => { const trimmed = prompt.trim(); if (!trimmed) return false; + if (attachments.length === 0) { + setError("Upload and attach at least one CSV or JSON file before running analysis."); + return false; + } let conversationId = activeConversationId; diff --git a/ui/src/hooks/useUpload.test.tsx b/ui/src/hooks/useUpload.test.tsx index 9334a06..45e459f 100644 --- a/ui/src/hooks/useUpload.test.tsx +++ b/ui/src/hooks/useUpload.test.tsx @@ -1,32 +1,120 @@ -import { renderHook } from "@testing-library/react"; +import { act, renderHook, waitFor } from "@testing-library/react"; import { afterEach, describe, expect, it, vi } from "vitest"; +import type { UploadedAsset } from "@/types/upload"; -async function loadUseUpload(mode: "demo" | "hybrid" | "live") { +function createAsset(id: string, name: string): UploadedAsset { + return { + id, + name, + type: "CSV", + source: "Workspace upload", + sizeLabel: "1.0 KB", + uploadedAt: new Date().toISOString(), + status: "verified", + }; +} + +async function loadUseUpload(options?: { + isReady?: boolean; + isAuthenticated?: boolean; + token?: string | null; + fetchedUploads?: UploadedAsset[]; +}) { vi.resetModules(); - vi.stubEnv("VITE_API_FALLBACK_MODE", mode); + const fetchUploads = vi.fn().mockResolvedValue(options?.fetchedUploads ?? []); + const uploadDataset = vi.fn().mockResolvedValue({ + asset: createAsset("source_uploaded", "orders.csv"), + fallback: false, + }); + const deleteUpload = vi.fn().mockResolvedValue(undefined); + + vi.doMock("@/hooks/useAuth", () => ({ + useAuth: () => ({ + user: + options?.isAuthenticated === false + ? null + : { id: 1, email: "test@example.com", display_name: null, created_at: new Date().toISOString() }, + token: options?.token ?? "token_123", + isReady: options?.isReady ?? true, + isAuthenticated: options?.isAuthenticated ?? 
true, + login: vi.fn(), + signUp: vi.fn(), + logout: vi.fn(), + }), + })); + vi.doMock("@/api/uploads", () => ({ + fetchUploads, + uploadDataset, + deleteUpload, + })); + const module = await import("@/hooks/useUpload"); - return module.useUpload; + return { useUpload: module.useUpload, fetchUploads, uploadDataset, deleteUpload }; } afterEach(() => { - vi.unstubAllEnvs(); vi.resetModules(); + vi.clearAllMocks(); }); describe("useUpload", () => { - it("seeds demo uploads only in demo-only mode", async () => { - const useUpload = await loadUseUpload("demo"); + it("stays empty when the user is not authenticated", async () => { + const { useUpload, fetchUploads } = await loadUseUpload({ isAuthenticated: false, token: null }); const { result } = renderHook(() => useUpload()); - expect(result.current.uploads.length).toBeGreaterThan(0); - expect(result.current.latestUploadMode).toBe("demo"); + expect(result.current.uploads).toHaveLength(0); + expect(result.current.latestUploadMode).toBeNull(); + expect(fetchUploads).not.toHaveBeenCalled(); }); - it("starts with an empty upload list outside demo-only mode", async () => { - const useUpload = await loadUseUpload("hybrid"); + it("hydrates uploads for the signed-in user", async () => { + const fetchedUploads = [createAsset("source_1", "orders.csv"), createAsset("source_2", "customers.csv")]; + const { useUpload, fetchUploads } = await loadUseUpload({ fetchedUploads }); const { result } = renderHook(() => useUpload()); - expect(result.current.uploads).toHaveLength(0); - expect(result.current.latestUploadMode).toBeNull(); + await waitFor(() => { + expect(result.current.uploads).toHaveLength(2); + }); + + expect(fetchUploads).toHaveBeenCalledWith("token_123"); + expect(result.current.latestUploadMode).toBe("live"); + }); + + it("prepends a newly uploaded file", async () => { + const fetchedUploads = [createAsset("source_existing", "customers.csv")]; + const uploadedAsset = createAsset("source_uploaded", "orders.csv"); + const { 
useUpload, uploadDataset } = await loadUseUpload({ fetchedUploads }); + uploadDataset.mockResolvedValue({ asset: uploadedAsset, fallback: false }); + + const { result } = renderHook(() => useUpload()); + + await waitFor(() => { + expect(result.current.uploads).toHaveLength(1); + }); + + await act(async () => { + await result.current.uploadFile(new File(["order_id\n1"], "orders.csv", { type: "text/csv" })); + }); + + expect(uploadDataset).toHaveBeenCalled(); + expect(result.current.uploads.map((asset) => asset.id)).toEqual(["source_uploaded", "source_existing"]); + expect(result.current.latestUploadMode).toBe("live"); + }); + + it("removes a deleted file from the upload list", async () => { + const fetchedUploads = [createAsset("source_existing", "customers.csv"), createAsset("source_delete", "orders.csv")]; + const { useUpload, deleteUpload } = await loadUseUpload({ fetchedUploads }); + const { result } = renderHook(() => useUpload()); + + await waitFor(() => { + expect(result.current.uploads).toHaveLength(2); + }); + + await act(async () => { + await result.current.deleteUpload("source_delete"); + }); + + expect(deleteUpload).toHaveBeenCalledWith("source_delete", "token_123"); + expect(result.current.uploads.map((asset) => asset.id)).toEqual(["source_existing"]); }); }); diff --git a/ui/src/hooks/useUpload.ts b/ui/src/hooks/useUpload.ts index dacd1eb..bcdc2f5 100644 --- a/ui/src/hooks/useUpload.ts +++ b/ui/src/hooks/useUpload.ts @@ -1,24 +1,54 @@ -import { useState } from "react"; -import { uploadDataset } from "@/api/uploads"; -import { isDemoOnlyMode } from "@/config/env"; -import { seededUploads } from "@/data/mockUploads"; +import { useEffect, useState } from "react"; +import { deleteUpload as deleteUploadRequest, fetchUploads, uploadDataset } from "@/api/uploads"; +import { useAuth } from "@/hooks/useAuth"; import type { UploadedAsset } from "@/types/upload"; export function useUpload() { - const [uploads, setUploads] = useState(() => (isDemoOnlyMode ? 
seededUploads : [])); + const { token, isReady, isAuthenticated } = useAuth(); + const [uploads, setUploads] = useState([]); const [isUploading, setIsUploading] = useState(false); + const [deletingUploadId, setDeletingUploadId] = useState(null); const [error, setError] = useState(null); - const [latestUploadMode, setLatestUploadMode] = useState<"live" | "demo" | null>(() => - isDemoOnlyMode && seededUploads.length > 0 ? "demo" : null, - ); + const [latestUploadMode, setLatestUploadMode] = useState<"live" | "demo" | null>(null); + + useEffect(() => { + if (!isReady) return; + + if (!isAuthenticated || !token) { + setUploads([]); + setLatestUploadMode(null); + setError(null); + return; + } + + let cancelled = false; + + const load = async () => { + setError(null); + try { + const nextUploads = await fetchUploads(token); + if (cancelled) return; + setUploads(nextUploads); + setLatestUploadMode(nextUploads.length > 0 ? "live" : null); + } catch (err) { + if (cancelled) return; + setError(err instanceof Error ? err.message : "Unable to load uploads."); + } + }; + + void load(); + return () => { + cancelled = true; + }; + }, [isAuthenticated, isReady, token]); const uploadFile = async (file: File) => { setIsUploading(true); setError(null); try { - const response = await uploadDataset(file); - setUploads((current) => [response.asset, ...current]); + const response = await uploadDataset(file, token); + setUploads((current) => [response.asset, ...current.filter((asset) => asset.id !== response.asset.id)]); setLatestUploadMode(response.fallback ? 
"demo" : "live"); return response.asset; } catch (err) { @@ -30,11 +60,35 @@ export function useUpload() { } }; + const deleteUpload = async (sourceId: string) => { + setDeletingUploadId(sourceId); + setError(null); + + try { + await deleteUploadRequest(sourceId, token); + let remainingCount = 0; + setUploads((current) => { + const next = current.filter((asset) => asset.id !== sourceId); + remainingCount = next.length; + return next; + }); + setLatestUploadMode(remainingCount > 0 ? "live" : null); + } catch (err) { + const message = err instanceof Error ? err.message : "Unable to delete file."; + setError(message); + throw err; + } finally { + setDeletingUploadId(null); + } + }; + return { uploads, isUploading, + deletingUploadId, error, latestUploadMode, uploadFile, + deleteUpload, }; } diff --git a/ui/src/lib/constants.ts b/ui/src/lib/constants.ts index 50627ea..3796ff6 100644 --- a/ui/src/lib/constants.ts +++ b/ui/src/lib/constants.ts @@ -43,7 +43,7 @@ export const homepageFeatureCards = [ }, { title: "Data upload and profiling", - description: "Bring CSV, TSV, and database data into the workflow and see freshness, shape, and quality at a glance.", + description: "Bring CSV and JSON data into the workflow and see freshness, shape, and quality at a glance.", }, { title: "Execution transparency", diff --git a/ui/src/lib/uploads.ts b/ui/src/lib/uploads.ts new file mode 100644 index 0000000..9e9ccd6 --- /dev/null +++ b/ui/src/lib/uploads.ts @@ -0,0 +1,6 @@ +export const UPLOAD_ACCEPT = ".csv,.json,application/json,text/csv"; + +export function isSupportedUploadFile(file: Pick): boolean { + const normalized = file.name.trim().toLowerCase(); + return normalized.endsWith(".csv") || normalized.endsWith(".json"); +} diff --git a/ui/src/pages/AppPage.tsx b/ui/src/pages/AppPage.tsx index 386fa6a..8887d65 100644 --- a/ui/src/pages/AppPage.tsx +++ b/ui/src/pages/AppPage.tsx @@ -5,7 +5,7 @@ import { ChatThread } from "@/components/app/ChatThread"; import { InspectionPanel } 
from "@/components/app/InspectionPanel"; import { InsightCard } from "@/components/app/InsightCard"; import { Sidebar } from "@/components/app/Sidebar"; -import { UploadCard } from "@/components/app/UploadCard"; +import { UploadsPanel } from "@/components/app/UploadsPanel"; import { StatusBadge } from "@/components/app/StatusBadge"; import { Button } from "@/components/shared/Button"; import { Card } from "@/components/shared/Card"; @@ -16,13 +16,13 @@ import { Spinner } from "@/components/shared/Spinner"; import { savedAnalyses } from "@/data/mockInsights"; import { env, isDemoOnlyMode } from "@/config/env"; import { useChat } from "@/hooks/useChat"; +import { useActiveUploads } from "@/hooks/useActiveUploads"; import { useInspectionPanel } from "@/hooks/useInspectionPanel"; import { useResponsiveSidebar } from "@/hooks/useResponsiveSidebar"; import { useUpload } from "@/hooks/useUpload"; import { AppLayout } from "@/layouts/AppLayout"; import { formatCompactNumber } from "@/lib/utils"; import { uiStore } from "@/store/uiStore"; -import type { UploadedAsset } from "@/types/upload"; type SidebarSection = "chats" | "uploads" | "saved" | "dashboards"; @@ -35,7 +35,15 @@ const dashboardMetrics = [ export function AppPage() { const { isMobile, collapsed, mobileOpen, closeMobileSidebar, toggleSidebar } = useResponsiveSidebar(); - const { uploads, isUploading, error: uploadError, latestUploadMode, uploadFile } = useUpload(); + const { + uploads, + isUploading, + deletingUploadId, + error: uploadError, + latestUploadMode, + uploadFile, + deleteUpload, + } = useUpload(); const { conversations, activeConversation, @@ -50,8 +58,8 @@ export function AppPage() { } = useChat(); const inspection = useInspectionPanel(); const [draft, setDraft] = useState(""); - const [attachments, setAttachments] = useState([]); const [activeSection, setActiveSection] = useState(() => uiStore.getActiveSection() as SidebarSection); + const { activeUploads, removeActiveUpload } = 
useActiveUploads(uploads); const currentTitle = useMemo(() => { if (activeSection === "uploads") return "Data uploads"; @@ -100,21 +108,28 @@ export function AppPage() { uiStore.setActiveSection(section); }; - const handleUpload = async (file: File) => { + const handleChatUpload = async (file: File) => { try { - const asset = await uploadFile(file); - setAttachments((current) => [asset, ...current].slice(0, 2)); + await uploadFile(file); handleSectionChange("chats"); } catch { // Upload errors are surfaced through the hook state and UI. } }; + const handleUploadsSectionUpload = async (file: File) => { + try { + await uploadFile(file); + handleSectionChange("uploads"); + } catch { + // Upload errors are surfaced through the hook state and UI. + } + }; + const handleSubmit = async () => { - const success = await sendPrompt(draft, attachments); + const success = await sendPrompt(draft, activeUploads); if (success) { setDraft(""); - setAttachments([]); } handleSectionChange("chats"); }; @@ -150,9 +165,9 @@ export function AppPage() { setDraft(prompt); handleSectionChange("chats"); }} - onUpload={(file) => void handleUpload(file)} - onRemoveAttachment={(assetId) => setAttachments((current) => current.filter((asset) => asset.id !== assetId))} - attachments={attachments} + onUpload={(file) => void handleChatUpload(file)} + onRemoveAttachment={removeActiveUpload} + attachments={activeUploads} isSubmitting={isSubmitting} isUploading={isUploading} /> @@ -160,21 +175,14 @@ export function AppPage() { ); const uploadsView = ( - - {uploadError ? : null} - {uploads.length ? ( -
- {uploads.map((asset) => ( - - ))} -
- ) : ( - - )} -
+ void handleUploadsSectionUpload(file)} + onDelete={(assetId) => void deleteUpload(assetId)} + /> ); const savedView = ( @@ -243,7 +251,6 @@ export function AppPage() { sidebar={