From 893319f7c28e65df41539e94824611a9623a9006 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Mon, 22 Jun 2026 17:16:32 +0100 Subject: [PATCH] Add getBearerToken callback for BYOK providers (Managed Identity v1) Adds a per-provider getBearerToken callback on ProviderConfig / NamedProviderConfig so SDK consumers can resolve bearer tokens (e.g. via @azure/identity) on the client side. The SDK strips the non-serializable function, sends a bearerTokenProvider wire flag, and answers the runtime's session-scoped providerToken.acquire RPC by dispatching to the matching per-provider callback. The token surface is provider-agnostic: bearerTokenScope is forwarded verbatim with no default. Includes e2e coverage (callback token reaches the model as the Authorization header, token refresh on expiry, and per-provider dispatch) with hand-authored replay snapshots. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- nodejs/src/client.ts | 87 +++++- nodejs/src/generated/rpc.ts | 73 +++++ nodejs/src/session.ts | 48 ++++ nodejs/src/types.ts | 106 ++++++++ .../byok_bearer_token_provider.e2e.test.ts | 250 ++++++++++++++++++ ...k_s_token_as_the_authorization_header.yaml | 10 + ...atches_token_acquisition_per_provider.yaml | 17 ++ ...ken_when_the_previous_one_has_expired.yaml | 14 + 8 files changed, 601 insertions(+), 4 deletions(-) create mode 100644 nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts create mode 100644 test/snapshots/byok_bearer_token_provider/applies_the_callback_s_token_as_the_authorization_header.yaml create mode 100644 test/snapshots/byok_bearer_token_provider/dispatches_token_acquisition_per_provider.yaml create mode 100644 test/snapshots/byok_bearer_token_provider/re_acquires_a_fresh_token_when_the_previous_one_has_expired.yaml diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index a6efb061a..b0810a5c9 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -47,11 +47,14 @@ import type { ExitPlanModeResult, ForegroundSessionInfo, GetAuthStatusResponse, + GetBearerToken, GetStatusResponse, InternalRuntimeConnection, LargeToolOutputConfig, MCPServerConfig, ModelInfo, + NamedProviderConfig, + ProviderConfig, ResumeSessionConfig, SectionTransformFn, SessionConfig, @@ -150,6 +153,62 @@ function toJsonSchema(parameters: Tool["parameters"]): Record | return parameters; } +/** Implicit provider name for the singular, whole-session {@link ProviderConfig}. */ +const DEFAULT_PROVIDER_NAME = "default"; + +/** Wire-safe singular provider config carrying the `bearerTokenProvider` flag. */ +type WireProviderConfig = Omit & { bearerTokenProvider?: boolean }; + +/** Wire-safe named provider config carrying the `bearerTokenProvider` flag. */ +type WireNamedProviderConfig = Omit & { + bearerTokenProvider?: boolean; +}; + +/** + * Strips the non-serializable {@link GetBearerToken} callbacks from the singular + * and named provider configs before they cross the RPC boundary, replacing each + * with a `bearerTokenProvider: true` wire flag. Any configured + * {@link ProviderConfig.bearerTokenScope} is forwarded verbatim (the bearer-token + * surface is provider-agnostic, so the SDK never substitutes a default scope). + * Returns wire-safe provider configs alongside a map of provider name → callback + * for session-side registration. + */ +function extractBearerTokenProviders( + provider: ProviderConfig | undefined, + providers: NamedProviderConfig[] | undefined +): { + wireProvider: WireProviderConfig | undefined; + wireProviders: WireNamedProviderConfig[] | undefined; + callbacks: Map; +} { + const callbacks = new Map(); + + let wireProvider: WireProviderConfig | undefined = provider; + if (provider?.getBearerToken) { + const { getBearerToken, ...rest } = provider; + callbacks.set(DEFAULT_PROVIDER_NAME, getBearerToken); + wireProvider = { + ...rest, + bearerTokenProvider: true, + }; + } + + let wireProviders: WireNamedProviderConfig[] | undefined = providers; + if (providers?.some((p) => p.getBearerToken)) { + wireProviders = providers.map((p) => { + if (!p.getBearerToken) return p; + const { getBearerToken, ...rest } = p; + callbacks.set(p.name, getBearerToken); + return { + ...rest, + bearerTokenProvider: true, + }; + }); + } + + return { wireProvider, wireProviders, callbacks }; +} + /** * Convert MCP server configs from public API format (workingDirectory) to * wire format (cwd) expected by the runtime. @@ -1161,6 +1220,15 @@ export class CopilotClient { const useServerGeneratedId = config.cloud != null && callerSessionId == null; const localSessionId = useServerGeneratedId ? undefined : (callerSessionId ?? randomUUID()); + // Strip non-serializable getBearerToken callbacks from provider configs, + // replacing them with a wire flag; keep the callbacks for session-side + // registration so the runtime can call back to acquire tokens. + const { + wireProvider: bearerWireProvider, + wireProviders: bearerWireProviders, + callbacks: bearerTokenCallbacks, + } = extractBearerTokenProviders(config.provider, config.providers); + // Extract transform callbacks from system message config before serialization. const { wirePayload: wireSystemMessage, transformCallbacks } = extractTransformCallbacks( config.systemMessage @@ -1178,6 +1246,9 @@ export class CopilotClient { s.registerTools(config.tools); s.registerCanvases(config.canvases); s.registerCommands(config.commands); + if (bearerTokenCallbacks.size > 0) { + s.registerBearerTokenProviders(bearerTokenCallbacks); + } s.registerPermissionHandler(config.onPermissionRequest); if (config.onUserInputRequest) { s.registerUserInputHandler(config.onUserInputRequest); @@ -1249,8 +1320,8 @@ export class CopilotClient { availableTools: toolFilterOptions.availableTools, excludedTools: toolFilterOptions.excludedTools, toolFilterPrecedence: toolFilterOptions.toolFilterPrecedence, - provider: config.provider, - providers: config.providers, + provider: bearerWireProvider, + providers: bearerWireProviders, models: config.models, enableSessionTelemetry: config.enableSessionTelemetry, modelCapabilities: config.modelCapabilities, @@ -1369,6 +1440,14 @@ export class CopilotClient { session.registerTools(config.tools); session.registerCanvases(config.canvases); session.registerCommands(config.commands); + const { + wireProvider: bearerWireProvider, + wireProviders: bearerWireProviders, + callbacks: bearerTokenCallbacks, + } = extractBearerTokenProviders(config.provider, config.providers); + if (bearerTokenCallbacks.size > 0) { + session.registerBearerTokenProviders(bearerTokenCallbacks); + } session.registerPermissionHandler(config.onPermissionRequest); if (config.onUserInputRequest) { session.registerUserInputHandler(config.onUserInputRequest); @@ -1435,8 +1514,8 @@ export class CopilotClient { name: cmd.name, description: cmd.description, })), - provider: config.provider, - providers: config.providers, + provider: bearerWireProvider, + providers: bearerWireProviders, models: config.models, modelCapabilities: config.modelCapabilities, largeOutput: toWireLargeOutput(config.largeOutput), diff --git a/nodejs/src/generated/rpc.ts b/nodejs/src/generated/rpc.ts index 6303f9db2..7752a5a3b 100644 --- a/nodejs/src/generated/rpc.ts +++ b/nodejs/src/generated/rpc.ts @@ -6516,6 +6516,14 @@ export interface NamedProviderConfig { headers?: { [k: string]: string | undefined; }; + /** + * When true, the SDK client supplies bearer tokens on demand: the runtime calls the client-session `providerToken.acquire` callback (with this provider's `name`) before each request and uses the returned token as the Authorization header. The token-acquiring function itself stays on the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive with `apiKey`/`bearerToken`. + */ + bearerTokenProvider?: boolean; + /** + * Token scope forwarded to the `providerToken.acquire` callback when `bearerTokenProvider` is set. Optional and provider-agnostic: when omitted, an empty scope is forwarded and the callback is responsible for supplying the correct scope to its identity library. + */ + bearerTokenScope?: string; } /** * Azure-specific provider options. @@ -8422,6 +8430,14 @@ export interface ProviderConfig { headers?: { [k: string]: string | undefined; }; + /** + * When true, the SDK client supplies bearer tokens on demand: the runtime calls the client-session `providerToken.acquire` callback before each request and uses the returned token as the Authorization header. The token-acquiring function itself stays on the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive with `apiKey`/`bearerToken`. + */ + bearerTokenProvider?: boolean; + /** + * Token scope forwarded to the `providerToken.acquire` callback when `bearerTokenProvider` is set. Optional and provider-agnostic: when omitted, an empty scope is forwarded and the callback is responsible for supplying the correct scope to its identity library. + */ + bearerTokenScope?: string; } /** * A snapshot of the provider endpoint the session is currently configured to talk to. @@ -8487,6 +8503,44 @@ export interface ProviderGetEndpointRequest { */ modelId?: string; } +/** + * Asks the SDK client to acquire a bearer token for a BYOK provider whose config set `bearerTokenProvider: true`. Issued by the runtime before an outbound model request when no fresh cached token is available. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "ProviderTokenAcquireRequest". + */ +/** @experimental */ +export interface ProviderTokenAcquireRequest { + /** + * Target session identifier + */ + sessionId: string; + /** + * Name of the BYOK provider needing a token. For the legacy whole-session `provider` this is the implicit provider name; for named providers it is `NamedProviderConfig.name`. + */ + providerName: string; + /** + * Token scope to request, mirroring the provider's `bearerTokenScope`. Empty when the provider configured no scope, in which case the callback is responsible for supplying its own scope. Provider-agnostic: no default scope is assumed. + */ + scope: string; +} +/** + * A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as `Authorization: Bearer ` and caches it until shortly before `expiresOnTimestamp`. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "ProviderTokenAcquireResult". + */ +/** @experimental */ +export interface ProviderTokenAcquireResult { + /** + * The bearer token value (without the `Bearer ` prefix). + */ + token: string; + /** + * Unix epoch time in milliseconds at which the token expires. When omitted, the runtime treats the token as single-use and re-acquires on the next request. + */ + expiresOnTimestamp?: number; +} /** * A BYOK model definition referencing a named provider. * @@ -15632,6 +15686,19 @@ export function createInternalSessionRpc(connection: MessageConnection, sessionI }; } +/** Handler for `providerToken` client session API methods. */ +/** @experimental */ +export interface ProviderTokenHandler { + /** + * Asks the SDK client to acquire a bearer token for a BYOK provider whose config set `bearerTokenProvider: true`. Session-scoped: the runtime calls it back on the connection that created the session, passing the provider name and scope, and uses the returned token as the Authorization header for outbound model requests (cached until shortly before its expiry). + * + * @param params Asks the SDK client to acquire a bearer token for a BYOK provider whose config set `bearerTokenProvider: true`. Issued by the runtime before an outbound model request when no fresh cached token is available. + * + * @returns A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as `Authorization: Bearer ` and caches it until shortly before `expiresOnTimestamp`. + */ + acquire(params: ProviderTokenAcquireRequest): Promise; +} + /** Handler for `sessionFs` client session API methods. */ /** @experimental */ export interface SessionFsHandler { @@ -15762,6 +15829,7 @@ export interface CanvasHandler { /** All client session API handler groups. */ export interface ClientSessionApiHandlers { + providerToken?: ProviderTokenHandler; sessionFs?: SessionFsHandler; canvas?: CanvasHandler; } @@ -15776,6 +15844,11 @@ export function registerClientSessionApiHandlers( connection: MessageConnection, getHandlers: (sessionId: string) => ClientSessionApiHandlers, ): void { + connection.onRequest("providerToken.acquire", async (params: ProviderTokenAcquireRequest) => { + const handler = getHandlers(params.sessionId).providerToken; + if (!handler) throw new Error(`No providerToken handler registered for session: ${params.sessionId}`); + return handler.acquire(params); + }); connection.onRequest("sessionFs.readFile", async (params: SessionFsReadFileRequest) => { const handler = getHandlers(params.sessionId).sessionFs; if (!handler) throw new Error(`No sessionFs handler registered for session: ${params.sessionId}`); diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 8ae19755a..d7f5ea36e 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -26,6 +26,7 @@ import type { ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, + GetBearerToken, UiInputOptions, MessageOptions, PermissionHandler, @@ -122,6 +123,7 @@ export class CopilotSession { new Map(); private toolHandlers: Map = new Map(); private canvases: Map = new Map(); + private bearerTokenProviders: Map = new Map(); private commandHandlers: Map = new Map(); private permissionHandler?: PermissionHandler; private userInputHandler?: UserInputHandler; @@ -759,6 +761,52 @@ export class CopilotSession { }; } + /** + * Registers per-provider {@link GetBearerToken} callbacks for BYOK providers + * configured with managed-identity / on-demand bearer-token auth. + * + * The runtime never receives the callback itself; the SDK strips it from the + * provider config and instead sends `bearerTokenProvider: true`. When the + * runtime needs a token it issues a session-scoped `providerToken.acquire` + * request, which this handler routes to the matching per-provider callback. + * + * @param providers - Map of provider name → callback, or undefined/empty to clear. + * @internal This method is called internally when creating/resuming a session. + */ + registerBearerTokenProviders(providers?: Map): void { + this.bearerTokenProviders.clear(); + if (!providers || providers.size === 0) { + delete this.clientSessionApis.providerToken; + return; + } + for (const [name, callback] of providers) { + this.bearerTokenProviders.set(name, callback); + } + + const self = this; + this.clientSessionApis.providerToken = { + async acquire(params) { + const callback = self.bearerTokenProviders.get(params.providerName); + if (!callback) { + throw new Error( + `No bearer-token provider registered for provider "${params.providerName}"` + ); + } + const result = await callback({ + providerName: params.providerName, + scope: params.scope, + }); + if (typeof result === "string") { + return { token: result }; + } + return { + token: result.token, + expiresOnTimestamp: result.expiresOnTimestamp, + }; + }, + }; + } + /** * Registers command handlers for this session. * diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index f198a88b3..2824467d7 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -2127,6 +2127,70 @@ export interface ResumeSessionConfig extends SessionConfigBase { openCanvases?: OpenCanvasInstance[]; } +/** + * Arguments passed to a {@link GetBearerToken} callback when the runtime needs a + * fresh bearer token for a BYOK provider. + * + * @experimental Part of the experimental managed-identity / bearer-token-provider + * surface and may change or be removed in future SDK or CLI releases. + */ +export interface ProviderTokenArgs { + /** + * Name of the BYOK provider needing a token. For the singular, whole-session + * {@link ProviderConfig} this is the implicit provider name (`"default"`); for + * {@link NamedProviderConfig} entries it is {@link NamedProviderConfig.name}. + */ + providerName: string; + + /** + * Token scope to request, exactly as configured in the provider's + * {@link ProviderConfig.bearerTokenScope}. Empty when the provider did not + * configure a scope — in that case the callback is responsible for supplying + * the correct scope to its identity library. The bearer-token surface is + * provider-agnostic; no scope default is assumed. + */ + scope: string; +} + +/** + * A bearer token (and optional expiry) returned from a {@link GetBearerToken} + * callback. The shape mirrors the Azure Identity `AccessToken` so the result of + * `credential.getToken(scope)` can be returned directly. + * + * @experimental Part of the experimental managed-identity / bearer-token-provider + * surface and may change or be removed in future SDK or CLI releases. + */ +export interface ProviderBearerToken { + /** + * The bearer token value (without the `Bearer ` prefix). + */ + token: string; + + /** + * Unix epoch time in milliseconds at which the token expires. When provided, + * the runtime caches the token and only re-invokes the callback shortly + * before expiry. When omitted, the runtime re-invokes the callback on every + * request. + */ + expiresOnTimestamp?: number; +} + +/** + * Per-provider callback that resolves a bearer token on demand. The Copilot SDK + * itself takes no Azure dependency: the consumer supplies this callback backed by + * their own identity library (for example `@azure/identity`'s + * `DefaultAzureCredential.getToken(scope)`), and the runtime calls it before + * outbound model requests, caching and refreshing automatically based on the + * returned {@link ProviderBearerToken.expiresOnTimestamp}. + * + * Returning a bare string is equivalent to returning `{ token }` with no expiry + * (re-invoked every request). + * + * @experimental Part of the experimental managed-identity / bearer-token-provider + * surface and may change or be removed in future SDK or CLI releases. + */ +export type GetBearerToken = (args: ProviderTokenArgs) => Promise; + /** * Configuration for a custom API provider. */ @@ -2158,6 +2222,27 @@ export interface ProviderConfig { */ bearerToken?: string; + /** + * Per-request bearer-token provider for managed-identity / on-demand auth. + * When set, the SDK keeps this function client-side (it is never serialized) + * and the runtime calls back into this client to acquire a token before each + * outbound request, caching and refreshing automatically. Mutually exclusive + * with {@link apiKey} / {@link bearerToken}. + * + * @experimental + */ + getBearerToken?: GetBearerToken; + + /** + * Token scope forwarded to {@link getBearerToken} as + * {@link ProviderTokenArgs.scope}. Optional and provider-agnostic: when + * omitted, the callback receives an empty scope and is responsible for + * supplying the correct scope to its identity library. + * + * @experimental + */ + bearerTokenScope?: string; + /** * Azure-specific options */ @@ -2249,6 +2334,27 @@ export interface NamedProviderConfig { */ bearerToken?: string; + /** + * Per-request bearer-token provider for managed-identity / on-demand auth. + * When set, the SDK keeps this function client-side (it is never serialized) + * and the runtime calls back into this client to acquire a token before each + * outbound request, caching and refreshing automatically. Mutually exclusive + * with {@link apiKey} / {@link bearerToken}. + * + * @experimental + */ + getBearerToken?: GetBearerToken; + + /** + * Token scope forwarded to {@link getBearerToken} as + * {@link ProviderTokenArgs.scope}. Optional and provider-agnostic: when + * omitted, the callback receives an empty scope and is responsible for + * supplying the correct scope to its identity library. + * + * @experimental + */ + bearerTokenScope?: string; + /** * Azure-specific options. */ diff --git a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts new file mode 100644 index 000000000..ceb9d9c47 --- /dev/null +++ b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts @@ -0,0 +1,250 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll } from "../../src/index.js"; +import type { + GetBearerToken, + NamedProviderConfig, + ProviderModelConfig, +} from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; +import { retry } from "./harness/sdkTestHelper.js"; +import type { ParsedHttpExchange } from "../../../test/harness/replayingCapiProxy"; + +/** + * End-to-end coverage for the experimental BYOK bearer-token-provider surface + * (`getBearerToken` on a provider config). The callback stays entirely on the + * SDK/client side: the SDK strips it from the wire config, sets the + * `bearerTokenProvider` flag, and the runtime calls back over the session-scoped + * `providerToken.acquire` RPC before each outbound model request, applying the + * returned token as the `Authorization` header. + * + * These tests validate, against a real runtime + replaying model proxy: + * 1. the callback's token reaches the model as `Authorization: Bearer `; + * 2. tokens are refreshed (re-acquired) when they expire; + * 3. per-provider dispatch routes each provider's turn to its own callback. + */ +describe("BYOK bearer-token provider", async () => { + const { copilotClient: client, openAiEndpoint } = await createSdkTestContext(); + + async function waitForExchanges(minimumCount = 1): Promise { + await retry( + `capture ${minimumCount} chat completion request(s)`, + async () => { + const exchanges = await openAiEndpoint.getExchanges(); + expect(exchanges.length).toBeGreaterThanOrEqual(minimumCount); + }, + 1_200 + ); + return openAiEndpoint.getExchanges(); + } + + function getHeader(exchange: ParsedHttpExchange, name: string): string | undefined { + const headers = exchange.requestHeaders ?? {}; + const key = Object.keys(headers).find((k) => k.toLowerCase() === name.toLowerCase()); + if (key === undefined) { + return undefined; + } + const value = headers[key]; + return Array.isArray(value) ? value[0] : value; + } + + it("applies the callback's token as the Authorization header", async () => { + const SENTINEL = "sentinel-bearer-token-abc123"; + let calls = 0; + const getBearerToken: GetBearerToken = async () => { + calls += 1; + // Far-future expiry: the runtime caches it, so a single turn needs + // only one acquisition. + return { token: SENTINEL, expiresOnTimestamp: Date.now() + 60 * 60 * 1000 }; + }; + + const providers: NamedProviderConfig[] = [ + { + name: "mi", + type: "openai", + wireApi: "completions", + baseUrl: openAiEndpoint.url, + getBearerToken, + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "default", provider: "mi", wireModel: "byok-gpt-4o" }, + ]; + + const session = await client.createSession({ + onPermissionRequest: approveAll, + model: "mi/default", + providers, + models, + }); + + try { + const reply = await session.sendAndWait({ prompt: "What is 5+5?" }); + const exchanges = await waitForExchanges(); + expect(exchanges.length).toBe(1); + + // The runtime acquired a token via the callback and applied it + // verbatim as the bearer credential on the outbound model request. + expect(getHeader(exchanges[0], "Authorization")).toBe(`Bearer ${SENTINEL}`); + // The far-future expiry means the token is cached, so the single + // turn needs only one acquisition (it is never re-fetched mid-turn). + expect(calls).toBe(1); + + // Validate the final assistant response arrived (guards against + // truncated captures). + expect(reply?.data.content).toContain("10"); + } finally { + try { + await session.disconnect(); + } catch { + // ignore disconnect errors for the fake BYOK endpoint + } + } + }); + + it("re-acquires a fresh token when the previous one has expired", async () => { + let calls = 0; + const getBearerToken: GetBearerToken = async () => { + calls += 1; + // Already-expired expiry forces the runtime to re-acquire on the next + // request rather than reuse the cached token. + return { token: `rotating-token-${calls}`, expiresOnTimestamp: Date.now() - 1 }; + }; + + const providers: NamedProviderConfig[] = [ + { + name: "mi", + type: "openai", + wireApi: "completions", + baseUrl: openAiEndpoint.url, + getBearerToken, + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "default", provider: "mi", wireModel: "byok-gpt-4o" }, + ]; + + const session = await client.createSession({ + onPermissionRequest: approveAll, + model: "mi/default", + providers, + models, + }); + + try { + const reply1 = await session.sendAndWait({ prompt: "What is 1+1?" }); + const afterTurn1 = await waitForExchanges(1); + const callsAfterTurn1 = calls; + + const reply2 = await session.sendAndWait({ prompt: "What is 2+2?" }); + const exchanges = await waitForExchanges(2); + + // Each outbound request carries a freshly-acquired, distinct token, + // proving the runtime refreshed rather than reusing the expired one. + const auth1 = getHeader(exchanges[0], "Authorization"); + const auth2 = getHeader(exchanges[1], "Authorization"); + expect(auth1).toMatch(/^Bearer rotating-token-\d+$/); + expect(auth2).toMatch(/^Bearer rotating-token-\d+$/); + expect(auth1).not.toBe(auth2); + + // The second turn triggered at least one additional acquisition. + expect(calls).toBeGreaterThan(callsAfterTurn1); + + // Validate the final assistant responses arrived (guards against + // truncated captures). + expect(reply1?.data.content).toContain("2"); + expect(reply2?.data.content).toContain("4"); + void afterTurn1; + } finally { + try { + await session.disconnect(); + } catch { + // ignore disconnect errors for the fake BYOK endpoint + } + } + }); + + it("dispatches token acquisition per provider", async () => { + const tokenByProvider: Record = { + red: "token-for-red", + blue: "token-for-blue", + }; + const acquiredFor: string[] = []; + const makeCallback = + (providerName: string): GetBearerToken => + async (args) => { + // The runtime forwards the requesting provider's name so the + // client can dispatch to the right credential. + expect(args.providerName).toBe(providerName); + acquiredFor.push(providerName); + return { + token: tokenByProvider[providerName], + expiresOnTimestamp: Date.now() + 60 * 60 * 1000, + }; + }; + + const providers: NamedProviderConfig[] = [ + { + name: "red", + type: "openai", + wireApi: "completions", + baseUrl: openAiEndpoint.url, + getBearerToken: makeCallback("red"), + }, + { + name: "blue", + type: "openai", + wireApi: "completions", + baseUrl: openAiEndpoint.url, + getBearerToken: makeCallback("blue"), + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "default", provider: "red", wireModel: "byok-gpt-4o" }, + { id: "default", provider: "blue", wireModel: "byok-gpt-4o" }, + ]; + + async function runTurn(selectionId: string, prompt: string): Promise { + const session = await client.createSession({ + onPermissionRequest: approveAll, + model: selectionId, + providers, + models, + }); + try { + const reply = await session.sendAndWait({ prompt }); + return reply?.data.content; + } finally { + try { + await session.disconnect(); + } catch { + // ignore disconnect errors for the fake BYOK endpoint + } + } + } + + const replyRed = await runTurn("red/default", "What is 3+3?"); + const afterRed = await waitForExchanges(1); + expect(getHeader(afterRed[0], "Authorization")).toBe(`Bearer ${tokenByProvider.red}`); + + const replyBlue = await runTurn("blue/default", "What is 4+4?"); + const exchanges = await waitForExchanges(2); + + // The two turns were authenticated with their respective providers' + // tokens, proving per-provider dispatch (not a single session-global + // credential). + const authValues = exchanges.map((e) => getHeader(e, "Authorization")); + expect(authValues).toContain(`Bearer ${tokenByProvider.red}`); + expect(authValues).toContain(`Bearer ${tokenByProvider.blue}`); + expect(acquiredFor).toContain("red"); + expect(acquiredFor).toContain("blue"); + + // Validate the final assistant responses arrived (guards against + // truncated captures). + expect(replyRed).toContain("6"); + expect(replyBlue).toContain("8"); + }); +}); diff --git a/test/snapshots/byok_bearer_token_provider/applies_the_callback_s_token_as_the_authorization_header.yaml b/test/snapshots/byok_bearer_token_provider/applies_the_callback_s_token_as_the_authorization_header.yaml new file mode 100644 index 000000000..faa2379e8 --- /dev/null +++ b/test/snapshots/byok_bearer_token_provider/applies_the_callback_s_token_as_the_authorization_header.yaml @@ -0,0 +1,10 @@ +models: + - byok-gpt-4o +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 5+5? + - role: assistant + content: 5 + 5 = 10 diff --git a/test/snapshots/byok_bearer_token_provider/dispatches_token_acquisition_per_provider.yaml b/test/snapshots/byok_bearer_token_provider/dispatches_token_acquisition_per_provider.yaml new file mode 100644 index 000000000..4297653bb --- /dev/null +++ b/test/snapshots/byok_bearer_token_provider/dispatches_token_acquisition_per_provider.yaml @@ -0,0 +1,17 @@ +models: + - byok-gpt-4o +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 3+3? + - role: assistant + content: 3 + 3 = 6 + - messages: + - role: system + content: ${system} + - role: user + content: What is 4+4? + - role: assistant + content: 4 + 4 = 8 diff --git a/test/snapshots/byok_bearer_token_provider/re_acquires_a_fresh_token_when_the_previous_one_has_expired.yaml b/test/snapshots/byok_bearer_token_provider/re_acquires_a_fresh_token_when_the_previous_one_has_expired.yaml new file mode 100644 index 000000000..01e073f98 --- /dev/null +++ b/test/snapshots/byok_bearer_token_provider/re_acquires_a_fresh_token_when_the_previous_one_has_expired.yaml @@ -0,0 +1,14 @@ +models: + - byok-gpt-4o +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 1+1? + - role: assistant + content: 1 + 1 = 2 + - role: user + content: What is 2+2? + - role: assistant + content: 2 + 2 = 4