Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/_ravnar/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def _make_stateful_router(
from _ravnar.file_storage import FileHandler
from _ravnar.mixin import SetupTeardownMixin

database = Database(url=str(storage_config.database_dsn))
file_handler = FileHandler(root=storage_config.file_storage_path, database=database)
database = Database(config=storage_config.database)
file_handler = FileHandler(config=storage_config.files, database=database)

router = schema.APIRouter(
tags=["Stateful"],
Expand Down
30 changes: 27 additions & 3 deletions src/_ravnar/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import sys
from datetime import timedelta
from pathlib import Path
from typing import Annotated, Any, Self, TypeVar

Expand All @@ -10,7 +11,7 @@
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, YamlConfigSettingsSource
from upath import UPath

from _ravnar.utils import ImportStringWithParams, render_template
from _ravnar.utils import ImportStringWithParams, normalize_hostname, render_template

from .agents import Agent, DefaultAgent
from .authenticators import Authenticator
Expand Down Expand Up @@ -77,10 +78,33 @@ def _local_storage() -> Path:
return p


class DatabaseConfig(BaseModel, RenderableConfigMixin):
dsn: str = Field(default_factory=lambda: f"sqlite:///{_local_storage() / 'state.db'}")


class URLDataSourceConfig(BaseModel, RenderableConfigMixin):
enabled: bool = False
allowlist: Allowlist = Field(default_factory=list)
timeout: timedelta = timedelta(seconds=30)

@field_validator("allowlist", mode="after")
@classmethod
def _normalize_allowlist_entries(cls, allowlist: list[str]) -> list[str]:
if "*" in allowlist:
return allowlist

return [normalize_hostname(entry) for entry in allowlist]


class FileStorageConfig(BaseModel, RenderableConfigMixin):
path: UPath = Field(default_factory=lambda: UPath(_local_storage() / "files"))
url_data_source: URLDataSourceConfig = Field(default_factory=URLDataSourceConfig)


class StorageConfig(BaseModel, RenderableConfigMixin):
enabled: bool = True
database_dsn: str = Field(default_factory=lambda: f"sqlite:///{_local_storage() / 'state.db'}")
file_storage_path: UPath = Field(default_factory=lambda: UPath(_local_storage() / "files"))
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
files: FileStorageConfig = Field(default_factory=FileStorageConfig)
Comment thread
pmeier marked this conversation as resolved.


class DynamicAgentConfig(BaseModel, RenderableConfigMixin):
Expand Down
9 changes: 6 additions & 3 deletions src/_ravnar/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import AsyncIterator, Awaitable, Callable, Collection
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from math import ceil
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

from fastapi import HTTPException, status
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
Expand All @@ -23,14 +23,17 @@
from .observability import traced
from .utils import as_async_context_manager, as_awaitable, now

if TYPE_CHECKING:
from _ravnar.config import DatabaseConfig


class SessionFactoryParams(TypedDict):
expire_on_commit: bool


class Database(SetupTeardownMixin):
def __init__(self, url: str) -> None:
url = make_url(url)
def __init__(self, config: DatabaseConfig) -> None:
url = make_url(config.dsn)

if url.drivername.startswith("sqlite") and (url.database is None or url.database == ":memory:"):
# See https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#using-a-memory-database-in-multiple-threads
Expand Down
116 changes: 84 additions & 32 deletions src/_ravnar/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import base64
import dataclasses
import mimetypes
import urllib.parse
import uuid
from datetime import datetime
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Annotated, Any, Self

import ag_ui.core
Expand All @@ -16,9 +17,10 @@

from _ravnar import orm, schema
from _ravnar.observability import traced
from _ravnar.utils import as_awaitable
from _ravnar.utils import as_awaitable, normalize_hostname

if TYPE_CHECKING:
from _ravnar.config import FileStorageConfig
from _ravnar.database import Database


Expand Down Expand Up @@ -98,23 +100,25 @@ class WrappedMetadata(schema.BaseModel):


class FileHandler:
def __init__(self, *, root: UPath, database: Database) -> None:
self._storage = _Storage(root)
def __init__(self, *, config: FileStorageConfig, database: Database) -> None:
self._config = config
self._storage = _Storage(config.path)
self._database = database

self._extractors = {
"data": self._extract_data,
"url": self._extract_url,
"custom": self._extract_custom,
}

@traced
async def add(self, file_input_content: FileInputContent, *, user_id: str) -> tuple[orm.File, bytes]:
source_type = file_input_content.source.type
if source_type not in self._extractors:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file source type")

data = await self._extractors[source_type](file_input_content)
try:
extractor = {
"data": self._extract_data,
"url": self._extract_url,
}[source_type]
except KeyError:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unsupported file source type"
) from None

data = await extractor(file_input_content)
file = orm.File(
user_id=user_id,
type=file_input_content.type,
Expand Down Expand Up @@ -152,24 +156,23 @@ async def _extract_data(file_input_content: FileInputContent) -> _FileData:
mime_type=file_input_content.source.mime_type,
)

@staticmethod
async def _extract_url(file_input_content: FileInputContent) -> _FileData:
async def _extract_url(self, file_input_content: FileInputContent) -> _FileData:
assert isinstance(file_input_content.source, ag_ui.core.InputContentUrlSource)

url = file_input_content.source.value
if not self._config.url_data_source.enabled:
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL file source is not enabled")

mime_type = file_input_content.source.mime_type
tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("FileHandler.fetch_url"):
async with httpx.AsyncClient(follow_redirects=True) as client:
response = await client.get(url)
if not response.is_success:
span = trace.get_current_span()
exc = HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL")
span.record_exception(exc)
span.set_status(trace.StatusCode.ERROR, description="Failed to fetch file from URL")
raise exc
content = response.content
content_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower()

response = await self._fetch_url(
file_input_content.source.value,
timeout=self._config.url_data_source.timeout,
allowlist=self._config.url_data_source.allowlist,
)

url = str(response.request.url)
content = response.content
content_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower()

if not mime_type:
mime_type = content_type
Expand All @@ -181,10 +184,59 @@ async def _extract_url(file_input_content: FileInputContent) -> _FileData:
return _FileData(content=content, mime_type=mime_type, source_data={"url": url})

@staticmethod
async def _extract_custom(file_input_content: FileInputContent) -> _FileData:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Custom file source type is not supported"
@traced(name="FileHandler.fetch_url")
async def _fetch_url(
url: str,
*,
timeout: timedelta, # noqa: ASYNC109
allowlist: list[str],
max_redirects: int = 20,
) -> httpx.Response:
redirect_chain: list[str] = []
failure_exception = HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch file from URL"
)
async with httpx.AsyncClient(follow_redirects=False, timeout=timeout.total_seconds()) as client:
for _ in range(max_redirects):
response = await client.get(FileHandler._validate_url(url, allowlist=allowlist))
next_request = response.next_request
if next_request is not None:
url = str(next_request.url)
redirect_chain.append(url)
continue

