Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion app/api/analyze.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""POST /analyze — extract drugs from OCR text."""

import re

from fastapi import APIRouter

from app.api.schemas import AnalyzeDataSources, AnalyzeRequest, AnalyzeResponse, DrugResult
Expand All @@ -8,12 +10,14 @@

# Router that the /analyze endpoint in this module is registered on.
router = APIRouter()

# Matches anything tag-shaped: "<", one or more characters other than ">", then ">".
# Only the tag itself matches; text sitting between an opening and a closing tag does not.
_HTML_TAG = re.compile(r"<[^>]+>")


@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze(request: AnalyzeRequest):
# Handle POST /analyze: run drug extraction on the submitted text and return the
# results together with the echoed text and the NER model id.
# Extraction is given request.text exactly as received (before any tag stripping).
drugs = await drug_analyzer.analyze(request.text)
return AnalyzeResponse(
# Each entry of `drugs` is unpacked as keyword arguments into a DrugResult.
drugs=[DrugResult(**d) for d in drugs],
# NOTE(review): two raw_text arguments appear in this capture; this looks like the
# removed and added lines of a diff shown side by side. Presumably only the
# sanitised one below remains in the real file; confirm (Python rejects a repeated keyword).
raw_text=request.text,
# Echo the input with every _HTML_TAG match removed. Only the tags are dropped;
# whatever text sat between them is kept.
raw_text=_HTML_TAG.sub("", request.text),
data_sources=AnalyzeDataSources(ner_model=ner_model.MODEL_ID),
)
10 changes: 7 additions & 3 deletions app/api/schemas.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Pydantic request/response models for the PillChecker API."""

from pydantic import BaseModel, Field
from typing import Annotated

from pydantic import BaseModel, Field, StringConstraints


# --- POST /analyze ---

# Request body for POST /analyze: a single required, non-empty `text` field.
class AnalyzeRequest(BaseModel):
# NOTE(review): `text` is declared twice in this capture; this looks like the removed
# and added lines of a diff shown side by side. If both were kept, the later
# declaration would be the one in effect. Confirm only one remains in the real file.
text: str = Field(..., min_length=1, examples=["BRUFEN Ibuprofen 400 mg Film-Coated Tablets"])
# Same field with an upper bound of 5000 characters added.
text: str = Field(..., min_length=1, max_length=5000, examples=["BRUFEN Ibuprofen 400 mg Film-Coated Tablets"])


class DrugResult(BaseModel):
Expand All @@ -32,7 +34,9 @@ class AnalyzeResponse(BaseModel):
# --- POST /interactions ---

# Request body for POST /interactions: a required `drugs` list with min_length=2.
class InteractionsRequest(BaseModel):
# NOTE(review): `drugs` is declared twice in this capture; this looks like the removed
# and added lines of a diff shown side by side. If both were kept, the later
# declaration would be the one in effect. Confirm only one remains in the real file.
drugs: list[str] = Field(..., min_length=2, examples=[["ibuprofen", "warfarin"]])
# Per-item constraints: strip_whitespace is enabled and each name is limited to
# 1-200 characters. The tests accompanying this change expect empty,
# whitespace-only and 201-character names to be rejected with a 422.
drugs: list[Annotated[str, StringConstraints(min_length=1, max_length=200, strip_whitespace=True)]] = Field(
..., min_length=2, examples=[["ibuprofen", "warfarin"]]
)


class DrugRef(BaseModel):
Expand Down
13 changes: 12 additions & 1 deletion app/services/drug_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import logging
import re

from app.clients import rxnorm_client
from app.middleware.audit_log import get_audit_context
Expand All @@ -17,6 +18,14 @@

logger = logging.getLogger(__name__)

_PUNCTUATION_ONLY = re.compile(r"^[^\w]+$")


def _is_valid_entity_name(name: str) -> bool:
"""Reject empty, single-char, or punctuation-only entity names."""
stripped = name.strip()
return len(stripped) > 1 and not _PUNCTUATION_ONLY.match(stripped)


