From cbe91aa0b28ef1e2ef6054d12aab27a69a204094 Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Fri, 1 May 2026 13:55:46 -0400 Subject: [PATCH] fix(cohere): add tool spans for chat tool calls Emit child TOOL spans for Cohere chat responses that include tool calls, covering both v1 and v2 response shapes. Rename the shared Cohere field accessor to document why tracing handles both dict and SDK model objects. Add VCR-backed regression coverage for v1 and v2 tool-call responses. Test: nox -s "test_cohere(latest)" -- -k "tool_call_spans" --- ...t_wrap_cohere_chat_v1_tool_call_spans.yaml | 83 ++++++++++ ...t_wrap_cohere_chat_v2_tool_call_spans.yaml | 81 ++++++++++ .../integrations/cohere/test_cohere.py | 88 ++++++++++- .../braintrust/integrations/cohere/tracing.py | 143 +++++++++++++----- 4 files changed, 358 insertions(+), 37 deletions(-) create mode 100644 py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v1_tool_call_spans.yaml create mode 100644 py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v2_tool_call_spans.yaml diff --git a/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v1_tool_call_spans.yaml b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v1_tool_call_spans.yaml new file mode 100644 index 00000000..9b0c8ff9 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v1_tool_call_spans.yaml @@ -0,0 +1,83 @@ +interactions: +- request: + body: '{"message":"Use the get_weather tool for Paris.","model":"command-a-03-2025","max_tokens":64,"tools":[{"name":"get_weather","description":"Get + the weather for a city.","parameter_definitions":{"city":{"description":"City + name","type":"str","required":true}}}],"force_single_step":true,"stream":false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '300' + Host: + - api.cohere.com +
User-Agent: + - cohere/6.1.0 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.12.12 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 6.1.0 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v1/chat + response: + body: + string: '{"response_id":"490af82c-d74b-4774-96aa-a48abc77b39d","text":"I will + use one or more of the available tools to find the answer","generation_id":"465b7c4c-823c-4932-bd72-cc91b3aa154d","chat_history":[{"role":"USER","message":"Use + the get_weather tool for Paris."},{"role":"CHATBOT","message":"I will use + one or more of the available tools to find the answer","tool_calls":[{"name":"get_weather","parameters":{"city":"Paris"}}]}],"finish_reason":"COMPLETE","meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":42,"output_tokens":22},"tokens":{"input_tokens":1462,"output_tokens":33},"cached_tokens":0},"tool_calls":[{"name":"get_weather","parameters":{"city":"Paris"}}]}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '684' + content-type: + - application/json + date: + - Fri, 01 May 2026 17:35:39 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + num_chars: + - '6894' + num_tokens: + - '64' + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - 9a84e7af81536f5a0ae6409b072713da + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '988' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '17' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v2_tool_call_spans.yaml 
b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v2_tool_call_spans.yaml new file mode 100644 index 00000000..6ef0ea00 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v2_tool_call_spans.yaml @@ -0,0 +1,81 @@ +interactions: +- request: + body: '{"model":"command-a-03-2025","messages":[{"role":"user","content":"Use + the get_weather tool for Paris."}],"tools":[{"type":"function","function":{"name":"get_weather","description":"Get + the weather for a city.","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}],"max_tokens":64,"tool_choice":"REQUIRED","stream":false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '361' + Host: + - api.cohere.com + User-Agent: + - cohere/6.1.0 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.12.12 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 6.1.0 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/chat + response: + body: + string: '{"id":"74529407-6112-4e14-9cb3-41ab95fe7725","message":{"role":"assistant","tool_plan":"I + will use one or more of the available tools to find the answer","tool_calls":[{"id":"get_weather_p1927913dgpp","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"Paris\"}"}}]},"finish_reason":"TOOL_CALL","usage":{"billed_units":{"input_tokens":37,"output_tokens":22},"tokens":{"input_tokens":1455,"output_tokens":33},"cached_tokens":0}}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '451' + content-type: + - application/json + date: + - Fri, 01 May 2026 17:34:51 GMT + expires: + - Thu, 01 Jan 
1970 00:00:00 GMT + num_chars: + - '6866' + num_tokens: + - '59' + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - b3c262976868c73da50611172c4163a1 + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '913' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '18' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/test_cohere.py b/py/src/braintrust/integrations/cohere/test_cohere.py index 7181964c..98900ac4 100644 --- a/py/src/braintrust/integrations/cohere/test_cohere.py +++ b/py/src/braintrust/integrations/cohere/test_cohere.py @@ -30,7 +30,8 @@ V2RerankPatcher, ) from braintrust.integrations.test_utils import assert_metrics_are_valid, verify_autoinstrument_script -from braintrust.test_helpers import init_test_logger +from braintrust.span_types import SpanTypeAttribute +from braintrust.test_helpers import find_spans_by_type, init_test_logger pytest.importorskip("cohere") @@ -260,6 +261,91 @@ def test_wrap_cohere_chat_v2_sync(memory_logger): assert_metrics_are_valid(span["metrics"], start, end) +@pytest.mark.vcr +def test_wrap_cohere_chat_v2_tool_call_spans(memory_logger): + if os.environ.get("BRAINTRUST_TEST_PACKAGE_VERSION") != "latest": + pytest.skip("v2 tool-call cassette is recorded for the latest Cohere SDK") + + assert not memory_logger.pop() + client = wrap_cohere(_v2_client(require_methods=("chat",))) + + response = client.chat( + model=CHAT_MODEL, + messages=[{"role": "user", "content": "Use the get_weather tool for Paris."}], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a city.", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ], + tool_choice="REQUIRED", + max_tokens=64, + ) + + tool_calls = response.message.tool_calls + assert 
tool_calls + assert tool_calls[0].function.name == "get_weather" + + spans = memory_logger.pop() + llm_spans = find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + assert len(llm_spans) == 1 + assert len(tool_spans) == 1 + tool_span = tool_spans[0] + assert tool_span["span_attributes"]["name"] == "tool: get_weather" + assert tool_span["span_parents"] == [llm_spans[0]["span_id"]] + assert tool_span["metadata"]["tool_call_id"] == tool_calls[0].id + assert tool_span["metadata"]["tool_type"] == "function" + assert "Paris" in str(tool_span["input"]) + + +@pytest.mark.vcr +def test_wrap_cohere_chat_v1_tool_call_spans(memory_logger): + if os.environ.get("BRAINTRUST_TEST_PACKAGE_VERSION") != "latest": + pytest.skip("v1 tool-call cassette is recorded for the latest Cohere SDK") + + assert not memory_logger.pop() + client = wrap_cohere(_v1_client()) + + response = client.chat( + model=CHAT_MODEL, + message="Use the get_weather tool for Paris.", + tools=[ + { + "name": "get_weather", + "description": "Get the weather for a city.", + "parameter_definitions": {"city": {"description": "City name", "type": "str", "required": True}}, + } + ], + force_single_step=True, + max_tokens=64, + ) + + tool_calls = response.tool_calls + assert tool_calls + assert tool_calls[0].name == "get_weather" + + spans = memory_logger.pop() + llm_spans = find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + assert len(llm_spans) == 1 + assert len(tool_spans) == 1 + tool_span = tool_spans[0] + assert tool_span["span_attributes"]["name"] == "tool: get_weather" + assert tool_span["span_parents"] == [llm_spans[0]["span_id"]] + assert "Paris" in str(tool_span["input"]) + + @pytest.mark.vcr def test_wrap_cohere_chat_v1_sync(memory_logger): assert not memory_logger.pop() diff --git a/py/src/braintrust/integrations/cohere/tracing.py b/py/src/braintrust/integrations/cohere/tracing.py 
index 938b7d22..82d9aec5 100644 --- a/py/src/braintrust/integrations/cohere/tracing.py +++ b/py/src/braintrust/integrations/cohere/tracing.py @@ -92,8 +92,16 @@ # --------------------------------------------------------------------------- -def _get(obj: Any, key: str) -> Any: - """Read *key* from *obj*, supporting both dicts and Pydantic/attribute objects.""" +def _get_field(obj: Any, key: str) -> Any: + """Return a field from either a mapping or a Cohere SDK model object. + + Cohere responses and stream events are not represented by one uniform type + across SDK versions and transports: some are plain dictionaries, while + others are Pydantic-style objects with attributes. Tracing code only needs + read-only access to named fields, so this helper centralizes that small + compatibility layer instead of scattering ``isinstance(..., dict)`` checks + throughout the normalization logic. + """ if obj is None: return None if isinstance(obj, dict): @@ -118,12 +126,12 @@ def _extract_response_metadata(result: Any) -> dict[str, Any]: metadata: dict[str, Any] = {} for key in _RESPONSE_METADATA_KEYS: - value = _get(result, key) + value = _get_field(result, key) if value is not None: metadata[key] = value - api_version = _get(_get(result, "meta"), "api_version") - version_value = _get(api_version, "version") + api_version = _get_field(_get_field(result, "meta"), "api_version") + version_value = _get_field(api_version, "version") if version_value is not None: metadata["api_version"] = version_value return metadata @@ -141,23 +149,23 @@ def _merge_usage_metrics(metrics: dict[str, float], usage: Any) -> None: ("total_tokens", "tokens"), ("cached_tokens", "prompt_cached_tokens"), ): - value = _get(usage, key) + value = _get_field(usage, key) if is_numeric(value): metrics[metric] = float(value) - tokens_block = _get(usage, "tokens") + tokens_block = _get_field(usage, "tokens") for src_key, metric in ( ("input_tokens", "prompt_tokens"), ("output_tokens", "completion_tokens"), 
("total_tokens", "tokens"), ): - value = _get(tokens_block, src_key) + value = _get_field(tokens_block, src_key) if is_numeric(value): metrics[metric] = float(value) # billed_units is Cohere's authoritative counter for billing — it intentionally # overrides values from the top-level or ``tokens`` block when both are present. - billed = _get(usage, "billed_units") + billed = _get_field(usage, "billed_units") for src_key, metric in ( ("input_tokens", "prompt_tokens"), ("output_tokens", "completion_tokens"), @@ -166,7 +174,7 @@ def _merge_usage_metrics(metrics: dict[str, float], usage: Any) -> None: ("images", "images"), ("image_tokens", "image_tokens"), ): - value = _get(billed, src_key) + value = _get_field(billed, src_key) if is_numeric(value): metrics[metric] = float(value) @@ -178,8 +186,8 @@ def _parse_usage_metrics(result: Any) -> dict[str, float]: metrics: dict[str, float] = {} _merge_usage_metrics(metrics, result) - _merge_usage_metrics(metrics, _get(result, "usage")) - _merge_usage_metrics(metrics, _get(result, "meta")) + _merge_usage_metrics(metrics, _get_field(result, "usage")) + _merge_usage_metrics(metrics, _get_field(result, "meta")) if "tokens" not in metrics and "prompt_tokens" in metrics and "completion_tokens" in metrics: metrics["tokens"] = metrics["prompt_tokens"] + metrics["completion_tokens"] @@ -243,7 +251,7 @@ def _audio_transcription_output(result: Any) -> str | None: """Return the transcribed text string for a transcription response.""" if result is None: return None - text = _get(result, "text") + text = _get_field(result, "text") return text if isinstance(text, str) else None @@ -262,12 +270,12 @@ def _chat_output(result: Any) -> Any: if result is None: return None - message = _get(result, "message") + message = _get_field(result, "message") if message is not None: return message - text = _get(result, "text") - tool_calls = _get(result, "tool_calls") + text = _get_field(result, "text") + tool_calls = _get_field(result, "tool_calls") if 
tool_calls: return { "role": "assistant", @@ -302,7 +310,7 @@ def _iter_embedding_lists(embeddings: Any): def _embed_output(result: Any) -> dict[str, Any] | None: """Return ``{embedding_count, embedding_length}`` summary for an embed response.""" - embeddings = _get(result, "embeddings") + embeddings = _get_field(result, "embeddings") for entry in _iter_embedding_lists(embeddings): if isinstance(entry, list) and entry and isinstance(entry[0], list): return { @@ -318,7 +326,7 @@ def _rerank_output(result: Any) -> list[dict[str, Any]] | None: Each result is summarized to ``{index, relevance_score}`` — the document payload (if present via ``return_documents=True``) is dropped on purpose. """ - results = _get(result, "results") + results = _get_field(result, "results") if not isinstance(results, list): return None @@ -328,8 +336,8 @@ def _rerank_output(result: Any) -> list[dict[str, Any]] | None: continue out.append( { - "index": _get(item, "index"), - "relevance_score": _get(item, "relevance_score"), + "index": _get_field(item, "index"), + "relevance_score": _get_field(item, "relevance_score"), } ) return out @@ -363,10 +371,10 @@ def _merge_tool_call(existing: dict[str, Any] | None, incoming: dict[str, Any]) def _v2_delta_text(chunk: Any) -> str | None: - content = _get(_get(_get(chunk, "delta"), "message"), "content") + content = _get_field(_get_field(_get_field(chunk, "delta"), "message"), "content") if isinstance(content, str): return content - text = _get(content, "text") + text = _get_field(content, "text") return text if isinstance(text, str) else None @@ -376,7 +384,7 @@ def _delta_message_tool_calls(chunk: Any) -> list[Any]: The SDK may emit a single tool-call object or a list; this normalizes to a list so callers can iterate uniformly. 
""" - tool_calls = _get(_get(_get(chunk, "delta"), "message"), "tool_calls") + tool_calls = _get_field(_get_field(_get_field(chunk, "delta"), "message"), "tool_calls") if tool_calls is None: return [] if isinstance(tool_calls, list): @@ -451,36 +459,36 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di for chunk in chunks: if chunk is None: continue - event_type = _get(chunk, "event_type") or _get(chunk, "type") + event_type = _get_field(chunk, "event_type") or _get_field(chunk, "type") # -- v1 shape --------------------------------------------------------- if event_type == "text-generation": - text = _get(chunk, "text") + text = _get_field(chunk, "text") if isinstance(text, str): text_parts.append(text) continue if event_type == "stream-end": - response = _get(chunk, "response") + response = _get_field(chunk, "response") if response is not None: terminal_response = response metrics.update(_parse_usage_metrics(response)) metadata.update(_extract_response_metadata(response)) - fr = _get(response, "finish_reason") + fr = _get_field(response, "finish_reason") if isinstance(fr, str): finish_reason = fr continue if event_type == "tool-calls-generation": - tool_calls = _get(chunk, "tool_calls") + tool_calls = _get_field(chunk, "tool_calls") if isinstance(tool_calls, list): _upsert_tool_calls(tool_calls_by_index, tool_call_order, tool_calls) continue # -- v2 shape --------------------------------------------------------- if event_type == "message-start": - msg_id = _get(chunk, "id") + msg_id = _get_field(chunk, "id") if isinstance(msg_id, str): metadata["id"] = msg_id - role_value = _get(_get(_get(chunk, "delta"), "message"), "role") + role_value = _get_field(_get_field(_get_field(chunk, "delta"), "message"), "role") if isinstance(role_value, str): role = role_value continue @@ -490,7 +498,7 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di text_parts.append(text) continue if event_type == 
"tool-call-start": - chunk_index = _get(chunk, "index") + chunk_index = _get_field(chunk, "index") _upsert_tool_calls( tool_calls_by_index, tool_call_order, @@ -499,7 +507,7 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di ) continue if event_type == "tool-call-delta": - chunk_index = _get(chunk, "index") + chunk_index = _get_field(chunk, "index") _upsert_tool_calls( tool_calls_by_index, tool_call_order, @@ -509,11 +517,11 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di ) continue if event_type == "message-end": - delta = _get(chunk, "delta") - fr = _get(delta, "finish_reason") + delta = _get_field(chunk, "delta") + fr = _get_field(delta, "finish_reason") if isinstance(fr, str): finish_reason = fr - usage = _get(delta, "usage") + usage = _get_field(delta, "usage") if usage is not None: _merge_usage_metrics(metrics, usage) if "tokens" not in metrics and "prompt_tokens" in metrics and "completion_tokens" in metrics: @@ -556,15 +564,77 @@ def _start_span(name: str, span_input: Any, metadata: dict[str, Any]): ) +def _tool_call_function(tool_call: Any) -> Any: + return _get_field(tool_call, "function") + + +def _tool_call_name(tool_call: Any) -> str | None: + function = _tool_call_function(tool_call) + name = _get_field(function, "name") or _get_field(tool_call, "name") + return name if isinstance(name, str) and name else None + + +def _tool_call_input(tool_call: Any) -> Any: + function = _tool_call_function(tool_call) + arguments = _get_field(function, "arguments") + if arguments is not None: + return arguments + parameters = _get_field(function, "parameters") + if parameters is not None: + return parameters + return _get_field(tool_call, "parameters") + + +def _tool_call_metadata(tool_call: Any) -> dict[str, Any] | None: + metadata = { + "tool_call_id": _get_field(tool_call, "id") or _get_field(tool_call, "call_id"), + "tool_type": _get_field(tool_call, "type"), + } + return {k: v for k, v in 
metadata.items() if v is not None} or None + + +def _iter_tool_calls(output: Any): + if output is None: + return + output_dict = output if isinstance(output, dict) else _try_to_dict(output) + if not isinstance(output_dict, dict): + return + tool_calls = output_dict.get("tool_calls") + if not isinstance(tool_calls, list): + return + for tool_call in tool_calls: + if tool_call is not None: + yield tool_call + + +def _log_tool_call_spans(output: Any, *, parent_export: str | None) -> None: + for tool_call in _iter_tool_calls(output): + name = _tool_call_name(tool_call) + if name is None: + continue + span_args = { + "name": f"tool: {name}", + "type": SpanTypeAttribute.TOOL, + "input": _tool_call_input(tool_call), + "metadata": _tool_call_metadata(tool_call), + } + if parent_export is not None: + span_args["parent"] = parent_export + with start_span(**span_args): + pass + + def _log_call_result(span, output_fn, start_time: float, result: Any) -> None: """Log output/metrics/metadata for *result* and end *span*.""" metrics = { **_timing_metrics(start_time, time.time()), **_parse_usage_metrics(result), } + output = output_fn(result) + _log_tool_call_spans(output, parent_export=span.export()) _log_and_end_span( span, - output=output_fn(result), + output=output, metrics=metrics, metadata=_extract_response_metadata(result) or None, ) @@ -734,6 +804,7 @@ def _finish(self, error: BaseException | None = None) -> None: **_timing_metrics(self._start_time, time.time(), self._first_token_time), **usage_metrics, } + _log_tool_call_spans(output, parent_export=self._span.export()) _log_and_end_span( self._span, output=output,