Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions py/src/braintrust/wrappers/claude_agent_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,39 +66,32 @@ def setup_claude_agent_sdk(

import claude_agent_sdk

# Store original classes before patching
original_client = claude_agent_sdk.ClaudeSDKClient if hasattr(claude_agent_sdk, "ClaudeSDKClient") else None
original_tool_class = claude_agent_sdk.SdkMcpTool if hasattr(claude_agent_sdk, "SdkMcpTool") else None
original_tool_fn = claude_agent_sdk.tool if hasattr(claude_agent_sdk, "tool") else None

# Patch ClaudeSDKClient
if original_client:
wrapped_client = _create_client_wrapper_class(original_client)
claude_agent_sdk.ClaudeSDKClient = wrapped_client

# Update all modules that already imported ClaudeSDKClient
for module in list(sys.modules.values()):
if module and hasattr(module, "ClaudeSDKClient"):
if getattr(module, "ClaudeSDKClient", None) is original_client:
setattr(module, "ClaudeSDKClient", wrapped_client)

# Patch SdkMcpTool
if original_tool_class:
wrapped_tool_class = _create_tool_wrapper_class(original_tool_class)
claude_agent_sdk.SdkMcpTool = wrapped_tool_class

# Update all modules that already imported SdkMcpTool
for module in list(sys.modules.values()):
if module and hasattr(module, "SdkMcpTool"):
if getattr(module, "SdkMcpTool", None) is original_tool_class:
setattr(module, "SdkMcpTool", wrapped_tool_class)

# Patch tool() decorator
if original_tool_fn:
wrapped_tool_fn = _wrap_tool_factory(original_tool_fn)
claude_agent_sdk.tool = wrapped_tool_fn

# Update all modules that already imported tool
for module in list(sys.modules.values()):
if module and hasattr(module, "tool"):
if getattr(module, "tool", None) is original_tool_fn:
Expand Down
71 changes: 71 additions & 0 deletions py/src/braintrust/wrappers/claude_agent_sdk/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from dataclasses import dataclass
from enum import Enum
from types import MappingProxyType
from typing import Final, Mapping


class MessageClassName(str, Enum):
ASSISTANT = "AssistantMessage"
USER = "UserMessage"
RESULT = "ResultMessage"
SYSTEM = "SystemMessage"
TASK_STARTED = "TaskStartedMessage"
TASK_PROGRESS = "TaskProgressMessage"
TASK_NOTIFICATION = "TaskNotificationMessage"


class BlockClassName(str, Enum):
TEXT = "TextBlock"
TOOL_USE = "ToolUseBlock"
TOOL_RESULT = "ToolResultBlock"


class SerializedContentType(str, Enum):
TEXT = "text"
TOOL_USE = "tool_use"
TOOL_RESULT = "tool_result"


@dataclass(frozen=True)
class ToolMetadataKeys:
tool_name: str = "gen_ai.tool.name"
tool_call_id: str = "gen_ai.tool.call.id"
raw_tool_name: str = "raw_tool_name"
operation_name: str = "gen_ai.operation.name"
mcp_method_name: str = "mcp.method.name"
mcp_server: str = "mcp.server"


@dataclass(frozen=True)
class MCPToolMetadataValues:
operation_name: str = "execute_tool"
method_name: str = "tools/call"


DEFAULT_TOOL_NAME: Final[str] = "unknown"

CLAUDE_AGENT_TASK_SPAN_NAME: Final[str] = "Claude Agent"
ANTHROPIC_MESSAGES_CREATE_SPAN_NAME: Final[str] = "anthropic.messages.create"

MCP_TOOL_PREFIX: Final[str] = "mcp__"
MCP_TOOL_NAME_DELIMITER: Final[str] = "__"

TOOL_METADATA: Final[ToolMetadataKeys] = ToolMetadataKeys()
MCP_TOOL_METADATA: Final[MCPToolMetadataValues] = MCPToolMetadataValues()

SERIALIZED_CONTENT_TYPE_BY_BLOCK_CLASS: Final[Mapping[str, SerializedContentType]] = MappingProxyType(
{
BlockClassName.TEXT: SerializedContentType.TEXT,
BlockClassName.TOOL_USE: SerializedContentType.TOOL_USE,
BlockClassName.TOOL_RESULT: SerializedContentType.TOOL_RESULT,
}
)

SYSTEM_MESSAGE_TYPES: Final[frozenset[MessageClassName]] = frozenset(
{
MessageClassName.SYSTEM,
MessageClassName.TASK_STARTED,
MessageClassName.TASK_PROGRESS,
MessageClassName.TASK_NOTIFICATION,
}
)
43 changes: 37 additions & 6 deletions py/src/braintrust/wrappers/claude_agent_sdk/_test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ def get_record_mode() -> str:
return "once"


def _version_suffix() -> str:
version = getattr(claude_agent_sdk, "__version__", "unknown")
return version.replace(".", "_")


def _require_sdk() -> None:
if _CLAUDE_AGENT_SDK_IMPORT_ERROR is not None:
raise ImportError(
Expand All @@ -57,7 +52,7 @@ def _require_sdk() -> None:


def cassette_path(name: str) -> Path:
return CASSETTES_DIR / f"{name}__sdk_{_version_suffix()}.json"
return CASSETTES_DIR / f"{name}.json"


def _normalize_write(data: str, *, sanitize: bool = False) -> dict[str, Any]:
Expand Down Expand Up @@ -91,12 +86,48 @@ def _sanitize_json_for_storage(value: Any) -> Any:
if isinstance(value, list):
return [_sanitize_json_for_storage(item) for item in value]
if isinstance(value, dict):
value = _compact_initialize_message_for_storage(value)
return {key: _sanitize_field_for_storage(key, item) for key, item in value.items()}
if isinstance(value, str):
return _sanitize_string_for_storage(value)
return value


def _compact_initialize_message_for_storage(value: dict[str, Any]) -> dict[str, Any]:
if value.get("type") != "control_response":
return value

response = value.get("response")
if not isinstance(response, dict) or response.get("subtype") != "success":
return value

result = response.get("response")
if not isinstance(result, dict) or not _looks_like_initialize_response(result):
return value

compact_result: dict[str, Any] = {}
if "account" in result:
compact_result["account"] = result["account"]

for key in ("available_output_styles", "commands", "models", "agents"):
if key in result:
compact_result[key] = []

return {
**value,
"response": {
**response,
"response": compact_result,
},
}


def _looks_like_initialize_response(value: dict[str, Any]) -> bool:
return "account" in value and any(
key in value for key in ("available_output_styles", "commands", "models", "agents")
)


def _sanitize_field_for_storage(key: str, value: Any) -> Any:
if not isinstance(value, str):
return _sanitize_json_for_storage(value)
Expand Down
Loading
Loading