Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/_ravnar/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@ class DatabaseConfig(BaseModel, RenderableConfigMixin):

class URLDataSourceConfig(BaseModel, RenderableConfigMixin):
enabled: bool = False
allowlist: Allowlist = Field(default_factory=list)
allowed_hostnames: Allowlist = Field(default_factory=list)
timeout: timedelta = timedelta(seconds=30)

@field_validator("allowlist", mode="after")
@field_validator("allowed_hostnames", mode="after")
@classmethod
def _normalize_allowlist_entries(cls, allowlist: list[str]) -> list[str]:
def _normalize_hostnames(cls, allowlist: list[str]) -> list[str]:
if "*" in allowlist:
return allowlist

return [normalize_hostname(entry) for entry in allowlist]
return [normalize_hostname(hostname) for hostname in allowlist]


class FileStorageConfig(BaseModel, RenderableConfigMixin):
Expand Down
2 changes: 1 addition & 1 deletion src/_ravnar/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def _template_render_validation_handler(request: Request, exc: RequestVali
error=str(original.__cause__),
)
return JSONResponse(
status_code=400,
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": original.message},
)
return await request_validation_exception_handler(request, exc)
Expand Down
12 changes: 6 additions & 6 deletions src/_ravnar/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async def _extract_url(self, file_input_content: FileInputContent) -> _FileData:
response = await self._fetch_url(
file_input_content.source.value,
timeout=self._config.url_data_source.timeout,
allowlist=self._config.url_data_source.allowlist,
allowed_hostnames=self._config.url_data_source.allowed_hostnames,
)

url = str(response.request.url)
Expand All @@ -189,7 +189,7 @@ async def _fetch_url(
url: str,
*,
timeout: timedelta, # noqa: ASYNC109
allowlist: list[str],
allowed_hostnames: list[str],
max_redirects: int = 20,
) -> httpx.Response:
redirect_chain: list[str] = []
Expand All @@ -198,7 +198,7 @@ async def _fetch_url(
)
async with httpx.AsyncClient(follow_redirects=False, timeout=timeout.total_seconds()) as client:
for _ in range(max_redirects):
response = await client.get(FileHandler._validate_url(url, allowlist=allowlist))
response = await client.get(FileHandler._validate_url(url, allowed_hostnames=allowed_hostnames))
next_request = response.next_request
if next_request is not None:
url = str(next_request.url)
Expand All @@ -217,7 +217,7 @@ async def _fetch_url(
raise failure_exception

@staticmethod
def _validate_url(url: str, *, allowlist: list[str]) -> str:
def _validate_url(url: str, *, allowed_hostnames: list[str]) -> str:
failure_exception = HTTPException(status.HTTP_400_BAD_REQUEST, detail="URL fetch not allowed")

parts = urllib.parse.urlsplit(url)
Expand All @@ -229,10 +229,10 @@ def _validate_url(url: str, *, allowlist: list[str]) -> str:
except Exception as exc:
raise failure_exception from exc

if "*" in allowlist:
if "*" in allowed_hostnames:
return url

for entry in allowlist:
for entry in allowed_hostnames:
if normalized_hostname == entry or normalized_hostname.endswith("." + entry):
return url

Expand Down
2 changes: 1 addition & 1 deletion tests/api/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def url_app_client(self, httpserver, request):
"files": {
"url_data_source": {
"enabled": True,
"allowlist": [hostname],
"allowed_hostnames": [hostname],
},
},
},
Expand Down
29 changes: 28 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
import yaml
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, YamlConfigSettingsSource

from _ravnar.config import AgentConfig, BaseConfig, Config, DynamicAgentConfig, ImportStringWithParams
from _ravnar.config import (
AgentConfig,
BaseConfig,
Config,
DynamicAgentConfig,
ImportStringWithParams,
URLDataSourceConfig,
)
from tests.utils import MockAgent


Expand Down Expand Up @@ -286,3 +293,23 @@ def test_wildcard_only(self):
def test_wildcard_with_other_entries_raises(self, matches="Wildcard"):
with pytest.raises(pydantic.ValidationError):
DynamicAgentConfig(enabled=True, allowed_env_vars=["*", "HOME"])


class TestURLDataSourceConfig:
def test_allowlist_normalization(self) -> None:
config = URLDataSourceConfig(allowed_hostnames=["GITHUB.COM", "München.example.com"])
assert config.allowed_hostnames == ["github.com", "xn--mnchen-3ya.example.com"]

def test_wildcard_preserved(self) -> None:
config = URLDataSourceConfig(allowed_hostnames=["*"])
assert config.allowed_hostnames == ["*"]

def test_wildcard_with_others_blocked(self) -> None:
with pytest.raises(pydantic.ValidationError) as exc_info:
URLDataSourceConfig(allowed_hostnames=["*", "example.com"])
assert "must be the sole" in str(exc_info.value)