async def analyze(text: str) -> list[dict]:
"""Analyze OCR text and return enriched drug profiles.
Expand Down Expand Up @@ -48,7 +57,9 @@ async def analyze(text: str) -> list[dict]:

drug_entities = [
e for e in entities
if e.label in ("CHEM", "Chemical", "CHEMICAL") and not e.text.isdigit()
if e.label in ("CHEM", "Chemical", "CHEMICAL")
and not e.text.isdigit()
and _is_valid_entity_name(e.text)
]

if drug_entities:
Expand Down
53 changes: 53 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,59 @@ def client(mock_drugbank, mock_severity, mock_severity_parser):
return TestClient(app)


class TestAnalyzeValidation:
    """Input-validation and output-sanitisation checks for POST /analyze."""

    def test_analyze_rejects_oversized_text(self, client):
        """A body whose text exceeds the 5000-character cap yields a 422."""
        oversized = "Metformin 500mg " * 500
        response = client.post(
            "/analyze",
            json={"text": oversized},
            headers={"X-API-Key": "test-key"},
        )
        assert response.status_code == 422

    def test_analyze_strips_html_from_raw_text(self, client):
        """Tags are removed from the echoed raw_text; the text between them stays."""
        payload = {"text": '<script>alert(1)</script>Metformin 500mg'}
        # Stub the analyzer so the check only exercises the endpoint's echo handling.
        analyzer_stub = AsyncMock(return_value=[])
        with patch("app.services.drug_analyzer.analyze", new=analyzer_stub):
            response = client.post(
                "/analyze",
                json=payload,
                headers={"X-API-Key": "test-key"},
            )
        assert response.status_code == 200
        echoed = response.json()["raw_text"]
        assert "<script>" not in echoed
        assert "alert(1)" in echoed


class TestInteractionsValidation:
    """Per-item validation checks for the `drugs` list of POST /interactions."""

    @staticmethod
    def _post_drugs(client, drugs):
        """POST *drugs* to /interactions with the test API key and return the response."""
        return client.post(
            "/interactions",
            json={"drugs": drugs},
            headers={"X-API-Key": "test-key"},
        )

    def test_interactions_rejects_empty_string_drug(self, client):
        """An empty string anywhere in the list yields a 422."""
        response = self._post_drugs(client, ["metformin", "", "lisinopril"])
        assert response.status_code == 422

    def test_interactions_rejects_whitespace_only_drug(self, client):
        """A name that is only whitespace yields a 422 once stripped."""
        response = self._post_drugs(client, [" ", "metformin"])
        assert response.status_code == 422

    def test_interactions_rejects_long_drug_name(self, client):
        """A name longer than 200 characters yields a 422."""
        too_long = "a" * 201
        response = self._post_drugs(client, [too_long, "metformin"])
        assert response.status_code == 422


class TestInteractionsEndpoint:
def test_known_interaction(self, client, mock_drugbank):
mock_drugbank.get_interactions.side_effect = [
Expand Down
42 changes: 42 additions & 0 deletions tests/test_drug_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,48 @@ async def test_low_confidence_ner_needs_confirmation():
assert results[0]["needs_confirmation"] is True


@pytest.mark.asyncio
async def test_single_char_ner_entity_filtered():
    """A one-character NER hit such as '-' is dropped; the real drug survives."""
    dash = ner_model.Entity(text="-", label="CHEM", score=0.95, start=0, end=1)
    metformin = ner_model.Entity(text="Metformin", label="CHEM", score=0.93, start=5, end=14)
    rxcui_stub = AsyncMock(return_value="6809")

    with (
        patch("app.services.drug_analyzer.ner_model.predict", return_value=[dash, metformin]),
        patch("app.services.drug_analyzer.rxnorm_client.get_rxcui", new=rxcui_stub),
    ):
        results = await drug_analyzer.analyze("- Metformin 500mg")

    assert len(results) == 1
    survivor = results[0]
    assert survivor["name"] == "Metformin"


@pytest.mark.asyncio
async def test_punctuation_only_ner_entity_filtered():
    """A NER hit consisting solely of punctuation is dropped; the real drug survives."""
    dots = ner_model.Entity(text="...", label="CHEM", score=0.90, start=0, end=3)
    lisinopril = ner_model.Entity(text="Lisinopril", label="CHEM", score=0.92, start=5, end=15)
    rxcui_stub = AsyncMock(return_value="29046")

    with (
        patch("app.services.drug_analyzer.ner_model.predict", return_value=[dots, lisinopril]),
        patch("app.services.drug_analyzer.rxnorm_client.get_rxcui", new=rxcui_stub),
    ):
        results = await drug_analyzer.analyze("... Lisinopril 10mg")

    assert len(results) == 1
    survivor = results[0]
    assert survivor["name"] == "Lisinopril"


@pytest.mark.asyncio
async def test_ner_results_sorted_by_confidence_descending():
"""Results must be sorted by confidence, highest first."""
Expand Down