diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1525020 --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.pyc +*.pyo +.Python + +# Virtual environments +.venv/ +venv/ +env/ + +# Distribution / packaging +dist/ +build/ +*.egg-info/ + +# Node / Wrangler +node_modules/ +.wrangler/ +.dev.vars + +# Test / coverage +.pytest_cache/ +.coverage +htmlcov/ + +# Environment variables +.env diff --git a/README.md b/README.md index 04caa50..b4a10ae 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,234 @@ -# scholarai +# ScholarAI + AI-powered research assistant designed to help students and scientists navigate large volumes of academic literature. Supports paper discovery, summarization, citation exploration, and question answering across research papers. Helps organize knowledge, identify trends, and accelerate literature review workflows. + +Built as a **Cloudflare Python Worker** — using **only the `workers` module** (no third-party frameworks) — powered by **Cloudflare Workers AI** (`@cf/meta/llama-3.1-8b-instruct`). 
+ +--- + +## Features + +| Feature | Endpoint | Description | +|---------|----------|-------------| +| Paper discovery | `POST /api/discover` | Find relevant papers for a query | +| Summarization | `POST /api/summarize` | Structured summary of a paper | +| Citation exploration | `POST /api/citations` | Explore a paper's citation network | +| Question answering | `POST /api/qa` | Answer research questions from literature | +| Knowledge organization | `POST /api/organize` | Cluster and map a reading list | +| Trend identification | `POST /api/trends` | Spot emerging and declining research trends | +| Literature review | `POST /api/review` | Generate a full literature review section | + +--- + +## Quick Start (local development) + +### Prerequisites + +- [Node.js](https://nodejs.org/) ≥ 18 +- [uv](https://github.com/astral-sh/uv) (Python package manager) + +```bash +# Install Node dependencies (Wrangler CLI) +npm install + +# Authenticate with Cloudflare (required for the Workers AI binding) +npx wrangler login + +# Start the local development server +npm run dev +``` + +The local server starts at `http://localhost:8787`. + +--- + +## API Reference + +### `GET /` + +Returns API metadata and a usage guide for all endpoints. + +--- + +### `POST /api/discover` + +Discover relevant academic papers for a research query. + +**Request body:** +```json +{ + "query": "transformer models in NLP", + "fields": ["machine learning", "NLP"], + "limit": 10 +} +``` + +**Response:** +```json +{ + "query": "transformer models in NLP", + "results": { + "papers": [...], + "research_directions": [...], + "key_concepts": [...], + "related_queries": [...] + } +} +``` + +--- + +### `POST /api/summarize` + +Generate a structured summary of a research paper. 
+ +**Request body** (at least one field required): +```json +{ + "title": "Attention Is All You Need", + "abstract": "We propose a new simple network architecture...", + "content": "Full paper text (optional, truncated to 4 000 chars)" +} +``` + +--- + +### `POST /api/citations` + +Explore a paper's citation network. + +**Request body:** +```json +{ + "paper": "Attention Is All You Need", + "type": "related" +} +``` + +`type` options: `forward`, `backward`, `related` (default: `related`). + +--- + +### `POST /api/qa` + +Answer a research question using the AI's knowledge and optional context. + +**Request body:** +```json +{ + "question": "What are the main advantages of self-attention over RNNs?", + "context": "Optional background text", + "papers": [{"title": "...", "year": 2017}] +} +``` + +--- + +### `POST /api/organize` + +Organize a reading list into thematic clusters and a knowledge map. + +**Request body:** +```json +{ + "papers": [ + {"title": "Paper A", "year": 2021}, + {"title": "Paper B", "year": 2022} + ], + "organize_by": "topic" +} +``` + +`organize_by` options: `topic`, `year`, `author`, `methodology` (default: `topic`). + +--- + +### `POST /api/trends` + +Identify research trends in a field or from a set of papers. + +**Request body** (at least one of `field` or `papers` required): +```json +{ + "field": "computer vision", + "time_range": "2018-2024", + "papers": [...] +} +``` + +--- + +### `POST /api/review` + +Generate a structured literature review. + +**Request body:** +```json +{ + "topic": "graph neural networks", + "papers": [...], + "style": "comprehensive", + "audience": "graduate students" +} +``` + +`style` options: `comprehensive`, `brief`, `systematic` (default: `comprehensive`). 
+ +--- + +## Deployment + +```bash +# Deploy to Cloudflare Workers +npm run deploy +``` + +--- + +## Development + +### Running tests + +```bash +# Install dev dependencies +uv sync --group dev + +# Run tests +uv run pytest tests/ -v +``` + +### Project structure + +``` +scholarai/ +├── src/ +│ └── entry.py # Handler functions + WorkerEntrypoint router +├── tests/ +│ └── test_entry.py # Async pytest tests with mocked AI binding +├── wrangler.toml # Cloudflare Workers configuration +├── pyproject.toml # Python project + dependencies +├── package.json # npm scripts for Wrangler CLI +├── conftest.py # Workers SDK stub for local testing +└── README.md +``` + +--- + +## Architecture + +``` +HTTP Request + │ + ▼ +Default.fetch() ← Cloudflare Workers runtime (WorkerEntrypoint) + │ manual URL routing + ▼ +handle_*() functions ← pure async business logic + │ + ▼ +env.AI.run(model, params) ← Cloudflare Workers AI (@cf/meta/llama-3.1-8b-instruct) + │ + ▼ +Response(json, status) ← workers.Response +``` + diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..5b34ab6 --- /dev/null +++ b/conftest.py @@ -0,0 +1,39 @@ +""" +Pytest configuration: stubs out the Cloudflare Workers SDK so that +src/entry.py can be imported in a plain Python test environment. 
def _make_workers_stub() -> ModuleType:
    """Build an in-memory stand-in for the ``workers`` module.

    The returned module object is named ``"workers"`` and carries exactly two
    attributes, ``Response`` and ``WorkerEntrypoint``, both defined locally
    below.
    """

    class Response:
        """Plain container that records the body, status and headers it is given."""

        def __init__(self, body="", status=200, headers=None):
            self.body = body
            self.status = status
            # A falsy ``headers`` (None or empty) is replaced by a fresh dict.
            self.headers = headers if headers else {}

        def json_body(self) -> dict:
            """Decode ``self.body`` as JSON and return the result."""
            decoded = json.loads(self.body)
            return decoded

    class WorkerEntrypoint:
        """Placeholder base class whose ``fetch`` always raises."""

        async def fetch(self, request):  # pragma: no cover
            raise NotImplementedError("Use the Cloudflare Workers runtime")

    stub = ModuleType("workers")
    # Publish each class on the module under its own class name.
    for exported in (Response, WorkerEntrypoint):
        setattr(stub, exported.__name__, exported)
    return stub
