diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/.dockerignore b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/.dockerignore new file mode 100644 index 000000000..352a2f365 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/.dockerignore @@ -0,0 +1,53 @@ +# Python artifacts +__pycache__/ +*.pyc +*.pyo +*.pyd +*.egg-info/ +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ + +# Virtual environments +.venv/ +venv/ +env/ + +# Jupyter +.ipynb_checkpoints/ + +# Git / tooling +.git/ +.gitignore +.gitattributes +.claude/ + +# Editor / OS +.DS_Store +.idea/ +.vscode/ +*.swp + +# Secrets and local env +.env +.env.local +*.pem +*.key + +# Generated / regenerable artifacts +data/ +logs/ +*.log +collection_stats.json +storage_status.json +docker_build.version.log +tmp.pytest.log + +# Docker build scripts (not needed inside the image) +docker_*.sh +docker_name.sh +Dockerfile.python_slim +Dockerfile.uv + +# Docs (the runtime doesn't read these) +*.md diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/.gitignore b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/.gitignore new file mode 100644 index 000000000..a42c47cd6 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/.gitignore @@ -0,0 +1,178 @@ +# ============================================================================= +# Python / General +# ============================================================================= + +# Python virtual environments +.venv/ +venv/ +env/ +ENV/ +.conda/ + +# Python cache and bytecode +__pycache__/ +*.py[cod] +*$py.class +*.pyc +*.pyo +.Python + +# Python build artifacts +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ 
+parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +*.egg +*.manifest +*.spec + +# ============================================================================= +# Testing and Coverage +# ============================================================================= +.pytest_cache/ +.coverage +.coverage.* +htmlcov/ +.tox/ +.nox/ +coverage.xml +*.cover +*.py,cover + +# ============================================================================= +# Type Checkers and Linters +# ============================================================================= +.mypy_cache/ +.dmypy.json +dmypy.json +.pyrightcache/ +.ruff_cache/ +.ruff_cache/ +.lint_cache/ + +# ============================================================================= +# IDE and Editor Settings +# ============================================================================= + +# JetBrains (PyCharm, IntelliJ, etc.) +.idea/ +*.iml +*.ipr +*.iws + +# VS Code +.vscode/ +*.code-workspace + +# Vim +*.swp +*.swo +*~ +*.vim + +# Emacs +*~ +\#*\# +.\#* + +# macOS +.DS_Store +.AppleDouble +.LSOverride + +# Windows +Thumbs.db +ehthumbs.db +Desktop.ini + +# ============================================================================= +# Environment Variables and Secrets +# ============================================================================= +.env +.env.local +.env.*.local +.env.development +.env.production +.env.test +*.env + +# Secrets and credentials +*.pem +*.key +secrets.json +credentials.json +service_account.json + +# ============================================================================= +# Claude and AI Tools +# ============================================================================= +.claude/ +.claude-code/ +*.claude.local.* + +# ============================================================================= +# Jupyter Notebooks +# ============================================================================= +.ipynb_checkpoints/ +.jupyter/ +*.ipynb_checkpoints + +# 
============================================================================= +# Logs and Temp Files +# ============================================================================= +*.log +logs/ +tmp/ +temp/ +tmp/ +scratch/ +.scratch/ + +# ============================================================================= +# Local pipeline artifacts (regenerated by collector / status scripts) +# ============================================================================= +collection_stats.json +storage_status.json + +# ============================================================================= +# Data and Model Artifacts (ignore directory contents, not the directory itself, so the .gitkeep negations take effect) +# ============================================================================= +data/* +!data/.gitkeep +models/* +!models/.gitkeep +checkpoints/ +*.h5 +*.pth +*.pt +*.onnx + +# ============================================================================= +# Docker +# ============================================================================= +*.dockerfile.local +docker-compose.override.yml + +# ============================================================================= +# Local Configuration (Keep Templates) +# ============================================================================= +*.local.yaml +*.local.json +*.local.yml +config.local.* + +# ============================================================================= +# Documentation Build +# ============================================================================= +docs/_build/ +site/ +__pypackages__/ diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/CLAUDE.md b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/CLAUDE.md new file mode 100644 index 000000000..49f4bcf5d --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/CLAUDE.md @@ -0,0 +1,128 @@ +# CLAUDE.md + +- Guidance for Claude Code when working in this 
project +- The repo-wide `CLAUDE.md` at the repository root still + applies (coding style, testing, notebook, markdown rules) — this file only + adds project-specific context + +# What This Project Is + +- A txtai-based market research platform for the DATA605 Spring 2026 course + (`UmdTask430`) +- End-to-end it does three things: + - **Collect**: SEC EDGAR filings + news (NewsAPI, Alpha Vantage) into a + four-tier store + - **Search**: an agentic pipeline routes a query to sub-agents (`sec`, + `news`), retrieves top-k chunks, and synthesizes a cited answer + - **Serve**: a FastAPI service with a Streamlit UI +- See `RUN_INSTRUCTIONS.md` for the full quickstart, env vars, and + troubleshooting — do not duplicate that here + +# Repository Layout + +- `app/`: application code (importable as `app.*`) + - `agents/research_agent.py`: core agentic pipeline; `run_research(query)` + streams events, `run_research_sync(query)` returns a single dict + - `agents/{diligence,earnings,regulatory,sentiment,web_research,orchestrator}.py`: + domain agents used by the dashboard/chat UI + - `api/server.py`: FastAPI app — `GET /`, `POST /research`, + `POST /research/stream` (SSE) + - `collectors/`: `base_collector.py`, `sec_collector.py`, + `news_collector.py` + - `pipeline/`: `ingest.py` (chunking/normalization), `embeddings.py` (txtai + index) + - `storage/`: four-tier storage clients + - `hot_storage/`: KeyDB (live prices, semantic cache, sessions) + - `warm_storage/`: PostgreSQL + pgvector (filings, chunks, XBRL facts) + - `cold_storage/`: MinIO (raw filings archive) + - `cache_manager.py`: thin wrapper over KeyDB + - `ui/`: Streamlit pages — `research.py` (agent chat), `dashboard.py`, + `chat.py`; entrypoint is `app/main.py` +- `scripts/`: one-shot CLIs (`run_sec_collector`, `run_sec_bulk`, + `run_all_collectors`, `backfill_txtai_from_chunks`, `eval_research`, + `check_storage_status`) +- `sql/init.sql`: PostgreSQL schema (mounted into the postgres 
container) +- `data/`: persisted txtai index (`documents`, `embeddings`, `index.db`, + `config.json`) — large binary files, do not commit +- `notebooks/`: jupytext-paired notebooks + - `txtai.API.{ipynb,py}`: txtai library primitives in isolation + - `txtai.example.{ipynb,py}`: full ingest → search → agent demo +- `docker-compose.yml`: brings up KeyDB, PostgreSQL+pgvector, MinIO + +# Common Commands + +- Bring up infra (KeyDB, Postgres+pgvector, MinIO): + ```bash + > docker-compose up -d + > docker-compose ps + ``` +- Install deps (Python 3.11+): + ```bash + > python -m venv .venv && source .venv/bin/activate + > pip install -r requirements.txt + ``` +- Collect data (one-time): + ```bash + > python -m scripts.run_sec_bulk --group all --skip-existing --limit 10 + > python -m scripts.run_all_collectors --tickers AAPL,MSFT,NVDA --skip-sec --no-search + > python -m scripts.backfill_txtai_from_chunks --from-scratch + ``` +- Run API + UI: + ```bash + > uvicorn app.api.server:app --host 127.0.0.1 --port 8000 & + > streamlit run app/ui/research.py --server.port 8501 + ``` +- Evaluate the pipeline: + ```bash + > python -m scripts.eval_research --warmup + > python -m scripts.eval_research --repeats 5 --json logs/eval.json + ``` +- Inspect storage state: + ```bash + > python -m scripts.check_storage_status + ``` + +# Configuration + +- Secrets and connection strings live in `.env` (template: `.env.example`) +- Required keys for a full run: + - `SEC_USER_AGENT`: SEC EDGAR requires a real contact email + - `NEWSAPI_KEY`, `ALPHAVANTAGE_API_KEY`: news collectors + - `OPENAI_API_KEY`: embeddings (txtai default backend) +- Optional for LLM-backed answer synthesis: + - `LLM_BASE_URL`, `LLM_API_KEY`, `LLM_MODEL` (any OpenAI-compatible endpoint, + including local Ollama) — without these the synthesizer falls back to an + extractive template +- Never read or write secrets in code; always pull from environment variables + +# Conventions Specific to This Project + +- Python imports use 
module-qualified names (`from app.storage import ...`, + `from app.agents.research_agent import ...`) — keep it that way +- Storage clients are accessed via factory helpers (`get_keydb_client`, + `get_postgres_client`, `get_minio_client`, `get_cache_manager`, + `get_embeddings`) — do not instantiate clients directly in new code +- Collectors inherit from `app/collectors/base_collector.py`; new sources + should follow the same `fetch → normalize → store` flow +- Long-running collection scripts must support `--skip-existing` and + `--no-search` style toggles so partial runs are cheap +- Logs go to `logs/` (gitignored); persistent artifacts go to `data/` + +# Things to Avoid + +- Do not commit `.env`, `data/`, `logs/`, `.venv/`, or anything in + `**/__pycache__/` +- Do not edit `data/{documents,embeddings,index.db}` by hand — rebuild via + `scripts.backfill_txtai_from_chunks` +- Do not bypass the storage tier abstraction (e.g. talking directly to psycopg + or boto from agent code) — go through `app.storage` +- Do not add new top-level docs files for one-off notes; extend + `RUN_INSTRUCTIONS.md` or this file + +# When in Doubt + +- Architecture / data flow: `RUN_INSTRUCTIONS.md` and `app/storage/README.md` +- Schema: `sql/init.sql` +- Agent behavior: `app/agents/research_agent.py` (router → retrievers → + synthesizer) +- Eval / benchmarks: `scripts/eval_research.py` diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/Dockerfile b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/Dockerfile new file mode 100644 index 000000000..fecc269a7 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/Dockerfile @@ -0,0 +1,39 @@ +# txtai Market Research Platform - Docker Configuration + +FROM python:3.11-slim + +# Set working directory +WORKDIR /app + +# Set environment variables +ENV PYTHONUNBUFFERED=1 +ENV 
PYTHONDONTWRITEBYTECODE=1 +ENV PIP_NO_CACHE_DIR=1 + +# Install system dependencies: build-essential (compiler toolchain; presumably for pip packages that build native code - confirm) and curl (for healthchecks). +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first so the dependency layer is reused when only application code changes +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Create data directory for txtai index +RUN mkdir -p /app/data + +# Expose API and Streamlit ports. +EXPOSE 8000 8501 + +# Set environment variable for data directory. +ENV TXTAI_DATA_DIR=/app/data + +# Default command runs the FastAPI server; docker-compose overrides this for +# the Streamlit UI service. +CMD ["uvicorn", "app.api.server:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/Dockerfile.python_slim b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/Dockerfile.python_slim new file mode 100644 index 000000000..6f1204a83 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/Dockerfile.python_slim @@ -0,0 +1,28 @@ +# Use Python 3.12 slim (already has Python and pip). +FROM python:3.12-slim + +# Avoid interactive prompts during apt operations. +ENV DEBIAN_FRONTEND=noninteractive + +# Install CA certificates (needed for HTTPS). +RUN apt-get update && apt-get install -y \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Install JupyterLab (with the vim extension) plus the project requirements. +RUN mkdir -p /install +COPY requirements.txt /install/requirements.txt +RUN pip install --upgrade pip && \ + pip install --no-cache-dir jupyterlab jupyterlab_vim -r /install/requirements.txt + +# Config: keep a copy of the sudoers file under /install, install it as /etc/sudoers, and set root's bashrc. +COPY etc_sudoers /install/ +COPY etc_sudoers /etc/sudoers +COPY bashrc /root/.bashrc + +# Report package versions. 
+COPY version.sh /install/ +RUN /install/version.sh 2>&1 | tee version.log + +# Jupyter. +EXPOSE 8888 diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/Dockerfile.uv b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/Dockerfile.uv new file mode 100644 index 000000000..ac27d2aeb --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/Dockerfile.uv @@ -0,0 +1,48 @@ +FROM ubuntu:24.04 +ENV DEBIAN_FRONTEND=noninteractive + +# Install system utilities and Python in a single layer. +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y --no-install-recommends \ + sudo \ + curl \ + git \ + build-essential \ + python3 \ + python3-pip \ + python3-dev \ + python3-venv \ + libgomp1 \ + g++ \ + && rm -rf /var/lib/apt/lists/* + +# Install Jupyter into the system interpreter; Ubuntu 24.04 marks it externally managed (PEP 668), so pip needs --break-system-packages. +RUN pip3 install --break-system-packages jupyterlab jupyterlab_vim + +# Install uv for package management. +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/root/.local/bin:$PATH" + +# Install project specific packages using uv (only the manifest and lock file are copied first, so this layer is reused when other files change). +COPY pyproject.toml uv.lock /app/ +WORKDIR /app +RUN uv sync +ENV PATH="/app/.venv/bin:$PATH" + +# Copy project files. +COPY . /app + +RUN mkdir /install + +# Config. +COPY etc_sudoers /install/ +COPY etc_sudoers /etc/sudoers +COPY bashrc /root/.bashrc + +# Report package versions. +COPY version.sh /install/ +RUN /install/version.sh 2>&1 | tee version.log + +# Jupyter. 
+EXPOSE 8888 diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/README.md b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/README.md new file mode 100644 index 000000000..9e067cb2e --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/README.md @@ -0,0 +1,375 @@ +# txtai for Multi-Agentic Market Research + +- A `txtai`-based market research platform that ingests SEC EDGAR filings + and financial news, stores them across a four-tier storage architecture, + and answers natural-language questions through an agentic retrieval + pipeline served behind a FastAPI + Streamlit front end +- The whole search and agent layer is built on `txtai.Embeddings` plus an + optional `txtai.LLM` synthesizer, so the project is essentially a worked + example of how far you can take `txtai` as the search core of a real + application +- Two ways to use it: a **Streamlit UI** for interactive Q&A, or **Jupyter + notebooks** for the same pipeline cell-by-cell +- Project: `UmdTask430` for DATA605 Spring 2026 + +# System Architecture + +```mermaid +flowchart TD + user([User]) + ui["Streamlit UI
(port 8501)"] + api["FastAPI
GET /
POST /research
POST /research/stream (SSE)"] + router["research_agent._route
keyword + ticker extraction"] + sec["SEC sub-agent
txtai search WHERE tags='sec'"] + news["News sub-agent
txtai search WHERE tags='news'"] + synth["_synthesize
txtai.LLM or extractive"] + txtai[["txtai.Embeddings
SQLite content store + ANN
(data/)"]] + hot["Hot tier - KeyDB
prices, semantic cache, sessions"] + warm["Warm tier - PostgreSQL + pgvector
filings, chunks(768d), xbrl_facts"] + cold["Cold tier - MinIO
raw filing HTML/XML"] + ext["External APIs
SEC EDGAR / NewsAPI / Alpha Vantage"] + + user --> ui --> api --> router + router --> sec + router --> news + sec --> synth + news --> synth + sec -. similar(:q) .-> txtai + news -. similar(:q) .-> txtai + txtai --- warm + api --- hot + ext --> cold --> warm --> txtai +``` + +- The diagram is the project at a glance: every read path resolves to a + `txtai.Embeddings.search` call; every write path lands a chunk in the + `txtai` index after going through the warm tier first + +# Storage Architecture + +- Four physical tiers, each chosen for a different access pattern. They are + wired together by the factory helpers in `app/storage/__init__.py` + (`get_keydb_client`, `get_postgres_client`, `get_minio_client`, + `get_cache_manager`) — agent and collector code never instantiates a + client directly + +- **Hot tier - KeyDB (Redis-compatible)** at port `6379` + - `prices:{ticker}` with 60s TTL: live snapshot from Alpha Vantage + - `cache:{md5(query)}` with 3600s TTL: semantic-search response cache so + repeated questions skip the embedding model and the SQL query + - `session:{id}` with 1800s TTL: per-session agent memory used by the + Streamlit chat + - Why KeyDB and not Redis? 
KeyDB is a drop-in fork with multithreaded I/O + and a more permissive license — same protocol, no code changes + - Client: `app/storage/hot_storage/keydb_client.py`, + cache wrapper: `app/storage/cache_manager.py` + +- **Warm tier - PostgreSQL + pgvector** at port `5432` + - `companies(cik, ticker, name, sector, …)` — entity table + - `filings(id, ticker, filing_type, accession_number, filing_date, …)` — + one row per SEC filing + - `chunks(id, filing_id, chunk_index, text, section, embedding vector(768))` + — chunked filing text with a `pgvector` 768-dim embedding column matching + `sentence-transformers/all-mpnet-base-v2` + - `document_metadata`, `xbrl_facts`, `articles` — auxiliary tables + - Indexed with `ivfflat (embedding vector_cosine_ops) WITH (lists = 100)` + so pgvector remains a viable analytics fallback to the txtai index + - Schema is in [`sql/init.sql`](sql/init.sql) and is mounted into the + Postgres container at boot + - Client: `app/storage/warm_storage/pgvector_client.py` + +- **Cold tier - MinIO (S3-compatible)** at ports `9000` (S3 API) / `9001` + (console) + - Buckets `sec/`, `news/`, `web/`, `social/` — each holds raw documents + keyed by `{ticker}/{accession}/{filename}` + - Append-only archive used to re-derive the warm tier and the search + index without re-hitting upstream APIs (SEC EDGAR rate-limits at 10 RPS) + - Client: `app/storage/cold_storage/minio_client.py` + +- **Search tier - `txtai.Embeddings` on SQLite** at `data/` + - The artifact this project is really about; described in its own section + below + +- Why four tiers, not one? Different data has different access patterns. + Live prices need sub-millisecond reads (hot). Structured rows with + metadata filters need SQL plus vectors (warm). Raw 5-MB HTML filings need + cheap durable storage (cold). The semantic search index needs + filter-aware ANN over the chunked text (search). 
Collapsing them into one + store either bloats the hot path or pushes blob writes through the search + index + +# Search Index — `txtai` Deep-Dive + +- The single source of truth for retrieval is one `txtai.Embeddings` + instance configured in + [`app/pipeline/embeddings.py`](app/pipeline/embeddings.py): + + ```python + Embeddings( + { + "path": "sentence-transformers/all-mpnet-base-v2", + "content": True, + "chunksize": 100, + } + ) + ``` + + - `path`: the encoder. 768-dim, matches the pgvector schema exactly so + the same embeddings can live in both the txtai SQLite store and Postgres + - `content=True`: the original text is stored alongside the vectors in + SQLite, which is what gives us a SQL surface (see filter section below) + - `chunksize=100`: streaming insert batch size; tuned to keep memory flat + during `scripts.backfill_txtai_from_chunks --from-scratch` + +- Persistence: the index lives at + `data/{config.json,documents,embeddings,index.db}` + - Written via `Embeddings.save(path)` after each upsert in `upsert(...)` + - Read at process start via `Embeddings.load(path)` inside + `create_embeddings()` + - Singleton-cached via `get_embeddings()` so the API server, eval harness, + and notebooks all share the same on-disk artifact and don't fight over + SQLite locks + +- Search shape — every query goes through one of two `txtai` calls: + 1. **Plain semantic top-k**: + ```python + embeddings.search(query, limit=k) + ``` + used by simple smoke tests and the notebook + 2. **SQL-style filter** (the production path): + ```python + embeddings.search( + "SELECT id, text, score, tags, data FROM txtai " + "WHERE similar(:q) AND tags = :src LIMIT :k", + parameters={"q": query, "src": "sec", "k": 5}, + ) + ``` + used by both sub-agents to scope a query to one source. 
The custom + `data` column carries per-chunk metadata (ticker, filing_type, + filing_date) that `txtai` does not surface by default — the wrapper in + `app.pipeline.embeddings.search` lifts those fields back to the top + level + +- Why `txtai` over raw FAISS / LangChain RAG? + - Bundles three things we need together: a SQLite-backed ANN index, a + SQL-like filter language over the same store, and a unified `LLM` + abstraction + - We get `WHERE tags = 'sec'` semantics without standing up a separate + Postgres-vector duplicate or hand-rolling a metadata layer over FAISS + - The pgvector copy in the warm tier is for analytics and is not on the + hot read path + +- The full `txtai` surface used by the project is enumerated in + [`notebooks/txtai.API.ipynb`](notebooks/txtai.API.ipynb), with each + primitive isolated in one cell + +# Agentic Infrastructure + +- The pipeline is implemented in + [`app/agents/research_agent.py`](app/agents/research_agent.py) as a small + state machine. It exposes two entry points: + - `run_research_sync(query) -> dict` — drains the pipeline into a single + JSON-serialisable result; called by `POST /research` + - `run_research(query) -> Iterator[ResearchEvent]` — streams one event + per stage (`route` -> `retrieve` -> `synthesize` -> `done`); consumed + by `POST /research/stream` (SSE) and the Streamlit UI + +- **Stage 1 - Router** (`_route(query)`) + - Deterministic keyword match against `_SEC_KEYWORDS` (10-K, 8-K, proxy, + risk factor, …) and `_NEWS_KEYWORDS` (analyst, upgrade, sentiment, + bearish, …) + - Ticker extraction prefers cashtags (`$AAPL`), then known company names, + then bare uppercase tokens with a small stop-list (`SEC`, `CEO`, …) + - Output: `{"agents": ["sec","news"], "ticker": "AAPL", "reason": "..."}` + - The router is intentionally LLM-free: it costs nothing, never hangs, + and is fully testable. 
An LLM router can be plugged in without changing + callers because the schema does not change + +- **Stage 2 - Retrieval** — one sub-agent per source, each backed by `txtai` + - `_run_sec_agent(query, ticker)` -> `embeddings.search(... WHERE tags='sec' ...)` + - `_run_news_agent(query, ticker)` -> `embeddings.search(... WHERE tags='news' ...)` + - Results are post-filtered by ticker (`_filter_by_ticker`) when the + router extracted one, then truncated to `_MAX_CHUNKS_PER_AGENT = 5` + - Each chunk carries the `metadata` dict the collector wrote so the + citation step can render `[1] AAPL 10-K 2024-09-28` + +- **Stage 3 - Synthesis** + - If `LLM_BASE_URL` / `LLM_API_KEY` / `LLM_MODEL` are set, the synthesizer + builds a numbered-citation prompt and calls `txtai.LLM(...)` (any + OpenAI-compatible endpoint, including local Ollama) + - Otherwise it falls back to `_synthesize_template`, an extractive + composer that takes the first 1–2 sentences of the top three chunks + and stitches them together with citation markers + - The fallback exists so the demo runs end-to-end on a laptop with no + paid API keys, and so citations always reference real text rather than + a hallucination + +- **Other agents** (`diligence`, `earnings`, `regulatory`, `orchestrator`) + - Live alongside `research_agent` and back the dashboard pages in + `app/ui/dashboard.py` + - They share the same `get_embeddings()` singleton — adding a new + domain agent is a matter of writing one `txtai`-search wrapper and + plugging it into the router, no infra changes required + +# Setup + +- Clone and enter the project + ```bash + > git clone <repo-url> + > cd class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research + ``` + +- Configure secrets + ```bash + > cp .env.example .env + ``` + - Edit `.env` to fill in: + - `SEC_USER_AGENT`: real contact email (required by SEC EDGAR) + - `NEWSAPI_KEY`, `ALPHAVANTAGE_API_KEY`: news collectors + - `OPENAI_API_KEY`: txtai embedding backend + 
- Optional, for LLM-backed answers: `LLM_BASE_URL`, `LLM_API_KEY`, + `LLM_MODEL` (any OpenAI-compatible endpoint, including local Ollama). + Without these the synthesizer falls back to an extractive template + +- Bring up the storage tiers and the API + ```bash + > docker-compose build + > docker-compose up -d + ``` + - Starts: KeyDB (6379), PostgreSQL+pgvector (5432), MinIO (9000/9001), + FastAPI (8000), Streamlit (8501) + +- Seed the index with one ticker (one-time) + ```bash + > docker-compose exec api python -m scripts.run_sec_collector --ticker AAPL --limit 5 + > docker-compose exec api python -m scripts.run_earnings_collector --ticker AAPL --quarters 4 + > docker-compose exec api python -m scripts.backfill_txtai_from_chunks --from-scratch + ``` + - The earnings step uses Alpha Vantage's `EARNINGS_CALL_TRANSCRIPT` + endpoint on the same `ALPHAVANTAGE_API_KEY` already in `.env` — no + extra key required, free tier caps at 25 requests/day + +- Stop everything when done + ```bash + > docker-compose down + ``` + +# Run the Streamlit UI + +- The Docker stack already serves the UI — just open the browser + - `http://localhost:8501`: Streamlit research chat + - `http://localhost:8000/docs`: FastAPI OpenAPI docs + +- Or run the UI from a local venv (Python 3.11+) + ```bash + > python -m venv .venv && source .venv/bin/activate + > pip install -r requirements.txt + > uvicorn app.api.server:app --host 127.0.0.1 --port 8000 & + > streamlit run app/ui/research.py --server.port 8501 + ``` + +- Try it + - Type a question like _"What are the key risks in Apple's latest 10-K?"_ + - The pipeline routes to the SEC and/or news sub-agents, retrieves the + top-k chunks, and synthesizes a cited answer + +# Run the Jupyter Notebooks + +- Two paired notebooks live in `notebooks/` + - `txtai.API.ipynb`: txtai library primitives in isolation (Embeddings, + SQL filter, save/load, LLM) + - `txtai.example.ipynb`: full end-to-end demo — ingest a ticker, query + the index, run the agentic 
pipeline, stream events + - Both notebooks open with a table of the `txtai` and project endpoints + they exercise so the reader can navigate cell-by-cell + +- Make sure the storage tiers are up first + ```bash + > docker-compose up -d keydb postgres minio + ``` + +- Launch Jupyter from the project root (so `app.*` imports resolve) + ```bash + > source .venv/bin/activate + > pip install jupyterlab jupytext + > jupyter lab notebooks/ + ``` + +- Open `txtai.example.ipynb` and run the cells top to bottom + - Cell 1: load `.env` and configure logging + - Cell 2: ingest a small batch of AAPL filings + - Cell 3: spot-check the txtai index + - Cell 4: call `run_research_sync(query)` — the same entry point used + by FastAPI + - Cell 5: stream events with `run_research(query)` — the same generator + used by the SSE endpoint and the Streamlit UI + +- Notebooks are paired with `*.py` files via jupytext, so edits to either + format stay in sync + +# API Endpoints + +- `GET /`: health probe and capability summary + +- `POST /research`: synchronous, returns the full result + ```json + { + "query": "Apple revenue trend", + "route": {"agents": ["sec","news"], "ticker": "AAPL"}, + "answer": "...", + "retrievals": [...], + "used_llm": false, + "chunk_count": 8 + } + ``` + +- `POST /research/stream`: Server-Sent Events, one per pipeline stage + (`route` -> `retrieve` -> `synthesize` -> `done`) + +# Project Layout + +- `app/`: application code + - `agents/research_agent.py`: the agentic pipeline (router -> sub-agents + -> synthesizer) + - `agents/{diligence,earnings,regulatory,orchestrator}.py`: domain agents + used by the dashboard + - `api/server.py`: FastAPI server + - `collectors/`: SEC and news collectors + - `pipeline/{ingest,embeddings}.py`: chunk and index documents into + `txtai.Embeddings` + - `storage/`: KeyDB / PostgreSQL+pgvector / MinIO clients + - `ui/`: Streamlit pages — `research.py`, `dashboard.py`, `chat.py` + +- `scripts/`: one-shot CLIs (`run_sec_collector`, 
`run_sec_bulk`, + `backfill_txtai_from_chunks`, `eval_research`, `check_storage_status`) + +- `notebooks/`: paired Jupyter / jupytext notebooks +- `sql/init.sql`: PostgreSQL schema +- `data/`: persisted txtai index (gitignored, regenerable) +- `docs/architecture.excalidraw`: editable architecture diagram source +- `RUN_INSTRUCTIONS.md`: deep operational guide and troubleshooting + +# Tests and Eval + +- Unit tests + ```bash + > docker-compose exec api python -m pytest app/agents/test app/pipeline/test -v + ``` + +- End-to-end smoke (ingest + one query) + ```bash + > ./scripts/smoke_test.sh + ``` + +- Latency and retrieval metrics + ```bash + > docker-compose exec api python -m scripts.eval_research --warmup + > docker-compose exec api python -m scripts.eval_research --repeats 5 --json logs/eval.json + ``` +- Reports per-stage p50/p95/p99 latency, routing accuracy on a benchmark + set, and retrieval health (chunks/query, chunk-empty rate) + +# Troubleshooting + +- See `RUN_INSTRUCTIONS.md` for SEC rate limits, missing OpenAI key, + pgvector extension issues, and other common problems diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/RUN_INSTRUCTIONS.md b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/RUN_INSTRUCTIONS.md new file mode 100644 index 000000000..063a6655f --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/RUN_INSTRUCTIONS.md @@ -0,0 +1,316 @@ +# Running the txtai Market Research Platform + +End-to-end this project does three things: + +1. **Collect** SEC filings (EDGAR) and news articles (NewsAPI + Alpha Vantage) + into a four-tier store: KeyDB (hot cache), MinIO (cold archive), + PostgreSQL + pgvector (warm structured store), and a txtai embeddings + index (search). +2. 
**Search** the index through an agentic pipeline — a router picks + sub-agents (`sec`, `news`), each retrieves the top-k chunks, and a + synthesizer writes a cited answer. +3. **Serve** the pipeline as a FastAPI service with a Streamlit UI on top. + +## Quickstart for new users + +```bash +# 1. Clone and enter the repo +git clone <repo-url> +cd class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research + +# 2. Configure secrets +cp .env.example .env +# Edit .env to add NEWSAPI_KEY (https://newsapi.org/register) +# and ALPHAVANTAGE_API_KEY (https://www.alphavantage.co/support/#api-key) +# Replace SEC_USER_AGENT email with yours + +# 3. Bring up storage tiers (KeyDB, MinIO, Postgres + pgvector) +docker-compose up -d + +# 4. Install Python deps (Python 3.11+) +python -m venv .venv && source .venv/bin/activate +pip install -r requirements.txt + +# 5. Collect data (one-time; ~50 min for the full ticker set) +python -m scripts.run_sec_bulk --group all --skip-existing --limit 10 +python -m scripts.run_all_collectors --tickers AAPL,MSFT,NVDA --skip-sec --no-search +python -m scripts.backfill_txtai_from_chunks --from-scratch + +# 6. Start the API + UI +uvicorn app.api.server:app --host 127.0.0.1 --port 8000 & +streamlit run app/ui/research.py --server.port 8501 +# Browse to http://localhost:8501 +``` + +## Agent / API / UI + +- **Agent core**: `app/agents/research_agent.py` — `run_research(query)` + yields streaming events; `run_research_sync(query)` returns a single dict. +- **FastAPI**: `app/api/server.py` + - `GET /` — health probe + - `POST /research` — sync request, returns JSON (answer + sources + timings) + - `POST /research/stream` — Server-Sent Events, one event per pipeline + step (route → retrieve → synthesize → done) +- **Streamlit UI**: `app/ui/research.py` — shows the agent's live trace + while it runs, then collapses it into an expander and renders the clean + answer + sources. 
+ +### Optional: enable LLM-backed answer synthesis + +With no LLM configured, the synthesizer falls back to an extractive template +(first 1-2 sentences of the top three chunks). Set these env vars on the API +server to use any OpenAI-compatible endpoint: + +```bash +export LLM_BASE_URL=http://localhost:11434/v1 # or https://api.openai.com/v1 +export LLM_API_KEY=sk-... # any value for local Ollama +export LLM_MODEL=qwen2.5:3b # or gpt-4o-mini +uvicorn app.api.server:app --host 127.0.0.1 --port 8000 +``` + +## Eval harness + +```bash +python -m scripts.eval_research --warmup +python -m scripts.eval_research --repeats 5 --json logs/eval.json +``` + +Prints p50/p95/p99 latency per pipeline stage, routing accuracy on a +benchmark set, and retrieval health metrics. + +--- + +# Running the SEC EDGAR Collector + +This guide explains how to run the SEC EDGAR collector to fetch and store filings. + +## Prerequisites + +### 1. Start Infrastructure Services + +First, start all required services (KeyDB, PostgreSQL, MinIO): + +```bash +docker-compose up -d +``` + +Verify services are running: + +```bash +docker-compose ps +``` + +You should see: +- `keydb` - Hot tier cache (port 6379) +- `postgres` - Warm tier database with pgvector (port 5432) +- `minio` - Cold tier object storage (ports 9000, 9001) + +### 2.
Configure Environment + +Copy the example environment file and configure: + +```bash +cp .env.example .env +``` + +Edit `.env` with your settings: + +```bash +# KeyDB Configuration (Hot Tier) +KEYDB_HOST=localhost +KEYDB_PORT=6379 +KEYDB_PASSWORD= + +# MinIO Configuration (Cold Tier) +MINIO_ENDPOINT=localhost:9000 +MINIO_ACCESS_KEY=minioadmin +MINIO_SECRET_KEY=minioadmin +MINIO_SECURE=false + +# PostgreSQL Configuration (Warm Tier) +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_DB=financial_kb +POSTGRES_USER=fin +POSTGRES_PASSWORD=fin_local + +# OpenAI API Key (required for embeddings) +OPENAI_API_KEY=sk-your-api-key-here + +# SEC EDGAR User Agent (required by SEC API) +SEC_USER_AGENT=txtai-market-research (your@email.com) +``` + +### 3. Install Dependencies + +```bash +pip install -r requirements.txt +``` + +### 4. Initialize Database + +The database is automatically initialized when PostgreSQL starts via the `sql/init.sql` script mounted in `docker-compose.yml`. + +To manually verify initialization: + +```bash +docker-compose exec postgres psql -U fin -d financial_kb -c "\dt" +``` + +You should see tables: `companies`, `filings`, `chunks`, `xbrl_facts`, `articles`, `collection_runs` + +## Running the SEC Collector + +### Basic Usage + +Fetch SEC filings for Apple (AAPL): + +```bash +python -m scripts.run_sec_collector --ticker AAPL +``` + +### Command Line Options + +``` +usage: run_sec_collector.py [-h] [-t TICKER] [-f FILING_TYPES] [-l LIMIT] + [--no-cold] [--no-warm] [--no-search] + [--use-cache] [-v] + +options: + -h, --help show this help message and exit + -t, --ticker TICKER Stock ticker symbol (default: AAPL) + -f, --filing-types Comma-separated filing types (default: 10-K,8-K,DEF 14A) + -l, --limit Maximum filings per type (default: 20) + --no-cold Skip cold storage (MinIO) + --no-warm Skip warm storage (PostgreSQL) + --no-search Skip search index (txtai) + --use-cache Use cached results if available + -v, --verbose Enable debug logging +``` + 
+### Examples + +**Fetch only 10-K filings for Tesla:** + +```bash +python -m scripts.run_sec_collector -t TSLA -f 10-K -l 5 +``` + +**Fetch multiple filing types for Microsoft:** + +```bash +python -m scripts.run_sec_collector -t MSFT -f "10-K,10-Q,8-K" -l 10 +``` + +**Skip search indexing (faster, just archive):** + +```bash +python -m scripts.run_sec_collector -t GOOGL --no-search +``` + +**Enable verbose logging for debugging:** + +```bash +python -m scripts.run_sec_collector -t AAPL -v +``` + +## Verifying Collection + +### Check MinIO (Cold Storage) + +Access MinIO console at http://localhost:9001 with credentials: +- Username: `minioadmin` +- Password: `minioadmin` + +Browse to the `filings` bucket to see stored SEC filings. + +### Check PostgreSQL (Warm Storage) + +Connect to PostgreSQL and query: + +```bash +docker-compose exec postgres psql -U fin -d financial_kb -c "SELECT ticker, filing_type, filing_date FROM filings ORDER BY filing_date DESC LIMIT 10;" +``` + +### Check Search Index + +Run the example notebook to verify search functionality end-to-end: + +```bash +jupyter lab notebooks/txtai.example.ipynb +``` + +Or do an isolated txtai-API tour without the storage tiers: + +```bash +jupyter lab notebooks/txtai.API.ipynb +``` + +## Troubleshooting + +### Connection Errors + +If you see connection errors: + +1. Verify Docker containers are running: + ```bash + docker-compose ps + ``` + +2. Check service logs: + ```bash + docker-compose logs postgres + docker-compose logs minio + docker-compose logs keydb + ``` + +### SEC API Rate Limiting + +The SEC API may rate-limit requests. If this happens: + +1. Ensure you have a valid `SEC_USER_AGENT` in `.env` +2. Reduce the `--limit` parameter +3. Wait a few minutes between requests + +### Missing OpenAI API Key + +Embeddings require an OpenAI API key. Set in `.env`: + +```bash +OPENAI_API_KEY=sk-... 
+``` + +### pgvector Extension Not Found + +If pgvector is not enabled: + +```bash +docker-compose down +docker volume rm <project>_pgdata # find the exact volume name with: docker volume ls +docker-compose up -d postgres +``` + +This recreates the PostgreSQL volume and reinitializes with pgvector. + +## Architecture Overview + +``` +SEC EDGAR API + │ + ▼ +┌─────────────────┐ +│ SECCollector │ +└────────┬────────┘ + │ + ┌────┴────┬─────────────┬──────────────┐ + ▼ ▼ ▼ ▼ +┌────────┐ ┌──────────┐ ┌─────────┐ ┌──────────┐ +│ MinIO │ │PostgreSQL│ │ txtai │ │ KeyDB │ +│ (Cold) │ │ (Warm) │ │ (Search)│ │ (Hot) │ +└────────┘ └──────────┘ └─────────┘ └──────────┘ +``` + +- **Cold (MinIO)**: Raw HTML/XML filings archived +- **Warm (PostgreSQL)**: Structured metadata, chunks with embeddings +- **Search (txtai)**: Semantic search index +- **Hot (KeyDB)**: API response caching diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/__init__.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/__init__.py new file mode 100644 index 000000000..30cd9de5a --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/__init__.py @@ -0,0 +1 @@ +# txtai Market Research Platform diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/__init__.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/__init__.py new file mode 100644 index 000000000..5413ea9e4 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/__init__.py @@ -0,0 +1 @@ +# txtai Agents for market research diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/diligence.py
b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/diligence.py new file mode 100644 index 000000000..78feb4643 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/diligence.py @@ -0,0 +1,213 @@ +""" +Due Diligence Agent for M&A analysis. + +This agent specializes in: +- Risk identification and flagging +- Synergy opportunity assessment +- Red-flag language pattern detection vs historical deal failures +""" + +import logging +from typing import Any + +from txtai import LLM + +from app.pipeline.embeddings import search + +_LOG = logging.getLogger(__name__) + + +# System prompt for the diligence agent +SYSTEM_PROMPT = """You are an M&A analyst specializing in risk identification and synergy assessment. + +Your role is to: +1. Identify potential risks in acquisition targets or deals +2. Assess synergy opportunities (cost, revenue, technology) +3. Detect red-flag language patterns that resemble historical deal failures +4. Provide actionable due diligence findings + +Focus on: +- Financial risks (debt levels, cash flow issues, accounting irregularities) +- Operational risks (customer concentration, supply chain, key person risk) +- Legal/regulatory risks (pending litigation, antitrust concerns) +- Cultural risks (integration challenges, talent retention) + +Be specific and cite evidence from SEC filings and earnings transcripts. +""" + + +def run(query: str, context: dict | None = None) -> dict[str, Any]: + """ + Run due diligence analysis for a query. + + Args: + query: The user's question (e.g., "What are the risks in acquiring Company X?") + context: Optional context dict with: + - ticker: Target company ticker + - acquirer_ticker: Acquiring company ticker (if applicable) + - deal_type: Type of deal (merger, acquisition, etc.) 
+ + Returns: + Dict with: + - risk_flags: List of identified risks with severity levels + - synergies: List of potential synergy opportunities + - red_flags: Language patterns matching historical deal failures + - financial_health: Assessment of financial stability + - summary: Natural language due diligence summary + """ + if context is None: + context = {} + + ticker = context.get("ticker", "") + acquirer = context.get("acquirer_ticker", "") + + # Search SEC filings and earnings transcripts + sec_results = search(query, source_filter="sec", limit=15) + earnings_results = search( + "earnings call transcript", source_filter="earnings", limit=10 + ) + + all_results = sec_results + earnings_results + + if not all_results: + return { + "risk_flags": [], + "synergies": [], + "red_flags": [], + "financial_health": "No data available", + "summary": f"No due diligence data available for {ticker}" + if ticker + else "No data available.", + } + + # Build context for LLM + context_text = _build_context_string(all_results) + + # Use LLM for analysis + llm = _get_llm() + + prompt = f"""{SYSTEM_PROMPT} + +Query: {query} +Target: {ticker} +Acquirer: {acquirer} + +Relevant documents from SEC filings and earnings calls: +{context_text} + +Provide your due diligence analysis in this exact JSON format: +{{ + "risk_flags": [ + {{"category": "financial|operational|legal|cultural", "severity": "high|medium|low", "description": "...", "evidence": "quote or reference"}} + ], + "synergies": [ + {{"type": "cost|revenue|technology|talent", "description": "...", "estimated_impact": "..."}} + ], + "red_flags": [ + {{"pattern": "...", "similarity_to_historical_failures": "high|medium|low", "explanation": "..."}} + ], + "financial_health": "assessment summary", + "summary": "2-3 sentence due diligence summary" +}} + +Be specific and cite filing dates or earnings call quarters where possible. 
+""" + + response = llm(prompt) + + # Parse the response + result = _parse_response(response, all_results) + + return result + + +def _build_context_string(results: list[dict]) -> str: + """Build formatted context string from search results.""" + lines = [] + + for i, result in enumerate(results, 1): + metadata = result.get("metadata", {}) + source = metadata.get("source", "unknown") + form_type = metadata.get("form_type", "") + filing_date = metadata.get("filing_date", "") + text = result.get("text", "")[:600] + + lines.append(f"[{i}] [{source}] {form_type} - Filed: {filing_date}") + lines.append(f" Content: {text}") + lines.append("") + + return "\n".join(lines) + + +def _get_llm() -> LLM: + """Get configured LLM using Ollama.""" + import os + + ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434") + ollama_model = os.getenv("OLLAMA_LLM_MODEL", "qwen2.5:7b") + + # Try Ollama first + try: + return LLM(model=f"ollama:{ollama_model}@{ollama_host}") + except Exception: + # Fallback to OpenAI + api_key = os.getenv("OPENAI_API_KEY") + if api_key: + return LLM(model="openai:gpt-4o-mini", api_key=api_key) + else: + return LLM(model="mistralai/Mistral-7B-Instruct-v0.2") + + +def _parse_response(response: str, results: list[dict]) -> dict[str, Any]: + """Parse LLM response into structured due diligence analysis.""" + import json + import re + + json_match = re.search(r"\{.*\}", response, re.DOTALL) + + if json_match: + try: + parsed = json.loads(json_match.group()) + return { + "risk_flags": parsed.get("risk_flags", []), + "synergies": parsed.get("synergies", []), + "red_flags": parsed.get("red_flags", []), + "financial_health": parsed.get("financial_health", ""), + "summary": parsed.get("summary", response), + } + except json.JSONDecodeError: + pass + + # Fallback: return basic structure + return { + "risk_flags": [ + { + "category": "unknown", + "severity": "medium", + "description": "Analysis unavailable", + "evidence": "LLM parsing failed", + } + ], + 
"synergies": [], + "red_flags": [], + "financial_health": "Unable to assess", + "summary": response[:500], + } + + +if __name__ == "__main__": + # Test the agent. + from dotenv import load_dotenv + + load_dotenv() + logging.basicConfig(level=logging.INFO) + result = run("What are the key risks for AAPL?", context={"ticker": "AAPL"}) + _LOG.info("Risk flags: %d", len(result["risk_flags"])) + for risk in result["risk_flags"][:3]: + _LOG.info( + " - [%s] %s: %s", + risk["severity"], + risk["category"], + risk["description"][:100], + ) + _LOG.info("Summary: %s", result["summary"]) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/earnings.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/earnings.py new file mode 100644 index 000000000..73ed31f44 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/earnings.py @@ -0,0 +1,200 @@ +""" +Earnings Agent for earnings call and KPI analysis. + +This agent specializes in: +- KPI trend analysis +- Guidance language changes quarter-over-quarter +- Management tone shift detection +- Earnings beat/miss analysis +""" + +import logging +from typing import Any + +from txtai import LLM + +from app.pipeline.embeddings import search + +_LOG = logging.getLogger(__name__) + + +# System prompt for the earnings agent +SYSTEM_PROMPT = """You are a financial analyst specializing in earnings call analysis. + +Your role is to: +1. Analyze KPI trends quarter-over-quarter +2. Detect changes in guidance language vs prior quarters +3. Assess management tone shifts (confidence, caution, optimism) +4. 
Identify earnings beats/misses and analyst reactions + +Focus on: +- Revenue and EPS vs expectations +- Forward guidance changes +- Management commentary on headwinds/tailwinds +- Analyst questions and concerns +- Segment-level performance + +Be specific and cite exact language from earnings transcripts. +""" + + +def run(query: str, context: dict | None = None) -> dict[str, Any]: + """ + Run earnings analysis for a query. + + Args: + query: The user's question (e.g., "How did AAPL's earnings go?") + context: Optional context dict with: + - ticker: Company ticker + - quarter: Specific quarter (e.g., "Q4 2024") + - compare_prior: Whether to compare to prior quarter + + Returns: + Dict with: + - kpi_trends: List of KPI changes quarter-over-quarter + - guidance_changes: Changes in forward guidance language + - management_tone: Assessment of management tone + - earnings_summary: Beat/miss analysis + - summary: Natural language earnings summary + """ + if context is None: + context = {} + + ticker = context.get("ticker", "") + quarter = context.get("quarter", "") + + # Search earnings transcripts + earnings_query = f"{ticker} earnings call" if ticker else query + earnings_results = search(earnings_query, source_filter="earnings", limit=15) + + # Also search news for earnings coverage + news_results = search(f"{ticker} earnings", source_filter="news", limit=5) + + all_results = earnings_results + news_results + + if not all_results: + return { + "kpi_trends": [], + "guidance_changes": [], + "management_tone": "No data available", + "earnings_summary": "No earnings data available", + "summary": f"No earnings data available for {ticker}" + if ticker + else "No data available.", + } + + # Build context for LLM + context_text = _build_context_string(all_results) + + # Use LLM for analysis + llm = _get_llm() + + prompt = f"""{SYSTEM_PROMPT} + +Query: {query} +Company: {ticker} +Quarter: {quarter if quarter else "Most recent"} + +Relevant documents from earnings transcripts 
and news: +{context_text} + +Provide your earnings analysis in this exact JSON format: +{{ + "kpi_trends": [ + {{"metric": "...", "current_value": "...", "prior_value": "...", "change": "...", "direction": "improving|declining|stable"}} + ], + "guidance_changes": [ + {{"area": "...", "prior_guidance": "...", "current_guidance": "...", "change_type": "raised|lowered|unchanged|narrowed|widened"}} + ], + "management_tone": "description of tone (confident/cautious/optimistic/concerned) with evidence", + "earnings_summary": "beat/miss analysis with specific numbers", + "summary": "2-3 sentence earnings summary" +}} + +Be specific and cite exact language from the transcript where possible. +""" + + response = llm(prompt) + + # Parse the response + result = _parse_response(response, all_results) + + return result + + +def _build_context_string(results: list[dict]) -> str: + """Build formatted context string from search results.""" + lines = [] + + for i, result in enumerate(results, 1): + metadata = result.get("metadata", {}) + source = metadata.get("source", "unknown") + title = metadata.get("title", "") + text = result.get("text", "")[:600] + + lines.append(f"[{i}] [{source}] {title}") + lines.append(f" Content: {text}") + lines.append("") + + return "\n".join(lines) + + +def _get_llm() -> LLM: + """Get configured LLM using Ollama.""" + import os + + ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434") + ollama_model = os.getenv("OLLAMA_LLM_MODEL", "qwen2.5:7b") + + # Try Ollama first + try: + return LLM(model=f"ollama:{ollama_model}@{ollama_host}") + except Exception: + # Fallback to OpenAI + api_key = os.getenv("OPENAI_API_KEY") + if api_key: + return LLM(model="openai:gpt-4o-mini", api_key=api_key) + else: + return LLM(model="mistralai/Mistral-7B-Instruct-v0.2") + + +def _parse_response(response: str, results: list[dict]) -> dict[str, Any]: + """Parse LLM response into structured earnings analysis.""" + import json + import re + + json_match = 
re.search(r"\{.*\}", response, re.DOTALL) + + if json_match: + try: + parsed = json.loads(json_match.group()) + return { + "kpi_trends": parsed.get("kpi_trends", []), + "guidance_changes": parsed.get("guidance_changes", []), + "management_tone": parsed.get("management_tone", ""), + "earnings_summary": parsed.get("earnings_summary", ""), + "summary": parsed.get("summary", response), + } + except json.JSONDecodeError: + pass + + # Fallback: return basic structure + return { + "kpi_trends": [], + "guidance_changes": [], + "management_tone": "Unable to assess", + "earnings_summary": "Data unavailable", + "summary": response[:500], + } + + +if __name__ == "__main__": + # Test the agent. + from dotenv import load_dotenv + + load_dotenv() + logging.basicConfig(level=logging.INFO) + result = run("How did AAPL's latest earnings go?", context={"ticker": "AAPL"}) + _LOG.info("Earnings summary: %s", result["earnings_summary"]) + _LOG.info("KPI trends: %d", len(result["kpi_trends"])) + _LOG.info("Summary: %s", result["summary"]) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/orchestrator.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/orchestrator.py new file mode 100644 index 000000000..39c117d10 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/orchestrator.py @@ -0,0 +1,212 @@ +""" +Orchestrator Agent for txtai market research platform. + +This is the main routing agent that: +1. Receives user queries +2. Classifies intent (diligence / earnings / regulatory) +3. Calls relevant sub-agents via txtai Agent tool calling +4. 
Merges and ranks results before returning to UI +""" + +import logging +import os +from typing import Any + +from txtai import Agent, LLM + +from app.agents.diligence import run as run_diligence +from app.agents.earnings import run as run_earnings +from app.agents.regulatory import run as run_regulatory + +_LOG = logging.getLogger(__name__) + + +# System prompt for the orchestrator +SYSTEM_PROMPT = """You are the Orchestrator for a market research platform. + +Your role is to: +1. Analyze the user's query to determine their intent +2. Route to the appropriate specialist agent(s): + - DILIGENCE: M&A due diligence, risk identification, synergy assessment + - EARNINGS: Earnings call analysis, KPI trends, guidance changes, management tone + - REGULATORY: SEC filings, regulatory risk, compliance issues, enforcement signals + +3. Merge and synthesize results from multiple agents when needed +4. Provide clear, actionable insights with source citations + +Always be specific and cite your sources. If multiple agents provide relevant +information, synthesize their findings into a coherent response. +""" + + +def create_agent() -> Agent: + """ + Create and configure the orchestrator Agent. + + Uses txtai's Agent class with: + - Ollama qwen2.5:7b as primary LLM + - OpenAI gpt-4o-mini as fallback + - Tool bindings for each sub-agent + """ + # Configure LLM with fallback + llm = get_llm() + + # Define tools for each sub-agent + # txtai Agent tools are functions that the LLM can call + tools = { + "due_diligence": { + "function": run_diligence, + "description": ( + "Perform M&A due diligence analysis. Use for questions about " + "risks, synergies, deal assessment, or acquisition targets." + ), + }, + "earnings_analysis": { + "function": run_earnings, + "description": ( + "Analyze earnings calls and KPIs. Use for questions about " + "earnings results, guidance changes, management tone, or " + "financial metrics." 
+ ), + }, + "regulatory_analysis": { + "function": run_regulatory, + "description": ( + "Analyze SEC filings and regulatory risk. Use for questions " + "about compliance, enforcement, regulatory issues, or filing " + "disclosures." + ), + }, + } + + # Create the txtai Agent + # The Agent class handles: + # - LLM invocation + # - Tool routing based on LLM decisions + # - Response synthesis + agent = Agent( + llm=llm, + tools=tools, + system_prompt=SYSTEM_PROMPT, + ) + + return agent + + +def get_llm() -> LLM: + """ + Get configured LLM using Ollama. + + Primary: Ollama qwen2.5:7b (or custom model via OLLAMA_LLM_MODEL) + Fallback: OpenAI gpt-4o-mini (if Ollama not configured) + """ + ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434") + ollama_model = os.getenv("OLLAMA_LLM_MODEL", "qwen2.5:7b") + + # Try Ollama first + try: + return LLM( + model=f"ollama:{ollama_model}@{ollama_host}", + ) + except Exception: + # Fallback to OpenAI if Ollama is not available + api_key = os.getenv("OPENAI_API_KEY") + if api_key: + return LLM( + model="openai:gpt-4o-mini", + api_key=api_key, + ) + else: + # Final fallback to HuggingFace + return LLM( + model="mistralai/Mistral-7B-Instruct-v0.2", + ) + + +# Global singleton agent instance +_orchestrator: Agent | None = None + + +def get_orchestrator() -> Agent: + """Get or create the singleton orchestrator agent.""" + global _orchestrator + if _orchestrator is None: + _orchestrator = create_agent() + return _orchestrator + + +def run(query: str, context: dict | None = None) -> dict[str, Any]: + """ + Run the orchestrator with a user query. 
+ + Args: + query: The user's question or request + context: Optional context dict with: + - ticker: Stock ticker symbol (e.g., "AAPL") + - time_range: Time range for analysis (e.g., "last 30 days") + + Returns: + Dict with: + - response: The synthesized response text + - sources: List of source documents cited + - agents_used: List of sub-agents that were called + - confidence: Confidence score (0-1) in the response + """ + if context is None: + context = {} + + # Add context to the query + ticker = context.get("ticker", "") + if ticker: + enhanced_query = f"For {ticker}: {query}" + else: + enhanced_query = query + + # Run the orchestrator agent + # txtai's Agent.__call__ handles the full workflow: + # 1. LLM analyzes query and decides which tools to call + # 2. Tools are executed with LLM-provided arguments + # 3. LLM synthesizes final response from tool results + agent = get_orchestrator() + response = agent(enhanced_query) + + # Parse the response to extract metadata + # txtai returns the raw LLM response, we need to structure it + result = { + "response": response if isinstance(response, str) else str(response), + "sources": [], + "agents_used": [], + "confidence": 0.8, # Default confidence + } + + # Try to extract sources from the response + # Sub-agents should include source markers in their output + if "[Sources:" in result["response"]: + sources_section = result["response"].split("[Sources:")[1].split("]")[0] + result["sources"] = [s.strip() for s in sources_section.split(",") if s.strip()] + + # Detect which agents were used based on response content + agent_indicators = { + "diligence": ["risk", "synergy", "due diligence", "M&A"], + "earnings": ["earnings", "KPI", "guidance", "revenue"], + "regulatory": ["SEC", "filing", "regulatory", "compliance"], + } + + response_lower = result["response"].lower() + for agent_name, indicators in agent_indicators.items(): + if any(ind in response_lower for ind in indicators): + result["agents_used"].append(agent_name) 
+ + return result + + +if __name__ == "__main__": + # Test the orchestrator. + from dotenv import load_dotenv + + load_dotenv() + logging.basicConfig(level=logging.INFO) + result = run("What are the key risks for AAPL?", context={"ticker": "AAPL"}) + _LOG.info("Response: %s", result["response"]) + _LOG.info("Agents used: %s", result["agents_used"]) + _LOG.info("Sources: %s", result["sources"]) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/regulatory.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/regulatory.py new file mode 100644 index 000000000..d72e31815 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/regulatory.py @@ -0,0 +1,205 @@ +""" +Regulatory Agent for SEC and regulatory filing analysis. + +This agent specializes in: +- Flagged filing identification +- Enforcement risk signal detection +- Material disclosure language analysis +- Compliance monitoring +""" + +import logging +from typing import Any + +from txtai import LLM + +from app.pipeline.embeddings import search + +_LOG = logging.getLogger(__name__) + + +# System prompt for the regulatory agent +SYSTEM_PROMPT = """You are a compliance analyst monitoring regulatory and policy risk. + +Your role is to: +1. Identify flagged filings that warrant attention +2. Detect enforcement risk signals +3. Analyze material disclosure language +4. Monitor regulatory and policy changes + +Focus on: +- SEC comment letters and responses +- Restatements and amendments +- Risk factor changes (Item 1A in 10-K) +- Legal proceedings disclosures +- Related party transactions +- Going concern qualifications + +Be specific and cite filing types, dates, and exact disclosure language. +""" + + +def run(query: str, context: dict | None = None) -> dict[str, Any]: + """ + Run regulatory analysis for a query. 
+ + Args: + query: The user's question (e.g., "Any regulatory issues with AAPL?") + context: Optional context dict with: + - ticker: Company ticker + - filing_type: Specific filing type to focus on + - risk_areas: Specific risk areas to investigate + + Returns: + Dict with: + - flagged_filings: List of filings that warrant attention + - enforcement_signals: Signs of potential enforcement action + - disclosure_changes: Material changes in disclosure language + - risk_assessment: Overall regulatory risk level + - summary: Natural language regulatory summary + """ + if context is None: + context = {} + + ticker = context.get("ticker", "") + filing_type = context.get("filing_type", "") + + # Search SEC filings + sec_query = f"{ticker} SEC filing" if ticker else query + sec_results = search(sec_query, source_filter="sec", limit=15) + + if not sec_results: + return { + "flagged_filings": [], + "enforcement_signals": [], + "disclosure_changes": [], + "risk_assessment": "No data available", + "summary": f"No regulatory data available for {ticker}" + if ticker + else "No data available.", + } + + # Build context for LLM + context_text = _build_context_string(sec_results) + + # Use LLM for analysis + llm = _get_llm() + + prompt = f"""{SYSTEM_PROMPT} + +Query: {query} +Company: {ticker} +Filing Type Focus: {filing_type if filing_type else "All types"} + +Relevant SEC filings: +{context_text} + +Provide your regulatory analysis in this exact JSON format: +{{ + "flagged_filings": [ + {{"filing_type": "...", "filing_date": "...", "issue": "...", "severity": "high|medium|low", "reference": "specific section or item"}} + ], + "enforcement_signals": [ + {{"signal": "...", "type": "comment_letter|restatement|investigation|warning", "description": "...", "severity": "high|medium|low"}} + ], + "disclosure_changes": [ + {{"area": "...", "change_type": "added|removed|modified", "prior_language": "...", "current_language": "...", "implication": "..."}} + ], + "risk_assessment": "overall 
regulatory risk level (low/medium/high) with explanation", + "summary": "2-3 sentence regulatory summary" +}} + +Be specific and cite filing dates, form types, and exact disclosure language. +Flag any unusual patterns or changes from standard disclosure practices. +""" + + response = llm(prompt) + + # Parse the response + result = _parse_response(response, sec_results) + + return result + + +def _build_context_string(results: list[dict]) -> str: + """Build formatted context string from search results.""" + lines = [] + + for i, result in enumerate(results, 1): + metadata = result.get("metadata", {}) + form_type = metadata.get("form_type", "") + company = metadata.get("company_name", "") + filing_date = metadata.get("filing_date", "") + items = metadata.get("items", []) + text = result.get("text", "")[:600] + + lines.append(f"[{i}] {company} - {form_type}") + lines.append(f" Filed: {filing_date}") + if items: + lines.append(f" Items: {', '.join(items)}") + lines.append(f" Content: {text}") + lines.append("") + + return "\n".join(lines) + + +def _get_llm() -> LLM: + """Get configured LLM using Ollama.""" + import os + + ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434") + ollama_model = os.getenv("OLLAMA_LLM_MODEL", "qwen2.5:7b") + + # Try Ollama first + try: + return LLM(model=f"ollama:{ollama_model}@{ollama_host}") + except Exception: + # Fallback to OpenAI + api_key = os.getenv("OPENAI_API_KEY") + if api_key: + return LLM(model="openai:gpt-4o-mini", api_key=api_key) + else: + return LLM(model="mistralai/Mistral-7B-Instruct-v0.2") + + +def _parse_response(response: str, results: list[dict]) -> dict[str, Any]: + """Parse LLM response into structured regulatory analysis.""" + import json + import re + + json_match = re.search(r"\{.*\}", response, re.DOTALL) + + if json_match: + try: + parsed = json.loads(json_match.group()) + return { + "flagged_filings": parsed.get("flagged_filings", []), + "enforcement_signals": parsed.get("enforcement_signals", []), 
+ "disclosure_changes": parsed.get("disclosure_changes", []), + "risk_assessment": parsed.get("risk_assessment", ""), + "summary": parsed.get("summary", response), + } + except json.JSONDecodeError: + pass + + # Fallback: return basic structure + return { + "flagged_filings": [], + "enforcement_signals": [], + "disclosure_changes": [], + "risk_assessment": "Unable to assess", + "summary": response[:500], + } + + +if __name__ == "__main__": + # Test the agent. + from dotenv import load_dotenv + + load_dotenv() + logging.basicConfig(level=logging.INFO) + result = run("Any regulatory issues for AAPL?", context={"ticker": "AAPL"}) + _LOG.info("Flagged filings: %d", len(result["flagged_filings"])) + _LOG.info("Enforcement signals: %d", len(result["enforcement_signals"])) + _LOG.info("Risk assessment: %s", result["risk_assessment"]) + _LOG.info("Summary: %s", result["summary"]) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/research_agent.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/research_agent.py new file mode 100644 index 000000000..e9dfe8cc4 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/research_agent.py @@ -0,0 +1,627 @@ +""" +Agentic search prototype. + +Pipeline: + 1. Router - inspects the user query, decides which sub-agents to fire and + extracts a ticker if one is mentioned. + 2. Sub-agents - SEC and News. Each runs a filtered semantic search against + the txtai index (and optionally narrows by ticker). + 3. Synthesizer - formats the retrieved chunks into a natural-language + answer with inline citations. + +Each step is a generator that yields events so the FastAPI / Streamlit layer +can stream progress to the user. 
+""" + +import logging +import os +import re +import time +from typing import Any, Iterator, Optional, TypedDict + +from app.pipeline.embeddings import search + +_LOG = logging.getLogger(__name__) + +# ############################################################################# +# Constants +# ############################################################################# + +# Map of common tickers we track so the router can recognise them in plain +# English (e.g. "Apple" -> "AAPL"). +_COMPANY_TO_TICKER = { + "apple": "AAPL", + "microsoft": "MSFT", + "google": "GOOGL", + "alphabet": "GOOGL", + "amazon": "AMZN", + "meta": "META", + "facebook": "META", + "tesla": "TSLA", + "nvidia": "NVDA", + "netflix": "NFLX", + "oracle": "ORCL", + "intel": "INTC", + "ibm": "IBM", + "jpmorgan": "JPM", + "bank of america": "BAC", + "wells fargo": "WFC", + "goldman sachs": "GS", + "morgan stanley": "MS", + "citi": "C", + "citigroup": "C", + "visa": "V", + "mastercard": "MA", + "berkshire": "BRK-B", + "walmart": "WMT", + "costco": "COST", + "home depot": "HD", + "johnson & johnson": "JNJ", + "pfizer": "PFE", + "merck": "MRK", + "eli lilly": "LLY", + "abbvie": "ABBV", + "exxon": "XOM", + "chevron": "CVX", + "amd": "AMD", + "qualcomm": "QCOM", + "broadcom": "AVGO", + "ford": "F", + "general motors": "GM", +} + +# Keyword -> sub-agent. Used when the LLM router is offline. Multiple agents +# can match a single query. +_SEC_KEYWORDS = { + "10-k", + "10k", + "10-q", + "10q", + "8-k", + "8k", + "filing", + "filings", + "proxy", + "def 14a", + "annual report", + "quarterly report", + "risk factor", + "sec", + "edgar", + "form", + "regulator", + "compliance", + "audit", + "auditor", +} + +_NEWS_KEYWORDS = { + "news", + "headline", + "article", + "sentiment", + "bullish", + "bearish", + "analyst", + "upgrade", + "downgrade", + "rumor", + "rumour", + "press", + "announcement", + "outlook", + "forecast", + "target price", +} + +# Maximum chunks per sub-agent fed into the synthesizer. 
+_MAX_CHUNKS_PER_AGENT = 5 + + +# ############################################################################# +# Event types +# ############################################################################# + + +class ResearchEvent(TypedDict): + """ + Streaming event emitted by `run_research`. + + :param step: which pipeline stage produced this event + ("route" | "retrieve" | "synthesize" | "done" | "error") + :param payload: stage-specific data (see individual stage docs) + """ + + step: str + payload: dict[str, Any] + + +# ############################################################################# +# Router +# ############################################################################# + + +def _extract_ticker(query: str) -> Optional[str]: + """ + Pull a ticker symbol out of a query. + + :param query: raw user query + :return: ticker (e.g. "AAPL") or None if nothing recognisable found + """ + query_lower = query.lower() + # Try a literal $TICKER mention first ("$AAPL"). + cashtag_match = re.search(r"\$([A-Z]{1,5}(?:-[A-Z])?)", query.upper()) + if cashtag_match: + return cashtag_match.group(1) + # Then look for a known company name. + for name, ticker in _COMPANY_TO_TICKER.items(): + if name in query_lower: + return ticker + # Finally, look for a bare 2-5 letter uppercase token. Skip if it's a + # common English word that happens to be all caps in the query. + bare_match = re.search(r"\b([A-Z]{2,5}(?:-[A-Z])?)\b", query) + if bare_match: + candidate = bare_match.group(1) + if candidate not in {"SEC", "CEO", "CFO", "USA", "EU", "AI", "ML"}: + return candidate + return None + + +def _route(query: str) -> dict[str, Any]: + """ + Decide which sub-agents to fire for this query. 
+ + :param query: raw user query + :return: dict with keys + - ``ticker`` : extracted ticker symbol or None + - ``agents`` : ordered list of sub-agent names that should run + - ``reason`` : short human-readable explanation of the routing + decision (shown in the UI) + """ + query_lower = query.lower() + ticker = _extract_ticker(query) + sec_hit = any(kw in query_lower for kw in _SEC_KEYWORDS) + news_hit = any(kw in query_lower for kw in _NEWS_KEYWORDS) + # Default routing: if neither bucket is explicitly mentioned, run both + # because the user probably wants a broad answer. + if sec_hit and not news_hit: + agents = ["sec"] + reason = "Question mentions SEC / filings keywords; routing to SEC agent only." + elif news_hit and not sec_hit: + agents = ["news"] + reason = ( + "Question mentions news / sentiment keywords; routing to News agent only." + ) + else: + agents = ["sec", "news"] + reason = "Generic question; routing to both SEC and News agents for coverage." + return {"ticker": ticker, "agents": agents, "reason": reason} + + +# ############################################################################# +# Sub-agents (retrieval) +# ############################################################################# + + +def _filter_by_ticker(results: list[dict], ticker: Optional[str]) -> list[dict]: + """ + Keep only chunks whose stored ticker metadata matches ``ticker``. + + :param results: raw txtai search results + :param ticker: ticker to match (case-insensitive); pass None to skip + :return: filtered list (same shape as input) + """ + if not ticker: + return results + ticker_upper = ticker.upper() + filtered = [] + for r in results: + # The metadata field can be a dict or a JSON-encoded string depending + # on how txtai stored it; handle both shapes defensively. 
+ md = r.get("metadata") or {} + if isinstance(md, str): + try: + import json + + md = json.loads(md) + except Exception: + md = {} + if md.get("ticker", "").upper() == ticker_upper: + filtered.append(r) + return filtered + + +def _run_sec_agent(query: str, ticker: Optional[str]) -> list[dict]: + """ + Retrieve top SEC chunks for the query. + + :param query: user query, used as the semantic search input + :param ticker: optional ticker to narrow the result set + :return: up to _MAX_CHUNKS_PER_AGENT chunk dicts with ``score``, ``text``, + and ``metadata`` + """ + _LOG.info("SEC agent query=%s ticker=%s", query, ticker) + raw = search(query, source_filter="sec", limit=_MAX_CHUNKS_PER_AGENT * 4) + if ticker: + narrowed = _filter_by_ticker(raw, ticker) + # If the ticker filter wipes out everything, fall back to the + # unfiltered set so the user still sees something. + if narrowed: + raw = narrowed + return raw[:_MAX_CHUNKS_PER_AGENT] + + +def _run_news_agent(query: str, ticker: Optional[str]) -> list[dict]: + """ + Retrieve top news chunks for the query. + + :param query: user query, used as the semantic search input + :param ticker: optional ticker to narrow the result set + :return: up to _MAX_CHUNKS_PER_AGENT chunk dicts + """ + _LOG.info("News agent query=%s ticker=%s", query, ticker) + raw = search(query, source_filter="news", limit=_MAX_CHUNKS_PER_AGENT * 4) + if ticker: + narrowed = _filter_by_ticker(raw, ticker) + if narrowed: + raw = narrowed + return raw[:_MAX_CHUNKS_PER_AGENT] + + +_AGENT_FUNCS = { + "sec": _run_sec_agent, + "news": _run_news_agent, +} + + +# ############################################################################# +# Synthesizer +# ############################################################################# + + +def _format_citation(idx: int, chunk: dict) -> str: + """ + Render a chunk as one bullet line for the citation list. 
+ + :param idx: 1-based citation number used in the body of the answer + :param chunk: raw txtai search result + :return: formatted markdown line + """ + md = chunk.get("metadata") or {} + if isinstance(md, str): + try: + import json + + md = json.loads(md) + except Exception: + md = {} + ticker = md.get("ticker", "?") + source = md.get("source") or chunk.get("tags") or "?" + filing_type = md.get("filing_type", "") + filing_date = md.get("filing_date") or md.get("published_at", "") + score = chunk.get("score", 0.0) + label_parts = [str(ticker), str(source)] + if filing_type: + label_parts.append(str(filing_type)) + if filing_date: + label_parts.append(str(filing_date)[:10]) + label = " | ".join(label_parts) + snippet = (chunk.get("text") or "").strip().replace("\n", " ")[:240] + return f"[{idx}] **{label}** (score={score:.3f})\n > {snippet}…" + + +def _first_sentences(text: str, n: int = 2) -> str: + """ + Pull the first ``n`` sentences out of a chunk for use in a prose answer. + + Uses a simple regex split on sentence-ending punctuation, then filters + out fragments that are too short to be real sentences (table headers, + page numbers, etc.). + + :param text: raw chunk text + :param n: maximum number of sentences to keep + :return: cleaned, joined sentences (no trailing newline) + """ + cleaned = re.sub(r"\s+", " ", text or "").strip() + parts = re.split(r"(?<=[.!?])\s+(?=[A-Z])", cleaned) + real = [p.strip() for p in parts if len(p.strip()) > 30] + if not real: + return cleaned[:300] + return " ".join(real[:n]) + + +def _synthesize_template(query: str, route_info: dict, all_chunks: list[dict]) -> str: + """ + Build a clean natural-language answer from retrieved chunks. + + Without an LLM available we use an extractive strategy: take the first + 1-2 sentences of the top three chunks and stitch them into a short + prose paragraph with inline ``[N]`` citation markers. The full source + list is rendered separately by the API/UI and is NOT included here. 
+ + :param query: original user question (kept for downstream LLM prompt + parity, not used in template output) + :param route_info: output of ``_route``; only the ticker is used + :param all_chunks: every chunk returned by every sub-agent, in the + order they should be cited + :return: short markdown paragraph with inline ``[N]`` citations + """ + _ = query # kept for symmetry with the LLM synthesizer. + if not all_chunks: + ticker = route_info.get("ticker") + target = f"about **{ticker}**" if ticker else "for that question" + agents = ", ".join(route_info.get("agents", [])) + return ( + f"I couldn't find relevant documents {target}. " + f"Searched the {agents} index. Try rephrasing or asking about " + "a different ticker (we have data for AAPL, MSFT, NVDA, JPM, " + "TSLA, and 60+ others)." + ) + # Take the strongest chunks first — they're already score-ordered per + # agent, so we just slice the top three across the merged list. + top = all_chunks[:3] + sentences = [] + for i, chunk in enumerate(top, start=1): + sent = _first_sentences(chunk.get("text") or "", n=2) + if sent: + sentences.append(f"{sent} [{i}]") + if not sentences: + return "I retrieved documents but couldn't extract a coherent summary." + return " ".join(sentences) + + +def _synthesize_with_llm( + query: str, route_info: dict, all_chunks: list[dict] +) -> Optional[str]: + """ + Try to synthesise an answer using an OpenAI-compatible LLM endpoint. + + Returns None on any error so the caller can fall back to the template. + + Reads endpoint configuration from environment: + - ``LLM_BASE_URL`` : OpenAI-compatible base URL (e.g. 
http://localhost:11434/v1) + - ``LLM_API_KEY`` : API key (use any string for local Ollama) + - ``LLM_MODEL`` : model name to use + + :param query: user query + :param route_info: output of ``_route`` + :param all_chunks: chunks to ground the answer in + :return: synthesized markdown answer, or None if the LLM call fails + """ + base_url = os.getenv("LLM_BASE_URL") + model = os.getenv("LLM_MODEL") + if not base_url or not model: + return None + api_key = os.getenv("LLM_API_KEY", "not-needed") + try: + import httpx + except ImportError: + return None + # Build the context block: each chunk numbered so the LLM can cite it. + context_lines = [] + for i, chunk in enumerate(all_chunks, start=1): + md = chunk.get("metadata") or {} + if isinstance(md, str): + import json + + try: + md = json.loads(md) + except Exception: + md = {} + ticker = md.get("ticker", "?") + src = md.get("source") or chunk.get("tags") or "?" + date = md.get("filing_date") or md.get("published_at", "") + snippet = (chunk.get("text") or "").strip()[:600] + context_lines.append(f"[{i}] ({ticker} | {src} | {date})\n{snippet}") + context_text = "\n\n".join(context_lines) + system = ( + "You are a financial research assistant. Answer the user's question using " + "ONLY the provided document excerpts. Cite each claim with [N] where N is " + "the document number. If the documents do not answer the question, say so." + ) + user = ( + f"Question: {query}\n\n" + f"Documents:\n{context_text}\n\n" + f"Write a concise answer (3-5 sentences) with inline [N] citations." 
+ ) + payload = { + "model": model, + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + "temperature": 0.2, + } + try: + with httpx.Client(timeout=60.0) as client: + resp = client.post( + f"{base_url.rstrip('/')}/chat/completions", + json=payload, + headers={"Authorization": f"Bearer {api_key}"}, + ) + resp.raise_for_status() + data = resp.json() + content = data["choices"][0]["message"]["content"].strip() + except Exception as e: + _LOG.warning("LLM synthesis failed: %s", e) + return None + # Sources are rendered separately by the caller, so we return only the + # natural-language answer (with the inline ``[N]`` citation markers). + return content + + +# ############################################################################# +# Public entry point +# ############################################################################# + + +def _chunk_to_source(idx: int, chunk: dict) -> dict[str, Any]: + """ + Project a raw chunk into a compact source record for the UI / API. 
def run_research(query: str) -> Iterator[ResearchEvent]:
    """
    Run the full agentic research pipeline as a streaming generator.

    Every emitted event is annotated with ``elapsed_ms`` (milliseconds since
    the start of the run) so a UI / eval harness can plot timing without
    needing its own clock.

    :param query: raw user question
    :yields: a sequence of ``ResearchEvent`` dicts in this order:
        - ``route``      : a "starting" marker, then the routing decision
                           (ticker, agents, reason)
        - ``retrieve``   : a "running" and a "complete" event per sub-agent,
                           the latter carrying the chunks it found
        - ``synthesize`` : a "running" marker, then the final
                           natural-language answer with its sources
        - ``done``       : terminal marker with the full timing breakdown
    """
    t_start = time.perf_counter()

    def _now_ms() -> float:
        return (time.perf_counter() - t_start) * 1000.0

    timings: dict[str, float] = {}
    # 1. Route.
    t0 = time.perf_counter()
    yield {
        "step": "route",
        "payload": {"query": query, "status": "starting", "elapsed_ms": _now_ms()},
    }
    route_info = _route(query)
    timings["route_ms"] = (time.perf_counter() - t0) * 1000.0
    yield {
        "step": "route",
        "payload": {
            "query": query,
            **route_info,
            "elapsed_ms": _now_ms(),
            "step_ms": timings["route_ms"],
        },
    }
    # 2. Retrieve from each sub-agent.
    all_chunks: list[dict] = []
    for agent_name in route_info["agents"]:
        yield {
            "step": "retrieve",
            "payload": {
                "agent": agent_name,
                "status": "running",
                "elapsed_ms": _now_ms(),
            },
        }
        t1 = time.perf_counter()
        chunks = _AGENT_FUNCS[agent_name](query, route_info.get("ticker"))
        agent_ms = (time.perf_counter() - t1) * 1000.0
        timings[f"retrieve_{agent_name}_ms"] = agent_ms
        yield {
            "step": "retrieve",
            "payload": {
                "agent": agent_name,
                "status": "complete",
                "count": len(chunks),
                "chunks": chunks,
                "elapsed_ms": _now_ms(),
                "step_ms": agent_ms,
            },
        }
        all_chunks.extend(chunks)
    # 3. Synthesize.
    yield {
        "step": "synthesize",
        "payload": {"status": "running", "elapsed_ms": _now_ms()},
    }
    t2 = time.perf_counter()
    answer = _synthesize_with_llm(query, route_info, all_chunks)
    # An empty string from the LLM is as unusable as None: in both cases the
    # template answer is what the user sees, so ``used_llm`` must be False.
    used_llm = bool(answer)
    if not used_llm:
        answer = _synthesize_template(query, route_info, all_chunks)
    timings["synthesize_ms"] = (time.perf_counter() - t2) * 1000.0
    sources = [_chunk_to_source(i, c) for i, c in enumerate(all_chunks, start=1)]
    timings["total_ms"] = _now_ms()
    yield {
        "step": "synthesize",
        "payload": {
            "status": "complete",
            "answer": answer,
            "sources": sources,
            "used_llm": used_llm,
            "chunk_count": len(all_chunks),
            "elapsed_ms": _now_ms(),
            "step_ms": timings["synthesize_ms"],
        },
    }
    yield {
        "step": "done",
        "payload": {
            "chunk_count": len(all_chunks),
            "timings": timings,
            # Keep the documented "every event carries elapsed_ms" promise.
            "elapsed_ms": _now_ms(),
        },
    }
sources, + "used_llm": used_llm, + "chunk_count": chunk_count, + "timings": timings, + } diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/test/__init__.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/test/test_research_agent.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/test/test_research_agent.py new file mode 100644 index 000000000..97f989341 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/agents/test/test_research_agent.py @@ -0,0 +1,269 @@ +""" +Unit tests for the research agent's pure functions. + +Covers: +- ``_extract_ticker``: cashtag, company name, bare uppercase token, ambiguous +- ``_route``: keyword routing for SEC / News / both +- ``_filter_by_ticker``: dict and JSON-string metadata shapes +- ``_first_sentences``: short text, multi-sentence text +- ``_format_citation``: minimal vs. full metadata +""" + +import json +import logging +import unittest + +from app.agents.research_agent import ( + _extract_ticker, + _filter_by_ticker, + _first_sentences, + _format_citation, + _route, +) + +_LOG = logging.getLogger(__name__) + + +# ############################################################################# +# Test__extract_ticker +# ############################################################################# + + +class Test__extract_ticker(unittest.TestCase): + """ + Test ticker extraction from natural-language queries. + """ + + def helper(self, query: str, expected: str | None) -> None: + """ + Run extractor and compare to expected ticker. 
+ + :param query: raw user query + :param expected: expected ticker symbol, or ``None`` + """ + # Run test. + actual = _extract_ticker(query) + # Check output. + self.assertEqual(actual, expected) + + def test1(self) -> None: + """ + Cashtag mention returns the explicit ticker. + """ + self.helper("How is $AAPL doing?", "AAPL") + + def test2(self) -> None: + """ + Company name maps to the canonical ticker. + """ + self.helper("What's the latest on apple?", "AAPL") + + def test3(self) -> None: + """ + Bare uppercase token is treated as a ticker when it is not a common + English word. + """ + self.helper("Tell me about NVDA earnings", "NVDA") + + def test4(self) -> None: + """ + Common acronyms (SEC, CEO, AI, ML) are not mistaken for tickers. + """ + self.helper("Are there any SEC issues?", None) + + def test5(self) -> None: + """ + Query with no ticker-like content returns ``None``. + """ + self.helper("how are markets today", None) + + +# ############################################################################# +# Test__route +# ############################################################################# + + +class Test__route(unittest.TestCase): + """ + Test the keyword router that chooses sub-agents. + """ + + def test1(self) -> None: + """ + SEC-only keyword routes to the SEC agent. + """ + # Run test. + out = _route("Summarize the latest 10-K filing") + # Check outputs. + self.assertEqual(out["agents"], ["sec"]) + + def test2(self) -> None: + """ + News-only keyword routes to the News agent. + """ + # Run test. + out = _route("Any bullish analyst upgrades?") + # Check outputs. + self.assertEqual(out["agents"], ["news"]) + + def test3(self) -> None: + """ + Generic question fans out to both sub-agents. + """ + # Run test. + out = _route("Tell me about Apple") + # Check outputs. + self.assertEqual(out["agents"], ["sec", "news"]) + + def test4(self) -> None: + """ + Router carries the extracted ticker through. + """ + # Run test. 
+ out = _route("Tell me about $TSLA") + # Check outputs. + self.assertEqual(out["ticker"], "TSLA") + + +# ############################################################################# +# Test__filter_by_ticker +# ############################################################################# + + +class Test__filter_by_ticker(unittest.TestCase): + """ + Test that ticker filtering handles both dict and JSON-string metadata. + """ + + def test1(self) -> None: + """ + Dict metadata is matched case-insensitively. + """ + # Prepare inputs. + results = [ + {"text": "a", "metadata": {"ticker": "AAPL"}}, + {"text": "b", "metadata": {"ticker": "MSFT"}}, + ] + # Run test. + out = _filter_by_ticker(results, "aapl") + # Check outputs. + self.assertEqual(len(out), 1) + self.assertEqual(out[0]["text"], "a") + + def test2(self) -> None: + """ + JSON-string metadata is parsed and matched. + """ + # Prepare inputs. + results = [ + {"text": "a", "metadata": json.dumps({"ticker": "AAPL"})}, + {"text": "b", "metadata": json.dumps({"ticker": "NVDA"})}, + ] + # Run test. + out = _filter_by_ticker(results, "NVDA") + # Check outputs. + self.assertEqual(len(out), 1) + self.assertEqual(out[0]["text"], "b") + + def test3(self) -> None: + """ + ``None`` ticker disables filtering and returns the input unchanged. + """ + # Prepare inputs. + results = [{"text": "a"}, {"text": "b"}] + # Run test. + out = _filter_by_ticker(results, None) + # Check outputs. + self.assertEqual(out, results) + + +# ############################################################################# +# Test__first_sentences +# ############################################################################# + + +class Test__first_sentences(unittest.TestCase): + """ + Test the simple sentence picker used by the extractive synthesizer. + """ + + def test1(self) -> None: + """ + Multi-sentence text returns the first ``n`` sentences joined. + """ + # Prepare inputs (each sentence > 30 chars to clear the fragment filter). 
+ text = ( + "Apple reported record fiscal-year revenue across services. " + "Services grew double digits versus the prior year period. " + "Hardware revenue was approximately flat compared to prior year. " + "The CEO commented at length about long-term gross margins." + ) + # Run test. + out = _first_sentences(text, n=2) + # Check outputs. + self.assertIn("Apple reported record", out) + self.assertIn("Services grew", out) + self.assertNotIn("CEO commented", out) + + def test2(self) -> None: + """ + Short single-fragment text falls back to a 300-char window. + """ + # Prepare inputs. + text = "Quarterly highlights" + # Run test. + out = _first_sentences(text) + # Check outputs. + self.assertEqual(out, "Quarterly highlights") + + +# ############################################################################# +# Test__format_citation +# ############################################################################# + + +class Test__format_citation(unittest.TestCase): + """ + Test citation rendering for source list bullets. + """ + + def test1(self) -> None: + """ + Full SEC metadata renders ticker, source, form type, and date. + """ + # Prepare inputs. + chunk = { + "score": 0.87, + "text": "Risk factors include macro uncertainty.", + "metadata": { + "ticker": "AAPL", + "source": "sec", + "filing_type": "10-K", + "filing_date": "2024-09-30", + }, + } + # Run test. + out = _format_citation(1, chunk) + # Check outputs. + self.assertIn("[1]", out) + self.assertIn("AAPL", out) + self.assertIn("sec", out) + self.assertIn("10-K", out) + self.assertIn("2024-09-30", out) + self.assertIn("0.870", out) + + def test2(self) -> None: + """ + Minimal metadata still renders a valid line. + """ + # Prepare inputs. + chunk = {"score": 0.1, "text": "x", "metadata": {}} + # Run test. + out = _format_citation(7, chunk) + # Check outputs. 
+ self.assertIn("[7]", out) + self.assertIn("0.100", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/api/__init__.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/api/server.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/api/server.py new file mode 100644 index 000000000..5f094dec2 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/api/server.py @@ -0,0 +1,142 @@ +""" +FastAPI server exposing the agentic research pipeline. + +Endpoints: + - GET / : health probe + capability summary + - POST /research : run the pipeline synchronously, return one + JSON document with the full trace + - POST /research/stream : run the pipeline as Server-Sent Events so a + UI can show each step as it happens + +Run: + uvicorn app.api.server:app --host 0.0.0.0 --port 8000 +""" + +import json +import logging +from typing import Any + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from app.agents.research_agent import run_research, run_research_sync + +_LOG = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +app = FastAPI( + title="txtai Market Research Agent", + description="Agentic search over SEC filings + news for 67 tickers.", + version="0.1.0", +) + +# Allow the Streamlit UI (running on :8501) to call us. 
def _sse_stream(query: str):
    """
    Adapt the research generator to the Server-Sent Events wire format.

    Each yielded event is encoded as ``data: <json>\\n\\n`` — one ``data:``
    line holding the JSON-serialised ``{"step": ..., "payload": ...}``
    object, followed by a blank line — so any standard SSE client (browser
    EventSource, httpx, requests with stream=True) can parse it.

    If ``run_research`` or the serialisation raises, the exception is logged
    and a final ``{"step": "error", "payload": {"message": ...}}`` event is
    emitted instead of propagating the error.

    :param query: user query passed through to ``run_research``
    :yield: ``str`` frames (not bytes), each a complete SSE event
    """
    try:
        for event in run_research(query):
            # Strip the heavyweight "chunks" payload from the wire — full
            # text would blow up the SSE buffer for large responses. We
            # send a compact preview instead; the final synthesise step
            # already references the originals via citations.
            # Shallow-copy so the event dict produced by the pipeline is not
            # mutated when "chunks" is replaced below.
            payload = dict(event["payload"])
            if "chunks" in payload:
                payload["chunks"] = [
                    {
                        "score": c.get("score"),
                        # Preview only: first 400 characters of the text.
                        "text": (c.get("text") or "")[:400],
                        "metadata": c.get("metadata"),
                        "tags": c.get("tags"),
                    }
                    for c in payload["chunks"]
                ]
            line = json.dumps({"step": event["step"], "payload": payload})
            yield f"data: {line}\n\n"
    except Exception as e:
        _LOG.exception("Stream failed")
        err = json.dumps({"step": "error", "payload": {"message": str(e)}})
        yield f"data: {err}\n\n"
+ +Each collector fetches data from a source and stores it across all storage tiers: +- Cold: MinIO for raw document archive +- Warm: PostgreSQL for structured data + pgvector +- Hot: KeyDB for caching +- Search: txtai EmbeddingsIndex for semantic search + +Usage: + from app.collectors import SECCollector, NewsCollector, EarningsCollector + + collector = SECCollector() + collector.collect("AAPL", filing_types=["10-K", "8-K"]) +""" + +from app.collectors.base_collector import BaseCollector +from app.collectors.sec_collector import SECCollector +from app.collectors.news_collector import NewsCollector +from app.collectors.earnings_collector import EarningsCollector + +__all__ = [ + "BaseCollector", + "SECCollector", + "NewsCollector", + "EarningsCollector", +] diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/base_collector.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/base_collector.py new file mode 100644 index 000000000..fdf4894fc --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/base_collector.py @@ -0,0 +1,489 @@ +""" +Base collector with common storage logic. 
+ +All collectors inherit from this base class which provides: +- Multi-tier storage (cold, warm, hot, search) +- Document chunking for embeddings +- Deduplication via deterministic IDs +- Progress tracking and logging +""" + +import hashlib +import logging +from abc import ABC, abstractmethod +from typing import Any, Optional + +from app.storage import ( + get_minio_client, + get_postgres_client, + get_cache_manager, + get_embeddings, +) +from app.storage.cold_storage.minio_client import MinIOClient +from app.storage.warm_storage.pgvector_client import PostgresClient +from app.storage.cache_manager import CacheManager + +_LOG = logging.getLogger(__name__) + + +_DEFAULT_SEPARATORS = ["\n\n", "\n", ". ", "? ", "! ", " ", ""] + + +def _split_keep_separator(text: str, separator: str) -> list[str]: + """ + Split ``text`` on ``separator`` while keeping the separator attached to + the preceding piece. Empty separator returns one-character splits. + """ + if separator == "": + return list(text) + parts = text.split(separator) + out: list[str] = [] + for i, part in enumerate(parts): + # Re-attach the separator to every piece except the last so that + # joining the chunks back yields the original text. + if i < len(parts) - 1: + out.append(part + separator) + elif part: + out.append(part) + return out + + +def _recursive_split( + text: str, separators: list[str], chunk_size: int +) -> list[str]: + """ + Recursively split ``text`` so that no piece exceeds ``chunk_size`` chars. + + Walks ``separators`` from coarse to fine. For each piece still over the + budget, recurse with the remaining (finer) separators. The final ``""`` + separator guarantees termination by falling back to a per-character split. 
+ """ + if len(text) <= chunk_size or not separators: + return [text] if text else [] + separator, rest = separators[0], separators[1:] + pieces = _split_keep_separator(text, separator) + out: list[str] = [] + for piece in pieces: + if len(piece) <= chunk_size: + out.append(piece) + else: + out.extend(_recursive_split(piece, rest, chunk_size)) + return out + + +def _merge_with_overlap( + pieces: list[str], chunk_size: int, overlap: int +) -> list[str]: + """ + Greedily concatenate ``pieces`` up to ``chunk_size``, carrying the trailing + ``overlap`` chars from each emitted chunk into the next one. + """ + chunks: list[str] = [] + current = "" + for piece in pieces: + if not piece: + continue + if len(current) + len(piece) <= chunk_size: + current += piece + continue + if current: + chunks.append(current.strip()) + tail = current[-overlap:] if overlap > 0 else "" + current = tail + piece + else: + # Single piece exceeds the budget on its own — keep as-is. + chunks.append(piece.strip()) + current = "" + if current.strip(): + chunks.append(current.strip()) + return chunks + + +class BaseCollector(ABC): + # Soft target: ~512 tokens × ~4 chars/token ≈ 2048 chars per chunk. 
+ MAX_TOKENS_PER_CHUNK = 512 + CHUNK_SIZE_CHARS = MAX_TOKENS_PER_CHUNK * 4 + CHUNK_OVERLAP_CHARS = 200 + + def __init__(self): + self.minio: MinIOClient = get_minio_client() + self.postgres: PostgresClient = get_postgres_client() + self.cache: CacheManager = get_cache_manager() + self.embeddings = get_embeddings() + + def _generate_doc_id(self, source: str, content: str, ticker: str) -> str: + key = f"{source}:{ticker}:{content[:500]}" + return hashlib.sha256(key.encode()).hexdigest()[:32] + + def _generate_filing_id(self, ticker: str, accession: str, form_type: str) -> str: + """Deterministic filing ID from accession number.""" + key = f"{ticker}:{accession}:{form_type}" + return hashlib.sha256(key.encode()).hexdigest()[:32] + + def _chunk_text(self, text: str) -> list[str]: + """ + Split ``text`` into chunks using a recursive character-based strategy. + + Tries separators from coarse to fine (paragraph → line → sentence → + word → char). Any segment that is still larger than ``CHUNK_SIZE_CHARS`` + after splitting on the current separator is recursively split with the + next one. Adjacent segments are then greedily merged back up to the + size budget, with ``CHUNK_OVERLAP_CHARS`` of trailing context carried + into the next chunk to preserve continuity across boundaries. 
+ + :param text: raw document text + :return: list of chunk strings, each ``<= CHUNK_SIZE_CHARS`` chars + (best effort — a single token longer than the budget is kept whole) + """ + if not text: + return [text] + pieces = _recursive_split( + text, _DEFAULT_SEPARATORS, self.CHUNK_SIZE_CHARS + ) + merged = _merge_with_overlap( + pieces, self.CHUNK_SIZE_CHARS, self.CHUNK_OVERLAP_CHARS + ) + return merged if merged else [text] + + def _store_to_cold_tier( + self, + ticker: str, + content: str, + metadata: dict[str, Any], + ) -> Optional[str]: + source = self._get_source_tag() + url = metadata.get("url", "") + + if source == "sec": + return self.minio.store_sec_filing( + ticker=ticker, + filing_type=metadata.get("form_type", "unknown"), + accession_number=metadata.get("accession_number", ""), + content=content, + metadata=metadata, + ) + elif source == "news": + return self.minio.store_news_article( + ticker=ticker, + url=url, + content=content, + metadata=metadata, + ) + elif source == "earnings": + return self.minio.store_earnings_transcript( + ticker=ticker, + quarter_code=metadata.get("quarter", "unknown"), + content=content, + metadata=metadata, + ) + else: + object_name = f"generic/{source}/{ticker}/{self._generate_doc_id(source, content, ticker)}.txt" + return self.minio.put_object("raw_docs", object_name, content) + + def _build_filing_row( + self, + ticker: str, + filing_id: str, + metadata: dict[str, Any], + file_size: int, + ) -> dict[str, Any]: + """ + Map sec_collector metadata keys → filings table columns. 
+ + sec_collector returns: + form_type, filing_date, report_date, accession_number, + cik, url, company_name, description + """ + + def parse_date(val: Any) -> Optional[str]: + """Return ISO date string or None.""" + if not val: + return None + if isinstance(val, str) and len(val) >= 10: + return val[:10] + return None + + return { + "id": filing_id, + "ticker": ticker, + "company_name": metadata.get("company_name", ""), + # sec_collector uses "form_type"; filings table calls it "filing_type" + "filing_type": metadata.get("form_type", "unknown"), + "cik": metadata.get("cik", ""), + "accession_number": metadata.get("accession_number", ""), + "filing_date": parse_date(metadata.get("filing_date")), + # sec_collector uses "report_date"; filings table calls it "period_of_report" + "period_of_report": parse_date(metadata.get("report_date")), + "document_url": metadata.get("url", ""), + "file_size_bytes": file_size, + } + + def _store_to_warm_tier( + self, + ticker: str, + chunks: list[dict[str, Any]], + filing_metadata: Optional[dict[str, Any]] = None, + ) -> int: + """ + Store structured data in PostgreSQL in the correct order: + 1. INSERT INTO filings (one row per document) + 2. INSERT INTO chunks (many rows, FK → filings.id) + 3. INSERT INTO document_metadata (key/value pairs, FK → chunks.id) + + Previously this method only called insert_chunks(), skipping filings + and document_metadata entirely, leaving filing_id = NULL on all chunks. + """ + if not chunks: + return 0 + + # ── 1. 
Insert filing row (once per unique filing_id) ────────────────── + seen_filing_ids: set[str] = set() + filing_rows: list[dict[str, Any]] = [] + + for chunk in chunks: + fid = chunk.get("filing_id") + if fid and fid not in seen_filing_ids: + seen_filing_ids.add(fid) + filing_rows.append(chunk["filing_row"]) + + if filing_rows: + try: + self.postgres.insert_filings(filing_rows) + _LOG.info("Inserted %d filing rows", len(filing_rows)) + except Exception as e: + _LOG.error("Failed to insert filings: %s", e) + # Don't continue — chunks FK-depend on filings existing + return 0 + + # ── 2. Wipe old chunks for these filings ───────────────────────────── + # Chunk IDs are content-hashed, but the chunks table also has a + # UNIQUE (filing_id, chunk_index) constraint. On re-ingest with even + # slightly different text the new chunk gets a new id while the + # (filing_id, chunk_index) slot is still held by the old row, so a + # plain ``ON CONFLICT (id) DO UPDATE`` upsert can't reach it. Delete + # the old chunks for these filings (cascades to document_metadata) + # so the next INSERT gets a clean slate. + if seen_filing_ids: + try: + deleted = self.postgres.delete_chunks_by_filing_ids( + list(seen_filing_ids) + ) + if deleted: + _LOG.info( + "Deleted %d existing chunks for %d filings before re-ingest", + deleted, + len(seen_filing_ids), + ) + except Exception as e: + _LOG.error("Failed to delete prior chunks: %s", e) + return 0 + + # ── 3. 
Insert chunks with embeddings ────────────────────────────────── + db_chunks = [] + for chunk in chunks: + try: + embedding_result = self.embeddings.batchtransform([chunk["text"]]) + embedding = ( + embedding_result[0].tolist() + if embedding_result is not None and len(embedding_result) > 0 + else None + ) + except Exception as e: + _LOG.warning("Embedding failed for chunk %s: %s", chunk["id"], e) + embedding = None + + db_chunks.append( + { + "id": chunk["id"], + "filing_id": chunk.get("filing_id"), # FK → filings.id ✓ + "chunk_index": chunk.get("chunk_index", 0), + "text": chunk["text"], + "section": chunk.get("section", ""), + "embedding": embedding, + } + ) + + try: + inserted = self.postgres.insert_chunks(db_chunks) + _LOG.info("Inserted %d chunks", inserted) + except Exception as e: + _LOG.error("Failed to insert chunks: %s", e) + return 0 + + # ── 4. Insert document_metadata (key/value pairs per chunk) ─────────── + metadata_rows = [] + for chunk in chunks: + chunk_meta = chunk.get("metadata", {}) + chunk_id = chunk["id"] + for key, value in chunk_meta.items(): + if value is None: + continue + metadata_rows.append( + { + "id": self._generate_doc_id( + "meta", f"{chunk_id}:{key}", ticker + ), + "chunk_id": chunk_id, + "key": key, + "value": str(value), + } + ) + + if metadata_rows: + try: + self.postgres.insert_document_metadata(metadata_rows) + _LOG.info("Inserted %d document_metadata rows", len(metadata_rows)) + except Exception as e: + # Non-fatal — metadata is supplementary + _LOG.warning("Failed to insert document_metadata: %s", e) + + return inserted + + def _store_to_search_index( + self, + ticker: str, + chunks: list[dict[str, Any]], + ) -> list[str]: + source = self._get_source_tag() + documents = [] + for chunk in chunks: + documents.append( + { + "id": chunk["id"], + "text": chunk["text"], + "tags": source, + "metadata": { + "ticker": ticker, + "source": source, + **chunk.get("metadata", {}), + }, + } + ) + return self.embeddings.upsert(documents) 
+ + def _cache_results(self, ticker, query_key, results, ttl=3600): + cache_key = f"fetch:{self._get_source_tag()}:{ticker}:{query_key}" + return self.cache.set(cache_key, results, ttl=ttl) + + def _get_cached_results(self, ticker, query_key): + cache_key = f"fetch:{self._get_source_tag()}:{ticker}:{query_key}" + return self.cache.get(cache_key) + + @abstractmethod + def _fetch_data(self, ticker: str, **kwargs) -> list[dict[str, Any]]: + pass + + @abstractmethod + def _get_source_tag(self) -> str: + pass + + def collect( + self, + ticker: str, + store_cold: bool = True, + store_warm: bool = True, + store_search: bool = True, + use_cache: bool = False, + **kwargs, + ) -> dict[str, int]: + + _LOG.info( + "Starting collection for %s ticker=%s", self._get_source_tag(), ticker + ) + + results = {"fetched": 0, "stored_cold": 0, "stored_warm": 0, "indexed": 0} + + if use_cache: + cache_key = str(sorted(kwargs.items())) + cached = self._get_cached_results(ticker, cache_key) + if cached: + _LOG.info("Using cached results for %s", self._get_source_tag()) + raw_docs = cached + else: + raw_docs = self._fetch_data(ticker, **kwargs) + self._cache_results(ticker, cache_key, raw_docs) + else: + raw_docs = self._fetch_data(ticker, **kwargs) + + results["fetched"] = len(raw_docs) + _LOG.info("Fetched %d documents", len(raw_docs)) + + all_chunks = [] + + for doc in raw_docs: + text = doc.get("text", "") + if not text or len(text) < 10: + continue + + metadata = doc.get("metadata", {}) + source = self._get_source_tag() + + # ── Generate a stable filing_id for this document ──────────────── + filing_id = self._generate_filing_id( + ticker, + metadata.get("accession_number", text[:100]), + metadata.get("form_type", source), + ) + + # ── Build the filings table row for this document ──────────────── + filing_row = self._build_filing_row( + ticker=ticker, + filing_id=filing_id, + metadata=metadata, + file_size=len(text.encode("utf-8")), + ) + + # ── Cold storage 
───────────────────────────────────────────────── + if store_cold: + object_path = self._store_to_cold_tier(ticker, text, metadata) + if object_path: + results["stored_cold"] += 1 + + # ── Chunk the document ─────────────────────────────────────────── + chunks = self._chunk_text(text) + for i, chunk_text in enumerate(chunks): + all_chunks.append( + { + "id": self._generate_doc_id(source, chunk_text, ticker), + "filing_id": filing_id, # ← FK to filings.id now set ✓ + "filing_row": filing_row, # ← carried for warm tier insert + "chunk_index": i, + "total_chunks": len(chunks), + "text": chunk_text, + "section": "", + "metadata": { + **metadata, + "source": source, + "ticker": ticker, + }, + } + ) + + # ── Warm tier (filings → chunks → document_metadata) ───────────────── + if store_warm and all_chunks: + inserted = self._store_to_warm_tier(ticker, all_chunks) + results["stored_warm"] = inserted + + # ── Search index ────────────────────────────────────────────────────── + if store_search and all_chunks: + indexed_ids = self._store_to_search_index(ticker, all_chunks) + results["indexed"] = len(indexed_ids) if indexed_ids else 0 + + if store_search: + from app.pipeline.embeddings import get_data_dir + + # txtai's Embeddings.save() takes a *directory* and writes + # ``config.json``, ``documents`` (SQLite), and ``embeddings`` + # (ANN index) into it. Passing a sub-path like + # ``data/index.db`` made txtai create that as a directory and + # left the live state in ``data/`` un-persisted. 
+ self.embeddings.save(str(get_data_dir())) + + _LOG.info( + "Collection complete: fetched=%d cold=%d warm=%d indexed=%d", + results["fetched"], + results["stored_cold"], + results["stored_warm"], + results["indexed"], + ) + return results diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/earnings_collector.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/earnings_collector.py new file mode 100644 index 000000000..8fc9e0fcd --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/earnings_collector.py @@ -0,0 +1,190 @@ +""" +Earnings call transcript collector. + +Fetches full earnings call transcripts from Alpha Vantage's +``EARNINGS_CALL_TRANSCRIPT`` endpoint and stores them across all four tiers. + +The free Alpha Vantage tier supports this endpoint at 25 calls/day, so we +default to the last 4 quarters per ticker. + +Stores: +- Cold (MinIO): raw transcript JSON under ``earnings/{ticker}/{quarter}.json`` +- Warm (Postgres): filing row (filing_type=``EARNINGS_TRANSCRIPT``) + + chunks + document_metadata +- Search (txtai): chunks indexed with ``tags="earnings"`` +- Hot (KeyDB): fetch cache via the base collector +""" + +import logging +import os +from datetime import datetime +from typing import Any + +import httpx + +from app.collectors.base_collector import BaseCollector + +_LOG = logging.getLogger(__name__) + + +# Calendar quarter -> end-of-quarter month/day used for period_of_report. +_QUARTER_END = { + 1: (3, 31), + 2: (6, 30), + 3: (9, 30), + 4: (12, 31), +} + + +class EarningsCollector(BaseCollector): + """ + Collector for earnings call transcripts. 
+ """ + + def _get_source_tag(self) -> str: + return "earnings" + + def _fetch_data( + self, + ticker: str, + quarters: int = 4, + year: int | None = None, + quarter: int | None = None, + **_, + ) -> list[dict]: + """ + Fetch earnings transcripts for the given ticker. + + :param ticker: stock ticker symbol + :param quarters: number of trailing quarters to fetch (default 4) + :param year: optional explicit year (overrides ``quarters``) + :param quarter: optional explicit quarter 1-4 (overrides ``quarters``) + :return: list of dicts with ``text`` and ``metadata`` + """ + api_key = os.getenv("ALPHAVANTAGE_API_KEY") + if not api_key: + _LOG.error( + "ALPHAVANTAGE_API_KEY not set; cannot fetch earnings transcripts" + ) + return [] + if year is not None and quarter is not None: + quarter_codes = [self._format_quarter(year, quarter)] + else: + quarter_codes = self._recent_quarter_codes(quarters) + results: list[dict] = [] + for q_code in quarter_codes: + doc = self._fetch_alphavantage_transcript(ticker, q_code, api_key) + if doc is not None: + results.append(doc) + _LOG.info("Fetched %d transcript(s) for %s", len(results), ticker) + return results + + @staticmethod + def _format_quarter(year: int, quarter: int) -> str: + return f"{year}Q{quarter}" + + @staticmethod + def _recent_quarter_codes(n: int) -> list[str]: + """ + Return the last ``n`` finished calendar quarter codes (most recent first). + """ + now = datetime.utcnow() + # Most recent finished quarter is the one before the current calendar + # quarter, since reports lag a few weeks. 
+ cur_q = (now.month - 1) // 3 + 1 + year, quarter = now.year, cur_q - 1 + if quarter == 0: + year, quarter = year - 1, 4 + codes: list[str] = [] + for _ in range(n): + codes.append(f"{year}Q{quarter}") + quarter -= 1 + if quarter == 0: + year, quarter = year - 1, 4 + return codes + + def _fetch_alphavantage_transcript( + self, + ticker: str, + quarter_code: str, + api_key: str, + ) -> dict | None: + """ + Hit the Alpha Vantage transcript endpoint for one ticker / quarter. + + Returns None if the response is empty or rate-limited. + """ + url = "https://www.alphavantage.co/query" + params = { + "function": "EARNINGS_CALL_TRANSCRIPT", + "symbol": ticker, + "quarter": quarter_code, + "apikey": api_key, + } + try: + with httpx.Client(timeout=20.0) as client: + response = client.get(url, params=params) + response.raise_for_status() + payload = response.json() + except httpx.HTTPError as e: + _LOG.error("Alpha Vantage transcript HTTP error for %s %s: %s", ticker, quarter_code, e) + return None + # Alpha Vantage signals throttling and empty responses with informational keys. + if "Note" in payload or "Information" in payload: + _LOG.warning( + "Alpha Vantage rate-limit / info response for %s %s: %s", + ticker, + quarter_code, + payload.get("Note") or payload.get("Information"), + ) + return None + turns = payload.get("transcript") or [] + if not turns: + _LOG.info("No transcript available for %s %s", ticker, quarter_code) + return None + # Join all speaker turns into a single text blob; keep speakers in metadata. 
+ text_lines = [] + speakers: list[str] = [] + for turn in turns: + speaker = (turn.get("speaker") or "").strip() + content = (turn.get("content") or "").strip() + if not content: + continue + if speaker: + text_lines.append(f"{speaker}: {content}") + if speaker not in speakers: + speakers.append(speaker) + else: + text_lines.append(content) + text = "\n".join(text_lines).strip() + if not text: + return None + year, q = self._parse_quarter_code(quarter_code) + period_end = self._quarter_end_date(year, q) + metadata: dict[str, Any] = { + "ticker": ticker, + "form_type": "EARNINGS_TRANSCRIPT", + "accession_number": f"{ticker}-EARN-{quarter_code}", + "quarter": quarter_code, + "fiscal_year": year, + "fiscal_quarter": q, + "period_of_report": period_end, + # filing_date is unknown from this endpoint; agents can fall back to period. + "filing_date": period_end, + "url": f"alphavantage:EARNINGS_CALL_TRANSCRIPT/{ticker}/{quarter_code}", + "company_name": payload.get("symbol", ticker), + "speakers": speakers, + "speaker_count": len(speakers), + "turn_count": len(turns), + } + return {"text": text, "metadata": metadata, "raw": payload} + + @staticmethod + def _parse_quarter_code(code: str) -> tuple[int, int]: + year_str, q_str = code.split("Q") + return int(year_str), int(q_str) + + @staticmethod + def _quarter_end_date(year: int, quarter: int) -> str: + month, day = _QUARTER_END[quarter] + return f"{year:04d}-{month:02d}-{day:02d}" diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/news_collector.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/news_collector.py new file mode 100644 index 000000000..0610b1eba --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/news_collector.py @@ -0,0 +1,221 @@ +""" +News collector for market research data. 
+ +Fetches headlines and articles from: +- NewsAPI (https://newsapi.org) +- Alpha Vantage News (https://www.alphavantage.co) + +Stores across all storage tiers. +""" + +import logging +import os +from datetime import datetime, timedelta + +import httpx + +from app.collectors.base_collector import BaseCollector + +_LOG = logging.getLogger(__name__) + + +class NewsCollector(BaseCollector): + """ + Collector for news articles. + + Stores: + - Cold: Article HTML in MinIO + - Warm: Article metadata in PostgreSQL articles table + - Search: Embedded chunks in txtai + - Hot: Fetch cache in KeyDB + """ + + def _get_source_tag(self) -> str: + return "news" + + def _fetch_data( + self, + ticker: str, + days_back: int = 7, + limit: int = 50, + ) -> list[dict]: + """ + Fetch news articles for a given ticker. + + Args: + ticker: Stock ticker symbol + days_back: Number of days to look back + limit: Maximum number of articles to return + + Returns: + List of dicts with text and metadata + """ + articles = [] + + # Try NewsAPI first + newsapi_key = os.getenv("NEWSAPI_KEY") + if newsapi_key: + fetched = self._fetch_newsapi(ticker, days_back, limit) + articles.extend(fetched) + + # Try Alpha Vantage as fallback/supplement + alphavantage_key = os.getenv("ALPHAVANTAGE_API_KEY") + if alphavantage_key: + fetched = self._fetch_alphavantage(ticker, limit) + articles.extend(fetched) + + # Deduplicate by title + seen = set() + unique = [] + for article in articles: + title_key = article["metadata"].get("title", "")[:50] + if title_key not in seen: + seen.add(title_key) + unique.append(article) + + return unique[:limit] + + def _fetch_newsapi( + self, + ticker: str, + days_back: int, + limit: int, + ) -> list[dict]: + """Fetch from NewsAPI.""" + key = os.getenv("NEWSAPI_KEY") + if not key: + return [] + + from_date = (datetime.utcnow() - timedelta(days=days_back)).strftime("%Y-%m-%d") + + url = "https://newsapi.org/v2/everything" + params = { + "q": f'"{ticker}" OR "{ticker.replace("NASDAQ:", 
"")}"', + "from": from_date, + "sortBy": "relevancy", + "language": "en", + "apiKey": key, + "pageSize": min(limit, 100), + } + + try: + with httpx.Client(timeout=10.0) as client: + response = client.get(url, params=params) + response.raise_for_status() + data = response.json() + + articles = [] + for item in data.get("articles", []): + text_parts = [ + item.get("title", ""), + item.get("description", ""), + item.get("content", ""), + ] + text = " ".join(filter(None, text_parts)) + + if text: + articles.append( + { + "text": text, + "metadata": { + "title": item.get("title", ""), + "source": item.get("source", {}).get("name", ""), + "url": item.get("url", ""), + "published_at": item.get("publishedAt", ""), + "author": item.get("author", ""), + }, + } + ) + + return articles + + except httpx.HTTPError as e: + _LOG.error("NewsAPI error: %s", e) + return [] + + def _fetch_alphavantage( + self, + ticker: str, + limit: int, + ) -> list[dict]: + """Fetch from Alpha Vantage News API.""" + key = os.getenv("ALPHAVANTAGE_API_KEY") + if not key: + return [] + + url = "https://www.alphavantage.co/query" + params = { + "function": "NEWS_SENTIMENT", + "ticker": ticker, + "apikey": key, + "limit": min(limit, 50), + } + + try: + with httpx.Client(timeout=10.0) as client: + response = client.get(url, params=params) + response.raise_for_status() + data = response.json() + + articles = [] + for item in data.get("feed", []): + text_parts = [ + item.get("title", ""), + item.get("summary", ""), + ] + text = " ".join(filter(None, text_parts)) + + if text: + articles.append( + { + "text": text, + "metadata": { + "title": item.get("title", ""), + "source": item.get("source", ""), + "url": item.get("url", ""), + "published_at": item.get("time_published", ""), + "sentiment_score": item.get("overall_sentiment_score"), + }, + } + ) + + return articles + + except httpx.HTTPError as e: + _LOG.error("Alpha Vantage error: %s", e) + return [] + + def collect( + self, + ticker: str, + days_back: int 
= 7, + limit: int = 50, + store_cold: bool = True, + store_warm: bool = True, + store_search: bool = True, + use_cache: bool = False, + ) -> dict[str, int]: + """ + Run the full news collection pipeline. + + Args: + ticker: Stock ticker symbol + days_back: Days to look back + limit: Maximum articles + store_cold: Store in MinIO + store_warm: Store in PostgreSQL + store_search: Index in txtai + use_cache: Use cached results + + Returns: + Dict with counts + """ + return super().collect( + ticker=ticker, + days_back=days_back, + limit=limit, + store_cold=store_cold, + store_warm=store_warm, + store_search=store_search, + use_cache=use_cache, + ) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/sec_collector.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/sec_collector.py new file mode 100644 index 000000000..660705d2c --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/collectors/sec_collector.py @@ -0,0 +1,614 @@ +""" +SEC EDGAR collector — built on official data.sec.gov APIs only. 
+ +Official API reference: + https://www.sec.gov/search-filings/edgar-application-programming-interfaces + +Endpoints used: + CIK lookup: https://www.sec.gov/files/company_tickers.json + Submissions: https://data.sec.gov/submissions/CIK{cik}.json + Filing docs: https://www.sec.gov/Archives/edgar/data/{cik}/{accession}/ + +Rate limit: 10 req/s (SEC documented max) +Concurrency: 8 parallel fetches (safely under limit with per-request delay) +""" + +import asyncio +import concurrent.futures +import logging +import os +import re +import time +from typing import Any, Coroutine, Optional + +import httpx + +from app.collectors.base_collector import BaseCollector + +_LOG = logging.getLogger(__name__) + +# ── Official SEC constants ──────────────────────────────────────────────────── +# User-Agent format required by SEC: "Company Name email@domain.com" +DEFAULT_USER_AGENT = os.getenv( + "SEC_USER_AGENT", + "txtai-market-research security-research@proton.me", +) + +# SEC allows max 10 req/s; stay safely under with 8 workers + 110ms delay +SEC_RATE_LIMIT_DELAY = 0.11 # 110ms proactive throttle per request +SEC_MAX_RETRIES = 4 +SEC_CONCURRENCY = 8 # semaphore cap + +# Default limits for large-scale collection +DEFAULT_FILING_TYPES = ["10-K", "10-Q", "8-K", "DEF 14A", "S-1", "10-K/A", "10-Q/A"] +DEFAULT_LARGE_SCALE_LIMIT = 5000 # Default for large-scale collection + +# Official API base URLs (data.sec.gov — documented public REST API) +DATA_API_BASE = "https://data.sec.gov" +SEC_BASE = "https://www.sec.gov" +SUBMISSIONS_URL = f"{DATA_API_BASE}/submissions/CIK{{cik}}.json" +COMPANY_TICKERS_URL = f"{SEC_BASE}/files/company_tickers.json" + +# Required headers per SEC developer FAQ +BASE_HEADERS = { + "User-Agent": DEFAULT_USER_AGENT, + "Accept-Encoding": "gzip, deflate", + "Host": "data.sec.gov", # overridden per-request where needed +} + + +def _run_async(coro: Coroutine[Any, Any, Any]) -> Any: + """ + Run an async coroutine from sync code, regardless of caller context. 
+ + ``asyncio.run`` raises if a loop is already running (Jupyter, FastAPI + handlers, async test harnesses). This helper transparently dispatches + such calls to a worker thread that owns its own loop, leaving the + parent loop untouched. + + :param coro: coroutine to execute to completion + :return: whatever the coroutine returns + """ + try: + asyncio.get_running_loop() + except RuntimeError: + # No loop running — typical CLI / script case. + return asyncio.run(coro) + # A loop is already running. Spin up a worker thread that owns its own + # loop so the running loop stays untouched. + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + return executor.submit(asyncio.run, coro).result() + + +class SECCollector(BaseCollector): + """ + Collector for SEC EDGAR filings using the official data.sec.gov REST APIs. + + Flow per ticker: + 1. Resolve ticker → CIK (company_tickers.json) + 2. Fetch filing history (submissions/CIK{cik}.json, paginated) + 3. Filter by form type + 4. Concurrently fetch filing index pages → resolve primary document URL + 5. Fetch and clean primary document HTML + """ + + def _get_source_tag(self) -> str: + return "sec" + + # ── CIK resolution ──────────────────────────────────────────────────────── + + def _get_cik_for_ticker(self, ticker: str) -> Optional[str]: + """ + Resolve ticker → zero-padded 10-digit CIK. + Uses the official SEC company_tickers.json mapping. 
+ """ + try: + with httpx.Client( + timeout=10.0, + headers={**BASE_HEADERS, "Host": "www.sec.gov"}, + ) as client: + r = client.get(COMPANY_TICKERS_URL) + r.raise_for_status() + for item in r.json().values(): + if item.get("ticker", "").upper() == ticker.upper(): + cik = str(item["cik_str"]).zfill(10) + _LOG.info("Resolved %s → CIK %s", ticker, cik) + return cik + _LOG.warning("Ticker %s not found in SEC company_tickers.json", ticker) + return None + except httpx.HTTPError as e: + _LOG.error("CIK lookup failed for %s: %s", ticker, e) + return None + + # ── Official submissions API ─────────────────────────────────────────────── + + def _fetch_submissions(self, cik: str, max_filings: int = 10000) -> dict: + """ + Fetch the complete submission history for a CIK from the official + data.sec.gov/submissions/ REST API. + + The primary JSON contains at least 1 year or 1,000 filings. + If the company has more filings, an additional `files` array lists + supplemental JSON files — we fetch and merge those too. + + Args: + cik: 10-digit CIK number + max_filings: Maximum number of filings to fetch (default: 10000) + + Returns the `filings.recent` dict (columnar arrays) with all historical + filings merged in, capped at max_filings. 
+ """ + url = SUBMISSIONS_URL.format(cik=cik) + try: + with httpx.Client( + timeout=15.0, + headers=BASE_HEADERS, + follow_redirects=True, + ) as client: + r = client.get(url) + r.raise_for_status() + data = r.json() + + recent = data.get("filings", {}).get("recent", {}) + + # Fetch any additional paginated filing files + extra_files = data.get("filings", {}).get("files", []) + _LOG.info( + "CIK %s: found %d extra submission files, fetching all...", + cik, + len(extra_files), + ) + for file_info in extra_files: + fname = file_info.get("name", "") + if not fname: + continue + extra_url = f"{DATA_API_BASE}/submissions/{fname}" + time.sleep(SEC_RATE_LIMIT_DELAY) + try: + er = client.get(extra_url) + er.raise_for_status() + extra = er.json() + for key, values in extra.items(): + if isinstance(values, list) and key in recent: + recent[key].extend(values) + _LOG.debug( + "Merged extra file %s (%d items)", + fname, + len(extra.get(key, [])), + ) + except httpx.HTTPError as e: + _LOG.warning( + "Failed to fetch extra submissions file %s: %s", fname, e + ) + + # Cap results if we have more than requested + total_filings = len(recent.get("form", [])) + if total_filings > max_filings: + _LOG.info( + "CIK %s: capping %d filings to %d (max_filings)", + cik, + total_filings, + max_filings, + ) + for key in recent: + if isinstance(recent[key], list): + recent[key] = recent[key][:max_filings] + + return recent + + except httpx.HTTPError as e: + _LOG.error("Submissions API failed for CIK %s: %s", cik, e) + return {} + + def _filter_submissions_by_type( + self, + recent: dict, + filing_type: str, + limit: int, + ) -> list[dict]: + """ + Extract matching filings from the columnar submissions data structure. + + The submissions JSON uses parallel arrays (not a list of objects): + recent["form"] → ["10-K", "8-K", ...] + recent["accessionNumber"] → ["0001234567-23-000001", ...] + recent["filingDate"] → ["2023-02-03", ...] + etc. + + Returns list of dicts, newest first, up to `limit`. 
+ """ + forms = recent.get("form", []) + accessions = recent.get("accessionNumber", []) + filing_dates = recent.get("filingDate", []) + primary_docs = recent.get("primaryDocument", []) + descriptions = recent.get("primaryDocDescription", []) + report_dates = recent.get("reportDate", []) + + matched = [] + for i, form in enumerate(forms): + # Match exact form type or amendment variants (e.g. 10-K/A) + if not (form == filing_type or form.startswith(f"{filing_type}/")): + continue + + matched.append( + { + "form": form, + "accession": accessions[i] if i < len(accessions) else "", + "filing_date": filing_dates[i] if i < len(filing_dates) else "", + "primary_doc": primary_docs[i] if i < len(primary_docs) else "", + "description": descriptions[i] if i < len(descriptions) else "", + "report_date": report_dates[i] if i < len(report_dates) else "", + } + ) + + if len(matched) >= limit: + break # submissions are already newest-first + + return matched + + # ── Async document fetching ─────────────────────────────────────────────── + + async def _async_get( + self, + client: httpx.AsyncClient, + url: str, + host_header: str = "www.sec.gov", + ) -> Optional[httpx.Response]: + """ + Async GET with proactive rate-limit delay and exponential-backoff + retry on 429/503 (the only codes SEC returns for throttling). 
async def _fetch_document(
    self,
    client: httpx.AsyncClient,
    semaphore: asyncio.Semaphore,
    cik: str,
    filing: dict,
) -> tuple[dict, str]:
    """
    Fetch the primary document for a single filing.

    Strategy (using official URL patterns from SEC docs):
      1. Build primary doc URL directly from submissions data if available.
         URL pattern: /Archives/edgar/data/{cik_int}/{accession_nodash}/{primary_doc}
      2. Fall back to parsing the filing index page if primary_doc is
         missing or the direct download came back empty/too short.
         URL pattern: /Archives/edgar/data/{cik_int}/{accession_nodash}/{accession}-index.htm

    Args:
        client: Shared async HTTP client used for every request.
        semaphore: Held for the whole fetch so concurrency stays bounded.
        cik: CIK as a digit string; leading zeros are dropped for the URL.
        filing: Filing metadata dict; reads "accession" (required) and
            "primary_doc" (optional).

    Returns:
        (filing, text) where text is the HTML-stripped document, the
        stripped index page when no document link was found, or "" when
        nothing could be downloaded.
    """
    async with semaphore:
        accession = filing["accession"]
        accession_clean = accession.replace("-", "")
        cik_int = int(cik)  # SEC URLs use integer CIK (no leading zeros)
        base_archive = f"{SEC_BASE}/Archives/edgar/data/{cik_int}/{accession_clean}"

        # Path 1: primary document filename is in the submissions JSON
        primary_doc = filing.get("primary_doc", "")
        if primary_doc:
            doc_url = f"{base_archive}/{primary_doc}"
            resp = await self._async_get(client, doc_url)
            # Very short bodies are treated as a miss and fall through.
            if resp and len(resp.text) > 500:
                return filing, self._strip_html(resp.text)

        # Path 2: fetch the index page and parse the primary document link.
        # The folder name is the dash-less accession, but the index file
        # itself is named with the dashed accession number.
        index_url = f"{base_archive}/{accession}-index.htm"
        index_resp = await self._async_get(client, index_url)
        if not index_resp:
            return filing, ""

        doc_url = self._find_primary_doc_in_index(index_resp.text, base_archive)
        if not doc_url:
            # Last resort: return cleaned index text
            return filing, self._strip_html(index_resp.text)

        doc_resp = await self._async_get(client, doc_url)
        if not doc_resp:
            return filing, ""

        return filing, self._strip_html(doc_resp.text)
async def _fetch_filing_type_async(
    self,
    cik: str,
    filing_type: str,
    filings_meta: list[dict],
) -> list[dict]:
    """Concurrently fetch document content for a list of filing metadata dicts.

    Thin wrapper kept for backward compatibility: the download, filtering
    and record-building logic lives in ``_fetch_all_types`` so the two entry
    points cannot drift apart.

    Args:
        cik: CIK as a digit string.
        filing_type: Accepted for interface compatibility; the filings in
            ``filings_meta`` are already filtered, so it is not consulted.
        filings_meta: Filing metadata dicts to download.

    Returns:
        List of dicts with 'text' and 'metadata' keys, one per filing whose
        document could be fetched.
    """
    return await self._fetch_all_types(cik, filings_meta)
+ ) -> list[dict]: + """ + Main entry point. Resolves CIK, fetches submission history via the + official data.sec.gov API, then concurrently downloads filing content. + + Args: + ticker: Stock ticker symbol + filing_types: List of form types to collect (default: 10-K, 10-Q, 8-K, DEF 14A, S-1) + limit: Maximum filings to return (default: 5000) + max_filings: Maximum filings to fetch from SEC API (default: 10000) + + Returns: + List of filing dicts with 'text' and 'metadata' keys + + Note: + For large-scale collection (1000+ filings), expect 5-15 minutes per ticker + depending on network speed and SEC API response times. + """ + if filing_types is None: + filing_types = DEFAULT_FILING_TYPES + + cik = self._get_cik_for_ticker(ticker) + if not cik: + return [] + + t0 = time.perf_counter() + + # ── Step 1: fetch full submission history (single API call) ── + _LOG.info( + "Fetching submission history for %s (CIK %s), max_filings=%d…", + ticker, + cik, + max_filings, + ) + recent = self._fetch_submissions(cik, max_filings=max_filings) + if not recent: + _LOG.error("No submission data returned for CIK %s", cik) + return [] + + total_on_record = len(recent.get("form", [])) + _LOG.info("Submission history: %d filings on record", total_on_record) + + # ── Step 2: filter by type (in-memory, no extra requests) ── + # For large-scale: get proportional share per filing type + per_type = max(100, limit // len(filing_types)) + all_meta: list[dict] = [] + for ft in filing_types: + matched = self._filter_submissions_by_type(recent, ft, per_type) + _LOG.info( + " [%s] matched %d filings (requested %d)", ft, len(matched), per_type + ) + all_meta.extend(matched) + + if not all_meta: + _LOG.warning( + "No matching filings found for %s with types %s", ticker, filing_types + ) + return [] + + _LOG.info( + "Total filings to fetch: %d (types: %s)", + len(all_meta), + ", ".join(filing_types), + ) + + # ── Step 3: concurrently fetch document content ── + # Process in batches to avoid memory 
async def _fetch_all_types(self, cik: str, all_meta: list[dict]) -> list[dict]:
    """Run all document fetches in a single async context.

    Opens one shared ``httpx.AsyncClient``, downloads every filing in
    ``all_meta`` through ``_fetch_document`` (bounded by a semaphore of
    ``SEC_CONCURRENCY``), and converts the successful downloads to records.

    Args:
        cik: CIK as a digit string.
        all_meta: Filing metadata dicts carrying the keys form, accession,
            filing_date, report_date and description.

    Returns:
        List of ``{"text": ..., "metadata": {...}}`` dicts. Filings whose
        fetch raised, or whose text is empty or shorter than 100
        characters, are left out.
    """
    semaphore = asyncio.Semaphore(SEC_CONCURRENCY)
    async with httpx.AsyncClient(
        timeout=30.0,
        follow_redirects=True,
    ) as client:
        tasks = [self._fetch_document(client, semaphore, cik, f) for f in all_meta]
        # return_exceptions=True: one failed download must not cancel the rest.
        results = await asyncio.gather(*tasks, return_exceptions=True)

    output = []
    for result in results:
        if isinstance(result, Exception):
            _LOG.warning("Document fetch error: %s", result)
            continue
        filing_meta, text = result
        if not text or len(text) < 100:
            continue
        accession = filing_meta["accession"]
        # Folder name is the dash-less accession; the index page inside it
        # is named with the dashed accession number.
        index_url = (
            f"{SEC_BASE}/Archives/edgar/data/{int(cik)}/"
            f"{accession.replace('-', '')}/"
            f"{accession}-index.htm"
        )
        output.append(
            {
                "text": text,
                "metadata": {
                    "form_type": filing_meta["form"],
                    "filing_date": filing_meta["filing_date"],
                    "report_date": filing_meta["report_date"],
                    "accession_number": accession,
                    "description": filing_meta["description"],
                    "cik": cik,
                    "url": index_url,
                },
            }
        )
    return output
Research Chat: Free-text Q&A routed through the orchestrator agent + +Usage: + streamlit run app/main.py +""" + +import os + +import streamlit as st + +from app.storage import get_cache_manager, get_keydb_client +from app.ui.dashboard import render as render_dashboard +from app.ui.chat import render as render_chat + + +# Page configuration +st.set_page_config( + page_title="txtai Market Research", + page_icon="📈", + layout="wide", + initial_sidebar_state="expanded", +) + + +def main(): + """Main application entry point.""" + # Initialize KeyDB connection on startup + cache_manager = get_cache_manager() + keydb_connected = get_keydb_client().ping() + + # Sidebar configuration + with st.sidebar: + st.title("txtai Market Research") + st.markdown("Multi-agent market research platform powered by txtai") + + # KeyDB Status Indicator + if keydb_connected: + st.success("KeyDB: Connected") + # Show cache stats + stats = cache_manager.get_stats() + with st.expander("Cache Statistics"): + st.metric("Prices Cached", stats["prices"]) + st.metric("Semantic Cache", stats["semantic"]) + st.metric("Active Sessions", stats["sessions"]) + else: + st.error("KeyDB: Disconnected") + + st.divider() + + # API Key configuration + api_key = st.text_input( + "OpenAI API Key", + type="password", + value=os.getenv("OPENAI_API_KEY", ""), + help="Enter your OpenAI API key for LLM inference", + ) + + if api_key: + os.environ["OPENAI_API_KEY"] = api_key + + st.divider() + + # Ticker input (used in Dashboard tab) + st.subheader("Company Lookup") + ticker = st.text_input( + "Ticker Symbol", + placeholder="AAPL", + max_chars=5, + help="Enter a stock ticker symbol", + ).upper() + + # Store ticker in session state for dashboard to use + st.session_state.ticker = ticker + + st.divider() + + # Utility links + st.markdown("### Resources") + st.markdown("- [txtai Documentation](https://neuml.github.io/txtai/)") + st.markdown("- [GitHub Repo](#)") + + # Main content area with tabs + tab1, tab2 = st.tabs(["📊 
def get_data_dir() -> Path:
    """Return the txtai data directory, creating it when it does not exist.

    The location comes from the ``TXTAI_DATA_DIR`` environment variable when
    it is set; otherwise it is the ``data`` folder three levels above this
    module's file.
    """
    configured = os.environ.get("TXTAI_DATA_DIR")
    if configured is None:
        # Default for local runs: <three levels up from this file>/data
        data_dir = Path(__file__).parent.parent.parent / "data"
    else:
        data_dir = Path(configured)
    # Idempotent: creates missing parents, tolerates an existing directory.
    data_dir.mkdir(parents=True, exist_ok=True)
    return data_dir
def search(query: str, source_filter: str | None = None, limit: int = 5) -> list[dict]:
    """
    Search the embeddings index with optional source filtering.

    :param query: the search query text
    :param source_filter: optional source tag filter (e.g. ``"news"`` or
        ``"sec"``); ``None`` or an empty string searches across everything
    :param limit: maximum number of results to return
    :return: list of dicts with keys ``id``, ``text``, ``score``, ``tags``,
        ``metadata`` (decoded from the JSON ``data`` column); ``metadata``
        is ``{}`` when the column is missing, undecodable, or not a JSON
        object
    """
    import json

    embeddings = get_embeddings()
    # Use parameterized SQL so apostrophes / quotes / colons in the user's
    # query don't break txtai's SQL parser. The ``data`` column is selected
    # so the custom per-chunk metadata stored with each document comes back
    # with the hit.
    params: dict[str, str] = {"q": query}
    where_clauses = ["similar(:q)"]
    if source_filter:
        where_clauses.append("tags = :src")
        params["src"] = source_filter
    where_sql = " AND ".join(where_clauses)
    sql = (
        "SELECT id, text, score, tags, data FROM txtai "
        f"WHERE {where_sql} LIMIT {int(limit)}"
    )
    raw = embeddings.search(sql, parameters=params, limit=limit)
    # Decode the ``data`` blob and lift the inner ``metadata`` dict to top
    # level so callers don't have to know about txtai's internal layout.
    out = []
    for row in raw:
        data = row.get("data") or {}
        if isinstance(data, str):
            try:
                data = json.loads(data)
            except ValueError:
                # json.JSONDecodeError is a ValueError: treat bad JSON as
                # "no metadata" instead of failing the whole search.
                data = {}
        if not isinstance(data, dict):
            # Valid JSON that is not an object (null, list, number) has no
            # fields to lift; previously this crashed on ``data.get``.
            data = {}
        out.append(
            {
                "id": row.get("id"),
                "text": row.get("text") or data.get("text", ""),
                "score": row.get("score"),
                "tags": row.get("tags") or data.get("tags"),
                "metadata": data.get("metadata") or {},
            }
        )
    return out
+ index_documents = [] + for doc in documents: + index_documents.append( + { + "id": doc["id"], + "text": doc["text"], + "tags": doc.get("tags", "unknown"), + "metadata": doc.get("metadata", {}), + } + ) + embeddings.upsert(index_documents) + if save: + # Persist ANN index files to the data directory. + embeddings.save(str(get_data_dir())) + return [doc["id"] for doc in index_documents] diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/pipeline/ingest.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/pipeline/ingest.py new file mode 100644 index 000000000..d9ce1f68f --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/pipeline/ingest.py @@ -0,0 +1,200 @@ +""" +Data ingestion pipeline for txtai market research platform. + +This module orchestrates the full ingestion pipeline using collectors: +1. Fetch data from all sources (sec, news, earnings) +2. Store raw documents in MinIO cold storage +3. Store structured data in PostgreSQL warm storage +4. Chunk documents to <= 512 tokens +5. Embed with txtai's sentence transformers +6. Upsert into the shared EmbeddingsIndex + +Usage: + python -m app.pipeline.ingest --ticker AAPL + +Or use collectors directly: + from app.collectors import SECCollector, NewsCollector, EarningsCollector + + sec = SECCollector() + sec.collect("AAPL", filing_types=["10-K", "8-K"]) +""" + +import logging + +from app.collectors import SECCollector, NewsCollector, EarningsCollector + +_LOG = logging.getLogger(__name__) + + +def ingest_all( + ticker: str, + store_cold: bool = True, + store_warm: bool = True, + store_search: bool = True, +) -> dict[str, dict[str, int]]: + """ + Run the full ingestion pipeline for all data sources. 
def ingest_source(
    ticker: str,
    source: str,
    store_cold: bool = True,
    store_warm: bool = True,
    store_search: bool = True,
    **kwargs,
) -> dict[str, int]:
    """
    Run ingestion for a specific data source.

    Only the collector for the requested source is instantiated, and the
    source name is validated before any collector is constructed.

    Args:
        ticker: Stock ticker symbol
        source: Source name (news, sec, earnings)
        store_cold: Store in MinIO
        store_warm: Store in PostgreSQL
        store_search: Index in txtai
        **kwargs: Source-specific arguments, forwarded to ``collect``

    Returns:
        Whatever the collector's ``collect`` returns; the CLI below reads
        the keys 'fetched', 'stored_cold', 'stored_warm' and 'indexed'.

    Raises:
        ValueError: If ``source`` is not one of the known source names.
    """
    # Map names to classes (not instances) so unused collectors are never
    # built and a bad source name fails fast.
    collector_classes = {
        "news": NewsCollector,
        "sec": SECCollector,
        "earnings": EarningsCollector,
    }

    if source not in collector_classes:
        raise ValueError(
            f"Unknown source: {source}. Valid: {list(collector_classes.keys())}"
        )

    collector = collector_classes[source]()
    return collector.collect(
        ticker=ticker,
        store_cold=store_cold,
        store_warm=store_warm,
        store_search=store_search,
        **kwargs,
    )
class Test_get_data_dir(unittest.TestCase):
    """
    Check that ``get_data_dir`` follows ``TXTAI_DATA_DIR`` and creates the
    directory it returns.
    """

    def test1(self) -> None:
        """
        With the env override set, the requested directory is returned and
        exists afterwards.
        """
        # Prepare inputs.
        saved = os.environ.get("TXTAI_DATA_DIR")

        def _restore_env() -> None:
            # Put the variable back exactly as it was before the test.
            if saved is not None:
                os.environ["TXTAI_DATA_DIR"] = saved
            else:
                os.environ.pop("TXTAI_DATA_DIR", None)

        self.addCleanup(_restore_env)
        with tempfile.TemporaryDirectory() as scratch:
            expected = Path(scratch) / "txtai_test_data"
            os.environ["TXTAI_DATA_DIR"] = str(expected)
            # Run test.
            actual = get_data_dir()
            # Check outputs.
            self.assertEqual(actual, expected)
            self.assertTrue(actual.exists())
            self.assertTrue(actual.is_dir())
+ +## Architecture Overview + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Storage Architecture │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ HOT TIER - KeyDB (Redis-compatible) │ │ +│ │ ┌───────────────┬──────────────────┬────────────────────────────┐ │ │ +│ │ │ prices:{ticker}│ cache:{md5_hash} │ session:{id} │ │ │ +│ │ │ TTL: 60s │ TTL: 3600s │ TTL: 1800s │ │ │ +│ │ │ (live prices) │ (semantic cache) │ (agent memory) │ │ │ +│ │ └───────────────┴──────────────────┴────────────────────────────┘ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ WARM TIER - PostgreSQL + pgvector │ │ +│ │ ┌────────────┬──────────────┬────────────────┬──────────────────┐ │ │ +│ │ │ filings │ chunks │ xbrl_facts │ articles │ │ │ +│ │ │ (metadata) │ (w/embeddings)│ (structured) │ (news metadata) │ │ │ +│ │ └────────────┴──────────────┴────────────────┴──────────────────┘ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ COLD TIER - MinIO (S3-compatible) │ │ +│ │ ┌────────────┬──────────────┬────────────────┬──────────────────┐ │ │ +│ │ │ sec/ │ news/ │ web/ │ social/ │ │ │ +│ │ │ (filings) │ (articles) │ (scrapes) │ (posts) │ │ │ +│ │ └────────────┴──────────────┴────────────────┴──────────────────┘ │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ SEARCH - txtai EmbeddingsIndex (SQLite) │ │ +│ │ - Embedded chunks with semantic search │ │ +│ │ - Filterable by source (news|sec|web|social) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ 
+└─────────────────────────────────────────────────────────────────────────────┘ +``` + +## Storage Tiers + +| Tier | Technology | Purpose | TTL | Data Types | +|------|------------|---------|-----|------------| +| **Hot** | KeyDB | Live cache, sessions | 60s-3600s | Prices, semantic cache, agent state | +| **Warm** | PostgreSQL + pgvector | Structured data, embeddings | Persistent | Filings, chunks, XBRL facts, articles | +| **Cold** | MinIO | Raw document archive | Persistent | SEC filings, news HTML, web scrapes | +| **Search** | txtai (SQLite) | Semantic search index | Persistent | Embedded chunks with metadata | + +## Components + +### Hot Tier - KeyDB + +**Files:** `hot_storage/keydb_client.py`, `cache_manager.py` + +```python +from app.storage import get_keydb_client, get_cache_manager + +# Low-level client +client = get_keydb_client() +client.set("key", "value", ttl=300) + +# High-level cache manager +cache = get_cache_manager() +cache.set_price("AAPL", price_data) +cache.set_semantic(query, results) +``` + +### Warm Tier - PostgreSQL + pgvector + +**Files:** `warm_storage/pgvector_client.py`, `warm_storage/filings_manager.py` + +```python +from app.storage import get_postgres_client + +postgres = get_postgres_client() + +# Insert filing metadata +filing_id = postgres.insert_filing(filing_data) + +# Insert chunks with embeddings +postgres.insert_chunks(chunks) + +# Vector similarity search +results = postgres.search_similar(query_embedding, limit=10) +``` + +### Cold Tier - MinIO + +**Files:** `cold_storage/minio_client.py` + +```python +from app.storage import get_minio_client + +minio = get_minio_client() + +# Store SEC filing +minio.store_sec_filing( + ticker="AAPL", + filing_type="10-K", + accession_number="0000320193-24-000006", + content=html_content, +) + +# Store news article +minio.store_news_article( + ticker="AAPL", + url="https://...", + content=html_content, + metadata={"title": "...", "published_at": "..."}, +) +``` + +### Search Tier - txtai 
+ +**Files:** `../pipeline/embeddings.py` + +```python +from app.pipeline.embeddings import get_embeddings, search + +embeddings = get_embeddings() +results = search("Apple revenue", source_filter="sec", limit=5) +``` + +## Collectors - Writing to All Tiers + +The `app/collectors/` module provides unified data collection that writes to all storage tiers: + +```python +from app.collectors import SECCollector, NewsCollector, WebCollector, SocialCollector + +# SEC filings collector +sec = SECCollector() +results = sec.collect( + ticker="AAPL", + filing_types=["10-K", "8-K"], + store_cold=True, # MinIO: raw filings + store_warm=True, # PostgreSQL: metadata + chunks + store_search=True, # txtai: embeddings +) + +# News collector +news = NewsCollector() +news.collect("AAPL", days_back=7) + +# Web collector (press releases) +web = WebCollector() +web.collect("AAPL") + +# Social collector (Reddit, StockTwits) +social = SocialCollector() +social.collect("AAPL", subreddits=["investing", "stocks"]) +``` + +### Collector Flow + +``` +┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ +│ Fetch │────▶│ Cold │────▶│ Warm │────▶│ Search │ +│ (API) │ │ (MinIO) │ │ (PostgreSQL)│ │ (txtai) │ +└──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘ + │ │ │ │ + │ │ │ │ + ▼ ▼ ▼ ▼ + Raw documents Raw HTML/JSON Structured data Embeddings + + with metadata archive + chunks semantic index +``` + +## Setup + +### 1. Install Dependencies + +```bash +pip install -r requirements.txt +``` + +### 2. 
Configure Environment + +Copy `.env.example` to `.env` and configure: + +```bash +# Hot Tier - KeyDB +KEYDB_HOST=localhost +KEYDB_PORT=6379 +KEYDB_PASSWORD= + +# Cold Tier - MinIO +MINIO_ENDPOINT=localhost:9000 +MINIO_ACCESS_KEY=minioadmin +MINIO_SECRET_KEY=minioadmin +MINIO_SECURE=false + +# Warm Tier - PostgreSQL +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_DB=financial_kb +POSTGRES_USER=fin +POSTGRES_PASSWORD=fin_local + +# Embeddings - Ollama +OLLAMA_HOST=http://localhost:11434 +OLLAMA_EMBEDDING_MODEL=nomic-embed-text +``` + +### 3. Start Infrastructure + +```bash +docker-compose up -d +``` + +This starts: +- KeyDB on port 6379 +- PostgreSQL + pgvector on port 5432 +- MinIO on ports 9000 (API) and 9001 (console) + +### 4. Verify Connections + +```bash +# KeyDB +redis-cli -h localhost -p 6379 ping # Should return: PONG + +# PostgreSQL +psql -h localhost -U fin -d financial_kb -c "SELECT 1" + +# MinIO +curl http://localhost:9000/minio/health/live # Should return: OK + +# Or use the test scripts +python -m app.storage.hot_storage.tests.test_keydb +python -m app.storage.warm_storage.tests.test_pgvector +python -m app.storage.cold_storage.tests.test_minio +``` + +## Data Flow Example + +Full pipeline for collecting and storing SEC filings: + +```python +from app.collectors import SECCollector + +# Initialize collector +sec = SECCollector() + +# Collect filings - stores to ALL tiers +results = sec.collect( + ticker="AAPL", + filing_types=["10-K", "8-K", "DEF 14A"], + limit=20, + store_cold=True, # Archive raw HTML in MinIO + store_warm=True, # Store metadata in PostgreSQL + store_search=True, # Generate embeddings in txtai +) + +print(f"Fetched: {results['fetched']}") +print(f"Stored in cold (MinIO): {results['stored_cold']}") +print(f"Stored in warm (PostgreSQL): {results['stored_warm']}") +print(f"Indexed for search: {results['indexed']}") +``` + +## Bucket Structure (MinIO) + +``` +filings/ +├── sec/ +│ ├── AAPL/ +│ │ ├── 10-K/ +│ │ │ └── 
000032019324000006.html +│ │ └── 8-K/ +│ └── MSFT/ +│ └── ... +articles/ +├── news/ +│ ├── AAPL/ +│ │ ├── 2024-01-15/ +│ │ │ └── abc123def456.html +│ └── ... +web_scrapes/ +├── web/ +│ ├── AAPL/ +│ │ └── abc123.html +social/ +├── reddit/ +│ └── AAPL/ +│ └── post_id.json +``` + +## Key Design Decisions + +### Why Multi-Tier? + +| Tier | Use Case | Query Pattern | +|------|----------|---------------| +| Hot | Real-time data, session state | Key-value lookup, sub-millisecond | +| Warm | Structured queries, vector search | SQL, cosine similarity | +| Cold | Compliance, audit, reprocessing | Object retrieval | +| Search | Semantic search | Natural language queries | + +### TTL Strategy + +| Data Type | TTL | Rationale | +|-----------|-----|-----------| +| Prices | 60s | Live feeds update frequently | +| Semantic cache | 3600s | Expensive embeddings, 1hr balances cost/freshness | +| Sessions | 1800s | User sessions expire after 30min inactivity | +| Cold/Warm | Persistent | Compliance and historical analysis | + +## Troubleshooting + +### Connection Issues + +```bash +# Check Docker containers +docker-compose ps + +# View logs +docker-compose logs keydb +docker-compose logs postgres +docker-compose logs minio +``` + +### MinIO Browser + +Access MinIO console at http://localhost:9001 with credentials: +- Username: `minioadmin` +- Password: `minioadmin` + +### PostgreSQL Vector Search + +```sql +-- Check pgvector extension +SELECT * FROM pg_extension WHERE extname = 'vector'; + +-- Check embedding dimensions +SELECT embedding::text FROM chunks LIMIT 1; +``` + +## Next Steps + +1. **Graph Tier**: Add Kuzu for company/filing relationships +2. **Analytics Tier**: Add DuckDB + Parquet for time-series analysis +3. 
**Backup**: Configure MinIO bucket replication for disaster recovery diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/__init__.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/__init__.py new file mode 100644 index 000000000..7146d12de --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/__init__.py @@ -0,0 +1,45 @@ +""" +Storage layer for txtai Market Research Platform. + +This package provides multi-tier storage infrastructure: +- Hot tier: KeyDB (Redis-compatible) for live prices, semantic cache, sessions +- Warm tier: PostgreSQL + pgvector for filings, chunks, XBRL facts +- Cold tier: MinIO for raw filings archive +- Graph tier: Kuzu for company/filing/regulation relationships (planned) +""" + +# Hot tier - KeyDB +from app.storage.hot_storage.keydb_client import KeyDBClient, get_keydb_client +from app.storage.cache_manager import CacheManager, get_cache_manager + +# Warm tier - PostgreSQL + pgvector +from app.storage.warm_storage import ( + PostgresClient, + get_postgres_client, + FilingsManager, + get_filings_manager, +) + +# Cold tier - MinIO +from app.storage.cold_storage import MinIOClient, get_minio_client + +# Embeddings +from app.pipeline.embeddings import get_embeddings + +__all__ = [ + # Hot tier + "KeyDBClient", + "get_keydb_client", + "CacheManager", + "get_cache_manager", + # Warm tier + "PostgresClient", + "get_postgres_client", + "FilingsManager", + "get_filings_manager", + # Cold tier + "MinIOClient", + "get_minio_client", + # Embeddings + "get_embeddings", +] diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cache_manager.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cache_manager.py new file 
mode 100644 index 000000000..eae19d7f3 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cache_manager.py @@ -0,0 +1,467 @@ +""" +Cache manager for KeyDB hot tier. + +Provides high-level cache operations for: +- Live price feeds (TTL 60s) +- Semantic cache (TTL 3600s) +- Session memory (TTL 1800s) + +Key Patterns: +- prices:{ticker} - Live stock prices +- cache:{md5_hash} - Semantic query cache +- session:{session_id} - Agent conversation state +""" + +import hashlib +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Optional + +from app.storage.hot_storage.keydb_client import KeyDBClient, get_keydb_client + +_LOG = logging.getLogger(__name__) + + +# TTL constants (in seconds) +TTL_PRICES = 60 # 1 minute for live price feeds +TTL_CACHE = 3600 # 1 hour for semantic cache +TTL_SESSION = 1800 # 30 minutes for agent sessions + + +@dataclass +class PriceData: + """Live price data for a ticker.""" + + ticker: str + price: float + change: float + change_percent: float + volume: int + timestamp: datetime + + def to_dict(self) -> dict: + """Convert to dictionary.""" + return { + "ticker": self.ticker, + "price": self.price, + "change": self.change, + "change_percent": self.change_percent, + "volume": self.volume, + "timestamp": self.timestamp.isoformat(), + } + + @classmethod + def from_dict(cls, data: dict) -> "PriceData": + """Create from dictionary.""" + return cls( + ticker=data["ticker"], + price=data["price"], + change=data["change"], + change_percent=data["change_percent"], + volume=data["volume"], + timestamp=datetime.fromisoformat(data["timestamp"]), + ) + + +class CacheManager: + """ + High-level cache manager for hot tier operations. + + Provides typed methods for each cache category with appropriate TTLs. + """ + + def __init__(self, client: Optional[KeyDBClient] = None): + """ + Initialize cache manager. 
+ + Args: + client: KeyDB client (uses singleton if not provided) + """ + self.client = client or get_keydb_client() + + # ------------------------------------------------------------------------- + # Price Cache (TTL: 60s) + # ------------------------------------------------------------------------- + + def set_price(self, ticker: str, price_data: PriceData) -> bool: + """ + Cache live price data for a ticker. + + Args: + ticker: Stock ticker symbol + price_data: Price data object + + Returns: + True if successful + """ + key = f"prices:{ticker.upper()}" + return self.client.set(key, price_data.to_dict(), ttl=TTL_PRICES) + + def get_price(self, ticker: str) -> Optional[PriceData]: + """ + Get cached price data for a ticker. + + Args: + ticker: Stock ticker symbol + + Returns: + PriceData if cached and not expired, None otherwise + """ + key = f"prices:{ticker.upper()}" + data = self.client.get(key) + if data: + return PriceData.from_dict(data) + return None + + def get_prices_batch(self, tickers: list[str]) -> dict[str, PriceData]: + """ + Get cached prices for multiple tickers. + + Args: + tickers: List of ticker symbols + + Returns: + Dict mapping ticker to PriceData (only for cached tickers) + """ + result = {} + for ticker in tickers: + price = self.get_price(ticker) + if price: + result[ticker.upper()] = price + return result + + def clear_price(self, ticker: str) -> bool: + """ + Clear cached price for a ticker. + + Args: + ticker: Stock ticker symbol + + Returns: + True if key was deleted + """ + key = f"prices:{ticker.upper()}" + return self.client.delete(key) > 0 + + # ------------------------------------------------------------------------- + # Semantic Cache (TTL: 3600s) + # ------------------------------------------------------------------------- + + def _compute_cache_key(self, query: str, context: Optional[str] = None) -> str: + """ + Compute MD5 hash for cache key. 
+ + Args: + query: Search query string + context: Optional context string + + Returns: + MD5 hash string + """ + key_material = f"{query}:{context}" if context else query + return hashlib.md5(key_material.encode()).hexdigest() + + def get_semantic(self, query: str, context: Optional[str] = None) -> Optional[Any]: + """ + Get cached semantic search results. + + Args: + query: Search query + context: Optional context for more specific caching + + Returns: + Cached results if found, None otherwise + """ + key = f"cache:{self._compute_cache_key(query, context)}" + return self.client.get(key) + + def set_semantic( + self, + query: str, + results: Any, + context: Optional[str] = None, + ) -> bool: + """ + Cache semantic search results. + + Args: + query: Search query + results: Results to cache (will be JSON serialized) + context: Optional context for more specific caching + + Returns: + True if successful + """ + key = f"cache:{self._compute_cache_key(query, context)}" + return self.client.set(key, results, ttl=TTL_CACHE) + + def clear_semantic(self, query: str, context: Optional[str] = None) -> bool: + """ + Clear cached semantic results. + + Args: + query: Search query + context: Optional context + + Returns: + True if key was deleted + """ + key = f"cache:{self._compute_cache_key(query, context)}" + return self.client.delete(key) > 0 + + def clear_all_semantic(self) -> int: + """ + Clear all semantic cache entries. + + Returns: + Number of keys deleted + """ + count = 0 + for key in self.client.scan_iter("cache:*"): + if self.client.delete(key) > 0: + count += 1 + return count + + # ------------------------------------------------------------------------- + # Session Memory (TTL: 1800s) + # ------------------------------------------------------------------------- + + def create_session( + self, session_id: str, initial_data: Optional[dict] = None + ) -> bool: + """ + Create a new session. 
+ + Args: + session_id: Unique session identifier + initial_data: Optional initial session data + + Returns: + True if successful + """ + key = f"session:{session_id}" + data = { + "created_at": datetime.utcnow().isoformat(), + "last_accessed": datetime.utcnow().isoformat(), + "access_count": 0, + **(initial_data or {}), + } + return self.client.set(key, data, ttl=TTL_SESSION) + + def get_session(self, session_id: str) -> Optional[dict]: + """ + Get session data. + + Args: + session_id: Session identifier + + Returns: + Session data dict if found, None otherwise + """ + key = f"session:{session_id}" + data = self.client.get(key) + if data: + # Update access metadata + data["last_accessed"] = datetime.utcnow().isoformat() + data["access_count"] = data.get("access_count", 0) + 1 + self.client.set(key, data, ttl=TTL_SESSION) # Refresh TTL + return data + + def update_session(self, session_id: str, updates: dict) -> bool: + """ + Update session data. + + Args: + session_id: Session identifier + updates: Dict of fields to update + + Returns: + True if successful + """ + key = f"session:{session_id}" + data = self.client.get(key) + if data: + data.update(updates) + data["last_accessed"] = datetime.utcnow().isoformat() + return self.client.set(key, data, ttl=TTL_SESSION) + return False + + def add_to_session_history( + self, + session_id: str, + role: str, + content: str, + max_history: int = 20, + ) -> bool: + """ + Add a message to session conversation history. 
+ + Args: + session_id: Session identifier + role: Message role (user/assistant/system) + content: Message content + max_history: Maximum messages to keep in history + + Returns: + True if successful + """ + key = f"session:{session_id}" + data = self.client.get(key) + if not data: + return False + + if "history" not in data: + data["history"] = [] + + data["history"].append( + { + "role": role, + "content": content, + "timestamp": datetime.utcnow().isoformat(), + } + ) + + # Trim history if needed + if len(data["history"]) > max_history: + data["history"] = data["history"][-max_history:] + + data["last_accessed"] = datetime.utcnow().isoformat() + return self.client.set(key, data, ttl=TTL_SESSION) + + def delete_session(self, session_id: str) -> bool: + """ + Delete a session. + + Args: + session_id: Session identifier + + Returns: + True if session was deleted + """ + key = f"session:{session_id}" + return self.client.delete(key) > 0 + + def get_active_sessions(self) -> list[str]: + """ + Get list of active session IDs. + + Returns: + List of session IDs + """ + session_ids = [] + for key in self.client.scan_iter("session:*"): + session_ids.append(key.replace("session:", "")) + return session_ids + + # ------------------------------------------------------------------------- + # Generic Cache Operations + # ------------------------------------------------------------------------- + + def get(self, key: str, default: Any = None) -> Any: + """ + Get a value by key. + + Args: + key: Cache key + default: Default value if not found + + Returns: + Cached value or default + """ + return self.client.get(key, default) + + def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: + """ + Set a value with optional TTL. + + Args: + key: Cache key + value: Value to cache + ttl: Time-to-live in seconds + + Returns: + True if successful + """ + return self.client.set(key, value, ttl=ttl) + + def delete(self, *keys: str) -> int: + """ + Delete keys. 
+ + Args: + keys: Keys to delete + + Returns: + Number of keys deleted + """ + return self.client.delete(*keys) + + def exists(self, *keys: str) -> int: + """ + Check if keys exist. + + Args: + keys: Keys to check + + Returns: + Number of keys that exist + """ + return self.client.exists(*keys) + + def clear_all(self) -> int: + """ + Clear all cache entries (prices, semantic, sessions). + + Returns: + Number of keys deleted + """ + count = 0 + for pattern in ["prices:*", "cache:*", "session:*"]: + for key in self.client.scan_iter(pattern): + if self.client.delete(key) > 0: + count += 1 + return count + + def get_stats(self) -> dict: + """ + Get cache statistics. + + Returns: + Dict with counts for each cache type + """ + stats = { + "prices": 0, + "semantic": 0, + "sessions": 0, + } + + for key in self.client.scan_iter("prices:*"): + stats["prices"] += 1 + + for key in self.client.scan_iter("cache:*"): + stats["semantic"] += 1 + + for key in self.client.scan_iter("session:*"): + stats["sessions"] += 1 + + stats["total"] = sum(stats.values()) + return stats + + +# Global singleton instance +_cache_manager: Optional[CacheManager] = None + + +def get_cache_manager() -> CacheManager: + """ + Get the singleton CacheManager instance. + + Returns: + CacheManager instance + """ + global _cache_manager + if _cache_manager is None: + _cache_manager = CacheManager() + _LOG.info("CacheManager initialized") + return _cache_manager diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/__init__.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/__init__.py new file mode 100644 index 000000000..acc7a8e0a --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/__init__.py @@ -0,0 +1,14 @@ +""" +Cold storage for raw document archive. 
+ +MinIO (S3-compatible) object storage for: +- Raw SEC filings (original HTML/XML) +- News article HTML +""" + +from app.storage.cold_storage.minio_client import MinIOClient, get_minio_client + +__all__ = [ + "MinIOClient", + "get_minio_client", +] diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/minio_client.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/minio_client.py new file mode 100644 index 000000000..ced74a6db --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/minio_client.py @@ -0,0 +1,463 @@ +""" +MinIO client for cold tier storage. + +Provides object storage operations for archiving raw documents: +- SEC filings (original HTML/XML) +- News articles (HTML snapshots) + +Environment Variables: + MINIO_ENDPOINT: MinIO server endpoint (default: localhost:9000) + MINIO_ACCESS_KEY: Access key (default: minioadmin) + MINIO_SECRET_KEY: Secret key (default: minioadmin) + MINIO_SECURE: Use HTTPS (default: false) + +Bucket Structure: + sec/{ticker}/{filing_type}/{accession_number}.html + news/{ticker}/{date}/{article_id}.html +""" + +import hashlib +import io +import json +import logging +import os +from datetime import datetime +from typing import Any, Optional + +from minio import Minio +from minio.error import S3Error + +_LOG = logging.getLogger(__name__) + + +class MinIOClient: + """ + MinIO client for cold tier object storage. + + Manages bucket creation and object operations with automatic + retry and error handling. + """ + + def __init__( + self, + endpoint: Optional[str] = None, + access_key: Optional[str] = None, + secret_key: Optional[str] = None, + secure: bool = False, + ): + """ + Initialize MinIO client. 
+ + Args: + endpoint: MinIO server endpoint (default: MINIO_ENDPOINT env var or localhost:9000) + access_key: Access key (default: MINIO_ACCESS_KEY env var or minioadmin) + secret_key: Secret key (default: MINIO_SECRET_KEY env var or minioadmin) + secure: Use HTTPS (default: False) + """ + self.endpoint = endpoint or os.getenv("MINIO_ENDPOINT", "localhost:9000") + self.access_key = access_key or os.getenv("MINIO_ACCESS_KEY", "minioadmin") + self.secret_key = secret_key or os.getenv("MINIO_SECRET_KEY", "minioadmin") + self.secure = secure + + self._client: Optional[Minio] = None + self._buckets_created: set[str] = set() + + def _get_client(self) -> Minio: + """Get or create Minio client.""" + if self._client is None: + self._client = Minio( + self.endpoint, + access_key=self.access_key, + secret_key=self.secret_key, + secure=self.secure, + ) + _LOG.info("MinIO client created for endpoint: %s", self.endpoint) + return self._client + + def ping(self) -> bool: + """ + Test connection to MinIO server. + + Returns: + True if connection successful + """ + try: + client = self._get_client() + # List buckets to test connection + list(client.list_buckets()) + return True + except S3Error as e: + _LOG.error("MinIO connection failed: %s", e) + return False + + def create_bucket(self, bucket_name: str) -> bool: + """ + Create a bucket if it doesn't exist. 
+ + Args: + bucket_name: Name of the bucket to create + + Returns: + True if bucket exists or was created successfully + """ + client = self._get_client() + + if bucket_name in self._buckets_created: + return True + + try: + if not client.bucket_exists(bucket_name): + client.make_bucket(bucket_name) + _LOG.info("Bucket '%s' created", bucket_name) + self._buckets_created.add(bucket_name) + return True + except S3Error as e: + _LOG.error("Failed to create bucket '%s': %s", bucket_name, e) + return False + + def put_object( + self, + bucket: str, + object_name: str, + data: str | bytes, + content_type: str = "application/octet-stream", + metadata: Optional[dict[str, str]] = None, + ) -> Optional[str]: + """ + Upload an object to a bucket. + + Args: + bucket: Bucket name + object_name: Object path within bucket + data: Data to upload (string or bytes) + content_type: MIME type of the content + metadata: Optional metadata headers + + Returns: + ETag of the uploaded object, or None on failure + """ + client = self._get_client() + + # Ensure bucket exists + if not self.create_bucket(bucket): + return None + + # Convert string to bytes + if isinstance(data, str): + data = data.encode("utf-8") + + try: + # Wrap data in BytesIO for compatibility with minio API + data_stream = io.BytesIO(data) + result = client.put_object( + bucket, + object_name, + data_stream, + length=len(data), + content_type=content_type, + metadata=metadata, + ) + _LOG.debug("Uploaded '%s/%s' (%d bytes)", bucket, object_name, len(data)) + return result.etag + except S3Error as e: + _LOG.error("Failed to upload '%s/%s': %s", bucket, object_name, e) + return None + + def get_object(self, bucket: str, object_name: str) -> Optional[bytes]: + """ + Download an object from a bucket. 
+ + Args: + bucket: Bucket name + object_name: Object path within bucket + + Returns: + Object content as bytes, or None on failure + """ + client = self._get_client() + + try: + response = client.get_object(bucket, object_name) + data = response.read() + response.close() + response.release_conn() + return data + except S3Error as e: + _LOG.error("Failed to get '%s/%s': %s", bucket, object_name, e) + return None + + def get_object_as_string(self, bucket: str, object_name: str) -> Optional[str]: + """ + Download an object and return as string. + + Args: + bucket: Bucket name + object_name: Object path within bucket + + Returns: + Object content as string, or None on failure + """ + data = self.get_object(bucket, object_name) + if data: + return data.decode("utf-8") + return None + + def delete_object(self, bucket: str, object_name: str) -> bool: + """ + Delete an object from a bucket. + + Args: + bucket: Bucket name + object_name: Object path within bucket + + Returns: + True if deleted successfully + """ + client = self._get_client() + + try: + client.remove_object(bucket, object_name) + _LOG.debug("Deleted '%s/%s'", bucket, object_name) + return True + except S3Error as e: + _LOG.error("Failed to delete '%s/%s': %s", bucket, object_name, e) + return False + + def list_objects( + self, + bucket: str, + prefix: str = "", + recursive: bool = True, + ) -> list[str]: + """ + List objects in a bucket. + + Args: + bucket: Bucket name + prefix: Optional prefix to filter objects + recursive: Whether to list recursively + + Returns: + List of object names + """ + client = self._get_client() + objects = [] + + try: + for obj in client.list_objects(bucket, prefix=prefix, recursive=recursive): + objects.append(obj.object_name) + except S3Error as e: + _LOG.error("Failed to list objects in '%s': %s", bucket, e) + + return objects + + def object_exists(self, bucket: str, object_name: str) -> bool: + """ + Check if an object exists in a bucket. 
+ + Args: + bucket: Bucket name + object_name: Object path within bucket + + Returns: + True if object exists + """ + client = self._get_client() + + try: + client.stat_object(bucket, object_name) + return True + except S3Error as e: + if e.code == "NoSuchKey": + return False + _LOG.error("Failed to check object '%s/%s': %s", bucket, object_name, e) + return False + + def put_json( + self, + bucket: str, + object_name: str, + data: dict[str, Any], + ) -> Optional[str]: + """ + Upload JSON data to a bucket. + + Args: + bucket: Bucket name + object_name: Object path within bucket + data: Dictionary to serialize and upload + + Returns: + ETag of the uploaded object, or None on failure + """ + json_str = json.dumps(data, indent=2, default=str) + return self.put_object( + bucket, + object_name, + json_str, + content_type="application/json", + ) + + def get_json(self, bucket: str, object_name: str) -> Optional[dict[str, Any]]: + """ + Download and parse JSON from a bucket. + + Args: + bucket: Bucket name + object_name: Object path within bucket + + Returns: + Parsed JSON as dict, or None on failure + """ + json_str = self.get_object_as_string(bucket, object_name) + if json_str: + try: + return json.loads(json_str) + except json.JSONDecodeError as e: + _LOG.error( + "Failed to parse JSON from '%s/%s': %s", bucket, object_name, e + ) + return None + + # ------------------------------------------------------------------------- + # Convenience Methods for Document Storage + # ------------------------------------------------------------------------- + + def store_sec_filing( + self, + ticker: str, + filing_type: str, + accession_number: str, + content: str | bytes, + metadata: Optional[dict[str, Any]] = None, + ) -> Optional[str]: + """ + Store a SEC filing in cold storage. 
+ + Args: + ticker: Stock ticker symbol + filing_type: Form type (e.g., "10-K", "8-K") + accession_number: SEC accession number + content: Filing content (HTML/XML text or bytes) + metadata: Optional metadata to store alongside + + Returns: + Object path if successful + """ + # Normalize accession number (remove dashes and colons — both are invalid in S3 object names) + filename = accession_number.replace("-", "").replace(":", "_") + # Build path, skipping empty filing_type to avoid double slashes + if filing_type: + object_name = f"sec/{ticker}/{filing_type}/{filename}.html" + else: + object_name = f"sec/{ticker}/{filename}.html" + + etag = self.put_object( + "filings", + object_name, + content, + content_type="text/html", + ) + + if etag and metadata: + # Store metadata as separate JSON (skip if filing_type is empty to avoid double slashes) + if filing_type: + metadata_path = f"sec/{ticker}/{filing_type}/{filename}.meta.json" + else: + metadata_path = f"sec/{ticker}/{filename}.meta.json" + self.put_json("filings", metadata_path, metadata) + + return object_name if etag else None + + def store_earnings_transcript( + self, + ticker: str, + quarter_code: str, + content: str | bytes, + metadata: dict[str, Any], + ) -> Optional[str]: + """ + Store a raw earnings call transcript JSON in cold storage. + + :param ticker: stock ticker symbol + :param quarter_code: ``YYYYQN`` quarter identifier (e.g. ``2024Q1``) + :param content: full transcript text (joined speaker turns) + :param metadata: transcript metadata (speakers, fiscal period, etc.) 
+ :return: object path if stored, ``None`` otherwise + """ + object_name = f"earnings/{ticker}/{quarter_code}.txt" + etag = self.put_object( + "transcripts", + object_name, + content, + content_type="text/plain", + ) + if etag: + metadata_path = f"earnings/{ticker}/{quarter_code}.meta.json" + self.put_json("transcripts", metadata_path, metadata) + return object_name if etag else None + + def store_news_article( + self, + ticker: str, + url: str, + content: str, + metadata: dict[str, Any], + ) -> Optional[str]: + """ + Store a news article in cold storage. + + Args: + ticker: Stock ticker symbol + url: Article URL + content: Article HTML or text content + metadata: Article metadata + + Returns: + Object path if successful + """ + # Generate deterministic filename from URL + url_hash = hashlib.md5(url.encode()).hexdigest()[:12] + date = metadata.get("published_at", datetime.utcnow().isoformat())[:10] + object_name = f"news/{ticker}/{date}/{url_hash}.html" + + etag = self.put_object( + "articles", + object_name, + content, + content_type="text/html", + ) + + if etag: + # Store metadata as separate JSON + metadata_path = f"news/{ticker}/{date}/{url_hash}.meta.json" + self.put_json("articles", metadata_path, metadata) + + return object_name if etag else None + + def close(self) -> None: + """Close MinIO client connections.""" + self._client = None + self._buckets_created.clear() + _LOG.info("MinIO client closed") + + +# Global singleton instance +_minio_client: Optional[MinIOClient] = None + + +def get_minio_client() -> MinIOClient: + """ + Get the singleton MinIO client instance. 
+ + Returns: + MinIOClient instance + """ + global _minio_client + if _minio_client is None: + _minio_client = MinIOClient() + if _minio_client.ping(): + _LOG.info("Connected to MinIO at %s", _minio_client.endpoint) + else: + _LOG.warning("MinIO connection test failed") + return _minio_client diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/tests/__init__.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/tests/__init__.py new file mode 100644 index 000000000..71a5fff2d --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Tests for MinIO cold storage. +""" diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/tests/test_minio.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/tests/test_minio.py new file mode 100644 index 000000000..b96f1cee6 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/cold_storage/tests/test_minio.py @@ -0,0 +1,214 @@ +# storage/cold_storage/tests/test_minio.py +""" +Test script for MinIO cold tier infrastructure. 
import logging

from app.storage.cold_storage.minio_client import get_minio_client

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
_LOG = logging.getLogger(__name__)


def test_minio_client():
    """
    Exercise the MinIO client end to end against a live server.

    Connection, bucket creation and the first upload are prerequisites for
    every later step, so a failure there aborts immediately. All remaining
    steps always run, and each failure is counted so the final result
    reflects every check rather than only the first three.

    Returns:
        True only if every step passed
    """
    print("\n" + "=" * 60)
    print("Testing MinIO Client")
    print("=" * 60)

    client = get_minio_client()
    failures = 0

    def check(passed, ok_msg, fail_msg):
        """Print the outcome of one step and record a failure if needed."""
        nonlocal failures
        if passed:
            print(ok_msg)
            return True
        failures += 1
        print(fail_msg)
        return False

    # Test connection
    print("\n1. Testing connection...")
    if not check(
        client.ping(), " [OK] Connected to MinIO", " [FAIL] Connection failed"
    ):
        return False

    # Test bucket creation
    print("\n2. Testing bucket creation...")
    test_bucket = "test-bucket"
    if not check(
        client.create_bucket(test_bucket),
        f" [OK] Bucket '{test_bucket}' created/exists",
        f" [FAIL] Failed to create bucket '{test_bucket}'",
    ):
        return False

    # Test put_object with string
    print("\n3. Testing put_object (string)...")
    test_content = "Hello, MinIO! This is a test document."
    test_object = "test/documents/hello.txt"
    etag = client.put_object(test_bucket, test_object, test_content)
    if not check(
        etag,
        f" [OK] Object uploaded with ETag: {etag}",
        " [FAIL] Failed to upload object",
    ):
        return False

    # Test put_object with bytes
    print("\n4. Testing put_object (bytes)...")
    test_bytes = b"Binary content test"
    test_object_bytes = "test/documents/binary.bin"
    etag_bytes = client.put_object(test_bucket, test_object_bytes, test_bytes)
    check(
        etag_bytes,
        f" [OK] Binary object uploaded with ETag: {etag_bytes}",
        " [FAIL] Failed to upload binary object",
    )

    # Test get_object
    print("\n5. Testing get_object...")
    retrieved = client.get_object(test_bucket, test_object)
    # Decode once, and only when something came back, so a missing object is
    # reported as a failure instead of raising on None.
    retrieved_text = retrieved.decode("utf-8") if retrieved else None
    check(
        retrieved_text == test_content,
        f" [OK] Object retrieved: '{retrieved_text}'",
        " [FAIL] Retrieved content mismatch",
    )

    # Test get_object_as_string
    print("\n6. Testing get_object_as_string...")
    retrieved_str = client.get_object_as_string(test_bucket, test_object)
    check(
        retrieved_str == test_content,
        " [OK] String retrieval works",
        " [FAIL] String retrieval failed",
    )

    # Test object_exists
    print("\n7. Testing object_exists...")
    check(
        client.object_exists(test_bucket, test_object),
        " [OK] Object exists check works",
        " [FAIL] Object exists check failed",
    )

    # Test list_objects
    print("\n8. Testing list_objects...")
    objects = client.list_objects(test_bucket, prefix="test/")
    check(
        len(objects) >= 2,
        f" [OK] Listed {len(objects)} objects: {objects}",
        f" [FAIL] Expected at least 2 objects, got {len(objects)}",
    )

    # Test put_json and get_json
    print("\n9. Testing JSON operations...")
    test_json = {
        "ticker": "AAPL",
        "company": "Apple Inc.",
        "filing_type": "10-K",
        "filing_date": "2024-01-15",
        "metrics": {
            "revenue": 383285000000,
            "net_income": 96995000000,
        },
    }
    json_object = "test/data/filing.json"
    json_etag = client.put_json(test_bucket, json_object, test_json)
    if check(
        json_etag,
        f" [OK] JSON uploaded with ETag: {json_etag}",
        " [FAIL] JSON upload failed",
    ):
        retrieved_json = client.get_json(test_bucket, json_object)
        check(
            retrieved_json == test_json,
            f" [OK] JSON retrieval works: {retrieved_json}",
            " [FAIL] JSON retrieval mismatch",
        )

    # Test convenience method: store_sec_filing
    print("\n10. Testing store_sec_filing...")
    sec_content = """
    <html>
    <head><title>Apple Inc. 10-K</title></head>
    <body>
    <h1>UNITED STATES SECURITIES AND EXCHANGE COMMISSION</h1>
    <h2>Form 10-K</h2>
    <p>Apple Inc. reported record revenue for fiscal year 2024.</p>
    </body>
    </html>
    """
    sec_metadata = {
        "company_name": "Apple Inc.",
        "cik": "0000320193",
        "filing_date": "2024-01-15",
    }
    sec_path = client.store_sec_filing(
        ticker="AAPL",
        filing_type="10-K",
        accession_number="0000320193-24-000001",
        content=sec_content,
        metadata=sec_metadata,
    )
    if check(
        sec_path,
        f" [OK] SEC filing stored at: {sec_path}",
        " [FAIL] SEC filing storage failed",
    ):
        # Verify we can retrieve it
        retrieved_sec = client.get_object_as_string("filings", sec_path)
        check(
            retrieved_sec and "Apple Inc." in retrieved_sec,
            " [OK] SEC filing retrieval works",
            " [FAIL] SEC filing retrieval failed",
        )

    # Test convenience method: store_news_article
    print("\n11. Testing store_news_article...")
    news_content = (
        "<html><body><h1>Apple Stock Rises on Strong Earnings</h1></body></html>"
    )
    news_metadata = {
        "title": "Apple Stock Rises on Strong Earnings",
        "source": "Reuters",
        "published_at": "2024-01-16T10:00:00Z",
        "url": "https://reuters.com/test-article",
    }
    news_path = client.store_news_article(
        ticker="AAPL",
        url="https://reuters.com/test-article",
        content=news_content,
        metadata=news_metadata,
    )
    check(
        news_path,
        f" [OK] News article stored at: {news_path}",
        " [FAIL] News article storage failed",
    )

    # Test delete_object
    print("\n12. Testing delete_object...")
    if check(
        client.delete_object(test_bucket, test_object),
        " [OK] Object deleted",
        " [FAIL] Delete failed",
    ):
        # Verify deletion
        check(
            not client.object_exists(test_bucket, test_object),
            " [OK] Object confirmed deleted",
            " [FAIL] Object still exists after deletion",
        )

    # Test close
    print("\n13. Testing close...")
    client.close()
    print(" [OK] Client closed")

    return failures == 0


def main():
    """
    Run all tests and print a summary banner.

    Returns:
        True if every test passed, False otherwise
    """
    print("\n" + "=" * 60)
    print("MinIO Cold Tier Infrastructure Tests")
    print("=" * 60)

    success = test_minio_client()

    print("\n" + "=" * 60)
    if success:
        print("All tests passed!")
    else:
        print("Some tests failed!")
    print("=" * 60 + "\n")
    return success


if __name__ == "__main__":
    # Propagate the outcome as the process exit status so shells and CI can
    # detect a failed run.
    raise SystemExit(0 if main() else 1)
class KeyDBClient:
    """
    KeyDB client with connection pooling and authentication.

    KeyDB is Redis-compatible, so the redis-py client is used. The pool and
    the client object are created lazily on first use. Every command wrapper
    logs ``redis.RedisError`` and returns a neutral fallback value instead of
    raising.
    """

    def __init__(
        self,
        host: Optional[str] = os.getenv("KEYDB_HOST", "localhost"),
        port: Optional[int] = int(os.getenv("KEYDB_PORT", "6379")),
        password: Optional[str] = os.getenv("KEYDB_PASSWORD"),
        db: int = 0,
        socket_timeout: float = 5.0,
        socket_connect_timeout: float = 5.0,
        max_connections: int = 10,
    ):
        """
        Store connection settings; no network I/O happens here.

        The environment-based defaults are read once, when the class is
        defined, not on every instantiation.

        Args:
            host: KeyDB server host (default: localhost or KEYDB_HOST env var)
            port: KeyDB server port (default: 6379 or KEYDB_PORT env var)
            password: Authentication password (default: KEYDB_PASSWORD env var)
            db: Database number (default: 0)
            socket_timeout: Socket timeout in seconds
            socket_connect_timeout: Connection timeout in seconds
            max_connections: Maximum connections in pool
        """
        self.host = host
        self.port = port
        self.password = password
        self.db = db

        self._pool: Optional["redis.ConnectionPool"] = None
        self._client: Optional["redis.Redis"] = None

        self._pool_kwargs = {
            "host": self.host,
            "port": self.port,
            "password": self.password,
            "db": self.db,
            "socket_timeout": socket_timeout,
            "socket_connect_timeout": socket_connect_timeout,
            "max_connections": max_connections,
            # Responses come back as str rather than bytes.
            "decode_responses": True,
        }

    def _get_pool(self) -> "redis.ConnectionPool":
        """Get or create connection pool."""
        if self._pool is None:
            self._pool = redis.ConnectionPool(**self._pool_kwargs)
        return self._pool

    def _get_client(self) -> "redis.Redis":
        """Get or create Redis client."""
        if self._client is None:
            self._client = redis.Redis(connection_pool=self._get_pool())
        return self._client

    def ping(self) -> bool:
        """
        Test connection to KeyDB server.

        Returns:
            True if connection successful, False on any redis error
        """
        try:
            return self._get_client().ping()
        except redis.RedisError as e:
            # RedisError (not just ConnectionError) so that socket timeouts,
            # which redis-py raises as a sibling exception type, are also
            # reported as a failed ping instead of propagating.
            _LOG.error("KeyDB connection failed: %s", e)
            return False

    def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None,
    ) -> bool:
        """
        Set a key-value pair with optional TTL.

        Args:
            key: Redis key
            value: Value to store (auto-serialized to JSON if not string)
            ttl: Time-to-live in seconds; a falsy value (None or 0) stores
                the key without expiry

        Returns:
            True if successful
        """
        client = self._get_client()

        # Serialize non-string values to JSON
        if not isinstance(value, str):
            value = json.dumps(value)

        try:
            if ttl:
                return client.setex(key, ttl, value)
            return client.set(key, value)
        except redis.RedisError as e:
            _LOG.error("KeyDB set error for key='%s': %s", key, e)
            return False

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a value by key.

        Any stored text that parses as JSON is returned decoded, so a plain
        string such as "123" comes back as the integer 123.

        Args:
            key: Redis key
            default: Default value if key not found

        Returns:
            Value (auto-deserialized from JSON) or default
        """
        client = self._get_client()

        try:
            value = client.get(key)
        except redis.RedisError as e:
            _LOG.error("KeyDB get error for key='%s': %s", key, e)
            return default

        if value is None:
            return default

        # Try to deserialize JSON; fall back to the raw value
        try:
            return json.loads(value)
        except (json.JSONDecodeError, TypeError):
            return value

    def delete(self, *keys: str) -> int:
        """
        Delete one or more keys.

        Args:
            keys: Keys to delete

        Returns:
            Number of keys deleted (0 on error)
        """
        client = self._get_client()

        try:
            return client.delete(*keys)
        except redis.RedisError as e:
            _LOG.error("KeyDB delete error: %s", e)
            return 0

    def exists(self, *keys: str) -> int:
        """
        Check if keys exist.

        Args:
            keys: Keys to check

        Returns:
            Number of keys that exist (0 on error)
        """
        client = self._get_client()

        try:
            return client.exists(*keys)
        except redis.RedisError as e:
            _LOG.error("KeyDB exists error: %s", e)
            return 0

    def ttl(self, key: str) -> int:
        """
        Get TTL for a key.

        Args:
            key: Redis key

        Returns:
            TTL in seconds, -1 if no TTL, -2 if key doesn't exist or the
            lookup failed
        """
        client = self._get_client()

        try:
            return client.ttl(key)
        except redis.RedisError as e:
            _LOG.error("KeyDB ttl error for key='%s': %s", key, e)
            return -2

    def expire(self, key: str, ttl: int) -> bool:
        """
        Set TTL on an existing key.

        Args:
            key: Redis key
            ttl: Time-to-live in seconds

        Returns:
            True if TTL was set
        """
        client = self._get_client()

        try:
            return client.expire(key, ttl)
        except redis.RedisError as e:
            _LOG.error("KeyDB expire error for key='%s': %s", key, e)
            return False

    def incr(self, key: str, amount: int = 1) -> Optional[int]:
        """
        Increment a key atomically.

        Args:
            key: Redis key
            amount: Amount to increment by

        Returns:
            New value after increment, or None on error
        """
        client = self._get_client()

        try:
            return client.incr(key, amount)
        except redis.RedisError as e:
            _LOG.error("KeyDB incr error for key='%s': %s", key, e)
            return None

    def hset(
        self,
        name: str,
        key: Optional[str] = None,
        value: Any = None,
        mapping: Optional[dict] = None,
    ) -> int:
        """
        Set hash field(s).

        Non-string values are JSON-serialized, both for the single
        ``key``/``value`` pair and for every value in ``mapping``, so nested
        structures can be stored through either form.

        Args:
            name: Hash name
            key: Field key (optional if mapping provided)
            value: Field value (optional if mapping provided)
            mapping: Dict of field-value pairs (optional)

        Returns:
            Number of fields set (0 on error)
        """
        client = self._get_client()

        # Serialize value if provided
        if key is not None and value is not None and not isinstance(value, str):
            value = json.dumps(value)

        try:
            if mapping is not None:
                # Build a new dict so the caller's mapping is not mutated.
                # str/bytes pass through unchanged.
                mapping = {
                    field: item
                    if isinstance(item, (str, bytes))
                    else json.dumps(item)
                    for field, item in mapping.items()
                }
            return client.hset(name, key, value, mapping=mapping)
        except (TypeError, ValueError) as e:
            # A mapping value that JSON cannot encode is reported like any
            # other rejected write.
            _LOG.error("KeyDB hset error for hash='%s': %s", name, e)
            return 0
        except redis.RedisError as e:
            _LOG.error("KeyDB hset error for hash='%s': %s", name, e)
            return 0

    def hgetall(self, name: str) -> dict:
        """
        Get all fields from a hash.

        Values are returned exactly as stored; unlike ``hget`` no JSON
        decoding is attempted.

        Args:
            name: Hash name

        Returns:
            Dict of field-value pairs (empty on error)
        """
        client = self._get_client()

        try:
            return client.hgetall(name)
        except redis.RedisError as e:
            _LOG.error("KeyDB hgetall error for hash='%s': %s", name, e)
            return {}

    def hget(self, name: str, key: str, default: Any = None) -> Any:
        """
        Get a single field from a hash.

        Args:
            name: Hash name
            key: Field key
            default: Default value if field not found

        Returns:
            Field value (auto-deserialized) or default
        """
        client = self._get_client()

        try:
            value = client.hget(name, key)
        except redis.RedisError as e:
            _LOG.error("KeyDB hget error for hash='%s', key='%s': %s", name, key, e)
            return default

        if value is None:
            return default

        # Try to deserialize JSON; fall back to the raw value
        try:
            return json.loads(value)
        except (json.JSONDecodeError, TypeError):
            return value

    def scan_iter(self, match: str, count: int = 100):
        """
        Scan keys matching a pattern.

        This is a generator: nothing is sent to the server until iteration
        starts, and a redis error ends the iteration early after logging.

        Args:
            match: Pattern to match (e.g., "prices:*")
            count: Hint for number of keys per iteration

        Yields:
            Matching keys
        """
        client = self._get_client()

        try:
            yield from client.scan_iter(match=match, count=count)
        except redis.RedisError as e:
            _LOG.error("KeyDB scan_iter error for pattern='%s': %s", match, e)

    def flushdb(self) -> bool:
        """
        Flush current database (dangerous - use with caution).

        Returns:
            True if successful
        """
        client = self._get_client()

        try:
            return client.flushdb()
        except redis.RedisError as e:
            _LOG.error("KeyDB flushdb error: %s", e)
            return False

    def close(self) -> None:
        """Close all connections in the pool and drop the cached client."""
        if self._pool:
            self._pool.disconnect()
            self._pool = None
            self._client = None
            _LOG.info("KeyDB connections closed")


# Global singleton instance
_keydb_client: Optional[KeyDBClient] = None


def get_keydb_client() -> KeyDBClient:
    """
    Get the singleton KeyDB client instance.

    Creates a new instance on first call, returns existing instance thereafter.
    The first call also pings the server and logs the outcome; a failed ping
    is only logged and the client is still returned.

    Returns:
        KeyDBClient instance
    """
    global _keydb_client
    if _keydb_client is None:
        _keydb_client = KeyDBClient()
        if _keydb_client.ping():
            _LOG.info(
                "Connected to KeyDB at %s:%d", _keydb_client.host, _keydb_client.port
            )
        else:
            _LOG.warning("KeyDB connection test failed")
    return _keydb_client
"""
Test script for KeyDB hot tier infrastructure.

Usage:
    python -m app.storage.hot_storage.tests.test_keydb
"""

import logging
from datetime import datetime

# The client module lives in the hot_storage package.
from app.storage.hot_storage.keydb_client import get_keydb_client

# NOTE(review): cache_manager is not visible from this file; its import path
# is left unchanged - confirm it matches the module's actual location.
from app.storage.cache_manager import (
    get_cache_manager,
    PriceData,
    TTL_PRICES,
    TTL_CACHE,
    TTL_SESSION,
)

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
_LOG = logging.getLogger(__name__)


def _check(passed, ok_msg, fail_msg):
    """Print the outcome of one step and return it as a bool."""
    print(ok_msg if passed else fail_msg)
    return bool(passed)


def test_keydb_client():
    """
    Test basic KeyDB client operations against a live server.

    A failed connection aborts immediately; every other step always runs and
    its outcome is recorded.

    Returns:
        True only if every step passed
    """
    print("\n" + "=" * 60)
    print("Testing KeyDB Client")
    print("=" * 60)

    client = get_keydb_client()
    outcomes = []

    # Test connection
    print("\n1. Testing connection...")
    if not _check(
        client.ping(), " [OK] Connected to KeyDB", " [FAIL] Connection failed"
    ):
        return False

    # Test set/get
    print("\n2. Testing set/get...")
    test_key = "test:string"
    test_value = "Hello, KeyDB!"
    client.set(test_key, test_value, ttl=60)
    result = client.get(test_key)
    outcomes.append(
        _check(
            result == test_value,
            f" [OK] String set/get: '{result}'",
            f" [FAIL] Expected '{test_value}', got '{result}'",
        )
    )

    # Test JSON serialization
    print("\n3. Testing JSON serialization...")
    test_dict = {"ticker": "AAPL", "price": 175.50, "volume": 1000000}
    client.set("test:json", test_dict, ttl=60)
    result = client.get("test:json")
    outcomes.append(
        _check(
            result == test_dict,
            f" [OK] JSON set/get: {result}",
            " [FAIL] JSON mismatch",
        )
    )

    # Test hash operations
    print("\n4. Testing hash operations...")
    client.hset("test:hash", "field1", "value1")
    client.hset("test:hash", "field2", {"nested": "data"})
    all_fields = client.hgetall("test:hash")
    outcomes.append(
        _check(
            all_fields.get("field1") == "value1" and "field2" in all_fields,
            f" [OK] Hash fields: {all_fields}",
            f" [FAIL] Unexpected hash fields: {all_fields}",
        )
    )

    # Test TTL
    print("\n5. Testing TTL...")
    client.set("test:ttl", "expires soon", ttl=10)
    ttl = client.ttl("test:ttl")
    outcomes.append(
        _check(
            0 < ttl <= 10,
            f" [OK] TTL set correctly: {ttl}s remaining",
            f" [FAIL] TTL incorrect: {ttl}",
        )
    )

    # Test delete
    print("\n6. Testing delete...")
    deleted = client.delete("test:string", "test:json", "test:hash", "test:ttl")
    outcomes.append(
        _check(
            deleted == 4,
            f" [OK] Deleted {deleted} keys",
            f" [FAIL] Expected 4 keys deleted, got {deleted}",
        )
    )

    return all(outcomes)


def test_cache_manager():
    """
    Test CacheManager operations against a live server.

    Returns:
        True only if every verified step passed
    """
    print("\n" + "=" * 60)
    print("Testing Cache Manager")
    print("=" * 60)

    cache = get_cache_manager()
    outcomes = []

    # Test price caching
    print("\n1. Testing price cache...")
    price_data = PriceData(
        ticker="AAPL",
        price=175.50,
        change=2.25,
        change_percent=1.30,
        volume=52000000,
        timestamp=datetime.utcnow(),
    )
    cache.set_price("AAPL", price_data)
    cached_price = cache.get_price("AAPL")
    if cached_price and cached_price.ticker == "AAPL":
        print(
            f" [OK] Price cached: ${cached_price.price} ({cached_price.change_percent}%)"
        )
        outcomes.append(True)
    else:
        print(" [FAIL] Price not cached correctly")
        outcomes.append(False)

    # Test batch price retrieval (informational: result is printed, not judged)
    print("\n2. Testing batch price retrieval...")
    for ticker in ["GOOGL", "MSFT", "AMZN"]:
        cache.set_price(
            ticker,
            PriceData(
                ticker=ticker,
                price=100.0,
                change=1.0,
                change_percent=1.0,
                volume=1000000,
                timestamp=datetime.utcnow(),
            ),
        )
    batch = cache.get_prices_batch(["AAPL", "GOOGL", "MSFT", "INVALID"])
    print(f" [OK] Retrieved {len(batch)} prices: {list(batch.keys())}")

    # Test semantic cache
    print("\n3. Testing semantic cache...")
    query = "What is Apple's revenue?"
    mock_results = [
        {"text": "Apple reported revenue of $89.5B", "score": 0.95},
        {"text": "Q4 revenue beat expectations", "score": 0.87},
    ]
    cache.set_semantic(query, mock_results)
    cached_results = cache.get_semantic(query)
    if cached_results and len(cached_results) == 2:
        print(f" [OK] Semantic cache: {len(cached_results)} results")
        outcomes.append(True)
    else:
        print(" [FAIL] Semantic cache failed")
        outcomes.append(False)

    # Test semantic cache with context
    print("\n4. Testing semantic cache with context...")
    cache.set_semantic("revenue query", {"data": "result1"}, context="AAPL")
    cache.set_semantic("revenue query", {"data": "result2"}, context="GOOGL")
    result_aapl = cache.get_semantic("revenue query", context="AAPL")
    result_googl = cache.get_semantic("revenue query", context="GOOGL")
    outcomes.append(
        _check(
            result_aapl != result_googl,
            " [OK] Context-specific caching works",
            " [FAIL] Context not differentiating results",
        )
    )

    # Test session management
    print("\n5. Testing session management...")
    session_id = "test_session_123"
    cache.create_session(session_id, {"user": "test_user", "ticker": "AAPL"})
    session = cache.get_session(session_id)
    if session and session.get("user") == "test_user":
        print(f" [OK] Session created for user: {session['user']}")
        outcomes.append(True)
    else:
        print(" [FAIL] Session not created correctly")
        outcomes.append(False)

    # Test session history
    print("\n6. Testing session conversation history...")
    cache.add_to_session_history(session_id, "user", "What's the stock price?")
    cache.add_to_session_history(session_id, "assistant", "AAPL is at $175.50")
    session = cache.get_session(session_id)
    # Guard against a missing session so it is reported as a failure rather
    # than raising AttributeError on None.
    history = session.get("history") if session else None
    if history and len(history) == 2:
        print(f" [OK] History has {len(history)} messages")
        outcomes.append(True)
    else:
        print(" [FAIL] Session history not working")
        outcomes.append(False)

    # Test session update
    print("\n7. Testing session update...")
    cache.update_session(session_id, {"ticker": "GOOGL", "last_query": "revenue"})
    session = cache.get_session(session_id)
    outcomes.append(
        _check(
            session and session.get("ticker") == "GOOGL",
            " [OK] Session updated successfully",
            " [FAIL] Session update failed",
        )
    )

    # Test cache stats (informational)
    print("\n8. Getting cache statistics...")
    stats = cache.get_stats()
    print(f" [OK] Cache stats: {stats}")

    # Cleanup test sessions
    print("\n9. Cleaning up test data...")
    cache.delete_session(session_id)
    cache.clear_all_semantic()
    for ticker in ["AAPL", "GOOGL", "MSFT", "AMZN"]:
        cache.clear_price(ticker)
    print(" [OK] Cleanup complete")

    return all(outcomes)


def main():
    """
    Run all tests and print a summary banner.

    The cache manager tests are skipped when the basic client tests fail.

    Returns:
        True if every test passed, False otherwise
    """
    print("\n" + "=" * 60)
    print("KeyDB Hot Tier Infrastructure Tests")
    print("=" * 60)
    print("\nTTL Configuration:")
    print(f" - Prices: {TTL_PRICES}s (1 minute)")
    print(f" - Semantic: {TTL_CACHE}s (1 hour)")
    print(f" - Sessions: {TTL_SESSION}s (30 minutes)")

    success = True

    if not test_keydb_client():
        print("\n[ABORT] KeyDB client tests failed - skipping cache manager tests")
        success = False
    elif not test_cache_manager():
        success = False

    print("\n" + "=" * 60)
    if success:
        print("All tests passed!")
    else:
        print("Some tests failed!")
    print("=" * 60 + "\n")
    return success


if __name__ == "__main__":
    # Propagate the outcome as the process exit status so shells and CI can
    # detect a failed run.
    raise SystemExit(0 if main() else 1)
b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/warm_storage/__init__.py new file mode 100644 index 000000000..20cadeaab --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/warm_storage/__init__.py @@ -0,0 +1,44 @@ +""" +Warm storage layer using PostgreSQL + pgvector. + +This package provides persistent storage for: +- SEC filings metadata +- Document chunks with vector embeddings +- XBRL facts (structured financial data) +- Semantic search via pgvector + +Modules: +- pgvector_client: Low-level PostgreSQL client with connection pooling +- filings_manager: High-level manager for filings, chunks, and XBRL facts +""" + +from app.storage.warm_storage.pgvector_client import PostgresClient, get_postgres_client +from app.storage.warm_storage.filings_manager import ( + FilingsManager, + get_filings_manager, + FilingData, + ChunkData, + XBRLFact, + SearchResults, + generate_filing_id, + generate_chunk_id, + generate_xbrl_fact_id, +) + +__all__ = [ + # Client + "PostgresClient", + "get_postgres_client", + # Manager + "FilingsManager", + "get_filings_manager", + # Data classes + "FilingData", + "ChunkData", + "XBRLFact", + "SearchResults", + # ID generators + "generate_filing_id", + "generate_chunk_id", + "generate_xbrl_fact_id", +] diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/warm_storage/filings_manager.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/warm_storage/filings_manager.py new file mode 100644 index 000000000..9960b8aca --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/storage/warm_storage/filings_manager.py @@ -0,0 +1,600 @@ +""" +High-level manager for SEC filings, chunks, and XBRL facts. 
+ +Provides typed methods for: +- Storing and retrieving SEC filings (10-K, 10-Q, 8-K, DEF 14A) +- Managing document chunks with vector embeddings +- Storing and querying XBRL facts +- Semantic search across filing content + +This is the warm tier storage - data persists beyond the hot KeyDB cache. +""" + +import hashlib +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, Optional + +import psycopg2 +from dotenv import load_dotenv + +from app.storage.warm_storage.pgvector_client import PostgresClient, get_postgres_client + +load_dotenv() + +_LOG = logging.getLogger(__name__) + + +# ============================================================================= +# Data Classes +# ============================================================================= + + +@dataclass +class FilingData: + """SEC filing metadata.""" + + id: str + ticker: str + company_name: str + filing_type: str # 10-K, 10-Q, 8-K, DEF 14A + cik: str + accession_number: str + filing_date: datetime + period_of_report: Optional[datetime] = None + document_url: Optional[str] = None + file_size_bytes: Optional[int] = None + + def to_dict(self) -> dict: + """Convert to dictionary for database insertion.""" + return { + "id": self.id, + "ticker": self.ticker, + "company_name": self.company_name, + "filing_type": self.filing_type, + "cik": self.cik, + "accession_number": self.accession_number, + "filing_date": self.filing_date.date() if self.filing_date else None, + "period_of_report": self.period_of_report.date() + if self.period_of_report + else None, + "document_url": self.document_url, + "file_size_bytes": self.file_size_bytes, + } + + +@dataclass +class ChunkData: + """Document chunk with embedding.""" + + id: str + filing_id: str + chunk_index: int + text: str + section: Optional[str] = None + embedding: Optional[List[float]] = None + token_count: Optional[int] = None + + def to_dict(self) -> dict: + """Convert to dictionary for 
database insertion.""" + return { + "id": self.id, + "filing_id": self.filing_id, + "chunk_index": self.chunk_index, + "text": self.text, + "section": self.section, + "embedding": self.embedding, + "token_count": self.token_count, + } + + +@dataclass +class XBRLFact: + """XBRL fact from SEC filing.""" + + id: str + filing_id: str + concept_name: str + value: str + value_numeric: Optional[float] = None + unit: Optional[str] = None + period_start: Optional[datetime] = None + period_end: Optional[datetime] = None + instant_date: Optional[datetime] = None + axis: Optional[str] = None + member: Optional[str] = None + + def to_dict(self) -> dict: + """Convert to dictionary for database insertion.""" + return { + "id": self.id, + "filing_id": self.filing_id, + "concept_name": self.concept_name, + "value": self.value, + "value_numeric": self.value_numeric, + "unit": self.unit, + "period_start": self.period_start.date() if self.period_start else None, + "period_end": self.period_end.date() if self.period_end else None, + "instant_date": self.instant_date.date() if self.instant_date else None, + "axis": self.axis, + "member": self.member, + } + + +@dataclass +class SearchResults: + """Semantic search results.""" + + chunk_id: str + filing_id: str + text: str + section: Optional[str] + similarity: float + ticker: str + filing_type: str + filing_date: datetime + company_name: Optional[str] + + @classmethod + def from_db(cls, row: Dict[str, Any]) -> "SearchResults": + """Create from database row.""" + return cls( + chunk_id=row["chunk_id"], + filing_id=row["filing_id"], + text=row["text"], + section=row.get("section"), + similarity=row["similarity"], + ticker=row["ticker"], + filing_type=row["filing_type"], + filing_date=row["filing_date"], + company_name=row.get("company_name"), + ) + + +# ============================================================================= +# FilingsManager +# ============================================================================= + + +class 
FilingsManager: + """ + High-level manager for warm tier storage operations. + + Provides typed methods for storing and retrieving: + - SEC filings metadata + - Document chunks with embeddings + - XBRL facts + - Semantic search results + """ + + def __init__(self, client: Optional[PostgresClient] = None): + """ + Initialize filings manager. + + Args: + client: PostgreSQL client (uses singleton if not provided) + """ + self.client = client or get_postgres_client() + + # ------------------------------------------------------------------------- + # Filing Operations + # ------------------------------------------------------------------------- + + def store_filing(self, filing: FilingData) -> bool: + """ + Store an SEC filing in the database. + + Args: + filing: Filing data object + + Returns: + True if successful + """ + filing_id = self.client.insert_filing(filing.to_dict()) + if filing_id: + _LOG.info("Stored filing %s for %s", filing_id, filing.ticker) + return True + return False + + def get_filing(self, filing_id: str) -> Optional[FilingData]: + """ + Retrieve a filing by ID. + + Args: + filing_id: Filing identifier + + Returns: + FilingData if found, None otherwise + """ + data = self.client.get_filing(filing_id) + if not data: + return None + + return FilingData( + id=data["id"], + ticker=data["ticker"], + company_name=data.get("company_name"), + filing_type=data["filing_type"], + cik=data.get("cik"), + accession_number=data.get("accession_number"), + filing_date=data["filing_date"], + period_of_report=data.get("period_of_report"), + document_url=data.get("document_url"), + file_size_bytes=data.get("file_size_bytes"), + ) + + def get_filings_for_ticker( + self, + ticker: str, + filing_types: Optional[List[str]] = None, + limit: int = 10, + ) -> List[FilingData]: + """ + Get filings for a specific ticker. 
+ + Args: + ticker: Stock ticker symbol + filing_types: Optional list of filing types to filter + limit: Maximum results to return + + Returns: + List of FilingData objects + """ + filings = self.client.get_filings_by_ticker(ticker, filing_types, limit) + result = [] + + for f in filings: + result.append( + FilingData( + id=f["id"], + ticker=f["ticker"], + company_name=f.get("company_name"), + filing_type=f["filing_type"], + cik=f.get("cik"), + accession_number=f.get("accession_number"), + filing_date=f["filing_date"], + period_of_report=f.get("period_of_report"), + document_url=f.get("document_url"), + file_size_bytes=f.get("file_size_bytes"), + ) + ) + + return result + + def delete_filing(self, filing_id: str) -> bool: + """ + Delete a filing and all associated chunks/facts. + + Args: + filing_id: Filing to delete + + Returns: + True if successful + """ + return self.client.delete_filing(filing_id) + + # ------------------------------------------------------------------------- + # Chunk Operations + # ------------------------------------------------------------------------- + + def store_chunks(self, chunks: List[ChunkData]) -> int: + """ + Store document chunks with embeddings. + + Args: + chunks: List of chunk data objects + + Returns: + Number of chunks stored + """ + chunk_dicts = [c.to_dict() for c in chunks] + count = self.client.insert_chunks(chunk_dicts) + _LOG.info("Stored %d chunks", count) + return count + + def get_chunks_for_filing( + self, + filing_id: str, + include_embedding: bool = False, + ) -> List[ChunkData]: + """ + Get all chunks for a filing. 
+ + Args: + filing_id: Filing identifier + include_embedding: Whether to include embedding vectors + + Returns: + List of ChunkData objects + """ + chunks = self.client.get_chunks_by_filing(filing_id, include_embedding) + result = [] + + for c in chunks: + result.append( + ChunkData( + id=c["id"], + filing_id=c["filing_id"], + chunk_index=c["chunk_index"], + text=c["text"], + section=c.get("section"), + embedding=c.get("embedding"), + token_count=c.get("token_count"), + ) + ) + + return result + + # ------------------------------------------------------------------------- + # XBRL Facts Operations + # ------------------------------------------------------------------------- + + def store_xbrl_facts(self, facts: List[XBRLFact]) -> int: + """ + Store XBRL facts from a filing. + + Args: + facts: List of XBRL fact objects + + Returns: + Number of facts stored + """ + fact_dicts = [f.to_dict() for f in facts] + count = self.client.insert_xbrl_facts(fact_dicts) + _LOG.info("Stored %d XBRL facts", count) + return count + + def get_xbrl_facts( + self, + filing_id: str, + concepts: Optional[List[str]] = None, + ) -> List[XBRLFact]: + """ + Get XBRL facts for a filing. 
+ + Args: + filing_id: Filing identifier + concepts: Optional list of specific concepts to retrieve + + Returns: + List of XBRLFact objects + """ + if concepts: + facts = self.client.get_xbrl_facts_by_concept(filing_id, concepts) + else: + # Get all facts for filing + query = "SELECT * FROM xbrl_facts WHERE filing_id = %s" + with self.client.get_cursor() as cur: + cur.execute(query, [filing_id]) + facts = [dict(row) for row in cur.fetchall()] + + result = [] + for f in facts: + result.append( + XBRLFact( + id=f["id"], + filing_id=f["filing_id"], + concept_name=f["concept_name"], + value=f["value"], + value_numeric=f.get("value_numeric"), + unit=f.get("unit"), + period_start=f.get("period_start"), + period_end=f.get("period_end"), + instant_date=f.get("instant_date"), + axis=f.get("axis"), + member=f.get("member"), + ) + ) + + return result + + # ------------------------------------------------------------------------- + # Semantic Search + # ------------------------------------------------------------------------- + + def search_similar( + self, + query_embedding: List[float], + ticker_filter: Optional[str] = None, + limit: int = 10, + threshold: float = 0.5, + ) -> List[SearchResults]: + """ + Search for similar chunks using vector embeddings. 
def get_ticker_stats(self, ticker: str) -> Dict[str, Any]:
    """
    Get statistics for a specific ticker.

    Args:
        ticker: Stock ticker symbol

    Returns:
        Dict with "ticker", "filing_count", "chunk_count",
        "xbrl_fact_count", "earliest_filing" and "latest_filing"; or
        ``{"ticker": ..., "error": ...}`` when the query fails or yields
        no row.
    """
    # Joining filings to both chunks and xbrl_facts multiplies the rows of a
    # filing (chunks x facts), so every counter must be DISTINCT; a plain
    # COUNT(c.id) / COUNT(xf.id) would report the size of that cross product.
    query = """
        SELECT
            COUNT(DISTINCT f.id) AS filing_count,
            COUNT(DISTINCT c.id) AS chunk_count,
            COUNT(DISTINCT xf.id) AS xbrl_fact_count,
            MIN(f.filing_date) AS earliest_filing,
            MAX(f.filing_date) AS latest_filing
        FROM filings f
        LEFT JOIN chunks c ON f.id = c.filing_id
        LEFT JOIN xbrl_facts xf ON f.id = xf.filing_id
        WHERE f.ticker = %s
    """

    try:
        with self.client.get_cursor() as cur:
            cur.execute(query, [ticker])
            result = cur.fetchone()
            if result:
                return {
                    "ticker": ticker,
                    "filing_count": result["filing_count"],
                    "chunk_count": result["chunk_count"],
                    "xbrl_fact_count": result["xbrl_fact_count"],
                    "earliest_filing": result["earliest_filing"],
                    "latest_filing": result["latest_filing"],
                }
    except psycopg2.Error as e:
        _LOG.error("Failed to get ticker stats: %s", e)

    return {"ticker": ticker, "error": "Failed to retrieve stats"}
def generate_xbrl_fact_id(
    filing_id: str, concept_name: str, period_end: Optional[datetime]
) -> str:
    """
    Generate a deterministic XBRL fact ID.

    Args:
        filing_id: Parent filing ID
        concept_name: XBRL concept name
        period_end: Period end date, or None when the fact has no period
            end (the ID is then derived from filing and concept only)

    Returns:
        32-character hexadecimal ID (truncated SHA-256 of the key)
    """
    # A missing period end used to raise AttributeError on ``.isoformat()``;
    # it now contributes an empty component. IDs for real dates are unchanged.
    period_key = period_end.isoformat() if period_end is not None else ""
    key = f"{filing_id}:{concept_name}:{period_key}"
    return hashlib.sha256(key.encode()).hexdigest()[:32]
+ +Provides connection pooling and vector operations for: +- Storing document chunks with embeddings +- Semantic similarity search +- XBRL facts storage +- Filing metadata management + +Environment Variables: +- POSTGRES_HOST: Database host (default: localhost) +- POSTGRES_PORT: Database port (default: 5432) +- POSTGRES_DB: Database name (default: financial_kb) +- POSTGRES_USER: Database user (default: fin) +- POSTGRES_PASSWORD: Database password (default: fin_local) +""" + +import logging +import os +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple +from dotenv import load_dotenv + +import psycopg2 +from psycopg2 import pool, sql +from psycopg2.extras import RealDictCursor, execute_batch + +load_dotenv() + +_LOG = logging.getLogger(__name__) + +# Default embedding dimensions for sentence-transformers/all-mpnet-base-v2 +DEFAULT_EMBEDDING_DIM = 768 + + +class PostgresClient: + """ + PostgreSQL client with connection pooling and pgvector support. + + Manages connections to the warm tier database with automatic + reconnection and connection pooling. + """ + + def __init__( + self, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + user: Optional[str] = None, + password: Optional[str] = None, + min_connections: int = 1, + max_connections: int = 10, + ): + """ + Initialize PostgreSQL client with connection pool. 
@contextmanager
def get_connection(self):
    """
    Borrow a connection from the pool for the duration of a ``with`` block.

    The connection is handed back to the pool when the block exits, whether
    it finished normally or raised. No commit or rollback is issued here.

    Yields:
        psycopg2 connection object

    Example:
        with client.get_connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
    """
    # If acquiring fails there is nothing to give back, so the error simply
    # propagates before the try/finally is entered.
    borrowed = self._get_pool().getconn()
    try:
        yield borrowed
    finally:
        if borrowed:
            self._get_pool().putconn(borrowed)
def execute_many(
    self,
    query: str,
    params_list: List[Tuple],
    batch_size: int = 100,
) -> int:
    """
    Run one statement against many parameter tuples, in slices.

    Args:
        query: SQL query with placeholders
        params_list: List of parameter tuples
        batch_size: Number of parameter tuples handed to each
            ``execute_batch`` call

    Returns:
        Number of parameter tuples submitted
    """
    submitted = 0
    with self.get_connection() as conn:
        with conn.cursor() as cur:
            offsets = range(0, len(params_list), batch_size)
            for offset in offsets:
                window = params_list[offset : offset + batch_size]
                execute_batch(cur, query, window)
                submitted = submitted + len(window)
        # Single commit after every slice has been sent.
        conn.commit()
    return submitted
def search_similar(
    self,
    query_embedding: List[float],
    table: str = "chunks",
    limit: int = 10,
    threshold: float = 0.5,
    ticker_filter: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    Search for similar chunks using vector similarity.

    Uses pgvector's cosine distance (<=>) for similarity search; the
    reported similarity is ``1 - distance``.

    Args:
        query_embedding: Query vector as list of floats
        table: Table to search (default: chunks); may be schema-qualified
            as ``schema.table``
        limit: Maximum results to return
        threshold: Minimum similarity; only rows strictly above it match
        ticker_filter: Optional ticker symbol to filter by

    Returns:
        List of matching chunks with metadata and similarity scores,
        closest first. An empty list is returned (and the error logged)
        when the query fails.
    """
    # pgvector parses the text form '[v1,v2,...]'. psycopg2's list adapter
    # renders 'ARRAY[v1,v2,...]', which the ::vector cast rejects, so the
    # literal is built by hand (insert_chunks strips the same prefix).
    vector_literal = "[" + ",".join(repr(float(v)) for v in query_embedding) + "]"

    # Parameters are appended in the order their placeholders appear:
    # SELECT similarity, WHERE similarity + threshold, [ticker], ORDER BY, LIMIT.
    filters = [sql.SQL("1 - (c.embedding <=> %s::vector) > %s")]
    params: List[Any] = [vector_literal, vector_literal, threshold]
    if ticker_filter:
        filters.append(sql.SQL("f.ticker = %s"))
        params.append(ticker_filter)
    params.extend([vector_literal, limit])

    # The table name is quoted as an identifier instead of being pasted into
    # the SQL text, matching insert_embedding.
    query = sql.SQL(
        """
        SELECT
            c.id,
            c.filing_id,
            c.text,
            c.section,
            c.chunk_index,
            1 - (c.embedding <=> %s::vector) AS similarity,
            f.ticker,
            f.filing_type,
            f.filing_date,
            f.company_name
        FROM {table} c
        JOIN filings f ON c.filing_id = f.id
        WHERE {filters}
        ORDER BY c.embedding <=> %s::vector
        LIMIT %s
        """
    ).format(
        table=sql.Identifier(*table.split(".")),
        filters=sql.SQL(" AND ").join(filters),
    )

    try:
        with self.get_cursor() as cur:
            cur.execute(query, params)
            results = cur.fetchall()
            return [dict(row) for row in results]
    except psycopg2.Error as e:
        _LOG.error("Vector search failed: %s", e)
        return []
def get_filing(self, filing_id: str) -> Optional[Dict[str, Any]]:
    """
    Look up a single filing row by its primary key.

    Args:
        filing_id: Filing identifier

    Returns:
        The row as a plain dict, or None when no row matches or the
        query fails (the failure is logged)
    """
    try:
        with self.get_cursor() as cursor:
            cursor.execute("SELECT * FROM filings WHERE id = %s", [filing_id])
            row = cursor.fetchone()
    except psycopg2.Error as exc:
        _LOG.error("Failed to get filing: %s", exc)
        return None

    if not row:
        return None
    return dict(row)
def insert_chunks(self, chunks: List[Dict[str, Any]]) -> int:
    """
    Insert multiple chunks with embeddings.

    Existing rows with the same id have their text, section and embedding
    overwritten (ON CONFLICT ... DO UPDATE). The input dicts are not
    modified.

    Args:
        chunks: List of chunk dicts with keys:
            - id, filing_id, chunk_index, text, section, embedding

    Returns:
        Number of chunks whose statement executed without a database error
    """
    if not chunks:
        return 0

    query = """
        INSERT INTO chunks (
            id, filing_id, chunk_index, text, section, embedding
        ) VALUES (
            %(id)s, %(filing_id)s, %(chunk_index)s, %(text)s,
            %(section)s, %(embedding)s
        )
        ON CONFLICT (id) DO UPDATE SET
            text = EXCLUDED.text,
            section = EXCLUDED.section,
            embedding = EXCLUDED.embedding,
            updated_at = CURRENT_TIMESTAMP
    """

    count = 0
    with self.get_connection() as conn:
        with conn.cursor() as cur:
            for chunk in chunks:
                params = chunk
                embedding = chunk.get("embedding")
                if isinstance(embedding, list):
                    # pgvector expects '[...]' not 'ARRAY[...]'. Convert on a
                    # shallow copy so the caller's dict keeps its list of
                    # floats instead of being overwritten with a string.
                    params = dict(chunk)
                    params["embedding"] = (
                        "[" + ",".join(repr(float(v)) for v in embedding) + "]"
                    )

                # Use a per-row SAVEPOINT so that a single failed insert
                # does not abort the whole transaction. Without this,
                # psycopg2 leaves the transaction in an aborted state and
                # every subsequent execute() raises "current transaction
                # is aborted, commands ignored until end of transaction
                # block".
                cur.execute("SAVEPOINT chunk_sp")
                try:
                    cur.execute(query, params)
                except psycopg2.Error as e:
                    cur.execute("ROLLBACK TO SAVEPOINT chunk_sp")
                    _LOG.error(
                        "Failed to insert chunk %s: %s", chunk.get("id"), e
                    )
                    continue
                cur.execute("RELEASE SAVEPOINT chunk_sp")
                count += 1
        conn.commit()
    return count
def insert_filings(self, rows: list[dict]) -> int:
    """
    Bulk-insert filing rows; ids that already exist are left untouched.

    :param rows: filing dicts providing every column named in the statement
    :return: ``len(rows)`` (0 for an empty input); rows skipped by
        ``ON CONFLICT (id) DO NOTHING`` are still included in that figure
    """
    statement = """
                INSERT INTO filings
                    (id, ticker, company_name, filing_type, cik,
                     accession_number, filing_date, period_of_report,
                     document_url, file_size_bytes)
                VALUES
                    (%(id)s, %(ticker)s, %(company_name)s, %(filing_type)s,
                     %(cik)s, %(accession_number)s, %(filing_date)s,
                     %(period_of_report)s, %(document_url)s, %(file_size_bytes)s)
                ON CONFLICT (id) DO NOTHING
                """
    if rows:
        with self.get_connection() as conn:
            with conn.cursor() as cur:
                cur.executemany(statement, rows)
            conn.commit()
        return len(rows)
    return 0
def insert_xbrl_facts(self, facts: List[Dict[str, Any]]) -> int:
    """
    Insert multiple XBRL facts.

    Each row runs under its own SAVEPOINT so a rejected row is rolled back
    and logged without discarding the rows around it.

    Args:
        facts: List of fact dicts with keys:
            - id, filing_id, concept_name, value, value_numeric,
              unit, period_start, period_end, instant_date, axis, member

    Returns:
        Number of facts whose INSERT executed without a database error
        (rows skipped by ON CONFLICT DO NOTHING are still counted)
    """
    if not facts:
        return 0

    query = """
        INSERT INTO xbrl_facts (
            id, filing_id, concept_name, value, value_numeric,
            unit, period_start, period_end, instant_date, axis, member
        ) VALUES (
            %(id)s, %(filing_id)s, %(concept_name)s, %(value)s,
            %(value_numeric)s, %(unit)s, %(period_start)s, %(period_end)s,
            %(instant_date)s, %(axis)s, %(member)s
        )
        ON CONFLICT (id) DO NOTHING
    """

    count = 0
    with self.get_connection() as conn:
        with conn.cursor() as cur:
            for fact in facts:
                # Without a savepoint, one failed INSERT leaves the
                # transaction aborted: every later execute() fails too and
                # the final commit persists nothing. Same approach as
                # insert_chunks.
                cur.execute("SAVEPOINT xbrl_fact_sp")
                try:
                    cur.execute(query, fact)
                except psycopg2.Error as e:
                    cur.execute("ROLLBACK TO SAVEPOINT xbrl_fact_sp")
                    _LOG.error(
                        "Failed to insert XBRL fact %s: %s", fact.get("id"), e
                    )
                    continue
                cur.execute("RELEASE SAVEPOINT xbrl_fact_sp")
                count += 1
        conn.commit()
    return count
def get_stats(self) -> Dict[str, Any]:
    """
    Collect row counts for the warm-storage tables.

    Returns:
        Dict keyed by "filings", "chunks", "xbrl_facts" and
        "unique_tickers". If a query fails the error is logged and only
        the counts gathered so far are returned.
    """
    counters = (
        ("filings", "SELECT COUNT(*) FROM filings"),
        ("chunks", "SELECT COUNT(*) FROM chunks"),
        ("xbrl_facts", "SELECT COUNT(*) FROM xbrl_facts"),
        ("unique_tickers", "SELECT COUNT(DISTINCT ticker) FROM filings"),
    )
    collected: Dict[str, Any] = {}

    try:
        # Plain tuple cursor: the single count is read positionally below.
        with self.get_cursor(cursor_factory=psycopg2.extensions.cursor) as cur:
            for label, statement in counters:
                cur.execute(statement)
                row = cur.fetchone()
                if row:
                    collected[label] = row[0]
                else:
                    collected[label] = 0
    except psycopg2.Error as exc:
        _LOG.error("Failed to get stats: %s", exc)

    return collected
def close(self) -> None:
    """Close every pooled connection and forget the pool; a no-op when no pool exists."""
    if not self._pool:
        return
    self._pool.closeall()
    # Dropping the reference lets _get_pool() build a fresh pool on next use.
    self._pool = None
    _LOG.info("PostgreSQL connections closed")
Demonstrates basic operations +""" + +import logging +import sys +from datetime import datetime + +from app.storage.warm_storage import ( + get_postgres_client, + get_filings_manager, + FilingData, + ChunkData, + generate_filing_id, + generate_chunk_id, +) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", +) +_LOG = logging.getLogger(__name__) + + +def test_connection() -> bool: + """Test PostgreSQL connection.""" + _LOG.info("Testing PostgreSQL connection...") + client = get_postgres_client() + + if client.ping(): + _LOG.info( + "✓ Connected to PostgreSQL at %s:%d/%s", + client.host, + client.port, + client.database, + ) + return True + else: + _LOG.error("✗ Failed to connect to PostgreSQL") + return False + + +def test_pgvector_extension() -> bool: + """Verify pgvector extension is enabled.""" + _LOG.info("Checking pgvector extension...") + client = get_postgres_client() + + try: + with client.get_cursor() as cur: + cur.execute(""" + SELECT extname FROM pg_extension + WHERE extname = 'vector' + """) + result = cur.fetchone() + if result: + _LOG.info("✓ pgvector extension is enabled") + return True + else: + _LOG.error("✗ pgvector extension not found") + return False + except Exception as e: + _LOG.error("✗ Error checking pgvector: %s", e) + return False + + +def test_storage_stats() -> None: + """Display storage statistics.""" + _LOG.info("Fetching storage statistics...") + manager = get_filings_manager() + stats = manager.get_stats() + + _LOG.info("=" * 50) + _LOG.info("Storage Statistics:") + _LOG.info(" - Filings: %d", stats.get("filings", 0)) + _LOG.info(" - Chunks: %d", stats.get("chunks", 0)) + _LOG.info(" - XBRL Facts: %d", stats.get("xbrl_facts", 0)) + _LOG.info(" - Unique Tickers: %d", stats.get("unique_tickers", 0)) + _LOG.info(" - Summary: %s", stats.get("summary", "N/A")) + _LOG.info("=" * 50) + + +def demo_store_filing() -> None: + """Demonstrate storing a filing with chunks.""" + 
def main() -> int:
    """
    Run the warm-storage checks in order.

    Returns:
        1 as soon as the connection or pgvector check reports failure,
        otherwise 0 after the statistics and demo steps have run.
    """
    banner = "=" * 60
    _LOG.info(banner)
    _LOG.info("Warm Storage Test Suite (PostgreSQL + pgvector)")
    _LOG.info(banner)

    # Gate checks: the first one that fails ends the run.
    gates = (
        (test_connection, "Connection test failed - exiting"),
        (test_pgvector_extension, "pgvector test failed - exiting"),
    )
    for check, failure_message in gates:
        if not check():
            _LOG.error(failure_message)
            return 1

    # Show stats
    test_storage_stats()

    # Demo operations
    demo_store_filing()

    _LOG.info(banner)
    _LOG.info("All warm storage tests completed successfully!")
    _LOG.info(banner)

    return 0
client.close() + + +# ── connection ──────────────────────────────────────────────────────────────── + + +def test_ping(pg): + assert pg.ping() is True + + +# ── companies ───────────────────────────────────────────────────────────────── + + +def test_upsert_and_get_company(pg): + pg.upsert_company( + cik="0000320193", + ticker="AAPL", + name="Apple Inc.", + sic_code="3571", + sector="Information Technology", + sub_industry="Technology Hardware", + ) + company = pg.get_company("0000320193") + assert company is not None + assert company["ticker"] == "AAPL" + assert company["name"] == "Apple Inc." + + +# ── filings ─────────────────────────────────────────────────────────────────── + + +def test_upsert_and_get_filing(pg): + pg.upsert_company( + cik="0000320193", + ticker="AAPL", + name="Apple Inc.", + ) + filing_id = pg.upsert_filing( + cik="0000320193", + form_type="10-K", + filing_date="2024-01-01", + accession="0000320193-24-000001", + s3_raw_path="s3://raw-edgar/0000320193/primary.htm", + period_of_report="2023-09-30", + ) + assert filing_id is not None + + filings = pg.get_filings("0000320193", form_type="10-K") + assert len(filings) >= 1 + assert filings[0]["form_type"] == "10-K" + _LOG.info("Filing id: %s", filing_id) + + +# ── chunks + embeddings ─────────────────────────────────────────────────────── + + +def test_insert_chunk_single(pg): + filings = pg.get_filings("0000320193", form_type="10-K") + filing_id = str(filings[0]["id"]) + dummy_embedding = [0.1] * 768 + + chunk_id = pg.insert_chunk( + filing_id=filing_id, + text="Apple faces intense competition in all markets.", + embedding=dummy_embedding, + section="Risk Factors", + chunk_index=0, + token_count=8, + metadata={"cik": "0000320193", "form": "10-K"}, + ) + assert chunk_id is not None + _LOG.info("Chunk id: %s", chunk_id) + + +def test_insert_chunks_batch(pg): + filings = pg.get_filings("0000320193", form_type="10-K") + filing_id = str(filings[0]["id"]) + dummy_embedding = [0.1] * 768 + + chunks = [ 
+ { + "filing_id": filing_id, + "section": "MD&A", + "chunk_index": 1, + "text": "Revenue increased 8 percent year over year.", + "token_count": 9, + "embedding": dummy_embedding, + "metadata": {"cik": "0000320193"}, + }, + { + "filing_id": filing_id, + "section": "MD&A", + "chunk_index": 2, + "text": "Gross margin expanded to 44 percent driven by services growth.", + "token_count": 11, + "embedding": dummy_embedding, + "metadata": {"cik": "0000320193"}, + }, + ] + count = pg.insert_chunks_batch(chunks) + assert count == 2 + _LOG.info("Batch inserted %d chunks", count) + + +def test_semantic_search(pg): + dummy_embedding = [0.1] * 768 + + results = pg.semantic_search( + embedding=dummy_embedding, + limit=5, + cik="0000320193", + ) + assert len(results) >= 1 + assert "text" in results[0] + assert "score" in results[0] + assert "section" in results[0] + _LOG.info( + "Top result: '%s' (score=%.4f)", + results[0]["text"][:60], + results[0]["score"], + ) + + +def test_semantic_search_with_section_filter(pg): + dummy_embedding = [0.1] * 768 + + results = pg.semantic_search( + embedding=dummy_embedding, + limit=5, + section="Risk Factors", + cik="0000320193", + ) + assert all(r["section"] == "Risk Factors" for r in results) + _LOG.info("Section filtered results: %d", len(results)) + + +# ── xbrl facts ──────────────────────────────────────────────────────────────── + + +def test_upsert_xbrl_facts(pg): + count = pg.upsert_xbrl_facts( + [ + { + "cik": "0000320193", + "concept": "Revenues", + "period_end": "2023-09-30", + "value": 383285000000, + "unit": "USD", + "form_type": "10-K", + "accession": "0000320193-24-000001", + }, + { + "cik": "0000320193", + "concept": "Assets", + "period_end": "2023-09-30", + "value": 352583000000, + "unit": "USD", + "form_type": "10-K", + "accession": "0000320193-24-000001", + }, + ] + ) + assert count == 2 + + +def test_get_xbrl_facts(pg): + facts = pg.get_xbrl_facts("0000320193", concept="Revenues") + assert len(facts) >= 1 + assert 
float(facts[0]["value"]) == 383285000000.0 + _LOG.info("Revenues: %s %s", facts[0]["value"], facts[0]["unit"]) + + +def test_get_all_xbrl_facts_for_company(pg): + facts = pg.get_xbrl_facts("0000320193") + assert len(facts) >= 2 + concepts = {f["concept"] for f in facts} + assert "Revenues" in concepts + assert "Assets" in concepts + + +# ── articles ────────────────────────────────────────────────────────────────── + + +def test_upsert_article(pg): + article_id = pg.upsert_article( + source="reuters", + url="https://reuters.com/test-article-pgvector-001", + title="Apple reports record revenue", + published_at="2024-01-15 10:00:00", + body_text="Apple Inc. reported record quarterly revenue.", + sentiment="positive", + tickers_mentioned=["AAPL"], + ) + assert article_id is not None + _LOG.info("Article id: %s", article_id) + + +def test_upsert_article_duplicate(pg): + # second insert of same url should return None (ON CONFLICT DO NOTHING) + duplicate = pg.upsert_article( + source="reuters", + url="https://reuters.com/test-article-pgvector-001", + title="Apple reports record revenue", + published_at="2024-01-15 10:00:00", + ) + assert duplicate is None + + +# ── audit log ───────────────────────────────────────────────────────────────── + + +def test_collection_run_lifecycle(pg): + run_id = pg.start_collection_run("test_pgvector_collector") + assert run_id is not None + + pg.finish_collection_run( + run_id=run_id, + records_written=42, + status="success", + ) + _LOG.info("Run %s completed", run_id) + + +def test_collection_run_failure(pg): + run_id = pg.start_collection_run("test_pgvector_collector_fail") + assert run_id is not None + + pg.finish_collection_run( + run_id=run_id, + records_written=0, + status="failed", + error_msg="Simulated failure for test", + ) + _LOG.info("Failed run %s logged", run_id) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/__init__.py 
b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/__init__.py new file mode 100644 index 000000000..3c628ba52 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/__init__.py @@ -0,0 +1 @@ +# Streamlit UI components diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/chat.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/chat.py new file mode 100644 index 000000000..598a153e1 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/chat.py @@ -0,0 +1,107 @@ +""" +Chat UI component for Streamlit. + +Provides a Q&A interface backed by the orchestrator agent. +Supports streaming responses with source citations. +""" + +import streamlit as st + +from app.agents.orchestrator import run as run_orchestrator + + +def render() -> None: + """Render the research chat interface.""" + st.header("Research Chat") + st.caption("Ask questions about companies, markets, or investments") + + # Initialize chat history in session state + if "messages" not in st.session_state: + st.session_state.messages = [] + + # Display chat history + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + if message.get("sources"): + _render_sources(message["sources"]) + + # Chat input + if prompt := st.chat_input("Ask about a company or market..."): + # Add user message to history + st.session_state.messages.append({"role": "user", "content": prompt}) + with st.chat_message("user"): + st.markdown(prompt) + + # Generate assistant response + with st.chat_message("assistant"): + with st.spinner("Researching..."): + response = _generate_response(prompt) + + st.markdown(response["content"]) + if response.get("sources"): + 
_render_sources(response["sources"]) + + # Add assistant response to history + st.session_state.messages.append(response) + + +def _generate_response(prompt: str) -> dict: + """ + Generate a response using the orchestrator agent. + + Args: + prompt: User's question + + Returns: + Dict with content and sources + """ + # Extract ticker from prompt if present + ticker = _extract_ticker(prompt) + + # Run orchestrator + result = run_orchestrator(prompt, context={"ticker": ticker} if ticker else {}) + + return { + "role": "assistant", + "content": result.get("response", "I couldn't generate a response."), + "sources": result.get("sources", []), + "agents_used": result.get("agents_used", []), + } + + +def _extract_ticker(prompt: str) -> str | None: + """ + Extract stock ticker from user prompt. + + Simple heuristic: look for uppercase 1-5 letter words. + """ + import re + + # Common ticker patterns + patterns = [ + r"\$([A-Z]{1,5})\b", # $AAPL + r"\b([A-Z]{1,5})\b", # AAPL (standalone uppercase) + ] + + for pattern in patterns: + matches = re.findall(pattern, prompt) + if matches: + return matches[0] + + return None + + +def _render_sources(sources: list[str]) -> None: + """Render source citations.""" + if not sources: + return + + with st.expander("Sources", expanded=False): + for i, source in enumerate(sources, 1): + st.markdown(f"{i}. {source}") + + +def clear_chat() -> None: + """Clear chat history.""" + st.session_state.messages = [] diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/dashboard.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/dashboard.py new file mode 100644 index 000000000..dabf7a431 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/dashboard.py @@ -0,0 +1,89 @@ +""" +Dashboard UI component for Streamlit. 
+ +Renders: +- Risk flags (from due diligence agent) +- Earnings KPI table (from earnings agent) +""" + +from typing import Any + +import streamlit as st + +from app.agents.diligence import run as run_diligence +from app.agents.earnings import run as run_earnings + + +def render(ticker: str) -> None: + """ + Render the dashboard for a given ticker. + + :param ticker: stock ticker symbol (e.g., ``"AAPL"``) + """ + st.header(f"Market Research Dashboard: {ticker}") + # Run the agents that have data sources wired up. + with st.spinner(f"Analyzing {ticker}..."): + diligence_result = run_diligence( + f"What are the key risks for {ticker}?", + context={"ticker": ticker}, + ) + earnings_result = run_earnings( + f"What are the latest earnings for {ticker}?", + context={"ticker": ticker}, + ) + # Risk flags occupy the top of the page. + st.subheader("Risk Flags") + _render_risk_flags(diligence_result) + # Earnings KPI table follows. + st.subheader("Earnings KPIs") + _render_earnings_table(earnings_result) + + +def _render_risk_flags(result: dict[str, Any]) -> None: + """ + Render risk flags from due diligence analysis grouped by severity. 
+ """ + risk_flags = result.get("risk_flags", []) + if not risk_flags: + st.info("No significant risks identified") + return + high_risks = [r for r in risk_flags if r.get("severity") == "high"] + medium_risks = [r for r in risk_flags if r.get("severity") == "medium"] + low_risks = [r for r in risk_flags if r.get("severity") == "low"] + if high_risks: + st.error(f"**High Risk ({len(high_risks)})**") + for risk in high_risks[:3]: + st.write(f"- {risk.get('description', 'N/A')[:150]}") + if medium_risks: + st.warning(f"**Medium Risk ({len(medium_risks)})**") + for risk in medium_risks[:3]: + st.write(f"- {risk.get('description', 'N/A')[:150]}") + if low_risks: + st.success(f"**Low Risk ({len(low_risks)})**") + for risk in low_risks[:3]: + st.write(f"- {risk.get('description', 'N/A')[:150]}") + + +def _render_earnings_table(result: dict[str, Any]) -> None: + """ + Render the earnings KPI table. + """ + kpi_trends = result.get("kpi_trends", []) + earnings_summary = result.get("earnings_summary", "") + if not kpi_trends: + st.info(earnings_summary or "No earnings data available") + return + rows = [] + for kpi in kpi_trends: + rows.append( + { + "Metric": kpi.get("metric", "N/A"), + "Current": kpi.get("current_value", "N/A"), + "Prior": kpi.get("prior_value", "N/A"), + "Change": kpi.get("change", "N/A"), + } + ) + st.table(rows) + tone = result.get("management_tone", "") + if tone: + st.caption(f"Management tone: {tone[:200]}") diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/research.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/research.py new file mode 100644 index 000000000..c76785f62 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/app/ui/research.py @@ -0,0 +1,223 @@ +""" +Streamlit UI for the agentic research API. 
+ +While the agent runs, a "thinking" panel shows each step live (route, +per-agent retrieval, synthesis). When the answer arrives, the thinking +panel collapses into a small expander and the clean answer + sources are +rendered prominently. + +Run: + streamlit run app/ui/research.py +""" + +import json +import os +import time + +import httpx +import streamlit as st + +API_URL = os.getenv("API_URL", "http://localhost:8000") + +# ############################################################################# +# Page setup +# ############################################################################# + +st.set_page_config( + page_title="Market Research Agent", + page_icon="🔎", + layout="wide", +) +st.title("Market Research Agent") +st.caption( + "Ask a question about any of our 67 tracked tickers. The agent picks " + "sub-agents, retrieves evidence, and writes a cited answer." +) + +# ############################################################################# +# Sidebar: status + examples +# ############################################################################# + +with st.sidebar: + st.subheader("Status") + try: + ping = httpx.get(f"{API_URL}/", timeout=2.0) + if ping.status_code == 200: + st.success(f"API: connected\n\n`{API_URL}`") + else: + st.error(f"API: HTTP {ping.status_code}") + except Exception as e: + st.error(f"API unreachable: {e}") + st.divider() + st.subheader("Examples") + examples = [ + "What does Apple disclose as risk factors?", + "Recent NVDA news sentiment?", + "How does JPMorgan describe regulatory risk?", + "What are analysts saying about Tesla?", + "Summarize Microsoft 8-K disclosures", + ] + for ex in examples: + if st.button(ex, key=f"ex_{ex[:18]}"): + st.session_state["query_text"] = ex + +# ############################################################################# +# Query input +# ############################################################################# + +query = st.text_input( + "Your question", + key="query_text", + 
placeholder="e.g. What are AAPL's main risk factors?", +) +go = st.button("Research", type="primary") + +# ############################################################################# +# Pipeline runner +# ############################################################################# + + +def _format_thinking_line(step: str, payload: dict) -> str: + """ + Render one streaming event as a single line in the live "thinking" + transcript. + + :param step: event step name from the API + :param payload: event payload from the API + :return: human-readable line; empty string if the event isn't worth + showing in the live transcript + """ + elapsed = payload.get("elapsed_ms") + elapsed_str = f"`{elapsed:.0f}ms`" if elapsed is not None else "" + if step == "route" and "agents" in payload: + ticker = payload.get("ticker") or "—" + agents = ", ".join(payload.get("agents", [])) + return ( + f"**Routed** {elapsed_str} → ticker=`{ticker}`, agents=`{agents}`. " + f"_{payload.get('reason', '')}_" + ) + if step == "retrieve" and payload.get("status") == "running": + return f"**Retrieving** {elapsed_str} from `{payload.get('agent')}` agent…" + if step == "retrieve" and payload.get("status") == "complete": + return ( + f"**Retrieved** {elapsed_str} from `{payload.get('agent')}` " + f"agent → {payload.get('count', 0)} chunk(s) in " + f"`{payload.get('step_ms', 0):.0f}ms`" + ) + if step == "synthesize" and payload.get("status") == "running": + return f"**Synthesizing** {elapsed_str}…" + if step == "synthesize" and payload.get("status") == "complete": + mode = "LLM" if payload.get("used_llm") else "extractive" + return ( + f"**Synthesized** {elapsed_str} ({mode}) in " + f"`{payload.get('step_ms', 0):.0f}ms`" + ) + return "" + + +def _run(query: str) -> None: + """ + Hit ``/research/stream`` and render each event live in a "thinking" + panel; collapse it once the final answer arrives. 
+ + :param query: user query (already non-empty) + """ + thinking_holder = st.empty() + answer_holder = st.empty() + thinking_lines: list[str] = [] + final_payload: dict | None = None + final_timings: dict | None = None + wall_t0 = time.perf_counter() + try: + with httpx.stream( + "POST", + f"{API_URL}/research/stream", + json={"query": query}, + timeout=120.0, + ) as response: + response.raise_for_status() + for raw in response.iter_lines(): + line = raw.decode() if isinstance(raw, bytes) else raw + if not line or not line.startswith("data:"): + continue + event = json.loads(line[5:].strip()) + step = event["step"] + payload = event["payload"] + # Keep the live transcript updated. + tl = _format_thinking_line(step, payload) + if tl: + thinking_lines.append(tl) + with thinking_holder.container(): + st.markdown("**Thinking…**") + for ln in thinking_lines: + st.markdown(f"- {ln}") + if step == "synthesize" and payload.get("status") == "complete": + final_payload = payload + elif step == "done": + final_timings = payload.get("timings", {}) + elif step == "error": + st.error(payload.get("message", "server-side error")) + return + except httpx.HTTPError as e: + st.error(f"Network error talking to {API_URL}: {e}") + return + except json.JSONDecodeError as e: + st.error(f"Could not parse server response: {e}") + return + if final_payload is None: + st.warning("Stream ended before the agent produced an answer.") + return + # Collapse the live thinking transcript into a small expander. + wall_ms = (time.perf_counter() - wall_t0) * 1000.0 + with thinking_holder.container(): + with st.expander( + f"Show agent trace ({len(thinking_lines)} steps, " + f"{wall_ms:.0f}ms wall time)", + expanded=False, + ): + for ln in thinking_lines: + st.markdown(f"- {ln}") + if final_timings: + st.markdown("**Timing breakdown**") + cols = st.columns(len(final_timings)) + for col, (k, v) in zip(cols, final_timings.items()): + col.metric(k, f"{v:.0f}ms") + # Render the clean answer + sources. 
+ with answer_holder.container(): + st.markdown("### Answer") + if final_payload.get("used_llm"): + st.caption("Generated with LLM synthesis") + else: + st.caption( + "Extractive synthesis — set LLM_BASE_URL / LLM_API_KEY / " + "LLM_MODEL on the API server to enable LLM prose generation" + ) + st.markdown(final_payload.get("answer", "")) + sources = final_payload.get("sources", []) + if sources: + st.markdown("### Sources") + for s in sources: + ticker = s.get("ticker") or "?" + src = s.get("source") or "?" + ftype = s.get("filing_type") or "" + date = (s.get("filing_date") or "")[:10] + score = s.get("score") or 0.0 + idx = s.get("id") + header_bits = [f"[{idx}]", ticker, src] + if ftype: + header_bits.append(ftype) + if date: + header_bits.append(date) + header_bits.append(f"score={score:.3f}") + with st.expander(" · ".join(header_bits)): + st.write(s.get("snippet", "")) + if s.get("accession_number"): + st.caption(f"Accession: {s['accession_number']}") + if s.get("url"): + st.markdown(f"[Open source]({s['url']})") + + +if go and query and query.strip(): + _run(query.strip()) +elif go and not query.strip(): + st.warning("Type a question first.") diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/bashrc b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/bashrc new file mode 100644 index 000000000..4b7ff4c49 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/bashrc @@ -0,0 +1 @@ +set -o vi diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/copy_docker_files.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/copy_docker_files.py new file mode 100644 index 000000000..0e97c194c --- /dev/null +++ 
b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/copy_docker_files.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python + +""" +Copy Docker-related files from the source directory to a destination directory. + +This script copies all Docker configuration and utility files from +class_project/project_template/ to a specified destination directory. + +Usage examples: + # Copy all files to a target directory. + > ./copy_docker_files.py --dst_dir /path/to/destination + + # Copy with verbose logging. + > ./copy_docker_files.py --dst_dir /path/to/destination -v DEBUG + +Import as: + +import class_project.project_template.copy_docker_files as cpdccodo +""" + +import argparse +import logging +import os +from typing import List + +import helpers.hdbg as hdbg +import helpers.hio as hio +import helpers.hparser as hparser +import helpers.hsystem as hsystem + +_LOG = logging.getLogger(__name__) + +# ############################################################################# +# Constants +# ############################################################################# + +# List of files to copy from the source directory. +_FILES_TO_COPY = [ + "bashrc", + "docker_bash.sh", + "docker_build.sh", + "docker_clean.sh", + "docker_cmd.sh", + "docker_exec.sh", + "docker_jupyter.sh", + "docker_name.sh", + "docker_push.sh", + "etc_sudoers", + "install_jupyter_extensions.sh", + "run_jupyter.sh", + "version.sh", +] + + +# ############################################################################# +# Helper functions +# ############################################################################# + + +def _get_source_dir() -> str: + """ + Get the absolute path to the source directory containing Docker files. + + :return: absolute path to class_project/project_template/ + """ + # Get the directory where this script is located. 
+ script_dir = os.path.dirname(os.path.abspath(__file__)) + _LOG.debug("Script directory='%s'", script_dir) + return script_dir + + +def _copy_files( + *, + src_dir: str, + dst_dir: str, + files: List[str], +) -> None: + """ + Copy specified files from source directory to destination directory. + + :param src_dir: source directory path + :param dst_dir: destination directory path + :param files: list of filenames to copy + """ + # Imported locally so the shell-quoting below is self-contained. + import shlex + + # Verify source directory exists. + hdbg.dassert_dir_exists(src_dir, "Source directory does not exist:", src_dir) + # Create destination directory if it doesn't exist. + hio.create_dir(dst_dir, incremental=True) + _LOG.info("Copying %d files from '%s' to '%s'", len(files), src_dir, dst_dir) + # Copy each file. + copied_count = 0 + for filename in files: + src_path = os.path.join(src_dir, filename) + dst_path = os.path.join(dst_dir, filename) + # Verify source file exists. + hdbg.dassert_path_exists( + src_path, "Source file does not exist:", src_path + ) + # Copy the file using cp -a to preserve all permissions and attributes. + # Quote both paths: dst_dir comes from the command line, so spaces or + # shell metacharacters must not split or inject into the command. + _LOG.debug("Copying '%s' -> '%s'", src_path, dst_path) + cmd = f"cp -a {shlex.quote(src_path)} {shlex.quote(dst_path)}" + hsystem.system(cmd) + copied_count += 1 + # + _LOG.info("Successfully copied %d files", copied_count) + + +# ############################################################################# + + +def _parse() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--dst_dir", + action="store", + required=True, + help="Destination directory where files will be copied", + ) + hparser.add_verbosity_arg(parser) + return parser + + +def _main(parser: argparse.ArgumentParser) -> None: + args = parser.parse_args() + hdbg.init_logger(verbosity=args.log_level, use_exec_path=True) + # Get source directory. + src_dir = _get_source_dir() + # Copy files to destination. 
+ _copy_files( + src_dir=src_dir, + dst_dir=args.dst_dir, + files=_FILES_TO_COPY, + ) + + +if __name__ == "__main__": + _main(_parse()) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker-compose.yml b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker-compose.yml new file mode 100644 index 000000000..ff7e5ed28 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker-compose.yml @@ -0,0 +1,115 @@ +services: + + keydb: + image: eqalpha/keydb:latest + restart: "no" + command: > + keydb-server + --server-threads 2 + --maxmemory 1gb + --maxmemory-policy allkeys-lru + --save "" + --protected-mode no + ports: + - "6379:6379" + volumes: + - keydbdata:/data + healthcheck: + test: ["CMD", "keydb-cli", "ping"] + interval: 5s + timeout: 3s + retries: 5 + + postgres: + image: pgvector/pgvector:pg16 + restart: "no" + environment: + POSTGRES_DB: financial_kb + POSTGRES_USER: fin + POSTGRES_PASSWORD: fin_local + ports: + - "5432:5432" + volumes: + - pgdata:/var/lib/postgresql/data + - ./sql/init.sql:/docker-entrypoint-initdb.d/init.sql + healthcheck: + test: ["CMD-SHELL", "pg_isready -U fin -d financial_kb"] + interval: 5s + timeout: 3s + retries: 10 + + minio: + image: minio/minio:latest + restart: "no" + command: server /data --console-address ":9001" + environment: + MINIO_ROOT_USER: minioadmin + MINIO_ROOT_PASSWORD: minioadmin + ports: + - "9000:9000" + - "9001:9001" + volumes: + - miniodata:/data + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 10s + timeout: 5s + retries: 5 + + api: + build: . 
+ image: txtai-market-research:latest + restart: "no" + command: ["uvicorn", "app.api.server:app", "--host", "0.0.0.0", "--port", "8000"] + environment: + KEYDB_HOST: keydb + KEYDB_PORT: 6379 + POSTGRES_HOST: postgres + POSTGRES_PORT: 5432 + POSTGRES_DB: financial_kb + POSTGRES_USER: fin + POSTGRES_PASSWORD: fin_local + MINIO_ENDPOINT: minio:9000 + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin + MINIO_SECURE: "false" + TXTAI_DATA_DIR: /app/data + env_file: + - .env + ports: + - "8000:8000" + volumes: + - ./data:/app/data + depends_on: + keydb: + condition: service_healthy + postgres: + condition: service_healthy + minio: + condition: service_healthy + healthcheck: + test: ["CMD-SHELL", "curl -f http://localhost:8000/ || exit 1"] + interval: 15s + timeout: 5s + retries: 5 + + ui: + image: txtai-market-research:latest + restart: "no" + command: ["streamlit", "run", "app/ui/research.py", "--server.port=8501", "--server.address=0.0.0.0"] + environment: + API_URL: http://api:8000 + env_file: + - .env + ports: + - "8501:8501" + volumes: + - ./data:/app/data + depends_on: + api: + condition: service_healthy + +volumes: + keydbdata: {} + pgdata: {} + miniodata: {} diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_bash.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_bash.sh new file mode 100755 index 000000000..0025e81f4 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_bash.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# """ +# This script launches a Docker container with an interactive bash shell for +# development. +# """ + +# Exit immediately if any command exits with a non-zero status. +set -e + +# Import the utility functions from the project template. 
+GIT_ROOT=$(git rev-parse --show-toplevel) +source $GIT_ROOT/class_project/project_template/utils.sh + +# Parse default args (-h, -v) and enable set -x if -v is passed. +parse_default_args "$@" + +# Load Docker configuration variables for this script. +get_docker_vars_script ${BASH_SOURCE[0]} +source $DOCKER_NAME +print_docker_vars + +# List the available Docker images matching the expected image name. +run "docker image ls $FULL_IMAGE_NAME" + +# Configure and run the Docker container with interactive bash shell. +# - Container is removed automatically on exit (--rm) +# - Interactive mode with TTY allocation (-ti) +# - Port forwarding for Jupyter or other services +# - Git root mounted to /git_root inside container +CONTAINER_NAME=${IMAGE_NAME}_bash +PORT= +DOCKER_CMD=$(get_docker_bash_command) +DOCKER_CMD_OPTS=$(get_docker_bash_options $CONTAINER_NAME $PORT) +run "$DOCKER_CMD $DOCKER_CMD_OPTS $FULL_IMAGE_NAME" diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_build.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_build.sh new file mode 100755 index 000000000..5b0957a99 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_build.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# """ +# Build a Docker container image for the project. +# +# This script sets up the build environment with error handling and command +# tracing, loads Docker configuration from docker_name.sh, and builds the +# Docker image using the build_container_image utility function. It supports +# both single-architecture and multi-architecture builds via the +# DOCKER_BUILD_MULTI_ARCH environment variable. +# """ + +# Exit immediately if any command exits with a non-zero status. +set -e + +# Import the utility functions. 
+GIT_ROOT=$(git rev-parse --show-toplevel) +source $GIT_ROOT/class_project/project_template/utils.sh + +# Parse default args (-h, -v) and enable set -x if -v is passed. +# Shift processed option flags so remaining args are passed to the build. +parse_default_args "$@" +shift $((OPTIND-1)) + +# Load Docker configuration variables (REPO_NAME, IMAGE_NAME, FULL_IMAGE_NAME). +get_docker_vars_script ${BASH_SOURCE[0]} +source $DOCKER_NAME +print_docker_vars + +# Configure Docker build settings. +# Enable BuildKit for improved build performance and features. +export DOCKER_BUILDKIT=1 +#export DOCKER_BUILDKIT=0 + +# Configure single-architecture build (set to 1 for multi-arch build). +#export DOCKER_BUILD_MULTI_ARCH=1 +export DOCKER_BUILD_MULTI_ARCH=0 + +# Build the container image. +# Pass extra arguments (e.g., --no-cache) via command line after -v. +build_container_image "$@" diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_clean.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_clean.sh new file mode 100755 index 000000000..7e40839ae --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_clean.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# """ +# Remove Docker container image for the project. +# +# This script cleans up Docker images by removing the container image +# matching the project configuration. Useful for freeing disk space or +# ensuring a fresh build. +# """ + +# Exit immediately if any command exits with a non-zero status. +set -e + +# Import the utility functions. +GIT_ROOT=$(git rev-parse --show-toplevel) +source $GIT_ROOT/class_project/project_template/utils.sh + +# Parse default args (-h, -v) and enable set -x if -v is passed. +parse_default_args "$@" + +# Load Docker configuration variables for this script. 
+get_docker_vars_script ${BASH_SOURCE[0]} +source $DOCKER_NAME +print_docker_vars + +# Remove the container image. +remove_container_image diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_cmd.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_cmd.sh new file mode 100755 index 000000000..f3616fc27 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_cmd.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# """ +# Execute a command in a Docker container. +# +# This script runs a specified command inside a new Docker container instance. +# The container is removed automatically after the command completes. The +# current directory is mounted to /data inside the container. +# """ + +# Exit immediately if any command exits with a non-zero status. +set -e + +# Import the utility functions. +GIT_ROOT=$(git rev-parse --show-toplevel) +source $GIT_ROOT/class_project/project_template/utils.sh + +# Parse default args (-h, -v) and enable set -x if -v is passed. +# Shift processed option flags so remaining args form the command. +parse_default_args "$@" +shift $((OPTIND-1)) + +# Capture the command to execute from remaining arguments. +# Use "$*" (not "$@") to join the arguments into one string explicitly. +CMD="$*" +echo "Executing: '$CMD'" + +# Load Docker configuration variables for this script. +get_docker_vars_script ${BASH_SOURCE[0]} +source $DOCKER_NAME +print_docker_vars + +# List available Docker images matching the expected image name. +run "docker image ls $FULL_IMAGE_NAME" +#(docker manifest inspect $FULL_IMAGE_NAME | grep arch) || true + +# Configure and run the Docker container with the specified command. 
+CONTAINER_NAME=$IMAGE_NAME +DOCKER_CMD=$(get_docker_cmd_command) +PORT="" +DOCKER_RUN_OPTS="" +DOCKER_CMD_OPTS=$(get_docker_bash_options $CONTAINER_NAME $PORT $DOCKER_RUN_OPTS) +run "$DOCKER_CMD $DOCKER_CMD_OPTS $FULL_IMAGE_NAME bash -c '$CMD'" diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_exec.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_exec.sh new file mode 100755 index 000000000..24f8e401a --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_exec.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# """ +# Execute a bash shell in a running Docker container. +# +# This script connects to an already running Docker container and opens an +# interactive bash session for debugging or inspection purposes. +# """ + +# Exit immediately if any command exits with a non-zero status. +set -e + +# Import the utility functions. +GIT_ROOT=$(git rev-parse --show-toplevel) +source $GIT_ROOT/class_project/project_template/utils.sh + +# Parse default args (-h, -v) and enable set -x if -v is passed. +parse_default_args "$@" + +# Load Docker configuration variables for this script. +get_docker_vars_script ${BASH_SOURCE[0]} +source $DOCKER_NAME +print_docker_vars + +# Execute bash shell in the running container. +exec_container diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_jupyter.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_jupyter.sh new file mode 100755 index 000000000..6c7d09b13 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_jupyter.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# """ +# Execute Jupyter Lab in a Docker container. 
+# +# This script launches a Docker container running Jupyter Lab with +# configurable port, directory mounting, and vim bindings. It passes +# command-line options to the run_jupyter.sh script inside the container. +# +# Usage: +# > docker_jupyter.sh [options] +# """ + +# Exit immediately if any command exits with a non-zero status. +set -e + +# Import the utility functions. +GIT_ROOT=$(git rev-parse --show-toplevel) +source $GIT_ROOT/class_project/project_template/utils.sh + +# Parse command-line options and set Jupyter configuration variables. +parse_docker_jupyter_args "$@" + +# Load Docker configuration variables for this script. +get_docker_vars_script ${BASH_SOURCE[0]} +source $DOCKER_NAME +print_docker_vars + +# List available Docker images and inspect architecture. +run "docker image ls $FULL_IMAGE_NAME" +(docker manifest inspect $FULL_IMAGE_NAME | grep arch) || true + +# Run the Docker container with Jupyter Lab. +CMD=$(get_run_jupyter_cmd "${BASH_SOURCE[0]}" "$OLD_CMD_OPTS") +CONTAINER_NAME=$IMAGE_NAME +DOCKER_CMD=$(get_docker_jupyter_command) +DOCKER_CMD_OPTS=$(get_docker_jupyter_options $CONTAINER_NAME $JUPYTER_HOST_PORT $JUPYTER_USE_VIM) +run "$DOCKER_CMD $DOCKER_CMD_OPTS $FULL_IMAGE_NAME $CMD" diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_name.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_name.sh new file mode 100755 index 000000000..498ea0139 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_name.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# """ +# Docker image naming configuration. +# +# This file defines the repository name, image name, and full image name +# variables used by all docker_*.sh scripts in the project template. +# """ + +REPO_NAME=gpsaggese +# The image name must be all lower case (Docker rejects upper-case repository names). NOTE(review): the value below does not match this project's directory name - confirm it is intended. 
+IMAGE_NAME=umd_causal_success_analysis +FULL_IMAGE_NAME=$REPO_NAME/$IMAGE_NAME diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_push.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_push.sh new file mode 100755 index 000000000..d4e4a8e0a --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docker_push.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# """ +# Push Docker container image to Docker Hub or registry. +# +# This script authenticates with the Docker registry using credentials from +# ~/.docker/passwd.$REPO_NAME.txt and pushes the locally built container +# image to the remote repository. +# """ + +# Exit immediately if any command exits with a non-zero status. +set -e + +# Import the utility functions. +GIT_ROOT=$(git rev-parse --show-toplevel) +source $GIT_ROOT/class_project/project_template/utils.sh + +# Load Docker image naming configuration. +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source $SCRIPT_DIR/docker_name.sh + +# Push the container image to the registry. 
+push_container_image diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docs/architecture.excalidraw b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docs/architecture.excalidraw new file mode 100644 index 000000000..e4b3f1608 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/docs/architecture.excalidraw @@ -0,0 +1,1456 @@ +{ + "type": "excalidraw", + "version": 2, + "source": "https://excalidraw.com", + "elements": [ + { + "id": "user", + "type": "rectangle", + "x": 360, + "y": 0, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffffff", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 1373158607, + "version": 1, + "versionNonce": 239081664, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "user-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "user-t", + "type": "text", + "x": 364, + "y": 18.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 53710185, + "version": 1, + "versionNonce": 1592467582, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "User", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "user", + "originalText": "User", + "lineHeight": 1.25 + }, + { + "id": "ui", + "type": "rectangle", + "x": 360, + "y": 100, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + 
"backgroundColor": "#e7f5ff", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 590620972, + "version": 1, + "versionNonce": 525901257, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "ui-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "ui-t", + "type": "text", + "x": 364, + "y": 118.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 479341424, + "version": 1, + "versionNonce": 299655413, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "Streamlit UI (8501)", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "ui", + "originalText": "Streamlit UI (8501)", + "lineHeight": 1.25 + }, + { + "id": "api", + "type": "rectangle", + "x": 360, + "y": 200, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#e7f5ff", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 1581559893, + "version": 1, + "versionNonce": 220106708, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "api-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "api-t", + "type": "text", + "x": 364, + "y": 218.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + 
"frameId": null, + "roundness": null, + "seed": 1453201079, + "version": 1, + "versionNonce": 1590571866, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "FastAPI / /research /research/stream", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "api", + "originalText": "FastAPI / /research /research/stream", + "lineHeight": 1.25 + }, + { + "id": "router", + "type": "rectangle", + "x": 360, + "y": 320, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#fff3bf", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 1915941033, + "version": 1, + "versionNonce": 1171165723, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "router-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "router-t", + "type": "text", + "x": 364, + "y": 338.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 186699714, + "version": 1, + "versionNonce": 1268073013, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "research_agent: _route()", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "router", + "originalText": "research_agent: _route()", + "lineHeight": 1.25 + }, + { + "id": "sec_agent", + "type": "rectangle", + "x": 120, + "y": 440, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#fff3bf", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": 
"solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 906070221, + "version": 1, + "versionNonce": 68252794, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "sec_agent-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "sec_agent-t", + "type": "text", + "x": 124, + "y": 458.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 63989048, + "version": 1, + "versionNonce": 201209006, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "SEC sub-agent (tags='sec')", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "sec_agent", + "originalText": "SEC sub-agent (tags='sec')", + "lineHeight": 1.25 + }, + { + "id": "news_agent", + "type": "rectangle", + "x": 600, + "y": 440, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#fff3bf", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 469521478, + "version": 1, + "versionNonce": 499635469, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "news_agent-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "news_agent-t", + "type": "text", + "x": 604, + "y": 458.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + 
"seed": 1085242217, + "version": 1, + "versionNonce": 1292825379, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "News sub-agent (tags='news')", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "news_agent", + "originalText": "News sub-agent (tags='news')", + "lineHeight": 1.25 + }, + { + "id": "synth", + "type": "rectangle", + "x": 360, + "y": 560, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#fff3bf", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 56985562, + "version": 1, + "versionNonce": 1205264596, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "synth-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "synth-t", + "type": "text", + "x": 364, + "y": 578.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 427000597, + "version": 1, + "versionNonce": 1537640409, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "_synthesize (LLM or extractive)", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "synth", + "originalText": "_synthesize (LLM or extractive)", + "lineHeight": 1.25 + }, + { + "id": "txtai", + "type": "rectangle", + "x": 360, + "y": 700, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#d3f9d8", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + 
"groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 1395616197, + "version": 1, + "versionNonce": 1506083911, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "txtai-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "txtai-t", + "type": "text", + "x": 364, + "y": 718.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 1170252924, + "version": 1, + "versionNonce": 900911955, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "txtai.Embeddings (SQLite + ANN)", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "txtai", + "originalText": "txtai.Embeddings (SQLite + ANN)", + "lineHeight": 1.25 + }, + { + "id": "hot", + "type": "rectangle", + "x": -20, + "y": 820, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffe3e3", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 473392625, + "version": 1, + "versionNonce": 964669078, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "hot-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "hot-t", + "type": "text", + "x": -16, + "y": 838.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 1265438423, + "version": 1, + "versionNonce": 597409993, + 
"isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "Hot: KeyDB (cache, prices)", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "hot", + "originalText": "Hot: KeyDB (cache, prices)", + "lineHeight": 1.25 + }, + { + "id": "warm", + "type": "rectangle", + "x": 240, + "y": 820, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffe3e3", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 1738238662, + "version": 1, + "versionNonce": 1866808230, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "warm-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "warm-t", + "type": "text", + "x": 244, + "y": 838.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 13955984, + "version": 1, + "versionNonce": 1629526406, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "Warm: Postgres + pgvector", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "warm", + "originalText": "Warm: Postgres + pgvector", + "lineHeight": 1.25 + }, + { + "id": "cold", + "type": "rectangle", + "x": 500, + "y": 820, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#ffe3e3", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 1730483679, + 
"version": 1, + "versionNonce": 342865763, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "cold-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "cold-t", + "type": "text", + "x": 504, + "y": 838.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 1499242942, + "version": 1, + "versionNonce": 907557513, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 18, + "fontFamily": 3, + "text": "Cold: MinIO (raw filings)", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "cold", + "originalText": "Cold: MinIO (raw filings)", + "lineHeight": 1.25 + }, + { + "id": "ext", + "type": "rectangle", + "x": 760, + "y": 820, + "width": 240, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "#f3f0ff", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 3 + }, + "seed": 730682428, + "version": 1, + "versionNonce": 596724165, + "isDeleted": false, + "boundElements": [ + { + "type": "text", + "id": "ext-t" + } + ], + "updated": 1, + "link": null, + "locked": false + }, + { + "id": "ext-t", + "type": "text", + "x": 764, + "y": 838.0, + "width": 232, + "height": 24, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": null, + "seed": 333889689, + "version": 1, + "versionNonce": 462382782, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "fontSize": 
18, + "fontFamily": 3, + "text": "External: SEC EDGAR / NewsAPI", + "textAlign": "center", + "verticalAlign": "middle", + "containerId": "ext", + "originalText": "External: SEC EDGAR / NewsAPI", + "lineHeight": 1.25 + }, + { + "id": "a1", + "type": "arrow", + "x": 480.0, + "y": 60, + "width": 0.0, + "height": 40, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 2055599410, + "version": 1, + "versionNonce": 1639591160, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + 0.0, + 40 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a2", + "type": "arrow", + "x": 480.0, + "y": 160, + "width": 0.0, + "height": 40, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 722831293, + "version": 1, + "versionNonce": 219494903, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + 0.0, + 40 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a3", + "type": "arrow", + "x": 480.0, + "y": 260, + "width": 0.0, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 
199170185, + "version": 1, + "versionNonce": 815887679, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + 0.0, + 60 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a4l", + "type": "arrow", + "x": 390, + "y": 380, + "width": -150, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 207696844, + "version": 1, + "versionNonce": 770902344, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + -150, + 60 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a4r", + "type": "arrow", + "x": 570, + "y": 380, + "width": 150, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 1819980298, + "version": 1, + "versionNonce": 738639289, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + 150, + 60 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a5l", + "type": "arrow", + "x": 240, + "y": 500, + "width": 150, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + 
"roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 1296491778, + "version": 1, + "versionNonce": 568054228, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + 150, + 60 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a5r", + "type": "arrow", + "x": 720, + "y": 500, + "width": -150, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 1733294784, + "version": 1, + "versionNonce": 93309106, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + -150, + 60 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a6", + "type": "arrow", + "x": 480.0, + "y": 620, + "width": 0.0, + "height": 80, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 1567087081, + "version": 1, + "versionNonce": 986607412, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + 0.0, + 80 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a7l", + "type": "arrow", + "x": 240, + "y": 500, + "width": 180, + "height": 200, + "angle": 0, + "strokeColor": 
"#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 1151541059, + "version": 1, + "versionNonce": 268062141, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + 180, + 200 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a7r", + "type": "arrow", + "x": 720, + "y": 500, + "width": -180, + "height": 200, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 2089750183, + "version": 1, + "versionNonce": 1980614225, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + -180, + 200 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a8a", + "type": "arrow", + "x": 390, + "y": 760, + "width": -290, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 812896394, + "version": 1, + "versionNonce": 169222133, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + -290, + 60 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + 
"id": "a8b", + "type": "arrow", + "x": 480.0, + "y": 760, + "width": -120.0, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 1185498233, + "version": 1, + "versionNonce": 629595553, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + -120.0, + 60 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a8c", + "type": "arrow", + "x": 570, + "y": 760, + "width": 50, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 1781132954, + "version": 1, + "versionNonce": 1349993688, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + 50, + 60 + ] + ], + "lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + }, + { + "id": "a9", + "type": "arrow", + "x": 880, + "y": 820, + "width": -420, + "height": 60, + "angle": 0, + "strokeColor": "#1e1e1e", + "backgroundColor": "transparent", + "fillStyle": "solid", + "strokeWidth": 2, + "strokeStyle": "solid", + "roughness": 1, + "opacity": 100, + "groupIds": [], + "frameId": null, + "roundness": { + "type": 2 + }, + "seed": 1328261054, + "version": 1, + "versionNonce": 1901493144, + "isDeleted": false, + "boundElements": null, + "updated": 1, + "link": null, + "locked": false, + "points": [ + [ + 0, + 0 + ], + [ + -420, + 60 + ] + ], + 
"lastCommittedPoint": null, + "startBinding": null, + "endBinding": null, + "startArrowhead": null, + "endArrowhead": "arrow" + } + ], + "appState": { + "gridSize": null, + "viewBackgroundColor": "#ffffff" + }, + "files": {} +} \ No newline at end of file diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/etc_sudoers b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/etc_sudoers new file mode 100644 index 000000000..ee0816a15 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/etc_sudoers @@ -0,0 +1,31 @@ +# +# This file MUST be edited with the 'visudo' command as root. +# +# Please consider adding local content in /etc/sudoers.d/ instead of +# directly modifying this file. +# +# See the man page for details on how to write a sudoers file. +# +Defaults env_reset +Defaults mail_badpass +Defaults secure_path="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin" + +# Host alias specification + +# User alias specification + +# Cmnd alias specification + +# User privilege specification +root ALL=(ALL:ALL) ALL + +# Members of the admin group may gain root privileges +%admin ALL=(ALL) ALL + +# Allow members of group sudo to execute any command +%sudo ALL=(ALL:ALL) ALL + +# See sudoers(5) for more information on "#include" directives: +postgres ALL=(ALL) NOPASSWD:ALL + +#includedir /etc/sudoers.d diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.API.ipynb b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.API.ipynb new file mode 100644 index 000000000..d1827484e --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.API.ipynb @@ -0,0 +1,343 @@ +{ + "cells": [ + { + 
"cell_type": "markdown", + "id": "1e41f3f0", + "metadata": {}, + "source": [ + "# txtai API Tour\n", + "\n", + "Standalone walkthrough of the txtai primitives we use in this project. Each\n", + "cell exercises one concept in isolation so that someone new to txtai can\n", + "read the cells top-to-bottom and understand what the building blocks do\n", + "before seeing them composed in `txtai.example.ipynb`.\n", + "\n", + "## txtai endpoints exercised in this notebook\n", + "\n", + "The notebook is organized around the small surface of the `txtai` library\n", + "that the project depends on. Each section below maps to one cell.\n", + "\n", + "| # | Endpoint | What it does | Cell |\n", + "|---|-----------------------------------------|---------------------------------------------------------|------|\n", + "| 1 | `txtai.embeddings.Embeddings(config)` | Build a vector index plus content store | 1 |\n", + "| 2 | `Embeddings.index(rows)` | Bulk-load `(id, text, tags)` tuples into the index | 2 |\n", + "| 3 | `Embeddings.search(query, limit)` | Pure semantic top-k | 3 |\n", + "| 4 | `Embeddings.search(sql, parameters, \u2026)` | SQL-style filter (`WHERE tags = 'sec'`) over the index | 4 |\n", + "| 5 | `Embeddings.save(path)` / `.load(path)` | Persist and reload an index from disk | 5 |\n", + "| 5 | `Embeddings.count()` | Number of rows currently in the index | 5 |\n", + "| 6 | `txtai.LLM(model)` | Optional OpenAI-compatible LLM wrapper | 6 |\n", + "\n", + "These are the only `txtai` calls the project relies on \u2014 everything else\n", + "(routing, agents, retrieval) is built on top of them in\n", + "`app/pipeline/embeddings.py` and `app/agents/research_agent.py`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6e006aa6", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# System libraries.\n", + "import logging\n", + "\n", + "# Third party libraries.\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2a8fe60f", + "metadata": {}, + "outputs": [], + "source": [ + "# Local utility for notebook logging setup.\n", + "_LOG = logging.getLogger(__name__)\n", + "logging.basicConfig(level=logging.INFO, format=\"%(levelname)s %(name)s: %(message)s\")" + ] + }, + { + "cell_type": "markdown", + "id": "4ee4be5b", + "metadata": {}, + "source": [ + "## 1. Embeddings index\n", + "\n", + "`txtai.Embeddings` is the workhorse: a vector index plus a content store.\n", + "We pass `content=True` so the original text is stored alongside the\n", + "vectors, which lets us return it from search results without a separate\n", + "document store." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "548b675c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/gprakash/src/umd_classes1/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/.venv/lib/python3.14/site-packages/tika/__init__.py:20: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " __import__('pkg_resources').declare_namespace(__name__)\n", + "INFO faiss.loader: Loading faiss.\n", + "INFO faiss.loader: Successfully loaded faiss.\n", + "'(ReadTimeoutError(\"HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. 
(read timeout=10)\"), '(Request ID: db40ad22-bc00-4991-a315-d11afd10374c)')' thrown while requesting HEAD https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/config.json\n", + "WARNING huggingface_hub.utils._http: '(ReadTimeoutError(\"HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)\"), '(Request ID: db40ad22-bc00-4991-a315-d11afd10374c)')' thrown while requesting HEAD https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/config.json\n", + "Retrying in 1s [Retry 1/5].\n", + "WARNING huggingface_hub.utils._http: Retrying in 1s [Retry 1/5].\n", + "INFO __main__: Created Embeddings instance with model=sentence-transformers/all-MiniLM-L6-v2\n" + ] + } + ], + "source": [ + "from txtai.embeddings import Embeddings\n", + "\n", + "# Create a fresh in-memory index using a small sentence-transformer.\n", + "embeddings = Embeddings(\n", + " {\n", + " \"path\": \"sentence-transformers/all-MiniLM-L6-v2\",\n", + " \"content\": True,\n", + " }\n", + ")\n", + "_LOG.info(\"Created Embeddings instance with model=%s\", embeddings.config[\"path\"])" + ] + }, + { + "cell_type": "markdown", + "id": "f336b9cb", + "metadata": {}, + "source": [ + "## 2. Index a few documents\n", + "\n", + "`index()` accepts an iterable of `(id, text, tags)` tuples. The `tags`\n", + "field is what enables SQL-style metadata filtering later." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "92b64099", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO __main__: Indexed 6 documents\n" + ] + } + ], + "source": [ + "# Build a tiny corpus: three SEC-flavored snippets and three news-flavored ones.\n", + "docs = [\n", + " (1, \"Risk factors include macro uncertainty and supply chain disruption.\", \"sec\"),\n", + " (2, \"Management discussed gross margin trends in the Q4 10-K filing.\", \"sec\"),\n", + " (3, \"The 8-K disclosed a material acquisition closing next quarter.\", \"sec\"),\n", + " (4, \"Analysts upgraded the stock to buy citing strong services growth.\", \"news\"),\n", + " (5, \"Press release announces a partnership with a major chip maker.\", \"news\"),\n", + " (6, \"Bearish commentary appeared after the latest earnings call.\", \"news\"),\n", + "]\n", + "embeddings.index(docs)\n", + "_LOG.info(\"Indexed %d documents\", len(docs))" + ] + }, + { + "cell_type": "markdown", + "id": "f33a7a83", + "metadata": {}, + "source": [ + "## 3. Plain semantic search\n", + "\n", + "`search(query, limit)` returns the top-k by cosine similarity, with the\n", + "stored text inlined." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d8f0c645", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO __main__: score=0.414 text=Risk factors include macro uncertainty and supply chain disruption.\n", + "INFO __main__: score=0.232 text=Management discussed gross margin trends in the Q4 10-K filing.\n", + "INFO __main__: score=0.203 text=The 8-K disclosed a material acquisition closing next quarter.\n" + ] + } + ], + "source": [ + "# Run a generic semantic query.\n", + "hits = embeddings.search(\"regulatory filings and risk\", limit=3)\n", + "for h in hits:\n", + " _LOG.info(\"score=%.3f text=%s\", h[\"score\"], h[\"text\"][:80])" + ] + }, + { + "cell_type": "markdown", + "id": "d2c57f5c", + "metadata": {}, + "source": [ + "## 4. SQL-style metadata filter\n", + "\n", + "Because `content=True` and we passed tags, txtai exposes a SQL surface.\n", + "We can WHERE-clause on the `tags` column to scope a search to a single\n", + "source \u2014 this is exactly how `app/agents/research_agent.py` keeps SEC and\n", + "News retrievals separated." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "eba61453", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO __main__: score=0.414 tag=sec text=Risk factors include macro uncertainty and supply chain disruption.\n", + "INFO __main__: score=0.232 tag=sec text=Management discussed gross margin trends in the Q4 10-K filing.\n", + "INFO __main__: score=0.203 tag=sec text=The 8-K disclosed a material acquisition closing next quarter.\n" + ] + } + ], + "source": [ + "# Restrict the same query to the SEC partition only.\n", + "sec_only = embeddings.search(\n", + " \"select id, text, score from txtai where similar(:q) and tags = 'sec'\",\n", + " parameters={\"q\": \"regulatory filings and risk\"},\n", + " limit=3,\n", + ")\n", + "for h in sec_only:\n", + " _LOG.info(\"score=%.3f tag=sec text=%s\", h[\"score\"], h[\"text\"][:80])" + ] + }, + { + "cell_type": "markdown", + "id": "7b8ccce9", + "metadata": {}, + "source": [ + "## 5. Persist and reload\n", + "\n", + "The shared production index lives on disk so multiple processes (the\n", + "collector, the API, the eval script) can hit it without rebuilding." 
LLM wrapper (optional)\n", + "\n", + "`txtai.LLM` is a thin abstraction over OpenAI-compatible endpoints (or\n", + "local Ollama). The synthesizer in `research_agent.py` uses it when\n", + "`LLM_BASE_URL`/`LLM_API_KEY`/`LLM_MODEL` are set, and falls back to an\n", + "extractive template when they are not." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "81efb963", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO __main__: LLM credentials not set; skipping. The pipeline still runs extractively.\n" + ] + } + ], + "source": [ + "# This cell is illustrative only \u2014 actual instantiation is gated on env vars.\n", + "import os\n", + "\n", + "if os.getenv(\"LLM_API_KEY\") and os.getenv(\"LLM_BASE_URL\"):\n", + " from txtai import LLM\n", + "\n", + " llm = LLM(model=os.getenv(\"LLM_MODEL\", \"gpt-4o-mini\"))\n", + " _LOG.info(\"LLM wrapper ready: %s\", llm)\n", + "else:\n", + " _LOG.info(\"LLM credentials not set; skipping. 
The pipeline still runs extractively.\")" + ] + }, + { + "cell_type": "markdown", + "id": "a7e92a48", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "- `Embeddings(content=True)` gives us vectors plus a SQL-queryable\n", + " content store backed by SQLite\n", + "- Tagged documents enable per-source filtering with `WHERE tags = ...`\n", + "- `save()` / `load()` round-trip an index so the API server reads the\n", + " same artifact the collectors wrote\n", + "- `LLM` is optional; the system degrades to extractive answers without it" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:percent" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.API.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.API.py new file mode 100644 index 000000000..c610bc7a1 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.API.py @@ -0,0 +1,173 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.19.1 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% [markdown] +# # txtai API Tour +# +# Standalone walkthrough of the txtai primitives we use in this project. 
Each +# cell exercises one concept in isolation so that someone new to txtai can +# read the cells top-to-bottom and understand what the building blocks do +# before seeing them composed in `txtai.example.ipynb`. +# +# ## txtai endpoints exercised in this notebook +# +# The notebook is organized around the small surface of the `txtai` library +# that the project depends on. Each section below maps to one cell. +# +# | # | Endpoint | What it does | Cell | +# |---|-----------------------------------------|---------------------------------------------------------|------| +# | 1 | `txtai.embeddings.Embeddings(config)` | Build a vector index plus content store | 1 | +# | 2 | `Embeddings.index(rows)` | Bulk-load `(id, text, tags)` tuples into the index | 2 | +# | 3 | `Embeddings.search(query, limit)` | Pure semantic top-k | 3 | +# | 4 | `Embeddings.search(sql, parameters, …)` | SQL-style filter (`WHERE tags = 'sec'`) over the index | 4 | +# | 5 | `Embeddings.save(path)` / `.load(path)` | Persist and reload an index from disk | 5 | +# | 5 | `Embeddings.count()` | Number of rows currently in the index | 5 | +# | 6 | `txtai.LLM(model)` | Optional OpenAI-compatible LLM wrapper | 6 | +# +# These are the only `txtai` calls the project relies on — everything else +# (routing, agents, retrieval) is built on top of them in +# `app/pipeline/embeddings.py` and `app/agents/research_agent.py`. + +# %% +# %load_ext autoreload +# %autoreload 2 + +# System libraries. +import logging + +# Third party libraries. +import numpy as np + +# %% +# Local utility for notebook logging setup. +_LOG = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s") + +# %% [markdown] +# ## 1. Embeddings index +# +# `txtai.Embeddings` is the workhorse: a vector index plus a content store. +# We pass `content=True` so the original text is stored alongside the +# vectors, which lets us return it from search results without a separate +# document store. 
+ +# %% +from txtai.embeddings import Embeddings + +# Create a fresh in-memory index using a small sentence-transformer. +embeddings = Embeddings( + { + "path": "sentence-transformers/all-MiniLM-L6-v2", + "content": True, + } +) +_LOG.info("Created Embeddings instance with model=%s", embeddings.config["path"]) + +# %% [markdown] +# ## 2. Index a few documents +# +# `index()` accepts an iterable of `(id, text, tags)` tuples. The `tags` +# field is what enables SQL-style metadata filtering later. + +# %% +# Build a tiny corpus: three SEC-flavored snippets and three news-flavored ones. +docs = [ + (1, "Risk factors include macro uncertainty and supply chain disruption.", "sec"), + (2, "Management discussed gross margin trends in the Q4 10-K filing.", "sec"), + (3, "The 8-K disclosed a material acquisition closing next quarter.", "sec"), + (4, "Analysts upgraded the stock to buy citing strong services growth.", "news"), + (5, "Press release announces a partnership with a major chip maker.", "news"), + (6, "Bearish commentary appeared after the latest earnings call.", "news"), +] +embeddings.index(docs) +_LOG.info("Indexed %d documents", len(docs)) + +# %% [markdown] +# ## 3. Plain semantic search +# +# `search(query, limit)` returns the top-k by cosine similarity, with the +# stored text inlined. + +# %% +# Run a generic semantic query. +hits = embeddings.search("regulatory filings and risk", limit=3) +for h in hits: + _LOG.info("score=%.3f text=%s", h["score"], h["text"][:80]) + +# %% [markdown] +# ## 4. SQL-style metadata filter +# +# Because `content=True` and we passed tags, txtai exposes a SQL surface. +# We can WHERE-clause on the `tags` column to scope a search to a single +# source — this is exactly how `app/agents/research_agent.py` keeps SEC and +# News retrievals separated. + +# %% +# Restrict the same query to the SEC partition only. 
+sec_only = embeddings.search( + "select id, text, score from txtai where similar(:q) and tags = 'sec'", + parameters={"q": "regulatory filings and risk"}, + limit=3, +) +for h in sec_only: + _LOG.info("score=%.3f tag=sec text=%s", h["score"], h["text"][:80]) + +# %% [markdown] +# ## 5. Persist and reload +# +# The shared production index lives on disk so multiple processes (the +# collector, the API, the eval script) can hit it without rebuilding. + +# %% +# Save and reload the index round-trip. +import tempfile +from pathlib import Path + +with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "demo_index" + embeddings.save(str(path)) + reloaded = Embeddings() + reloaded.load(str(path)) + _LOG.info("Reloaded index has %d rows", reloaded.count()) + +# %% [markdown] +# ## 6. LLM wrapper (optional) +# +# `txtai.LLM` is a thin abstraction over OpenAI-compatible endpoints (or +# local Ollama). The synthesizer in `research_agent.py` uses it when +# `LLM_BASE_URL`/`LLM_API_KEY`/`LLM_MODEL` are set, and falls back to an +# extractive template when they are not. + +# %% +# This cell is illustrative only — actual instantiation is gated on env vars. +import os + +if os.getenv("LLM_API_KEY") and os.getenv("LLM_BASE_URL"): + from txtai import LLM + + llm = LLM(model=os.getenv("LLM_MODEL", "gpt-4o-mini")) + _LOG.info("LLM wrapper ready: %s", llm) +else: + _LOG.info("LLM credentials not set; skipping. 
The pipeline still runs extractively.") + +# %% [markdown] +# ## Summary +# +# - `Embeddings(content=True)` gives us vectors plus a SQL-queryable +# content store backed by SQLite +# - Tagged documents enable per-source filtering with `WHERE tags = ...` +# - `save()` / `load()` round-trip an index so the API server reads the +# same artifact the collectors wrote +# - `LLM` is optional; the system degrades to extractive answers without it diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.example.ipynb b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.example.ipynb new file mode 100644 index 000000000..04fff6670 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.example.ipynb @@ -0,0 +1,417 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "50742d7e", + "metadata": {}, + "source": [ + "# txtai Market Research \u2014 End-to-End Example\n", + "\n", + "Walks through the full project pipeline against AAPL:\n", + "\n", + "1. Ingest a small batch of SEC filings + news articles\n", + "2. Confirm the documents land in the txtai index\n", + "3. Drive the agentic research pipeline (router -> retrievers -> synthesizer)\n", + "4. Inspect routing, retrievals, and the synthesized answer\n", + "\n", + "Prerequisites: `docker-compose up -d` must be running so KeyDB,\n", + "PostgreSQL+pgvector, and MinIO are reachable. Secrets must be present in\n", + "`.env` (`SEC_USER_AGENT`, `NEWSAPI_KEY`, `ALPHAVANTAGE_API_KEY`,\n", + "`OPENAI_API_KEY`).\n", + "\n", + "## Endpoints exercised in this notebook\n", + "\n", + "The notebook touches three layers of the project. 
Knowing which call\n", + "lives where makes it easy to swap in a different driver (CLI, HTTP,\n", + "Streamlit) for the same pipeline.\n", + "\n", + "### Python entry points (used directly by this notebook)\n", + "\n", + "| Symbol | Source | Purpose | Cell |\n", + "|----------------------------------------------|-------------------------------------|----------------------------------------------------------|------|\n", + "| `app.collectors.SECCollector.collect(...)` | `app/collectors/sec_collector.py` | Pull filings into MinIO + Postgres + txtai | 1 |\n", + "| `app.pipeline.embeddings.get_embeddings()` | `app/pipeline/embeddings.py` | Singleton `txtai.Embeddings` index used by every agent | 2 |\n", + "| `app.pipeline.embeddings.search(query, k)` | `app/pipeline/embeddings.py` | Thin wrapper around `Embeddings.search` | 2 |\n", + "| `app.agents.research_agent.run_research_sync(query)` | `app/agents/research_agent.py` | Sync route -> retrieve -> synthesize, returns one dict | 3 |\n", + "| `app.agents.research_agent.run_research(query)` | `app/agents/research_agent.py` | Generator form, yields one event per pipeline stage | 4 |\n", + "\n", + "### HTTP endpoints (the same pipeline, exposed for UIs)\n", + "\n", + "Defined in `app/api/server.py` \u2014 not invoked from this notebook, but the\n", + "request bodies match the function signatures above.\n", + "\n", + "| Verb / Path | Body | Behavior |\n", + "|-------------------------|-----------------------|-------------------------------------------------------------------------|\n", + "| `GET /` | - | Health probe, returns the registered endpoint map |\n", + "| `POST /research` | `{\"query\": \"...\"}` | Calls `run_research_sync`, returns a JSON dict with the full trace |\n", + "| `POST /research/stream` | `{\"query\": \"...\"}` | SSE stream of `route` -> `retrieve` -> `synthesize` -> `done` events |\n", + "\n", + "### txtai primitives under the hood\n", + "\n", + "Every retrieval cell ultimately resolves to a 
`txtai.Embeddings` call:\n", + "\n", + "- `Embeddings.search(query, limit)` for plain semantic top-k\n", + "- `Embeddings.search(\"... where tags = 'sec'\", parameters=..., limit=...)`\n", + " for per-source scoping inside the SEC and News sub-agents\n", + "- `Embeddings.save(path)` / `.load(path)` for the on-disk index in `data/`\n", + "- Optional `txtai.LLM(model)` in the synthesizer when LLM credentials are set\n", + "\n", + "See `notebooks/txtai.API.ipynb` for each of those primitives in isolation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "65f5c84c", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# System libraries.\n", + "import logging\n", + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Third party libraries.\n", + "from dotenv import load_dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "36fe335e", + "metadata": {}, + "outputs": [], + "source": [ + "# Make the project root importable when running the notebook from notebooks/.\n", + "project_root = Path.cwd().parent\n", + "if str(project_root) not in sys.path:\n", + " sys.path.insert(0, str(project_root))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "409191d6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO __main__: OPENAI_API_KEY present: False\n", + "INFO __main__: SEC_USER_AGENT present: True\n" + ] + } + ], + "source": [ + "# Load secrets from .env and configure notebook logging.\n", + "load_dotenv(project_root / \".env\")\n", + "_LOG = logging.getLogger(__name__)\n", + "logging.basicConfig(level=logging.INFO, format=\"%(levelname)s %(name)s: %(message)s\")\n", + "_LOG.info(\"OPENAI_API_KEY present: %s\", bool(os.getenv(\"OPENAI_API_KEY\")))\n", + "_LOG.info(\"SEC_USER_AGENT present: %s\", bool(os.getenv(\"SEC_USER_AGENT\")))" + ] + }, + { + "cell_type": "markdown", + "id": 
"6cbb19dc", + "metadata": {}, + "source": [ + "## 1. Ingest a small batch\n", + "\n", + "Pull a handful of recent SEC filings for AAPL into MinIO + PostgreSQL +\n", + "the txtai index. Limit kept small so the cell finishes in a couple of\n", + "minutes." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f879966b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/gprakash/src/umd_classes1/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/.venv/lib/python3.14/site-packages/tika/__init__.py:20: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " __import__('pkg_resources').declare_namespace(__name__)\n", + "INFO faiss.loader: Loading faiss.\n", + "INFO faiss.loader: Successfully loaded faiss.\n", + "INFO app.storage.cold_storage.minio_client: MinIO client created for endpoint: localhost:9000\n", + "INFO app.storage.cold_storage.minio_client: Connected to MinIO at localhost:9000\n", + "INFO app.storage.warm_storage.pgvector_client: PostgreSQL connection pool created: min=1, max=10\n", + "INFO app.storage.warm_storage.pgvector_client: Connected to PostgreSQL at localhost:5432/financial_kb\n", + "INFO app.storage.hot_storage.keydb_client: Connected to KeyDB at localhost:6379\n", + "INFO app.storage.cache_manager: CacheManager initialized\n", + "INFO app.collectors.base_collector: Starting collection for sec ticker=AAPL\n", + "INFO app.collectors.sec_collector: Resolved AAPL \u2192 CIK 0000320193\n", + "INFO app.collectors.sec_collector: Fetching submission history for AAPL (CIK 0000320193), max_filings=10000\u2026\n", + "INFO app.collectors.sec_collector: CIK 0000320193: found 1 extra submission files, fetching all...\n", + "INFO 
app.collectors.sec_collector: Submission history: 2229 filings on record\n", + "INFO app.collectors.sec_collector: [10-K] matched 32 filings (requested 100)\n", + "INFO app.collectors.sec_collector: [8-K] matched 100 filings (requested 100)\n", + "INFO app.collectors.sec_collector: Total filings to fetch: 132 (types: 10-K, 8-K)\n", + "INFO app.collectors.sec_collector: Fetching batch 1-100 of 132 filings (concurrency=8)\u2026\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019394000016/000032019394000016-index.htm \u2014 retrying in 2s (attempt 1/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000091205799010244/000091205799010244-index.htm \u2014 retrying in 2s (attempt 1/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000104746998001822/000104746998001822-index.htm \u2014 retrying in 2s (attempt 1/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000104746997006960/000104746997006960-index.htm \u2014 retrying in 2s (attempt 1/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019395000016/000032019395000016-index.htm \u2014 retrying in 2s (attempt 1/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019396000023/000032019396000023-index.htm \u2014 retrying in 2s (attempt 1/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019394000016/000032019394000016-index.htm \u2014 retrying in 4s (attempt 2/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000091205799010244/000091205799010244-index.htm \u2014 retrying in 4s (attempt 2/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on 
https://www.sec.gov/Archives/edgar/data/320193/000104746998001822/000104746998001822-index.htm \u2014 retrying in 4s (attempt 2/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000104746997006960/000104746997006960-index.htm \u2014 retrying in 4s (attempt 2/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019395000016/000032019395000016-index.htm \u2014 retrying in 4s (attempt 2/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019396000023/000032019396000023-index.htm \u2014 retrying in 4s (attempt 2/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019394000016/000032019394000016-index.htm \u2014 retrying in 6s (attempt 3/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000091205799010244/000091205799010244-index.htm \u2014 retrying in 6s (attempt 3/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000104746998001822/000104746998001822-index.htm \u2014 retrying in 6s (attempt 3/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019395000016/000032019395000016-index.htm \u2014 retrying in 6s (attempt 3/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000104746997006960/000104746997006960-index.htm \u2014 retrying in 6s (attempt 3/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019396000023/000032019396000023-index.htm \u2014 retrying in 6s (attempt 3/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019394000016/000032019394000016-index.htm \u2014 retrying in 8s 
(attempt 4/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000104746998001822/000104746998001822-index.htm \u2014 retrying in 8s (attempt 4/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019395000016/000032019395000016-index.htm \u2014 retrying in 8s (attempt 4/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000091205799010244/000091205799010244-index.htm \u2014 retrying in 8s (attempt 4/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000104746997006960/000104746997006960-index.htm \u2014 retrying in 8s (attempt 4/4)\n", + "WARNING app.collectors.sec_collector: HTTP 503 on https://www.sec.gov/Archives/edgar/data/320193/000032019396000023/000032019396000023-index.htm \u2014 retrying in 8s (attempt 4/4)\n", + "INFO app.collectors.sec_collector: Fetching batch 101-132 of 132 filings (concurrency=8)\u2026\n", + "INFO app.collectors.sec_collector: Done: 2 filings fetched in 30.3s (15.15s avg/filing)\n", + "INFO app.collectors.base_collector: Fetched 2 documents\n", + "INFO app.collectors.base_collector: Inserted 2 filing rows\n", + "INFO app.collectors.base_collector: Inserted 7 chunks\n", + "INFO app.collectors.base_collector: Inserted 63 document_metadata rows\n", + "INFO app.collectors.base_collector: Collection complete: fetched=2 cold=2 warm=7 indexed=0\n", + "INFO __main__: SEC collection summary: {'fetched': 2, 'stored_cold': 2, 'stored_warm': 7, 'indexed': 0}\n" + ] + } + ], + "source": [ + "# Run the SEC collector for one ticker.\n", + "import importlib, app.collectors.sec_collector as sc \n", + "importlib.reload(sc) \n", + "from app.collectors import SECCollector\n", + "\n", + "sec = SECCollector()\n", + "sec_summary = sec.collect(\n", + " ticker=\"AAPL\",\n", + " filing_types=[\"10-K\", \"8-K\"],\n", + " limit=2,\n", + 
")\n", + "_LOG.info(\"SEC collection summary: %s\", sec_summary)" + ] + }, + { + "cell_type": "markdown", + "id": "ae854b2e", + "metadata": {}, + "source": [ + "## 2. Confirm the index has content\n", + "\n", + "`get_embeddings()` returns the singleton index used by every agent." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1b50f0d8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO __main__: Index row count: 7\n", + "INFO __main__: score=0.516 text=The information contained in this Current Report shall not be deemed \u201cfiled\u201d for purposes of Section 18 of the Securitie\n", + "INFO __main__: score=0.498 text=aapl-20260430 false 0000320193 0000320193 2026-04-30 2026-04-30 0000320193 us-gaap:CommonStockMember 2026-04-30 2026-04-\n", + "INFO __main__: score=0.495 text=Date: April 20, 2026 \n", + " \n", + "\n", + " \n", + " Apple Inc. \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \u00a0 \n", + "\n", + " \u00a0 \n", + "\n", + " \u00a0 \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \u00a0 \n", + "\n", + " \n", + " By: \n", + " \n", + "\n", + " \n", + " /s/ Jennifer Newstead \n", + " \n", + " \n", + "\n", + " \n", + "\n", + "\n" + ] + } + ], + "source": [ + "# Spot-check the index by issuing one semantic query.\n", + "from app.pipeline.embeddings import get_embeddings, search\n", + "\n", + "embeddings = get_embeddings()\n", + "_LOG.info(\"Index row count: %d\", embeddings.count())\n", + "hits = search(\"Apple revenue trend\", limit=3)\n", + "for h in hits:\n", + " _LOG.info(\"score=%.3f text=%s\", h[\"score\"], (h.get(\"text\") or \"\")[:120])" + ] + }, + { + "cell_type": "markdown", + "id": "67044bdd", + "metadata": {}, + "source": [ + "## 3. Run the agentic pipeline synchronously\n", + "\n", + "`run_research_sync` is the same entry point the FastAPI `/research`\n", + "handler calls. It returns a single dict with the full trace." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ae2d4070", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO app.agents.research_agent: SEC agent query=What are the key risks discussed in Apple's latest 10-K? ticker=AAPL\n", + "INFO __main__: Route: {'query': \"What are the key risks discussed in Apple's latest 10-K?\", 'ticker': 'AAPL', 'agents': ['sec'], 'reason': 'Question mentions SEC / filings keywords; routing to SEC agent only.', 'elapsed_ms': 0.07258300320245326, 'step_ms': 0.06799999391660094}\n", + "INFO __main__: Used LLM: False\n", + "INFO __main__: Chunk count: 5\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u2610 Item 5.02 Departure of Directors or Certain Officers; Election of Directors; Appointment of Certain Officers; Compensatory Arrangements of Certain Officers . On April 20, 2026, Apple Inc. (\u201cApple\u201d) announced that Tim Cook will transition from his role as Chief Executive Officer to Executive Chair of Apple\u2019s Board of Directors (the \u201cBoard\u201d), effective September 1, 2026 (the \u201cTransition Date\u201d). [1] The information contained in this Current Report shall not be deemed \u201cfiled\u201d for purposes of Section 18 of the Securities Exchange Act of 1934, as amended (the \u201cExchange Act\u201d), or incorporated by reference in any filing under the Securities Act of 1933, as amended, or the Exchange Act, except as shall be expressly set forth by specific reference in such a filing. Item 9.01 Financial Statements and Exhibits. (d) Exhibits. 
[2] aapl-20260430 false 0000320193 0000320193 2026-04-30 2026-04-30 0000320193 us-gaap:CommonStockMember 2026-04-30 2026-04-30 0000320193 aapl:A1.625NotesDue2026Member 2026-04-30 2026-04-30 0000320193 aapl:A2.000NotesDue2027Member 2026-04-30 2026-04-30 0000320193 aapl:A1.375NotesDue2029Member 2026-04-30 2026-04-30 0000320193 aapl:A3.050NotesDue2029Member 2026-04-30 2026-04-30 0000320193 aapl:A0.500Notesdue2031Member 2026-04-30 2026-04-30 0000320193 aapl:A3.600NotesDue2042Member 2026-04-30 2026-04-30 UNITED STATES SECURITIES AND EXCHANGE COMMISSION Washington, D.C. 20549 FORM 8-K CURRENT REPORT Pursuant to Section 13 OR 15(d) of The Securities Exchange Act of 1934 April 30, 2026 Date of Report (Date of earliest event reported) Apple Inc. (Exact name of Registrant as specified in its charter) California 001-36743 94-2404110 (State or other jurisdiction of incorporation) (Commission File Number) (I.R.S. Employer Identification No.) One Apple Park Way Cupertino , California 95014 (Address of principal executive offices) (Zip Code) ( 408 ) 996-1010 (Registrant\u2019s telephone number, including area code) Not applicable (Former name or former address, if changed since last report.) [3]\n" + ] + } + ], + "source": [ + "# Drive the full route -> retrieve -> synthesize pipeline.\n", + "from app.agents.research_agent import run_research_sync\n", + "\n", + "result = run_research_sync(\"What are the key risks discussed in Apple's latest 10-K?\")\n", + "_LOG.info(\"Route: %s\", result.get(\"route\"))\n", + "_LOG.info(\"Used LLM: %s\", result.get(\"used_llm\"))\n", + "_LOG.info(\"Chunk count: %d\", result.get(\"chunk_count\", 0))\n", + "print(result.get(\"answer\", \"\"))" + ] + }, + { + "cell_type": "markdown", + "id": "653d324a", + "metadata": {}, + "source": [ + "## 4. Stream the same query\n", + "\n", + "`run_research` is the generator form. The FastAPI SSE endpoint and the\n", + "Streamlit UI consume this so they can show each stage as it happens." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "17d8bfcb", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO __main__: [route] payload keys=['query', 'status', 'elapsed_ms']\n", + "INFO __main__: [route] payload keys=['query', 'ticker', 'agents', 'reason', 'elapsed_ms', 'step_ms']\n", + "INFO __main__: [retrieve] agent=news chunks=0\n", + "INFO app.agents.research_agent: News agent query=Any analyst upgrades for AAPL recently? ticker=AAPL\n", + "INFO __main__: [retrieve] agent=news chunks=0\n", + "INFO __main__: [synthesize] answer=\n", + "INFO __main__: [synthesize] answer=I couldn't find relevant documents about **AAPL**. Searched the news index. Try rephrasing or asking about a different ticker (we have data for AAPL, MSFT, NVDA\n", + "INFO __main__: [done] payload keys=['chunk_count', 'timings']\n" + ] + } + ], + "source": [ + "# Iterate over the streaming events and log each one as it arrives.\n", + "from app.agents.research_agent import run_research\n", + "\n", + "for event in run_research(\"Any analyst upgrades for AAPL recently?\"):\n", + " step = event[\"step\"]\n", + " if step == \"synthesize\":\n", + " _LOG.info(\"[%s] answer=%s\", step, event[\"payload\"].get(\"answer\", \"\")[:160])\n", + " elif step == \"retrieve\":\n", + " chunks = event[\"payload\"].get(\"chunks\", [])\n", + " _LOG.info(\"[%s] agent=%s chunks=%d\", step, event[\"payload\"].get(\"agent\"), len(chunks))\n", + " else:\n", + " _LOG.info(\"[%s] payload keys=%s\", step, list(event[\"payload\"].keys()))" + ] + }, + { + "cell_type": "markdown", + "id": "47cf58aa", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "- The full ingest -> index -> search -> agent pipeline runs end-to-end\n", + " from a notebook with no API server required\n", + "- The same `run_research_sync` function powers the FastAPI handler and\n", + " the Streamlit UI; this notebook is the primary smoke test for the\n", + " research 
pipeline\n", + "- To extend: add another collector (`app/collectors/my_collector.py`),\n", + " re-run the ingest cell, and the agent picks up the new chunks\n", + " automatically because they share the txtai index" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,py:percent" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.example.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.example.py new file mode 100644 index 000000000..c992c4115 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/notebooks/txtai.example.py @@ -0,0 +1,179 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.19.1 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% [markdown] +# # txtai Market Research — End-to-End Example +# +# Walks through the full project pipeline against AAPL: +# +# 1. Ingest a small batch of SEC filings + news articles +# 2. Confirm the documents land in the txtai index +# 3. Drive the agentic research pipeline (router -> retrievers -> synthesizer) +# 4. Inspect routing, retrievals, and the synthesized answer +# +# Prerequisites: `docker-compose up -d` must be running so KeyDB, +# PostgreSQL+pgvector, and MinIO are reachable. 
Secrets must be present in +# `.env` (`SEC_USER_AGENT`, `NEWSAPI_KEY`, `ALPHAVANTAGE_API_KEY`, +# `OPENAI_API_KEY`). +# +# ## Endpoints exercised in this notebook +# +# The notebook touches three layers of the project. Knowing which call +# lives where makes it easy to swap in a different driver (CLI, HTTP, +# Streamlit) for the same pipeline. +# +# ### Python entry points (used directly by this notebook) +# +# | Symbol | Source | Purpose | Cell | +# |----------------------------------------------|-------------------------------------|----------------------------------------------------------|------| +# | `app.collectors.SECCollector.collect(...)` | `app/collectors/sec_collector.py` | Pull filings into MinIO + Postgres + txtai | 1 | +# | `app.pipeline.embeddings.get_embeddings()` | `app/pipeline/embeddings.py` | Singleton `txtai.Embeddings` index used by every agent | 2 | +# | `app.pipeline.embeddings.search(query, k)` | `app/pipeline/embeddings.py` | Thin wrapper around `Embeddings.search` | 2 | +# | `app.agents.research_agent.run_research_sync(query)` | `app/agents/research_agent.py` | Sync route -> retrieve -> synthesize, returns one dict | 3 | +# | `app.agents.research_agent.run_research(query)` | `app/agents/research_agent.py` | Generator form, yields one event per pipeline stage | 4 | +# +# ### HTTP endpoints (the same pipeline, exposed for UIs) +# +# Defined in `app/api/server.py` — not invoked from this notebook, but the +# request bodies match the function signatures above. 
+# +# | Verb / Path | Body | Behavior | +# |-------------------------|-----------------------|-------------------------------------------------------------------------| +# | `GET /` | - | Health probe, returns the registered endpoint map | +# | `POST /research` | `{"query": "..."}` | Calls `run_research_sync`, returns a JSON dict with the full trace | +# | `POST /research/stream` | `{"query": "..."}` | SSE stream of `route` -> `retrieve` -> `synthesize` -> `done` events | +# +# ### txtai primitives under the hood +# +# Every retrieval cell ultimately resolves to a `txtai.Embeddings` call: +# +# - `Embeddings.search(query, limit)` for plain semantic top-k +# - `Embeddings.search("... where tags = 'sec'", parameters=..., limit=...)` +# for per-source scoping inside the SEC and News sub-agents +# - `Embeddings.save(path)` / `.load(path)` for the on-disk index in `data/` +# - Optional `txtai.LLM(model)` in the synthesizer when LLM credentials are set +# +# See `notebooks/txtai.API.ipynb` for each of those primitives in isolation. + +# %% +# %load_ext autoreload +# %autoreload 2 + +# System libraries. +import logging +import os +import sys +from pathlib import Path + +# Third party libraries. +from dotenv import load_dotenv + +# %% +# Make the project root importable when running the notebook from notebooks/. +project_root = Path.cwd().parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# %% +# Load secrets from .env and configure notebook logging. +load_dotenv(project_root / ".env") +_LOG = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s") +_LOG.info("OPENAI_API_KEY present: %s", bool(os.getenv("OPENAI_API_KEY"))) +_LOG.info("SEC_USER_AGENT present: %s", bool(os.getenv("SEC_USER_AGENT"))) + +# %% [markdown] +# ## 1. Ingest a small batch +# +# Pull a handful of recent SEC filings for AAPL into MinIO + PostgreSQL + +# the txtai index. 
Limit kept small so the cell finishes in a couple of +# minutes. + +# %% +# Run the SEC collector for one ticker. +import importlib, app.collectors.sec_collector as sc +importlib.reload(sc) +from app.collectors import SECCollector + +sec = SECCollector() +sec_summary = sec.collect( + ticker="AAPL", + filing_types=["10-K", "8-K"], + limit=2, +) +_LOG.info("SEC collection summary: %s", sec_summary) + +# %% [markdown] +# ## 2. Confirm the index has content +# +# `get_embeddings()` returns the singleton index used by every agent. + +# %% +# Spot-check the index by issuing one semantic query. +from app.pipeline.embeddings import get_embeddings, search + +embeddings = get_embeddings() +_LOG.info("Index row count: %d", embeddings.count()) +hits = search("Apple revenue trend", limit=3) +for h in hits: + _LOG.info("score=%.3f text=%s", h["score"], (h.get("text") or "")[:120]) + +# %% [markdown] +# ## 3. Run the agentic pipeline synchronously +# +# `run_research_sync` is the same entry point the FastAPI `/research` +# handler calls. It returns a single dict with the full trace. + +# %% +# Drive the full route -> retrieve -> synthesize pipeline. +from app.agents.research_agent import run_research_sync + +result = run_research_sync("What are the key risks discussed in Apple's latest 10-K?") +_LOG.info("Route: %s", result.get("route")) +_LOG.info("Used LLM: %s", result.get("used_llm")) +_LOG.info("Chunk count: %d", result.get("chunk_count", 0)) +print(result.get("answer", "")) + +# %% [markdown] +# ## 4. Stream the same query +# +# `run_research` is the generator form. The FastAPI SSE endpoint and the +# Streamlit UI consume this so they can show each stage as it happens. + +# %% +# Iterate over the streaming events and log each one as it arrives. 
+from app.agents.research_agent import run_research + +for event in run_research("Any analyst upgrades for AAPL recently?"): + step = event["step"] + if step == "synthesize": + _LOG.info("[%s] answer=%s", step, event["payload"].get("answer", "")[:160]) + elif step == "retrieve": + chunks = event["payload"].get("chunks", []) + _LOG.info("[%s] agent=%s chunks=%d", step, event["payload"].get("agent"), len(chunks)) + else: + _LOG.info("[%s] payload keys=%s", step, list(event["payload"].keys())) + +# %% [markdown] +# ## Summary +# +# - The full ingest -> index -> search -> agent pipeline runs end-to-end +# from a notebook with no API server required +# - The same `run_research_sync` function powers the FastAPI handler and +# the Streamlit UI; this notebook is the primary smoke test for the +# research pipeline +# - To extend: add another collector (`app/collectors/my_collector.py`), +# re-run the ingest cell, and the agent picks up the new chunks +# automatically because they share the txtai index diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/requirements.txt b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/requirements.txt new file mode 100644 index 000000000..5f0b20f83 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/requirements.txt @@ -0,0 +1,38 @@ +# txtai Market Research Platform - Dependencies + +# Core txtai with all extras (includes sentence-transformers, transformers, torch) +txtai[all]>=7.0.0 + +# Streamlit for UI +streamlit>=1.28.0 + +# FastAPI and uvicorn (for potential API deployment) +fastapi>=0.104.0 +uvicorn>=0.24.0 + +# HTTP client for API calls +httpx>=0.25.0 + +# Web scraping +beautifulsoup4>=4.12.0 + +# Reddit API (optional, for social media fetching) +praw>=7.7.0 + +# Charting library +plotly>=5.18.0 + +# Environment variable management +python-dotenv>=1.0.0 + +# 
OpenAI API client (txtai wrapper handles actual calls) +openai>=1.3.0 + +# Redis client for KeyDB (hot tier cache) +redis>=5.0.0 + +# PostgreSQL client with connection pooling (warm tier) +psycopg2-binary>=2.9.9 + +# MinIO client for cold tier object storage +minio>=7.2.0 diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/run_jupyter.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/run_jupyter.sh new file mode 100755 index 000000000..4f0bbe316 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/run_jupyter.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# """ +# Launch Jupyter Lab server. +# +# This script starts Jupyter Lab on port 8888 with the following configuration: +# - No browser auto-launch (useful for Docker containers) +# - Accessible from any IP address (0.0.0.0) +# - Root user allowed (required for Docker environments) +# - No authentication token or password (for development convenience) +# - Vim keybindings can be enabled via JUPYTER_USE_VIM environment variable +# """ + +# Exit immediately if any command exits with a non-zero status. +set -e + +# Print each command to stdout before executing it. +#set -x + +# Import the utility functions from the project template (path quoted so it is never word-split). +GIT_ROOT=/git_root +source "$GIT_ROOT/class_project/project_template/utils.sh" + +# Load Docker configuration variables for this script; expansions are quoted so paths containing spaces or glob characters are passed intact. +get_docker_vars_script "${BASH_SOURCE[0]}" +source "$DOCKER_NAME" +print_docker_vars + +# Configure vim keybindings and notifications. +configure_jupyter_vim_keybindings +configure_jupyter_notifications + +# Initialize Jupyter Lab command with base configuration. +JUPYTER_ARGS=$(get_jupyter_args) + +# Start Jupyter Lab with development-friendly settings.
+run "jupyter lab $JUPYTER_ARGS" diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/__init__.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/__init__.py new file mode 100644 index 000000000..33ddf1eff --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/__init__.py @@ -0,0 +1,3 @@ +""" +Scripts for running collectors and other operations. +""" diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/backfill_txtai_from_chunks.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/backfill_txtai_from_chunks.py new file mode 100644 index 000000000..2f2f66c66 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/backfill_txtai_from_chunks.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 +""" +Backfill txtai search index from PostgreSQL chunks table. + +The warm tier already contains ~14,863 SEC chunks (with text + filing +metadata). The txtai index is empty, which leaves all search-backed +agents starved. This script reads chunks in batches and indexes them +into txtai with tag "sec" (or "news" for non-SEC form types) and per-doc metadata. + +Re-embeds with sentence-transformers/all-mpnet-base-v2 on CPU; expect +roughly 10-30 minutes for the full 14k-chunk backfill.
+ +Usage: + python -m scripts.backfill_txtai_from_chunks + python -m scripts.backfill_txtai_from_chunks --batch-size 500 --limit 1000 +""" + +import argparse +import logging +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from app.pipeline.embeddings import get_data_dir, get_embeddings, upsert +from app.storage import get_postgres_client + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) +_LOG = logging.getLogger(__name__) + + +_FETCH_QUERY = """ + SELECT + c.id, + c.text, + c.section, + f.ticker, + f.filing_type, + f.filing_date, + f.accession_number + FROM chunks c + JOIN filings f ON c.filing_id = f.id + WHERE c.text IS NOT NULL AND LENGTH(c.text) > 0 + ORDER BY c.id + LIMIT %s OFFSET %s +""" + +# SEC-issued form types — anything else collected via the news pipeline is +# tagged as "news" so the txtai source filter still works. 
+_SEC_FILING_TYPES = { + "10-K", + "10-K/A", + "10-Q", + "10-Q/A", + "8-K", + "8-K/A", + "DEF 14A", + "DEFA14A", + "DEF 14C", + "S-1", + "S-3", + "S-4", + "S-8", + "20-F", + "20-F/A", + "40-F", + "6-K", + "13F-HR", + "13F-HR/A", + "SC 13D", + "SC 13D/A", + "SC 13G", + "SC 13G/A", + "SD", +} + + +def _classify_source(filing_type: str) -> str: + """Return the txtai source tag for a row's filing_type.""" + if filing_type in _SEC_FILING_TYPES: + return "sec" + return "news" + + +def _count_chunks(client) -> int: + """Count chunks eligible for indexing.""" + with client.get_cursor() as cur: + cur.execute( + "SELECT COUNT(*) AS n FROM chunks " + "WHERE text IS NOT NULL AND LENGTH(text) > 0" + ) + return cur.fetchone()["n"] + + +def _fetch_batch(client, batch_size: int, offset: int) -> list[dict]: + """Fetch one batch of chunks joined with filing metadata.""" + with client.get_cursor() as cur: + cur.execute(_FETCH_QUERY, (batch_size, offset)) + rows = cur.fetchall() + documents = [] + for row in rows: + filing_date = row["filing_date"] + filing_type = row["filing_type"] or "" + source = _classify_source(filing_type) + documents.append( + { + "id": row["id"], + "text": row["text"], + "tags": source, + "metadata": { + "ticker": row["ticker"], + "filing_type": filing_type, + "filing_date": str(filing_date) if filing_date else None, + "section": row["section"], + "accession_number": row["accession_number"], + "source": source, + }, + } + ) + return documents + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Backfill txtai index from chunks table" + ) + parser.add_argument( + "--batch-size", + type=int, + default=500, + help="Chunks per upsert call (default: 500)", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Cap total chunks indexed (default: all)", + ) + parser.add_argument( + "--from-scratch", + action="store_true", + help="Delete existing txtai index files before backfilling", + ) + return 
parser.parse_args() + + +def main() -> int: + load_dotenv() + args = _parse_args() + if args.from_scratch: + data_dir = get_data_dir() + for f in data_dir.iterdir(): + if f.is_file(): + _LOG.info("Removing %s", f) + f.unlink() + pg_client = get_postgres_client() + total_eligible = _count_chunks(pg_client) + target = min(total_eligible, args.limit) if args.limit else total_eligible + _LOG.info("Eligible chunks in warm tier: %d", total_eligible) + _LOG.info("Will index: %d (batch=%d)", target, args.batch_size) + # Warm up embeddings model so the first batch isn't artificially slow. + embeddings = get_embeddings() + _LOG.info("Embeddings model loaded") + indexed = 0 + offset = 0 + started = time.time() + while indexed < target: + remaining = target - indexed + this_batch = min(args.batch_size, remaining) + batch = _fetch_batch(pg_client, this_batch, offset) + if not batch: + _LOG.info("No more rows at offset=%d", offset) + break + t0 = time.time() + upsert(batch, save=False) + indexed += len(batch) + offset += len(batch) + elapsed = time.time() - t0 + rate = len(batch) / elapsed if elapsed > 0 else 0 + eta_seconds = (target - indexed) / rate if rate > 0 else 0 + _LOG.info( + "Indexed %d/%d (%.1f%%) batch_time=%.1fs rate=%.1f/s eta=%.0fs", + indexed, + target, + 100.0 * indexed / target, + elapsed, + rate, + eta_seconds, + ) + # Persist ANN index files once at the end. 
+ _LOG.info("Saving txtai index to %s", get_data_dir()) + embeddings.save(str(get_data_dir())) + total_elapsed = time.time() - started + _LOG.info( + "Backfill complete: %d chunks in %.1fs (%.1f/s)", + indexed, + total_elapsed, + indexed / total_elapsed if total_elapsed > 0 else 0, + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/check_storage_status.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/check_storage_status.py new file mode 100644 index 000000000..b4dd85592 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/check_storage_status.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +""" +Storage Status Check Script + +Displays content statistics across all storage tiers: +- Hot Tier: KeyDB (cache, sessions, prices) +- Warm Tier: PostgreSQL (filings, chunks, XBRL facts) +- Cold Tier: MinIO (raw document archive) +- Search Tier: txtai (semantic search index) + +Usage: + python -m scripts.check_storage_status +""" + +import json +import logging +import sys +from pathlib import Path + +from dotenv import load_dotenv + +# Add project root to path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from app.storage import get_keydb_client, get_postgres_client, get_minio_client +from app.pipeline.embeddings import get_embeddings + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +_LOG = logging.getLogger(__name__) + + +def check_hot_tier_keydb() -> dict: + """Collect KeyDB (hot tier) stats: status, per-pattern key counts, up to 10 sample keys, and server info; failures are reported in the returned dict's "status" rather than raised.""" + result = { + "status": "unknown", + "keys": {}, + "total_keys": 0, + } + + try: + client = get_keydb_client() + if not client.ping(): + result["status"] = "disconnected" + return result + + result["status"] = "connected" + + # 
Count keys by pattern (the dict values are human-readable labels kept for reference) + patterns = { + "prices:*": "Price Cache", + "cache:*": "Semantic Cache", + "session:*": "Sessions", + "fetch:*": "Fetch Cache", + } + + for pattern in patterns: + matched = sum(1 for _ in client.scan_iter(pattern)) # count lazily; never hold the matched keys in memory + result["keys"][pattern] = matched + result["total_keys"] += matched + + # Get some sample keys: stop after 10 instead of listing the whole keyspace (count=100 is passed through to scan_iter unchanged) + key_iter = client.scan_iter("*", count=100) + result["sample_keys"] = [key for _, key in zip(range(10), key_iter)] + + # Get KeyDB info + info = client._get_client().info() + result["memory_used"] = info.get("used_memory_human", "unknown") + result["connected_clients"] = info.get("connected_clients", 0) + + except Exception as e: + result["status"] = f"error: {e}" + + return result + + +def check_warm_tier_postgres() -> dict: + """Check PostgreSQL (warm tier) content.""" + result = { + "status": "unknown", + "tables": {}, + "total_records": 0, + } + + try: + client = get_postgres_client() + if not client.ping(): + result["status"] = "disconnected" + return result + + result["status"] = "connected" + + # Get stats from each table + queries = { + "companies": "SELECT COUNT(*) FROM companies", + "filings": "SELECT COUNT(*) FROM filings", + "chunks": "SELECT COUNT(*) FROM chunks", + "document_metadata": "SELECT COUNT(*) FROM document_metadata", + "xbrl_facts": "SELECT COUNT(*) FROM xbrl_facts", + "collection_runs": "SELECT COUNT(*) FROM collection_runs", + } + + with client.get_cursor() as cur: + for table, query in queries.items(): + try: + cur.execute(query) + count = cur.fetchone()[0] + result["tables"][table] = count + result["total_records"] += count + except Exception as e: + result["tables"][table] = f"error: {e}" + + # Get filings by ticker + try: + cur.execute(""" + SELECT ticker, COUNT(*) as count, + STRING_AGG(DISTINCT filing_type, ', ') as types + FROM filings + GROUP BY ticker + ORDER BY count DESC + LIMIT 10 + """) + result["filings_by_ticker"] = [ + {"ticker": row[0], "count": row[1], "types": row[2]} + for row in cur.fetchall() + ] + except 
Exception: + pass + + # Get chunks with embeddings + try: + cur.execute("SELECT COUNT(*) FROM chunks WHERE embedding IS NOT NULL") + result["chunks_with_embeddings"] = cur.fetchone()[0] + except Exception: + pass + + # Get recent collection runs + try: + cur.execute(""" + SELECT collector, ticker, status, records_written, started_at + FROM collection_runs + ORDER BY started_at DESC + LIMIT 5 + """) + result["recent_runs"] = [ + { + "collector": row[0], + "ticker": row[1], + "status": row[2], + "records": row[3], + "started_at": str(row[4]), + } + for row in cur.fetchall() + ] + except Exception: + pass + + except Exception as e: + result["status"] = f"error: {e}" + + return result + + +def check_cold_tier_minio() -> dict: + """Check MinIO (cold tier) content.""" + result = { + "status": "unknown", + "buckets": {}, + "total_objects": 0, + } + + try: + client = get_minio_client() + if not client.ping(): + result["status"] = "disconnected" + return result + + result["status"] = "connected" + + minio_client = client._get_client() + + # Check each expected bucket + bucket_names = ["filings", "articles", "raw_docs"] + + for bucket_name in bucket_names: + try: + # Check if bucket exists + exists = minio_client.bucket_exists(bucket_name) + if not exists: + result["buckets"][bucket_name] = {"exists": False, "objects": 0} + continue + + # Count objects + objects = list(minio_client.list_objects(bucket_name, recursive=True)) + result["buckets"][bucket_name] = { + "exists": True, + "objects": len(objects), + } + result["total_objects"] += len(objects) + + except Exception as e: + result["buckets"][bucket_name] = {"exists": False, "error": str(e)} + + # Get bucket with most objects + if result["buckets"]: + max_bucket = max( + [(k, v.get("objects", 0)) for k, v in result["buckets"].items()], + key=lambda x: x[1], + default=(None, 0), + ) + result["largest_bucket"] = max_bucket[0] + result["largest_bucket_count"] = max_bucket[1] + + except Exception as e: + result["status"] = 
f"error: {e}" + + return result + + +def check_search_tier_txtai() -> dict: + """Check txtai (search tier) content.""" + result = { + "status": "unknown", + "index_count": 0, + } + + try: + embeddings = get_embeddings() + + # Use txtai's public count() API. Note: ``embeddings.index`` is a + # method that builds the index, not the loaded index object, so + # introspecting it gives a misleading 0. + try: + result["index_count"] = embeddings.count() + except Exception: + pass + + # Check data directory + from app.pipeline.embeddings import get_data_dir + + data_dir = get_data_dir() + index_file = data_dir / "index.db" + + result["data_dir"] = str(data_dir) + result["index_file_exists"] = index_file.exists() + result["index_file_size"] = ( + index_file.stat().st_size if index_file.exists() else 0 + ) + + result["status"] = "connected" + + except Exception as e: + result["status"] = f"error: {e}" + + return result + + +def print_summary(hot, warm, cold, search) -> None: + """Print formatted summary.""" + print("\n" + "=" * 70) + print("STORAGE TIER STATUS SUMMARY") + print("=" * 70) + + # Hot Tier + print("\n🔥 HOT TIER (KeyDB)") + print(f" Status: {hot['status'].upper()}") + if hot["status"] == "connected": + print(f" Total Keys: {hot['total_keys']}") + print(f" Memory Used: {hot.get('memory_used', 'N/A')}") + print(f" Connected Clients: {hot.get('connected_clients', 'N/A')}") + for pattern, count in hot.get("keys", {}).items(): + print(f" - {pattern}: {count}") + + # Warm Tier + print("\n📊 WARM TIER (PostgreSQL)") + print(f" Status: {warm['status'].upper()}") + if warm["status"] == "connected": + print(f" Total Records: {warm['total_records']}") + for table, count in warm.get("tables", {}).items(): + print(f" - {table}: {count}") + if warm.get("chunks_with_embeddings"): + print(f" Chunks with Embeddings: {warm['chunks_with_embeddings']}") + if warm.get("filings_by_ticker"): + print("\n Filings by Ticker:") + for item in warm["filings_by_ticker"]: + print(f" - 
{item['ticker']}: {item['count']} ({item['types']})") + if warm.get("recent_runs"): + print("\n Recent Collection Runs:") + for run in warm["recent_runs"]: + print( + f" - {run['collector']} ({run['ticker']}): {run['status']} - {run['records']} records" + ) + + # Cold Tier + print("\n❄️ COLD TIER (MinIO)") + print(f" Status: {cold['status'].upper()}") + if cold["status"] == "connected": + print(f" Total Objects: {cold['total_objects']}") + for bucket, info in cold.get("buckets", {}).items(): + if info.get("exists"): + print(f" - {bucket}: {info['objects']} objects") + else: + print(f" - {bucket}: (not created)") + + # Search Tier + print("\n🔍 SEARCH TIER (txtai)") + print(f" Status: {search['status'].upper()}") + if search["status"] == "connected": + print(f" Index Count: {search['index_count']}") + print(f" Data Dir: {search['data_dir']}") + print( + f" Index File: {search['index_file_exists']} ({search['index_file_size']} bytes)" + ) + + print("\n" + "=" * 70) + + +def main() -> int: + """Probe all four storage tiers, print the summary, write storage_status.json in the current directory, and return 0.""" + import time # local import: this script has no module-level import of time + load_dotenv() + + _LOG.info("Checking storage tiers...") + + # Check all tiers + hot_status = check_hot_tier_keydb() + warm_status = check_warm_tier_postgres() + cold_status = check_cold_tier_minio() + search_status = check_search_tier_txtai() + + # Print summary + print_summary(hot_status, warm_status, cold_status, search_status) + + # Save detailed status + detailed_status = { + "timestamp": time.time(), # epoch seconds of this check (was the mtime of the working directory) + "hot_tier": hot_status, + "warm_tier": warm_status, + "cold_tier": cold_status, + "search_tier": search_status, + } + + status_file = Path("storage_status.json") + # default=str serializes any value json cannot encode natively (e.g. bytes keys) + status_file.write_text(json.dumps(detailed_status, indent=2, default=str)) + _LOG.info("Detailed status saved to %s", status_file) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/eval_research.py 
b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/eval_research.py new file mode 100644 index 000000000..3dda8947f --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/eval_research.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +""" +Eval harness for the agentic research pipeline. + +Runs a fixed set of benchmark queries, repeats each ``--repeats`` times to +average out noise, then reports: + + - p50 / p95 / p99 total latency + - Per-step latency breakdown (route / retrieve / synthesize) + - Routing accuracy (extracted ticker matches expected; selected agents + match expected set) + - Retrieval health (% queries returning >= 1 chunk, mean top-1 score) + - Answer length distribution + +Usage: + python -m scripts.eval_research # full benchmark + python -m scripts.eval_research --repeats 5 # 5 runs/query + python -m scripts.eval_research --json out.json # also dump JSON +""" + +import argparse +import json +import logging +import statistics +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +# Make ``app`` importable when run from the repo root. +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from app.agents.research_agent import run_research_sync + +logging.basicConfig(level=logging.WARNING, format="%(message)s") +_LOG = logging.getLogger(__name__) + +# ############################################################################# +# Benchmark set +# ############################################################################# + +# (query, expected_ticker, expected_agents_subset) +# ``expected_agents_subset`` is a set of agents the router *must* fire; extra +# agents are OK (e.g. if the router defaults to both, that's fine when only +# {"sec"} was expected). 
+_BENCHMARK = [ + ( + "What does Apple disclose as risk factors in its 10-K?", + "AAPL", + {"sec"}, + ), + ( + "What is the recent NVDA news sentiment?", + "NVDA", + {"news"}, + ), + ( + "How does JPMorgan describe regulatory risk?", + "JPM", + {"sec"}, + ), + ( + "What are analysts saying about Tesla?", + "TSLA", + {"news"}, + ), + ( + "Summarize Microsoft 8-K disclosures", + "MSFT", + {"sec"}, + ), + ( + "Tell me about NVIDIA", + "NVDA", + {"sec", "news"}, + ), + ( + "Recent news around $AMD", + "AMD", + {"news"}, + ), + ( + "What does Goldman Sachs say in its filings about market risk?", + "GS", + {"sec"}, + ), + ( + "Wells Fargo earnings outlook", + "WFC", + {"news"}, + ), + ( + "Pfizer 10-K risk factors", + "PFE", + {"sec"}, + ), +] + + +# ############################################################################# +# Stat helpers +# ############################################################################# + + +def _percentile(values: list[float], p: float) -> float: + """ + Compute a percentile without numpy. + + :param values: sample values + :param p: percentile in [0, 100] + :return: percentile value (0.0 if the input is empty) + """ + if not values: + return 0.0 + s = sorted(values) + k = (len(s) - 1) * (p / 100.0) + lo = int(k) + hi = min(lo + 1, len(s) - 1) + frac = k - lo + return s[lo] * (1 - frac) + s[hi] * frac + + +def _summary(label: str, values: list[float], unit: str = "ms") -> str: + """ + Format a one-line summary of a numeric distribution. + + :param label: column label printed at the start of the line + :param values: sample values + :param unit: unit suffix (e.g. 
"ms", "chunks") + :return: tab-aligned summary line + """ + if not values: + return f"{label:<24} (no data)" + mean = statistics.fmean(values) + p50 = _percentile(values, 50) + p95 = _percentile(values, 95) + p99 = _percentile(values, 99) + return ( + f"{label:<24} mean={mean:7.1f}{unit} " + f"p50={p50:7.1f}{unit} " + f"p95={p95:7.1f}{unit} " + f"p99={p99:7.1f}{unit} " + f"n={len(values)}" + ) + + +# ############################################################################# +# Main +# ############################################################################# + + +def _parse_args() -> argparse.Namespace: + """ + Parse CLI args. + """ + parser = argparse.ArgumentParser(description="Eval the agentic research pipeline.") + parser.add_argument( + "--repeats", + type=int, + default=3, + help="Number of runs per query (default: 3)", + ) + parser.add_argument( + "--json", + type=str, + default=None, + help="Optional path to dump the full per-run JSON", + ) + parser.add_argument( + "--warmup", + action="store_true", + help="Run one extra warmup pass before timing", + ) + return parser.parse_args() + + +def _main() -> int: + """ + Run the benchmark and print the report. 
+ """ + load_dotenv() + args = _parse_args() + if args.warmup: + print("Warming up…") + run_research_sync(_BENCHMARK[0][0]) + print( + f"Running {len(_BENCHMARK)} queries x {args.repeats} repeats = " + f"{len(_BENCHMARK) * args.repeats} total runs" + ) + print("─" * 80) + runs: list[dict] = [] + bench_t0 = time.perf_counter() + for query, expected_ticker, expected_agents in _BENCHMARK: + print(f"\n>>> {query}") + for r in range(args.repeats): + t0 = time.perf_counter() + result = run_research_sync(query) + wall_ms = (time.perf_counter() - t0) * 1000.0 + timings = result.get("timings", {}) + route = result.get("route", {}) + actual_ticker = route.get("ticker") + actual_agents = set(route.get("agents", [])) + ticker_ok = actual_ticker == expected_ticker + agents_ok = expected_agents.issubset(actual_agents) + top_score = result["sources"][0]["score"] if result.get("sources") else 0.0 + runs.append( + { + "query": query, + "repeat": r, + "wall_ms": wall_ms, + "timings": timings, + "expected_ticker": expected_ticker, + "actual_ticker": actual_ticker, + "ticker_ok": ticker_ok, + "expected_agents": sorted(expected_agents), + "actual_agents": sorted(actual_agents), + "agents_ok": agents_ok, + "chunk_count": result.get("chunk_count", 0), + "top_score": top_score, + "answer_len": len(result.get("answer", "")), + "used_llm": result.get("used_llm", False), + } + ) + ok = "✓" if (ticker_ok and agents_ok) else "✗" + print( + f" [{r + 1}/{args.repeats}] {ok} ticker={actual_ticker} " + f"agents={sorted(actual_agents)} chunks={result.get('chunk_count', 0)} " + f"top={top_score:.3f} wall={wall_ms:.0f}ms" + ) + bench_total_s = time.perf_counter() - bench_t0 + print("\n" + "═" * 80) + print(f"REPORT ({len(runs)} runs in {bench_total_s:.1f}s)") + print("═" * 80) + # Latency. 
+ print("\n— Latency —") + print(_summary("total wall", [r["wall_ms"] for r in runs])) + for key in ["route_ms", "retrieve_sec_ms", "retrieve_news_ms", "synthesize_ms"]: + vals = [r["timings"][key] for r in runs if key in r["timings"]] + print(_summary(key, vals)) + # Routing accuracy. + print("\n— Routing accuracy —") + ticker_acc = sum(r["ticker_ok"] for r in runs) / len(runs) + agents_acc = sum(r["agents_ok"] for r in runs) / len(runs) + both_acc = sum(r["ticker_ok"] and r["agents_ok"] for r in runs) / len(runs) + print(f" ticker correct {ticker_acc:6.1%}") + print(f" agents superset {agents_acc:6.1%}") + print(f" both correct {both_acc:6.1%}") + misses = [r for r in runs if not (r["ticker_ok"] and r["agents_ok"])] + if misses: + print(" failures:") + seen = set() + for r in misses: + key = (r["query"], r["actual_ticker"], tuple(r["actual_agents"])) + if key in seen: + continue + seen.add(key) + print( + f" - {r['query'][:55]:<55} " + f"got ticker={r['actual_ticker']} agents={r['actual_agents']}" + ) + # Retrieval. + print("\n— Retrieval —") + nonempty = [r for r in runs if r["chunk_count"] > 0] + print(f" queries w/ chunks {len(nonempty) / len(runs):6.1%}") + print(_summary("chunks/query", [r["chunk_count"] for r in runs], unit=" ")) + print(_summary("top-1 score", [r["top_score"] for r in nonempty], unit="")) + # Answer. 
+ print("\n— Answer —") + print(_summary("answer length (chars)", [r["answer_len"] for r in runs], unit="")) + llm_pct = sum(r["used_llm"] for r in runs) / len(runs) + print(f" used LLM synthesis {llm_pct:6.1%}") + print() + if args.json: + Path(args.json).write_text(json.dumps(runs, indent=2)) + print(f"Wrote per-run JSON to {args.json}") + return 0 + + +if __name__ == "__main__": + sys.exit(_main()) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_all_collectors.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_all_collectors.py new file mode 100644 index 000000000..2dd057310 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_all_collectors.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +""" +Comprehensive Market Research Data Collection Script + +Runs collectors to gather data from: +- SEC EDGAR (10-K, 10-Q, 8-K, DEF 14A filings) +- News articles (NewsAPI, Alpha Vantage) + +Usage: + python -m scripts.run_all_collectors --tickers AAPL,MSFT,TSLA --sec-limit 500 + python -m scripts.run_all_collectors --sector tech --sec-limit 1000 +""" + +import argparse +import json +import logging +import sys +import time +from datetime import datetime +from pathlib import Path + +from dotenv import load_dotenv + +# Add project root to path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from app.collectors import SECCollector, NewsCollector + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +_LOG = logging.getLogger(__name__) + +# Sector definitions +SECTORS = { + "tech": ["AAPL", "MSFT", "GOOGL", "META", "NVDA", "ADBE", "CRM", "ORCL"], + "finance": ["JPM", "BAC", "WFC", "GS", "MS", "C", "BLK", "V"], + "healthcare": ["JNJ", "UNH", "PFE", "MRK", "ABBV", "TMO", "ABT", 
"LLY"], + "consumer": ["AMZN", "WMT", "HD", "PG", "KO", "PEP", "COST", "MCD"], + "ev": ["TSLA", "RIVN", "LCID", "NIO", "XPEV", "LI"], +} + +# Default filing types for SEC collection +DEFAULT_FILING_TYPES = ["10-K", "10-Q", "8-K", "DEF 14A"] + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Comprehensive Market Research Data Collection" + ) + + # Ticker selection + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--tickers", + type=str, + default="AAPL", + help="Comma-separated tickers (default: AAPL)", + ) + group.add_argument( + "--sector", + type=str, + choices=list(SECTORS.keys()), + help="Collect all tickers from a sector", + ) + + # Collection limits + parser.add_argument( + "--sec-limit", + type=int, + default=500, + help="Max SEC filings per ticker (default: 500)", + ) + parser.add_argument( + "--news-limit", + type=int, + default=50, + help="Max news articles per ticker (default: 50)", + ) + parser.add_argument( + "--days-back", type=int, default=30, help="Days back for news (default: 30)" + ) + + # Selective collection + parser.add_argument( + "--sec-only", action="store_true", help="Only run SEC collector" + ) + parser.add_argument("--skip-sec", action="store_true", help="Skip SEC collector") + parser.add_argument("--skip-news", action="store_true", help="Skip news collector") + + # Storage options + parser.add_argument( + "--no-cold", action="store_true", help="Skip cold storage (MinIO)" + ) + parser.add_argument( + "--no-warm", action="store_true", help="Skip warm storage (PostgreSQL)" + ) + parser.add_argument( + "--no-search", action="store_true", help="Skip search index (txtai)" + ) + + # Output + parser.add_argument( + "--output-stats", + type=str, + default="collection_stats.json", + help="Output file for collection statistics", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="Enable debug logging" + ) + + return 
parser.parse_args() + + +def get_ticker_list(args) -> list[str]: + """Get list of tickers to collect.""" + if args.sector: + return SECTORS[args.sector] + if args.tickers: + return [t.strip() for t in args.tickers.split(",")] + return ["AAPL"] + + +def run_sec_collector(ticker: str, args) -> dict: + """Run SEC EDGAR collector.""" + _LOG.info(f" [{ticker}] SEC EDGAR: Fetching up to {args.sec_limit} filings...") + + collector = SECCollector() + results = collector.collect( + ticker=ticker, + filing_types=DEFAULT_FILING_TYPES, + limit=args.sec_limit, + max_filings=args.sec_limit + 500, # Fetch extra to account for filtering + store_cold=not args.no_cold, + store_warm=not args.no_warm, + store_search=not args.no_search, + ) + + return { + "source": "sec", + "fetched": results.get("fetched", 0), + "stored_cold": results.get("stored_cold", 0), + "stored_warm": results.get("stored_warm", 0), + "indexed": results.get("indexed", 0), + } + + +def run_news_collector(ticker: str, args) -> dict: + """Run News collector.""" + _LOG.info(f" [{ticker}] News: Fetching up to {args.news_limit} articles...") + + collector = NewsCollector() + results = collector.collect( + ticker=ticker, + days_back=args.days_back, + limit=args.news_limit, + store_cold=not args.no_cold, + store_warm=not args.no_warm, + store_search=not args.no_search, + ) + + return { + "source": "news", + "fetched": results.get("fetched", 0), + "stored_cold": results.get("stored_cold", 0), + "stored_warm": results.get("stored_warm", 0), + "indexed": results.get("indexed", 0), + } + + +def main() -> int: + """Main entry point.""" + args = parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + load_dotenv() + + # Get ticker list + tickers = get_ticker_list(args) + + _LOG.info("=" * 70) + _LOG.info("COMPREHENSIVE MARKET RESEARCH DATA COLLECTION") + _LOG.info("=" * 70) + _LOG.info( + "Tickers: %d (%s)", + len(tickers), + ", ".join(tickers[:5]) + ("..." 
if len(tickers) > 5 else ""), + ) + _LOG.info("SEC Limit: %d filings/ticker", args.sec_limit) + _LOG.info("News Limit: %d articles/ticker", args.news_limit) + _LOG.info("Days Back: %d", args.days_back) + _LOG.info("=" * 70) + + # Determine which collectors to run + run_sec = not args.skip_sec or args.sec_only + run_news = not args.skip_news and not args.sec_only + + if args.sec_only: + _LOG.info("Running: SEC ONLY") + else: + collectors = [] + if run_sec: + collectors.append("SEC") + if run_news: + collectors.append("News") + _LOG.info("Running collectors: %s", ", ".join(collectors)) + + _LOG.info("=" * 70) + + # Initialize collectors + all_stats = {} + start_time = time.time() + + for i, ticker in enumerate(tickers): + _LOG.info("\n[%d/%d] Processing ticker: %s", i + 1, len(tickers), ticker) + _LOG.info("-" * 50) + + ticker_stats = {} + + try: + # SEC Collector + if run_sec: + stats = run_sec_collector(ticker, args) + ticker_stats["sec"] = stats + _LOG.info( + f" -> SEC: {stats['fetched']} fetched, {stats['stored_warm']} stored warm" + ) + + # News Collector + if run_news: + stats = run_news_collector(ticker, args) + ticker_stats["news"] = stats + _LOG.info( + f" -> News: {stats['fetched']} fetched, {stats['stored_warm']} stored warm" + ) + + except Exception as e: + _LOG.error(f"Error collecting for {ticker}: {e}") + ticker_stats["error"] = str(e) + + all_stats[ticker] = ticker_stats + + # Delay between tickers (rate limiting) + if i < len(tickers) - 1: + delay = 3.0 # seconds + _LOG.info(f" Waiting {delay}s before next ticker...") + time.sleep(delay) + + # Calculate totals + elapsed = time.time() - start_time + + _LOG.info("\n" + "=" * 70) + _LOG.info("COLLECTION COMPLETE") + _LOG.info("=" * 70) + + # Aggregate stats by source + totals_by_source = {} + for ticker, stats in all_stats.items(): + for source, source_stats in stats.items(): + if source == "error": + continue + if source not in totals_by_source: + totals_by_source[source] = { + "fetched": 0, + 
"stored_warm": 0, + "indexed": 0, + } + for key in ["fetched", "stored_warm", "indexed"]: + totals_by_source[source][key] += source_stats.get(key, 0) + + _LOG.info("\nTotals by Source:") + grand_total = 0 + for source, totals in totals_by_source.items(): + _LOG.info( + f" {source.upper()}: {totals['fetched']} fetched, {totals['stored_warm']} stored warm, {totals['indexed']} indexed" + ) + grand_total += totals["fetched"] + + _LOG.info(f"\nGrand Total: {grand_total} documents collected") + _LOG.info(f"Total Time: {elapsed:.1f} seconds ({elapsed / 60:.1f} minutes)") + _LOG.info("=" * 70) + + # Save statistics + output_data = { + "timestamp": datetime.now().isoformat(), + "tickers": tickers, + "results": all_stats, + "totals_by_source": totals_by_source, + "grand_total": grand_total, + "elapsed_seconds": elapsed, + } + + with open(args.output_stats, "w") as f: + json.dump(output_data, f, indent=2) + _LOG.info(f"Statistics saved to {args.output_stats}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_earnings_collector.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_earnings_collector.py new file mode 100644 index 000000000..562b872b8 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_earnings_collector.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +""" +Earnings Call Transcript Collector Script. + +Fetches earnings transcripts from Alpha Vantage and stores them across all +storage tiers. 
+ +Usage: + python -m scripts.run_earnings_collector --ticker AAPL --quarters 4 + python -m scripts.run_earnings_collector -t MSFT --year 2024 --quarter 1 + +Environment Variables Required: + - ALPHAVANTAGE_API_KEY: free tier supports EARNINGS_CALL_TRANSCRIPT (25/day) + - POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DB, POSTGRES_USER, POSTGRES_PASSWORD + - MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY + - OPENAI_API_KEY or OLLAMA_HOST (one of, for embeddings) +""" + +import argparse +import logging +import os +import sys +from pathlib import Path + +from dotenv import load_dotenv + +# Add project root to path. +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from app.collectors import EarningsCollector + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +_LOG = logging.getLogger(__name__) + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Earnings Call Transcript Collector for txtai Market Research" + ) + parser.add_argument( + "-t", "--ticker", type=str, default="AAPL", help="Stock ticker (default: AAPL)" + ) + parser.add_argument( + "-q", + "--quarters", + type=int, + default=4, + help="Number of trailing quarters to fetch (default: 4)", + ) + parser.add_argument( + "--year", + type=int, + default=None, + help="Explicit fiscal year (use with --quarter)", + ) + parser.add_argument( + "--quarter", + type=int, + choices=[1, 2, 3, 4], + default=None, + help="Explicit fiscal quarter 1-4 (use with --year)", + ) + parser.add_argument( + "--no-cold", action="store_true", help="Skip cold storage (MinIO)" + ) + parser.add_argument( + "--no-warm", action="store_true", help="Skip warm storage (PostgreSQL)" + ) + parser.add_argument( + "--no-search", action="store_true", help="Skip search index (txtai)" + ) + parser.add_argument( + "--use-cache", action="store_true", help="Use cached results if 
available" + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="Enable debug logging" + ) + return parser.parse_args() + + +def validate_environment() -> bool: + """Validate required environment variables.""" + required = [ + "ALPHAVANTAGE_API_KEY", + "POSTGRES_HOST", + "POSTGRES_DB", + "POSTGRES_USER", + "POSTGRES_PASSWORD", + "MINIO_ENDPOINT", + "MINIO_ACCESS_KEY", + "MINIO_SECRET_KEY", + ] + has_openai = bool(os.getenv("OPENAI_API_KEY")) + has_ollama = bool(os.getenv("OLLAMA_HOST")) + missing = [v for v in required if not os.getenv(v)] + if not has_openai and not has_ollama: + missing.append("OPENAI_API_KEY or OLLAMA_HOST (need one for embeddings)") + if missing: + _LOG.error("Missing required environment variables: %s", ", ".join(missing)) + _LOG.info("Copy .env.example to .env and fill in the values") + return False + return True + + +def main() -> int: + """Main entry point.""" + args = parse_args() + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + _LOG.info("=" * 60) + _LOG.info("Earnings Transcript Collector") + _LOG.info("=" * 60) + load_dotenv() + if not validate_environment(): + return 1 + # year+quarter is all-or-nothing. 
+ if (args.year is None) ^ (args.quarter is None): + _LOG.error("--year and --quarter must be provided together") + return 1 + _LOG.info("Configuration:") + _LOG.info(" Ticker: %s", args.ticker) + if args.year is not None: + _LOG.info(" Quarter: %sQ%s", args.year, args.quarter) + else: + _LOG.info(" Trailing Qs: %d", args.quarters) + _LOG.info(" Cold Storage: %s", "disabled" if args.no_cold else "enabled") + _LOG.info(" Warm Storage: %s", "disabled" if args.no_warm else "enabled") + _LOG.info(" Search Index: %s", "disabled" if args.no_search else "enabled") + _LOG.info(" Use Cache: %s", "yes" if args.use_cache else "no") + _LOG.info("=" * 60) + collector = EarningsCollector() + try: + results = collector.collect( + ticker=args.ticker, + quarters=args.quarters, + year=args.year, + quarter=args.quarter, + store_cold=not args.no_cold, + store_warm=not args.no_warm, + store_search=not args.no_search, + use_cache=args.use_cache, + ) + _LOG.info("=" * 60) + _LOG.info("Collection Results:") + _LOG.info(" Transcripts Fetched: %d", results.get("fetched", 0)) + _LOG.info(" Stored in Cold: %d", results.get("stored_cold", 0)) + _LOG.info(" Stored in Warm: %d", results.get("stored_warm", 0)) + _LOG.info(" Indexed for Search: %d", results.get("indexed", 0)) + _LOG.info("=" * 60) + if results.get("fetched", 0) == 0: + _LOG.warning( + "No transcripts were fetched. Possible causes: " + "Alpha Vantage rate limit (25/day on free tier), no transcript " + "available for the requested quarter, or invalid ticker." 
+ ) + return 0 + _LOG.info("Earnings collection completed successfully!") + return 0 + except Exception as e: + _LOG.exception("Collection failed: %s", e) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_sec_bulk.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_sec_bulk.py new file mode 100644 index 000000000..93d4a5e8a --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_sec_bulk.py @@ -0,0 +1,560 @@ +#!/usr/bin/env python3 +""" +Bulk SEC EDGAR Collector + +Scrape SEC filings for multiple companies with filtering options. + +Usage: + # Scrape predefined groups + python -m scripts.run_sec_bulk --group faang + python -m scripts.run_sec_bulk --group banks + python -m scripts.run_sec_bulk --group all + + # Scrape a custom list of tickers + python -m scripts.run_sec_bulk --tickers AAPL MSFT GOOGL AMZN + + # Filter by filing type and date range + python -m scripts.run_sec_bulk --group faang --filing-types 10-K + python -m scripts.run_sec_bulk --tickers TSLA --filing-types 10-K,8-K --after 2022-01-01 + python -m scripts.run_sec_bulk --tickers AAPL --filing-types 10-K --before 2023-12-31 + + # Control concurrency and limits + python -m scripts.run_sec_bulk --group faang --limit 5 --workers 3 + + # Dry run — shows what would be scraped without storing + python -m scripts.run_sec_bulk --group banks --dry-run + + # Skip already-scraped tickers (checks filings table first) + python -m scripts.run_sec_bulk --group all --skip-existing + +Environment Variables Required: + POSTGRES_HOST, POSTGRES_DB, POSTGRES_USER, POSTGRES_PASSWORD + MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY + OPENAI_API_KEY or OLLAMA_HOST +""" + +import argparse +import logging +import os +import sys +import time +from 
concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from dotenv import load_dotenv + +# Add project root to path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from app.collectors import SECCollector +from app.storage import get_postgres_client + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +_LOG = logging.getLogger(__name__) + +# ── Predefined ticker groups ────────────────────────────────────────────────── + +TICKER_GROUPS = { + "faang": ["META", "AAPL", "AMZN", "NFLX", "GOOGL"], + "big_tech": [ + "AAPL", + "MSFT", + "GOOGL", + "AMZN", + "META", + "NVDA", + "TSLA", + "ORCL", + "IBM", + "INTC", + ], + "banks": ["JPM", "BAC", "WFC", "GS", "MS", "C", "USB", "PNC", "TFC", "COF"], + "healthcare": [ + "JNJ", + "UNH", + "PFE", + "ABBV", + "MRK", + "TMO", + "ABT", + "DHR", + "BMY", + "AMGN", + ], + "energy": ["XOM", "CVX", "COP", "EOG", "SLB", "MPC", "PSX", "VLO", "OXY", "HAL"], + "retail": ["WMT", "AMZN", "COST", "TGT", "HD", "LOW", "TJX", "ROST", "DG", "DLTR"], + "ev": ["TSLA", "RIVN", "LCID", "NIO", "XPEV", "F", "GM", "STLA"], + "semiconductors": [ + "NVDA", + "AMD", + "INTC", + "QCOM", + "AVGO", + "TXN", + "MU", + "AMAT", + "KLAC", + "LRCX", + ], + "sp500_sample": [ + "AAPL", + "MSFT", + "AMZN", + "NVDA", + "GOOGL", + "META", + "BRK-B", + "LLY", + "AVGO", + "JPM", + "TSLA", + "UNH", + "XOM", + "V", + "MA", + "JNJ", + "PG", + "HD", + "COST", + "MRK", + ], + # All groups combined (deduplicated) + "all": [], +} + +# Populate "all" from all other groups +_all_tickers: list[str] = [] +for _group, _tickers in TICKER_GROUPS.items(): + if _group != "all": + _all_tickers.extend(_tickers) +TICKER_GROUPS["all"] = list(dict.fromkeys(_all_tickers)) # preserve order, dedupe + + +# ── Argument parsing ────────────────────────────────────────────────────────── + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + 
description="Bulk SEC EDGAR collector — scrape multiple companies", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # ── Ticker selection + ticker_group = parser.add_mutually_exclusive_group(required=True) + ticker_group.add_argument( + "--tickers", + "-t", + nargs="+", + metavar="TICKER", + help="One or more ticker symbols e.g. AAPL MSFT GOOGL", + ) + ticker_group.add_argument( + "--group", + "-g", + choices=list(TICKER_GROUPS.keys()), + help=f"Predefined group. Available: {', '.join(TICKER_GROUPS.keys())}", + ) + + # ── Filing filters + parser.add_argument( + "--filing-types", + "-f", + type=str, + default="10-K,8-K,DEF 14A", + help="Comma-separated filing types (default: 10-K,8-K,DEF 14A)", + ) + parser.add_argument( + "--after", + type=str, + default=None, + metavar="YYYY-MM-DD", + help="Only fetch filings on or after this date", + ) + parser.add_argument( + "--before", + type=str, + default=None, + metavar="YYYY-MM-DD", + help="Only fetch filings on or before this date", + ) + parser.add_argument( + "--limit", + "-l", + type=int, + default=10, + help="Max filings per ticker per type (default: 10)", + ) + + # ── Storage flags + parser.add_argument( + "--no-cold", action="store_true", help="Skip MinIO cold storage" + ) + parser.add_argument( + "--no-warm", action="store_true", help="Skip PostgreSQL warm storage" + ) + parser.add_argument( + "--no-search", action="store_true", help="Skip txtai search index" + ) + + # ── Run behaviour + parser.add_argument( + "--workers", + "-w", + type=int, + default=2, + help="Parallel workers for different tickers (default: 2). 
" + "Keep low — each worker makes concurrent async requests to EDGAR.", + ) + parser.add_argument( + "--delay", + type=float, + default=2.0, + help="Seconds to wait between tickers (default: 2.0)", + ) + parser.add_argument( + "--skip-existing", + action="store_true", + help="Skip tickers that already have filings in the database", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print what would be scraped without actually storing anything", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable debug logging", + ) + + return parser.parse_args() + + +# ── Environment validation ───────────────────────────────────────────────────── + + +def validate_environment() -> bool: + required = [ + "POSTGRES_HOST", + "POSTGRES_DB", + "POSTGRES_USER", + "POSTGRES_PASSWORD", + "MINIO_ENDPOINT", + "MINIO_ACCESS_KEY", + "MINIO_SECRET_KEY", + ] + missing = [v for v in required if not os.getenv(v)] + + if not os.getenv("OPENAI_API_KEY") and not os.getenv("OLLAMA_HOST"): + missing.append("OPENAI_API_KEY or OLLAMA_HOST") + + if missing: + _LOG.error("Missing environment variables: %s", ", ".join(missing)) + return False + return True + + +# ── Existing ticker check ───────────────────────────────────────────────────── + + +def get_existing_tickers() -> set[str]: + """Return tickers that already have at least one filing in the DB.""" + try: + pg = get_postgres_client() + with pg.get_cursor() as cur: + cur.execute("SELECT DISTINCT ticker FROM filings") + return {row["ticker"] for row in cur.fetchall()} + except Exception as e: + _LOG.warning("Could not query existing tickers: %s", e) + return set() + + +# ── Date filtering ──────────────────────────────────────────────────────────── + + +def apply_date_filter( + collector: SECCollector, + ticker: str, + filing_types: list[str], + limit: int, + after: str | None, + before: str | None, + store_cold: bool, + store_warm: bool, + store_search: bool, +) -> dict[str, int]: + """ + 
Collect filings and post-filter by date range. + + SEC submissions API doesn't support date range params directly, + so we fetch then filter in-memory before storing. + """ + # Temporarily monkey-patch _fetch_data to apply date filter + original_fetch = collector._fetch_data + + def filtered_fetch(t, **kwargs): + docs = original_fetch(t, **kwargs) + filtered = [] + for doc in docs: + filing_date = doc.get("metadata", {}).get("filing_date", "") + if after and filing_date and filing_date < after: + continue + if before and filing_date and filing_date > before: + continue + filtered.append(doc) + if after or before: + _LOG.info( + "[%s] Date filter (%s → %s): %d/%d filings kept", + t, + after or "any", + before or "any", + len(filtered), + len(docs), + ) + return filtered + + collector._fetch_data = filtered_fetch + + result = collector.collect( + ticker=ticker, + filing_types=filing_types, + limit=limit, + store_cold=store_cold, + store_warm=store_warm, + store_search=store_search, + ) + + # Restore original method + collector._fetch_data = original_fetch + return result + + +# ── Per-ticker scrape ───────────────────────────────────────────────────────── + + +def scrape_ticker( + ticker: str, + filing_types: list[str], + limit: int, + after: str | None, + before: str | None, + store_cold: bool, + store_warm: bool, + store_search: bool, +) -> dict: + """Run the full collection pipeline for a single ticker.""" + _LOG.info("━" * 50) + _LOG.info("Scraping %s | types=%s | limit=%d", ticker, filing_types, limit) + + t0 = time.perf_counter() + try: + collector = SECCollector() + results = apply_date_filter( + collector=collector, + ticker=ticker, + filing_types=filing_types, + limit=limit, + after=after, + before=before, + store_cold=store_cold, + store_warm=store_warm, + store_search=store_search, + ) + elapsed = time.perf_counter() - t0 + results["ticker"] = ticker + results["elapsed"] = round(elapsed, 1) + results["status"] = "ok" + _LOG.info( + "✓ %s done in %.1fs — 
fetched=%d cold=%d warm=%d indexed=%d", + ticker, + elapsed, + results.get("fetched", 0), + results.get("stored_cold", 0), + results.get("stored_warm", 0), + results.get("indexed", 0), + ) + except Exception as e: + elapsed = time.perf_counter() - t0 + _LOG.exception("✗ %s failed after %.1fs: %s", ticker, elapsed, e) + results = { + "ticker": ticker, + "elapsed": round(elapsed, 1), + "status": "error", + "error": str(e), + "fetched": 0, + "stored_cold": 0, + "stored_warm": 0, + "indexed": 0, + } + + return results + + +# ── Summary table ───────────────────────────────────────────────────────────── + + +def print_summary(all_results: list[dict], total_elapsed: float) -> None: + _LOG.info("") + _LOG.info("=" * 65) + _LOG.info("BULK COLLECTION SUMMARY") + _LOG.info("=" * 65) + _LOG.info( + "%-8s %-8s %-6s %-6s %-8s %-8s %s", + "TICKER", + "STATUS", + "FETCHED", + "COLD", + "WARM", + "INDEXED", + "TIME(s)", + ) + _LOG.info("-" * 65) + + totals = {"fetched": 0, "stored_cold": 0, "stored_warm": 0, "indexed": 0} + + for r in all_results: + status_icon = "✓" if r["status"] == "ok" else "✗" + _LOG.info( + "%-8s %-8s %-6d %-6d %-8d %-8d %.1f", + r["ticker"], + f"{status_icon} {r['status']}", + r.get("fetched", 0), + r.get("stored_cold", 0), + r.get("stored_warm", 0), + r.get("indexed", 0), + r.get("elapsed", 0), + ) + for key in totals: + totals[key] += r.get(key, 0) + + _LOG.info("-" * 65) + _LOG.info( + "%-8s %-8s %-6d %-6d %-8d %-8d %.1f", + "TOTAL", + "", + totals["fetched"], + totals["stored_cold"], + totals["stored_warm"], + totals["indexed"], + total_elapsed, + ) + _LOG.info("=" * 65) + + errors = [r for r in all_results if r["status"] == "error"] + if errors: + _LOG.warning("Failed tickers: %s", ", ".join(r["ticker"] for r in errors)) + + +# ── Main ────────────────────────────────────────────────────────────────────── + + +def main() -> int: + args = parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + load_dotenv() + + if not 
args.dry_run and not validate_environment(): + return 1 + + # Resolve ticker list + tickers: list[str] = args.tickers if args.tickers else TICKER_GROUPS[args.group] + filing_types = [ft.strip() for ft in args.filing_types.split(",")] + + # Skip existing tickers if requested + if args.skip_existing and not args.dry_run: + existing = get_existing_tickers() + before_count = len(tickers) + tickers = [t for t in tickers if t not in existing] + skipped = before_count - len(tickers) + if skipped: + _LOG.info("Skipping %d already-scraped tickers", skipped) + + if not tickers: + _LOG.info("No tickers to scrape (all already exist or list is empty).") + return 0 + + # ── Dry run ── + if args.dry_run: + _LOG.info("DRY RUN — nothing will be stored") + _LOG.info("Tickers (%d): %s", len(tickers), ", ".join(tickers)) + _LOG.info("Filing types: %s", ", ".join(filing_types)) + _LOG.info("Limit per ticker: %d", args.limit) + _LOG.info("Date range: %s → %s", args.after or "any", args.before or "any") + _LOG.info("Workers: %d", args.workers) + return 0 + + # ── Live run ── + _LOG.info("=" * 65) + _LOG.info("BULK SEC COLLECTION") + _LOG.info("=" * 65) + _LOG.info("Tickers : %s", ", ".join(tickers)) + _LOG.info("Filing types: %s", ", ".join(filing_types)) + _LOG.info("Limit : %d per ticker", args.limit) + _LOG.info("Date range : %s → %s", args.after or "any", args.before or "any") + _LOG.info("Workers : %d", args.workers) + _LOG.info("Cold storage: %s", "disabled" if args.no_cold else "enabled") + _LOG.info("Warm storage: %s", "disabled" if args.no_warm else "enabled") + _LOG.info("Search index: %s", "disabled" if args.no_search else "enabled") + _LOG.info("=" * 65) + + all_results: list[dict] = [] + t0_total = time.perf_counter() + + if args.workers > 1: + # Parallel scraping across tickers + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = { + executor.submit( + scrape_ticker, + ticker, + filing_types, + args.limit, + args.after, + args.before, + not 
args.no_cold, + not args.no_warm, + not args.no_search, + ): ticker + for ticker in tickers + } + for future in as_completed(futures): + result = future.result() + all_results.append(result) + # Brief pause between completions to stay polite to EDGAR + time.sleep(args.delay) + else: + # Sequential — easier to read logs, safer for EDGAR rate limits + for i, ticker in enumerate(tickers): + result = scrape_ticker( + ticker=ticker, + filing_types=filing_types, + limit=args.limit, + after=args.after, + before=args.before, + store_cold=not args.no_cold, + store_warm=not args.no_warm, + store_search=not args.no_search, + ) + all_results.append(result) + + # Delay between tickers (except after the last one) + if i < len(tickers) - 1: + _LOG.info("Waiting %.1fs before next ticker…", args.delay) + time.sleep(args.delay) + + total_elapsed = time.perf_counter() - t0_total + print_summary(all_results, total_elapsed) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_sec_collector.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_sec_collector.py new file mode 100644 index 000000000..3aea505cf --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/run_sec_collector.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +""" +SEC EDGAR Collector Script + +Run the SEC collector to fetch and store filings for a given ticker. 
+ +Usage: + python -m scripts.run_sec_collector --ticker AAPL --filing-types 10-K,8-K --limit 10 + python -m scripts.run_sec_collector -t TSLA -f 10-K -l 5 + +Environment Variables Required: + - OPENAI_API_KEY: For embeddings + - POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DB, POSTGRES_USER, POSTGRES_PASSWORD + - MINIO_ENDPOINT, MINIO_ACCESS_KEY, MINIO_SECRET_KEY + - KEYDB_HOST, KEYDB_PORT (or use defaults) +""" + +import argparse +import logging +import os +import sys +from pathlib import Path + +from dotenv import load_dotenv + +# Add project root to path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from app.collectors import SECCollector + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +_LOG = logging.getLogger(__name__) + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="SEC EDGAR Collector for txtai Market Research Platform" + ) + parser.add_argument( + "-t", + "--ticker", + type=str, + default="AAPL", + help="Stock ticker symbol (default: AAPL)", + ) + parser.add_argument( + "-f", + "--filing-types", + type=str, + default="10-K,8-K,DEF 14A", + help="Comma-separated filing types (default: 10-K,8-K,DEF 14A)", + ) + parser.add_argument( + "-l", + "--limit", + type=int, + default=5000, + help="Maximum number of filings to fetch (default: 5000, supports up to 10000+)", + ) + parser.add_argument( + "--max-filings", + type=int, + default=10000, + help="Maximum filings to fetch from SEC API (default: 10000)", + ) + parser.add_argument( + "--no-cold", action="store_true", help="Skip cold storage (MinIO)" + ) + parser.add_argument( + "--no-warm", action="store_true", help="Skip warm storage (PostgreSQL)" + ) + parser.add_argument( + "--no-search", action="store_true", help="Skip search index (txtai)" + ) + parser.add_argument( + "--use-cache", action="store_true", help="Use cached results if 
available" + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="Enable debug logging" + ) + + return parser.parse_args() + + +def validate_environment() -> bool: + """Validate required environment variables.""" + # Required for all deployments + required = [ + "POSTGRES_HOST", + "POSTGRES_DB", + "POSTGRES_USER", + "POSTGRES_PASSWORD", + "MINIO_ENDPOINT", + "MINIO_ACCESS_KEY", + "MINIO_SECRET_KEY", + ] + + # Check for embedding provider (either OpenAI or Ollama) + has_openai = bool(os.getenv("OPENAI_API_KEY")) + has_ollama = bool(os.getenv("OLLAMA_HOST")) + + missing = [] + for var in required: + if not os.getenv(var): + missing.append(var) + + if not has_openai and not has_ollama: + missing.append("OPENAI_API_KEY or OLLAMA_HOST (need one for embeddings)") + + if missing: + _LOG.error("Missing required environment variables: %s", ", ".join(missing)) + _LOG.info("Copy .env.example to .env and fill in the values") + return False + + return True + + +def main() -> int: + """Main entry point.""" + args = parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + _LOG.info("=" * 60) + _LOG.info("SEC EDGAR Collector") + _LOG.info("=" * 60) + + # Load environment variables + load_dotenv() + + # Validate environment + if not validate_environment(): + return 1 + + # Parse filing types + filing_types = [ft.strip() for ft in args.filing_types.split(",")] + + _LOG.info("Configuration:") + _LOG.info(" Ticker: %s", args.ticker) + _LOG.info(" Filing Types: %s", ", ".join(filing_types)) + _LOG.info(" Limit: %d total filings", args.limit) + _LOG.info(" Max Filings: %d from SEC API", args.max_filings) + _LOG.info(" Cold Storage: %s", "disabled" if args.no_cold else "enabled") + _LOG.info(" Warm Storage: %s", "disabled" if args.no_warm else "enabled") + _LOG.info(" Search Index: %s", "disabled" if args.no_search else "enabled") + _LOG.info(" Use Cache: %s", "yes" if args.use_cache else "no") + _LOG.info("=" * 60) + + if args.limit > 1000: 
+ _LOG.info( + "LARGE-SCALE COLLECTION: Fetching %d filings - this may take 5-15 minutes", + args.limit, + ) + + # Initialize collector + _LOG.info("Initializing SEC collector...") + collector = SECCollector() + + # Run collection + _LOG.info("Starting collection for %s...", args.ticker) + try: + results = collector.collect( + ticker=args.ticker, + filing_types=filing_types, + limit=args.limit, + store_cold=not args.no_cold, + store_warm=not args.no_warm, + store_search=not args.no_search, + use_cache=args.use_cache, + ) + + _LOG.info("=" * 60) + _LOG.info("Collection Results:") + _LOG.info(" Documents Fetched: %d", results.get("fetched", 0)) + _LOG.info(" Stored in Cold: %d", results.get("stored_cold", 0)) + _LOG.info(" Stored in Warm: %d", results.get("stored_warm", 0)) + _LOG.info(" Indexed for Search: %d", results.get("indexed", 0)) + _LOG.info("=" * 60) + + if results.get("fetched", 0) == 0: + _LOG.warning( + "No filings were fetched. Check the SEC API or try a different ticker." + ) + return 0 + + _LOG.info("SEC collection completed successfully!") + return 0 + + except Exception as e: + _LOG.exception("Collection failed: %s", e) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/smoke_test.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/smoke_test.sh new file mode 100755 index 000000000..2b2c1e4f3 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/scripts/smoke_test.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +# End-to-end smoke test: +# 1. assume `docker-compose up -d` is already running (api + storage tiers) +# 2. ingest one ticker with --limit 1 +# 3. backfill the txtai index from chunks +# 4. 
POST one query to /research and assert the answer is non-empty +# +# Usage: +# > ./scripts/smoke_test.sh +# > TICKER=MSFT ./scripts/smoke_test.sh + +set -euo pipefail + +TICKER="${TICKER:-AAPL}" +API_URL="${API_URL:-http://localhost:8000}" +QUERY="${QUERY:-What are the key risks discussed in the latest 10-K?}" + +echo "[smoke] ensuring docker-compose stack is up..." +docker-compose ps --services --filter status=running | grep -q api \ + || (echo "[smoke] api service is not running. Start it with 'docker-compose up -d'." >&2; exit 1) + +echo "[smoke] ingesting one filing for ${TICKER}..." +docker-compose exec -T api python -m scripts.run_sec_collector \ + --ticker "${TICKER}" \ + --filing-types 10-K \ + --limit 1 + +echo "[smoke] backfilling txtai index from chunks..." +docker-compose exec -T api python -m scripts.backfill_txtai_from_chunks --from-scratch + +echo "[smoke] hitting ${API_URL}/research..." +RESPONSE=$(curl -sS -X POST "${API_URL}/research" \ + -H 'Content-Type: application/json' \ + -d "{\"query\": \"${QUERY}\"}") + +echo "${RESPONSE}" | head -c 800 +echo "" + +# Pull the answer field via python (no jq dependency). 
+ANSWER=$(printf '%s' "${RESPONSE}" | python3 -c 'import json,sys; print(json.load(sys.stdin).get("answer",""))') + +if [ -z "${ANSWER}" ]; then + echo "[smoke] FAIL: empty answer returned" >&2 + exit 1 +fi + +echo "[smoke] OK — answer length=${#ANSWER}" diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/sql/init.sql b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/sql/init.sql new file mode 100644 index 000000000..4e357d1d0 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/sql/init.sql @@ -0,0 +1,112 @@ +-- txtai Market Research Platform - Database Schema +-- Tiers: Warm (PostgreSQL + pgvector), Graph (Kuzu - planned) +-- Note: Cold tier (MinIO) and Hot tier (KeyDB) are separate + +CREATE EXTENSION IF NOT EXISTS vector; +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; + +-- ############################################################################# +-- Companies Table +-- ############################################################################# +CREATE TABLE IF NOT EXISTS companies ( + cik VARCHAR(10) PRIMARY KEY, + ticker VARCHAR(10) UNIQUE, + name TEXT NOT NULL, + sic_code VARCHAR(4), + sector TEXT, + sub_industry TEXT, + exchange VARCHAR(10), + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_companies_ticker ON companies(ticker); + +-- ############################################################################# +-- Filings Table - Stores SEC filing metadata +-- ############################################################################# +CREATE TABLE IF NOT EXISTS filings ( + id VARCHAR(64) PRIMARY KEY, -- SHA256 hash of ticker:accession:form_type + ticker VARCHAR(10) NOT NULL, + company_name TEXT, + filing_type VARCHAR(20) NOT NULL, -- e.g., "10-K", "8-K", "DEF 14A" + cik VARCHAR(10), + accession_number VARCHAR(25) 
UNIQUE NOT NULL, + filing_date DATE, + period_of_report DATE, + document_url TEXT, + file_size_bytes INTEGER, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_filings_ticker ON filings(ticker); +CREATE INDEX IF NOT EXISTS idx_filings_type ON filings(filing_type); +CREATE INDEX IF NOT EXISTS idx_filings_date ON filings(filing_date DESC); +CREATE INDEX IF NOT EXISTS idx_filings_cik ON filings(cik); + +-- ############################################################################# +-- Chunks Table - Document chunks with embeddings for semantic search +-- ############################################################################# +CREATE TABLE IF NOT EXISTS chunks ( + id VARCHAR(64) PRIMARY KEY, -- SHA256 hash of source:ticker:chunk_text[:100] + filing_id VARCHAR(64) REFERENCES filings(id) ON DELETE CASCADE, + chunk_index INTEGER NOT NULL, + text TEXT NOT NULL, + section TEXT, + embedding vector(768), -- sentence-transformers/all-mpnet-base-v2 + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_chunks_embedding + ON chunks USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100); +CREATE INDEX IF NOT EXISTS idx_chunks_filing ON chunks(filing_id); +CREATE INDEX IF NOT EXISTS idx_chunks_section ON chunks(section); + +-- ############################################################################# +-- Document Metadata Table - Key/value metadata for chunks +-- ############################################################################# +CREATE TABLE IF NOT EXISTS document_metadata ( + id VARCHAR(64) PRIMARY KEY, -- SHA256 hash of chunk_id:key + chunk_id VARCHAR(64) REFERENCES chunks(id) ON DELETE CASCADE, + key VARCHAR(255) NOT NULL, + value TEXT, + created_at TIMESTAMPTZ DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_document_metadata_chunk ON document_metadata(chunk_id); +CREATE INDEX IF NOT EXISTS idx_document_metadata_key ON 
document_metadata(key); + +-- ############################################################################# +-- XBRL Facts Table - Structured financial data from SEC filings +-- ############################################################################# +CREATE TABLE IF NOT EXISTS xbrl_facts ( + id VARCHAR(64) PRIMARY KEY, + filing_id VARCHAR(64) REFERENCES filings(id) ON DELETE CASCADE, + concept_name VARCHAR(100) NOT NULL, + value TEXT, + value_numeric NUMERIC, + unit VARCHAR(20), + period_start DATE, + period_end DATE, + instant_date DATE, + axis VARCHAR(100), + member TEXT, + created_at TIMESTAMPTZ DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_xbrl_facts_filing ON xbrl_facts(filing_id); +CREATE INDEX IF NOT EXISTS idx_xbrl_concept ON xbrl_facts(concept_name); +CREATE INDEX IF NOT EXISTS idx_xbrl_period ON xbrl_facts(period_end DESC); + +-- ############################################################################# +-- Collection Runs Table - Track data collection job history +-- ############################################################################# +CREATE TABLE IF NOT EXISTS collection_runs ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + collector VARCHAR(50) NOT NULL, + ticker VARCHAR(10), + started_at TIMESTAMPTZ DEFAULT NOW(), + finished_at TIMESTAMPTZ, + records_written INTEGER DEFAULT 0, + status VARCHAR(20) DEFAULT 'running', + error_msg TEXT +); +CREATE INDEX IF NOT EXISTS idx_collection_runs_collector ON collection_runs(collector); +CREATE INDEX IF NOT EXISTS idx_collection_runs_ticker ON collection_runs(ticker); \ No newline at end of file diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/template_utils.py b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/template_utils.py new file mode 100644 index 000000000..37dda5fb9 --- /dev/null +++ 
b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/template_utils.py @@ -0,0 +1,703 @@ +""" +Causal Success Analysis - Simulation and Inference Utilities. + +Import as: + +import research.A_Causal_Analysis_of_Success_in_Modern_Society.causal_success_utils as racaosimscsu +""" + +from typing import List, Optional, Dict, Any + +import numpy as np +import pandas as pd + +import helpers.hdbg as hdbg + +# Optional Bayesian dependencies (simulation works without these). +# try: +import pymc as pm # type: ignore +import arviz as az # type: ignore +# except Exception: # pragma: no cover - optional import. +# pm = None +# az = None + +# __all__ = [ +# "Agent", +# "create_population", +# "calculate_gini", +# "get_results_dataframe", +# "generate_summary_statistics", +# "validate_simulation_results", +# "run_simulation", +# "run_policy_simulation", +# "fit_bayesian_luck_model", +# "summarize_bayesian_fit", +# "posterior_predictive_check", +# ] + + +# ############################################################################# +# Agent +# ############################################################################# + + +class Agent: + """ + Agent representing an individual in the simulation. + + Each agent has four characteristics that define their position in the system: + + 1. Intensity (0-1): Activity level and effort. + - How active the agent is in seeking opportunities and experiences. + - Higher intensity → higher probability of encountering events (both good and bad). + - Think of it as "surface area for luck": more active people encounter more events. + - Influences event exposure probability via sigmoid function. + + 2. IQ (0-1): Ability to capitalize on opportunities. + - When a lucky event occurs, IQ determines if the agent successfully exploits it. + - Does NOT create opportunities, only gates whether they can be converted to gains. + - Unlucky events always apply (no IQ gate). 
+ - Used as probability of capitalizing on beneficial events. + + 3. Networking (0-1): Social connectivity and spillover. + - Represents social connections and access to network effects. + - When an agent benefits from a lucky event, there's a chance (10%) that + a connected agent also benefits (at reduced impact: 50% of original). + - Spillover amount weighted by networking score. + + 4. Initial Capital: Starting wealth. + - Set to 1.0 for all agents in baseline simulation. + - This ensures inequality EMERGES from dynamics, not inherited advantages. + - Minimum enforced: 0.01 (prevent collapse to zero). + """ + + def __init__( + self, + agent_id: int, + intensity: float, + iq: float, + networking: float, + *, + initial_capital: float = 1.0, + ): + """ + Initialize the Agent with talents and initial capital. + + :param agent_id: Unique agent identifier (typically agent's index) + :param intensity: Intensity talent, affects event exposure (0-1) + :param iq: IQ talent, affects ability to capitalize on luck (0-1) + :param networking: Networking talent, affects spillover effects (0-1) + :param initial_capital: Starting wealth level (default 1.0) + """ + self.id = int(agent_id) + # Enforce bounds and safe floor for capital. + self.talent = { + "intensity": float(np.clip(intensity, 0.0, 1.0)), + "iq": float(np.clip(iq, 0.0, 1.0)), + "networking": float(np.clip(networking, 0.0, 1.0)), + "initial_capital": float(max(0.01, initial_capital)), + } + self.capital = float(self.talent["initial_capital"]) + self.capital_history: List[float] = [self.capital] + self.lucky_events: int = 0 + self.unlucky_events: int = 0 + + @property + def talent_norm(self) -> float: + """ + Euclidean norm of the 4D talent vector. 
+ + :return: L2 norm of talent dimensions + """ + values = np.array( + [ + self.talent["intensity"], + self.talent["iq"], + self.talent["networking"], + self.talent["initial_capital"], + ], + dtype=float, + ) + return float(np.linalg.norm(values)) + + def get_event_probability(self) -> float: + """ + Probability of encountering an event based on intensity. + + Uses a sigmoid centered at 0.5. Higher intensity = higher exposure. + + :return: Event probability in [0, 1] + """ + alpha = 2.0 + return float( + 1.0 / (1.0 + np.exp(-alpha * (self.talent["intensity"] - 0.5))) + ) + + def apply_event(self, event_type: str, impact: float) -> None: + """ + Apply an event to capital using multiplicative dynamics. + + :param event_type: "lucky" or "unlucky" + :param impact: magnitude as a decimal (e.g., 0.25 = 25%) + """ + impact = float(abs(impact)) + if event_type == "lucky": + self.capital *= 1.0 + impact + self.lucky_events += 1 + elif event_type == "unlucky": + self.capital *= 1.0 - impact + self.capital = max(0.01, self.capital) + self.unlucky_events += 1 + else: + raise ValueError(f"Unknown event type: {event_type}") + self.capital_history.append(self.capital) + + +# ############################################################################# + + +def create_population(n_agents: int = 100, *, seed: int = 42) -> List[Agent]: + """ + Create a population of agents with normally distributed talents. 
+ + :param n_agents: number of agents to create (default 100) + :param seed: RNG seed for reproducibility (default 42) + :return: List of Agent objects, each with random talents and capital=1.0 + """ + hdbg.dassert_lt(0, n_agents, "n_agents must be positive") + rng = np.random.default_rng(seed) + agents: List[Agent] = [] + for i in range(n_agents): + intensity = float(np.clip(rng.normal(0.5, 0.15), 0.0, 1.0)) + iq = float(np.clip(rng.normal(0.5, 0.15), 0.0, 1.0)) + networking = float(np.clip(rng.normal(0.5, 0.15), 0.0, 1.0)) + agents.append(Agent(i, intensity, iq, networking, initial_capital=1.0)) + return agents + + +def calculate_gini(values: np.ndarray) -> float: + """ + Compute the Gini coefficient for non-negative values. + + The Gini coefficient measures inequality in a distribution (e.g., wealth). + + :param values: 1D array of non-negative values (e.g., capital amounts) + :return: Gini coefficient in [0, 1] + """ + x = np.asarray(values, dtype=float) + hdbg.dassert_lt( + 0, x.size, "Cannot calculate Gini coefficient for empty array" + ) + hdbg.dassert( + not np.any(x < 0), + "Gini coefficient requires non-negative values", + ) + if np.all(x == 0): + return 0.0 + x_sorted = np.sort(x) + n = x_sorted.size + index = np.arange(1, n + 1, dtype=float) + gini = (2.0 * np.sum(index * x_sorted)) / (n * np.sum(x_sorted)) - ( + n + 1.0 + ) / n + return float(np.clip(gini, 0.0, 1.0)) + + +def get_results_dataframe(agents: List[Agent]) -> pd.DataFrame: + """ + Convert a list of agents to a DataFrame for analysis. 
+ + :param agents: List of Agent objects + :return: DataFrame with agent attributes + """ + if not agents: + return pd.DataFrame() + rows: List[Dict[str, Any]] = [] + for a in agents: + rows.append( + { + "id": a.id, + "talent_intensity": a.talent["intensity"], + "talent_iq": a.talent["iq"], + "talent_networking": a.talent["networking"], + "initial_capital": a.talent["initial_capital"], + "talent_norm": a.talent_norm, + "capital": a.capital, + "lucky_events": a.lucky_events, + "unlucky_events": a.unlucky_events, + "net_events": a.lucky_events - a.unlucky_events, + } + ) + return pd.DataFrame(rows) + + +def generate_summary_statistics(agents: List[Agent]) -> Dict[str, float]: + """ + Generate comprehensive summary statistics for the simulation output. + + :param agents: List of Agent objects (after simulation) + :return: Dictionary mapping metric names to float values. Example output: + { + 'n_agents': 100.0, + 'mean_capital': 2.15, + 'median_capital': 1.85, + 'std_capital': 1.42, + 'min_capital': 0.01, + 'max_capital': 8.50, + 'capital_range': 850.0, + 'gini_coefficient': 0.38, + 'top_10_pct_share': 0.35, + 'top_20_pct_share': 0.52, + 'bottom_50_pct_share': 0.15, + 'mean_lucky_events': 4.2, + 'mean_unlucky_events': 4.1, + 'mean_talent_norm': 1.95, + } + """ + df = get_results_dataframe(agents) + if df.empty: + return {"n_agents": 0} + capital = df["capital"].to_numpy(dtype=float) + gini = calculate_gini(capital) + min_cap = float(np.min(capital)) + max_cap = float(np.max(capital)) + total_cap = float(np.sum(capital)) + n = len(df) + # Guard against division by zero (should not happen due to floor). 
+ cap_range = max_cap / max(min_cap, 1e-12) + top_10_n = max(1, n // 10) + top_20_n = max(1, n // 5) + bottom_50_n = max(1, n // 2) + return { + "n_agents": float(n), + "mean_capital": float(np.mean(capital)), + "median_capital": float(np.median(capital)), + "std_capital": float(np.std(capital)), + "min_capital": min_cap, + "max_capital": max_cap, + "capital_range": float(cap_range), + "gini_coefficient": float(gini), + "top_10_pct_share": float( + df.nlargest(top_10_n, "capital")["capital"].sum() / total_cap + ), + "top_20_pct_share": float( + df.nlargest(top_20_n, "capital")["capital"].sum() / total_cap + ), + "bottom_50_pct_share": float( + df.nsmallest(bottom_50_n, "capital")["capital"].sum() / total_cap + ), + "mean_lucky_events": float(df["lucky_events"].mean()), + "mean_unlucky_events": float(df["unlucky_events"].mean()), + "mean_talent_norm": float(df["talent_norm"].mean()), + } + + +def validate_simulation_results(agents: List[Agent]) -> bool: + """ + Validate simulation results for basic correctness. + + Raises AssertionError if anything looks inconsistent. 
+ + :param agents: List of Agent objects to validate + :return: True if validation passes + """ + df = get_results_dataframe(agents) + hdbg.dassert(not df.empty, "No agents provided to validate") + hdbg.dassert(not (df["capital"] < 0).any(), "Negative capital detected") + hdbg.dassert(not df.isnull().any().any(), "NaN values detected") + hdbg.dassert( + not ((df["lucky_events"] < 0).any() or (df["unlucky_events"] < 0).any()), + "Negative event counts detected", + ) + for a in agents: + expected = 1 + a.lucky_events + a.unlucky_events + hdbg.dassert_eq( + len(a.capital_history), + expected, + "Agent has inconsistent capital history length (expected, got):", + expected, + len(a.capital_history), + ) + return True + + +def run_simulation( + agents: List[Agent], + *, + n_periods: int = 80, + n_lucky_events_per_period: int = 5, + n_unlucky_events_per_period: int = 5, + lucky_mean: float = 0.25, + lucky_std: float = 0.08, + unlucky_mean: float = 0.15, + unlucky_std: float = 0.05, + seed: Optional[int] = 42, + verbose: bool = False, +) -> List[Agent]: + """ + Execute the agent-based simulation over multiple periods. 
+ + :param agents: List of Agent objects to simulate (modified in-place) + :param n_periods: Number of time periods to simulate (default 80) + :param n_lucky_events_per_period: Lucky events per period (default 5) + :param n_unlucky_events_per_period: Unlucky events per period (default 5) + :param lucky_mean: Mean impact of lucky events (default 0.25 = 25%) + :param lucky_std: Std dev of lucky event impacts (default 0.08) + :param unlucky_mean: Mean impact of unlucky events (default 0.15 = 15%) + :param unlucky_std: Std dev of unlucky event impacts (default 0.05) + :param seed: RNG seed for reproducibility (default 42) + :param verbose: Show progress bar if True (requires tqdm, default False) + :return: Same agents list with updated capital and event histories + """ + hdbg.dassert_lt(0, n_periods, "n_periods must be positive") + hdbg.dassert(agents, "agents list cannot be empty") + hdbg.dassert( + n_lucky_events_per_period >= 0 and n_unlucky_events_per_period >= 0, + "event counts per period must be non-negative", + ) + + rng = np.random.default_rng(seed) + n_agents = len(agents) + if verbose: + try: + from tqdm import tqdm # type: ignore + + periods_iter = tqdm( + range(n_periods), desc="Running simulation", unit="period" + ) + except Exception: + periods_iter = range(n_periods) + else: + periods_iter = range(n_periods) + for _ in periods_iter: + # Lucky events. + for _ in range(n_lucky_events_per_period): + exposure = np.array( + [a.get_event_probability() for a in agents], dtype=float + ) + exposure = ( + exposure / exposure.sum() + if exposure.sum() > 0 + else np.ones(n_agents) / n_agents + ) + selected_idx = int(rng.choice(n_agents, p=exposure)) + selected = agents[selected_idx] + impact = float( + np.clip(rng.normal(lucky_mean, lucky_std), 0.05, 0.50) + ) + # IQ gates whether a lucky event can be capitalized on. + if rng.random() < selected.talent["iq"]: + selected.apply_event("lucky", impact) + # Networking spillover (10%). 
+ if rng.random() < 0.1: + net = np.array( + [a.talent["networking"] for a in agents], dtype=float + ) + if net.sum() > 0: + net = net / net.sum() + inherited_idx = int(rng.choice(n_agents, p=net)) + if ( + inherited_idx != selected_idx + and rng.random() < agents[inherited_idx].talent["iq"] + ): + agents[inherited_idx].apply_event("lucky", impact * 0.5) + # Unlucky events. + for _ in range(n_unlucky_events_per_period): + exposure = np.array( + [a.get_event_probability() for a in agents], dtype=float + ) + exposure = ( + exposure / exposure.sum() + if exposure.sum() > 0 + else np.ones(n_agents) / n_agents + ) + selected_idx = int(rng.choice(n_agents, p=exposure)) + selected = agents[selected_idx] + impact = float( + np.clip(rng.normal(unlucky_mean, unlucky_std), 0.05, 0.30) + ) + selected.apply_event("unlucky", impact) + return agents + + +def run_policy_simulation( + agents: List[Agent], + *, + policy: str = "egalitarian", + resource_amount: float = 100.0, + cate_values: Optional[np.ndarray] = None, + **simulation_kwargs, +) -> List[Agent]: + """ + Allocate initial resources under a policy, then run the standard simulation. + + 1. "egalitarian" + - Every agent gets: resource_amount / n_agents + - Rationale: Reduce initial inequality, give everyone equal chance + - Typical outcome: Lowest final Gini (most equitable) + - Typical outcome: Moderate total welfare + + 2. "meritocratic" + - Allocation ∝ talent_norm (total ability) + - Rationale: Reward potentially capable people + - Typical outcome: Moderate final Gini + - Typical outcome: High total welfare (resources go to productive people) + + 3. "performance" + - Allocation ∝ current capital (rich get richer) + - Rationale: Compound success (controversial, tested for comparison) + - Typical outcome: Highest final Gini (most unequal) + - Typical outcome: Lowest total welfare (resources wasted on already-rich) + + 4. 
"random" + - One randomly chosen agent gets ALL resources + - Rationale: Extreme luck-based allocation + - Typical outcome: Very high Gini + - Typical outcome: Highest possible total welfare (concentrated resources) + + 5. "cate_optimal" + - Allocation ∝ CATE estimates (heterogeneous treatment effects) + - Rationale: Give resources to agents who benefit most from them + - Requires: cate_values array with one value per agent + - Typical outcome: High total welfare, moderate Gini + - Note: Only allocates to agents with non-negative CATE + + :param agents: List of Agent objects (capital modified in-place) + :param policy: Allocation rule: "egalitarian", "meritocratic", "performance", + "random", or "cate_optimal" (default "egalitarian") + :param resource_amount: Total budget to distribute at t=0 (default 100.0) + :param cate_values: 1D array of CATE estimates, one per agent. + Required if policy="cate_optimal", ignored otherwise. + :param simulation_kwargs: Additional arguments forwarded to run_simulation() + (e.g., n_periods=80, seed=42, verbose=True) + :return: Same agents list after resource allocation and full simulation + """ + hdbg.dassert(agents, "agents list cannot be empty") + hdbg.dassert_lt( + -0.0001, resource_amount, "resource_amount must be non-negative" + ) + n = len(agents) + rng = np.random.default_rng(simulation_kwargs.get("seed", None)) + # Handle random policy separately (single winner). + if policy == "random": + winner_idx = int(rng.integers(n)) + agents[winner_idx].capital += resource_amount + agents[winner_idx].capital_history[0] = agents[winner_idx].capital + return run_simulation(agents, **simulation_kwargs) + # For all other policies, we compute weights and allocate proportionally. 
+ weights = np.zeros(n, dtype=float) + if policy == "egalitarian": + weights[:] = 1.0 + elif policy == "meritocratic": + weights = np.array([a.talent_norm for a in agents], dtype=float) + elif policy == "performance": + weights = np.array([a.capital for a in agents], dtype=float) + elif policy == "cate_optimal": + hdbg.dassert_is_not( + cate_values, + None, + "cate_values must be provided when policy='cate_optimal'.", + ) + cate_array = np.asarray(cate_values, dtype=float) + hdbg.dassert_eq( + cate_array.shape[0], + n, + "cate_values must have length (expected, got):", + n, + cate_array.shape[0], + ) + # Use only non-negative CATEs; negative values are clamped to zero. + weights = np.maximum(cate_array, 0.0) + else: + raise ValueError( + f"Unknown policy: {policy}. Must be one of: " + f"egalitarian, meritocratic, performance, random, cate_optimal" + ) + total_weight = float(weights.sum()) + if total_weight <= 0.0: + # Fallback: if everything is zero, allocate equally. + weights = np.ones(n, dtype=float) + total_weight = float(n) + shares = weights / total_weight + allocations = shares * float(resource_amount) + for a, alloc in zip(agents, allocations): + a.capital += float(alloc) + # Keep history consistent at t=0. + a.capital_history[0] = a.capital + return run_simulation(agents, **simulation_kwargs) + + +# ############################################################################# +# Bayesian model +# ############################################################################# +# ------------------------------------------------------------------- + + +def fit_bayesian_luck_model( + df: pd.DataFrame, + *, + draws: int = 1000, + tune: int = 1000, + target_accept: float = 0.9, + random_seed: int = 42, +): + """ + Fit a Bayesian regression model to estimate causal effect of luck on capital. 
+ + :param df: DataFrame from get_results_dataframe(agents), must have: + 'capital', 'lucky_events', 'talent_intensity', 'talent_iq', 'talent_networking' + :param draws: Number of posterior draws per chain (default 1000) + Higher = more accurate posterior, longer runtime + :param tune: Number of NUTS tuning/burn-in iterations (default 1000) + These are discarded and not used in inference + :param target_accept: NUTS sampler target acceptance rate (default 0.9) + Valid range: (0.5, 1.0), higher = slower but more stable + :param random_seed: RNG seed for reproducibility (default 42) + + :return: Tuple (model, idata): + - model: PyMC Model object (for diagnostics, re-sampling, etc.) + - idata: ArviZ InferenceData object containing posterior samples + Use with summarize_bayesian_fit() or posterior_predictive_check() + """ + required_cols = [ + "capital", + "lucky_events", + "talent_intensity", + "talent_iq", + "talent_networking", + ] + missing = [c for c in required_cols if c not in df.columns] + hdbg.dassert(not missing, "DataFrame is missing required columns:", missing) + capital = df["capital"].to_numpy(dtype=float) + y = np.log(capital) # log-capital is more stable and closer to normal. + lucky = df["lucky_events"].to_numpy(dtype=float) + intensity = df["talent_intensity"].to_numpy(dtype=float) + iq = df["talent_iq"].to_numpy(dtype=float) + networking = df["talent_networking"].to_numpy(dtype=float) + # THE QUESTION + # ============ + # Does luck causally affect outcomes, even after controlling for talent? + # This model answers that by regressing log(capital) on both luck and talent. 
+ + # THE MODEL + # ========= + # Linear Bayesian regression: + + # log(capital_i) = alpha + # + beta_luck * lucky_events_i + # + beta_intensity * talent_intensity_i + # + beta_iq * talent_iq_i + # + beta_networking * talent_networking_i + # + epsilon_i + + # Where: + # - log(capital) is the outcome (log-scale for stability) + # - lucky_events is the treatment (how many beneficial events occurred) + # - talent_* are confounders (control for inherent ability) + # - epsilon ~ N(0, sigma) is residual error + + # PRIORS + # ====== + # All coefficients use weakly informative N(0, 1) priors (centered at 0). + # This allows the data to dominate the inference without strong prior beliefs. + with pm.Model() as model: + # Priors: fairly weakly informative, centered at 0. + alpha = pm.Normal("alpha", mu=0.0, sigma=1.0) + beta_luck = pm.Normal("beta_luck", mu=0.0, sigma=1.0) + beta_intensity = pm.Normal("beta_intensity", mu=0.0, sigma=1.0) + beta_iq = pm.Normal("beta_iq", mu=0.0, sigma=1.0) + beta_networking = pm.Normal("beta_networking", mu=0.0, sigma=1.0) + sigma = pm.HalfNormal("sigma", sigma=1.0) + mu = ( + alpha + + beta_luck * lucky + + beta_intensity * intensity + + beta_iq * iq + + beta_networking * networking + ) + pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y) + idata = pm.sample( + draws=draws, + tune=tune, + target_accept=target_accept, + random_seed=random_seed, + return_inferencedata=True, + progressbar=True, + ) + return model, idata + + +def summarize_bayesian_fit( + idata, *, var_names: Optional[List[str]] = None +) -> pd.DataFrame: + """ + Return a tidy summary table (posterior mean, sd, and credible intervals). + + For the Bayesian model parameters. + + :param idata: ArviZ InferenceData returned by fit_bayesian_luck_model + :param var_names: optional subset of parameter names to summarize + :return: pandas DataFrame with summary statistics (mean, sd, hdi, etc.) + """ + if var_names is None: + # By default, summarize the main coefficients and sigma. 
+ var_names = [ + "alpha", + "beta_luck", + "beta_intensity", + "beta_iq", + "beta_networking", + "sigma", + ] + summary = az.summary(idata, var_names=var_names) + return summary + + +def posterior_predictive_check( + model, + idata, + df: pd.DataFrame, + *, + random_seed: int = 123, +) -> Dict[str, np.ndarray]: + """ + Simple posterior predictive check (PPC). + + This function draws from the posterior predictive distribution and compares + simulated log-capital to the observed log-capital. + + :param model: PyMC model returned by fit_bayesian_luck_model + :param idata: ArviZ InferenceData with posterior draws + :param df: same DataFrame used for fitting + :param random_seed: RNG seed for reproducibility + :return: dict with: + - "y_obs": observed log-capital + - "y_pred_mean": posterior predictive mean log-capital per agent + - "y_pred_std": posterior predictive std-dev per agent + """ + capital = df["capital"].to_numpy(dtype=float) + y_obs = np.log(capital) + with model: + ppc = pm.sample_posterior_predictive( + idata, + var_names=["y_obs"], + random_seed=random_seed, + progressbar=False, + ) + # ppc["y_obs"] has shape (chains, draws, n) or (draws, n) depending on PyMC version. + y_sim = np.asarray(ppc["y_obs"]) + if y_sim.ndim == 3: + # (chains, draws, n) -> (chains * draws, n). + y_sim = y_sim.reshape(-1, y_sim.shape[-1]) + elif y_sim.ndim == 2: + # (draws, n) -> OK. 
+ pass + else: + raise ValueError(f"Unexpected PPC shape for y_obs: {y_sim.shape}") + y_pred_mean = y_sim.mean(axis=0) + y_pred_std = y_sim.std(axis=0) + return { + "y_obs": y_obs, + "y_pred_mean": y_pred_mean, + "y_pred_std": y_pred_std, + } diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/utils.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/utils.sh new file mode 100644 index 000000000..6dc059628 --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/utils.sh @@ -0,0 +1,22 @@ +DEBIAN_FRONTEND=noninteractive + + +print_vars() { + echo "AM_CONTAINER_VERSION=$AM_CONTAINER_VERSION" + echo "APP_DIR=$APP_DIR" + echo "CLEAN_UP_INSTALLATION=$CLEAN_UP_INSTALLATION" + echo "ENV_NAME=$ENV_NAME" + echo "HOME=$HOME" + echo "INSTALL_DIND=$INSTALL_DIND" + echo "POETRY_MODE=$POETRY_MODE" +} + + +report_disk_usage() { + du -h --max-depth=1 / --exclude=/proc | sort -hr + # Print dirs with size larger than 1MB. + DIRS="/usr /var" + du -h --max-depth=1 $DIRS 2>/dev/null | \ + awk '$1 ~ /[0-9\.]+M/ || $1 ~ /[0-9\.]+G/ || $1 ~ /[0-9\.]+T/' | \ + sort -hr +} \ No newline at end of file diff --git a/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/version.sh b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/version.sh new file mode 100755 index 000000000..c46ed254c --- /dev/null +++ b/class_project/data605/Spring2026/projects/UmdTask430_DATA605_Spring2026_txtai_for_market_research/version.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# """ +# Display versions of installed tools and packages. +# +# This script prints version information for Python, pip, Jupyter, and all +# installed Python packages. Used for debugging and documentation purposes +# to verify the Docker container environment setup. 
+# """ + +# Display Python 3 version. +echo "# Python3" +python3 --version + +# Display pip version. +echo "# pip3" +pip3 --version + +# Display Jupyter version. +echo "# jupyter" +jupyter --version + +# List all installed Python packages and their versions. +echo "# Python packages" +pip3 list + +# Template for adding additional tool versions. +# echo "# mongo" +# mongod --version