From 7800c41c85dd732eccb2a5bf8716cf7f375ff6ac Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 16:07:06 -0400 Subject: [PATCH 01/22] updated env template to include synthscholar env variables --- env.template | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/env.template b/env.template index 28be2e2..e0a4e3b 100644 --- a/env.template +++ b/env.template @@ -194,6 +194,55 @@ OLLAMA_API_ENDPOINT=http://host.docker.internal:11434 OLLAMA_MODEL=nomic-embed-text OLLAMA_PORT=11434 +# ---------------------------------------------------------------------------- +# ML Service - SynthScholar (PRISMA literature review) +# ---------------------------------------------------------------------------- +# Mounts under /api/synth-scholar/* in ml_service. End-to-end PRISMA-guided +# pipeline (search strategy → article fetch → title/abstract + full-text +# screening → data charting → critical appraisal → narrative synthesis → +# GRADE), driven by the `synthscholar` Python library and orchestrated by +# the local module at ml_service/core/synth_scholar/. SSE progress, plan- +# confirmation gating, exports (markdown / JSON / BibTeX / Turtle / JSON-LD), +# per-user ownership, and shared/cached corpus search are all wired up. +# +# Database: SynthScholar reuses the unified `brainkb` Postgres database +# (the same one JWT auth and usermanagement_service write to). Its tables +# (`reviews`, plus synthscholar's bundled `article_store` / `review_cache` / +# `pipeline_checkpoints`) coexist with the existing `Web_*` tables — names +# are snake_case here vs PascalCase there, so there are no collisions. No +# separate DSN is needed; the JWT_POSTGRES_DATABASE_* values above are +# reused at startup. Optional: install pgvector in the postgres image to +# enable semantic article/review search; without it, only keyword/title +# search works (synthscholar's bundled migration 004 logs and skips). +# +# Auth: every endpoint sits behind ml_service's existing JWT (clients call +# /api/token with NEXT_PUBLIC_JWT_USER / _PASSWORD on the UI side). The +# UI also forwards a per-user OpenRouter key on every create-review call +# (resolved from the user's personal sessionStorage key, or the admin- +# shared key set in /admin/settings/openrouter-key). The OPENROUTER_API_KEY +# below is only the operator fallback used when neither is configured. + +# OpenRouter API key — operator fallback only. The UI normally forwards a +# per-user or admin-shared key on every request, so this env value is used +# only when neither is configured. Safe to leave blank in production. +OPENROUTER_API_KEY= + +# Optional. NCBI E-utilities API key — recommended in production to raise +# PubMed rate limits from 3 to 10 requests/second. +# https://www.ncbi.nlm.nih.gov/account/settings/ +NCBI_API_KEY= + +# Optional. Semantic Scholar API key — adds Semantic Scholar to the search +# fan-out. Silently skipped at runtime if absent. +SEMANTIC_SCHOLAR_API_KEY= + +# Optional. CORE API key — adds the CORE provider. Silently skipped if absent. +CORE_API_KEY= + +# Optional. Contact email passed to OpenAlex / Crossref / Europe PMC for +# polite-pool rate limits. Silently skipped if absent. 
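+# e.g. SYNTHSCHOLAR_EMAIL=you@example.org  (placeholder address — use a
+# contact the upstream APIs can actually reach you at).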
+SYNTHSCHOLAR_EMAIL= + From ff7a7b3489c3625925c2dc3627456e1b8de2cd29 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 16:07:31 -0400 Subject: [PATCH 02/22] synthscholar added --- ml_service/requirements.txt | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ml_service/requirements.txt b/ml_service/requirements.txt index b1e3ea8..8b4d68e 100644 --- a/ml_service/requirements.txt +++ b/ml_service/requirements.txt @@ -45,4 +45,11 @@ grpcio-tools==1.60.2 grpcio-health-checking==1.60.2 PyMuPDF==1.26.5 -ollama==0.6.0 \ No newline at end of file +ollama==0.6.0 + +# SynthScholar (PRISMA literature review). The `synthscholar` package owns +# the AI pipeline; the local module under core/synth_scholar/ is just the +# orchestration layer (sessions, SSE, exports). SQLAlchemy 2.0 async backs +# the review tables, separate from ml_service's raw-asyncpg pool. +synthscholar[fulltext,semantic]==0.0.6 +sqlalchemy[asyncio]>=2.0.30 \ No newline at end of file From 4b35828249996a57d86aed808a1f0844ec7b3fbb Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 16:28:21 -0400 Subject: [PATCH 03/22] synthscholar integrated --- ml_service/core/main.py | 43 +- ml_service/core/synth_scholar/__init__.py | 10 + ml_service/core/synth_scholar/database.py | 122 ++ ml_service/core/synth_scholar/db_models.py | 65 + .../core/synth_scholar/progress_events.py | 240 +++ ml_service/core/synth_scholar/routes.py | 1601 +++++++++++++++++ ml_service/core/synth_scholar/schemas.py | 368 ++++ ml_service/core/synth_scholar/store.py | 715 ++++++++ 8 files changed, 3163 insertions(+), 1 deletion(-) create mode 100644 ml_service/core/synth_scholar/__init__.py create mode 100644 ml_service/core/synth_scholar/database.py create mode 100644 ml_service/core/synth_scholar/db_models.py create mode 100644 ml_service/core/synth_scholar/progress_events.py create mode 100644 ml_service/core/synth_scholar/routes.py create mode 100644 ml_service/core/synth_scholar/schemas.py create mode 100644 ml_service/core/synth_scholar/store.py diff --git a/ml_service/core/main.py b/ml_service/core/main.py index 4984652..c7d7ec6 100644 --- a/ml_service/core/main.py +++ b/ml_service/core/main.py @@ -13,12 +13,24 @@ from core.configure_logging import configure_logging from core.routers.index import router as index_router -from core.routers.jwt_auth import router as jwt_router +from core.routers.jwt_auth import router as jwt_router from core.routers.structsense import router as structsense_router from core.database import init_db_pool, get_db_pool, debug_pool_status from core.configuration import load_environment from motor.motor_asyncio import AsyncIOMotorClient +# SynthScholar (PRISMA literature review). Imports are lazy-guarded so a +# missing `synthscholar` library doesn't crash the rest of ml_service — +# the router simply won't mount and the /api/synth-scholar/* surface returns 404. 
+try: + from core.synth_scholar.routes import router as synth_scholar_router + from core.synth_scholar.database import init_db as init_synth_scholar_db, close_db as close_synth_scholar_db + from core.synth_scholar.store import fix_stuck_reviews as fix_synth_scholar_stuck_reviews + _SYNTH_SCHOLAR_AVAILABLE = True +except Exception as _exc: + _SYNTH_SCHOLAR_AVAILABLE = False + _SYNTH_SCHOLAR_IMPORT_ERROR = _exc + from fastapi.middleware.cors import CORSMiddleware # Initialize logger - will be configured in lifespan @@ -44,6 +56,25 @@ async def lifespan(app: FastAPI): logger.error(f"Failed to initialize database pool: {e}") raise + # Initialise SynthScholar (PRISMA review) tables and recover any sessions + # that were running when the server last shut down. Failures here are + # logged but non-fatal so the rest of ml_service still boots. + if _SYNTH_SCHOLAR_AVAILABLE: + try: + await init_synth_scholar_db() + stuck = await fix_synth_scholar_stuck_reviews() + if stuck: + logger.warning("SynthScholar: marked %d stuck review(s) as FAILED on startup", stuck) + logger.info("SynthScholar: database ready") + except Exception as e: + logger.error("SynthScholar startup failed (continuing without it): %s", e) + else: + logger.warning( + "SynthScholar disabled — import failed: %s. Install `synthscholar` " + "and `sqlalchemy[asyncio]` to enable.", + _SYNTH_SCHOLAR_IMPORT_ERROR, + ) + # Initialize MongoDB client (reused across all requests) try: env = load_environment() @@ -113,6 +144,14 @@ async def lifespan(app: FastAPI): except Exception as e: logger.error(f"Error closing MongoDB client: {e}") + # Dispose the SynthScholar SQLAlchemy engine. + if _SYNTH_SCHOLAR_AVAILABLE: + try: + await close_synth_scholar_db() + logger.info("SynthScholar: SQLAlchemy engine disposed") + except Exception as e: + logger.error("SynthScholar shutdown error: %s", e) + logger.info("FastAPI shutdown complete") app = FastAPI( @@ -150,6 +189,8 @@ async def lifespan(app: FastAPI): app.include_router(index_router, prefix="/api") app.include_router(jwt_router, prefix="/api", tags=["Security"]) app.include_router(structsense_router, prefix="/api", tags=["Multi-agent Systems"]) +if _SYNTH_SCHOLAR_AVAILABLE: + app.include_router(synth_scholar_router, prefix="/api") # Exception handlers @app.exception_handler(HTTPException) diff --git a/ml_service/core/synth_scholar/__init__.py b/ml_service/core/synth_scholar/__init__.py new file mode 100644 index 0000000..92d3483 --- /dev/null +++ b/ml_service/core/synth_scholar/__init__.py @@ -0,0 +1,10 @@ +"""SynthScholar — PRISMA-guided literature review pipeline. + +Ported from aep-knowledge-synthesis/backend. Uses an independent SQLAlchemy +2.0 async engine for its own tables (reviews, plus synthscholar's bundled +article_store / review_cache / pipeline_checkpoints), separate from the +existing ml_service raw-asyncpg pool used by jwt_auth + structsense. + +Entry point: `core.synth_scholar.routes.router` (mounted at /api/synth-scholar +by core/main.py). All endpoints require a valid JWT (Depends(get_current_user)). +""" diff --git a/ml_service/core/synth_scholar/database.py b/ml_service/core/synth_scholar/database.py new file mode 100644 index 0000000..e7abacd --- /dev/null +++ b/ml_service/core/synth_scholar/database.py @@ -0,0 +1,122 @@ +"""Async SQLAlchemy engine and session factory for the SynthScholar tables. 
+ +Standalone from ml_service's raw-asyncpg pool: SynthScholar uses SQLAlchemy +2.0 declarative models with JSONB columns, and tying it to the existing pool +would force a dual-driver setup. They share the same Postgres database — the +DSN is built from the same JWT_POSTGRES_DATABASE_* vars ml_service already +uses — but the engines are independent so each side manages its own pool. + +SynthScholar's tables (`reviews` + synthscholar's bundled `article_store` / +`review_cache` / `pipeline_checkpoints`) live alongside the existing JWT auth +tables in the same `brainkb` database. No name collisions: synthscholar uses +lower_snake_case while the JWT side uses `Web_*` PascalCase. +""" + +from __future__ import annotations + +import logging +import os +from importlib.resources import files + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase + +logger = logging.getLogger(__name__) + + +def _resolve_dsn() -> str: + """Construct the SQLAlchemy async DSN from the unified Postgres env vars. + + Required: JWT_POSTGRES_DATABASE_HOST_URL / _USER / _PASSWORD / _NAME. + Port defaults to 5432. Falls back to a local-dev DSN only when none of + the env vars are present (e.g. running `python -m` outside docker). + """ + host = os.environ.get("JWT_POSTGRES_DATABASE_HOST_URL") + user = os.environ.get("JWT_POSTGRES_DATABASE_USER") + password = os.environ.get("JWT_POSTGRES_DATABASE_PASSWORD") + db = os.environ.get("JWT_POSTGRES_DATABASE_NAME") + port = os.environ.get("JWT_POSTGRES_DATABASE_PORT", "5432") + if host and user and password and db: + return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{db}" + return "postgresql+asyncpg://postgres:postgres@localhost:5432/brainkb" + + +DATABASE_URL = _resolve_dsn() + +engine = create_async_engine(DATABASE_URL, echo=False, pool_pre_ping=True) +async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + +class Base(DeclarativeBase): + """Declarative base for SynthScholar ORM models.""" + pass + + +async def init_db() -> None: + """Create review tables and apply additive ALTERs idempotently. + + Safe to call on every startup. The CREATE TABLE is checkfirst-aware; the + ALTERs use ADD COLUMN IF NOT EXISTS so re-running is a no-op. + """ + from . 
import db_models # noqa: F401 — register models with metadata + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + async with engine.begin() as conn: + for stmt in ( + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS run_request_json JSONB", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS share_to_cache BOOLEAN NOT NULL DEFAULT FALSE", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS checkpoint_json JSONB", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS last_completed_step INTEGER NOT NULL DEFAULT 0", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS stage VARCHAR(200)", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS stage_idx INTEGER", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS stage_total INTEGER", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS stage_done_count INTEGER NOT NULL DEFAULT 0", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS stage_remaining INTEGER", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS articles_included INTEGER", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS latest_message TEXT", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS owner_email VARCHAR(320)", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS pending_plan_json JSONB", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS pending_plan_iteration INTEGER", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS plan_response_json JSONB", + "ALTER TABLE reviews ADD COLUMN IF NOT EXISTS cancel_requested BOOLEAN NOT NULL DEFAULT FALSE", + ): + await conn.execute(text(stmt)) + + await apply_synthscholar_migrations() + + +async def apply_synthscholar_migrations() -> None: + """Apply synthscholar's bundled cache migrations idempotently. + + These create the article_store / review_cache / pipeline_checkpoints + tables plus optional pgvector embeddings used by the search endpoints. + Migration 004 requires the pgvector extension; if Postgres doesn't ship + it, the failure is logged and semantic search degrades to keyword/by_title. + """ + try: + from synthscholar.cache import migrations as _migrations_pkg # type: ignore + except Exception as exc: + logger.warning("synthscholar migrations not available: %s", exc) + return + + migrations_dir = ( + files(_migrations_pkg) + if hasattr(_migrations_pkg, "__path__") + else files("synthscholar.cache").joinpath("migrations") + ) + sql_files = sorted(p for p in migrations_dir.iterdir() if p.name.endswith(".sql")) + for path in sql_files: + sql = path.read_text(encoding="utf-8") + try: + async with engine.begin() as conn: + await conn.exec_driver_sql(sql) + logger.info("Applied synthscholar migration: %s", path.name) + except Exception as exc: + logger.warning("Migration %s skipped: %s", path.name, exc) + + +async def close_db() -> None: + """Dispose the connection pool on shutdown.""" + await engine.dispose() diff --git a/ml_service/core/synth_scholar/db_models.py b/ml_service/core/synth_scholar/db_models.py new file mode 100644 index 0000000..53d5c9d --- /dev/null +++ b/ml_service/core/synth_scholar/db_models.py @@ -0,0 +1,65 @@ +"""SQLAlchemy ORM models for SynthScholar reviews. + +Complex nested objects (protocol, result) are stored as JSONB. owner_email +is the only addition over the upstream model — it lets us scope listings to +the calling user without joining against the JWT user table. 
+""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from sqlalchemy import Boolean, DateTime, Integer, String, Text +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from .database import Base + + +class ReviewRow(Base): + """Persistent review session.""" + + __tablename__ = "reviews" + + review_id: Mapped[str] = mapped_column(String(128), primary_key=True) + status: Mapped[str] = mapped_column(String(20), nullable=False, default="pending") + title: Mapped[str] = mapped_column(String(500), nullable=False, default="") + + protocol_json: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + result_json: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + + pipeline_log: Mapped[list] = mapped_column(JSONB, nullable=False, default=list) + progress_step: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) + ) + completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + error: Mapped[str | None] = mapped_column(Text, nullable=True) + is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + share_to_cache: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + run_request_json: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + checkpoint_json: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + last_completed_step: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + # Typed progress state (written on each batch flush + terminal events). + stage: Mapped[str | None] = mapped_column(String(200), nullable=True) + stage_idx: Mapped[int | None] = mapped_column(Integer, nullable=True) + stage_total: Mapped[int | None] = mapped_column(Integer, nullable=True) + stage_done_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + stage_remaining: Mapped[int | None] = mapped_column(Integer, nullable=True) + articles_included: Mapped[int | None] = mapped_column(Integer, nullable=True) + latest_message: Mapped[str | None] = mapped_column(Text, nullable=True) + + # JWT email of the user who created the review. Listings are filtered by + # this in the routes layer; admins see everything. + owner_email: Mapped[str | None] = mapped_column(String(320), nullable=True) + + # Cross-worker plan-confirmation gate. The pipeline lives in one gunicorn + # worker but SSE/POST requests round-robin across all of them, so the + # plan and the user's response have to flow through Postgres for any + # other worker to see them. + pending_plan_json: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + pending_plan_iteration: Mapped[int | None] = mapped_column(Integer, nullable=True) + plan_response_json: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + cancel_requested: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) diff --git a/ml_service/core/synth_scholar/progress_events.py b/ml_service/core/synth_scholar/progress_events.py new file mode 100644 index 0000000..f26373b --- /dev/null +++ b/ml_service/core/synth_scholar/progress_events.py @@ -0,0 +1,240 @@ +"""Classify raw pipeline progress strings into typed events. + +synthscholar's `progress_callback` still receives plain strings; the richer +status display is produced by parsing those strings server-side into +categorised events with stage / article counters. 
This module is the only +place that knows the shape of those strings — everything downstream (session +state, SSE stream, frontend) consumes the structured dicts returned here. +""" + +from __future__ import annotations + +import re +from typing import Any, Optional, TypedDict + + +# ── Event kinds ──────────────────────────────────────────────────────── +# +# "log" — informational line (default) +# "stage_start" — a named pipeline phase begins (may carry stage_total) +# "stage_done" — a named pipeline phase ends (may carry articles_included) +# "article_start" — work on a single study begins within a stage +# "article_done" — work on a single study completes (carries done/total/remaining) +# "plan_ready" — search strategy is awaiting user confirmation +# "done" — entire pipeline finished successfully + +EVENT_KINDS = ( + "log", + "stage_start", + "stage_done", + "article_start", + "article_done", + "plan_ready", + "done", +) + + +class ClassifiedEvent(TypedDict, total=False): + kind: str + stage: Optional[str] + stage_index: Optional[int] + stage_total: Optional[int] + stage_done: Optional[int] + stage_remaining: Optional[int] + articles_included: Optional[int] + source: Optional[str] + + +_KNOWN_SOURCES = ( + "PubMed", "bioRxiv", "medRxiv", "Europe PMC", + "Semantic Scholar", "CrossRef", "OpenAlex", +) + + +# Order matters: first match wins. More specific patterns first. +_STAGE_START_PATTERNS: tuple[tuple[re.Pattern[str], str, Optional[int]], ...] = ( + (re.compile(r"^Generating search strategy"), "Search Strategy", 1), + (re.compile(r"^PubMed search \d+/\d+:"), "Database Search", 3), + (re.compile(r"^bioRxiv search:"), "Database Search", 3), + (re.compile(r"^Finding related articles"), "Related Articles", 4), + (re.compile(r"^Citation hop"), "Citation Hops", 5), + (re.compile(r"^Deduplicating"), "Deduplication", 6), + (re.compile(r"^Screening \d+ articles \(title/abstract"), "Title/Abstract Screening", 7), + (re.compile(r"^Fetching full text for"), "Full-text Retrieval", 8), + (re.compile(r"^Full-text eligibility screening"), "Full-text Screening", 9), + (re.compile(r"^Extracting evidence spans"), "Evidence Extraction", 10), + (re.compile(r"^Extracting data from \d+ studies"), "Data Extraction", 11), + (re.compile(r"^Assessing risk of bias"), "Risk of Bias", 12), + (re.compile(r"^(?:Charting|Data charting for)\s"), "Data Charting", 13), + (re.compile(r"^(?:Critical appraisal|Appraising)\s"), "Critical Appraisal", 14), + (re.compile(r"^(?:Narrative rows?|Building narrative)"),"Narrative Synthesis", 15), + (re.compile(r"^Synthesizing \d+ articles"), "Synthesis", 16), + (re.compile(r"^Validating grounding"), "Grounding Validation", 17), + (re.compile(r"^Assessing overall bias and GRADE"), "GRADE Assessment", 18), +) + +_STAGE_TOTAL_RE = re.compile( + r"(?:Screening|screening|Extracting data from|Fetching full text for)\s+\((?P\d+)\s+articles|" + r"(?:Screening|Extracting evidence spans\s+—)\s+(?P\d+)\s+articles|" + r"Extracting data from\s+(?P\d+)\s+studies|" + r"Fetching full text for\s+(?P\d+)\s+PMC" +) + +_ARTICLE_DONE_RE = re.compile( + r"^\s*✓\s+(?:Charted\s+|Appraised\s+|Narrative\s+)?\S+\s+\[(\d+)/(\d+)\s+done,\s+(\d+)\s+remaining\]" +) + +_ARTICLE_START_RE = re.compile( + r"^\s*(?:\[(\d+)/(\d+)\]|(?:Charting|Appraising|Narrative)\s+\[(\d+)/(\d+),\s+(\d+)\s+remaining\])" +) + +_TA_SCREENING_DONE_RE = re.compile(r"^Screening:\s+(\d+)\s+included,\s+(\d+)\s+excluded") +_FT_INCLUDED_RE = re.compile(r"^Final included:\s+(\d+)\s+articles") +_DEDUP_DONE_RE = re.compile(r"^After 
dedup:\s+(\d+)\s+\(removed\s+(\d+)\)") +_EVIDENCE_DONE_RE = re.compile(r"^Extracted\s+\d+\s+evidence spans from\s+(\d+)\s+articles") +_TOTAL_IDENT_RE = re.compile(r"^Total identified:\s+(\d+)") +_PLAN_READY_RE = re.compile(r"^Awaiting plan confirmation") +_DONE_RE = re.compile(r"^Review complete!?\s*$") + + +def _extract_source(message: str) -> Optional[str]: + stripped = message.lstrip() + for src in _KNOWN_SOURCES: + if stripped.startswith(src): + return src + return None + + +def classify(message: str) -> ClassifiedEvent: + """Classify one raw progress string into a typed event. + + Returns a dict with `kind` plus any fields the message explicitly + establishes (stage, counters, source). Callers maintain their own + cumulative session state by merging successive events. + """ + msg = message or "" + + if _PLAN_READY_RE.match(msg): + return {"kind": "plan_ready"} + + if _DONE_RE.match(msg): + return {"kind": "done"} + + m = _ARTICLE_DONE_RE.match(msg) + if m: + return { + "kind": "article_done", + "stage_done": int(m.group(1)), + "stage_total": int(m.group(2)), + "stage_remaining": int(m.group(3)), + } + + m = _ARTICLE_START_RE.match(msg) + if m: + idx = m.group(1) or m.group(3) + total = m.group(2) or m.group(4) + remaining = m.group(5) + ev: ClassifiedEvent = {"kind": "article_start"} + if idx and total: + ev["stage_done"] = int(idx) - 1 + ev["stage_total"] = int(total) + if remaining: + ev["stage_remaining"] = int(remaining) + return ev + + m = _TA_SCREENING_DONE_RE.match(msg) + if m: + return { + "kind": "stage_done", + "stage": "Title/Abstract Screening", + "stage_index": 7, + "articles_included": int(m.group(1)), + } + + m = _FT_INCLUDED_RE.match(msg) + if m: + return { + "kind": "stage_done", + "stage": "Full-text Screening", + "stage_index": 9, + "articles_included": int(m.group(1)), + } + + m = _DEDUP_DONE_RE.match(msg) + if m: + return { + "kind": "stage_done", + "stage": "Deduplication", + "stage_index": 6, + "stage_total": int(m.group(1)), + } + + m = _EVIDENCE_DONE_RE.match(msg) + if m: + return { + "kind": "stage_done", + "stage": "Evidence Extraction", + "stage_index": 10, + "stage_total": int(m.group(1)), + } + + m = _TOTAL_IDENT_RE.match(msg) + if m: + return { + "kind": "stage_done", + "stage": "Database Search", + "stage_index": 5, + "stage_total": int(m.group(1)), + } + + for pat, stage_name, stage_index in _STAGE_START_PATTERNS: + if pat.match(msg): + ev = {"kind": "stage_start", "stage": stage_name, "stage_index": stage_index} + tm = _STAGE_TOTAL_RE.search(msg) + if tm: + total_str = tm.group("a") or tm.group("b") or tm.group("c") or tm.group("d") + if total_str: + ev["stage_total"] = int(total_str) + src = _extract_source(msg) + if src: + ev["source"] = src + return ev + + ev_log: ClassifiedEvent = {"kind": "log"} + src = _extract_source(msg) + if src: + ev_log["source"] = src + return ev_log + + +def merge_into_state(state: dict[str, Any], event: ClassifiedEvent) -> dict[str, Any]: + """Merge a classified event into cumulative session progress state.""" + kind = event.get("kind", "log") + + if kind == "stage_start": + state["stage"] = event.get("stage") or state.get("stage") + state["stage_index"] = event.get("stage_index") or state.get("stage_index") + state["stage_total"] = event.get("stage_total") + state["stage_done"] = 0 + state["stage_remaining"] = event.get("stage_total") + return state + + if kind == "stage_done": + if event.get("stage"): + state["stage"] = event["stage"] + state["stage_index"] = event.get("stage_index") or state.get("stage_index") + if 
event.get("stage_total") is not None: + state["stage_total"] = event["stage_total"] + state["stage_done"] = event["stage_total"] + state["stage_remaining"] = 0 + if event.get("articles_included") is not None: + state["articles_included"] = event["articles_included"] + return state + + if kind in ("article_start", "article_done"): + for key in ("stage_total", "stage_done", "stage_remaining"): + if event.get(key) is not None: + state[key] = event[key] + return state + + return state diff --git a/ml_service/core/synth_scholar/routes.py b/ml_service/core/synth_scholar/routes.py new file mode 100644 index 0000000..564a8ea --- /dev/null +++ b/ml_service/core/synth_scholar/routes.py @@ -0,0 +1,1601 @@ +"""FastAPI routes for SynthScholar (PRISMA literature review). + +Ported from aep-knowledge-synthesis/backend/routes.py with three adaptations +for ml_service: + +1. Every endpoint requires a valid JWT via `Depends(get_current_user)`. The + bearer's email is recorded as `owner_email` on each new review and is + used to scope listings + reads. A claim of `roles=["Admin"]` bypasses + ownership checks (admin oversight). +2. Routes are mounted under /api/synth-scholar (not /api/v1) to align with + ml_service's existing /api prefix and avoid colliding with structsense's + surface (/ws/*, /save/*, /ner, /structured-resource). +3. All references to the BioSynthAI feature (separate vertical in the source) + are dropped — only the PRISMA review pipeline is ported. +""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import inspect +import logging +import os +import re +import time +from datetime import datetime, timezone +from typing import Annotated, AsyncGenerator, Optional + +from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request +from fastapi.responses import StreamingResponse + +from synthscholar import __version__ as agent_version # type: ignore[import-not-found] +from synthscholar.models import ( # type: ignore[import-not-found] + ReviewProtocol, ReviewPlan, RoBTool, CompareReviewResult, +) +from synthscholar.pipeline import PRISMAReviewPipeline # type: ignore[import-not-found] +from synthscholar.export import ( # type: ignore[import-not-found] + to_markdown, to_json, to_bibtex, to_turtle, to_jsonld, + to_rubric_markdown, to_rubric_json, + to_charting_markdown, to_charting_json, + to_appraisal_markdown, to_appraisal_json, + to_narrative_summary_markdown, to_narrative_summary_json, + to_compare_markdown, to_compare_json, + to_compare_charting_markdown, to_compare_charting_json, +) +from synthscholar.agents import ROB_DOMAINS # type: ignore[import-not-found] + +from core.security import decode_jwt, get_user, get_current_user + +from .schemas import ( + RunReviewRequest, CompareRunRequest, CompareReviewSummaryResponse, + RetryRequest, PlanResponseRequest, ReviewStatus, ReviewSummaryResponse, + ReviewDetailResponse, ArticleSummary, EvidenceSpanResponse, + ScreeningLogResponse, GRADEResponse, FlowResponse, ProgressEvent, + HealthResponse, VisibilityRequest, CacheSharingRequest, + LiteratureSearchRequest, LiteratureSearchResponse, + ReviewSearchRequest, ReviewSearchResponse, + SearchArticleResult, SearchReviewResult, + SearchSynthesisResponse, SearchSynthesisGroup, +) +from .store import review_store, ReviewSession +from .progress_events import classify + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# ── SSE auth dependency ──────────────────────────────────────────────── +# +# EventSource (the browser API for SSE) cannot set 
custom headers, so the +# usual `Authorization: Bearer ...` dance doesn't work. Instead the client +# passes the JWT as a `?token=` query parameter. This dependency mirrors +# `core.security.get_current_user` but reads the token from either the +# Authorization header OR the query string. Used only by the SSE route. + +async def _get_user_for_sse( + request: Request, + token: Optional[str] = Query(default=None, description="JWT bearer token (query-param fallback for SSE clients)"), +) -> dict: + raw = token + if not raw: + auth = request.headers.get("authorization") or request.headers.get("Authorization") + if auth and auth.lower().startswith("bearer "): + raw = auth.split(None, 1)[1].strip() + if not raw: + raise HTTPException( + status_code=401, + detail="Missing JWT (set Authorization header or ?token= query param).", + ) + try: + payload = decode_jwt(raw) + except Exception: + raise HTTPException(status_code=401, detail="Invalid or expired token.") + email = payload.get("sub") + if not email: + raise HTTPException(status_code=401, detail="Invalid token (missing sub claim).") + user = await get_user(email=email) + if user is None: + raise HTTPException(status_code=401, detail="Token user not found.") + return user + + +# OpenRouter slugs use hyphens (Anthropic API ID format), not dots. +# `anthropic/claude-opus-4.7` is NOT a valid slug — OpenRouter returns 401 +# "User not found" for unknown models, which masquerades as an auth failure. +# Keep this list aligned with brainkb-ui/src/app/user/synth-scholar/page.tsx +# `MODEL_OPTIONS`. Add new entries only after confirming on +# https://openrouter.ai/models that OpenRouter routes them. +AVAILABLE_MODELS = [ + # Anthropic — Claude 4.x family. + "anthropic/claude-opus-4-7", + "anthropic/claude-opus-4-6", + "anthropic/claude-sonnet-4-6", + "anthropic/claude-opus-4", + "anthropic/claude-sonnet-4", + "anthropic/claude-haiku-4-5", + "anthropic/claude-haiku-4", + # Google. + "google/gemini-2.5-pro", + "google/gemini-2.5-flash", + # OpenAI — only the slugs OpenRouter actually exposes. + "openai/gpt-4.1", + "openai/gpt-4o", + "openai/gpt-4o-mini", + # xAI / DeepSeek / Meta / Mistral. + "x-ai/grok-2-1212", + "deepseek/deepseek-chat", + "deepseek/deepseek-r1", + "meta-llama/llama-3.3-70b-instruct", + "mistralai/mistral-large-2411", +] + + +# ────────────────────── Helpers ──────────────────────────────────────── + +def _filter_run_kwargs(run_method, kwargs: dict) -> dict: + """Drop kwargs the installed synthscholar's pipeline.run doesn't accept. + + Defensive against version skew — older synthscholar wheels were missing + checkpoint/on_checkpoint/assemble_timeout. Keeps the backend usable + against pinned-broken versions while logging which features were dropped. + """ + try: + params = inspect.signature(run_method).parameters + except (TypeError, ValueError): + return kwargs + accepted = {k: v for k, v in kwargs.items() if k in params} + dropped = [k for k in kwargs if k not in params] + if dropped: + logger.warning( + "synthscholar pipeline.run does not accept %s — feature(s) silently disabled.", + dropped, + ) + return accepted + + +def _resolve_api_key(request_key: Optional[str] = None) -> str: + """Resolve the OpenRouter API key for a pipeline run. + + Precedence: + 1. Caller-supplied key (signed-in user's personal key, or the admin + shared `shared.openrouter_api_key` setting forwarded by the frontend). + 2. OPENROUTER_API_KEY env var (operator fallback / dev). 
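+
+    Illustrative precedence check (both key values below are hypothetical):
+
+        >>> os.environ["OPENROUTER_API_KEY"] = "sk-or-operator-fallback"
+        >>> _resolve_api_key("sk-or-user-key")   # caller-supplied key wins
+        'sk-or-user-key'
+        >>> _resolve_api_key(None)               # falls back to the env var
+        'sk-or-operator-fallback'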
+ """ + if request_key and request_key.strip(): + return request_key.strip() + key = os.environ.get("OPENROUTER_API_KEY", "") + if not key: + raise HTTPException( + status_code=400, + detail=( + "No OpenRouter API key available. Configure one in the dashboard " + "API key tab, ask an admin to set the shared key in /admin " + "settings, or set OPENROUTER_API_KEY on the backend." + ), + ) + return key + + +def _is_admin(user: dict) -> bool: + """JWT users carry a `scopes` list (existing ml_service convention) and + optionally a `roles` claim (added when /api/token issues admin tokens). + Treat either as admin for the purposes of cross-user oversight.""" + if not user: + return False + roles = user.get("roles") or [] + if isinstance(roles, list) and "Admin" in roles: + return True + scopes = user.get("scopes") or [] + if isinstance(scopes, list) and "admin" in scopes: + return True + return False + + +def _user_email(user: dict) -> str: + """Pull a stable email/identifier from the JWT user dict. + + `get_current_user` returns the user record from the DB; the calling JWT's + `sub` claim is its email. Fall back to `email` attribute on the model. + """ + if isinstance(user, dict): + return user.get("email") or user.get("sub") or "" + return getattr(user, "email", "") or "" + + +async def _session_or_404(review_id: str, user: dict) -> ReviewSession: + session = await review_store.get(review_id) + if not session: + raise HTTPException(status_code=404, detail=f"Review '{review_id}' not found") + if not _is_admin(user): + owner = session.owner_email or "" + if owner and owner != _user_email(user): + # Don't leak existence: return 404 to non-owners. + raise HTTPException(status_code=404, detail=f"Review '{review_id}' not found") + return session + + +def _to_flow_response(flow) -> FlowResponse: + return FlowResponse(**flow.model_dump()) + + +def _to_summary_response(s: ReviewSession) -> ReviewSummaryResponse: + r = s.result + return ReviewSummaryResponse( + review_id=s.review_id, + status=s.status, + title=s.protocol.title if s.protocol else "", + created_at=s.created_at, + completed_at=s.completed_at, + flow=_to_flow_response(r.flow) if r and not isinstance(r, CompareReviewResult) else None, + included_count=( + sum(len(mr.result.included_articles) for mr in r.model_results if mr.result) + if r and isinstance(r, CompareReviewResult) + else (len(r.included_articles) if r else 0) + ), + is_public=s.is_public, + share_to_cache=s.share_to_cache, + error=s.error, + stage=s.stage, + stage_index=s.stage_index, + stage_total=s.stage_total, + stage_done=s.stage_done, + stage_remaining=s.stage_remaining, + articles_included=s.articles_included, + ) + + +def _to_article_summary(article) -> ArticleSummary: + return ArticleSummary( + pmid=article.pmid, + title=article.title, + authors=article.authors, + year=article.year, + journal=article.journal, + doi=article.doi, + source=article.source, + inclusion_status=article.inclusion_status.value if hasattr(article.inclusion_status, "value") else str(article.inclusion_status), + rob_overall=article.risk_of_bias.overall.value if article.risk_of_bias else "", + study_design=article.extracted_data.study_design if article.extracted_data else "", + quality_score=article.quality_score, + ) + + +def _to_detail_response(session: ReviewSession) -> ReviewDetailResponse: + r = session.result + resp = ReviewDetailResponse( + review_id=session.review_id, + status=session.status, + title=session.protocol.title if session.protocol else "", + created_at=session.created_at, + 
completed_at=session.completed_at, + is_public=session.is_public, + share_to_cache=session.share_to_cache, + enable_cache=session.run_request.get("enable_cache") if session.run_request else None, + last_completed_step=session.last_completed_step, + run_request=session.run_request, + error=session.error, + ) + if r and isinstance(r, CompareReviewResult): + resp.compare_result = r.model_dump() + return resp + if r: + resp.research_question = r.research_question + resp.flow = _to_flow_response(r.flow) + resp.included_articles = [_to_article_summary(a) for a in r.included_articles] + resp.screening_log = [ + ScreeningLogResponse( + pmid=s.pmid, title=s.title, + decision=s.decision.value if hasattr(s.decision, "value") else str(s.decision), + reason=s.reason, + stage=s.stage.value if hasattr(s.stage, "value") else str(s.stage), + ) + for s in r.screening_log + ] + resp.evidence_spans = [ + EvidenceSpanResponse( + text=e.text, paper_pmid=e.paper_pmid, paper_title=e.paper_title, + section=e.section, relevance_score=e.relevance_score, + claim=e.claim, doi=e.doi, + ) + for e in r.evidence_spans + ] + resp.synthesis_text = r.synthesis_text + resp.bias_assessment = r.bias_assessment + resp.limitations = r.limitations + resp.grade_assessments = [ + GRADEResponse( + outcome=outcome, + overall_certainty=g.overall_certainty.value, + summary=g.summary, + domains={k: {"rating": v.rating, "explanation": v.explanation} for k, v in g.domains.items()}, + ) + for outcome, g in r.grade_assessments.items() + ] + resp.search_queries = r.search_queries + resp.data_charting_rubrics = r.data_charting_rubrics + resp.narrative_rows = r.narrative_rows + resp.critical_appraisals = r.critical_appraisals + resp.grounding_validation = r.grounding_validation + resp.structured_abstract = r.structured_abstract + resp.introduction_text = r.introduction_text + resp.conclusions_text = r.conclusions_text + resp.quality_checklist = r.quality_checklist + pga = getattr(r, "per_group_analysis", None) + if pga is not None: + resp.per_group_analysis = pga.model_dump() if hasattr(pga, "model_dump") else pga + return resp + + +_LOG_ENTRY_RE = re.compile(r"^\[([^\]]+)\] (.+)$", re.DOTALL) + + +def _parse_log_entry(entry: str) -> tuple[str, str]: + m = _LOG_ENTRY_RE.match(entry) + if not m: + return (datetime.now(timezone.utc).isoformat(), entry) + ts_str, msg = m.group(1), m.group(2) + if len(ts_str) > 8 and ("T" in ts_str or "-" in ts_str): + return (ts_str, msg) + today = datetime.now(timezone.utc).strftime("%Y-%m-%d") + return (f"{today}T{ts_str}+00:00", msg) + + +# ────────────────────── Background Task ──────────────────────────────── + +def _build_protocol(request: RunReviewRequest, session: ReviewSession) -> ReviewProtocol: + from .database import DATABASE_URL + db_url = DATABASE_URL.replace("+asyncpg", "") if request.enable_cache else "" + return ReviewProtocol( + title=request.protocol.title, + objective=request.protocol.objective or request.protocol.title, + pico_population=request.protocol.pico_population, + pico_intervention=request.protocol.pico_intervention, + pico_comparison=request.protocol.pico_comparison, + pico_outcome=request.protocol.pico_outcome, + inclusion_criteria=request.protocol.inclusion_criteria, + exclusion_criteria=request.protocol.exclusion_criteria, + databases=request.protocol.databases, + date_range_start=request.protocol.date_range_start, + date_range_end=request.protocol.date_range_end, + max_hops=request.protocol.max_hops, + registration_number=request.protocol.registration_number, + 
protocol_url=request.protocol.protocol_url, + funding_sources=request.protocol.funding_sources, + competing_interests=request.protocol.competing_interests, + rob_tool=request.protocol.rob_tool, + charting_questions=request.protocol.charting_questions, + appraisal_domains=request.protocol.appraisal_domains, + grouping_dimension=request.protocol.grouping_dimension, + default_group_questions=request.protocol.default_group_questions, + per_group_questions=request.protocol.per_group_questions, + pg_dsn=db_url, + review_id=session.review_id, + share_to_cache=session.is_public or session.share_to_cache, + max_articles=request.max_articles, + article_concurrency=request.concurrency, + ) + + +async def _run_pipeline(session: ReviewSession, request: RunReviewRequest): + try: + await session.mark_running() + api_key = _resolve_api_key(request.openrouter_api_key) + main_loop = asyncio.get_running_loop() + session._live.main_loop = main_loop + + protocol = _build_protocol(request, session) + session.protocol = protocol + + pipeline = PRISMAReviewPipeline( + api_key=api_key, + model_name=request.model, + ncbi_api_key=os.environ.get("NCBI_API_KEY", ""), + protocol=protocol, + enable_cache=request.enable_cache, + max_per_query=request.max_results_per_query, + related_depth=request.related_depth, + biorxiv_days=request.biorxiv_days, + ) + + data_items = request.data_items if request.extract_data else None + + async def on_checkpoint(state: dict) -> None: + await session.save_checkpoint(state) + + if not request.auto_confirm: + result = await _run_pipeline_with_gate( + session, pipeline, request, data_items, main_loop, on_checkpoint + ) + else: + run_kwargs = _filter_run_kwargs(pipeline.run, dict( + progress_callback=session.update_progress, + data_items=data_items, + auto_confirm=True, + output_synthesis_style=request.output_synthesis_style, + checkpoint=session.checkpoint_json, + on_checkpoint=on_checkpoint, + )) + result = await pipeline.run(**run_kwargs) + + await session.mark_completed(result) + + except asyncio.CancelledError: + await session.mark_cancelled() + raise + except RuntimeError as e: + # Cross-worker cancel: cancel_review on a non-owning worker writes + # cancel_requested=True; the gate raises this RuntimeError. Treat + # it as a cancellation, not a failure. 
+ if "cancelled by user" in str(e).lower(): + await session.mark_cancelled() + else: + logger.exception("[synth_scholar] pipeline failed for %s", session.review_id) + await session.mark_failed(str(e)) + except Exception as e: + logger.exception("[synth_scholar] pipeline failed for %s", session.review_id) + await session.mark_failed(str(e)) + finally: + review_store.evict(session.review_id) + + +async def _run_compare_pipeline(session: ReviewSession, request: CompareRunRequest): + try: + await session.mark_running() + api_key = _resolve_api_key(request.openrouter_api_key) + main_loop = asyncio.get_running_loop() + session._live.main_loop = main_loop + + protocol = _build_protocol( + RunReviewRequest( + protocol=request.protocol, + model=request.compare_models[0], + max_results_per_query=request.max_results_per_query, + related_depth=request.related_depth, + biorxiv_days=request.biorxiv_days, + enable_cache=request.enable_cache, + extract_data=request.extract_data, + data_items=request.data_items, + max_plan_iterations=request.max_plan_iterations, + output_synthesis_style=request.output_synthesis_style, + ), + session, + ) + session.protocol = protocol + + pipeline = PRISMAReviewPipeline( + api_key=api_key, + model_name=request.compare_models[0], + ncbi_api_key=os.environ.get("NCBI_API_KEY", ""), + protocol=protocol, + enable_cache=request.enable_cache, + max_per_query=request.max_results_per_query, + related_depth=request.related_depth, + biorxiv_days=request.biorxiv_days, + ) + + data_items = request.data_items if request.extract_data else None + + if not request.auto_confirm: + compare_result = await _run_compare_pipeline_with_gate( + session, pipeline, request, data_items, main_loop + ) + else: + compare_kwargs = _filter_run_kwargs(pipeline.run_compare, dict( + models=request.compare_models, + progress_callback=session.update_progress, + data_items=data_items, + auto_confirm=True, + consensus_model=request.consensus_model or request.compare_models[0], + max_plan_iterations=request.max_plan_iterations, + output_synthesis_style=request.output_synthesis_style, + assemble_timeout=3600.0, + )) + compare_result = await pipeline.run_compare(**compare_kwargs) + + await session.mark_completed(compare_result) + + except asyncio.CancelledError: + await session.mark_cancelled() + raise + except RuntimeError as e: + if "cancelled by user" in str(e).lower(): + await session.mark_cancelled() + else: + logger.exception("[synth_scholar] compare pipeline failed for %s", session.review_id) + await session.mark_failed(str(e)) + except Exception as e: + logger.exception("[synth_scholar] compare pipeline failed for %s", session.review_id) + await session.mark_failed(str(e)) + finally: + review_store.evict(session.review_id) + + +async def _run_pipeline_with_gate( + session: ReviewSession, + pipeline: PRISMAReviewPipeline, + request: RunReviewRequest, + data_items, + main_loop: asyncio.AbstractEventLoop, + on_checkpoint_async=None, +): + """Run pipeline in thread pool so the sync confirm_callback can block.""" + + def confirm_callback(plan: ReviewPlan) -> "bool | str": + asyncio.run_coroutine_threadsafe( + _notify_plan(session, plan), main_loop + ).result(timeout=5) + return _wait_plan_gate(session, main_loop) + + def on_checkpoint_thread_safe(state: dict) -> None: + if on_checkpoint_async: + asyncio.run_coroutine_threadsafe(on_checkpoint_async(state), main_loop) + + def thread_fn(): + import asyncio as _asyncio + thread_loop = _asyncio.new_event_loop() + _asyncio.set_event_loop(thread_loop) + try: + 
run_kwargs = _filter_run_kwargs(pipeline.run, dict( + progress_callback=_make_thread_safe_callback(session, main_loop), + data_items=data_items, + auto_confirm=False, + confirm_callback=confirm_callback, + max_plan_iterations=request.max_plan_iterations, + output_synthesis_style=request.output_synthesis_style, + checkpoint=session.checkpoint_json, + on_checkpoint=on_checkpoint_thread_safe, + )) + return thread_loop.run_until_complete(pipeline.run(**run_kwargs)) + finally: + thread_loop.close() + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="prisma-gate") + try: + result = await main_loop.run_in_executor(executor, thread_fn) + finally: + executor.shutdown(wait=False) + return result + + +async def _run_compare_pipeline_with_gate( + session: ReviewSession, + pipeline: PRISMAReviewPipeline, + request: CompareRunRequest, + data_items, + main_loop: asyncio.AbstractEventLoop, +): + def confirm_callback(plan: ReviewPlan) -> "bool | str": + asyncio.run_coroutine_threadsafe( + _notify_plan(session, plan), main_loop + ).result(timeout=5) + return _wait_plan_gate(session, main_loop) + + def thread_fn(): + import asyncio as _asyncio + thread_loop = _asyncio.new_event_loop() + _asyncio.set_event_loop(thread_loop) + try: + compare_kwargs = _filter_run_kwargs(pipeline.run_compare, dict( + models=request.compare_models, + progress_callback=_make_thread_safe_callback(session, main_loop), + data_items=data_items, + auto_confirm=False, + confirm_callback=confirm_callback, + consensus_model=request.consensus_model or request.compare_models[0], + max_plan_iterations=request.max_plan_iterations, + output_synthesis_style=request.output_synthesis_style, + assemble_timeout=3600.0, + )) + return thread_loop.run_until_complete(pipeline.run_compare(**compare_kwargs)) + finally: + thread_loop.close() + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="prisma-compare-gate") + try: + result = await main_loop.run_in_executor(executor, thread_fn) + finally: + executor.shutdown(wait=False) + return result + + +def _wait_plan_gate( + session: ReviewSession, + main_loop: asyncio.AbstractEventLoop, + timeout_seconds: float = 600.0, +) -> "bool | str": + """Block the pipeline thread until a plan response arrives, from either + side of the cross-worker boundary: + + * Fast path — the local threading.Event set by submit_plan_response when + it happens to land on this worker. + * Slow path — a response written to plan_response_json in Postgres by + submit_plan_response on a different worker, drained here via a 1s + poll (claim_plan_response atomically reads + clears the column). + + Also honors a cross-worker cancel: cancel_review writes + cancel_requested=True from any worker; this poll picks it up. + + Race-condition note: when the response arrives via the slow path, the + worker that handled `submit_plan_response` already wrote + `status="running"` to the DB via `mark_running()`. We MUST mirror that + onto this worker's `_live.status` before returning — otherwise the very + next `_persist_progress` flush from this worker reads + `effective_status = self._live.status` (still "plan_pending") and + overwrites the DB back to plan_pending, which the UI sees as the status + oscillating and the plan dialog re-appearing. + """ + def _mark_resumed() -> None: + # Mirror the running-state transition that submit_plan_response made + # in the DB. 
Clearing pending_plan stops set_plan_pending's stale + # value from leaking back into the next SSE replay or + # _persist_progress flush. + session._live.status = ReviewStatus.RUNNING.value + session.status = ReviewStatus.RUNNING + session._live.pending_plan = None + + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + if session._live.plan_gate.wait(timeout=1.0): + session._live.plan_gate.clear() + if session._live.cancel_flag.is_set(): + raise RuntimeError("Review cancelled by user") + _mark_resumed() + return session._live.plan_response[0] + + try: + response = asyncio.run_coroutine_threadsafe( + session.claim_plan_response(), main_loop + ).result(timeout=2.0) + except Exception: + response = None + if response is not None: + _mark_resumed() + return response + + try: + cancelled = asyncio.run_coroutine_threadsafe( + session.is_cancel_requested(), main_loop + ).result(timeout=2.0) + except Exception: + cancelled = False + if cancelled: + raise RuntimeError("Review cancelled by user") + + raise TimeoutError("Plan confirmation timed out (10 min)") + + +async def _notify_plan(session: ReviewSession, plan: ReviewPlan) -> None: + session.set_plan_pending(plan, plan.iteration) + session.signal_plan_notify() + + +def _make_thread_safe_callback(session: ReviewSession, main_loop: asyncio.AbstractEventLoop): + """Return a progress_callback that signals SSE from a non-main thread.""" + def callback(message: str): + session.progress_step += 1 + session._live.latest_message = message + session._live.progress_step = session.progress_step + session.pipeline_log.append( + f"[{datetime.now(timezone.utc).isoformat()}] {message}" + ) + session._apply_classified(message) + session._live._write_counter += 1 + if session._live._write_counter >= 10: + session._live._write_counter = 0 + asyncio.run_coroutine_threadsafe(session._persist_progress(), main_loop) + asyncio.run_coroutine_threadsafe(_set_and_clear_progress(session), main_loop) + return callback + + +async def _set_and_clear_progress(session: ReviewSession) -> None: + session._live.progress_event.set() + session._live.progress_event.clear() + + +# ────────────────────── Routes ───────────────────────────────────────── + +@router.get("/synth-scholar/health", response_model=HealthResponse, tags=["SynthScholar — system"]) +async def health(_user: Annotated[dict, Depends(get_current_user)]): + """Health check with available models and RoB tools.""" + return HealthResponse( + status="ok", + version=agent_version, + models=AVAILABLE_MODELS, + rob_tools=[t.value for t in RoBTool], + ) + + +@router.get("/synth-scholar/rob-tools", tags=["SynthScholar — system"]) +async def list_rob_tools(_user: Annotated[dict, Depends(get_current_user)]): + """List all available Risk of Bias tools with their domains.""" + return [ + { + "id": tool.value, + "name": tool.value, + "domains": ROB_DOMAINS.get(tool.value, []), + "domain_count": len(ROB_DOMAINS.get(tool.value, [])), + } + for tool in RoBTool + ] + + +@router.post( + "/synth-scholar/reviews", + response_model=ReviewSummaryResponse, + status_code=202, + tags=["SynthScholar — reviews"], +) +async def create_review( + request: RunReviewRequest, + user: Annotated[dict, Depends(get_current_user)], +): + """Start a new PRISMA review. Returns immediately with a review_id; the + pipeline runs in the background. 
Use GET /reviews/{id}/stream for live + progress, or poll GET /reviews/{id}/status.""" + # Validate up-front that we have a usable key; fail fast before creating + # the DB row + spawning the background task. + _resolve_api_key(request.openrouter_api_key) + + protocol = ReviewProtocol( + title=request.protocol.title, + objective=request.protocol.objective or request.protocol.title, + ) + # Persist the run request without the API key — it never lands in the DB. + session = await review_store.create( + protocol, + run_request=request.model_dump(exclude={"openrouter_api_key"}), + owner_email=_user_email(user), + ) + + task = asyncio.create_task(_run_pipeline(session, request)) + session._live.cancel_task = task + + return ReviewSummaryResponse( + review_id=session.review_id, + status=session.status, + title=request.protocol.title, + created_at=session.created_at, + ) + + +@router.post( + "/synth-scholar/reviews/compare", + response_model=CompareReviewSummaryResponse, + status_code=202, + tags=["SynthScholar — reviews"], +) +async def create_compare_review( + request: CompareRunRequest, + user: Annotated[dict, Depends(get_current_user)], +): + """Start a compare-mode PRISMA review across 2–5 models.""" + _resolve_api_key(request.openrouter_api_key) + + protocol = ReviewProtocol( + title=request.protocol.title, + objective=request.protocol.objective or request.protocol.title, + ) + run_req_dict = {**request.model_dump(exclude={"openrouter_api_key"}), "compare_mode": True} + session = await review_store.create( + protocol, + run_request=run_req_dict, + owner_email=_user_email(user), + ) + + task = asyncio.create_task(_run_compare_pipeline(session, request)) + session._live.cancel_task = task + + return CompareReviewSummaryResponse( + review_id=session.review_id, + status=session.status, + compare_models=request.compare_models, + created_at=session.created_at, + ) + + +@router.get( + "/synth-scholar/reviews", + response_model=list[ReviewSummaryResponse], + tags=["SynthScholar — reviews"], +) +async def list_reviews(user: Annotated[dict, Depends(get_current_user)]): + """List the caller's review sessions. Admins see every review.""" + owner = None if _is_admin(user) else _user_email(user) + sessions = await review_store.list_for_owner(owner) + return [_to_summary_response(s) for s in sessions] + + +@router.get( + "/synth-scholar/reviews/{review_id}", + response_model=ReviewDetailResponse, + tags=["SynthScholar — reviews"], +) +async def get_review( + review_id: str, + user: Annotated[dict, Depends(get_current_user)], +): + """Get full review detail including synthesis, evidence, screening log.""" + session = await _session_or_404(review_id, user) + return _to_detail_response(session) + + +@router.get( + "/synth-scholar/reviews/{review_id}/status", + response_model=ReviewSummaryResponse, + tags=["SynthScholar — reviews"], +) +async def get_review_status( + review_id: str, + user: Annotated[dict, Depends(get_current_user)], +): + """Lightweight polling endpoint with status + flow counts.""" + session = await _session_or_404(review_id, user) + return _to_summary_response(session) + + +async def _stream_from_db( + review_id: str, + session: ReviewSession, + last_step: int, +) -> AsyncGenerator[str, None]: + """SSE driver for the cross-worker case: the pipeline is running in some + other gunicorn worker, so we don't have an asyncio.Event to wait on. + Poll Postgres on a fixed tick instead. 
Everything emitted here is a + replay of state that the owning worker already persisted, so it stays + consistent with what the pipeline truly knows.""" + POLL_INTERVAL = 1.5 + last_plan_iteration = ( + session.pending_plan_iteration_db + if session.pending_plan_iteration_db is not None + else 0 + ) + if ( + session.status == ReviewStatus.PLAN_PENDING + and session.pending_plan_db is not None + and last_plan_iteration > 0 + ): + yield ( + f"data: {ProgressEvent(review_id=review_id, step=last_step, message=f'Awaiting plan confirmation (iteration {last_plan_iteration})...', timestamp=datetime.now().isoformat(), event_type='plan_review', plan=session.pending_plan_db).model_dump_json()}\n\n" + ) + + while True: + await asyncio.sleep(POLL_INTERVAL) + # Owning worker may have come back online — let the live path take + # over on the next reconnect. + if review_id in review_store._runtime: + yield f"data: {ProgressEvent(review_id=review_id, step=last_step, message='Owning worker reattached — reconnect for live stream.', timestamp=datetime.now().isoformat(), event_type='cancelled').model_dump_json()}\n\n" + return + + refreshed = await review_store.get(review_id) + if refreshed is None: + return + + new_log = list(refreshed.pipeline_log) + if len(new_log) > last_step: + for i in range(last_step, len(new_log)): + ts, msg = _parse_log_entry(new_log[i]) + ev = classify(msg) + yield ( + f"data: {ProgressEvent(review_id=review_id, step=i + 1, message=msg, timestamp=ts, event_type='progress', source=ev.get('source'), kind=ev.get('kind', 'log'), stage=refreshed.stage, stage_index=refreshed.stage_index, stage_total=refreshed.stage_total, stage_done=refreshed.stage_done, stage_remaining=refreshed.stage_remaining, articles_included=refreshed.articles_included).model_dump_json()}\n\n" + ) + last_step = len(new_log) + + if ( + refreshed.status == ReviewStatus.PLAN_PENDING + and refreshed.pending_plan_db is not None + and (refreshed.pending_plan_iteration_db or 0) > last_plan_iteration + ): + last_plan_iteration = refreshed.pending_plan_iteration_db or 0 + yield ( + f"data: {ProgressEvent(review_id=review_id, step=last_step, message=f'Awaiting plan confirmation (iteration {last_plan_iteration})...', timestamp=datetime.now().isoformat(), event_type='plan_review', plan=refreshed.pending_plan_db).model_dump_json()}\n\n" + ) + + if refreshed.status in (ReviewStatus.COMPLETED, ReviewStatus.FAILED, ReviewStatus.CANCELLED): + etype = ( + "completed" if refreshed.status == ReviewStatus.COMPLETED + else "cancelled" if refreshed.status == ReviewStatus.CANCELLED + else "failed" + ) + yield ( + f"data: {ProgressEvent(review_id=review_id, step=last_step, message=f'Review {refreshed.status.value}' + (f': {refreshed.error}' if refreshed.error else ''), timestamp=datetime.now().isoformat(), event_type=etype).model_dump_json()}\n\n" + ) + return + + +@router.get("/synth-scholar/reviews/{review_id}/stream", tags=["SynthScholar — reviews"]) +async def stream_progress( + review_id: str, + user: Annotated[dict, Depends(_get_user_for_sse)], +): + """Server-Sent Events (SSE) stream of pipeline progress. + + Two paths, picked per-request based on whether this gunicorn worker + owns the pipeline's runtime (`review_store._runtime[review_id]`): + + * Owning worker → live path. asyncio.Event wakeups from the pipeline + thread give sub-second latency. + * Non-owning worker → `_stream_from_db` polls Postgres on a 1.5s tick. 
+ The pipeline persists pipeline_log + pending_plan_json to the DB on + every relevant transition, and `submit_plan_response` writes the + response back via `plan_response_json` — so the round-trip works + from any worker. The owning worker's `_wait_plan_gate` drains that + column to unblock the gate. + + Never call `mark_failed` from this endpoint — the request landing on + the wrong worker is normal, not a failure. + """ + session = await _session_or_404(review_id, user) + + replay_state: dict = { + "stage": None, "stage_index": None, "stage_total": None, + "stage_done": 0, "stage_remaining": None, "articles_included": None, + } + + async def event_generator() -> AsyncGenerator[str, None]: + from .progress_events import merge_into_state + history = list(session.pipeline_log) + for i, entry in enumerate(history): + ts, msg = _parse_log_entry(entry) + ev = classify(msg) + merge_into_state(replay_state, ev) + yield f"data: {ProgressEvent(review_id=review_id, step=i + 1, message=msg, timestamp=ts, event_type='history', source=ev.get('source'), kind=ev.get('kind', 'log'), stage=replay_state['stage'], stage_index=replay_state['stage_index'], stage_total=replay_state['stage_total'], stage_done=replay_state['stage_done'], stage_remaining=replay_state['stage_remaining'], articles_included=replay_state['articles_included']).model_dump_json()}\n\n" + + last_step = len(history) + + # Cross-worker fallback. See stream_progress docstring for the full + # picture; the short version is: _LiveState is per-process, so if + # this request landed on a worker that doesn't own the pipeline, + # we replay from Postgres instead of fabricating a failure. + if review_id not in review_store._runtime and session.status in ( + ReviewStatus.RUNNING, + ReviewStatus.PLAN_PENDING, + ReviewStatus.PENDING, + ): + async for chunk in _stream_from_db(review_id, session, last_step): + yield chunk + return + + if session.status in (ReviewStatus.COMPLETED, ReviewStatus.FAILED, ReviewStatus.CANCELLED): + if session.status == ReviewStatus.CANCELLED: + _etype = "cancelled" + elif session.status == ReviewStatus.COMPLETED: + _etype = "completed" + else: + _etype = "failed" + yield f"data: {ProgressEvent(review_id=review_id, step=last_step, message=f'Review {session.status.value}' + (f': {session.error}' if session.error else ''), timestamp=datetime.now().isoformat(), event_type=_etype).model_dump_json()}\n\n" + return + + bound_token = session._live.run_token + last_plan_iteration = 0 + while True: + current_live = review_store._runtime.get(review_id) + if current_live is not None and current_live.run_token != bound_token: + yield f"data: {ProgressEvent(review_id=review_id, step=last_step, message='Session reset — reconnect for new run.', timestamp=datetime.now().isoformat(), event_type='cancelled').model_dump_json()}\n\n" + return + + try: + await asyncio.wait_for(session._progress_event.wait(), timeout=30.0) + except asyncio.TimeoutError: + if ( + session._live.status == ReviewStatus.PLAN_PENDING.value + and session._live.pending_plan is not None + and session._live.pending_plan.iteration > last_plan_iteration + ): + last_plan_iteration = session._live.pending_plan.iteration + yield f"data: {ProgressEvent(review_id=review_id, step=last_step, message=f'Awaiting plan confirmation (iteration {last_plan_iteration})...', timestamp=datetime.now().isoformat(), event_type='plan_review', plan=session._live.pending_plan.model_dump()).model_dump_json()}\n\n" + else: + yield f"data: {ProgressEvent(review_id=review_id, step=last_step, 
message='keepalive', timestamp=datetime.now().isoformat(), event_type='keepalive').model_dump_json()}\n\n" + continue + + if ( + session.status == ReviewStatus.PLAN_PENDING + and session._live.pending_plan is not None + and session._live.pending_plan.iteration > last_plan_iteration + ): + last_plan_iteration = session._live.pending_plan.iteration + yield f"data: {ProgressEvent(review_id=review_id, step=last_step, message=f'Awaiting plan confirmation (iteration {last_plan_iteration})...', timestamp=datetime.now().isoformat(), event_type='plan_review', plan=session._live.pending_plan.model_dump()).model_dump_json()}\n\n" + + live_step = session._live.progress_step + if live_step > last_step: + last_step = live_step + msg = session._live.latest_message + latest = session._live.latest_event or {} + event = ProgressEvent( + review_id=review_id, + step=last_step, + message=msg, + timestamp=datetime.now().isoformat(), + event_type="progress", + source=latest.get("source"), + kind=latest.get("kind", "log"), + stage=session._live.stage, + stage_index=session._live.stage_index, + stage_total=session._live.stage_total, + stage_done=session._live.stage_done, + stage_remaining=session._live.stage_remaining, + articles_included=session._live.articles_included, + ) + yield f"data: {event.model_dump_json()}\n\n" + + live_status = session._live.status + if live_status in ("completed", "failed", "cancelled"): + final = ProgressEvent( + review_id=review_id, + step=last_step, + message=f"Review {live_status}" + (f": {session.error}" if session.error else ""), + timestamp=datetime.now().isoformat(), + event_type=live_status, + ) + yield f"data: {final.model_dump_json()}\n\n" + break + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +@router.post("/synth-scholar/reviews/{review_id}/plan-response", tags=["SynthScholar — reviews"]) +async def submit_plan_response( + review_id: str, + body: PlanResponseRequest, + user: Annotated[dict, Depends(get_current_user)], +): + """Respond to a plan confirmation gate. + + The pipeline lives on whichever gunicorn worker handled POST /reviews, + but this request can land on any worker. Source of truth is the DB: + + 1. Write the response to plan_response_json so the owning worker's + confirm_callback drains it via its DB-poll fallback. + 2. If we happen to be the owning worker, also signal the local + threading.Event so the gate unblocks immediately (skips the 1s + poll latency). + 3. Flip status to RUNNING optimistically. The pipeline thread continues + within ~1s of the DB write and starts emitting progress events. 
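+
+    Illustrative request bodies (shape only, per PlanResponseRequest):
+    approve with {"approved": true}; request a revision with
+    {"approved": false, "feedback": "drop arXiv and restrict to 2020-2025"}.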
+ """ + session = await _session_or_404(review_id, user) + if session.status != ReviewStatus.PLAN_PENDING: + raise HTTPException( + status_code=409, + detail=f"Review is not awaiting plan confirmation (status: {session.status.value})", + ) + response = True if body.approved else body.feedback + iteration = ( + session.pending_plan_iteration_db + if session.pending_plan_iteration_db is not None + else 0 + ) + await session.write_plan_response(response) + if review_id in review_store._runtime: + session._live = review_store._runtime[review_id] + session.resolve_plan(response) + await session.mark_running() + return { + "review_id": review_id, + "status": ReviewStatus.RUNNING.value, + "iteration": iteration, + } + + +@router.post( + "/synth-scholar/reviews/{review_id}/retry", + response_model=ReviewSummaryResponse, + status_code=202, + tags=["SynthScholar — reviews"], +) +async def retry_review( + review_id: str, + user: Annotated[dict, Depends(get_current_user)], + body: RetryRequest = Body(default=RetryRequest()), +): + """Retry a failed or cancelled review; optionally override enable_cache.""" + session = await _session_or_404(review_id, user) + if session.status not in (ReviewStatus.FAILED, ReviewStatus.CANCELLED): + raise HTTPException( + status_code=409, + detail=f"Review cannot be retried (status: {session.status.value})", + ) + if not session.run_request: + raise HTTPException( + status_code=500, + detail="Cannot retry — original configuration not saved. Please create a new review.", + ) + await session.reset_for_retry(clear_checkpoint=not body.resume) + run_req = dict(session.run_request) + if body.enable_cache is not None: + run_req["enable_cache"] = body.enable_cache + session.run_request = run_req + review_store._runtime[review_id] = session._live + if run_req.get("compare_mode"): + compare_request = CompareRunRequest(**{k: v for k, v in run_req.items() if k != "compare_mode"}) + task = asyncio.create_task(_run_compare_pipeline(session, compare_request)) + else: + request = RunReviewRequest(**run_req) + task = asyncio.create_task(_run_pipeline(session, request)) + session._live.cancel_task = task + return ReviewSummaryResponse( + review_id=session.review_id, + status=session.status, + title=session.protocol.title if session.protocol else "", + created_at=session.created_at, + completed_at=session.completed_at, + is_public=session.is_public, + share_to_cache=session.share_to_cache, + error=session.error, + ) + + +@router.post( + "/synth-scholar/reviews/{review_id}/cancel", + response_model=ReviewSummaryResponse, + tags=["SynthScholar — reviews"], +) +async def cancel_review( + review_id: str, + user: Annotated[dict, Depends(get_current_user)], +): + """Cancel a running, pending, or plan_pending review. + + Cross-worker behavior: if this request lands on the worker that owns the + pipeline, we cancel the asyncio.Task directly and the existing + `except CancelledError → mark_cancelled` path handles state transitions. + Otherwise we only set cancel_requested in Postgres — the owning worker's + confirm_callback poll picks it up at the next plan gate, and the + pipeline raises, which trips its own mark_failed/mark_cancelled. We do + NOT call mark_cancelled() here from a non-owning worker, because that + would race with a still-running pipeline and corrupt the final status. 
+ """ + session = await _session_or_404(review_id, user) + effective_status = session.status.value + if effective_status not in ("pending", "running", "plan_pending"): + raise HTTPException( + status_code=409, + detail=f"Review cannot be cancelled (status: {effective_status})", + ) + + if review_id in review_store._runtime: + live = review_store._runtime[review_id] + session._live = live + live.cancel_flag.set() + if effective_status == "plan_pending": + session.resolve_plan(False) + if live.cancel_task: + live.cancel_task.cancel() + await session.mark_cancelled() + else: + await session.request_cancel() + if effective_status == "plan_pending": + await session.write_plan_response(False) + # Status flip is deferred — the owning worker's pipeline will hit + # the cancel signal at its next plan gate or progress tick and + # transition CANCELLED itself. The UI sees CANCELLED via the SSE + # DB-poll fallback within ~2s. + + return ReviewSummaryResponse( + review_id=session.review_id, + status=session.status, + title=session.protocol.title if session.protocol else "", + created_at=session.created_at, + completed_at=session.completed_at, + is_public=session.is_public, + share_to_cache=session.share_to_cache, + error=session.error, + ) + + +@router.get("/synth-scholar/reviews/{review_id}/export", tags=["SynthScholar — export"]) +async def export_review( + review_id: str, + user: Annotated[dict, Depends(get_current_user)], + format: str = Query( + default="markdown", + pattern=r"^(markdown|json|bibtex|ttl|jsonld|rubric_markdown|rubric_json|charting_markdown|charting_json|appraisal_markdown|appraisal_json|narrative_summary_markdown|narrative_summary_json)$", + ), + model: Optional[str] = Query(default=None, description="Compare-mode only: export a single model's result by model_name"), +): + """Export a completed review in the requested format.""" + session = await _session_or_404(review_id, user) + if session.status != ReviewStatus.COMPLETED or not session.result: + raise HTTPException( + status_code=400, + detail=f"Review is not completed (status: {session.status.value})", + ) + + result = session.result + ts = datetime.now().strftime("%Y%m%d") + is_compare = isinstance(result, CompareReviewResult) + + single_result = None + single_model_slug = "" + if is_compare: + if model: + match = next( + (r for r in result.model_results if r.model_name == model and r.result is not None), + None, + ) + if match is None: + raise HTTPException( + status_code=404, + detail=f"Model '{model}' has no successful result in this compare review", + ) + single_result = match.result + single_model_slug = "_" + re.sub(r"[^a-zA-Z0-9_-]+", "-", model).strip("-") + else: + first_ok = next((r for r in result.model_results if r.result is not None), None) + if first_ok is not None: + single_result = first_ok.result + single_model_slug = "_" + re.sub(r"[^a-zA-Z0-9_-]+", "-", first_ok.model_name).strip("-") + + def _stem(base: str) -> str: + return f"{base}{single_model_slug}_{ts}" if (is_compare and single_result and model) else f"{base}_{ts}" + + if format == "markdown": + if is_compare and not model: + content = to_compare_markdown(result) + filename = f"prisma_compare_{ts}.md" + else: + target = single_result or result + content = to_markdown(target) + pr = getattr(target, "prisma_review", None) + extraction_with_fields = ( + pr and pr.methods + and any(getattr(r, "field_answers", None) for r in (pr.methods.data_extraction or [])) + ) + if extraction_with_fields: + content += "\n\n---\n\n" + to_charting_markdown(target) + 
filename = f"{_stem('prisma_review')}.md" + media_type = "text/markdown" + elif format == "json": + if is_compare and not model: + content = to_compare_json(result) + filename = f"prisma_compare_{ts}.json" + else: + content = to_json(single_result or result) + filename = f"{_stem('prisma_review')}.json" + media_type = "application/json" + elif format == "bibtex": + content = to_bibtex(single_result or result) + filename = f"{_stem('prisma_references')}.bib" + media_type = "text/plain" + elif format == "ttl": + content = to_turtle(single_result or result) + filename = f"{_stem('prisma_review')}.ttl" + media_type = "text/turtle" + elif format == "jsonld": + content = to_jsonld(single_result or result) + filename = f"{_stem('prisma_review')}.jsonld" + media_type = "application/ld+json" + elif format == "rubric_markdown": + if is_compare and not model: + content = to_compare_charting_markdown(result) + filename = f"prisma_compare_rubrics_{ts}.md" + else: + content = to_rubric_markdown(single_result or result) + filename = f"{_stem('prisma_rubrics')}.md" + media_type = "text/markdown" + elif format == "rubric_json": + if is_compare and not model: + content = to_compare_charting_json(result) + filename = f"prisma_compare_rubrics_{ts}.json" + else: + content = to_rubric_json(single_result or result) + filename = f"{_stem('prisma_rubrics')}.json" + media_type = "application/json" + elif format in ("charting_markdown", "charting_json"): + target = single_result or result + pr = getattr(target, "prisma_review", None) + has_extraction = bool(pr and pr.methods and getattr(pr.methods, "data_extraction", None)) + has_field_answers = has_extraction and any( + getattr(r, "field_answers", None) for r in pr.methods.data_extraction + ) + if not has_field_answers: + raise HTTPException( + status_code=422, + detail=( + "Charting export not available — this review has no field-level " + "extraction. Re-run with extract_data=true and a charting template." + ), + ) + if format == "charting_markdown": + content = to_charting_markdown(target) + filename = f"{_stem('prisma_charting')}.md" + media_type = "text/markdown" + else: + content = to_charting_json(target) + filename = f"{_stem('prisma_charting')}.json" + media_type = "application/json" + elif format in ("appraisal_markdown", "appraisal_json"): + target = single_result or result + pr = getattr(target, "prisma_review", None) + appraisal_results = ( + (pr.methods.critical_appraisal_results if pr and pr.methods else None) + or getattr(target, "structured_appraisal_results", None) + or [] + ) + if not appraisal_results: + raise HTTPException( + status_code=422, + detail=( + "Critical appraisal export not available — this review has no " + "appraisal results. Re-run with a critical_appraisal config." + ), + ) + if format == "appraisal_markdown": + content = to_appraisal_markdown(target) + filename = f"{_stem('prisma_appraisal')}.md" + media_type = "text/markdown" + else: + content = to_appraisal_json(target) + filename = f"{_stem('prisma_appraisal')}.json" + media_type = "application/json" + elif format in ("narrative_summary_markdown", "narrative_summary_json"): + target = single_result or result + if not getattr(target, "narrative_rows", None): + raise HTTPException( + status_code=422, + detail=( + "Narrative summary export not available — this review has no narrative rows." 
+ ), + ) + if format == "narrative_summary_markdown": + content = to_narrative_summary_markdown(target) + filename = f"{_stem('prisma_narrative_summary')}.md" + media_type = "text/markdown" + else: + content = to_narrative_summary_json(target) + filename = f"{_stem('prisma_narrative_summary')}.json" + media_type = "application/json" + else: + raise HTTPException(status_code=400, detail=f"Unknown format: {format}") + + return StreamingResponse( + iter([content]), + media_type=media_type, + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + ) + + +@router.get("/synth-scholar/reviews/{review_id}/log", tags=["SynthScholar — reviews"]) +async def get_pipeline_log( + review_id: str, + user: Annotated[dict, Depends(get_current_user)], +): + """Get the full pipeline execution log.""" + session = await _session_or_404(review_id, user) + log_entries = list(session.pipeline_log) + log_events = [ + {"step": i + 1, "message": msg, "timestamp": ts} + for i, (ts, msg) in enumerate(_parse_log_entry(e) for e in log_entries) + ] + return { + "review_id": review_id, + "status": session.status.value, + "step_count": session.progress_step, + "log": log_entries, + "log_events": log_events, + } + + +@router.patch( + "/synth-scholar/reviews/{review_id}/visibility", + response_model=ReviewSummaryResponse, + tags=["SynthScholar — reviews"], +) +async def set_review_visibility( + review_id: str, + body: VisibilityRequest, + user: Annotated[dict, Depends(get_current_user)], +): + """Toggle public/private visibility. Mirrors share_to_cache.""" + session = await _session_or_404(review_id, user) + await review_store.set_visibility(review_id, body.is_public) + session.is_public = body.is_public + session.share_to_cache = body.is_public + await _apply_cache_sharing(review_id, body.is_public) + _r = session.result + return ReviewSummaryResponse( + review_id=session.review_id, + status=session.status, + title=session.protocol.title if session.protocol else "", + created_at=session.created_at, + completed_at=session.completed_at, + flow=_to_flow_response(_r.flow) if _r and not isinstance(_r, CompareReviewResult) else None, + included_count=len(_r.included_articles) if _r and not isinstance(_r, CompareReviewResult) else 0, + is_public=session.is_public, + share_to_cache=session.share_to_cache, + error=session.error, + ) + + +@router.patch( + "/synth-scholar/reviews/{review_id}/cache-sharing", + response_model=ReviewSummaryResponse, + tags=["SynthScholar — reviews"], +) +async def set_cache_sharing( + review_id: str, + body: CacheSharingRequest, + user: Annotated[dict, Depends(get_current_user)], +): + """Toggle whether this review's cache is available to other users.""" + session = await _session_or_404(review_id, user) + await review_store.set_cache_sharing(review_id, body.share_to_cache) + session.share_to_cache = body.share_to_cache + await _apply_cache_sharing(review_id, body.share_to_cache) + _r2 = session.result + return ReviewSummaryResponse( + review_id=session.review_id, + status=session.status, + title=session.protocol.title if session.protocol else "", + created_at=session.created_at, + completed_at=session.completed_at, + flow=_to_flow_response(_r2.flow) if _r2 and not isinstance(_r2, CompareReviewResult) else None, + included_count=len(_r2.included_articles) if _r2 and not isinstance(_r2, CompareReviewResult) else 0, + is_public=session.is_public, + share_to_cache=session.share_to_cache, + error=session.error, + ) + + +async def _apply_cache_sharing(review_id: str, is_shared: bool) -> None: 
+ """Update is_shared on any existing review_cache entries for this review.""" + from .database import DATABASE_URL + try: + from synthscholar.cache.store import CacheStore # type: ignore[import-not-found] + pg_dsn = DATABASE_URL.replace("+asyncpg", "") + store = CacheStore(dsn=pg_dsn) + await store.connect() + await store.set_sharing(review_id, is_shared) + await store.close() + except Exception: + pass # Cache may not be initialised; safe to ignore + + +@router.delete("/synth-scholar/reviews/{review_id}", tags=["SynthScholar — reviews"]) +async def delete_review( + review_id: str, + user: Annotated[dict, Depends(get_current_user)], +): + """Delete a review session.""" + session = await _session_or_404(review_id, user) + if session.status == ReviewStatus.RUNNING: + raise HTTPException( + status_code=409, + detail="Cannot delete a running review. Wait for completion.", + ) + await review_store.delete(review_id) + return {"deleted": review_id} + + +# ────────────────────── Search ──────────────────────────────────────── + +def _psycopg_dsn() -> str: + from .database import DATABASE_URL + return DATABASE_URL.replace("+asyncpg", "") + + +@router.post( + "/synth-scholar/search/literature", + response_model=LiteratureSearchResponse, + tags=["SynthScholar — search"], +) +async def search_literature( + req: LiteratureSearchRequest, + _user: Annotated[dict, Depends(get_current_user)], +): + """Search the cached article corpus (keyword / by_title / semantic).""" + from synthscholar.cache.article_store import ArticleStore # type: ignore[import-not-found] + + store = ArticleStore(dsn=_psycopg_dsn()) + try: + await store.connect() + try: + if req.mode == "semantic": + articles = await store.search_semantic(req.query, limit=req.top) + elif req.mode == "by_title": + articles = await store.search_by_title(req.query, limit=req.top) + else: + articles = await store.search_by_keyword(req.query, limit=req.top) + except RuntimeError as exc: + raise HTTPException(status_code=503, detail=str(exc)) + + results = [ + SearchArticleResult( + pmid=a.pmid, title=a.title, abstract=a.abstract, authors=a.authors, + journal=a.journal, year=a.year, doi=a.doi, pmc_id=a.pmc_id, + source=getattr(a, "source", "") or "", + similarity=getattr(a, "similarity", None), + ) + for a in articles + ] + + synthesis = None + if req.summarize and articles: + api_key = _get_api_key() + from synthscholar.agents import AgentDeps, run_search_synthesis # type: ignore[import-not-found] + + deps = AgentDeps( + protocol=ReviewProtocol(title=req.query, objective=req.query), + api_key=api_key, + model_name=req.summary_model, + ) + try: + synth = await run_search_synthesis(req.query, articles, deps, top_k=req.summary_top) + synthesis = SearchSynthesisResponse( + query=synth.query, + n_articles_synthesized=synth.n_articles_synthesized, + overview=synth.overview, + overall_caveats=getattr(synth, "overall_caveats", "") or "", + groups=[ + SearchSynthesisGroup( + label=g.label, + n_studies=g.n_studies, + aggregate_finding=g.aggregate_finding, + representative_pmids=list(getattr(g, "representative_pmids", []) or []), + caveats=getattr(g, "caveats", "") or "", + ) + for g in synth.groups + ], + ) + except Exception as exc: + raise HTTPException(status_code=502, detail=f"Search synthesis failed: {exc}") + + return LiteratureSearchResponse( + query=req.query, mode=req.mode, results=results, synthesis=synthesis, + ) + finally: + await store.close() + + +@router.post( + "/synth-scholar/search/reviews", + response_model=ReviewSearchResponse, + 
tags=["SynthScholar — search"], +) +async def search_reviews( + req: ReviewSearchRequest, + _user: Annotated[dict, Depends(get_current_user)], +): + """Search past completed reviews stored in the cache (shared or owned).""" + from synthscholar.cache.store import CacheStore # type: ignore[import-not-found] + + store = CacheStore(dsn=_psycopg_dsn()) + try: + await store.connect() + try: + if req.mode == "semantic": + entries = await store.search_reviews_semantic( + req.query, limit=req.top, include_expired=req.include_expired, + ) + else: + entries = await store.search_reviews_keyword( + req.query, limit=req.top, include_expired=req.include_expired, + ) + except RuntimeError as exc: + raise HTTPException(status_code=503, detail=str(exc)) + + results = [] + for e in entries: + crit = e.criteria_json or {} + results.append( + SearchReviewResult( + review_id=getattr(e, "review_id", "") or "", + criteria_fingerprint=e.criteria_fingerprint, + title=str(crit.get("title", "")), + research_question=str(crit.get("question", "") or crit.get("research_question", "")), + model_name=e.model_name, + created_at=e.created_at, + similarity=getattr(e, "similarity", None), + ) + ) + return ReviewSearchResponse(query=req.query, mode=req.mode, results=results) + finally: + await store.close() diff --git a/ml_service/core/synth_scholar/schemas.py b/ml_service/core/synth_scholar/schemas.py new file mode 100644 index 0000000..9800995 --- /dev/null +++ b/ml_service/core/synth_scholar/schemas.py @@ -0,0 +1,368 @@ +"""API request / response schemas. + +Thin wrappers around synthscholar models for API boundaries, plus +review-session management models. Imports from `synthscholar.models` are +late-bound below so this module loads even when the library hasn't been +installed yet (e.g. lint / typecheck environments without the heavy +dependency tree). +""" + +from __future__ import annotations + +from enum import Enum +from typing import Literal, Optional, Any +from datetime import datetime + +from pydantic import BaseModel, Field, model_validator + +from synthscholar.models import ( # type: ignore[import-not-found] + RoBTool, + DataChartingRubric, + PRISMANarrativeRow, + CriticalAppraisalRubric, + GroundingValidationResult, +) + + +# ────────────────────── Request Schemas ──────────────────────────────── + +class ProtocolRequest(BaseModel): + """Create or update a review protocol.""" + title: str = Field(..., min_length=1, max_length=500) + objective: str = "" + pico_population: str = "" + pico_intervention: str = "" + pico_comparison: str = "" + pico_outcome: str = "" + inclusion_criteria: str = "" + exclusion_criteria: str = "" + databases: list[str] = Field( + default_factory=lambda: [ + "pubmed", "biorxiv", "medrxiv", + "europe_pmc", "openalex", "crossref", "doaj", + "semantic_scholar", "arxiv", "core", + ] + ) + date_range_start: str = "" + date_range_end: str = "" + max_hops: int = Field(default=1, ge=0, le=10) + registration_number: str = "" + protocol_url: str = "" + funding_sources: str = "" + competing_interests: str = "" + rob_tool: RoBTool = RoBTool.ROB_2 + charting_questions: list[str] = Field(default_factory=list) + appraisal_domains: list[str] = Field(default_factory=list) + + grouping_dimension: str = Field( + default="disorder_cohort", + description=( + "DataChartingRubric attribute used for bucketing during per-group analysis." 
+ ), + ) + default_group_questions: list[str] = Field(default_factory=list, max_length=10) + per_group_questions: dict[str, list[str]] = Field(default_factory=dict) + + +class RunReviewRequest(BaseModel): + """Start a review pipeline run.""" + protocol: ProtocolRequest + model: str = "anthropic/claude-sonnet-4" + max_results_per_query: int = Field(default=20, ge=5, le=1000) + related_depth: int = Field(default=1, ge=0, le=10) + biorxiv_days: int = Field(default=180, ge=30, le=730) + enable_cache: bool = True + extract_data: bool = True + data_items: list[str] = Field(default_factory=list) + auto_confirm: bool = True + max_plan_iterations: int = Field(default=3, ge=1, le=10) + output_synthesis_style: Literal["paragraph", "question_answer", "bullet_list", "table"] = "paragraph" + max_articles: Optional[int] = Field(default=None, ge=10, le=10000) + concurrency: int = Field(default=5, ge=1, le=50) + # Caller-supplied OpenRouter API key. The frontend resolves this from the + # signed-in user's personal key (sessionStorage) or the admin-shared key + # (admin_setting `shared.openrouter_api_key`). The backend uses it as the + # primary key for the run; if absent, falls back to OPENROUTER_API_KEY env. + # Not persisted to the review's run_request_json — see routes.py. + openrouter_api_key: Optional[str] = Field(default=None, exclude=True, repr=False) + + +class CompareRunRequest(BaseModel): + """Start a compare-mode review with 2–5 models.""" + protocol: ProtocolRequest + compare_models: list[str] = Field(..., min_length=2, max_length=5) + consensus_model: Optional[str] = None + max_results_per_query: int = Field(default=20, ge=5, le=1000) + related_depth: int = Field(default=1, ge=0, le=10) + biorxiv_days: int = Field(default=180, ge=30, le=730) + enable_cache: bool = True + extract_data: bool = True + data_items: list[str] = Field(default_factory=list) + auto_confirm: bool = True + max_plan_iterations: int = Field(default=3, ge=1, le=10) + output_synthesis_style: Literal["paragraph", "question_answer", "bullet_list", "table"] = "paragraph" + max_articles: Optional[int] = Field(default=None, ge=10, le=10000) + concurrency: int = Field(default=5, ge=1, le=50) + # See RunReviewRequest.openrouter_api_key — same semantics. 
+ openrouter_api_key: Optional[str] = Field(default=None, exclude=True, repr=False) + + +class RetryRequest(BaseModel): + enable_cache: Optional[bool] = None + resume: bool = True + + +class PlanResponseRequest(BaseModel): + approved: bool + feedback: str = "" + + @model_validator(mode="after") + def feedback_required_when_rejected(self) -> "PlanResponseRequest": + if not self.approved and not self.feedback.strip(): + raise ValueError("feedback is required when approved=false") + return self + + +# ────────────────────── Response Schemas ─────────────────────────────── + +class ReviewStatus(str, Enum): + PENDING = "pending" + PLAN_PENDING = "plan_pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class ArticleSummary(BaseModel): + pmid: str + title: str + authors: str + year: str + journal: str + doi: str + source: str + inclusion_status: str + rob_overall: str = "" + study_design: str = "" + quality_score: float = 0.0 + + +class EvidenceSpanResponse(BaseModel): + text: str + paper_pmid: str + paper_title: str + section: str + relevance_score: float + claim: str + doi: str + + +class ScreeningLogResponse(BaseModel): + pmid: str + title: str + decision: str + reason: str + stage: str + + +class GRADEResponse(BaseModel): + outcome: str + overall_certainty: str + summary: str + domains: dict = Field(default_factory=dict) + + +class FlowResponse(BaseModel): + """PRISMA flow counts.""" + db_pubmed: int = 0 + db_biorxiv: int = 0 + db_medrxiv: int = 0 + db_related: int = 0 + db_hops: int = 0 + db_other_sources: dict[str, int] = Field(default_factory=dict) + total_identified: int = 0 + duplicates_removed: int = 0 + after_dedup: int = 0 + screened_title_abstract: int = 0 + excluded_title_abstract: int = 0 + sought_fulltext: int = 0 + not_retrieved: int = 0 + assessed_eligibility: int = 0 + excluded_eligibility: int = 0 + excluded_reasons: dict[str, int] = Field(default_factory=dict) + included_synthesis: int = 0 + + +class ReviewSummaryResponse(BaseModel): + review_id: str + status: ReviewStatus + title: str + created_at: str + completed_at: Optional[str] = None + flow: Optional[FlowResponse] = None + included_count: int = 0 + is_public: bool = False + share_to_cache: bool = False + error: Optional[str] = None + stage: Optional[str] = None + stage_index: Optional[int] = None + stage_total: Optional[int] = None + stage_done: Optional[int] = None + stage_remaining: Optional[int] = None + articles_included: Optional[int] = None + + +class CompareReviewSummaryResponse(BaseModel): + review_id: str + status: ReviewStatus + compare_models: list[str] + created_at: str + + +class ReviewDetailResponse(BaseModel): + review_id: str + status: ReviewStatus + title: str + created_at: str + completed_at: Optional[str] = None + is_public: bool = False + share_to_cache: bool = False + enable_cache: Optional[bool] = None + last_completed_step: int = 0 + run_request: Optional[dict] = None + research_question: str = "" + flow: Optional[FlowResponse] = None + included_articles: list[ArticleSummary] = Field(default_factory=list) + screening_log: list[ScreeningLogResponse] = Field(default_factory=list) + evidence_spans: list[EvidenceSpanResponse] = Field(default_factory=list) + synthesis_text: str = "" + bias_assessment: str = "" + limitations: str = "" + grade_assessments: list[GRADEResponse] = Field(default_factory=list) + search_queries: list[str] = Field(default_factory=list) + data_charting_rubrics: list[DataChartingRubric] = Field(default_factory=list) + 
narrative_rows: list[PRISMANarrativeRow] = Field(default_factory=list) + critical_appraisals: list[CriticalAppraisalRubric] = Field(default_factory=list) + grounding_validation: Optional[GroundingValidationResult] = None + structured_abstract: str = "" + introduction_text: str = "" + conclusions_text: str = "" + quality_checklist: dict[str, bool] = Field(default_factory=dict) + error: Optional[str] = None + compare_result: Optional[dict] = None + per_group_analysis: Optional[dict] = None + + +class ProgressEvent(BaseModel): + """SSE progress event. + + `event_type` is the transport category (progress | history | plan_review | + completed | failed | cancelled | keepalive). `kind` is the pipeline-level + classification (log | stage_start | stage_done | article_start | + article_done | plan_ready | done) produced by the server-side parser. + """ + review_id: str + step: int + message: str + timestamp: str = "" + event_type: str = "progress" + source: Optional[str] = None + plan: Optional[dict] = None + + kind: str = "log" + stage: Optional[str] = None + stage_index: Optional[int] = None + stage_total: Optional[int] = None + stage_done: Optional[int] = None + stage_remaining: Optional[int] = None + articles_included: Optional[int] = None + + +class HealthResponse(BaseModel): + status: str = "ok" + version: str = "" + models: list[str] = Field(default_factory=list) + rob_tools: list[str] = Field(default_factory=list) + + +class VisibilityRequest(BaseModel): + is_public: bool + + +class CacheSharingRequest(BaseModel): + share_to_cache: bool + + +# ────────────────────── Search Schemas ──────────────────────────────── + +LiteratureSearchMode = Literal["keyword", "by_title", "semantic"] +ReviewSearchMode = Literal["keyword", "semantic"] + + +class LiteratureSearchRequest(BaseModel): + query: str = Field(..., min_length=1) + mode: LiteratureSearchMode = "keyword" + top: int = Field(default=20, ge=1, le=200) + summarize: bool = False + summary_top: int = Field(default=15, ge=1, le=100) + summary_model: str = "anthropic/claude-sonnet-4" + + +class ReviewSearchRequest(BaseModel): + query: str = Field(..., min_length=1) + mode: ReviewSearchMode = "keyword" + top: int = Field(default=20, ge=1, le=200) + include_expired: bool = False + + +class SearchArticleResult(BaseModel): + pmid: str = "" + title: str = "" + abstract: str = "" + authors: str = "" + journal: str = "" + year: str = "" + doi: str = "" + pmc_id: str = "" + source: str = "" + similarity: Optional[float] = None + + +class SearchReviewResult(BaseModel): + review_id: str = "" + criteria_fingerprint: str + title: str = "" + research_question: str = "" + model_name: str + created_at: datetime + similarity: Optional[float] = None + + +class SearchSynthesisGroup(BaseModel): + label: str + n_studies: int + aggregate_finding: str + representative_pmids: list[str] = Field(default_factory=list) + caveats: str = "" + + +class SearchSynthesisResponse(BaseModel): + query: str + n_articles_synthesized: int + overview: str + overall_caveats: str = "" + groups: list[SearchSynthesisGroup] = Field(default_factory=list) + + +class LiteratureSearchResponse(BaseModel): + query: str + mode: LiteratureSearchMode + results: list[SearchArticleResult] + synthesis: Optional[SearchSynthesisResponse] = None + + +class ReviewSearchResponse(BaseModel): + query: str + mode: ReviewSearchMode + results: list[SearchReviewResult] diff --git a/ml_service/core/synth_scholar/store.py b/ml_service/core/synth_scholar/store.py new file mode 100644 index 0000000..8e188c4 --- 
/dev/null +++ b/ml_service/core/synth_scholar/store.py @@ -0,0 +1,715 @@ +"""PostgreSQL-backed review session store — DB-only architecture. + +All persistent state (status, log, stage counters, result) lives exclusively +in Postgres. The only things kept in memory are the runtime primitives that +*cannot* be serialised: asyncio.Event, threading.Event, asyncio.Task, and a +tiny mirror of the last-known hot-path values so the SSE loop can emit events +without an extra DB round-trip on every message. + +Ported from aep-knowledge-synthesis with one extension: the `owner_email` +column links reviews to the JWT subject that created them, so the routes +layer can scope listings per user (admins see all). +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Optional + +from sqlalchemy import select, delete as sa_delete, update as sa_update + +from synthscholar.models import ( # type: ignore[import-not-found] + PRISMAReviewResult, + CompareReviewResult, + ReviewPlan, + ReviewProtocol, +) + +from .database import async_session +from .db_models import ReviewRow +from .progress_events import classify, merge_into_state +from .schemas import ReviewStatus + +logger = logging.getLogger(__name__) + +# How many progress messages to buffer before flushing to DB. Reduces write +# amplification while keeping intermediate state reasonably fresh. +_WRITE_BATCH = 10 + + +# ── Runtime-only state ────────────────────────────────────────────────── + +@dataclass +class _LiveState: + """Per-session in-memory state for SSE subscribers and pipeline control. + + Only runtime OS-level primitives and a tiny mirror of hot-path values + live here. Everything else is in Postgres. 
+ """ + run_token: str = field(default_factory=lambda: uuid.uuid4().hex) + + progress_event: asyncio.Event = field(default_factory=asyncio.Event) + plan_gate: threading.Event = field(default_factory=threading.Event) + plan_response: list = field(default_factory=lambda: [True]) + pending_plan: Optional[Any] = None + plan_notify: asyncio.Event = field(default_factory=asyncio.Event) + main_loop: Optional[Any] = None + cancel_task: Optional[Any] = None + cancel_flag: threading.Event = field(default_factory=threading.Event) + + latest_message: str = "" + latest_event: Optional[dict] = None + status: str = "pending" + progress_step: int = 0 + stage: Optional[str] = None + stage_index: Optional[int] = None + stage_total: Optional[int] = None + stage_done: int = 0 + stage_remaining: Optional[int] = None + articles_included: Optional[int] = None + + _write_counter: int = 0 + + +# ── ReviewSession ──────────────────────────────────────────────────────── + +@dataclass +class ReviewSession: + """Thin DTO that mirrors the DB row plus the attached _LiveState.""" + review_id: str + status: ReviewStatus = ReviewStatus.PENDING + protocol: Optional[ReviewProtocol] = None + result: Optional[Any] = None + pipeline_log: list[str] = field(default_factory=list) + progress_step: int = 0 + created_at: str = "" + completed_at: Optional[str] = None + error: Optional[str] = None + is_public: bool = False + share_to_cache: bool = False + run_request: Optional[dict] = None + checkpoint_json: Optional[dict] = None + last_completed_step: int = 0 + owner_email: Optional[str] = None + + stage: Optional[str] = None + stage_index: Optional[int] = None + stage_total: Optional[int] = None + stage_done: int = 0 + stage_remaining: Optional[int] = None + articles_included: Optional[int] = None + latest_event: Optional[dict] = None + + # DB-mirrored plan fields. Populated by _row_to_session so the SSE + # generator can replay plan-pending events on any worker, not just the + # one running the pipeline. 
+ pending_plan_db: Optional[dict] = None + pending_plan_iteration_db: Optional[int] = None + + _live: _LiveState = field(default_factory=_LiveState) + + @property + def _progress_event(self) -> asyncio.Event: + return self._live.progress_event + + @property + def _latest_message(self) -> str: + return self._live.latest_message + + @_latest_message.setter + def _latest_message(self, value: str): + self._live.latest_message = value + + def __post_init__(self): + if not self.created_at: + self.created_at = datetime.now(timezone.utc).isoformat() + + # ── Progress mutation ────────────────────────────────────────────── + + def update_progress(self, message: str): + self.progress_step += 1 + self._latest_message = message + self.pipeline_log.append( + f"[{datetime.now(timezone.utc).isoformat()}] {message}" + ) + if len(self.pipeline_log) > 5000: + self.pipeline_log = self.pipeline_log[-5000:] + self._apply_classified(message) + self._live.progress_event.set() + self._live.progress_event.clear() + self._live._write_counter += 1 + if self._live._write_counter >= _WRITE_BATCH: + self._live._write_counter = 0 + try: + loop = asyncio.get_running_loop() + loop.create_task(self._persist_progress()) + except RuntimeError: + pass + + def _apply_classified(self, message: str) -> dict: + event = dict(classify(message)) + state: dict = { + "stage": self.stage, + "stage_index": self.stage_index, + "stage_total": self.stage_total, + "stage_done": self.stage_done, + "stage_remaining": self.stage_remaining, + "articles_included": self.articles_included, + } + merge_into_state(state, event) + self.stage = state["stage"] + self.stage_index = state["stage_index"] + self.stage_total = state["stage_total"] + self.stage_done = state["stage_done"] or 0 + self.stage_remaining = state["stage_remaining"] + self.articles_included = state["articles_included"] + self.latest_event = event + self._live.stage = self.stage + self._live.stage_index = self.stage_index + self._live.stage_total = self.stage_total + self._live.stage_done = self.stage_done + self._live.stage_remaining = self.stage_remaining + self._live.articles_included = self.articles_included + self._live.latest_event = event + self._live.progress_step = self.progress_step + return event + + def rebuild_progress_state(self) -> None: + """Recompute cumulative progress state from the full pipeline_log.""" + self.stage = None + self.stage_index = None + self.stage_total = None + self.stage_done = 0 + self.stage_remaining = None + self.articles_included = None + for entry in self.pipeline_log: + msg = entry.split("] ", 1)[1] if entry.startswith("[") and "] " in entry else entry + self._apply_classified(msg) + + # ── DB persistence ───────────────────────────────────────────────── + + async def _persist_progress(self): + try: + effective_status = self._live.status or self.status.value + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values( + progress_step=self.progress_step, + pipeline_log=list(self.pipeline_log[-2000:]), + status=effective_status, + stage=self.stage, + stage_idx=self.stage_index, + stage_total=self.stage_total, + stage_done_count=self.stage_done or 0, + stage_remaining=self.stage_remaining, + articles_included=self.articles_included, + latest_message=self._live.latest_message or None, + ) + ) + await db.commit() + except Exception as exc: + logger.warning("[_persist_progress] %s: %s", self.review_id, exc) + + async def mark_completed(self, result: PRISMAReviewResult) -> 
None: + self.status = ReviewStatus.COMPLETED + self._live.status = ReviewStatus.COMPLETED.value + self.result = result + self.completed_at = datetime.now(timezone.utc).isoformat() + self._live.progress_event.set() + try: + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values( + status=self.status.value, + result_json=self.result.model_dump() if self.result else None, + completed_at=datetime.fromisoformat(self.completed_at), + progress_step=self.progress_step, + pipeline_log=list(self.pipeline_log), + stage=self.stage, + stage_idx=self.stage_index, + stage_total=self.stage_total, + stage_done_count=self.stage_done or 0, + stage_remaining=self.stage_remaining, + articles_included=self.articles_included, + latest_message=self._live.latest_message or None, + pending_plan_json=None, + pending_plan_iteration=None, + plan_response_json=None, + ) + ) + await db.commit() + except Exception as exc: + logger.error("[mark_completed] persist error %s: %s", self.review_id, exc) + self.status = ReviewStatus.FAILED + self._live.status = ReviewStatus.FAILED.value + self.error = f"Persist failed after completion: {exc}" + await self._persist_failed_minimal() + + async def mark_failed(self, error: str) -> None: + self.status = ReviewStatus.FAILED + self._live.status = ReviewStatus.FAILED.value + self.error = error + self.completed_at = datetime.now(timezone.utc).isoformat() + self._live.progress_event.set() + await self._persist_failed_minimal() + + async def _persist_failed_minimal(self) -> None: + try: + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values( + status=self.status.value, + error=self.error, + completed_at=datetime.fromisoformat(self.completed_at) if self.completed_at else None, + progress_step=self.progress_step, + pipeline_log=list(self.pipeline_log[-2000:]), + pending_plan_json=None, + pending_plan_iteration=None, + plan_response_json=None, + ) + ) + await db.commit() + except Exception as exc: + logger.critical("[_persist_failed_minimal] %s: %s", self.review_id, exc) + + async def mark_cancelled(self) -> None: + self.status = ReviewStatus.CANCELLED + self._live.status = ReviewStatus.CANCELLED.value + self.completed_at = datetime.now(timezone.utc).isoformat() + self._live.progress_event.set() + try: + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values( + status=self.status.value, + completed_at=datetime.fromisoformat(self.completed_at), + progress_step=self.progress_step, + pipeline_log=list(self.pipeline_log[-2000:]), + pending_plan_json=None, + pending_plan_iteration=None, + plan_response_json=None, + cancel_requested=False, + ) + ) + await db.commit() + except Exception as exc: + logger.error("[mark_cancelled] %s: %s", self.review_id, exc) + + async def mark_running(self) -> None: + self.status = ReviewStatus.RUNNING + self._live.status = ReviewStatus.RUNNING.value + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values( + status=ReviewStatus.RUNNING.value, + pending_plan_json=None, + pending_plan_iteration=None, + ) + ) + await db.commit() + + async def reset_for_retry(self, clear_checkpoint: bool = True): + self.status = ReviewStatus.PENDING + self._live.status = ReviewStatus.PENDING.value + self.progress_step = 0 + self.pipeline_log = [] + self.error = None + 
self.completed_at = None + self.result = None + self.stage = None + self.stage_index = None + self.stage_total = None + self.stage_done = 0 + self.stage_remaining = None + self.articles_included = None + self.latest_event = None + self._live = _LiveState() + self._live.status = ReviewStatus.PENDING.value + values: dict = dict( + status=ReviewStatus.PENDING.value, + progress_step=0, + pipeline_log=[], + error=None, + completed_at=None, + result_json=None, + stage=None, + stage_idx=None, + stage_total=None, + stage_done_count=0, + stage_remaining=None, + articles_included=None, + latest_message=None, + ) + if clear_checkpoint: + self.checkpoint_json = None + self.last_completed_step = 0 + values["checkpoint_json"] = None + values["last_completed_step"] = 0 + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values(**values) + ) + await db.commit() + + async def save_checkpoint(self, state: dict) -> None: + step = state.get("last_completed_step", 0) + self.checkpoint_json = state + self.last_completed_step = step + try: + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values(checkpoint_json=state, last_completed_step=step) + ) + await db.commit() + except Exception as exc: + logger.warning("[checkpoint] save failed for %s: %s", self.review_id, exc) + + async def clear_checkpoint(self) -> None: + self.checkpoint_json = None + self.last_completed_step = 0 + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values(checkpoint_json=None, last_completed_step=0) + ) + await db.commit() + + # ── Plan confirmation gate ────────────────────────────────────────── + + def set_plan_pending(self, plan: ReviewPlan, iteration: int) -> None: + self.status = ReviewStatus.PLAN_PENDING + self._live.status = ReviewStatus.PLAN_PENDING.value + self._live.pending_plan = plan + self.pipeline_log.append( + f"[{datetime.now().strftime('%H:%M:%S')}] Plan ready for review (iteration {iteration})" + ) + try: + loop = asyncio.get_running_loop() + loop.create_task(self._persist_plan_pending(plan, iteration)) + except RuntimeError: + pass + + async def _persist_plan_pending(self, plan: ReviewPlan, iteration: int) -> None: + """Stash pending_plan in Postgres so SSE on any worker can replay it + and so submit_plan_response can write a response that the owning + worker's confirm_callback will pick up via DB poll.""" + try: + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values( + status=ReviewStatus.PLAN_PENDING.value, + pipeline_log=list(self.pipeline_log[-2000:]), + progress_step=self.progress_step, + pending_plan_json=plan.model_dump(), + pending_plan_iteration=iteration, + plan_response_json=None, + ) + ) + await db.commit() + except Exception as exc: + logger.warning("[_persist_plan_pending] %s: %s", self.review_id, exc) + + async def _clear_plan_state(self) -> None: + """Clear pending_plan/response columns on terminal transitions or + once the pipeline consumes the response.""" + try: + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values( + pending_plan_json=None, + pending_plan_iteration=None, + plan_response_json=None, + ) + ) + await db.commit() + except Exception as exc: + logger.warning("[_clear_plan_state] %s: %s", 
self.review_id, exc) + + async def write_plan_response(self, response: "bool | str") -> None: + """Stash the user's plan response in Postgres so the owning worker's + confirm_callback picks it up via its DB-poll fallback.""" + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values(plan_response_json={"response": response}) + ) + await db.commit() + + async def claim_plan_response(self) -> "bool | str | None": + """Atomically read+clear the plan_response_json column. Returns the + response if one was waiting, or None. Used by confirm_callback's + DB-poll loop to drain a response written by another worker.""" + async with async_session() as db: + row = await db.get(ReviewRow, self.review_id) + if not row or not row.plan_response_json: + return None + response = row.plan_response_json.get("response") + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values(plan_response_json=None) + ) + await db.commit() + return response + + async def request_cancel(self) -> None: + """Set the cancel_requested flag in Postgres so the owning worker's + progress and confirm callbacks pick it up. The status flip to + CANCELLED is performed by the pipeline coroutine itself, not here — + marking it cancelled from a non-owning worker would race with a + still-running pipeline.""" + async with async_session() as db: + await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == self.review_id) + .values(cancel_requested=True) + ) + await db.commit() + + async def is_cancel_requested(self) -> bool: + async with async_session() as db: + row = await db.get(ReviewRow, self.review_id) + return bool(row and row.cancel_requested) + + def signal_plan_notify(self) -> None: + self._live.plan_notify.set() + self._live.plan_notify.clear() + + def resolve_plan(self, response: "bool | str") -> None: + self._live.plan_response.clear() + self._live.plan_response.append(response) + self._live.plan_gate.set() + + +# ── Hydration helpers ──────────────────────────────────────────────────── + +def _row_to_session(row: ReviewRow) -> ReviewSession: + """Convert a DB row to a ReviewSession (without live SSE state).""" + protocol = None + if row.protocol_json: + protocol = ReviewProtocol.model_validate(row.protocol_json) + + result = None + if row.result_json: + is_compare = row.run_request_json and row.run_request_json.get("compare_mode") + if is_compare: + result = CompareReviewResult.model_validate(row.result_json) + else: + result = PRISMAReviewResult.model_validate(row.result_json) + + session = ReviewSession( + review_id=row.review_id, + status=ReviewStatus(row.status), + protocol=protocol, + result=result, + pipeline_log=list(row.pipeline_log) if row.pipeline_log else [], + progress_step=row.progress_step, + created_at=row.created_at.isoformat() if row.created_at else "", + completed_at=row.completed_at.isoformat() if row.completed_at else None, + error=row.error, + is_public=row.is_public, + share_to_cache=row.share_to_cache, + run_request=row.run_request_json, + checkpoint_json=row.checkpoint_json, + last_completed_step=row.last_completed_step or 0, + owner_email=row.owner_email, + ) + if row.stage is not None or row.stage_idx is not None: + session.stage = row.stage + session.stage_index = row.stage_idx + session.stage_total = row.stage_total + session.stage_done = row.stage_done_count or 0 + session.stage_remaining = row.stage_remaining + session.articles_included = row.articles_included + 
else: + session.rebuild_progress_state() + # Surface the persisted plan-pending state so SSE generators on a worker + # that doesn't own the runtime can still emit a plan_review event with + # the plan body. Live runtime state, when present, takes precedence. + session.pending_plan_db = row.pending_plan_json + session.pending_plan_iteration_db = row.pending_plan_iteration + return session + + +# ── ReviewStore ────────────────────────────────────────────────────────── + +class ReviewStore: + """PostgreSQL-only store. Runtime primitives live in _runtime; everything + else is in Postgres. There is no in-memory data cache — every get() reads + from the DB so status is always authoritative.""" + + def __init__(self): + self._runtime: dict[str, _LiveState] = {} + + async def create( + self, + protocol: ReviewProtocol, + run_request: Optional[dict] = None, + owner_email: Optional[str] = None, + ) -> ReviewSession: + now = datetime.now(timezone.utc) + async with async_session() as db: + count_result = await db.execute(select(ReviewRow.review_id)) + counter = len(count_result.all()) + 1 + + review_id = f"review_{counter:04d}_{now.strftime('%Y%m%d%H%M%S')}" + + row = ReviewRow( + review_id=review_id, + status=ReviewStatus.PENDING.value, + title=protocol.title, + protocol_json=protocol.model_dump(), + pipeline_log=[], + progress_step=0, + created_at=now, + is_public=False, + share_to_cache=False, + run_request_json=run_request, + owner_email=owner_email, + ) + async with async_session() as db: + db.add(row) + await db.commit() + + session = ReviewSession( + review_id=review_id, + status=ReviewStatus.PENDING, + protocol=protocol, + created_at=now.isoformat(), + share_to_cache=False, + run_request=run_request, + owner_email=owner_email, + ) + live = _LiveState(status=ReviewStatus.PENDING.value) + session._live = live + self._runtime[review_id] = live + return session + + async def get(self, review_id: str) -> Optional[ReviewSession]: + async with async_session() as db: + row = await db.get(ReviewRow, review_id) + if not row: + return None + session = _row_to_session(row) + if review_id in self._runtime: + live = self._runtime[review_id] + session._live = live + if live.status: + try: + session.status = ReviewStatus(live.status) + except ValueError: + pass + session.stage = live.stage if live.stage is not None else session.stage + session.stage_index = live.stage_index if live.stage_index is not None else session.stage_index + session.stage_total = live.stage_total if live.stage_total is not None else session.stage_total + session.stage_done = live.stage_done or session.stage_done + session.stage_remaining = live.stage_remaining if live.stage_remaining is not None else session.stage_remaining + session.articles_included = live.articles_included if live.articles_included is not None else session.articles_included + session.progress_step = max(session.progress_step, live.progress_step) + return session + + async def list_for_owner(self, owner_email: Optional[str]) -> list[ReviewSession]: + """List reviews. 
If owner_email is None → list ALL (admin view).""" + async with async_session() as db: + stmt = select(ReviewRow).order_by(ReviewRow.created_at.desc()) + if owner_email is not None: + stmt = stmt.where(ReviewRow.owner_email == owner_email) + result = await db.execute(stmt) + rows = result.scalars().all() + + sessions = [] + for row in rows: + s = _row_to_session(row) + if row.review_id in self._runtime: + live = self._runtime[row.review_id] + s._live = live + if live.status: + try: + s.status = ReviewStatus(live.status) + except ValueError: + pass + s.stage = live.stage if live.stage is not None else s.stage + s.stage_index = live.stage_index if live.stage_index is not None else s.stage_index + s.stage_total = live.stage_total if live.stage_total is not None else s.stage_total + s.stage_done = live.stage_done or s.stage_done + s.stage_remaining = live.stage_remaining if live.stage_remaining is not None else s.stage_remaining + s.articles_included = live.articles_included if live.articles_included is not None else s.articles_included + s.progress_step = max(s.progress_step, live.progress_step) + sessions.append(s) + return sessions + + async def delete(self, review_id: str) -> bool: + self._runtime.pop(review_id, None) + async with async_session() as db: + result = await db.execute( + sa_delete(ReviewRow).where(ReviewRow.review_id == review_id) + ) + await db.commit() + return result.rowcount > 0 + + async def set_visibility(self, review_id: str, is_public: bool) -> bool: + async with async_session() as db: + result = await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == review_id) + .values(is_public=is_public, share_to_cache=is_public) + ) + await db.commit() + return result.rowcount > 0 + + async def set_cache_sharing(self, review_id: str, share_to_cache: bool) -> bool: + async with async_session() as db: + result = await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.review_id == review_id) + .values(share_to_cache=share_to_cache) + ) + await db.commit() + return result.rowcount > 0 + + def evict(self, review_id: str): + self._runtime.pop(review_id, None) + + +async def fix_stuck_reviews() -> int: + """Mark in-progress reviews as FAILED at server startup.""" + async with async_session() as db: + result = await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.status.in_(["running", "pending", "plan_pending"])) + .values( + status=ReviewStatus.FAILED.value, + error="Server restarted while review was in progress — please retry.", + ) + ) + await db.commit() + return result.rowcount + + +# Singleton +review_store = ReviewStore() From 405a64f191c01ecb05971000c3b3f25a3cb456ec Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 16:29:03 -0400 Subject: [PATCH 04/22] superadmin access everything --- usermanagement_service/core/database.py | 29 ++++++++++++++- .../core/routers/user_management.py | 35 +++++++------------ 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/usermanagement_service/core/database.py b/usermanagement_service/core/database.py index 660d7ae..d64cd19 100644 --- a/usermanagement_service/core/database.py +++ b/usermanagement_service/core/database.py @@ -1534,6 +1534,33 @@ async def check_access( profile_id: Optional[int], role_names: List[str], ) -> tuple[bool, str]: + """Decide whether a user may access `page_key`. 
+ + Roles are re-read from the DB when `profile_id` is supplied — the JWT's + baked-in `roles` claim becomes stale the moment an admin grants or + revokes a role mid-session, and that staleness was the source of the + "I have access but the tool still says disabled" glitch. The fresh DB + read is one indexed SELECT per call; it makes role grants take effect + on the next request rather than on next sign-in. + + SuperAdmin always passes — that role is the platform's protected + oversight tier and explicitly bypasses per-page rules. Regular Admin + is *not* bypassed: admins can grant themselves access through + /admin/page-access if they want it, the same as any other role. + """ + # Fresh roles from the DB — supersede the JWT claim when both are + # available. Falls back to the supplied `role_names` for code paths + # that don't carry a profile_id (e.g. unauthenticated public pages). + effective_roles = list(role_names) + if profile_id is not None: + db_roles = await user_role_repo.get_user_role_names(session, profile_id) + if db_roles: + effective_roles = db_roles + + # SuperAdmin bypass only — see docstring. + if "SuperAdmin" in effective_roles: + return (True, "role") + page = await self.get_by_key(session, page_key) if page is None: return (False, "not_found") @@ -1544,7 +1571,7 @@ async def check_access( if profile_id in user_ids: return (True, "user_override") allowed_roles = await self.get_allowed_roles(session, page.id) - if any(r in allowed_roles for r in role_names): + if any(r in allowed_roles for r in effective_roles): return (True, "role") return (False, "denied") diff --git a/usermanagement_service/core/routers/user_management.py b/usermanagement_service/core/routers/user_management.py index d69e8c8..3e5fc77 100644 --- a/usermanagement_service/core/routers/user_management.py +++ b/usermanagement_service/core/routers/user_management.py @@ -301,7 +301,10 @@ async def create_profile(jwt_user: Annotated[dict, Depends(get_current_user)], organizations_data = profile.organizations or [] education_data = profile.education or [] expertise_data = profile.expertise_areas or [] - roles_data = profile.roles or [] + # `roles` from the request body is ignored on the public endpoint — + # role assignment is admin-only via /api/admin/users/{id}/roles. + # Defaults are seeded by the OAuth callback (Curator) and the + # bootstrap superadmin allowlist (SuperAdmin + Admin). profile_data = profile.dict(exclude={'countries', 'organizations', 'education', 'expertise_areas', 'roles'}) profile_instance = await user_profile_repo.create_or_update_profile( @@ -353,15 +356,8 @@ async def create_profile(jwt_user: Annotated[dict, Depends(get_current_user)], years_experience=expertise_data.years_experience ) - # Add roles if provided - for role_data in roles_data: - await user_role_repo.assign_role( - session=session, - profile_id=profile_instance.id, - role=role_data.role, - is_active=role_data.is_active, - expires_at=role_data.expires_at - ) + # NOTE: roles are intentionally NOT applied from the request body + # on this public endpoint. Use /api/admin/users/{id}/roles instead. # Log activity with enhanced information await user_activity_repo.log_activity( @@ -451,7 +447,8 @@ async def update_profile( organizations_data = update_data.pop('organizations', None) education_data = update_data.pop('education', None) expertise_data = update_data.pop('expertise_areas', None) - roles_data = update_data.pop('roles', None) + # `roles` is dropped from the public update — see note below. 
+ update_data.pop('roles', None) # Update profile with new data for key, value in update_data.items(): @@ -517,17 +514,11 @@ async def update_profile( years_experience=exp_data['years_experience'] ) - # Update roles if provided - if roles_data is not None: - logger.info(f"Updating roles: {roles_data}") - await user_role_repo.remove_all_roles(session, profile.id) - for role_data in roles_data: - await user_role_repo.assign_role( - session=session, - profile_id=profile.id, - role=role_data['role'], - is_active=role_data['is_active'] - ) + # NOTE: roles are intentionally NOT updated from this endpoint. + # Self-service role assignment was a privilege-escalation hole — + # any user could grant themselves Admin. Roles are now managed + # exclusively via /api/admin/users/{id}/roles (require_admin). + # The default role (Curator) is seeded by the OAuth callback. # Log activity with enhanced information await user_activity_repo.log_activity( From a9fd57bd8943c8eb76dee7901992ac3e5f4506a8 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 16:29:18 -0400 Subject: [PATCH 05/22] updated network --- docker-compose.unified.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docker-compose.unified.yml b/docker-compose.unified.yml index e271967..feb90ea 100644 --- a/docker-compose.unified.yml +++ b/docker-compose.unified.yml @@ -163,4 +163,9 @@ volumes: networks: brainkb-network: external: true + # Pin the network name so it stays the same regardless of the compose + # project (folder) name. The brainkb-ui stack joins the same network by + # this fixed name. Create it once before bringing up either stack: + # docker network create brainkb-network + name: brainkb-network From 634d032483dfb7895b815d78bf94763838ca8e3b Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 16:29:58 -0400 Subject: [PATCH 06/22] updated docker to include python 3.11 --- Dockerfile.unified | 5 +++-- ml_service/Dockerfile | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Dockerfile.unified b/Dockerfile.unified index 68e363c..e2f9fd8 100644 --- a/Dockerfile.unified +++ b/Dockerfile.unified @@ -5,7 +5,8 @@ # Oxigraph will be run as a separate service in docker-compose or use the fallback script # Main build stage -FROM python:3.10-slim +# Python 3.11+ required: synthscholar (ml_service) depends on it. 
+FROM python:3.11-slim # Set metadata LABEL project="BrainyPedia" \ @@ -46,7 +47,7 @@ COPY ml_service/ /app/ml_service/ WORKDIR /app/ml_service RUN pip install --use-deprecated=legacy-resolver "structsense==0.0.4" || \ pip install --use-deprecated=legacy-resolver --no-deps "structsense==0.0.4" && \ - pip install -r requirements.txt + pip install --use-deprecated=legacy-resolver -r requirements.txt # Copy usermanagement_service COPY usermanagement_service/ /app/usermanagement_service/ diff --git a/ml_service/Dockerfile b/ml_service/Dockerfile index d4527d2..8771c1b 100644 --- a/ml_service/Dockerfile +++ b/ml_service/Dockerfile @@ -1,4 +1,4 @@ -FROM tiangolo/uvicorn-gunicorn-fastapi:python3.10 +FROM tiangolo/uvicorn-gunicorn-fastapi:python3.11 # Set metadata for the image LABEL project="BrainyPedia" \ From 48845ef942507a8cb6a0ccae2c2b085c8a77f9e2 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 16:31:19 -0400 Subject: [PATCH 07/22] updated readme to include synthscholar + deployment information --- README.unified-docker.md | 8 +++++++- readme.md | 9 +++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/README.unified-docker.md b/README.unified-docker.md index fe9cf69..423b21d 100644 --- a/README.unified-docker.md +++ b/README.unified-docker.md @@ -4,11 +4,16 @@ This Dockerfile deploys the following BrainKB backend services: - **APItokenmanager** (Django) - Port 8000 - **query_service** (FastAPI) - Port 8010 - **ml_service** (FastAPI) - Port 8007 + - Hosts the **SynthScholar** PRISMA literature-review pipeline at + `/api/synth-scholar/*` alongside the existing structsense routes. + SynthScholar reuses the unified `brainkb` Postgres database (no extra + DSN config required); see the SynthScholar block in `env.template` for + the optional API keys it consumes (OpenRouter / NCBI / Semantic Scholar + / CORE). - **usermanagement_service** (FastAPI) - Port 8004 - **oxigraph** (SPARQL Database) - Port 7878 **Services NOT included in this unified deployment (deploy separately):** -- **brainkb-ui** - The UI is not included. See [SETUP_UI.md](SETUP_UI.md) for UI deployment instructions. - **chat_service** - Deploy separately using `chat_service/docker-compose-prod.yml` or `chat_service/docker-compose-dev.yml`. ## Quick Start with Docker Compose (Recommended) @@ -549,6 +554,7 @@ All environment variables are loaded from the `.env` file in the project root. T - **User Management OAuth**: `USERMANAGEMENT_SERVICE_JWT_SECRET_KEY`, `USERMANAGEMENT_PUBLIC_BASE_URL`, `USERMANAGEMENT_FRONTEND_CALLBACK_URL`, `USERMANAGEMENT_OAUTH_TOKEN_ENC_KEY`, `USERMANAGEMENT_BOOTSTRAP_SUPERADMIN_EMAILS`, `GITHUB_CLIENT_ID/SECRET`, `ORCID_CLIENT_ID/SECRET`, `GLOBUS_CLIENT_ID/SECRET` - **Ollama**: `OLLAMA_MODEL`, `OLLAMA_PORT`, `OLLAMA_API_ENDPOINT` - **ML Service**: `MONGO_DB_URL`, `WEAVIATE_*`, etc. +- **SynthScholar** (PRISMA reviews, lives in ml_service): `OPENROUTER_API_KEY` (operator fallback — UI normally forwards a per-user or admin-shared key), `NCBI_API_KEY` (optional, raises PubMed rate limits), `SEMANTIC_SCHOLAR_API_KEY` / `CORE_API_KEY` / `SYNTHSCHOLAR_EMAIL` (all optional). Database is shared — no separate DSN. - **Query Service**: `GRAPHDATABASE_*`, `RAPID_RELEASE_FILE` See `env.template` for the complete list of all available environment variables with descriptions. 
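Editor's note on the key-handling description above: the OpenRouter precedence the README documents reduces to a three-step fallback. The sketch below is illustrative only — the function name and parameters are invented for this note, and the real resolution lives inside ml_service's SynthScholar routes — but it captures the order stated here: a per-user key forwarded with the request wins, then the admin-shared key from /api/admin/settings/openrouter-key, then the `OPENROUTER_API_KEY` operator fallback from the environment.

```python
import os
from typing import Optional


def resolve_openrouter_key(
    request_key: Optional[str],  # per-user key forwarded by the UI on each call (assumed name)
    shared_key: Optional[str],   # admin-shared key fetched via the usermanagement service (assumed name)
) -> Optional[str]:
    """Illustrative precedence only — not the actual ml_service code.

    1. per-user key sent with the request
    2. admin-shared key configured by an admin
    3. OPENROUTER_API_KEY from the environment (operator fallback)
    Returns None when no key is configured anywhere.
    """
    return request_key or shared_key or os.environ.get("OPENROUTER_API_KEY") or None
```

Under this reading, leaving `OPENROUTER_API_KEY` blank in production only matters when neither a per-user nor an admin-shared key is configured.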
diff --git a/readme.md b/readme.md index 2a673a1..1ac77a9 100644 --- a/readme.md +++ b/readme.md @@ -12,7 +12,6 @@ BrainKB serves as a knowledge base platform that provides scientists worldwide w - [GraphDB](graphdb) The docker compose configuration of GraphDB. - [JWT User & Scope Manager](APItokenmanager) A toolkit to manage JWT users and their permissions for API endpoint access. - [Query Service](query_service) Provides the functionalities for querying (and updating) the knowledge graphs from the graph database. -- [RabbitMQ](rabbit-mq) The docker compose configuration of RabbitMQ. - [SPARQL Queries](sparql_queries) List of SPARQL queries tested or used in BrainKB. ## Running @@ -50,9 +49,15 @@ Once started, services are accessible at: - **Query Service (FastAPI)**: `http://localhost:8010/` - Now supports ingestion than just querying. - **ML Service (FastAPI)**: `http://localhost:8007/` - - Integrates StructSense + - Integrates StructSense (multi-agent NER + structured-resource extraction). + - Hosts **SynthScholar** at `/api/synth-scholar/*` — PRISMA-guided literature + review pipeline (search → screening → critical appraisal → synthesis, + with SSE progress streaming and markdown / JSON / RDF exports). Reuses + the unified `brainkb` Postgres database. See `env.template` for the + optional API keys (OpenRouter, NCBI, Semantic Scholar, CORE). - **Oxigraph SPARQL**: `http://localhost:7878/` (password protected) graph database - **pgAdmin**: `http://localhost:5051/` +- **User management service**: http://localhost:8004 ## Documentation From 80c34bfc949b86a003c55fee7db4079a323197b6 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 17:01:08 -0400 Subject: [PATCH 08/22] missing env variables added. --- env.template | 2 ++ 1 file changed, 2 insertions(+) diff --git a/env.template b/env.template index e0a4e3b..1d2807d 100644 --- a/env.template +++ b/env.template @@ -123,6 +123,8 @@ PGADMIN_PORT=5051 # ---------------------------------------------------------------------------- OXIGRAPH_USER=admin OXIGRAPH_PASSWORD=your_oxigraph_password_change_this +OXIGRAPH_DATA_PATH=/ +OXIGRAPH_TMP_PATH=/ # Oxigraph Data Storage Configuration # For local development: Leave commented to use Docker named volumes (recommended) From f5df73ccc25664631b47f6deda36890d69617585 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 18:31:38 -0400 Subject: [PATCH 09/22] added missing files --- usermanagement_service/core/oauth/__init__.py | 4 + usermanagement_service/core/oauth/base.py | 59 ++++ .../core/oauth/providers.py | 262 ++++++++++++++++++ 3 files changed, 325 insertions(+) create mode 100644 usermanagement_service/core/oauth/__init__.py create mode 100644 usermanagement_service/core/oauth/base.py create mode 100644 usermanagement_service/core/oauth/providers.py diff --git a/usermanagement_service/core/oauth/__init__.py b/usermanagement_service/core/oauth/__init__.py new file mode 100644 index 0000000..cf1d784 --- /dev/null +++ b/usermanagement_service/core/oauth/__init__.py @@ -0,0 +1,4 @@ +from core.oauth.base import OAuthProvider, OAuthUserInfo, TokenResponse +from core.oauth.providers import get_provider, REGISTRY + +__all__ = ["OAuthProvider", "OAuthUserInfo", "TokenResponse", "get_provider", "REGISTRY"] diff --git a/usermanagement_service/core/oauth/base.py b/usermanagement_service/core/oauth/base.py new file mode 100644 index 0000000..14ce861 --- /dev/null +++ b/usermanagement_service/core/oauth/base.py @@ -0,0 +1,59 @@ +"""OAuth provider abstraction. 
Each provider knows: + - how to build its authorize URL (with PKCE when supported) + - how to exchange an authorization code for tokens + - how to fetch user info and normalize it into OAuthUserInfo + +The router (core/routers/oauth.py) is provider-agnostic.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Dict, Any + + +@dataclass +class TokenResponse: + access_token: str + refresh_token: Optional[str] = None + expires_in: Optional[int] = None # seconds + id_token: Optional[str] = None + scope: Optional[str] = None + raw: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class OAuthUserInfo: + """Normalized user profile from an OAuth provider.""" + provider: str + provider_user_id: str + email: Optional[str] + name: Optional[str] + # Provider-specific identifiers we want to persist to UserProfile. + orcid_id: Optional[str] = None + github_username: Optional[str] = None + # Any other raw fields; stored in Web_oauth_identity.raw_profile. + raw: Dict[str, Any] = field(default_factory=dict) + + +class OAuthProvider: + """Base class. Subclasses override the four hooks below.""" + + name: str = "" + supports_pkce: bool = False + default_scopes: str = "" + + def __init__(self, client_id: str, client_secret: str): + self.client_id = client_id + self.client_secret = client_secret + + def is_configured(self) -> bool: + return bool(self.client_id and self.client_secret) + + def authorize_url(self, *, redirect_uri: str, state: str, code_challenge: Optional[str] = None) -> str: + raise NotImplementedError + + async def exchange_code(self, *, code: str, redirect_uri: str, code_verifier: Optional[str] = None) -> TokenResponse: + raise NotImplementedError + + async def fetch_userinfo(self, *, access_token: str, token_response: TokenResponse) -> OAuthUserInfo: + raise NotImplementedError diff --git a/usermanagement_service/core/oauth/providers.py b/usermanagement_service/core/oauth/providers.py new file mode 100644 index 0000000..2f5ceec --- /dev/null +++ b/usermanagement_service/core/oauth/providers.py @@ -0,0 +1,262 @@ +"""Concrete OAuth providers: GitHub, ORCID, Globus. + +All three follow the OAuth 2.0 authorization-code flow. 
ORCID and Globus also +support OIDC userinfo endpoints and PKCE; GitHub does not support PKCE and +uses its own /user endpoint for profile data.""" + +from __future__ import annotations + +import logging +import urllib.parse +from typing import Optional, Dict + +import httpx + +from core.configuration import config +from core.oauth.base import OAuthProvider, TokenResponse, OAuthUserInfo + +logger = logging.getLogger(__name__) + +_TIMEOUT = httpx.Timeout(10.0, connect=5.0) + + +class GitHubProvider(OAuthProvider): + name = "github" + supports_pkce = False + default_scopes = "read:user user:email" + + AUTHORIZE = "https://github.com/login/oauth/authorize" + TOKEN = "https://github.com/login/oauth/access_token" + USER = "https://api.github.com/user" + EMAILS = "https://api.github.com/user/emails" + + def authorize_url(self, *, redirect_uri: str, state: str, code_challenge: Optional[str] = None) -> str: + params = { + "client_id": self.client_id, + "redirect_uri": redirect_uri, + "scope": self.default_scopes, + "state": state, + "allow_signup": "true", + } + return f"{self.AUTHORIZE}?{urllib.parse.urlencode(params)}" + + async def exchange_code(self, *, code: str, redirect_uri: str, code_verifier: Optional[str] = None) -> TokenResponse: + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": redirect_uri, + } + headers = {"Accept": "application/json"} + async with httpx.AsyncClient(timeout=_TIMEOUT) as client: + resp = await client.post(self.TOKEN, data=data, headers=headers) + resp.raise_for_status() + payload = resp.json() + if "access_token" not in payload: + raise ValueError(f"GitHub token exchange failed: {payload}") + return TokenResponse( + access_token=payload["access_token"], + refresh_token=payload.get("refresh_token"), + expires_in=payload.get("expires_in"), + scope=payload.get("scope"), + raw=payload, + ) + + async def fetch_userinfo(self, *, access_token: str, token_response: TokenResponse) -> OAuthUserInfo: + headers = {"Authorization": f"Bearer {access_token}", "Accept": "application/vnd.github+json"} + async with httpx.AsyncClient(timeout=_TIMEOUT) as client: + u = await client.get(self.USER, headers=headers) + u.raise_for_status() + user = u.json() + email = user.get("email") + if not email: + # user.email can be null if hidden; fetch verified primary from /user/emails + e = await client.get(self.EMAILS, headers=headers) + if e.status_code == 200: + for row in e.json(): + if row.get("primary") and row.get("verified"): + email = row.get("email") + break + return OAuthUserInfo( + provider=self.name, + provider_user_id=str(user["id"]), + email=email, + name=user.get("name") or user.get("login"), + github_username=user.get("login"), + raw=user, + ) + + +class ORCIDProvider(OAuthProvider): + """ORCID OIDC. 
Uses openid+email+profile scopes to get the userinfo endpoint.""" + + name = "orcid" + supports_pkce = True + default_scopes = "openid email profile" + + def __init__(self, client_id: str, client_secret: str, base_url: str = "https://orcid.org"): + super().__init__(client_id, client_secret) + self.base_url = base_url.rstrip("/") + + @property + def authorize_endpoint(self) -> str: + return f"{self.base_url}/oauth/authorize" + + @property + def token_endpoint(self) -> str: + return f"{self.base_url}/oauth/token" + + @property + def userinfo_endpoint(self) -> str: + return f"{self.base_url}/oauth/userinfo" + + def authorize_url(self, *, redirect_uri: str, state: str, code_challenge: Optional[str] = None) -> str: + params = { + "client_id": self.client_id, + "response_type": "code", + "scope": self.default_scopes, + "redirect_uri": redirect_uri, + "state": state, + } + if code_challenge: + params["code_challenge"] = code_challenge + params["code_challenge_method"] = "S256" + return f"{self.authorize_endpoint}?{urllib.parse.urlencode(params)}" + + async def exchange_code(self, *, code: str, redirect_uri: str, code_verifier: Optional[str] = None) -> TokenResponse: + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + } + if code_verifier: + data["code_verifier"] = code_verifier + headers = {"Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded"} + async with httpx.AsyncClient(timeout=_TIMEOUT) as client: + resp = await client.post(self.token_endpoint, data=data, headers=headers) + resp.raise_for_status() + payload = resp.json() + if "access_token" not in payload: + raise ValueError(f"ORCID token exchange failed: {payload}") + return TokenResponse( + access_token=payload["access_token"], + refresh_token=payload.get("refresh_token"), + expires_in=payload.get("expires_in"), + id_token=payload.get("id_token"), + scope=payload.get("scope"), + raw=payload, + ) + + async def fetch_userinfo(self, *, access_token: str, token_response: TokenResponse) -> OAuthUserInfo: + headers = {"Authorization": f"Bearer {access_token}", "Accept": "application/json"} + async with httpx.AsyncClient(timeout=_TIMEOUT) as client: + resp = await client.get(self.userinfo_endpoint, headers=headers) + resp.raise_for_status() + user = resp.json() + # ORCID userinfo returns `sub` = the ORCID iD + orcid = user.get("sub") + given = user.get("given_name") or "" + family = user.get("family_name") or "" + full_name = user.get("name") or (f"{given} {family}".strip() or None) + return OAuthUserInfo( + provider=self.name, + provider_user_id=orcid, + email=user.get("email"), + name=full_name, + orcid_id=orcid, + raw=user, + ) + + +class GlobusProvider(OAuthProvider): + """Globus Auth. 
Uses OIDC with PKCE.""" + + name = "globus" + supports_pkce = True + default_scopes = "openid email profile" + + AUTHORIZE = "https://auth.globus.org/v2/oauth2/authorize" + TOKEN = "https://auth.globus.org/v2/oauth2/token" + USERINFO = "https://auth.globus.org/v2/oauth2/userinfo" + + def authorize_url(self, *, redirect_uri: str, state: str, code_challenge: Optional[str] = None) -> str: + params = { + "client_id": self.client_id, + "response_type": "code", + "scope": self.default_scopes, + "redirect_uri": redirect_uri, + "state": state, + "access_type": "online", + } + if code_challenge: + params["code_challenge"] = code_challenge + params["code_challenge_method"] = "S256" + return f"{self.AUTHORIZE}?{urllib.parse.urlencode(params)}" + + async def exchange_code(self, *, code: str, redirect_uri: str, code_verifier: Optional[str] = None) -> TokenResponse: + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + } + if code_verifier: + data["code_verifier"] = code_verifier + auth = (self.client_id, self.client_secret) + async with httpx.AsyncClient(timeout=_TIMEOUT) as client: + resp = await client.post(self.TOKEN, data=data, auth=auth) + resp.raise_for_status() + payload = resp.json() + if "access_token" not in payload: + raise ValueError(f"Globus token exchange failed: {payload}") + return TokenResponse( + access_token=payload["access_token"], + refresh_token=payload.get("refresh_token"), + expires_in=payload.get("expires_in"), + id_token=payload.get("id_token"), + scope=payload.get("scope"), + raw=payload, + ) + + async def fetch_userinfo(self, *, access_token: str, token_response: TokenResponse) -> OAuthUserInfo: + headers = {"Authorization": f"Bearer {access_token}"} + async with httpx.AsyncClient(timeout=_TIMEOUT) as client: + resp = await client.get(self.USERINFO, headers=headers) + resp.raise_for_status() + user = resp.json() + return OAuthUserInfo( + provider=self.name, + provider_user_id=user.get("sub"), + email=user.get("email") or user.get("preferred_username"), + name=user.get("name"), + raw=user, + ) + + +def _build_registry() -> Dict[str, OAuthProvider]: + return { + "github": GitHubProvider( + client_id=config.github_client_id or "", + client_secret=config.github_client_secret or "", + ), + "orcid": ORCIDProvider( + client_id=config.orcid_client_id or "", + client_secret=config.orcid_client_secret or "", + base_url=config.orcid_base_url, + ), + "globus": GlobusProvider( + client_id=config.globus_client_id or "", + client_secret=config.globus_client_secret or "", + ), + } + + +REGISTRY: Dict[str, OAuthProvider] = _build_registry() + + +def get_provider(name: str) -> OAuthProvider: + p = REGISTRY.get(name.lower()) + if p is None: + raise KeyError(f"Unknown OAuth provider: {name}") + return p From 0057d8632769aacede68cad0975499243fd82f83 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Thu, 30 Apr 2026 18:36:57 -0400 Subject: [PATCH 10/22] missing codes --- usermanagement_service/core/bootstrap.py | 207 +++++++ usermanagement_service/core/routers/access.py | 102 +++ usermanagement_service/core/routers/admin.py | 586 ++++++++++++++++++ usermanagement_service/core/routers/oauth.py | 322 ++++++++++ 4 files changed, 1217 insertions(+) create mode 100644 usermanagement_service/core/bootstrap.py create mode 100644 usermanagement_service/core/routers/access.py create mode 100644 usermanagement_service/core/routers/admin.py create mode 100644 usermanagement_service/core/routers/oauth.py diff --git a/usermanagement_service/core/bootstrap.py 
b/usermanagement_service/core/bootstrap.py new file mode 100644 index 0000000..6f1ef89 --- /dev/null +++ b/usermanagement_service/core/bootstrap.py @@ -0,0 +1,207 @@ +"""Startup bootstrap: seed baseline roles, permissions, and page access rows, and +promote configured superadmin emails to the SuperAdmin + Admin roles. + +Idempotent — safe to run on every startup.""" + +from __future__ import annotations + +import logging +from typing import List + +from sqlalchemy import select + +from core.configuration import config +from core.database import ( + user_db_manager, user_profile_repo, user_role_repo, + available_role_repo, permission_repo, page_access_repo, role_permission_repo, +) +from core.models.database_models import AvailableRole, Permission +from core.models.user import UserRoleEnum + +logger = logging.getLogger(__name__) + + +# Baseline roles. Mirrors UserRoleEnum but adds the category so the admin UI can group them. +_BASELINE_ROLES: list[tuple[str, str, str]] = [ + ("SuperAdmin", "Admin", "Super administrator — protected. Cannot be banned, deleted, or have this role stripped via the UI/API."), + ("Admin", "Admin", "Platform administrator — manages roles, permissions, users, and page access."), + ("Submitter", "Content", "Submits new content for curation."), + ("Annotator", "Content", "Annotates submitted content."), + ("Mapper", "Content", "Maps entities across ontologies."), + ("Curator", "Content", "Default role for scientific contributors."), + ("Reviewer", "Quality", "Reviews submitted content."), + ("Validator", "Quality", "Validates curated content."), + ("Conflict Resolver", "Quality", "Resolves conflicting curations."), + ("Knowledge Contributor", "Knowledge", "Contributes knowledge artifacts."), + ("Evidence Tracer", "Knowledge", "Links evidence to claims."), + ("Provenance Tracker", "Knowledge", "Tracks provenance metadata."), + ("Moderator", "Community", "Moderates community discussions."), + ("Ambassador", "Community", "Community ambassador."), +] + +# Baseline permissions — resource:action. Fine-grained permissions used by admin endpoints. 
+_BASELINE_PERMISSIONS: list[tuple[str, str, str, str]] = [ + ("user.read", "user", "read", "View user profiles"), + ("user.update", "user", "update", "Update user profiles"), + ("user.delete", "user", "delete", "Delete user profiles"), + ("user.list", "user", "list", "List all users"), + ("role.read", "role", "read", "View roles"), + ("role.manage", "role", "manage", "Create/update/delete roles"), + ("role.assign", "role", "assign", "Assign roles to users"), + ("permission.read", "permission", "read", "View permissions"), + ("permission.manage", "permission", "manage", "Create/update/delete permissions"), + ("page_access.read", "page_access", "read", "View page access rules"), + ("page_access.manage", "page_access", "manage", "Create/update/delete page access rules"), + ("oauth.identity.read", "oauth_identity", "read", "View OAuth identities"), +] + + +async def seed_roles() -> None: + async with user_db_manager.get_async_session() as session: + existing = await available_role_repo.get_active_roles(session) + existing_names = {r.name for r in existing} + for name, category, description in _BASELINE_ROLES: + if name in existing_names: + continue + session.add(AvailableRole(name=name, category=category, description=description, is_active=True)) + await session.commit() + + +async def seed_permissions() -> None: + async with user_db_manager.get_async_session() as session: + existing = await permission_repo.list_all(session) + existing_names = {p.name for p in existing} + for name, resource, action, description in _BASELINE_PERMISSIONS: + if name in existing_names: + continue + session.add(Permission(name=name, resource=resource, action=action, description=description)) + await session.commit() + + +async def grant_admin_all_permissions() -> None: + """Ensure the Admin and SuperAdmin roles each own every permission + currently in the registry. SuperAdmin shadows Admin for permissions — + its only extra power is being un-bannable / un-deletable.""" + async with user_db_manager.get_async_session() as session: + perms = await permission_repo.list_all(session) + perm_ids = [p.id for p in perms] + for role_name in ("Admin", "SuperAdmin"): + result = await session.execute(select(AvailableRole).where(AvailableRole.name == role_name)) + role = result.scalar_one_or_none() + if not role: + logger.warning(f"{role_name} role missing during permission grant bootstrap") + continue + await role_permission_repo.set_role_permissions(session, role.id, perm_ids) + await session.commit() + + +async def seed_default_page_access() -> None: + """Seed a couple of sensible defaults so the UI has something to reference. 
+ The admin can edit these via the admin UI afterward.""" + defaults = [ + {"page_key": "admin.dashboard", "description": "Admin dashboard", "is_public": False, "allowed_roles": ["Admin"], "allowed_emails": []}, + {"page_key": "admin.users", "description": "User management", "is_public": False, "allowed_roles": ["Admin"], "allowed_emails": []}, + {"page_key": "admin.roles", "description": "Role & permission management", "is_public": False, "allowed_roles": ["Admin"], "allowed_emails": []}, + {"page_key": "admin.page_access", "description": "Page access management", "is_public": False, "allowed_roles": ["Admin"], "allowed_emails": []}, + {"page_key": "curate.submit", "description": "Submit content for curation", "is_public": False, "allowed_roles": ["Admin", "Submitter", "Curator"], "allowed_emails": []}, + {"page_key": "curate.review", "description": "Review submitted content", "is_public": False, "allowed_roles": ["Admin", "Reviewer", "Validator"], "allowed_emails": []}, + # SynthScholar (PRISMA literature review). Seeded for Curator (the + # default role assigned at first OAuth login) so any signed-in user + # can run a review without an extra admin-grant step. Tighten this + # in /admin/page-access if access should be more restrictive. + {"page_key": "tools.synth-scholar", "description": "SynthScholar — PRISMA literature review", "is_public": False, "allowed_roles": ["Admin", "Curator"], "allowed_emails": []}, + {"page_key": "home", "description": "Public landing page", "is_public": True, "allowed_roles": [], "allowed_emails": []}, + ] + async with user_db_manager.get_async_session() as session: + existing = await page_access_repo.list_all(session) + existing_keys = {p.page_key for p in existing} + for d in defaults: + if d["page_key"] in existing_keys: + continue + profile_ids: List[int] = [] + for email in d["allowed_emails"]: + p = await user_profile_repo.get_by_email(session, email) + if p: + profile_ids.append(p.id) + await page_access_repo.upsert_with_members( + session, + page_key=d["page_key"], + description=d["description"], + is_public=d["is_public"], + allowed_role_names=d["allowed_roles"], + allowed_profile_ids=profile_ids, + ) + await session.commit() + + +async def promote_bootstrap_superadmins() -> None: + """Assign the SuperAdmin and Admin roles to every email in + USERMANAGEMENT_BOOTSTRAP_SUPERADMIN_EMAILS that already has a UserProfile. + Emails without a profile are ignored — the require_admin dependency honors + the bootstrap list too, so they can log in and create their profile first. + + SuperAdmin is the immutable marker; Admin grants the actual permissions. + Seeding both keeps the page-access RBAC (which checks for "Admin") working + without special-casing SuperAdmin everywhere.""" + emails = config.bootstrap_superadmin_emails + if not emails: + return + async with user_db_manager.get_async_session() as session: + for email in emails: + profile = await user_profile_repo.get_by_email(session, email) + if not profile: + logger.info(f"Bootstrap superadmin {email} has no profile yet — require_admin will still accept them via env allowlist.") + continue + for role in ("Admin", "SuperAdmin"): + await user_role_repo.assign_role(session, profile_id=profile.id, role=role, is_active=True) + await session.commit() + + +async def apply_inline_schema_migrations() -> None: + """Idempotent ALTER TABLE migrations for columns that were added after + the initial create_all(). 
create_all(checkfirst=True) only creates + *missing tables*, never adds columns to existing ones — so any new + column on an existing model needs an explicit ALTER here. + + Each migration uses ADD COLUMN IF NOT EXISTS (Postgres ≥ 9.6) so + re-running on an already-migrated DB is a no-op. + + When you bump the schema, append a one-liner here. Don't drop or + rename existing columns from this function — those are destructive and + belong in a real migration tool with versioning.""" + from sqlalchemy import text as _text + statements = [ + # User ban support (per-user only; IP bans deferred until a WAF + # decision is made — see brainkb-ui/README.md). + 'ALTER TABLE "Web_user_profile" ADD COLUMN IF NOT EXISTS is_banned BOOLEAN NOT NULL DEFAULT FALSE', + 'ALTER TABLE "Web_user_profile" ADD COLUMN IF NOT EXISTS banned_at TIMESTAMP', + 'ALTER TABLE "Web_user_profile" ADD COLUMN IF NOT EXISTS banned_by INTEGER REFERENCES "Web_user_profile"(id) ON DELETE SET NULL', + 'ALTER TABLE "Web_user_profile" ADD COLUMN IF NOT EXISTS ban_reason TEXT', + ] + async with user_db_manager.get_async_session() as session: + for stmt in statements: + try: + await session.execute(_text(stmt)) + except Exception as e: + # Don't abort other migrations because one failed — log and + # continue. A failure here typically means the column type + # changed and the running schema disagrees with the model; + # surface that in the logs without bringing the boot down. + logger.warning(f"Inline migration failed (continuing): {stmt!r} → {e}") + await session.commit() + + +async def run_bootstrap() -> None: + """Run all bootstrap steps. Order matters: schema migrations before + seeding (later code may reference the new columns), roles before + role_permissions, admins last.""" + try: + await apply_inline_schema_migrations() + await seed_roles() + await seed_permissions() + await grant_admin_all_permissions() + await seed_default_page_access() + await promote_bootstrap_superadmins() + logger.info("Bootstrap complete") + except Exception as e: + logger.error(f"Bootstrap failed: {e}") diff --git a/usermanagement_service/core/routers/access.py b/usermanagement_service/core/routers/access.py new file mode 100644 index 0000000..40a1349 --- /dev/null +++ b/usermanagement_service/core/routers/access.py @@ -0,0 +1,102 @@ +"""Page access check — used by the UI to decide if the current user can view a page. + +The UI calls GET /api/access/page/{page_key}: + - if the page is public, returns {allowed: true, reason: "public"} even for anon + - otherwise requires Bearer auth and checks the user's roles + per-user overrides + +The UI can also call POST /api/access/pages with a list of page_keys to batch-check +(e.g. 
to decide which nav items to render).""" + +from __future__ import annotations + +from typing import Annotated, List, Optional + +from fastapi import APIRouter, Depends + +from core.database import user_db_manager, page_access_repo +from core.models.user import PageAccessCheck +from core.security import get_current_user_optional + +router = APIRouter() + + +@router.get("/access/page/{page_key}", response_model=PageAccessCheck) +async def check_page_access( + page_key: str, + current_user: Annotated[Optional[dict], Depends(get_current_user_optional)] = None, +): + profile_id = current_user.get("profile_id") if current_user else None + roles = current_user.get("roles", []) if current_user else [] + async with user_db_manager.get_async_session() as session: + allowed, reason = await page_access_repo.check_access( + session, page_key=page_key, profile_id=profile_id, role_names=roles + ) + return PageAccessCheck(page_key=page_key, allowed=allowed, reason=reason) + + +@router.post("/access/pages", response_model=List[PageAccessCheck]) +async def check_page_access_batch( + page_keys: List[str], + current_user: Annotated[Optional[dict], Depends(get_current_user_optional)] = None, +): + profile_id = current_user.get("profile_id") if current_user else None + roles = current_user.get("roles", []) if current_user else [] + results: List[PageAccessCheck] = [] + async with user_db_manager.get_async_session() as session: + for key in page_keys: + allowed, reason = await page_access_repo.check_access( + session, page_key=key, profile_id=profile_id, role_names=roles + ) + results.append(PageAccessCheck(page_key=key, allowed=allowed, reason=reason)) + return results + + +# ============================================================================= +# Effective shared-key fetch for the UI. +# ============================================================================= +# The admin sets a shared OpenRouter API key via /api/admin/settings/openrouter-key +# (encrypted at rest). Any signed-in user whose role is in the setting's +# allowed_role_names list (or any signed-in user, if the list is empty/null) can +# fetch the *plaintext* via this endpoint so their browser can use the key for +# OpenRouter calls. We deliberately do NOT echo the plaintext to the user UI — +# the dashboard input shows it masked. Plaintext is only retrievable by: +# - admins, via the admin endpoint (with `reveal=true`) +# - users with an allowed role, via this endpoint (used to make calls, +# never displayed) + +from core.database import admin_setting_repo +from core.security import decrypt_token, get_current_user + +_OPENROUTER_KEY_SETTING = "shared.openrouter_api_key" + + +@router.get("/settings/openrouter-key/effective") +async def get_effective_openrouter_key( + current_user: Annotated[dict, Depends(get_current_user)], +): + """Return the effective shared OpenRouter key for the calling user. + Response shape: + { "source": "shared" | "none", + "api_key": str | null, + "last_4": str | null } + `source=none` means there is no shared key the caller may use — they + should fall back to their own personal key. 
`api_key` is included so + the browser can use it for tool calls; the UI is expected to never + render it as plaintext to non-admins.""" + async with user_db_manager.get_async_session() as session: + row = await admin_setting_repo.get(session, _OPENROUTER_KEY_SETTING) + if not row or not row.value_enc: + return {"source": "none", "api_key": None, "last_4": None} + allowed: List[str] = row.allowed_role_names or [] + user_roles: List[str] = current_user.get("roles", []) or [] + # Admins always pass; otherwise the user must have at least one role + # in the allowed list. Empty allowed list = open to any signed-in user. + is_admin = "Admin" in user_roles + if not is_admin and allowed and not (set(user_roles) & set(allowed)): + return {"source": "none", "api_key": None, "last_4": None} + plaintext = decrypt_token(row.value_enc) + return { + "source": "shared", + "api_key": plaintext, + "last_4": plaintext[-4:] if plaintext and len(plaintext) >= 4 else None, + } diff --git a/usermanagement_service/core/routers/admin.py b/usermanagement_service/core/routers/admin.py new file mode 100644 index 0000000..9f81ddb --- /dev/null +++ b/usermanagement_service/core/routers/admin.py @@ -0,0 +1,586 @@ +"""Admin endpoints — mounted at /api/admin/*. Every route requires the Admin role. + +Surfaces: + - /roles CRUD over Web_available_role + - /permissions CRUD over Web_permission + - /roles/{id}/permissions role ↔ permission assignment + - /page-access CRUD over Web_page_access (+ role/user overrides) + - /users list / get / delete / assign-role for UserProfiles +""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import List, Optional, Annotated + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import select, func + +from core.database import ( + user_db_manager, user_profile_repo, user_role_repo, user_activity_repo, + available_role_repo, permission_repo, role_permission_repo, page_access_repo, + oauth_identity_repo, +) +from core.models.database_models import ( + AvailableRole as AvailableRoleModel, + Permission as PermissionModel, + UserProfile as UserProfileModel, + UserRole as UserRoleModel, + OAuthIdentity as OAuthIdentityModel, +) +from core.models.user import ( + AvailableRoleInput, AvailableRole, + PermissionInput, Permission, RolePermissionAssignment, + PageAccessInput, PageAccess, + AdminUserListItem, UserRoleInput, ActivityType, +) +from core.security import require_admin + +logger = logging.getLogger(__name__) +router = APIRouter() + + +# ============================================================================= +# ROLES +# ============================================================================= + +@router.get("/roles", response_model=List[AvailableRole]) +async def list_roles(_admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + result = await session.execute(select(AvailableRoleModel).order_by(AvailableRoleModel.category, AvailableRoleModel.name)) + return [AvailableRole.model_validate(r, from_attributes=True) for r in result.scalars().all()] + + +@router.post("/roles", response_model=AvailableRole, status_code=201) +async def create_role(role: AvailableRoleInput, _admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + existing = await session.execute(select(AvailableRoleModel).where(AvailableRoleModel.name == role.name)) + if existing.scalar_one_or_none(): + raise 
HTTPException(status_code=409, detail=f"Role '{role.name}' already exists") + new_role = AvailableRoleModel(**role.dict()) + session.add(new_role) + await session.commit() + await session.refresh(new_role) + return AvailableRole.model_validate(new_role, from_attributes=True) + + +@router.put("/roles/{role_id}", response_model=AvailableRole) +async def update_role(role_id: int, role: AvailableRoleInput, _admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + existing = await session.get(AvailableRoleModel, role_id) + if not existing: + raise HTTPException(status_code=404, detail="Role not found") + for k, v in role.dict().items(): + setattr(existing, k, v) + existing.updated_at = datetime.utcnow() + await session.commit() + await session.refresh(existing) + return AvailableRole.model_validate(existing, from_attributes=True) + + +@router.delete("/roles/{role_id}", status_code=204) +async def delete_role(role_id: int, _admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + existing = await session.get(AvailableRoleModel, role_id) + if not existing: + raise HTTPException(status_code=404, detail="Role not found") + if existing.name == "Admin": + raise HTTPException(status_code=400, detail="The Admin role cannot be deleted") + await session.delete(existing) + await session.commit() + return None + + +@router.get("/roles/{role_id}/permissions", response_model=List[Permission]) +async def get_role_permissions(role_id: int, _admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + perms = await role_permission_repo.get_permissions_for_role(session, role_id) + return [Permission.model_validate(p, from_attributes=True) for p in perms] + + +@router.put("/roles/{role_id}/permissions", response_model=List[Permission]) +async def set_role_permissions( + role_id: int, + body: RolePermissionAssignment, + _admin: Annotated[dict, Depends(require_admin)], +): + async with user_db_manager.get_async_session() as session: + existing = await session.get(AvailableRoleModel, role_id) + if not existing: + raise HTTPException(status_code=404, detail="Role not found") + await role_permission_repo.set_role_permissions(session, role_id, body.permission_ids) + perms = await role_permission_repo.get_permissions_for_role(session, role_id) + await session.commit() + return [Permission.model_validate(p, from_attributes=True) for p in perms] + + +# ============================================================================= +# PERMISSIONS +# ============================================================================= + +@router.get("/permissions", response_model=List[Permission]) +async def list_permissions(_admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + perms = await permission_repo.list_all(session) + return [Permission.model_validate(p, from_attributes=True) for p in perms] + + +@router.post("/permissions", response_model=Permission, status_code=201) +async def create_permission(body: PermissionInput, _admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + existing = await permission_repo.get_by_name(session, body.name) + if existing: + raise HTTPException(status_code=409, detail=f"Permission '{body.name}' already exists") + new_perm = PermissionModel(**body.dict()) + session.add(new_perm) + await session.commit() + await 
session.refresh(new_perm) + return Permission.model_validate(new_perm, from_attributes=True) + + +@router.put("/permissions/{permission_id}", response_model=Permission) +async def update_permission( + permission_id: int, + body: PermissionInput, + _admin: Annotated[dict, Depends(require_admin)], +): + async with user_db_manager.get_async_session() as session: + existing = await session.get(PermissionModel, permission_id) + if not existing: + raise HTTPException(status_code=404, detail="Permission not found") + for k, v in body.dict().items(): + setattr(existing, k, v) + existing.updated_at = datetime.utcnow() + await session.commit() + await session.refresh(existing) + return Permission.model_validate(existing, from_attributes=True) + + +@router.delete("/permissions/{permission_id}", status_code=204) +async def delete_permission(permission_id: int, _admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + existing = await session.get(PermissionModel, permission_id) + if not existing: + raise HTTPException(status_code=404, detail="Permission not found") + await session.delete(existing) + await session.commit() + return None + + +# ============================================================================= +# PAGE ACCESS +# ============================================================================= + +async def _page_access_to_response(session, page) -> PageAccess: + roles = await page_access_repo.get_allowed_roles(session, page.id) + user_ids = await page_access_repo.get_allowed_user_profile_ids(session, page.id) + emails: List[str] = [] + if user_ids: + result = await session.execute( + select(UserProfileModel.email).where(UserProfileModel.id.in_(user_ids)) + ) + emails = [row[0] for row in result.all()] + return PageAccess( + id=page.id, + page_key=page.page_key, + description=page.description, + is_public=page.is_public, + allowed_roles=roles, + allowed_user_emails=emails, + created_at=page.created_at, + updated_at=page.updated_at, + ) + + +@router.get("/page-access", response_model=List[PageAccess]) +async def list_page_access(_admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + pages = await page_access_repo.list_all(session) + return [await _page_access_to_response(session, p) for p in pages] + + +@router.get("/page-access/{page_key}", response_model=PageAccess) +async def get_page_access(page_key: str, _admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + page = await page_access_repo.get_by_key(session, page_key) + if not page: + raise HTTPException(status_code=404, detail="Page not found") + return await _page_access_to_response(session, page) + + +@router.put("/page-access/{page_key}", response_model=PageAccess) +async def upsert_page_access( + page_key: str, + body: PageAccessInput, + _admin: Annotated[dict, Depends(require_admin)], +): + if body.page_key != page_key: + raise HTTPException(status_code=400, detail="page_key in URL and body must match") + async with user_db_manager.get_async_session() as session: + profile_ids: List[int] = [] + for email in body.allowed_user_emails: + p = await user_profile_repo.get_by_email(session, email) + if not p: + raise HTTPException(status_code=400, detail=f"No profile found for email {email}") + profile_ids.append(p.id) + page = await page_access_repo.upsert_with_members( + session, + page_key=body.page_key, + description=body.description, + 
is_public=body.is_public, + allowed_role_names=body.allowed_roles, + allowed_profile_ids=profile_ids, + ) + response = await _page_access_to_response(session, page) + await session.commit() + return response + + +@router.delete("/page-access/{page_key}", status_code=204) +async def delete_page_access(page_key: str, _admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + ok = await page_access_repo.delete_by_key(session, page_key) + if not ok: + raise HTTPException(status_code=404, detail="Page not found") + await session.commit() + return None + + +# ============================================================================= +# USERS +# ============================================================================= + +@router.get("/users", response_model=List[AdminUserListItem]) +async def list_users( + _admin: Annotated[dict, Depends(require_admin)], + q: Optional[str] = Query(None, description="Search in name/email/orcid"), + role: Optional[str] = Query(None, description="Filter by assigned role name"), + limit: int = Query(50, ge=1, le=200), + offset: int = Query(0, ge=0), +): + async with user_db_manager.get_async_session() as session: + stmt = select(UserProfileModel) + if q: + like = f"%{q}%" + stmt = stmt.where( + (UserProfileModel.name.ilike(like)) + | (UserProfileModel.email.ilike(like)) + | (UserProfileModel.orcid_id.ilike(like)) + ) + if role: + stmt = stmt.join(UserRoleModel, UserRoleModel.profile_id == UserProfileModel.id).where( + UserRoleModel.role == role, UserRoleModel.is_active == True # noqa: E712 + ) + stmt = stmt.order_by(UserProfileModel.created_at.desc()).limit(limit).offset(offset) + result = await session.execute(stmt) + profiles = list(result.scalars().unique().all()) + + items: List[AdminUserListItem] = [] + for p in profiles: + role_names = await user_role_repo.get_user_role_names(session, p.id) + identities = await oauth_identity_repo.list_for_profile(session, p.id) + items.append(AdminUserListItem( + profile_id=p.id, + name=p.name, + email=p.email, + orcid_id=p.orcid_id, + roles=role_names, + providers=sorted({i.provider for i in identities}), + created_at=p.created_at, + is_banned=bool(getattr(p, "is_banned", False)), + banned_at=getattr(p, "banned_at", None), + banned_by=getattr(p, "banned_by", None), + ban_reason=getattr(p, "ban_reason", None), + )) + return items + + +@router.get("/users/count") +async def count_users(_admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + result = await session.execute(select(func.count()).select_from(UserProfileModel)) + return {"count": result.scalar_one()} + + +@router.delete("/users/{profile_id}", status_code=204) +async def delete_user(profile_id: int, _admin: Annotated[dict, Depends(require_admin)]): + async with user_db_manager.get_async_session() as session: + profile = await session.get(UserProfileModel, profile_id) + if not profile: + raise HTTPException(status_code=404, detail="User not found") + # SuperAdmin accounts are protected — they cannot be deleted via the + # admin endpoints. Drop them from the bootstrap allowlist + restart + # if this is genuinely needed. 
+ target_roles = await user_role_repo.get_user_role_names(session, profile_id) + if "SuperAdmin" in (target_roles or []): + raise HTTPException( + status_code=403, + detail="SuperAdmin accounts cannot be deleted via the admin UI.", + ) + # Cascades delete activities, roles, contributions, countries, orgs, education, expertise, oauth_identity. + await session.delete(profile) + await session.commit() + return None + + +@router.post("/users/{profile_id}/roles", response_model=List[str]) +async def assign_role_to_user( + profile_id: int, + body: UserRoleInput, + admin: Annotated[dict, Depends(require_admin)], +): + async with user_db_manager.get_async_session() as session: + profile = await session.get(UserProfileModel, profile_id) + if not profile: + raise HTTPException(status_code=404, detail="User not found") + await user_role_repo.assign_role( + session=session, + profile_id=profile_id, + role=body.role, + assigned_by=admin.get("profile_id"), + is_active=body.is_active, + expires_at=body.expires_at, + ) + await user_activity_repo.log_activity( + session=session, + profile_id=profile_id, + activity_type=ActivityType.CONTENT_CURATION, + description=f"Role '{body.role}' assigned by admin", + ) + roles = await user_role_repo.get_user_role_names(session, profile_id) + await session.commit() + return roles + + +@router.delete("/users/{profile_id}/roles/{role_name}", response_model=List[str]) +async def remove_role_from_user( + profile_id: int, + role_name: str, + _admin: Annotated[dict, Depends(require_admin)], +): + async with user_db_manager.get_async_session() as session: + profile = await session.get(UserProfileModel, profile_id) + if not profile: + raise HTTPException(status_code=404, detail="User not found") + # SuperAdmin role itself can never be stripped through the admin UI. + if role_name == "SuperAdmin": + raise HTTPException( + status_code=403, + detail="The SuperAdmin role cannot be removed via the admin UI.", + ) + await user_role_repo.remove_role(session, profile_id, role_name) + roles = await user_role_repo.get_user_role_names(session, profile_id) + await session.commit() + return roles + + +# ============================================================================= +# ADMIN-MANAGED SETTINGS (shared keys, etc.) +# ============================================================================= + +# Stable key for the shared OpenRouter API key. New shared settings should +# follow the dotted-namespace convention. +_OPENROUTER_KEY_SETTING = "shared.openrouter_api_key" + + +def _redact_key(plain: Optional[str]) -> dict: + """Return a non-sensitive summary of an API key — last 4 chars only. + Lets the admin UI confirm the value is set without redisplaying it after + the page is refreshed (the admin re-paste flow stays explicit).""" + if not plain: + return {"has_key": False, "last_4": None, "length": 0} + plain = plain.strip() + return { + "has_key": bool(plain), + "last_4": plain[-4:] if len(plain) >= 4 else None, + "length": len(plain), + } + + +@router.get("/settings/openrouter-key") +async def admin_get_openrouter_key( + _admin: Annotated[dict, Depends(require_admin)], + reveal: bool = Query(False, description="Return the plaintext key. Default false (returns last-4 only)."), +): + """Inspect the shared OpenRouter key. By default returns metadata only + (has_key, last_4, allowed_role_names, updated_at, updated_by). 
Pass + `reveal=true` to retrieve the plaintext — only admins can do this.""" + from core.database import admin_setting_repo + from core.security import decrypt_token + + async with user_db_manager.get_async_session() as session: + row = await admin_setting_repo.get(session, _OPENROUTER_KEY_SETTING) + if not row or not row.value_enc: + return { + "has_key": False, + "last_4": None, + "length": 0, + "allowed_role_names": [], + "updated_at": None, + "updated_by": None, + "plaintext": None, + } + plaintext = decrypt_token(row.value_enc) + body = _redact_key(plaintext) + body["allowed_role_names"] = row.allowed_role_names or [] + body["updated_at"] = row.updated_at.isoformat() if row.updated_at else None + body["updated_by"] = row.updated_by + body["plaintext"] = plaintext if reveal else None + return body + + +@router.put("/settings/openrouter-key") +async def admin_set_openrouter_key( + payload: dict, + admin: Annotated[dict, Depends(require_admin)], +): + """Set or replace the shared OpenRouter API key. + Body: { "api_key": str, "allowed_role_names": [str, ...] | null }. + `allowed_role_names` controls which roles can fetch the effective key + via the user-facing endpoint. Empty list / null means "any signed-in + user with a profile". The plaintext is encrypted at rest with the same + Fernet key as OAuth tokens.""" + from core.database import admin_setting_repo + from core.security import encrypt_token + + api_key = (payload.get("api_key") or "").strip() + if not api_key: + raise HTTPException(status_code=400, detail="api_key is required") + if len(api_key) > 4000: + raise HTTPException(status_code=400, detail="api_key looks too long; check the value") + allowed_role_names = payload.get("allowed_role_names") + if allowed_role_names is not None: + if not isinstance(allowed_role_names, list) or not all(isinstance(x, str) for x in allowed_role_names): + raise HTTPException(status_code=400, detail="allowed_role_names must be a list of strings or null") + # Normalise: trim, drop empties, dedupe, preserve order. + seen = set() + normalised = [] + for r in (s.strip() for s in allowed_role_names): + if r and r not in seen: + seen.add(r) + normalised.append(r) + allowed_role_names = normalised + + async with user_db_manager.get_async_session() as session: + # `admin` from require_admin includes a profile_id claim used as updated_by. + updated_by = admin.get("profile_id") if isinstance(admin, dict) else None + await admin_setting_repo.upsert( + session, + key=_OPENROUTER_KEY_SETTING, + value_enc=encrypt_token(api_key), + allowed_role_names=allowed_role_names, + updated_by=updated_by, + ) + await session.commit() + return {**_redact_key(api_key), "allowed_role_names": allowed_role_names or []} + + +@router.delete("/settings/openrouter-key", status_code=204) +async def admin_delete_openrouter_key(_admin: Annotated[dict, Depends(require_admin)]): + """Clear the shared OpenRouter API key. 
Users without their own key will + fall back to the "no shared key" path.""" + from core.database import admin_setting_repo + async with user_db_manager.get_async_session() as session: + await admin_setting_repo.delete(session, _OPENROUTER_KEY_SETTING) + await session.commit() + return None + + +# ============================================================================= +# USER BAN +# ============================================================================= +# Banned users keep their UserProfile (so history is preserved) but every +# authenticated request returns 403 — see core.security.get_current_user +# which re-reads `is_banned` per request. To delete a user entirely use +# DELETE /api/admin/users/{profile_id}. + +@router.post("/users/{profile_id}/ban") +async def ban_user( + profile_id: int, + payload: dict, + admin: Annotated[dict, Depends(require_admin)], +): + """Suspend a user. Body: { "reason": str }. Idempotent — re-banning an + already-banned user updates the reason and timestamp. + + Refuses to ban yourself or to ban a SuperAdmin. Regular Admins are + bannable directly — multiple admins can coexist, and one admin moderating + another is part of the model. + """ + reason = (payload.get("reason") or "").strip() if isinstance(payload, dict) else "" + if not reason: + raise HTTPException(status_code=400, detail="reason is required") + if len(reason) > 1000: + raise HTTPException(status_code=400, detail="reason is too long (max 1000 chars)") + + actor_id = admin.get("profile_id") if isinstance(admin, dict) else None + if actor_id is not None and actor_id == profile_id: + raise HTTPException(status_code=400, detail="You cannot ban yourself.") + + async with user_db_manager.get_async_session() as session: + profile = await session.get(UserProfileModel, profile_id) + if not profile: + raise HTTPException(status_code=404, detail="User not found") + + target_roles = await user_role_repo.get_user_role_names(session, profile_id) + if "SuperAdmin" in (target_roles or []): + raise HTTPException( + status_code=403, + detail="SuperAdmin accounts cannot be banned via the admin UI.", + ) + + profile.is_banned = True + profile.banned_at = datetime.utcnow() + profile.banned_by = actor_id + profile.ban_reason = reason + await session.flush() + + await user_activity_repo.log_activity( + session=session, + profile_id=actor_id if actor_id is not None else profile_id, + activity_type=ActivityType.USER_BAN, + description=f"Banned profile_id={profile_id}: {reason}", + ip_address=None, + user_agent=None, + ) + await session.commit() + return { + "profile_id": profile_id, + "is_banned": True, + "banned_at": profile.banned_at.isoformat() if profile.banned_at else None, + "banned_by": profile.banned_by, + "ban_reason": profile.ban_reason, + } + + +@router.delete("/users/{profile_id}/ban") +async def unban_user( + profile_id: int, + admin: Annotated[dict, Depends(require_admin)], +): + """Lift a suspension. 
Idempotent — unbanning an already-active user is a no-op + that still records the activity for audit.""" + actor_id = admin.get("profile_id") if isinstance(admin, dict) else None + async with user_db_manager.get_async_session() as session: + profile = await session.get(UserProfileModel, profile_id) + if not profile: + raise HTTPException(status_code=404, detail="User not found") + was_banned = bool(getattr(profile, "is_banned", False)) + profile.is_banned = False + profile.banned_at = None + profile.banned_by = None + profile.ban_reason = None + await session.flush() + + if was_banned: + await user_activity_repo.log_activity( + session=session, + profile_id=actor_id if actor_id is not None else profile_id, + activity_type=ActivityType.USER_UNBAN, + description=f"Unbanned profile_id={profile_id}", + ip_address=None, + user_agent=None, + ) + await session.commit() + return {"profile_id": profile_id, "is_banned": False} diff --git a/usermanagement_service/core/routers/oauth.py b/usermanagement_service/core/routers/oauth.py new file mode 100644 index 0000000..01da808 --- /dev/null +++ b/usermanagement_service/core/routers/oauth.py @@ -0,0 +1,322 @@ +"""Unified OAuth routes: /api/auth/{provider}/login and /api/auth/{provider}/callback. + +Flow: + 1. UI hits GET /api/auth/{provider}/login?redirect_after_login=/dashboard + → we mint a state+PKCE pair, store it in Web_oauth_state, return { authorize_url }. + UI redirects the browser to authorize_url. + 2. Provider redirects back to GET /api/auth/{provider}/callback?code=...&state=... + → we validate state, exchange the code, fetch userinfo, upsert + UserProfile + Web_oauth_identity + JWTUser shell, assign default role, + issue a BrainKB JWT, then redirect to USERMANAGEMENT_FRONTEND_CALLBACK_URL + with ?token=... in the query string. +""" + +from __future__ import annotations + +import base64 +import hashlib +import logging +import secrets +from datetime import datetime, timedelta +from typing import Optional +from urllib.parse import urlencode + +from fastapi import APIRouter, HTTPException, Query, Request +from fastapi.responses import RedirectResponse + +from core.configuration import config +from core.database import ( + user_db_manager, user_profile_repo, user_role_repo, jwt_user_repo, + oauth_identity_repo, oauth_state_repo, user_activity_repo, +) +from core.models.user import ActivityType, OAuthLoginStart, UserRoleEnum +from core.models.database_models import UserProfile as UserProfileModel, JWTUser as JWTUserModel +from core.oauth import get_provider +from core.security import ( + create_access_token_v2, encrypt_token, get_password_hash, +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# ---- helpers ------------------------------------------------------------ + +def _pkce_pair() -> tuple[str, str]: + verifier = secrets.token_urlsafe(64)[:128] + digest = hashlib.sha256(verifier.encode()).digest() + challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=") + return verifier, challenge + + +def _redirect_uri_for(provider_name: str) -> str: + return f"{config.public_base_url.rstrip('/')}/api/auth/{provider_name}/callback" + + +async def _upsert_profile_for_oauth(session, userinfo) -> UserProfileModel: + """Find or create a UserProfile for an OAuth identity. 
+ Matching order: (1) existing OAuth identity → its linked profile, + (2) UserProfile.email, (3) UserProfile.orcid_id (for ORCID logins), + (4) create a new profile.""" + existing_identity = await oauth_identity_repo.get_by_provider_user( + session, userinfo.provider, userinfo.provider_user_id + ) + if existing_identity: + return await session.get(UserProfileModel, existing_identity.profile_id) + + if userinfo.email: + by_email = await user_profile_repo.get_by_email(session, userinfo.email) + if by_email: + return by_email + + if userinfo.orcid_id: + by_orcid = await user_profile_repo.get_by_orcid_id(session, userinfo.orcid_id) + if by_orcid: + return by_orcid + + if not userinfo.email: + # Some GitHub accounts hide email and have no verified one. We can't + # create a profile without an email — surface a clear error. + raise HTTPException( + status_code=400, + detail=f"{userinfo.provider} did not return an email and no existing profile could be matched. Please make your email public on {userinfo.provider} or log in with ORCID/Globus first.", + ) + + new_profile = UserProfileModel( + name=userinfo.name or userinfo.email.split("@")[0], + email=userinfo.email, + orcid_id=userinfo.orcid_id, + github=userinfo.github_username, + ) + session.add(new_profile) + await session.flush() + await session.refresh(new_profile) + return new_profile + + +async def _ensure_jwt_user_shell(session, email: str, full_name: str) -> JWTUserModel: + """Make sure a Web_jwtuser row exists for this email. OAuth users don't have + a usable password — we store a random high-entropy hash (can't be logged in + with, just exists so the JWT user_id claim is stable). The shell is created + with `is_active=False`, so the lookup must not filter by activation; using + `get_by_email` (active-only) here would re-INSERT on every sign-in and + collide with the unique-email constraint.""" + existing = await jwt_user_repo.get_by_email_any_status(session, email) + if existing: + return existing + random_password = secrets.token_urlsafe(48) + return await jwt_user_repo.create_user( + session=session, + full_name=full_name, + email=email, + password=get_password_hash(random_password), + ) + + +# ---- routes ------------------------------------------------------------- + +@router.get("/auth/providers") +async def list_providers(): + """List OAuth providers and whether each is currently configured. + UI can use this to show/hide login buttons.""" + from core.oauth import REGISTRY + return { + "providers": [ + {"name": p.name, "configured": p.is_configured(), "supports_pkce": p.supports_pkce} + for p in REGISTRY.values() + ] + } + + +@router.get("/auth/{provider_name}/login", response_model=OAuthLoginStart) +async def oauth_login( + provider_name: str, + redirect_after_login: Optional[str] = Query(None, description="Relative path to send the user to after login completes"), +): + """Start an OAuth flow. 
Returns the authorize URL; the UI redirects the browser there.""" + try: + provider = get_provider(provider_name) + except KeyError: + raise HTTPException(status_code=404, detail=f"Unknown provider: {provider_name}") + + if not provider.is_configured(): + raise HTTPException(status_code=503, detail=f"{provider_name} OAuth is not configured on the server") + + state = secrets.token_urlsafe(32) + code_verifier = None + code_challenge = None + if provider.supports_pkce: + code_verifier, code_challenge = _pkce_pair() + + redirect_uri = _redirect_uri_for(provider.name) + authorize_url = provider.authorize_url( + redirect_uri=redirect_uri, + state=state, + code_challenge=code_challenge, + ) + + # Persist state so the callback (on a different request) can validate it. + async with user_db_manager.get_async_session() as session: + await oauth_state_repo.create( + session, + state=state, + provider=provider.name, + code_verifier=code_verifier, + redirect_after_login=redirect_after_login, + expires_at=datetime.utcnow() + timedelta(minutes=10), + ) + await session.commit() + + return OAuthLoginStart(authorize_url=authorize_url, state=state) + + +@router.get("/auth/{provider_name}/callback") +async def oauth_callback( + provider_name: str, + request: Request, + code: Optional[str] = Query(None), + state: Optional[str] = Query(None), + error: Optional[str] = Query(None), + error_description: Optional[str] = Query(None), +): + """Handle the OAuth provider redirect. On success, redirects the browser to + USERMANAGEMENT_FRONTEND_CALLBACK_URL with ?token=&redirect=.""" + if error: + logger.warning(f"OAuth error from {provider_name}: {error} {error_description}") + return RedirectResponse( + _frontend_error_redirect(f"{error}: {error_description or ''}"), + status_code=302, + ) + + if not code or not state: + raise HTTPException(status_code=400, detail="Missing code or state") + + try: + provider = get_provider(provider_name) + except KeyError: + raise HTTPException(status_code=404, detail=f"Unknown provider: {provider_name}") + + async with user_db_manager.get_async_session() as session: + await oauth_state_repo.purge_expired(session) + state_row = await oauth_state_repo.consume(session, state) + if state_row is None or state_row.provider != provider_name: + await session.commit() + raise HTTPException(status_code=400, detail="Invalid or expired state") + if state_row.expires_at < datetime.utcnow(): + await session.commit() + raise HTTPException(status_code=400, detail="OAuth state expired") + code_verifier = state_row.code_verifier + redirect_after_login = state_row.redirect_after_login + await session.commit() + + redirect_uri = _redirect_uri_for(provider.name) + try: + token_resp = await provider.exchange_code(code=code, redirect_uri=redirect_uri, code_verifier=code_verifier) + userinfo = await provider.fetch_userinfo(access_token=token_resp.access_token, token_response=token_resp) + except Exception as e: + logger.exception(f"OAuth callback failed for {provider_name}") + return RedirectResponse(_frontend_error_redirect(str(e)), status_code=302) + + if not userinfo.provider_user_id: + return RedirectResponse(_frontend_error_redirect("provider returned no user id"), status_code=302) + + async with user_db_manager.get_async_session() as session: + try: + profile = await _upsert_profile_for_oauth(session, userinfo) + + # Top up profile fields the provider may have just given us. 
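+            # For example, an ORCID sign-in can supply an ORCID iD the profile
+            # lacks, and a GitHub sign-in can supply the GitHub username; each
+            # field is only written when it is still empty on the profile.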
+ dirty = False + if userinfo.orcid_id and not profile.orcid_id: + profile.orcid_id = userinfo.orcid_id + dirty = True + if userinfo.github_username and not profile.github: + profile.github = userinfo.github_username + dirty = True + if dirty: + profile.updated_at = datetime.utcnow() + await session.flush() + + jwt_user = await _ensure_jwt_user_shell( + session, + email=profile.email, + full_name=profile.name or userinfo.name or profile.email, + ) + + # Default role on first login = Curator. + existing_roles = await user_role_repo.get_user_role_names(session, profile.id) + if not existing_roles: + await user_role_repo.assign_role( + session, + profile_id=profile.id, + role=UserRoleEnum.CURATOR.value, + is_active=True, + ) + existing_roles = [UserRoleEnum.CURATOR.value] + + # Upsert the oauth identity row (encrypt tokens at rest). + token_expires_at = None + if token_resp.expires_in: + token_expires_at = datetime.utcnow() + timedelta(seconds=int(token_resp.expires_in)) + await oauth_identity_repo.upsert( + session, + provider=provider.name, + provider_user_id=userinfo.provider_user_id, + profile_id=profile.id, + email=userinfo.email, + access_token_enc=encrypt_token(token_resp.access_token), + refresh_token_enc=encrypt_token(token_resp.refresh_token), + token_expires_at=token_expires_at, + raw_profile=userinfo.raw, + ) + + # Bootstrap-superadmin allowlist: if configured, elevate on first sight. + # Seed both Admin (for permissions / page-access checks) and + # SuperAdmin (the immutable marker that protects the account). + if (profile.email or "").lower() in config.bootstrap_superadmin_emails: + for role_name in (UserRoleEnum.ADMIN.value, UserRoleEnum.SUPERADMIN.value): + if role_name not in existing_roles: + await user_role_repo.assign_role( + session, + profile_id=profile.id, + role=role_name, + is_active=True, + ) + existing_roles.append(role_name) + + # Log activity. 
+ await user_activity_repo.log_activity( + session=session, + profile_id=profile.id, + activity_type=ActivityType.LOGIN, + description=f"Login via {provider.name}", + ip_address=request.client.host if request.client else None, + user_agent=request.headers.get("user-agent"), + ) + + scopes = await jwt_user_repo.get_user_scopes(session, jwt_user.id) or ["read"] + token = create_access_token_v2( + email=profile.email, + jwt_user_id=jwt_user.id, + profile_id=profile.id, + roles=existing_roles, + scopes=scopes, + auth_source=provider.name, + ) + await session.commit() + except HTTPException: + await session.rollback() + raise + except Exception as e: + await session.rollback() + logger.exception("Error finalizing OAuth login") + return RedirectResponse(_frontend_error_redirect(f"finalize_failed: {e}"), status_code=302) + + qs = {"token": token} + if redirect_after_login: + qs["redirect"] = redirect_after_login + return RedirectResponse(f"{config.frontend_callback_url}?{urlencode(qs)}", status_code=302) + + +def _frontend_error_redirect(message: str) -> str: + return f"{config.frontend_callback_url}?{urlencode({'error': message})}" From a51293260c0ee6bf72aa672e02187b27423c49cc Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri <52251022+tekrajchhetri@users.noreply.github.com> Date: Mon, 4 May 2026 09:56:51 -0400 Subject: [PATCH 11/22] Update ml_service/core/synth_scholar/routes.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ml_service/core/synth_scholar/routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml_service/core/synth_scholar/routes.py b/ml_service/core/synth_scholar/routes.py index 564a8ea..d442793 100644 --- a/ml_service/core/synth_scholar/routes.py +++ b/ml_service/core/synth_scholar/routes.py @@ -1519,7 +1519,7 @@ async def search_literature( synthesis = None if req.summarize and articles: - api_key = _get_api_key() + api_key = _resolve_api_key() from synthscholar.agents import AgentDeps, run_search_synthesis # type: ignore[import-not-found] deps = AgentDeps( From 1410194846f1dbe88afd95206db44c7528445d6b Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri <52251022+tekrajchhetri@users.noreply.github.com> Date: Mon, 4 May 2026 09:58:36 -0400 Subject: [PATCH 12/22] Update ml_service/core/main.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ml_service/core/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml_service/core/main.py b/ml_service/core/main.py index c7d7ec6..360f70c 100644 --- a/ml_service/core/main.py +++ b/ml_service/core/main.py @@ -27,7 +27,7 @@ from core.synth_scholar.database import init_db as init_synth_scholar_db, close_db as close_synth_scholar_db from core.synth_scholar.store import fix_stuck_reviews as fix_synth_scholar_stuck_reviews _SYNTH_SCHOLAR_AVAILABLE = True -except Exception as _exc: +except ImportError as _exc: _SYNTH_SCHOLAR_AVAILABLE = False _SYNTH_SCHOLAR_IMPORT_ERROR = _exc From 5d26acd0decb4fe0b8cc25d83fdfe4b6811eefa6 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Mon, 4 May 2026 15:54:22 -0400 Subject: [PATCH 13/22] heartbeat emitting to prevent timeout --- ml_service/core/synth_scholar/routes.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/ml_service/core/synth_scholar/routes.py b/ml_service/core/synth_scholar/routes.py index 564a8ea..10cac5b 100644 --- a/ml_service/core/synth_scholar/routes.py +++ b/ml_service/core/synth_scholar/routes.py @@ -846,11 
+846,13 @@ async def _stream_from_db( replay of state that the owning worker already persisted, so it stays consistent with what the pipeline truly knows.""" POLL_INTERVAL = 1.5 + HEARTBEAT_SECONDS = 15.0 # Below ALB's 60s default idle timeout last_plan_iteration = ( session.pending_plan_iteration_db if session.pending_plan_iteration_db is not None else 0 ) + last_yield = time.monotonic() if ( session.status == ReviewStatus.PLAN_PENDING and session.pending_plan_db is not None @@ -859,6 +861,7 @@ async def _stream_from_db( yield ( f"data: {ProgressEvent(review_id=review_id, step=last_step, message=f'Awaiting plan confirmation (iteration {last_plan_iteration})...', timestamp=datetime.now().isoformat(), event_type='plan_review', plan=session.pending_plan_db).model_dump_json()}\n\n" ) + last_yield = time.monotonic() while True: await asyncio.sleep(POLL_INTERVAL) @@ -881,6 +884,7 @@ async def _stream_from_db( f"data: {ProgressEvent(review_id=review_id, step=i + 1, message=msg, timestamp=ts, event_type='progress', source=ev.get('source'), kind=ev.get('kind', 'log'), stage=refreshed.stage, stage_index=refreshed.stage_index, stage_total=refreshed.stage_total, stage_done=refreshed.stage_done, stage_remaining=refreshed.stage_remaining, articles_included=refreshed.articles_included).model_dump_json()}\n\n" ) last_step = len(new_log) + last_yield = time.monotonic() if ( refreshed.status == ReviewStatus.PLAN_PENDING @@ -891,6 +895,7 @@ async def _stream_from_db( yield ( f"data: {ProgressEvent(review_id=review_id, step=last_step, message=f'Awaiting plan confirmation (iteration {last_plan_iteration})...', timestamp=datetime.now().isoformat(), event_type='plan_review', plan=refreshed.pending_plan_db).model_dump_json()}\n\n" ) + last_yield = time.monotonic() if refreshed.status in (ReviewStatus.COMPLETED, ReviewStatus.FAILED, ReviewStatus.CANCELLED): etype = ( @@ -903,6 +908,10 @@ async def _stream_from_db( ) return + if time.monotonic() - last_yield >= HEARTBEAT_SECONDS: + yield ": keepalive\n\n" + last_yield = time.monotonic() + @router.get("/synth-scholar/reviews/{review_id}/stream", tags=["SynthScholar — reviews"]) async def stream_progress( @@ -976,7 +985,7 @@ async def event_generator() -> AsyncGenerator[str, None]: return try: - await asyncio.wait_for(session._progress_event.wait(), timeout=30.0) + await asyncio.wait_for(session._progress_event.wait(), timeout=15.0) except asyncio.TimeoutError: if ( session._live.status == ReviewStatus.PLAN_PENDING.value @@ -1519,7 +1528,7 @@ async def search_literature( synthesis = None if req.summarize and articles: - api_key = _get_api_key() + api_key = _resolve_api_key() from synthscholar.agents import AgentDeps, run_search_synthesis # type: ignore[import-not-found] deps = AgentDeps( From 4c9f77db72d893654cf3a08ff218d0ab9bb15da2 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Tue, 5 May 2026 17:06:09 -0400 Subject: [PATCH 14/22] push to external graph database --- .../core/synth_scholar/oxigraph_push.py | 154 ++++++++++++++++++ ml_service/core/synth_scholar/store.py | 13 ++ 2 files changed, 167 insertions(+) create mode 100644 ml_service/core/synth_scholar/oxigraph_push.py diff --git a/ml_service/core/synth_scholar/oxigraph_push.py b/ml_service/core/synth_scholar/oxigraph_push.py new file mode 100644 index 0000000..0e54e19 --- /dev/null +++ b/ml_service/core/synth_scholar/oxigraph_push.py @@ -0,0 +1,154 @@ +"""Push completed PRISMA review results to BrainKB's Oxigraph triplestore. 
+ +Wraps :class:`synthscholar.ontology.rdf_push.GraphDBConfig` with BrainKB's +existing ``GRAPHDATABASE_*`` env-var conventions (the same vars the +query_service already uses) so a unified configuration drives both services. + +Environment variables consumed +------------------------------ + +============================================ =========================================== +``GRAPHDATABASE_USERNAME`` Basic-auth username (default ``admin``). +``GRAPHDATABASE_PASSWORD`` Basic-auth password. +``GRAPHDATABASE_HOSTNAME`` Bare hostname (``oxigraph``) or full URL + (``https://db.brainkb.org``). Default + ``oxigraph`` (internal docker hostname). +``GRAPHDATABASE_PORT`` Default ``7878``. +``GRAPHDATABASE_TYPE`` Informational; only ``OXIGRAPH`` triggers + the GSP path. Default ``OXIGRAPH``. +``SYNTH_SCHOLAR_PUSH_TO_GRAPHDB`` Feature flag (``true``/``false``). + Default ``true`` — set to ``false`` to + disable the push without unsetting creds. +``SYNTH_SCHOLAR_GRAPHDB_PATH`` Endpoint path (default ``/store`` for GSP, + use ``/update`` for SPARQL Update). +``SYNTH_SCHOLAR_GRAPHDB_NAMED_GRAPH_PREFIX`` IRI prefix for review-specific named graphs. + Default ``https://brainkb.org/reviews/``. + Each review lands at ``{prefix}{review_id}``. +``SYNTH_SCHOLAR_GRAPHDB_PROTOCOL`` ``gsp`` (default) or ``update``. +``SYNTH_SCHOLAR_GRAPHDB_REPLACE`` If ``true``, replace the named graph each + run (HTTP PUT). Default ``true`` — review + IDs are stable, so a re-run should overwrite. +============================================ =========================================== +""" + +from __future__ import annotations + +import logging +import os +from typing import Optional + +logger = logging.getLogger(__name__) + + +def _truthy(val: Optional[str], default: bool = False) -> bool: + if val is None or val == "": + return default + return val.strip().lower() in ("1", "true", "yes", "on") + + +def _build_endpoint() -> Optional[str]: + """Compose the Oxigraph endpoint URL from BrainKB's GRAPHDATABASE_* vars.""" + host = (os.getenv("GRAPHDATABASE_HOSTNAME") or "oxigraph").strip() + port = (os.getenv("GRAPHDATABASE_PORT") or "7878").strip() + path = (os.getenv("SYNTH_SCHOLAR_GRAPHDB_PATH") or "/store").strip() + if not path.startswith("/"): + path = "/" + path + + if not host: + return None + + # Accept full URL ("https://db.brainkb.org[:port]") or bare hostname ("oxigraph"). + if host.startswith(("http://", "https://")): + # Don't append port if it's already in the URL. + from urllib.parse import urlparse + parsed = urlparse(host) + if parsed.port is None and port: + base = f"{host.rstrip('/')}:{port}" + else: + base = host.rstrip("/") + else: + base = f"http://{host.rstrip('/')}:{port}" + return base + path + + +def _make_config(): + """Build a ``GraphDBConfig`` from BrainKB env vars, or None if disabled. + + Returns ``None`` (without raising) when the push is disabled, when the + optional ``synthscholar`` import fails, or when no usable endpoint can + be composed. 
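+
+    Illustrative composition (assumed defaults, not a real deployment):
+    ``GRAPHDATABASE_HOSTNAME=oxigraph`` with ``GRAPHDATABASE_PORT=7878`` and
+    the default ``SYNTH_SCHOLAR_GRAPHDB_PATH=/store`` yields the endpoint
+    ``http://oxigraph:7878/store``, pushed over GSP with ``replace=True``.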
+ """ + if not _truthy(os.getenv("SYNTH_SCHOLAR_PUSH_TO_GRAPHDB"), default=True): + return None + + try: + from synthscholar.ontology.rdf_push import GraphDBConfig # type: ignore[import-not-found] + except Exception as exc: + logger.warning("[oxigraph_push] synthscholar.rdf_push unavailable: %s", exc) + return None + + endpoint = _build_endpoint() + if not endpoint: + return None + + return GraphDBConfig( + endpoint=endpoint, + user=os.getenv("GRAPHDATABASE_USERNAME") or None, + password=os.getenv("GRAPHDATABASE_PASSWORD") or None, + protocol=(os.getenv("SYNTH_SCHOLAR_GRAPHDB_PROTOCOL") or "gsp").lower(), + replace=_truthy(os.getenv("SYNTH_SCHOLAR_GRAPHDB_REPLACE"), default=True), + ) + + +def _named_graph_for(review_id: str) -> str: + """IRI for the named graph holding *review_id*'s triples.""" + prefix = os.getenv( + "SYNTH_SCHOLAR_GRAPHDB_NAMED_GRAPH_PREFIX", + "https://brainkb.org/reviews/", + ) + if not prefix.endswith("/"): + prefix += "/" + return prefix + review_id + + +def push_review_to_oxigraph(review_id: str, result) -> Optional[int]: + """Best-effort push of a completed review's RDF to BrainKB's Oxigraph. + + Returns the HTTP status code on success, or ``None`` if the push was + skipped or failed. **Never raises** — failures are logged at WARNING and + must not affect the review's recorded status in Postgres. + + Parameters + ---------- + review_id: + The review's stable ID, used to compose the named-graph IRI. + result: + A ``PRISMAReviewResult`` (or anything ``synthscholar.export.to_oxigraph_store`` + accepts). Passed by reference; not mutated. + """ + if result is None: + logger.debug("[oxigraph_push] no result to push for %s", review_id) + return None + + cfg = _make_config() + if cfg is None: + logger.debug("[oxigraph_push] push disabled or not configured for %s", review_id) + return None + + cfg.named_graph = _named_graph_for(review_id) + + try: + from synthscholar.export import to_oxigraph_store # type: ignore[import-not-found] + store = to_oxigraph_store(result) + status = store.push_with_config(cfg) + logger.info( + "[oxigraph_push] pushed review %s to %s (graph <%s>) — HTTP %s", + review_id, cfg.resolved_endpoint, cfg.named_graph, status, + ) + return status + except Exception as exc: + logger.warning( + "[oxigraph_push] failed for review %s (endpoint=%s graph=%s): %s", + review_id, cfg.resolved_endpoint, cfg.named_graph, exc, + ) + return None diff --git a/ml_service/core/synth_scholar/store.py b/ml_service/core/synth_scholar/store.py index 8e188c4..e599ad9 100644 --- a/ml_service/core/synth_scholar/store.py +++ b/ml_service/core/synth_scholar/store.py @@ -253,6 +253,19 @@ async def mark_completed(self, result: PRISMAReviewResult) -> None: self._live.status = ReviewStatus.FAILED.value self.error = f"Persist failed after completion: {exc}" await self._persist_failed_minimal() + return + + # Best-effort push to BrainKB's Oxigraph triplestore. Runs in a + # threadpool so the synchronous httpx client doesn't block the event + # loop, and never raises — Postgres remains the source of truth. 
+ try: + from .oxigraph_push import push_review_to_oxigraph + await asyncio.to_thread(push_review_to_oxigraph, self.review_id, result) + except Exception as exc: + logger.warning( + "[mark_completed] oxigraph push raised unexpectedly for %s: %s", + self.review_id, exc, + ) async def mark_failed(self, error: str) -> None: self.status = ReviewStatus.FAILED From dc622ccb6dadc54524752fb6ab47783ef9f01570 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Tue, 5 May 2026 17:11:28 -0400 Subject: [PATCH 15/22] package version corrected --- ml_service/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml_service/requirements.txt b/ml_service/requirements.txt index 8b4d68e..77ec2bd 100644 --- a/ml_service/requirements.txt +++ b/ml_service/requirements.txt @@ -51,5 +51,5 @@ ollama==0.6.0 # the AI pipeline; the local module under core/synth_scholar/ is just the # orchestration layer (sessions, SSE, exports). SQLAlchemy 2.0 async backs # the review tables, separate from ml_service's raw-asyncpg pool. -synthscholar[fulltext,semantic]==0.0.6 +synthscholar[fulltext,semantic]==0.0.7 sqlalchemy[asyncio]>=2.0.30 \ No newline at end of file From e67d20603403dd2fec182402052cfbc924150f85 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Tue, 5 May 2026 17:28:08 -0400 Subject: [PATCH 16/22] ALB keep alive for long task --- .../core/synth_scholar/oxigraph_push.py | 96 +++++++++++++++++-- 1 file changed, 88 insertions(+), 8 deletions(-) diff --git a/ml_service/core/synth_scholar/oxigraph_push.py b/ml_service/core/synth_scholar/oxigraph_push.py index 0e54e19..27d36e8 100644 --- a/ml_service/core/synth_scholar/oxigraph_push.py +++ b/ml_service/core/synth_scholar/oxigraph_push.py @@ -4,6 +4,11 @@ existing ``GRAPHDATABASE_*`` env-var conventions (the same vars the query_service already uses) so a unified configuration drives both services. +For long-running reviews (6-8 hours) the push happens once at the very end, +when the asyncio task calls ``mark_completed``. To survive large RDF +serialisations and transient network blips, the push uses a configurable +timeout and retries with exponential backoff. + Environment variables consumed ------------------------------ @@ -28,6 +33,11 @@ ``SYNTH_SCHOLAR_GRAPHDB_REPLACE`` If ``true``, replace the named graph each run (HTTP PUT). Default ``true`` — review IDs are stable, so a re-run should overwrite. +``SYNTH_SCHOLAR_GRAPHDB_TIMEOUT`` HTTP timeout in seconds for one push attempt. + Default ``600`` (10 min) — generous because + large reviews can produce megabytes of TTL. +``SYNTH_SCHOLAR_GRAPHDB_MAX_RETRIES`` Total attempts including the first. + Default ``3``. Backoff is 5 / 30 / 120 s. ============================================ =========================================== """ @@ -35,10 +45,15 @@ import logging import os +import time from typing import Optional logger = logging.getLogger(__name__) +# Backoff schedule (seconds) for retries — index 0 is unused (first attempt +# never sleeps). Anything past len(_BACKOFF) clamps to the last value. 
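+# Worked example with the default SYNTH_SCHOLAR_GRAPHDB_MAX_RETRIES=3:
+# attempt 1 fails -> sleep 5 s, attempt 2 fails -> sleep 30 s, attempt 3
+# fails -> give up. The 120 s / 300 s steps only come into play when the
+# retry count is raised above 3.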
+_BACKOFF = [0, 5, 30, 120, 300] + def _truthy(val: Optional[str], default: bool = False) -> bool: if val is None or val == "": @@ -91,12 +106,23 @@ def _make_config(): if not endpoint: return None + timeout_env = os.getenv("SYNTH_SCHOLAR_GRAPHDB_TIMEOUT") + try: + timeout = float(timeout_env) if timeout_env else 600.0 + except ValueError: + logger.warning( + "[oxigraph_push] invalid SYNTH_SCHOLAR_GRAPHDB_TIMEOUT=%r — using 600s", + timeout_env, + ) + timeout = 600.0 + return GraphDBConfig( endpoint=endpoint, user=os.getenv("GRAPHDATABASE_USERNAME") or None, password=os.getenv("GRAPHDATABASE_PASSWORD") or None, protocol=(os.getenv("SYNTH_SCHOLAR_GRAPHDB_PROTOCOL") or "gsp").lower(), replace=_truthy(os.getenv("SYNTH_SCHOLAR_GRAPHDB_REPLACE"), default=True), + timeout=timeout, ) @@ -111,6 +137,20 @@ def _named_graph_for(review_id: str) -> str: return prefix + review_id +def _is_retryable(exc: BaseException) -> bool: + """True for transient failures worth retrying (network, 5xx, 429).""" + try: + import httpx + except Exception: + return False + if isinstance(exc, (httpx.TimeoutException, httpx.NetworkError, httpx.RemoteProtocolError)): + return True + if isinstance(exc, httpx.HTTPStatusError): + code = exc.response.status_code + return code == 429 or 500 <= code < 600 # rate-limited or server error + return False + + def push_review_to_oxigraph(review_id: str, result) -> Optional[int]: """Best-effort push of a completed review's RDF to BrainKB's Oxigraph. @@ -118,6 +158,10 @@ def push_review_to_oxigraph(review_id: str, result) -> Optional[int]: skipped or failed. **Never raises** — failures are logged at WARNING and must not affect the review's recorded status in Postgres. + For long-running reviews the RDF is built once (in-memory) and then the + HTTP push is retried with exponential backoff on transient failures + (timeouts, 5xx, 429). Auth and schema errors fail fast on the first try. + Parameters ---------- review_id: @@ -137,18 +181,54 @@ def push_review_to_oxigraph(review_id: str, result) -> Optional[int]: cfg.named_graph = _named_graph_for(review_id) + try: + max_retries = max(1, int(os.getenv("SYNTH_SCHOLAR_GRAPHDB_MAX_RETRIES") or 3)) + except ValueError: + max_retries = 3 + + # Serialise once — retries reuse the same store, so we don't pay the + # rdflib → pyoxigraph conversion cost on every attempt. 
try: from synthscholar.export import to_oxigraph_store # type: ignore[import-not-found] store = to_oxigraph_store(result) - status = store.push_with_config(cfg) - logger.info( - "[oxigraph_push] pushed review %s to %s (graph <%s>) — HTTP %s", - review_id, cfg.resolved_endpoint, cfg.named_graph, status, - ) - return status except Exception as exc: logger.warning( - "[oxigraph_push] failed for review %s (endpoint=%s graph=%s): %s", - review_id, cfg.resolved_endpoint, cfg.named_graph, exc, + "[oxigraph_push] failed to serialise review %s: %s", review_id, exc, ) return None + + last_exc: Optional[BaseException] = None + for attempt in range(1, max_retries + 1): + try: + status = store.push_with_config(cfg) + if attempt > 1: + logger.info( + "[oxigraph_push] pushed review %s on attempt %d/%d to %s " + "(graph <%s>) — HTTP %s", + review_id, attempt, max_retries, + cfg.resolved_endpoint, cfg.named_graph, status, + ) + else: + logger.info( + "[oxigraph_push] pushed review %s to %s (graph <%s>) — HTTP %s", + review_id, cfg.resolved_endpoint, cfg.named_graph, status, + ) + return status + except Exception as exc: + last_exc = exc + if attempt >= max_retries or not _is_retryable(exc): + break + sleep_s = _BACKOFF[min(attempt, len(_BACKOFF) - 1)] + logger.warning( + "[oxigraph_push] attempt %d/%d failed for review %s " + "(retrying in %ds): %s", + attempt, max_retries, review_id, sleep_s, exc, + ) + time.sleep(sleep_s) + + logger.warning( + "[oxigraph_push] giving up on review %s after %d attempt(s) " + "(endpoint=%s graph=%s): %s", + review_id, max_retries, cfg.resolved_endpoint, cfg.named_graph, last_exc, + ) + return None From 6bf520728a2cd5d9340b4fb1ef0267703204d336 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Tue, 5 May 2026 17:47:33 -0400 Subject: [PATCH 17/22] Update requirements.txt --- ml_service/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml_service/requirements.txt b/ml_service/requirements.txt index 77ec2bd..9707ded 100644 --- a/ml_service/requirements.txt +++ b/ml_service/requirements.txt @@ -51,5 +51,5 @@ ollama==0.6.0 # the AI pipeline; the local module under core/synth_scholar/ is just the # orchestration layer (sessions, SSE, exports). SQLAlchemy 2.0 async backs # the review tables, separate from ml_service's raw-asyncpg pool. -synthscholar[fulltext,semantic]==0.0.7 +synthscholar[fulltext,semantic]==0.0.8 sqlalchemy[asyncio]>=2.0.30 \ No newline at end of file From deabfb0a59f7410c751a9eb9e1b4aabb0b9bf40c Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Tue, 5 May 2026 18:21:49 -0400 Subject: [PATCH 18/22] missing resume feature added. --- ml_service/core/synth_scholar/routes.py | 20 +++ ml_service/core/synth_scholar/store.py | 37 ++++- .../tests/test_resume_preserves_checkpoint.py | 138 ++++++++++++++++++ 3 files changed, 191 insertions(+), 4 deletions(-) create mode 100644 ml_service/core/tests/test_resume_preserves_checkpoint.py diff --git a/ml_service/core/synth_scholar/routes.py b/ml_service/core/synth_scholar/routes.py index 10cac5b..fba17e3 100644 --- a/ml_service/core/synth_scholar/routes.py +++ b/ml_service/core/synth_scholar/routes.py @@ -1117,6 +1117,26 @@ async def retry_review( status_code=500, detail="Cannot retry — original configuration not saved. Please create a new review.", ) + # Log resume intent before reset so we can audit "Resume from step N" + # behaviour end-to-end. 
If body.resume=True and a checkpoint exists, + # reset_for_retry preserves checkpoint_json + last_completed_step and + # the next pipeline.run picks up from that step (skipping completed work). + _step = session.last_completed_step or 0 + if body.resume and _step > 0: + logger.info( + "[retry] review %s resume=True, preserving checkpoint at step %d", + review_id, _step, + ) + elif body.resume: + logger.info( + "[retry] review %s resume=True but no checkpoint exists — full restart", + review_id, + ) + else: + logger.info( + "[retry] review %s resume=False — clearing any checkpoint, restart from step 0", + review_id, + ) await session.reset_for_retry(clear_checkpoint=not body.resume) run_req = dict(session.run_request) if body.enable_cache is not None: diff --git a/ml_service/core/synth_scholar/store.py b/ml_service/core/synth_scholar/store.py index e599ad9..a560ff5 100644 --- a/ml_service/core/synth_scholar/store.py +++ b/ml_service/core/synth_scholar/store.py @@ -710,18 +710,47 @@ def evict(self, review_id: str): async def fix_stuck_reviews() -> int: - """Mark in-progress reviews as FAILED at server startup.""" + """Mark in-progress reviews as FAILED at server startup. + + The pipeline ran as an in-process asyncio task; if the container restarted + (deploy, OOM, scale-in) while a review was active, that task is gone and + the review is no longer making progress. Two distinct cases: + + 1. **Resumable** — ``checkpoint_json`` is non-NULL. The pipeline had + saved at least one checkpoint, so the user can pick up where it + stopped via the "Resume from step N" action in the UI. We preserve + ``checkpoint_json`` and ``last_completed_step`` (UPDATE doesn't touch + them) and set an error message that hints at this option. + 2. **Not resumable** — no checkpoint yet. Plain retry is the only path. + + Returns the total number of rows updated across both cases. + """ + in_progress = ["running", "pending", "plan_pending"] async with async_session() as db: - result = await db.execute( + # Resumable: checkpoint exists → tell the user they can resume. + resumable = await db.execute( + sa_update(ReviewRow) + .where(ReviewRow.status.in_(in_progress)) + .where(ReviewRow.checkpoint_json.isnot(None)) + .values( + status=ReviewStatus.FAILED.value, + error="Server restarted while review was in progress — " + "use 'Resume from step N' to continue, or 'Full restart' " + "to start over.", + ) + ) + # Not resumable: no checkpoint to fall back on. + not_resumable = await db.execute( sa_update(ReviewRow) - .where(ReviewRow.status.in_(["running", "pending", "plan_pending"])) + .where(ReviewRow.status.in_(in_progress)) + .where(ReviewRow.checkpoint_json.is_(None)) .values( status=ReviewStatus.FAILED.value, error="Server restarted while review was in progress — please retry.", ) ) await db.commit() - return result.rowcount + return (resumable.rowcount or 0) + (not_resumable.rowcount or 0) # Singleton diff --git a/ml_service/core/tests/test_resume_preserves_checkpoint.py b/ml_service/core/tests/test_resume_preserves_checkpoint.py new file mode 100644 index 0000000..4859dd5 --- /dev/null +++ b/ml_service/core/tests/test_resume_preserves_checkpoint.py @@ -0,0 +1,138 @@ +"""Regression test — Resume from step N must preserve the checkpoint. + +Pins the invariant: when a user clicks "Resume from step N" in the UI, the +backend must keep ``checkpoint_json`` and ``last_completed_step`` intact so +the next ``pipeline.run`` call picks up from step N rather than restarting +at step 0. 
+ +Path under test:: + + UI clicks "Resume from step N" + → POST /reviews/{id}/retry body={"resume": true} + → reset_for_retry(clear_checkpoint=False) + → checkpoint_json + last_completed_step UNCHANGED + → asyncio.create_task(_run_pipeline(...)) + → pipeline.run(checkpoint=session.checkpoint_json, ...) + → _ckpt_step = checkpoint["last_completed_step"] (= N) + → all stage guards `if _ckpt_step >= N` skip already-completed work + +Test scope: the in-process portion of ``reset_for_retry`` — we mock +``async_session`` so this runs offline without a real Postgres. +""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from unittest.mock import MagicMock, patch + +import pytest + +from core.synth_scholar.schemas import ReviewStatus +from core.synth_scholar.store import ReviewSession + + +def _stub_session(): + """Build a session with a populated checkpoint, as if a 6-hour run failed at step 15.""" + s = ReviewSession(review_id="rv-resume-test-001") + s.status = ReviewStatus.FAILED + s.error = "Server restarted while review was in progress" + s.progress_step = 15 + s.pipeline_log = ["step 1 done", "step 7 done", "step 15 done"] + # Pretend the pipeline checkpointed at step 15 with rich payload data. + s.checkpoint_json = { + "last_completed_step": 15, + "ta_included": [{"pmid": "1"}, {"pmid": "2"}], + "ft_included": [{"pmid": "1"}], + "evidence": [{"text": "snippet", "paper_pmid": "1"}], + "data_charting_rubrics": [{"foo": "bar"}], + "narrative_rows": [], + } + s.last_completed_step = 15 + return s + + +@asynccontextmanager +async def _fake_db_session(*args, **kwargs): + """Async context manager that yields a no-op DB session.""" + db = MagicMock() + db.execute = MagicMock(return_value=asyncio.sleep(0)) + db.commit = MagicMock(return_value=asyncio.sleep(0)) + yield db + + +@pytest.mark.asyncio +async def test_resume_true_preserves_checkpoint_in_memory(): + """``reset_for_retry(clear_checkpoint=False)`` keeps the in-memory checkpoint intact.""" + s = _stub_session() + expected_ckpt = dict(s.checkpoint_json) + expected_step = s.last_completed_step + + with patch("core.synth_scholar.store.async_session", _fake_db_session): + await s.reset_for_retry(clear_checkpoint=False) + + # Status reset for the new run... + assert s.status == ReviewStatus.PENDING + assert s.progress_step == 0 + assert s.pipeline_log == [] + assert s.error is None + # ...but the checkpoint must survive verbatim. + assert s.checkpoint_json == expected_ckpt, ( + "Resume from step N path lost the checkpoint payload — " + "the next pipeline.run will restart from step 0" + ) + assert s.last_completed_step == expected_step + + +@pytest.mark.asyncio +async def test_resume_false_clears_checkpoint(): + """``reset_for_retry(clear_checkpoint=True)`` (Full restart) must wipe checkpoint.""" + s = _stub_session() + + with patch("core.synth_scholar.store.async_session", _fake_db_session): + await s.reset_for_retry(clear_checkpoint=True) + + assert s.checkpoint_json is None + assert s.last_completed_step == 0 + # Status / progress also reset (covered above; assert here too for completeness). + assert s.status == ReviewStatus.PENDING + + +@pytest.mark.asyncio +async def test_resume_db_update_omits_checkpoint_columns(): + """When clear_checkpoint=False, the DB UPDATE must NOT touch checkpoint columns. 
+ + Otherwise a buggy implementation could clear the row's checkpoint in the + database while leaving the in-memory copy intact — the next worker that + loads from the DB would see no checkpoint and silently restart from step 0. + """ + s = _stub_session() + captured_values: list[dict] = [] + + @asynccontextmanager + async def _capturing_session(*args, **kwargs): + db = MagicMock() + async def _exec(stmt): + # SQLAlchemy update().values(**kwargs) keeps the kwargs on the + # compiled statement — we read them back via the internal + # _values mapping. + try: + vals = dict(stmt.compile().params) + except Exception: + vals = dict(getattr(stmt, "_values", {}) or {}) + captured_values.append(vals) + return MagicMock(rowcount=1) + db.execute = _exec + async def _commit(): return None + db.commit = _commit + yield db + + with patch("core.synth_scholar.store.async_session", _capturing_session): + await s.reset_for_retry(clear_checkpoint=False) + + assert captured_values, "Expected at least one DB UPDATE to be issued" + update_vals = captured_values[0] + # The whole point: these two columns must NOT be in the UPDATE values + # when clear_checkpoint=False. + assert "checkpoint_json" not in update_vals + assert "last_completed_step" not in update_vals From c2a9ddf5c42ed7e8e5e23f181a7ba0800a2354e4 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Tue, 5 May 2026 18:51:11 -0400 Subject: [PATCH 19/22] pg bad character issue fixed --- ml_service/core/synth_scholar/store.py | 63 ++++++++++-- ml_service/core/tests/test_pg_scrub.py | 136 +++++++++++++++++++++++++ ml_service/requirements.txt | 2 +- 3 files changed, 192 insertions(+), 9 deletions(-) create mode 100644 ml_service/core/tests/test_pg_scrub.py diff --git a/ml_service/core/synth_scholar/store.py b/ml_service/core/synth_scholar/store.py index a560ff5..2bb6769 100644 --- a/ml_service/core/synth_scholar/store.py +++ b/ml_service/core/synth_scholar/store.py @@ -42,6 +42,45 @@ _WRITE_BATCH = 10 +# ── Postgres-safe text sanitiser ───────────────────────────────────────── +# PostgreSQL's TEXT and JSONB types reject the NUL byte (\x00) — asyncpg +# raises: UntranslatableCharacterError: unsupported Unicode escape sequence +# DETAIL: \x00 cannot be converted to text. +# NUL bytes occasionally leak in via PDF full-text extraction (corrupt or +# OCR-generated PDFs), upstream API responses, or LLM output. The agent +# normally strips them at source, but we run this defensive scrub at every +# write boundary so a single dirty byte can never abort a 6-hour review. +# +# We also strip the rest of the C0 set EXCEPT \t \n \r — those are valid +# in Postgres TEXT and frequently appear in legitimate prose (newlines in +# multi-line summaries, tab-separated tables, etc.). + +import re as _re_sanitise + +_PG_BAD_BYTES = _re_sanitise.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F]") + + +def _scrub_pg_unsafe(obj): + """Recursively replace Postgres-illegal control bytes in any JSON-able value. + + Strings: forbidden control bytes are removed (not escaped — these were + never meant to be in user-visible text in the first place). + Lists/tuples: walked element-wise. + Dicts: walked value-wise (keys are left alone — they're field names from + Pydantic models, can't legally contain control bytes). + Other types (int, float, bool, None, datetime): returned unchanged. 
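+
+    Illustrative call (hypothetical payload; ``\\x00`` stands for a raw NUL
+    byte leaked in by PDF extraction)::
+
+        _scrub_pg_unsafe({"text": "pdf\\x00chunk", "count": 3})
+        # -> {"text": "pdfchunk", "count": 3}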
+ """ + if isinstance(obj, str): + return _PG_BAD_BYTES.sub("", obj) if _PG_BAD_BYTES.search(obj) else obj + if isinstance(obj, list): + return [_scrub_pg_unsafe(v) for v in obj] + if isinstance(obj, tuple): + return tuple(_scrub_pg_unsafe(v) for v in obj) + if isinstance(obj, dict): + return {k: _scrub_pg_unsafe(v) for k, v in obj.items()} + return obj + + # ── Runtime-only state ────────────────────────────────────────────────── @dataclass @@ -202,7 +241,7 @@ async def _persist_progress(self): .where(ReviewRow.review_id == self.review_id) .values( progress_step=self.progress_step, - pipeline_log=list(self.pipeline_log[-2000:]), + pipeline_log=_scrub_pg_unsafe(list(self.pipeline_log[-2000:])), status=effective_status, stage=self.stage, stage_idx=self.stage_index, @@ -210,7 +249,7 @@ async def _persist_progress(self): stage_done_count=self.stage_done or 0, stage_remaining=self.stage_remaining, articles_included=self.articles_included, - latest_message=self._live.latest_message or None, + latest_message=_scrub_pg_unsafe(self._live.latest_message) or None, ) ) await db.commit() @@ -230,17 +269,21 @@ async def mark_completed(self, result: PRISMAReviewResult) -> None: .where(ReviewRow.review_id == self.review_id) .values( status=self.status.value, - result_json=self.result.model_dump() if self.result else None, + # Scrub Postgres-illegal control bytes (NUL etc.) from + # the entire result tree — they sometimes leak in via + # PDF full-text extraction and would otherwise abort + # the UPDATE with UntranslatableCharacterError. + result_json=_scrub_pg_unsafe(self.result.model_dump()) if self.result else None, completed_at=datetime.fromisoformat(self.completed_at), progress_step=self.progress_step, - pipeline_log=list(self.pipeline_log), + pipeline_log=_scrub_pg_unsafe(list(self.pipeline_log)), stage=self.stage, stage_idx=self.stage_index, stage_total=self.stage_total, stage_done_count=self.stage_done or 0, stage_remaining=self.stage_remaining, articles_included=self.articles_included, - latest_message=self._live.latest_message or None, + latest_message=_scrub_pg_unsafe(self._live.latest_message) or None, pending_plan_json=None, pending_plan_iteration=None, plan_response_json=None, @@ -283,10 +326,10 @@ async def _persist_failed_minimal(self) -> None: .where(ReviewRow.review_id == self.review_id) .values( status=self.status.value, - error=self.error, + error=_scrub_pg_unsafe(self.error), completed_at=datetime.fromisoformat(self.completed_at) if self.completed_at else None, progress_step=self.progress_step, - pipeline_log=list(self.pipeline_log[-2000:]), + pipeline_log=_scrub_pg_unsafe(list(self.pipeline_log[-2000:])), pending_plan_json=None, pending_plan_iteration=None, plan_response_json=None, @@ -385,12 +428,16 @@ async def save_checkpoint(self, state: dict) -> None: step = state.get("last_completed_step", 0) self.checkpoint_json = state self.last_completed_step = step + # Checkpoints contain intermediate stage payloads (full-text-derived + # evidence, charting rubrics, etc.) so they need the same NUL-byte + # scrub the final result_json gets. 
+ scrubbed = _scrub_pg_unsafe(state) try: async with async_session() as db: await db.execute( sa_update(ReviewRow) .where(ReviewRow.review_id == self.review_id) - .values(checkpoint_json=state, last_completed_step=step) + .values(checkpoint_json=scrubbed, last_completed_step=step) ) await db.commit() except Exception as exc: diff --git a/ml_service/core/tests/test_pg_scrub.py b/ml_service/core/tests/test_pg_scrub.py new file mode 100644 index 0000000..8b06e53 --- /dev/null +++ b/ml_service/core/tests/test_pg_scrub.py @@ -0,0 +1,136 @@ +"""Regression test — Postgres-illegal control bytes must be scrubbed before write. + +Reproduces the failure mode reported in the user's incident:: + + asyncpg.exceptions.UntranslatableCharacterError: + unsupported Unicode escape sequence + DETAIL: \\u0000 cannot be converted to text. + +A NUL byte (or other forbidden C0 control byte) was leaking into a 4 MB +``result_json`` payload from PDF full-text extraction and aborting the +``mark_completed`` UPDATE. ``_scrub_pg_unsafe`` is the defensive scrubber at +every BrainKB write boundary; this test pins its contract. +""" + +from __future__ import annotations + +import json + +from core.synth_scholar.store import _scrub_pg_unsafe + + +# ── Strings ─────────────────────────────────────────────────────────────── + + +def test_strips_nul_byte(): + assert _scrub_pg_unsafe("hello\x00world") == "helloworld" + + +def test_strips_other_forbidden_c0_bytes(): + """All C0 except \\t \\n \\r are illegal in Postgres TEXT.""" + raw = "BEL \x07 BS \x08 VT \x0b FF \x0c SO \x0e SI \x0f" + out = _scrub_pg_unsafe(raw) + assert out == "BEL BS VT FF SO SI " + + +def test_preserves_legitimate_whitespace(): + """\\t \\n \\r are legitimate prose whitespace and must NOT be touched.""" + raw = "first\nsecond\tthird\rfourth" + assert _scrub_pg_unsafe(raw) == raw + + +def test_clean_string_passthrough(): + """No regex overhead for clean strings (fast path).""" + s = "ordinary prose with no control bytes" + assert _scrub_pg_unsafe(s) is s or _scrub_pg_unsafe(s) == s + + +# ── Container types ─────────────────────────────────────────────────────── + + +def test_walks_lists_recursively(): + log = [ + "[2026-05-05] entry one", + "[2026-05-05] entry with NUL \x00 buried", + "[2026-05-05] entry with BEL \x07", + ] + out = _scrub_pg_unsafe(log) + assert out == [ + "[2026-05-05] entry one", + "[2026-05-05] entry with NUL buried", + "[2026-05-05] entry with BEL ", + ] + + +def test_walks_dicts_recursively(): + """Reproduces the actual result_json shape from the incident.""" + dirty = { + "research_question": "ADHD biomarkers", + "evidence_spans": [ + {"text": "PDF text with \x00 from extractor", "paper_pmid": "1"}, + {"text": "good prose\nwith newlines", "paper_pmid": "2"}, + ], + "synthesis_text": "long body with \x00 at char 1234", + "flow": {"total_identified": 247}, + } + out = _scrub_pg_unsafe(dirty) + assert out["evidence_spans"][0]["text"] == "PDF text with from extractor" + assert "\x00" not in out["synthesis_text"] + # Whitespace preserved on the clean spans: + assert "\n" in out["evidence_spans"][1]["text"] + + +def test_serialised_output_is_postgres_safe(): + """The final smoke test — feed the scrubbed dict through json.dumps and + confirm the serialisation never contains a forbidden byte. 
This is the + exact data shape asyncpg sends to Postgres for JSONB columns.""" + dirty = { + "a": "x\x00y", + "b": ["\x07", "\x0c", "fine"], + "c": {"d": "with \x01 SOH and \x1f US"}, + } + serialised = json.dumps(_scrub_pg_unsafe(dirty)) + forbidden = {chr(c) for c in range(0x20)} - {"\t", "\n", "\r"} + for ch in forbidden: + # json.dumps escapes \x00 as "\\u0000" (six chars), so we check for + # the raw byte AND its escaped form. + assert ch not in serialised, f"raw byte {ord(ch):#04x} survived" + assert f"\\u{ord(ch):04x}" not in serialised, ( + f"escaped \\u{ord(ch):04x} survived in {serialised!r}" + ) + + +# ── Non-string types passthrough ───────────────────────────────────────── + + +def test_int_unchanged(): + assert _scrub_pg_unsafe(247) == 247 + + +def test_none_unchanged(): + assert _scrub_pg_unsafe(None) is None + + +def test_bool_unchanged(): + assert _scrub_pg_unsafe(True) is True + assert _scrub_pg_unsafe(False) is False + + +def test_dict_keys_left_alone(): + """Keys are Pydantic field names — they can't legally contain control bytes, + so we don't pay the cost of scrubbing every key on every write.""" + out = _scrub_pg_unsafe({"clean_key": "value with \x00 NUL"}) + assert "clean_key" in out + assert out["clean_key"] == "value with NUL" + + +def test_tuple_walked(): + """Tuples are rare in JSON-able payloads but we support them defensively.""" + out = _scrub_pg_unsafe(("a\x00", "b")) + assert out == ("a", "b") + + +def test_empty_containers(): + assert _scrub_pg_unsafe([]) == [] + assert _scrub_pg_unsafe({}) == {} + assert _scrub_pg_unsafe("") == "" diff --git a/ml_service/requirements.txt b/ml_service/requirements.txt index 9707ded..e5add44 100644 --- a/ml_service/requirements.txt +++ b/ml_service/requirements.txt @@ -51,5 +51,5 @@ ollama==0.6.0 # the AI pipeline; the local module under core/synth_scholar/ is just the # orchestration layer (sessions, SSE, exports). SQLAlchemy 2.0 async backs # the review tables, separate from ml_service's raw-asyncpg pool. -synthscholar[fulltext,semantic]==0.0.8 +synthscholar[fulltext,semantic]==0.0.9 sqlalchemy[asyncio]>=2.0.30 \ No newline at end of file From 243d4d8ac98513d32b5bdb16b9fd96dc34ba32bd Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Tue, 5 May 2026 21:15:38 -0400 Subject: [PATCH 20/22] fix openrouter key issue + cancel + resume --- ml_service/core/synth_scholar/routes.py | 27 ++++++++++++++++++++++++ ml_service/core/synth_scholar/schemas.py | 7 ++++++ 2 files changed, 34 insertions(+) diff --git a/ml_service/core/synth_scholar/routes.py b/ml_service/core/synth_scholar/routes.py index fba17e3..73c935f 100644 --- a/ml_service/core/synth_scholar/routes.py +++ b/ml_service/core/synth_scholar/routes.py @@ -1117,6 +1117,16 @@ async def retry_review( status_code=500, detail="Cannot retry — original configuration not saved. Please create a new review.", ) + # Diagnostic — proves THIS version of the handler is running. If you + # see the legacy "No OpenRouter API key available. Configure one in the + # dashboard…" message but DON'T see this log line for the same review_id, + # the backend container is still running pre-fix code and needs a restart. + logger.info( + "[retry] handler v2 entered: review=%s resume=%s enable_cache=%s key_present=%s", + review_id, body.resume, body.enable_cache, + bool(body.openrouter_api_key), + ) + # Log resume intent before reset so we can audit "Resume from step N" # behaviour end-to-end. 
If body.resume=True and a checkpoint exists, # reset_for_retry preserves checkpoint_json + last_completed_step and @@ -1142,6 +1152,23 @@ async def retry_review( if body.enable_cache is not None: run_req["enable_cache"] = body.enable_cache session.run_request = run_req + # Re-attach the OpenRouter key for this run only. The key is intentionally + # never stored with `run_request` (security), so it MUST come from the + # retry body. We validate it before re-spawning to give a 400 here + # rather than letting the pipeline fail asynchronously with the same + # "no key available" error and silently re-mark the review FAILED. + if not body.openrouter_api_key: + raise HTTPException( + status_code=400, + detail=( + "No OpenRouter API key provided for retry. The key is not " + "persisted with reviews — your client must include it in the " + "retry request body. Configure one in the dashboard's API " + "key tab, then click Resume / Retry again." + ), + ) + _resolve_api_key(body.openrouter_api_key) # validates env vs. shared vs. provided + run_req["openrouter_api_key"] = body.openrouter_api_key review_store._runtime[review_id] = session._live if run_req.get("compare_mode"): compare_request = CompareRunRequest(**{k: v for k, v in run_req.items() if k != "compare_mode"}) diff --git a/ml_service/core/synth_scholar/schemas.py b/ml_service/core/synth_scholar/schemas.py index 9800995..1539d63 100644 --- a/ml_service/core/synth_scholar/schemas.py +++ b/ml_service/core/synth_scholar/schemas.py @@ -110,6 +110,13 @@ class CompareRunRequest(BaseModel): class RetryRequest(BaseModel): enable_cache: Optional[bool] = None resume: bool = True + # OpenRouter API key for the re-run. The key is intentionally NOT + # persisted with the original review (`run_request` is stored with + # ``exclude={"openrouter_api_key"}`` for security), so every retry has + # to provide it again — same shape as ``RunReviewRequest`` / + # ``CompareRunRequest``. Excluded from logs / model dumps via + # ``exclude=True, repr=False``. + openrouter_api_key: Optional[str] = Field(default=None, exclude=True, repr=False) class PlanResponseRequest(BaseModel): From 9472cfa1b045fda441a78409b6f32c14a8a1a0c4 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Tue, 5 May 2026 21:31:12 -0400 Subject: [PATCH 21/22] iri updated which was causing rdf data, i.e., review issue --- .../synth_scholar/backfill_oxigraph_push.py | 200 ++++++++++++++++++ .../core/synth_scholar/oxigraph_push.py | 4 +- ml_service/requirements.txt | 2 +- 3 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 ml_service/core/synth_scholar/backfill_oxigraph_push.py diff --git a/ml_service/core/synth_scholar/backfill_oxigraph_push.py b/ml_service/core/synth_scholar/backfill_oxigraph_push.py new file mode 100644 index 0000000..8ef062f --- /dev/null +++ b/ml_service/core/synth_scholar/backfill_oxigraph_push.py @@ -0,0 +1,200 @@ +"""One-shot backfill — push every completed review's RDF to Oxigraph. + +Use cases: + +* You ran reviews **before** the auto-push (mark_completed → oxigraph_push) + was wired in. Their result_json is sitting in Postgres but no triples + ever reached Oxigraph. +* The auto-push silently failed for a batch of reviews (e.g. the bare-IRI + serialisation bug fixed in synthscholar 0.0.10) and you want to re-push + them after upgrading. +* You're standing up a fresh Oxigraph and want to populate it from the + durable Postgres source of truth. 
+ +Each push uses ``GraphDBConfig.replace=True`` (the default for BrainKB), +so running this script twice is **idempotent** — the second run overwrites +the same named graph rather than duplicating triples. + +Usage (run inside the brainkb-unified container, where the Postgres DSN +and GRAPHDATABASE_* env vars are already set):: + + # Push every completed review: + docker exec brainkb-unified python -m core.synth_scholar.backfill_oxigraph_push + + # Just one review: + docker exec brainkb-unified python -m core.synth_scholar.backfill_oxigraph_push \\ + --review-id review_0001_20260505234027 + + # Inspect first; don't push: + docker exec brainkb-unified python -m core.synth_scholar.backfill_oxigraph_push \\ + --dry-run --limit 5 +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import sys +from typing import Any, Optional + +from sqlalchemy import select + +from .database import async_session +from .db_models import ReviewRow +from .oxigraph_push import push_review_to_oxigraph + +logger = logging.getLogger("backfill_oxigraph_push") + + +def _build_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + prog="backfill_oxigraph_push", + description=__doc__.split("\n\n", 1)[0], + ) + p.add_argument( + "--review-id", + action="append", + default=[], + metavar="REVIEW_ID", + help="Push only this review_id. Repeatable. If omitted, every " + "completed review with a result_json is pushed.", + ) + p.add_argument( + "--limit", + type=int, + default=None, + metavar="N", + help="Cap how many reviews to push (after filtering). Useful for " + "smoke-testing before a full backfill.", + ) + p.add_argument( + "--dry-run", + action="store_true", + help="Don't actually push — just print which reviews would be processed.", + ) + p.add_argument( + "-v", "--verbose", + action="store_true", + help="Verbose logging (DEBUG level).", + ) + return p + + +async def _select_reviews( + only_ids: list[str] | None = None, +) -> list[tuple[str, dict[str, Any]]]: + """Return [(review_id, result_json), ...] for completed reviews with payload. + + Filters out rows where ``result_json`` is NULL — those reviews never + reached the synthesis stage and have nothing to serialise. + """ + async with async_session() as db: + stmt = ( + select(ReviewRow.review_id, ReviewRow.result_json) + .where(ReviewRow.status == "completed") + .where(ReviewRow.result_json.isnot(None)) + .order_by(ReviewRow.completed_at.asc()) + ) + if only_ids: + stmt = stmt.where(ReviewRow.review_id.in_(only_ids)) + result = await db.execute(stmt) + rows: list[tuple[str, dict[str, Any]]] = [ + (rid, payload) for rid, payload in result.all() + ] + return rows + + +def _reconstruct_result(payload: dict[str, Any]) -> Optional[Any]: + """Materialise a PRISMAReviewResult from the persisted JSONB payload. + + Returns None if synthscholar can't be imported or validation fails + (callers should treat this as a skip, not a hard error — backfill + should keep going for other reviews). 
+ """ + try: + from synthscholar.models import PRISMAReviewResult # type: ignore[import-not-found] + except Exception as exc: + logger.error( + "Cannot import synthscholar.models — install the agent package " + "in this container before running the backfill: %s", exc, + ) + return None + try: + return PRISMAReviewResult.model_validate(payload) + except Exception as exc: + logger.warning( + "PRISMAReviewResult.model_validate failed (skipping): %s", + str(exc)[:300], + ) + return None + + +async def main_async(argv: list[str] | None = None) -> int: + args = _build_arg_parser().parse_args(argv) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s %(levelname)-7s %(name)s: %(message)s", + ) + + rows = await _select_reviews(args.review_id or None) + if args.limit is not None: + rows = rows[: args.limit] + + if not rows: + logger.warning( + "No completed reviews with a result_json matched. " + "Did you pass --review-id for a review that hasn't completed?" + ) + return 0 + + logger.info( + "Backfill plan: %d review%s%s", + len(rows), + "" if len(rows) == 1 else "s", + " (dry-run — no pushes will fire)" if args.dry_run else "", + ) + + pushed = 0 + skipped = 0 + failed = 0 + + for review_id, payload in rows: + if args.dry_run: + logger.info("[dry-run] would push %s", review_id) + continue + + result = _reconstruct_result(payload) + if result is None: + skipped += 1 + continue + + # push_review_to_oxigraph never raises — it returns the HTTP status + # on success, or None on skip/failure (and logs WARNING). Counts + # below mirror that contract. + status = await asyncio.to_thread(push_review_to_oxigraph, review_id, result) + if status is None: + failed += 1 + logger.warning("✗ push returned None for %s — see warning above", review_id) + else: + pushed += 1 + logger.info("✓ pushed %s — HTTP %s", review_id, status) + + logger.info( + "Backfill done: pushed=%d skipped=%d failed=%d (out of %d total)", + pushed, skipped, failed, len(rows), + ) + # Non-zero exit if anything failed AND we were not in dry-run mode, so + # CI / cron drivers can surface the problem. + return 1 if (failed > 0 and not args.dry_run) else 0 + + +def main() -> int: + try: + return asyncio.run(main_async()) + except KeyboardInterrupt: + return 130 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/ml_service/core/synth_scholar/oxigraph_push.py b/ml_service/core/synth_scholar/oxigraph_push.py index 27d36e8..e8acf2c 100644 --- a/ml_service/core/synth_scholar/oxigraph_push.py +++ b/ml_service/core/synth_scholar/oxigraph_push.py @@ -27,7 +27,7 @@ ``SYNTH_SCHOLAR_GRAPHDB_PATH`` Endpoint path (default ``/store`` for GSP, use ``/update`` for SPARQL Update). ``SYNTH_SCHOLAR_GRAPHDB_NAMED_GRAPH_PREFIX`` IRI prefix for review-specific named graphs. - Default ``https://brainkb.org/reviews/``. + Default ``https://brainkb.org/synthscholar/reviews/``. Each review lands at ``{prefix}{review_id}``. ``SYNTH_SCHOLAR_GRAPHDB_PROTOCOL`` ``gsp`` (default) or ``update``. 
``SYNTH_SCHOLAR_GRAPHDB_REPLACE`` If ``true``, replace the named graph each @@ -130,7 +130,7 @@ def _named_graph_for(review_id: str) -> str: """IRI for the named graph holding *review_id*'s triples.""" prefix = os.getenv( "SYNTH_SCHOLAR_GRAPHDB_NAMED_GRAPH_PREFIX", - "https://brainkb.org/reviews/", + "https://brainkb.org/synthscholar/reviews/", ) if not prefix.endswith("/"): prefix += "/" diff --git a/ml_service/requirements.txt b/ml_service/requirements.txt index e5add44..2fac77d 100644 --- a/ml_service/requirements.txt +++ b/ml_service/requirements.txt @@ -51,5 +51,5 @@ ollama==0.6.0 # the AI pipeline; the local module under core/synth_scholar/ is just the # orchestration layer (sessions, SSE, exports). SQLAlchemy 2.0 async backs # the review tables, separate from ml_service's raw-asyncpg pool. -synthscholar[fulltext,semantic]==0.0.9 +synthscholar[fulltext,semantic]==0.0.10 sqlalchemy[asyncio]>=2.0.30 \ No newline at end of file From c6d46f77e61797424d590ae7a98c0626621fae22 Mon Sep 17 00:00:00 2001 From: Tek Raj Chhetri Date: Tue, 5 May 2026 22:04:40 -0400 Subject: [PATCH 22/22] Update requirements.txt --- ml_service/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml_service/requirements.txt b/ml_service/requirements.txt index 2fac77d..dc518ce 100644 --- a/ml_service/requirements.txt +++ b/ml_service/requirements.txt @@ -51,5 +51,5 @@ ollama==0.6.0 # the AI pipeline; the local module under core/synth_scholar/ is just the # orchestration layer (sessions, SSE, exports). SQLAlchemy 2.0 async backs # the review tables, separate from ml_service's raw-asyncpg pool. -synthscholar[fulltext,semantic]==0.0.10 +synthscholar[fulltext,semantic]==0.0.11 sqlalchemy[asyncio]>=2.0.30 \ No newline at end of file
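The ``_scrub_pg_unsafe`` helper exercised by test_pg_scrub.py lives in ml_service/core/synth_scholar/store.py, which this series only touches at the checkpoint / mark_completed write sites, so its body never appears in the diffs above. A minimal sketch that satisfies the contract the tests pin (strip every C0 control byte except \t \n \r, recurse through dicts, lists and tuples, leave keys and non-string values alone, and hand clean strings back unchanged) could look like this; the regex approach and the exact structure are illustrative, not the actual store.py code::

    import re
    from typing import Any

    # C0 control bytes that Postgres TEXT / JSONB reject; \t \n \r are kept
    # because they are legitimate prose whitespace.
    _PG_UNSAFE = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f]")

    def _scrub_pg_unsafe(value: Any) -> Any:
        """Recursively drop Postgres-illegal control bytes from strings."""
        if isinstance(value, str):
            # Fast path: a clean string is returned as the same object.
            if _PG_UNSAFE.search(value) is None:
                return value
            return _PG_UNSAFE.sub("", value)
        if isinstance(value, dict):
            # Keys are trusted Pydantic field names; only values are scrubbed.
            return {k: _scrub_pg_unsafe(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_scrub_pg_unsafe(v) for v in value]
        if isinstance(value, tuple):
            return tuple(_scrub_pg_unsafe(v) for v in value)
        return value  # ints, floats, bools, None pass through untouched

Whatever the real implementation looks like, the property the tests encode is that scrubbing happens once, at the Postgres write boundary, instead of inside every producer of result_json or checkpoint_json.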
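On the client side, the practical consequence of the RetryRequest change in patch 20 is that every resume / retry call has to re-send the OpenRouter key, because it is never persisted with the review. A hypothetical call sketch follows; the host, route path and review id are placeholders (the exact retry URL is not visible in this series), only the body fields mirror the RetryRequest schema, and whatever auth header ml_service requires is omitted::

    import json
    import urllib.error
    import urllib.request

    BASE = "http://localhost:8000"              # placeholder ml_service host
    REVIEW_ID = "review_0001_20260505234027"    # example id reused from the backfill docstring

    body = {
        "resume": True,                     # keep checkpoint_json / last_completed_step and continue
        "enable_cache": True,
        "openrouter_api_key": "sk-or-...",  # must be supplied on every retry; never stored server-side
    }
    req = urllib.request.Request(
        f"{BASE}/api/synth-scholar/reviews/{REVIEW_ID}/retry",  # path is a guess, not from the diff
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req) as resp:
            print(resp.status)              # 2xx: the pipeline was re-spawned
    except urllib.error.HTTPError as err:
        # A missing or rejected key now fails fast here with a 400 instead of
        # letting the pipeline die asynchronously and re-mark the review FAILED.
        print(err.code, err.read().decode())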
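Once synthscholar is on 0.0.10 or later and the backfill has run, a quick way to confirm that a review's triples really landed under the new https://brainkb.org/synthscholar/reviews/ prefix is a SPARQL count against the per-review named graph. A standard-library sketch, assuming Oxigraph's stock /query endpoint next to the /store and /update endpoints the push itself uses; the base URL is a placeholder for whatever GRAPHDATABASE_* host the deployment points at::

    import json
    import urllib.parse
    import urllib.request

    OXIGRAPH = "http://localhost:7878"          # placeholder; Oxigraph's default server port
    REVIEW_ID = "review_0001_20260505234027"

    graph = f"https://brainkb.org/synthscholar/reviews/{REVIEW_ID}"
    query = f"SELECT (COUNT(*) AS ?n) WHERE {{ GRAPH <{graph}> {{ ?s ?p ?o }} }}"

    req = urllib.request.Request(
        f"{OXIGRAPH}/query?{urllib.parse.urlencode({'query': query})}",
        headers={"Accept": "application/sparql-results+json"},
    )
    with urllib.request.urlopen(req) as resp:
        n = json.load(resp)["results"]["bindings"][0]["n"]["value"]
        print(f"{graph} holds {n} triples")

A count of 0 for a review the backfill reported as pushed usually points at a prefix mismatch; check SYNTH_SCHOLAR_GRAPHDB_NAMED_GRAPH_PREFIX in the container environment.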