def test_invalid_allowlist_entry_raises(self) -> None:
# Double dot creates an empty label which is invalid in IDNA
with pytest.raises(pydantic.ValidationError):
URLDataSourceConfig(allowed_hostnames=["example..com"])
195 changes: 68 additions & 127 deletions tests/test_ssrf.py
Original file line number Diff line number Diff line change
@@ -1,150 +1,91 @@
from __future__ import annotations

import urllib.parse
from datetime import timedelta

import pydantic
import pytest
from fastapi import HTTPException, status

from _ravnar.config import URLDataSourceConfig, normalize_hostname
from _ravnar.file_storage import FileHandler
from _ravnar.utils import normalize_hostname


class TestNormalizeHostname:
def test_ascii_lowercasing(self) -> None:
assert normalize_hostname("GITHUB.COM") == "github.com"
assert normalize_hostname("Example.COM") == "example.com"

def test_unicode_to_punycode(self) -> None:
assert normalize_hostname("München.example.com") == "xn--mnchen-3ya.example.com"

def test_punycode_idempotent(self) -> None:
result = normalize_hostname("xn--mnchen-3ya.example.com")
assert result == "xn--mnchen-3ya.example.com"
@pytest.mark.parametrize(
("hostname", "expected"),
[
("GITHUB.COM", "github.com"),
("Example.COM", "example.com"),
("München.example.com", "xn--mnchen-3ya.example.com"),
("xn--mnchen-3ya.example.com", "xn--mnchen-3ya.example.com"),
("93.184.216.34", "93.184.216.34"),
("github.com.", "github.com."),
],
)
def test_normalize(self, hostname: str, expected: str) -> None:
assert normalize_hostname(hostname) == expected

def test_invalid_idna_raises_value_error(self) -> None:
# Double dot creates an empty label which is invalid in IDNA
with pytest.raises(ValueError):
normalize_hostname("example..com")

def test_ipv4_passthrough(self) -> None:
assert normalize_hostname("93.184.216.34") == "93.184.216.34"

def test_trailing_dot_preserved(self) -> None:
assert normalize_hostname("github.com.") == "github.com."


class TestURLDataSourceConfig:
def test_allowlist_normalization(self) -> None:
config = URLDataSourceConfig(allowlist=["GITHUB.COM", "München.example.com"])
assert config.allowlist == ["github.com", "xn--mnchen-3ya.example.com"]

def test_wildcard_preserved(self) -> None:
config = URLDataSourceConfig(allowlist=["*"])
assert config.allowlist == ["*"]

def test_wildcard_with_others_blocked(self) -> None:
with pytest.raises(pydantic.ValidationError) as exc_info:
URLDataSourceConfig(allowlist=["*", "example.com"])
assert "must be the sole" in str(exc_info.value)

def test_invalid_allowlist_entry_raises(self) -> None:
# Double dot creates an empty label which is invalid in IDNA
with pytest.raises(pydantic.ValidationError):
URLDataSourceConfig(allowlist=["example..com"])

def test_timeout_default(self) -> None:
config = URLDataSourceConfig()
assert config.timeout == timedelta(seconds=30)

def test_enabled_default(self) -> None:
config = URLDataSourceConfig()
assert config.enabled is False


class TestValidateURL:
"""Tests for FileHandler._validate_url (a sync @staticmethod)."""

def test_not_enabled(self) -> None:
"""URL source not enabled is checked in _extract_url, not _validate_url.

_validate_url only validates the hostname against the allowlist.
When allowlist is empty, all URLs are rejected.
"""
allowlist: list[str] = []
with pytest.raises(HTTPException) as exc_info:
FileHandler._validate_url("http://example.com/file", allowlist=allowlist)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

def test_exact_match(self) -> None:
result = FileHandler._validate_url("http://example.com/file", allowlist=["example.com"])
assert result == "http://example.com/file"

def test_subdomain_match(self) -> None:
result = FileHandler._validate_url("http://sub.example.com/file", allowlist=["example.com"])
assert result == "http://sub.example.com/file"

def test_case_insensitive_match(self) -> None:
# Allowlist entry is already normalized (lowercased) at config load time
result = FileHandler._validate_url("http://example.com/file", allowlist=["example.com"])
assert result == "http://example.com/file"

def test_non_match(self) -> None:
@pytest.mark.parametrize(
("url", "allowed_hostnames"),
[
pytest.param("http://example.com/file", ["example.com"], id="exact_match"),
pytest.param("http://sub.example.com/file", ["example.com"], id="subdomain_match"),
pytest.param("http://example.com/file", ["example.com"], id="case_insensitive_match"),
pytest.param("http://evil.com/file", ["*"], id="wildcard_allows_all"),
pytest.param("http://169.254.169.254/latest/meta-data/", ["*"], id="wildcard_allows_internal"),
pytest.param("http://user:pass@example.com/file", ["example.com"], id="url_with_userinfo"),
pytest.param("http://example.com:8080/file", ["example.com"], id="url_with_non_standard_port"),
pytest.param("http://93.184.216.34/file", ["93.184.216.34"], id="ip_literal_in_allowlist"),
pytest.param(
"http://MÜNCHEN.example.com/file",
["xn--mnchen-3ya.example.com"],
id="idn_hostname_matching_idn_entry",
),
pytest.param(
"http://MÜNCHEN.example.com/file",
["xn--mnchen-3ya.example.com"],
id="idn_hostname_matching_punycode_entry",
),
],
)
def test_allowed(self, url: str, allowed_hostnames: list[str]) -> None:
assert FileHandler._validate_url(url, allowed_hostnames=allowed_hostnames) == url

