diff --git a/py/src/braintrust/wrappers/pydantic_ai.py b/py/src/braintrust/wrappers/pydantic_ai.py index c932c91a..9ed61462 100644 --- a/py/src/braintrust/wrappers/pydantic_ai.py +++ b/py/src/braintrust/wrappers/pydantic_ai.py @@ -127,6 +127,26 @@ def agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): wrap_function_wrapper(Agent, "run_sync", agent_run_sync_wrapper) + def agent_to_cli_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + _ensure_model_wrapped(instance) + input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance) + + with start_span( + name=f"agent_to_cli_sync [{instance.name}]" + if hasattr(instance, "name") and instance.name + else "agent_to_cli_sync", + type=SpanTypeAttribute.LLM, + input=input_data if input_data else None, + metadata=metadata, + ) as agent_span: + start_time = time.time() + result = wrapped(*args, **kwargs) + end_time = time.time() + agent_span.log(metrics={"start": start_time, "end": end_time, "duration": end_time - start_time}) + return result + + wrap_function_wrapper(Agent, "to_cli_sync", agent_to_cli_sync_wrapper) + def agent_run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): _ensure_model_wrapped(instance) input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance) diff --git a/py/src/braintrust/wrappers/test_pydantic_ai_integration.py b/py/src/braintrust/wrappers/test_pydantic_ai_integration.py index 4a6ad2ae..14088f74 100644 --- a/py/src/braintrust/wrappers/test_pydantic_ai_integration.py +++ b/py/src/braintrust/wrappers/test_pydantic_ai_integration.py @@ -3,6 +3,7 @@ # pyright: reportUnknownParameterType=false # pyright: reportPrivateUsage=false import asyncio +import inspect import time import pytest @@ -13,6 +14,7 @@ from pydantic import BaseModel from pydantic_ai import Agent, ModelSettings from pydantic_ai.messages import ModelRequest, UserPromptPart +from pydantic_ai.usage import UsageLimits PROJECT_NAME = "test-pydantic-ai-integration" MODEL = "openai:gpt-4o-mini" # Use cheaper model for tests @@ -168,6 +170,71 @@ def is_descendant(child_span, ancestor_id): assert "completion_tokens" in agent_sync_span["metrics"] +def test_agent_to_cli_sync(memory_logger, monkeypatch): + """Test Agent.to_cli_sync() records a CLI session span.""" + assert not memory_logger.pop() + + cli_signature = inspect.signature(Agent.to_cli_sync) + message_history = [ModelRequest(parts=[UserPromptPart(content="Previous question")])] + agent = Agent(MODEL, name="cli-agent", model_settings=ModelSettings(max_tokens=50)) + captured = {} + + async def fake_run_chat( + *, + stream, + agent, + deps, + console, + code_theme, + prog_name, + message_history, + model_settings=None, + usage_limits=None, + ): + assert stream is True + assert prog_name == "braintrust-cli" + assert message_history is not None + captured["model_settings"] = model_settings + captured["usage_limits"] = usage_limits + return 0 + + monkeypatch.setattr("pydantic_ai._cli.run_chat", fake_run_chat) + + cli_kwargs = { + "prog_name": "braintrust-cli", + "message_history": message_history, + } + # pydantic_ai 1.10.0 exposes a smaller to_cli_sync API; newer versions add + # model_settings and usage_limits, so assert those fields only when present. + if "model_settings" in cli_signature.parameters: + cli_kwargs["model_settings"] = ModelSettings(max_tokens=20, temperature=0.2) + if "usage_limits" in cli_signature.parameters: + cli_kwargs["usage_limits"] = UsageLimits(request_limit=3) + + start = time.time() + agent.to_cli_sync(**cli_kwargs) + end = time.time() + + spans = memory_logger.pop() + assert len(spans) == 1, f"Expected 1 CLI span, got {len(spans)}" + + cli_span = spans[0] + assert cli_span["span_attributes"]["type"] == SpanTypeAttribute.LLM + assert cli_span["span_attributes"]["name"] == "agent_to_cli_sync [cli-agent]" + assert cli_span["metadata"]["model"] == "gpt-4o-mini" + assert cli_span["metadata"]["provider"] == "openai" + assert cli_span["input"]["prog_name"] == "braintrust-cli" + assert "message_history" in cli_span["input"] + if "model_settings" in cli_signature.parameters: + assert captured["model_settings"] is not None + assert cli_span["input"]["model_settings"]["max_tokens"] == 20 + assert cli_span["input"]["model_settings"]["temperature"] == 0.2 + if "usage_limits" in cli_signature.parameters: + assert captured["usage_limits"] is not None + assert cli_span["input"]["usage_limits"]["request_limit"] == 3 + _assert_metrics_are_valid(cli_span["metrics"], start, end) + + @pytest.mark.vcr @pytest.mark.asyncio async def test_multiple_identical_sequential_streams(memory_logger):