def _try_parse_json(text: str):
    """Extract and parse the JSON object embedded in a model reply.

    The text is scanned left to right. At each ``{`` a JSON object is decoded
    in place with ``JSONDecoder.raw_decode``, so prose, code fences or stray
    braces before or after the payload do not prevent parsing. When several
    top-level objects decode successfully, the one spanning the most
    characters is returned.

    Args:
        text: Raw model output.

    Returns:
        The decoded ``dict`` when an object is found; otherwise ``text``
        unchanged. Non-string input is also returned unchanged.
    """
    if not isinstance(text, str):
        return text

    decoder = json.JSONDecoder()
    best = None
    best_span = 0
    pos = text.find("{")
    while pos >= 0:
        try:
            parsed, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            # Not the start of a valid object; try the next opening brace.
            pos = text.find("{", pos + 1)
            continue
        if end - pos > best_span:
            best, best_span = parsed, end - pos
        # Resume after the decoded object so nested objects are not re-parsed.
        pos = text.find("{", end)

    return text if best is None else best
unit-test) +# --------------------------------------------------------------------------- + + +def handle_info() -> dict: + """Return API metadata.""" + return { + "name": "ScholarAI", + "version": "1.0.0", + "description": ( + "AI-powered research assistant for academic literature, " + "powered by Cloudflare Workers AI." + ), + "endpoints": { + "GET /": "API information", + "POST /api/discover": "Discover relevant academic papers", + "POST /api/summarize": "Summarize a research paper", + "POST /api/citations": "Explore paper citations and related work", + "POST /api/qa": "Question answering about research topics", + "POST /api/organize": "Organize a collection of papers", + "POST /api/trends": "Identify research trends in a field", + "POST /api/review": "Generate a literature review", + }, + "usage": { + "discover": { + "method": "POST", + "body": { + "query": "string (required)", + "fields": "list[string] (optional)", + "limit": "int (optional, default: 10)", + }, + }, + "summarize": { + "method": "POST", + "body": { + "title": "string (optional)", + "abstract": "string (optional)", + "content": "string (optional)", + }, + }, + "citations": { + "method": "POST", + "body": { + "paper": "string (required)", + "type": "string (optional: forward | backward | related)", + }, + }, + "qa": { + "method": "POST", + "body": { + "question": "string (required)", + "context": "string (optional)", + "papers": "list[dict] (optional)", + }, + }, + "organize": { + "method": "POST", + "body": { + "papers": "list[dict] (required)", + "organize_by": "string (optional: topic | year | author | methodology)", + }, + }, + "trends": { + "method": "POST", + "body": { + "field": "string (optional)", + "papers": "list[dict] (optional)", + "time_range": "string (optional)", + }, + }, + "review": { + "method": "POST", + "body": { + "topic": "string (required)", + "papers": "list[dict] (optional)", + "style": "string (optional: comprehensive | brief | systematic)", + "audience": "string 
async def handle_discover(body: dict, env) -> tuple[dict, int]:
    """Discover relevant academic papers for a research query.

    Every request field is validated before the model is called, so malformed
    input yields a 400 payload instead of an unhandled exception.

    Args:
        body: Decoded JSON request body. Recognised keys: ``query``
            (required, non-blank string), ``fields`` (optional string or list
            of strings) and ``limit`` (optional positive integer or numeric
            string; 10 when missing or null).
        env: Worker environment, passed through unchanged to ``run_ai``.

    Returns:
        ``(payload, status)``. On success the payload holds the stripped
        ``query`` and ``results`` (the model reply after ``_try_parse_json``)
        with status 200. On invalid input the payload is ``{"error": ...}``
        with status 400.
    """
    raw_query = body.get("query")
    # A JSON null or non-string value must never reach ``.strip()``.
    if not isinstance(raw_query, str) or not raw_query.strip():
        return {"error": "query is required"}, 400
    query = raw_query.strip()

    fields = body.get("fields") or []
    if isinstance(fields, str):
        # A bare string would otherwise be joined character by character.
        fields = [fields]
    if not isinstance(fields, list) or not all(isinstance(f, str) for f in fields):
        return {"error": "fields must be a list of strings"}, 400

    raw_limit = body.get("limit")
    try:
        limit = 10 if raw_limit is None else int(raw_limit)
    except (TypeError, ValueError, OverflowError):
        limit = 0  # falls through to the range check below
    if limit < 1:
        return {"error": "limit must be a positive integer"}, 400

    field_context = f" in the fields of {', '.join(fields)}" if fields else ""

    system_prompt = (
        "You are ScholarAI, an expert academic research assistant. "
        "You help researchers discover relevant academic papers and literature. "
        "When given a research query, provide structured information about relevant "
        "papers, key concepts, and research directions. "
        "Format your response as a valid JSON object."
    )
    user_prompt = (
        f"Discover academic papers for the following research query{field_context}:\n\n"
        f"Query: {query}\n"
        f"Requested limit: {limit} papers\n\n"
        "Return a JSON object with keys: papers (list with title, authors, year, venue, "
        "abstract_summary, relevance_score, key_topics), research_directions (list), "
        "key_concepts (list), related_queries (list)."
    )
    response_text = await run_ai(env, system_prompt, user_prompt)
    return {"query": query, "results": _try_parse_json(response_text)}, 200