@pytest.mark.parametrize(
("url", "allowed_hostnames"),
[
pytest.param(
"http://example.com/file",
[],
id="empty_allowlist",
),
pytest.param("http://evil.com/file", ["example.com"], id="non_match"),
pytest.param(
"http://example.com./file",
["example.com"],
id="hostname_trailing_dot_not_matching",
),
pytest.param(
"http://93.184.216.34/file",
["example.com"],
id="ip_literal_not_in_allowlist",
),
pytest.param("file:///etc/passwd", ["example.com"], id="url_with_no_hostname"),
],
)
def test_blocked(self, url: str, allowed_hostnames: list[str]) -> None:
with pytest.raises(HTTPException) as exc_info:
FileHandler._validate_url("http://evil.com/file", allowlist=["example.com"])
FileHandler._validate_url(url, allowed_hostnames=allowed_hostnames)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

def test_wildcard_allows_all(self) -> None:
result = FileHandler._validate_url("http://evil.com/file", allowlist=["*"])
assert result == "http://evil.com/file"

def test_wildcard_allows_internal(self) -> None:
result = FileHandler._validate_url("http://169.254.169.254/latest/meta-data/", allowlist=["*"])
assert result == "http://169.254.169.254/latest/meta-data/"

def test_url_with_userinfo(self) -> None:
result = FileHandler._validate_url("http://user:pass@example.com/file", allowlist=["example.com"])
assert result == "http://user:pass@example.com/file"

def test_url_with_non_standard_port(self) -> None:
result = FileHandler._validate_url("http://example.com:8080/file", allowlist=["example.com"])
assert result == "http://example.com:8080/file"

def test_hostname_trailing_dot_not_matching(self) -> None:
with pytest.raises(HTTPException) as exc_info:
FileHandler._validate_url("http://example.com./file", allowlist=["example.com"])
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

def test_ip_literal_not_in_allowlist(self) -> None:
with pytest.raises(HTTPException) as exc_info:
FileHandler._validate_url("http://93.184.216.34/file", allowlist=["example.com"])
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

def test_ip_literal_in_allowlist(self) -> None:
result = FileHandler._validate_url("http://93.184.216.34/file", allowlist=["93.184.216.34"])
assert result == "http://93.184.216.34/file"

def test_url_with_no_hostname(self) -> None:
with pytest.raises(HTTPException) as exc_info:
FileHandler._validate_url("file:///etc/passwd", allowlist=["example.com"])
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST

def test_idn_hostname_matching_idn_entry(self) -> None:
# Config normalizes the IDN allowlist entry at load time
result = FileHandler._validate_url(
"http://MÜNCHEN.example.com/file",
allowlist=["xn--mnchen-3ya.example.com"],
)
assert result == "http://MÜNCHEN.example.com/file"

def test_idn_hostname_matching_punycode_entry(self) -> None:
result = FileHandler._validate_url(
"http://MÜNCHEN.example.com/file",
allowlist=["xn--mnchen-3ya.example.com"],
)
assert result == "http://MÜNCHEN.example.com/file"


class TestValidateURLIntegration:
"""Integration tests via HTTP client against a running ravnar instance."""
Expand All @@ -164,7 +105,7 @@ def app_client(self, httpserver):
"files": {
"url_data_source": {
"enabled": True,
"allowlist": [hostname],
"allowed_hostnames": [hostname],
},
},
},
Expand All @@ -184,7 +125,7 @@ def test_upload_from_allowlisted_url_succeeds(self, app_client, httpserver):
"source": {"type": "url", "value": url},
},
)
assert response.status_code == 200
assert response.status_code == status.HTTP_200_OK

def test_upload_from_non_allowlisted_url_fails(self, app_client, httpserver):
response = app_client.post(
Expand All @@ -194,7 +135,7 @@ def test_upload_from_non_allowlisted_url_fails(self, app_client, httpserver):
"source": {"type": "url", "value": "http://evil.com/malware"},
},
)
assert response.status_code == 400
assert response.status_code == status.HTTP_400_BAD_REQUEST

def test_url_source_disabled_fails(self, app_client, httpserver):
from _ravnar.config import BaseConfig
Expand Down Expand Up @@ -222,4 +163,4 @@ def test_url_source_disabled_fails(self, app_client, httpserver):
"source": {"type": "url", "value": "http://example.com/file"},
},
)
assert response.status_code == 400
assert response.status_code == status.HTTP_400_BAD_REQUEST
Loading