diff --git a/api/db/repositories.py b/api/db/repositories.py
index 6608718..7686510 100644
--- a/api/db/repositories.py
+++ b/api/db/repositories.py
@@ -16,4 +16,10 @@ def create_form(session: Session, form: FormSubmission) -> FormSubmission:
session.add(form)
session.commit()
session.refresh(form)
- return form
\ No newline at end of file
+ return form
+
+def get_all_templates(session: Session, limit: int = 100, offset: int = 0) -> list[Template]:
+ return session.exec(select(Template).offset(offset).limit(limit)).all()
+
+def get_form(session: Session, submission_id: int) -> FormSubmission | None:
+ return session.get(FormSubmission, submission_id)
\ No newline at end of file
diff --git a/api/main.py b/api/main.py
index d0b8c79..0a7d8e7 100644
--- a/api/main.py
+++ b/api/main.py
@@ -1,7 +1,25 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
from api.routes import templates, forms
+from api.errors.base import AppError
+from typing import Union
app = FastAPI()
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+@app.exception_handler(AppError)
+def app_error_handler(request: Request, exc: AppError):
+ return JSONResponse(
+ status_code=exc.status_code,
+ content={"detail": exc.message}
+ )
+
app.include_router(templates.router)
app.include_router(forms.router)
\ No newline at end of file
diff --git a/api/routes/forms.py b/api/routes/forms.py
index f3430ed..3491d4e 100644
--- a/api/routes/forms.py
+++ b/api/routes/forms.py
@@ -1,25 +1,82 @@
+import os
from fastapi import APIRouter, Depends
+from fastapi.responses import FileResponse
from sqlmodel import Session
from api.deps import get_db
from api.schemas.forms import FormFill, FormFillResponse
-from api.db.repositories import create_form, get_template
+from api.db.repositories import create_form, get_template, get_form
from api.db.models import FormSubmission
from api.errors.base import AppError
from src.controller import Controller
router = APIRouter(prefix="/forms", tags=["forms"])
+
@router.post("/fill", response_model=FormFillResponse)
def fill_form(form: FormFill, db: Session = Depends(get_db)):
- if not get_template(db, form.template_id):
+ # Single DB query (fixes issue #149 - redundant query)
+ template = get_template(db, form.template_id)
+ if not template:
raise AppError("Template not found", status_code=404)
- fetched_template = get_template(db, form.template_id)
+ try:
+ controller = Controller()
+ # FileManipulator.fill_form expects fields as a list of key strings
+ fields_list = list(template.fields.keys()) if isinstance(template.fields, dict) else template.fields
+ path = controller.fill_form(
+ user_input=form.input_text,
+ fields=fields_list,
+ pdf_form_path=template.pdf_path
+ )
+ except ConnectionError:
+ raise AppError(
+ "Could not connect to Ollama. Make sure ollama serve is running.",
+ status_code=503
+ )
+ except Exception as e:
+ raise AppError(f"PDF filling failed: {str(e)}", status_code=500)
+
+ # Guard: controller returned None instead of a file path
+ if not path:
+ raise AppError(
+ "PDF generation failed — no output file was produced. "
+ "Check that the PDF template is a valid fillable form and Ollama is running.",
+ status_code=500
+ )
- controller = Controller()
- path = controller.fill_form(user_input=form.input_text, fields=fetched_template.fields, pdf_form_path=fetched_template.pdf_path)
+ if not os.path.exists(path):
+ raise AppError(
+ f"PDF was generated but file not found at: {path}",
+ status_code=500
+ )
- submission = FormSubmission(**form.model_dump(), output_pdf_path=path)
+ submission = FormSubmission(
+ **form.model_dump(),
+ output_pdf_path=path
+ )
return create_form(db, submission)
+@router.get("/{submission_id}", response_model=FormFillResponse)
+def get_submission(submission_id: int, db: Session = Depends(get_db)):
+ submission = get_form(db, submission_id)
+ if not submission:
+ raise AppError("Submission not found", status_code=404)
+ return submission
+
+
+@router.get("/download/{submission_id}")
+def download_filled_pdf(submission_id: int, db: Session = Depends(get_db)):
+ submission = get_form(db, submission_id)
+ if not submission:
+ raise AppError("Submission not found", status_code=404)
+
+ file_path = submission.output_pdf_path
+ if not os.path.exists(file_path):
+ raise AppError("PDF file not found on server", status_code=404)
+
+ return FileResponse(
+ path=file_path,
+ media_type="application/pdf",
+ filename=os.path.basename(file_path)
+ )
\ No newline at end of file
diff --git a/api/routes/templates.py b/api/routes/templates.py
index 5c2281b..9419ae6 100644
--- a/api/routes/templates.py
+++ b/api/routes/templates.py
@@ -1,16 +1,89 @@
-from fastapi import APIRouter, Depends
+import os
+import shutil
+import uuid
+from fastapi import APIRouter, Depends, UploadFile, File, Form
from sqlmodel import Session
from api.deps import get_db
-from api.schemas.templates import TemplateCreate, TemplateResponse
-from api.db.repositories import create_template
+from api.schemas.templates import TemplateResponse
+from api.db.repositories import create_template, get_all_templates
from api.db.models import Template
-from src.controller import Controller
+from api.errors.base import AppError
router = APIRouter(prefix="/templates", tags=["templates"])
+# Save directly into src/inputs/ — stable location, won't get wiped
+TEMPLATES_DIR = os.path.join("src", "inputs")
+os.makedirs(TEMPLATES_DIR, exist_ok=True)
+
+
@router.post("/create", response_model=TemplateResponse)
-def create(template: TemplateCreate, db: Session = Depends(get_db)):
- controller = Controller()
- template_path = controller.create_template(template.pdf_path)
- tpl = Template(**template.model_dump(exclude={"pdf_path"}), pdf_path=template_path)
- return create_template(db, tpl)
\ No newline at end of file
+async def create(
+ name: str = Form(...),
+ file: UploadFile = File(...),
+ db: Session = Depends(get_db)
+):
+ # Validate PDF
+ if not file.filename or not file.filename.lower().endswith(".pdf"):
+ raise AppError("Only PDF files are allowed", status_code=400)
+
+ # Save uploaded file with unique name into src/inputs/
+ unique_name = f"{uuid.uuid4().hex}_{file.filename}"
+ save_path = os.path.join(TEMPLATES_DIR, unique_name)
+
+ with open(save_path, "wb") as f:
+ shutil.copyfileobj(file.file, f)
+
+ # Extract fields using pypdf (reads AcroForm field names directly from the PDF)
+ # Store as simple list of field name strings — what Filler expects
+ try:
+ # NOTE(review): commonforms.prepare_form was imported here but never used — removed
+ from pypdf import PdfReader
+
+ # Read real field names directly from original PDF
+ # Use /T (internal name) as both key and label
+ # Real names like "JobTitle", "Phone Number" are already human-readable
+ reader = PdfReader(save_path)
+ raw_fields = reader.get_fields() or {}
+
+ fields = {}
+ for internal_name, field_data in raw_fields.items():
+ # Use /TU tooltip if available, otherwise prettify /T name
+ label = None
+ if isinstance(field_data, dict):
+ label = field_data.get("/TU")
+ if not label:
+ # Prettify: "JobTitle" → "Job Title", "DATE7_af_date" → "Date"
+ import re
+ label = re.sub(r'([a-z])([A-Z])', r'\1 \2', internal_name)
+ label = re.sub(r'_af_.*$', '', label) # strip "_af_date" suffix
+ label = label.replace('_', ' ').strip().title()
+ fields[internal_name] = label
+
+ except Exception as e:
+ print(f"Field extraction failed: {e}")
+ fields = {}  # keep dict type consistent with the success path above
+
+ # Save to DB
+ tpl = Template(name=name, pdf_path=save_path, fields=fields)
+ return create_template(db, tpl)
+
+
+@router.get("", response_model=list[TemplateResponse])
+def list_templates(
+ limit: int = 100,
+ offset: int = 0,
+ db: Session = Depends(get_db)
+):
+ return get_all_templates(db, limit=limit, offset=offset)
+
+
+@router.get("/{template_id}", response_model=TemplateResponse)
+def get_template_by_id(
+ template_id: int,
+ db: Session = Depends(get_db)
+):
+ from api.db.repositories import get_template
+ tpl = get_template(db, template_id)
+ if not tpl:
+ raise AppError("Template not found", status_code=404)
+ return tpl
\ No newline at end of file
diff --git a/frontend/index.html b/frontend/index.html
new file mode 100644
index 0000000..a3b0083
--- /dev/null
+++ b/frontend/index.html
@@ -0,0 +1,467 @@
+
+
+
+
+
+FireForm — Report Once, File Everywhere
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
UN Digital Public Good · GSoC 2026
+
REPORTONCE.
+
Describe any incident in plain language. FireForm uses a locally-running AI to extract every relevant detail and auto-fill all required agency forms — instantly and privately.
+
+
+
+
+
1
+
Upload Template
Any fillable PDF form
+
+
+
2
+
Select Template
Choose from saved forms
+
+
+
3
+
Describe Incident
Plain language report
+
+
+
4
+
Download PDF
All fields auto-filled
+
+
+
+
+
+
+
+
+
No submissions yet this session.
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/llm.py b/src/llm.py
index 70937f9..2463e0f 100644
--- a/src/llm.py
+++ b/src/llm.py
@@ -1,14 +1,19 @@
import json
import os
+import time
import requests
class LLM:
def __init__(self, transcript_text=None, target_fields=None, json=None):
+ """
+ target_fields: dict or list containing the template field names to extract
+ (dict format: {"field_name": "human_label"}, list format: ["field_name1", "field_name2"])
+ """
if json is None:
json = {}
self._transcript_text = transcript_text # str
- self._target_fields = target_fields # List, contains the template field.
+ self._target_fields = target_fields # dict or list
self._json = json # dictionary
def type_check_all(self):
@@ -17,64 +22,204 @@ def type_check_all(self):
f"ERROR in LLM() attributes ->\
Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}"
)
- elif type(self._target_fields) is not list:
+ if not isinstance(self._target_fields, (list, dict)):
raise TypeError(
f"ERROR in LLM() attributes ->\
- Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}"
+ Target fields must be a list or dict. Input:\n\ttarget_fields: {self._target_fields}"
+ )
+
+ def build_batch_prompt(self) -> str:
+ """
+ Build a single prompt that extracts ALL fields at once.
+ Sends human-readable labels as context so Mistral understands
+ what each internal field name means.
+ Fixes Issue #196 — reduces N Ollama calls to 1.
+ """
+ if isinstance(self._target_fields, dict):
+ fields_lines = "\n".join(
+ f' "{k}": null // {v if v and v != k else k}'
+ for k, v in self._target_fields.items()
+ )
+ else:
+ fields_lines = "\n".join(
+ f' "{f}": null'
+ for f in self._target_fields
)
- def build_prompt(self, current_field):
+ prompt = f"""You are filling out an official form. Extract values from the transcript below.
+
+FORM FIELDS (each line: "internal_key": null // visible label on form):
+{{
+{fields_lines}
+}}
+
+RULES:
+1. Return ONLY a valid JSON object — no explanation, no markdown, no extra text
+2. Use the visible label (after //) to understand what each field means
+3. Fill each key with the matching value from the transcript
+4. If a value is not found in the transcript, use null
+5. Never invent or guess values not present in the transcript
+6. For multiple values (e.g. multiple victims), use a semicolon-separated string: "Name1; Name2"
+7. Distinguish roles carefully: Officer/Employee is NOT the same as Victim or Suspect
+
+TRANSCRIPT:
+{self._transcript_text}
+
+JSON:"""
+
+ return prompt
+
+ def build_prompt(self, current_field: str) -> str:
"""
- This method is in charge of the prompt engineering. It creates a specific prompt for each target field.
- @params: current_field -> represents the current element of the json that is being prompted.
+ Legacy single-field prompt — kept for backward compatibility.
+ Used as fallback if batch parsing fails.
"""
- prompt = f"""
- SYSTEM PROMPT:
- You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings.
- You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return
- only a single string containing the identified value for the JSON field.
- If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";".
- If you don't identify the value in the provided text, return "-1".
- ---
- DATA:
- Target JSON field to find in text: {current_field}
-
- TEXT: {self._transcript_text}
- """
+ field_lower = current_field.lower()
+ is_plural = field_lower.endswith('s') and not field_lower.endswith('ss')
+
+ if any(w in field_lower for w in ['officer', 'employee', 'dispatcher', 'caller', 'reporting', 'supervisor']):
+ role_guidance = """
+ROLE: Extract the PRIMARY OFFICER/EMPLOYEE/DISPATCHER
+- This is typically the person speaking or reporting the incident
+- DO NOT extract victims, witnesses, or members of the public
+- Example: "Officer Smith reporting... victims are John and Jane" → extract "Smith"
+"""
+ elif any(w in field_lower for w in ['victim', 'injured', 'affected', 'casualty', 'patient']):
+ role_guidance = f"""
+ROLE: Extract VICTIM/AFFECTED PERSON(S)
+- Focus on people who experienced harm
+- Ignore officers, dispatchers, and witnesses
+{'- Return ALL names separated by ";"' if is_plural else '- Return the FIRST/PRIMARY victim'}
+"""
+ elif any(w in field_lower for w in ['location', 'address', 'street', 'place', 'where']):
+ role_guidance = """
+ROLE: Extract LOCATION/ADDRESS
+- Extract WHERE the incident occurred
+- Return only the incident location, not other addresses mentioned
+"""
+ elif any(w in field_lower for w in ['date', 'time', 'when', 'occurred', 'reported']):
+ role_guidance = """
+ROLE: Extract DATE/TIME
+- Extract WHEN the incident occurred
+- Return in the format it appears in the text
+"""
+ elif any(w in field_lower for w in ['phone', 'number', 'contact', 'tel']):
+ role_guidance = "ROLE: Extract PHONE NUMBER — return exactly as it appears in text"
+ elif any(w in field_lower for w in ['email', 'mail']):
+ role_guidance = "ROLE: Extract EMAIL ADDRESS"
+ elif any(w in field_lower for w in ['department', 'unit', 'division']):
+ role_guidance = "ROLE: Extract DEPARTMENT/UNIT name"
+ elif any(w in field_lower for w in ['title', 'job', 'role', 'rank', 'position']):
+ role_guidance = "ROLE: Extract JOB TITLE or RANK"
+ elif any(w in field_lower for w in ['id', 'badge', 'identifier']):
+ role_guidance = "ROLE: Extract ID or BADGE NUMBER"
+ elif any(w in field_lower for w in ['description', 'incident', 'detail', 'nature', 'summary']):
+ role_guidance = "ROLE: Extract a brief INCIDENT DESCRIPTION"
+ else:
+ role_guidance = f"""
+ROLE: Generic extraction for field "{current_field}"
+{'- Return MULTIPLE values separated by ";" if applicable' if is_plural else '- Return the PRIMARY matching value'}
+"""
+
+ prompt = f"""
+SYSTEM: You are extracting specific information from an incident report transcript.
+
+FIELD TO EXTRACT: {current_field}
+{'[SINGULAR - Extract ONE value]' if not is_plural else '[PLURAL - Extract MULTIPLE values separated by semicolon]'}
+
+EXTRACTION RULES:
+{role_guidance}
+
+CRITICAL RULES:
+1. Read the ENTIRE text before answering
+2. Extract ONLY what belongs to this specific field
+3. Return values exactly as they appear in the text
+4. If not found, return: -1
+
+TRANSCRIPT:
+{self._transcript_text}
+
+ANSWER: Return ONLY the extracted value(s), nothing else."""
return prompt
def main_loop(self):
- # self.type_check_all()
- for field in self._target_fields.keys():
- prompt = self.build_prompt(field)
- # print(prompt)
- # ollama_url = "http://localhost:11434/api/generate"
- ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
- ollama_url = f"{ollama_host}/api/generate"
-
- payload = {
- "model": "mistral",
- "prompt": prompt,
- "stream": False, # don't really know why --> look into this later.
- }
+ """
+ Single batch Ollama call — extracts ALL fields in one request.
+ Falls back to per-field extraction if JSON parsing fails.
+ Fixes Issue #196 (O(N) → O(1) LLM calls).
+ """
+ ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
+ ollama_url = f"{ollama_host}/api/generate"
- try:
- response = requests.post(ollama_url, json=payload)
- response.raise_for_status()
- except requests.exceptions.ConnectionError:
- raise ConnectionError(
- f"Could not connect to Ollama at {ollama_url}. "
- "Please ensure Ollama is running and accessible."
- )
- except requests.exceptions.HTTPError as e:
- raise RuntimeError(f"Ollama returned an error: {e}")
-
- # parse response
- json_data = response.json()
- parsed_response = json_data["response"]
- # print(parsed_response)
- self.add_response_to_json(field, parsed_response)
+ # Get field keys for result mapping
+ if isinstance(self._target_fields, dict):
+ field_keys = list(self._target_fields.keys())
+ else:
+ field_keys = list(self._target_fields)
+
+ # ── Single batch call ─────────────────────────────────────
+ prompt = self.build_batch_prompt()
+ payload = {"model": "mistral", "prompt": prompt, "stream": False}
+
+ # Progress logging (#132)
+ if isinstance(self._target_fields, dict):
+ field_count = len(self._target_fields)
+ field_names = list(self._target_fields.values())
+ else:
+ field_count = len(self._target_fields)
+ field_names = list(self._target_fields)
+
+ print(f"[LOG] Starting batch extraction for {field_count} field(s)...")
+ for i, name in enumerate(field_names, 1):
+ print(f"[LOG] Queuing field {i}/{field_count} -> '{name}'")
+ print(f"[LOG] Sending single batch request to Ollama (model: mistral)...")
+ _start = time.time()
+
+ try:
+ timeout = int(os.getenv("OLLAMA_TIMEOUT", "120"))
+ response = requests.post(ollama_url, json=payload, timeout=timeout)
+ response.raise_for_status()
+ _elapsed = time.time() - _start
+ print(f"[LOG] Ollama responded in {_elapsed:.2f}s")
+ except requests.exceptions.ConnectionError:
+ raise ConnectionError(
+ f"Could not connect to Ollama at {ollama_url}. "
+ "Please ensure Ollama is running and accessible."
+ )
+ except requests.exceptions.Timeout:
+ raise RuntimeError(
+ f"Ollama timed out after {timeout}s. "
+ "Try increasing the OLLAMA_TIMEOUT environment variable."
+ )
+ except requests.exceptions.HTTPError as e:
+ raise RuntimeError(f"Ollama returned an error: {e}")
+
+ raw = response.json()["response"].strip()
+
+ # Strip markdown code fences if Mistral wraps in ```json ... ```
+ raw = raw.replace("```json", "").replace("```", "").strip()
+
+ print("----------------------------------")
+ print("\t[LOG] Raw Mistral batch response:")
+ print(raw)
+
+ # ── Parse JSON response ───────────────────────────────────
+ try:
+ extracted = json.loads(raw)
+ for key in field_keys:
+ val = extracted.get(key)
+ if val and str(val).lower() not in ("null", "none", ""):
+ self._json[key] = val
+ else:
+ self._json[key] = None
+
+ print("\t[LOG] Batch extraction successful.")
+
+ except json.JSONDecodeError:
+ print("\t[WARN] Batch JSON parse failed — falling back to per-field extraction")
+ self._json = {}
+ self._fallback_per_field(ollama_url, field_keys)
print("----------------------------------")
print("\t[LOG] Resulting JSON created from the input text:")
@@ -83,10 +228,38 @@ def main_loop(self):
return self
+ def _fallback_per_field(self, ollama_url: str, field_keys: list):
+ """
+ Legacy per-field extraction — used only when batch JSON parse fails.
+ """
+ print("\t[LOG] Running fallback per-field extraction...")
+
+ total = len(field_keys)
+ for i, field in enumerate(field_keys, 1):
+ print(f"[LOG] Extracting field {i}/{total} -> '{field}'")
+ if isinstance(self._target_fields, dict):
+ label = self._target_fields.get(field, field)
+ if not label or label == field:
+ label = field
+ else:
+ label = field
+
+ prompt = self.build_prompt(label)
+ payload = {"model": "mistral", "prompt": prompt, "stream": False}
+
+ try:
+ response = requests.post(ollama_url, json=payload, timeout=int(os.getenv("OLLAMA_TIMEOUT", "120")))
+ response.raise_for_status()
+ parsed_response = response.json()["response"]
+ self.add_response_to_json(field, parsed_response)
+ except Exception as e:
+ print(f"\t[WARN] Failed to extract field '{field}': {e}")
+ self._json[field] = None
+
def add_response_to_json(self, field, value):
"""
- this method adds the following value under the specified field,
- or under a new field if the field doesn't exist, to the json dict
+ Add extracted value under field name.
+ Handles plural (semicolon-separated) values.
"""
value = value.strip().replace('"', "")
parsed_value = None
@@ -94,42 +267,35 @@ def add_response_to_json(self, field, value):
if value != "-1":
parsed_value = value
- if ";" in value:
- parsed_value = self.handle_plural_values(value)
+ if parsed_value and ";" in parsed_value:
+ parsed_value = self.handle_plural_values(parsed_value)
- if field in self._json.keys():
- self._json[field].append(parsed_value)
+ if field in self._json:
+ existing = self._json[field]
+ if isinstance(existing, list):
+ if isinstance(parsed_value, list):
+ existing.extend(parsed_value)
+ else:
+ existing.append(parsed_value)
+ else:
+ self._json[field] = [existing, parsed_value]
else:
self._json[field] = parsed_value
- return
-
def handle_plural_values(self, plural_value):
"""
- This method handles plural values.
- Takes in strings of the form 'value1; value2; value3; ...; valueN'
- returns a list with the respective values -> [value1, value2, value3, ..., valueN]
+ Split semicolon-separated values into a list.
+ "Mark Smith; Jane Doe" → ["Mark Smith", "Jane Doe"]
"""
if ";" not in plural_value:
raise ValueError(
f"Value is not plural, doesn't have ; separator, Value: {plural_value}"
)
- print(
- f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..."
- )
- values = plural_value.split(";")
-
- # Remove trailing leading whitespace
- for i in range(len(values)):
- current = i + 1
- if current < len(values):
- clean_value = values[current].lstrip()
- values[current] = clean_value
-
+ print(f"\t[LOG]: Formatting plural values for JSON, [For input {plural_value}]...")
+ values = [v.strip() for v in plural_value.split(";") if v.strip()]
print(f"\t[LOG]: Resulting formatted list of values: {values}")
-
return values
def get_data(self):
- return self._json
+ return self._json
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index 5bb632b..e07578b 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,5 +1,6 @@
import os
# from backend import Fill
+from typing import Union
from commonforms import prepare_form
from pypdf import PdfReader
from controller import Controller
diff --git a/tests/conftest.py b/tests/conftest.py
index 7cb4db3..ff92c19 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,12 +3,10 @@
from sqlalchemy.pool import StaticPool
import pytest
-
from api.main import app
from api.deps import get_db
from api.db.models import Template, FormSubmission
-# In-memory SQLite database for tests
TEST_DATABASE_URL = "sqlite://"
engine = create_engine(
@@ -23,12 +21,12 @@ def override_get_db():
yield session
-# Apply dependency override
app.dependency_overrides[get_db] = override_get_db
-@pytest.fixture(scope="session", autouse=True)
-def create_test_db():
+@pytest.fixture(autouse=True)
+def reset_db():
+ SQLModel.metadata.drop_all(engine)
SQLModel.metadata.create_all(engine)
yield
SQLModel.metadata.drop_all(engine)
@@ -37,3 +35,10 @@ def create_test_db():
@pytest.fixture
def client():
return TestClient(app)
+
+
+@pytest.fixture
+def db_session():
+ """Direct DB session for test setup."""
+ with Session(engine) as session:
+ yield session
diff --git a/tests/test_forms.py b/tests/test_forms.py
index 8f432bf..5e32755 100644
--- a/tests/test_forms.py
+++ b/tests/test_forms.py
@@ -1,25 +1,107 @@
-def test_submit_form(client):
- pass
- # First create a template
- # form_payload = {
- # "template_id": 3,
- # "input_text": "Hi. The employee's name is John Doe. His job title is managing director. His department supervisor is Jane Doe. His phone number is 123456. His email is jdoe@ucsc.edu. The signature is , and the date is 01/02/2005",
- # }
-
- # template_res = client.post("/templates/", json=template_payload)
- # template_id = template_res.json()["id"]
-
- # # Submit a form
- # form_payload = {
- # "template_id": template_id,
- # "data": {"rating": 5, "comment": "Great service"},
- # }
-
- # response = client.post("/forms/", json=form_payload)
-
- # assert response.status_code == 200
-
- # data = response.json()
- # assert data["id"] is not None
- # assert data["template_id"] == template_id
- # assert data["data"] == form_payload["data"]
+"""
+Tests for /forms endpoints.
+Closes #165, #205, #163
+"""
+
+import pytest
+from unittest.mock import patch
+from api.db.models import Template, FormSubmission
+from datetime import datetime
+
+
+# ── helpers ───────────────────────────────────────────────────────────────────
+
+def make_template(db_session):
+ t = Template(
+ name="Test Form",
+ fields={"JobTitle": "Job Title"},
+ pdf_path="/tmp/test.pdf",
+ created_at=datetime.utcnow(),
+ )
+ db_session.add(t)
+ db_session.commit()
+ db_session.refresh(t)
+ return t.id
+
+
+def make_submission(db_session, template_id, output_path="/tmp/filled.pdf"):
+ s = FormSubmission(
+ template_id=template_id,
+ input_text="John Smith is a firefighter.",
+ output_pdf_path=output_path,
+ created_at=datetime.utcnow(),
+ )
+ db_session.add(s)
+ db_session.commit()
+ db_session.refresh(s)
+ return s.id
+
+
+# ── POST /forms/fill ──────────────────────────────────────────────────────────
+
+class TestFillForm:
+
+ def test_fill_form_template_not_found(self, client):
+ """Returns 404 when template_id does not exist."""
+ response = client.post("/forms/fill", json={
+ "template_id": 999999,
+ "input_text": "John Smith is a firefighter.",
+ })
+ assert response.status_code == 404
+
+ def test_fill_form_missing_fields_returns_422(self, client):
+ """Returns 422 when required fields are missing."""
+ response = client.post("/forms/fill", json={
+ "template_id": 1,
+ })
+ assert response.status_code == 422
+
+ def test_fill_form_ollama_down_returns_503(self, client, db_session):
+ """Returns 503 when Ollama is not reachable."""
+ template_id = make_template(db_session)
+
+ with patch("src.controller.Controller.fill_form",
+ side_effect=ConnectionError("Ollama not running")):
+ response = client.post("/forms/fill", json={
+ "template_id": template_id,
+ "input_text": "John Smith is a firefighter.",
+ })
+
+ assert response.status_code == 503
+
+
+# ── GET /forms/{submission_id} ────────────────────────────────────────────────
+
+class TestGetSubmission:
+
+ def test_get_submission_not_found(self, client):
+ """Returns 404 for non-existent submission ID."""
+ response = client.get("/forms/999999")
+ assert response.status_code == 404
+
+ def test_get_submission_invalid_id(self, client):
+ """Returns 422 for non-integer submission ID."""
+ response = client.get("/forms/not-an-id")
+ assert response.status_code == 422
+
+
+# ── GET /forms/download/{submission_id} ───────────────────────────────────────
+
+class TestDownloadSubmission:
+
+ def test_download_not_found_submission(self, client):
+ """Returns 404 when submission does not exist."""
+ response = client.get("/forms/download/999999")
+ assert response.status_code == 404
+
+ def test_download_file_missing_on_disk(self, client, db_session):
+ """Returns 404 when submission exists but PDF missing on disk."""
+ template_id = make_template(db_session)
+ submission_id = make_submission(
+ db_session, template_id, "/nonexistent/filled.pdf"
+ )
+
+ with patch("os.path.exists", return_value=False):
+ response = client.get(f"/forms/download/{submission_id}")
+
+ assert response.status_code == 404
diff --git a/tests/test_llm.py b/tests/test_llm.py
new file mode 100644
index 0000000..cfe483b
--- /dev/null
+++ b/tests/test_llm.py
@@ -0,0 +1,278 @@
+"""
+Unit tests for src/llm.py — LLM class.
+
+Closes: #186 (Unit tests for LLM class methods)
+Covers: batch prompt, per-field prompt, add_response_to_json,
+ handle_plural_values, type_check_all, main_loop (mocked)
+"""
+
+import json
+import pytest
+from unittest.mock import patch, MagicMock
+from src.llm import LLM
+
+
+# ── Fixtures ─────────────────────────────────────────────────────────────────
+
+@pytest.fixture
+def dict_fields():
+ """Realistic dict fields: {internal_name: human_label}"""
+ return {
+ "NAME/SID": "Employee Or Student Name",
+ "JobTitle": "Job Title",
+ "Department": "Department",
+ "Phone Number": "Phone Number",
+ "email": "Email",
+ }
+
+@pytest.fixture
+def list_fields():
+ """Legacy list fields: [internal_name, ...]"""
+ return ["officer_name", "location", "incident_date"]
+
+@pytest.fixture
+def transcript():
+ return (
+ "Employee name is John Smith. Employee ID is EMP-2024-789. "
+ "Job title is Firefighter Paramedic. Department is Emergency Medical Services. "
+ "Phone number is 916-555-0147."
+ )
+
+@pytest.fixture
+def llm_dict(dict_fields, transcript):
+ return LLM(transcript_text=transcript, target_fields=dict_fields)
+
+@pytest.fixture
+def llm_list(list_fields, transcript):
+ return LLM(transcript_text=transcript, target_fields=list_fields)
+
+
+# ── type_check_all ────────────────────────────────────────────────────────────
+
+class TestTypeCheckAll:
+
+ def test_raises_on_non_string_transcript(self, dict_fields):
+ llm = LLM(transcript_text=12345, target_fields=dict_fields)
+ with pytest.raises(TypeError, match="Transcript must be text"):
+ llm.type_check_all()
+
+ def test_raises_on_none_transcript(self, dict_fields):
+ llm = LLM(transcript_text=None, target_fields=dict_fields)
+ with pytest.raises(TypeError):
+ llm.type_check_all()
+
+ def test_raises_on_invalid_fields_type(self, transcript):
+ llm = LLM(transcript_text=transcript, target_fields="not_a_list_or_dict")
+ with pytest.raises(TypeError, match="list or dict"):
+ llm.type_check_all()
+
+ def test_passes_with_dict_fields(self, llm_dict):
+ # Should not raise
+ llm_dict.type_check_all()
+
+ def test_passes_with_list_fields(self, llm_list):
+ # Should not raise
+ llm_list.type_check_all()
+
+
+# ── build_batch_prompt ────────────────────────────────────────────────────────
+
+class TestBuildBatchPrompt:
+
+ def test_contains_all_field_keys(self, llm_dict, dict_fields):
+ prompt = llm_dict.build_batch_prompt()
+ for key in dict_fields.keys():
+ assert key in prompt, f"Field key '{key}' missing from batch prompt"
+
+ def test_contains_human_labels(self, llm_dict, dict_fields):
+ prompt = llm_dict.build_batch_prompt()
+ for label in dict_fields.values():
+ assert label in prompt, f"Label '{label}' missing from batch prompt"
+
+ def test_contains_transcript(self, llm_dict, transcript):
+ prompt = llm_dict.build_batch_prompt()
+ assert transcript in prompt
+
+ def test_contains_json_instruction(self, llm_dict):
+ prompt = llm_dict.build_batch_prompt()
+ assert "JSON" in prompt
+
+ def test_list_fields_batch_prompt(self, llm_list, list_fields):
+ prompt = llm_list.build_batch_prompt()
+ for field in list_fields:
+ assert field in prompt
+
+ def test_labels_used_as_comments(self, llm_dict):
+ """Human labels should appear after // in the prompt"""
+ prompt = llm_dict.build_batch_prompt()
+ assert "//" in prompt
+
+
+# ── build_prompt (legacy per-field) ──────────────────────────────────────────
+
+class TestBuildPrompt:
+
+ def test_officer_field_gets_officer_guidance(self, llm_dict):
+ prompt = llm_dict.build_prompt("officer_name")
+ assert "OFFICER" in prompt.upper() or "EMPLOYEE" in prompt.upper()
+
+ def test_location_field_gets_location_guidance(self, llm_dict):
+ prompt = llm_dict.build_prompt("incident_location")
+ assert "LOCATION" in prompt.upper() or "ADDRESS" in prompt.upper()
+
+ def test_victim_field_gets_victim_guidance(self, llm_dict):
+ prompt = llm_dict.build_prompt("victim_name")
+ assert "VICTIM" in prompt.upper()
+
+ def test_phone_field_gets_phone_guidance(self, llm_dict):
+ prompt = llm_dict.build_prompt("Phone Number")
+ assert "PHONE" in prompt.upper()
+
+ def test_prompt_contains_transcript(self, llm_dict, transcript):
+ prompt = llm_dict.build_prompt("some_field")
+ assert transcript in prompt
+
+ def test_generic_field_still_builds_prompt(self, llm_dict):
+ prompt = llm_dict.build_prompt("textbox_0_0")
+ assert len(prompt) > 50
+
+
+# ── handle_plural_values ──────────────────────────────────────────────────────
+
+class TestHandlePluralValues:
+
+    def test_splits_on_semicolon(self, llm_dict):
+        result = llm_dict.handle_plural_values("Mark Smith;Jane Doe")
+        assert "Mark Smith" in result
+        assert "Jane Doe" in result
+
+    def test_strips_whitespace(self, llm_dict):
+        # Exact-value check: the old `all(v == v.strip())` was vacuously
+        # true for an empty result and never verified the split content.
+        result = llm_dict.handle_plural_values("Mark Smith; Jane Doe; Bob")
+        assert result == ["Mark Smith", "Jane Doe", "Bob"]
+
+    def test_returns_list(self, llm_dict):
+        result = llm_dict.handle_plural_values("A;B;C")
+        assert isinstance(result, list)
+
+    def test_raises_without_semicolon(self, llm_dict):
+        with pytest.raises(ValueError, match="separator"):
+            llm_dict.handle_plural_values("no semicolon here")
+
+    def test_three_values(self, llm_dict):
+        # Assert content and order, not just length.
+        result = llm_dict.handle_plural_values("Alice;Bob;Charlie")
+        assert result == ["Alice", "Bob", "Charlie"]
+
+
+# ── add_response_to_json ──────────────────────────────────────────────────────
+
+class TestAddResponseToJson:
+
+    def test_stores_value_under_field(self, llm_dict):
+        llm_dict.add_response_to_json("NAME/SID", "John Smith")
+        assert llm_dict._json["NAME/SID"] == "John Smith"
+
+    def test_ignores_minus_one(self, llm_dict):
+        llm_dict.add_response_to_json("email", "-1")
+        assert llm_dict._json["email"] is None
+
+    def test_strips_quotes(self, llm_dict):
+        llm_dict.add_response_to_json("JobTitle", '"Firefighter"')
+        assert llm_dict._json["JobTitle"] == "Firefighter"
+
+    def test_strips_whitespace(self, llm_dict):
+        llm_dict.add_response_to_json("Department", " EMS ")
+        assert llm_dict._json["Department"] == "EMS"
+
+    def test_plural_value_becomes_list(self, llm_dict):
+        # Assert the split content, not merely the type.
+        llm_dict.add_response_to_json("victims", "Mark Smith;Jane Doe")
+        assert llm_dict._json["victims"] == ["Mark Smith", "Jane Doe"]
+
+    def test_existing_field_becomes_list(self, llm_dict):
+        """Adding to existing field should not overwrite silently."""
+        llm_dict._json["NAME/SID"] = "John"
+        llm_dict.add_response_to_json("NAME/SID", "Jane")
+        assert isinstance(llm_dict._json["NAME/SID"], list)
+
+
+# ── get_data ──────────────────────────────────────────────────────────────────
+
+class TestGetData:
+
+ def test_returns_dict(self, llm_dict):
+ assert isinstance(llm_dict.get_data(), dict)
+
+ def test_returns_same_reference_as_internal_json(self, llm_dict):
+ llm_dict._json["test_key"] = "test_value"
+ assert llm_dict.get_data()["test_key"] == "test_value"
+
+
+# ── main_loop (mocked Ollama) ─────────────────────────────────────────────────
+
+class TestMainLoop:
+
+ def _mock_response(self, json_body: dict):
+ """Build a mock requests.Response returning a valid Mistral JSON reply."""
+ mock_resp = MagicMock()
+ mock_resp.raise_for_status = MagicMock()
+ mock_resp.json.return_value = {
+ "response": json.dumps(json_body)
+ }
+ return mock_resp
+
+ def test_batch_success_fills_all_fields(self, llm_dict, dict_fields):
+ expected = {
+ "NAME/SID": "John Smith",
+ "JobTitle": "Firefighter Paramedic",
+ "Department": "Emergency Medical Services",
+ "Phone Number": "916-555-0147",
+ "email": None,
+ }
+ with patch("requests.post", return_value=self._mock_response(expected)):
+ llm_dict.main_loop()
+
+ result = llm_dict.get_data()
+ assert result["NAME/SID"] == "John Smith"
+ assert result["JobTitle"] == "Firefighter Paramedic"
+ assert result["Department"] == "Emergency Medical Services"
+ assert result["Phone Number"] == "916-555-0147"
+
+ def test_batch_makes_exactly_one_ollama_call(self, llm_dict, dict_fields):
+ """Core performance requirement — O(1) not O(N)."""
+ expected = {k: "value" for k in dict_fields.keys()}
+ with patch("requests.post", return_value=self._mock_response(expected)) as mock_post:
+ llm_dict.main_loop()
+
+ assert mock_post.call_count == 1, (
+ f"Expected 1 Ollama call, got {mock_post.call_count}. "
+ "main_loop() must use batch extraction, not per-field."
+ )
+
+ def test_fallback_on_invalid_json(self, llm_dict, dict_fields):
+ """If Mistral returns non-JSON, fallback per-field runs without crash."""
+ bad_response = MagicMock()
+ bad_response.raise_for_status = MagicMock()
+ bad_response.json.return_value = {"response": "This is not JSON at all."}
+
+ good_response = MagicMock()
+ good_response.raise_for_status = MagicMock()
+ good_response.json.return_value = {"response": "John Smith"}
+
+ # First call returns bad JSON, rest return single values
+ with patch("requests.post", side_effect=[bad_response] + [good_response] * len(dict_fields)):
+ llm_dict.main_loop() # should not raise
+
+ def test_connection_error_raises_connection_error(self, llm_dict):
+ import requests as req
+ with patch("requests.post", side_effect=req.exceptions.ConnectionError):
+ with pytest.raises(ConnectionError, match="Ollama"):
+ llm_dict.main_loop()
+
+ def test_null_values_stored_as_none(self, llm_dict, dict_fields):
+ """Mistral returning null should be stored as None, not the string 'null'."""
+ response_with_nulls = {k: None for k in dict_fields.keys()}
+ with patch("requests.post", return_value=self._mock_response(response_with_nulls)):
+ llm_dict.main_loop()
+
+ result = llm_dict.get_data()
+ for key in dict_fields.keys():
+ assert result[key] is None, f"Expected None for '{key}', got {result[key]!r}"
diff --git a/tests/test_templates.py b/tests/test_templates.py
index bbced2b..9b7cf8e 100644
--- a/tests/test_templates.py
+++ b/tests/test_templates.py
@@ -1,18 +1,126 @@
-def test_create_template(client):
- payload = {
- "name": "Template 1",
- "pdf_path": "src/inputs/file.pdf",
- "fields": {
- "Employee's name": "string",
- "Employee's job title": "string",
- "Employee's department supervisor": "string",
- "Employee's phone number": "string",
- "Employee's email": "string",
- "Signature": "string",
- "Date": "string",
- },
- }
-
- response = client.post("/templates/create", json=payload)
-
- assert response.status_code == 200
+"""
+Tests for /templates endpoints.
+Closes #162, #160, #163
+"""
+
+import io
+from datetime import datetime, timezone
+from unittest.mock import patch, MagicMock
+import pytest
+from api.db.models import Template
+
+
+# ── POST /templates/create ────────────────────────────────────────────────────
+
+class TestCreateTemplate:
+
+ def test_create_template_success(self, client):
+ """Uploading a valid PDF returns 200 with template data."""
+ pdf_bytes = (
+ b"%PDF-1.4\n1 0 obj<>endobj\n"
+ b"2 0 obj<>endobj\n"
+ b"3 0 obj<>endobj\n"
+ b"xref\n0 4\n0000000000 65535 f\n"
+ b"trailer<>\nstartxref\n0\n%%EOF"
+ )
+
+ mock_fields = {
+ "JobTitle": {"/T": "JobTitle", "/FT": "/Tx"},
+ "Department": {"/T": "Department", "/FT": "/Tx"},
+ }
+
+ with patch("commonforms.prepare_form"), \
+ patch("pypdf.PdfReader") as mock_reader, \
+ patch("shutil.copyfileobj"), \
+ patch("builtins.open", MagicMock()), \
+ patch("os.path.exists", return_value=True), \
+ patch("os.remove"):
+
+ mock_reader.return_value.get_fields.return_value = mock_fields
+
+ response = client.post(
+ "/templates/create",
+ files={"file": ("form.pdf", io.BytesIO(pdf_bytes), "application/pdf")},
+ data={"name": "Vaccine Form"},
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["name"] == "Vaccine Form"
+ assert "id" in data
+ assert "fields" in data
+
+ def test_create_template_without_file_returns_422(self, client):
+ """Missing file field returns 422 Unprocessable Entity."""
+ response = client.post(
+ "/templates/create",
+ data={"name": "No File"},
+ )
+ assert response.status_code == 422
+
+ def test_create_template_non_pdf_returns_400(self, client):
+ """Uploading a non-PDF returns 400."""
+ with patch("shutil.copyfileobj"), \
+ patch("builtins.open", MagicMock()):
+ response = client.post(
+ "/templates/create",
+ files={"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")},
+ data={"name": "CSV attempt"},
+ )
+ assert response.status_code == 400
+
+
+# ── GET /templates ────────────────────────────────────────────────────────────
+
+class TestListTemplates:
+
+ def test_list_templates_returns_200(self, client):
+ """GET /templates returns 200."""
+ response = client.get("/templates")
+ assert response.status_code == 200
+
+ def test_list_templates_returns_list(self, client):
+ """Response is always a list."""
+ response = client.get("/templates")
+ assert isinstance(response.json(), list)
+
+ def test_list_templates_empty_on_fresh_db(self, client):
+ """Fresh DB returns empty list."""
+ response = client.get("/templates")
+ assert response.json() == []
+
+ def test_list_templates_pagination_accepted(self, client):
+ """Pagination params accepted without error."""
+ response = client.get("/templates?limit=5&offset=0")
+ assert response.status_code == 200
+
+
+# ── GET /templates/{template_id} ──────────────────────────────────────────────
+
+class TestGetTemplate:
+
+    def test_get_template_not_found(self, client):
+        """Returns 404 for non-existent ID."""
+        response = client.get("/templates/999999")
+        assert response.status_code == 404
+
+    def test_get_template_invalid_id_type(self, client):
+        """Returns 422 for non-integer ID."""
+        response = client.get("/templates/not-an-id")
+        assert response.status_code == 422
+
+    def test_get_template_by_id(self, client, db_session):
+        """Returns correct template for valid ID."""
+        t = Template(
+            name="Cal Fire Form",
+            fields={"officer_name": "Officer Name"},
+            pdf_path="/tmp/cal_fire.pdf",
+            # datetime.utcnow() is deprecated since Python 3.12;
+            # use an explicitly timezone-aware timestamp instead.
+            created_at=datetime.now(timezone.utc),
+        )
+        db_session.add(t)
+        db_session.commit()
+        db_session.refresh(t)
+
+        response = client.get(f"/templates/{t.id}")
+        assert response.status_code == 200
+        assert response.json()["name"] == "Cal Fire Form"