Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ OPENROUTER_API_KEY="<your_openrouter_api_key>"
API_HOST=0.0.0.0
API_PORT=8000
API_RATE_LIMIT=60
# Optional shared header auth for internal callers; leave empty to keep API open.
OST_LINKER_SERVICE_TOKEN=

# --- dbt ---
# Target profile: "local" (port 5433, default) or "docker" (port 5432, host "db").
Expand Down
19 changes: 19 additions & 0 deletions src/services/api/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
import secrets

from fastapi import Header, HTTPException


def require_service_token(
x_service_token: str | None = Header(default=None),
) -> None:
"""Require X-Service-Token when service-token auth is enabled."""
expected = os.environ.get("OST_LINKER_SERVICE_TOKEN")
if not expected:
return

if not secrets.compare_digest(x_service_token or "", expected):
raise HTTPException(
status_code=401,
detail="Invalid or missing service token",
)
1 change: 1 addition & 0 deletions src/services/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ class APIConfig(BaseSettings):
host: str = Field(default="0.0.0.0", alias="API_HOST")
port: int = Field(default=8000, alias="API_PORT")
rate_limit: int = Field(default=60, alias="API_RATE_LIMIT")
service_token: str | None = Field(default=None, alias="OST_LINKER_SERVICE_TOKEN")

model_config = {"populate_by_name": True}
18 changes: 14 additions & 4 deletions src/services/api/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request, Response
from fastapi import Depends, FastAPI, Request, Response
from fastapi.responses import JSONResponse
from slowapi.errors import RateLimitExceeded

from src.services.api.auth import require_service_token
from src.services.api.config import APIConfig
from src.services.api.dependencies import close_pool, init_pool
from src.services.api.rate_limit import limiter
Expand Down Expand Up @@ -48,6 +49,15 @@ def _rate_limit_handler(request: Request, exc: RateLimitExceeded) -> Response:
# backend, not by browsers. Add CORSMiddleware if browser access is needed later.

app.include_router(health.router)
app.include_router(references.router)
app.include_router(projects.router)
app.include_router(recommendations.router)
app.include_router(
references.router,
dependencies=[Depends(require_service_token)],
)
app.include_router(
projects.router,
dependencies=[Depends(require_service_token)],
)
app.include_router(
recommendations.router,
dependencies=[Depends(require_service_token)],
)
144 changes: 144 additions & 0 deletions tests/api/test_service_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from unittest.mock import MagicMock

import pytest
from fastapi.testclient import TestClient

from src.services.api.dependencies import get_pool
from src.services.api.main import app

pytestmark = pytest.mark.api


def _make_pool(
*,
fetchall_rows: list[dict] | None = None,
fetchone_rows: list[dict | None] | None = None,
) -> MagicMock:
"""Create a mock pool whose cursor returns the given rows."""
mock_cursor = MagicMock()
mock_cursor.fetchall.return_value = fetchall_rows or []
if fetchone_rows is not None:
mock_cursor.fetchone.side_effect = fetchone_rows
else:
mock_cursor.fetchone.return_value = None

mock_pool = MagicMock()
mock_pool.get_cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_pool.get_cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_pool


class TestServiceTokenOpen:
def test_open_mode_allows_requests_without_header(
self, client: TestClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Protected endpoints stay open when no service token is configured."""
monkeypatch.delenv("OST_LINKER_SERVICE_TOKEN", raising=False)
pool = _make_pool(
fetchall_rows=[
{
"id": "1",
"title": "React App",
"description": "A react app",
"repo_url": "https://github.com/org/react-app",
"published": True,
"trending": False,
"logo_url": None,
}
]
)
app.dependency_overrides[get_pool] = lambda: pool
try:
response = client.get("/projects/search?q=foo")
finally:
app.dependency_overrides.pop(get_pool, None)

assert response.status_code == 200


class TestServiceTokenEnforced:
@pytest.mark.parametrize(
"path",
[
"/projects/search?q=foo",
"/projects/550e8400-e29b-41d4-a716-446655440000",
"/projects/550e8400-e29b-41d4-a716-446655440000/similar",
"/recommendations/trending",
"/categories",
"/domains",
"/techstacks",
],
)
def test_missing_header(
self, client: TestClient, monkeypatch: pytest.MonkeyPatch, path: str
) -> None:
"""Protected endpoints return 401 without a service token header."""
monkeypatch.setenv("OST_LINKER_SERVICE_TOKEN", "expected-token")

response = client.get(path)

assert response.status_code == 401
assert response.json() == {"detail": "Invalid or missing service token"}

def test_mismatched_header(
self, client: TestClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Protected endpoints return 401 for the wrong service token."""
monkeypatch.setenv("OST_LINKER_SERVICE_TOKEN", "expected-token")

response = client.get(
"/projects/search?q=foo",
headers={"X-Service-Token": "wrong-token"},
)

assert response.status_code == 401
assert response.json() == {"detail": "Invalid or missing service token"}

def test_matching_header(
self, client: TestClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Protected endpoints allow requests with the matching service token."""
monkeypatch.setenv("OST_LINKER_SERVICE_TOKEN", "expected-token")
pool = _make_pool(
fetchall_rows=[
{
"id": "1",
"title": "React App",
"description": "A react app",
"repo_url": "https://github.com/org/react-app",
"published": True,
"trending": False,
"logo_url": None,
}
]
)
app.dependency_overrides[get_pool] = lambda: pool
try:
response = client.get(
"/projects/search?q=foo",
headers={"X-Service-Token": "expected-token"},
)
finally:
app.dependency_overrides.pop(get_pool, None)

assert response.status_code == 200


class TestHealthOpen:
@pytest.mark.parametrize("service_token", [None, "expected-token"])
def test_health_stays_open(
self,
client: TestClient,
monkeypatch: pytest.MonkeyPatch,
service_token: str | None,
) -> None:
"""Health remains open whether service-token auth is enabled or not."""
if service_token is None:
monkeypatch.delenv("OST_LINKER_SERVICE_TOKEN", raising=False)
else:
monkeypatch.setenv("OST_LINKER_SERVICE_TOKEN", service_token)

response = client.get("/health")

assert response.status_code == 200
assert response.json() == {"status": "ok"}
Loading