diff --git a/.changeset/fine-cameras-allow.md b/.changeset/fine-cameras-allow.md new file mode 100644 index 000000000..67d8d4082 --- /dev/null +++ b/.changeset/fine-cameras-allow.md @@ -0,0 +1,5 @@ +--- +"braintrust": patch +--- + +fix(cohere): Wrap v2 subclient diff --git a/e2e/scenarios/cohere-instrumentation/__snapshots__/cohere-v7-14-0-wrapped.span-events.json b/e2e/scenarios/cohere-instrumentation/__snapshots__/cohere-v7-14-0-wrapped.span-events.json new file mode 100644 index 000000000..cf928fca7 --- /dev/null +++ b/e2e/scenarios/cohere-instrumentation/__snapshots__/cohere-v7-14-0-wrapped.span-events.json @@ -0,0 +1,152 @@ +[ + { + "has_input": false, + "has_output": false, + "metadata": { + "scenario": "cohere-instrumentation" + }, + "metric_keys": [], + "name": "cohere-instrumentation-root", + "root_span_id": "", + "span_id": "", + "span_parents": [], + "type": "task" + }, + { + "has_input": false, + "has_output": false, + "metadata": { + "operation": "chat" + }, + "metric_keys": [], + "name": "cohere-v2-chat-operation", + "root_span_id": "", + "span_id": "", + "span_parents": [ + "" + ], + "type": null + }, + { + "has_input": true, + "has_output": true, + "metadata": { + "model": "command-a-03-2025", + "provider": "cohere" + }, + "metric_keys": [ + "completion_tokens", + "prompt_cached_tokens", + "prompt_tokens", + "time_to_first_token", + "tokens" + ], + "name": "cohere.chat", + "root_span_id": "", + "span_id": "", + "span_parents": [ + "" + ], + "type": "llm" + }, + { + "has_input": false, + "has_output": false, + "metadata": { + "operation": "chat-stream" + }, + "metric_keys": [], + "name": "cohere-v2-chat-stream-operation", + "root_span_id": "", + "span_id": "", + "span_parents": [ + "" + ], + "type": null + }, + { + "has_input": true, + "has_output": false, + "metadata": { + "model": "command-a-03-2025", + "provider": "cohere" + }, + "metric_keys": [], + "name": "cohere.chatStream", + "root_span_id": "", + "span_id": "", + "span_parents": [ + "" + ], + "type": "llm" + }, + { + "has_input": false, + "has_output": false, + "metadata": { + "operation": "embed" + }, + "metric_keys": [], + "name": "cohere-v2-embed-operation", + "root_span_id": "", + "span_id": "", + "span_parents": [ + "" + ], + "type": null + }, + { + "has_input": true, + "has_output": true, + "metadata": { + "inputType": "search_document", + "model": "embed-english-v3.0", + "provider": "cohere" + }, + "metric_keys": [ + "prompt_tokens" + ], + "name": "cohere.embed", + "root_span_id": "", + "span_id": "", + "span_parents": [ + "" + ], + "type": "llm" + }, + { + "has_input": false, + "has_output": false, + "metadata": { + "operation": "rerank" + }, + "metric_keys": [], + "name": "cohere-v2-rerank-operation", + "root_span_id": "", + "span_id": "", + "span_parents": [ + "" + ], + "type": null + }, + { + "has_input": true, + "has_output": true, + "metadata": { + "document_count": 3, + "model": "rerank-english-v3.0", + "provider": "cohere", + "topN": 2 + }, + "metric_keys": [ + "search_units" + ], + "name": "cohere.rerank", + "root_span_id": "", + "span_id": "", + "span_parents": [ + "" + ], + "type": "llm" + } +] diff --git a/e2e/scenarios/cohere-instrumentation/__snapshots__/cohere-v7-14-0.span-events.json b/e2e/scenarios/cohere-instrumentation/__snapshots__/cohere-v7-14-0.span-events.json index f31d4de70..1ab342dc3 100644 --- a/e2e/scenarios/cohere-instrumentation/__snapshots__/cohere-v7-14-0.span-events.json +++ b/e2e/scenarios/cohere-instrumentation/__snapshots__/cohere-v7-14-0.span-events.json @@ -19,7 +19,7 @@ "operation": "chat" }, "metric_keys": [], - "name": "cohere-chat-operation", + "name": "cohere-v2-chat-operation", "root_span_id": "", "span_id": "", "span_parents": [ @@ -56,7 +56,7 @@ "operation": "chat-stream" }, "metric_keys": [], - "name": "cohere-chat-stream-operation", + "name": "cohere-v2-chat-stream-operation", "root_span_id": "", "span_id": "", "span_parents": [ @@ -93,7 +93,7 @@ "operation": "embed" }, "metric_keys": [], - "name": "cohere-embed-operation", + "name": "cohere-v2-embed-operation", "root_span_id": "", "span_id": "", "span_parents": [ @@ -127,7 +127,7 @@ "operation": "rerank" }, "metric_keys": [], - "name": "cohere-rerank-operation", + "name": "cohere-v2-rerank-operation", "root_span_id": "", "span_id": "", "span_parents": [ diff --git a/e2e/scenarios/cohere-instrumentation/assertions.ts b/e2e/scenarios/cohere-instrumentation/assertions.ts index 5e21c458d..bd0a6941a 100644 --- a/e2e/scenarios/cohere-instrumentation/assertions.ts +++ b/e2e/scenarios/cohere-instrumentation/assertions.ts @@ -42,21 +42,40 @@ function isCohereProviderLimitError(error: unknown): boolean { return message.includes("TooManyRequestsError") || message.includes("429"); } +function getOperationName( + baseName: string, + { useV2Namespace }: { useV2Namespace: boolean }, +) { + return useV2Namespace + ? `cohere-v2-${baseName}-operation` + : `cohere-${baseName}-operation`; +} + function buildSpanSummary( events: CapturedLogEvent[], supportsThinking: boolean, + useV2Namespace: boolean, ): Json { - const chatOperation = findLatestSpan(events, "cohere-chat-operation"); + const chatOperation = findLatestSpan( + events, + getOperationName("chat", { useV2Namespace }), + ); const chatStreamOperation = findLatestSpan( events, - "cohere-chat-stream-operation", + getOperationName("chat-stream", { useV2Namespace }), ); const chatStreamThinkingOperation = findLatestSpan( events, - "cohere-chat-stream-thinking-operation", + getOperationName("chat-stream-thinking", { useV2Namespace }), + ); + const embedOperation = findLatestSpan( + events, + getOperationName("embed", { useV2Namespace }), + ); + const rerankOperation = findLatestSpan( + events, + getOperationName("rerank", { useV2Namespace }), ); - const embedOperation = findLatestSpan(events, "cohere-embed-operation"); - const rerankOperation = findLatestSpan(events, "cohere-rerank-operation"); const summaryEvents = [ findLatestSpan(events, ROOT_NAME), @@ -104,6 +123,8 @@ export function defineCohereInstrumentationAssertions(options: { supportsThinking: boolean; testFileUrl: string; timeoutMs: number; + requireChatStreamOutput?: boolean; + useV2Namespace?: boolean; }): void { const spanSnapshotPath = resolveFileSnapshotPath( options.testFileUrl, @@ -151,7 +172,13 @@ export function defineCohereInstrumentationAssertions(options: { context.skip(); } - const chatOperation = findLatestSpan(events, "cohere-chat-operation"); + const chatOperationName = getOperationName("chat", { + useV2Namespace: options.useV2Namespace ?? false, + }); + const chatStreamOperationName = getOperationName("chat-stream", { + useV2Namespace: options.useV2Namespace ?? false, + }); + const chatOperation = findLatestSpan(events, chatOperationName); const chatSpan = findCohereSpan( events, chatOperation?.span.id, @@ -159,7 +186,7 @@ export function defineCohereInstrumentationAssertions(options: { ); const chatStreamOperation = findLatestSpan( events, - "cohere-chat-stream-operation", + chatStreamOperationName, ); const chatStreamSpan = findCohereSpan( events, @@ -179,7 +206,9 @@ export function defineCohereInstrumentationAssertions(options: { expect(chatStreamSpan?.row.metadata).toMatchObject({ provider: "cohere", }); - expect(chatStreamSpan?.output).toBeDefined(); + if (options.requireChatStreamOutput ?? true) { + expect(chatStreamSpan?.output).toBeDefined(); + } }); if (options.supportsThinking) { @@ -194,7 +223,9 @@ export function defineCohereInstrumentationAssertions(options: { const root = findLatestSpan(events, ROOT_NAME); const operation = findLatestSpan( events, - "cohere-chat-stream-thinking-operation", + getOperationName("chat-stream-thinking", { + useV2Namespace: options.useV2Namespace ?? false, + }), ); const span = findCohereSpan( events, @@ -250,7 +281,12 @@ export function defineCohereInstrumentationAssertions(options: { context.skip(); } - const operation = findLatestSpan(events, "cohere-embed-operation"); + const operation = findLatestSpan( + events, + getOperationName("embed", { + useV2Namespace: options.useV2Namespace ?? false, + }), + ); const span = findCohereSpan(events, operation?.span.id, "cohere.embed"); const output = span?.output as { embedding_length?: number } | undefined; @@ -268,7 +304,12 @@ export function defineCohereInstrumentationAssertions(options: { context.skip(); } - const operation = findLatestSpan(events, "cohere-rerank-operation"); + const operation = findLatestSpan( + events, + getOperationName("rerank", { + useV2Namespace: options.useV2Namespace ?? false, + }), + ); const span = findCohereSpan(events, operation?.span.id, "cohere.rerank"); expect(operation).toBeDefined(); @@ -290,7 +331,11 @@ export function defineCohereInstrumentationAssertions(options: { await expect( formatJsonFileSnapshot( - buildSpanSummary(events, options.supportsThinking), + buildSpanSummary( + events, + options.supportsThinking, + options.useV2Namespace ?? false, + ), ), ).toMatchFileSnapshot(spanSnapshotPath); }); diff --git a/e2e/scenarios/cohere-instrumentation/scenario.cohere-v7.mjs b/e2e/scenarios/cohere-instrumentation/scenario.cohere-v7.mjs index 6b0f4d5bd..5a7c67e68 100644 --- a/e2e/scenarios/cohere-instrumentation/scenario.cohere-v7.mjs +++ b/e2e/scenarios/cohere-instrumentation/scenario.cohere-v7.mjs @@ -1,16 +1,13 @@ -import { - CohereClient as CohereClientV7, - CohereClientV2 as CohereClientV7V2, -} from "cohere-sdk-v7"; import { runMain } from "../../helpers/provider-runtime.mjs"; import { runAutoCohereInstrumentation } from "./scenario.impl.mjs"; -runMain(async () => - runAutoCohereInstrumentation(CohereClientV7, { +runMain(async () => { + const cohere = await import( + process.env.COHERE_PACKAGE_NAME ?? "cohere-sdk-v7" + ); + + await runAutoCohereInstrumentation(cohere.CohereClient, { apiVersion: "v7", - ThinkingCohereClient: - process.env.COHERE_SUPPORTS_THINKING === "1" - ? CohereClientV7V2 - : undefined, - }), -); + useV2Namespace: true, + }); +}); diff --git a/e2e/scenarios/cohere-instrumentation/scenario.cohere-v7.ts b/e2e/scenarios/cohere-instrumentation/scenario.cohere-v7.ts index 3a01391a8..fdb082a8a 100644 --- a/e2e/scenarios/cohere-instrumentation/scenario.cohere-v7.ts +++ b/e2e/scenarios/cohere-instrumentation/scenario.cohere-v7.ts @@ -1,18 +1,15 @@ import { wrapCohere } from "braintrust"; -import { - CohereClient as CohereClientV7, - CohereClientV2 as CohereClientV7V2, -} from "cohere-sdk-v7"; import { runMain } from "../../helpers/scenario-runtime"; import { runWrappedCohereInstrumentation } from "./scenario.impl.mjs"; -runMain(async () => - runWrappedCohereInstrumentation(CohereClientV7, { +runMain(async () => { + const cohere = await import( + process.env.COHERE_PACKAGE_NAME ?? "cohere-sdk-v7" + ); + + await runWrappedCohereInstrumentation(cohere.CohereClient, { apiVersion: "v7", decorateClient: wrapCohere, - ThinkingCohereClient: - process.env.COHERE_SUPPORTS_THINKING === "1" - ? CohereClientV7V2 - : undefined, - }), -); + useV2Namespace: true, + }); +}); diff --git a/e2e/scenarios/cohere-instrumentation/scenario.impl.mjs b/e2e/scenarios/cohere-instrumentation/scenario.impl.mjs index b1e74a8a6..9d94c5e25 100644 --- a/e2e/scenarios/cohere-instrumentation/scenario.impl.mjs +++ b/e2e/scenarios/cohere-instrumentation/scenario.impl.mjs @@ -24,6 +24,7 @@ export const COHERE_SCENARIO_SPECS = [ dependencyName: "cohere-sdk-v7-14-0", snapshotName: "cohere-v7-14-0", supportsThinking: false, + useV2Namespace: true, wrapperEntry: "scenario.cohere-v7.ts", }, { @@ -31,6 +32,7 @@ export const COHERE_SCENARIO_SPECS = [ autoEntry: "scenario.cohere-v7.mjs", dependencyName: "cohere-sdk-v7-20-0", snapshotName: "cohere-v7-20-0", + useV2Namespace: true, wrapperEntry: "scenario.cohere-v7.ts", }, { @@ -38,6 +40,7 @@ export const COHERE_SCENARIO_SPECS = [ autoEntry: "scenario.cohere-v7.mjs", dependencyName: "cohere-sdk-v7-21-0", snapshotName: "cohere-v7-21-0", + useV2Namespace: true, wrapperEntry: "scenario.cohere-v7.ts", }, { @@ -45,6 +48,7 @@ export const COHERE_SCENARIO_SPECS = [ autoEntry: "scenario.cohere-v7.mjs", dependencyName: "cohere-sdk-v7", snapshotName: "cohere-v7", + useV2Namespace: true, wrapperEntry: "scenario.cohere-v7.ts", }, { @@ -60,8 +64,8 @@ function getApiKey() { return process.env.COHERE_API_KEY || process.env.CO_API_KEY; } -function getChatRequest(apiVersion, { stream = false } = {}) { - if (apiVersion === "v8") { +function getChatRequest(apiVersion, { stream = false, useV2Api = false } = {}) { + if (apiVersion === "v8" || useV2Api) { return { model: CHAT_MODEL_V8, messages: [ @@ -157,9 +161,27 @@ function getRerankRequest(apiVersion) { }; } +function getOperationName(baseName, { useV2Namespace = false } = {}) { + return useV2Namespace + ? `cohere-v2-${baseName}-operation` + : `cohere-${baseName}-operation`; +} + +function getOperationClient(client, { useV2Namespace = false } = {}) { + if (!useV2Namespace) { + return client; + } + + if (!client.v2) { + throw new Error("Expected Cohere client to expose a v2 namespace"); + } + + return client.v2; +} + async function runCohereInstrumentationScenario( CohereClient, - { apiVersion, decorateClient, ThinkingCohereClient } = {}, + { apiVersion, decorateClient, ThinkingCohereClient, useV2Namespace } = {}, ) { const apiKey = getApiKey(); if (!apiKey) { @@ -170,6 +192,7 @@ async function runCohereInstrumentationScenario( token: apiKey, }); const client = decorateClient ? decorateClient(baseClient) : baseClient; + const operationClient = getOperationClient(client, { useV2Namespace }); const thinkingClientClass = ThinkingCohereClient ?? CohereClient; const thinkingBaseClient = thinkingClientClass === CohereClient @@ -180,19 +203,31 @@ async function runCohereInstrumentationScenario( const thinkingClient = decorateClient ? decorateClient(thinkingBaseClient) : thinkingBaseClient; + const thinkingOperationClient = getOperationClient(thinkingClient, { + useV2Namespace: useV2Namespace && thinkingClientClass === CohereClient, + }); await runTracedScenario({ callback: async () => { - await runOperation("cohere-chat-operation", "chat", async () => { - await client.chat(getChatRequest(apiVersion)); - }); + await runOperation( + getOperationName("chat", { useV2Namespace }), + "chat", + async () => { + await operationClient.chat( + getChatRequest(apiVersion, { useV2Api: useV2Namespace }), + ); + }, + ); await runOperation( - "cohere-chat-stream-operation", + getOperationName("chat-stream", { useV2Namespace }), "chat-stream", async () => { - const stream = await client.chatStream( - getChatRequest(apiVersion, { stream: true }), + const stream = await operationClient.chatStream( + getChatRequest(apiVersion, { + stream: true, + useV2Api: useV2Namespace, + }), ); await collectAsync(stream); }, @@ -200,10 +235,10 @@ async function runCohereInstrumentationScenario( if (shouldRunThinkingScenario(apiVersion)) { await runOperation( - "cohere-chat-stream-thinking-operation", + getOperationName("chat-stream-thinking", { useV2Namespace }), "chat-stream-thinking", async () => { - const stream = await thinkingClient.chatStream( + const stream = await thinkingOperationClient.chatStream( getThinkingChatRequest(), ); await collectAsync(stream); @@ -211,13 +246,21 @@ async function runCohereInstrumentationScenario( ); } - await runOperation("cohere-embed-operation", "embed", async () => { - await client.embed(getEmbedRequest(apiVersion)); - }); + await runOperation( + getOperationName("embed", { useV2Namespace }), + "embed", + async () => { + await operationClient.embed(getEmbedRequest(apiVersion)); + }, + ); - await runOperation("cohere-rerank-operation", "rerank", async () => { - await client.rerank(getRerankRequest(apiVersion)); - }); + await runOperation( + getOperationName("rerank", { useV2Namespace }), + "rerank", + async () => { + await operationClient.rerank(getRerankRequest(apiVersion)); + }, + ); }, metadata: { scenario: SCENARIO_NAME, diff --git a/e2e/scenarios/cohere-instrumentation/scenario.test.ts b/e2e/scenarios/cohere-instrumentation/scenario.test.ts index 70c6ad6ea..05c2d14db 100644 --- a/e2e/scenarios/cohere-instrumentation/scenario.test.ts +++ b/e2e/scenarios/cohere-instrumentation/scenario.test.ts @@ -34,6 +34,7 @@ for (const scenario of cohereScenarios) { await runScenarioDir({ entry: scenario.wrapperEntry, env: { + COHERE_PACKAGE_NAME: scenario.dependencyName, COHERE_SUPPORTS_THINKING: supportsThinking ? "1" : "0", }, runContext: { variantKey: scenario.snapshotName }, @@ -41,10 +42,15 @@ for (const scenario of cohereScenarios) { timeoutMs: COHERE_SCENARIO_TIMEOUT_MS, }); }, - snapshotName: scenario.snapshotName, + requireChatStreamOutput: scenario.snapshotName !== "cohere-v7-14-0", + snapshotName: + scenario.snapshotName === "cohere-v7-14-0" + ? "cohere-v7-14-0-wrapped" + : scenario.snapshotName, supportsThinking, testFileUrl: import.meta.url, timeoutMs: COHERE_SCENARIO_TIMEOUT_MS, + useV2Namespace: scenario.useV2Namespace ?? false, }); defineCohereInstrumentationAssertions({ @@ -53,6 +59,7 @@ for (const scenario of cohereScenarios) { await runNodeScenarioDir({ entry: scenario.autoEntry, env: { + COHERE_PACKAGE_NAME: scenario.dependencyName, COHERE_SUPPORTS_THINKING: supportsThinking ? "1" : "0", }, nodeArgs: ["--import", "braintrust/hook.mjs"], @@ -65,6 +72,7 @@ for (const scenario of cohereScenarios) { supportsThinking, testFileUrl: import.meta.url, timeoutMs: COHERE_SCENARIO_TIMEOUT_MS, + useV2Namespace: scenario.useV2Namespace ?? false, }); }); } diff --git a/js/src/auto-instrumentations/configs/cohere.ts b/js/src/auto-instrumentations/configs/cohere.ts index 955038c84..8810fa8c6 100644 --- a/js/src/auto-instrumentations/configs/cohere.ts +++ b/js/src/auto-instrumentations/configs/cohere.ts @@ -19,7 +19,20 @@ export const cohereConfigs: InstrumentationConfig[] = [ channelName: cohereChannels.chat.channelName, module: { name: "cohere-ai", - versionRange: ">=7.20.0 <8.0.0", + versionRange: ">=7.0.0 <7.21.0", + filePath: "api/resources/v2/client/Client.js", + }, + functionQuery: { + className: "V2", + methodName: "chat", + kind: "Async", + }, + }, + { + channelName: cohereChannels.chat.channelName, + module: { + name: "cohere-ai", + versionRange: ">=7.21.0 <8.0.0", filePath: "api/resources/v2/client/Client.js", }, functionQuery: { @@ -71,7 +84,20 @@ export const cohereConfigs: InstrumentationConfig[] = [ channelName: cohereChannels.chatStream.channelName, module: { name: "cohere-ai", - versionRange: ">=7.20.0 <8.0.0", + versionRange: ">=7.0.0 <7.21.0", + filePath: "api/resources/v2/client/Client.js", + }, + functionQuery: { + className: "V2", + methodName: "chatStream", + kind: "Async", + }, + }, + { + channelName: cohereChannels.chatStream.channelName, + module: { + name: "cohere-ai", + versionRange: ">=7.21.0 <8.0.0", filePath: "api/resources/v2/client/Client.js", }, functionQuery: { @@ -123,7 +149,20 @@ export const cohereConfigs: InstrumentationConfig[] = [ channelName: cohereChannels.embed.channelName, module: { name: "cohere-ai", - versionRange: ">=7.20.0 <8.0.0", + versionRange: ">=7.0.0 <7.21.0", + filePath: "api/resources/v2/client/Client.js", + }, + functionQuery: { + className: "V2", + methodName: "embed", + kind: "Async", + }, + }, + { + channelName: cohereChannels.embed.channelName, + module: { + name: "cohere-ai", + versionRange: ">=7.21.0 <8.0.0", filePath: "api/resources/v2/client/Client.js", }, functionQuery: { @@ -175,7 +214,20 @@ export const cohereConfigs: InstrumentationConfig[] = [ channelName: cohereChannels.rerank.channelName, module: { name: "cohere-ai", - versionRange: ">=7.20.0 <8.0.0", + versionRange: ">=7.0.0 <7.21.0", + filePath: "api/resources/v2/client/Client.js", + }, + functionQuery: { + className: "V2", + methodName: "rerank", + kind: "Async", + }, + }, + { + channelName: cohereChannels.rerank.channelName, + module: { + name: "cohere-ai", + versionRange: ">=7.21.0 <8.0.0", filePath: "api/resources/v2/client/Client.js", }, functionQuery: { diff --git a/js/src/vendor-sdk-types/cohere.ts b/js/src/vendor-sdk-types/cohere.ts index 920c01f3b..0e4158842 100644 --- a/js/src/vendor-sdk-types/cohere.ts +++ b/js/src/vendor-sdk-types/cohere.ts @@ -188,5 +188,6 @@ export type CohereClient = { request: CohereRerankRequest, options?: unknown, ) => Promise; + v2?: CohereClient; [key: string]: unknown; }; diff --git a/js/src/wrappers/cohere.test.ts b/js/src/wrappers/cohere.test.ts index d193404ce..b5a473a70 100644 --- a/js/src/wrappers/cohere.test.ts +++ b/js/src/wrappers/cohere.test.ts @@ -143,6 +143,56 @@ describe("cohere wrapper", () => { }); }); + test("preserves chatStream promise subclass helpers", async () => { + class MockResponsePromise extends Promise { + withRawResponse() { + return "raw"; + } + } + + async function* stream() { + yield { + eventType: "stream-end", + response: { + finishReason: "COMPLETE", + id: "resp_stream", + meta: { + tokens: { + inputTokens: 1, + outputTokens: 1, + }, + }, + text: "OK", + }, + }; + } + + const rawPromise = new MockResponsePromise>( + (resolve) => { + resolve(stream()); + }, + ); + const client = wrapCohere({ + chatStream: vi.fn(() => rawPromise), + }); + + const resultPromise = client.chatStream({ + message: "Say OK", + model: "command-r", + }); + + expect((resultPromise as any).withRawResponse()).toBe("raw"); + const result = await resultPromise; + const chunks: unknown[] = []; + for await (const chunk of result) { + chunks.push(chunk); + } + expect(chunks).toHaveLength(1); + + const spans = await backgroundLogger.drain(); + expect(spans).toHaveLength(1); + }); + test("wraps embed and rerank", async () => { const client = wrapCohere({ embed: vi.fn(async () => ({ @@ -219,4 +269,50 @@ describe("cohere wrapper", () => { }, ]); }); + + test("wraps methods on v2 namespace clients", async () => { + const rawV2Client = { + chat: vi.fn(async () => ({ + finishReason: "COMPLETE", + id: "resp_v2_chat", + meta: { + tokens: { + inputTokens: 6, + outputTokens: 2, + }, + }, + message: { + content: "OK", + role: "assistant", + }, + })), + }; + const client = wrapCohere({ + chat: vi.fn(), + v2: rawV2Client, + }); + + expect(client.v2).toBe(client.v2); + + await client.v2.chat({ + messages: [{ content: "Reply with exactly OK.", role: "user" }], + model: "command-a-03-2025", + temperature: 0, + }); + + expect(rawV2Client.chat).toHaveBeenCalledTimes(1); + + const spans = await backgroundLogger.drain(); + expect(spans).toHaveLength(1); + const span = spans[0] as Record; + expect(span.span_attributes).toMatchObject({ + name: "cohere.chat", + type: "llm", + }); + expect(span.metadata).toMatchObject({ + provider: "cohere", + model: "command-a-03-2025", + temperature: 0, + }); + }); }); diff --git a/js/src/wrappers/cohere.ts b/js/src/wrappers/cohere.ts index 72a8d6554..aa6038030 100644 --- a/js/src/wrappers/cohere.ts +++ b/js/src/wrappers/cohere.ts @@ -24,6 +24,8 @@ export function wrapCohere(cohere: T): T { return cohere; } +const cohereProxyCache = new WeakMap(); + function isRecord(value: unknown): value is Record { return typeof value === "object" && value !== null; } @@ -50,7 +52,12 @@ function isSupportedCohereClient(value: unknown): value is CohereClient { } function cohereProxy(cohere: CohereClient): CohereClient { - return new Proxy(cohere, { + const cached = cohereProxyCache.get(cohere); + if (cached) { + return cached; + } + + const proxy = new Proxy(cohere, { get(target, prop, receiver) { switch (prop) { case "chat": @@ -69,11 +76,16 @@ function cohereProxy(cohere: CohereClient): CohereClient { return typeof target.rerank === "function" ? wrapRerank(target.rerank.bind(target)) : target.rerank; - default: - return Reflect.get(target, prop, receiver); + default: { + const value = Reflect.get(target, prop, receiver); + return isSupportedCohereClient(value) ? cohereProxy(value) : value; + } } }, }); + + cohereProxyCache.set(cohere, proxy); + return proxy; } function wrapChat(