if not response.is_success:
raise failure_exception

span = trace.get_current_span()
span.set_attribute("ssrf.redirect_chain", redirect_chain)
span.set_attribute("ssrf.redirect_count", len(redirect_chain))

return response

raise failure_exception

@staticmethod
def _validate_url(url: str, *, allowlist: list[str]) -> str:
failure_exception = HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed")

parts = urllib.parse.urlsplit(url)
if not parts.hostname:
raise failure_exception

try:
normalized_hostname = normalize_hostname(parts.hostname)
except Exception as exc:
raise failure_exception from exc

if "*" in allowlist:
return url

for entry in allowlist:
if normalized_hostname == entry or normalized_hostname.endswith("." + entry):
return url

raise failure_exception

@traced
async def get(self, id: uuid.UUID, *, user_id: str) -> orm.File:
Expand Down
4 changes: 4 additions & 0 deletions src/_ravnar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,7 @@ def call(v: Any) -> Any:
return v

return self.cls_or_fn(**{k: call(v) for k, v in self.params.items()})


def normalize_hostname(host: str) -> str:
return host.encode("idna").decode("ascii").lower()
35 changes: 31 additions & 4 deletions tests/api/test_files.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import base64
import mimetypes
from urllib.parse import urlparse

import ag_ui.core
import compyre
import pydantic
import pytest
import pytest_httpserver.httpserver

from _ravnar.config import BaseConfig
from _ravnar.file_storage import MIME_TYPE, DataSourceValue, FileInputContent


Expand Down Expand Up @@ -44,11 +46,36 @@ def test_e2e_data_source(self, app_client, mime_type, metadata):
assert response.content == content
assert response.headers.get("Content-Type") == mime_type

@pytest.fixture
def url_app_client(self, httpserver, request):
"""Create a test client with URL source enabled and the test server's hostname allowlisted."""
from tests.utils import TestClient

parsed = urlparse(httpserver.url_for("/"))
hostname = parsed.hostname or "localhost"
config = BaseConfig.model_validate(
{
"security": {
"authenticator": "tests.utils.HeaderAuthenticator",
},
"storage": {
"files": {
"url_data_source": {
"enabled": True,
"allowlist": [hostname],
},
},
},
}
)
with TestClient.from_config(config) as client:
yield client

@pytest.mark.parametrize("mime_type", [None, "image/jpeg", "application/octet-stream"])
@pytest.mark.parametrize("source_content_type", [None, "image/png"])
@pytest.mark.parametrize("metadata", [None, "metadata", {"foo": "bar"}])
@pytest.mark.parametrize("endpoint", ["/image.jpg", "/file"])
def test_e2e_url_source(self, app_client, httpserver, mime_type, source_content_type, metadata, endpoint):
def test_e2e_url_source(self, url_app_client, httpserver, mime_type, source_content_type, metadata, endpoint):
content = b"content"

response_cls = pytest_httpserver.httpserver.Response
Expand All @@ -62,7 +89,7 @@ def test_e2e_url_source(self, app_client, httpserver, mime_type, source_content_
mime_type or source_content_type or mimetypes.guess_type(url, strict=False)[0] or "application/octet-stream"
)

response = app_client.post(
response = url_app_client.post(
"/api/files",
json=ag_ui.core.ImageInputContent(
source=ag_ui.core.InputContentUrlSource(value=url, mime_type=mime_type), metadata=metadata
Expand All @@ -81,10 +108,10 @@ def test_e2e_url_source(self, app_client, httpserver, mime_type, source_content_
file_id = value.file_id

expected = file_input_content
response = app_client.get(f"/api/files/{file_id}").raise_for_status()
response = url_app_client.get(f"/api/files/{file_id}").raise_for_status()
actual = pydantic.TypeAdapter(FileInputContent).validate_json(response.content)
compyre.assert_equal(actual, expected)

response = app_client.get(f"/api/files/{file_id}/content").raise_for_status()
response = url_app_client.get(f"/api/files/{file_id}/content").raise_for_status()
assert response.content == content
assert response.headers.get("Content-Type") == expected_mime_type
Loading
Loading