11import dataclasses
2+ import inspect
23import logging
34import threading
45import time
5- from collections .abc import AsyncGenerator , AsyncIterable
6+ from collections .abc import AsyncGenerator , AsyncIterable , Mapping
67from typing import Any
78
89from braintrust .logger import start_span
@@ -80,6 +81,7 @@ def release(self) -> None:
8081def _log_tracing_warning (exc : Exception ) -> None :
8182 log .warning ("Error in tracing code" , exc_info = exc )
8283
84+
8385def _parse_tool_name (tool_name : Any ) -> ParsedToolName :
8486 raw_name = str (tool_name ) if tool_name is not None else DEFAULT_TOOL_NAME
8587
@@ -135,6 +137,7 @@ def _serialize_tool_result_output(tool_result_block: Any) -> dict[str, Any]:
135137
136138 return output
137139
140+
138141def _serialize_system_message (message : Any ) -> dict [str , Any ]:
139142 serialized = {"subtype" : getattr (message , "subtype" , None )}
140143
@@ -163,6 +166,137 @@ def _serialize_system_message(message: Any) -> dict[str, Any]:
163166 return serialized
164167
165168
169+ def _serialize_hook_value (value : Any ) -> Any :
170+ if value is None or isinstance (value , (bool , int , float , str )):
171+ return value
172+
173+ if dataclasses .is_dataclass (value ):
174+ return _serialize_hook_value (dataclasses .asdict (value ))
175+
176+ if isinstance (value , Mapping ):
177+ return {str (key ): _serialize_hook_value (item ) for key , item in value .items ()}
178+
179+ if isinstance (value , (list , tuple )):
180+ return [_serialize_hook_value (item ) for item in value ]
181+
182+ if hasattr (value , "__dict__" ):
183+ return {
184+ key : _serialize_hook_value (item )
185+ for key , item in vars (value ).items ()
186+ if not key .startswith ("_" ) and not callable (item )
187+ }
188+
189+ return str (value )
190+
191+
192+ def _serialize_hook_context (context : Any ) -> dict [str , Any ] | None :
193+ serialized = _serialize_hook_value (context )
194+ if not isinstance (serialized , dict ):
195+ return None
196+
197+ serialized .pop ("signal" , None )
198+ return serialized or None
199+
200+
201+ def _callback_name (callback : Any ) -> str :
202+ return getattr (callback , "__qualname__" , None ) or getattr (callback , "__name__" , None ) or type (callback ).__name__
203+
204+
205+ def _resolve_hook_parent (tool_use_id : Any ) -> str | None :
206+ tool_span_tracker = getattr (_thread_local , "tool_span_tracker" , None )
207+ if tool_span_tracker is not None :
208+ tool_span_export = tool_span_tracker .get_span_export (tool_use_id )
209+ if tool_span_export is not None :
210+ return tool_span_export
211+
212+ return getattr (_thread_local , "claude_agent_task_span_export" , None )
213+
214+
215+ def _wrap_hook_callback (callback : Any , * , event_name : str , matcher : Any ) -> Any :
216+ if not callable (callback ) or hasattr (callback , "_braintrust_wrapped_claude_hook" ):
217+ return callback
218+
219+ callback_name = _callback_name (callback )
220+ serialized_matcher = None if matcher is None else str (matcher )
221+
222+ async def wrapped_hook (* args : Any , ** kwargs : Any ) -> Any :
223+ hook_input = args [0 ] if args else kwargs .get ("input" )
224+ tool_use_id = kwargs .get ("tool_use_id" ) if "tool_use_id" in kwargs else (args [1 ] if len (args ) > 1 else None )
225+ context = kwargs .get ("context" ) if "context" in kwargs else (args [2 ] if len (args ) > 2 else None )
226+
227+ span_input = {"input" : _serialize_hook_value (hook_input )}
228+ if tool_use_id is not None :
229+ span_input ["tool_use_id" ] = str (tool_use_id )
230+
231+ context_payload = _serialize_hook_context (context )
232+ if context_payload :
233+ span_input ["context" ] = context_payload
234+
235+ metadata = {
236+ "hook.event" : event_name ,
237+ "hook.callback" : callback_name ,
238+ }
239+ if serialized_matcher :
240+ metadata ["hook.matcher" ] = serialized_matcher
241+
242+ with start_span (
243+ name = f"{ event_name } hook: { callback_name } " ,
244+ span_attributes = {"type" : SpanTypeAttribute .FUNCTION },
245+ input = span_input ,
246+ metadata = metadata ,
247+ parent = _resolve_hook_parent (tool_use_id ),
248+ ) as span :
249+ result = callback (* args , ** kwargs )
250+ if inspect .isawaitable (result ):
251+ result = await result
252+ span .log (output = _serialize_hook_value (result ))
253+ return result
254+
255+ wrapped_hook ._braintrust_wrapped_claude_hook = True # type: ignore[attr-defined]
256+ return wrapped_hook
257+
258+
259+ def _wrap_hook_matcher (matcher : Any , * , event_name : str ) -> Any :
260+ hooks = getattr (matcher , "hooks" , None )
261+ if not isinstance (hooks , list ):
262+ return _wrap_hook_callback (matcher , event_name = event_name , matcher = None )
263+
264+ wrapped_hooks = [
265+ _wrap_hook_callback (callback , event_name = event_name , matcher = getattr (matcher , "matcher" , None ))
266+ for callback in hooks
267+ ]
268+ if hooks == wrapped_hooks :
269+ return matcher
270+
271+ try :
272+ setattr (matcher , "hooks" , wrapped_hooks )
273+ return matcher
274+ except Exception :
275+ if dataclasses .is_dataclass (matcher ):
276+ return dataclasses .replace (matcher , hooks = wrapped_hooks )
277+ return matcher
278+
279+
280+ def _wrap_client_hooks (client : Any ) -> None :
281+ options = getattr (client , "options" , None )
282+ hooks_by_event = getattr (options , "hooks" , None )
283+ if not isinstance (hooks_by_event , dict ):
284+ return
285+
286+ for event_name , matchers in list (hooks_by_event .items ()):
287+ if not isinstance (matchers , list ):
288+ continue
289+
290+ wrapped_matchers = [_wrap_hook_matcher (matcher , event_name = str (event_name )) for matcher in matchers ]
291+ if wrapped_matchers == matchers :
292+ continue
293+
294+ try :
295+ hooks_by_event [event_name ] = wrapped_matchers
296+ except Exception :
297+ continue
298+
299+
166300def _create_tool_wrapper_class (original_tool_class : Any ) -> Any :
167301 """Creates a wrapper class for SdkMcpTool that re-enters active TOOL spans."""
168302
@@ -613,6 +747,7 @@ class WrappedClaudeSDKClient(Wrapper):
613747 def __init__ (self , * args : Any , ** kwargs : Any ):
614748 # Create the original client instance
615749 client = original_client_class (* args , ** kwargs )
750+ _wrap_client_hooks (client )
616751 super ().__init__ (client )
617752 self .__client = client
618753 self .__last_prompt : str | None = None
@@ -624,6 +759,7 @@ async def query(self, *args: Any, **kwargs: Any) -> Any:
624759 # Capture the time when query is called (when LLM call starts)
625760 self .__query_start_time = time .time ()
626761 self .__captured_messages = None
762+ _wrap_client_hooks (self .__client )
627763
628764 # Capture the prompt for use in receive_response
629765 prompt = args [0 ] if args else kwargs .get ("prompt" )
@@ -678,6 +814,7 @@ async def receive_response(self) -> AsyncGenerator[Any, None]:
678814 llm_tracker = LLMSpanTracker (query_start_time = self .__query_start_time )
679815 tool_tracker = ToolSpanTracker ()
680816 task_event_span_tracker = TaskEventSpanTracker (span .export (), tool_tracker )
817+ _thread_local .claude_agent_task_span_export = span .export ()
681818 _thread_local .tool_span_tracker = tool_tracker
682819
683820 try :
@@ -757,6 +894,8 @@ async def receive_response(self) -> AsyncGenerator[Any, None]:
757894 llm_tracker .cleanup ()
758895 if hasattr (_thread_local , "tool_span_tracker" ):
759896 delattr (_thread_local , "tool_span_tracker" )
897+ if hasattr (_thread_local , "claude_agent_task_span_export" ):
898+ delattr (_thread_local , "claude_agent_task_span_export" )
760899
761900 async def __aenter__ (self ) -> "WrappedClaudeSDKClient" :
762901 await self .__client .__aenter__ ()
0 commit comments