diff --git a/src/_ravnar/api/agents.py b/src/_ravnar/api/agents.py index 50b5b28..458c13b 100644 --- a/src/_ravnar/api/agents.py +++ b/src/_ravnar/api/agents.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable from typing import TYPE_CHECKING, Annotated, Any import ag_ui.core @@ -9,6 +9,7 @@ from _ravnar import schema from _ravnar.security import User +from _ravnar.utils import ImportStringWithParams if TYPE_CHECKING: from _ravnar.core import AgentHandler @@ -49,7 +50,13 @@ def _make_dynamic_agents_router( "Can be checked with [`GET /api/config`](#/API/get_config_api_config_get)." ) - @router.post("", description=description) + async def set_dynamic_agent_render_template_context() -> AsyncIterator[None]: + with ImportStringWithParams.explicit_render_template_context( + agent_handler.get_dynamic_render_template_context() + ): + yield + + @router.post("", description=description, dependencies=[Depends(set_dynamic_agent_render_template_context)]) async def register_agent( data: schema.RegisterAgentData, user: User = Depends(authorized_user_with("agents:write")), # noqa: B008 diff --git a/src/_ravnar/config.py b/src/_ravnar/config.py index 19813ad..37b8808 100644 --- a/src/_ravnar/config.py +++ b/src/_ravnar/config.py @@ -3,15 +3,10 @@ import os import sys from pathlib import Path -from typing import Any, Self, TypeVar +from typing import Annotated, Any, Self, TypeVar import l2sl -from pydantic import ( - BaseModel, - Field, - field_validator, - model_validator, -) +from pydantic import AfterValidator, BaseModel, Field, field_validator, model_validator from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, YamlConfigSettingsSource from upath import UPath @@ -27,24 +22,33 @@ def interactive_session() -> bool: return sys.stdout.isatty() -class RenderableMixin: +def _validate_allowlist_wildcard(allowlist: list[str]) -> list[str]: + if "*" in 
allowlist and len(allowlist) > 1: + raise ValueError('Wildcard "*" must be the sole allowlist entry. It cannot be combined with other entries.') + return allowlist + + +Allowlist = Annotated[list[str], AfterValidator(_validate_allowlist_wildcard)] + + +class RenderableConfigMixin: @field_validator("*", mode="before") @classmethod def _render_templates(cls, data: Any) -> Any: - return render_template(data) + return render_template(data, context=dict(os.environ)) -class LoggingConfig(BaseModel, RenderableMixin): +class LoggingConfig(BaseModel, RenderableConfigMixin): level: l2sl.LogLevel = l2sl.LogLevel("info") as_json: bool = Field(default_factory=lambda: not interactive_session()) -class TracingConfig(BaseModel, RenderableMixin): +class TracingConfig(BaseModel, RenderableConfigMixin): endpoint: str | None = None as_logs: bool = Field(default_factory=lambda values: interactive_session() and values["endpoint"] is None) -class ServerConfig(BaseModel, RenderableMixin): +class ServerConfig(BaseModel, RenderableConfigMixin): hostname: str = "127.0.0.1" port: int = 8000 proxy_headers: bool = False @@ -54,12 +58,12 @@ class ServerConfig(BaseModel, RenderableMixin): tracing: TracingConfig = Field(default_factory=TracingConfig) -class CORSConfig(BaseModel, RenderableMixin): - allowed_origins: list[str] = Field(default_factory=lambda: ["*"]) - allowed_headers: list[str] = Field(default_factory=list) +class CORSConfig(BaseModel, RenderableConfigMixin): + allowed_origins: Allowlist = Field(default_factory=lambda: ["*"]) + allowed_headers: Allowlist = Field(default_factory=list) -class SecurityConfig(BaseModel, RenderableMixin): +class SecurityConfig(BaseModel, RenderableConfigMixin): authenticator: ImportStringWithParams[Authenticator] | None = None cors: CORSConfig = Field(default_factory=CORSConfig) @@ -73,17 +77,18 @@ def _local_storage() -> Path: return p -class StorageConfig(BaseModel, RenderableMixin): +class StorageConfig(BaseModel, RenderableConfigMixin): enabled: bool 
= True database_dsn: str = Field(default_factory=lambda: f"sqlite:///{_local_storage() / 'state.db'}") file_storage_path: UPath = Field(default_factory=lambda: UPath(_local_storage() / "files")) -class DynamicAgentConfig(BaseModel, RenderableMixin): +class DynamicAgentConfig(BaseModel, RenderableConfigMixin): enabled: bool = False + allowed_env_vars: Allowlist = Field(default_factory=list) -class AgentConfig(BaseModel, RenderableMixin): +class AgentConfig(BaseModel, RenderableConfigMixin): static: dict[str, ImportStringWithParams[Agent]] = Field( default_factory=lambda: { # type: ignore[arg-type] "default": ImportStringWithParams(cls_or_fn=DefaultAgent), @@ -98,7 +103,7 @@ def _ensure_not_agentless(self) -> Self: return self -class BaseConfig(BaseSettings, RenderableMixin): +class BaseConfig(BaseSettings, RenderableConfigMixin): server: ServerConfig = Field(default_factory=ServerConfig) security: SecurityConfig = Field(default_factory=SecurityConfig) storage: StorageConfig = Field(default_factory=StorageConfig) diff --git a/src/_ravnar/core.py b/src/_ravnar/core.py index 7ffb232..0855803 100644 --- a/src/_ravnar/core.py +++ b/src/_ravnar/core.py @@ -1,15 +1,19 @@ from __future__ import annotations import asyncio +import os from collections.abc import AsyncIterator, Awaitable, Callable from typing import TYPE_CHECKING, cast import ag_ui.core import ag_ui.encoder import fastsse -from fastapi import FastAPI, HTTPException, status +import structlog +from fastapi import FastAPI, HTTPException, Request, status +from fastapi.exception_handlers import request_validation_exception_handler +from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import RedirectResponse, Response +from fastapi.responses import JSONResponse, RedirectResponse, Response from opentelemetry import trace from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor @@ -18,7 +22,7 @@ from _ravnar.mixin import 
SetupTeardownMixin from _ravnar.observability import configure_logging, configure_tracing from _ravnar.security import SecurityHeadersMiddleware, make_authorized_user_factory -from _ravnar.utils import as_awaitable +from _ravnar.utils import TemplateRenderError, as_awaitable from .api import make_router as make_api_router from .config import AgentConfig, BaseConfig, Config @@ -74,6 +78,24 @@ async def health() -> Response: async def version() -> str: return __version__ + @app.exception_handler(RequestValidationError) + async def _template_render_validation_handler(request: Request, exc: RequestValidationError) -> JSONResponse: + for error in exc.errors(): + if error.get("type") == "value_error": + original = error.get("ctx", {}).get("error") + if isinstance(original, TemplateRenderError): + structlog.get_logger().warning( + "Template rendering blocked", + template=original.template, + reason=original.reason, + error=str(original.__cause__), + ) + return JSONResponse( + status_code=400, + content={"detail": original.message}, + ) + return await request_validation_exception_handler(request, exc) + app.include_router( make_api_router( storage_config=config.storage, @@ -111,6 +133,14 @@ def __init__(self, agent_config: AgentConfig) -> None: self._dynamic_agents: dict[str, Agent] = {} self._event_encoder = ag_ui.encoder.EventEncoder() self._dynamic_enabled = agent_config.dynamic.enabled + self._dynamic_allowed_env_vars = agent_config.dynamic.allowed_env_vars + + def get_dynamic_render_template_context(self) -> dict[str, str]: + ctx = dict(os.environ) + if "*" in self._dynamic_allowed_env_vars: + return ctx + + return {k: v for k, v in ctx.items() if k in self._dynamic_allowed_env_vars} @staticmethod async def _setup_agent(agent: Agent) -> None: diff --git a/src/_ravnar/utils.py b/src/_ravnar/utils.py index 17ca2be..b8c0f2a 100644 --- a/src/_ravnar/utils.py +++ b/src/_ravnar/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import contextvars 
import functools import inspect import json @@ -8,9 +9,10 @@ import re from collections.abc import AsyncIterator, Awaitable, Callable, Iterator from datetime import UTC, datetime -from typing import Any, Generic, TypeVar, cast, get_type_hints +from typing import Any, ClassVar, Generic, TypeVar, cast, get_type_hints import jinja2 +import jinja2.sandbox from pydantic import ( BaseModel, Field, @@ -107,17 +109,44 @@ def now() -> datetime: return datetime.now(tz=UTC) -def render_template(s: Any) -> Any: +class TemplateRenderError(ValueError): + def __init__(self, *, template: str, reason: str, message: str) -> None: + self.template = template + self.reason = reason + self.message = message + super().__init__(message) + + +def render_template(s: Any, context: dict[str, str]) -> Any: if isinstance(s, str): - return jinja2.Environment().from_string(s).render(**os.environ) + env = jinja2.sandbox.SandboxedEnvironment(undefined=jinja2.StrictUndefined) + return env.from_string(s).render(**context) if isinstance(s, dict): - return {render_template(k): render_template(v) for k, v in s.items()} + return {render_template(k, context): render_template(v, context) for k, v in s.items()} if isinstance(s, list): - return [render_template(v) for v in s] + return [render_template(v, context) for v in s] return s class ImportStringWithParams(BaseModel, Generic[T]): + _render_template_context: ClassVar[contextvars.ContextVar[dict[str, str] | None]] = contextvars.ContextVar( + "render_template_context", default=None + ) + + @classmethod + def explicit_render_template_context(cls, ctx: dict[str, str]) -> contextlib.AbstractContextManager[dict[str, str]]: + """Scope template rendering to ``ctx`` for the duration of the ``with`` block.""" + + @contextlib.contextmanager + def cm() -> Iterator[dict[str, str]]: + token = cls._render_template_context.set(ctx) + try: + yield ctx + finally: + cls._render_template_context.reset(token) + + return cm() + cls_or_fn: ImportString[type[T] | Callable[..., T]] params: dict[str, Any] = Field(default_factory=dict) @@ -172,14 +201,29 @@ def validate(v: Any, loc: tuple[str 
| int, ...]) -> Any: @classmethod def _render_field_templates(cls, f: Any) -> Any: if isinstance(f, str): - return render_template(f) + return cls._render_template(f) return f @field_validator("params", mode="after") @classmethod def _render_param_items(cls, params: dict[str, Any]) -> dict[str, Any]: - return {render_template(k): render_template(v) for k, v in params.items()} + return {cls._render_template(k): cls._render_template(v) for k, v in params.items()} + + @classmethod + def _render_template(cls, s: Any) -> Any: + explicit_ctx = cls._render_template_context.get() + ctx = explicit_ctx if explicit_ctx is not None else dict(os.environ) + try: + return render_template(s, ctx) + except (jinja2.exceptions.SecurityError, jinja2.exceptions.UndefinedError) as exc: + if explicit_ctx is not None: + raise TemplateRenderError( + template=str(exc), + reason=type(exc).__name__, + message="Invalid configuration", + ) from exc + raise @model_serializer(mode="wrap") def _serialize(self, nxt: SerializerFunctionWrapHandler) -> Any: diff --git a/tests/api/test_agents.py b/tests/api/test_agents.py index 74ee4aa..cef4bb6 100644 --- a/tests/api/test_agents.py +++ b/tests/api/test_agents.py @@ -1,3 +1,5 @@ +import os + import compyre import pydantic import pytest @@ -6,7 +8,7 @@ import ravnar.agents from _ravnar import schema from _ravnar.config import BaseConfig -from tests.utils import HeaderAuthenticator, make_app_client +from tests.utils import HeaderAuthenticator, MockAgent, make_app_client def make_config(*, dynamic_enabled=False): @@ -203,3 +205,79 @@ def test_unregister_twice_returns_404(self, client): response = client.delete("/api/agents/delete-twice") assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_register_agent_with_env_var_default_deny(self, client): + response = client.post( + "/api/agents", + json={ + "id": "env-agent", + "agent": { + "cls_or_fn": f"{MockAgent.__module__}.{MockAgent.__name__}", + "params": {"param": "{{ HOME }}"}, + }, + }, + 
) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["detail"] == "Invalid configuration" + + def test_register_agent_with_sandbox_escape(self, client): + response = client.post( + "/api/agents", + json={ + "id": "sandbox-agent", + "agent": { + "cls_or_fn": f"{MockAgent.__module__}.{MockAgent.__name__}", + "params": {"param": "{{ ''.__class__ }}"}, + }, + }, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["detail"] == "Invalid configuration" + + +class TestDynamicAgentsWithAllowedEnvVars: + @pytest.fixture + def client(self, mocker): + mocker.patch.dict(os.environ, {"ALLOWED_VAR": "allowed_value", "DENIED_VAR": "denied_value"}) + with make_app_client( + BaseConfig.model_validate( + { + "agents": { + "static": { + "default": {"cls_or_fn": "ravnar.agents.DefaultAgent"}, + }, + "dynamic": {"enabled": True, "allowed_env_vars": ["ALLOWED_VAR"]}, + }, + } + ) + ) as c: + yield c + + def test_register_agent_with_allowed_env_var(self, client): + response = client.post( + "/api/agents", + json={ + "id": "allowed-env-agent", + "agent": { + "cls_or_fn": f"{MockAgent.__module__}.{MockAgent.__name__}", + "params": {"param": "{{ ALLOWED_VAR }}"}, + }, + }, + ) + assert response.status_code == status.HTTP_200_OK + info = schema.AgentInfo.model_validate_json(response.content) + assert info.id == "allowed-env-agent" + + def test_register_agent_with_denied_env_var(self, client): + response = client.post( + "/api/agents", + json={ + "id": "denied-env-agent", + "agent": { + "cls_or_fn": f"{MockAgent.__module__}.{MockAgent.__name__}", + "params": {"param": "{{ DENIED_VAR }}"}, + }, + }, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["detail"] == "Invalid configuration" diff --git a/tests/test_config.py b/tests/test_config.py index 2ce0d17..1b49eba 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,8 +7,8 @@ import yaml from pydantic_settings import 
BaseSettings, PydanticBaseSettingsSource, YamlConfigSettingsSource -from _ravnar.agents import Agent from _ravnar.config import AgentConfig, BaseConfig, Config, DynamicAgentConfig, ImportStringWithParams +from tests.utils import MockAgent @pytest.fixture() @@ -167,19 +167,10 @@ def test_template_rendering_in_list(mocker, make_test_config, source): assert config.security.cors.allowed_origins == [f"https://{app_domain}"] -class MockAgent(Agent): - def __init__(self, param="unset"): - self.param = param - - async def run(self, input): - raise AssertionError - yield - - @pytest.mark.parametrize("source", ["file", "env", "env_json"]) @pytest.mark.parametrize("input_type", ["plain", "object"]) def test_import_string_with_params(make_test_config, source, input_type): - import_path = f"{__name__}.{MockAgent.__name__}" + import_path = f"{MockAgent.__module__}.{MockAgent.__name__}" default_param = "unset" explicit_param = "sentinel" @@ -211,7 +202,7 @@ def test_import_string_with_params_nested_error_localization(): with pytest.raises(pydantic.ValidationError) as exc_info: ImportStringWithParams.model_validate( { - "cls_or_fn": f"{__name__}.{MockAgent.__name__}", + "cls_or_fn": f"{MockAgent.__module__}.{MockAgent.__name__}", "params": {"param": {"cls_or_fn": "non_existing_module.NonExistingClass"}}, } ) @@ -223,3 +214,75 @@ def test_import_string_with_params_nested_error_localization(): details = ve.errors()[0] assert details["loc"] == ("params", "param", "cls_or_fn") assert "non_existing_module" in details["msg"] + + +class TestImportStringWithParamsRestrictedContext: + def test_allowed_env_var_renders(self, mocker): + var = "ALLOWED_VAR" + value = "allowed_value" + + mocker.patch.dict(os.environ, {var: value}) + with ImportStringWithParams.explicit_render_template_context({var: value}): + result = ImportStringWithParams.model_validate( + { + "cls_or_fn": MockAgent, + "params": {"param": "{{ ALLOWED_VAR }}"}, + } + ) + assert result.params["param"] == value + + def 
test_denied_env_var_raises_template_render_error(self, mocker): + var = "DENIED_VAR" + value = "denied_value" + + mocker.patch.dict(os.environ, {var: value}) + with ( + ImportStringWithParams.explicit_render_template_context({}), + pytest.raises(pydantic.ValidationError, match="Invalid configuration"), + ): + ImportStringWithParams.model_validate( + { + "cls_or_fn": MockAgent, + "params": {"param": "{{ DENIED_VAR }}"}, + } + ) + + def test_security_error_in_restricted_context_raises_template_render_error(self, mocker): + mocker.patch.dict(os.environ, {"SECRET": "secret_value"}) + with ( + ImportStringWithParams.explicit_render_template_context({"SECRET": "secret_value"}), + pytest.raises(pydantic.ValidationError, match="Invalid configuration"), + ): + ImportStringWithParams.model_validate( + { + "cls_or_fn": MockAgent, + "params": {"param": "{{ ''.__class__ }}"}, + } + ) + + def test_no_context_falls_back_to_full_environ(self, mocker): + var = "FULL_VAR" + value = "full_value" + + mocker.patch.dict(os.environ, {var: value}) + result = ImportStringWithParams.model_validate( + { + "cls_or_fn": MockAgent, + "params": {"param": "{{ FULL_VAR }}"}, + } + ) + assert result.params["param"] == value + + +class TestAllowlist: + def test_valid_entries(self): + config = DynamicAgentConfig(enabled=True, allowed_env_vars=["HOME", "USER"]) + assert config.allowed_env_vars == ["HOME", "USER"] + + def test_wildcard_only(self): + config = DynamicAgentConfig(enabled=True, allowed_env_vars=["*"]) + assert config.allowed_env_vars == ["*"] + + def test_wildcard_with_other_entries_raises(self): + with pytest.raises(pydantic.ValidationError, match="Wildcard"): + DynamicAgentConfig(enabled=True, allowed_env_vars=["*", "HOME"]) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b923566 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,52 @@ +import os + +import jinja2 +import pytest + +from _ravnar.utils import render_template + + +class 
TestRenderTemplate: + def test_basic_math(self): + assert render_template("{{ 7 * 7 }}", {}) == "49" + + def test_env_var_access(self, mocker): + mocker.patch.dict(os.environ, {"TEST_HOME": "/home/test"}) + assert render_template("{{ TEST_HOME }}", dict(os.environ)) == "/home/test" + + def test_dunder_access_raises_security_error(self): + with pytest.raises(jinja2.exceptions.SecurityError): + render_template("{{ config.__class__ }}", {"config": "value"}) + + def test_self_dunder_access_raises_security_error(self): + with pytest.raises(jinja2.exceptions.SecurityError): + render_template("{{ ''.__class__.__mro__ }}", {}) + + def test_strict_undefined_raises_on_missing_var(self): + with pytest.raises(jinja2.exceptions.UndefinedError): + render_template("{{ MISSING_VAR }}", {}) + + def test_strict_undefined_default_filter(self): + assert render_template('{{ MISSING_VAR | default("fallback") }}', {}) == "fallback" + + def test_strict_undefined_is_defined_test(self): + assert render_template("{% if MISSING_VAR is defined %}yes{% else %}no{% endif %}", {}) == "no" + + def test_strict_undefined_conditional_access_raises(self): + with pytest.raises(jinja2.exceptions.UndefinedError): + render_template("{{ VAR if VAR else 'x' }}", {}) + + def test_dict_key_rendering(self): + assert render_template("{{ KEY }}", {"KEY": "value"}) == "value" + + def test_nested_dict(self): + result = render_template({"key": "{{ VAL }}"}, {"VAL": "v"}) + assert result == {"key": "v"} + + def test_nested_list(self): + result = render_template(["{{ A }}", "{{ B }}"], {"A": "1", "B": "2"}) + assert result == ["1", "2"] + + def test_non_string_passed_through(self): + assert render_template(42, {}) == 42 + assert render_template(None, {}) is None diff --git a/tests/utils.py b/tests/utils.py index 601e99a..e30d166 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,7 @@ from fastapi.security import APIKeyHeader from fastapi.testclient import TestClient as _TestClient +from _ravnar.agents import 
Agent from _ravnar.config import BaseConfig from _ravnar.core import Ravnar from _ravnar.security import ALL_PERMISSIONS, User @@ -109,3 +110,12 @@ def safe_extract_response_content(response): decoded_content = content.decode() decoded_content = f"\n{json.dumps(json.loads(content), indent=2)}" return decoded_content + + +class MockAgent(Agent): + def __init__(self, param="unset"): + self.param = param + + async def run(self, input): + raise AssertionError + yield