Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion py/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@
LITELLM_VERSIONS = (LATEST, "1.74.0")
# CLI bundling started in 0.1.10 - older versions require external Claude Code installation
CLAUDE_AGENT_SDK_VERSIONS = (LATEST, "0.1.10")
AGNO_VERSIONS = (LATEST, "2.1.0")
# Keep LATEST for newest API coverage, and pin 2.4.0 to cover the 2.4 -> 2.5 breaking change
# to internals we leverage for instrumentation.
AGNO_VERSIONS = (LATEST, "2.4.0", "2.1.0")
# pydantic_ai 1.x requires Python >= 3.10
# Two test suites with different version requirements:
# 1. wrap_openai approach: works with older versions (0.1.9+)
Expand Down
28 changes: 14 additions & 14 deletions py/src/braintrust/wrappers/agno/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from braintrust.span_types import SpanTypeAttribute
from wrapt import wrap_function_wrapper

from .run_helpers import arun_public_dispatch_wrapper, run_public_dispatch_wrapper
from .utils import (
_aggregate_agent_chunks,
extract_metadata,
Expand Down Expand Up @@ -46,19 +47,18 @@ def _run_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any):
return _create_run_span(wrapped, instance, args, kwargs, input_data)

def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
"""Entry point for public run(input)."""
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
input_data = {"input": input_arg}
return _create_run_span(wrapped, instance, args, kwargs, input_data)
return run_public_dispatch_wrapper(
wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent"
)

# Wrap private method if it exists, otherwise wrap public method
if hasattr(Agent, "_run"):
wrap_function_wrapper(Agent, "_run", _run_wrapper_private)
elif hasattr(Agent, "run"):
wrap_function_wrapper(Agent, "run", _run_wrapper_public)

async def _create_arun_span(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict):
"""Shared logic to create span and execute arun method."""
async def _create_arun_span_private(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict):
"""Shared logic to create span and execute async private _arun method."""
agent_name = getattr(instance, "name", None) or "Agent"
span_name = f"{agent_name}.arun"

Expand All @@ -80,19 +80,16 @@ async def _arun_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs:
run_response = args[0] if len(args) > 0 else kwargs.get("run_response")
input_arg = args[1] if len(args) > 1 else kwargs.get("input")
input_data = {"run_response": run_response, "input": input_arg}
return await _create_arun_span(wrapped, instance, args, kwargs, input_data)
return await _create_arun_span_private(wrapped, instance, args, kwargs, input_data)

async def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
"""Entry point for public arun(input)."""
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
input_data = {"input": input_arg}
return await _create_arun_span(wrapped, instance, args, kwargs, input_data)
def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
return arun_public_dispatch_wrapper(
wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent"
)

# Wrap private method if it exists, otherwise wrap public method
if hasattr(Agent, "_arun"):
wrap_function_wrapper(Agent, "_arun", _arun_wrapper_private)
elif hasattr(Agent, "arun"):
wrap_function_wrapper(Agent, "arun", _arun_wrapper_public)

def run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
agent_name = getattr(instance, "name", None) or "Agent"
Expand Down Expand Up @@ -211,6 +208,9 @@ async def _trace_stream():

if hasattr(Agent, "_arun_stream"):
wrap_function_wrapper(Agent, "_arun_stream", arun_stream_wrapper)
elif not hasattr(Agent, "_arun") and hasattr(Agent, "arun"):
# Agno >= 2.5 routes through public arun(..., stream=...)
wrap_function_wrapper(Agent, "arun", _arun_wrapper_public)

mark_patched(Agent)
return Agent
139 changes: 139 additions & 0 deletions py/src/braintrust/wrappers/agno/run_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import time
from inspect import isawaitable
from typing import Any

from braintrust.logger import start_span
from braintrust.span_types import SpanTypeAttribute

from .utils import (
extract_metadata,
extract_metrics,
is_async_iterator,
is_sync_iterator,
omit,
trace_async_stream_result,
trace_sync_stream_result,
)


def run_public_dispatch_wrapper(
wrapped: Any,
instance: Any,
args: Any,
kwargs: Any,
*,
default_name: str,
metadata_component: str,
) -> Any:
"""Trace a public synchronous `run(...)` dispatch method.

Handles both non-streaming return values and synchronous streaming iterators.
For iterator results, span lifecycle is delegated to `trace_sync_stream_result`.
"""
component_name = getattr(instance, "name", None) or default_name
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
input_data = {"input": input_arg}
metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)}

span = start_span(
name=f"{component_name}.run",
type=SpanTypeAttribute.TASK,
input=input_data,
metadata=metadata,
)
span.set_current()
start = time.time()
try:
result = wrapped(*args, **kwargs)
if is_sync_iterator(result):
return trace_sync_stream_result(result, span, start)
span.log(
output=result,
metrics=extract_metrics(result),
)
span.unset_current()
span.end()
return result
except Exception as e:
span.log(error=str(e))
span.unset_current()
span.end()
raise


