diff --git a/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py b/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py index 775b52c0..8dfffa97 100644 --- a/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py +++ b/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_integration.py @@ -2333,19 +2333,19 @@ def test_setup_pydantic_ai_is_idempotent_across_new_patch_points(): assert agent_graph_module.ToolManager.__dict__[tool_method_name] is tool_method -def test_serialize_content_part_with_binary_content(): - """Unit test to verify _serialize_content_part handles BinaryContent correctly. +def test_shape_content_part_with_binary_content(): + """Unit test to verify _shape_content_part handles BinaryContent correctly. - This tests the direct serialization of BinaryContent objects and verifies + This tests the direct shaping of BinaryContent objects and verifies they are converted to Braintrust Attachment objects. """ - from braintrust.integrations.pydantic_ai.tracing import _serialize_content_part + from braintrust.integrations.pydantic_ai.tracing import _shape_content_part from braintrust.logger import Attachment from pydantic_ai.models.function import BinaryContent # Test 1: Direct BinaryContent serialization binary = BinaryContent(data=b"test pdf data", media_type="application/pdf") - result = _serialize_content_part(binary) + result = _shape_content_part(binary) assert result is not None, "Should serialize BinaryContent" assert result["type"] == "binary", "Should have type 'binary'" @@ -2356,14 +2356,14 @@ def test_serialize_content_part_with_binary_content(): assert result["attachment"]._reference["content_type"] == "application/pdf" -def test_serialize_content_part_with_user_prompt_part(): - """Unit test to verify _serialize_content_part handles UserPromptPart with nested BinaryContent. +def test_shape_content_part_with_user_prompt_part(): + """Unit test to verify _shape_content_part handles UserPromptPart with nested BinaryContent. This is the critical test for the bug: when a UserPromptPart has a content list - containing BinaryContent, we need to recursively serialize the content items + containing BinaryContent, we need to recursively shape the content items so that BinaryContent is converted to Braintrust Attachment. """ - from braintrust.integrations.pydantic_ai.tracing import _serialize_content_part + from braintrust.integrations.pydantic_ai.tracing import _shape_content_part from braintrust.logger import Attachment from pydantic_ai.messages import UserPromptPart from pydantic_ai.models.function import BinaryContent @@ -2373,10 +2373,10 @@ def test_serialize_content_part_with_user_prompt_part(): binary = BinaryContent(data=pdf_data, media_type="application/pdf") user_prompt_part = UserPromptPart(content=[binary, "What is in this document?"]) - # Serialize the UserPromptPart - result = _serialize_content_part(user_prompt_part) + # Shape the UserPromptPart + result = _shape_content_part(user_prompt_part) - # Verify the result is a dict with serialized content + # Verify the result is a dict with shaped content assert isinstance(result, dict), f"Should return dict, got {type(result)}" assert "content" in result, f"Should have 'content' key. Keys: {result.keys()}" @@ -2384,7 +2384,7 @@ def test_serialize_content_part_with_user_prompt_part(): assert isinstance(content, list), f"Content should be a list, got {type(content)}" assert len(content) == 2, f"Should have 2 content items, got {len(content)}" - # CRITICAL: First item should be serialized BinaryContent with Attachment + # CRITICAL: First item should be shaped BinaryContent with Attachment binary_item = content[0] assert isinstance(binary_item, dict), f"Binary item should be dict, got {type(binary_item)}" assert binary_item.get("type") == "binary", f"Binary item should have type='binary'. Got: {binary_item}" @@ -2398,13 +2398,13 @@ def test_serialize_content_part_with_user_prompt_part(): assert content[1] == "What is in this document?" -def test_serialize_messages_with_binary_content(): - """Unit test to verify _serialize_messages handles ModelRequest with BinaryContent in parts. +def test_shape_messages_with_binary_content(): + """Unit test to verify _shape_messages handles ModelRequest with BinaryContent in parts. - This tests the full message serialization path that's used for the chat span, + This tests the full message shaping path that's used for the chat span, ensuring that nested BinaryContent in UserPromptPart is properly converted. """ - from braintrust.integrations.pydantic_ai.tracing import _serialize_messages + from braintrust.integrations.pydantic_ai.tracing import _shape_messages from braintrust.logger import Attachment from pydantic_ai.messages import ModelRequest, UserPromptPart from pydantic_ai.models.function import BinaryContent @@ -2415,9 +2415,9 @@ def test_serialize_messages_with_binary_content(): user_prompt_part = UserPromptPart(content=[binary, "What is in this document?"]) model_request = ModelRequest(parts=[user_prompt_part]) - # Serialize the messages + # Shape the messages messages = [model_request] - result = _serialize_messages(messages) + result = _shape_messages(messages) # Verify structure assert len(result) == 1, f"Should have 1 message, got {len(result)}" @@ -2435,7 +2435,7 @@ def test_serialize_messages_with_binary_content(): assert isinstance(content, list), f"Content should be list, got {type(content)}" assert len(content) == 2, f"Should have 2 content items, got {len(content)}" - # CRITICAL: First content item should be serialized BinaryContent with Attachment + # CRITICAL: First content item should be shaped BinaryContent with Attachment binary_item = content[0] assert isinstance(binary_item, dict), f"Binary item should be dict, got {type(binary_item)}" assert binary_item.get("type") == "binary", f"Binary item should have type='binary'. Got: {binary_item}" diff --git a/py/src/braintrust/integrations/pydantic_ai/tracing.py b/py/src/braintrust/integrations/pydantic_ai/tracing.py index b4fb6372..1f131457 100644 --- a/py/src/braintrust/integrations/pydantic_ai/tracing.py +++ b/py/src/braintrust/integrations/pydantic_ai/tracing.py @@ -3,10 +3,10 @@ import logging import sys import time +from collections.abc import Mapping from contextlib import AbstractAsyncContextManager from typing import Any -from braintrust.bt_json import bt_safe_deep_copy from braintrust.integrations.utils import _materialize_attachment from braintrust.logger import start_span from braintrust.span_types import SpanTypeAttribute @@ -82,7 +82,7 @@ async def _agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any _maybe_create_tool_spans_from_messages(result) - output = _serialize_result_output(result) + output = _shape_result_output(result) agent_span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time)) return result finally: @@ -106,7 +106,7 @@ def _agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) _maybe_create_tool_spans_from_messages(result) - output = _serialize_result_output(result) + output = _shape_result_output(result) agent_span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time)) return result finally: @@ -213,7 +213,7 @@ async def _agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: An } if final_result: - output = _serialize_result_output(final_result) + output = _shape_result_output(final_result) agent_span.log(output=output, metrics=metrics) finally: @@ -236,7 +236,7 @@ async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): result = await wrapped(*args, **kwargs) end_time = time.time() - output = _serialize_model_response(result) + output = _shape_model_response(result) span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time)) return result @@ -259,7 +259,7 @@ def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): result = wrapped(*args, **kwargs) end_time = time.time() - output = _serialize_model_response(result) + output = _shape_model_response(result) span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time)) return result @@ -313,7 +313,7 @@ async def wrapper(*args, **kwargs): result = await original_func(*args, **kwargs) end_time = time.time() - output = _serialize_model_response(result) + output = _shape_model_response(result) span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time)) return result @@ -334,7 +334,7 @@ def wrapper(*args, **kwargs): result = original_func(*args, **kwargs) end_time = time.time() - output = _serialize_model_response(result) + output = _shape_model_response(result) span.log(output=output, metrics=_wrapper_span_metrics(start_time, end_time)) return result @@ -382,11 +382,11 @@ def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any) messages = args[0] if len(args) > 0 else kwargs.get("messages") model_settings = args[1] if len(args) > 1 else kwargs.get("model_settings") - serialized_messages = _serialize_messages(messages) + shaped_messages = _shape_messages(messages) - input_data = {"messages": serialized_messages} + input_data = {"messages": shaped_messages} if model_settings is not None: - input_data["model_settings"] = bt_safe_deep_copy(model_settings) + input_data["model_settings"] = model_settings metadata = _build_model_metadata(model_name, provider, model_settings=None) @@ -409,7 +409,7 @@ async def model_request_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: result = await wrapped(*args, **kwargs) end_time = time.time() - output = _serialize_model_response(result) + output = _shape_model_response(result) metrics = _extract_response_metrics(result, start_time, end_time) span.log(output=output, metrics=metrics) @@ -475,7 +475,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): _maybe_create_tool_spans_from_messages(self.stream_result) - output = _serialize_stream_output(self.stream_result) + output = _shape_stream_output(self.stream_result) self.span_cm.log( output=output, metrics=_wrapper_span_metrics(self.start_time, end_time, self._first_token_time), @@ -576,7 +576,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): try: final_response = self.stream.get() - output = _serialize_model_response(final_response) + output = _shape_model_response(final_response) if self.span_type == SpanTypeAttribute.LLM: metrics = _extract_response_metrics( final_response, self.start_time, end_time, self._first_token_time @@ -686,7 +686,7 @@ def _finalize(self): _maybe_create_tool_spans_from_messages(self._stream_result) - output = _serialize_stream_output(self._stream_result) + output = _shape_stream_output(self._stream_result) self._span.log( output=output, metrics=_wrapper_span_metrics(self._start_time, end_time, self._first_token_time), @@ -747,7 +747,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): try: final_response = self.stream.get() - output = _serialize_model_response(final_response) + output = _shape_model_response(final_response) self.span_cm.log( output=output, metrics=_wrapper_span_metrics(self.start_time, end_time, self._first_token_time), @@ -809,7 +809,7 @@ async def _trace_tool_execution(wrapped: Any, args: Any, kwargs: Any): try: input_data = call.args_as_dict() except Exception: - input_data = bt_safe_deep_copy(getattr(call, "args", None)) + input_data = getattr(call, "args", None) metadata = {"tool_call_id": tool_call_id} if tool_call_id else None @@ -819,7 +819,7 @@ async def _trace_tool_execution(wrapped: Any, args: Any, kwargs: Any): result = await wrapped(*args, **kwargs) end_time = time.time() tool_span.log( - output=bt_safe_deep_copy(result), + output=result, metrics={"start": start_time, "end": end_time, "duration": end_time - start_time}, ) return result @@ -875,14 +875,13 @@ def _create_tool_spans_from_messages_impl(result: Any) -> None: try: input_data = part.args_as_dict() except Exception: - input_data = bt_safe_deep_copy(getattr(part, "args", None)) + input_data = getattr(part, "args", None) output_data = None return_ts: float | None = None if tool_call_id and tool_call_id in returns_by_id: return_part, return_ts = returns_by_id[tool_call_id] - output_data = bt_safe_deep_copy(getattr(return_part, "content", None)) - + output_data = getattr(return_part, "content", None) metadata = {} if tool_call_id: metadata["tool_call_id"] = tool_call_id @@ -916,137 +915,146 @@ def _msg_timestamp(msg: Any) -> float | None: return None -def _serialize_user_prompt(user_prompt: Any) -> Any: - """Serialize user prompt, handling BinaryContent and other types.""" - if user_prompt is None: - return None +_MISSING = object() +_MESSAGE_FIELDS = ("kind", "role", "timestamp") +_PART_FIELDS = ("kind", "part_kind", "tool_name", "tool_call_id") +_RESPONSE_FIELDS = ("kind", "model_name", "timestamp", "usage", "provider_response_id", "provider_details") - if isinstance(user_prompt, str): + +def _shape_user_prompt(user_prompt: Any) -> Any: + """Shape user prompt, materializing BinaryContent where needed.""" + if user_prompt is None or isinstance(user_prompt, str): return user_prompt if isinstance(user_prompt, list): - return [_serialize_content_part(part) for part in user_prompt] + return [_shape_content_part(part) for part in user_prompt] - return _serialize_content_part(user_prompt) + return _shape_content_part(user_prompt) -def _serialize_content_part(part: Any) -> Any: - """Serialize a content part, handling BinaryContent specially. +def _shape_messages(messages: Any) -> Any: + """Shape messages, replacing binary content in message parts with Attachments.""" + if not messages: + return [] - This function handles: - - BinaryContent: converts to Braintrust Attachment - - Parts with nested content (UserPromptPart): recursively serializes content items - - Strings: passes through unchanged - - Other objects: converts to dict via model_dump - """ - if part is None: - return None + return [_shape_message(message) for message in messages] - if hasattr(part, "data") and hasattr(part, "media_type") and hasattr(part, "kind"): - if part.kind == "binary": - data = part.data - media_type = part.media_type - - resolved_attachment = _materialize_attachment(data, mime_type=media_type) - if resolved_attachment is not None: - return { - "type": "binary", - "attachment": resolved_attachment.attachment, - "media_type": resolved_attachment.mime_type, - } - - if hasattr(part, "content"): - content = part.content - if isinstance(content, list): - serialized_content = [_serialize_content_part(item) for item in content] - result = bt_safe_deep_copy(part) - if isinstance(result, dict): - result["content"] = serialized_content - return result - elif content is not None: - serialized_content = _serialize_content_part(content) - result = bt_safe_deep_copy(part) - if isinstance(result, dict): - result["content"] = serialized_content - return result - if isinstance(part, str): +def _shape_message(message: Any) -> Any: + parts = _field_value(message, "parts") + if not parts: + return message + + return _shape_object( + message, fields=_MESSAGE_FIELDS, overrides={"parts": [_shape_content_part(part) for part in parts]} + ) + + +def _shape_content_part(part: Any) -> Any: + """Shape a content part, materializing binary content into Braintrust Attachments.""" + if part is None or isinstance(part, str): return part - return bt_safe_deep_copy(part) + attachment_payload = _shape_binary_content(part) + if attachment_payload is not None: + return attachment_payload + content = _field_value(part, "content") + if content is not _MISSING: + shaped_content = ( + [_shape_content_part(item) for item in content] + if isinstance(content, list) + else _shape_content_part(content) + ) + return _shape_object(part, fields=_PART_FIELDS, overrides={"content": shaped_content}) -def _serialize_messages(messages: Any) -> Any: - """Serialize messages list.""" - if not messages: - return [] + return part - result = [] - for msg in messages: - if hasattr(msg, "parts") and msg.parts: - original_parts = msg.parts - serialized_parts = [_serialize_content_part(p) for p in original_parts] - # Use model_dump with exclude to avoid serializing parts field prematurely - if hasattr(msg, "model_dump"): - try: - serialized_msg = msg.model_dump(exclude={"parts"}, exclude_none=True) - except (TypeError, ValueError): - # If exclude parameter not supported, fall back to bt_safe_deep_copy - serialized_msg = bt_safe_deep_copy(msg) - else: - serialized_msg = bt_safe_deep_copy(msg) - - if isinstance(serialized_msg, dict): - serialized_msg["parts"] = serialized_parts - else: - serialized_msg = bt_safe_deep_copy(msg) +def _shape_binary_content(part: Any) -> dict[str, Any] | None: + if _field_value(part, "kind") != "binary": + return None + + data = _field_value(part, "data") + media_type = _field_value(part, "media_type") + if data is _MISSING or media_type is _MISSING: + return None + + resolved_attachment = _materialize_attachment(data, mime_type=media_type) + if resolved_attachment is None: + return None + + return { + "type": "binary", + "attachment": resolved_attachment.attachment, + "media_type": resolved_attachment.mime_type, + } + + +def _shape_object(value: Any, *, fields: tuple[str, ...], overrides: dict[str, Any]) -> dict[str, Any]: + """Return a shallow readable shape with selected fields and overrides. + + Braintrust handles final serialization. This helper only builds a small dict + when we need to replace nested binary content with Attachments. + """ + shaped = {} + for field in fields: + field_value = _field_value(value, field) + if field_value is not _MISSING: + shaped[field] = field_value + shaped.update(overrides) + return shaped - result.append(serialized_msg) - return result +def _field_value(value: Any, field: str) -> Any: + if isinstance(value, Mapping): + return value.get(field, _MISSING) + return getattr(value, field, _MISSING) -def _serialize_result_output(result: Any) -> Any: - """Serialize agent run result output.""" +def _shape_result_output(result: Any) -> Any: + """Shape agent run result output.""" if not result: return None output_dict = {} if hasattr(result, "output"): - output_dict["output"] = bt_safe_deep_copy(result.output) + output_dict["output"] = result.output if hasattr(result, "response"): - output_dict["response"] = _serialize_model_response(result.response) + output_dict["response"] = _shape_model_response(result.response) - return output_dict if output_dict else bt_safe_deep_copy(result) + return output_dict if output_dict else result -def _serialize_stream_output(stream_result: Any) -> Any: - """Serialize stream result output.""" +def _shape_stream_output(stream_result: Any) -> Any: + """Shape stream result output.""" if not stream_result: return None output_dict = {} if hasattr(stream_result, "response"): - output_dict["response"] = _serialize_model_response(stream_result.response) + output_dict["response"] = _shape_model_response(stream_result.response) return output_dict if output_dict else None -def _serialize_model_response(response: Any) -> Any: - """Serialize a model response.""" +def _shape_model_response(response: Any) -> Any: + """Shape a model response, replacing binary parts with Attachments when present.""" if not response: return None - response_dict = bt_safe_deep_copy(response) - - if hasattr(response, "parts") and isinstance(response_dict, dict): - response_dict["parts"] = [_serialize_content_part(p) for p in response.parts] + parts = _field_value(response, "parts") + if parts is not _MISSING: + return _shape_object( + response, + fields=_RESPONSE_FIELDS, + overrides={"parts": [_shape_content_part(part) for part in parts]}, + ) - return response_dict + return response def _extract_model_info_from_model_instance(model: Any) -> tuple[str | None, str | None]: @@ -1121,7 +1129,7 @@ def _build_model_metadata(model_name: str | None, provider: str | None, model_se if provider: metadata["provider"] = provider if model_settings: - metadata["model_settings"] = bt_safe_deep_copy(model_settings) + metadata["model_settings"] = model_settings return metadata @@ -1237,8 +1245,8 @@ def _context_wrapped_async_producer() -> None: return wrapper -def _serialize_type(obj: Any) -> Any: - """Serialize a type/class for logging, handling Pydantic models and other types. +def _shape_type(obj: Any) -> Any: + """Shape a type/class for logging, handling Pydantic models and other types. This is useful for output_type, toolsets, and similar type parameters. Returns full JSON schema for Pydantic models so engineers can see exactly @@ -1248,7 +1256,7 @@ def _serialize_type(obj: Any) -> Any: # For sequences of types (like Union types or list of models) if isinstance(obj, (list, tuple)): - return [_serialize_type(item) for item in obj] + return [_shape_type(item) for item in obj] # Handle Pydantic AI's output wrappers (ToolOutput, NativeOutput, PromptedOutput, TextOutput) if hasattr(obj, "output"): @@ -1258,7 +1266,7 @@ def _serialize_type(obj: Any) -> Any: wrapper_info["name"] = obj.name if hasattr(obj, "description") and obj.description: wrapper_info["description"] = obj.description - wrapper_info["output"] = _serialize_type(obj.output) + wrapper_info["output"] = _shape_type(obj.output) return wrapper_info # If it's a Pydantic model class, return its full JSON schema @@ -1279,8 +1287,7 @@ def _serialize_type(obj: Any) -> Any: if hasattr(obj, "__name__"): return obj.__name__ - # Try standard serialization - return bt_safe_deep_copy(obj) + return obj def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tuple[dict[str, Any], dict[str, Any]]: @@ -1293,21 +1300,18 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu user_prompt = args[0] if len(args) > 0 else kwargs.get("user_prompt") if user_prompt is not None: - input_data["user_prompt"] = _serialize_user_prompt(user_prompt) + input_data["user_prompt"] = _shape_user_prompt(user_prompt) for key, value in kwargs.items(): if key == "deps": continue elif key == "message_history": - input_data[key] = _serialize_messages(value) if value is not None else None + input_data[key] = _shape_messages(value) if value is not None else None elif key in ("output_type", "toolsets"): # These often contain types/classes, use special serialization - input_data[key] = _serialize_type(value) if value is not None else None - elif key == "model_settings": - # model_settings passed to run() goes in INPUT (it's a run() parameter) - input_data[key] = bt_safe_deep_copy(value) if value is not None else None + input_data[key] = _shape_type(value) if value is not None else None else: - input_data[key] = bt_safe_deep_copy(value) if value is not None else None + input_data[key] = value if "model" in kwargs: model_name, provider = _parse_model_string(kwargs["model"]) @@ -1333,7 +1337,7 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu # output_type can be a Pydantic model, str, or other types that get converted to JSON schema if "output_type" not in kwargs and hasattr(instance, "output_type") and instance.output_type is not None: try: - metadata["output_type"] = _serialize_type(instance.output_type) + metadata["output_type"] = _shape_type(instance.output_type) except Exception as e: logger.debug(f"Failed to extract output_type from agent: {e}") @@ -1344,7 +1348,7 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu toolsets = instance.toolsets if toolsets: # Convert toolsets to a list with FULL tool schemas for input - serialized_toolsets = [] + shaped_toolsets = [] for ts in toolsets: ts_info = { "id": getattr(ts, "id", str(type(ts).__name__)), @@ -1369,8 +1373,8 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu tool_dict["parameters"] = tool_obj.function_schema.json_schema tools_list.append(tool_dict) ts_info["tools"] = tools_list - serialized_toolsets.append(ts_info) - input_data["toolsets"] = serialized_toolsets + shaped_toolsets.append(ts_info) + input_data["toolsets"] = shaped_toolsets except Exception as e: logger.debug(f"Failed to extract toolsets from agent: {e}") @@ -1402,11 +1406,11 @@ def _build_direct_model_input_and_metadata(args: Any, kwargs: Any) -> tuple[dict messages = args[1] if len(args) > 1 else kwargs.get("messages", []) if messages: - input_data["messages"] = _serialize_messages(messages) + input_data["messages"] = _shape_messages(messages) for key, value in kwargs.items(): if key not in ["model", "messages"]: - input_data[key] = bt_safe_deep_copy(value) if value is not None else None + input_data[key] = value model_name, provider = _parse_model_string(model) metadata = _build_model_metadata(model_name, provider)