diff --git a/src/extension_shield/api/main.py b/src/extension_shield/api/main.py index 3ebbc4c..5f48608 100644 --- a/src/extension_shield/api/main.py +++ b/src/extension_shield/api/main.py @@ -62,6 +62,14 @@ # Initialize logger logger = logging.getLogger(__name__) + +def _parse_trusted_proxy_hosts() -> list[str]: + """Return the explicit proxy hosts allowed to send forwarded headers.""" + raw_hosts = os.getenv("TRUSTED_PROXY_HOSTS", "").strip() + if raw_hosts: + return [host.strip() for host in raw_hosts.split(",") if host.strip()] + return ["127.0.0.1", "localhost", "::1"] + # Import safe JSON utilities from shared module from extension_shield.utils.json_encoder import ( safe_json_dumps, @@ -361,8 +369,9 @@ async def add_security_headers(request: Request, call_next): print(f"✅ CSP: Production mode detected (STATIC_DIR={STATIC_DIR}, index.html exists)") app.add_middleware(CSPMiddleware, is_dev=_is_dev) -# Trust X-Forwarded-Proto / X-Forwarded-For from Railway/Cloudflare so request.url.scheme is correct -app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") +# Trust forwarded headers only from explicitly allowed proxy hosts. +# Set TRUSTED_PROXY_HOSTS to your actual reverse proxy / CDN hop(s). +app.add_middleware(ProxyHeadersMiddleware, trusted_hosts=_parse_trusted_proxy_hosts()) # In-memory state lives in shared.py; import references here so existing # code in this file (and tests) can continue using module-level names. @@ -408,20 +417,9 @@ def _get_client_ip(request: Request) -> str: """ Get the client's IP address for rate limiting anonymous users. - Handles proxied requests via X-Forwarded-For and X-Real-IP headers. - Falls back to client host if no headers present. + Relies on ProxyHeadersMiddleware to rewrite request.client only when the + request came from a trusted proxy host. """ - # Check X-Forwarded-For header (from reverse proxy/load balancer) - x_forwarded_for = request.headers.get("x-forwarded-for") - if x_forwarded_for: - # Take the first IP (original client) - return x_forwarded_for.split(",")[0].strip() - - # Check X-Real-IP header (from nginx) - x_real_ip = request.headers.get("x-real-ip") - if x_real_ip: - return x_real_ip.strip() - # Fall back to direct client IP if request.client: return request.client.host @@ -504,6 +502,27 @@ def _require_admin_or_telemetry_key(request: Request) -> None: ) +def _require_private_scan_artifact_access( + request: Request, + extension_id: str, + payload: Optional[Dict[str, Any]] = None, +) -> None: + """Block access to private scan artifacts unless the requester owns the scan.""" + requester_id = getattr(getattr(request, "state", None), "user_id", None) + if isinstance(payload, dict): + is_private = payload.get("visibility") == "private" or payload.get("source") == "upload" + owner_id = payload.get("user_id") or scan_user_ids.get(extension_id) + else: + is_private = scan_source.get(extension_id) == "upload" + owner_id = scan_user_ids.get(extension_id) + + if not is_private: + return + + if not requester_id or not owner_id or requester_id != owner_id: + raise HTTPException(status_code=404, detail="Scan results not found") + + def _deep_scan_limit_status(rate_limit_key: str) -> Dict[str, Any]: """Get deep scan limit status. Returns unlimited in local/dev environments. Anonymous (IP-based) users get 1 scan per day; authenticated users get 3. @@ -1086,6 +1105,38 @@ def _extension_icon_file_response(icon_file_path: str) -> FileResponse: ) +def _resolve_extracted_root_path(extracted_path: Optional[str]) -> Optional[Path]: + """Resolve extracted root path safely, including relative storage paths.""" + if not extracted_path: + return None + root = Path(extracted_path) + if not root.is_absolute(): + root = Path(get_settings().extension_storage_path) / root + try: + root = root.resolve(strict=True) + except Exception: + return None + return root if root.is_dir() else None + + +def _resolve_icon_candidate_path(extracted_root: Path, icon_path: str) -> Optional[Path]: + """Resolve icon candidate and ensure it remains within extracted root.""" + candidate = Path(icon_path) + if not candidate.is_absolute(): + candidate = extracted_root / candidate + try: + resolved = candidate.resolve(strict=True) + except Exception: + return None + try: + in_root = os.path.commonpath([str(extracted_root), str(resolved)]) == str(extracted_root) + except Exception: + in_root = False + if not in_root or not resolved.is_file(): + return None + return resolved + + def _load_icon_record_from_db(extension_id: str) -> Dict[str, Optional[str]]: """ Load icon-related fields for an extension from DB. @@ -1138,21 +1189,16 @@ def _extract_icon_blob_for_storage( if not icon_path or not extracted_path: return None, None try: - abs_extracted_path = os.path.abspath(extracted_path) - candidate_path = ( - os.path.abspath(icon_path) - if os.path.isabs(icon_path) - else os.path.abspath(os.path.join(extracted_path, icon_path)) - ) + extracted_root = _resolve_extracted_root_path(extracted_path) + if not extracted_root: + return None, None - # Security check: icon must stay inside extracted extension dir. - if os.path.commonpath([abs_extracted_path, candidate_path]) != abs_extracted_path: + candidate = _resolve_icon_candidate_path(extracted_root, icon_path) + if not candidate: logger.warning("[ICON] Refusing out-of-bounds icon path for persistence: %s", icon_path) return None, None - if not os.path.isfile(candidate_path): - return None, None - with open(candidate_path, "rb") as icon_file: + with open(candidate, "rb") as icon_file: icon_bytes = icon_file.read() if not icon_bytes: return None, None @@ -1160,12 +1206,12 @@ def _extract_icon_blob_for_storage( logger.warning( "[ICON] Skipping icon persistence for oversized icon (%s bytes): %s", len(icon_bytes), - candidate_path, + candidate, ) return None, None icon_b64 = base64.b64encode(icon_bytes).decode("ascii") - guessed_media_type, _ = mimetypes.guess_type(candidate_path) + guessed_media_type, _ = mimetypes.guess_type(str(candidate)) media_type = _normalize_image_media_type(guessed_media_type) return icon_b64, media_type except Exception as exc: @@ -2888,7 +2934,7 @@ async def batch_scan_status(req: BatchStatusRequest, request: Request): @app.get("/api/scan/enforcement_bundle/{extension_id}") -async def get_enforcement_bundle(extension_id: str): +async def get_enforcement_bundle(extension_id: str, http_request: Request): """ Get the governance enforcement bundle for an analyzed extension. @@ -2926,6 +2972,8 @@ async def get_enforcement_bundle(extension_id: str): if not results: raise HTTPException(status_code=404, detail="Scan results not found") + + _require_private_scan_artifact_access(http_request, extension_id, results) # Check if governance analysis was run governance_bundle = results.get("governance_bundle") @@ -2959,7 +3007,7 @@ async def get_enforcement_bundle(extension_id: str): @app.get("/api/scan/report/{extension_id}") -async def generate_pdf_report(extension_id: str) -> Response: +async def generate_pdf_report(extension_id: str, http_request: Request) -> Response: """ Generate a PDF security report for an analyzed extension. @@ -2989,6 +3037,8 @@ async def generate_pdf_report(extension_id: str) -> Response: if not results: raise HTTPException(status_code=404, detail="Scan results not found") + _require_private_scan_artifact_access(http_request, extension_id, results) + # Generate PDF report try: report_generator = ReportGenerator() @@ -3038,6 +3088,8 @@ async def get_file_list(extension_id: str, http_request: Request) -> FileListRes if not results: raise HTTPException(status_code=404, detail="Extension not found") + _require_private_scan_artifact_access(http_request, extension_id, results) + extracted_path = results.get("extracted_path") if not extracted_path or not os.path.exists(extracted_path): raise HTTPException(status_code=404, detail="Extracted files not found") @@ -3069,22 +3121,43 @@ async def get_file_content(extension_id: str, file_path: str, http_request: Requ if not results: raise HTTPException(status_code=404, detail="Extension not found") + _require_private_scan_artifact_access(http_request, extension_id, results) + extracted_path = results.get("extracted_path") if not extracted_path: raise HTTPException(status_code=404, detail="Extracted files not found") - # Construct full file path - full_path = os.path.join(extracted_path, file_path) + # Resolve extracted root path (handles relative paths stored in DB). + extracted_root = Path(extracted_path) + if not extracted_root.is_absolute(): + extracted_root = Path(get_settings().extension_storage_path) / extracted_root + + try: + extracted_root = extracted_root.resolve(strict=True) + except Exception: + raise HTTPException(status_code=404, detail="Extracted files not found") + + if not extracted_root.is_dir(): + raise HTTPException(status_code=404, detail="Extracted files not found") + + # Resolve requested path and ensure it stays inside extracted root. + try: + candidate_path = (extracted_root / file_path).resolve(strict=True) + except Exception: + raise HTTPException(status_code=404, detail="File not found") - # Security check: ensure path is within extracted directory - if not os.path.abspath(full_path).startswith(os.path.abspath(extracted_path)): + try: + in_root = os.path.commonpath([str(extracted_root), str(candidate_path)]) == str(extracted_root) + except Exception: + in_root = False + if not in_root: raise HTTPException(status_code=403, detail="Access denied") - if not os.path.exists(full_path): + if not candidate_path.is_file(): raise HTTPException(status_code=404, detail="File not found") try: - with open(full_path, "r", encoding="utf-8") as f: + with open(candidate_path, "r", encoding="utf-8") as f: content = f.read() return FileContentResponse(content=content, file_path=file_path) except UnicodeDecodeError as exc: @@ -3824,7 +3897,7 @@ async def database_health_check(request: Request): @app.get("/api/scan/icon/{extension_id}") -async def get_extension_icon(extension_id: str): +async def get_extension_icon(extension_id: str, http_request: Request): """ Get extension icon from the extracted extension folder. Uses icon_path from storage when available, and falls back to persisted icon bytes. @@ -3850,6 +3923,9 @@ async def get_extension_icon(extension_id: str): icon_media_type = results.get("icon_media_type") else: db_icon_record = _load_icon_record_from_db(extension_id) + results = db.get_scan_result(extension_id) + if results: + scan_results[extension_id] = results extracted_path = db_icon_record.get("extracted_path") icon_path = db_icon_record.get("icon_path") icon_base64 = db_icon_record.get("icon_base64") @@ -3893,6 +3969,7 @@ async def get_extension_icon(extension_id: str): # Best practice: if we have a persisted icon blob, serve it immediately. # This avoids relying on filesystem state (ephemeral/persistent) and prevents slow fallbacks. + _require_private_scan_artifact_access(http_request, extension_id, results) persisted = _extension_icon_response_from_base64(icon_base64, icon_media_type) if persisted: return persisted @@ -3934,30 +4011,19 @@ def _persisted_icon_response() -> Optional[Response]: logger.debug(f"[ICON] No extracted_path for {extension_id}, returning placeholder") return _extension_icon_placeholder_response() - # Convert to absolute path if it's relative - # extracted_path is relative to extension_storage_path, not RESULTS_DIR - if not os.path.isabs(extracted_path): - settings = get_settings() - storage_path = Path(settings.extension_storage_path) - # If extracted_path is just a directory name, join with storage_path - if os.path.basename(extracted_path) == extracted_path: - extracted_path = os.path.join(str(storage_path), extracted_path) - else: - # Already has path components, resolve relative to storage_path - extracted_path = os.path.join(str(storage_path), extracted_path) - - # Verify the path exists - if not os.path.exists(extracted_path): + extracted_root = _resolve_extracted_root_path(extracted_path) + if not extracted_root: logger.warning(f"Extracted path does not exist: {extracted_path}") # Try alternative: search in storage_path for matching directory settings = get_settings() storage_path = Path(settings.extension_storage_path) if storage_path.exists(): # Look for directory matching the basename - basename = os.path.basename(extracted_path) + basename = os.path.basename(extracted_path) if extracted_path else "" for item in storage_path.iterdir(): if item.is_dir() and (item.name == basename or item.name.startswith(basename)): extracted_path = str(item) + extracted_root = _resolve_extracted_root_path(extracted_path) logger.debug(f"Found extracted extension at: {extracted_path}") break else: @@ -3974,26 +4040,30 @@ def _persisted_icon_response() -> Optional[Response]: return persisted_response logger.debug(f"[ICON] Storage path missing for {extension_id}, returning placeholder") return _extension_icon_placeholder_response() + + if not extracted_root: + persisted_response = _persisted_icon_response() + if persisted_response: + logger.debug("[ICON] Served persisted icon blob for %s", extension_id) + return persisted_response + logger.debug(f"[ICON] Extracted root unresolved for {extension_id}, returning placeholder") + return _extension_icon_placeholder_response() - logger.debug(f"[ICON] extracted_path={extracted_path}, icon_path={icon_path}") + logger.debug(f"[ICON] extracted_path={extracted_root}, icon_path={icon_path}") # First, try using icon_path from database if available if icon_path: - full_icon_path = os.path.join(extracted_path, icon_path) - # Security check: ensure icon_path is within extracted_path - abs_icon_path = os.path.abspath(full_icon_path) - abs_extracted_path = os.path.abspath(extracted_path) - - logger.debug(f"[ICON] Trying stored icon_path: {full_icon_path}") - if abs_icon_path.startswith(abs_extracted_path) and os.path.exists(full_icon_path): - logger.info(f"[ICON] Found icon using stored icon_path: {full_icon_path}") - return _extension_icon_file_response(full_icon_path) + candidate = _resolve_icon_candidate_path(extracted_root, icon_path) + logger.debug(f"[ICON] Trying stored icon_path: {icon_path}") + if candidate: + logger.info(f"[ICON] Found icon using stored icon_path: {candidate}") + return _extension_icon_file_response(str(candidate)) else: - logger.warning(f"[ICON] Stored icon_path {icon_path} not found at {full_icon_path}, falling back to search") + logger.warning(f"[ICON] Stored icon_path {icon_path} is invalid or out of bounds, falling back to search") # Fallback: Try common icon sizes in order of preference icon_sizes = ["128", "64", "48", "32", "16", "96", "256"] - icons_dir = os.path.join(extracted_path, "icons") + icons_dir = os.path.join(str(extracted_root), "icons") # First try icons directory if os.path.exists(icons_dir): @@ -4005,18 +4075,18 @@ def _persisted_icon_response() -> Optional[Response]: # Try root directory for size in icon_sizes: - test_icon_path = os.path.join(extracted_path, f"icon{size}.png") + test_icon_path = os.path.join(str(extracted_root), f"icon{size}.png") if os.path.exists(test_icon_path): logger.debug(f"Found icon at: {test_icon_path}") return _extension_icon_file_response(test_icon_path) - test_icon_path = os.path.join(extracted_path, f"{size}.png") + test_icon_path = os.path.join(str(extracted_root), f"{size}.png") if os.path.exists(test_icon_path): logger.debug(f"Found icon at: {test_icon_path}") return _extension_icon_file_response(test_icon_path) # Try images directory (common for many extensions) - images_dir = os.path.join(extracted_path, "images") + images_dir = os.path.join(str(extracted_root), "images") if os.path.exists(images_dir): # Look for icon files in images directory for icon_name in ["icon128.png", "icon.png", "icon64.png", "icon48.png", "icon32.png", "icon16.png", "logo.png"]: @@ -4026,7 +4096,7 @@ def _persisted_icon_response() -> Optional[Response]: return _extension_icon_file_response(icon_path) # Try checking manifest for icon paths - manifest_path = os.path.join(extracted_path, "manifest.json") + manifest_path = os.path.join(str(extracted_root), "manifest.json") if os.path.exists(manifest_path): try: with open(manifest_path, "r", encoding="utf-8") as f: @@ -4038,16 +4108,11 @@ def _persisted_icon_response() -> Optional[Response]: # Get the largest icon largest_size = max(manifest_icons.keys(), key=lambda x: int(x)) icon_rel_path = manifest_icons[largest_size] - manifest_icon_path = os.path.join(extracted_path, icon_rel_path) - - # Security check - abs_icon_path = os.path.abspath(manifest_icon_path) - abs_extracted_path = os.path.abspath(extracted_path) - - if abs_icon_path.startswith(abs_extracted_path): - if os.path.exists(manifest_icon_path): + if isinstance(icon_rel_path, str): + manifest_icon_path = _resolve_icon_candidate_path(extracted_root, icon_rel_path) + if manifest_icon_path: logger.debug(f"Found icon from manifest at: {manifest_icon_path}") - return _extension_icon_file_response(manifest_icon_path) + return _extension_icon_file_response(str(manifest_icon_path)) except Exception as e: logger.warning(f"Failed to read manifest for icons: {e}") diff --git a/tests/api/test_enforcement_bundle.py b/tests/api/test_enforcement_bundle.py index 0b1f90a..174e090 100644 --- a/tests/api/test_enforcement_bundle.py +++ b/tests/api/test_enforcement_bundle.py @@ -118,6 +118,38 @@ def test_get_enforcement_bundle_not_found(self, client): assert response.status_code == 404 assert "not found" in response.json()["detail"].lower() + + def test_private_scan_artifact_requires_owner(self, client, tmp_path): + """Private scan artifacts should not be readable without the owning user.""" + ext_id = "privateartifact1234567890123456" + extracted = tmp_path / ext_id + extracted.mkdir() + + scan_results[ext_id] = { + "extension_id": ext_id, + "extension_name": "Private Extension", + "status": "completed", + "visibility": "private", + "user_id": "owner-user-1", + "governance_bundle": {"decision": {"verdict": "ALLOW"}}, + "extracted_path": str(extracted), + } + + try: + endpoints = [ + f"/api/scan/enforcement_bundle/{ext_id}", + f"/api/scan/report/{ext_id}", + f"/api/scan/files/{ext_id}", + f"/api/scan/file/{ext_id}/manifest.json", + f"/api/scan/icon/{ext_id}", + ] + + for endpoint in endpoints: + response = client.get(endpoint) + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + finally: + scan_results.pop(ext_id, None) def test_get_enforcement_bundle_no_governance_data(self, client): """Test 404 when governance bundle not available.""" diff --git a/tests/api/test_scan_results_endpoint.py b/tests/api/test_scan_results_endpoint.py index 192e6a4..15233c8 100644 --- a/tests/api/test_scan_results_endpoint.py +++ b/tests/api/test_scan_results_endpoint.py @@ -95,3 +95,66 @@ def test_legacy_payload_is_upgraded_with_consumer_insights(self, client: TestCli assert isinstance(rvm["consumer_insights"], dict) +def test_scan_file_blocks_path_traversal(client: TestClient, tmp_path) -> None: + """/api/scan/file should block traversal outside extracted root.""" + extension_id = "pathtraversaltest1234567890123456" + + extracted_dir = tmp_path / "extracted" + extracted_dir.mkdir() + inside_file = extracted_dir / "manifest.json" + inside_file.write_text('{"name": "safe"}', encoding="utf-8") + + outside_file = tmp_path / "outside.txt" + outside_file.write_text("secret", encoding="utf-8") + + scan_results[extension_id] = { + "extension_id": extension_id, + "status": "completed", + "visibility": "public", + "extracted_path": str(extracted_dir), + } + + try: + ok = client.get(f"/api/scan/file/{extension_id}/manifest.json") + assert ok.status_code == 200 + assert "safe" in ok.json()["content"] + + blocked = client.get(f"/api/scan/file/{extension_id}/../outside.txt") + assert blocked.status_code == 403 + finally: + scan_results.pop(extension_id, None) + + +def test_scan_icon_blocks_manifest_path_escape(client: TestClient, tmp_path) -> None: + """/api/scan/icon should not serve out-of-root files via manifest icon path.""" + extension_id = "iconpathtest1234567890123456789012" + + extracted_dir = tmp_path / "icon-extracted" + extracted_dir.mkdir() + (extracted_dir / "manifest.json").write_text( + '{"icons": {"128": "../../outside-secret.png"}}', + encoding="utf-8", + ) + + outside_file = tmp_path / "outside-secret.png" + outside_file.write_bytes(b"OUTSIDE_SECRET_BYTES") + + scan_results[extension_id] = { + "extension_id": extension_id, + "status": "completed", + "visibility": "public", + "extracted_path": str(extracted_dir), + "icon_path": "../../outside-secret.png", + "icon_base64": None, + "icon_media_type": None, + } + + try: + response = client.get(f"/api/scan/icon/{extension_id}") + assert response.status_code == 200 + assert response.headers.get("X-Extension-Icon-Source") != "filesystem" + assert response.content != b"OUTSIDE_SECRET_BYTES" + finally: + scan_results.pop(extension_id, None) + + diff --git a/uv.lock b/uv.lock index 9169c81..e5c6b19 100644 --- a/uv.lock +++ b/uv.lock @@ -779,6 +779,7 @@ dependencies = [ { name = "click" }, { name = "fastapi" }, { name = "fastmcp" }, + { name = "jinja2" }, { name = "langchain" }, { name = "langchain-ibm" }, { name = "langchain-ollama" }, @@ -817,6 +818,7 @@ requires-dist = [ { name = "click", specifier = ">=8.1.0" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "fastmcp", specifier = ">=1.0" }, + { name = "jinja2", specifier = ">=3.1.0" }, { name = "langchain", specifier = ">=0.3.27" }, { name = "langchain-ibm", specifier = ">=0.3.18" }, { name = "langchain-ollama", specifier = ">=1.0.0" },