def arun_public_dispatch_wrapper(
wrapped: Any,
instance: Any,
args: Any,
kwargs: Any,
*,
default_name: str,
metadata_component: str,
) -> Any:
"""Trace a public `arun(...)` dispatch method across async return contracts.

Supports all observed `arun` dispatcher behaviors:
- immediate return value
- awaitable returning a value
- direct async iterator
- awaitable returning an async iterator

If an async iterator is returned (directly or after await), span lifecycle is
delegated to `trace_async_stream_result` so the span remains open until stream
consumption completes.
"""
component_name = getattr(instance, "name", None) or default_name
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
input_data = {"input": input_arg}
metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)}

span = start_span(
name=f"{component_name}.arun",
type=SpanTypeAttribute.TASK,
input=input_data,
metadata=metadata,
)
span.set_current()
start = time.time()
try:
result = wrapped(*args, **kwargs)

if isawaitable(result):

async def _trace_awaitable():
should_end_span = True
try:
awaited = await result
if is_async_iterator(awaited):
should_end_span = False
return trace_async_stream_result(awaited, span, start)
span.log(
output=awaited,
metrics=extract_metrics(awaited),
)
return awaited
except Exception as e:
span.log(error=str(e))
raise
finally:
if should_end_span:
span.unset_current()
span.end()

return _trace_awaitable()

if is_async_iterator(result):
return trace_async_stream_result(result, span, start)

span.log(
output=result,
metrics=extract_metrics(result),
)
span.unset_current()
span.end()
return result
except Exception as e:
span.log(error=str(e))
span.unset_current()
span.end()
raise
28 changes: 14 additions & 14 deletions py/src/braintrust/wrappers/agno/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from braintrust.span_types import SpanTypeAttribute
from wrapt import wrap_function_wrapper

from .run_helpers import arun_public_dispatch_wrapper, run_public_dispatch_wrapper
from .utils import (
_aggregate_agent_chunks,
extract_metadata,
Expand Down Expand Up @@ -46,19 +47,18 @@ def _run_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any):
return _create_run_span(wrapped, instance, args, kwargs, input_data)

def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
"""Entry point for public run(input)."""
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
input_data = {"input": input_arg}
return _create_run_span(wrapped, instance, args, kwargs, input_data)
return run_public_dispatch_wrapper(
wrapped, instance, args, kwargs, default_name="Team", metadata_component="team"
)

# Wrap private method if it exists, otherwise wrap public method
if hasattr(Team, "_run"):
wrap_function_wrapper(Team, "_run", _run_wrapper_private)
elif hasattr(Team, "run"):
wrap_function_wrapper(Team, "run", _run_wrapper_public)

async def _create_arun_span(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict):
"""Shared logic to create span and execute arun method."""
async def _create_arun_span_private(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict):
"""Shared logic to create span and execute async private _arun method."""
agent_name = getattr(instance, "name", None) or "Team"
span_name = f"{agent_name}.arun"

Expand All @@ -80,19 +80,16 @@ async def _arun_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs:
run_response = args[0] if len(args) > 0 else kwargs.get("run_response")
input_arg = args[1] if len(args) > 1 else kwargs.get("input")
input_data = {"run_response": run_response, "input": input_arg}
return await _create_arun_span(wrapped, instance, args, kwargs, input_data)
return await _create_arun_span_private(wrapped, instance, args, kwargs, input_data)

async def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
"""Entry point for public arun(input)."""
input_arg = args[0] if len(args) > 0 else kwargs.get("input")
input_data = {"input": input_arg}
return await _create_arun_span(wrapped, instance, args, kwargs, input_data)
def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any):
return arun_public_dispatch_wrapper(
wrapped, instance, args, kwargs, default_name="Team", metadata_component="team"
)

# Wrap private method if it exists, otherwise wrap public method
if hasattr(Team, "_arun"):
wrap_function_wrapper(Team, "_arun", _arun_wrapper_private)
elif hasattr(Team, "arun"):
wrap_function_wrapper(Team, "arun", _arun_wrapper_public)

def run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
agent_name = getattr(instance, "name", None) or "Team"
Expand Down Expand Up @@ -211,6 +208,9 @@ async def _trace_stream():

if hasattr(Team, "_arun_stream"):
wrap_function_wrapper(Team, "_arun_stream", arun_stream_wrapper)
elif not hasattr(Team, "_arun") and hasattr(Team, "arun"):
# Agno >= 2.5 routes through public arun(..., stream=...)
wrap_function_wrapper(Team, "arun", _arun_wrapper_public)

mark_patched(Team)
return Team
Loading
Loading