diff --git a/src/_ravnar/api/__init__.py b/src/_ravnar/api/__init__.py index 59989dd..86eb8e9 100644 --- a/src/_ravnar/api/__init__.py +++ b/src/_ravnar/api/__init__.py @@ -70,8 +70,8 @@ def _make_stateful_router( from _ravnar.file_storage import FileHandler from _ravnar.mixin import SetupTeardownMixin - database = Database(url=str(storage_config.database_dsn)) - file_handler = FileHandler(root=storage_config.file_storage_path, database=database) + database = Database(config=storage_config.database) + file_handler = FileHandler(config=storage_config.files, database=database) router = schema.APIRouter( tags=["Stateful"], diff --git a/src/_ravnar/config.py b/src/_ravnar/config.py index 37b8808..d118dfe 100644 --- a/src/_ravnar/config.py +++ b/src/_ravnar/config.py @@ -2,6 +2,7 @@ import os import sys +from datetime import timedelta from pathlib import Path from typing import Annotated, Any, Self, TypeVar @@ -10,7 +11,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, YamlConfigSettingsSource from upath import UPath -from _ravnar.utils import ImportStringWithParams, render_template +from _ravnar.utils import ImportStringWithParams, normalize_hostname, render_template from .agents import Agent, DefaultAgent from .authenticators import Authenticator @@ -77,10 +78,33 @@ def _local_storage() -> Path: return p +class DatabaseConfig(BaseModel, RenderableConfigMixin): + dsn: str = Field(default_factory=lambda: f"sqlite:///{_local_storage() / 'state.db'}") + + +class URLDataSourceConfig(BaseModel, RenderableConfigMixin): + enabled: bool = False + allowlist: Allowlist = Field(default_factory=list) + timeout: timedelta = timedelta(seconds=30) + + @field_validator("allowlist", mode="after") + @classmethod + def _normalize_allowlist_entries(cls, allowlist: list[str]) -> list[str]: + if "*" in allowlist: + return allowlist + + return [normalize_hostname(entry) for entry in allowlist] + + +class FileStorageConfig(BaseModel, RenderableConfigMixin): + path: UPath = Field(default_factory=lambda: UPath(_local_storage() / "files")) + url_data_source: URLDataSourceConfig = Field(default_factory=URLDataSourceConfig) + + class StorageConfig(BaseModel, RenderableConfigMixin): enabled: bool = True - database_dsn: str = Field(default_factory=lambda: f"sqlite:///{_local_storage() / 'state.db'}") - file_storage_path: UPath = Field(default_factory=lambda: UPath(_local_storage() / "files")) + database: DatabaseConfig = Field(default_factory=DatabaseConfig) + files: FileStorageConfig = Field(default_factory=FileStorageConfig) class DynamicAgentConfig(BaseModel, RenderableConfigMixin): diff --git a/src/_ravnar/database.py b/src/_ravnar/database.py index ea90b56..e9bc014 100644 --- a/src/_ravnar/database.py +++ b/src/_ravnar/database.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Awaitable, Callable, Collection from contextlib import AbstractAsyncContextManager, AbstractContextManager from math import ceil -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from fastapi import HTTPException, status from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor @@ -23,14 +23,17 @@ from .observability import traced from .utils import as_async_context_manager, as_awaitable, now +if TYPE_CHECKING: + from _ravnar.config import DatabaseConfig + class SessionFactoryParams(TypedDict): expire_on_commit: bool class Database(SetupTeardownMixin): - def __init__(self, url: str) -> None: - url = make_url(url) + def __init__(self, config: DatabaseConfig) -> None: + url = make_url(config.dsn) if url.drivername.startswith("sqlite") and (url.database is None or url.database == ":memory:"): # See https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#using-a-memory-database-in-multiple-threads diff --git a/src/_ravnar/file_storage.py b/src/_ravnar/file_storage.py index 25ac57c..8b4ede4 100644 --- a/src/_ravnar/file_storage.py +++ b/src/_ravnar/file_storage.py @@ -3,8 +3,9 @@ import base64 import dataclasses import mimetypes +import urllib.parse import uuid -from datetime import datetime +from datetime import datetime, timedelta from typing import TYPE_CHECKING, Annotated, Any, Self import ag_ui.core @@ -16,9 +17,10 @@ from _ravnar import orm, schema from _ravnar.observability import traced -from _ravnar.utils import as_awaitable +from _ravnar.utils import as_awaitable, normalize_hostname if TYPE_CHECKING: + from _ravnar.config import FileStorageConfig from _ravnar.database import Database @@ -98,23 +100,25 @@ class WrappedMetadata(schema.BaseModel): class FileHandler: - def __init__(self, *, root: UPath, database: Database) -> None: - self._storage = _Storage(root) + def __init__(self, *, config: FileStorageConfig, database: Database) -> None: + self._config = config + self._storage = _Storage(config.path) self._database = database - self._extractors = { - "data": self._extract_data, - "url": self._extract_url, - "custom": self._extract_custom, - } - @traced async def add(self, file_input_content: FileInputContent, *, user_id: str) -> tuple[orm.File, bytes]: source_type = file_input_content.source.type - if source_type not in self._extractors: - raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file source type") - - data = await self._extractors[source_type](file_input_content) + try: + extractor = { + "data": self._extract_data, + "url": self._extract_url, + }[source_type] + except KeyError: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file source type" + ) from None + + data = await extractor(file_input_content) file = orm.File( user_id=user_id, type=file_input_content.type, @@ -152,24 +156,23 @@ async def _extract_data(file_input_content: FileInputContent) -> _FileData: mime_type=file_input_content.source.mime_type, ) - @staticmethod - async def _extract_url(file_input_content: FileInputContent) -> _FileData: + async def _extract_url(self, file_input_content: FileInputContent) -> _FileData: assert isinstance(file_input_content.source, ag_ui.core.InputContentUrlSource) - url = file_input_content.source.value + if not self._config.url_data_source.enabled: + raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL file source is not enabled") + mime_type = file_input_content.source.mime_type - tracer = trace.get_tracer(__name__) - with tracer.start_as_current_span("FileHandler.fetch_url"): - async with httpx.AsyncClient(follow_redirects=True) as client: - response = await client.get(url) - if not response.is_success: - span = trace.get_current_span() - exc = HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL") - span.record_exception(exc) - span.set_status(trace.StatusCode.ERROR, description="Failed to fetch file from URL") - raise exc - content = response.content - content_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower() + + response = await self._fetch_url( + file_input_content.source.value, + timeout=self._config.url_data_source.timeout, + allowlist=self._config.url_data_source.allowlist, + ) + + url = str(response.request.url) + content = response.content + content_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower() if not mime_type: mime_type = content_type @@ -181,10 +184,59 @@ async def _extract_url(file_input_content: FileInputContent) -> _FileData: return _FileData(content=content, mime_type=mime_type, source_data={"url": url}) @staticmethod - async def _extract_custom(file_input_content: FileInputContent) -> _FileData: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Custom file source type is not supported" + @traced(name="FileHandler.fetch_url") + async def _fetch_url( + url: str, + *, + timeout: timedelta, # noqa: ASYNC109 + allowlist: list[str], + max_redirects: int = 20, + ) -> httpx.Response: + redirect_chain: list[str] = [] + failure_exception = HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL" ) + async with httpx.AsyncClient(follow_redirects=False, timeout=timeout.total_seconds()) as client: + for _ in range(max_redirects): + response = await client.get(FileHandler._validate_url(url, allowlist=allowlist)) + next_request = response.next_request + if next_request is not None: + url = str(next_request.url) + redirect_chain.append(url) + continue + + if not response.is_success: + raise failure_exception + + span = trace.get_current_span() + span.set_attribute("ssrf.redirect_chain", redirect_chain) + span.set_attribute("ssrf.redirect_count", len(redirect_chain)) + + return response + + raise failure_exception + + @staticmethod + def _validate_url(url: str, *, allowlist: list[str]) -> str: + failure_exception = HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed") + + parts = urllib.parse.urlsplit(url) + if not parts.hostname: + raise failure_exception + + try: + normalized_hostname = normalize_hostname(parts.hostname) + except Exception as exc: + raise failure_exception from exc + + if "*" in allowlist: + return url + + for entry in allowlist: + if normalized_hostname == entry or normalized_hostname.endswith("." + entry): + return url + + raise failure_exception @traced async def get(self, id: uuid.UUID, *, user_id: str) -> orm.File: diff --git a/src/_ravnar/utils.py b/src/_ravnar/utils.py index b8c0f2a..5d1c4dc 100644 --- a/src/_ravnar/utils.py +++ b/src/_ravnar/utils.py @@ -241,3 +241,7 @@ def call(v: Any) -> Any: return v return self.cls_or_fn(**{k: call(v) for k, v in self.params.items()}) + + +def normalize_hostname(host: str) -> str: + return host.encode("idna").decode("ascii").lower() diff --git a/tests/api/test_files.py b/tests/api/test_files.py index 69d9787..b8f1482 100644 --- a/tests/api/test_files.py +++ b/tests/api/test_files.py @@ -1,5 +1,6 @@ import base64 import mimetypes +from urllib.parse import urlparse import ag_ui.core import compyre @@ -7,6 +8,7 @@ import pytest import pytest_httpserver.httpserver +from _ravnar.config import BaseConfig from _ravnar.file_storage import MIME_TYPE, DataSourceValue, FileInputContent @@ -44,11 +46,36 @@ def test_e2e_data_source(self, app_client, mime_type, metadata): assert response.content == content assert response.headers.get("Content-Type") == mime_type + @pytest.fixture + def url_app_client(self, httpserver, request): + """Create a test client with URL source enabled and the test server's hostname allowlisted.""" + from tests.utils import TestClient + + parsed = urlparse(httpserver.url_for("/")) + hostname = parsed.hostname or "localhost" + config = BaseConfig.model_validate( + { + "security": { + "authenticator": "tests.utils.HeaderAuthenticator", + }, + "storage": { + "files": { + "url_data_source": { + "enabled": True, + "allowlist": [hostname], + }, + }, + }, + } + ) + with TestClient.from_config(config) as client: + yield client + @pytest.mark.parametrize("mime_type", [None, "image/jpeg", "application/octet-stream"]) @pytest.mark.parametrize("source_content_type", [None, "image/png"]) @pytest.mark.parametrize("metadata", [None, "metadata", {"foo": "bar"}]) @pytest.mark.parametrize("endpoint", ["/image.jpg", "/file"]) - def test_e2e_url_source(self, app_client, httpserver, mime_type, source_content_type, metadata, endpoint): + def test_e2e_url_source(self, url_app_client, httpserver, mime_type, source_content_type, metadata, endpoint): content = b"content" response_cls = pytest_httpserver.httpserver.Response @@ -62,7 +89,7 @@ def test_e2e_url_source(self, app_client, httpserver, mime_type, source_content_ mime_type or source_content_type or mimetypes.guess_type(url, strict=False)[0] or "application/octet-stream" ) - response = app_client.post( + response = url_app_client.post( "/api/files", json=ag_ui.core.ImageInputContent( source=ag_ui.core.InputContentUrlSource(value=url, mime_type=mime_type), metadata=metadata @@ -81,10 +108,10 @@ def test_e2e_url_source(self, app_client, httpserver, mime_type, source_content_ file_id = value.file_id expected = file_input_content - response = app_client.get(f"/api/files/{file_id}").raise_for_status() + response = url_app_client.get(f"/api/files/{file_id}").raise_for_status() actual = pydantic.TypeAdapter(FileInputContent).validate_json(response.content) compyre.assert_equal(actual, expected) - response = app_client.get(f"/api/files/{file_id}/content").raise_for_status() + response = url_app_client.get(f"/api/files/{file_id}/content").raise_for_status() assert response.content == content assert response.headers.get("Content-Type") == expected_mime_type diff --git a/tests/test_ssrf.py b/tests/test_ssrf.py new file mode 100644 index 0000000..e0ffadf --- /dev/null +++ b/tests/test_ssrf.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import urllib.parse +from datetime import timedelta + +import pydantic +import pytest +from fastapi import HTTPException, status + +from _ravnar.config import URLDataSourceConfig, normalize_hostname +from _ravnar.file_storage import FileHandler + + +class TestNormalizeHostname: + def test_ascii_lowercasing(self) -> None: + assert normalize_hostname("GITHUB.COM") == "github.com" + assert normalize_hostname("Example.COM") == "example.com" + + def test_unicode_to_punycode(self) -> None: + assert normalize_hostname("München.example.com") == "xn--mnchen-3ya.example.com" + + def test_punycode_idempotent(self) -> None: + result = normalize_hostname("xn--mnchen-3ya.example.com") + assert result == "xn--mnchen-3ya.example.com" + + def test_invalid_idna_raises_value_error(self) -> None: + # Double dot creates an empty label which is invalid in IDNA + with pytest.raises(ValueError): + normalize_hostname("example..com") + + def test_ipv4_passthrough(self) -> None: + assert normalize_hostname("93.184.216.34") == "93.184.216.34" + + def test_trailing_dot_preserved(self) -> None: + assert normalize_hostname("github.com.") == "github.com." + + +class TestURLDataSourceConfig: + def test_allowlist_normalization(self) -> None: + config = URLDataSourceConfig(allowlist=["GITHUB.COM", "München.example.com"]) + assert config.allowlist == ["github.com", "xn--mnchen-3ya.example.com"] + + def test_wildcard_preserved(self) -> None: + config = URLDataSourceConfig(allowlist=["*"]) + assert config.allowlist == ["*"] + + def test_wildcard_with_others_blocked(self) -> None: + with pytest.raises(pydantic.ValidationError) as exc_info: + URLDataSourceConfig(allowlist=["*", "example.com"]) + assert "must be the sole" in str(exc_info.value) + + def test_invalid_allowlist_entry_raises(self) -> None: + # Double dot creates an empty label which is invalid in IDNA + with pytest.raises(pydantic.ValidationError): + URLDataSourceConfig(allowlist=["example..com"]) + + def test_timeout_default(self) -> None: + config = URLDataSourceConfig() + assert config.timeout == timedelta(seconds=30) + + def test_enabled_default(self) -> None: + config = URLDataSourceConfig() + assert config.enabled is False + + +class TestValidateURL: + """Tests for FileHandler._validate_url (a sync @staticmethod).""" + + def test_not_enabled(self) -> None: + """URL source not enabled is checked in _extract_url, not _validate_url. + + _validate_url only validates the hostname against the allowlist. + When allowlist is empty, all URLs are rejected. + """ + allowlist: list[str] = [] + with pytest.raises(HTTPException) as exc_info: + FileHandler._validate_url("http://example.com/file", allowlist=allowlist) + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_exact_match(self) -> None: + result = FileHandler._validate_url("http://example.com/file", allowlist=["example.com"]) + assert result == "http://example.com/file" + + def test_subdomain_match(self) -> None: + result = FileHandler._validate_url("http://sub.example.com/file", allowlist=["example.com"]) + assert result == "http://sub.example.com/file" + + def test_case_insensitive_match(self) -> None: + # Allowlist entry is already normalized (lowercased) at config load time + result = FileHandler._validate_url("http://example.com/file", allowlist=["example.com"]) + assert result == "http://example.com/file" + + def test_non_match(self) -> None: + with pytest.raises(HTTPException) as exc_info: + FileHandler._validate_url("http://evil.com/file", allowlist=["example.com"]) + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_wildcard_allows_all(self) -> None: + result = FileHandler._validate_url("http://evil.com/file", allowlist=["*"]) + assert result == "http://evil.com/file" + + def test_wildcard_allows_internal(self) -> None: + result = FileHandler._validate_url("http://169.254.169.254/latest/meta-data/", allowlist=["*"]) + assert result == "http://169.254.169.254/latest/meta-data/" + + def test_url_with_userinfo(self) -> None: + result = FileHandler._validate_url("http://user:pass@example.com/file", allowlist=["example.com"]) + assert result == "http://user:pass@example.com/file" + + def test_url_with_non_standard_port(self) -> None: + result = FileHandler._validate_url("http://example.com:8080/file", allowlist=["example.com"]) + assert result == "http://example.com:8080/file" + + def test_hostname_trailing_dot_not_matching(self) -> None: + with pytest.raises(HTTPException) as exc_info: + FileHandler._validate_url("http://example.com./file", allowlist=["example.com"]) + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_ip_literal_not_in_allowlist(self) -> None: + with pytest.raises(HTTPException) as exc_info: + FileHandler._validate_url("http://93.184.216.34/file", allowlist=["example.com"]) + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_ip_literal_in_allowlist(self) -> None: + result = FileHandler._validate_url("http://93.184.216.34/file", allowlist=["93.184.216.34"]) + assert result == "http://93.184.216.34/file" + + def test_url_with_no_hostname(self) -> None: + with pytest.raises(HTTPException) as exc_info: + FileHandler._validate_url("file:///etc/passwd", allowlist=["example.com"]) + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + + def test_idn_hostname_matching_idn_entry(self) -> None: + # Config normalizes the IDN allowlist entry at load time + result = FileHandler._validate_url( + "http://MÜNCHEN.example.com/file", + allowlist=["xn--mnchen-3ya.example.com"], + ) + assert result == "http://MÜNCHEN.example.com/file" + + def test_idn_hostname_matching_punycode_entry(self) -> None: + result = FileHandler._validate_url( + "http://MÜNCHEN.example.com/file", + allowlist=["xn--mnchen-3ya.example.com"], + ) + assert result == "http://MÜNCHEN.example.com/file" + + +class TestValidateURLIntegration: + """Integration tests via HTTP client against a running ravnar instance.""" + + @pytest.fixture + def app_client(self, httpserver): + from _ravnar.config import BaseConfig + from tests.utils import TestClient + + hostname = urllib.parse.urlparse(httpserver.url_for("/")).hostname or "localhost" + config = BaseConfig.model_validate( + { + "security": { + "authenticator": "tests.utils.HeaderAuthenticator", + }, + "storage": { + "files": { + "url_data_source": { + "enabled": True, + "allowlist": [hostname], + }, + }, + }, + } + ) + with TestClient.from_config(config) as client: + yield client + + def test_upload_from_allowlisted_url_succeeds(self, app_client, httpserver): + httpserver.expect_request("/file.txt").respond_with_data(b"hello") + url = httpserver.url_for("/file.txt") + + response = app_client.post( + "/api/files", + json={ + "type": "document", + "source": {"type": "url", "value": url}, + }, + ) + assert response.status_code == 200 + + def test_upload_from_non_allowlisted_url_fails(self, app_client, httpserver): + response = app_client.post( + "/api/files", + json={ + "type": "document", + "source": {"type": "url", "value": "http://evil.com/malware"}, + }, + ) + assert response.status_code == 400 + + def test_url_source_disabled_fails(self, app_client, httpserver): + from _ravnar.config import BaseConfig + from tests.utils import TestClient + + config = BaseConfig.model_validate( + { + "security": { + "authenticator": "tests.utils.HeaderAuthenticator", + }, + "storage": { + "files": { + "url_data_source": { + "enabled": False, + }, + }, + }, + } + ) + with TestClient.from_config(config) as client: + response = client.post( + "/api/files", + json={ + "type": "document", + "source": {"type": "url", "value": "http://example.com/file"}, + }, + ) + assert response.status_code == 400