Provide clear, concise, and accurate summaries " + "that capture the key contributions, methodology, results, and implications." + ) + user_prompt = ( + f"Summarize the following research paper:\n\n{chr(10).join(parts)}\n\n" + "Structure your summary with these sections:\n" + "1. Main contribution\n" + "2. Problem being solved\n" + "3. Methodology\n" + "4. Key findings / results\n" + "5. Limitations\n" + "6. Future work directions\n" + "7. Impact and significance" + ) + summary = await run_ai(env, system_prompt, user_prompt) + return {"title": title, "summary": summary}, 200 + + +async def handle_citations(body: dict, env) -> tuple[dict, int]: + """Explore the citation network around a paper.""" + paper = body.get("paper", "").strip() + if not paper: + return {"error": "paper is required"}, 400 + + citation_type = body.get("type", "related") + + system_prompt = ( + "You are ScholarAI, an expert academic research assistant specializing in " + "citation analysis and research genealogy. Help researchers understand the " + "citation network around papers, including foundational work, recent advances, " + "and related research." + ) + user_prompt = ( + f"Explore citations and related work for the following paper:\n\n" + f"Paper: {paper}\n" + f"Citation type requested: {citation_type}\n\n" + "Please provide:\n" + "1. Key foundational papers (seminal works this paper builds upon)\n" + "2. Papers that cite this work (if applicable)\n" + "3. Related papers in the same research area\n" + "4. Research lineage and evolution of ideas\n" + "5. Key authors and research groups in this area\n" + "6. 
Connections to adjacent research fields" + ) + result = await run_ai(env, system_prompt, user_prompt) + return {"paper": paper, "citation_type": citation_type, "exploration": result}, 200 + + +async def handle_qa(body: dict, env) -> tuple[dict, int]: + """Answer a research question using the AI.""" + question = body.get("question", "").strip() + if not question: + return {"error": "question is required"}, 400 + + context = body.get("context", "") + papers = body.get("papers", []) + + context_parts = [] + if context: + context_parts.append(f"Context:\n{context}") + if papers: + context_parts.append(f"Available papers:\n{json.dumps(papers[:5], indent=2)}") + + system_prompt = ( + "You are ScholarAI, an expert academic research assistant with deep knowledge " + "across multiple scientific disciplines. Answer research questions accurately, " + "citing relevant papers and providing evidence-based responses. " + "Be precise, thorough, and acknowledge uncertainty when appropriate." + ) + extra = ("\n\n" + "\n\n".join(context_parts)) if context_parts else "" + user_prompt = ( + f"Please answer the following research question:\n\n" + f"Question: {question}{extra}\n\n" + "Provide:\n" + "1. A comprehensive answer\n" + "2. Key evidence and findings from the literature\n" + "3. Relevant papers and references\n" + "4. Important caveats or limitations\n" + "5. Areas of ongoing debate or uncertainty" + ) + answer = await run_ai(env, system_prompt, user_prompt) + return {"question": question, "answer": answer}, 200 + + +async def handle_organize(body: dict, env) -> tuple[dict, int]: + """Organize a collection of papers into thematic clusters.""" + papers = body.get("papers") + if not papers: + return {"error": "papers is required"}, 400 + + organize_by = body.get("organize_by", "topic") + + system_prompt = ( + "You are ScholarAI, an expert academic research assistant specialized in " + "knowledge organization and taxonomy. 
Help researchers organize and structure " + "their literature collection for better understanding and navigation." + ) + papers_json = json.dumps(papers[:20], indent=2) + user_prompt = ( + f"Organize the following research papers by {organize_by}:\n\n" + f"Papers:\n{papers_json}\n\n" + "Please provide:\n" + "1. Organized groupings / clusters\n" + "2. Key themes and relationships between papers\n" + "3. A knowledge map showing connections\n" + "4. Recommended reading order\n" + "5. Research gaps identified from this collection\n" + "6. Cross-cutting themes and methodologies" + ) + result = await run_ai(env, system_prompt, user_prompt) + return {"organize_by": organize_by, "organization": result}, 200 + + +async def handle_trends(body: dict, env) -> tuple[dict, int]: + """Identify research trends in a field or from a set of papers.""" + field = body.get("field", "").strip() + papers = body.get("papers", []) + + if not field and not papers: + return {"error": "At least one of field or papers is required"}, 400 + + time_range = body.get("time_range", "") + context_parts = [] + if field: + context_parts.append(f"Research field: {field}") + if time_range: + context_parts.append(f"Time range: {time_range}") + if papers: + context_parts.append(f"Papers:\n{json.dumps(papers[:20], indent=2)}") + + system_prompt = ( + "You are ScholarAI, an expert academic research assistant specialized in " + "research trend analysis and bibliometrics. Identify emerging trends, " + "declining areas, hot topics, and research trajectories in academic literature." + ) + user_prompt = ( + f"Identify research trends from the following information:\n\n" + f"{chr(10).join(context_parts)}\n\n" + "Please provide:\n" + "1. Emerging research trends and hot topics\n" + "2. Declining or saturated research areas\n" + "3. Methodological trends and shifts\n" + "4. Key breakthroughs and milestone papers\n" + "5. Future research directions\n" + "6. 
Cross-disciplinary connections and emerging intersections\n" + "7. Technology / tool adoption trends" + ) + result = await run_ai(env, system_prompt, user_prompt) + return {"field": field, "trends": result}, 200 + + +async def handle_review(body: dict, env) -> tuple[dict, int]: + """Generate a structured literature review.""" + topic = body.get("topic", "").strip() + if not topic: + return {"error": "topic is required"}, 400 + + papers = body.get("papers", []) + style = body.get("style", "comprehensive") + audience = body.get("audience", "researchers") + + context_parts = [ + f"Topic: {topic}", + f"Review style: {style}", + f"Target audience: {audience}", + ] + if papers: + context_parts.append(f"Papers to include:\n{json.dumps(papers[:15], indent=2)}") + + system_prompt = ( + "You are ScholarAI, an expert academic research assistant specialized in " + "writing high-quality literature reviews. Generate well-structured, " + "comprehensive, and academically rigorous literature review sections that " + "synthesize research findings." + ) + user_prompt = ( + f"Generate a literature review for the following:\n\n" + f"{chr(10).join(context_parts)}\n\n" + f"Write a {style} literature review for {audience} that includes:\n" + "1. Introduction to the research area\n" + "2. Historical background and development\n" + "3. Current state of the art\n" + "4. Key methodologies and approaches\n" + "5. Major findings and contributions\n" + "6. Ongoing challenges and open problems\n" + "7. Future research directions\n" + "8. 
class Default(WorkerEntrypoint):
    """Worker entrypoint that routes requests to the handler functions."""

    async def fetch(self, request):
        """Dispatch one HTTP request and return a JSON ``Response``.

        ``GET /`` returns the payload of ``handle_info``. A ``POST`` to a path
        listed in ``_POST_ROUTES`` is decoded as JSON and handed to the
        matching handler together with ``self.env``. Any other method/path
        combination yields a 404 error payload.

        A body that cannot be decoded, or that decodes to something other
        than a JSON object, yields a 400 error payload. Handlers call
        ``body.get(...)``, so a list, string, number or null body must not be
        passed to them.
        """
        # Trailing slashes are ignored; an empty path is treated as the root.
        path = urlparse(request.url).path.rstrip("/") or "/"
        method = request.method.upper()

        if method == "GET" and path == "/":
            return _json_response(handle_info())

        handler = _POST_ROUTES.get(path) if method == "POST" else None
        if handler is None:
            return _error("Not found", 404)

        try:
            body = await request.json()
        except Exception:
            # The exception type raised by the runtime's JSON decoding is not
            # visible from this module, so the catch is deliberately broad.
            return _error("Invalid JSON body", 400)

        if not isinstance(body, dict):
            return _error("JSON body must be an object", 400)

        data, status = await handler(body, self.env)
        return _json_response(data, status)
class MockAI:
    """Test double exposing the ``run`` coroutine that handlers await on ``env.AI``."""

    # Fixed text returned for every inference request.
    DEFAULT_RESPONSE = "This is a mocked AI response for testing purposes."

    async def run(self, model: str, params: dict) -> dict:  # noqa: ARG002
        """Ignore both arguments and return the fixed reply under ``"response"``."""
        reply = dict(response=self.DEFAULT_RESPONSE)
        return reply
"ML"], "limit": 5}, ENV + ) + assert status == 200 + assert data["query"] == "protein folding" + + +@pytest.mark.asyncio +async def test_discover_missing_query(): + from src.entry import handle_discover + + data, status = await handle_discover({}, ENV) + assert status == 400 + assert "error" in data + + +# --------------------------------------------------------------------------- +# Tests: handle_summarize +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_summarize_with_abstract(): + from src.entry import handle_summarize + + data, status = await handle_summarize( + {"title": "Test Paper", "abstract": "This paper proposes a new method."}, ENV + ) + assert status == 200 + assert data["title"] == "Test Paper" + assert "summary" in data + assert len(data["summary"]) > 0 + + +@pytest.mark.asyncio +async def test_summarize_with_content_only(): + from src.entry import handle_summarize + + data, status = await handle_summarize( + {"content": "Full paper content goes here..."}, ENV + ) + assert status == 200 + + +@pytest.mark.asyncio +async def test_summarize_no_input(): + from src.entry import handle_summarize + + data, status = await handle_summarize({}, ENV) + assert status == 400 + assert "error" in data + + +# --------------------------------------------------------------------------- +# Tests: handle_citations +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_citations_success(): + from src.entry import handle_citations + + data, status = await handle_citations( + {"paper": "Attention Is All You Need", "type": "related"}, ENV + ) + assert status == 200 + assert data["paper"] == "Attention Is All You Need" + assert data["citation_type"] == "related" + assert "exploration" in data + + +@pytest.mark.asyncio +async def test_citations_missing_paper(): + from src.entry import handle_citations + + data, status = await 
handle_citations({}, ENV) + assert status == 400 + assert "error" in data + + +@pytest.mark.asyncio +async def test_citations_forward_type(): + from src.entry import handle_citations + + data, status = await handle_citations( + {"paper": "BERT: Pre-training of Deep Bidirectional Transformers", "type": "forward"}, + ENV, + ) + assert status == 200 + assert data["citation_type"] == "forward" + + +# --------------------------------------------------------------------------- +# Tests: handle_qa +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_qa_success(): + from src.entry import handle_qa + + data, status = await handle_qa({"question": "What is transfer learning?"}, ENV) + assert status == 200 + assert data["question"] == "What is transfer learning?" + assert "answer" in data + assert len(data["answer"]) > 0 + + +@pytest.mark.asyncio +async def test_qa_with_context_and_papers(): + from src.entry import handle_qa + + data, status = await handle_qa( + { + "question": "How does BERT work?", + "context": "BERT is a language model.", + "papers": [{"title": "BERT", "year": 2018}], + }, + ENV, + ) + assert status == 200 + + +@pytest.mark.asyncio +async def test_qa_missing_question(): + from src.entry import handle_qa + + data, status = await handle_qa({}, ENV) + assert status == 400 + assert "error" in data + + +# --------------------------------------------------------------------------- +# Tests: handle_organize +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_organize_success(): + from src.entry import handle_organize + + papers = [{"title": "Paper A", "year": 2020}, {"title": "Paper B", "year": 2021}] + data, status = await handle_organize({"papers": papers, "organize_by": "topic"}, ENV) + assert status == 200 + assert data["organize_by"] == "topic" + assert "organization" in data + + +@pytest.mark.asyncio +async def 
test_organize_missing_papers(): + from src.entry import handle_organize + + data, status = await handle_organize({}, ENV) + assert status == 400 + assert "error" in data + + +@pytest.mark.asyncio +async def test_organize_by_year(): + from src.entry import handle_organize + + data, status = await handle_organize( + {"papers": [{"title": "Paper A", "year": 2019}], "organize_by": "year"}, ENV + ) + assert status == 200 + assert data["organize_by"] == "year" + + +# --------------------------------------------------------------------------- +# Tests: handle_trends +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_trends_with_field(): + from src.entry import handle_trends + + data, status = await handle_trends({"field": "machine learning"}, ENV) + assert status == 200 + assert data["field"] == "machine learning" + assert "trends" in data + + +@pytest.mark.asyncio +async def test_trends_with_papers(): + from src.entry import handle_trends + + data, status = await handle_trends( + {"papers": [{"title": "Paper A", "year": 2023}]}, ENV + ) + assert status == 200 + + +@pytest.mark.asyncio +async def test_trends_no_field_or_papers(): + from src.entry import handle_trends + + data, status = await handle_trends({}, ENV) + assert status == 400 + assert "error" in data + + +@pytest.mark.asyncio +async def test_trends_with_time_range(): + from src.entry import handle_trends + + data, status = await handle_trends({"field": "NLP", "time_range": "2018-2024"}, ENV) + assert status == 200 + + +# --------------------------------------------------------------------------- +# Tests: handle_review +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_review_success(): + from src.entry import handle_review + + data, status = await handle_review({"topic": "deep learning for NLP"}, ENV) + assert status == 200 + assert data["topic"] == "deep learning for NLP" 
+ assert data["style"] == "comprehensive" + assert "review" in data + assert len(data["review"]) > 0 + + +@pytest.mark.asyncio +async def test_review_with_papers_and_style(): + from src.entry import handle_review + + data, status = await handle_review( + { + "topic": "graph neural networks", + "papers": [{"title": "GNN Survey", "year": 2022}], + "style": "brief", + "audience": "graduate students", + }, + ENV, + ) + assert status == 200 + assert data["style"] == "brief" + + +@pytest.mark.asyncio +async def test_review_missing_topic(): + from src.entry import handle_review + + data, status = await handle_review({}, ENV) + assert status == 400 + assert "error" in data + + +# --------------------------------------------------------------------------- +# Tests: _json_response and _error helpers +# --------------------------------------------------------------------------- + + +def test_json_response_content_type(): + from src.entry import _json_response + + resp = _json_response({"key": "value"}) + assert resp.headers.get("Content-Type") == "application/json" + assert resp.status == 200 + assert json.loads(resp.body) == {"key": "value"} + + +def test_json_response_custom_status(): + from src.entry import _json_response + + resp = _json_response({"error": "bad"}, status=400) + assert resp.status == 400 + + +def test_error_helper(): + from src.entry import _error + + resp = _error("Not found", 404) + assert resp.status == 404 + assert json.loads(resp.body) == {"error": "Not found"} + diff --git a/wrangler.toml b/wrangler.toml new file mode 100644 index 0000000..d2a0606 --- /dev/null +++ b/wrangler.toml @@ -0,0 +1,10 @@ +name = "scholarai" +main = "src/entry.py" +compatibility_date = "2025-11-02" +compatibility_flags = ["python_workers"] + +[ai] +binding = "AI" + +[observability] +enabled = true