diff --git a/src/_ravnar/agents.py b/src/_ravnar/agents.py index 4ad6caf..6ff526f 100644 --- a/src/_ravnar/agents.py +++ b/src/_ravnar/agents.py @@ -10,6 +10,8 @@ import ag_ui.core import pydantic +from _ravnar.security import User + from .mixin import SetupTeardownMixin if TYPE_CHECKING: @@ -24,7 +26,7 @@ class Agent(abc.ABC, SetupTeardownMixin): """Agent base class""" @abc.abstractmethod - def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]: ... + def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]: ... def get_capabilities(self) -> ag_ui.core.AgentCapabilities: """The capabilities of the agent.""" @@ -36,7 +38,7 @@ def get_quick_prompts(self) -> list[QuickPrompt]: class DefaultAgent(Agent): - async def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]: + async def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]: message_id = str(uuid.uuid4()) message = """ Hello, I'm ravnar's default agent. @@ -97,7 +99,7 @@ def __init__( super().__init__(capabilities=capabilities, quick_prompts=quick_prompts) - async def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]: + async def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]: import httpx import httpx_sse @@ -152,10 +154,10 @@ async def setup(self) -> None: self._capabilities = await self.extract_capabilities(self._agent) - def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]: + def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]: from pydantic_ai.ui.ag_ui import AGUIAdapter - return AGUIAdapter(agent=self._agent, run_input=input, accept="text/event-stream").run_stream() # type: ignore[return-value] + return AGUIAdapter(agent=self._agent, run_input=input, accept="text/event-stream").run_stream(deps=user) # type: ignore[return-value, arg-type] def get_capabilities(self) -> ag_ui.core.AgentCapabilities: """The capabilities of the agent.""" @@ -275,9 +277,24 @@ def __init__( super().__init__(capabilities=capabilities, quick_prompts=quick_prompts) - def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]: + def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]: from agno.os.interfaces.agui.router import run_agent + # Inject user_id into forwarded_props so agno's run_agent() picks it up natively and passes it to + # agent.arun(user_id=...) + forwarded_props = dict(input.forwarded_props or {}) + forwarded_props["user_id"] = user.id + + # Put the full serialized User into session_state so tools can access it via ctx.session_state + state: dict[str, Any] = dict(input.state or {}) + state["user"] = user.model_dump(mode="json") + + input = input.model_copy( + update={ + "forwarded_props": forwarded_props, + "state": state, + } + ) return run_agent(self._agent, input) # type: ignore[return-value] @staticmethod diff --git a/src/_ravnar/api/agents.py b/src/_ravnar/api/agents.py index 458c13b..5f529cd 100644 --- a/src/_ravnar/api/agents.py +++ b/src/_ravnar/api/agents.py @@ -31,7 +31,7 @@ async def create_stateless_run( run_agent_input: ag_ui.core.RunAgentInput, user: User = Depends(authorized_user_with("agents:read")), # noqa: B008 ) -> fastsse.Response: - return await agent_handler.run(agent_id, run_agent_input) + return await agent_handler.run(agent_id, run_agent_input, user=user) if agent_handler.dynamic_enabled: _make_dynamic_agents_router(router, agent_handler=agent_handler, authorized_user_with=authorized_user_with) diff --git a/src/_ravnar/api/threads.py b/src/_ravnar/api/threads.py index 95745b8..af076e8 100644 --- a/src/_ravnar/api/threads.py +++ b/src/_ravnar/api/threads.py @@ -142,7 +142,7 @@ async def callback(event_processor: EventProcessor) -> None: run = event_processor.extract(include_input_message_ids={m.id for m in data.messages}) await database.create_run(run) - return await agent_handler.run(thread.agent_id, run_agent_input, callback=callback) + return await agent_handler.run(thread.agent_id, run_agent_input, user=user, callback=callback) @traced(name="file-hydration") async def hydrate_files( diff --git a/src/_ravnar/core.py b/src/_ravnar/core.py index 58a6f72..b1e600b 100644 --- a/src/_ravnar/core.py +++ b/src/_ravnar/core.py @@ -21,7 +21,7 @@ from _ravnar.events import EventProcessor from _ravnar.mixin import SetupTeardownMixin from _ravnar.observability import configure_logging, configure_tracing -from _ravnar.security import SecurityHeadersMiddleware, make_authorized_user_factory +from _ravnar.security import SecurityHeadersMiddleware, User, make_authorized_user_factory from _ravnar.utils import TemplateRenderError, as_awaitable from .api import make_router as make_api_router @@ -207,6 +207,7 @@ async def run( agent_id: str, run_agent_input: ag_ui.core.RunAgentInput, *, + user: User, callback: Callable[[EventProcessor], Awaitable[None]] | None = None, ) -> fastsse.Response: agent = self._get_agent(agent_id) @@ -222,7 +223,7 @@ async def run( async def event_stream() -> AsyncIterator[ag_ui.core.Event]: try: - async for event in event_processor.process_event_stream(agent.run(run_agent_input)): + async for event in event_processor.process_event_stream(agent.run(run_agent_input, user=user)): yield event if callback is not None: diff --git a/tests/utils.py b/tests/utils.py index e30d166..7412020 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -116,6 +116,6 @@ class MockAgent(Agent): def __init__(self, param="unset"): self.param = param - async def run(self, input): + async def run(self, input, *, user): raise AssertionError yield