From 2bbef0d8bedac31e4e3a4efbc2a87924ccf07c34 Mon Sep 17 00:00:00 2001 From: cedric Date: Mon, 9 Mar 2026 16:24:59 -0700 Subject: [PATCH 1/7] fix: Multi turn evaluation causes JSON serialization error --- py/src/braintrust/functions/invoke.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py/src/braintrust/functions/invoke.py b/py/src/braintrust/functions/invoke.py index 5c566c3f..d8b759ba 100644 --- a/py/src/braintrust/functions/invoke.py +++ b/py/src/braintrust/functions/invoke.py @@ -3,6 +3,7 @@ from sseclient import SSEClient from .._generated_types import FunctionTypeEnum +from ..bt_json import bt_dumps from ..logger import Exportable, _internal_get_global_state, get_span_parent_object, login, proxy_conn from ..util import response_raise_for_status from .constants import INVOKE_API_VERSION @@ -201,7 +202,8 @@ def invoke( if org_name is not None: headers["x-bt-org-name"] = org_name - resp = proxy_conn().post("function/invoke", json=request, headers=headers, stream=stream) + request = bt_dumps(request) + resp = proxy_conn().post("function/invoke", data=request, headers=headers, stream=stream) if resp.status_code == 500: raise BraintrustInvokeError(resp.text) From 4e65eb9c8fa6bd746216b76c2c19f89f82f2b21d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ViaD=C3=A9zo1er?= Date: Tue, 10 Mar 2026 09:49:15 -0700 Subject: [PATCH 2/7] Rename json request variable Co-authored-by: Abhijeet Prasad --- py/src/braintrust/functions/invoke.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/src/braintrust/functions/invoke.py b/py/src/braintrust/functions/invoke.py index d8b759ba..f0b1c3c0 100644 --- a/py/src/braintrust/functions/invoke.py +++ b/py/src/braintrust/functions/invoke.py @@ -202,8 +202,8 @@ def invoke( if org_name is not None: headers["x-bt-org-name"] = org_name - request = bt_dumps(request) - resp = proxy_conn().post("function/invoke", data=request, headers=headers, stream=stream) + request_json = bt_dumps(request) + resp = proxy_conn().post("function/invoke", data=request_json, headers=headers, stream=stream) if resp.status_code == 500: raise BraintrustInvokeError(resp.text) From 9c510c553e661fdbeb914a9382e7918aa080a88d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Tue, 10 Mar 2026 18:20:14 -0700 Subject: [PATCH 3/7] chore: regression test for serialization of LLM messages Currently tests OpenAI, Anthropic, and Google --- py/src/braintrust/functions/test_invoke.py | 51 ++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/py/src/braintrust/functions/test_invoke.py b/py/src/braintrust/functions/test_invoke.py index c38e2e10..72425702 100644 --- a/py/src/braintrust/functions/test_invoke.py +++ b/py/src/braintrust/functions/test_invoke.py @@ -1,6 +1,9 @@ """Tests for the invoke module, particularly init_function.""" +import json +import pytest +from braintrust.bt_json import bt_dumps from braintrust.functions.invoke import init_function from braintrust.logger import _internal_get_global_state, _internal_reset_global_state @@ -59,3 +62,51 @@ def test_init_function_permanently_disables_cache(self): # Try to start again - should still be disabled because of explicit disable state.span_cache.start() assert state.span_cache.disabled is True + + +class TestInvokeSerializationRegression: + """Regression tests for JSON serialization in invoke (GitHub issue #38).""" + + def test_llm_provider_messages_are_serializable(self): + provider_messages = [] + + try: + from openai.types.chat import ChatCompletionMessage + + 
provider_messages.append(ChatCompletionMessage(role="assistant", content="The answer is X.")) + except ImportError: + print("OpenAI not imported") + + try: + from anthropic.types import Message, TextBlock, Usage + + provider_messages.append( + Message( + id="msg_123", + type="message", + role="assistant", + content=[TextBlock(type="text", text="The answer is X.")], + model="claude-3-5-sonnet-20241022", + stop_reason="end_turn", + stop_sequence=None, + usage=Usage(input_tokens=10, output_tokens=20), + ) + ) + except ImportError: + print("Anthropic not imported") + + try: + from google.genai.types import Content, Part + + provider_messages.append(Content(role="model", parts=[Part(text="The answer is X.")])) + except ImportError: + print("Google GenAI not imported") + + if not provider_messages: + pytest.skip("no supported LLM provider packages available") + + for msg in provider_messages: + result = bt_dumps(msg) + assert isinstance(result, str) + # Verify the output is valid JSON and serialization didn't silently fail + json.loads(result) From b9f29d341fe00d403e4a2ac0b72beecde3e3621a Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Mon, 9 Mar 2026 18:12:30 -0400 Subject: [PATCH 4/7] fix: Make sure cross context cleanup doesn't raise an error (#58) --- AGENTS.md | 3 + Makefile | 9 +- py/Makefile | 6 +- py/src/braintrust/context.py | 5 +- py/src/braintrust/test_context.py | 21 ++++ .../wrappers/claude_agent_sdk/test_wrapper.py | 114 ++++++++++++++++++ 6 files changed, 153 insertions(+), 5 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index ab9d11e8..becb51bf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -58,6 +58,8 @@ make test-core nox -l ``` +For larger or cross-cutting changes, also run `make pylint` from `py/` before handing work off. + Targeted wrapper/session runs: ```bash @@ -77,6 +79,7 @@ Key facts: - `test_core` runs without optional vendor packages. - wrapper coverage is split across dedicated nox sessions by provider/version. - `pylint` installs the broad dependency surface before checking files. +- `cd py && make pylint` runs only `pylint`; `cd py && make lint` runs pre-commit hooks first and then `pylint`. - `test-wheel` is a wheel sanity check and requires a built wheel first. When changing behavior, run the narrowest affected session first, then expand only if needed. 
diff --git a/Makefile b/Makefile index 9185d785..ff5c5edd 100644 --- a/Makefile +++ b/Makefile @@ -24,9 +24,12 @@ test-core: test-wheel: mise exec -- $(MAKE) -C py test-wheel -lint pylint: +lint: mise exec -- $(MAKE) -C py lint +pylint: + mise exec -- $(MAKE) -C py pylint + nox: test help: @@ -35,8 +38,8 @@ help: @echo " fixup - Run pre-commit hooks across the repo" @echo " install-deps - Install Python SDK dependencies via py/Makefile" @echo " install-dev - Install pinned tools and create/update the repo env via mise" - @echo " lint - Run Python SDK lint checks via py/Makefile" - @echo " pylint - Alias for lint" + @echo " lint - Run pre-commit hooks plus Python SDK pylint via py/Makefile" + @echo " pylint - Run Python SDK pylint only via py/Makefile" @echo " nox - Alias for test" @echo " test - Run the Python SDK nox matrix via py/Makefile" @echo " test-core - Run Python SDK core tests via py/Makefile" diff --git a/py/Makefile b/py/Makefile index bfdae330..f692d4e0 100644 --- a/py/Makefile +++ b/py/Makefile @@ -2,7 +2,7 @@ PYTHON ?= python UV := $(PYTHON) -m uv UV_VERSION := $(shell awk '$$1=="uv" { print $$2 }' ../.tool-versions) -.PHONY: lint test test-wheel _template-version clean fixup build verify-build verify help install-build-deps install-dev install-optional test-core _check-git-clean +.PHONY: lint pylint test test-wheel _template-version clean fixup build verify-build verify help install-build-deps install-dev install-optional test-core _check-git-clean clean: rm -rf build dist @@ -14,6 +14,9 @@ fixup: lint: fixup nox -s pylint +pylint: + nox -s pylint + test: nox -x @@ -69,6 +72,7 @@ help: @echo " install-build-deps - Install build dependencies for CI" @echo " install-dev - Install package in development mode with all dependencies" @echo " lint - Run pylint checks" + @echo " pylint - Run pylint without pre-commit hooks" @echo " test - Run all tests" @echo " test-core - Run core tests only" @echo " test-wheel - Run tests against built wheel" diff --git a/py/src/braintrust/context.py b/py/src/braintrust/context.py index 0018f1a0..051864c1 100644 --- a/py/src/braintrust/context.py +++ b/py/src/braintrust/context.py @@ -103,7 +103,10 @@ def set_current_span(self, span_object: Any) -> Any: def unset_current_span(self, context_token: Any = None) -> None: """Unset the current active span.""" if context_token: - self._current_span.reset(context_token) + try: + self._current_span.reset(context_token) + except ValueError: + self._current_span.set(None) else: self._current_span.set(None) diff --git a/py/src/braintrust/test_context.py b/py/src/braintrust/test_context.py index 5a1d9ece..6c3c1fbd 100644 --- a/py/src/braintrust/test_context.py +++ b/py/src/braintrust/test_context.py @@ -753,6 +753,27 @@ async def task_work(): ) +@pytest.mark.asyncio +async def test_unset_current_span_with_cross_context_token_falls_back_to_clear(): + """Cross-context cleanup should not raise if the token can't be reset.""" + from braintrust.context import BraintrustContextManager + + context_manager = BraintrustContextManager() + token = context_manager.set_current_span("parent") + result = {} + + async def other_task(): + try: + context_manager.unset_current_span(token) + result["outcome"] = "ok" + except Exception as e: + result["outcome"] = f"{type(e).__name__}: {e}" + + await asyncio.create_task(other_task()) + + assert result["outcome"] == "ok" + + @pytest.mark.asyncio async def test_async_generator_early_break_context_token(test_logger, with_memory_logger): """ diff --git 
a/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py b/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py index db2fd729..e769fd6d 100644 --- a/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py +++ b/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py @@ -5,6 +5,12 @@ the actual Claude Agent SDK. """ +import asyncio +import gc +import sys +import types +from typing import Type + import pytest # Try to import the Claude Agent SDK - skip tests if not available @@ -19,6 +25,7 @@ from braintrust import logger from braintrust.span_types import SpanTypeAttribute from braintrust.test_helpers import init_test_logger +from braintrust.wrappers.claude_agent_sdk import setup_claude_agent_sdk from braintrust.wrappers.claude_agent_sdk._wrapper import ( _create_client_wrapper_class, _create_tool_wrapper_class, @@ -292,3 +299,110 @@ class TestAutoInstrumentClaudeAgentSDK: def test_auto_instrument_claude_agent_sdk(self): """Test auto_instrument patches Claude Agent SDK and creates spans.""" verify_autoinstrument_script("test_auto_claude_agent_sdk.py") + + +class _FakeClaudeAgentOptions: + def __init__(self, model, permission_mode=None): + self.model = model + self.permission_mode = permission_mode + + +class _FakeMessage: + def __init__(self, content): + self.content = content + + +class _FakeResultMessage: + def __init__(self): + self.usage = types.SimpleNamespace(input_tokens=1, output_tokens=1, cache_creation_input_tokens=0) + self.num_turns = 1 + self.session_id = "session-123" + + +class _FakeClaudeSDKClient: + def __init__(self, options): + self.options = options + self._prompt = None + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + return None + + async def query(self, prompt): + self._prompt = prompt + + async def receive_response(self): + yield _FakeMessage("Hello") + await asyncio.sleep(0) + yield _FakeResultMessage() + + +class _FakeClaudeSdkModule(types.ModuleType): + ClaudeSDKClient: Type[_FakeClaudeSDKClient] + ClaudeAgentOptions: Type[_FakeClaudeAgentOptions] + SdkMcpTool = None + tool = None + + +class _FakeConsumerModule(types.ModuleType): + ClaudeSDKClient: Type[_FakeClaudeSDKClient] + ClaudeAgentOptions: Type[_FakeClaudeAgentOptions] + + +def _install_fake_claude_sdk(monkeypatch): + fake_module = _FakeClaudeSdkModule("claude_agent_sdk") + fake_module.ClaudeSDKClient = _FakeClaudeSDKClient + fake_module.ClaudeAgentOptions = _FakeClaudeAgentOptions + monkeypatch.setitem(sys.modules, "claude_agent_sdk", fake_module) + return fake_module + + +@pytest.mark.asyncio +async def test_setup_claude_agent_sdk_repro_import_before_setup(memory_logger, monkeypatch): + """Regression test for https://github.com/braintrustdata/braintrust-sdk-python/issues/7.""" + assert not memory_logger.pop() + + fake_sdk = _install_fake_claude_sdk(monkeypatch) + consumer_module_name = "test_issue7_repro_module" + consumer_module = _FakeConsumerModule(consumer_module_name) + consumer_module.ClaudeSDKClient = fake_sdk.ClaudeSDKClient + consumer_module.ClaudeAgentOptions = fake_sdk.ClaudeAgentOptions + monkeypatch.setitem(sys.modules, consumer_module_name, consumer_module) + + # Mirror the reported import pattern: + # from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions + assert setup_claude_agent_sdk(project=PROJECT_NAME, api_key=logger.TEST_API_KEY) + assert consumer_module.ClaudeSDKClient is not _FakeClaudeSDKClient + + loop_errors = [] + received_types = [] + + async def main(): + loop = asyncio.get_running_loop() + 
loop.set_exception_handler(lambda loop, ctx: loop_errors.append(ctx.get("exception") or ctx.get("message"))) + + options = consumer_module.ClaudeAgentOptions( + model="claude-sonnet-4-20250514", + permission_mode="bypassPermissions", + ) + async with consumer_module.ClaudeSDKClient(options=options) as client: + await client.query("Hello") + async for message in client.receive_response(): + received_types.append(type(message).__name__) + + await asyncio.sleep(0) + gc.collect() + await asyncio.sleep(0.01) + + await main() + + assert loop_errors == [] + assert received_types == ["_FakeMessage", "_FakeResultMessage"] + + spans = memory_logger.pop() + task_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TASK] + assert len(task_spans) == 1 + assert task_spans[0]["span_attributes"]["name"] == "Claude Agent" + assert task_spans[0]["input"] == "Hello" From 8259398de8e1722001c779cf4aa2e4fe046c9f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Wed, 11 Mar 2026 17:08:08 -0700 Subject: [PATCH 5/7] feat: tracing for raw_response in wrap_openai openai.Client().responses.with_raw_response is now traced. Tests added, closely mirroring the tests for the non-raw responses. Covers the create and parse methods. --- py/src/braintrust/oai.py | 61 +++++++--- py/src/braintrust/wrappers/test_openai.py | 138 ++++++++++++++++++++++ 2 files changed, 180 insertions(+), 19 deletions(-) diff --git a/py/src/braintrust/oai.py b/py/src/braintrust/oai.py index df848f46..c717f5db 100644 --- a/py/src/braintrust/oai.py +++ b/py/src/braintrust/oai.py @@ -350,18 +350,23 @@ def _postprocess_streaming_results(cls, all_results: list[dict[str, Any]]) -> di class ResponseWrapper: - def __init__(self, create_fn: Callable[..., Any] | None, acreate_fn: Callable[..., Any] | None, name: str = "openai.responses.create"): + def __init__( + self, + create_fn: Callable[..., Any] | None, + acreate_fn: Callable[..., Any] | None, + name: str = "openai.responses.create", + return_raw: bool = False, + ): self.create_fn = create_fn self.acreate_fn = acreate_fn self.name = name + self.return_raw = return_raw def create(self, *args: Any, **kwargs: Any) -> Any: params = self._parse_params(kwargs) stream = kwargs.get("stream", False) - span = start_span( - **merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params) - ) + span = start_span(**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params)) should_end = True try: @@ -373,6 +378,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any: else: raw_response = create_response if stream: + def gen(): try: first = True @@ -401,7 +407,7 @@ def gen(): event_data["metrics"] = {} event_data["metrics"]["time_to_first_token"] = time.time() - start span.log(**event_data) - return raw_response + return create_response if (self.return_raw and hasattr(create_response, "parse")) else raw_response finally: if should_end: span.end() @@ -410,9 +416,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: params = self._parse_params(kwargs) stream = kwargs.get("stream", False) - span = start_span( - **merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params) - ) + span = start_span(**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params)) should_end = True try: @@ -424,6 +428,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: else: raw_response = create_response if stream: + async def gen(): try: first = True @@ 
-453,7 +458,7 @@ async def gen(): event_data["metrics"] = {} event_data["metrics"]["time_to_first_token"] = time.time() - start span.log(**event_data) - return raw_response + return create_response if (self.return_raw and hasattr(create_response, "parse")) else raw_response finally: if should_end: span.end() @@ -506,7 +511,12 @@ def _postprocess_streaming_results(cls, all_results: list[Any]) -> dict[str, Any for result in all_results: usage = getattr(result, "usage", None) - if not usage and hasattr(result, "type") and result.type == "response.completed" and hasattr(result, "response"): + if ( + not usage + and hasattr(result, "type") + and result.type == "response.completed" + and hasattr(result, "response") + ): # Handle summaries from completed response if present if hasattr(result.response, "output") and result.response.output: for output_item in result.response.output: @@ -787,29 +797,43 @@ def __init__(self, chat: Any): class ResponsesV1Wrapper(NamedWrapper): - def __init__(self, responses: Any): + def __init__(self, responses: Any, return_raw: bool = False) -> None: self.__responses = responses + self.__return_raw = return_raw + if not return_raw: + self.with_raw_response = ResponsesV1Wrapper(responses, return_raw=True) super().__init__(responses) def create(self, *args: Any, **kwargs: Any) -> Any: - return ResponseWrapper(self.__responses.with_raw_response.create, None).create(*args, **kwargs) + return ResponseWrapper(self.__responses.with_raw_response.create, None, return_raw=self.__return_raw).create( + *args, **kwargs + ) def parse(self, *args: Any, **kwargs: Any) -> Any: - return ResponseWrapper(self.__responses.with_raw_response.parse, None, "openai.responses.parse").create(*args, **kwargs) + return ResponseWrapper( + self.__responses.with_raw_response.parse, None, "openai.responses.parse", return_raw=self.__return_raw + ).create(*args, **kwargs) class AsyncResponsesV1Wrapper(NamedWrapper): - def __init__(self, responses: Any): + def __init__(self, responses: Any, return_raw: bool = False) -> None: self.__responses = responses + self.__return_raw = return_raw + if not return_raw: + self.with_raw_response = AsyncResponsesV1Wrapper(responses, return_raw=True) super().__init__(responses) async def create(self, *args: Any, **kwargs: Any) -> Any: - response = await ResponseWrapper(None, self.__responses.with_raw_response.create).acreate(*args, **kwargs) - return AsyncResponseWrapper(response) + response = await ResponseWrapper( + None, self.__responses.with_raw_response.create, return_raw=self.__return_raw + ).acreate(*args, **kwargs) + return response if self.__return_raw else AsyncResponseWrapper(response) async def parse(self, *args: Any, **kwargs: Any) -> Any: - response = await ResponseWrapper(None, self.__responses.with_raw_response.parse, "openai.responses.parse").acreate(*args, **kwargs) - return AsyncResponseWrapper(response) + response = await ResponseWrapper( + None, self.__responses.with_raw_response.parse, "openai.responses.parse", return_raw=self.__return_raw + ).acreate(*args, **kwargs) + return response if self.__return_raw else AsyncResponseWrapper(response) class BetaCompletionsV1Wrapper(NamedWrapper): @@ -938,7 +962,6 @@ def _parse_metrics_from_usage(usage: Any) -> dict[str, Any]: return metrics - def prettify_params(params: dict[str, Any]) -> dict[str, Any]: # Filter out NOT_GIVEN parameters # https://linear.app/braintrustdata/issue/BRA-2467 diff --git a/py/src/braintrust/wrappers/test_openai.py b/py/src/braintrust/wrappers/test_openai.py index 
d763cf22..91e0bce4 100644 --- a/py/src/braintrust/wrappers/test_openai.py +++ b/py/src/braintrust/wrappers/test_openai.py @@ -378,6 +378,7 @@ def __init__(self, id="test_id", type="message"): # No spans should be generated from this unit test assert not memory_logger.pop() + @pytest.mark.vcr def test_openai_embeddings(memory_logger): assert not memory_logger.pop() @@ -1210,6 +1211,142 @@ class NumberAnswer(BaseModel): assert span["output"][0]["content"][0]["parsed"]["reasoning"] +@pytest.mark.vcr +def test_openai_responses_with_raw_response_create(memory_logger): + """Test that with_raw_response.create returns HTTP response headers AND generates a tracing span.""" + assert not memory_logger.pop() + + # Unwrapped client: with_raw_response should work but produce no spans. + unwrapped_client = openai.OpenAI() + raw = unwrapped_client.responses.with_raw_response.create( + model=TEST_MODEL, + input=TEST_PROMPT, + instructions="Just the number please", + ) + assert raw.headers # HTTP response headers are accessible + response = raw.parse() + assert response.output + content = response.output[0].content[0].text + assert "24" in content or "twenty-four" in content.lower() + assert not memory_logger.pop() + + # Wrapped client: with_raw_response should ALSO generate a span. + client = wrap_openai(openai.OpenAI()) + start = time.time() + raw = client.responses.with_raw_response.create( + model=TEST_MODEL, + input=TEST_PROMPT, + instructions="Just the number please", + ) + end = time.time() + + # The raw HTTP response (with headers) must be returned to the caller. + assert raw.headers + response = raw.parse() + assert response.output + content = response.output[0].content[0].text + assert "24" in content or "twenty-four" in content.lower() + + # A span must have been recorded with correct metrics and metadata. 
+ spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + metrics = span["metrics"] + assert_metrics_are_valid(metrics, start, end) + assert TEST_MODEL in span["metadata"]["model"] + assert span["metadata"]["provider"] == "openai" + assert TEST_PROMPT in str(span["input"]) + assert len(span["output"]) > 0 + span_content = span["output"][0]["content"][0]["text"] + assert "24" in span_content or "twenty-four" in span_content.lower() + + +@pytest.mark.vcr +def test_openai_responses_with_raw_response_parse(memory_logger): + """Test that with_raw_response.parse returns HTTP response headers AND generates a tracing span.""" + assert not memory_logger.pop() + + class NumberAnswer(BaseModel): + value: int + reasoning: str + + unwrapped_client = openai.OpenAI() + raw_parse = unwrapped_client.responses.with_raw_response.parse( + model=TEST_MODEL, input=TEST_PROMPT, text_format=NumberAnswer + ) + assert raw_parse.headers + parse_response = raw_parse.parse() + assert parse_response.output_parsed + assert parse_response.output_parsed.value == 24 + assert not memory_logger.pop() + + client = wrap_openai(openai.OpenAI()) + start = time.time() + raw_parse = client.responses.with_raw_response.parse(model=TEST_MODEL, input=TEST_PROMPT, text_format=NumberAnswer) + end = time.time() + + assert raw_parse.headers + parse_response = raw_parse.parse() + assert parse_response.output_parsed + assert parse_response.output_parsed.value == 24 + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + metrics = span["metrics"] + assert_metrics_are_valid(metrics, start, end) + assert TEST_MODEL in span["metadata"]["model"] + assert span["metadata"]["provider"] == "openai" + assert TEST_PROMPT in str(span["input"]) + assert span["output"][0]["content"][0]["parsed"]["value"] == 24 + + +@pytest.mark.asyncio +@pytest.mark.vcr +async def test_openai_responses_with_raw_response_async(memory_logger): + """Async version of test_openai_responses_with_raw_response.""" + assert not memory_logger.pop() + + unwrapped_client = AsyncOpenAI() + raw = await unwrapped_client.responses.with_raw_response.create( + model=TEST_MODEL, + input=TEST_PROMPT, + instructions="Just the number please", + ) + assert raw.headers + response = raw.parse() + assert response.output + content = response.output[0].content[0].text + assert "24" in content or "twenty-four" in content.lower() + assert not memory_logger.pop() + + client = wrap_openai(AsyncOpenAI()) + start = time.time() + raw = await client.responses.with_raw_response.create( + model=TEST_MODEL, + input=TEST_PROMPT, + instructions="Just the number please", + ) + end = time.time() + + assert raw.headers + response = raw.parse() + assert response.output + content = response.output[0].content[0].text + assert "24" in content or "twenty-four" in content.lower() + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + metrics = span["metrics"] + assert_metrics_are_valid(metrics, start, end) + assert TEST_MODEL in span["metadata"]["model"] + assert TEST_PROMPT in str(span["input"]) + assert len(span["output"]) > 0 + span_content = span["output"][0]["content"][0]["text"] + assert "24" in span_content or "twenty-four" in span_content.lower() + + @pytest.mark.vcr def test_openai_parallel_tool_calls(memory_logger): """Test parallel tool calls with both streaming and non-streaming modes.""" @@ -1935,6 +2072,7 @@ def test_auto_instrument_openai(self): """Test auto_instrument patches OpenAI, creates spans, and uninstrument works.""" 
verify_autoinstrument_script("test_auto_openai.py") + class TestZAICompatibleOpenAI: """Tests for validating some ZAI compatibility with OpenAI wrapper.""" From 64a73f01c836ef973e18e1eb5ae35de7796c33dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Wed, 11 Mar 2026 17:12:39 -0700 Subject: [PATCH 6/7] Wrong branch to push --- AGENTS.md | 3 - Makefile | 9 +- py/Makefile | 6 +- py/src/braintrust/context.py | 5 +- py/src/braintrust/oai.py | 61 +++----- py/src/braintrust/test_context.py | 21 --- .../wrappers/claude_agent_sdk/test_wrapper.py | 114 --------------- py/src/braintrust/wrappers/test_openai.py | 138 ------------------ 8 files changed, 24 insertions(+), 333 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index becb51bf..ab9d11e8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -58,8 +58,6 @@ make test-core nox -l ``` -For larger or cross-cutting changes, also run `make pylint` from `py/` before handing work off. - Targeted wrapper/session runs: ```bash @@ -79,7 +77,6 @@ Key facts: - `test_core` runs without optional vendor packages. - wrapper coverage is split across dedicated nox sessions by provider/version. - `pylint` installs the broad dependency surface before checking files. -- `cd py && make pylint` runs only `pylint`; `cd py && make lint` runs pre-commit hooks first and then `pylint`. - `test-wheel` is a wheel sanity check and requires a built wheel first. When changing behavior, run the narrowest affected session first, then expand only if needed. diff --git a/Makefile b/Makefile index ff5c5edd..9185d785 100644 --- a/Makefile +++ b/Makefile @@ -24,12 +24,9 @@ test-core: test-wheel: mise exec -- $(MAKE) -C py test-wheel -lint: +lint pylint: mise exec -- $(MAKE) -C py lint -pylint: - mise exec -- $(MAKE) -C py pylint - nox: test help: @@ -38,8 +35,8 @@ help: @echo " fixup - Run pre-commit hooks across the repo" @echo " install-deps - Install Python SDK dependencies via py/Makefile" @echo " install-dev - Install pinned tools and create/update the repo env via mise" - @echo " lint - Run pre-commit hooks plus Python SDK pylint via py/Makefile" - @echo " pylint - Run Python SDK pylint only via py/Makefile" + @echo " lint - Run Python SDK lint checks via py/Makefile" + @echo " pylint - Alias for lint" @echo " nox - Alias for test" @echo " test - Run the Python SDK nox matrix via py/Makefile" @echo " test-core - Run Python SDK core tests via py/Makefile" diff --git a/py/Makefile b/py/Makefile index f692d4e0..bfdae330 100644 --- a/py/Makefile +++ b/py/Makefile @@ -2,7 +2,7 @@ PYTHON ?= python UV := $(PYTHON) -m uv UV_VERSION := $(shell awk '$$1=="uv" { print $$2 }' ../.tool-versions) -.PHONY: lint pylint test test-wheel _template-version clean fixup build verify-build verify help install-build-deps install-dev install-optional test-core _check-git-clean +.PHONY: lint test test-wheel _template-version clean fixup build verify-build verify help install-build-deps install-dev install-optional test-core _check-git-clean clean: rm -rf build dist @@ -14,9 +14,6 @@ fixup: lint: fixup nox -s pylint -pylint: - nox -s pylint - test: nox -x @@ -72,7 +69,6 @@ help: @echo " install-build-deps - Install build dependencies for CI" @echo " install-dev - Install package in development mode with all dependencies" @echo " lint - Run pylint checks" - @echo " pylint - Run pylint without pre-commit hooks" @echo " test - Run all tests" @echo " test-core - Run core tests only" @echo " test-wheel - Run tests against built wheel" diff --git a/py/src/braintrust/context.py 
b/py/src/braintrust/context.py index 051864c1..0018f1a0 100644 --- a/py/src/braintrust/context.py +++ b/py/src/braintrust/context.py @@ -103,10 +103,7 @@ def set_current_span(self, span_object: Any) -> Any: def unset_current_span(self, context_token: Any = None) -> None: """Unset the current active span.""" if context_token: - try: - self._current_span.reset(context_token) - except ValueError: - self._current_span.set(None) + self._current_span.reset(context_token) else: self._current_span.set(None) diff --git a/py/src/braintrust/oai.py b/py/src/braintrust/oai.py index c717f5db..df848f46 100644 --- a/py/src/braintrust/oai.py +++ b/py/src/braintrust/oai.py @@ -350,23 +350,18 @@ def _postprocess_streaming_results(cls, all_results: list[dict[str, Any]]) -> di class ResponseWrapper: - def __init__( - self, - create_fn: Callable[..., Any] | None, - acreate_fn: Callable[..., Any] | None, - name: str = "openai.responses.create", - return_raw: bool = False, - ): + def __init__(self, create_fn: Callable[..., Any] | None, acreate_fn: Callable[..., Any] | None, name: str = "openai.responses.create"): self.create_fn = create_fn self.acreate_fn = acreate_fn self.name = name - self.return_raw = return_raw def create(self, *args: Any, **kwargs: Any) -> Any: params = self._parse_params(kwargs) stream = kwargs.get("stream", False) - span = start_span(**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params)) + span = start_span( + **merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params) + ) should_end = True try: @@ -378,7 +373,6 @@ def create(self, *args: Any, **kwargs: Any) -> Any: else: raw_response = create_response if stream: - def gen(): try: first = True @@ -407,7 +401,7 @@ def gen(): event_data["metrics"] = {} event_data["metrics"]["time_to_first_token"] = time.time() - start span.log(**event_data) - return create_response if (self.return_raw and hasattr(create_response, "parse")) else raw_response + return raw_response finally: if should_end: span.end() @@ -416,7 +410,9 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: params = self._parse_params(kwargs) stream = kwargs.get("stream", False) - span = start_span(**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params)) + span = start_span( + **merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params) + ) should_end = True try: @@ -428,7 +424,6 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any: else: raw_response = create_response if stream: - async def gen(): try: first = True @@ -458,7 +453,7 @@ async def gen(): event_data["metrics"] = {} event_data["metrics"]["time_to_first_token"] = time.time() - start span.log(**event_data) - return create_response if (self.return_raw and hasattr(create_response, "parse")) else raw_response + return raw_response finally: if should_end: span.end() @@ -511,12 +506,7 @@ def _postprocess_streaming_results(cls, all_results: list[Any]) -> dict[str, Any for result in all_results: usage = getattr(result, "usage", None) - if ( - not usage - and hasattr(result, "type") - and result.type == "response.completed" - and hasattr(result, "response") - ): + if not usage and hasattr(result, "type") and result.type == "response.completed" and hasattr(result, "response"): # Handle summaries from completed response if present if hasattr(result.response, "output") and result.response.output: for output_item in result.response.output: @@ -797,43 +787,29 @@ def __init__(self, 
chat: Any): class ResponsesV1Wrapper(NamedWrapper): - def __init__(self, responses: Any, return_raw: bool = False) -> None: + def __init__(self, responses: Any): self.__responses = responses - self.__return_raw = return_raw - if not return_raw: - self.with_raw_response = ResponsesV1Wrapper(responses, return_raw=True) super().__init__(responses) def create(self, *args: Any, **kwargs: Any) -> Any: - return ResponseWrapper(self.__responses.with_raw_response.create, None, return_raw=self.__return_raw).create( - *args, **kwargs - ) + return ResponseWrapper(self.__responses.with_raw_response.create, None).create(*args, **kwargs) def parse(self, *args: Any, **kwargs: Any) -> Any: - return ResponseWrapper( - self.__responses.with_raw_response.parse, None, "openai.responses.parse", return_raw=self.__return_raw - ).create(*args, **kwargs) + return ResponseWrapper(self.__responses.with_raw_response.parse, None, "openai.responses.parse").create(*args, **kwargs) class AsyncResponsesV1Wrapper(NamedWrapper): - def __init__(self, responses: Any, return_raw: bool = False) -> None: + def __init__(self, responses: Any): self.__responses = responses - self.__return_raw = return_raw - if not return_raw: - self.with_raw_response = AsyncResponsesV1Wrapper(responses, return_raw=True) super().__init__(responses) async def create(self, *args: Any, **kwargs: Any) -> Any: - response = await ResponseWrapper( - None, self.__responses.with_raw_response.create, return_raw=self.__return_raw - ).acreate(*args, **kwargs) - return response if self.__return_raw else AsyncResponseWrapper(response) + response = await ResponseWrapper(None, self.__responses.with_raw_response.create).acreate(*args, **kwargs) + return AsyncResponseWrapper(response) async def parse(self, *args: Any, **kwargs: Any) -> Any: - response = await ResponseWrapper( - None, self.__responses.with_raw_response.parse, "openai.responses.parse", return_raw=self.__return_raw - ).acreate(*args, **kwargs) - return response if self.__return_raw else AsyncResponseWrapper(response) + response = await ResponseWrapper(None, self.__responses.with_raw_response.parse, "openai.responses.parse").acreate(*args, **kwargs) + return AsyncResponseWrapper(response) class BetaCompletionsV1Wrapper(NamedWrapper): @@ -962,6 +938,7 @@ def _parse_metrics_from_usage(usage: Any) -> dict[str, Any]: return metrics + def prettify_params(params: dict[str, Any]) -> dict[str, Any]: # Filter out NOT_GIVEN parameters # https://linear.app/braintrustdata/issue/BRA-2467 diff --git a/py/src/braintrust/test_context.py b/py/src/braintrust/test_context.py index 6c3c1fbd..5a1d9ece 100644 --- a/py/src/braintrust/test_context.py +++ b/py/src/braintrust/test_context.py @@ -753,27 +753,6 @@ async def task_work(): ) -@pytest.mark.asyncio -async def test_unset_current_span_with_cross_context_token_falls_back_to_clear(): - """Cross-context cleanup should not raise if the token can't be reset.""" - from braintrust.context import BraintrustContextManager - - context_manager = BraintrustContextManager() - token = context_manager.set_current_span("parent") - result = {} - - async def other_task(): - try: - context_manager.unset_current_span(token) - result["outcome"] = "ok" - except Exception as e: - result["outcome"] = f"{type(e).__name__}: {e}" - - await asyncio.create_task(other_task()) - - assert result["outcome"] == "ok" - - @pytest.mark.asyncio async def test_async_generator_early_break_context_token(test_logger, with_memory_logger): """ diff --git a/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py 
b/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py index e769fd6d..db2fd729 100644 --- a/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py +++ b/py/src/braintrust/wrappers/claude_agent_sdk/test_wrapper.py @@ -5,12 +5,6 @@ the actual Claude Agent SDK. """ -import asyncio -import gc -import sys -import types -from typing import Type - import pytest # Try to import the Claude Agent SDK - skip tests if not available @@ -25,7 +19,6 @@ from braintrust import logger from braintrust.span_types import SpanTypeAttribute from braintrust.test_helpers import init_test_logger -from braintrust.wrappers.claude_agent_sdk import setup_claude_agent_sdk from braintrust.wrappers.claude_agent_sdk._wrapper import ( _create_client_wrapper_class, _create_tool_wrapper_class, @@ -299,110 +292,3 @@ class TestAutoInstrumentClaudeAgentSDK: def test_auto_instrument_claude_agent_sdk(self): """Test auto_instrument patches Claude Agent SDK and creates spans.""" verify_autoinstrument_script("test_auto_claude_agent_sdk.py") - - -class _FakeClaudeAgentOptions: - def __init__(self, model, permission_mode=None): - self.model = model - self.permission_mode = permission_mode - - -class _FakeMessage: - def __init__(self, content): - self.content = content - - -class _FakeResultMessage: - def __init__(self): - self.usage = types.SimpleNamespace(input_tokens=1, output_tokens=1, cache_creation_input_tokens=0) - self.num_turns = 1 - self.session_id = "session-123" - - -class _FakeClaudeSDKClient: - def __init__(self, options): - self.options = options - self._prompt = None - - async def __aenter__(self): - return self - - async def __aexit__(self, *args): - return None - - async def query(self, prompt): - self._prompt = prompt - - async def receive_response(self): - yield _FakeMessage("Hello") - await asyncio.sleep(0) - yield _FakeResultMessage() - - -class _FakeClaudeSdkModule(types.ModuleType): - ClaudeSDKClient: Type[_FakeClaudeSDKClient] - ClaudeAgentOptions: Type[_FakeClaudeAgentOptions] - SdkMcpTool = None - tool = None - - -class _FakeConsumerModule(types.ModuleType): - ClaudeSDKClient: Type[_FakeClaudeSDKClient] - ClaudeAgentOptions: Type[_FakeClaudeAgentOptions] - - -def _install_fake_claude_sdk(monkeypatch): - fake_module = _FakeClaudeSdkModule("claude_agent_sdk") - fake_module.ClaudeSDKClient = _FakeClaudeSDKClient - fake_module.ClaudeAgentOptions = _FakeClaudeAgentOptions - monkeypatch.setitem(sys.modules, "claude_agent_sdk", fake_module) - return fake_module - - -@pytest.mark.asyncio -async def test_setup_claude_agent_sdk_repro_import_before_setup(memory_logger, monkeypatch): - """Regression test for https://github.com/braintrustdata/braintrust-sdk-python/issues/7.""" - assert not memory_logger.pop() - - fake_sdk = _install_fake_claude_sdk(monkeypatch) - consumer_module_name = "test_issue7_repro_module" - consumer_module = _FakeConsumerModule(consumer_module_name) - consumer_module.ClaudeSDKClient = fake_sdk.ClaudeSDKClient - consumer_module.ClaudeAgentOptions = fake_sdk.ClaudeAgentOptions - monkeypatch.setitem(sys.modules, consumer_module_name, consumer_module) - - # Mirror the reported import pattern: - # from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions - assert setup_claude_agent_sdk(project=PROJECT_NAME, api_key=logger.TEST_API_KEY) - assert consumer_module.ClaudeSDKClient is not _FakeClaudeSDKClient - - loop_errors = [] - received_types = [] - - async def main(): - loop = asyncio.get_running_loop() - loop.set_exception_handler(lambda loop, ctx: 
loop_errors.append(ctx.get("exception") or ctx.get("message"))) - - options = consumer_module.ClaudeAgentOptions( - model="claude-sonnet-4-20250514", - permission_mode="bypassPermissions", - ) - async with consumer_module.ClaudeSDKClient(options=options) as client: - await client.query("Hello") - async for message in client.receive_response(): - received_types.append(type(message).__name__) - - await asyncio.sleep(0) - gc.collect() - await asyncio.sleep(0.01) - - await main() - - assert loop_errors == [] - assert received_types == ["_FakeMessage", "_FakeResultMessage"] - - spans = memory_logger.pop() - task_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TASK] - assert len(task_spans) == 1 - assert task_spans[0]["span_attributes"]["name"] == "Claude Agent" - assert task_spans[0]["input"] == "Hello" diff --git a/py/src/braintrust/wrappers/test_openai.py b/py/src/braintrust/wrappers/test_openai.py index 91e0bce4..d763cf22 100644 --- a/py/src/braintrust/wrappers/test_openai.py +++ b/py/src/braintrust/wrappers/test_openai.py @@ -378,7 +378,6 @@ def __init__(self, id="test_id", type="message"): # No spans should be generated from this unit test assert not memory_logger.pop() - @pytest.mark.vcr def test_openai_embeddings(memory_logger): assert not memory_logger.pop() @@ -1211,142 +1210,6 @@ class NumberAnswer(BaseModel): assert span["output"][0]["content"][0]["parsed"]["reasoning"] -@pytest.mark.vcr -def test_openai_responses_with_raw_response_create(memory_logger): - """Test that with_raw_response.create returns HTTP response headers AND generates a tracing span.""" - assert not memory_logger.pop() - - # Unwrapped client: with_raw_response should work but produce no spans. - unwrapped_client = openai.OpenAI() - raw = unwrapped_client.responses.with_raw_response.create( - model=TEST_MODEL, - input=TEST_PROMPT, - instructions="Just the number please", - ) - assert raw.headers # HTTP response headers are accessible - response = raw.parse() - assert response.output - content = response.output[0].content[0].text - assert "24" in content or "twenty-four" in content.lower() - assert not memory_logger.pop() - - # Wrapped client: with_raw_response should ALSO generate a span. - client = wrap_openai(openai.OpenAI()) - start = time.time() - raw = client.responses.with_raw_response.create( - model=TEST_MODEL, - input=TEST_PROMPT, - instructions="Just the number please", - ) - end = time.time() - - # The raw HTTP response (with headers) must be returned to the caller. - assert raw.headers - response = raw.parse() - assert response.output - content = response.output[0].content[0].text - assert "24" in content or "twenty-four" in content.lower() - - # A span must have been recorded with correct metrics and metadata. 
- spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - metrics = span["metrics"] - assert_metrics_are_valid(metrics, start, end) - assert TEST_MODEL in span["metadata"]["model"] - assert span["metadata"]["provider"] == "openai" - assert TEST_PROMPT in str(span["input"]) - assert len(span["output"]) > 0 - span_content = span["output"][0]["content"][0]["text"] - assert "24" in span_content or "twenty-four" in span_content.lower() - - -@pytest.mark.vcr -def test_openai_responses_with_raw_response_parse(memory_logger): - """Test that with_raw_response.parse returns HTTP response headers AND generates a tracing span.""" - assert not memory_logger.pop() - - class NumberAnswer(BaseModel): - value: int - reasoning: str - - unwrapped_client = openai.OpenAI() - raw_parse = unwrapped_client.responses.with_raw_response.parse( - model=TEST_MODEL, input=TEST_PROMPT, text_format=NumberAnswer - ) - assert raw_parse.headers - parse_response = raw_parse.parse() - assert parse_response.output_parsed - assert parse_response.output_parsed.value == 24 - assert not memory_logger.pop() - - client = wrap_openai(openai.OpenAI()) - start = time.time() - raw_parse = client.responses.with_raw_response.parse(model=TEST_MODEL, input=TEST_PROMPT, text_format=NumberAnswer) - end = time.time() - - assert raw_parse.headers - parse_response = raw_parse.parse() - assert parse_response.output_parsed - assert parse_response.output_parsed.value == 24 - - spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - metrics = span["metrics"] - assert_metrics_are_valid(metrics, start, end) - assert TEST_MODEL in span["metadata"]["model"] - assert span["metadata"]["provider"] == "openai" - assert TEST_PROMPT in str(span["input"]) - assert span["output"][0]["content"][0]["parsed"]["value"] == 24 - - -@pytest.mark.asyncio -@pytest.mark.vcr -async def test_openai_responses_with_raw_response_async(memory_logger): - """Async version of test_openai_responses_with_raw_response.""" - assert not memory_logger.pop() - - unwrapped_client = AsyncOpenAI() - raw = await unwrapped_client.responses.with_raw_response.create( - model=TEST_MODEL, - input=TEST_PROMPT, - instructions="Just the number please", - ) - assert raw.headers - response = raw.parse() - assert response.output - content = response.output[0].content[0].text - assert "24" in content or "twenty-four" in content.lower() - assert not memory_logger.pop() - - client = wrap_openai(AsyncOpenAI()) - start = time.time() - raw = await client.responses.with_raw_response.create( - model=TEST_MODEL, - input=TEST_PROMPT, - instructions="Just the number please", - ) - end = time.time() - - assert raw.headers - response = raw.parse() - assert response.output - content = response.output[0].content[0].text - assert "24" in content or "twenty-four" in content.lower() - - spans = memory_logger.pop() - assert len(spans) == 1 - span = spans[0] - metrics = span["metrics"] - assert_metrics_are_valid(metrics, start, end) - assert TEST_MODEL in span["metadata"]["model"] - assert TEST_PROMPT in str(span["input"]) - assert len(span["output"]) > 0 - span_content = span["output"][0]["content"][0]["text"] - assert "24" in span_content or "twenty-four" in span_content.lower() - - @pytest.mark.vcr def test_openai_parallel_tool_calls(memory_logger): """Test parallel tool calls with both streaming and non-streaming modes.""" @@ -2072,7 +1935,6 @@ def test_auto_instrument_openai(self): """Test auto_instrument patches OpenAI, creates spans, and uninstrument works.""" 
verify_autoinstrument_script("test_auto_openai.py") - class TestZAICompatibleOpenAI: """Tests for validating some ZAI compatibility with OpenAI wrapper.""" From 38504fa1fa12bdb69c3712a45cb75059a718c4aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Halber?= Date: Wed, 11 Mar 2026 17:38:53 -0700 Subject: [PATCH 7/7] chore: improve regression test about serializing LLM message --- py/src/braintrust/functions/test_invoke.py | 100 +++++++++++---------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/py/src/braintrust/functions/test_invoke.py b/py/src/braintrust/functions/test_invoke.py index 72425702..264217e2 100644 --- a/py/src/braintrust/functions/test_invoke.py +++ b/py/src/braintrust/functions/test_invoke.py @@ -1,10 +1,10 @@ """Tests for the invoke module, particularly init_function.""" import json +from unittest.mock import MagicMock, patch import pytest -from braintrust.bt_json import bt_dumps -from braintrust.functions.invoke import init_function +from braintrust.functions.invoke import init_function, invoke from braintrust.logger import _internal_get_global_state, _internal_reset_global_state @@ -64,49 +64,53 @@ def test_init_function_permanently_disables_cache(self): assert state.span_cache.disabled is True -class TestInvokeSerializationRegression: - """Regression tests for JSON serialization in invoke (GitHub issue #38).""" - - def test_llm_provider_messages_are_serializable(self): - provider_messages = [] - - try: - from openai.types.chat import ChatCompletionMessage - - provider_messages.append(ChatCompletionMessage(role="assistant", content="The answer is X.")) - except ImportError: - print("OpenAI not imported") - - try: - from anthropic.types import Message, TextBlock, Usage - - provider_messages.append( - Message( - id="msg_123", - type="message", - role="assistant", - content=[TextBlock(type="text", text="The answer is X.")], - model="claude-3-5-sonnet-20241022", - stop_reason="end_turn", - stop_sequence=None, - usage=Usage(input_tokens=10, output_tokens=20), - ) - ) - except ImportError: - print("Anthropic not imported") - - try: - from google.genai.types import Content, Part - - provider_messages.append(Content(role="model", parts=[Part(text="The answer is X.")])) - except ImportError: - print("Google GenAI not imported") - - if not provider_messages: - pytest.skip("no supported LLM provider packages available") - - for msg in provider_messages: - result = bt_dumps(msg) - assert isinstance(result, str) - # Verify the output is valid JSON and serialization didn't silently fail - json.loads(result) +def _invoke_with_messages(messages): + """Call invoke() with mocked proxy_conn; return the parsed request body.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + mock_conn = MagicMock() + mock_conn.post.return_value = mock_resp + + with ( + patch("braintrust.functions.invoke.login"), + patch("braintrust.functions.invoke.get_span_parent_object") as mock_parent, + patch("braintrust.functions.invoke.proxy_conn", return_value=mock_conn), + ): + mock_parent.return_value.export.return_value = "span-export" + invoke(project_name="test-project", slug="test-fn", messages=messages) + + kwargs = mock_conn.post.call_args.kwargs + assert "data" in kwargs, "invoke must use data= (bt_dumps) not json= (json.dumps) (see issue 38)" + assert "json" not in kwargs + return json.loads(kwargs["data"]) + + +def test_invoke_serializes_openai_messages(): + openai_chat = pytest.importorskip("openai.types.chat") + msg = 
openai_chat.ChatCompletionMessage(role="assistant", content="The answer is X.") + parsed = _invoke_with_messages([msg]) + assert isinstance(parsed, dict) and parsed + + +def test_invoke_serializes_anthropic_messages(): + anthropic_types = pytest.importorskip("anthropic.types") + msg = anthropic_types.Message( + id="msg_123", + type="message", + role="assistant", + content=[anthropic_types.TextBlock(type="text", text="The answer is X.")], + model="claude-3-5-sonnet-20241022", + stop_reason="end_turn", + stop_sequence=None, + usage=anthropic_types.Usage(input_tokens=10, output_tokens=20), + ) + parsed = _invoke_with_messages([msg]) + assert isinstance(parsed, dict) and parsed + + +def test_invoke_serializes_google_messages(): + google_types = pytest.importorskip("google.genai.types") + msg = google_types.Content(role="model", parts=[google_types.Part(text="The answer is X.")]) + parsed = _invoke_with_messages([msg]) + assert isinstance(parsed, dict) and parsed
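
A note on the serialization fix (patches 1, 3, and 7): as the PATCH 7/7 test asserts, invoke() previously sent the request body via json=, which encodes with the standard library's json.dumps and raises TypeError on the pydantic message objects returned by LLM provider SDKs; the fix pre-serializes with bt_dumps and sends the resulting string via data=. Below is a minimal sketch of both paths, assuming only that the openai package is installed and that bt_dumps(obj) returns a JSON string for provider message objects (which the regression tests verify):

    import json

    from braintrust.bt_json import bt_dumps
    from openai.types.chat import ChatCompletionMessage

    # A provider message carried over from a previous turn, as in a
    # multi-turn evaluation (same construction as the regression test).
    msg = ChatCompletionMessage(role="assistant", content="The answer is X.")

    # Pre-patch path: json= encodes with json.dumps, which cannot handle
    # pydantic models.
    try:
        json.dumps({"messages": [msg]})
    except TypeError as exc:
        print(exc)  # Object of type ChatCompletionMessage is not JSON serializable

    # Post-patch path: pre-serialize with bt_dumps, then send the string
    # via data= so no second serialization is attempted.
    request_json = bt_dumps({"messages": [msg]})
    assert json.loads(request_json)["messages"][0]["role"] == "assistant"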