diff --git a/py/noxfile.py b/py/noxfile.py index 78cd11b8..3ce32ed4 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -61,7 +61,9 @@ LITELLM_VERSIONS = (LATEST, "1.74.0") # CLI bundling started in 0.1.10 - older versions require external Claude Code installation CLAUDE_AGENT_SDK_VERSIONS = (LATEST, "0.1.10") -AGNO_VERSIONS = (LATEST, "2.1.0") +# Keep LATEST for newest API coverage, and pin 2.4.0 to cover the 2.4 -> 2.5 breaking change +# to internals we leverage for instrumentation. +AGNO_VERSIONS = (LATEST, "2.4.0", "2.1.0") # pydantic_ai 1.x requires Python >= 3.10 # Two test suites with different version requirements: # 1. wrap_openai approach: works with older versions (0.1.9+) diff --git a/py/src/braintrust/wrappers/agno/agent.py b/py/src/braintrust/wrappers/agno/agent.py index 31b020b6..cb63cc49 100644 --- a/py/src/braintrust/wrappers/agno/agent.py +++ b/py/src/braintrust/wrappers/agno/agent.py @@ -5,6 +5,7 @@ from braintrust.span_types import SpanTypeAttribute from wrapt import wrap_function_wrapper +from .run_helpers import arun_public_dispatch_wrapper, run_public_dispatch_wrapper from .utils import ( _aggregate_agent_chunks, extract_metadata, @@ -46,10 +47,9 @@ def _run_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any): return _create_run_span(wrapped, instance, args, kwargs, input_data) def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): - """Entry point for public run(input).""" - input_arg = args[0] if len(args) > 0 else kwargs.get("input") - input_data = {"input": input_arg} - return _create_run_span(wrapped, instance, args, kwargs, input_data) + return run_public_dispatch_wrapper( + wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent" + ) # Wrap private method if it exists, otherwise wrap public method if hasattr(Agent, "_run"): @@ -57,8 +57,8 @@ def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): elif hasattr(Agent, "run"): wrap_function_wrapper(Agent, "run", _run_wrapper_public) - async def _create_arun_span(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict): - """Shared logic to create span and execute arun method.""" + async def _create_arun_span_private(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict): + """Shared logic to create span and execute async private _arun method.""" agent_name = getattr(instance, "name", None) or "Agent" span_name = f"{agent_name}.arun" @@ -80,19 +80,16 @@ async def _arun_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: run_response = args[0] if len(args) > 0 else kwargs.get("run_response") input_arg = args[1] if len(args) > 1 else kwargs.get("input") input_data = {"run_response": run_response, "input": input_arg} - return await _create_arun_span(wrapped, instance, args, kwargs, input_data) + return await _create_arun_span_private(wrapped, instance, args, kwargs, input_data) - async def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): - """Entry point for public arun(input).""" - input_arg = args[0] if len(args) > 0 else kwargs.get("input") - input_data = {"input": input_arg} - return await _create_arun_span(wrapped, instance, args, kwargs, input_data) + def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): + return arun_public_dispatch_wrapper( + wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent" + ) # Wrap private method if it exists, otherwise wrap public method if hasattr(Agent, "_arun"): wrap_function_wrapper(Agent, "_arun", _arun_wrapper_private) - elif hasattr(Agent, "arun"): - wrap_function_wrapper(Agent, "arun", _arun_wrapper_public) def run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): agent_name = getattr(instance, "name", None) or "Agent" @@ -211,6 +208,9 @@ async def _trace_stream(): if hasattr(Agent, "_arun_stream"): wrap_function_wrapper(Agent, "_arun_stream", arun_stream_wrapper) + elif not hasattr(Agent, "_arun") and hasattr(Agent, "arun"): + # Agno >= 2.5 routes through public arun(..., stream=...) + wrap_function_wrapper(Agent, "arun", _arun_wrapper_public) mark_patched(Agent) return Agent diff --git a/py/src/braintrust/wrappers/agno/run_helpers.py b/py/src/braintrust/wrappers/agno/run_helpers.py new file mode 100644 index 00000000..3be7f587 --- /dev/null +++ b/py/src/braintrust/wrappers/agno/run_helpers.py @@ -0,0 +1,139 @@ +import time +from inspect import isawaitable +from typing import Any + +from braintrust.logger import start_span +from braintrust.span_types import SpanTypeAttribute + +from .utils import ( + extract_metadata, + extract_metrics, + is_async_iterator, + is_sync_iterator, + omit, + trace_async_stream_result, + trace_sync_stream_result, +) + + +def run_public_dispatch_wrapper( + wrapped: Any, + instance: Any, + args: Any, + kwargs: Any, + *, + default_name: str, + metadata_component: str, +) -> Any: + """Trace a public synchronous `run(...)` dispatch method. + + Handles both non-streaming return values and synchronous streaming iterators. + For iterator results, span lifecycle is delegated to `trace_sync_stream_result`. + """ + component_name = getattr(instance, "name", None) or default_name + input_arg = args[0] if len(args) > 0 else kwargs.get("input") + input_data = {"input": input_arg} + metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)} + + span = start_span( + name=f"{component_name}.run", + type=SpanTypeAttribute.TASK, + input=input_data, + metadata=metadata, + ) + span.set_current() + start = time.time() + try: + result = wrapped(*args, **kwargs) + if is_sync_iterator(result): + return trace_sync_stream_result(result, span, start) + span.log( + output=result, + metrics=extract_metrics(result), + ) + span.unset_current() + span.end() + return result + except Exception as e: + span.log(error=str(e)) + span.unset_current() + span.end() + raise + + +def arun_public_dispatch_wrapper( + wrapped: Any, + instance: Any, + args: Any, + kwargs: Any, + *, + default_name: str, + metadata_component: str, +) -> Any: + """Trace a public `arun(...)` dispatch method across async return contracts. + + Supports all observed `arun` dispatcher behaviors: + - immediate return value + - awaitable returning a value + - direct async iterator + - awaitable returning an async iterator + + If an async iterator is returned (directly or after await), span lifecycle is + delegated to `trace_async_stream_result` so the span remains open until stream + consumption completes. + """ + component_name = getattr(instance, "name", None) or default_name + input_arg = args[0] if len(args) > 0 else kwargs.get("input") + input_data = {"input": input_arg} + metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)} + + span = start_span( + name=f"{component_name}.arun", + type=SpanTypeAttribute.TASK, + input=input_data, + metadata=metadata, + ) + span.set_current() + start = time.time() + try: + result = wrapped(*args, **kwargs) + + if isawaitable(result): + + async def _trace_awaitable(): + should_end_span = True + try: + awaited = await result + if is_async_iterator(awaited): + should_end_span = False + return trace_async_stream_result(awaited, span, start) + span.log( + output=awaited, + metrics=extract_metrics(awaited), + ) + return awaited + except Exception as e: + span.log(error=str(e)) + raise + finally: + if should_end_span: + span.unset_current() + span.end() + + return _trace_awaitable() + + if is_async_iterator(result): + return trace_async_stream_result(result, span, start) + + span.log( + output=result, + metrics=extract_metrics(result), + ) + span.unset_current() + span.end() + return result + except Exception as e: + span.log(error=str(e)) + span.unset_current() + span.end() + raise diff --git a/py/src/braintrust/wrappers/agno/team.py b/py/src/braintrust/wrappers/agno/team.py index 294fbc08..f82fc9b5 100644 --- a/py/src/braintrust/wrappers/agno/team.py +++ b/py/src/braintrust/wrappers/agno/team.py @@ -5,6 +5,7 @@ from braintrust.span_types import SpanTypeAttribute from wrapt import wrap_function_wrapper +from .run_helpers import arun_public_dispatch_wrapper, run_public_dispatch_wrapper from .utils import ( _aggregate_agent_chunks, extract_metadata, @@ -46,10 +47,9 @@ def _run_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any): return _create_run_span(wrapped, instance, args, kwargs, input_data) def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): - """Entry point for public run(input).""" - input_arg = args[0] if len(args) > 0 else kwargs.get("input") - input_data = {"input": input_arg} - return _create_run_span(wrapped, instance, args, kwargs, input_data) + return run_public_dispatch_wrapper( + wrapped, instance, args, kwargs, default_name="Team", metadata_component="team" + ) # Wrap private method if it exists, otherwise wrap public method if hasattr(Team, "_run"): @@ -57,8 +57,8 @@ def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): elif hasattr(Team, "run"): wrap_function_wrapper(Team, "run", _run_wrapper_public) - async def _create_arun_span(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict): - """Shared logic to create span and execute arun method.""" + async def _create_arun_span_private(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict): + """Shared logic to create span and execute async private _arun method.""" agent_name = getattr(instance, "name", None) or "Team" span_name = f"{agent_name}.arun" @@ -80,19 +80,16 @@ async def _arun_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: run_response = args[0] if len(args) > 0 else kwargs.get("run_response") input_arg = args[1] if len(args) > 1 else kwargs.get("input") input_data = {"run_response": run_response, "input": input_arg} - return await _create_arun_span(wrapped, instance, args, kwargs, input_data) + return await _create_arun_span_private(wrapped, instance, args, kwargs, input_data) - async def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): - """Entry point for public arun(input).""" - input_arg = args[0] if len(args) > 0 else kwargs.get("input") - input_data = {"input": input_arg} - return await _create_arun_span(wrapped, instance, args, kwargs, input_data) + def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): + return arun_public_dispatch_wrapper( + wrapped, instance, args, kwargs, default_name="Team", metadata_component="team" + ) # Wrap private method if it exists, otherwise wrap public method if hasattr(Team, "_arun"): wrap_function_wrapper(Team, "_arun", _arun_wrapper_private) - elif hasattr(Team, "arun"): - wrap_function_wrapper(Team, "arun", _arun_wrapper_public) def run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): agent_name = getattr(instance, "name", None) or "Team" @@ -211,6 +208,9 @@ async def _trace_stream(): if hasattr(Team, "_arun_stream"): wrap_function_wrapper(Team, "_arun_stream", arun_stream_wrapper) + elif not hasattr(Team, "_arun") and hasattr(Team, "arun"): + # Agno >= 2.5 routes through public arun(..., stream=...) + wrap_function_wrapper(Team, "arun", _arun_wrapper_public) mark_patched(Team) return Team diff --git a/py/src/braintrust/wrappers/agno/utils.py b/py/src/braintrust/wrappers/agno/utils.py index 8aabe989..072fbfee 100644 --- a/py/src/braintrust/wrappers/agno/utils.py +++ b/py/src/braintrust/wrappers/agno/utils.py @@ -134,21 +134,7 @@ def extract_metrics(result: Any, messages: list | None = None) -> dict[str, Any] # For agent/team responses with metrics if hasattr(result, "metrics") and result.metrics: - agno_metrics = result.metrics - metrics = {} - - # Direct field mapping for agent/team metrics - if hasattr(agno_metrics, "input_tokens") and agno_metrics.input_tokens: - metrics["prompt_tokens"] = agno_metrics.input_tokens - if hasattr(agno_metrics, "output_tokens") and agno_metrics.output_tokens: - metrics["completion_tokens"] = agno_metrics.output_tokens - if hasattr(agno_metrics, "total_tokens") and agno_metrics.total_tokens: - metrics["total_tokens"] = agno_metrics.total_tokens - if hasattr(agno_metrics, "duration") and agno_metrics.duration: - metrics["duration"] = agno_metrics.duration - if hasattr(agno_metrics, "time_to_first_token") and agno_metrics.time_to_first_token: - metrics["time_to_first_token"] = agno_metrics.time_to_first_token - + metrics = parse_metrics_from_agno(result.metrics) return metrics if metrics else None # If no metrics found and we have messages, look for metrics in assistant messages (model-specific) @@ -165,14 +151,16 @@ def extract_streaming_metrics(aggregated: dict[str, Any], start_time: float) -> """Extract metrics from aggregated streaming response.""" metrics = {} - # Add duration - metrics["duration"] = time.time() - start_time - # Extract metrics from aggregated data # The metrics are already in Braintrust format from _aggregate_model_chunks if aggregated.get("metrics") and isinstance(aggregated["metrics"], dict): # Merge the aggregated metrics metrics.update(aggregated["metrics"]) + # Handle object-like metrics payloads (e.g. RunCompletedEvent.metrics) + elif aggregated.get("metrics"): + parsed_metrics = parse_metrics_from_agno(aggregated["metrics"]) + if parsed_metrics: + metrics.update(parsed_metrics) # Also check response_usage for backward compatibility elif aggregated.get("response_usage"): response_metrics = parse_metrics_from_agno(aggregated["response_usage"]) @@ -357,15 +345,15 @@ def _aggregate_agent_chunks(chunks: list[Any]) -> dict[str, Any]: } for chunk in chunks: - # Handle RunStartedEvent - if hasattr(chunk, "event") and chunk.event == "RunStarted": + event = getattr(chunk, "event", None) + + if event == "RunStarted": if hasattr(chunk, "model"): aggregated["model"] = chunk.model if hasattr(chunk, "model_provider"): aggregated["model_provider"] = chunk.model_provider - # Handle RunContentEvent - elif hasattr(chunk, "event") and chunk.event == "RunContent": + elif event == "RunContent": if hasattr(chunk, "content") and chunk.content: aggregated["content"] += str(chunk.content) # type: ignore if hasattr(chunk, "reasoning_content") and chunk.reasoning_content: @@ -375,18 +363,16 @@ def _aggregate_agent_chunks(chunks: list[Any]) -> dict[str, Any]: if hasattr(chunk, "references"): aggregated["references"] = chunk.references - # Handle RunCompletedEvent - elif hasattr(chunk, "event") and chunk.event == "RunCompleted": + elif event == "RunCompleted": if hasattr(chunk, "metrics"): - aggregated["metrics"] = chunk.metrics + parsed_metrics = parse_metrics_from_agno(chunk.metrics) + aggregated["metrics"] = parsed_metrics if parsed_metrics else chunk.metrics aggregated["finish_reason"] = "stop" - # Handle RunError - elif hasattr(chunk, "event") and chunk.event == "RunError": + elif event == "RunError": aggregated["finish_reason"] = "error" - # Handle tool calls - elif hasattr(chunk, "event") and chunk.event == "ToolCallStarted": + elif event == "ToolCallStarted": if hasattr(chunk, "tool_call"): aggregated["tool_calls"].append( # type:ignore { @@ -402,6 +388,78 @@ def _aggregate_agent_chunks(chunks: list[Any]) -> dict[str, Any]: return {k: v for k, v in aggregated.items() if v not in (None, "")} +def is_sync_iterator(result: Any) -> bool: + return hasattr(result, "__iter__") and hasattr(result, "__next__") + + +def is_async_iterator(result: Any) -> bool: + return hasattr(result, "__aiter__") and hasattr(result, "__anext__") + + +def trace_sync_stream_result(result: Any, span: Any, start: float): + def _trace_stream(): + should_unset = True + try: + first = True + all_chunks = [] + for chunk in result: + if first: + span.log(metrics={"time_to_first_token": time.time() - start}) + first = False + all_chunks.append(chunk) + yield chunk + + aggregated = _aggregate_agent_chunks(all_chunks) + span.log( + output=aggregated, + metrics=extract_streaming_metrics(aggregated, start), + ) + except GeneratorExit: + should_unset = False + raise + except Exception as e: + span.log(error=str(e)) + raise + finally: + if should_unset: + span.unset_current() + span.end() + + return _trace_stream() + + +def trace_async_stream_result(result: Any, span: Any, start: float): + async def _trace_astream(): + should_unset = True + try: + first = True + all_chunks = [] + async for chunk in result: + if first: + span.log(metrics={"time_to_first_token": time.time() - start}) + first = False + all_chunks.append(chunk) + yield chunk + + aggregated = _aggregate_agent_chunks(all_chunks) + span.log( + output=aggregated, + metrics=extract_streaming_metrics(aggregated, start), + ) + except GeneratorExit: + should_unset = False + raise + except Exception as e: + span.log(error=str(e)) + raise + finally: + if should_unset: + span.unset_current() + span.end() + + return _trace_astream() + + # Legacy aliases for backward compatibility _extract_run_metrics = extract_metrics _extract_streaming_metrics = extract_streaming_metrics diff --git a/py/src/braintrust/wrappers/test_agno.py b/py/src/braintrust/wrappers/test_agno.py index 4fd3d845..6516dcdd 100644 --- a/py/src/braintrust/wrappers/test_agno.py +++ b/py/src/braintrust/wrappers/test_agno.py @@ -4,10 +4,18 @@ # pyright: reportUnknownParameterType=false # pyright: reportUnknownVariableType=false # pyright: reportUnknownArgumentType=false +from inspect import isawaitable + import pytest from braintrust import logger +from braintrust.logger import start_span from braintrust.test_helpers import init_test_logger +from braintrust.wrappers.agno import agent as agno_agent_module +from braintrust.wrappers.agno import run_helpers as agno_run_helpers_module from braintrust.wrappers.agno import setup_agno +from braintrust.wrappers.agno import team as agno_team_module +from braintrust.wrappers.agno.agent import wrap_agent +from braintrust.wrappers.agno.team import wrap_team from braintrust.wrappers.test_utils import verify_autoinstrument_script TEST_ORG_ID = "test-org-123" @@ -103,3 +111,456 @@ class TestAutoInstrumentAgno: def test_auto_instrument_agno(self): """Test auto_instrument patches Agno and creates spans.""" verify_autoinstrument_script("test_auto_agno.py") + + +class _FakeMetrics: + def __init__(self): + self.input_tokens = 1 + self.output_tokens = 2 + self.total_tokens = 3 + self.duration = 0.1 + self.time_to_first_token = 0.01 + + +class _FakeRunOutput: + def __init__(self, content: str): + self.content = content + self.status = "COMPLETED" + self.model = "fake-model" + self.model_provider = "FakeProvider" + self.metrics = _FakeMetrics() + + +class _FakeEvent: + def __init__(self, event: str, **kwargs): + self.event = event + for k, v in kwargs.items(): + setattr(self, k, v) + + +def _make_fake_component(name: str): + class FakeComponent: + def __init__(self): + self.name = name + + def run(self, input, stream=False, **kwargs): + if stream: + def _stream(): + yield _FakeEvent("RunStarted", model="fake-model", model_provider="FakeProvider") + yield _FakeEvent("RunContent", content=f"{input}-sync") + yield _FakeEvent("RunCompleted", metrics=_FakeMetrics()) + + return _stream() + return _FakeRunOutput(f"{input}-sync") + + def arun(self, input, stream=False, **kwargs): + if stream: + async def _astream(): + yield _FakeEvent("RunStarted", model="fake-model", model_provider="FakeProvider") + yield _FakeEvent("RunContent", content=f"{input}-async") + yield _FakeEvent("RunCompleted", metrics=_FakeMetrics()) + + return _astream() + + async def _result(): + return _FakeRunOutput(f"{input}-async") + + return _result() + + return FakeComponent + + +def _make_fake_async_dispatch_component(name: str): + class FakeComponent: + def __init__(self): + self.name = name + + async def arun(self, input, stream=False, **kwargs): + if stream: + async def _astream(): + yield _FakeEvent("RunStarted", model="fake-model", model_provider="FakeProvider") + yield _FakeEvent("RunContent", content=f"{input}-awaited-async") + yield _FakeEvent("RunCompleted", metrics=_FakeMetrics()) + + return _astream() + return {"content": f"{input}-awaited-async"} + + return FakeComponent + + +def _make_fake_error_component(name: str): + class FakeComponent: + def __init__(self): + self.name = name + + def run(self, input, stream=False, **kwargs): + if stream: + def _stream(): + yield _FakeEvent("RunStarted", model="fake-model", model_provider="FakeProvider") + raise RuntimeError("sync-stream-error") + + return _stream() + return _FakeRunOutput(f"{input}-sync") + + def arun(self, input, stream=False, **kwargs): + if stream: + async def _astream(): + yield _FakeEvent("RunStarted", model="fake-model", model_provider="FakeProvider") + raise RuntimeError("async-stream-error") + + return _astream() + + async def _result(): + return _FakeRunOutput(f"{input}-async") + + return _result() + + return FakeComponent + + +def _make_fake_private_public_component(name: str): + class FakeComponent: + def __init__(self): + self.name = name + self.calls = [] + + def _run(self, run_response=None, run_messages=None, **kwargs): + self.calls.append("_run") + return _FakeRunOutput("private-run") + + def run(self, input, **kwargs): + self.calls.append("run") + return _FakeRunOutput("public-run") + + async def _arun(self, run_response=None, input=None, **kwargs): + self.calls.append("_arun") + return _FakeRunOutput("private-arun") + + def arun(self, input, **kwargs): + self.calls.append("arun") + + async def _result(): + return _FakeRunOutput("public-arun") + + return _result() + + return FakeComponent + + +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgent"), + (wrap_team, "CompatTeam"), + ], +) +def test_agno_public_run_stream_dispatcher_compat(memory_logger, wrapper, name): + """Ensures public run(stream=True) dispatchers are traced as a single streamed task span.""" + Component = wrapper(_make_fake_component(name)) + instance = Component() + + chunks = list(instance.run("hello", stream=True)) + assert len(chunks) == 3 + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["span_attributes"]["name"] == f"{name}.run" + assert span["output"]["content"] == "hello-sync" + assert span["metrics"]["prompt_tokens"] == 1 + assert span["metrics"]["completion_tokens"] == 2 + assert span["metrics"]["duration"] >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentAsync"), + (wrap_team, "CompatTeamAsync"), + ], +) +async def test_agno_public_arun_stream_dispatcher_compat(memory_logger, wrapper, name): + """Covers async streaming when arun returns an async iterator directly.""" + Component = wrapper(_make_fake_component(name)) + instance = Component() + + stream = instance.arun("hello", stream=True) + if isawaitable(stream): + stream = await stream + chunks = [chunk async for chunk in stream] + assert len(chunks) == 3 + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["span_attributes"]["name"] == f"{name}.arun" + assert span["output"]["content"] == "hello-async" + assert span["metrics"]["prompt_tokens"] == 1 + assert span["metrics"]["completion_tokens"] == 2 + assert span["metrics"]["duration"] >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentAwaitedAsync"), + (wrap_team, "CompatTeamAwaitedAsync"), + ], +) +async def test_agno_public_arun_awaited_async_iterator_compat(memory_logger, wrapper, name): + """Covers async streaming when arun must be awaited before yielding an async iterator.""" + Component = wrapper(_make_fake_async_dispatch_component(name)) + instance = Component() + + stream = await instance.arun("hello", stream=True) + chunks = [chunk async for chunk in stream] + assert len(chunks) == 3 + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["span_attributes"]["name"] == f"{name}.arun" + assert span["output"]["content"] == "hello-awaited-async" + assert span["metrics"]["prompt_tokens"] == 1 + assert span["metrics"]["completion_tokens"] == 2 + + +class _StrictSpan: + def __init__(self): + self.ended = False + + def set_current(self): + return None + + def unset_current(self): + return None + + def log(self, **kwargs): + if self.ended: + raise AssertionError("log called after span.end()") + + def end(self): + self.ended = True + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "module,wrapper,name", + [ + (agno_agent_module, wrap_agent, "StrictAgentAwaitedAsync"), + (agno_team_module, wrap_team, "StrictTeamAwaitedAsync"), + ], +) +async def test_agno_public_arun_awaited_async_iterator_span_lifecycle(monkeypatch, module, wrapper, name): + """Guards against ending the span before awaited async-stream consumption completes.""" + strict_span = _StrictSpan() + monkeypatch.setattr(module, "start_span", lambda **kwargs: strict_span) + monkeypatch.setattr(agno_run_helpers_module, "start_span", lambda **kwargs: strict_span) + + Component = wrapper(_make_fake_async_dispatch_component(name)) + instance = Component() + + stream = await instance.arun("hello", stream=True) + # Span must remain open until the async stream is consumed. + assert strict_span.ended is False + + chunks = [chunk async for chunk in stream] + assert len(chunks) == 3 + assert strict_span.ended is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentAsyncNonStream"), + (wrap_team, "CompatTeamAsyncNonStream"), + ], +) +async def test_agno_public_arun_non_stream_awaitable_compat(memory_logger, wrapper, name): + """Validates non-streaming async dispatcher path logs output without stream-specific handling.""" + Component = wrapper(_make_fake_component(name)) + instance = Component() + + result = instance.arun("hello", stream=False) + if isawaitable(result): + result = await result + + assert result.content == "hello-async" + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + assert span["span_attributes"]["name"] == f"{name}.arun" + assert span["output"] + + +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentSyncError"), + (wrap_team, "CompatTeamSyncError"), + ], +) +def test_agno_public_run_stream_error_path(memory_logger, wrapper, name): + """Ensures sync stream exceptions are surfaced and recorded on the task span.""" + Component = wrapper(_make_fake_error_component(name)) + instance = Component() + + with pytest.raises(RuntimeError, match="sync-stream-error"): + list(instance.run("boom", stream=True)) + + spans = memory_logger.pop() + assert len(spans) == 1 + assert "sync-stream-error" in spans[0]["error"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentAsyncError"), + (wrap_team, "CompatTeamAsyncError"), + ], +) +async def test_agno_public_arun_stream_error_path(memory_logger, wrapper, name): + """Ensures async stream exceptions are surfaced and recorded on the task span.""" + Component = wrapper(_make_fake_error_component(name)) + instance = Component() + + stream = instance.arun("boom", stream=True) + if isawaitable(stream): + stream = await stream + + with pytest.raises(RuntimeError, match="async-stream-error"): + async for _ in stream: + pass + + spans = memory_logger.pop() + assert len(spans) == 1 + assert "async-stream-error" in spans[0]["error"] + + +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentSyncEarlyBreak"), + (wrap_team, "CompatTeamSyncEarlyBreak"), + ], +) +def test_agno_public_run_stream_early_break(memory_logger, wrapper, name): + """Covers early consumer break from sync stream without span lifecycle regressions.""" + Component = wrapper(_make_fake_component(name)) + instance = Component() + + for _ in instance.run("hello", stream=True): + break + + spans = memory_logger.pop() + assert len(spans) == 1 + assert spans[0]["span_attributes"]["name"] == f"{name}.run" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentAsyncEarlyBreak"), + (wrap_team, "CompatTeamAsyncEarlyBreak"), + ], +) +async def test_agno_public_arun_stream_early_break(memory_logger, wrapper, name): + """Covers early consumer break from async stream without span lifecycle regressions.""" + Component = wrapper(_make_fake_component(name)) + instance = Component() + + stream = instance.arun("hello", stream=True) + if isawaitable(stream): + stream = await stream + + async for _ in stream: + break + + spans = memory_logger.pop() + assert len(spans) == 1 + assert spans[0]["span_attributes"]["name"] == f"{name}.arun" + + +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentParentSync"), + (wrap_team, "CompatTeamParentSync"), + ], +) +def test_agno_public_run_parent_span_nesting(memory_logger, wrapper, name): + """Confirms public run spans nest under an already-active parent span.""" + Component = wrapper(_make_fake_component(name)) + instance = Component() + + with start_span(name="outer_sync_parent", type="task"): + instance.run("hello") + + spans = memory_logger.pop() + by_name = {s["span_attributes"]["name"]: s for s in spans} + outer = by_name["outer_sync_parent"] + child = by_name[f"{name}.run"] + assert child["span_parents"] == [outer["span_id"]] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentParentAsync"), + (wrap_team, "CompatTeamParentAsync"), + ], +) +async def test_agno_public_arun_parent_span_nesting(memory_logger, wrapper, name): + """Confirms public arun streaming spans nest under an already-active parent span.""" + Component = wrapper(_make_fake_component(name)) + instance = Component() + + with start_span(name="outer_async_parent", type="task"): + stream = instance.arun("hello", stream=True) + if isawaitable(stream): + stream = await stream + async for _ in stream: + pass + + spans = memory_logger.pop() + by_name = {s["span_attributes"]["name"]: s for s in spans} + outer = by_name["outer_async_parent"] + child = by_name[f"{name}.arun"] + assert child["span_parents"] == [outer["span_id"]] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "wrapper,name", + [ + (wrap_agent, "CompatAgentPrivatePrecedence"), + (wrap_team, "CompatTeamPrivatePrecedence"), + ], +) +async def test_agno_private_method_precedence_over_public(memory_logger, wrapper, name): + """Ensures classes from older Agno versions that expose private run methods still trace those paths.""" + Component = wrapper(_make_fake_private_public_component(name)) + instance = Component() + + _ = instance.run("hello") + _ = await instance.arun("hello") + _ = instance._run("rr", "rm") + _ = await instance._arun("rr", "hello") + + spans = memory_logger.pop() + span_names = {s["span_attributes"]["name"] for s in spans} + + # Calling public methods should not trigger tracing when private wrappers are present. + assert instance.calls == ["run", "arun", "_run", "_arun"] + # Private methods are traced, and they use the same span names as public run/arun. + assert f"{name}.run" in span_names + assert f"{name}.arun" in span_names + assert len(spans) == 2