Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 83 additions & 4 deletions nodejs/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ import type {
ExitPlanModeResult,
ForegroundSessionInfo,
GetAuthStatusResponse,
GetBearerToken,
GetStatusResponse,
InternalRuntimeConnection,
LargeToolOutputConfig,
MCPServerConfig,
ModelInfo,
NamedProviderConfig,
ProviderConfig,
ResumeSessionConfig,
SectionTransformFn,
SessionConfig,
Expand Down Expand Up @@ -150,6 +153,62 @@ function toJsonSchema(parameters: Tool["parameters"]): Record<string, unknown> |
return parameters;
}

/** Implicit provider name for the singular, whole-session {@link ProviderConfig}. */
const DEFAULT_PROVIDER_NAME = "default";

/** Wire-safe singular provider config carrying the `bearerTokenProvider` flag. */
type WireProviderConfig = Omit<ProviderConfig, "getBearerToken"> & { bearerTokenProvider?: boolean };

/** Wire-safe named provider config carrying the `bearerTokenProvider` flag. */
type WireNamedProviderConfig = Omit<NamedProviderConfig, "getBearerToken"> & {
bearerTokenProvider?: boolean;
};

/**
* Strips the non-serializable {@link GetBearerToken} callbacks from the singular
* and named provider configs before they cross the RPC boundary, replacing each
* with a `bearerTokenProvider: true` wire flag. Any configured
* {@link ProviderConfig.bearerTokenScope} is forwarded verbatim (the bearer-token
* surface is provider-agnostic, so the SDK never substitutes a default scope).
* Returns wire-safe provider configs alongside a map of provider name → callback
* for session-side registration.
*/
function extractBearerTokenProviders(
provider: ProviderConfig | undefined,
providers: NamedProviderConfig[] | undefined
): {
wireProvider: WireProviderConfig | undefined;
wireProviders: WireNamedProviderConfig[] | undefined;
callbacks: Map<string, GetBearerToken>;
} {
const callbacks = new Map<string, GetBearerToken>();

let wireProvider: WireProviderConfig | undefined = provider;
if (provider?.getBearerToken) {
const { getBearerToken, ...rest } = provider;
callbacks.set(DEFAULT_PROVIDER_NAME, getBearerToken);
wireProvider = {
...rest,
bearerTokenProvider: true,
};
}

let wireProviders: WireNamedProviderConfig[] | undefined = providers;
if (providers?.some((p) => p.getBearerToken)) {
wireProviders = providers.map((p) => {
if (!p.getBearerToken) return p;
const { getBearerToken, ...rest } = p;
callbacks.set(p.name, getBearerToken);
return {
...rest,
bearerTokenProvider: true,
};
});
}

return { wireProvider, wireProviders, callbacks };
}

