Skip to content

Commit df192c5

Browse files
committed
feat: Add support for hooks instrumentation to claude agent sdk
1 parent 28875eb commit df192c5

3 files changed

Lines changed: 1069 additions & 1 deletion

File tree

py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import dataclasses
2+
import inspect
23
import logging
34
import threading
45
import time
5-
from collections.abc import AsyncGenerator, AsyncIterable
6+
from collections.abc import AsyncGenerator, AsyncIterable, Mapping
67
from typing import Any
78

89
from braintrust.logger import start_span
@@ -80,6 +81,7 @@ def release(self) -> None:
8081
def _log_tracing_warning(exc: Exception) -> None:
8182
log.warning("Error in tracing code", exc_info=exc)
8283

84+
8385
def _parse_tool_name(tool_name: Any) -> ParsedToolName:
8486
raw_name = str(tool_name) if tool_name is not None else DEFAULT_TOOL_NAME
8587

@@ -135,6 +137,7 @@ def _serialize_tool_result_output(tool_result_block: Any) -> dict[str, Any]:
135137

136138
return output
137139

140+
138141
def _serialize_system_message(message: Any) -> dict[str, Any]:
139142
serialized = {"subtype": getattr(message, "subtype", None)}
140143

@@ -163,6 +166,137 @@ def _serialize_system_message(message: Any) -> dict[str, Any]:
163166
return serialized
164167

165168

169+
def _serialize_hook_value(value: Any) -> Any:
170+
if value is None or isinstance(value, (bool, int, float, str)):
171+
return value
172+
173+
if dataclasses.is_dataclass(value):
174+
return _serialize_hook_value(dataclasses.asdict(value))
175+
176+
if isinstance(value, Mapping):
177+
return {str(key): _serialize_hook_value(item) for key, item in value.items()}
178+
179+
if isinstance(value, (list, tuple)):
180+
return [_serialize_hook_value(item) for item in value]
181+
182+
if hasattr(value, "__dict__"):
183+
return {
184+
key: _serialize_hook_value(item)
185+
for key, item in vars(value).items()
186+
if not key.startswith("_") and not callable(item)
187+
}
188+
189+
return str(value)
190+
191+
192+
def _serialize_hook_context(context: Any) -> dict[str, Any] | None:
193+
serialized = _serialize_hook_value(context)
194+
if not isinstance(serialized, dict):
195+
return None
196+
197+
serialized.pop("signal", None)
198+
return serialized or None
199+
200+
201+
def _callback_name(callback: Any) -> str:
202+
return getattr(callback, "__qualname__", None) or getattr(callback, "__name__", None) or type(callback).__name__
203+
204+
205+
def _resolve_hook_parent(tool_use_id: Any) -> str | None:
206+
tool_span_tracker = getattr(_thread_local, "tool_span_tracker", None)
207+
if tool_span_tracker is not None:
208+
tool_span_export = tool_span_tracker.get_span_export(tool_use_id)
209+
if tool_span_export is not None:
210+
return tool_span_export
211+
212+
return getattr(_thread_local, "claude_agent_task_span_export", None)
213+
214+
215+
def _wrap_hook_callback(callback: Any, *, event_name: str, matcher: Any) -> Any:
216+
if not callable(callback) or hasattr(callback, "_braintrust_wrapped_claude_hook"):
217+
return callback
218+
219+
callback_name = _callback_name(callback)
220+
serialized_matcher = None if matcher is None else str(matcher)
221+
222+
async def wrapped_hook(*args: Any, **kwargs: Any) -> Any:
223+
hook_input = args[0] if args else kwargs.get("input")
224+
tool_use_id = kwargs.get("tool_use_id") if "tool_use_id" in kwargs else (args[1] if len(args) > 1 else None)
225+
context = kwargs.get("context") if "context" in kwargs else (args[2] if len(args) > 2 else None)
226+
227+
span_input = {"input": _serialize_hook_value(hook_input)}
228+
if tool_use_id is not None:
229+
span_input["tool_use_id"] = str(tool_use_id)
230+
231+
context_payload = _serialize_hook_context(context)
232+
if context_payload:
233+
span_input["context"] = context_payload
234+
235+
metadata = {
236+
"hook.event": event_name,
237+
"hook.callback": callback_name,
238+
}
239+
if serialized_matcher:
240+
metadata["hook.matcher"] = serialized_matcher
241+
242+
with start_span(
243+
name=f"{event_name} hook: {callback_name}",
244+
span_attributes={"type": SpanTypeAttribute.FUNCTION},
245+
input=span_input,
246+
metadata=metadata,
247+
parent=_resolve_hook_parent(tool_use_id),
248+
) as span:
249+
result = callback(*args, **kwargs)
250+
if inspect.isawaitable(result):
251+
result = await result
252+
span.log(output=_serialize_hook_value(result))
253+
return result
254+
255+
wrapped_hook._braintrust_wrapped_claude_hook = True # type: ignore[attr-defined]
256+
return wrapped_hook
257+
258+
259+
def _wrap_hook_matcher(matcher: Any, *, event_name: str) -> Any:
260+
hooks = getattr(matcher, "hooks", None)
261+
if not isinstance(hooks, list):
262+
return _wrap_hook_callback(matcher, event_name=event_name, matcher=None)
263+
264+
wrapped_hooks = [
265+
_wrap_hook_callback(callback, event_name=event_name, matcher=getattr(matcher, "matcher", None))
266+
for callback in hooks
267+
]
268+
if hooks == wrapped_hooks:
269+
return matcher
270+
271+
try:
272+
setattr(matcher, "hooks", wrapped_hooks)
273+
return matcher
274+
except Exception:
275+
if dataclasses.is_dataclass(matcher):
276+
return dataclasses.replace(matcher, hooks=wrapped_hooks)
277+
return matcher
278+
279+
280+
def _wrap_client_hooks(client: Any) -> None:
281+
options = getattr(client, "options", None)
282+
hooks_by_event = getattr(options, "hooks", None)
283+
if not isinstance(hooks_by_event, dict):
284+
return
285+
286+
for event_name, matchers in list(hooks_by_event.items()):
287+
if not isinstance(matchers, list):
288+
continue
289+
290+
wrapped_matchers = [_wrap_hook_matcher(matcher, event_name=str(event_name)) for matcher in matchers]
291+
if wrapped_matchers == matchers:
292+
continue
293+
294+
try:
295+
hooks_by_event[event_name] = wrapped_matchers
296+
except Exception:
297+
continue
298+
299+
166300
def _create_tool_wrapper_class(original_tool_class: Any) -> Any:
167301
"""Creates a wrapper class for SdkMcpTool that re-enters active TOOL spans."""
168302

