diff --git a/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v1_tool_call_spans.yaml b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v1_tool_call_spans.yaml new file mode 100644 index 00000000..9b0c8ff9 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v1_tool_call_spans.yaml @@ -0,0 +1,83 @@ +interactions: +- request: + body: '{"message":"Use the get_weather tool for Paris.","model":"command-a-03-2025","max_tokens":64,"tools":[{"name":"get_weather","description":"Get + the weather for a city.","parameter_definitions":{"city":{"description":"City + name","type":"str","required":true}}}],"force_single_step":true,"stream":false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '300' + Host: + - api.cohere.com + User-Agent: + - cohere/6.1.0 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.12.12 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 6.1.0 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v1/chat + response: + body: + string: '{"response_id":"490af82c-d74b-4774-96aa-a48abc77b39d","text":"I will + use one or more of the available tools to find the answer","generation_id":"465b7c4c-823c-4932-bd72-cc91b3aa154d","chat_history":[{"role":"USER","message":"Use + the get_weather tool for Paris."},{"role":"CHATBOT","message":"I will use + one or more of the available tools to find the answer","tool_calls":[{"name":"get_weather","parameters":{"city":"Paris"}}]}],"finish_reason":"COMPLETE","meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":42,"output_tokens":22},"tokens":{"input_tokens":1462,"output_tokens":33},"cached_tokens":0},"tool_calls":[{"name":"get_weather","parameters":{"city":"Paris"}}]}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '684' + content-type: + - application/json + date: + - Fri, 01 May 2026 17:35:39 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + num_chars: + - '6894' + num_tokens: + - '64' + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - 9a84e7af81536f5a0ae6409b072713da + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '988' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '17' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v2_tool_call_spans.yaml b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v2_tool_call_spans.yaml new file mode 100644 index 00000000..6ef0ea00 --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_v2_tool_call_spans.yaml @@ -0,0 +1,81 @@ +interactions: +- request: + body: '{"model":"command-a-03-2025","messages":[{"role":"user","content":"Use + the get_weather tool for Paris."}],"tools":[{"type":"function","function":{"name":"get_weather","description":"Get + the weather for a city.","parameters":{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}}],"max_tokens":64,"tool_choice":"REQUIRED","stream":false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '361' + Host: + - api.cohere.com + User-Agent: + - cohere/6.1.0 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.12.12 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 6.1.0 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/chat + response: + body: + string: '{"id":"74529407-6112-4e14-9cb3-41ab95fe7725","message":{"role":"assistant","tool_plan":"I + will use one or more of the available tools to find the answer","tool_calls":[{"id":"get_weather_p1927913dgpp","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"Paris\"}"}}]},"finish_reason":"TOOL_CALL","usage":{"billed_units":{"input_tokens":37,"output_tokens":22},"tokens":{"input_tokens":1455,"output_tokens":33},"cached_tokens":0}}' + headers: + Alt-Svc: + - h3=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-length: + - '451' + content-type: + - application/json + date: + - Fri, 01 May 2026 17:34:51 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + num_chars: + - '6866' + num_tokens: + - '59' + pragma: + - no-cache + server: + - envoy + vary: + - Origin,Accept-Encoding + x-accel-expires: + - '0' + x-debug-trace-id: + - b3c262976868c73da50611172c4163a1 + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '913' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '18' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/test_cohere.py b/py/src/braintrust/integrations/cohere/test_cohere.py index 7181964c..98900ac4 100644 --- a/py/src/braintrust/integrations/cohere/test_cohere.py +++ b/py/src/braintrust/integrations/cohere/test_cohere.py @@ -30,7 +30,8 @@ V2RerankPatcher, ) from braintrust.integrations.test_utils import assert_metrics_are_valid, verify_autoinstrument_script -from braintrust.test_helpers import init_test_logger +from braintrust.span_types import SpanTypeAttribute +from braintrust.test_helpers import find_spans_by_type, init_test_logger pytest.importorskip("cohere") @@ -260,6 +261,91 @@ def test_wrap_cohere_chat_v2_sync(memory_logger): assert_metrics_are_valid(span["metrics"], start, end) +@pytest.mark.vcr +def test_wrap_cohere_chat_v2_tool_call_spans(memory_logger): + if os.environ.get("BRAINTRUST_TEST_PACKAGE_VERSION") != "latest": + pytest.skip("v2 tool-call cassette is recorded for the latest Cohere SDK") + + assert not memory_logger.pop() + client = wrap_cohere(_v2_client(require_methods=("chat",))) + + response = client.chat( + model=CHAT_MODEL, + messages=[{"role": "user", "content": "Use the get_weather tool for Paris."}], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a city.", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ], + tool_choice="REQUIRED", + max_tokens=64, + ) + + tool_calls = response.message.tool_calls + assert tool_calls + assert tool_calls[0].function.name == "get_weather" + + spans = memory_logger.pop() + llm_spans = find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + assert len(llm_spans) == 1 + assert len(tool_spans) == 1 + tool_span = tool_spans[0] + assert tool_span["span_attributes"]["name"] == "tool: get_weather" + assert tool_span["span_parents"] == [llm_spans[0]["span_id"]] + assert tool_span["metadata"]["tool_call_id"] == tool_calls[0].id + assert tool_span["metadata"]["tool_type"] == "function" + assert "Paris" in str(tool_span["input"]) + + +@pytest.mark.vcr +def test_wrap_cohere_chat_v1_tool_call_spans(memory_logger): + if os.environ.get("BRAINTRUST_TEST_PACKAGE_VERSION") != "latest": + pytest.skip("v1 tool-call cassette is recorded for the latest Cohere SDK") + + assert not memory_logger.pop() + client = wrap_cohere(_v1_client()) + + response = client.chat( + model=CHAT_MODEL, + message="Use the get_weather tool for Paris.", + tools=[ + { + "name": "get_weather", + "description": "Get the weather for a city.", + "parameter_definitions": {"city": {"description": "City name", "type": "str", "required": True}}, + } + ], + force_single_step=True, + max_tokens=64, + ) + + tool_calls = response.tool_calls + assert tool_calls + assert tool_calls[0].name == "get_weather" + + spans = memory_logger.pop() + llm_spans = find_spans_by_type(spans, SpanTypeAttribute.LLM) + tool_spans = find_spans_by_type(spans, SpanTypeAttribute.TOOL) + + assert len(llm_spans) == 1 + assert len(tool_spans) == 1 + tool_span = tool_spans[0] + assert tool_span["span_attributes"]["name"] == "tool: get_weather" + assert tool_span["span_parents"] == [llm_spans[0]["span_id"]] + assert "Paris" in str(tool_span["input"]) + + @pytest.mark.vcr def test_wrap_cohere_chat_v1_sync(memory_logger): assert not memory_logger.pop() diff --git a/py/src/braintrust/integrations/cohere/tracing.py b/py/src/braintrust/integrations/cohere/tracing.py index 938b7d22..82d9aec5 100644 --- a/py/src/braintrust/integrations/cohere/tracing.py +++ b/py/src/braintrust/integrations/cohere/tracing.py @@ -92,8 +92,16 @@ # --------------------------------------------------------------------------- -def _get(obj: Any, key: str) -> Any: - """Read *key* from *obj*, supporting both dicts and Pydantic/attribute objects.""" +def _get_field(obj: Any, key: str) -> Any: + """Return a field from either a mapping or a Cohere SDK model object. + + Cohere responses and stream events are not represented by one uniform type + across SDK versions and transports: some are plain dictionaries, while + others are Pydantic-style objects with attributes. Tracing code only needs + read-only access to named fields, so this helper centralizes that small + compatibility layer instead of scattering ``isinstance(..., dict)`` checks + throughout the normalization logic. + """ if obj is None: return None if isinstance(obj, dict): @@ -118,12 +126,12 @@ def _extract_response_metadata(result: Any) -> dict[str, Any]: metadata: dict[str, Any] = {} for key in _RESPONSE_METADATA_KEYS: - value = _get(result, key) + value = _get_field(result, key) if value is not None: metadata[key] = value - api_version = _get(_get(result, "meta"), "api_version") - version_value = _get(api_version, "version") + api_version = _get_field(_get_field(result, "meta"), "api_version") + version_value = _get_field(api_version, "version") if version_value is not None: metadata["api_version"] = version_value return metadata @@ -141,23 +149,23 @@ def _merge_usage_metrics(metrics: dict[str, float], usage: Any) -> None: ("total_tokens", "tokens"), ("cached_tokens", "prompt_cached_tokens"), ): - value = _get(usage, key) + value = _get_field(usage, key) if is_numeric(value): metrics[metric] = float(value) - tokens_block = _get(usage, "tokens") + tokens_block = _get_field(usage, "tokens") for src_key, metric in ( ("input_tokens", "prompt_tokens"), ("output_tokens", "completion_tokens"), ("total_tokens", "tokens"), ): - value = _get(tokens_block, src_key) + value = _get_field(tokens_block, src_key) if is_numeric(value): metrics[metric] = float(value) # billed_units is Cohere's authoritative counter for billing — it intentionally # overrides values from the top-level or ``tokens`` block when both are present. - billed = _get(usage, "billed_units") + billed = _get_field(usage, "billed_units") for src_key, metric in ( ("input_tokens", "prompt_tokens"), ("output_tokens", "completion_tokens"), @@ -166,7 +174,7 @@ def _merge_usage_metrics(metrics: dict[str, float], usage: Any) -> None: ("images", "images"), ("image_tokens", "image_tokens"), ): - value = _get(billed, src_key) + value = _get_field(billed, src_key) if is_numeric(value): metrics[metric] = float(value) @@ -178,8 +186,8 @@ def _parse_usage_metrics(result: Any) -> dict[str, float]: metrics: dict[str, float] = {} _merge_usage_metrics(metrics, result) - _merge_usage_metrics(metrics, _get(result, "usage")) - _merge_usage_metrics(metrics, _get(result, "meta")) + _merge_usage_metrics(metrics, _get_field(result, "usage")) + _merge_usage_metrics(metrics, _get_field(result, "meta")) if "tokens" not in metrics and "prompt_tokens" in metrics and "completion_tokens" in metrics: metrics["tokens"] = metrics["prompt_tokens"] + metrics["completion_tokens"] @@ -243,7 +251,7 @@ def _audio_transcription_output(result: Any) -> str | None: """Return the transcribed text string for a transcription response.""" if result is None: return None - text = _get(result, "text") + text = _get_field(result, "text") return text if isinstance(text, str) else None @@ -262,12 +270,12 @@ def _chat_output(result: Any) -> Any: if result is None: return None - message = _get(result, "message") + message = _get_field(result, "message") if message is not None: return message - text = _get(result, "text") - tool_calls = _get(result, "tool_calls") + text = _get_field(result, "text") + tool_calls = _get_field(result, "tool_calls") if tool_calls: return { "role": "assistant", @@ -302,7 +310,7 @@ def _iter_embedding_lists(embeddings: Any): def _embed_output(result: Any) -> dict[str, Any] | None: """Return ``{embedding_count, embedding_length}`` summary for an embed response.""" - embeddings = _get(result, "embeddings") + embeddings = _get_field(result, "embeddings") for entry in _iter_embedding_lists(embeddings): if isinstance(entry, list) and entry and isinstance(entry[0], list): return { @@ -318,7 +326,7 @@ def _rerank_output(result: Any) -> list[dict[str, Any]] | None: Each result is summarized to ``{index, relevance_score}`` — the document payload (if present via ``return_documents=True``) is dropped on purpose. """ - results = _get(result, "results") + results = _get_field(result, "results") if not isinstance(results, list): return None @@ -328,8 +336,8 @@ def _rerank_output(result: Any) -> list[dict[str, Any]] | None: continue out.append( { - "index": _get(item, "index"), - "relevance_score": _get(item, "relevance_score"), + "index": _get_field(item, "index"), + "relevance_score": _get_field(item, "relevance_score"), } ) return out @@ -363,10 +371,10 @@ def _merge_tool_call(existing: dict[str, Any] | None, incoming: dict[str, Any]) def _v2_delta_text(chunk: Any) -> str | None: - content = _get(_get(_get(chunk, "delta"), "message"), "content") + content = _get_field(_get_field(_get_field(chunk, "delta"), "message"), "content") if isinstance(content, str): return content - text = _get(content, "text") + text = _get_field(content, "text") return text if isinstance(text, str) else None @@ -376,7 +384,7 @@ def _delta_message_tool_calls(chunk: Any) -> list[Any]: The SDK may emit a single tool-call object or a list; this normalizes to a list so callers can iterate uniformly. """ - tool_calls = _get(_get(_get(chunk, "delta"), "message"), "tool_calls") + tool_calls = _get_field(_get_field(_get_field(chunk, "delta"), "message"), "tool_calls") if tool_calls is None: return [] if isinstance(tool_calls, list): @@ -451,36 +459,36 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di for chunk in chunks: if chunk is None: continue - event_type = _get(chunk, "event_type") or _get(chunk, "type") + event_type = _get_field(chunk, "event_type") or _get_field(chunk, "type") # -- v1 shape --------------------------------------------------------- if event_type == "text-generation": - text = _get(chunk, "text") + text = _get_field(chunk, "text") if isinstance(text, str): text_parts.append(text) continue if event_type == "stream-end": - response = _get(chunk, "response") + response = _get_field(chunk, "response") if response is not None: terminal_response = response metrics.update(_parse_usage_metrics(response)) metadata.update(_extract_response_metadata(response)) - fr = _get(response, "finish_reason") + fr = _get_field(response, "finish_reason") if isinstance(fr, str): finish_reason = fr continue if event_type == "tool-calls-generation": - tool_calls = _get(chunk, "tool_calls") + tool_calls = _get_field(chunk, "tool_calls") if isinstance(tool_calls, list): _upsert_tool_calls(tool_calls_by_index, tool_call_order, tool_calls) continue # -- v2 shape --------------------------------------------------------- if event_type == "message-start": - msg_id = _get(chunk, "id") + msg_id = _get_field(chunk, "id") if isinstance(msg_id, str): metadata["id"] = msg_id - role_value = _get(_get(_get(chunk, "delta"), "message"), "role") + role_value = _get_field(_get_field(_get_field(chunk, "delta"), "message"), "role") if isinstance(role_value, str): role = role_value continue @@ -490,7 +498,7 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di text_parts.append(text) continue if event_type == "tool-call-start": - chunk_index = _get(chunk, "index") + chunk_index = _get_field(chunk, "index") _upsert_tool_calls( tool_calls_by_index, tool_call_order, @@ -499,7 +507,7 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di ) continue if event_type == "tool-call-delta": - chunk_index = _get(chunk, "index") + chunk_index = _get_field(chunk, "index") _upsert_tool_calls( tool_calls_by_index, tool_call_order, @@ -509,11 +517,11 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di ) continue if event_type == "message-end": - delta = _get(chunk, "delta") - fr = _get(delta, "finish_reason") + delta = _get_field(chunk, "delta") + fr = _get_field(delta, "finish_reason") if isinstance(fr, str): finish_reason = fr - usage = _get(delta, "usage") + usage = _get_field(delta, "usage") if usage is not None: _merge_usage_metrics(metrics, usage) if "tokens" not in metrics and "prompt_tokens" in metrics and "completion_tokens" in metrics: @@ -556,15 +564,77 @@ def _start_span(name: str, span_input: Any, metadata: dict[str, Any]): ) +def _tool_call_function(tool_call: Any) -> Any: + return _get_field(tool_call, "function") + + +def _tool_call_name(tool_call: Any) -> str | None: + function = _tool_call_function(tool_call) + name = _get_field(function, "name") or _get_field(tool_call, "name") + return name if isinstance(name, str) and name else None + + +def _tool_call_input(tool_call: Any) -> Any: + function = _tool_call_function(tool_call) + arguments = _get_field(function, "arguments") + if arguments is not None: + return arguments + parameters = _get_field(function, "parameters") + if parameters is not None: + return parameters + return _get_field(tool_call, "parameters") + + +def _tool_call_metadata(tool_call: Any) -> dict[str, Any] | None: + metadata = { + "tool_call_id": _get_field(tool_call, "id") or _get_field(tool_call, "call_id"), + "tool_type": _get_field(tool_call, "type"), + } + return {k: v for k, v in metadata.items() if v is not None} or None + + +def _iter_tool_calls(output: Any): + if output is None: + return + output_dict = output if isinstance(output, dict) else _try_to_dict(output) + if not isinstance(output_dict, dict): + return + tool_calls = output_dict.get("tool_calls") + if not isinstance(tool_calls, list): + return + for tool_call in tool_calls: + if tool_call is not None: + yield tool_call + + +def _log_tool_call_spans(output: Any, *, parent_export: str | None) -> None: + for tool_call in _iter_tool_calls(output): + name = _tool_call_name(tool_call) + if name is None: + continue + span_args = { + "name": f"tool: {name}", + "type": SpanTypeAttribute.TOOL, + "input": _tool_call_input(tool_call), + "metadata": _tool_call_metadata(tool_call), + } + if parent_export is not None: + span_args["parent"] = parent_export + with start_span(**span_args): + pass + + def _log_call_result(span, output_fn, start_time: float, result: Any) -> None: """Log output/metrics/metadata for *result* and end *span*.""" metrics = { **_timing_metrics(start_time, time.time()), **_parse_usage_metrics(result), } + output = output_fn(result) + _log_tool_call_spans(output, parent_export=span.export()) _log_and_end_span( span, - output=output_fn(result), + output=output, metrics=metrics, metadata=_extract_response_metadata(result) or None, ) @@ -734,6 +804,7 @@ def _finish(self, error: BaseException | None = None) -> None: **_timing_metrics(self._start_time, time.time(), self._first_token_time), **usage_metrics, } + _log_tool_call_spans(output, parent_export=self._span.export()) _log_and_end_span( self._span, output=output,