diff --git a/app/agent/executor.py b/app/agent/executor.py
index b9915e5..d8e58af 100644
--- a/app/agent/executor.py
+++ b/app/agent/executor.py
@@ -61,15 +61,18 @@ def _register_artifacts(conn: duckdb.DuckDBPyConnection, state: AnalysisState) -
def _execute_sql(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary:
- conn = new_duckdb_connection()
- _register_artifacts(conn, state)
- frame = conn.execute(step["code"]).fetchdf()
- state["artifacts"][step["output_alias"]] = frame
- return _summarize_artifact(step["output_alias"], frame)
+ conn = new_duckdb_connection(state.get("dataset_context"))
+ try:
+ _register_artifacts(conn, state)
+ frame = conn.execute(step["code"]).fetchdf()
+ state["artifacts"][step["output_alias"]] = frame
+ return _summarize_artifact(step["output_alias"], frame)
+ finally:
+ conn.close()
def _execute_pandas(state: AnalysisState, step: dict[str, Any]) -> ArtifactSummary:
- context = get_semantic_context()
+ context = get_semantic_context(state.get("source_ids"))
local_env: dict[str, Any] = {
**context.raw_views,
**context.semantic_views,
@@ -112,26 +115,29 @@ def preflight_compiled_plan(state: AnalysisState, compiled_plan: dict[str, Any])
Steps are checked in order so later queries can reference earlier output aliases.
"""
- conn = new_duckdb_connection()
- _register_artifacts(conn, state)
- rows = list(compiled_plan.get("plan") or [])
- rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id)
-
- for row in rows:
- internal = compiled_plan_row_to_internal(row)
- sql = internal["code"].strip().rstrip(";")
- try:
- preview = conn.execute(f"SELECT * FROM ({sql}) AS __planera_preflight LIMIT 0").fetchdf()
- conn.register(internal["output_alias"], preview)
- except Exception as exc:
- return {
- "status": "failed",
- "failed_step_id": internal["id"],
- "error": str(exc),
- "query": internal["code"],
- }
-
- return {"status": "success"}
+ conn = new_duckdb_connection(state.get("dataset_context"))
+ try:
+ _register_artifacts(conn, state)
+ rows = list(compiled_plan.get("plan") or [])
+ rows.sort(key=lambda r: r["id"] if isinstance(r, dict) else r.id)
+
+ for row in rows:
+ internal = compiled_plan_row_to_internal(row)
+ sql = internal["code"].strip().rstrip(";")
+ try:
+ preview = conn.execute(f"SELECT * FROM ({sql}) AS __planera_preflight LIMIT 0").fetchdf()
+ conn.register(internal["output_alias"], preview)
+ except Exception as exc:
+ return {
+ "status": "failed",
+ "failed_step_id": internal["id"],
+ "error": str(exc),
+ "query": internal["code"],
+ }
+
+ return {"status": "success"}
+ finally:
+ conn.close()
def _try_sql_step(
diff --git a/app/agent/graph.py b/app/agent/graph.py
index 0bc9ef3..076944b 100644
--- a/app/agent/graph.py
+++ b/app/agent/graph.py
@@ -28,7 +28,7 @@ def _append_error(state: AnalysisState, step: str, message: str, recoverable: bo
def load_schema_context_node(state: AnalysisState) -> AnalysisState:
step_name = "load_schema_context_node"
_append_trace(state, step_name, "started", {})
- context = get_semantic_context()
+ context = get_semantic_context(state.get("source_ids"))
state["dataset_context"] = context.schema_manifest
_append_trace(
state,
@@ -162,8 +162,8 @@ def build_graph():
return graph.compile()
-def run_analysis(query: str) -> AnalysisState:
+def run_analysis(query: str, source_ids: list[str] | None = None) -> AnalysisState:
"""Execute the full workflow for a single user query."""
workflow = build_graph()
- return workflow.invoke(create_initial_state(query))
+ return workflow.invoke(create_initial_state(query, source_ids=source_ids))
diff --git a/app/agent/planner.py b/app/agent/planner.py
index 0c9892a..5fc8bdb 100644
--- a/app/agent/planner.py
+++ b/app/agent/planner.py
@@ -38,7 +38,11 @@ def _field_terms(*values: str) -> set[str]:
def _column_relevance_score(column: dict[str, Any], question_terms: set[str]) -> int:
- column_terms = _field_terms(column.get("name", ""))
+ column_terms = _field_terms(
+ column.get("name", ""),
+ column.get("original_name", ""),
+ column.get("source_path", ""),
+ )
for hint in column.get("semantic_hints") or []:
column_terms.update(_field_terms(hint))
@@ -50,10 +54,20 @@ def _column_relevance_score(column: dict[str, Any], question_terms: set[str]) ->
def _relation_relevance_score(relation: dict[str, Any], question_terms: set[str]) -> int:
- score = len(_field_terms(relation.get("name", ""), relation.get("grain", "")) & question_terms) * 4
+ score = len(
+ _field_terms(
+ relation.get("name", ""),
+ relation.get("grain", ""),
+ relation.get("source_name", ""),
+ json.dumps(relation.get("lineage", {}), default=str),
+ )
+ & question_terms
+ ) * 4
score += sum(_column_relevance_score(column, question_terms) for column in relation.get("columns") or [])
for mapping in relation.get("semantic_mappings") or []:
score += len(_field_terms(mapping.get("concept", "")) & question_terms) * 5
+ if relation.get("is_primary"):
+ score += 3
return score
diff --git a/app/agent/state.py b/app/agent/state.py
index 18a1a7d..c940ed7 100644
--- a/app/agent/state.py
+++ b/app/agent/state.py
@@ -11,6 +11,7 @@ class AnalysisState(TypedDict):
"""Explicit state carried through the analytics workflow."""
query: str
+ source_ids: list[str]
dataset_context: dict[str, Any]
intent: str
metric: str
@@ -27,11 +28,12 @@ class AnalysisState(TypedDict):
errors: list[dict[str, Any]]
-def create_initial_state(query: str) -> AnalysisState:
+def create_initial_state(query: str, source_ids: list[str] | None = None) -> AnalysisState:
"""Return the initial workflow state for a new request."""
return AnalysisState(
query=query,
+ source_ids=list(source_ids or []),
dataset_context={},
intent="",
metric="",
diff --git a/app/api/chat_routes.py b/app/api/chat_routes.py
index 0c5f72c..5dfa109 100644
--- a/app/api/chat_routes.py
+++ b/app/api/chat_routes.py
@@ -30,6 +30,7 @@
)
from app.services.analysis_run import run_stored_analysis
from app.services.inspection_persistence import save_inspection_for_assistant_message
+from app.uploads.service import get_authorized_source_ids
from app.utils.logging import get_logger
@@ -146,8 +147,19 @@ def chat_turn(
db.add(user_msg)
db.flush()
+ requested_source_ids = list(dict.fromkeys(body.source_ids or []))
+ if requested_source_ids:
+ valid_source_ids = get_authorized_source_ids(db, current_user, requested_source_ids)
+ if len(valid_source_ids) != len(requested_source_ids):
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail={"message": "Attach a valid uploaded data source before running analysis."},
+ )
+ else:
+ valid_source_ids = None
+
try:
- analysis_run = run_stored_analysis(body.query)
+ analysis_run = run_stored_analysis(body.query, source_ids=valid_source_ids)
analysis_result = analysis_run.response
inspection_payload = analysis_run.inspection
except Exception as exc: # pragma: no cover - defensive API fallback
diff --git a/app/api/routes.py b/app/api/routes.py
index da1cd92..a3a450c 100644
--- a/app/api/routes.py
+++ b/app/api/routes.py
@@ -2,18 +2,19 @@
from __future__ import annotations
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi import APIRouter, Depends, File, HTTPException, Response, UploadFile, status
from sqlalchemy.orm import Session
-from app.api.workspace import get_inspection, profile_upload
-from app.auth.deps import get_current_user_optional
+from app.api.workspace import get_inspection
+from app.auth.deps import get_current_user, get_current_user_optional
from app.config import get_settings
from app.db.session import get_db
from app.models.conversation import Conversation
from app.models.inspection_snapshot import InspectionSnapshot
from app.models.user import User
-from app.schemas import AnalyzeRequest, AnalyzeResponse, HealthResponse, InspectionData, InspectionResponse, SampleQuestionsResponse, UploadResponse
+from app.schemas import AnalyzeRequest, AnalyzeResponse, HealthResponse, InspectionData, InspectionResponse, SampleQuestionsResponse, UploadedAsset, UploadResponse
from app.services.analysis_run import run_stored_analysis
+from app.uploads.service import create_user_upload, delete_user_upload, get_authorized_source_ids, list_user_uploads
from app.utils.constants import SAMPLE_QUESTIONS
from app.utils.logging import get_logger
@@ -37,18 +38,52 @@ def sample_questions() -> SampleQuestionsResponse:
return SampleQuestionsResponse(questions=SAMPLE_QUESTIONS)
+@router.get("/uploads", response_model=list[UploadedAsset])
+def list_uploads(
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+) -> list[UploadedAsset]:
+ """Return uploads owned by the signed-in user."""
+
+ return list_user_uploads(db, current_user)
+
+
@router.post("/uploads", response_model=UploadResponse)
-async def upload_dataset(file: UploadFile = File(...)) -> UploadResponse:
+async def upload_dataset(
+ file: UploadFile = File(...),
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+) -> UploadResponse:
"""Accept a workspace upload and return a profiled asset summary."""
contents = await file.read()
try:
- asset = profile_upload(file.filename or "upload.csv", contents)
+ asset = create_user_upload(
+ db,
+ current_user,
+ filename=file.filename or "upload.csv",
+ content_type=file.content_type,
+ content=contents,
+ )
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail={"message": str(exc)}) from exc
return UploadResponse(asset=asset, fallback=False)
+@router.delete("/uploads/{source_id}", status_code=status.HTTP_204_NO_CONTENT)
+def delete_upload(
+ source_id: str,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+) -> Response:
+ """Delete one upload owned by the signed-in user."""
+
+ deleted = delete_user_upload(db, current_user, source_id)
+ if not deleted:
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail={"message": "Upload not found."})
+ return Response(status_code=status.HTTP_204_NO_CONTENT)
+
+
@router.get("/inspections/{inspection_id}", response_model=InspectionResponse)
def inspection_details(
inspection_id: str,
@@ -94,17 +129,35 @@ def inspection_details(
description=(
"**Not the primary product API.** For normal use, authenticated clients should call "
"`POST /chat`, which persists conversations, messages, and inspection snapshots.\n\n"
- "This endpoint runs the same analytics pipeline without auth, without database persistence, "
+ "This endpoint runs the same analytics pipeline with auth but without conversation/database persistence, "
"and keeps the inspection payload only in process memory (lost on restart). Use it for "
"local debugging, Swagger/Postman checks, and quick stateless demos.\n\n"
"Response shape aligns with the analysis/trace/steps portion of `POST /chat`."
),
)
-def analyze(request: AnalyzeRequest) -> AnalyzeResponse:
- """Run analytics without persistence (see OpenAPI ``description`` — prefer ``POST /chat`` for product flows)."""
+def analyze(
+ request: AnalyzeRequest,
+ current_user: User = Depends(get_current_user),
+ db: Session = Depends(get_db),
+) -> AnalyzeResponse:
+ """Run analytics with auth but without persistence (see OpenAPI ``description`` — prefer ``POST /chat``)."""
+
+ requested_source_ids = list(dict.fromkeys(request.source_ids or []))
+ if not requested_source_ids:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail={"message": "Upload and attach at least one CSV or JSON data source before running analysis."},
+ )
+
+ valid_source_ids = get_authorized_source_ids(db, current_user, requested_source_ids)
+ if len(valid_source_ids) != len(requested_source_ids):
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail={"message": "Attach a valid uploaded data source before running analysis."},
+ )
try:
- return run_stored_analysis(request.query).response
+ return run_stored_analysis(request.query, source_ids=valid_source_ids).response
except Exception as exc: # pragma: no cover - defensive API fallback
logger.exception("Analyze request failed", extra={"query": request.query})
settings = get_settings()
diff --git a/app/api/workspace.py b/app/api/workspace.py
index e9b47a2..c67b238 100644
--- a/app/api/workspace.py
+++ b/app/api/workspace.py
@@ -4,17 +4,13 @@
import json
import re
-import tempfile
-from dataclasses import dataclass
from datetime import datetime, timezone
-from io import BytesIO
-from pathlib import Path
from threading import Lock
from typing import Any
from uuid import uuid4
-import pandas as pd
-
+from app.data.registry import clear_source_registry, ingest_source
+from app.data.semantic_model import clear_semantic_context_cache
from app.schemas import AnalyzeResponse, ArtifactSummary, InspectionData, MetadataItem, ResultTableData, TraceEntry, UploadedAsset, ValidationCheck
@@ -29,58 +25,22 @@
_STORE_LOCK = Lock()
_INSPECTIONS: dict[str, InspectionData] = {}
-_UPLOAD_DIR = Path(tempfile.gettempdir()) / "planera_uploads"
-
-
-@dataclass
-class StoredUpload:
- """Backend-only metadata for uploaded files."""
-
- asset: UploadedAsset
- file_path: Path
-
-
-_UPLOADS: dict[str, StoredUpload] = {}
def clear_workspace_state() -> None:
"""Reset in-memory upload/inspection storage for tests."""
with _STORE_LOCK:
- for stored in _UPLOADS.values():
- if stored.file_path.exists():
- stored.file_path.unlink(missing_ok=True)
- _UPLOADS.clear()
_INSPECTIONS.clear()
+ clear_source_registry()
+ clear_semantic_context_cache()
def profile_upload(filename: str, content: bytes) -> UploadedAsset:
- """Profile an uploaded CSV/TSV file and persist it to a temp location."""
-
- safe_name = Path(filename or "upload.csv").name
- frame = _read_uploaded_frame(safe_name, content)
- row_count = int(len(frame))
- column_count = int(len(frame.columns))
- asset = UploadedAsset(
- id=_short_id("upload"),
- name=safe_name,
- type=_derive_file_type(safe_name),
- source="Workspace upload",
- sizeLabel=_bytes_to_size(len(content)),
- uploadedAt=_now_iso(),
- status="verified",
- rows=row_count,
- columns=column_count,
- summary=_build_upload_summary(frame, row_count, column_count),
- )
-
- _UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
- file_path = _UPLOAD_DIR / f"{asset.id}_{safe_name}"
- file_path.write_bytes(content)
-
- with _STORE_LOCK:
- _UPLOADS[asset.id] = StoredUpload(asset=asset, file_path=file_path)
+ """Persist an uploaded structured dataset to the source registry."""
+ asset = ingest_source(filename, content)
+ clear_semantic_context_cache()
return asset
@@ -100,35 +60,6 @@ def get_inspection(inspection_id: str) -> InspectionData | None:
return _INSPECTIONS.get(inspection_id)
-def _read_uploaded_frame(filename: str, content: bytes) -> pd.DataFrame:
- if not content:
- raise ValueError("Uploaded file is empty.")
-
- suffix = Path(filename).suffix.lower()
- buffer = BytesIO(content)
-
- try:
- if suffix in {".csv"}:
- return pd.read_csv(buffer)
- if suffix in {".tsv", ".tab"}:
- return pd.read_csv(buffer, sep="\t")
- if suffix in {".txt"}:
- return pd.read_csv(buffer, sep=None, engine="python")
- except Exception as exc: # pragma: no cover - pandas error details vary
- raise ValueError(f"Could not parse {filename} as a structured text dataset.") from exc
-
- raise ValueError("Only CSV, TSV, TAB, and TXT uploads are currently supported.")
-
-
-def _build_upload_summary(frame: pd.DataFrame, row_count: int, column_count: int) -> str:
- preview_columns = [str(column) for column in list(frame.columns[:4])]
- column_label = ", ".join(preview_columns)
- suffix = "..." if column_count > 4 else ""
- if column_label:
- return f"Profiled {row_count} rows across {column_count} columns. Leading columns: {column_label}{suffix}."
- return f"Profiled {row_count} rows across {column_count} columns."
-
-
def _build_inspection(inspection_id: str, prompt: str, response: AnalyzeResponse) -> InspectionData:
executed_steps = response.executed_steps or []
primary_artifact = _pick_primary_artifact(response)
@@ -463,20 +394,5 @@ def _short_id(prefix: str) -> str:
return f"{prefix}_{uuid4().hex[:8]}"
-def _derive_file_type(filename: str) -> str:
- suffix = Path(filename).suffix.lstrip(".").upper()
- return suffix or "FILE"
-
-
-def _bytes_to_size(num_bytes: int) -> str:
- if num_bytes < 1024:
- return f"{num_bytes} B"
- if num_bytes < 1024 * 1024:
- return f"{num_bytes / 1024:.1f} KB"
- if num_bytes < 1024 * 1024 * 1024:
- return f"{num_bytes / (1024 * 1024):.1f} MB"
- return f"{num_bytes / (1024 * 1024 * 1024):.1f} GB"
-
-
def _now_iso() -> str:
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
diff --git a/app/config.py b/app/config.py
index 7c9d9d8..52a683f 100644
--- a/app/config.py
+++ b/app/config.py
@@ -20,6 +20,8 @@ class Settings(BaseSettings):
api_host: str = "0.0.0.0"
api_port: int = 8000
data_dir: Path = Field(default=BASE_DIR / "data")
+ registry_path: Path = Field(default=BASE_DIR / "data" / "source_registry.duckdb")
+ upload_storage_dir: Path = Field(default=BASE_DIR / "data" / "uploads", alias="UPLOAD_STORAGE_DIR")
crm_path: Path = Field(default=BASE_DIR / "data" / "crm.csv")
subscriptions_path: Path = Field(default=BASE_DIR / "data" / "subscriptions.csv")
crm_dataset_dir: Path = Field(default=BASE_DIR / "data" / "CRM+Sales+Opportunities")
diff --git a/app/data/registry.py b/app/data/registry.py
new file mode 100644
index 0000000..71d0e19
--- /dev/null
+++ b/app/data/registry.py
@@ -0,0 +1,1172 @@
+"""Persistent DuckDB-backed source registry and source-package ingestion."""
+
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from hashlib import sha256
+from io import BytesIO
+from pathlib import Path
+from typing import Any
+from uuid import uuid4
+
+import duckdb
+import pandas as pd
+
+from app.config import get_settings
+from app.schemas import SchemaColumn, SchemaConceptMapping, SchemaJoinKey, SchemaManifest, SchemaRelation, UploadedAsset
+
+
+_SOURCES_TABLE = "__planera_registry_data_sources"
+_RELATIONS_TABLE = "__planera_registry_source_relations"
+_COLUMNS_TABLE = "__planera_registry_source_columns"
+_LINKS_TABLE = "__planera_registry_source_links"
+_SYSTEM_COLUMNS = {"record_id", "parent_record_id", "ordinal"}
+_SEMANTIC_ALIAS_LEXICON: dict[str, list[str]] = {
+ "owner": ["agent", "rep", "representative", "sales rep", "assignee"],
+ "manager": ["manager", "lead", "supervisor", "team lead"],
+ "regional_office": ["region", "regional office", "office", "territory"],
+ "account_id": ["account", "customer", "client", "account identifier"],
+ "deal_id": ["deal", "opportunity", "opportunity identifier"],
+ "deal_value": ["revenue", "deal size", "amount", "value"],
+ "stage": ["status stage", "pipeline stage"],
+ "segment": ["customer segment", "market segment"],
+ "pipeline_velocity_days": ["pipeline velocity", "cycle time", "sales cycle length"],
+}
+
+
+@dataclass(frozen=True)
+class SourceLinkPackage:
+ """Explicit relation link stored in the registry."""
+
+ left_source_id: str
+ left_relation_name: str
+ right_source_id: str
+ right_relation_name: str
+ join_keys: list[dict[str, str]]
+ link_type: str = "explicit"
+ is_explicit: bool = True
+ metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(frozen=True)
+class RelationPackage:
+ """One materialized relation plus its normalized schema metadata."""
+
+ relation: SchemaRelation
+ frame: pd.DataFrame
+
+
+@dataclass(frozen=True)
+class SourcePackage:
+ """Unified internal representation for uploaded or built-in sources."""
+
+ source_id: str
+ source_name: str
+ source_slug: str
+ source_kind: str
+ source_format: str
+ origin: str
+ file_name: str
+ file_type: str
+ size_bytes: int
+ raw_payload: bytes | None
+ relations: list[RelationPackage]
+ links: list[SourceLinkPackage] = field(default_factory=list)
+ metadata: dict[str, Any] = field(default_factory=dict)
+
+ @property
+ def primary_relation_name(self) -> str:
+ primary = next((item.relation.name for item in self.relations if item.relation.is_primary), self.relations[0].relation.name)
+ return primary
+
+
+def get_registry_path() -> Path:
+ """Return the on-disk DuckDB registry path."""
+
+ return get_settings().registry_path
+
+
+def _connect_registry(read_only: bool = False) -> duckdb.DuckDBPyConnection:
+ path = get_registry_path()
+ path.parent.mkdir(parents=True, exist_ok=True)
+ if read_only and not path.exists():
+ conn = duckdb.connect(database=str(path), read_only=False)
+ _ensure_registry_tables(conn)
+ return conn
+ conn = duckdb.connect(database=str(path), read_only=read_only)
+ if not read_only:
+ _ensure_registry_tables(conn)
+ return conn
+
+
+def _ensure_registry_tables(conn: duckdb.DuckDBPyConnection) -> None:
+ conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {_SOURCES_TABLE} (
+ source_id TEXT PRIMARY KEY,
+ source_name TEXT,
+ source_slug TEXT,
+ source_kind TEXT,
+ source_format TEXT,
+ origin TEXT,
+ file_name TEXT,
+ file_type TEXT,
+ size_bytes BIGINT,
+ created_at TEXT,
+ content_hash TEXT,
+ primary_relation_name TEXT,
+ relation_count INTEGER,
+ row_count INTEGER,
+ status TEXT,
+ raw_payload BLOB,
+ metadata_json TEXT
+ )
+ """
+ )
+ conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {_RELATIONS_TABLE} (
+ relation_name TEXT PRIMARY KEY,
+ source_id TEXT,
+ kind TEXT,
+ is_primary BOOLEAN,
+ parent_relation TEXT,
+ row_count BIGINT,
+ grain TEXT,
+ identifier_columns_json TEXT,
+ time_columns_json TEXT,
+ measure_columns_json TEXT,
+ dimension_columns_json TEXT,
+ lineage_json TEXT
+ )
+ """
+ )
+ conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {_COLUMNS_TABLE} (
+ source_id TEXT,
+ relation_name TEXT,
+ ordinal INTEGER,
+ column_name TEXT,
+ dtype TEXT,
+ type_family TEXT,
+ original_name TEXT,
+ source_path TEXT,
+ nullable BOOLEAN,
+ semantic_hints_json TEXT
+ )
+ """
+ )
+ conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {_LINKS_TABLE} (
+ link_id TEXT PRIMARY KEY,
+ left_source_id TEXT,
+ left_relation_name TEXT,
+ right_source_id TEXT,
+ right_relation_name TEXT,
+ link_type TEXT,
+ is_explicit BOOLEAN,
+ join_keys_json TEXT,
+ metadata_json TEXT
+ )
+ """
+ )
+
+
+def clear_source_registry() -> None:
+ """Remove the persisted registry database."""
+
+ path = get_registry_path()
+ if path.exists():
+ path.unlink()
+ from app.data.semantic_model import clear_semantic_context_cache
+
+ clear_semantic_context_cache()
+
+
+def _now_iso() -> str:
+ return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
+
+
+def _short_id(prefix: str) -> str:
+ return f"{prefix}_{uuid4().hex[:8]}"
+
+
+def _slugify(value: str) -> str:
+ cleaned = re.sub(r"[^a-zA-Z0-9]+", "_", value.strip().lower()).strip("_")
+ cleaned = re.sub(r"_+", "_", cleaned)
+ return cleaned[:48] or "source"
+
+
+def _safe_identifier(value: str) -> str:
+ normalized = _slugify(value)
+ if normalized[0].isdigit():
+ return f"c_{normalized}"
+ return normalized
+
+
+def _derive_file_type(filename: str, source_format: str) -> str:
+ suffix = Path(filename).suffix.lstrip(".").upper()
+ if suffix:
+ return suffix
+ return source_format.upper() or "FILE"
+
+
+def _split_identifier(value: str) -> list[str]:
+ cleaned = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", value)
+ return [token.lower() for token in re.split(r"[^a-zA-Z0-9]+", cleaned) if token]
+
+
+def _type_family(dtype: str) -> str:
+ lower = dtype.lower()
+ if any(token in lower for token in ("int", "float", "double", "decimal")):
+ return "number"
+ if "bool" in lower:
+ return "boolean"
+ if any(token in lower for token in ("datetime", "timestamp", "date")):
+ return "datetime"
+ if any(token in lower for token in ("object", "string", "category")):
+ return "string"
+ return "unknown"
+
+
+def _semantic_hints(column_name: str, original_name: str = "", source_path: str = "") -> list[str]:
+ tokens = _split_identifier(column_name)
+ hints = {column_name, column_name.replace("_", " ")}
+ if original_name:
+ hints.add(original_name)
+ hints.add(original_name.replace("_", " "))
+ if source_path:
+ hints.add(source_path)
+ hints.add(source_path.replace(".", " "))
+ hints.update(tokens)
+ if column_name.endswith("_id"):
+ base = column_name[: -len("_id")].replace("_", " ").strip()
+ if base:
+ hints.add(f"{base} id")
+ hints.add(f"{base} identifier")
+
+ for key, aliases in _SEMANTIC_ALIAS_LEXICON.items():
+ key_tokens = set(_split_identifier(key))
+ if column_name == key or key_tokens.issubset(set(tokens)):
+ hints.update(aliases)
+
+ return sorted(hint for hint in hints if hint)
+
+
+def _build_semantic_mappings(columns: list[SchemaColumn]) -> list[SchemaConceptMapping]:
+ concept_to_columns: dict[str, set[str]] = {}
+ for column in columns:
+ for hint in column.semantic_hints:
+ normalized_hint = hint.strip().lower()
+ if not normalized_hint or normalized_hint == column.name.lower():
+ continue
+ concept_to_columns.setdefault(normalized_hint, set()).add(column.name)
+
+ mappings: list[SchemaConceptMapping] = []
+ for concept, mapped_columns in sorted(concept_to_columns.items()):
+ if len(concept) < 4:
+ continue
+ mappings.append(SchemaConceptMapping(concept=concept, columns=sorted(mapped_columns)))
+ return mappings[:20]
+
+
+def _is_identifier_column(column_name: str, series: pd.Series) -> bool:
+ lowered = column_name.lower()
+ if lowered == "id" or lowered.endswith("_id"):
+ return True
+ non_null = series.dropna()
+ return bool(len(non_null) == len(series) and len(non_null) > 0 and non_null.nunique(dropna=False) == len(series))
+
+
+def _infer_grain(name: str, frame: pd.DataFrame, identifier_columns: list[str]) -> str:
+ if identifier_columns:
+ primary = identifier_columns[0]
+ if primary.lower().endswith("_id"):
+ entity = primary[: -len("_id")].replace("_", " ").strip()
+ if entity:
+ return f"Approximately one row per {entity}"
+ return f"Rows can be keyed by {primary}"
+ return f"Rows represent records in {name}"
+
+
+def _normalize_scalar(value: Any) -> Any:
+ if isinstance(value, (str, int, float, bool)) or value is None:
+ return value
+ if isinstance(value, (datetime, pd.Timestamp)):
+ return pd.Timestamp(value)
+ return json.dumps(value, default=str)
+
+
+def _parse_text_bool(series: pd.Series) -> pd.Series | None:
+ non_null = series.dropna().astype(str).str.strip().str.lower()
+ if non_null.empty:
+ return None
+ if set(non_null.unique()).issubset({"true", "false"}):
+ mapped = series.map(
+ lambda item: None
+ if pd.isna(item)
+ else str(item).strip().lower() == "true"
+ )
+ return mapped.astype("boolean")
+ return None
+
+
+def _parse_text_numeric(series: pd.Series) -> pd.Series | None:
+ non_null = series.dropna()
+ if non_null.empty:
+ return None
+ parsed = pd.to_numeric(non_null, errors="coerce")
+ if parsed.notna().sum() != len(non_null):
+ return None
+ return pd.to_numeric(series, errors="coerce")
+
+
+def _parse_text_datetime(column_name: str, series: pd.Series) -> pd.Series | None:
+ lowered_name = column_name.lower()
+ if not (
+ any(token in lowered_name for token in ("date", "time", "timestamp"))
+ or lowered_name.endswith("_at")
+ or lowered_name == "at"
+ ):
+ return None
+ non_null = series.dropna()
+ if non_null.empty:
+ return None
+ parsed = pd.to_datetime(non_null, errors="coerce", utc=False)
+ if parsed.notna().sum() / len(non_null) < 0.8:
+ return None
+ return pd.to_datetime(series, errors="coerce", utc=False)
+
+
+def _coerce_frame_types(frame: pd.DataFrame) -> pd.DataFrame:
+ coerced = frame.copy()
+ for column_name in coerced.columns:
+ series = coerced[column_name]
+ if pd.api.types.is_bool_dtype(series) or pd.api.types.is_numeric_dtype(series) or pd.api.types.is_datetime64_any_dtype(series):
+ continue
+ if not (pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series)):
+ continue
+
+ parsed = _parse_text_bool(series)
+ if parsed is None:
+ parsed = _parse_text_numeric(series)
+ if parsed is None:
+ parsed = _parse_text_datetime(str(column_name), series)
+ if parsed is not None:
+ coerced[column_name] = parsed
+ return coerced.convert_dtypes()
+
+
+def _rename_system_columns(frame: pd.DataFrame) -> pd.DataFrame:
+ renamed = frame.copy()
+ next_names: dict[str, str] = {}
+ for column in renamed.columns:
+ safe_name = _safe_identifier(str(column))
+ candidate = safe_name
+ if candidate in _SYSTEM_COLUMNS:
+ candidate = f"{candidate}_value"
+ counter = 2
+ while candidate in next_names.values():
+ candidate = f"{safe_name}_{counter}"
+ counter += 1
+ next_names[column] = candidate
+ return renamed.rename(columns=next_names)
+
+
+def _ensure_record_id(frame: pd.DataFrame) -> pd.DataFrame:
+ framed = frame.copy()
+ if "record_id" in framed.columns:
+ framed = framed.rename(columns={"record_id": "record_id_value"})
+ framed.insert(0, "record_id", [f"record_{index + 1}" for index in range(len(framed))])
+ return framed
+
+
+def _schema_relation_for_frame(
+ *,
+ relation_name: str,
+ source_id: str,
+ source_name: str,
+ frame: pd.DataFrame,
+ kind: str,
+ is_primary: bool,
+ parent_relation: str | None,
+ join_keys: list[SchemaJoinKey],
+ lineage: dict[str, Any],
+ column_paths: dict[str, dict[str, str]] | None = None,
+) -> SchemaRelation:
+ column_paths = column_paths or {}
+ columns: list[SchemaColumn] = []
+ identifier_columns: list[str] = []
+ time_columns: list[str] = []
+ measure_columns: list[str] = []
+ dimension_columns: list[str] = []
+
+ for column_name in frame.columns:
+ dtype = str(frame[column_name].dtype)
+ family = _type_family(dtype)
+ path_meta = column_paths.get(str(column_name), {})
+ column = SchemaColumn(
+ name=str(column_name),
+ dtype=dtype,
+ type_family=family,
+ original_name=path_meta.get("original_name", str(column_name)),
+ source_path=path_meta.get("source_path", str(column_name)),
+ nullable=bool(frame[column_name].isna().any()),
+ semantic_hints=_semantic_hints(
+ str(column_name),
+ original_name=path_meta.get("original_name", str(column_name)),
+ source_path=path_meta.get("source_path", str(column_name)),
+ ),
+ )
+ columns.append(column)
+
+ if _is_identifier_column(str(column_name), frame[column_name]):
+ identifier_columns.append(str(column_name))
+ if family == "datetime":
+ time_columns.append(str(column_name))
+ elif family == "number":
+ measure_columns.append(str(column_name))
+ else:
+ dimension_columns.append(str(column_name))
+
+ return SchemaRelation(
+ name=relation_name,
+ kind=kind,
+ source_id=source_id,
+ source_name=source_name,
+ is_primary=is_primary,
+ parent_relation=parent_relation,
+ row_count=int(len(frame)),
+ grain=_infer_grain(relation_name, frame, identifier_columns),
+ identifier_columns=identifier_columns,
+ time_columns=time_columns,
+ measure_columns=measure_columns,
+ dimension_columns=dimension_columns,
+ join_keys=join_keys,
+ lineage=lineage,
+ columns=columns,
+ semantic_mappings=_build_semantic_mappings(columns),
+ )
+
+
+def _source_filter_sql(source_ids: list[str] | None) -> tuple[str, list[str]]:
+ if not source_ids:
+ return "", []
+ placeholders = ", ".join(["?"] * len(source_ids))
+ return f" WHERE source_id IN ({placeholders})", list(source_ids)
+
+
+def _link_filter_sql(source_ids: list[str] | None) -> tuple[str, list[str]]:
+ if not source_ids:
+ return "", []
+ placeholders = ", ".join(["?"] * len(source_ids))
+ return (
+ f" WHERE left_source_id IN ({placeholders}) AND right_source_id IN ({placeholders})",
+ list(source_ids) + list(source_ids),
+ )
+
+
+def _read_uploaded_frame(filename: str, content: bytes) -> pd.DataFrame:
+ if not content:
+ raise ValueError("Uploaded file is empty.")
+
+ suffix = Path(filename).suffix.lower()
+ buffer = BytesIO(content)
+ try:
+ if suffix == ".csv":
+ return pd.read_csv(buffer)
+ except Exception as exc: # pragma: no cover - pandas errors vary
+ raise ValueError(f"Could not parse {filename} as a structured text dataset.") from exc
+
+ raise ValueError("Only CSV and JSON uploads are currently supported.")
+
+
def _build_uploaded_asset(package: SourcePackage) -> UploadedAsset:
    """Summarize a persisted source package as the ``UploadedAsset`` UI payload.

    Row and column counts come from the primary relation; the ``record_id``
    column is excluded from the column count.
    """
    primary_relation = next(entry for entry in package.relations if entry.relation.is_primary).relation
    column_total = sum(1 for column in primary_relation.columns if column.name != "record_id")
    relation_count = len(package.relations)
    plural = "s" if relation_count != 1 else ""
    summary = (
        f"Persisted {primary_relation.row_count} rows across {column_total} columns"
        f" into {relation_count} relation{plural}. "
        f"Primary relation: {package.primary_relation_name}."
    )
    return UploadedAsset(
        id=package.source_id,
        name=package.file_name,
        type=package.file_type,
        source=package.origin,
        sizeLabel=_bytes_to_size(package.size_bytes),
        uploadedAt=_now_iso(),
        status="verified",
        rows=primary_relation.row_count,
        columns=column_total,
        relationCount=relation_count,
        primaryRelationName=package.primary_relation_name,
        summary=summary,
    )
+
+
+def _bytes_to_size(num_bytes: int) -> str:
+ if num_bytes < 1024:
+ return f"{num_bytes} B"
+ if num_bytes < 1024 * 1024:
+ return f"{num_bytes / 1024:.1f} KB"
+ if num_bytes < 1024 * 1024 * 1024:
+ return f"{num_bytes / (1024 * 1024):.1f} MB"
+ return f"{num_bytes / (1024 * 1024 * 1024):.1f} GB"
+
+
def _persist_source_package(conn: duckdb.DuckDBPyConnection, package: SourcePackage) -> None:
    """Write one ingested source into the registry inside a single transaction.

    For the package's ``source_id`` this replaces, in order: the materialized
    relation tables with their relation/column metadata, every link touching
    the source, and the source row itself. Anything previously registered for
    the same ``source_id`` is purged first so that re-ingesting a source does
    not leave stale relations behind.

    Args:
        conn: Open read-write registry connection. The caller owns and closes it.
        package: Fully built source package to persist.

    Raises:
        ValueError: If the package has no primary relation.
        Exception: Any database error, re-raised after the transaction is
            rolled back.
    """
    primary = next((item for item in package.relations if item.relation.is_primary), None)
    if primary is None:
        # Fail before opening a transaction rather than leaking StopIteration
        # from the middle of it.
        raise ValueError("Source package does not contain a primary relation.")

    created_at = _now_iso()
    conn.execute("BEGIN TRANSACTION")
    try:
        _purge_source_relations(conn, package.source_id)
        for relation in package.relations:
            _persist_relation(conn, package.source_id, relation)
        _persist_links(conn, package)
        _persist_source_row(conn, package, primary.relation.row_count, created_at)
        conn.execute("COMMIT")
    except Exception:
        conn.execute("ROLLBACK")
        raise


def _quoted_identifier(name: str) -> str:
    """Return ``name`` as a double-quoted SQL identifier with quotes escaped."""
    escaped = name.replace('"', '""')
    return f'"{escaped}"'


def _purge_source_relations(conn: duckdb.DuckDBPyConnection, source_id: str) -> None:
    """Drop tables and relation/column metadata already registered for ``source_id``."""
    stale_names = [
        row[0]
        for row in conn.execute(
            f"SELECT relation_name FROM {_RELATIONS_TABLE} WHERE source_id = ?",
            [source_id],
        ).fetchall()
    ]
    for stale_name in stale_names:
        conn.execute(f"DROP TABLE IF EXISTS {_quoted_identifier(stale_name)}")
    conn.execute(f"DELETE FROM {_RELATIONS_TABLE} WHERE source_id = ?", [source_id])
    conn.execute(f"DELETE FROM {_COLUMNS_TABLE} WHERE source_id = ?", [source_id])


def _persist_relation(conn: duckdb.DuckDBPyConnection, source_id: str, relation: RelationPackage) -> None:
    """Materialize one relation's frame as a table and record its metadata rows."""
    meta = relation.relation
    # The name may still be registered under a different source id; clear it by
    # name as well so the inserts below never duplicate a relation.
    conn.execute(f"DELETE FROM {_RELATIONS_TABLE} WHERE relation_name = ?", [meta.name])
    conn.execute(f"DELETE FROM {_COLUMNS_TABLE} WHERE relation_name = ?", [meta.name])

    # Expose the frame under a throwaway name so it can be copied into a table.
    table = _quoted_identifier(meta.name)
    temp_name = f"tmp_{uuid4().hex[:8]}"
    conn.register(temp_name, relation.frame)
    conn.execute(f"DROP TABLE IF EXISTS {table}")
    conn.execute(f'CREATE TABLE {table} AS SELECT * FROM "{temp_name}"')
    conn.unregister(temp_name)

    conn.execute(
        f"""
        INSERT INTO {_RELATIONS_TABLE} (
            relation_name, source_id, kind, is_primary, parent_relation, row_count, grain,
            identifier_columns_json, time_columns_json, measure_columns_json, dimension_columns_json, lineage_json
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        [
            meta.name,
            source_id,
            meta.kind,
            meta.is_primary,
            meta.parent_relation,
            meta.row_count,
            meta.grain,
            json.dumps(meta.identifier_columns),
            json.dumps(meta.time_columns),
            json.dumps(meta.measure_columns),
            json.dumps(meta.dimension_columns),
            json.dumps(meta.lineage),
        ],
    )

    column_rows = [
        (
            source_id,
            meta.name,
            ordinal,
            column.name,
            column.dtype,
            column.type_family,
            column.original_name,
            column.source_path,
            column.nullable,
            json.dumps(column.semantic_hints),
        )
        for ordinal, column in enumerate(meta.columns, start=1)
    ]
    if column_rows:
        conn.executemany(
            f"""
            INSERT INTO {_COLUMNS_TABLE} (
                source_id, relation_name, ordinal, column_name, dtype, type_family,
                original_name, source_path, nullable, semantic_hints_json
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            column_rows,
        )


def _persist_links(conn: duckdb.DuckDBPyConnection, package: SourcePackage) -> None:
    """Replace every link touching the package's source with the package's links."""
    conn.execute(
        f"DELETE FROM {_LINKS_TABLE} WHERE left_source_id = ? OR right_source_id = ?",
        [package.source_id, package.source_id],
    )
    for link in package.links:
        conn.execute(
            f"""
            INSERT INTO {_LINKS_TABLE} (
                link_id, left_source_id, left_relation_name, right_source_id, right_relation_name,
                link_type, is_explicit, join_keys_json, metadata_json
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [
                _short_id("link"),
                link.left_source_id,
                link.left_relation_name,
                link.right_source_id,
                link.right_relation_name,
                link.link_type,
                link.is_explicit,
                json.dumps(link.join_keys),
                json.dumps(link.metadata),
            ],
        )


def _persist_source_row(
    conn: duckdb.DuckDBPyConnection,
    package: SourcePackage,
    row_count: int,
    created_at: str,
) -> None:
    """Replace the source-level registry row for the package."""
    content_hash = sha256(package.raw_payload or b"").hexdigest()
    # Upload payloads are hashed but not stored in the registry row.
    raw_payload = None if package.source_kind == "upload" else package.raw_payload
    conn.execute(f"DELETE FROM {_SOURCES_TABLE} WHERE source_id = ?", [package.source_id])
    conn.execute(
        f"""
        INSERT INTO {_SOURCES_TABLE} (
            source_id, source_name, source_slug, source_kind, source_format, origin, file_name, file_type,
            size_bytes, created_at, content_hash, primary_relation_name, relation_count, row_count, status, raw_payload, metadata_json
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        [
            package.source_id,
            package.source_name,
            package.source_slug,
            package.source_kind,
            package.source_format,
            package.origin,
            package.file_name,
            package.file_type,
            package.size_bytes,
            created_at,
            content_hash,
            package.primary_relation_name,
            len(package.relations),
            row_count,
            "verified",
            raw_payload,
            json.dumps(package.metadata),
        ],
    )
+
+
+def _build_primary_relation_name(source_slug: str, source_id: str) -> str:
+ return f"{source_slug}_{source_id.split('_')[-1]}"
+
+
def _build_child_relation_name(primary_relation_name: str, path_segments: tuple[str, ...]) -> str:
    """Name a child relation as the primary name plus its sanitized path.

    Every path segment is passed through ``_safe_identifier``; segments and the
    primary name are all separated by a double underscore.
    """
    sanitized_segments = [_safe_identifier(segment) for segment in path_segments]
    return primary_relation_name + "__" + "__".join(sanitized_segments)
+
+
def _ingest_csv_source(filename: str, content: bytes, source_id: str | None = None) -> SourcePackage:
    """Turn a CSV upload into a single-relation ``SourcePackage``.

    The file is parsed, system-named columns are renamed, a ``record_id`` is
    ensured and types are coerced before the schema relation is described.
    A new source id is generated when none is supplied.

    Raises:
        ValueError: Propagated from ``_read_uploaded_frame`` for empty or
            unparsable content.
    """
    safe_name = Path(filename or "upload.csv").name
    parsed = _read_uploaded_frame(safe_name, content)
    frame = _coerce_frame_types(_ensure_record_id(_rename_system_columns(parsed)))

    resolved_id = source_id or _short_id("source")
    slug = _slugify(Path(safe_name).stem)
    relation_name = _build_primary_relation_name(slug, resolved_id)

    # A CSV column's source path is just its own name.
    column_paths: dict[str, dict[str, str]] = {}
    for column in frame.columns:
        label = str(column)
        column_paths[label] = {"original_name": label, "source_path": label}

    relation = _schema_relation_for_frame(
        relation_name=relation_name,
        source_id=resolved_id,
        source_name=safe_name,
        frame=frame,
        kind="table",
        is_primary=True,
        parent_relation=None,
        join_keys=[],
        lineage={"format": "csv", "json_path": "$"},
        column_paths=column_paths,
    )
    primary_package = RelationPackage(relation=relation, frame=frame)
    return SourcePackage(
        source_id=resolved_id,
        source_name=safe_name,
        source_slug=slug,
        source_kind="upload",
        source_format="csv",
        origin="Workspace upload",
        file_name=safe_name,
        file_type=_derive_file_type(safe_name, "csv"),
        size_bytes=len(content),
        raw_payload=content,
        relations=[primary_package],
    )
+
+
def _ingest_json_source(filename: str, content: bytes, source_id: str | None = None) -> SourcePackage:
    """Normalize a JSON upload into a primary relation plus child relations.

    The payload must be a top-level object or an array of objects. Each record
    becomes one row of the primary relation:

    * nested objects are flattened into ``parent__child`` columns;
    * every array becomes its own child relation whose rows carry
      ``record_id``, ``parent_record_id`` and ``ordinal``, linked to the parent
      through a ``parent_child`` link.

    An array nested inside plain objects is keyed by its flattened in-relation
    path (``billing__items``), not by its last key alone. This keeps
    same-named arrays under different parent objects in separate child
    relations, and child columns are named relative to the array element.

    Args:
        filename: Upload name; only its final path component is used.
        content: Raw UTF-8 encoded JSON bytes.
        source_id: Existing id to reuse; a new one is generated when omitted.

    Raises:
        ValueError: If the bytes are not valid JSON, or the payload is not an
            object or an array of objects.
    """
    safe_name = Path(filename or "upload.json").name
    try:
        payload = json.loads(content.decode("utf-8"))
    except Exception as exc:  # pragma: no cover - json errors vary
        raise ValueError(f"Could not parse {safe_name} as JSON.") from exc

    if isinstance(payload, dict):
        records = [payload]
    elif isinstance(payload, list) and all(isinstance(item, dict) for item in payload):
        # Also accepts the empty array: all() over no items is True.
        records = payload
    else:
        raise ValueError("JSON uploads must contain a top-level object or an array of objects.")

    source_id = source_id or _short_id("source")
    source_slug = _slugify(Path(safe_name).stem)
    primary_relation_name = _build_primary_relation_name(source_slug, source_id)

    # All per-relation state is keyed by the relation path: () is the primary
    # relation, ("items",) a top-level array, ("items", "tags") an array nested
    # in the elements of that array.
    relation_rows: dict[tuple[str, ...], list[dict[str, Any]]] = {(): []}
    relation_column_paths: dict[tuple[str, ...], dict[str, dict[str, str]]] = {
        (): {"record_id": {"original_name": "record_id", "source_path": "$.record_id"}}
    }
    relation_parents: dict[tuple[str, ...], str | None] = {(): None}
    relation_links: list[SourceLinkPackage] = []

    def remember_column(path_key: tuple[str, ...], column_name: str, original_name: str, source_path: str) -> None:
        relation_column_paths.setdefault(path_key, {})
        relation_column_paths[path_key][column_name] = {"original_name": original_name, "source_path": source_path}

    def walk_object(
        obj: dict[str, Any],
        *,
        path_key: tuple[str, ...],
        row: dict[str, Any],
        record_id: str,
        full_path: tuple[str, ...],
        column_prefix: tuple[str, ...] = (),
    ) -> None:
        # ``full_path`` is the raw key path from the document root (lineage);
        # ``column_prefix`` is the raw key path inside the current relation
        # (reset at every array boundary) and drives column/array naming.
        for raw_key, value in obj.items():
            key_text = str(raw_key)
            full_column_path = full_path + (key_text,)
            relative_path = column_prefix + (key_text,)
            if isinstance(value, dict):
                walk_object(
                    value,
                    path_key=path_key,
                    row=row,
                    record_id=record_id,
                    full_path=full_column_path,
                    column_prefix=relative_path,
                )
                continue

            flattened = "__".join(_safe_identifier(part) for part in relative_path)
            if flattened in _SYSTEM_COLUMNS:
                # Never let user data shadow record_id / parent_record_id / ordinal.
                flattened = f"{flattened}_value"

            if isinstance(value, list):
                handle_array(
                    value,
                    parent_relation_path=path_key,
                    parent_relation_name=primary_relation_name if not path_key else _build_child_relation_name(primary_relation_name, path_key),
                    parent_record_id=record_id,
                    array_path=path_key + (flattened,),
                    full_path=full_column_path,
                )
                continue

            row[flattened] = _normalize_scalar(value)
            remember_column(path_key, flattened, key_text, ".".join(full_column_path))

    def handle_array(
        values: list[Any],
        *,
        parent_relation_path: tuple[str, ...],
        parent_relation_name: str,
        parent_record_id: str,
        array_path: tuple[str, ...],
        full_path: tuple[str, ...],
    ) -> None:
        relation_rows.setdefault(array_path, [])
        relation_column_paths.setdefault(
            array_path,
            {
                "record_id": {"original_name": "record_id", "source_path": ".".join(full_path) + ".record_id"},
                "parent_record_id": {"original_name": "parent_record_id", "source_path": ".".join(full_path) + ".parent_record_id"},
                "ordinal": {"original_name": "ordinal", "source_path": ".".join(full_path) + ".ordinal"},
            },
        )
        relation_parents[array_path] = parent_relation_name
        relation_name = _build_child_relation_name(primary_relation_name, array_path)
        # One link per child relation, no matter how many records contain the array.
        if not any(link.right_relation_name == relation_name for link in relation_links):
            relation_links.append(
                SourceLinkPackage(
                    left_source_id=source_id,
                    left_relation_name=parent_relation_name,
                    right_source_id=source_id,
                    right_relation_name=relation_name,
                    join_keys=[{"left_column": "record_id", "right_column": "parent_record_id"}],
                    link_type="parent_child",
                    is_explicit=False,
                    metadata={"json_path": ".".join(full_path), "parent_relation_path": ".".join(parent_relation_path)},
                )
            )
        for index, item in enumerate(values):
            # Record ids are 1-based while the stored ordinal is the 0-based position.
            child_record_id = f"{parent_record_id}__{_safe_identifier(array_path[-1])}_{index + 1}"
            child_row = {
                "record_id": child_record_id,
                "parent_record_id": parent_record_id,
                "ordinal": index,
            }
            if isinstance(item, dict):
                walk_object(item, path_key=array_path, row=child_row, record_id=child_record_id, full_path=full_path)
            elif isinstance(item, list):
                # Arrays of arrays: the inner array becomes a "value" child relation.
                handle_array(
                    item,
                    parent_relation_path=array_path,
                    parent_relation_name=relation_name,
                    parent_record_id=child_record_id,
                    array_path=array_path + ("value",),
                    full_path=full_path + ("value",),
                )
            else:
                child_row["value"] = _normalize_scalar(item)
                remember_column(array_path, "value", "value", ".".join(full_path))
            relation_rows[array_path].append(child_row)

    for index, item in enumerate(records, start=1):
        row = {"record_id": f"record_{index}"}
        walk_object(item, path_key=(), row=row, record_id=row["record_id"], full_path=())
        relation_rows[()].append(row)

    primary_frame = pd.DataFrame(relation_rows[()])
    if "record_id" not in primary_frame.columns:
        primary_frame["record_id"] = pd.Series(dtype="string")
    primary_frame = _coerce_frame_types(primary_frame)
    relations: list[RelationPackage] = [
        RelationPackage(
            relation=_schema_relation_for_frame(
                relation_name=primary_relation_name,
                source_id=source_id,
                source_name=safe_name,
                frame=primary_frame,
                kind="table",
                is_primary=True,
                parent_relation=None,
                join_keys=[],
                lineage={"format": "json", "json_path": "$"},
                column_paths=relation_column_paths.get((), {}),
            ),
            frame=primary_frame,
        )
    ]

    child_system_columns = (("record_id", "string"), ("parent_record_id", "string"), ("ordinal", "Int64"))
    for relation_path in sorted(key for key in relation_rows if key):
        relation_name = _build_child_relation_name(primary_relation_name, relation_path)
        frame = pd.DataFrame(relation_rows[relation_path])
        # An array that is empty in every record produces no rows and therefore
        # no columns; give it the system columns, as done for the primary frame.
        for system_column, dtype in child_system_columns:
            if system_column not in frame.columns:
                frame[system_column] = pd.Series(dtype=dtype)
        frame = _coerce_frame_types(frame)
        join_keys = [
            SchemaJoinKey(
                target_relation=relation_parents[relation_path] or primary_relation_name,
                source_column="parent_record_id",
                target_column="record_id",
                link_type="parent_child",
            )
        ]
        relations.append(
            RelationPackage(
                relation=_schema_relation_for_frame(
                    relation_name=relation_name,
                    source_id=source_id,
                    source_name=safe_name,
                    frame=frame,
                    kind="table",
                    is_primary=False,
                    parent_relation=relation_parents[relation_path],
                    join_keys=join_keys,
                    lineage={"format": "json", "json_path": ".".join(relation_path)},
                    column_paths=relation_column_paths.get(relation_path, {}),
                ),
                frame=frame,
            )
        )

    return SourcePackage(
        source_id=source_id,
        source_name=safe_name,
        source_slug=source_slug,
        source_kind="upload",
        source_format="json",
        origin="Workspace upload",
        file_name=safe_name,
        file_type=_derive_file_type(safe_name, "json"),
        size_bytes=len(content),
        raw_payload=content,
        relations=relations,
        links=relation_links,
    )
+
+
def ensure_builtin_sources() -> None:
    """Do nothing; kept so existing callers keep working.

    Uploads are the only runtime data sources, so there is nothing to seed.
    """
    return None
+
+
def ingest_source(filename: str, content: bytes, *, source_id: str | None = None) -> UploadedAsset:
    """Ingest a CSV or JSON upload into the registry and describe it for the UI.

    The file type is chosen from the extension of the final path component of
    ``filename`` (``upload.csv`` when no name is given). After the package is
    persisted, the semantic-context cache is cleared so readers see the new
    data.

    Raises:
        ValueError: If the extension is neither ``.csv`` nor ``.json``, or the
            content cannot be ingested.
    """
    safe_name = Path(filename or "upload.csv").name
    extension = Path(safe_name).suffix.lower()
    if extension not in (".json", ".csv"):
        raise ValueError("Only CSV and JSON uploads are currently supported.")

    build_package = _ingest_json_source if extension == ".json" else _ingest_csv_source
    package = build_package(safe_name, content, source_id=source_id)

    registry = _connect_registry(read_only=False)
    try:
        _persist_source_package(registry, package)
    finally:
        registry.close()

    # Imported here rather than at module top level.
    from app.data.semantic_model import clear_semantic_context_cache

    clear_semantic_context_cache()
    return _build_uploaded_asset(package)
+
+
def delete_source(source_id: str) -> None:
    """Delete an uploaded source: its metadata rows, links and relation tables.

    The relation names are read first; all deletes and table drops then run in
    one transaction that is rolled back on any error. The semantic-context
    cache is cleared after a successful delete.
    """
    registry = _connect_registry(read_only=False)
    try:
        owned_relations = registry.execute(
            f"SELECT relation_name FROM {_RELATIONS_TABLE} WHERE source_id = ?",
            [source_id],
        ).fetchall()
        relation_names = [record[0] for record in owned_relations]

        purge_statements = (
            (f"DELETE FROM {_COLUMNS_TABLE} WHERE source_id = ?", [source_id]),
            (
                f"DELETE FROM {_LINKS_TABLE} WHERE left_source_id = ? OR right_source_id = ?",
                [source_id, source_id],
            ),
            (f"DELETE FROM {_RELATIONS_TABLE} WHERE source_id = ?", [source_id]),
            (f"DELETE FROM {_SOURCES_TABLE} WHERE source_id = ?", [source_id]),
        )
        registry.execute("BEGIN TRANSACTION")
        try:
            for statement, bindings in purge_statements:
                registry.execute(statement, bindings)
            for relation_name in relation_names:
                registry.execute(f'DROP TABLE IF EXISTS "{relation_name}"')
            registry.execute("COMMIT")
        except Exception:
            registry.execute("ROLLBACK")
            raise
    finally:
        registry.close()

    from app.data.semantic_model import clear_semantic_context_cache

    clear_semantic_context_cache()
+
+
def create_source_link(
    left_relation_name: str,
    left_column: str,
    right_relation_name: str,
    right_column: str,
) -> None:
    """Record an explicit join between one column of each of two relations.

    Both relations must already be registered. The column names are stored as
    given and are not checked against the relations' columns. The
    semantic-context cache is cleared once the link is stored.

    Raises:
        ValueError: If either relation is not present in the registry.
    """
    registry = _connect_registry(read_only=False)
    try:
        found = registry.execute(
            f"""
            SELECT relation_name, source_id
            FROM {_RELATIONS_TABLE}
            WHERE relation_name IN (?, ?)
            """,
            [left_relation_name, right_relation_name],
        ).fetchall()
        owner_of = {record[0]: record[1] for record in found}
        if not (left_relation_name in owner_of and right_relation_name in owner_of):
            raise ValueError("Both relations must exist before creating an explicit source link.")

        join_keys = [{"left_column": left_column, "right_column": right_column}]
        link_values = [
            _short_id("link"),
            owner_of[left_relation_name],
            left_relation_name,
            owner_of[right_relation_name],
            right_relation_name,
            "explicit",
            True,
            json.dumps(join_keys),
            json.dumps({}),
        ]
        registry.execute(
            f"""
            INSERT INTO {_LINKS_TABLE} (
                link_id, left_source_id, left_relation_name, right_source_id, right_relation_name,
                link_type, is_explicit, join_keys_json, metadata_json
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            link_values,
        )
    finally:
        registry.close()

    from app.data.semantic_model import clear_semantic_context_cache

    clear_semantic_context_cache()
+
+
def _load_link_map(conn: duckdb.DuckDBPyConnection, source_ids: list[str] | None = None) -> dict[str, list[SchemaJoinKey]]:
    """Load stored links as join keys grouped by relation name.

    Every stored column pair is emitted twice: once from the left relation
    toward the right, and once mirrored from the right toward the left. Any
    link type other than ``parent_child`` is reported as ``explicit``.
    """
    filter_sql, params = _link_filter_sql(source_ids)
    stored_links = conn.execute(
        f"""
        SELECT left_source_id, left_relation_name, right_source_id, right_relation_name, link_type, join_keys_json
        FROM {_LINKS_TABLE}
        {filter_sql}
        """,
        params,
    ).fetchall()

    link_map: dict[str, list[SchemaJoinKey]] = {}
    for _left_source, left_name, _right_source, right_name, stored_type, join_keys_json in stored_links:
        normalized_type = "parent_child" if stored_type == "parent_child" else "explicit"
        for pair in json.loads(join_keys_json or "[]"):
            left_column = pair["left_column"]
            right_column = pair["right_column"]
            forward = SchemaJoinKey(
                target_relation=right_name,
                source_column=left_column,
                target_column=right_column,
                link_type=normalized_type,
            )
            backward = SchemaJoinKey(
                target_relation=left_name,
                source_column=right_column,
                target_column=left_column,
                link_type=normalized_type,
            )
            link_map.setdefault(left_name, []).append(forward)
            link_map.setdefault(right_name, []).append(backward)
    return link_map
+
+
def get_schema_manifest(source_ids: list[str] | None = None) -> dict[str, Any]:
    """Assemble the schema manifest for the given sources from the registry.

    With no ``source_ids`` every registered relation is included. Relations are
    ordered by source, primary relation first, then by name; each one carries
    its columns (in stored ordinal order), join keys from the link table and
    derived semantic mappings. The manifest is returned as a plain dict.
    """
    registry = _connect_registry(read_only=True)
    try:
        filter_sql, params = _source_filter_sql(source_ids)
        relation_records = registry.execute(
            f"""
            SELECT relation_name, source_id, kind, is_primary, parent_relation, row_count, grain,
                   identifier_columns_json, time_columns_json, measure_columns_json, dimension_columns_json, lineage_json
            FROM {_RELATIONS_TABLE}
            {filter_sql}
            ORDER BY source_id, is_primary DESC, relation_name
            """,
            params,
        ).fetchall()
        source_records = registry.execute(
            f"SELECT source_id, source_name FROM {_SOURCES_TABLE}{filter_sql}",
            params,
        ).fetchall()
        source_names = {record[0]: record[1] for record in source_records}
        link_map = _load_link_map(registry, source_ids)

        relations: list[SchemaRelation] = []
        for record in relation_records:
            (
                relation_name,
                source_id,
                kind,
                is_primary,
                parent_relation,
                row_count,
                grain,
                identifier_json,
                time_json,
                measure_json,
                dimension_json,
                lineage_json,
            ) = record
            column_records = registry.execute(
                f"""
                SELECT column_name, dtype, type_family, original_name, source_path, nullable, semantic_hints_json
                FROM {_COLUMNS_TABLE}
                WHERE relation_name = ?
                ORDER BY ordinal
                """,
                [relation_name],
            ).fetchall()

            columns: list[SchemaColumn] = []
            for column_name, dtype, type_family, original_name, source_path, nullable, hints_json in column_records:
                columns.append(
                    SchemaColumn(
                        name=column_name,
                        dtype=dtype,
                        type_family=type_family,
                        # Fall back to the column name when no lineage was stored.
                        original_name=original_name or column_name,
                        source_path=source_path or column_name,
                        nullable=bool(nullable),
                        semantic_hints=json.loads(hints_json or "[]"),
                    )
                )

            relations.append(
                SchemaRelation(
                    name=relation_name,
                    kind=kind,
                    source_id=source_id,
                    source_name=source_names.get(source_id, source_id),
                    is_primary=bool(is_primary),
                    parent_relation=parent_relation,
                    row_count=int(row_count),
                    grain=grain or "",
                    identifier_columns=json.loads(identifier_json or "[]"),
                    time_columns=json.loads(time_json or "[]"),
                    measure_columns=json.loads(measure_json or "[]"),
                    dimension_columns=json.loads(dimension_json or "[]"),
                    join_keys=link_map.get(relation_name, []),
                    lineage=json.loads(lineage_json or "{}"),
                    columns=columns,
                    semantic_mappings=_build_semantic_mappings(columns),
                )
            )

        # Compact per-relation summaries exposed alongside the full relations.
        views: list[dict[str, Any]] = []
        for relation in relations:
            views.append(
                {
                    "name": relation.name,
                    "source_id": relation.source_id,
                    "source_name": relation.source_name,
                    "row_count": relation.row_count,
                    "is_primary": relation.is_primary,
                    "columns": [{"name": column.name, "dtype": column.dtype} for column in relation.columns],
                }
            )

        manifest = SchemaManifest(
            reference_date="",
            source="source_registry",
            dialect="duckdb",
            relations=relations,
            views=views,
        )
        return manifest.model_dump()
    finally:
        registry.close()
+
+
def load_relation_frames(source_ids: list[str] | None = None) -> dict[str, pd.DataFrame]:
    """Load every materialized relation table in scope as a DataFrame.

    Args:
        source_ids: Sources to include; ``None`` or empty loads all relations.

    Returns:
        A mapping of relation name to its full table contents, in relation-name
        order.
    """
    conn = _connect_registry(read_only=True)
    try:
        filter_sql, params = _source_filter_sql(source_ids)
        names = [
            row[0]
            for row in conn.execute(
                f"SELECT relation_name FROM {_RELATIONS_TABLE}{filter_sql} ORDER BY relation_name",
                params,
            ).fetchall()
        ]
        frames: dict[str, pd.DataFrame] = {}
        for name in names:
            # Escape embedded double quotes so the name is always read as a
            # single identifier, matching how views are created elsewhere.
            escaped = name.replace('"', '""')
            frames[name] = conn.execute(f'SELECT * FROM "{escaped}"').fetchdf()
        return frames
    finally:
        conn.close()
+
+
def get_registered_relation_names(source_ids: list[str] | None = None) -> list[str]:
    """List registered relation names in scope, sorted by name.

    ``None`` or an empty ``source_ids`` lists the relations of every source.
    """
    registry = _connect_registry(read_only=True)
    try:
        filter_sql, params = _source_filter_sql(source_ids)
        query = f"SELECT relation_name FROM {_RELATIONS_TABLE}{filter_sql} ORDER BY relation_name"
        records = registry.execute(query, params).fetchall()
        return [record[0] for record in records]
    finally:
        registry.close()
+
+
def get_upload_source_ids(source_ids: list[str] | None = None) -> list[str]:
    """Return the ids from ``source_ids`` that are registered upload sources.

    Args:
        source_ids: Candidate ids; duplicates are ignored.

    Returns:
        The matching ids sorted ascending, or an empty list when ``source_ids``
        is ``None`` or empty.
    """
    if not source_ids:
        return []
    # De-duplicate BEFORE sizing the placeholder list: the number of ``?``
    # markers must equal the number of bound parameters.
    unique_ids = list(dict.fromkeys(source_ids))
    # NOTE(review): this read-only query opens the registry read-write, unlike
    # the other readers in this module; confirm whether that is intentional.
    conn = _connect_registry(read_only=False)
    try:
        placeholders = ", ".join(["?"] * len(unique_ids))
        rows = conn.execute(
            f"""
            SELECT source_id
            FROM {_SOURCES_TABLE}
            WHERE source_kind = 'upload' AND source_id IN ({placeholders})
            ORDER BY source_id
            """,
            unique_ids,
        ).fetchall()
        return [row[0] for row in rows]
    finally:
        conn.close()
diff --git a/app/data/semantic_model.py b/app/data/semantic_model.py
index 3cdcef6..53d851c 100644
--- a/app/data/semantic_model.py
+++ b/app/data/semantic_model.py
@@ -1,8 +1,7 @@
-"""Dataset views and schema manifest for planner and executor."""
+"""Registry-backed dataset views and schema manifest for planner and executor."""
from __future__ import annotations
-import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Any
@@ -10,8 +9,7 @@
import duckdb
import pandas as pd
-from app.data.loader import load_data
-from app.schemas import SchemaColumn, SchemaConceptMapping, SchemaManifest, SchemaRelation
+from app.data.registry import get_registered_relation_names, get_registry_path, get_schema_manifest, load_relation_frames
@dataclass(frozen=True)
@@ -25,175 +23,63 @@ class SemanticContext:
schema_manifest: dict[str, Any]
-_SEMANTIC_ALIAS_LEXICON: dict[str, list[str]] = {
- "owner": ["agent", "rep", "representative", "sales rep", "assignee"],
- "manager": ["manager", "lead", "supervisor", "team lead"],
- "regional_office": ["region", "regional office", "office", "territory"],
- "account_id": ["account", "customer", "client", "account identifier"],
- "deal_id": ["deal", "opportunity", "opportunity identifier"],
- "deal_value": ["revenue", "deal size", "amount", "value"],
- "stage": ["status stage", "pipeline stage"],
- "segment": ["customer segment", "market segment"],
- "pipeline_velocity_days": ["pipeline velocity", "cycle time", "sales cycle length"],
-}
-
-
-def _split_identifier(value: str) -> list[str]:
- cleaned = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", value)
- tokens = [token.lower() for token in re.split(r"[^a-zA-Z0-9]+", cleaned) if token]
- return tokens
-
-
-def _type_family(dtype: str) -> str:
- lower = dtype.lower()
- if any(token in lower for token in ("int", "float", "double", "decimal")):
- return "number"
- if "bool" in lower:
- return "boolean"
- if any(token in lower for token in ("datetime", "timestamp", "date")):
- return "datetime"
- if any(token in lower for token in ("object", "string", "category")):
- return "string"
- return "unknown"
-
-
-def _semantic_hints(column_name: str) -> list[str]:
- tokens = _split_identifier(column_name)
- hints = {column_name, column_name.replace("_", " ")}
- hints.update(tokens)
- if column_name.endswith("_id"):
- base = column_name[: -len("_id")].replace("_", " ").strip()
- if base:
- hints.add(f"{base} id")
- hints.add(f"{base} identifier")
-
- for key, aliases in _SEMANTIC_ALIAS_LEXICON.items():
- key_tokens = set(_split_identifier(key))
- if column_name == key or key_tokens.issubset(set(tokens)):
- hints.update(aliases)
-
- return sorted(hint for hint in hints if hint)
-
-
-def _is_identifier_column(column_name: str, series: pd.Series) -> bool:
- lowered = column_name.lower()
- if lowered == "id" or lowered.endswith("_id"):
- return True
- non_null = series.dropna()
- return bool(len(non_null) == len(series) and len(non_null) > 0 and non_null.nunique(dropna=False) == len(series))
-
-
-def _infer_grain(name: str, frame: pd.DataFrame, identifier_columns: list[str]) -> str:
- if identifier_columns:
- primary = identifier_columns[0]
- if primary.lower().endswith("_id"):
- entity = primary[: -len("_id")].replace("_", " ").strip()
- if entity:
- return f"Approximately one row per {entity}"
- return f"Rows can be keyed by {primary}"
- return f"Rows represent records in {name}"
-
-
-def _build_semantic_mappings(columns: list[SchemaColumn]) -> list[SchemaConceptMapping]:
- concept_to_columns: dict[str, set[str]] = {}
- for column in columns:
- for hint in column.semantic_hints:
- normalized_hint = hint.strip().lower()
- if not normalized_hint or normalized_hint == column.name.lower():
- continue
- concept_to_columns.setdefault(normalized_hint, set()).add(column.name)
-
- mappings: list[SchemaConceptMapping] = []
- for concept, mapped_columns in sorted(concept_to_columns.items()):
- if len(concept) < 4:
- continue
- mappings.append(
- SchemaConceptMapping(
- concept=concept,
- columns=sorted(mapped_columns),
- )
- )
- return mappings[:20]
-
-
-def _relation_for_frame(name: str, frame: pd.DataFrame, kind: str = "view") -> SchemaRelation:
- columns: list[SchemaColumn] = []
- identifier_columns: list[str] = []
- time_columns: list[str] = []
- measure_columns: list[str] = []
- dimension_columns: list[str] = []
-
- for column_name in frame.columns:
- dtype = str(frame[column_name].dtype)
- family = _type_family(dtype)
- column = SchemaColumn(
- name=column_name,
- dtype=dtype,
- type_family=family,
- semantic_hints=_semantic_hints(column_name),
- )
- columns.append(column)
-
- if _is_identifier_column(column_name, frame[column_name]):
- identifier_columns.append(column_name)
- if family == "datetime":
- time_columns.append(column_name)
- elif family == "number":
- measure_columns.append(column_name)
- else:
- dimension_columns.append(column_name)
-
- return SchemaRelation(
- name=name,
- kind=kind,
- row_count=int(len(frame)),
- grain=_infer_grain(name, frame, identifier_columns),
- identifier_columns=identifier_columns,
- time_columns=time_columns,
- measure_columns=measure_columns,
- dimension_columns=dimension_columns,
- columns=columns,
- semantic_mappings=_build_semantic_mappings(columns),
- )
-
+def _source_ids_key(source_ids: list[str] | tuple[str, ...] | None) -> tuple[str, ...]:
+ if not source_ids:
+ return ()
+ return tuple(sorted(source_ids))
-@lru_cache(maxsize=1)
-def get_semantic_context() -> SemanticContext:
- """Build and cache views plus a schema-only manifest for planning."""
- bundle = load_data()
- raw_views = {name: frame.copy() for name, frame in bundle.raw_views.items()}
- semantic_views = {"opportunities_enriched": bundle.crm.copy()}
- all_frames = {**raw_views, **semantic_views}
- relations = [_relation_for_frame(name, frame, kind="view") for name, frame in all_frames.items()]
- schema_manifest = SchemaManifest(
- reference_date=bundle.reference_date,
- source=bundle.source,
- dialect="duckdb",
- relations=relations,
- views=[
- {
- "name": relation.name,
- "row_count": relation.row_count,
- "columns": [{"name": column.name, "dtype": column.dtype} for column in relation.columns],
- }
- for relation in relations
- ],
- ).model_dump()
@lru_cache(maxsize=32)
def _get_semantic_context_cached(source_ids_key: tuple[str, ...]) -> SemanticContext:
    """Build the semantic context for one normalized scope; results are memoized.

    Relations flagged ``is_primary`` in the manifest are exposed as semantic
    views; every other relation is exposed as a raw view. An empty key means
    "all sources".
    """
    scope = list(source_ids_key) if source_ids_key else None
    manifest = get_schema_manifest(scope)
    frames = load_relation_frames(scope)

    primary_names = set()
    for relation in manifest.get("relations", []):
        if relation.get("is_primary"):
            primary_names.add(relation["name"])

    raw_views: dict[str, pd.DataFrame] = {}
    semantic_views: dict[str, pd.DataFrame] = {}
    for name, frame in frames.items():
        target = semantic_views if name in primary_names else raw_views
        target[name] = frame

    return SemanticContext(
        reference_date=manifest.get("reference_date", ""),
        source=manifest.get("source", ""),
        raw_views=raw_views,
        semantic_views=semantic_views,
        schema_manifest=manifest,
    )
-def new_duckdb_connection() -> duckdb.DuckDBPyConnection:
- """Create an in-memory DuckDB connection with curated views registered."""
def get_semantic_context(source_ids: list[str] | tuple[str, ...] | None = None) -> SemanticContext:
    """Return the cached semantic context (views plus schema manifest) for a scope.

    ``source_ids`` is normalized into a hashable key first, so lists and tuples
    are both accepted; ``None`` or empty selects all sources.
    """
    cache_key = _source_ids_key(source_ids)
    return _get_semantic_context_cached(cache_key)
+
+
def clear_semantic_context_cache() -> None:
    """Drop every memoized semantic context so the next lookup is rebuilt."""
    cached_builder = _get_semantic_context_cached
    cached_builder.cache_clear()
+
+
def new_duckdb_connection(dataset_context: dict[str, Any] | None = None) -> duckdb.DuckDBPyConnection:
    """Create an in-memory DuckDB connection with one view per requested relation.

    Relation names are taken from ``dataset_context["relations"]``, falling
    back to ``dataset_context["views"]`` and finally to every registered
    relation. The registry database is attached as ``registry_db`` and each
    relation is exposed as a same-named view over its registry table.

    Args:
        dataset_context: Optional manifest-like dict with ``relations`` and/or
            ``views`` entries, each a list of dicts with a ``name`` key.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        Exception: Any DuckDB error raised while attaching the registry or
            creating a view; the connection is closed before re-raising.
    """
    context = dataset_context or {}
    relation_names = [
        relation["name"]
        for relation in (context.get("relations") or [])
        if relation.get("name")
    ]
    if not relation_names:
        relation_names = [
            view["name"]
            for view in (context.get("views") or [])
            if view.get("name")
        ]
    if not relation_names:
        relation_names = get_registered_relation_names()
    # A repeated name would make the second CREATE VIEW fail.
    relation_names = list(dict.fromkeys(relation_names))

    conn = duckdb.connect(database=":memory:")
    try:
        registry_path = str(get_registry_path()).replace("'", "''")
        # NOTE(review): the registry is attached without READ_ONLY, so SQL run
        # on this connection can reach and modify any registry table through
        # the registry_db prefix; consider a read-only attach if the
        # deployment allows it.
        conn.execute(f"ATTACH '{registry_path}' AS registry_db")

        for relation_name in relation_names:
            escaped = relation_name.replace('"', '""')
            conn.execute(f'CREATE VIEW "{escaped}" AS SELECT * FROM registry_db."{escaped}"')
    except Exception:
        # Do not leak the half-configured connection (and its attached file).
        conn.close()
        raise
    return conn
diff --git a/app/db/schema_compat.py b/app/db/schema_compat.py
new file mode 100644
index 0000000..3a26c8f
--- /dev/null
+++ b/app/db/schema_compat.py
@@ -0,0 +1,144 @@
+"""Small SQLite compatibility migrations for local development databases."""
+
+from __future__ import annotations
+
+from sqlalchemy import inspect, text
+from sqlalchemy.engine import Engine
+
+_UPLOADS_TABLE = "uploads"
+_LEGACY_UPLOADS_TABLE = "uploads__legacy_schema_compat"
+_EXPECTED_UPLOAD_COLUMNS = {
+ "source_id",
+ "user_id",
+ "original_filename",
+ "storage_path",
+ "content_type",
+ "size_bytes",
+ "content_hash",
+ "status",
+ "rows",
+ "columns",
+ "relation_count",
+ "primary_relation_name",
+ "created_at",
+ "updated_at",
+}
+
+
+def ensure_sqlite_schema_compatibility(engine: Engine) -> None:
+ """Upgrade legacy local SQLite tables that predate the current ORM shape.
+
+ Currently this only runs the ``uploads`` table migration check.
+ """
+
+ _migrate_uploads_table_if_needed(engine)
+
+
+def _migrate_uploads_table_if_needed(engine: Engine) -> None:
+ """Rebuild the ``uploads`` table when its columns or primary key differ from the expected shape.
+
+ Does nothing when the table is absent, or when its column names equal
+ ``_EXPECTED_UPLOAD_COLUMNS`` and ``source_id`` is the only primary-key column.
+ Otherwise the table is renamed to ``_LEGACY_UPLOADS_TABLE``, recreated, filled
+ from the legacy rows via ``_build_upload_copy_select``, and the legacy table is dropped.
+
+ Raises:
+ RuntimeError: if a table named ``_LEGACY_UPLOADS_TABLE`` already exists.
+ """
+ inspector = inspect(engine)
+ if _UPLOADS_TABLE not in inspector.get_table_names():
+ return
+
+ columns = inspector.get_columns(_UPLOADS_TABLE)
+ column_names = {column["name"] for column in columns}
+ pk_columns = [column["name"] for column in columns if column.get("primary_key")]
+ # Already in the current shape: same column set and source_id as the sole primary key.
+ if column_names == _EXPECTED_UPLOAD_COLUMNS and pk_columns == ["source_id"]:
+ return
+
+ # Refuse to continue if an earlier run left its temporary table behind.
+ if _LEGACY_UPLOADS_TABLE in inspector.get_table_names():
+ raise RuntimeError(
+ "Cannot migrate the uploads table because a previous temporary migration table still exists: "
+ f"{_LEGACY_UPLOADS_TABLE}."
+ )
+
+ select_sql = _build_upload_copy_select(column_names)
+ # NOTE(review): all statements run inside one engine.begin() block; whether the DDL rolls back
+ # together with a failed INSERT depends on the SQLite driver's transaction handling — confirm.
+ with engine.begin() as connection:
+ connection.execute(text(f"ALTER TABLE {_UPLOADS_TABLE} RENAME TO {_LEGACY_UPLOADS_TABLE}"))
+ connection.execute(
+ text(
+ f"""
+ CREATE TABLE {_UPLOADS_TABLE} (
+ source_id VARCHAR(64) NOT NULL PRIMARY KEY,
+ user_id INTEGER NOT NULL,
+ original_filename VARCHAR(512) NOT NULL,
+ storage_path VARCHAR(1024) NOT NULL,
+ content_type VARCHAR(255) NOT NULL DEFAULT 'application/octet-stream',
+ size_bytes INTEGER NOT NULL,
+ content_hash VARCHAR(128) NOT NULL,
+ status VARCHAR(32) NOT NULL DEFAULT 'verified',
+ rows INTEGER,
+ columns INTEGER,
+ relation_count INTEGER,
+ primary_relation_name VARCHAR(255),
+ created_at DATETIME NOT NULL,
+ updated_at DATETIME NOT NULL,
+ FOREIGN KEY(user_id) REFERENCES users (id) ON DELETE CASCADE
+ )
+ """
+ )
+ )
+ # Copy the legacy rows; select_sql supplies one expression per target column, in this order.
+ connection.execute(
+ text(
+ f"""
+ INSERT INTO {_UPLOADS_TABLE} (
+ source_id,
+ user_id,
+ original_filename,
+ storage_path,
+ content_type,
+ size_bytes,
+ content_hash,
+ status,
+ rows,
+ columns,
+ relation_count,
+ primary_relation_name,
+ created_at,
+ updated_at
+ )
+ {select_sql}
+ """
+ )
+ )
+ connection.execute(text(f"DROP TABLE {_LEGACY_UPLOADS_TABLE}"))
+ connection.execute(text("CREATE INDEX IF NOT EXISTS ix_uploads_user_id ON uploads (user_id)"))
+
+
+def _build_upload_copy_select(column_names: set[str]) -> str:
+ """Build the SELECT that maps legacy ``uploads`` rows onto the current column list.
+
+ ``column_names`` is the set of columns present in the legacy table. Each target
+ column reads the first matching legacy column, or a literal default when none
+ of the candidates exists. The statement selects from ``_LEGACY_UPLOADS_TABLE``.
+ """
+ def has_column(name: str) -> bool:
+ return name in column_names
+
+ def quoted(name: str) -> str:
+ # Wrap in double quotes; the names passed in this function are fixed literals.
+ return f'"{name}"'
+
+ def first_present(*names: str, default: str) -> str:
+ # Return the first candidate column that exists (quoted), else the SQL default expression.
+ for name in names:
+ if has_column(name):
+ return quoted(name)
+ return default
+
+ # Prefer source_id, then the older upload_id column, else generate a random id.
+ source_id_expr = first_present(
+ "source_id",
+ "upload_id",
+ default="('source_' || lower(hex(randomblob(4))))",
+ )
+ content_type_expr = first_present("content_type", default="'application/octet-stream'")
+ status_expr = first_present("status", default="'verified'")
+ created_at_expr = first_present("created_at", default="CURRENT_TIMESTAMP")
+ updated_at_expr = first_present("updated_at", "created_at", default="CURRENT_TIMESTAMP")
+
+ return f"""
+ SELECT
+ COALESCE(NULLIF({source_id_expr}, ''), 'source_' || lower(hex(randomblob(4)))) AS source_id,
+ {first_present("user_id", default='0')} AS user_id,
+ {first_present("original_filename", default="'upload.csv'")} AS original_filename,
+ {first_present("storage_path", default="''")} AS storage_path,
+ COALESCE({content_type_expr}, 'application/octet-stream') AS content_type,
+ {first_present("size_bytes", default='0')} AS size_bytes,
+ {first_present("content_hash", default="''")} AS content_hash,
+ COALESCE({status_expr}, 'verified') AS status,
+ {first_present("rows", default='NULL')} AS rows,
+ {first_present("columns", default='NULL')} AS columns,
+ {first_present("relation_count", default='NULL')} AS relation_count,
+ {first_present("primary_relation_name", default='NULL')} AS primary_relation_name,
+ COALESCE({created_at_expr}, CURRENT_TIMESTAMP) AS created_at,
+ COALESCE({updated_at_expr}, COALESCE({created_at_expr}, CURRENT_TIMESTAMP)) AS updated_at
+ FROM {_LEGACY_UPLOADS_TABLE}
+ """
diff --git a/app/main.py b/app/main.py
index 4085152..5f495b1 100644
--- a/app/main.py
+++ b/app/main.py
@@ -23,10 +23,13 @@ async def lifespan(_app: FastAPI):
"""Create database tables on startup (SQLite file demo; no Alembic in this phase)."""
from app.db.base import Base
+ from app.db.schema_compat import ensure_sqlite_schema_compatibility
from app.db.session import get_engine
- from app.models import Conversation, InspectionSnapshot, Message, User # noqa: F401
+ from app.models import Conversation, InspectionSnapshot, Message, UploadRecord, User # noqa: F401
- Base.metadata.create_all(bind=get_engine())
+ engine = get_engine()
+ Base.metadata.create_all(bind=engine)
+ ensure_sqlite_schema_compatibility(engine)
yield
diff --git a/app/models/__init__.py b/app/models/__init__.py
index 4649a41..580bf2c 100644
--- a/app/models/__init__.py
+++ b/app/models/__init__.py
@@ -2,6 +2,7 @@
from app.models.conversation import Conversation, Message
from app.models.inspection_snapshot import InspectionSnapshot
+from app.models.upload import UploadRecord
from app.models.user import User
-__all__ = ["Conversation", "InspectionSnapshot", "Message", "User"]
+__all__ = ["Conversation", "InspectionSnapshot", "Message", "UploadRecord", "User"]
diff --git a/app/models/upload.py b/app/models/upload.py
new file mode 100644
index 0000000..3f16b6b
--- /dev/null
+++ b/app/models/upload.py
@@ -0,0 +1,44 @@
+"""User-owned uploaded source metadata."""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING
+
+from sqlalchemy import DateTime, ForeignKey, Integer, String
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from app.db.base import Base
+
+if TYPE_CHECKING:
+ from app.models.user import User
+
+
+def _utc_now() -> datetime:
+    """Return the current moment as a timezone-aware UTC ``datetime``."""
+    current = datetime.now(tz=timezone.utc)
+    return current
+
+
+class UploadRecord(Base):
+ """ORM row in the ``uploads`` table holding metadata for one uploaded source."""
+
+ __tablename__ = "uploads"
+
+ # Primary key: the source identifier string (up to 64 characters).
+ source_id: Mapped[str] = mapped_column(String(64), primary_key=True)
+ # Owning user; indexed, with the foreign key declared ON DELETE CASCADE.
+ user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True, nullable=False)
+ original_filename: Mapped[str] = mapped_column(String(512), nullable=False)
+ storage_path: Mapped[str] = mapped_column(String(1024), nullable=False)
+ # Defaults to a generic binary type when none is supplied.
+ content_type: Mapped[str] = mapped_column(String(255), nullable=False, default="application/octet-stream")
+ size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
+ content_hash: Mapped[str] = mapped_column(String(128), nullable=False)
+ status: Mapped[str] = mapped_column(String(32), nullable=False, default="verified")
+ # Optional profiling figures; each may be NULL.
+ rows: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ columns: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ relation_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ primary_relation_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
+ # Timestamps default to the current UTC time; updated_at is also refreshed on update.
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utc_now, nullable=False)
+ updated_at: Mapped[datetime] = mapped_column(
+ DateTime(timezone=True),
+ default=_utc_now,
+ onupdate=_utc_now,
+ nullable=False,
+ )
+
+ # Paired with the ``uploads`` relationship on ``User``.
+ user: Mapped[User] = relationship("User", back_populates="uploads")
diff --git a/app/models/user.py b/app/models/user.py
index dcdd9df..48cea3f 100644
--- a/app/models/user.py
+++ b/app/models/user.py
@@ -12,6 +12,7 @@
if TYPE_CHECKING:
from app.models.conversation import Conversation
+ from app.models.upload import UploadRecord
def _utc_now() -> datetime:
@@ -32,3 +33,8 @@ class User(Base):
back_populates="user",
cascade="all, delete-orphan",
)
+ uploads: Mapped[list[UploadRecord]] = relationship(
+ "UploadRecord",
+ back_populates="user",
+ cascade="all, delete-orphan",
+ )
diff --git a/app/prompts/planner_compiled.j2 b/app/prompts/planner_compiled.j2
index d57d691..74c9587 100644
--- a/app/prompts/planner_compiled.j2
+++ b/app/prompts/planner_compiled.j2
@@ -10,6 +10,9 @@ Rules:
- Use only these registered relation names: {{ relation_names_json }}
- Use exact column names only. Never invent, normalize, paraphrase, or rename a column.
- Use semantic mappings only to translate user language into exact schema fields.
+- Prefer relations where "is_primary" is true before reaching for child relations.
+- Only join relations when the schema subset exposes an explicit path in "join_keys".
+- For nested JSON child relations, use the exact join columns from "join_keys" instead of guessing parent-child keys.
- Do not assume the question premise is true. First verify the main metric or comparison before planning causal or grouped breakdowns.
- Prefer SQL over multiple trivial splits; combine logic when one query suffices.
- No imports, file I/O, network calls, or plotting.
diff --git a/app/prompts/planner_repair.j2 b/app/prompts/planner_repair.j2
index 8e697d2..ada2780 100644
--- a/app/prompts/planner_repair.j2
+++ b/app/prompts/planner_repair.j2
@@ -8,6 +8,8 @@ Rules:
- updated_step must use "type": "sql", the same id as the failed step ({{ failed_step_id }}), and a fixed "query".
- Use only exact relation and column names from the schema manifest.
- Use semantic mappings only to translate user language into exact schema fields.
+- Prefer relations where "is_primary" is true before using child relations.
+- Only join relations when the schema manifest exposes the join path in "join_keys".
- Do not invent fields or rename columns.
Original plan:
diff --git a/app/schemas.py b/app/schemas.py
index 61c27e5..ba30286 100644
--- a/app/schemas.py
+++ b/app/schemas.py
@@ -12,6 +12,7 @@ class AnalyzeRequest(BaseModel):
"""Incoming analysis request."""
query: str = Field(..., min_length=3, examples=["Why did pipeline velocity drop this week?"])
+ source_ids: list[str] = Field(default_factory=list)
class TraceEvent(BaseModel):
@@ -99,9 +100,23 @@ class SchemaColumn(BaseModel):
name: str = Field(..., min_length=1)
dtype: str = Field(..., min_length=1)
type_family: Literal["string", "number", "boolean", "datetime", "unknown"] = "unknown"
+ original_name: str = ""
+ source_path: str = ""
+ nullable: bool = True
semantic_hints: list[str] = Field(default_factory=list)
+class SchemaJoinKey(BaseModel):
+ """One explicit relation-level join path available to the planner."""
+
+ # Reject payloads that carry fields not declared on this model.
+ model_config = ConfigDict(extra="forbid")
+
+ # All three names are required and must be non-empty.
+ # NOTE(review): presumably source_column belongs to the relation holding this key and
+ # target_column to target_relation — confirm against the manifest builder.
+ target_relation: str = Field(..., min_length=1)
+ source_column: str = Field(..., min_length=1)
+ target_column: str = Field(..., min_length=1)
+ # Restricted to these two values; "explicit" when omitted.
+ link_type: Literal["parent_child", "explicit"] = "explicit"
+
+
class SchemaRelation(BaseModel):
"""One normalized table or view available to the planner."""
@@ -109,12 +124,18 @@ class SchemaRelation(BaseModel):
name: str = Field(..., min_length=1)
kind: Literal["table", "view"] = "view"
+ source_id: str = ""
+ source_name: str = ""
+ is_primary: bool = False
+ parent_relation: str | None = None
row_count: int = 0
grain: str = ""
identifier_columns: list[str] = Field(default_factory=list)
time_columns: list[str] = Field(default_factory=list)
measure_columns: list[str] = Field(default_factory=list)
dimension_columns: list[str] = Field(default_factory=list)
+ join_keys: list[SchemaJoinKey] = Field(default_factory=list)
+ lineage: dict[str, Any] = Field(default_factory=dict)
columns: list[SchemaColumn] = Field(default_factory=list)
semantic_mappings: list[SchemaConceptMapping] = Field(default_factory=list)
@@ -266,6 +287,7 @@ class ChatSubmitRequest(BaseModel):
conversation_id: int | str | None = None
query: str = Field(..., min_length=3)
+ source_ids: list[str] = Field(default_factory=list)
@field_validator("conversation_id", mode="before")
@classmethod
@@ -334,6 +356,8 @@ class UploadedAsset(BaseModel):
status: Literal["uploaded", "profiling", "verified", "error"]
rows: int | None = None
columns: int | None = None
+ relationCount: int | None = None
+ primaryRelationName: str | None = None
summary: str | None = None
diff --git a/app/services/analysis_run.py b/app/services/analysis_run.py
index e779925..8a55000 100644
--- a/app/services/analysis_run.py
+++ b/app/services/analysis_run.py
@@ -25,10 +25,10 @@ class StoredAnalysisRun:
inspection: InspectionData
-def run_stored_analysis(query: str) -> StoredAnalysisRun:
+def run_stored_analysis(query: str, source_ids: list[str] | None = None) -> StoredAnalysisRun:
"""Run `run_analysis`, persist inspection to process memory, return API + inspection objects."""
- state = run_analysis(query)
+ state = run_analysis(query, source_ids=source_ids)
base = AnalyzeResponse(
analysis=state["analysis"],
trace=state.get("trace", []),
diff --git a/app/uploads/__init__.py b/app/uploads/__init__.py
new file mode 100644
index 0000000..dd275ba
--- /dev/null
+++ b/app/uploads/__init__.py
@@ -0,0 +1 @@
+"""Upload persistence services."""
diff --git a/app/uploads/service.py b/app/uploads/service.py
new file mode 100644
index 0000000..7f6c91a
--- /dev/null
+++ b/app/uploads/service.py
@@ -0,0 +1,155 @@
+"""User-owned upload persistence and authorization helpers."""
+
+from __future__ import annotations
+
+from datetime import timezone
+from hashlib import sha256
+from pathlib import Path
+from uuid import uuid4
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from app.data.registry import delete_source, ingest_source
+from app.models.upload import UploadRecord
+from app.models.user import User
+from app.schemas import UploadedAsset
+from app.uploads.storage import LocalUploadBlobStore
+
+
+def _short_source_id() -> str:
+    """Return a new id of the form ``source_`` plus eight hex characters of a random UUID."""
+    token = uuid4().hex[:8]
+    return "source_" + token
+
+
+def _to_uploaded_at(value) -> str:
+    """Format a datetime as a second-precision UTC ISO-8601 string ending in ``Z``.
+
+    A naive ``value`` is treated as already being in UTC; an aware one is
+    converted to UTC. Microseconds are dropped.
+    """
+    aware = value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
+    in_utc = aware.astimezone(timezone.utc)
+    iso_text = in_utc.replace(microsecond=0).isoformat()
+    return iso_text.replace("+00:00", "Z")
+
+
+def _asset_from_record(record: UploadRecord) -> UploadedAsset:
+    """Convert a persisted ``UploadRecord`` into the API-facing ``UploadedAsset``.
+
+    A human-readable summary is produced only when both ``rows`` and ``columns``
+    are known; otherwise ``summary`` stays ``None``. The asset type is the
+    upper-cased filename suffix, or ``"FILE"`` when the name has no suffix.
+    """
+    summary: str | None = None
+    if record.rows is not None and record.columns is not None:
+        relation_count = record.relation_count or 0
+        summary = (
+            f"Persisted {record.rows} rows across {record.columns} columns"
+            f" into {relation_count} relation{'s' if relation_count != 1 else ''}."
+        )
+        # Append the primary relation only to an existing summary: formatting a
+        # ``None`` summary would produce the text "None Primary relation: ...".
+        if record.primary_relation_name:
+            summary = f"{summary} Primary relation: {record.primary_relation_name}."
+
+    return UploadedAsset(
+        id=record.source_id,
+        name=record.original_filename,
+        type=Path(record.original_filename).suffix.lstrip(".").upper() or "FILE",
+        source="Workspace upload",
+        sizeLabel=_bytes_to_size(record.size_bytes),
+        uploadedAt=_to_uploaded_at(record.created_at),
+        status=record.status,  # type: ignore[arg-type]
+        rows=record.rows,
+        columns=record.columns,
+        relationCount=record.relation_count,
+        primaryRelationName=record.primary_relation_name,
+        summary=summary,
+    )
+
+
+def _bytes_to_size(num_bytes: int) -> str:
+    """Render a byte count as a short label in B, KB, MB or GB (1024-based).
+
+    Values below 1024 are shown as whole bytes; larger values use one decimal place.
+    """
+    kib = 1024
+    if num_bytes < kib:
+        return f"{num_bytes} B"
+    # Each unit applies while the value is below the next power of 1024.
+    for power, unit in ((2, "KB"), (3, "MB")):
+        if num_bytes < kib**power:
+            return f"{num_bytes / kib ** (power - 1):.1f} {unit}"
+    return f"{num_bytes / kib ** 3:.1f} GB"
+
+
+def list_user_uploads(db: Session, user: User) -> list[UploadedAsset]:
+    """Return the user's uploads as assets, ordered by ``created_at`` then ``source_id``, both descending."""
+    newest_first = (
+        select(UploadRecord)
+        .where(UploadRecord.user_id == user.id)
+        .order_by(UploadRecord.created_at.desc(), UploadRecord.source_id.desc())
+    )
+    assets: list[UploadedAsset] = []
+    for record in db.execute(newest_first).scalars():
+        assets.append(_asset_from_record(record))
+    return assets
+
+
+def get_authorized_source_ids(db: Session, user: User, source_ids: list[str] | None = None) -> list[str]:
+    """Return the requested source ids that belong to ``user``.
+
+    Duplicates are dropped and the first-seen request order is kept. An empty or
+    missing request returns ``[]`` without querying the database.
+    """
+    # A dict keeps insertion order, so this de-duplicates without reordering.
+    requested = list({source_id: None for source_id in (source_ids or [])})
+    if len(requested) == 0:
+        return []
+
+    ownership_query = select(UploadRecord.source_id).where(
+        UploadRecord.user_id == user.id,
+        UploadRecord.source_id.in_(requested),
+    )
+    owned = set(db.execute(ownership_query).scalars())
+    return [candidate for candidate in requested if candidate in owned]
+
+
+def create_user_upload(
+ db: Session,
+ user: User,
+ *,
+ filename: str,
+ content_type: str | None,
+ content: bytes,
+ blob_store: LocalUploadBlobStore | None = None,
+) -> UploadedAsset:
+ """Store an uploaded file, ingest it, and persist its ``UploadRecord``.
+
+ Steps run in this order: save the raw bytes to the blob store, ingest the
+ content via ``ingest_source``, then insert and commit the database row.
+ Any exception from ingestion or from the commit is re-raised after the
+ earlier steps have been undone.
+ """
+ store = blob_store or LocalUploadBlobStore()
+ source_id = _short_source_id()
+ storage_path = store.save(user_id=user.id, upload_id=source_id, filename=filename, content=content)
+
+ try:
+ asset = ingest_source(filename, content, source_id=source_id)
+ except Exception:
+ # Ingestion failed: remove the blob that was just written, then propagate.
+ store.delete(storage_path)
+ raise
+
+ record = UploadRecord(
+ source_id=source_id,
+ user_id=user.id,
+ original_filename=filename,
+ storage_path=str(storage_path),
+ content_type=content_type or "application/octet-stream",
+ size_bytes=len(content),
+ content_hash=sha256(content).hexdigest(),
+ status=asset.status,
+ rows=asset.rows,
+ columns=asset.columns,
+ relation_count=asset.relationCount,
+ primary_relation_name=asset.primaryRelationName,
+ )
+ db.add(record)
+
+ try:
+ db.commit()
+ except Exception:
+ # Commit failed: roll back, then undo the ingestion and the blob write.
+ db.rollback()
+ delete_source(source_id)
+ store.delete(storage_path)
+ raise
+
+ # Reload the committed row before building the response asset.
+ db.refresh(record)
+ return _asset_from_record(record)
+
+
+def delete_user_upload(
+    db: Session,
+    user: User,
+    source_id: str,
+    *,
+    blob_store: LocalUploadBlobStore | None = None,
+) -> bool:
+    """Delete one of the user's uploads and report whether anything was deleted.
+
+    Returns ``False`` when no record has this ``source_id`` or it belongs to a
+    different user. Otherwise the database row is deleted and committed first,
+    then ``delete_source`` is called, and the stored blob is removed even if
+    ``delete_source`` raises.
+    """
+    record = db.get(UploadRecord, source_id)
+    owned_by_user = record is not None and record.user_id == user.id
+    if not owned_by_user:
+        return False
+
+    store = blob_store or LocalUploadBlobStore()
+    # Read the path before the row is deleted.
+    blob_path = record.storage_path
+    db.delete(record)
+    db.commit()
+    try:
+        delete_source(source_id)
+    finally:
+        store.delete(blob_path)
+    return True
diff --git a/app/uploads/storage.py b/app/uploads/storage.py
new file mode 100644
index 0000000..53adab3
--- /dev/null
+++ b/app/uploads/storage.py
@@ -0,0 +1,34 @@
+"""Blob storage for raw uploaded files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from app.config import get_settings
+
+
+class LocalUploadBlobStore:
+    """Persist raw uploads under a stable per-user filesystem path.
+
+    Files are laid out as ``<root>/<user_id>/<upload_id>/original<suffix>``.
+    """
+
+    def __init__(self, root: Path | None = None) -> None:
+        # Fall back to the configured upload directory when no root is given.
+        self.root = root or get_settings().upload_storage_dir
+
+    def save(self, *, user_id: int, upload_id: str, filename: str, content: bytes) -> Path:
+        """Write ``content`` into the upload's directory and return the file path.
+
+        Only the lower-cased suffix of ``filename`` is used; the stored file is
+        always named ``original<suffix>``.
+        """
+        suffix = Path(filename).suffix.lower()
+        upload_dir = self.root / str(user_id) / upload_id
+        upload_dir.mkdir(parents=True, exist_ok=True)
+        path = upload_dir / f"original{suffix}"
+        path.write_bytes(content)
+        return path
+
+    def delete(self, storage_path: str | Path) -> None:
+        """Remove a stored file and prune the empty directories it leaves behind.
+
+        A missing file is not an error. Pruning walks upwards from the file's
+        directory and stops at the first non-empty directory. It never removes
+        the storage root or anything outside it.
+        """
+        path = Path(storage_path)
+        path.unlink(missing_ok=True)
+
+        # Compare resolved paths on both sides. With a relative or non-normalised
+        # root an unresolved parent chain never equals the resolved root, and the
+        # walk would go on to delete the root and empty ancestors above it.
+        root = self.root.resolve()
+        current = path.parent.resolve()
+        while current != root and root in current.parents and current.exists():
+            if any(current.iterdir()):
+                break
+            current.rmdir()
+            current = current.parent
diff --git a/data/source_registry.duckdb b/data/source_registry.duckdb
new file mode 100644
index 0000000..f9fc572
Binary files /dev/null and b/data/source_registry.duckdb differ
diff --git a/tests/test_api.py b/tests/test_api.py
index c99c639..fbb7da8 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,6 +1,7 @@
"""API response structure tests."""
from io import BytesIO
+import sqlite3
import pytest
from fastapi.testclient import TestClient
@@ -15,6 +16,8 @@
@pytest.fixture
def client(tmp_path, monkeypatch):
monkeypatch.setenv("DATABASE_PATH", str(tmp_path / "api_test.sqlite"))
+ monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "api_source_registry.duckdb"))
+ monkeypatch.setenv("UPLOAD_STORAGE_DIR", str(tmp_path / "uploads"))
get_settings.cache_clear()
reset_engine_and_session()
with TestClient(app) as test_client:
@@ -61,8 +64,14 @@ def test_sample_questions_endpoint(client: TestClient) -> None:
assert len(payload["questions"]) >= 3
+def _signup(client: TestClient, email: str, password: str = "password123") -> str:
+    """Register a user through ``/auth/signup`` and return the issued access token."""
+    credentials = {"email": email, "password": password}
+    signup_response = client.post("/auth/signup", json=credentials)
+    assert signup_response.status_code == 200, signup_response.text
+    return signup_response.json()["access_token"]
+
+
def test_analyze_endpoint_structure(client: TestClient) -> None:
- def fake_run_analysis(query: str) -> dict: # noqa: ARG001
+ def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001
return {
"analysis": "## Summary\nPipeline velocity improved.\n",
"trace": [{"step": "planner_compiled_node", "status": "completed", "details": {"objective": "x"}}],
@@ -95,7 +104,19 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001
original = analysis_run.run_analysis
analysis_run.run_analysis = fake_run_analysis
try:
- response = client.post("/analyze", json={"query": "Why did pipeline velocity drop this week?"})
+ token = _signup(client, "analyze-structure@example.com")
+ headers = {"Authorization": f"Bearer {token}"}
+ upload_response = client.post(
+ "/uploads",
+ files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")},
+ headers=headers,
+ )
+ source_id = upload_response.json()["asset"]["id"]
+ response = client.post(
+ "/analyze",
+ json={"query": "Why did pipeline velocity drop this week?", "source_ids": [source_id]},
+ headers=headers,
+ )
finally:
analysis_run.run_analysis = original
assert response.status_code == 200
@@ -108,7 +129,7 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001
def test_analyze_endpoint_returns_http_500_on_failure(client: TestClient) -> None:
- def fake_run_analysis(query: str) -> dict: # noqa: ARG001
+ def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001
raise RuntimeError("planner exploded")
import app.services.analysis_run as analysis_run
@@ -116,7 +137,19 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001
original = analysis_run.run_analysis
analysis_run.run_analysis = fake_run_analysis
try:
- response = client.post("/analyze", json={"query": "Why did pipeline velocity drop this week?"})
+ token = _signup(client, "analyze-failure@example.com")
+ headers = {"Authorization": f"Bearer {token}"}
+ upload_response = client.post(
+ "/uploads",
+ files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")},
+ headers=headers,
+ )
+ source_id = upload_response.json()["asset"]["id"]
+ response = client.post(
+ "/analyze",
+ json={"query": "Why did pipeline velocity drop this week?", "source_ids": [source_id]},
+ headers=headers,
+ )
finally:
analysis_run.run_analysis = original
@@ -126,10 +159,21 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001
assert payload["detail"]["error"] == "planner exploded"
-def test_upload_endpoint_profiles_csv(client: TestClient) -> None:
+def test_upload_endpoint_requires_auth(client: TestClient) -> None:
+    """Posting to ``/uploads`` without an ``Authorization`` header yields HTTP 401."""
+    csv_payload = BytesIO(b"stage,amount\nopen,10\n")
+    anonymous_upload = client.post("/uploads", files={"file": ("pipeline.csv", csv_payload, "text/csv")})
+
+    assert anonymous_upload.status_code == 401
+
+
+def test_upload_endpoint_profiles_csv(client: TestClient, tmp_path) -> None:
+ token = _signup(client, "upload-csv@example.com")
response = client.post(
"/uploads",
files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")},
+ headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
@@ -140,10 +184,134 @@ def test_upload_endpoint_profiles_csv(client: TestClient) -> None:
assert payload["asset"]["rows"] == 2
assert payload["asset"]["columns"] == 2
assert payload["asset"]["status"] == "verified"
+ assert payload["asset"]["relationCount"] == 1
+ assert payload["asset"]["primaryRelationName"]
+ assert any((tmp_path / "uploads").rglob("original.csv"))
+
+ uploads_response = client.get("/uploads", headers={"Authorization": f"Bearer {token}"})
+ assert uploads_response.status_code == 200
+ assert [asset["id"] for asset in uploads_response.json()] == [payload["asset"]["id"]]
+
+
+def test_upload_endpoint_profiles_json(client: TestClient) -> None:
+ """A nested JSON upload is accepted and reported with one row and two relations."""
+ token = _signup(client, "upload-json@example.com")
+ # Payload: a single order containing a nested object and a nested list of items.
+ response = client.post(
+ "/uploads",
+ files={
+ "file": (
+ "orders.json",
+ BytesIO(
+ b'[{"order_id":"o1","customer":{"name":"Ada"},"items":[{"sku":"A1","qty":2}]}]'
+ ),
+ "application/json",
+ )
+ },
+ headers={"Authorization": f"Bearer {token}"},
+ )
+
+ assert response.status_code == 200
+ payload = response.json()
+ assert payload["asset"]["type"] == "JSON"
+ assert payload["asset"]["rows"] == 1
+ assert payload["asset"]["relationCount"] == 2
+ assert payload["asset"]["primaryRelationName"]
+
+
+def test_upload_endpoint_rejects_unsupported_file_types(client: TestClient) -> None:
+    """A ``.tsv`` upload is refused with HTTP 400 and the supported-formats message."""
+    token = _signup(client, "upload-unsupported@example.com")
+    tsv_file = ("pipeline.tsv", BytesIO(b"stage\tamount\nopen\t10\n"), "text/tab-separated-values")
+    rejected = client.post(
+        "/uploads",
+        files={"file": tsv_file},
+        headers={"Authorization": f"Bearer {token}"},
+    )
+
+    assert rejected.status_code == 400
+    detail = rejected.json()["detail"]
+    assert detail["message"] == "Only CSV and JSON uploads are currently supported."
+
+
+def test_upload_endpoint_migrates_legacy_uploads_table(tmp_path, monkeypatch) -> None:
+ """Starting the app on a legacy ``uploads`` table migrates it and keeps the old row."""
+ db_path = tmp_path / "legacy_uploads.sqlite"
+ registry_path = tmp_path / "legacy_registry.duckdb"
+ upload_dir = tmp_path / "uploads"
+
+ # Build a database whose uploads table has the old shape (integer id key, extra columns).
+ conn = sqlite3.connect(db_path)
+ conn.execute(
+ """
+ CREATE TABLE uploads (
+ id INTEGER NOT NULL PRIMARY KEY,
+ upload_id VARCHAR(64) NOT NULL,
+ user_id INTEGER NOT NULL,
+ source_id VARCHAR(64) NOT NULL,
+ original_filename VARCHAR(255) NOT NULL,
+ file_type VARCHAR(32) NOT NULL,
+ storage_path VARCHAR(1024) NOT NULL,
+ content_type VARCHAR(255),
+ size_bytes BIGINT NOT NULL,
+ content_hash VARCHAR(64) NOT NULL,
+ status VARCHAR(32) NOT NULL,
+ rows INTEGER,
+ columns INTEGER,
+ relation_count INTEGER,
+ primary_relation_name VARCHAR(255),
+ summary TEXT,
+ created_at DATETIME NOT NULL,
+ updated_at DATETIME NOT NULL
+ )
+ """
+ )
+ conn.execute(
+ """
+ INSERT INTO uploads (
+ id, upload_id, user_id, source_id, original_filename, file_type, storage_path, content_type,
+ size_bytes, content_hash, status, rows, columns, relation_count, primary_relation_name, summary,
+ created_at, updated_at
+ ) VALUES (
+ 1, 'legacy_upload_1', 999, 'legacy_source_1', 'legacy.csv', 'CSV', '/tmp/legacy.csv', 'text/csv',
+ 12, 'abc123', 'verified', 1, 2, 1, 'legacy_relation', 'legacy summary',
+ '2026-04-01 12:00:00', '2026-04-01 12:00:00'
+ )
+ """
+ )
+ conn.commit()
+ conn.close()
+
+ # Point the app at the prepared files and drop cached settings and engine.
+ monkeypatch.setenv("DATABASE_PATH", str(db_path))
+ monkeypatch.setenv("REGISTRY_PATH", str(registry_path))
+ monkeypatch.setenv("UPLOAD_STORAGE_DIR", str(upload_dir))
+ get_settings.cache_clear()
+ reset_engine_and_session()
+
+ try:
+ with TestClient(app) as legacy_client:
+ token = _signup(legacy_client, "legacy-migration@example.com")
+ response = legacy_client.post(
+ "/uploads",
+ files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\n"), "text/csv")},
+ headers={"Authorization": f"Bearer {token}"},
+ )
+
+ assert response.status_code == 200, response.text
+
+ # The legacy row belongs to user 999, so the new user sees only the fresh upload.
+ uploads_response = legacy_client.get("/uploads", headers={"Authorization": f"Bearer {token}"})
+ assert uploads_response.status_code == 200
+ assert len(uploads_response.json()) == 1
+ finally:
+ get_settings.cache_clear()
+ reset_engine_and_session()
+
+ # Inspect the SQLite file directly to verify the migrated shape and contents.
+ migrated_conn = sqlite3.connect(db_path)
+ columns = [row[1] for row in migrated_conn.execute("PRAGMA table_info(uploads)").fetchall()]
+ source_ids = [row[0] for row in migrated_conn.execute("SELECT source_id FROM uploads ORDER BY source_id").fetchall()]
+ migrated_conn.close()
+
+ assert "upload_id" not in columns
+ assert "file_type" not in columns
+ assert "source_id" in columns
+ assert source_ids == ["legacy_source_1", uploads_response.json()[0]["id"]]
def test_inspection_endpoint_returns_stored_inspection(client: TestClient) -> None:
- def fake_run_analysis(query: str) -> dict: # noqa: ARG001
+ def fake_run_analysis(query: str, source_ids=None) -> dict: # noqa: ARG001
return {
"analysis": "## Summary\nPipeline velocity improved.\n",
"trace": [{"step": "planner_compiled_node", "status": "completed", "details": {"objective": "x"}}],
@@ -175,7 +343,19 @@ def fake_run_analysis(query: str) -> dict: # noqa: ARG001
original = analysis_run.run_analysis
analysis_run.run_analysis = fake_run_analysis
try:
- analyze_response = client.post("/analyze", json={"query": "Why did pipeline velocity drop this week?"})
+ token = _signup(client, "inspection@example.com")
+ headers = {"Authorization": f"Bearer {token}"}
+ upload_response = client.post(
+ "/uploads",
+ files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")},
+ headers=headers,
+ )
+ source_id = upload_response.json()["asset"]["id"]
+ analyze_response = client.post(
+ "/analyze",
+ json={"query": "Why did pipeline velocity drop this week?", "source_ids": [source_id]},
+ headers=headers,
+ )
finally:
analysis_run.run_analysis = original
@@ -194,3 +374,131 @@ def test_inspection_endpoint_returns_404_for_unknown_id(client: TestClient) -> N
response = client.get("/inspections/inspect_missing")
assert response.status_code == 404
assert response.json()["detail"]["message"] == "Inspection not found."
+
+
+def test_analyze_endpoint_forwards_source_ids(client: TestClient) -> None:
+ """``/analyze`` passes the request's query and source ids on to ``run_analysis``."""
+ captured: dict[str, object] = {}
+
+ def fake_run_analysis(query: str, source_ids=None) -> dict:
+ # Record the arguments so they can be asserted after the request.
+ captured["query"] = query
+ captured["source_ids"] = source_ids
+ return {
+ "analysis": "## Summary\nScoped analysis.\n",
+ "trace": [],
+ "executed_steps": [],
+ "errors": [],
+ }
+
+ import app.services.analysis_run as analysis_run
+
+ # Swap in the fake for the duration of the requests; always restore the original.
+ original = analysis_run.run_analysis
+ analysis_run.run_analysis = fake_run_analysis
+ try:
+ token = _signup(client, "analyze-forward@example.com")
+ headers = {"Authorization": f"Bearer {token}"}
+ upload_response = client.post(
+ "/uploads",
+ files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")},
+ headers=headers,
+ )
+ source_id = upload_response.json()["asset"]["id"]
+ response = client.post("/analyze", json={"query": "Use this upload", "source_ids": [source_id]}, headers=headers)
+ finally:
+ analysis_run.run_analysis = original
+
+ assert response.status_code == 200
+ assert captured["query"] == "Use this upload"
+ assert captured["source_ids"] == [source_id]
+
+
+def test_analyze_endpoint_requires_uploaded_source_before_run_analysis(client: TestClient) -> None:
+ """A request without source ids gets HTTP 400 and ``run_analysis`` is never called."""
+ import app.services.analysis_run as analysis_run
+
+ called = False
+ original = analysis_run.run_analysis
+
+ def fake_run_analysis(query: str, source_ids=None) -> dict:
+ # Flip the flag if the endpoint reaches the analysis step.
+ nonlocal called
+ called = True
+ return {"analysis": "", "trace": [], "executed_steps": [], "errors": []}
+
+ analysis_run.run_analysis = fake_run_analysis
+ try:
+ token = _signup(client, "analyze-requires-upload@example.com")
+ response = client.post(
+ "/analyze",
+ json={"query": "Why did pipeline velocity drop this week?"},
+ headers={"Authorization": f"Bearer {token}"},
+ )
+ finally:
+ analysis_run.run_analysis = original
+
+ assert response.status_code == 400
+ assert response.json()["detail"]["message"] == "Upload and attach at least one CSV or JSON data source before running analysis."
+ assert called is False
+
+
+def test_uploads_and_analysis_are_scoped_to_the_signed_in_user(client: TestClient) -> None:
+ """Uploads are listed only for their owner, and ``/analyze`` rejects another user's source id."""
+ owner = _signup(client, "owner-uploads@example.com")
+ intruder = _signup(client, "intruder-uploads@example.com")
+ owner_headers = {"Authorization": f"Bearer {owner}"}
+ intruder_headers = {"Authorization": f"Bearer {intruder}"}
+
+ upload_response = client.post(
+ "/uploads",
+ files={"file": ("pipeline.csv", BytesIO(b"stage,amount\nopen,10\nwon,25\n"), "text/csv")},
+ headers=owner_headers,
+ )
+ source_id = upload_response.json()["asset"]["id"]
+
+ owner_list = client.get("/uploads", headers=owner_headers)
+ intruder_list = client.get("/uploads", headers=intruder_headers)
+
+ assert [asset["id"] for asset in owner_list.json()] == [source_id]
+ assert intruder_list.json() == []
+
+ import app.services.analysis_run as analysis_run
+
+ called = False
+ original = analysis_run.run_analysis
+
+ def fake_run_analysis(query: str, source_ids=None) -> dict:
+ # Flip the flag if the endpoint reaches the analysis step.
+ nonlocal called
+ called = True
+ return {"analysis": "", "trace": [], "executed_steps": [], "errors": []}
+
+ analysis_run.run_analysis = fake_run_analysis
+ try:
+ # The intruder submits the owner's source id.
+ forbidden = client.post(
+ "/analyze",
+ json={"query": "Use someone else's upload", "source_ids": [source_id]},
+ headers=intruder_headers,
+ )
+ finally:
+ analysis_run.run_analysis = original
+
+ assert forbidden.status_code == 400
+ assert forbidden.json()["detail"]["message"] == "Attach a valid uploaded data source before running analysis."
+ assert called is False
+
+
+def test_delete_upload_removes_file_and_listing(client: TestClient, tmp_path) -> None:
+ """Only the owner can delete an upload; deleting removes the listing entry and the stored file."""
+ owner = _signup(client, "delete-owner@example.com")
+ intruder = _signup(client, "delete-intruder@example.com")
+ owner_headers = {"Authorization": f"Bearer {owner}"}
+ intruder_headers = {"Authorization": f"Bearer {intruder}"}
+
+ upload_response = client.post(
+ "/uploads",
+ files={"file": ("orders.json", BytesIO(b'[{"order_id":"o1"}]'), "application/json")},
+ headers=owner_headers,
+ )
+ source_id = upload_response.json()["asset"]["id"]
+
+ # A different user gets 404 rather than a hint that the upload exists.
+ forbidden = client.delete(f"/uploads/{source_id}", headers=intruder_headers)
+ assert forbidden.status_code == 404
+
+ delete_response = client.delete(f"/uploads/{source_id}", headers=owner_headers)
+ assert delete_response.status_code == 204
+ assert client.get("/uploads", headers=owner_headers).json() == []
+ assert not any((tmp_path / "uploads").rglob("original.json"))
diff --git a/tests/test_conversations.py b/tests/test_conversations.py
index d731f9d..e1dfa7b 100644
--- a/tests/test_conversations.py
+++ b/tests/test_conversations.py
@@ -13,6 +13,8 @@
@pytest.fixture
def chat_client(tmp_path, monkeypatch):
monkeypatch.setenv("DATABASE_PATH", str(tmp_path / "chat_test.sqlite"))
+ monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "chat_source_registry.duckdb"))
+ monkeypatch.setenv("UPLOAD_STORAGE_DIR", str(tmp_path / "uploads"))
get_settings.cache_clear()
reset_engine_and_session()
with TestClient(app) as client:
@@ -27,7 +29,7 @@ def _signup(client: TestClient, email: str, password: str = "password123") -> st
return r.json()["access_token"]
-def _fake_analysis_state(query: str) -> dict: # noqa: ARG001
+def _fake_analysis_state(query: str, source_ids=None) -> dict: # noqa: ARG001
return {
"analysis": "## Demo\nHello from fake analysis.\n",
"trace": [{"step": "planner_compiled_node", "status": "completed", "details": {}}],
@@ -228,3 +230,41 @@ def test_persisted_inspection_forbidden_for_other_user(chat_client: TestClient)
forbidden = chat_client.get(f"/inspections/{inspection_id}", headers={"Authorization": f"Bearer {intruder}"})
assert forbidden.status_code == 403
+
+
+def test_chat_rejects_uploads_owned_by_another_user(chat_client: TestClient) -> None:  # POST /chat with another user's upload id must return 400 and never call the stubbed run_analysis.
+ import app.services.analysis_run as analysis_run
+
+ owner = _signup(chat_client, "owner-upload-chat@example.com")
+ intruder = _signup(chat_client, "intruder-upload-chat@example.com")
+ owner_headers = {"Authorization": f"Bearer {owner}"}
+ intruder_headers = {"Authorization": f"Bearer {intruder}"}
+
+ upload_response = chat_client.post(
+ "/uploads",
+ headers=owner_headers,
+ files={"file": ("pipeline.csv", b"stage,amount\nopen,10\nwon,25\n", "text/csv")},
+ )
+ source_id = upload_response.json()["asset"]["id"]
+
+ called = False  # flipped to True by the stub below if it is ever invoked
+ original = analysis_run.run_analysis  # kept so the real function can be restored in the finally block
+
+ def fake_run_analysis(query: str, source_ids=None) -> dict:
+ nonlocal called
+ called = True
+ return {"analysis": "", "trace": [], "executed_steps": [], "errors": []}
+
+ analysis_run.run_analysis = fake_run_analysis  # manual module-attribute patch (no monkeypatch fixture is requested by this test)
+ try:
+ response = chat_client.post(
+ "/chat",
+ headers=intruder_headers,
+ json={"query": "Use another user's upload", "source_ids": [source_id]},
+ )
+ finally:
+ analysis_run.run_analysis = original
+
+ assert response.status_code == 400
+ assert response.json()["detail"]["message"] == "Attach a valid uploaded data source before running analysis."
+ assert called is False  # the stub was never invoked for the intruder's request
diff --git a/tests/test_planner_schema.py b/tests/test_planner_schema.py
index a82500e..3659289 100644
--- a/tests/test_planner_schema.py
+++ b/tests/test_planner_schema.py
@@ -1,11 +1,27 @@
"""Compiled plan schema and planner retry behavior."""
+import pytest
+
from app.agent.planner import plan_compiled_query
from app.agent.state import create_initial_state
-from app.data.semantic_model import get_semantic_context
+from app.config import get_settings
+from app.data.registry import clear_source_registry, ingest_source
+from app.data.semantic_model import clear_semantic_context_cache, get_semantic_context
from app.schemas import CompiledPlan
+@pytest.fixture(autouse=True)
+def isolated_registry(tmp_path, monkeypatch):  # Autouse: sets REGISTRY_PATH to a file under tmp_path and calls the reset helpers before and after every test in this module.
+ monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "planner_source_registry.duckdb"))
+ get_settings.cache_clear()  # presumably so settings are re-read with the REGISTRY_PATH set above — confirm get_settings caches on env
+ clear_source_registry()
+ clear_semantic_context_cache()
+ yield  # the test body runs here; the same resets are repeated afterwards
+ clear_source_registry()
+ clear_semantic_context_cache()
+ get_settings.cache_clear()
+
+
def test_compiled_plan_normalizes_max_steps() -> None:
plan = CompiledPlan.model_validate(
{
@@ -64,7 +80,7 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002
monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: stub)
state = create_initial_state("Compare segments")
- state["dataset_context"] = {"reference_date": "2017-12-31", "views": [{"name": "opportunities_enriched"}]}
+ state["dataset_context"] = {"reference_date": "2017-12-31", "relations": [], "views": []}
state = plan_compiled_query(state)
assert stub.calls == 2
@@ -73,17 +89,20 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002
def test_semantic_context_exposes_normalized_manifest() -> None:
+ asset = ingest_source("pipeline.csv", b"owner,pipeline_velocity_days\nAda,10\nBen,14\n")
manifest = get_semantic_context().schema_manifest
assert manifest["dialect"] == "duckdb"
assert manifest["relations"]
- relation = next(relation for relation in manifest["relations"] if relation["name"] == "opportunities_enriched")
+ relation = next(relation for relation in manifest["relations"] if relation["name"] == asset.primaryRelationName)
assert relation["columns"]
assert relation["identifier_columns"]
assert relation["grain"]
def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None:
+ asset = ingest_source("pipeline.csv", b"owner,pipeline_velocity_days\nAda,10\nBen,14\n")
+ relation_name = asset.primaryRelationName
bad = {
"objective": "Analyze by sales agent",
"plan": [
@@ -91,7 +110,7 @@ def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None:
"id": 1,
"purpose": "Break out velocity by agent",
"type": "sql",
- "query": "SELECT sales_agent, AVG(pipeline_velocity_days) AS avg_velocity_days FROM opportunities_enriched GROUP BY sales_agent",
+ "query": f"SELECT sales_agent, AVG(pipeline_velocity_days) AS avg_velocity_days FROM {relation_name} GROUP BY sales_agent",
"output_alias": "velocity_by_agent",
}
],
@@ -106,7 +125,7 @@ def test_planner_retries_after_sql_preflight_failure(monkeypatch) -> None:
"id": 1,
"purpose": "Break out velocity by owner",
"type": "sql",
- "query": "SELECT owner, AVG(pipeline_velocity_days) AS avg_velocity_days FROM opportunities_enriched GROUP BY owner",
+ "query": f"SELECT owner, AVG(pipeline_velocity_days) AS avg_velocity_days FROM {relation_name} GROUP BY owner",
"output_alias": "velocity_by_owner",
}
],
@@ -131,7 +150,7 @@ def generate_json(self, prompt: str, schema=None): # noqa: ANN001, ARG002
monkeypatch.setattr("app.agent.planner.get_llm_client", lambda: stub)
state = create_initial_state("Why did pipeline velocity drop this week by sales agent?")
- state["dataset_context"] = get_semantic_context().schema_manifest
+ state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest
state = plan_compiled_query(state)
assert stub.calls == 2
diff --git a/tests/test_source_registry.py b/tests/test_source_registry.py
new file mode 100644
index 0000000..12de3f6
--- /dev/null
+++ b/tests/test_source_registry.py
@@ -0,0 +1,100 @@
+"""Tests for the persistent source registry and registry-backed manifests."""
+
+from __future__ import annotations
+
+import pytest
+
+from app.agent.planner import _schema_subset_for_question
+from app.config import get_settings
+from app.data.registry import clear_source_registry, create_source_link, get_schema_manifest, ingest_source
+from app.data.semantic_model import clear_semantic_context_cache, get_semantic_context
+
+
+@pytest.fixture(autouse=True)
+def isolated_registry(tmp_path, monkeypatch):  # Autouse: sets REGISTRY_PATH to a file under tmp_path and calls the reset helpers before and after every test in this module.
+ monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "source_registry.duckdb"))
+ get_settings.cache_clear()  # presumably so settings are re-read with the REGISTRY_PATH set above — confirm get_settings caches on env
+ clear_source_registry()
+ clear_semantic_context_cache()
+ yield  # the test body runs here; the same resets are repeated afterwards
+ clear_source_registry()
+ clear_semantic_context_cache()
+ get_settings.cache_clear()
+
+
+def _relation_by_name(manifest: dict, relation_name: str) -> dict:  # Return the first entry of manifest["relations"] whose "name" equals relation_name.
+ return next(relation for relation in manifest["relations"] if relation["name"] == relation_name)  # next() has no default, so a missing name raises StopIteration
+
+
+def test_registry_manifest_starts_empty_without_builtin_sources() -> None:  # With nothing ingested in this test, the manifest must list no relations.
+ manifest = get_schema_manifest()
+
+ assert manifest["dialect"] == "duckdb"
+ assert manifest["relations"] == []
+
+
+def test_nested_json_upload_creates_primary_and_child_relations() -> None:  # A JSON record with a nested object and a nested array must yield a primary relation plus one child relation.
+ asset = ingest_source(
+ "orders.json",
+ b'[{"order_id":"o1","customer":{"name":"Ada"},"items":[{"sku":"A1","qty":2},{"sku":"B2","qty":1}]}]',
+ )
+
+ manifest = get_schema_manifest([asset.id])
+ root = _relation_by_name(manifest, asset.primaryRelationName)
+ child = next(relation for relation in manifest["relations"] if relation["name"].startswith(f"{asset.primaryRelationName}__"))  # the child is located by the "<primary>__" name prefix
+
+ assert asset.relationCount == 2
+ assert root["is_primary"] is True
+ assert any(column["name"] == "customer__name" and column["source_path"] == "customer.name" for column in root["columns"])  # nested object field: "__"-joined column name, dotted source_path
+ assert child["parent_relation"] == asset.primaryRelationName
+ assert any(join["target_relation"] == asset.primaryRelationName for join in child["join_keys"])
+ assert {column["name"] for column in child["columns"]} >= {"record_id", "parent_record_id", "ordinal", "sku", "qty"}  # superset check: extra child columns are allowed
+
+
+def test_unrelated_sources_do_not_auto_link_even_with_shared_column_names() -> None:  # Two uploads that both contain a customer_id column must not receive join keys automatically.
+ customers = ingest_source("customers.csv", b"customer_id,name\nc1,Ada\nc2,Ben\n")
+ orders = ingest_source("orders.csv", b"customer_id,amount\nc1,100\nc2,50\n")
+
+ manifest = get_schema_manifest([customers.id, orders.id])
+ customers_relation = _relation_by_name(manifest, customers.primaryRelationName)
+ orders_relation = _relation_by_name(manifest, orders.primaryRelationName)
+
+ assert customers_relation["join_keys"] == []
+ assert orders_relation["join_keys"] == []
+
+
+def test_explicit_source_links_appear_in_manifest() -> None:  # After create_source_link, the join must be listed in the join_keys of both linked relations.
+ customers = ingest_source("customers.csv", b"customer_id,name\nc1,Ada\nc2,Ben\n")
+ orders = ingest_source("orders.csv", b"customer_id,amount\nc1,100\nc2,50\n")
+
+ create_source_link(customers.primaryRelationName, "customer_id", orders.primaryRelationName, "customer_id")
+ manifest = get_schema_manifest([customers.id, orders.id])
+ customers_relation = _relation_by_name(manifest, customers.primaryRelationName)
+ orders_relation = _relation_by_name(manifest, orders.primaryRelationName)
+
+ assert any(join["target_relation"] == orders.primaryRelationName and join["source_column"] == "customer_id" for join in customers_relation["join_keys"])  # customers -> orders direction
+ assert any(join["target_relation"] == customers.primaryRelationName and join["source_column"] == "customer_id" for join in orders_relation["join_keys"])  # orders -> customers direction
+
+
+def test_semantic_context_scopes_to_selected_sources() -> None:  # get_semantic_context() without ids covers both uploads; passing one id must exclude the other upload's relation.
+ inventory = ingest_source("inventory.csv", b"sku,stock\nA1,10\nB2,3\n")
+ invoices = ingest_source("invoices.csv", b"invoice_id,amount\ni1,100\ni2,250\n")
+
+ full_context = get_semantic_context().schema_manifest
+ scoped_context = get_semantic_context([inventory.id]).schema_manifest
+
+ assert any(relation["name"] == inventory.primaryRelationName for relation in full_context["relations"])
+ assert any(relation["name"] == invoices.primaryRelationName for relation in full_context["relations"])
+ assert not any(relation["name"] == invoices.primaryRelationName for relation in scoped_context["relations"])  # the unselected source must not appear in the scoped manifest
+ assert any(relation["name"] == inventory.primaryRelationName for relation in scoped_context["relations"])
+
+
+def test_schema_subset_prefers_relevant_uploaded_primary_relation() -> None:  # With three uploads ingested, the subset chosen for an invoice question must include the invoices relation.
+ ingest_source("inventory.csv", b"sku,stock\nA1,10\nB2,3\n")
+ invoices = ingest_source("invoices.csv", b"invoice_id,invoice_amount,status\ni1,100,paid\ni2,250,due\n")
+ ingest_source("campaigns.csv", b"campaign,spend\nspring,1200\nsummer,950\n")
+
+ manifest = get_schema_manifest()
+ subset = _schema_subset_for_question(manifest, "Which invoice amount is highest?")
+
+ assert any(relation["name"] == invoices.primaryRelationName for relation in subset["relations"])  # only inclusion is checked; the other relations may or may not be present
diff --git a/tests/test_tools.py b/tests/test_tools.py
index d2fa90f..9811fec 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -1,13 +1,30 @@
"""Executor tests for compiled SQL plans."""
+import pytest
+
from app.agent.executor import execute_plan, execute_single_plan_step
from app.agent.state import create_initial_state
-from app.data.semantic_model import get_semantic_context
+from app.config import get_settings
+from app.data.registry import clear_source_registry, ingest_source
+from app.data.semantic_model import clear_semantic_context_cache, get_semantic_context
+
+
+@pytest.fixture(autouse=True)
+def isolated_registry(tmp_path, monkeypatch):  # Autouse: sets REGISTRY_PATH to a file under tmp_path and calls the reset helpers before and after every test in this module.
+ monkeypatch.setenv("REGISTRY_PATH", str(tmp_path / "tools_source_registry.duckdb"))
+ get_settings.cache_clear()  # presumably so settings are re-read with the REGISTRY_PATH set above — confirm get_settings caches on env
+ clear_source_registry()
+ clear_semantic_context_cache()
+ yield  # the test body runs here; the same resets are repeated afterwards
+ clear_source_registry()
+ clear_semantic_context_cache()
+ get_settings.cache_clear()
def test_execute_sql_step_returns_artifact_summary() -> None:
+ asset = ingest_source("pipeline.csv", b"segment\nSMB\nEnterprise\nSMB\n")
state = create_initial_state("Compare SMB vs Enterprise performance")
- state["dataset_context"] = get_semantic_context().schema_manifest
+ state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest
compiled_plan = {
"objective": "Segment counts",
"max_steps": 3,
@@ -16,7 +33,7 @@ def test_execute_sql_step_returns_artifact_summary() -> None:
"id": 1,
"purpose": "Get one sample aggregation.",
"type": "sql",
- "query": "SELECT segment, COUNT(*) AS deals FROM opportunities_enriched GROUP BY segment ORDER BY deals DESC",
+ "query": f"SELECT segment, COUNT(*) AS deals FROM {asset.primaryRelationName} GROUP BY segment ORDER BY deals DESC",
"output_alias": "segment_counts",
}
],
@@ -30,8 +47,9 @@ def test_execute_sql_step_returns_artifact_summary() -> None:
def test_execute_plan_marks_empty_table_as_failed() -> None:
+ asset = ingest_source("pipeline.csv", b"segment\nSMB\nEnterprise\n")
state = create_initial_state("Why did pipeline velocity drop this week?")
- state["dataset_context"] = get_semantic_context().schema_manifest
+ state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest
compiled_plan = {
"objective": "Empty query",
"max_steps": 3,
@@ -51,8 +69,9 @@ def test_execute_plan_marks_empty_table_as_failed() -> None:
def test_execute_single_plan_step_retry() -> None:
+ asset = ingest_source("pipeline.csv", b"segment\nSMB\nEnterprise\n")
state = create_initial_state("Test retry")
- state["dataset_context"] = get_semantic_context().schema_manifest
+ state["dataset_context"] = get_semantic_context([asset.id]).schema_manifest
step = {
"id": 1,
"purpose": "Get one row.",
diff --git a/ui/src/api/chat.ts b/ui/src/api/chat.ts
index 15daccf..bb8b7eb 100644
--- a/ui/src/api/chat.ts
+++ b/ui/src/api/chat.ts
@@ -9,6 +9,7 @@ import { requestWithAuth, ApiError } from "@/api/client";
import { cacheInspection } from "@/api/inspections";
import { conversationTitleFromPrompt, mapAnalyzeResponseToUi } from "@/api/mappers";
import type {
+ AnalyzeApiRequest,
AnalyzeApiResponse,
ApiChatTurnResponse,
ApiConversationDetailResponse,
@@ -143,10 +144,13 @@ export async function submitChatPrompt(payload: ChatRequest, accessToken: string
}
try {
- const body: Record
{asset.summary}
: null} {(asset.rows || asset.columns) && (Primary relation: {asset.primaryRelationName}
+ ) : null} ); } diff --git a/ui/src/components/app/UploadsPanel.test.tsx b/ui/src/components/app/UploadsPanel.test.tsx new file mode 100644 index 0000000..c699a95 --- /dev/null +++ b/ui/src/components/app/UploadsPanel.test.tsx @@ -0,0 +1,77 @@ +import { cleanup, fireEvent, render, screen } from "@testing-library/react"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { UploadsPanel } from "@/components/app/UploadsPanel"; + +afterEach(() => { + cleanup(); +}); + +describe("UploadsPanel", () => { + it("renders a file picker restricted to csv and json", () => { + render( +Upload datasets
+Add a CSV or JSON file to profile it and make it available for analysis.
+