From 761f7759aa876a3e8e7b6a03be8afcd596497f45 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 8 Jun 2026 07:30:24 +0000 Subject: [PATCH 1/4] redesign with config restructuring and SSRF guard details --- DESIGN.md | 252 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 DESIGN.md diff --git a/DESIGN.md b/DESIGN.md new file mode 100644 index 0000000..9b23c73 --- /dev/null +++ b/DESIGN.md @@ -0,0 +1,252 @@ +# Design: SSRF Guard for URL-Based File Sources + +## Summary + +Add controls to the file upload endpoint's URL-fetching capability so that ravnar cannot be used as a proxy to reach internal infrastructure or cloud metadata endpoints. URL fetching will be disabled by default, and when enabled, restricted to an explicit allowlist of permitted domains. The implementation adds a new configuration sub-object to the storage config and a validation layer in `FileHandler`. + +## Goals + +- Prevent ravnar from being coerced into making HTTP requests to internal or private IP ranges. +- Prevent data exfiltration via attacker-controlled URLs. +- Prevent DNS-rebinding-based bypasses through defense-in-depth design (per-hop redirect validation, hostname normalization). +- Maintain existing functionality for legitimate use cases (fetching files from known external domains). + +## Non-Goals + +- Adding a full HTTP proxy or egress filtering system — that is the deployment environment's responsibility. +- Rate-limiting or connection-pool sizing — these are gateway concerns (per established architecture boundaries). +- Adding support for authenticated URL fetches (e.g., Bearer tokens, cookies) — the URL source is intended for public resources only. +- Replacing the underlying HTTP client library or adding a custom DNS resolver. +- DNS-level rebinding protection (see Tradeoffs & Risks). +- Backwards compatibility for config structure — ravnar is in alpha. 
+ +## Background / Motivation + +Services that fetch content from user-supplied URLs are a well-known vector for Server-Side Request Forgery (SSRF). An attacker who can supply a URL to the server can probe internal services, reach cloud metadata endpoints (e.g., `169.254.169.254`), or exfiltrate data to an attacker-controlled endpoint via query parameters, path segments, or DNS lookups. + +The current implementation in `FileHandler._extract_url` creates an `httpx.AsyncClient` with `follow_redirects=True` and no transport-level restrictions. Any authenticated user with `files:write` can supply any URL. This means: + +- An attacker can reach any internal service on the host or network that ravnar can reach. +- An attacker can use redirect chains (allowed → internal) to bypass naive hostname filtering. +- An attacker can use IPv6 literal addresses, DNS rebinding, or alternative representations of internal IPs. + +The fix follows a deny-by-default model: URL fetching is opt-in, and when enabled, only explicitly listed domains are permitted. + +## Design + +### 1. Configuration Model + +Restructure the `StorageConfig` hierarchy to split the monolithic config into sub-objects. Add a new `URLDataSourceConfig` for the SSRF guard settings. + +#### Current structure (before) + +```python +class StorageConfig(BaseModel): + enabled: bool = True + database_dsn: str = ... + file_storage_path: UPath = ... +``` + +#### New structure (after) + +```python +class DatabaseConfig(BaseModel): + dsn: str = ... + +class URLDataSourceConfig(BaseModel): + enabled: bool = False + allowlist: list[str] = [] + timeout_seconds: int = 30 + +class FileStorageConfig(BaseModel): + path: UPath = ... 
+ url_data_source: URLDataSourceConfig = Field(default_factory=URLDataSourceConfig) + +class StorageConfig(BaseModel): + enabled: bool = True + database: DatabaseConfig = Field(default_factory=DatabaseConfig) + files: FileStorageConfig = Field(default_factory=FileStorageConfig) +``` + +- `storage.enabled` remains a top-level toggle that disables all stateful routes (database, files, threads). +- `Database` takes a `DatabaseConfig` instead of a raw DSN string. +- `FileHandler` takes a `FileStorageConfig` instead of raw `root` and `database` (plus other relevant params). + +##### `url_data_source` properties + +| Property | Type | Default | Description | +|---|---|---|---| +| `enabled` | `bool` | `false` | When false, any file source with `type: url` returns a 400 error. | +| `allowlist` | `list[str]` | `[]` | Case-insensitive domain list. Only URLs whose hostname matches an entry (or is a subdomain of an entry) are permitted. When empty and `enabled` is true, all URLs are rejected. Each entry is normalized via the IDNA punycode encoder before storage. | +| `timeout_seconds` | `int` | `30` | Per-request timeout for DNS + connect + read of the URL fetch. | + +##### Example YAML + +```yaml +storage: + enabled: true + database: + dsn: sqlite:///data/state.db + files: + path: /data/files + url_data_source: + enabled: true + allowlist: + - "raw.githubusercontent.com" + - "github.com" + timeout_seconds: 30 +``` + +### 2. IDNA / Punycode Normalization Helper + +Write a small shared helper function to normalize hostnames before comparison. It is called both: + +- At config load time, to normalize each entry in `url_data_source.allowlist`. +- At request time, to normalize the extracted hostname from the user-supplied URL. + +```python +def normalize_hostname(host: str) -> str: + """Normalize a hostname to lowercase ASCII (punycode form). + + Handles internationalized domain names by encoding them to + their IDNA2003 ASCII-compatible form. 
Pure-ASCII inputs are + lowercased and returned as-is. + + Raises ValueError if the hostname is not valid IDNA. + """ + return host.encode("idna").decode("ascii").lower() +``` + +Using `str.encode("idna")` and then decoding back to ASCII ensures that: + +- `"München.example.com"` → `"xn--mnchen-3ya.example.com"` +- `"GITHUB.COM"` → `"github.com"` +- `"xn--mnchen-3ya.example.com"` → `"xn--mnchen-3ya.example.com"` (idempotent) + +### 3. Validation Logic in FileHandler + +Add a validation method `_validate_url` in `FileHandler`, called at the start of `_extract_url` and after each redirect hop. + +```python +async def _validate_url(self, url: str) -> str: + """Validate a URL against the SSRF guard config. + + Returns the validated URL string on success. + Raises HTTPException(400) on failure. + """ + config = self._file_storage_config.url_data_source + + if not config.enabled: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL file source is not enabled") + + parsed = urllib.parse.urlparse(url) + hostname = parsed.hostname + + if not hostname: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + + normalized = normalize_hostname(hostname) + + # Allowlist check + if not config.allowlist: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + + allowed = False + for entry in config.allowlist: + entry_norm = normalize_hostname(entry) + if normalized == entry_norm or normalized.endswith("." + entry_norm): + allowed = True + break + + if not allowed: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + + return url +``` + +Key points: + +- **IP-literal hostnames** (IPv4 dotted, IPv6 colon-hex, bracketed IPv6) are rejected unless the operator explicitly adds the IP string to the allowlist. This is a natural consequence of the allowlist model — `"93.184.216.34"` would match `normalize_hostname("93.184.216.34")` → `"93.184.216.34"`. 
+- **Error messages are generic** — never include the blocked hostname, the allowlist entry checked, or any internal details. All diagnostic information goes to the OpenTelemetry trace. +- **Port numbers** are handled naturally by `urllib.parse.urlparse` — `urlparse` separates hostname from port, so `github.com:8080` extracts hostname `"github.com"`. +- **Trailing dots** in hostnames (valid DNS root references like `github.com.`) are **not** normalized by this code. If `github.com.` reaches the validator, its hostname is `"github.com."` which won't match `"github.com"`. This is acceptable — operators should not use trailing dots. No normalization is attempted. + +### 4. Redirect Handling + +Set `follow_redirects=False` on the `httpx.AsyncClient` and implement a manual redirect loop within `_extract_url`: + +1. Issue the initial GET request with `follow_redirects=False`. +2. If the response is a redirect (3xx with a `Location` header), extract the redirect target URL. +3. Validate the target URL through `_validate_url`. +4. Issue a new GET request to the validated target. +5. Repeat up to a maximum of 20 redirects (httpx default). +6. On success, proceed with content extraction as before. + +This approach ensures **every hop** in a redirect chain is validated against the allowlist. An allowlisted domain cannot redirect to an internal IP without being caught. + +**No config option to disable redirects.** Redirects are always allowed, subject to per-hop allowlist validation. A redirects-enabled toggle would create an operator footgun without meaningful security benefit. + +### 5. 
Error Responses + +All blocked URL fetch requests return a `400 Bad Request` with a generic `detail` string, using FastAPI's standard `HTTPException`: + +```python +raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL file source is not enabled") +# or +raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") +``` + +These match the existing pattern used by `_extract_custom` for unsupported source types (`422` for invalid source type, `400` for blocked URL). + +### 6. Tracing + +The existing `"FileHandler.fetch_url"` OpenTelemetry span is extended with attributes for diagnostics: + +- `ssrf.blocked_reason` — `"not_enabled"` or `"not_allowed"` (when the request is rejected) +- `ssrf.hostname` — the normalized hostname that was checked +- `ssrf.allowlist_entry` — the allowlist entry that matched (if applicable) +- `ssrf.redirect_chain` — list of URLs visited in the redirect chain +- `ssrf.redirect_count` — number of redirect hops followed + +No dedicated structured log lines are emitted beyond tracing. The trace provides full detail for debugging; logs do not need to duplicate it. + +## Tradeoffs & Risks + +- **Usability vs. security:** Disabling URL fetching by default breaks any workflow that depends on it until the operator explicitly configures it. This is intentional — SSRF is a critical-class vulnerability and should require deliberate enabling. The error message directs operators to the configuration option indirectly (generic "not enabled" message). +- **Allowlist granularity:** Domain-level allowlisting is coarse. An attacker who controls a subdomain of an allowlisted domain (e.g., `evil.github.io` if `github.io` is allowlisted) could still abuse it. More granular approaches (path-based, content-type-based) add complexity. The allowlist is documented as a security boundary that operators must configure carefully. 
+- **DNS rebinding:** An attacker who controls a domain and its authoritative DNS server can return different IPs for successive queries from the same client. The allowlist check is purely name-based and never resolves the hostname, so the IP address the HTTP client actually connects to is never inspected: if an allowlisted name (or a subdomain of one) resolves to an internal IP when the actual HTTP request is made, the request reaches an internal target despite the allowlist check. Full protection requires a custom transport layer that pins DNS resolution — out of scope for this design. The per-hop redirect validation and IP-literal rejection mitigate simpler bypass variants. This is an accepted risk for the initial implementation. +- **IDNA2003 vs IDNA2008:** Python's `encode("idna")` implements IDNA2003. Some Unicode characters are treated differently by IDNA2008 (e.g., IDNA2003 maps `ß` → `"ss"`, whereas IDNA2008 keeps `ß` as a distinct character), so normalization may produce unexpected results for such names. This is acceptable for an alpha-stage project. If edge cases arise, the normalization helper can be swapped for an IDNA2008 library. +- **Performance:** URL validation is cheap (string comparison). The HTTP timeout prevents resource exhaustion from a slow peer. +- **No logging beyond tracing:** Detailed diagnostic data is stored in OpenTelemetry spans, not in structured logs. This keeps log volume low for normal operation. Debugging a blocked request requires accessing trace data. + +## Testing Strategy + +- **Unit tests for `normalize_hostname`:** ASCII lowercasing, Unicode → punycode, already-punycode idempotency, invalid IDNA raises `ValueError`. +- **Unit tests for `_validate_url`:** + - `url_data_source.enabled = false` → 400. + - `url_data_source.allowlist = []` with enabled → 400. + - Exact match, subdomain match, case-insensitive match. + - Non-match → 400. + - IP literal hostname (not in allowlist) → 400. + - IP literal hostname (in allowlist as string) → allowed. + - IDN hostname matching IDN allowlist entry. + - IDN hostname matching punycode allowlist entry. + - Hostname with trailing dot (not matching) → 400. + - URL with userinfo (`user:pass@host`). 
+ - URL with non-standard port. +- **Unit tests for redirect loop:** + - Single redirect to allowlisted domain → success. + - Single redirect to non-allowlisted domain → 400. + - Chain of consecutive redirects staying within allowlist → success. + - Chain that eventually leaves allowlist → 400. + - Exceeding max redirect count → error. +- **Integration tests:** + - Start ravnar with URL fetching enabled and a known allowlist. + - Upload file from allowlisted URL → success. + - Upload file from non-allowlisted URL → 400. + - Upload file with `type: url` when fetching is disabled → 400. +- **No e2e tests needed** beyond the integration coverage. + +## Open Questions + +*(none — all design decisions are resolved)* From 2276459328855f79240d1b6c6b185e0af97c1bb7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 8 Jun 2026 08:08:24 +0000 Subject: [PATCH 2/4] update design with wildcard sentinel, timedelta timeout, and per-request scope --- DESIGN.md | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/DESIGN.md b/DESIGN.md index 9b23c73..137a733 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -56,7 +56,7 @@ class DatabaseConfig(BaseModel): class URLDataSourceConfig(BaseModel): enabled: bool = False allowlist: list[str] = [] - timeout_seconds: int = 30 + timeout: timedelta = timedelta(seconds=30) class FileStorageConfig(BaseModel): path: UPath = ... @@ -77,8 +77,8 @@ class StorageConfig(BaseModel): | Property | Type | Default | Description | |---|---|---|---| | `enabled` | `bool` | `false` | When false, any file source with `type: url` returns a 400 error. | -| `allowlist` | `list[str]` | `[]` | Case-insensitive domain list. Only URLs whose hostname matches an entry (or is a subdomain of an entry) are permitted. When empty and `enabled` is true, all URLs are rejected. Each entry is normalized via the IDNA punycode encoder before storage. | -| `timeout_seconds` | `int` | `30` | Per-request timeout for DNS + connect + read of the URL fetch. 
| +| `allowlist` | `list[str]` | `[]` | Case-insensitive domain list. Only URLs whose hostname matches an entry (or is a subdomain of an entry) are permitted. When empty and `enabled` is true, all URLs are rejected. The sentinel value `"*"` (when present anywhere in the list) allows all hostnames — same pattern as the `allow_origins` parameter of Starlette's `CORSMiddleware`. Each entry is normalized via the IDNA punycode encoder before comparison. | +| `timeout` | `timedelta` | `30s` | Per-request timeout applied to each individual HTTP request in the redirect chain. Pydantic accepts an `int` or `float` (seconds) or an ISO 8601 duration string such as `"PT30S"`. The total worst-case time for a chain is `timeout × max_redirects` (fixed at 20). | ##### Example YAML @@ -94,7 +94,7 @@ storage: allowlist: - "raw.githubusercontent.com" - "github.com" - timeout_seconds: 30 + timeout: 30 ``` ### 2. IDNA / Punycode Normalization Helper @@ -151,6 +151,10 @@ async def _validate_url(self, url: str) -> str: if not config.allowlist: raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + # Wildcard sentinel — same pattern as Starlette CORSMiddleware + if "*" in config.allowlist: + return url + allowed = False for entry in config.allowlist: entry_norm = normalize_hostname(entry) @@ -179,9 +183,11 @@ Set `follow_redirects=False` on the `httpx.AsyncClient` and implement a manual r 2. If the response is a redirect (3xx with a `Location` header), extract the redirect target URL. 3. Validate the target URL through `_validate_url`. 4. Issue a new GET request to the validated target. -5. Repeat up to a maximum of 20 redirects (httpx default). +5. Repeat up to a maximum of 20 redirects. 6. On success, proceed with content extraction as before. +Each request in the loop gets its own full `timeout` budget (per-request, not cumulative across the chain). The max redirect count bounds the total worst-case time to `timeout × 20`. + This approach ensures **every hop** in a redirect chain is validated against the allowlist. 
An allowlisted domain cannot redirect to an internal IP without being caught. **No config option to disable redirects.** Redirects are always allowed, subject to per-hop allowlist validation. A redirects-enabled toggle would create an operator footgun without meaningful security benefit. From b65f0b69b496cf37a4f8808f0f121e83035ebf6d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 8 Jun 2026 08:15:40 +0000 Subject: [PATCH 3/4] Add SSRF guard for URL-based file sources --- src/_ravnar/api/__init__.py | 4 +- src/_ravnar/config.py | 40 +++++- src/_ravnar/file_storage.py | 116 +++++++++++++++-- tests/api/test_files.py | 35 +++++- tests/test_ssrf.py | 241 ++++++++++++++++++++++++++++++++++++ 5 files changed, 415 insertions(+), 21 deletions(-) create mode 100644 tests/test_ssrf.py diff --git a/src/_ravnar/api/__init__.py b/src/_ravnar/api/__init__.py index 1ed47f6..85c0e29 100644 --- a/src/_ravnar/api/__init__.py +++ b/src/_ravnar/api/__init__.py @@ -70,8 +70,8 @@ def _make_stateful_router( from _ravnar.file_storage import FileHandler from _ravnar.mixin import SetupTeardownMixin - database = Database(url=str(storage_config.database_dsn)) - file_handler = FileHandler(root=storage_config.file_storage_path, database=database) + database = Database(url=str(storage_config.database.dsn)) + file_handler = FileHandler(file_storage_config=storage_config.files, database=database) router = schema.APIRouter( tags=["Stateful"], diff --git a/src/_ravnar/config.py b/src/_ravnar/config.py index 19813ad..89ccfaa 100644 --- a/src/_ravnar/config.py +++ b/src/_ravnar/config.py @@ -2,6 +2,7 @@ import os import sys +from datetime import timedelta from pathlib import Path from typing import Any, Self, TypeVar @@ -23,6 +24,18 @@ T = TypeVar("T") +def normalize_hostname(host: str) -> str: + """Normalize a hostname to lowercase ASCII (punycode form). + + Handles internationalized domain names by encoding them to + their IDNA2003 ASCII-compatible form. 
Pure-ASCII inputs are + lowercased and returned as-is. + + Raises ValueError if the hostname is not valid IDNA. + """ + return host.encode("idna").decode("ascii").lower() + + def interactive_session() -> bool: return sys.stdout.isatty() @@ -73,10 +86,33 @@ def _local_storage() -> Path: return p +class DatabaseConfig(BaseModel, RenderableMixin): + dsn: str = Field(default_factory=lambda: f"sqlite:///{_local_storage() / 'state.db'}") + + +class URLDataSourceConfig(BaseModel, RenderableMixin): + enabled: bool = False + allowlist: list[str] = [] + timeout: timedelta = timedelta(seconds=30) + + @field_validator("allowlist", mode="after") + @classmethod + def _normalize_allowlist_entries(cls, v: list[str]) -> list[str]: + return [ + entry if entry == "*" else normalize_hostname(entry) + for entry in v + ] + + +class FileStorageConfig(BaseModel, RenderableMixin): + path: UPath = Field(default_factory=lambda: UPath(_local_storage() / "files")) + url_data_source: URLDataSourceConfig = Field(default_factory=URLDataSourceConfig) + + class StorageConfig(BaseModel, RenderableMixin): enabled: bool = True - database_dsn: str = Field(default_factory=lambda: f"sqlite:///{_local_storage() / 'state.db'}") - file_storage_path: UPath = Field(default_factory=lambda: UPath(_local_storage() / "files")) + database: DatabaseConfig = Field(default_factory=DatabaseConfig) + files: FileStorageConfig = Field(default_factory=FileStorageConfig) class DynamicAgentConfig(BaseModel, RenderableMixin): diff --git a/src/_ravnar/file_storage.py b/src/_ravnar/file_storage.py index 25ac57c..0b46690 100644 --- a/src/_ravnar/file_storage.py +++ b/src/_ravnar/file_storage.py @@ -3,6 +3,7 @@ import base64 import dataclasses import mimetypes +import urllib.parse import uuid from datetime import datetime from typing import TYPE_CHECKING, Annotated, Any, Self @@ -15,6 +16,7 @@ from upath import UPath from _ravnar import orm, schema +from _ravnar.config import FileStorageConfig, normalize_hostname from 
_ravnar.observability import traced from _ravnar.utils import as_awaitable @@ -98,8 +100,9 @@ class WrappedMetadata(schema.BaseModel): class FileHandler: - def __init__(self, *, root: UPath, database: Database) -> None: - self._storage = _Storage(root) + def __init__(self, *, file_storage_config: "FileStorageConfig", database: Database) -> None: + self._file_storage_config = file_storage_config + self._storage = _Storage(file_storage_config.path) self._database = database self._extractors = { @@ -108,6 +111,52 @@ def __init__(self, *, root: UPath, database: Database) -> None: "custom": self._extract_custom, } + def _validate_url(self, url: str) -> str: + """Validate a URL against the SSRF guard config. + + Returns the validated URL string on success. + Raises HTTPException(400) on failure. + """ + config = self._file_storage_config.url_data_source + + if not config.enabled: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL file source is not enabled") + + parsed = urllib.parse.urlparse(url) + hostname = parsed.hostname + + if not hostname: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + + try: + normalized = normalize_hostname(hostname) + except Exception: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + + span = trace.get_current_span() + span.set_attribute("ssrf.hostname", normalized) + + if not config.allowlist: + span.set_attribute("ssrf.blocked_reason", "not_allowed") + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + + # Wildcard sentinel — same pattern as Starlette CORSMiddleware + if "*" in config.allowlist: + return url + + allowed = False + for entry in config.allowlist: + if normalized == entry or normalized.endswith("." 
+ entry): + span.set_attribute("ssrf.allowlist_entry", entry) + allowed = True + break + + if not allowed: + span.set_attribute("ssrf.blocked_reason", "not_allowed") + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + + return url + @traced async def add(self, file_input_content: FileInputContent, *, user_id: str) -> tuple[orm.File, bytes]: source_type = file_input_content.source.type @@ -152,33 +201,74 @@ async def _extract_data(file_input_content: FileInputContent) -> _FileData: mime_type=file_input_content.source.mime_type, ) - @staticmethod - async def _extract_url(file_input_content: FileInputContent) -> _FileData: + async def _extract_url(self, file_input_content: FileInputContent) -> _FileData: assert isinstance(file_input_content.source, ag_ui.core.InputContentUrlSource) url = file_input_content.source.value mime_type = file_input_content.source.mime_type + max_redirects = 20 + timeout = self._file_storage_config.url_data_source.timeout tracer = trace.get_tracer(__name__) with tracer.start_as_current_span("FileHandler.fetch_url"): - async with httpx.AsyncClient(follow_redirects=True) as client: - response = await client.get(url) - if not response.is_success: + self._validate_url(url) + + config = httpx.Timeout(timeout.total_seconds()) + async with httpx.AsyncClient(follow_redirects=False, timeout=config) as client: + redirect_chain: list[str] = [] + current_url = url + for _ in range(max_redirects): + response = await client.get(current_url) + + if response.is_redirect: + location = response.headers.get("Location") + if not location: + span = trace.get_current_span() + exc = HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Failed to fetch file from URL", + ) + span.record_exception(exc) + span.set_status(trace.StatusCode.ERROR, description="Redirect missing Location header") + raise exc + redirect_chain.append(location) + self._validate_url(location) + current_url = location + continue + + if not 
response.is_success: + span = trace.get_current_span() + exc = HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL" + ) + span.record_exception(exc) + span.set_status(trace.StatusCode.ERROR, description="Failed to fetch file from URL") + raise exc + + content = response.content + content_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower() + break + else: span = trace.get_current_span() - exc = HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL") + exc = HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL" + ) span.record_exception(exc) - span.set_status(trace.StatusCode.ERROR, description="Failed to fetch file from URL") + span.set_status(trace.StatusCode.ERROR, description="Too many redirects") raise exc - content = response.content - content_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower() + + span = trace.get_current_span() + if redirect_chain: + span.set_attribute("ssrf.redirect_chain", redirect_chain) + span.set_attribute("ssrf.redirect_count", len(redirect_chain)) if not mime_type: mime_type = content_type if not mime_type: - mime_type, _ = mimetypes.guess_type(url, strict=False) + mime_type, _ = mimetypes.guess_type(current_url, strict=False) if not mime_type: mime_type = "application/octet-stream" - return _FileData(content=content, mime_type=mime_type, source_data={"url": url}) + return _FileData(content=content, mime_type=mime_type, source_data={"url": current_url}) @staticmethod async def _extract_custom(file_input_content: FileInputContent) -> _FileData: diff --git a/tests/api/test_files.py b/tests/api/test_files.py index 69d9787..b8f1482 100644 --- a/tests/api/test_files.py +++ b/tests/api/test_files.py @@ -1,5 +1,6 @@ import base64 import mimetypes +from urllib.parse import urlparse import ag_ui.core import compyre @@ -7,6 +8,7 @@ import pytest import 
pytest_httpserver.httpserver +from _ravnar.config import BaseConfig from _ravnar.file_storage import MIME_TYPE, DataSourceValue, FileInputContent @@ -44,11 +46,36 @@ def test_e2e_data_source(self, app_client, mime_type, metadata): assert response.content == content assert response.headers.get("Content-Type") == mime_type + @pytest.fixture + def url_app_client(self, httpserver, request): + """Create a test client with URL source enabled and the test server's hostname allowlisted.""" + from tests.utils import TestClient + + parsed = urlparse(httpserver.url_for("/")) + hostname = parsed.hostname or "localhost" + config = BaseConfig.model_validate( + { + "security": { + "authenticator": "tests.utils.HeaderAuthenticator", + }, + "storage": { + "files": { + "url_data_source": { + "enabled": True, + "allowlist": [hostname], + }, + }, + }, + } + ) + with TestClient.from_config(config) as client: + yield client + @pytest.mark.parametrize("mime_type", [None, "image/jpeg", "application/octet-stream"]) @pytest.mark.parametrize("source_content_type", [None, "image/png"]) @pytest.mark.parametrize("metadata", [None, "metadata", {"foo": "bar"}]) @pytest.mark.parametrize("endpoint", ["/image.jpg", "/file"]) - def test_e2e_url_source(self, app_client, httpserver, mime_type, source_content_type, metadata, endpoint): + def test_e2e_url_source(self, url_app_client, httpserver, mime_type, source_content_type, metadata, endpoint): content = b"content" response_cls = pytest_httpserver.httpserver.Response @@ -62,7 +89,7 @@ def test_e2e_url_source(self, app_client, httpserver, mime_type, source_content_ mime_type or source_content_type or mimetypes.guess_type(url, strict=False)[0] or "application/octet-stream" ) - response = app_client.post( + response = url_app_client.post( "/api/files", json=ag_ui.core.ImageInputContent( source=ag_ui.core.InputContentUrlSource(value=url, mime_type=mime_type), metadata=metadata @@ -81,10 +108,10 @@ def test_e2e_url_source(self, app_client, httpserver, 
mime_type, source_content_ file_id = value.file_id expected = file_input_content - response = app_client.get(f"/api/files/{file_id}").raise_for_status() + response = url_app_client.get(f"/api/files/{file_id}").raise_for_status() actual = pydantic.TypeAdapter(FileInputContent).validate_json(response.content) compyre.assert_equal(actual, expected) - response = app_client.get(f"/api/files/{file_id}/content").raise_for_status() + response = url_app_client.get(f"/api/files/{file_id}/content").raise_for_status() assert response.content == content assert response.headers.get("Content-Type") == expected_mime_type diff --git a/tests/test_ssrf.py b/tests/test_ssrf.py new file mode 100644 index 0000000..842c595 --- /dev/null +++ b/tests/test_ssrf.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import urllib.parse +from datetime import timedelta + +import httpx +import pytest +import pytest_httpserver +from fastapi import HTTPException, status + +import pydantic +from _ravnar.config import URLDataSourceConfig, FileStorageConfig, normalize_hostname + + +class TestNormalizeHostname: + def test_ascii_lowercasing(self) -> None: + assert normalize_hostname("GITHUB.COM") == "github.com" + assert normalize_hostname("Example.COM") == "example.com" + + def test_unicode_to_punycode(self) -> None: + assert normalize_hostname("München.example.com") == "xn--mnchen-3ya.example.com" + + def test_punycode_idempotent(self) -> None: + result = normalize_hostname("xn--mnchen-3ya.example.com") + assert result == "xn--mnchen-3ya.example.com" + + def test_invalid_idna_raises_value_error(self) -> None: + # Double dot creates an empty label which is invalid in IDNA + with pytest.raises(ValueError): + normalize_hostname("example..com") + + def test_ipv4_passthrough(self) -> None: + assert normalize_hostname("93.184.216.34") == "93.184.216.34" + + def test_trailing_dot_preserved(self) -> None: + assert normalize_hostname("github.com.") == "github.com." 
+ + +class TestURLDataSourceConfig: + def test_allowlist_normalization(self) -> None: + config = URLDataSourceConfig(allowlist=["GITHUB.COM", "München.example.com"]) + assert config.allowlist == ["github.com", "xn--mnchen-3ya.example.com"] + + def test_wildcard_preserved(self) -> None: + config = URLDataSourceConfig(allowlist=["*"]) + assert config.allowlist == ["*"] + + def test_invalid_allowlist_entry_raises(self) -> None: + # Double dot creates an empty label which is invalid in IDNA + with pytest.raises(pydantic.ValidationError): + URLDataSourceConfig(allowlist=["example..com"]) + + def test_timeout_default(self) -> None: + config = URLDataSourceConfig() + assert config.timeout == timedelta(seconds=30) + + def test_enabled_default(self) -> None: + config = URLDataSourceConfig() + assert config.enabled is False + + +def _make_handler(url_data_source_config: URLDataSourceConfig | None = None): + """Create a FileHandler with the given URL data source config for testing.""" + from _ravnar.file_storage import FileHandler + + file_storage_config = FileStorageConfig( + url_data_source=url_data_source_config or URLDataSourceConfig(enabled=True, allowlist=["example.com"]) + ) + handler = FileHandler.__new__(FileHandler) + handler._file_storage_config = file_storage_config + return handler + + +class TestValidateURL: + def test_not_enabled(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=False)) + with pytest.raises(HTTPException) as exc_info: + handler._validate_url("http://example.com/file") + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert "URL file source is not enabled" in exc_info.value.detail + + def test_empty_allowlist_blocks_all(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=[])) + with pytest.raises(HTTPException) as exc_info: + handler._validate_url("http://example.com/file") + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_exact_match(self) -> 
None: + handler = _make_handler() + result = handler._validate_url("http://example.com/file") + assert result == "http://example.com/file" + + def test_subdomain_match(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["example.com"])) + result = handler._validate_url("http://sub.example.com/file") + assert result == "http://sub.example.com/file" + + def test_case_insensitive_match(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["EXAMPLE.COM"])) + result = handler._validate_url("http://example.com/file") + assert result == "http://example.com/file" + + def test_non_match(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["example.com"])) + with pytest.raises(HTTPException) as exc_info: + handler._validate_url("http://evil.com/file") + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_wildcard_allows_all(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["*"])) + result = handler._validate_url("http://evil.com/file") + assert result == "http://evil.com/file" + + def test_wildcard_allows_internal(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["*"])) + result = handler._validate_url("http://169.254.169.254/latest/meta-data/") + assert result == "http://169.254.169.254/latest/meta-data/" + + def test_url_with_userinfo(self) -> None: + handler = _make_handler() + result = handler._validate_url("http://user:pass@example.com/file") + assert result == "http://user:pass@example.com/file" + + def test_url_with_non_standard_port(self) -> None: + handler = _make_handler() + result = handler._validate_url("http://example.com:8080/file") + assert result == "http://example.com:8080/file" + + def test_hostname_trailing_dot_not_matching(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["example.com"])) + with pytest.raises(HTTPException) as 
exc_info: + handler._validate_url("http://example.com./file") + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_ip_literal_not_in_allowlist(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["example.com"])) + with pytest.raises(HTTPException) as exc_info: + handler._validate_url("http://93.184.216.34/file") + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_ip_literal_in_allowlist(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["93.184.216.34"])) + result = handler._validate_url("http://93.184.216.34/file") + assert result == "http://93.184.216.34/file" + + def test_url_with_no_hostname(self) -> None: + handler = _make_handler() + with pytest.raises(HTTPException) as exc_info: + handler._validate_url("file:///etc/passwd") + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_idn_hostname_matching_idn_entry(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["münchen.example.com"])) + result = handler._validate_url("http://MÜNCHEN.example.com/file") + assert result == "http://MÜNCHEN.example.com/file" + + def test_idn_hostname_matching_punycode_entry(self) -> None: + handler = _make_handler(URLDataSourceConfig(enabled=True, allowlist=["xn--mnchen-3ya.example.com"])) + result = handler._validate_url("http://MÜNCHEN.example.com/file") + assert result == "http://MÜNCHEN.example.com/file" + + +class TestValidateURLIntegration: + """Integration tests via HTTP client against a running ravnar instance.""" + + @pytest.fixture + def app_client(self, httpserver): + from tests.utils import TestClient + from _ravnar.config import BaseConfig + + hostname = urllib.parse.urlparse(httpserver.url_for("/")).hostname or "localhost" + config = BaseConfig.model_validate( + { + "security": { + "authenticator": "tests.utils.HeaderAuthenticator", + }, + "storage": { + "files": { + "url_data_source": 
{ + "enabled": True, + "allowlist": [hostname], + }, + }, + }, + } + ) + with TestClient.from_config(config) as client: + yield client + + def test_upload_from_allowlisted_url_succeeds(self, app_client, httpserver): + httpserver.expect_request("/file.txt").respond_with_data(b"hello") + url = httpserver.url_for("/file.txt") + + response = app_client.post( + "/api/files", + json={ + "type": "document", + "source": {"type": "url", "value": url}, + }, + ) + assert response.status_code == 200 + + def test_upload_from_non_allowlisted_url_fails(self, app_client, httpserver): + # Point to a URL on a non-allowlisted host + response = app_client.post( + "/api/files", + json={ + "type": "document", + "source": {"type": "url", "value": "http://evil.com/malware"}, + }, + ) + assert response.status_code == 400 + + def test_url_source_disabled_fails(self, app_client, httpserver): + """Override fixture to disable URL source.""" + from tests.utils import TestClient + from _ravnar.config import BaseConfig + + config = BaseConfig.model_validate( + { + "security": { + "authenticator": "tests.utils.HeaderAuthenticator", + }, + "storage": { + "files": { + "url_data_source": { + "enabled": False, + }, + }, + }, + } + ) + with TestClient.from_config(config) as client: + response = client.post( + "/api/files", + json={ + "type": "document", + "source": {"type": "url", "value": "http://example.com/file"}, + }, + ) + assert response.status_code == 400 From b14d51f822c23f8d6c581524e18b8c3e431bee43 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 8 Jun 2026 11:28:29 +0200 Subject: [PATCH 4/4] cleanup --- DESIGN.md | 258 ------------------------------------ src/_ravnar/api/__init__.py | 4 +- src/_ravnar/config.py | 33 ++--- src/_ravnar/database.py | 9 +- src/_ravnar/file_storage.py | 206 ++++++++++++---------------- src/_ravnar/utils.py | 4 + tests/test_ssrf.py | 96 ++++++-------- 7 files changed, 146 insertions(+), 464 deletions(-) delete mode 100644 DESIGN.md diff --git a/DESIGN.md 
b/DESIGN.md deleted file mode 100644 index 137a733..0000000 --- a/DESIGN.md +++ /dev/null @@ -1,258 +0,0 @@ -# Design: SSRF Guard for URL-Based File Sources - -## Summary - -Add controls to the file upload endpoint's URL-fetching capability so that ravnar cannot be used as a proxy to reach internal infrastructure or cloud metadata endpoints. URL fetching will be disabled by default, and when enabled, restricted to an explicit allowlist of permitted domains. The implementation adds a new configuration sub-object to the storage config and a validation layer in `FileHandler`. - -## Goals - -- Prevent ravnar from being coerced into making HTTP requests to internal or private IP ranges. -- Prevent data exfiltration via attacker-controlled URLs. -- Prevent DNS-rebinding-based bypasses through defense-in-depth design (per-hop redirect validation, hostname normalization). -- Maintain existing functionality for legitimate use cases (fetching files from known external domains). - -## Non-Goals - -- Adding a full HTTP proxy or egress filtering system — that is the deployment environment's responsibility. -- Rate-limiting or connection-pool sizing — these are gateway concerns (per established architecture boundaries). -- Adding support for authenticated URL fetches (e.g., Bearer tokens, cookies) — the URL source is intended for public resources only. -- Replacing the underlying HTTP client library or adding a custom DNS resolver. -- DNS-level rebinding protection (see Tradeoffs & Risks). -- Backwards compatibility for config structure — ravnar is in alpha. - -## Background / Motivation - -Services that fetch content from user-supplied URLs are a well-known vector for Server-Side Request Forgery (SSRF). An attacker who can supply a URL to the server can probe internal services, reach cloud metadata endpoints (e.g., `169.254.169.254`), or exfiltrate data to an attacker-controlled endpoint via query parameters, path segments, or DNS lookups. 
- -The current implementation in `FileHandler._extract_url` creates an `httpx.AsyncClient` with `follow_redirects=True` and no transport-level restrictions. Any authenticated user with `files:write` can supply any URL. This means: - -- An attacker can reach any internal service on the host or network that ravnar can reach. -- An attacker can use redirect chains (allowed → internal) to bypass naive hostname filtering. -- An attacker can use IPv6 literal addresses, DNS rebinding, or alternative representations of internal IPs. - -The fix follows a deny-by-default model: URL fetching is opt-in, and when enabled, only explicitly listed domains are permitted. - -## Design - -### 1. Configuration Model - -Restructure the `StorageConfig` hierarchy to split the monolithic config into sub-objects. Add a new `URLDataSourceConfig` for the SSRF guard settings. - -#### Current structure (before) - -```python -class StorageConfig(BaseModel): - enabled: bool = True - database_dsn: str = ... - file_storage_path: UPath = ... -``` - -#### New structure (after) - -```python -class DatabaseConfig(BaseModel): - dsn: str = ... - -class URLDataSourceConfig(BaseModel): - enabled: bool = False - allowlist: list[str] = [] - timeout: timedelta = timedelta(seconds=30) - -class FileStorageConfig(BaseModel): - path: UPath = ... - url_data_source: URLDataSourceConfig = Field(default_factory=URLDataSourceConfig) - -class StorageConfig(BaseModel): - enabled: bool = True - database: DatabaseConfig = Field(default_factory=DatabaseConfig) - files: FileStorageConfig = Field(default_factory=FileStorageConfig) -``` - -- `storage.enabled` remains a top-level toggle that disables all stateful routes (database, files, threads). -- `Database` takes a `DatabaseConfig` instead of a raw DSN string. -- `FileHandler` takes a `FileStorageConfig` instead of raw `root` and `database` (plus other relevant params). 
- -##### `url_data_source` properties - -| Property | Type | Default | Description | -|---|---|---|---| -| `enabled` | `bool` | `false` | When false, any file source with `type: url` returns a 400 error. | -| `allowlist` | `list[str]` | `[]` | Case-insensitive domain list. Only URLs whose hostname matches an entry (or is a subdomain of an entry) are permitted. When empty and `enabled` is true, all URLs are rejected. The sentinel value `"*"` (as the sole entry) allows all hostnames — same pattern as Starlette's `CORSMiddleware.allowed_origins`. Each entry is normalized via the IDNA punycode encoder before comparison. | -| `timeout` | `timedelta` | `30s` | Per-request timeout applied to each individual HTTP request in the redirect chain. Pydantic accepts `int` (seconds), `"30s"`, `"0.5m"`, etc. The total worst-case time for a chain is `timeout × max_redirects` (default 20). | - -##### Example YAML - -```yaml -storage: - enabled: true - database: - dsn: sqlite:///data/state.db - files: - path: /data/files - url_data_source: - enabled: true - allowlist: - - "raw.githubusercontent.com" - - "github.com" - timeout: 30 -``` - -### 2. IDNA / Punycode Normalization Helper - -Write a small shared helper function to normalize hostnames before comparison. It is called both: - -- At config load time, to normalize each entry in `url_fetch_allowlist`. -- At request time, to normalize the extracted hostname from the user-supplied URL. - -```python -def normalize_hostname(host: str) -> str: - """Normalize a hostname to lowercase ASCII (punycode form). - - Handles internationalized domain names by encoding them to - their IDNA2003 ASCII-compatible form. Pure-ASCII inputs are - lowercased and returned as-is. - - Raises ValueError if the hostname is not valid IDNA. 
- """ - return host.encode("idna").decode("ascii").lower() -``` - -Using `str.encode("idna")` and then decoding back to ASCII ensures that: - -- `"München.example.com"` → `"xn--mnchen-3ya.example.com"` -- `"GITHUB.COM"` → `"github.com"` -- `"xn--mnchen-3ya.example.com"` → `"xn--mnchen-3ya.example.com"` (idempotent) - -### 3. Validation Logic in FileHandler - -Add a validation method `_validate_url` in `FileHandler`, called at the start of `_extract_url` and after each redirect hop. - -```python -async def _validate_url(self, url: str) -> str: - """Validate a URL against the SSRF guard config. - - Returns the validated URL string on success. - Raises HTTPException(400) on failure. - """ - config = self._file_storage_config.url_data_source - - if not config.enabled: - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL file source is not enabled") - - parsed = urllib.parse.urlparse(url) - hostname = parsed.hostname - - if not hostname: - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") - - normalized = normalize_hostname(hostname) - - # Allowlist check - if not config.allowlist: - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") - - # Wildcard sentinel — same pattern as Starlette CORSMiddleware - if "*" in config.allowlist: - return url - - allowed = False - for entry in config.allowlist: - entry_norm = normalize_hostname(entry) - if normalized == entry_norm or normalized.endswith("." + entry_norm): - allowed = True - break - - if not allowed: - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") - - return url -``` - -Key points: - -- **IP-literal hostnames** (IPv4 dotted, IPv6 colon-hex, bracketed IPv6) are rejected unless the operator explicitly adds the IP string to the allowlist. This is a natural consequence of the allowlist model — `"93.184.216.34"` would match `normalize_hostname("93.184.216.34")` → `"93.184.216.34"`. 
-- **Error messages are generic** — never include the blocked hostname, the allowlist entry checked, or any internal details. All diagnostic information goes to the OpenTelemetry trace. -- **Port numbers** are handled naturally by `urllib.parse.urlparse` — `urlparse` separates hostname from port, so `github.com:8080` extracts hostname `"github.com"`. -- **Trailing dots** in hostnames (valid DNS root references like `github.com.`) are **not** normalized by this code. If `github.com.` reaches the validator, its hostname is `"github.com."` which won't match `"github.com"`. This is acceptable — operators should not use trailing dots. No normalization is attempted. - -### 4. Redirect Handling - -Set `follow_redirects=False` on the `httpx.AsyncClient` and implement a manual redirect loop within `_extract_url`: - -1. Issue the initial GET request with `follow_redirects=False`. -2. If the response is a redirect (3xx with a `Location` header), extract the redirect target URL. -3. Validate the target URL through `_validate_url`. -4. Issue a new GET request to the validated target. -5. Repeat up to a maximum of 20 redirects. -6. On success, proceed with content extraction as before. - -Each request in the loop gets its own full `timeout` budget (per-request, not cumulative across the chain). The max redirect count bounds the total worst-case time to `timeout × 20`. - -This approach ensures **every hop** in a redirect chain is validated against the allowlist. An allowlisted domain cannot redirect to an internal IP without being caught. - -**No config option to disable redirects.** Redirects are always allowed, subject to per-hop allowlist validation. A redirects-enabled toggle would create an operator footgun without meaningful security benefit. - -### 5. 
Error Responses - -All blocked URL fetch requests return a `400 Bad Request` with a generic `detail` string, using FastAPI's standard `HTTPException`: - -```python -raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL file source is not enabled") -# or -raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") -``` - -These match the existing pattern used by `_extract_custom` for unsupported source types (`422` for invalid source type, `400` for blocked URL). - -### 6. Tracing - -The existing `"FileHandler.fetch_url"` OpenTelemetry span is extended with attributes for diagnostics: - -- `ssrf.blocked_reason` — `"not_enabled"` or `"not_allowed"` (when the request is rejected) -- `ssrf.hostname` — the normalized hostname that was checked -- `ssrf.allowlist_entry` — the allowlist entry that matched (if applicable) -- `ssrf.redirect_chain` — list of URLs visited in the redirect chain -- `ssrf.redirect_count` — number of redirect hops followed - -No dedicated structured log lines are emitted beyond tracing. The trace provides full detail for debugging; logs do not need to duplicate it. - -## Tradeoffs & Risks - -- **Usability vs. security:** Disabling URL fetching by default breaks any workflow that depends on it until the operator explicitly configures it. This is intentional — SSRF is a critical-class vulnerability and should require deliberate enabling. The error message directs operators to the configuration option indirectly (generic "not enabled" message). -- **Allowlist granularity:** Domain-level allowlisting is coarse. An attacker who controls a subdomain of an allowlisted domain (e.g., `evil.github.io` if `github.io` is allowlisted) could still abuse it. More granular approaches (path-based, content-type-based) add complexity. The allowlist is documented as a security boundary that operators must configure carefully. 
-- **DNS rebinding:** An attacker who controls a domain and its authoritative DNS server can return different IPs for successive queries from the same client. If the first query (during allowlist validation) returns a public IP and the second query (during the actual HTTP request) returns an internal IP, the request reaches an internal target despite the allowlist check. Full protection requires a custom transport layer that pins DNS resolution — out of scope for this design. The per-hop redirect validation and IP-literal rejection mitigate simpler bypass variants. This is an accepted risk for the initial implementation. -- **IDNA2003 vs IDNA2008:** Python's `encode("idna")` implements IDNA2003. Some Unicode characters handled by IDNA2008 (e.g., `ß` → `"ss"`) may produce unexpected results. This is acceptable for an alpha-stage project. If edge cases arise, the normalization helper can be swapped for an IDNA2008 library. -- **Performance:** URL validation is cheap (string comparison). The HTTP timeout prevents resource exhaustion from a slow peer. -- **No logging beyond tracing:** Detailed diagnostic data is stored in OpenTelemetry spans, not in structured logs. This keeps log volume low for normal operation. Debugging a blocked request requires accessing trace data. - -## Testing Strategy - -- **Unit tests for `normalize_hostname`:** ASCII lowercasing, Unicode → punycode, already-punycode idempotency, invalid IDNA raises `ValueError`. -- **Unit tests for `_validate_url`:** - - `url_fetch_enabled = false` → 400. - - `url_fetch_allowlist = []` with enabled → 400. - - Exact match, subdomain match, case-insensitive match. - - Non-match → 400. - - IP literal hostname (not in allowlist) → 400. - - IP literal hostname (in allowlist as string) → allowed. - - IDN hostname matching IDN allowlist entry. - - IDN hostname matching punycode allowlist entry. - - Hostname with trailing dot (not matching) → 400. - - URL with userinfo (`user:pass@host`). 
- - URL with non-standard port. -- **Unit tests for redirect loop:** - - Single redirect to allowlisted domain → success. - - Single redirect to non-allowlisted domain → 400. - - Chain of consecutive redirects staying within allowlist → success. - - Chain that eventually leaves allowlist → 400. - - Exceeding max redirect count → error. -- **Integration tests:** - - Start ravnar with URL fetching enabled and a known allowlist. - - Upload file from allowlisted URL → success. - - Upload file from non-allowlisted URL → 400. - - Upload file with `type: url` when fetching is disabled → 400. -- **No e2e tests needed** beyond the integration coverage. - -## Open Questions - -*(none — all design decisions are resolved)* diff --git a/src/_ravnar/api/__init__.py b/src/_ravnar/api/__init__.py index 85c0e29..b629c31 100644 --- a/src/_ravnar/api/__init__.py +++ b/src/_ravnar/api/__init__.py @@ -70,8 +70,8 @@ def _make_stateful_router( from _ravnar.file_storage import FileHandler from _ravnar.mixin import SetupTeardownMixin - database = Database(url=str(storage_config.database.dsn)) - file_handler = FileHandler(file_storage_config=storage_config.files, database=database) + database = Database(config=storage_config.database) + file_handler = FileHandler(config=storage_config.files, database=database) router = schema.APIRouter( tags=["Stateful"], diff --git a/src/_ravnar/config.py b/src/_ravnar/config.py index 89ccfaa..90c149e 100644 --- a/src/_ravnar/config.py +++ b/src/_ravnar/config.py @@ -7,16 +7,11 @@ from typing import Any, Self, TypeVar import l2sl -from pydantic import ( - BaseModel, - Field, - field_validator, - model_validator, -) +from pydantic import BaseModel, Field, field_validator, model_validator from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, YamlConfigSettingsSource from upath import UPath -from _ravnar.utils import ImportStringWithParams, render_template +from _ravnar.utils import ImportStringWithParams, 
normalize_hostname, render_template from .agents import Agent, DefaultAgent from .authenticators import Authenticator @@ -24,18 +19,6 @@ T = TypeVar("T") -def normalize_hostname(host: str) -> str: - """Normalize a hostname to lowercase ASCII (punycode form). - - Handles internationalized domain names by encoding them to - their IDNA2003 ASCII-compatible form. Pure-ASCII inputs are - lowercased and returned as-is. - - Raises ValueError if the hostname is not valid IDNA. - """ - return host.encode("idna").decode("ascii").lower() - - def interactive_session() -> bool: return sys.stdout.isatty() @@ -98,10 +81,14 @@ class URLDataSourceConfig(BaseModel, RenderableMixin): @field_validator("allowlist", mode="after") @classmethod def _normalize_allowlist_entries(cls, v: list[str]) -> list[str]: - return [ - entry if entry == "*" else normalize_hostname(entry) - for entry in v - ] + if "*" in v: + if len(v) > 1: + raise ValueError( + 'Wildcard "*" must be the sole allowlist entry. It cannot be combined with specific domains.' 
+ ) + else: + v = [normalize_hostname(entry) for entry in v] + return v class FileStorageConfig(BaseModel, RenderableMixin): diff --git a/src/_ravnar/database.py b/src/_ravnar/database.py index ea90b56..e9bc014 100644 --- a/src/_ravnar/database.py +++ b/src/_ravnar/database.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Awaitable, Callable, Collection from contextlib import AbstractAsyncContextManager, AbstractContextManager from math import ceil -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from fastapi import HTTPException, status from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor @@ -23,14 +23,17 @@ from .observability import traced from .utils import as_async_context_manager, as_awaitable, now +if TYPE_CHECKING: + from _ravnar.config import DatabaseConfig + class SessionFactoryParams(TypedDict): expire_on_commit: bool class Database(SetupTeardownMixin): - def __init__(self, url: str) -> None: - url = make_url(url) + def __init__(self, config: DatabaseConfig) -> None: + url = make_url(config.dsn) if url.drivername.startswith("sqlite") and (url.database is None or url.database == ":memory:"): # See https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#using-a-memory-database-in-multiple-threads diff --git a/src/_ravnar/file_storage.py b/src/_ravnar/file_storage.py index 0b46690..8b4ede4 100644 --- a/src/_ravnar/file_storage.py +++ b/src/_ravnar/file_storage.py @@ -5,7 +5,7 @@ import mimetypes import urllib.parse import uuid -from datetime import datetime +from datetime import datetime, timedelta from typing import TYPE_CHECKING, Annotated, Any, Self import ag_ui.core @@ -16,11 +16,11 @@ from upath import UPath from _ravnar import orm, schema -from _ravnar.config import FileStorageConfig, normalize_hostname from _ravnar.observability import traced -from _ravnar.utils import as_awaitable +from _ravnar.utils import as_awaitable, normalize_hostname if TYPE_CHECKING: + from _ravnar.config import 
FileStorageConfig from _ravnar.database import Database @@ -100,70 +100,25 @@ class WrappedMetadata(schema.BaseModel): class FileHandler: - def __init__(self, *, file_storage_config: "FileStorageConfig", database: Database) -> None: - self._file_storage_config = file_storage_config - self._storage = _Storage(file_storage_config.path) + def __init__(self, *, config: FileStorageConfig, database: Database) -> None: + self._config = config + self._storage = _Storage(config.path) self._database = database - self._extractors = { - "data": self._extract_data, - "url": self._extract_url, - "custom": self._extract_custom, - } - - def _validate_url(self, url: str) -> str: - """Validate a URL against the SSRF guard config. - - Returns the validated URL string on success. - Raises HTTPException(400) on failure. - """ - config = self._file_storage_config.url_data_source - - if not config.enabled: - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL file source is not enabled") - - parsed = urllib.parse.urlparse(url) - hostname = parsed.hostname - - if not hostname: - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") - - try: - normalized = normalize_hostname(hostname) - except Exception: - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") - - span = trace.get_current_span() - span.set_attribute("ssrf.hostname", normalized) - - if not config.allowlist: - span.set_attribute("ssrf.blocked_reason", "not_allowed") - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") - - # Wildcard sentinel — same pattern as Starlette CORSMiddleware - if "*" in config.allowlist: - return url - - allowed = False - for entry in config.allowlist: - if normalized == entry or normalized.endswith("." 
+ entry): - span.set_attribute("ssrf.allowlist_entry", entry) - allowed = True - break - - if not allowed: - span.set_attribute("ssrf.blocked_reason", "not_allowed") - raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") - - return url - @traced async def add(self, file_input_content: FileInputContent, *, user_id: str) -> tuple[orm.File, bytes]: source_type = file_input_content.source.type - if source_type not in self._extractors: - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file source type") - - data = await self._extractors[source_type](file_input_content) + try: + extractor = { + "data": self._extract_data, + "url": self._extract_url, + }[source_type] + except KeyError: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file source type" + ) from None + + data = await extractor(file_input_content) file = orm.File( user_id=user_id, type=file_input_content.type, @@ -204,77 +159,84 @@ async def _extract_data(file_input_content: FileInputContent) -> _FileData: async def _extract_url(self, file_input_content: FileInputContent) -> _FileData: assert isinstance(file_input_content.source, ag_ui.core.InputContentUrlSource) - url = file_input_content.source.value + if not self._config.url_data_source.enabled: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL file source is not enabled") + mime_type = file_input_content.source.mime_type - max_redirects = 20 - timeout = self._file_storage_config.url_data_source.timeout - tracer = trace.get_tracer(__name__) - with tracer.start_as_current_span("FileHandler.fetch_url"): - self._validate_url(url) - - config = httpx.Timeout(timeout.total_seconds()) - async with httpx.AsyncClient(follow_redirects=False, timeout=config) as client: - redirect_chain: list[str] = [] - current_url = url - for _ in range(max_redirects): - response = await client.get(current_url) - - if response.is_redirect: - location = 
response.headers.get("Location") - if not location: - span = trace.get_current_span() - exc = HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, - detail="Failed to fetch file from URL", - ) - span.record_exception(exc) - span.set_status(trace.StatusCode.ERROR, description="Redirect missing Location header") - raise exc - redirect_chain.append(location) - self._validate_url(location) - current_url = location - continue - - if not response.is_success: - span = trace.get_current_span() - exc = HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL" - ) - span.record_exception(exc) - span.set_status(trace.StatusCode.ERROR, description="Failed to fetch file from URL") - raise exc - - content = response.content - content_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower() - break - else: - span = trace.get_current_span() - exc = HTTPException( - status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL" - ) - span.record_exception(exc) - span.set_status(trace.StatusCode.ERROR, description="Too many redirects") - raise exc - - span = trace.get_current_span() - if redirect_chain: - span.set_attribute("ssrf.redirect_chain", redirect_chain) - span.set_attribute("ssrf.redirect_count", len(redirect_chain)) + + response = await self._fetch_url( + file_input_content.source.value, + timeout=self._config.url_data_source.timeout, + allowlist=self._config.url_data_source.allowlist, + ) + + url = str(response.request.url) + content = response.content + content_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower() if not mime_type: mime_type = content_type if not mime_type: - mime_type, _ = mimetypes.guess_type(current_url, strict=False) + mime_type, _ = mimetypes.guess_type(url, strict=False) if not mime_type: mime_type = "application/octet-stream" - return _FileData(content=content, mime_type=mime_type, source_data={"url": current_url}) + return 
_FileData(content=content, mime_type=mime_type, source_data={"url": url}) @staticmethod - async def _extract_custom(file_input_content: FileInputContent) -> _FileData: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Custom file source type is not supported" + @traced(name="FileHandler.fetch_url") + async def _fetch_url( + url: str, + *, + timeout: timedelta, # noqa: ASYNC109 + allowlist: list[str], + max_redirects: int = 20, + ) -> httpx.Response: + redirect_chain: list[str] = [] + failure_exception = HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL" ) + async with httpx.AsyncClient(follow_redirects=False, timeout=timeout.total_seconds()) as client: + for _ in range(max_redirects): + response = await client.get(FileHandler._validate_url(url, allowlist=allowlist)) + next_request = response.next_request + if next_request is not None: + url = str(next_request.url) + redirect_chain.append(url) + continue + + if not response.is_success: + raise failure_exception + + span = trace.get_current_span() + span.set_attribute("ssrf.redirect_chain", redirect_chain) + span.set_attribute("ssrf.redirect_count", len(redirect_chain)) + + return response + + raise failure_exception + + @staticmethod + def _validate_url(url: str, *, allowlist: list[str]) -> str: + failure_exception = HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + + parts = urllib.parse.urlsplit(url) + if not parts.hostname: + raise failure_exception + + try: + normalized_hostname = normalize_hostname(parts.hostname) + except Exception as exc: + raise failure_exception from exc + + if "*" in allowlist: + return url + + for entry in allowlist: + if normalized_hostname == entry or normalized_hostname.endswith("." 
+ entry): + return url + + raise failure_exception @traced async def get(self, id: uuid.UUID, *, user_id: str) -> orm.File: diff --git a/src/_ravnar/utils.py b/src/_ravnar/utils.py index 17ca2be..b306a58 100644 --- a/src/_ravnar/utils.py +++ b/src/_ravnar/utils.py @@ -201,3 +201,7 @@ def call(v: Any) -> Any: return v return self.cls_or_fn(**{k: call(v) for k, v in self.params.items()}) + + +def normalize_hostname(host: str) -> str: + return host.encode("idna").decode("ascii").lower() diff --git a/tests/test_ssrf.py b/tests/test_ssrf.py index 842c595..e0ffadf 100644 --- a/tests/test_ssrf.py +++ b/tests/test_ssrf.py @@ -3,13 +3,12 @@ import urllib.parse from datetime import timedelta -import httpx +import pydantic import pytest -import pytest_httpserver from fastapi import HTTPException, status -import pydantic -from _ravnar.config import URLDataSourceConfig, FileStorageConfig, normalize_hostname +from _ravnar.config import URLDataSourceConfig, normalize_hostname +from _ravnar.file_storage import FileHandler class TestNormalizeHostname: @@ -45,6 +44,11 @@ def test_wildcard_preserved(self) -> None: config = URLDataSourceConfig(allowlist=["*"]) assert config.allowlist == ["*"] + def test_wildcard_with_others_blocked(self) -> None: + with pytest.raises(pydantic.ValidationError) as exc_info: + URLDataSourceConfig(allowlist=["*", "example.com"]) + assert "must be the sole" in str(exc_info.value) + def test_invalid_allowlist_entry_raises(self) -> None: # Double dot creates an empty label which is invalid in IDNA with pytest.raises(pydantic.ValidationError): @@ -59,104 +63,86 @@ def test_enabled_default(self) -> None: assert config.enabled is False -def _make_handler(url_data_source_config: URLDataSourceConfig | None = None): - """Create a FileHandler with the given URL data source config for testing.""" - from _ravnar.file_storage import FileHandler - - file_storage_config = FileStorageConfig( - url_data_source=url_data_source_config or URLDataSourceConfig(enabled=True, 
class TestValidateURL:
    """Tests for FileHandler._validate_url (a sync @staticmethod).

    The method receives the allowlist directly, so these tests pass entries in
    the form they are compared against (lowercase, punycode for IDNs).
    """

    def test_empty_allowlist_blocks_all(self) -> None:
        """An empty allowlist rejects every URL.

        Whether the URL source is enabled at all is not checked by
        _validate_url, which only compares the hostname with the allowlist.
        """
        allowlist: list[str] = []
        with pytest.raises(HTTPException) as exc_info:
            FileHandler._validate_url("http://example.com/file", allowlist=allowlist)
        assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

    def test_exact_match(self) -> None:
        result = FileHandler._validate_url("http://example.com/file", allowlist=["example.com"])
        assert result == "http://example.com/file"

    def test_subdomain_match(self) -> None:
        result = FileHandler._validate_url("http://sub.example.com/file", allowlist=["example.com"])
        assert result == "http://sub.example.com/file"

    def test_case_insensitive_match(self) -> None:
        # The hostname in the URL is upper-case; it must still match the
        # lower-case allowlist entry, and the URL is returned unmodified.
        result = FileHandler._validate_url("http://EXAMPLE.COM/file", allowlist=["example.com"])
        assert result == "http://EXAMPLE.COM/file"

    def test_non_match(self) -> None:
        with pytest.raises(HTTPException) as exc_info:
            FileHandler._validate_url("http://evil.com/file", allowlist=["example.com"])
        assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

    def test_wildcard_allows_all(self) -> None:
        result = FileHandler._validate_url("http://evil.com/file", allowlist=["*"])
        assert result == "http://evil.com/file"

    def test_wildcard_allows_internal(self) -> None:
        result = FileHandler._validate_url("http://169.254.169.254/latest/meta-data/", allowlist=["*"])
        assert result == "http://169.254.169.254/latest/meta-data/"

    def test_url_with_userinfo(self) -> None:
        result = FileHandler._validate_url("http://user:pass@example.com/file", allowlist=["example.com"])
        assert result == "http://user:pass@example.com/file"

    def test_url_with_non_standard_port(self) -> None:
        result = FileHandler._validate_url("http://example.com:8080/file", allowlist=["example.com"])
        assert result == "http://example.com:8080/file"

    def test_hostname_trailing_dot_not_matching(self) -> None:
        with pytest.raises(HTTPException) as exc_info:
            FileHandler._validate_url("http://example.com./file", allowlist=["example.com"])
        assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

    def test_ip_literal_not_in_allowlist(self) -> None:
        with pytest.raises(HTTPException) as exc_info:
            FileHandler._validate_url("http://93.184.216.34/file", allowlist=["example.com"])
        assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

    def test_ip_literal_in_allowlist(self) -> None:
        result = FileHandler._validate_url("http://93.184.216.34/file", allowlist=["93.184.216.34"])
        assert result == "http://93.184.216.34/file"

    def test_url_with_no_hostname(self) -> None:
        with pytest.raises(HTTPException) as exc_info:
            FileHandler._validate_url("file:///etc/passwd", allowlist=["example.com"])
        assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

    def test_idn_hostname_matching_idn_entry(self) -> None:
        # Start from the Unicode spelling and normalize it the same way the
        # hostname is normalized, so this test is not a duplicate of the
        # punycode-entry test below.
        result = FileHandler._validate_url(
            "http://MÜNCHEN.example.com/file",
            allowlist=[normalize_hostname("münchen.example.com")],
        )
        assert result == "http://MÜNCHEN.example.com/file"

    def test_idn_hostname_matching_punycode_entry(self) -> None:
        result = FileHandler._validate_url(
            "http://MÜNCHEN.example.com/file",
            allowlist=["xn--mnchen-3ya.example.com"],
        )
        assert result == "http://MÜNCHEN.example.com/file"