/**
* Convert MCP server configs from public API format (workingDirectory) to
* wire format (cwd) expected by the runtime.
Expand Down Expand Up @@ -1161,6 +1220,15 @@ export class CopilotClient {
const useServerGeneratedId = config.cloud != null && callerSessionId == null;
const localSessionId = useServerGeneratedId ? undefined : (callerSessionId ?? randomUUID());

// Strip non-serializable getBearerToken callbacks from provider configs,
// replacing them with a wire flag; keep the callbacks for session-side
// registration so the runtime can call back to acquire tokens.
const {
wireProvider: bearerWireProvider,
wireProviders: bearerWireProviders,
callbacks: bearerTokenCallbacks,
} = extractBearerTokenProviders(config.provider, config.providers);

// Extract transform callbacks from system message config before serialization.
const { wirePayload: wireSystemMessage, transformCallbacks } = extractTransformCallbacks(
config.systemMessage
Expand All @@ -1178,6 +1246,9 @@ export class CopilotClient {
s.registerTools(config.tools);
s.registerCanvases(config.canvases);
s.registerCommands(config.commands);
if (bearerTokenCallbacks.size > 0) {
s.registerBearerTokenProviders(bearerTokenCallbacks);
}
s.registerPermissionHandler(config.onPermissionRequest);
if (config.onUserInputRequest) {
s.registerUserInputHandler(config.onUserInputRequest);
Expand Down Expand Up @@ -1249,8 +1320,8 @@ export class CopilotClient {
availableTools: toolFilterOptions.availableTools,
excludedTools: toolFilterOptions.excludedTools,
toolFilterPrecedence: toolFilterOptions.toolFilterPrecedence,
provider: config.provider,
providers: config.providers,
provider: bearerWireProvider,
providers: bearerWireProviders,
models: config.models,
enableSessionTelemetry: config.enableSessionTelemetry,
modelCapabilities: config.modelCapabilities,
Expand Down Expand Up @@ -1369,6 +1440,14 @@ export class CopilotClient {
session.registerTools(config.tools);
session.registerCanvases(config.canvases);
session.registerCommands(config.commands);
const {
wireProvider: bearerWireProvider,
wireProviders: bearerWireProviders,
callbacks: bearerTokenCallbacks,
} = extractBearerTokenProviders(config.provider, config.providers);
if (bearerTokenCallbacks.size > 0) {
session.registerBearerTokenProviders(bearerTokenCallbacks);
}
session.registerPermissionHandler(config.onPermissionRequest);
if (config.onUserInputRequest) {
session.registerUserInputHandler(config.onUserInputRequest);
Expand Down Expand Up @@ -1435,8 +1514,8 @@ export class CopilotClient {
name: cmd.name,
description: cmd.description,
})),
provider: config.provider,
providers: config.providers,
provider: bearerWireProvider,
providers: bearerWireProviders,
models: config.models,
modelCapabilities: config.modelCapabilities,
largeOutput: toWireLargeOutput(config.largeOutput),
Expand Down
73 changes: 73 additions & 0 deletions nodejs/src/generated/rpc.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions nodejs/src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import type {
ExitPlanModeHandler,
ExitPlanModeRequest,
ExitPlanModeResult,
GetBearerToken,
UiInputOptions,
MessageOptions,
PermissionHandler,
Expand Down Expand Up @@ -122,6 +123,7 @@ export class CopilotSession {
new Map();
private toolHandlers: Map<string, ToolHandler> = new Map();
private canvases: Map<string, Canvas> = new Map();
private bearerTokenProviders: Map<string, GetBearerToken> = new Map();
private commandHandlers: Map<string, CommandHandler> = new Map();
private permissionHandler?: PermissionHandler;
private userInputHandler?: UserInputHandler;
Expand Down Expand Up @@ -759,6 +761,52 @@ export class CopilotSession {
};
}

/**
* Registers per-provider {@link GetBearerToken} callbacks for BYOK providers
* configured with managed-identity / on-demand bearer-token auth.
*
* The runtime never receives the callback itself; the SDK strips it from the
* provider config and instead sends `bearerTokenProvider: true`. When the
* runtime needs a token it issues a session-scoped `providerToken.acquire`
* request, which this handler routes to the matching per-provider callback.
*
* @param providers - Map of provider name → callback, or undefined/empty to clear.
* @internal This method is called internally when creating/resuming a session.
*/
registerBearerTokenProviders(providers?: Map<string, GetBearerToken>): void {
this.bearerTokenProviders.clear();
if (!providers || providers.size === 0) {
delete this.clientSessionApis.providerToken;
return;
}
for (const [name, callback] of providers) {
this.bearerTokenProviders.set(name, callback);
}

const self = this;
this.clientSessionApis.providerToken = {
async acquire(params) {
const callback = self.bearerTokenProviders.get(params.providerName);
if (!callback) {
throw new Error(
`No bearer-token provider registered for provider "${params.providerName}"`
);
}
const result = await callback({
providerName: params.providerName,
scope: params.scope,
});
if (typeof result === "string") {
return { token: result };
}
return {
token: result.token,
expiresOnTimestamp: result.expiresOnTimestamp,
};
},
};
}

/**
* Registers command handlers for this session.
*
Expand Down
Loading
Loading