Skip to content

Commit c0efe5e

Browse files
committed
pr review
1 parent 8fff014 commit c0efe5e

1 file changed

Lines changed: 67 additions & 75 deletions

File tree

  • py/src/braintrust/integrations/agentscope

py/src/braintrust/integrations/agentscope/tracing.py

Lines changed: 67 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -150,52 +150,70 @@ def _tool_name(tool_call: Any) -> str:
150150
return str(getattr(tool_call, "name", "unknown_tool"))
151151

152152

153-
async def _agent_call_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: dict[str, Any]) -> Any:
154-
with start_span(
155-
name=f"{_agent_name(instance)}.reply",
156-
type=SpanTypeAttribute.TASK,
157-
input=_args_kwargs_input(args, kwargs),
158-
metadata=_clean({"agent_class": instance.__class__.__name__}),
159-
) as span:
160-
try:
161-
result = await wrapped(*args, **kwargs)
162-
span.log(output=result)
163-
return result
164-
except Exception as exc:
165-
span.log(error=str(exc))
166-
raise
153+
def _make_task_wrapper(
154+
*,
155+
name_fn: Any,
156+
metadata_fn: Any,
157+
input_fn: Any = _args_kwargs_input,
158+
) -> Any:
159+
"""Build a simple async wrapper that creates a TASK span and logs the result."""
160+
161+
async def _wrapper(wrapped: Any, instance: Any, args: Any, kwargs: dict[str, Any]) -> Any:
162+
with start_span(
163+
name=name_fn(instance, args, kwargs),
164+
type=SpanTypeAttribute.TASK,
165+
input=input_fn(args, kwargs),
166+
metadata=metadata_fn(instance, args, kwargs),
167+
) as span:
168+
try:
169+
result = await wrapped(*args, **kwargs)
170+
span.log(output=result)
171+
return result
172+
except Exception as exc:
173+
span.log(error=str(exc))
174+
raise
175+
176+
return _wrapper
177+
178+
179+
_agent_call_wrapper = _make_task_wrapper(
180+
name_fn=lambda instance, _a, _k: f"{_agent_name(instance)}.reply",
181+
metadata_fn=lambda instance, _a, _k: _clean({"agent_class": instance.__class__.__name__}),
182+
)
183+
184+
_sequential_pipeline_wrapper = _make_task_wrapper(
185+
name_fn=lambda _i, _a, _k: "sequential_pipeline.run",
186+
metadata_fn=lambda _i, args, kwargs: _pipeline_metadata(args, kwargs),
187+
)
188+
189+
_fanout_pipeline_wrapper = _make_task_wrapper(
190+
name_fn=lambda _i, _a, _k: "fanout_pipeline.run",
191+
metadata_fn=lambda _i, args, kwargs: _pipeline_metadata(args, kwargs),
192+
)
167193

168194

169-
async def _sequential_pipeline_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: dict[str, Any]) -> Any:
170-
with start_span(
171-
name="sequential_pipeline.run",
172-
type=SpanTypeAttribute.TASK,
173-
input=_args_kwargs_input(args, kwargs),
174-
metadata=_pipeline_metadata(args, kwargs),
175-
) as span:
176-
try:
177-
result = await wrapped(*args, **kwargs)
178-
span.log(output=result)
179-
return result
180-
except Exception as exc:
181-
span.log(error=str(exc))
182-
raise
195+
def _is_async_iterator(value: Any) -> bool:
196+
try:
197+
return getattr(value, "__aiter__", None) is not None and getattr(value, "__anext__", None) is not None
198+
except Exception:
199+
return False
183200

184201

185-
async def _fanout_pipeline_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: dict[str, Any]) -> Any:
186-
with start_span(
187-
name="fanout_pipeline.run",
188-
type=SpanTypeAttribute.TASK,
189-
input=_args_kwargs_input(args, kwargs),
190-
metadata=_pipeline_metadata(args, kwargs),
191-
) as span:
192-
try:
193-
result = await wrapped(*args, **kwargs)
194-
span.log(output=result)
195-
return result
196-
except Exception as exc:
197-
span.log(error=str(exc))
198-
raise
202+
def _deferred_stream_trace(result: Any, span: Any, stack: contextlib.ExitStack, log_fn: Any) -> Any:
203+
"""Wrap an async iterator so the span stays open until the stream is consumed."""
204+
deferred = stack.pop_all()
205+
206+
async def _trace():
207+
with deferred:
208+
last_chunk = None
209+
async with aclosing(result) as agen:
210+
async for chunk in agen:
211+
last_chunk = chunk
212+
yield chunk
213+
if last_chunk is not None:
214+
log_fn(span, last_chunk)
215+
216+
return _trace()
199217

200218

201219
async def _toolkit_call_tool_function_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: dict[str, Any]) -> Any:
@@ -218,19 +236,7 @@ async def _toolkit_call_tool_function_wrapper(wrapped: Any, instance: Any, args:
218236
try:
219237
result = await wrapped(*args, **kwargs)
220238
if _is_async_iterator(result):
221-
deferred = stack.pop_all()
222-
223-
async def _trace():
224-
with deferred:
225-
last_chunk = None
226-
async with aclosing(result) as agen:
227-
async for chunk in agen:
228-
last_chunk = chunk
229-
yield chunk
230-
if last_chunk is not None:
231-
span.log(output=last_chunk)
232-
233-
return _trace()
239+
return _deferred_stream_trace(result, span, stack, lambda s, chunk: s.log(output=chunk))
234240

235241
span.log(output=result)
236242
return result
@@ -239,13 +245,6 @@ async def _trace():
239245
raise
240246

241247

242-
def _is_async_iterator(value: Any) -> bool:
243-
try:
244-
return getattr(value, "__aiter__", None) is not None and getattr(value, "__anext__", None) is not None
245-
except Exception:
246-
return False
247-
248-
249248
async def _model_call_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: dict[str, Any]) -> Any:
250249
with contextlib.ExitStack() as stack:
251250
span = stack.enter_context(
@@ -259,19 +258,12 @@ async def _model_call_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: di
259258
try:
260259
result = await wrapped(*args, **kwargs)
261260
if _is_async_iterator(result):
262-
deferred = stack.pop_all()
263-
264-
async def _trace():
265-
with deferred:
266-
last_chunk = None
267-
async with aclosing(result) as agen:
268-
async for chunk in agen:
269-
last_chunk = chunk
270-
yield chunk
271-
if last_chunk is not None:
272-
span.log(output=_model_call_output(last_chunk), metrics=_extract_metrics(last_chunk))
273-
274-
return _trace()
261+
return _deferred_stream_trace(
262+
result,
263+
span,
264+
stack,
265+
lambda s, chunk: s.log(output=_model_call_output(chunk), metrics=_extract_metrics(chunk)),
266+
)
275267

276268
span.log(output=_model_call_output(result), metrics=_extract_metrics(result))
277269
return result

0 commit comments

Comments
 (0)