Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/_ravnar/api/agents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from typing import TYPE_CHECKING, Annotated, Any

import ag_ui.core
Expand All @@ -9,6 +9,7 @@

from _ravnar import schema
from _ravnar.security import User
from _ravnar.utils import ImportStringWithParams

if TYPE_CHECKING:
from _ravnar.core import AgentHandler
Expand Down Expand Up @@ -49,7 +50,13 @@ def _make_dynamic_agents_router(
"Can be checked with [`GET /api/config`](#/API/get_config_api_config_get)."
)

@router.post("", description=description)
async def set_dynamic_agent_render_template_context() -> AsyncIterator[None]:
with ImportStringWithParams.explicit_render_template_context(
agent_handler.get_dynamic_render_template_context()
):
yield

@router.post("", description=description, dependencies=[Depends(set_dynamic_agent_render_template_context)])
async def register_agent(
data: schema.RegisterAgentData,
user: User = Depends(authorized_user_with("agents:write")), # noqa: B008
Expand Down
45 changes: 25 additions & 20 deletions src/_ravnar/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,10 @@
import os
import sys
from pathlib import Path
from typing import Any, Self, TypeVar
from typing import Annotated, Any, Self, TypeVar

import l2sl
from pydantic import (
BaseModel,
Field,
field_validator,
model_validator,
)
from pydantic import AfterValidator, BaseModel, Field, field_validator, model_validator
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, YamlConfigSettingsSource
from upath import UPath

Expand All @@ -27,24 +22,33 @@ def interactive_session() -> bool:
return sys.stdout.isatty()


class RenderableMixin:
def _validate_allowlist_wildcard(allowlist: list[str]) -> list[str]:
if "*" in allowlist and len(allowlist) > 1:
raise ValueError('Wildcard "*" must be the sole allowlist entry. It cannot be combined with other entries.')
return allowlist


Allowlist = Annotated[list[str], AfterValidator(_validate_allowlist_wildcard)]


class RenderableConfigMixin:
@field_validator("*", mode="before")
@classmethod
def _render_templates(cls, data: Any) -> Any:
return render_template(data)
return render_template(data, context=dict(os.environ))


class LoggingConfig(BaseModel, RenderableMixin):
class LoggingConfig(BaseModel, RenderableConfigMixin):
level: l2sl.LogLevel = l2sl.LogLevel("info")
as_json: bool = Field(default_factory=lambda: not interactive_session())


class TracingConfig(BaseModel, RenderableMixin):
class TracingConfig(BaseModel, RenderableConfigMixin):
endpoint: str | None = None
as_logs: bool = Field(default_factory=lambda values: interactive_session() and values["endpoint"] is None)


class ServerConfig(BaseModel, RenderableMixin):
class ServerConfig(BaseModel, RenderableConfigMixin):
hostname: str = "127.0.0.1"
port: int = 8000
proxy_headers: bool = False
Expand All @@ -54,12 +58,12 @@ class ServerConfig(BaseModel, RenderableMixin):
tracing: TracingConfig = Field(default_factory=TracingConfig)


class CORSConfig(BaseModel, RenderableMixin):
allowed_origins: list[str] = Field(default_factory=lambda: ["*"])
allowed_headers: list[str] = Field(default_factory=list)
class CORSConfig(BaseModel, RenderableConfigMixin):
allowed_origins: Allowlist = Field(default_factory=lambda: ["*"])
allowed_headers: Allowlist = Field(default_factory=list)


class SecurityConfig(BaseModel, RenderableMixin):
class SecurityConfig(BaseModel, RenderableConfigMixin):
authenticator: ImportStringWithParams[Authenticator] | None = None
cors: CORSConfig = Field(default_factory=CORSConfig)

Expand All @@ -73,17 +77,18 @@ def _local_storage() -> Path:
return p


class StorageConfig(BaseModel, RenderableMixin):
class StorageConfig(BaseModel, RenderableConfigMixin):
enabled: bool = True
database_dsn: str = Field(default_factory=lambda: f"sqlite:///{_local_storage() / 'state.db'}")
file_storage_path: UPath = Field(default_factory=lambda: UPath(_local_storage() / "files"))


class DynamicAgentConfig(BaseModel, RenderableMixin):
class DynamicAgentConfig(BaseModel, RenderableConfigMixin):
enabled: bool = False
allowed_env_vars: Allowlist = Field(default_factory=list)


class AgentConfig(BaseModel, RenderableMixin):
class AgentConfig(BaseModel, RenderableConfigMixin):
static: dict[str, ImportStringWithParams[Agent]] = Field(
default_factory=lambda: { # type: ignore[arg-type]
"default": ImportStringWithParams(cls_or_fn=DefaultAgent),
Expand All @@ -98,7 +103,7 @@ def _ensure_not_agentless(self) -> Self:
return self


class BaseConfig(BaseSettings, RenderableMixin):
class BaseConfig(BaseSettings, RenderableConfigMixin):
server: ServerConfig = Field(default_factory=ServerConfig)
security: SecurityConfig = Field(default_factory=SecurityConfig)
storage: StorageConfig = Field(default_factory=StorageConfig)
Expand Down
36 changes: 33 additions & 3 deletions src/_ravnar/core.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from __future__ import annotations

import asyncio
import os
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import TYPE_CHECKING, cast

import ag_ui.core
import ag_ui.encoder
import fastsse
from fastapi import FastAPI, HTTPException, status
import structlog
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.exception_handlers import request_validation_exception_handler
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, Response
from fastapi.responses import JSONResponse, RedirectResponse, Response
from opentelemetry import trace
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

Expand All @@ -18,7 +22,7 @@
from _ravnar.mixin import SetupTeardownMixin
from _ravnar.observability import configure_logging, configure_tracing
from _ravnar.security import SecurityHeadersMiddleware, make_authorized_user_factory
from _ravnar.utils import as_awaitable
from _ravnar.utils import TemplateRenderError, as_awaitable

from .api import make_router as make_api_router
from .config import AgentConfig, BaseConfig, Config
Expand Down Expand Up @@ -74,6 +78,24 @@ async def health() -> Response:
async def version() -> str:
return __version__

@app.exception_handler(RequestValidationError)
async def _template_render_validation_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
for error in exc.errors():
if error.get("type") == "value_error":
original = error.get("ctx", {}).get("error")
if isinstance(original, TemplateRenderError):
structlog.get_logger().warning(
"Template rendering blocked",
template=original.template,
reason=original.reason,
error=str(original.__cause__),
)
return JSONResponse(
status_code=400,
content={"detail": original.message},
)
return await request_validation_exception_handler(request, exc)

app.include_router(
make_api_router(
storage_config=config.storage,
Expand Down Expand Up @@ -111,6 +133,14 @@ def __init__(self, agent_config: AgentConfig) -> None:
self._dynamic_agents: dict[str, Agent] = {}
self._event_encoder = ag_ui.encoder.EventEncoder()
self._dynamic_enabled = agent_config.dynamic.enabled
self._dynamic_allowed_env_vars = agent_config.dynamic.allowed_env_vars

def get_dynamic_render_template_context(self) -> dict[str, str]:
ctx = dict(os.environ)
if "*" in self._dynamic_allowed_env_vars:
return ctx

return {k: v for k, v in ctx.items() if k in self._dynamic_allowed_env_vars}

@staticmethod
async def _setup_agent(agent: Agent) -> None:
Expand Down
54 changes: 47 additions & 7 deletions src/_ravnar/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

import contextlib
import contextvars
import functools
import inspect
import json
import os
import re
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
from datetime import UTC, datetime
from typing import Any, Generic, TypeVar, cast, get_type_hints
from typing import Any, ClassVar, Generic, TypeVar, cast, get_type_hints

import jinja2
import jinja2.sandbox
from pydantic import (
BaseModel,
Field,
Expand Down Expand Up @@ -107,17 +109,40 @@ def now() -> datetime:
return datetime.now(tz=UTC)


def render_template(s: Any) -> Any:
class TemplateRenderError(ValueError):
def __init__(self, *, template: str, reason: str, message: str) -> None:
self.template = template
self.reason = reason
self.message = message
super().__init__(message)


def render_template(s: Any, context: dict[str, str]) -> Any:
if isinstance(s, str):
return jinja2.Environment().from_string(s).render(**os.environ)
env = jinja2.sandbox.SandboxedEnvironment(undefined=jinja2.StrictUndefined)
return env.from_string(s).render(**context)
if isinstance(s, dict):
return {render_template(k): render_template(v) for k, v in s.items()}
return {render_template(k, context): render_template(v, context) for k, v in s.items()}
if isinstance(s, list):
return [render_template(v) for v in s]
return [render_template(v, context) for v in s]
return s


class ImportStringWithParams(BaseModel, Generic[T]):
_render_template_context: ClassVar[contextvars.ContextVar[dict[str, str] | None]] = contextvars.ContextVar(
"render_template_context", default=None
)

@classmethod
def explicit_render_template_context(cls, ctx: dict[str, str]) -> contextlib.AbstractContextManager[dict[str, str]]:
@contextlib.contextmanager
def cm() -> Iterator[dict[str, str]]:
cls._render_template_context.set(ctx)
yield ctx
cls._render_template_context.set(None)

return cm()

cls_or_fn: ImportString[type[T] | Callable[..., T]]
params: dict[str, Any] = Field(default_factory=dict)

Expand Down Expand Up @@ -172,14 +197,29 @@ def validate(v: Any, loc: tuple[str | int, ...]) -> Any:
@classmethod
def _render_field_templates(cls, f: Any) -> Any:
if isinstance(f, str):
return render_template(f)
return cls._render_template(f)

return f

@field_validator("params", mode="after")
@classmethod
def _render_param_items(cls, params: dict[str, Any]) -> dict[str, Any]:
return {render_template(k): render_template(v) for k, v in params.items()}
return {cls._render_template(k): cls._render_template(v) for k, v in params.items()}

@classmethod
def _render_template(cls, s: Any) -> Any:
explicit_ctx = cls._render_template_context.get()
ctx = explicit_ctx if explicit_ctx is not None else dict(os.environ)
try:
return render_template(s, ctx)
except (jinja2.exceptions.SecurityError, jinja2.exceptions.UndefinedError) as exc:
if explicit_ctx is not None:
raise TemplateRenderError(
template=str(exc),
reason=type(exc).__name__,
message="Invalid configuration",
) from exc
raise

@model_serializer(mode="wrap")
def _serialize(self, nxt: SerializerFunctionWrapHandler) -> Any:
Expand Down
80 changes: 79 additions & 1 deletion tests/api/test_agents.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import compyre
import pydantic
import pytest
Expand All @@ -6,7 +8,7 @@
import ravnar.agents
from _ravnar import schema
from _ravnar.config import BaseConfig
from tests.utils import HeaderAuthenticator, make_app_client
from tests.utils import HeaderAuthenticator, MockAgent, make_app_client


def make_config(*, dynamic_enabled=False):
Expand Down Expand Up @@ -203,3 +205,79 @@ def test_unregister_twice_returns_404(self, client):

response = client.delete("/api/agents/delete-twice")
assert response.status_code == status.HTTP_404_NOT_FOUND

def test_register_agent_with_env_var_default_deny(self, client):
response = client.post(
"/api/agents",
json={
"id": "env-agent",
"agent": {
"cls_or_fn": f"{MockAgent.__module__}.{MockAgent.__name__}",
"params": {"param": "{{ HOME }}"},
},
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["detail"] == "Invalid configuration"

def test_register_agent_with_sandbox_escape(self, client):
response = client.post(
"/api/agents",
json={
"id": "sandbox-agent",
"agent": {
"cls_or_fn": f"{MockAgent.__module__}.{MockAgent.__name__}",
"params": {"param": "{{ ''.__class__ }}"},
},
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["detail"] == "Invalid configuration"


class TestDynamicAgentsWithAllowedEnvVars:
@pytest.fixture
def client(self, mocker):
mocker.patch.dict(os.environ, {"ALLOWED_VAR": "allowed_value", "DENIED_VAR": "denied_value"})
with make_app_client(
BaseConfig.model_validate(
{
"agents": {
"static": {
"default": {"cls_or_fn": "ravnar.agents.DefaultAgent"},
},
"dynamic": {"enabled": True, "allowed_env_vars": ["ALLOWED_VAR"]},
},
}
)
) as c:
yield c

def test_register_agent_with_allowed_env_var(self, client):
response = client.post(
"/api/agents",
json={
"id": "allowed-env-agent",
"agent": {
"cls_or_fn": f"{MockAgent.__module__}.{MockAgent.__name__}",
"params": {"param": "{{ ALLOWED_VAR }}"},
},
},
)
assert response.status_code == status.HTTP_200_OK
info = schema.AgentInfo.model_validate_json(response.content)
assert info.id == "allowed-env-agent"

def test_register_agent_with_denied_env_var(self, client):
response = client.post(
"/api/agents",
json={
"id": "denied-env-agent",
"agent": {
"cls_or_fn": f"{MockAgent.__module__}.{MockAgent.__name__}",
"params": {"param": "{{ DENIED_VAR }}"},
},
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["detail"] == "Invalid configuration"
Loading
Loading