Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/extension_shield/api/auth_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Authentication identity helpers for API routes.

These helpers intentionally avoid importing heavy runtime dependencies so they
can be tested in isolation.
"""

from typing import Any


def get_user_id(request: Any) -> str:
    """
    Resolve the caller's identity from authenticated request state.

    Only ``request.state.user_id`` (the Supabase-authenticated JWT ``sub``)
    is consulted. Request headers are never read, because they are
    controlled by the client and cannot be trusted for identity.

    Returns the identifier converted to ``str``, or ``"anon"`` when the
    request has no ``state``, the state has no ``user_id``, or the stored
    value is falsy.
    """
    request_state = getattr(request, "state", None)
    # getattr on None simply yields the default, so a missing state is safe.
    authenticated_id = getattr(request_state, "user_id", None)
    return str(authenticated_id) if authenticated_id else "anon"


def can_view_private_scan(request_user_id: Any, scan_result: Any) -> bool:
    """
    Return whether the requester may view a scan result that may be private.

    Public scan results remain visible to everyone. Private scan results require
    the authenticated request user_id to match the stored owner user_id.

    Args:
        request_user_id: Identifier of the requester. ``None``, any other
            falsy value, and the ``"anon"`` placeholder all count as
            unauthenticated.
        scan_result: Stored scan record. Anything that is not a ``dict`` is
            rejected.

    Returns:
        ``True`` when the result is not marked ``"private"``, or when it is
        private and the authenticated requester is its owner; ``False``
        otherwise.
    """
    if not isinstance(scan_result, dict):
        return False

    if scan_result.get("visibility") != "private":
        return True

    owner_user_id = scan_result.get("user_id")
    if not owner_user_id or not request_user_id:
        return False

    requester = str(request_user_id)
    # "anon" is what get_user_id() returns for unauthenticated callers. It is
    # a placeholder, not an identity, so it must never satisfy an ownership
    # check -- even if a private record was stored with "anon" as its owner.
    if requester == "anon":
        return False

    return str(owner_user_id) == requester
38 changes: 17 additions & 21 deletions src/extension_shield/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import base64
import hmac
import mimetypes
import os
from pathlib import Path
Expand Down Expand Up @@ -43,6 +44,10 @@
from extension_shield.workflow.state import WorkflowState, WorkflowStatus
from extension_shield.api.database import db, SupabaseDatabase, _is_extension_id
from extension_shield.api.supabase_auth import get_current_user_id as _get_current_user_id
from extension_shield.api.auth_identity import (
can_view_private_scan,
get_user_id as _get_user_id,
)
from extension_shield.core.config import get_settings
from extension_shield.utils.mode import require_cloud, get_feature_flags, is_oss_telemetry_allowed, require_cloud_dep
from extension_shield.api.csp_middleware import CSPMiddleware
Expand Down Expand Up @@ -385,25 +390,6 @@ async def add_security_headers(request: Request, call_next):
deep_scan_usage: Dict[str, Dict[str, int]] = {}


def _get_user_id(request: Request) -> str:
    """
    Best-effort user identifier.

    Use only the Supabase-authenticated user_id (JWT `sub`) stored on
    `request.state`. The `X-User-Id` header is deliberately NOT consulted:
    it is client-controlled, so honoring it would let any caller claim to be
    another user.
    No IP-based fallback (privacy-first).

    Returns the authenticated identifier as `str`, or "anon" when the request
    carries no authenticated user_id.
    """
    state_user = getattr(getattr(request, "state", None), "user_id", None)
    if state_user:
        return str(state_user)

    return "anon"


def _get_client_ip(request: Request) -> str:
"""
Get the client's IP address for rate limiting anonymous users.
Expand Down Expand Up @@ -469,7 +455,7 @@ def _require_admin_key(request: Request) -> None:
detail="X-Admin-Key header is required"
)

