@@ -150,52 +150,70 @@ def _tool_name(tool_call: Any) -> str:
150150 return str (getattr (tool_call , "name" , "unknown_tool" ))
151151
152152
153- async def _agent_call_wrapper (wrapped : Any , instance : Any , args : Any , kwargs : dict [str , Any ]) -> Any :
154- with start_span (
155- name = f"{ _agent_name (instance )} .reply" ,
156- type = SpanTypeAttribute .TASK ,
157- input = _args_kwargs_input (args , kwargs ),
158- metadata = _clean ({"agent_class" : instance .__class__ .__name__ }),
159- ) as span :
160- try :
161- result = await wrapped (* args , ** kwargs )
162- span .log (output = result )
163- return result
164- except Exception as exc :
165- span .log (error = str (exc ))
166- raise
153+ def _make_task_wrapper (
154+ * ,
155+ name_fn : Any ,
156+ metadata_fn : Any ,
157+ input_fn : Any = _args_kwargs_input ,
158+ ) -> Any :
159+ """Build a simple async wrapper that creates a TASK span and logs the result."""
160+
161+ async def _wrapper (wrapped : Any , instance : Any , args : Any , kwargs : dict [str , Any ]) -> Any :
162+ with start_span (
163+ name = name_fn (instance , args , kwargs ),
164+ type = SpanTypeAttribute .TASK ,
165+ input = input_fn (args , kwargs ),
166+ metadata = metadata_fn (instance , args , kwargs ),
167+ ) as span :
168+ try :
169+ result = await wrapped (* args , ** kwargs )
170+ span .log (output = result )
171+ return result
172+ except Exception as exc :
173+ span .log (error = str (exc ))
174+ raise
175+
176+ return _wrapper
177+
178+
179+ _agent_call_wrapper = _make_task_wrapper (
180+ name_fn = lambda instance , _a , _k : f"{ _agent_name (instance )} .reply" ,
181+ metadata_fn = lambda instance , _a , _k : _clean ({"agent_class" : instance .__class__ .__name__ }),
182+ )
183+
184+ _sequential_pipeline_wrapper = _make_task_wrapper (
185+ name_fn = lambda _i , _a , _k : "sequential_pipeline.run" ,
186+ metadata_fn = lambda _i , args , kwargs : _pipeline_metadata (args , kwargs ),
187+ )
188+
189+ _fanout_pipeline_wrapper = _make_task_wrapper (
190+ name_fn = lambda _i , _a , _k : "fanout_pipeline.run" ,
191+ metadata_fn = lambda _i , args , kwargs : _pipeline_metadata (args , kwargs ),
192+ )
167193
168194
169- async def _sequential_pipeline_wrapper (wrapped : Any , instance : Any , args : Any , kwargs : dict [str , Any ]) -> Any :
170- with start_span (
171- name = "sequential_pipeline.run" ,
172- type = SpanTypeAttribute .TASK ,
173- input = _args_kwargs_input (args , kwargs ),
174- metadata = _pipeline_metadata (args , kwargs ),
175- ) as span :
176- try :
177- result = await wrapped (* args , ** kwargs )
178- span .log (output = result )
179- return result
180- except Exception as exc :
181- span .log (error = str (exc ))
182- raise
195+ def _is_async_iterator (value : Any ) -> bool :
196+ try :
197+ return getattr (value , "__aiter__" , None ) is not None and getattr (value , "__anext__" , None ) is not None
198+ except Exception :
199+ return False
183200
184201
185- async def _fanout_pipeline_wrapper (wrapped : Any , instance : Any , args : Any , kwargs : dict [str , Any ]) -> Any :
186- with start_span (
187- name = "fanout_pipeline.run" ,
188- type = SpanTypeAttribute .TASK ,
189- input = _args_kwargs_input (args , kwargs ),
190- metadata = _pipeline_metadata (args , kwargs ),
191- ) as span :
192- try :
193- result = await wrapped (* args , ** kwargs )
194- span .log (output = result )
195- return result
196- except Exception as exc :
197- span .log (error = str (exc ))
198- raise
202+ def _deferred_stream_trace (result : Any , span : Any , stack : contextlib .ExitStack , log_fn : Any ) -> Any :
203+ """Wrap an async iterator so the span stays open until the stream is consumed."""
204+ deferred = stack .pop_all ()
205+
206+ async def _trace ():
207+ with deferred :
208+ last_chunk = None
209+ async with aclosing (result ) as agen :
210+ async for chunk in agen :
211+ last_chunk = chunk
212+ yield chunk
213+ if last_chunk is not None :
214+ log_fn (span , last_chunk )
215+
216+ return _trace ()
199217
200218
201219async def _toolkit_call_tool_function_wrapper (wrapped : Any , instance : Any , args : Any , kwargs : dict [str , Any ]) -> Any :
@@ -218,19 +236,7 @@ async def _toolkit_call_tool_function_wrapper(wrapped: Any, instance: Any, args:
218236 try :
219237 result = await wrapped (* args , ** kwargs )
220238 if _is_async_iterator (result ):
221- deferred = stack .pop_all ()
222-
223- async def _trace ():
224- with deferred :
225- last_chunk = None
226- async with aclosing (result ) as agen :
227- async for chunk in agen :
228- last_chunk = chunk
229- yield chunk
230- if last_chunk is not None :
231- span .log (output = last_chunk )
232-
233- return _trace ()
239+ return _deferred_stream_trace (result , span , stack , lambda s , chunk : s .log (output = chunk ))
234240
235241 span .log (output = result )
236242 return result
@@ -239,13 +245,6 @@ async def _trace():
239245 raise
240246
241247
242- def _is_async_iterator (value : Any ) -> bool :
243- try :
244- return getattr (value , "__aiter__" , None ) is not None and getattr (value , "__anext__" , None ) is not None
245- except Exception :
246- return False
247-
248-
249248async def _model_call_wrapper (wrapped : Any , instance : Any , args : Any , kwargs : dict [str , Any ]) -> Any :
250249 with contextlib .ExitStack () as stack :
251250 span = stack .enter_context (
@@ -259,19 +258,12 @@ async def _model_call_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: di
259258 try :
260259 result = await wrapped (* args , ** kwargs )
261260 if _is_async_iterator (result ):
262- deferred = stack .pop_all ()
263-
264- async def _trace ():
265- with deferred :
266- last_chunk = None
267- async with aclosing (result ) as agen :
268- async for chunk in agen :
269- last_chunk = chunk
270- yield chunk
271- if last_chunk is not None :
272- span .log (output = _model_call_output (last_chunk ), metrics = _extract_metrics (last_chunk ))
273-
274- return _trace ()
261+ return _deferred_stream_trace (
262+ result ,
263+ span ,
264+ stack ,
265+ lambda s , chunk : s .log (output = _model_call_output (chunk ), metrics = _extract_metrics (chunk )),
266+ )
275267
276268 span .log (output = _model_call_output (result ), metrics = _extract_metrics (result ))
277269 return result
0 commit comments