diff --git a/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_stream_v2_rag_citations.yaml b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_stream_v2_rag_citations.yaml new file mode 100644 index 00000000..2730276d --- /dev/null +++ b/py/src/braintrust/integrations/cohere/cassettes/latest/test_wrap_cohere_chat_stream_v2_rag_citations.yaml @@ -0,0 +1,201 @@ +interactions: +- request: + body: '{"model":"command-a-03-2025","messages":[{"role":"user","content":"What + is Braintrust? Cite the provided document."}],"documents":[{"data":{"title":"Braintrust + overview","snippet":"Braintrust is a platform for evaluating, logging, and improving + AI applications."}}],"citation_options":{"mode":"fast"},"max_tokens":80,"stream":true}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '331' + Host: + - api.cohere.com + User-Agent: + - cohere/6.1.0 + X-Fern-Language: + - Python + X-Fern-Platform: + - darwin/25.2.0 + X-Fern-Runtime: + - python/3.12.12 + X-Fern-SDK-Name: + - cohere + X-Fern-SDK-Version: + - 6.1.0 + content-type: + - application/json + method: POST + uri: https://api.cohere.com/v2/chat + response: + body: + string: 'event: message-start + + data: {"id":"eb4d8291-1dde-42d1-a10b-e15931a1891f","type":"message-start","delta":{"message":{"role":"assistant","content":[],"tool_plan":"","tool_calls":[],"citations":[]}}} + + + event: content-start + + data: {"type":"content-start","index":0,"delta":{"message":{"content":{"type":"text","text":""}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"Brain"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"trust"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + is"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + a"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + "}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"platform"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + for"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + evaluating"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":","}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + logging"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":","}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + and"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + improving"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + AI"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":" + applications"}}}} + + + event: content-delta + + data: {"type":"content-delta","index":0,"delta":{"message":{"content":{"text":"."}}}} + + + event: citation-start + + data: {"type":"citation-start","index":0,"delta":{"message":{"citations":{"start":16,"end":80,"text":"platform + for evaluating, logging, and improving AI applications.","sources":[{"type":"document","id":"doc:0","document":{"id":"doc:0","snippet":"Braintrust + is a platform for evaluating, logging, and improving AI applications.","title":"Braintrust + overview"}}],"type":"TEXT_CONTENT"}}}} + + + event: citation-end + + data: {"type":"citation-end","index":0} + + + event: content-end + + data: {"type":"content-end","index":0} + + + event: message-end + + data: {"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":34,"output_tokens":15},"tokens":{"input_tokens":1677,"output_tokens":27},"cached_tokens":0}}} + + + data: [DONE] + + + ' + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Transfer-Encoding: + - chunked + Via: + - 1.1 google + access-control-expose-headers: + - X-Debug-Trace-ID + cache-control: + - no-cache, no-store, no-transform, must-revalidate, private, max-age=0 + content-type: + - text/event-stream + date: + - Mon, 04 May 2026 17:51:26 GMT + expires: + - Thu, 01 Jan 1970 00:00:00 GMT + pragma: + - no-cache + server: + - envoy + vary: + - Origin + x-accel-expires: + - '0' + x-debug-trace-id: + - e47c41547aa42c0050d9ff38acdbe9c4 + x-endpoint-monthly-call-limit: + - '1000' + x-envoy-upstream-service-time: + - '19' + x-trial-endpoint-call-limit: + - '20' + x-trial-endpoint-call-remaining: + - '18' + status: + code: 200 + message: OK +version: 1 diff --git a/py/src/braintrust/integrations/cohere/test_cohere.py b/py/src/braintrust/integrations/cohere/test_cohere.py index 98900ac4..0bbeb0bd 100644 --- a/py/src/braintrust/integrations/cohere/test_cohere.py +++ b/py/src/braintrust/integrations/cohere/test_cohere.py @@ -567,6 +567,53 @@ def test_wrap_cohere_chat_stream_v2_sync(memory_logger): assert metrics.get("completion_tokens", 0) > 0 +@pytest.mark.vcr +def test_wrap_cohere_chat_stream_v2_rag_citations(memory_logger): + if os.environ.get("BRAINTRUST_TEST_PACKAGE_VERSION") != "latest": + pytest.skip("v2 RAG citation cassette is recorded for the latest Cohere SDK") + + assert not memory_logger.pop() + client = wrap_cohere(_v2_client(require_methods=("chat_stream",))) + documents = [ + { + "data": { + "title": "Braintrust overview", + "snippet": "Braintrust is a platform for evaluating, logging, and improving AI applications.", + } + } + ] + citation_options = {"mode": "fast"} + + events = list( + client.chat_stream( + model=CHAT_MODEL, + messages=[{"role": "user", "content": "What is Braintrust? Cite the provided document."}], + documents=documents, + citation_options=citation_options, + max_tokens=80, + ) + ) + + assert events + event_types = [getattr(event, "type", None) or getattr(event, "event_type", None) for event in events] + assert "citation-start" in event_types + + spans = memory_logger.pop() + assert len(spans) == 1 + span = spans[0] + + assert span["metadata"]["documents"] == documents + assert span["metadata"]["citation_options"] == citation_options + output = span["output"] + assert isinstance(output, dict) + citations = output.get("citations") + assert isinstance(citations, list) and citations + assert citations[0].get("start") is not None + assert citations[0].get("end") is not None + assert citations[0].get("text") + assert citations[0].get("sources") + + @pytest.mark.vcr def test_wrap_cohere_chat_v1_async(memory_logger): assert not memory_logger.pop() diff --git a/py/src/braintrust/integrations/cohere/tracing.py b/py/src/braintrust/integrations/cohere/tracing.py index 82d9aec5..dfc89fd2 100644 --- a/py/src/braintrust/integrations/cohere/tracing.py +++ b/py/src/braintrust/integrations/cohere/tracing.py @@ -50,6 +50,8 @@ "presence_penalty", "raw_prompting", "search_queries_only", + "documents", + "citation_options", "strict_tools", "tool_choice", "tools", @@ -393,7 +395,7 @@ def _delta_message_tool_calls(chunk: Any) -> list[Any]: def _as_merge_dict(value: Any) -> dict[str, Any] | None: - """Coerce a tool-call delta to a plain dict for merging, or return ``None``.""" + """Coerce a provider object to a plain dict for merging, or return ``None``.""" if isinstance(value, dict): return value converted = _try_to_dict(value) @@ -455,6 +457,8 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di finish_reason: str | None = None metadata: dict[str, Any] = {} metrics: dict[str, float] = {} + citations_by_index: dict[int, dict[str, Any]] = {} + citation_order: list[int] = [] for chunk in chunks: if chunk is None: @@ -516,6 +520,16 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di merge=True, ) continue + if event_type == "citation-start": + citation = _get_field(_get_field(_get_field(chunk, "delta"), "message"), "citations") + citation_dict = _as_merge_dict(citation) + chunk_index = _get_field(chunk, "index") + idx = chunk_index if isinstance(chunk_index, int) else len(citation_order) + if citation_dict is not None: + if idx not in citation_order: + citation_order.append(idx) + citations_by_index[idx] = citation_dict + continue if event_type == "message-end": delta = _get_field(chunk, "delta") fr = _get_field(delta, "finish_reason") @@ -531,18 +545,28 @@ def _aggregate_chat_stream(chunks: list[Any]) -> tuple[Any, dict[str, float], di merged_tool_calls = [ tool_calls_by_index[i] for i in sorted(tool_call_order) if isinstance(tool_calls_by_index.get(i), dict) ] + merged_citations = [ + citations_by_index[i] for i in sorted(citation_order) if isinstance(citations_by_index.get(i), dict) + ] output: Any = _chat_output(terminal_response) if terminal_response is not None else None if output is None: + output_dict = {} merged_text = "".join(text_parts) - if merged_tool_calls or role or merged_text: - output = {} - if role: - output["role"] = role - if merged_text: - output["content"] = merged_text - if merged_tool_calls: - output["tool_calls"] = merged_tool_calls + if role: + output_dict["role"] = role + if merged_text: + output_dict["content"] = merged_text + if merged_tool_calls: + output_dict["tool_calls"] = merged_tool_calls + if merged_citations: + output_dict["citations"] = merged_citations + output = output_dict or None + elif merged_citations: + output_dict = output if isinstance(output, dict) else _try_to_dict(output) + if isinstance(output_dict, dict): + output_dict.setdefault("citations", merged_citations) + output = output_dict if finish_reason is not None: metadata["finish_reason"] = finish_reason