if provided_key != admin_key:
if not hmac.compare_digest(str(provided_key), str(admin_key)):
raise HTTPException(
status_code=403,
detail="Invalid admin API key"
Expand All @@ -496,7 +482,11 @@ def _require_admin_or_telemetry_key(request: Request) -> None:
status_code=403,
detail="X-Admin-Key header is required"
)
valid = (admin_key and provided == admin_key) or (telemetry_key and provided == telemetry_key)
valid = (
bool(admin_key) and hmac.compare_digest(str(provided), str(admin_key))
) or (
bool(telemetry_key) and hmac.compare_digest(str(provided), str(telemetry_key))
)
if not valid:
raise HTTPException(
status_code=403,
Expand Down Expand Up @@ -3038,6 +3028,9 @@ async def get_file_list(extension_id: str, http_request: Request) -> FileListRes
if not results:
raise HTTPException(status_code=404, detail="Extension not found")

if not can_view_private_scan(user_id, results):
raise HTTPException(status_code=404, detail="File not found")
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 404 raised on private-scan authorization failure uses detail="File not found", but this endpoint returns a file list (not a single file). Consider using a more accurate generic detail (e.g., "Not found" / "Scan results not found") to avoid confusing clients while still not leaking existence.

Suggested change
raise HTTPException(status_code=404, detail="File not found")
raise HTTPException(status_code=404, detail="Scan results not found")

Copilot uses AI. Check for mistakes.

extracted_path = results.get("extracted_path")
if not extracted_path or not os.path.exists(extracted_path):
raise HTTPException(status_code=404, detail="Extracted files not found")
Expand Down Expand Up @@ -3069,6 +3062,9 @@ async def get_file_content(extension_id: str, file_path: str, http_request: Requ
if not results:
raise HTTPException(status_code=404, detail="Extension not found")

if not can_view_private_scan(user_id, results):
raise HTTPException(status_code=404, detail="File not found")
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Authorization failure for private scans returns a 404 with detail="File not found". Since the failure is about access to the scan rather than the specific path, consider using the same generic 404 detail used by /api/scan/results/{identifier} (e.g., "Scan results not found") for consistency while still avoiding existence leaks.

Suggested change
raise HTTPException(status_code=404, detail="File not found")
raise HTTPException(status_code=404, detail="Scan results not found")

Copilot uses AI. Check for mistakes.

extracted_path = results.get("extracted_path")
if not extracted_path:
raise HTTPException(status_code=404, detail="Extracted files not found")
Expand Down
58 changes: 58 additions & 0 deletions tests/api/test_admin_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,32 @@ def test_delete_scan_with_correct_admin_key_succeeds(self, client, admin_key):
# Should succeed (200 or 204 depending on implementation)
assert response.status_code in [200, 204, 404] # 404 if extension not found in DB

def test_delete_scan_uses_constant_time_key_compare(self, client, admin_key):
    """DELETE must route admin-key validation through ``hmac.compare_digest``.

    ``compare_digest`` is patched to always return False, so a request that
    sends the configured key is only rejected with 403 if the endpoint
    actually delegates the comparison to it.
    """
    from unittest.mock import MagicMock

    extension_id = "test-ext-123"
    fake_settings = MagicMock(admin_api_key=admin_key, telemetry_admin_key=None)
    cloud_flags = MagicMock(mode="cloud", telemetry_enabled=True)

    with patch("extension_shield.api.main.get_settings", return_value=fake_settings), \
            patch("extension_shield.utils.mode.get_feature_flags", return_value=cloud_flags), \
            patch("extension_shield.api.main.hmac.compare_digest", return_value=False) as digest_spy:
        # Seed a completed scan so the request targets an existing record.
        scan_results[extension_id] = {"extension_id": extension_id, "status": "completed"}

        response = client.delete(
            f"/api/scan/{extension_id}",
            headers={"X-Admin-Key": admin_key},
        )

    assert response.status_code == 403
    digest_spy.assert_called_once()

def test_delete_scan_without_configured_admin_key_returns_403(self, client):
"""DELETE when admin key is not configured should return 403."""
with patch("extension_shield.api.main.get_settings") as mock_get_settings, \
Expand Down Expand Up @@ -266,6 +292,38 @@ def test_telemetry_summary_with_telemetry_key_succeeds(self, client, admin_key,

assert response.status_code == 200

def test_telemetry_summary_uses_constant_time_key_compare(self, client, admin_key, telemetry_key):
    """GET must check both the admin and telemetry keys via ``hmac.compare_digest``.

    The patched ``compare_digest`` answers False on its first call and True
    on its second, so a 200 response together with exactly two recorded calls
    shows that both key comparisons went through it.
    """
    from unittest.mock import MagicMock

    fake_settings = MagicMock(admin_api_key=admin_key, telemetry_admin_key=telemetry_key)
    cloud_flags = MagicMock(mode="cloud", telemetry_enabled=True)
    empty_summary = {
        "days": 14,
        "start_day": None,
        "end_day": None,
        "by_day": {},
        "by_path": {},
        "rows": [],
    }

    with patch("extension_shield.api.main.get_settings", return_value=fake_settings), \
            patch("extension_shield.utils.mode.get_feature_flags", return_value=cloud_flags), \
            patch("extension_shield.api.main.hmac.compare_digest", side_effect=[False, True]) as digest_spy, \
            patch("extension_shield.api.main.db") as db_stub:
        db_stub.get_page_view_summary.return_value = empty_summary
        response = client.get(
            "/api/telemetry/summary",
            headers={"X-Admin-Key": telemetry_key},
        )

    assert response.status_code == 200
    assert digest_spy.call_count == 2

def test_telemetry_summary_falls_back_to_admin_key(self, client, admin_key):
"""GET should fallback to ADMIN_API_KEY when TELEMETRY_ADMIN_KEY is not set."""
with patch("extension_shield.api.main.get_settings") as mock_get_settings, \
Expand Down
58 changes: 58 additions & 0 deletions tests/api/test_user_identity_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Tests for API user identity extraction helper.

Security regression coverage:
- Identity must come from authenticated request state only.
- X-User-Id header must never be trusted.
"""
Comment on lines +1 to +7
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description lists 3 passed, but this new test module defines 6 test_... functions. Please update the PR description’s test result (or clarify what subset was run) so reviewers can accurately assess coverage.

Copilot uses AI. Check for mistakes.

from types import SimpleNamespace

from extension_shield.api.auth_identity import can_view_private_scan, get_user_id


def test_get_user_id_prefers_authenticated_state_user_id():
    """The authenticated state value wins even when a spoofed header is present."""
    authenticated_state = SimpleNamespace(user_id="real-user-123")
    spoofed_headers = {"X-User-Id": "attacker-id"}
    request = SimpleNamespace(headers=spoofed_headers, state=authenticated_state)

    resolved = get_user_id(request)

    assert resolved == "real-user-123"


def test_get_user_id_ignores_x_user_id_header_when_not_authenticated():
    """A spoofed ``X-User-Id`` header must not become the caller's identity."""
    unauthenticated_state = SimpleNamespace(user_id=None)
    spoofed_headers = {"X-User-Id": "attacker-id"}
    request = SimpleNamespace(headers=spoofed_headers, state=unauthenticated_state)

    resolved = get_user_id(request)

    assert resolved == "anon"


def test_get_user_id_returns_anon_without_authenticated_user():
    """With no authenticated user and no headers the helper falls back to "anon"."""
    unauthenticated_state = SimpleNamespace(user_id=None)
    request = SimpleNamespace(headers={}, state=unauthenticated_state)

    resolved = get_user_id(request)

    assert resolved == "anon"


def test_can_view_private_scan_allows_public_results():
    """Public results are visible to anonymous and arbitrary requesters alike."""
    public_scan = {"user_id": "owner-123", "visibility": "public"}

    for requester in (None, "any-user"):
        assert can_view_private_scan(requester, public_scan)


def test_can_view_private_scan_blocks_non_owner_for_private_result():
    """Neither an anonymous requester nor a different user may see a private result."""
    private_scan = {"user_id": "owner-123", "visibility": "private"}

    for requester in (None, "other-user"):
        assert not can_view_private_scan(requester, private_scan)


def test_can_view_private_scan_allows_owner_for_private_result():
    """The stored owner of a private result is allowed to view it."""
    owner_id = "owner-123"
    private_scan = {"user_id": owner_id, "visibility": "private"}

    allowed = can_view_private_scan(owner_id, private_scan)

    assert allowed
Loading