Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions src/_ravnar/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import ag_ui.core
import pydantic

from _ravnar.security import User

from .mixin import SetupTeardownMixin

if TYPE_CHECKING:
Expand All @@ -24,7 +26,7 @@ class Agent(abc.ABC, SetupTeardownMixin):
"""Agent base class"""

@abc.abstractmethod
def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]: ...
def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]: ...

def get_capabilities(self) -> ag_ui.core.AgentCapabilities:
"""The capabilities of the agent."""
Expand All @@ -36,7 +38,7 @@ def get_quick_prompts(self) -> list[QuickPrompt]:


class DefaultAgent(Agent):
async def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]:
async def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]:
message_id = str(uuid.uuid4())
message = """
Hello, I'm ravnar's default agent.
Expand Down Expand Up @@ -97,7 +99,7 @@ def __init__(

super().__init__(capabilities=capabilities, quick_prompts=quick_prompts)

async def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]:
async def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]:
import httpx
import httpx_sse

Expand Down Expand Up @@ -152,10 +154,10 @@ async def setup(self) -> None:

self._capabilities = await self.extract_capabilities(self._agent)

def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]:
def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]:
from pydantic_ai.ui.ag_ui import AGUIAdapter

return AGUIAdapter(agent=self._agent, run_input=input, accept="text/event-stream").run_stream() # type: ignore[return-value]
return AGUIAdapter(agent=self._agent, run_input=input, accept="text/event-stream").run_stream(deps=user) # type: ignore[return-value, arg-type]

def get_capabilities(self) -> ag_ui.core.AgentCapabilities:
"""The capabilities of the agent."""
Expand Down Expand Up @@ -275,9 +277,24 @@ def __init__(

super().__init__(capabilities=capabilities, quick_prompts=quick_prompts)

def run(self, input: ag_ui.core.RunAgentInput) -> AsyncIterator[ag_ui.core.Event]:
def run(self, input: ag_ui.core.RunAgentInput, user: User) -> AsyncIterator[ag_ui.core.Event]:
from agno.os.interfaces.agui.router import run_agent

# Inject user_id into forwarded_props so agno's run_agent() picks it up natively and passes it to
# agent.arun(user_id=...)
forwarded_props = dict(input.forwarded_props or {})
forwarded_props["user_id"] = user.id

# Put the full serialized User into session_state so tools can access it via ctx.session_state
state: dict[str, Any] = dict(input.state or {})
state["user"] = user.model_dump(mode="json")

input = input.model_copy(
update={
"forwarded_props": forwarded_props,
"state": state,
}
)
return run_agent(self._agent, input) # type: ignore[return-value]

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion src/_ravnar/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def create_stateless_run(
run_agent_input: ag_ui.core.RunAgentInput,
user: User = Depends(authorized_user_with("agents:read")), # noqa: B008
) -> fastsse.Response:
return await agent_handler.run(agent_id, run_agent_input)
return await agent_handler.run(agent_id, run_agent_input, user=user)

if agent_handler.dynamic_enabled:
_make_dynamic_agents_router(router, agent_handler=agent_handler, authorized_user_with=authorized_user_with)
Expand Down
2 changes: 1 addition & 1 deletion src/_ravnar/api/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async def callback(event_processor: EventProcessor) -> None:
run = event_processor.extract(include_input_message_ids={m.id for m in data.messages})
await database.create_run(run)

return await agent_handler.run(thread.agent_id, run_agent_input, callback=callback)
return await agent_handler.run(thread.agent_id, run_agent_input, user=user, callback=callback)

@traced(name="file-hydration")
async def hydrate_files(
Expand Down
5 changes: 3 additions & 2 deletions src/_ravnar/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from _ravnar.events import EventProcessor
from _ravnar.mixin import SetupTeardownMixin
from _ravnar.observability import configure_logging, configure_tracing
from _ravnar.security import SecurityHeadersMiddleware, make_authorized_user_factory
from _ravnar.security import SecurityHeadersMiddleware, User, make_authorized_user_factory
from _ravnar.utils import TemplateRenderError, as_awaitable

from .api import make_router as make_api_router
Expand Down Expand Up @@ -207,6 +207,7 @@ async def run(
agent_id: str,
run_agent_input: ag_ui.core.RunAgentInput,
*,
user: User,
callback: Callable[[EventProcessor], Awaitable[None]] | None = None,
) -> fastsse.Response:
agent = self._get_agent(agent_id)
Expand All @@ -222,7 +223,7 @@ async def run(

async def event_stream() -> AsyncIterator[ag_ui.core.Event]:
try:
async for event in event_processor.process_event_stream(agent.run(run_agent_input)):
async for event in event_processor.process_event_stream(agent.run(run_agent_input, user=user)):
yield event

if callback is not None:
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,6 @@ class MockAgent(Agent):
def __init__(self, param="unset"):
self.param = param

async def run(self, input):
async def run(self, input, *, user):
raise AssertionError
yield
Loading