@@ -613,6 +747,7 @@ class WrappedClaudeSDKClient(Wrapper):
613747
def __init__(self, *args: Any, **kwargs: Any):
614748
# Create the original client instance
615749
client = original_client_class(*args, **kwargs)
750+
_wrap_client_hooks(client)
616751
super().__init__(client)
617752
self.__client = client
618753
self.__last_prompt: str | None = None
@@ -624,6 +759,7 @@ async def query(self, *args: Any, **kwargs: Any) -> Any:
624759
# Capture the time when query is called (when LLM call starts)
625760
self.__query_start_time = time.time()
626761
self.__captured_messages = None
762+
_wrap_client_hooks(self.__client)
627763

628764
# Capture the prompt for use in receive_response
629765
prompt = args[0] if args else kwargs.get("prompt")
@@ -678,6 +814,7 @@ async def receive_response(self) -> AsyncGenerator[Any, None]:
678814
llm_tracker = LLMSpanTracker(query_start_time=self.__query_start_time)
679815
tool_tracker = ToolSpanTracker()
680816
task_event_span_tracker = TaskEventSpanTracker(span.export(), tool_tracker)
817+
_thread_local.claude_agent_task_span_export = span.export()
681818
_thread_local.tool_span_tracker = tool_tracker
682819

683820
try:
@@ -757,6 +894,8 @@ async def receive_response(self) -> AsyncGenerator[Any, None]:
757894
llm_tracker.cleanup()
758895
if hasattr(_thread_local, "tool_span_tracker"):
759896
delattr(_thread_local, "tool_span_tracker")
897+
if hasattr(_thread_local, "claude_agent_task_span_export"):
898+
delattr(_thread_local, "claude_agent_task_span_export")
760899

761900
async def __aenter__(self) -> "WrappedClaudeSDKClient":
762901
await self.__client.__aenter__()

0 commit comments

Comments
 (0)