diff --git a/dotnet/src/BearerTokenProvider.cs b/dotnet/src/BearerTokenProvider.cs new file mode 100644 index 000000000..2c59da09b --- /dev/null +++ b/dotnet/src/BearerTokenProvider.cs @@ -0,0 +1,32 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Diagnostics.CodeAnalysis; + +namespace GitHub.Copilot; + +/// +/// Arguments passed to a bearer-token callback (the GetBearerToken property +/// on / ) when the +/// runtime needs a fresh bearer token for a BYOK provider. +/// +/// +/// Part of the experimental managed-identity / bearer-token-provider surface and +/// may change or be removed in future SDK or CLI releases. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class ProviderTokenArgs +{ + /// + /// Name of the BYOK provider needing a token. For the singular, whole-session + /// this is the implicit provider name + /// ("default"); for entries it is + /// . + /// + /// + /// The callback closes over its own token scope/audience; the runtime is + /// provider-agnostic and forwards only the provider name. + /// + public required string ProviderName { get; init; } +} diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 9859dec90..4c81af392 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -652,6 +652,7 @@ private CopilotSession InitializeSession( } ConfigureSessionFsHandlers(session, config.CreateSessionFsProvider); session.SetCanvasHandler(config.CanvasHandler); + session.RegisterBearerTokenProviders(BuildBearerTokenCallbacks(config)); RegisterSession(session); session.StartProcessingEvents(); LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, @@ -664,6 +665,37 @@ private CopilotSession InitializeSession( return session; } + /// + /// Implicit provider name for the singular, whole-session . + /// + private const string DefaultBearerTokenProviderName = "default"; + + /// + /// Collects the per-provider GetBearerToken callbacks keyed by + /// provider name for session-side registration. The singular, whole-session + /// uses the implicit + /// . + /// + private static Dictionary>> BuildBearerTokenCallbacks(SessionConfigBase config) + { + var callbacks = new Dictionary>>(StringComparer.Ordinal); + if (config.Provider?.GetBearerToken is { } singular) + { + callbacks[DefaultBearerTokenProviderName] = singular; + } + if (config.Providers != null) + { + foreach (var provider in config.Providers) + { + if (provider.GetBearerToken is { } callback) + { + callbacks[provider.Name] = callback; + } + } + } + return callbacks; + } + /// /// Catches misuse of / /// at the SDK boundary so @@ -839,7 +871,6 @@ private async Task UpdateSessionOptionsForModeAsync(CopilotSession session, Sess try { -#pragma warning disable GHCP001 await session.Rpc.Options.UpdateAsync( skipCustomInstructions: skipCustomInstructions, customAgentsLocalOnly: customAgentsLocalOnly, @@ -847,7 +878,6 @@ await session.Rpc.Options.UpdateAsync( manageScheduleEnabled: manageScheduleEnabled, installedPlugins: installedPlugins, cancellationToken: cancellationToken).ConfigureAwait(false); -#pragma warning restore GHCP001 } catch { @@ -2436,7 +2466,6 @@ internal record CreateSessionRequest( IList? PluginDirectories = null, LargeToolOutputConfig? LargeOutput = null, MemoryConfiguration? Memory = null, -#pragma warning disable GHCP001 IList? Canvases = null, bool? RequestCanvasRenderer = null, bool? RequestExtensions = null, @@ -2445,7 +2474,6 @@ internal record CreateSessionRequest( IList? Providers = null, IList? Models = null, OptionsUpdateToolFilterPrecedence? ToolFilterPrecedence = null); -#pragma warning restore GHCP001 internal record ToolDefinition( string Name, @@ -2471,9 +2499,7 @@ internal record CreateSessionResponse( string SessionId, string? WorkspacePath, SessionCapabilities? Capabilities = null, -#pragma warning disable GHCP001 IList? OpenCanvases = null); -#pragma warning restore GHCP001 internal record ResumeSessionRequest( string SessionId, @@ -2530,7 +2556,6 @@ internal record ResumeSessionRequest( IList? PluginDirectories = null, LargeToolOutputConfig? LargeOutput = null, MemoryConfiguration? Memory = null, -#pragma warning disable GHCP001 IList? Canvases = null, bool? RequestCanvasRenderer = null, bool? RequestExtensions = null, @@ -2540,15 +2565,12 @@ internal record ResumeSessionRequest( IList? Providers = null, IList? Models = null, OptionsUpdateToolFilterPrecedence? ToolFilterPrecedence = null); -#pragma warning restore GHCP001 internal record ResumeSessionResponse( string SessionId, string? WorkspacePath, SessionCapabilities? Capabilities = null, -#pragma warning disable GHCP001 IList? OpenCanvases = null); -#pragma warning restore GHCP001 internal record CommandWireDefinition( string Name, diff --git a/dotnet/src/Generated/Rpc.cs b/dotnet/src/Generated/Rpc.cs index 0b4202024..c82f6cbc2 100644 --- a/dotnet/src/Generated/Rpc.cs +++ b/dotnet/src/Generated/Rpc.cs @@ -6215,6 +6215,10 @@ public sealed class NamedProviderConfig [JsonPropertyName("bearerToken")] public string? BearerToken { get; set; } + /// When true, the SDK client supplies bearer tokens on demand: the runtime calls the client-session `providerToken.getToken` callback before each request and uses the returned token as the Authorization header. The token-acquiring function itself stays on the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive with `apiKey`/`bearerToken`. + [JsonPropertyName("hasBearerTokenProvider")] + public bool? HasBearerTokenProvider { get; set; } + /// Custom HTTP headers to include in all outbound requests to the provider. [JsonPropertyName("headers")] public IDictionary? Headers { get; set; } @@ -6362,6 +6366,10 @@ public sealed class ProviderConfig [JsonPropertyName("bearerToken")] public string? BearerToken { get; set; } + /// When true, the SDK client supplies bearer tokens on demand: the runtime calls the client-session `providerToken.getToken` callback before each request and uses the returned token as the Authorization header. The token-acquiring function itself stays on the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive with `apiKey`/`bearerToken`. + [JsonPropertyName("hasBearerTokenProvider")] + public bool? HasBearerTokenProvider { get; set; } + /// Custom HTTP headers to include in all outbound requests to the provider. [JsonPropertyName("headers")] public IDictionary? Headers { get; set; } @@ -10743,6 +10751,27 @@ public sealed class CanvasProviderInvokeActionRequest public string SessionId { get; set; } = string.Empty; } +/// A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as `Authorization: Bearer <token>` on the outbound request and does no caching; the SDK consumer owns token caching and refresh. +[Experimental(Diagnostics.Experimental)] +public sealed class ProviderTokenAcquireResult +{ + /// The bearer token value (without the `Bearer ` prefix). + [JsonPropertyName("token")] + public string Token { get; set; } = string.Empty; +} + +/// Asks the SDK client to acquire a bearer token for a BYOK provider whose config set `hasBearerTokenProvider: true`. Issued by the runtime before each outbound model request; the runtime does no caching, so this is sent once per request. +public sealed class ProviderTokenAcquireRequest +{ + /// Name of the BYOK provider needing a token. For the legacy whole-session `provider` this is the implicit provider name; for named providers it is `NamedProviderConfig.name`. + [JsonPropertyName("providerName")] + public string ProviderName { get; set; } = string.Empty; + + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; +} + /// Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. [Experimental(Diagnostics.Experimental)] public sealed class LlmInferenceHttpRequestStartResult @@ -20553,6 +20582,17 @@ public interface ICanvasHandler Task InvokeAsync(CanvasProviderInvokeActionRequest request, CancellationToken cancellationToken = default); } +/// Handles `providerToken` client session API methods. +[Experimental(Diagnostics.Experimental)] +public interface IProviderTokenHandler +{ + /// Asks the SDK client to get a bearer token for a BYOK provider whose config set `hasBearerTokenProvider: true`. Session-scoped: the runtime calls it back on the connection that created the session, passing the provider name, and uses the returned token as the Authorization header for the outbound model request. The runtime does no caching — it calls this once per outbound request; the SDK consumer owns token acquisition, caching, and refresh. + /// Asks the SDK client to acquire a bearer token for a BYOK provider whose config set `hasBearerTokenProvider: true`. Issued by the runtime before each outbound model request; the runtime does no caching, so this is sent once per request. + /// The to monitor for cancellation requests. The default is . + /// A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as `Authorization: Bearer <token>` on the outbound request and does no caching; the SDK consumer owns token caching and refresh. + Task GetTokenAsync(ProviderTokenAcquireRequest request, CancellationToken cancellationToken = default); +} + /// Provides all client session API handler groups for a session. public sealed class ClientSessionApiHandlers { @@ -20561,6 +20601,9 @@ public sealed class ClientSessionApiHandlers /// Optional handler for Canvas client session API methods. public ICanvasHandler? Canvas { get; set; } + + /// Optional handler for ProviderToken client session API methods. + public IProviderTokenHandler? ProviderToken { get; set; } } /// Registers client session API handlers on a JSON-RPC connection. @@ -20663,6 +20706,12 @@ public static void RegisterClientSessionApiHandlers(JsonRpc rpc, Func>)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).ProviderToken; + if (handler is null) throw new InvalidOperationException($"No providerToken handler registered for session: {request.SessionId}"); + return await handler.GetTokenAsync(request, cancellationToken); + }), singleObjectParam: true); } } @@ -21310,6 +21359,8 @@ public static void RegisterClientGlobalApiHandlers(JsonRpc rpc, ClientGlobalApiH [JsonSerializable(typeof(ProviderGetEndpointRequestWithSession))] [JsonSerializable(typeof(ProviderModelConfig))] [JsonSerializable(typeof(ProviderSessionToken))] +[JsonSerializable(typeof(ProviderTokenAcquireRequest))] +[JsonSerializable(typeof(ProviderTokenAcquireResult))] [JsonSerializable(typeof(PushAttachment))] [JsonSerializable(typeof(PushAttachmentFileLineRange))] [JsonSerializable(typeof(PushAttachmentSelectionDetails))] diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 095c1abf7..8ba3c807a 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -58,6 +58,7 @@ public sealed partial class CopilotSession : IAsyncDisposable { private readonly Dictionary _toolHandlers = []; private readonly Dictionary> _commandHandlers = []; + private readonly Dictionary>> _bearerTokenProviders = new(StringComparer.Ordinal); private readonly ILogger _logger; private readonly CopilotClient _parentClient; @@ -76,9 +77,7 @@ private sealed record EventSubscription(Type EventType, Action Han private Dictionary>>? _transformCallbacks; private readonly SemaphoreSlim _transformCallbacksLock = new(1, 1); -#pragma warning disable GHCP001 private IReadOnlyList _openCanvases = Array.Empty(); -#pragma warning restore GHCP001 private int _isDisposed; @@ -126,7 +125,6 @@ public SessionCapabilities Capabilities private set; } -#pragma warning disable GHCP001 /// /// Canvas instances currently known to be open for this session. /// @@ -136,7 +134,6 @@ public SessionCapabilities Capabilities /// [Experimental(Diagnostics.Experimental)] public IReadOnlyList OpenCanvases => _openCanvases; -#pragma warning restore GHCP001 /// /// Gets the UI API for eliciting information from the user during this session. @@ -873,6 +870,51 @@ internal void RegisterAutoModeSwitchHandler(Func + /// Registers per-provider GetBearerToken callbacks for BYOK + /// providers configured with managed-identity / on-demand bearer-token auth. + /// + /// + /// The runtime never receives the callback itself; the SDK strips it from the + /// provider config and instead sends hasBearerTokenProvider: true. When + /// the runtime needs a token it issues a session-scoped + /// providerToken.getToken request, which this handler routes to the + /// matching per-provider callback. + /// + /// Map of provider name to callback, or null/empty to clear. + internal void RegisterBearerTokenProviders(IReadOnlyDictionary>>? providers) + { + _bearerTokenProviders.Clear(); + if (providers is null || providers.Count == 0) + { + ClientSessionApis.ProviderToken = null; + return; + } + foreach (var (name, callback) in providers) + { + _bearerTokenProviders[name] = callback; + } + ClientSessionApis.ProviderToken = new BearerTokenProviderHandler(this); + } + + /// + /// Routes runtime providerToken.getToken requests to the matching + /// per-provider GetBearerToken callback registered on the session. + /// + private sealed class BearerTokenProviderHandler(CopilotSession session) : IProviderTokenHandler + { + public async Task GetTokenAsync(ProviderTokenAcquireRequest request, CancellationToken cancellationToken = default) + { + if (!session._bearerTokenProviders.TryGetValue(request.ProviderName, out var callback)) + { + throw new InvalidOperationException( + $"No bearer-token provider registered for provider \"{request.ProviderName}\""); + } + var token = await callback(new ProviderTokenArgs { ProviderName = request.ProviderName }).ConfigureAwait(false); + return new ProviderTokenAcquireResult { Token = token }; + } + } + /// /// Sets the capabilities reported by the host for this session. /// @@ -882,7 +924,6 @@ internal void SetCapabilities(SessionCapabilities? capabilities) Capabilities = capabilities ?? new SessionCapabilities(); } -#pragma warning disable GHCP001 internal void SetOpenCanvases(IList? canvases) { _openCanvases = canvases is { Count: > 0 } @@ -962,7 +1003,6 @@ private static JsonElement SerializeActionResult(object? value) var element = CopilotClient.ToJsonElementForWire(value); return element ?? NullJsonElement; } -#pragma warning restore GHCP001 private sealed class CanvasHandlerAdapter(ICanvasHandler handler) : Rpc.ICanvasHandler { diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 5f9d8f861..11b3e7348 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -5,12 +5,14 @@ using GitHub.Copilot.Rpc; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; +using System; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using System.Threading.Tasks; namespace GitHub.Copilot; @@ -2032,6 +2034,28 @@ public sealed class ProviderConfig [JsonPropertyName("bearerToken")] public string? BearerToken { get; set; } + /// + /// Wire-only flag, emitted automatically when is set, that tells + /// the runtime to request a token over the session-scoped providerToken.getToken RPC + /// before each outbound request to this provider. Derived from ; + /// internal and never part of the public API. + /// + [JsonInclude] + [JsonPropertyName("hasBearerTokenProvider")] + internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null; + + /// + /// Per-request callback that resolves a bearer token on demand for this BYOK provider (for + /// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a + /// callback backed by your own identity library. Never serialized — setting it makes the SDK send + /// hasBearerTokenProvider: true on the wire and answer the runtime's + /// providerToken.getToken requests. Mutually exclusive with and + /// . + /// + [JsonIgnore] + [Experimental(Diagnostics.Experimental)] + public Func>? GetBearerToken { get; set; } + /// /// Azure-specific configuration options. /// @@ -2164,6 +2188,28 @@ public sealed class NamedProviderConfig [JsonPropertyName("bearerToken")] public string? BearerToken { get; set; } + /// + /// Wire-only flag, emitted automatically when is set, that tells + /// the runtime to request a token over the session-scoped providerToken.getToken RPC + /// before each outbound request to this provider. Derived from ; + /// internal and never part of the public API. + /// + [JsonInclude] + [JsonPropertyName("hasBearerTokenProvider")] + internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null; + + /// + /// Per-request callback that resolves a bearer token on demand for this BYOK provider (for + /// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a + /// callback backed by your own identity library. Never serialized — setting it makes the SDK send + /// hasBearerTokenProvider: true on the wire and answer the runtime's + /// providerToken.getToken requests. Mutually exclusive with and + /// . + /// + [JsonIgnore] + [Experimental(Diagnostics.Experimental)] + public Func>? GetBearerToken { get; set; } + /// /// Azure-specific configuration options. /// @@ -2679,14 +2725,12 @@ protected SessionConfigBase(SessionConfigBase? other) CreateSessionFsProvider = other.CreateSessionFsProvider; GitHubToken = other.GitHubToken; RemoteSession = other.RemoteSession; -#pragma warning disable GHCP001 Canvases = other.Canvases is not null ? [.. other.Canvases] : null; RequestCanvasRenderer = other.RequestCanvasRenderer; RequestExtensions = other.RequestExtensions; ExtensionSdkPath = other.ExtensionSdkPath; ExtensionInfo = other.ExtensionInfo; CanvasHandler = other.CanvasHandler; -#pragma warning restore GHCP001 SkillDirectories = other.SkillDirectories is not null ? [.. other.SkillDirectories] : null; PluginDirectories = other.PluginDirectories is not null ? [.. other.PluginDirectories] : null; InstructionDirectories = other.InstructionDirectories is not null ? [.. other.InstructionDirectories] : null; @@ -3055,7 +3099,6 @@ protected SessionConfigBase(SessionConfigBase? other) /// public RemoteSessionMode? RemoteSession { get; set; } -#pragma warning disable GHCP001 /// /// Canvas declarations advertised by this connection. The runtime forwards /// these to the agent and routes inbound canvas.* requests for any @@ -3104,7 +3147,6 @@ protected SessionConfigBase(SessionConfigBase? other) [Experimental(Diagnostics.Experimental)] [JsonIgnore] public ICanvasHandler? CanvasHandler { get; set; } -#pragma warning restore GHCP001 } /// @@ -3190,7 +3232,6 @@ private ResumeSessionConfig(ResumeSessionConfig? other) : base(other) /// public bool? ContinuePendingWork { get; set; } -#pragma warning disable GHCP001 /// /// Snapshot of canvases that were already open when the session was suspended. /// When provided on resume, the runtime can rehydrate canvas state so consumers @@ -3198,7 +3239,6 @@ private ResumeSessionConfig(ResumeSessionConfig? other) : base(other) /// [Experimental(Diagnostics.Experimental)] public IList? OpenCanvases { get; set; } -#pragma warning restore GHCP001 /// /// Creates a shallow clone of this instance. @@ -3767,10 +3807,8 @@ public sealed class SystemMessageTransformRpcResponse [JsonSerializable(typeof(object))] [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(string[]))] -#pragma warning disable GHCP001 [JsonSerializable(typeof(CanvasDeclaration))] [JsonSerializable(typeof(CanvasProviderOpenResult))] [JsonSerializable(typeof(CanvasHostContext))] [JsonSerializable(typeof(ExtensionInfo))] -#pragma warning restore GHCP001 internal partial class TypesJsonContext : JsonSerializerContext; diff --git a/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs new file mode 100644 index 000000000..3f869a437 --- /dev/null +++ b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs @@ -0,0 +1,287 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Collections.Concurrent; +using System.Net; +using System.Net.Http; +using GitHub.Copilot.Test.Harness; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.Test.E2E; + +/// +/// End-to-end coverage for the experimental BYOK bearer-token-provider surface +/// (GetBearerToken on a provider config). The callback stays entirely on +/// the SDK/client side: the SDK strips it from the wire config, sets the +/// hasBearerTokenProvider flag, and the runtime calls back over the +/// session-scoped providerToken.getToken RPC before each outbound model +/// request, applying the returned token as the Authorization header. +/// +/// +/// +/// These tests mirror the Node SDK's byok_bearer_token_provider.e2e.test.ts. +/// Rather than standing up a real HTTP listener, each test installs a +/// that intercepts the runtime's outbound +/// model request in-process, captures the Authorization header, and +/// returns a synthetic response — so nothing touches the network and there is no +/// CAPI proxy acting as the inference endpoint. They validate, against a real +/// runtime: +/// +/// +/// the callback's token reaches the model request as Authorization: Bearer <token>; +/// the runtime re-acquires a token per request (no runtime-side caching); +/// per-provider dispatch routes each provider's turn to its own callback, +/// and the resulting token reaches that provider's endpoint. +/// +/// +public class ByokBearerTokenProviderE2ETests(E2ETestFixture fixture, ITestOutputHelper output) + : E2ETestBase(fixture, "byok_bearer_token_provider", output) +{ + // Fake BYOK provider hosts. These are never actually dialed: the request + // handler fully answers any request aimed at a `.invalid` host, so they only + // need to be syntactically valid, non-resolving URLs. Distinct hosts let the + // per-provider test assert routing by host. + private const string PrimaryHost = "byok-endpoint.invalid"; + private const string PrimaryBaseUrl = $"https://{PrimaryHost}/v1"; + private const string RedHost = "byok-red.invalid"; + private const string RedBaseUrl = $"https://{RedHost}/v1"; + private const string BlueHost = "byok-blue.invalid"; + private const string BlueBaseUrl = $"https://{BlueHost}/v1"; + + private CopilotClient CreateClientWith(CapturingRequestHandler handler) => + Ctx.CreateClient(options: new CopilotClientOptions + { + Connection = RuntimeConnection.ForStdio(), + RequestHandler = handler, + }); + + /// + /// Drives one BYOK turn against the given providers/models. The capturing + /// handler 404s the BYOK request, which errors the turn after the runtime has + /// already applied the (token-bearing) Authorization header — which is + /// all these tests assert on. The resulting error is swallowed. + /// + private static async Task RunTurnAsync( + CopilotClient client, + IList providers, + IList models, + string selectionId, + string prompt) + { + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Model = selectionId, + Providers = providers, + Models = models, + }); + try + { + await session.SendAndWaitAsync(new MessageOptions { Prompt = prompt }); + } + catch + { + // The handler always 404s the BYOK endpoint, so the turn errors after + // the token-bearing request was already captured. Expected. + } + finally + { + await session.DisposeAsync(); + } + } + + [Fact] + public async Task Applies_The_Callbacks_Token_As_The_Authorization_Header() + { + const string sentinel = "sentinel-bearer-token-abc123"; + var calls = 0; + + var handler = new CapturingRequestHandler(); + await using var client = CreateClientWith(handler); + await client.StartAsync(); + + var providers = new List + { + new() + { + Name = "mi", + Type = "openai", + WireApi = "completions", + BaseUrl = PrimaryBaseUrl, + GetBearerToken = _ => + { + Interlocked.Increment(ref calls); + return Task.FromResult(sentinel); + }, + }, + }; + var models = new List + { + new() { Id = "default", Provider = "mi", WireModel = "byok-gpt-4o" }, + }; + + await RunTurnAsync(client, providers, models, "mi/default", "What is 5+5?"); + + // The runtime acquired a token via the callback and applied it verbatim as + // the bearer credential on the outbound model request. + Assert.Contains($"Bearer {sentinel}", handler.AuthHeaders()); + Assert.True(calls >= 1, "Expected the bearer-token callback to be invoked at least once."); + } + + [Fact] + public async Task Re_Acquires_A_Fresh_Token_For_Each_Request() + { + var calls = 0; + + var handler = new CapturingRequestHandler(); + await using var client = CreateClientWith(handler); + await client.StartAsync(); + + var providers = new List + { + new() + { + Name = "mi", + Type = "openai", + WireApi = "completions", + BaseUrl = PrimaryBaseUrl, + // A distinct token per acquisition proves the runtime re-invokes + // the callback per request rather than caching a previous token. + GetBearerToken = _ => + { + var n = Interlocked.Increment(ref calls); + return Task.FromResult($"rotating-token-{n}"); + }, + }, + }; + var models = new List + { + new() { Id = "default", Provider = "mi", WireModel = "byok-gpt-4o" }, + }; + + await RunTurnAsync(client, providers, models, "mi/default", "What is 1+1?"); + await RunTurnAsync(client, providers, models, "mi/default", "What is 2+2?"); + + // Each outbound request carries a freshly-acquired, distinct token. + var auths = handler.AuthHeaders(); + Assert.True(auths.Count >= 2, $"Expected at least 2 captured Authorization headers, saw {auths.Count}."); + Assert.Matches(@"^Bearer rotating-token-\d+$", auths[0]); + Assert.Matches(@"^Bearer rotating-token-\d+$", auths[1]); + Assert.NotEqual(auths[0], auths[1]); + Assert.True(calls >= 2, "Expected the bearer-token callback to be invoked at least twice."); + } + + [Fact] + public async Task Dispatches_Token_Acquisition_Per_Provider() + { + var tokenByProvider = new Dictionary + { + ["red"] = "token-for-red", + ["blue"] = "token-for-blue", + }; + var acquiredFor = new ConcurrentBag(); + + Func> MakeCallback(string providerName) => + args => + { + // The runtime forwards the requesting provider's name so the client + // can dispatch to the right credential. + Assert.Equal(providerName, args.ProviderName); + acquiredFor.Add(providerName); + return Task.FromResult(tokenByProvider[providerName]); + }; + + var handler = new CapturingRequestHandler(); + await using var client = CreateClientWith(handler); + await client.StartAsync(); + + var providers = new List + { + new() + { + Name = "red", + Type = "openai", + WireApi = "completions", + BaseUrl = RedBaseUrl, + GetBearerToken = MakeCallback("red"), + }, + new() + { + Name = "blue", + Type = "openai", + WireApi = "completions", + BaseUrl = BlueBaseUrl, + GetBearerToken = MakeCallback("blue"), + }, + }; + var models = new List + { + new() { Id = "default", Provider = "red", WireModel = "byok-gpt-4o" }, + new() { Id = "default", Provider = "blue", WireModel = "byok-gpt-4o" }, + }; + + await RunTurnAsync(client, providers, models, "red/default", "What is 3+3?"); + await RunTurnAsync(client, providers, models, "blue/default", "What is 4+4?"); + + // Each provider's turn was authenticated with its own token AND that token + // was delivered to that provider's endpoint, proving per-provider dispatch + // (not a single session-global credential). + Assert.Equal($"Bearer {tokenByProvider["red"]}", handler.AuthHeaderForHost(RedHost)); + Assert.Equal($"Bearer {tokenByProvider["blue"]}", handler.AuthHeaderForHost(BlueHost)); + Assert.Contains("red", acquiredFor); + Assert.Contains("blue", acquiredFor); + } +} + +/// +/// A used in place of a real HTTP listener. +/// The runtime invokes for every model-layer HTTP +/// request. Requests aimed at a fake BYOK host (*.invalid) are captured — +/// recording the Authorization header the runtime applied after calling +/// the provider's GetBearerToken callback over the session-scoped +/// providerToken.getToken RPC — and answered with a synthetic 404 +/// (a non-retryable status, so each outbound model request yields exactly one +/// capture). Every other request (CAPI bootstrap: model catalog, policy, …) is +/// served a synthetic well-formed response so the bootstrap never touches the +/// network. +/// +internal sealed class CapturingRequestHandler : CopilotRequestHandler +{ + private readonly ConcurrentQueue _captures = new(); + + protected override Task SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx) + { + var uri = request.RequestUri!; + if (uri.Host.EndsWith(".invalid", StringComparison.Ordinal)) + { + _captures.Enqueue(new CapturedRequest( + uri.Host, + request.Headers.TryGetValues("Authorization", out var values) + ? string.Join(", ", values) + : null)); + + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent( + "{\"error\":{\"message\":\"fake byok endpoint\"}}", + System.Text.Encoding.UTF8, + "application/json"), + }); + } + + // CAPI bootstrap (model catalog, policy, …) — answered off-network. + return Task.FromResult(RecordingRequestHandler.BuildNonInferenceResponse(uri.ToString())); + } + + /// The Authorization headers captured across BYOK requests, in arrival order. + public IReadOnlyList AuthHeaders() => + [.. _captures.Select(c => c.Authorization).Where(v => v is not null).Cast()]; + + /// The Authorization header captured for requests aimed at , if any. + public string? AuthHeaderForHost(string host) => + _captures.FirstOrDefault(c => string.Equals(c.Host, host, StringComparison.Ordinal))?.Authorization; + + private sealed record CapturedRequest(string Host, string? Authorization); +} diff --git a/go/client.go b/go/client.go index 5dc44e027..8387d7e3e 100644 --- a/go/client.go +++ b/go/client.go @@ -53,6 +53,30 @@ import ( "github.com/github/copilot-sdk/go/rpc" ) +// defaultBearerTokenProviderName is the implicit provider name for the singular, +// whole-session [ProviderConfig]. Named providers are keyed by their own Name. +const defaultBearerTokenProviderName = "default" + +// collectBearerTokenProviders gathers the per-provider [GetBearerToken] callbacks +// from the singular provider and any named providers, keyed by provider name. The +// singular provider uses the implicit name "default"; named providers use their +// own Name. Returns nil when no callbacks are configured. +func collectBearerTokenProviders(provider *ProviderConfig, providers []NamedProviderConfig) map[string]GetBearerToken { + callbacks := make(map[string]GetBearerToken) + if provider != nil && provider.GetBearerToken != nil { + callbacks[defaultBearerTokenProviderName] = provider.GetBearerToken + } + for i := range providers { + if providers[i].GetBearerToken != nil { + callbacks[providers[i].Name] = providers[i].GetBearerToken + } + } + if len(callbacks) == 0 { + return nil + } + return callbacks +} + func validateSessionFSConfig(config *SessionFSConfig) error { if config == nil { return nil @@ -808,6 +832,9 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses if config.CanvasHandler != nil { s.registerCanvasHandler(config.CanvasHandler) } + if bearerTokenProviders := collectBearerTokenProviders(config.Provider, config.Providers); bearerTokenProviders != nil { + s.registerBearerTokenProviders(bearerTokenProviders) + } c.sessionsMux.Lock() c.sessions[sessionID] = s @@ -1104,6 +1131,9 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, if config.CanvasHandler != nil { session.registerCanvasHandler(config.CanvasHandler) } + if bearerTokenProviders := collectBearerTokenProviders(config.Provider, config.Providers); bearerTokenProviders != nil { + session.registerBearerTokenProviders(bearerTokenProviders) + } c.sessionsMux.Lock() c.sessions[sessionID] = session diff --git a/go/internal/e2e/byok_bearer_token_provider_e2e_test.go b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go new file mode 100644 index 000000000..6a6e5cbc2 --- /dev/null +++ b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go @@ -0,0 +1,284 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "net/http" + "strconv" + "strings" + "sync" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// Fake BYOK provider base URLs. These hosts are never actually dialed: the +// capturing RoundTripper fully answers any request aimed at a `.invalid` host, +// so they only need to be syntactically valid, non-resolving URLs. Distinct +// hosts let the per-provider test assert routing by host. +const ( + byokPrimaryHost = "byok-endpoint.invalid" + byokPrimaryBaseURL = "https://" + byokPrimaryHost + "/v1" + byokRedHost = "byok-red.invalid" + byokRedBaseURL = "https://" + byokRedHost + "/v1" + byokBlueHost = "byok-blue.invalid" + byokBlueBaseURL = "https://" + byokBlueHost + "/v1" +) + +// capturedBYOKRequest records the host and Authorization header of one outbound +// HTTP request the runtime aimed at a fake BYOK provider endpoint. +type capturedBYOKRequest struct { + host string + authorization string +} + +// byokCapturingRoundTripper stands in for a real HTTP upstream. It records the +// `Authorization` header the runtime applied (after calling the provider's +// GetBearerToken callback over the session-scoped `providerToken.getToken` RPC) +// for every request aimed at a fake `.invalid` BYOK host, answering them with a +// synthetic 404 (a non-retryable status, so each outbound model request yields +// exactly one capture). Every other request (CAPI bootstrap: model catalog, +// policy, session) is fabricated locally so the test never touches the network. +type byokCapturingRoundTripper struct { + mu sync.Mutex + captures []capturedBYOKRequest +} + +func (rt *byokCapturingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if strings.HasSuffix(req.URL.Hostname(), ".invalid") { + rt.mu.Lock() + rt.captures = append(rt.captures, capturedBYOKRequest{ + host: req.URL.Host, + authorization: req.Header.Get("Authorization"), + }) + rt.mu.Unlock() + if req.Body != nil { + _ = req.Body.Close() + } + return buildJSONResponse(http.StatusNotFound, `{"error":{"message":"fake byok endpoint"}}`), nil + } + return buildNonInferenceResponse(req.URL.String()), nil +} + +// authHeaders returns the captured Authorization headers in arrival order. +func (rt *byokCapturingRoundTripper) authHeaders() []string { + rt.mu.Lock() + defer rt.mu.Unlock() + headers := make([]string, 0, len(rt.captures)) + for _, c := range rt.captures { + if c.authorization != "" { + headers = append(headers, c.authorization) + } + } + return headers +} + +// authHeaderForHost returns the Authorization header captured for requests aimed +// at host, if any. +func (rt *byokCapturingRoundTripper) authHeaderForHost(host string) string { + rt.mu.Lock() + defer rt.mu.Unlock() + for _, c := range rt.captures { + if c.host == host { + return c.authorization + } + } + return "" +} + +func (rt *byokCapturingRoundTripper) reset() { + rt.mu.Lock() + defer rt.mu.Unlock() + rt.captures = nil +} + +// TestBYOKBearerTokenProvider is end-to-end coverage for the experimental BYOK +// bearer-token-provider surface (GetBearerToken on a provider config). The +// callback stays entirely on the SDK/client side: the SDK strips it from the +// wire config, sets the `hasBearerTokenProvider` flag, and the runtime calls +// back over the session-scoped `providerToken.getToken` RPC before each outbound +// model request, applying the returned token as the `Authorization` header. +// +// Rather than standing up a real HTTP listener, the test installs a capturing +// RoundTripper that intercepts the runtime's outbound model request in-process, +// captures the `Authorization` header, and returns a synthetic response. It +// validates, against a real runtime: +// 1. the callback's token reaches the model request as `Authorization: Bearer `; +// 2. the runtime re-acquires a token per request (no runtime-side caching); +// 3. per-provider dispatch routes each provider's turn to its own callback, and +// the resulting token reaches that provider's endpoint. +func TestBYOKBearerTokenProvider(t *testing.T) { + ctx := testharness.NewTestContext(t) + rt := &byokCapturingRoundTripper{} + handler := &copilot.CopilotRequestHandler{Transport: rt} + + client := newCopilotRequestClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // runTurn drives one BYOK turn; the synthetic 404 errors the turn after the + // runtime has already sent the token-bearing request, which is all the test + // asserts on, so the resulting error is expected and swallowed. + runTurn := func(providers []copilot.NamedProviderConfig, models []copilot.ProviderModelConfig, selectionID, prompt string) { + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Model: selectionID, + Providers: providers, + Models: models, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + _, _ = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: prompt}) + _ = session.Disconnect() + } + + t.Run("applies the callback's token as the Authorization header", func(t *testing.T) { + rt.reset() + const sentinel = "sentinel-bearer-token-abc123" + var mu sync.Mutex + calls := 0 + getBearerToken := func(args copilot.ProviderTokenArgs) (string, error) { + mu.Lock() + calls++ + mu.Unlock() + return sentinel, nil + } + + providers := []copilot.NamedProviderConfig{{ + Name: "mi", + Type: "openai", + WireAPI: "completions", + BaseURL: byokPrimaryBaseURL, + GetBearerToken: getBearerToken, + }} + models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}} + + runTurn(providers, models, "mi/default", "What is 5+5?") + + // The runtime acquired a token via the callback and applied it verbatim + // as the bearer credential on the outbound model request. + if !containsString(rt.authHeaders(), "Bearer "+sentinel) { + t.Fatalf("Expected captured Authorization headers to contain %q, got %v", "Bearer "+sentinel, rt.authHeaders()) + } + mu.Lock() + gotCalls := calls + mu.Unlock() + if gotCalls < 1 { + t.Fatalf("Expected the callback to be invoked at least once, got %d", gotCalls) + } + }) + + t.Run("re-acquires a fresh token for each request (no runtime caching)", func(t *testing.T) { + rt.reset() + var mu sync.Mutex + calls := 0 + getBearerToken := func(args copilot.ProviderTokenArgs) (string, error) { + mu.Lock() + calls++ + token := "rotating-token-" + strconv.Itoa(calls) + mu.Unlock() + // A distinct token per acquisition proves the runtime re-invokes the + // callback per request rather than caching a previous token. + return token, nil + } + + providers := []copilot.NamedProviderConfig{{ + Name: "mi", + Type: "openai", + WireAPI: "completions", + BaseURL: byokPrimaryBaseURL, + GetBearerToken: getBearerToken, + }} + models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}} + + runTurn(providers, models, "mi/default", "What is 1+1?") + runTurn(providers, models, "mi/default", "What is 2+2?") + + // Each outbound request carries a freshly-acquired, distinct token. + auths := rt.authHeaders() + if len(auths) < 2 { + t.Fatalf("Expected at least 2 captured Authorization headers, got %d: %v", len(auths), auths) + } + if !strings.HasPrefix(auths[0], "Bearer rotating-token-") || !strings.HasPrefix(auths[1], "Bearer rotating-token-") { + t.Fatalf("Expected rotating-token bearer headers, got %v", auths) + } + if auths[0] == auths[1] { + t.Fatalf("Expected distinct tokens per request, both were %q", auths[0]) + } + mu.Lock() + gotCalls := calls + mu.Unlock() + if gotCalls < 2 { + t.Fatalf("Expected the callback to be invoked at least twice, got %d", gotCalls) + } + }) + + t.Run("dispatches token acquisition per provider", func(t *testing.T) { + rt.reset() + tokenByProvider := map[string]string{ + "red": "token-for-red", + "blue": "token-for-blue", + } + var mu sync.Mutex + var acquiredFor []string + makeCallback := func(providerName string) copilot.GetBearerToken { + return func(args copilot.ProviderTokenArgs) (string, error) { + // The runtime forwards the requesting provider's name so the + // client can dispatch to the right credential. + if args.ProviderName != providerName { + t.Errorf("Expected providerName %q, got %q", providerName, args.ProviderName) + } + mu.Lock() + acquiredFor = append(acquiredFor, providerName) + mu.Unlock() + return tokenByProvider[providerName], nil + } + } + + providers := []copilot.NamedProviderConfig{ + { + Name: "red", + Type: "openai", + WireAPI: "completions", + BaseURL: byokRedBaseURL, + GetBearerToken: makeCallback("red"), + }, + { + Name: "blue", + Type: "openai", + WireAPI: "completions", + BaseURL: byokBlueBaseURL, + GetBearerToken: makeCallback("blue"), + }, + } + models := []copilot.ProviderModelConfig{ + {ID: "default", Provider: "red", WireModel: "byok-gpt-4o"}, + {ID: "default", Provider: "blue", WireModel: "byok-gpt-4o"}, + } + + runTurn(providers, models, "red/default", "What is 3+3?") + runTurn(providers, models, "blue/default", "What is 4+4?") + + // Each provider's turn was authenticated with its own token AND that + // token was delivered to that provider's endpoint, proving per-provider + // dispatch (not a single session-global credential). + if got := rt.authHeaderForHost(byokRedHost); got != "Bearer "+tokenByProvider["red"] { + t.Fatalf("Expected red host to receive %q, got %q", "Bearer "+tokenByProvider["red"], got) + } + if got := rt.authHeaderForHost(byokBlueHost); got != "Bearer "+tokenByProvider["blue"] { + t.Fatalf("Expected blue host to receive %q, got %q", "Bearer "+tokenByProvider["blue"], got) + } + mu.Lock() + got := append([]string(nil), acquiredFor...) + mu.Unlock() + if !containsString(got, "red") || !containsString(got, "blue") { + t.Fatalf("Expected both providers to acquire tokens, got %v", got) + } + }) +} diff --git a/go/rpc/zrpc.go b/go/rpc/zrpc.go index a105a712c..6d37ef9af 100644 --- a/go/rpc/zrpc.go +++ b/go/rpc/zrpc.go @@ -3434,6 +3434,12 @@ type NamedProviderConfig struct { // Bearer token for authentication. Sets the Authorization header directly. Takes precedence // over apiKey when both are set. BearerToken *string `json:"bearerToken,omitempty"` + // When true, the SDK client supplies bearer tokens on demand: the runtime calls the + // client-session `providerToken.getToken` callback before each request and uses the + // returned token as the Authorization header. The token-acquiring function itself stays on + // the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive + // with `apiKey`/`bearerToken`. + HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` // Custom HTTP headers to include in all outbound requests to the provider. Headers map[string]string `json:"headers,omitzero"` // Stable identifier referenced by BYOK model definitions. Must not contain '/'. @@ -4949,6 +4955,12 @@ type ProviderConfig struct { // Bearer token for authentication. Sets the Authorization header directly. Takes precedence // over apiKey when both are set. BearerToken *string `json:"bearerToken,omitempty"` + // When true, the SDK client supplies bearer tokens on demand: the runtime calls the + // client-session `providerToken.getToken` callback before each request and uses the + // returned token as the Authorization header. The token-acquiring function itself stays on + // the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive + // with `apiKey`/`bearerToken`. + HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` // Custom HTTP headers to include in all outbound requests to the provider. Headers map[string]string `json:"headers,omitzero"` // Maximum context window tokens for the model. @@ -5050,6 +5062,29 @@ type ProviderSessionToken struct { Token string `json:"token"` } +// Asks the SDK client to acquire a bearer token for a BYOK provider whose config set +// `hasBearerTokenProvider: true`. Issued by the runtime before each outbound model request; +// the runtime does no caching, so this is sent once per request. +// Experimental: ProviderTokenAcquireRequest is part of an experimental API and may change +// or be removed. +type ProviderTokenAcquireRequest struct { + // Name of the BYOK provider needing a token. For the legacy whole-session `provider` this + // is the implicit provider name; for named providers it is `NamedProviderConfig.name`. + ProviderName string `json:"providerName"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +// A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as +// `Authorization: Bearer ` on the outbound request and does no caching; the SDK +// consumer owns token caching and refresh. +// Experimental: ProviderTokenAcquireResult is part of an experimental API and may change or +// be removed. +type ProviderTokenAcquireResult struct { + // The bearer token value (without the `Bearer ` prefix). + Token string `json:"token"` +} + // Schema for the `PushAttachment` type. // Experimental: PushAttachment is part of an experimental API and may change or be removed. type PushAttachment interface { @@ -16857,6 +16892,28 @@ type CanvasHandler interface { Open(request *CanvasProviderOpenRequest) (*CanvasProviderOpenResult, error) } +// Experimental: ProviderTokenHandler contains experimental APIs that may change or be +// removed. +type ProviderTokenHandler interface { + // GetToken asks the SDK client to get a bearer token for a BYOK provider whose config set + // `hasBearerTokenProvider: true`. Session-scoped: the runtime calls it back on the + // connection that created the session, passing the provider name, and uses the returned + // token as the Authorization header for the outbound model request. The runtime does no + // caching — it calls this once per outbound request; the SDK consumer owns token + // acquisition, caching, and refresh. + // + // RPC method: providerToken.getToken. + // + // Parameters: Asks the SDK client to acquire a bearer token for a BYOK provider whose + // config set `hasBearerTokenProvider: true`. Issued by the runtime before each outbound + // model request; the runtime does no caching, so this is sent once per request. + // + // Returns: A bearer token supplied by the SDK client for a BYOK provider. The runtime sets + // it as `Authorization: Bearer ` on the outbound request and does no caching; the + // SDK consumer owns token caching and refresh. + GetToken(request *ProviderTokenAcquireRequest) (*ProviderTokenAcquireResult, error) +} + // Experimental: SessionFSHandler contains experimental APIs that may change or be removed. type SessionFSHandler interface { // AppendFile appends content to a file in the client-provided session filesystem. @@ -16975,8 +17032,9 @@ type SessionFSHandler interface { // ClientSessionAPIHandlers provides all client session API handler groups for a session. type ClientSessionAPIHandlers struct { - Canvas CanvasHandler - SessionFS SessionFSHandler + Canvas CanvasHandler + ProviderToken ProviderTokenHandler + SessionFS SessionFSHandler } func clientSessionHandlerError(err error) *jsonrpc2.Error { @@ -17050,6 +17108,25 @@ func RegisterClientSessionAPIHandlers(client *jsonrpc2.Client, getHandlers func( } return raw, nil }) + client.SetRequestHandler("providerToken.getToken", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request ProviderTokenAcquireRequest + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.ProviderToken == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No providerToken handler registered for session: %s", request.SessionID)} + } + result, err := handlers.ProviderToken.GetToken(&request) + if err != nil { + return nil, clientSessionHandlerError(err) + } + raw, err := json.Marshal(result) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)} + } + return raw, nil + }) client.SetRequestHandler("sessionFs.appendFile", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { var request SessionFSAppendFileRequest if err := json.Unmarshal(params, &request); err != nil { diff --git a/go/session.go b/go/session.go index acd698677..2f7d23d30 100644 --- a/go/session.go +++ b/go/session.go @@ -77,6 +77,8 @@ type Session struct { elicitationMu sync.RWMutex canvasHandler CanvasHandler canvasMu sync.RWMutex + bearerTokenProviders map[string]GetBearerToken + bearerTokenMu sync.RWMutex openCanvases []rpc.OpenCanvasInstance openCanvasesMu sync.RWMutex capabilities SessionCapabilities @@ -183,6 +185,66 @@ func (s *Session) getCanvasHandler() CanvasHandler { return s.canvasHandler } +// registerBearerTokenProviders installs per-provider [GetBearerToken] callbacks +// for BYOK providers configured with managed-identity / on-demand bearer-token +// auth, keyed by provider name. +// +// The runtime never receives the callback itself; the SDK strips it from the +// provider config and instead sends `hasBearerTokenProvider: true`. When the +// runtime needs a token it issues a session-scoped `providerToken.getToken` +// request, which the session's provider-token adapter routes to the matching +// per-provider callback. +func (s *Session) registerBearerTokenProviders(providers map[string]GetBearerToken) { + s.bearerTokenMu.Lock() + defer s.bearerTokenMu.Unlock() + s.bearerTokenProviders = make(map[string]GetBearerToken, len(providers)) + for name, callback := range providers { + if callback == nil { + continue + } + s.bearerTokenProviders[name] = callback + } +} + +func (s *Session) getBearerTokenProvider(providerName string) GetBearerToken { + s.bearerTokenMu.RLock() + defer s.bearerTokenMu.RUnlock() + return s.bearerTokenProviders[providerName] +} + +type providerTokenClientSessionAdapter struct { + session *Session +} + +func newProviderTokenClientSessionAdapter(session *Session) rpc.ProviderTokenHandler { + return &providerTokenClientSessionAdapter{session: session} +} + +func (a *providerTokenClientSessionAdapter) GetToken(request *rpc.ProviderTokenAcquireRequest) (*rpc.ProviderTokenAcquireResult, error) { + if request == nil { + return nil, providerTokenJSONRPCError("missing provider token request") + } + if a.session == nil || a.session.SessionID != request.SessionID { + return nil, providerTokenJSONRPCError(fmt.Sprintf("unknown session %s", request.SessionID)) + } + callback := a.session.getBearerTokenProvider(request.ProviderName) + if callback == nil { + return nil, providerTokenJSONRPCError(fmt.Sprintf("No bearer-token provider registered for provider %q", request.ProviderName)) + } + token, err := callback(ProviderTokenArgs{ProviderName: request.ProviderName}) + if err != nil { + return nil, providerTokenJSONRPCError(err.Error()) + } + return &rpc.ProviderTokenAcquireResult{Token: token}, nil +} + +func providerTokenJSONRPCError(message string) *jsonrpc2.Error { + return &jsonrpc2.Error{ + Code: -32603, + Message: message, + } +} + type canvasClientSessionAdapter struct { session *Session } @@ -309,6 +371,7 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) RPC: rpc.NewSessionRPC(client, sessionID), } s.clientSessionAPIs.Canvas = newCanvasClientSessionAdapter(s) + s.clientSessionAPIs.ProviderToken = newProviderTokenClientSessionAdapter(s) go s.processEvents() return s } diff --git a/go/types.go b/go/types.go index 3ff6d0f9c..4d0db2b19 100644 --- a/go/types.go +++ b/go/types.go @@ -1536,6 +1536,37 @@ type ResumeSessionConfig struct { // ExtensionInfo identifies the stable extension providing this session's canvases. ExtensionInfo *ExtensionInfo } + +// ProviderTokenArgs carries the context passed to a [GetBearerToken] callback +// when the runtime needs a fresh bearer token for a BYOK provider. +// +// Experimental: ProviderTokenArgs is part of the experimental managed-identity / +// bearer-token-provider surface and may change or be removed in future SDK or CLI +// releases. +type ProviderTokenArgs struct { + // ProviderName is the name of the BYOK provider needing a token. For the + // singular, whole-session [ProviderConfig] this is the implicit provider name + // ("default"); for [NamedProviderConfig] entries it is + // [NamedProviderConfig.Name]. + // + // The callback closes over its own token scope/audience; the runtime is + // provider-agnostic and forwards only the provider name. + ProviderName string +} + +// GetBearerToken is a per-provider callback that resolves a bearer token on +// demand, returning the raw token string (without the "Bearer " prefix). The +// Copilot SDK itself takes no Azure dependency: the consumer supplies this +// callback backed by their own identity library (for example azidentity's +// DefaultAzureCredential.GetToken), and the runtime calls it once before each +// outbound model request. The runtime does no caching of its own, so the callback +// (or the identity library it wraps) owns token caching and refresh. +// +// Experimental: GetBearerToken is part of the experimental managed-identity / +// bearer-token-provider surface and may change or be removed in future SDK or CLI +// releases. +type GetBearerToken func(args ProviderTokenArgs) (string, error) + type ProviderConfig struct { // Type is the provider type: "openai", "azure", or "anthropic". Defaults to "openai". Type string `json:"type,omitempty"` @@ -1576,6 +1607,33 @@ type ProviderConfig struct { // tokens. When hit, the model stops generating and returns a truncated // response. MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + // GetBearerToken resolves a bearer token on demand for this provider + // (managed-identity / on-demand auth). When set, the SDK strips the callback + // from the wire config and instead sends `hasBearerTokenProvider: true`; the + // runtime calls back over the session-scoped `providerToken.getToken` RPC + // before each outbound model request and applies the returned token as the + // Authorization header. Never serialized. + // + // Experimental: part of the experimental managed-identity / bearer-token-provider + // surface and may change or be removed in future SDK or CLI releases. + GetBearerToken GetBearerToken `json:"-"` +} + +// MarshalJSON serializes the provider config, deriving the wire-only +// `hasBearerTokenProvider` flag from the presence of [ProviderConfig.GetBearerToken]. +// The non-serializable callback never crosses the RPC boundary; the runtime only +// learns that a token provider exists and forwards the provider name back when it +// needs a token. +func (p ProviderConfig) MarshalJSON() ([]byte, error) { + type wire ProviderConfig + aux := struct { + wire + HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` + }{wire: wire(p)} + if p.GetBearerToken != nil { + aux.HasBearerTokenProvider = Bool(true) + } + return json.Marshal(aux) } // CapiSessionOptions configures provider-scoped Copilot API (CAPI) session behavior. @@ -1630,6 +1688,33 @@ type NamedProviderConfig struct { Azure *AzureProviderOptions `json:"azure,omitempty"` // Headers are custom HTTP headers included in all outbound provider requests. Headers map[string]string `json:"headers,omitempty"` + // GetBearerToken resolves a bearer token on demand for this provider + // (managed-identity / on-demand auth). When set, the SDK strips the callback + // from the wire config and instead sends `hasBearerTokenProvider: true`; the + // runtime calls back over the session-scoped `providerToken.getToken` RPC + // before each outbound model request and applies the returned token as the + // Authorization header. Never serialized. + // + // Experimental: part of the experimental managed-identity / bearer-token-provider + // surface and may change or be removed in future SDK or CLI releases. + GetBearerToken GetBearerToken `json:"-"` +} + +// MarshalJSON serializes the named provider config, deriving the wire-only +// `hasBearerTokenProvider` flag from the presence of +// [NamedProviderConfig.GetBearerToken]. The non-serializable callback never +// crosses the RPC boundary; the runtime only learns that a token provider exists +// and forwards the provider name back when it needs a token. +func (p NamedProviderConfig) MarshalJSON() ([]byte, error) { + type wire NamedProviderConfig + aux := struct { + wire + HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` + }{wire: wire(p)} + if p.GetBearerToken != nil { + aux.HasBearerTokenProvider = Bool(true) + } + return json.Marshal(aux) } // ProviderModelConfig is a BYOK model definition that references a diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index adfeac013..df4491bf1 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -75,6 +75,7 @@ import com.github.copilot.rpc.ExitPlanModeRequest; import com.github.copilot.rpc.ExitPlanModeResult; import com.github.copilot.rpc.ElicitationSchema; +import com.github.copilot.rpc.GetBearerToken; import com.github.copilot.rpc.GetMessagesResponse; import com.github.copilot.rpc.HookInvocation; import com.github.copilot.rpc.InputOptions; @@ -169,6 +170,7 @@ public final class CopilotSession implements AutoCloseable { private final Set> eventHandlers = ConcurrentHashMap.newKeySet(); private final Map toolHandlers = new ConcurrentHashMap<>(); private final Map commandHandlers = new ConcurrentHashMap<>(); + private final Map bearerTokenProviders = new ConcurrentHashMap<>(); private final AtomicReference permissionHandler = new AtomicReference<>(); private final AtomicReference userInputHandler = new AtomicReference<>(); private final AtomicReference elicitationHandler = new AtomicReference<>(); @@ -1348,6 +1350,33 @@ void registerElicitationHandler(ElicitationHandler handler) { elicitationHandler.set(handler); } + /** + * Registers bearer-token provider callbacks for this session. + *

+ * Called internally when creating or resuming a session with BYOK providers + * that use managed-identity token callbacks. + * + * @param providers + * the callbacks keyed by provider name + */ + void registerBearerTokenProviders(Map providers) { + bearerTokenProviders.clear(); + if (providers != null) { + bearerTokenProviders.putAll(providers); + } + } + + /** + * Gets the bearer-token provider callback for the given provider name. + * + * @param providerName + * the provider name + * @return the registered callback, or {@code null} if none is registered + */ + GetBearerToken getBearerTokenProvider(String providerName) { + return bearerTokenProviders.get(providerName); + } + /** * Registers an exit-plan-mode handler for this session. *

diff --git a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java index 391f270db..b62e8c582 100644 --- a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java +++ b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java @@ -19,6 +19,8 @@ import com.github.copilot.generated.SessionEvent; import com.github.copilot.rpc.AutoModeSwitchRequest; import com.github.copilot.rpc.ExitPlanModeRequest; +import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.ProviderTokenArgs; import com.github.copilot.rpc.PermissionRequestResult; import com.github.copilot.rpc.PermissionRequestResultKind; import com.github.copilot.rpc.SessionLifecycleEvent; @@ -88,6 +90,8 @@ void registerHandlers(JsonRpcClient rpc) { rpc.registerMethodHandler("hooks.invoke", (requestId, params) -> handleHooksInvoke(rpc, requestId, params)); rpc.registerMethodHandler("systemMessage.transform", (requestId, params) -> handleSystemMessageTransform(rpc, requestId, params)); + rpc.registerMethodHandler("providerToken.getToken", + (requestId, params) -> handleProviderTokenGetToken(rpc, requestId, params)); } private void handleSessionEvent(JsonNode params) { @@ -300,6 +304,68 @@ private void handleUserInputRequest(JsonRpcClient rpc, String requestId, JsonNod }); } + private void handleProviderTokenGetToken(JsonRpcClient rpc, String requestId, JsonNode params) { + LOG.fine("Received providerToken.getToken: " + params); + runAsync(() -> { + final long requestIdLong = parseRequestId(requestId, "providerToken.getToken"); + if (requestIdLong == -1) { + return; + } + try { + String sessionId = params.get("sessionId").asText(); + String providerName = params.get("providerName").asText(); + + CopilotSession session = sessions.get(sessionId); + if (session == null) { + rpc.sendErrorResponse(requestIdLong, -32602, "Unknown session " + sessionId); + return; + } + + GetBearerToken provider = session.getBearerTokenProvider(providerName); + if (provider == null) { + rpc.sendErrorResponse(requestIdLong, -32603, + "No bearer-token provider registered for provider " + providerName); + return; + } + + CompletableFuture tokenFuture = provider.getToken(new ProviderTokenArgs(providerName)); + if (tokenFuture == null) { + rpc.sendErrorResponse(requestIdLong, -32603, + "Bearer-token provider returned null future for provider " + providerName); + return; + } + + tokenFuture.thenAccept(token -> { + try { + if (token == null) { + rpc.sendErrorResponse(requestIdLong, -32603, + "Bearer-token provider returned null token for provider " + providerName); + return; + } + rpc.sendResponse(requestIdLong, Map.of("token", token)); + } catch (IOException e) { + LOG.log(Level.SEVERE, "Error sending provider token response", e); + } + }).exceptionally(ex -> { + LOG.log(Level.WARNING, "Bearer-token provider exception", ex); + try { + rpc.sendErrorResponse(requestIdLong, -32603, "Bearer-token provider error: " + ex.getMessage()); + } catch (IOException e) { + LOG.log(Level.SEVERE, "Error sending provider token error", e); + } + return null; + }); + } catch (Exception e) { + LOG.log(Level.SEVERE, "Error handling providerToken.getToken", e); + try { + rpc.sendErrorResponse(requestIdLong, -32603, "Provider token handler error: " + e.getMessage()); + } catch (IOException ioException) { + LOG.log(Level.SEVERE, "Error sending provider token handler error", ioException); + } + } + }); + } + private void handleExitPlanModeRequest(JsonRpcClient rpc, String requestId, JsonNode params) { runAsync(() -> { final long requestIdLong = parseRequestId(requestId, "exitPlanMode.request"); diff --git a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java index c26548a2f..073628945 100644 --- a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java +++ b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java @@ -5,11 +5,15 @@ package com.github.copilot; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.function.Function; import com.github.copilot.rpc.CreateSessionRequest; +import com.github.copilot.rpc.ProviderConfig; +import com.github.copilot.rpc.NamedProviderConfig; +import com.github.copilot.rpc.GetBearerToken; import com.github.copilot.rpc.CommandWireDefinition; import com.github.copilot.rpc.ResumeSessionConfig; import com.github.copilot.rpc.ResumeSessionRequest; @@ -329,6 +333,11 @@ static void configureSession(CopilotSession session, SessionConfig config) { if (config.getOnElicitationRequest() != null) { session.registerElicitationHandler(config.getOnElicitationRequest()); } + Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), + config.getProviders()); + if (!bearerTokenProviders.isEmpty()) { + session.registerBearerTokenProviders(bearerTokenProviders); + } if (config.getOnExitPlanMode() != null) { session.registerExitPlanModeHandler(config.getOnExitPlanMode()); } @@ -371,6 +380,11 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) if (config.getOnElicitationRequest() != null) { session.registerElicitationHandler(config.getOnElicitationRequest()); } + Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), + config.getProviders()); + if (!bearerTokenProviders.isEmpty()) { + session.registerBearerTokenProviders(bearerTokenProviders); + } if (config.getOnExitPlanMode() != null) { session.registerExitPlanModeHandler(config.getOnExitPlanMode()); } @@ -381,4 +395,21 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) session.on(config.getOnEvent()); } } + + private static Map collectBearerTokenProviders(ProviderConfig provider, + List providers) { + Map bearerTokenProviders = new HashMap<>(); + if (provider != null && provider.getGetBearerToken() != null) { + bearerTokenProviders.put("default", provider.getGetBearerToken()); + } + if (providers != null) { + for (NamedProviderConfig namedProvider : providers) { + if (namedProvider != null && namedProvider.getName() != null + && namedProvider.getGetBearerToken() != null) { + bearerTokenProviders.put(namedProvider.getName(), namedProvider.getGetBearerToken()); + } + } + } + return bearerTokenProviders; + } } diff --git a/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java b/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java new file mode 100644 index 000000000..27ec7f09c --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java @@ -0,0 +1,40 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.concurrent.CompletableFuture; + +import com.github.copilot.CopilotExperimental; + +/** + * Functional interface for supplying per-provider bearer tokens for BYOK + * provider requests. + *

+ * The callback returns the raw token without a {@code Bearer } prefix. The SDK + * keeps this callback client-side and the runtime requests a token via the + * session-scoped {@code providerToken.getToken} RPC before each outbound model + * request. + *

+ * Experimental. This managed-identity surface may change or be + * removed in future SDK or CLI releases. + * + * @see ProviderConfig#setGetBearerToken(GetBearerToken) + * @see NamedProviderConfig#setGetBearerToken(GetBearerToken) + * @since 1.0.0 + */ +@CopilotExperimental +@FunctionalInterface +public interface GetBearerToken { + + /** + * Gets a bearer token for the provider identified by {@code args}. + * + * @param args + * the provider token request arguments + * @return a future that completes with the raw token, without a {@code Bearer } + * prefix + */ + CompletableFuture getToken(ProviderTokenArgs args); +} diff --git a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java index dbc157739..2bdf2678f 100644 --- a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java @@ -7,6 +7,7 @@ import java.util.Collections; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; @@ -59,6 +60,9 @@ public class NamedProviderConfig { @JsonProperty("bearerToken") private String bearerToken; + @JsonIgnore + private GetBearerToken getBearerToken; + @JsonProperty("azure") private AzureOptions azure; @@ -212,6 +216,39 @@ public NamedProviderConfig setBearerToken(String bearerToken) { return this; } + /** + * Gets the bearer-token provider callback. + * + * @return the bearer-token provider callback, or {@code null} if not set + */ + public GetBearerToken getGetBearerToken() { + return getBearerToken; + } + + /** + * Sets a callback that supplies bearer tokens for outbound provider requests. + *

+ * Experimental. The callback stays SDK-side and is not + * serialized. Instead, the runtime receives a {@code hasBearerTokenProvider} + * flag and calls back over the session-scoped {@code providerToken.getToken} + * RPC before each model request. Return the raw token without a {@code Bearer } + * prefix. + * + * @param getBearerToken + * the bearer-token provider callback + * @return this config for method chaining + */ + public NamedProviderConfig setGetBearerToken(GetBearerToken getBearerToken) { + this.getBearerToken = getBearerToken; + return this; + } + + @JsonProperty("hasBearerTokenProvider") + @JsonInclude(JsonInclude.Include.NON_NULL) + Boolean hasBearerTokenProviderWireFlag() { + return getBearerToken != null ? Boolean.TRUE : null; + } + /** * Gets the Azure-specific options. * diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java index 8ba492ed9..ae59e7ead 100644 --- a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java @@ -56,6 +56,9 @@ public class ProviderConfig { @JsonProperty("bearerToken") private String bearerToken; + @JsonIgnore + private GetBearerToken getBearerToken; + @JsonProperty("azure") private AzureOptions azure; @@ -222,6 +225,39 @@ public ProviderConfig setBearerToken(String bearerToken) { return this; } + /** + * Gets the bearer-token provider callback. + * + * @return the bearer-token provider callback, or {@code null} if not set + */ + public GetBearerToken getGetBearerToken() { + return getBearerToken; + } + + /** + * Sets a callback that supplies bearer tokens for outbound provider requests. + *

+ * Experimental. The callback stays SDK-side and is not + * serialized. Instead, the runtime receives a {@code hasBearerTokenProvider} + * flag and calls back over the session-scoped {@code providerToken.getToken} + * RPC before each model request. Return the raw token without a {@code Bearer } + * prefix. + * + * @param getBearerToken + * the bearer-token provider callback + * @return this config for method chaining + */ + public ProviderConfig setGetBearerToken(GetBearerToken getBearerToken) { + this.getBearerToken = getBearerToken; + return this; + } + + @JsonProperty("hasBearerTokenProvider") + @JsonInclude(JsonInclude.Include.NON_NULL) + Boolean hasBearerTokenProviderWireFlag() { + return getBearerToken != null ? Boolean.TRUE : null; + } + /** * Gets the Azure-specific options. * diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java new file mode 100644 index 000000000..3866cc0ad --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java @@ -0,0 +1,63 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import com.github.copilot.CopilotExperimental; + +/** + * Arguments passed to a BYOK bearer-token provider callback. + *

+ * Experimental. This managed-identity surface may change or be + * removed in future SDK or CLI releases. + * + * @since 1.0.0 + */ +@CopilotExperimental +public class ProviderTokenArgs { + + private String providerName; + + /** + * Creates an empty argument object. + */ + public ProviderTokenArgs() { + } + + /** + * Creates argument object for the named provider. + * + * @param providerName + * the name of the BYOK provider needing a token; {@code "default"} + * for the singular whole-session provider, otherwise the named + * provider's {@code name} + */ + public ProviderTokenArgs(String providerName) { + this.providerName = providerName; + } + + /** + * Gets the name of the BYOK provider needing a token. + *

+ * The value is {@code "default"} for the singular whole-session provider, + * otherwise the named provider's {@code name}. + * + * @return the provider name + */ + public String getProviderName() { + return providerName; + } + + /** + * Sets the name of the BYOK provider needing a token. + * + * @param providerName + * the provider name + * @return this args instance for method chaining + */ + public ProviderTokenArgs setProviderName(String providerName) { + this.providerName = providerName; + return this; + } +} diff --git a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java new file mode 100644 index 000000000..253ce136c --- /dev/null +++ b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java @@ -0,0 +1,274 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static com.github.copilot.CopilotRequestTestSupport.buildNonInferenceResponse; +import static com.github.copilot.CopilotRequestTestSupport.newLlmClient; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.net.ssl.SSLSession; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.NamedProviderConfig; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.ProviderModelConfig; +import com.github.copilot.rpc.SessionConfig; + +/** + * End-to-end coverage for the experimental BYOK bearer-token-provider surface + * ({@code getBearerToken} on a provider config). The callback stays entirely on + * the SDK/client side: the SDK keeps it off the wire, sends only the + * {@code hasBearerTokenProvider} flag, and the runtime calls back over the + * session-scoped {@code providerToken.getToken} RPC before each outbound model + * request. + */ +public class ByokBearerTokenProviderE2ETest { + + private static final String PRIMARY_HOST = "byok-endpoint.invalid"; + private static final String PRIMARY_BASE_URL = "https://" + PRIMARY_HOST + "/v1"; + private static final String RED_HOST = "byok-red.invalid"; + private static final String RED_BASE_URL = "https://" + RED_HOST + "/v1"; + private static final String BLUE_HOST = "byok-blue.invalid"; + private static final String BLUE_BASE_URL = "https://" + BLUE_HOST + "/v1"; + + private static E2ETestContext ctx; + private CapturingRequestHandler handler; + + @BeforeAll + static void setup() throws Exception { + ctx = E2ETestContext.create(); + } + + @AfterAll + static void teardown() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + @BeforeEach + void resetHandler() { + handler = new CapturingRequestHandler(); + } + + @Test + void appliesCallbackTokenAsAuthorizationHeader() throws Exception { + String sentinel = "sentinel-bearer-token-abc123"; + AtomicInteger calls = new AtomicInteger(); + GetBearerToken getBearerToken = args -> { + calls.incrementAndGet(); + return CompletableFuture.completedFuture(sentinel); + }; + + List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai") + .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken)); + List models = List + .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o")); + + runTurn(providers, models, "mi/default", "What is 5+5?"); + + assertTrue(handler.authHeaders().contains("Bearer " + sentinel), + "Expected captured Authorization headers to contain the callback token: " + handler.authHeaders()); + assertTrue(calls.get() >= 1, "Expected the callback to be invoked at least once"); + } + + @Test + void reacquiresFreshTokenForEachRequest() throws Exception { + AtomicInteger calls = new AtomicInteger(); + GetBearerToken getBearerToken = args -> CompletableFuture + .completedFuture("rotating-token-" + calls.incrementAndGet()); + + List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai") + .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken)); + List models = List + .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o")); + + runTurn(providers, models, "mi/default", "What is 1+1?"); + runTurn(providers, models, "mi/default", "What is 2+2?"); + + List auths = handler.authHeaders(); + assertTrue(auths.size() >= 2, "Expected at least two captured Authorization headers, got " + auths); + assertTrue(auths.get(0).startsWith("Bearer rotating-token-"), "Expected rotating token, got " + auths); + assertTrue(auths.get(1).startsWith("Bearer rotating-token-"), "Expected rotating token, got " + auths); + assertNotEquals(auths.get(0), auths.get(1), "Expected distinct tokens per request"); + assertTrue(calls.get() >= 2, "Expected the callback to be invoked at least twice"); + } + + @Test + void dispatchesTokenAcquisitionPerProvider() throws Exception { + List acquiredFor = new ArrayList<>(); + GetBearerToken redCallback = args -> { + assertEquals("red", args.getProviderName(), "Expected providerName to be forwarded"); + synchronized (acquiredFor) { + acquiredFor.add("red"); + } + return CompletableFuture.completedFuture("token-for-red"); + }; + GetBearerToken blueCallback = args -> { + assertEquals("blue", args.getProviderName(), "Expected providerName to be forwarded"); + synchronized (acquiredFor) { + acquiredFor.add("blue"); + } + return CompletableFuture.completedFuture("token-for-blue"); + }; + + List providers = List.of( + new NamedProviderConfig().setName("red").setType("openai").setWireApi("completions") + .setBaseUrl(RED_BASE_URL).setGetBearerToken(redCallback), + new NamedProviderConfig().setName("blue").setType("openai").setWireApi("completions") + .setBaseUrl(BLUE_BASE_URL).setGetBearerToken(blueCallback)); + List models = List.of( + new ProviderModelConfig().setId("default").setProvider("red").setWireModel("byok-gpt-4o"), + new ProviderModelConfig().setId("default").setProvider("blue").setWireModel("byok-gpt-4o")); + + runTurn(providers, models, "red/default", "What is 3+3?"); + runTurn(providers, models, "blue/default", "What is 4+4?"); + + assertEquals("Bearer token-for-red", handler.authHeaderForHost(RED_HOST)); + assertEquals("Bearer token-for-blue", handler.authHeaderForHost(BLUE_HOST)); + synchronized (acquiredFor) { + assertTrue(acquiredFor.contains("red"), "Expected red provider to acquire a token"); + assertTrue(acquiredFor.contains("blue"), "Expected blue provider to acquire a token"); + } + } + + private void runTurn(List providers, List models, String selectionId, + String prompt) throws Exception { + try (CopilotClient client = newLlmClient(ctx, handler)) { + CopilotSession session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setModel(selectionId).setProviders(providers).setModels(models)) + .get(60, TimeUnit.SECONDS); + try { + session.sendAndWait(new MessageOptions().setPrompt(prompt)).get(60, TimeUnit.SECONDS); + } catch (Exception ignored) { + // The fake BYOK endpoint returns 404 after capturing the token-bearing request. + } finally { + try { + session.close(); + } catch (Exception ignored) { + // Ignore disconnect errors for the fake BYOK endpoint. + } + } + } + } + + private static final class CapturingRequestHandler extends CopilotRequestHandler { + + private final ConcurrentLinkedQueue captures = new ConcurrentLinkedQueue<>(); + + @Override + protected HttpResponse sendRequest(HttpRequest request, CopilotRequestContext rctx) + throws Exception { + String host = request.uri().getHost(); + if (host != null && host.endsWith(".invalid")) { + captures.add(new CapturedRequest(request.uri().getHost(), + request.headers().firstValue("Authorization").orElse(null))); + return new StubHttpResponse(404, "{\"error\":{\"message\":\"fake byok endpoint\"}}"); + } + return buildNonInferenceResponse(request.uri().toString()); + } + + List authHeaders() { + List auths = new ArrayList<>(); + for (CapturedRequest capture : captures) { + if (capture.authorization() != null) { + auths.add(capture.authorization()); + } + } + return auths; + } + + String authHeaderForHost(String host) { + for (CapturedRequest capture : captures) { + if (host.equals(capture.host())) { + return capture.authorization(); + } + } + return null; + } + } + + private static final class StubHttpResponse implements HttpResponse { + + private final int status; + private final HttpHeaders headers; + private final byte[] body; + + StubHttpResponse(int status, String body) { + this.status = status; + this.body = body.getBytes(StandardCharsets.UTF_8); + this.headers = HttpHeaders.of(Map.of("content-type", List.of("application/json")), (k, v) -> true); + } + + @Override + public int statusCode() { + return status; + } + + @Override + public HttpRequest request() { + return null; + } + + @Override + public Optional> previousResponse() { + return Optional.empty(); + } + + @Override + public HttpHeaders headers() { + return headers; + } + + @Override + public InputStream body() { + return new ByteArrayInputStream(body); + } + + @Override + public Optional sslSession() { + return Optional.empty(); + } + + @Override + public URI uri() { + return null; + } + + @Override + public HttpClient.Version version() { + return HttpClient.Version.HTTP_1_1; + } + } + + private record CapturedRequest(String host, String authorization) { + } +} diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 96ac60842..1c29dd53c 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -51,11 +51,14 @@ import type { ExitPlanModeResult, ForegroundSessionInfo, GetAuthStatusResponse, + GetBearerToken, GetStatusResponse, InternalRuntimeConnection, LargeToolOutputConfig, MCPServerConfig, ModelInfo, + NamedProviderConfig, + ProviderConfig, ResumeSessionConfig, SectionTransformFn, SessionConfig, @@ -154,6 +157,62 @@ function toJsonSchema(parameters: Tool["parameters"]): Record | return parameters; } +/** Implicit provider name for the singular, whole-session {@link ProviderConfig}. */ +const DEFAULT_PROVIDER_NAME = "default"; + +/** Wire-safe singular provider config carrying the `hasBearerTokenProvider` flag. */ +type WireProviderConfig = Omit & { hasBearerTokenProvider?: boolean }; + +/** Wire-safe named provider config carrying the `hasBearerTokenProvider` flag. */ +type WireNamedProviderConfig = Omit & { + hasBearerTokenProvider?: boolean; +}; + +/** + * Strips the non-serializable {@link GetBearerToken} callbacks from the singular + * and named provider configs before they cross the RPC boundary, replacing each + * with a `hasBearerTokenProvider: true` wire flag. The callback closes over its + * own token scope/audience, so nothing scope-related crosses the wire — the + * runtime only forwards the provider name back when it needs a token. + * Returns wire-safe provider configs alongside a map of provider name → callback + * for session-side registration. + */ +function extractBearerTokenProviders( + provider: ProviderConfig | undefined, + providers: NamedProviderConfig[] | undefined +): { + wireProvider: WireProviderConfig | undefined; + wireProviders: WireNamedProviderConfig[] | undefined; + callbacks: Map; +} { + const callbacks = new Map(); + + let wireProvider: WireProviderConfig | undefined = provider; + if (provider?.getBearerToken) { + const { getBearerToken, ...rest } = provider; + callbacks.set(DEFAULT_PROVIDER_NAME, getBearerToken); + wireProvider = { + ...rest, + hasBearerTokenProvider: true, + }; + } + + let wireProviders: WireNamedProviderConfig[] | undefined = providers; + if (providers?.some((p) => p.getBearerToken)) { + wireProviders = providers.map((p) => { + if (!p.getBearerToken) return p; + const { getBearerToken, ...rest } = p; + callbacks.set(p.name, getBearerToken); + return { + ...rest, + hasBearerTokenProvider: true, + }; + }); + } + + return { wireProvider, wireProviders, callbacks }; +} + /** * Convert MCP server configs from public API format (workingDirectory) to * wire format (cwd) expected by the runtime. @@ -1237,6 +1296,15 @@ export class CopilotClient { const useServerGeneratedId = config.cloud != null && callerSessionId == null; const localSessionId = useServerGeneratedId ? undefined : (callerSessionId ?? randomUUID()); + // Strip non-serializable getBearerToken callbacks from provider configs, + // replacing them with a wire flag; keep the callbacks for session-side + // registration so the runtime can call back to acquire tokens. + const { + wireProvider: bearerWireProvider, + wireProviders: bearerWireProviders, + callbacks: bearerTokenCallbacks, + } = extractBearerTokenProviders(config.provider, config.providers); + // Extract transform callbacks from system message config before serialization. const { wirePayload: wireSystemMessage, transformCallbacks } = extractTransformCallbacks( config.systemMessage @@ -1254,6 +1322,9 @@ export class CopilotClient { s.registerTools(config.tools); s.registerCanvases(config.canvases); s.registerCommands(config.commands); + if (bearerTokenCallbacks.size > 0) { + s.registerBearerTokenProviders(bearerTokenCallbacks); + } s.registerPermissionHandler(config.onPermissionRequest); if (config.onUserInputRequest) { s.registerUserInputHandler(config.onUserInputRequest); @@ -1325,9 +1396,9 @@ export class CopilotClient { availableTools: toolFilterOptions.availableTools, excludedTools: toolFilterOptions.excludedTools, toolFilterPrecedence: toolFilterOptions.toolFilterPrecedence, - provider: config.provider, + provider: bearerWireProvider, capi: config.capi, - providers: config.providers, + providers: bearerWireProviders, models: config.models, enableSessionTelemetry: config.enableSessionTelemetry, modelCapabilities: config.modelCapabilities, @@ -1446,6 +1517,14 @@ export class CopilotClient { session.registerTools(config.tools); session.registerCanvases(config.canvases); session.registerCommands(config.commands); + const { + wireProvider: bearerWireProvider, + wireProviders: bearerWireProviders, + callbacks: bearerTokenCallbacks, + } = extractBearerTokenProviders(config.provider, config.providers); + if (bearerTokenCallbacks.size > 0) { + session.registerBearerTokenProviders(bearerTokenCallbacks); + } session.registerPermissionHandler(config.onPermissionRequest); if (config.onUserInputRequest) { session.registerUserInputHandler(config.onUserInputRequest); @@ -1512,9 +1591,9 @@ export class CopilotClient { name: cmd.name, description: cmd.description, })), - provider: config.provider, + provider: bearerWireProvider, capi: config.capi, - providers: config.providers, + providers: bearerWireProviders, models: config.models, modelCapabilities: config.modelCapabilities, largeOutput: toWireLargeOutput(config.largeOutput), diff --git a/nodejs/src/generated/rpc.ts b/nodejs/src/generated/rpc.ts index df44de84c..423ab4fe8 100644 --- a/nodejs/src/generated/rpc.ts +++ b/nodejs/src/generated/rpc.ts @@ -6597,6 +6597,10 @@ export interface NamedProviderConfig { headers?: { [k: string]: string | undefined; }; + /** + * When true, the SDK client supplies bearer tokens on demand: the runtime calls the client-session `providerToken.getToken` callback before each request and uses the returned token as the Authorization header. The token-acquiring function itself stays on the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive with `apiKey`/`bearerToken`. + */ + hasBearerTokenProvider?: boolean; } /** * Azure-specific provider options. @@ -8575,6 +8579,10 @@ export interface ProviderConfig { headers?: { [k: string]: string | undefined; }; + /** + * When true, the SDK client supplies bearer tokens on demand: the runtime calls the client-session `providerToken.getToken` callback before each request and uses the returned token as the Authorization header. The token-acquiring function itself stays on the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive with `apiKey`/`bearerToken`. + */ + hasBearerTokenProvider?: boolean; } /** * A snapshot of the provider endpoint the session is currently configured to talk to. @@ -13627,6 +13635,36 @@ export interface WorkspacesSaveLargePasteResult { sizeBytes: number; } | null; } +/** + * Asks the SDK client to acquire a bearer token for a BYOK provider whose config set `hasBearerTokenProvider: true`. Issued by the runtime before each outbound model request; the runtime does no caching, so this is sent once per request. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "ProviderTokenAcquireRequest". + */ +/** @experimental */ +export interface ProviderTokenAcquireRequest { + /** + * Target session identifier + */ + sessionId: string; + /** + * Name of the BYOK provider needing a token. For the legacy whole-session `provider` this is the implicit provider name; for named providers it is `NamedProviderConfig.name`. + */ + providerName: string; +} +/** + * A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as `Authorization: Bearer ` on the outbound request and does no caching; the SDK consumer owns token caching and refresh. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "ProviderTokenAcquireResult". + */ +/** @experimental */ +export interface ProviderTokenAcquireResult { + /** + * The bearer token value (without the `Bearer ` prefix). + */ + token: string; +} /** * Standard MCP CallToolResult * @@ -15920,10 +15958,24 @@ export interface CanvasHandler { invoke(params: CanvasProviderInvokeActionRequest): Promise; } +/** Handler for `providerToken` client session API methods. */ +/** @experimental */ +export interface ProviderTokenHandler { + /** + * Asks the SDK client to get a bearer token for a BYOK provider whose config set `hasBearerTokenProvider: true`. Session-scoped: the runtime calls it back on the connection that created the session, passing the provider name, and uses the returned token as the Authorization header for the outbound model request. The runtime does no caching — it calls this once per outbound request; the SDK consumer owns token acquisition, caching, and refresh. + * + * @param params Asks the SDK client to acquire a bearer token for a BYOK provider whose config set `hasBearerTokenProvider: true`. Issued by the runtime before each outbound model request; the runtime does no caching, so this is sent once per request. + * + * @returns A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as `Authorization: Bearer ` on the outbound request and does no caching; the SDK consumer owns token caching and refresh. + */ + getToken(params: ProviderTokenAcquireRequest): Promise; +} + /** All client session API handler groups. */ export interface ClientSessionApiHandlers { sessionFs?: SessionFsHandler; canvas?: CanvasHandler; + providerToken?: ProviderTokenHandler; } /** @@ -16011,6 +16063,11 @@ export function registerClientSessionApiHandlers( if (!handler) throw new Error(`No canvas handler registered for session: ${params.sessionId}`); return handler.invoke(params); }); + connection.onRequest("providerToken.getToken", async (params: ProviderTokenAcquireRequest) => { + const handler = getHandlers(params.sessionId).providerToken; + if (!handler) throw new Error(`No providerToken handler registered for session: ${params.sessionId}`); + return handler.getToken(params); + }); } /** Handler for `llmInference` client global API methods. */ diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index 9bf02a32c..740a7bc89 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -84,6 +84,7 @@ export type { MCPHTTPServerConfig, MCPServerConfig, DefaultAgentConfig, + GetBearerToken, MessageOptions, ModelBilling, ModelBillingTokenPrices, @@ -99,6 +100,7 @@ export type { PermissionRequestResult, ProviderConfig, ProviderModelConfig, + ProviderTokenArgs, RemoteSessionMode, ResumeSessionConfig, SectionOverride, diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 0ba42ab76..7ffc8d26f 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -26,6 +26,7 @@ import type { ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, + GetBearerToken, UiInputOptions, MessageOptions, PermissionHandler, @@ -122,6 +123,7 @@ export class CopilotSession { new Map(); private toolHandlers: Map = new Map(); private canvases: Map = new Map(); + private bearerTokenProviders: Map = new Map(); private commandHandlers: Map = new Map(); private permissionHandler?: PermissionHandler; private userInputHandler?: UserInputHandler; @@ -797,6 +799,45 @@ export class CopilotSession { }; } + /** + * Registers per-provider {@link GetBearerToken} callbacks for BYOK providers + * configured with managed-identity / on-demand bearer-token auth. + * + * The runtime never receives the callback itself; the SDK strips it from the + * provider config and instead sends `hasBearerTokenProvider: true`. When the + * runtime needs a token it issues a session-scoped `providerToken.getToken` + * request, which this handler routes to the matching per-provider callback. + * + * @param providers - Map of provider name → callback, or undefined/empty to clear. + * @internal This method is called internally when creating/resuming a session. + */ + registerBearerTokenProviders(providers?: Map): void { + this.bearerTokenProviders.clear(); + if (!providers || providers.size === 0) { + delete this.clientSessionApis.providerToken; + return; + } + for (const [name, callback] of providers) { + this.bearerTokenProviders.set(name, callback); + } + + const self = this; + this.clientSessionApis.providerToken = { + async getToken(params) { + const callback = self.bearerTokenProviders.get(params.providerName); + if (!callback) { + throw new Error( + `No bearer-token provider registered for provider "${params.providerName}"` + ); + } + const token = await callback({ + providerName: params.providerName, + }); + return { token }; + }, + }; + } + /** * Registers command handlers for this session. * diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index bdf02a7b0..9ed724d55 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -2192,6 +2192,39 @@ export interface ResumeSessionConfig extends SessionConfigBase { openCanvases?: OpenCanvasInstance[]; } +/** + * Arguments passed to a {@link GetBearerToken} callback when the runtime needs a + * fresh bearer token for a BYOK provider. + * + * @experimental Part of the experimental managed-identity / bearer-token-provider + * surface and may change or be removed in future SDK or CLI releases. + */ +export interface ProviderTokenArgs { + /** + * Name of the BYOK provider needing a token. For the singular, whole-session + * {@link ProviderConfig} this is the implicit provider name (`"default"`); for + * {@link NamedProviderConfig} entries it is {@link NamedProviderConfig.name}. + * + * The callback closes over its own token scope/audience; the runtime is + * provider-agnostic and forwards only the provider name. + */ + providerName: string; +} + +/** + * Per-provider callback that resolves a bearer token on demand, returning the + * raw token string (without the `Bearer ` prefix). The Copilot SDK itself takes + * no Azure dependency: the consumer supplies this callback backed by their own + * identity library (for example `@azure/identity`'s + * `DefaultAzureCredential.getToken(scope)`), and the runtime calls it once before + * each outbound model request. The runtime does no caching of its own, so the + * callback (or the identity library it wraps) owns token caching and refresh. + * + * @experimental Part of the experimental managed-identity / bearer-token-provider + * surface and may change or be removed in future SDK or CLI releases. + */ +export type GetBearerToken = (args: ProviderTokenArgs) => Promise; + /** * Configuration for a custom API provider. */ @@ -2234,6 +2267,18 @@ export interface ProviderConfig { */ bearerToken?: string; + /** + * Per-request bearer-token provider for managed-identity / on-demand auth. + * When set, the SDK keeps this function client-side (it is never serialized) + * and the runtime calls back into this client to acquire a token before each + * outbound request. The runtime does no caching of its own, so the callback + * owns token caching and refresh. Mutually exclusive with {@link apiKey} / + * {@link bearerToken}. + * + * @experimental + */ + getBearerToken?: GetBearerToken; + /** * Azure-specific options */ @@ -2325,6 +2370,18 @@ export interface NamedProviderConfig { */ bearerToken?: string; + /** + * Per-request bearer-token provider for managed-identity / on-demand auth. + * When set, the SDK keeps this function client-side (it is never serialized) + * and the runtime calls back into this client to acquire a token before each + * outbound request. The runtime does no caching of its own, so the callback + * owns token caching and refresh. Mutually exclusive with {@link apiKey} / + * {@link bearerToken}. + * + * @experimental + */ + getBearerToken?: GetBearerToken; + /** * Azure-specific options. */ diff --git a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts new file mode 100644 index 000000000..228b7a022 --- /dev/null +++ b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts @@ -0,0 +1,255 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { beforeEach, describe, expect, it } from "vitest"; +import { approveAll, CopilotRequestHandler } from "../../src/index.js"; +import type { + CopilotRequestContext, + GetBearerToken, + NamedProviderConfig, + ProviderModelConfig, +} from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +/** + * A captured outbound HTTP request the runtime aimed at a fake BYOK provider + * endpoint: just the host and the `Authorization` header, which is all these + * tests need to assert on. + */ +interface CapturedRequest { + host: string; + authorization?: string; +} + +// Fake BYOK provider base URLs. These hosts are never actually dialed: the +// client-global request interceptor fully answers any request aimed at a +// `.invalid` host, so they only need to be syntactically valid, non-resolving +// URLs. Distinct hosts let the per-provider test assert routing by host. +const PRIMARY_HOST = "byok-endpoint.invalid"; +const PRIMARY_BASE_URL = `https://${PRIMARY_HOST}/v1`; +const RED_HOST = "byok-red.invalid"; +const RED_BASE_URL = `https://${RED_HOST}/v1`; +const BLUE_HOST = "byok-blue.invalid"; +const BLUE_BASE_URL = `https://${BLUE_HOST}/v1`; + +/** + * Client-global HTTP request interceptor (from the SDK's `CopilotRequestHandler` + * surface) used in place of a real HTTP listener. + * + * The runtime invokes {@link sendRequest} for every model-layer HTTP request it + * would otherwise issue. We capture the ones aimed at a fake BYOK host — + * recording the `Authorization` header the runtime applied after calling the + * provider's `getBearerToken` callback over the session-scoped + * `providerToken.getToken` RPC — and answer them with a synthetic `404` (a + * non-retryable status, so each outbound model request yields exactly one + * capture). Every other request (CAPI bootstrap: model catalog, policy, …) is + * passed straight through to the real network via `super.sendRequest`. + * + * Because the handler is client-global (one per CLI process), it is installed + * once for the whole fixture and {@link reset} between tests. + */ +class CapturingRequestHandler extends CopilotRequestHandler { + public readonly captures: CapturedRequest[] = []; + + protected override async sendRequest( + request: Request, + ctx: CopilotRequestContext + ): Promise { + const url = new URL(request.url); + if (url.hostname.endsWith(".invalid")) { + this.captures.push({ + host: url.host, + authorization: request.headers.get("authorization") ?? undefined, + }); + return new Response(JSON.stringify({ error: { message: "fake byok endpoint" } }), { + status: 404, + headers: { "content-type": "application/json" }, + }); + } + return super.sendRequest(request, ctx); + } + + reset(): void { + this.captures.length = 0; + } + + /** The `Authorization` headers captured across BYOK requests, in arrival order. */ + authHeaders(): string[] { + return this.captures + .map((c) => c.authorization) + .filter((v): v is string => typeof v === "string"); + } + + /** The `Authorization` header captured for requests aimed at `host`, if any. */ + authHeaderForHost(host: string): string | undefined { + return this.captures.find((c) => c.host === host)?.authorization; + } +} + +/** + * End-to-end coverage for the experimental BYOK bearer-token-provider surface + * (`getBearerToken` on a provider config). The callback stays entirely on the + * SDK/client side: the SDK strips it from the wire config, sets the + * `hasBearerTokenProvider` flag, and the runtime calls back over the session-scoped + * `providerToken.getToken` RPC before each outbound model request, applying the + * returned token as the `Authorization` header. + * + * Rather than standing up a real HTTP listener, these tests install a + * client-global {@link CapturingRequestHandler} that intercepts the runtime's + * outbound model request in-process, captures the `Authorization` header, and + * returns a synthetic response. They validate, against a real runtime: + * 1. the callback's token reaches the model request as `Authorization: Bearer `; + * 2. the runtime re-acquires a token per request (no runtime-side caching); + * 3. per-provider dispatch routes each provider's turn to its own callback, + * and the resulting token reaches that provider's endpoint. + */ +describe("BYOK bearer-token provider", async () => { + const handler = new CapturingRequestHandler(); + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { requestHandler: handler }, + }); + + beforeEach(() => { + handler.reset(); + }); + + /** Drive one BYOK turn; the synthetic 404 errors the turn, which is expected. */ + async function runTurn( + providers: NamedProviderConfig[], + models: ProviderModelConfig[], + selectionId: string, + prompt: string + ): Promise { + const session = await client.createSession({ + onPermissionRequest: approveAll, + model: selectionId, + providers, + models, + }); + try { + // The interceptor always 404s, so the turn errors after the runtime + // has already sent the (token-bearing) request — which is all we + // assert on. Swallow the resulting error. + await session.sendAndWait({ prompt }).catch(() => undefined); + } finally { + try { + await session.disconnect(); + } catch { + // ignore disconnect errors for the fake BYOK endpoint + } + } + } + + it("applies the callback's token as the Authorization header", async () => { + const SENTINEL = "sentinel-bearer-token-abc123"; + let calls = 0; + const getBearerToken: GetBearerToken = async () => { + calls += 1; + return SENTINEL; + }; + + const providers: NamedProviderConfig[] = [ + { + name: "mi", + type: "openai", + wireApi: "completions", + baseUrl: PRIMARY_BASE_URL, + getBearerToken, + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "default", provider: "mi", wireModel: "byok-gpt-4o" }, + ]; + + await runTurn(providers, models, "mi/default", "What is 5+5?"); + + // The runtime acquired a token via the callback and applied it verbatim as + // the bearer credential on the outbound model request. + expect(handler.authHeaders()).toContain(`Bearer ${SENTINEL}`); + expect(calls).toBeGreaterThanOrEqual(1); + }); + + it("re-acquires a fresh token for each request (no runtime caching)", async () => { + let calls = 0; + const getBearerToken: GetBearerToken = async () => { + calls += 1; + // A distinct token per acquisition proves the runtime re-invokes the + // callback per request rather than caching a previous token. + return `rotating-token-${calls}`; + }; + + const providers: NamedProviderConfig[] = [ + { + name: "mi", + type: "openai", + wireApi: "completions", + baseUrl: PRIMARY_BASE_URL, + getBearerToken, + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "default", provider: "mi", wireModel: "byok-gpt-4o" }, + ]; + + await runTurn(providers, models, "mi/default", "What is 1+1?"); + await runTurn(providers, models, "mi/default", "What is 2+2?"); + + // Each outbound request carries a freshly-acquired, distinct token. + const auths = handler.authHeaders(); + expect(auths.length).toBeGreaterThanOrEqual(2); + expect(auths[0]).toMatch(/^Bearer rotating-token-\d+$/); + expect(auths[1]).toMatch(/^Bearer rotating-token-\d+$/); + expect(auths[0]).not.toBe(auths[1]); + expect(calls).toBeGreaterThanOrEqual(2); + }); + + it("dispatches token acquisition per provider", async () => { + const tokenByProvider: Record = { + red: "token-for-red", + blue: "token-for-blue", + }; + const acquiredFor: string[] = []; + const makeCallback = + (providerName: string): GetBearerToken => + async (args) => { + // The runtime forwards the requesting provider's name so the client + // can dispatch to the right credential. + expect(args.providerName).toBe(providerName); + acquiredFor.push(providerName); + return tokenByProvider[providerName]; + }; + + const providers: NamedProviderConfig[] = [ + { + name: "red", + type: "openai", + wireApi: "completions", + baseUrl: RED_BASE_URL, + getBearerToken: makeCallback("red"), + }, + { + name: "blue", + type: "openai", + wireApi: "completions", + baseUrl: BLUE_BASE_URL, + getBearerToken: makeCallback("blue"), + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "default", provider: "red", wireModel: "byok-gpt-4o" }, + { id: "default", provider: "blue", wireModel: "byok-gpt-4o" }, + ]; + + await runTurn(providers, models, "red/default", "What is 3+3?"); + await runTurn(providers, models, "blue/default", "What is 4+4?"); + + // Each provider's turn was authenticated with its own token AND that token + // was delivered to that provider's endpoint, proving per-provider dispatch + // (not a single session-global credential). + expect(handler.authHeaderForHost(RED_HOST)).toBe(`Bearer ${tokenByProvider.red}`); + expect(handler.authHeaderForHost(BLUE_HOST)).toBe(`Bearer ${tokenByProvider.blue}`); + expect(acquiredFor).toContain("red"); + expect(acquiredFor).toContain("blue"); + }); +}); diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 06ecf4188..1e7a3afb1 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -100,6 +100,7 @@ ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, + GetBearerToken, InfiniteSessionConfig, InputOptions, LargeToolOutputConfig, @@ -128,6 +129,7 @@ PreToolUseHookOutput, ProviderConfig, ProviderModelConfig, + ProviderTokenArgs, ReasoningSummary, SessionCapabilities, SessionEndHandler, @@ -214,6 +216,7 @@ "ExtensionInfo", "CopilotWebSocketForwarder", "GetAuthStatusResponse", + "GetBearerToken", "GetStatusResponse", "InfiniteSessionConfig", "InputOptions", @@ -257,6 +260,7 @@ "PreToolUseHookOutput", "ProviderConfig", "ProviderModelConfig", + "ProviderTokenArgs", "ReasoningSummary", "RemoteSessionMode", "RuntimeConnection", diff --git a/python/copilot/client.py b/python/copilot/client.py index 5dc670903..cfe031532 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -89,6 +89,7 @@ DefaultAgentConfig, ElicitationHandler, ExitPlanModeHandler, + GetBearerToken, InfiniteSessionConfig, LargeToolOutputConfig, MCPServerConfig, @@ -171,6 +172,36 @@ def _capi_session_options_to_wire(options: CapiSessionOptions) -> dict[str, Any] return wire +# Implicit provider name for the singular, whole-session ``provider`` config. +# Named providers are keyed by their own ``name``. +_DEFAULT_BEARER_TOKEN_PROVIDER_NAME = "default" + + +def _collect_bearer_token_callbacks( + provider: ProviderConfig | None, + providers: list[NamedProviderConfig] | None, +) -> dict[str, GetBearerToken]: + """Collect per-provider ``get_bearer_token`` callbacks keyed by provider name. + + The singular, whole-session ``provider`` uses the implicit + ``_DEFAULT_BEARER_TOKEN_PROVIDER_NAME``; ``providers`` entries use their own + ``name``. The callbacks are never serialized — the wire conversion emits + ``hasBearerTokenProvider: true`` instead and the runtime calls back over + ``providerToken.getToken``. + """ + callbacks: dict[str, GetBearerToken] = {} + if provider is not None: + singular = provider.get("get_bearer_token") + if singular is not None: + callbacks[_DEFAULT_BEARER_TOKEN_PROVIDER_NAME] = singular + if providers: + for named in providers: + callback = named.get("get_bearer_token") + if callback is not None: + callbacks[named["name"]] = callback + return callbacks + + def _validate_session_fs_config(config: SessionFsConfig) -> None: if not config.get("initial_working_directory"): raise ValueError("session_fs.initial_working_directory is required") @@ -2112,6 +2143,9 @@ def _initialize_session(sid: str) -> CopilotSession: s._register_auto_mode_switch_handler(on_auto_mode_switch_request) if canvas_handler is not None: s._register_canvas_handler(canvas_handler) + s._register_bearer_token_providers( + _collect_bearer_token_callbacks(provider, providers) + ) if hooks: s._register_hooks(hooks) if transform_callbacks: @@ -2669,6 +2703,9 @@ async def resume_session( session._register_auto_mode_switch_handler(on_auto_mode_switch_request) if canvas_handler is not None: session._register_canvas_handler(canvas_handler) + session._register_bearer_token_providers( + _collect_bearer_token_callbacks(provider, providers) + ) if hooks: session._register_hooks(hooks) if transform_callbacks: @@ -3199,6 +3236,8 @@ def _convert_provider_to_wire_format( wire_provider["transport"] = provider["transport"] if "bearer_token" in provider: wire_provider["bearerToken"] = provider["bearer_token"] + if provider.get("get_bearer_token") is not None: + wire_provider["hasBearerTokenProvider"] = True if "headers" in provider: wire_provider["headers"] = provider["headers"] if "model_id" in provider: @@ -3235,6 +3274,8 @@ def _convert_named_provider_to_wire_format( wire["apiKey"] = provider["api_key"] if "bearer_token" in provider: wire["bearerToken"] = provider["bearer_token"] + if provider.get("get_bearer_token") is not None: + wire["hasBearerTokenProvider"] = True if "headers" in provider: wire["headers"] = provider["headers"] if "azure" in provider: diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index 6b46e465f..f02e59515 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -5150,6 +5150,27 @@ def to_dict(self) -> dict: result["model"] = from_union([from_str, from_none], self.model) return result +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class ProviderTokenAcquireResult: + """A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as + `Authorization: Bearer ` on the outbound request and does no caching; the SDK + consumer owns token caching and refresh. + """ + token: str + """The bearer token value (without the `Bearer ` prefix).""" + + @staticmethod + def from_dict(obj: Any) -> 'ProviderTokenAcquireResult': + assert isinstance(obj, dict) + token = from_str(obj.get("token")) + return ProviderTokenAcquireResult(token) + + def to_dict(self) -> dict: + result: dict = {} + result["token"] = from_str(self.token) + return result + # Experimental: this type is part of an experimental API and may change or be removed. @dataclass class PushAttachmentFileLineRange: @@ -11700,6 +11721,13 @@ class NamedProviderConfig: """Bearer token for authentication. Sets the Authorization header directly. Takes precedence over apiKey when both are set. """ + has_bearer_token_provider: bool | None = None + """When true, the SDK client supplies bearer tokens on demand: the runtime calls the + client-session `providerToken.getToken` callback before each request and uses the + returned token as the Authorization header. The token-acquiring function itself stays on + the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive + with `apiKey`/`bearerToken`. + """ headers: dict[str, str] | None = None """Custom HTTP headers to include in all outbound requests to the provider.""" @@ -11717,10 +11745,11 @@ def from_dict(obj: Any) -> 'NamedProviderConfig': api_key = from_union([from_str, from_none], obj.get("apiKey")) azure = from_union([ProviderConfigAzure.from_dict, from_none], obj.get("azure")) bearer_token = from_union([from_str, from_none], obj.get("bearerToken")) + has_bearer_token_provider = from_union([from_bool, from_none], obj.get("hasBearerTokenProvider")) headers = from_union([lambda x: from_dict(from_str, x), from_none], obj.get("headers")) type = from_union([ProviderType, from_none], obj.get("type")) wire_api = from_union([ProviderWireAPI, from_none], obj.get("wireApi")) - return NamedProviderConfig(base_url, name, api_key, azure, bearer_token, headers, type, wire_api) + return NamedProviderConfig(base_url, name, api_key, azure, bearer_token, has_bearer_token_provider, headers, type, wire_api) def to_dict(self) -> dict: result: dict = {} @@ -11732,6 +11761,8 @@ def to_dict(self) -> dict: result["azure"] = from_union([lambda x: to_class(ProviderConfigAzure, x), from_none], self.azure) if self.bearer_token is not None: result["bearerToken"] = from_union([from_str, from_none], self.bearer_token) + if self.has_bearer_token_provider is not None: + result["hasBearerTokenProvider"] = from_union([from_bool, from_none], self.has_bearer_token_provider) if self.headers is not None: result["headers"] = from_union([lambda x: from_dict(from_str, x), from_none], self.headers) if self.type is not None: @@ -11758,6 +11789,13 @@ class ProviderConfig: """Bearer token for authentication. Sets the Authorization header directly. Takes precedence over apiKey when both are set. """ + has_bearer_token_provider: bool | None = None + """When true, the SDK client supplies bearer tokens on demand: the runtime calls the + client-session `providerToken.getToken` callback before each request and uses the + returned token as the Authorization header. The token-acquiring function itself stays on + the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive + with `apiKey`/`bearerToken`. + """ headers: dict[str, str] | None = None """Custom HTTP headers to include in all outbound requests to the provider.""" @@ -11792,6 +11830,7 @@ def from_dict(obj: Any) -> 'ProviderConfig': api_key = from_union([from_str, from_none], obj.get("apiKey")) azure = from_union([ProviderConfigAzure.from_dict, from_none], obj.get("azure")) bearer_token = from_union([from_str, from_none], obj.get("bearerToken")) + has_bearer_token_provider = from_union([from_bool, from_none], obj.get("hasBearerTokenProvider")) headers = from_union([lambda x: from_dict(from_str, x), from_none], obj.get("headers")) max_context_window_tokens = from_union([from_float, from_none], obj.get("maxContextWindowTokens")) max_output_tokens = from_union([from_float, from_none], obj.get("maxOutputTokens")) @@ -11800,7 +11839,7 @@ def from_dict(obj: Any) -> 'ProviderConfig': type = from_union([ProviderType, from_none], obj.get("type")) wire_api = from_union([ProviderWireAPI, from_none], obj.get("wireApi")) wire_model = from_union([from_str, from_none], obj.get("wireModel")) - return ProviderConfig(base_url, api_key, azure, bearer_token, headers, max_context_window_tokens, max_output_tokens, max_prompt_tokens, model_id, type, wire_api, wire_model) + return ProviderConfig(base_url, api_key, azure, bearer_token, has_bearer_token_provider, headers, max_context_window_tokens, max_output_tokens, max_prompt_tokens, model_id, type, wire_api, wire_model) def to_dict(self) -> dict: result: dict = {} @@ -11811,6 +11850,8 @@ def to_dict(self) -> dict: result["azure"] = from_union([lambda x: to_class(ProviderConfigAzure, x), from_none], self.azure) if self.bearer_token is not None: result["bearerToken"] = from_union([from_str, from_none], self.bearer_token) + if self.has_bearer_token_provider is not None: + result["hasBearerTokenProvider"] = from_union([from_bool, from_none], self.has_bearer_token_provider) if self.headers is not None: result["headers"] = from_union([lambda x: from_dict(from_str, x), from_none], self.headers) if self.max_context_window_tokens is not None: @@ -16717,6 +16758,33 @@ def to_dict(self) -> dict: result["supports"] = from_union([lambda x: to_class(ModelCapabilitiesOverrideSupports, x), from_none], self.supports) return result +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class ProviderTokenAcquireRequest: + """Asks the SDK client to acquire a bearer token for a BYOK provider whose config set + `hasBearerTokenProvider: true`. Issued by the runtime before each outbound model request; + the runtime does no caching, so this is sent once per request. + """ + provider_name: str + """Name of the BYOK provider needing a token. For the legacy whole-session `provider` this + is the implicit provider name; for named providers it is `NamedProviderConfig.name`. + """ + session_id: str + """Target session identifier""" + + @staticmethod + def from_dict(obj: Any) -> 'ProviderTokenAcquireRequest': + assert isinstance(obj, dict) + provider_name = from_str(obj.get("providerName")) + session_id = from_str(obj.get("sessionId")) + return ProviderTokenAcquireRequest(provider_name, session_id) + + def to_dict(self) -> dict: + result: dict = {} + result["providerName"] = from_str(self.provider_name) + result["sessionId"] = from_str(self.session_id) + return result + # Experimental: this type is part of an experimental API and may change or be removed. @dataclass class OptionsUpdateAdditionalContentExclusionPolicy: @@ -21191,6 +21259,8 @@ class RPC: provider_get_endpoint_request: ProviderGetEndpointRequest provider_model_config: ProviderModelConfig provider_session_token: ProviderSessionToken + provider_token_acquire_request: ProviderTokenAcquireRequest + provider_token_acquire_result: ProviderTokenAcquireResult push_attachment: PushAttachment push_attachment_blob: PushAttachmentBlob push_attachment_directory: PushAttachmentDirectory @@ -21957,6 +22027,8 @@ def from_dict(obj: Any) -> 'RPC': provider_get_endpoint_request = ProviderGetEndpointRequest.from_dict(obj.get("ProviderGetEndpointRequest")) provider_model_config = ProviderModelConfig.from_dict(obj.get("ProviderModelConfig")) provider_session_token = ProviderSessionToken.from_dict(obj.get("ProviderSessionToken")) + provider_token_acquire_request = ProviderTokenAcquireRequest.from_dict(obj.get("ProviderTokenAcquireRequest")) + provider_token_acquire_result = ProviderTokenAcquireResult.from_dict(obj.get("ProviderTokenAcquireResult")) push_attachment = _load_PushAttachment(obj.get("PushAttachment")) push_attachment_blob = PushAttachmentBlob.from_dict(obj.get("PushAttachmentBlob")) push_attachment_directory = PushAttachmentDirectory.from_dict(obj.get("PushAttachmentDirectory")) @@ -22275,7 +22347,7 @@ def from_dict(obj: Any) -> 'RPC': subagent_settings = from_union([SubagentSettings.from_dict, from_none], obj.get("SubagentSettings")) task_progress = from_union([TaskProgress.from_dict, from_none], obj.get("TaskProgress")) workspace_summary = from_union([WorkspaceSummary.from_dict, from_none], obj.get("WorkspaceSummary")) - return RPC(abort_request, abort_result, account_all_users, account_get_all_users_result, account_get_current_auth_result, account_get_quota_request, account_get_quota_result, account_login_request, account_login_result, account_logout_request, account_logout_result, account_quota_snapshot, agent_discovery_path, agent_discovery_path_list, agent_discovery_path_scope, agent_get_current_result, agent_info, agent_info_source, agent_list, agent_registry_live_target_entry, agent_registry_live_target_entry_attention_kind, agent_registry_live_target_entry_kind, agent_registry_live_target_entry_last_terminal_event, agent_registry_live_target_entry_status, agent_registry_log_capture, agent_registry_log_capture_open_error_reason, agent_registry_spawn_error, agent_registry_spawn_permission_mode, agent_registry_spawn_registry_timeout, agent_registry_spawn_request, agent_registry_spawn_result, agent_registry_spawn_spawned, agent_registry_spawn_validation_error, agent_registry_spawn_validation_error_field, agent_registry_spawn_validation_error_reason, agent_reload_result, agents_discover_request, agent_select_request, agent_select_result, agents_get_discovery_paths_request, allow_all_permission_set_result, allow_all_permission_state, api_key_auth_info, auth_info, auth_info_type, cancel_user_requested_shell_command_result, canvas_action, canvas_action_invoke_request, canvas_action_invoke_result, canvas_close_request, canvas_host_context, canvas_host_context_capabilities, canvas_instance_availability, canvas_json_schema, canvas_list, canvas_list_open_result, canvas_open_request, canvas_provider_close_request, canvas_provider_invoke_action_request, canvas_provider_open_request, canvas_provider_open_result, canvas_session_context, command_list, commands_handle_pending_command_request, commands_handle_pending_command_result, commands_invoke_request, commands_list_request, commands_respond_to_queued_command_request, commands_respond_to_queued_command_result, configure_session_extensions_params, connected_remote_session_metadata, connected_remote_session_metadata_kind, connected_remote_session_metadata_repository, connect_remote_session_params, connect_request, connect_result, content_filter_mode, copilot_api_token_auth_info, copilot_user_response, copilot_user_response_endpoints, copilot_user_response_quota_snapshots, copilot_user_response_quota_snapshots_chat, copilot_user_response_quota_snapshots_completions, copilot_user_response_quota_snapshots_premium_interactions, current_model, current_tool_metadata, discovered_canvas, discovered_mcp_server, discovered_mcp_server_type, enqueue_command_params, enqueue_command_result, env_auth_info, event_log_read_request, event_log_release_interest_result, event_log_tail_result, event_log_types, events_agent_scope, events_cursor_status, events_read_result, execute_command_params, execute_command_result, extension, extension_context_push_input, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_binary_results_for_llm, external_tool_text_result_for_llm_binary_results_for_llm_type, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, fleet_start_request, fleet_start_result, folder_trust_add_params, folder_trust_check_params, folder_trust_check_result, gh_cli_auth_info, handle_pending_tool_call_request, handle_pending_tool_call_result, history_abort_manual_compaction_result, history_cancel_background_compaction_result, history_compact_context_window, history_compact_request, history_compact_result, history_summarize_for_handoff_result, history_truncate_request, history_truncate_result, hmac_auth_info, installed_plugin, installed_plugin_info, installed_plugin_source, installed_plugin_source_git_hub, installed_plugin_source_local, installed_plugin_source_url, instruction_discovery_path, instruction_discovery_path_kind, instruction_discovery_path_list, instruction_discovery_path_location, instructions_discover_request, instructions_get_discovery_paths_request, instructions_get_sources_result, instruction_source, instruction_source_location, instruction_source_type, llm_inference_headers, llm_inference_http_request_chunk_request, llm_inference_http_request_chunk_result, llm_inference_http_request_start_request, llm_inference_http_request_start_result, llm_inference_http_request_start_transport, llm_inference_http_response_chunk_error, llm_inference_http_response_chunk_request, llm_inference_http_response_chunk_result, llm_inference_http_response_start_request, llm_inference_http_response_start_result, llm_inference_set_provider_result, local_session_metadata_value, log_request, log_result, lsp_initialize_request, marketplace_add_result, marketplace_browse_result, marketplace_info, marketplace_list_result, marketplace_plugin_info, marketplace_refresh_entry, marketplace_refresh_result, marketplace_remove_result, mcp_allowed_server, mcp_apps_call_tool_request, mcp_apps_diagnose_capability, mcp_apps_diagnose_request, mcp_apps_diagnose_result, mcp_apps_diagnose_server, mcp_apps_host_context, mcp_apps_host_context_details, mcp_apps_host_context_details_available_display_mode, mcp_apps_host_context_details_display_mode, mcp_apps_host_context_details_platform, mcp_apps_host_context_details_theme, mcp_apps_list_tools_request, mcp_apps_list_tools_result, mcp_apps_read_resource_request, mcp_apps_read_resource_result, mcp_apps_resource_content, mcp_apps_set_host_context_details, mcp_apps_set_host_context_details_available_display_mode, mcp_apps_set_host_context_details_display_mode, mcp_apps_set_host_context_details_platform, mcp_apps_set_host_context_details_theme, mcp_apps_set_host_context_request, mcp_cancel_sampling_execution_params, mcp_cancel_sampling_execution_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_configure_git_hub_request, mcp_configure_git_hub_result, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_execute_sampling_params, mcp_execute_sampling_request, mcp_execute_sampling_result, mcp_filtered_server, mcp_host_state, mcp_is_server_running_request, mcp_is_server_running_result, mcp_list_tools_request, mcp_list_tools_result, mcp_oauth_handle_pending_request, mcp_oauth_handle_pending_result, mcp_oauth_login_request, mcp_oauth_login_result, mcp_oauth_pending_request_response, mcp_oauth_respond_request, mcp_oauth_respond_result, mcp_register_external_client_request, mcp_reload_with_config_request, mcp_remove_git_hub_result, mcp_restart_server_request, mcp_sampling_execution_action, mcp_sampling_execution_result, mcp_server, mcp_server_auth_config, mcp_server_auth_config_redirect_port, mcp_server_config, mcp_server_config_defer_tools, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_stdio, mcp_server_failure_info, mcp_server_list, mcp_server_needs_auth_info, mcp_set_env_value_mode_details, mcp_set_env_value_mode_params, mcp_set_env_value_mode_result, mcp_start_server_request, mcp_start_servers_result, mcp_stop_server_request, mcp_tools, mcp_unregister_external_client_request, memory_configuration, metadata_context_info_request, metadata_context_info_result, metadata_is_processing_result, metadata_recompute_context_tokens_request, metadata_recompute_context_tokens_result, metadata_record_context_change_request, metadata_record_context_change_result, metadata_set_working_directory_request, metadata_set_working_directory_result, metadata_snapshot_current_mode, metadata_snapshot_remote_metadata, metadata_snapshot_remote_metadata_repository, metadata_snapshot_remote_metadata_task_type, model, model_billing, model_billing_token_prices, model_billing_token_prices_long_context, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_list_request, model_picker_category, model_picker_price_category, model_policy, model_policy_state, model_set_reasoning_effort_request, model_set_reasoning_effort_result, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, named_provider_config, name_get_result, name_set_auto_request, name_set_auto_result, name_set_request, open_canvas_instance, options_update_additional_content_exclusion_policy, options_update_additional_content_exclusion_policy_rule, options_update_additional_content_exclusion_policy_rule_source, options_update_additional_content_exclusion_policy_scope, options_update_context_tier, options_update_env_value_mode, options_update_reasoning_summary, options_update_tool_filter_precedence, pending_permission_request, pending_permission_request_list, permission_decision, permission_decision_approved, permission_decision_approved_for_location, permission_decision_approved_for_session, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_extension_management, permission_decision_approve_for_location_approval_extension_permission_access, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_extension_management, permission_decision_approve_for_session_approval_extension_permission_access, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_cancelled, permission_decision_denied_by_content_exclusion_policy, permission_decision_denied_by_permission_request_hook, permission_decision_denied_by_rules, permission_decision_denied_interactively_by_user, permission_decision_denied_no_approval_rule_and_could_not_request_from_user, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_location_add_tool_approval_params, permission_location_apply_params, permission_location_apply_result, permission_location_resolve_params, permission_location_resolve_result, permission_location_type, permission_paths_add_params, permission_paths_allowed_check_params, permission_paths_allowed_check_result, permission_paths_config, permission_paths_list, permission_paths_update_primary_params, permission_paths_workspace_check_params, permission_paths_workspace_check_result, permission_prompt_shown_notification, permission_request_result, permission_rules_set, permissions_configure_additional_content_exclusion_policy, permissions_configure_additional_content_exclusion_policy_rule, permissions_configure_additional_content_exclusion_policy_rule_source, permissions_configure_additional_content_exclusion_policy_scope, permissions_configure_params, permissions_configure_result, permissions_folder_trust_add_trusted_result, permissions_get_allow_all_request, permissions_locations_add_tool_approval_details, permissions_locations_add_tool_approval_details_commands, permissions_locations_add_tool_approval_details_custom_tool, permissions_locations_add_tool_approval_details_extension_management, permissions_locations_add_tool_approval_details_extension_permission_access, permissions_locations_add_tool_approval_details_mcp, permissions_locations_add_tool_approval_details_mcp_sampling, permissions_locations_add_tool_approval_details_memory, permissions_locations_add_tool_approval_details_read, permissions_locations_add_tool_approval_details_write, permissions_locations_add_tool_approval_result, permissions_modify_rules_params, permissions_modify_rules_result, permissions_modify_rules_scope, permissions_notify_prompt_shown_result, permissions_paths_add_result, permissions_paths_list_request, permissions_paths_update_primary_result, permissions_pending_requests_request, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_allow_all_request, permissions_set_allow_all_source, permissions_set_approve_all_request, permissions_set_approve_all_result, permissions_set_approve_all_source, permissions_set_required_request, permissions_set_required_result, permissions_urls_set_unrestricted_mode_result, permission_urls_config, permission_urls_set_unrestricted_mode_params, ping_request, ping_result, plan_read_result, plan_read_sql_todos_result, plan_read_sql_todos_with_dependencies_result, plan_sql_todo_dependency, plan_sql_todos_row, plan_update_request, plugin, plugin_install_result, plugin_list, plugin_list_result, plugins_disable_request, plugins_enable_request, plugins_install_request, plugins_marketplaces_add_request, plugins_marketplaces_browse_request, plugins_marketplaces_refresh_request, plugins_marketplaces_remove_request, plugins_reload_request, plugins_uninstall_request, plugins_update_request, plugin_update_all_entry, plugin_update_all_result, plugin_update_result, poll_spawned_sessions_result, provider_add_request, provider_add_result, provider_config, provider_config_azure, provider_config_type, provider_config_wire_api, provider_endpoint, provider_endpoint_type, provider_endpoint_wire_api, provider_get_endpoint_request, provider_model_config, provider_session_token, push_attachment, push_attachment_blob, push_attachment_directory, push_attachment_file, push_attachment_file_line_range, push_attachment_git_hub_reference, push_attachment_git_hub_reference_type, push_attachment_selection, push_attachment_selection_details, push_attachment_selection_details_end, push_attachment_selection_details_start, queued_command_handled, queued_command_not_handled, queued_command_result, queue_pending_items, queue_pending_items_kind, queue_pending_items_result, queue_remove_most_recent_result, register_event_interest_params, register_event_interest_result, register_extension_tools_params, register_extension_tools_result, release_event_interest_params, remote_control_config, remote_control_config_existing_mc_session, remote_control_status, remote_control_status_active, remote_control_status_connecting, remote_control_status_error, remote_control_status_off, remote_control_status_result, remote_control_stop_result, remote_control_transfer_result, remote_enable_request, remote_enable_result, remote_notify_steerable_changed_request, remote_notify_steerable_changed_result, remote_session_connection_result, remote_session_metadata_repository, remote_session_metadata_task_type, remote_session_metadata_value, remote_session_mode, remote_session_repository, sandbox_config, sandbox_config_user_policy, sandbox_config_user_policy_experimental, sandbox_config_user_policy_experimental_seatbelt, sandbox_config_user_policy_filesystem, sandbox_config_user_policy_network, schedule_entry, schedule_list, schedule_stop_request, schedule_stop_result, secrets_add_filter_values_request, secrets_add_filter_values_result, send_agent_mode, send_attachments_to_message_params, send_mode, send_request, send_result, server_agent_list, server_instruction_source_list, server_skill, server_skill_list, session_activity, session_auth_status, session_bulk_delete_result, session_capability, session_context, session_context_host_type, session_enrich_metadata_result, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_capabilities, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_sqlite_exists_request, session_fs_sqlite_exists_result, session_fs_sqlite_query_request, session_fs_sqlite_query_result, session_fs_sqlite_query_type, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_installed_plugin, session_installed_plugin_source, session_installed_plugin_source_git_hub, session_installed_plugin_source_local, session_installed_plugin_source_url, session_list, session_list_entry, session_list_filter, session_load_deferred_repo_hooks_result, session_log_level, session_mcp_apps_call_tool_result, session_metadata_snapshot, session_mode, session_model_list, session_open_options, session_open_options_additional_content_exclusion_policy, session_open_options_additional_content_exclusion_policy_rule, session_open_options_additional_content_exclusion_policy_rule_source, session_open_options_additional_content_exclusion_policy_scope, session_open_options_env_value_mode, session_open_options_reasoning_summary, session_open_params, session_open_result, session_prune_result, sessions_bulk_delete_request, sessions_check_in_use_request, sessions_check_in_use_result, sessions_close_request, sessions_close_result, sessions_enrich_metadata_request, session_set_credentials_params, session_set_credentials_result, sessions_find_by_prefix_request, sessions_find_by_prefix_result, sessions_find_by_task_id_request, sessions_find_by_task_id_result, sessions_fork_request, sessions_fork_result, sessions_get_board_entry_count_request, sessions_get_board_entry_count_result, sessions_get_event_file_path_request, sessions_get_event_file_path_result, sessions_get_last_for_context_request, sessions_get_last_for_context_result, sessions_get_persisted_remote_steerable_request, sessions_get_persisted_remote_steerable_result, session_sizes, sessions_list_request, sessions_load_deferred_repo_hooks_request, sessions_open_attach, sessions_open_cloud, sessions_open_create, sessions_open_handoff, sessions_open_handoff_task_type, sessions_open_progress, sessions_open_progress_status, sessions_open_progress_step, sessions_open_remote, sessions_open_resume, sessions_open_resume_last, sessions_open_status, session_source, sessions_poll_spawned_sessions_event, sessions_poll_spawned_sessions_request, sessions_prune_old_request, sessions_register_extension_tools_on_session_options, sessions_release_lock_request, sessions_release_lock_result, sessions_reload_plugin_hooks_request, sessions_reload_plugin_hooks_result, sessions_save_request, sessions_save_result, sessions_set_additional_plugins_request, sessions_set_additional_plugins_result, sessions_set_remote_control_steering_request, sessions_start_remote_control_request, sessions_stop_remote_control_request, sessions_transfer_remote_control_request, session_telemetry_engagement, session_update_options_params, session_update_options_result, session_working_directory_context, session_working_directory_context_host_type, shell_cancel_user_requested_request, shell_exec_request, shell_exec_result, shell_execute_user_requested_request, shell_kill_request, shell_kill_result, shell_kill_signal, shutdown_request, skill, skill_discovery_path, skill_discovery_path_list, skill_discovery_scope, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, skills_get_discovery_paths_request, skills_get_invoked_result, skills_invoked_skill, skills_load_diagnostics, slash_command_agent_prompt_result, slash_command_completed_result, slash_command_info, slash_command_input, slash_command_input_completion, slash_command_invocation_result, slash_command_kind, slash_command_select_subcommand_option, slash_command_select_subcommand_result, slash_command_text_result, subagent_settings_entry, subagent_settings_entry_context_tier, task_agent_info, task_agent_progress, task_execution_mode, task_info, task_list, task_progress_line, tasks_cancel_request, tasks_cancel_result, tasks_get_current_promotable_result, tasks_get_progress_request, tasks_get_progress_result, task_shell_info, task_shell_info_attachment_mode, task_shell_progress, tasks_promote_current_to_background_result, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_refresh_result, tasks_remove_request, tasks_remove_result, tasks_send_message_request, tasks_send_message_result, tasks_start_agent_request, tasks_start_agent_result, task_status, tasks_wait_for_pending_result, telemetry_set_feature_overrides_request, token_auth_info, tool, tool_list, tools_get_current_metadata_result, tools_initialize_and_validate_result, tools_list_request, tools_update_subagent_settings_result, ui_auto_mode_switch_response, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_ephemeral_query_request, ui_ephemeral_query_result, ui_exit_plan_mode_action, ui_exit_plan_mode_response, ui_handle_pending_auto_mode_switch_request, ui_handle_pending_elicitation_request, ui_handle_pending_exit_plan_mode_request, ui_handle_pending_result, ui_handle_pending_sampling_request, ui_handle_pending_sampling_response, ui_handle_pending_user_input_request, ui_register_direct_auto_mode_switch_handler_result, ui_unregister_direct_auto_mode_switch_handler_request, ui_unregister_direct_auto_mode_switch_handler_result, ui_user_input_response, update_subagent_settings_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, user_auth_info, user_requested_shell_command_result, workspace_diff_file_change, workspace_diff_file_change_type, workspace_diff_mode, workspace_diff_result, workspaces_checkpoints, workspaces_create_file_request, workspaces_diff_request, workspaces_get_workspace_result, workspaces_list_checkpoints_result, workspaces_list_files_result, workspaces_read_checkpoint_request, workspaces_read_checkpoint_result, workspaces_read_file_request, workspaces_read_file_result, workspaces_save_large_paste_request, workspaces_save_large_paste_result, workspace_summary_host_type, workspaces_workspace_details_host_type, session_context_info, subagent_settings, task_progress, workspace_summary) + return RPC(abort_request, abort_result, account_all_users, account_get_all_users_result, account_get_current_auth_result, account_get_quota_request, account_get_quota_result, account_login_request, account_login_result, account_logout_request, account_logout_result, account_quota_snapshot, agent_discovery_path, agent_discovery_path_list, agent_discovery_path_scope, agent_get_current_result, agent_info, agent_info_source, agent_list, agent_registry_live_target_entry, agent_registry_live_target_entry_attention_kind, agent_registry_live_target_entry_kind, agent_registry_live_target_entry_last_terminal_event, agent_registry_live_target_entry_status, agent_registry_log_capture, agent_registry_log_capture_open_error_reason, agent_registry_spawn_error, agent_registry_spawn_permission_mode, agent_registry_spawn_registry_timeout, agent_registry_spawn_request, agent_registry_spawn_result, agent_registry_spawn_spawned, agent_registry_spawn_validation_error, agent_registry_spawn_validation_error_field, agent_registry_spawn_validation_error_reason, agent_reload_result, agents_discover_request, agent_select_request, agent_select_result, agents_get_discovery_paths_request, allow_all_permission_set_result, allow_all_permission_state, api_key_auth_info, auth_info, auth_info_type, cancel_user_requested_shell_command_result, canvas_action, canvas_action_invoke_request, canvas_action_invoke_result, canvas_close_request, canvas_host_context, canvas_host_context_capabilities, canvas_instance_availability, canvas_json_schema, canvas_list, canvas_list_open_result, canvas_open_request, canvas_provider_close_request, canvas_provider_invoke_action_request, canvas_provider_open_request, canvas_provider_open_result, canvas_session_context, command_list, commands_handle_pending_command_request, commands_handle_pending_command_result, commands_invoke_request, commands_list_request, commands_respond_to_queued_command_request, commands_respond_to_queued_command_result, configure_session_extensions_params, connected_remote_session_metadata, connected_remote_session_metadata_kind, connected_remote_session_metadata_repository, connect_remote_session_params, connect_request, connect_result, content_filter_mode, copilot_api_token_auth_info, copilot_user_response, copilot_user_response_endpoints, copilot_user_response_quota_snapshots, copilot_user_response_quota_snapshots_chat, copilot_user_response_quota_snapshots_completions, copilot_user_response_quota_snapshots_premium_interactions, current_model, current_tool_metadata, discovered_canvas, discovered_mcp_server, discovered_mcp_server_type, enqueue_command_params, enqueue_command_result, env_auth_info, event_log_read_request, event_log_release_interest_result, event_log_tail_result, event_log_types, events_agent_scope, events_cursor_status, events_read_result, execute_command_params, execute_command_result, extension, extension_context_push_input, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_binary_results_for_llm, external_tool_text_result_for_llm_binary_results_for_llm_type, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, fleet_start_request, fleet_start_result, folder_trust_add_params, folder_trust_check_params, folder_trust_check_result, gh_cli_auth_info, handle_pending_tool_call_request, handle_pending_tool_call_result, history_abort_manual_compaction_result, history_cancel_background_compaction_result, history_compact_context_window, history_compact_request, history_compact_result, history_summarize_for_handoff_result, history_truncate_request, history_truncate_result, hmac_auth_info, installed_plugin, installed_plugin_info, installed_plugin_source, installed_plugin_source_git_hub, installed_plugin_source_local, installed_plugin_source_url, instruction_discovery_path, instruction_discovery_path_kind, instruction_discovery_path_list, instruction_discovery_path_location, instructions_discover_request, instructions_get_discovery_paths_request, instructions_get_sources_result, instruction_source, instruction_source_location, instruction_source_type, llm_inference_headers, llm_inference_http_request_chunk_request, llm_inference_http_request_chunk_result, llm_inference_http_request_start_request, llm_inference_http_request_start_result, llm_inference_http_request_start_transport, llm_inference_http_response_chunk_error, llm_inference_http_response_chunk_request, llm_inference_http_response_chunk_result, llm_inference_http_response_start_request, llm_inference_http_response_start_result, llm_inference_set_provider_result, local_session_metadata_value, log_request, log_result, lsp_initialize_request, marketplace_add_result, marketplace_browse_result, marketplace_info, marketplace_list_result, marketplace_plugin_info, marketplace_refresh_entry, marketplace_refresh_result, marketplace_remove_result, mcp_allowed_server, mcp_apps_call_tool_request, mcp_apps_diagnose_capability, mcp_apps_diagnose_request, mcp_apps_diagnose_result, mcp_apps_diagnose_server, mcp_apps_host_context, mcp_apps_host_context_details, mcp_apps_host_context_details_available_display_mode, mcp_apps_host_context_details_display_mode, mcp_apps_host_context_details_platform, mcp_apps_host_context_details_theme, mcp_apps_list_tools_request, mcp_apps_list_tools_result, mcp_apps_read_resource_request, mcp_apps_read_resource_result, mcp_apps_resource_content, mcp_apps_set_host_context_details, mcp_apps_set_host_context_details_available_display_mode, mcp_apps_set_host_context_details_display_mode, mcp_apps_set_host_context_details_platform, mcp_apps_set_host_context_details_theme, mcp_apps_set_host_context_request, mcp_cancel_sampling_execution_params, mcp_cancel_sampling_execution_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_configure_git_hub_request, mcp_configure_git_hub_result, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_execute_sampling_params, mcp_execute_sampling_request, mcp_execute_sampling_result, mcp_filtered_server, mcp_host_state, mcp_is_server_running_request, mcp_is_server_running_result, mcp_list_tools_request, mcp_list_tools_result, mcp_oauth_handle_pending_request, mcp_oauth_handle_pending_result, mcp_oauth_login_request, mcp_oauth_login_result, mcp_oauth_pending_request_response, mcp_oauth_respond_request, mcp_oauth_respond_result, mcp_register_external_client_request, mcp_reload_with_config_request, mcp_remove_git_hub_result, mcp_restart_server_request, mcp_sampling_execution_action, mcp_sampling_execution_result, mcp_server, mcp_server_auth_config, mcp_server_auth_config_redirect_port, mcp_server_config, mcp_server_config_defer_tools, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_stdio, mcp_server_failure_info, mcp_server_list, mcp_server_needs_auth_info, mcp_set_env_value_mode_details, mcp_set_env_value_mode_params, mcp_set_env_value_mode_result, mcp_start_server_request, mcp_start_servers_result, mcp_stop_server_request, mcp_tools, mcp_unregister_external_client_request, memory_configuration, metadata_context_info_request, metadata_context_info_result, metadata_is_processing_result, metadata_recompute_context_tokens_request, metadata_recompute_context_tokens_result, metadata_record_context_change_request, metadata_record_context_change_result, metadata_set_working_directory_request, metadata_set_working_directory_result, metadata_snapshot_current_mode, metadata_snapshot_remote_metadata, metadata_snapshot_remote_metadata_repository, metadata_snapshot_remote_metadata_task_type, model, model_billing, model_billing_token_prices, model_billing_token_prices_long_context, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_list_request, model_picker_category, model_picker_price_category, model_policy, model_policy_state, model_set_reasoning_effort_request, model_set_reasoning_effort_result, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, named_provider_config, name_get_result, name_set_auto_request, name_set_auto_result, name_set_request, open_canvas_instance, options_update_additional_content_exclusion_policy, options_update_additional_content_exclusion_policy_rule, options_update_additional_content_exclusion_policy_rule_source, options_update_additional_content_exclusion_policy_scope, options_update_context_tier, options_update_env_value_mode, options_update_reasoning_summary, options_update_tool_filter_precedence, pending_permission_request, pending_permission_request_list, permission_decision, permission_decision_approved, permission_decision_approved_for_location, permission_decision_approved_for_session, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_extension_management, permission_decision_approve_for_location_approval_extension_permission_access, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_extension_management, permission_decision_approve_for_session_approval_extension_permission_access, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_cancelled, permission_decision_denied_by_content_exclusion_policy, permission_decision_denied_by_permission_request_hook, permission_decision_denied_by_rules, permission_decision_denied_interactively_by_user, permission_decision_denied_no_approval_rule_and_could_not_request_from_user, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_location_add_tool_approval_params, permission_location_apply_params, permission_location_apply_result, permission_location_resolve_params, permission_location_resolve_result, permission_location_type, permission_paths_add_params, permission_paths_allowed_check_params, permission_paths_allowed_check_result, permission_paths_config, permission_paths_list, permission_paths_update_primary_params, permission_paths_workspace_check_params, permission_paths_workspace_check_result, permission_prompt_shown_notification, permission_request_result, permission_rules_set, permissions_configure_additional_content_exclusion_policy, permissions_configure_additional_content_exclusion_policy_rule, permissions_configure_additional_content_exclusion_policy_rule_source, permissions_configure_additional_content_exclusion_policy_scope, permissions_configure_params, permissions_configure_result, permissions_folder_trust_add_trusted_result, permissions_get_allow_all_request, permissions_locations_add_tool_approval_details, permissions_locations_add_tool_approval_details_commands, permissions_locations_add_tool_approval_details_custom_tool, permissions_locations_add_tool_approval_details_extension_management, permissions_locations_add_tool_approval_details_extension_permission_access, permissions_locations_add_tool_approval_details_mcp, permissions_locations_add_tool_approval_details_mcp_sampling, permissions_locations_add_tool_approval_details_memory, permissions_locations_add_tool_approval_details_read, permissions_locations_add_tool_approval_details_write, permissions_locations_add_tool_approval_result, permissions_modify_rules_params, permissions_modify_rules_result, permissions_modify_rules_scope, permissions_notify_prompt_shown_result, permissions_paths_add_result, permissions_paths_list_request, permissions_paths_update_primary_result, permissions_pending_requests_request, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_allow_all_request, permissions_set_allow_all_source, permissions_set_approve_all_request, permissions_set_approve_all_result, permissions_set_approve_all_source, permissions_set_required_request, permissions_set_required_result, permissions_urls_set_unrestricted_mode_result, permission_urls_config, permission_urls_set_unrestricted_mode_params, ping_request, ping_result, plan_read_result, plan_read_sql_todos_result, plan_read_sql_todos_with_dependencies_result, plan_sql_todo_dependency, plan_sql_todos_row, plan_update_request, plugin, plugin_install_result, plugin_list, plugin_list_result, plugins_disable_request, plugins_enable_request, plugins_install_request, plugins_marketplaces_add_request, plugins_marketplaces_browse_request, plugins_marketplaces_refresh_request, plugins_marketplaces_remove_request, plugins_reload_request, plugins_uninstall_request, plugins_update_request, plugin_update_all_entry, plugin_update_all_result, plugin_update_result, poll_spawned_sessions_result, provider_add_request, provider_add_result, provider_config, provider_config_azure, provider_config_type, provider_config_wire_api, provider_endpoint, provider_endpoint_type, provider_endpoint_wire_api, provider_get_endpoint_request, provider_model_config, provider_session_token, provider_token_acquire_request, provider_token_acquire_result, push_attachment, push_attachment_blob, push_attachment_directory, push_attachment_file, push_attachment_file_line_range, push_attachment_git_hub_reference, push_attachment_git_hub_reference_type, push_attachment_selection, push_attachment_selection_details, push_attachment_selection_details_end, push_attachment_selection_details_start, queued_command_handled, queued_command_not_handled, queued_command_result, queue_pending_items, queue_pending_items_kind, queue_pending_items_result, queue_remove_most_recent_result, register_event_interest_params, register_event_interest_result, register_extension_tools_params, register_extension_tools_result, release_event_interest_params, remote_control_config, remote_control_config_existing_mc_session, remote_control_status, remote_control_status_active, remote_control_status_connecting, remote_control_status_error, remote_control_status_off, remote_control_status_result, remote_control_stop_result, remote_control_transfer_result, remote_enable_request, remote_enable_result, remote_notify_steerable_changed_request, remote_notify_steerable_changed_result, remote_session_connection_result, remote_session_metadata_repository, remote_session_metadata_task_type, remote_session_metadata_value, remote_session_mode, remote_session_repository, sandbox_config, sandbox_config_user_policy, sandbox_config_user_policy_experimental, sandbox_config_user_policy_experimental_seatbelt, sandbox_config_user_policy_filesystem, sandbox_config_user_policy_network, schedule_entry, schedule_list, schedule_stop_request, schedule_stop_result, secrets_add_filter_values_request, secrets_add_filter_values_result, send_agent_mode, send_attachments_to_message_params, send_mode, send_request, send_result, server_agent_list, server_instruction_source_list, server_skill, server_skill_list, session_activity, session_auth_status, session_bulk_delete_result, session_capability, session_context, session_context_host_type, session_enrich_metadata_result, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_capabilities, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_sqlite_exists_request, session_fs_sqlite_exists_result, session_fs_sqlite_query_request, session_fs_sqlite_query_result, session_fs_sqlite_query_type, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_installed_plugin, session_installed_plugin_source, session_installed_plugin_source_git_hub, session_installed_plugin_source_local, session_installed_plugin_source_url, session_list, session_list_entry, session_list_filter, session_load_deferred_repo_hooks_result, session_log_level, session_mcp_apps_call_tool_result, session_metadata_snapshot, session_mode, session_model_list, session_open_options, session_open_options_additional_content_exclusion_policy, session_open_options_additional_content_exclusion_policy_rule, session_open_options_additional_content_exclusion_policy_rule_source, session_open_options_additional_content_exclusion_policy_scope, session_open_options_env_value_mode, session_open_options_reasoning_summary, session_open_params, session_open_result, session_prune_result, sessions_bulk_delete_request, sessions_check_in_use_request, sessions_check_in_use_result, sessions_close_request, sessions_close_result, sessions_enrich_metadata_request, session_set_credentials_params, session_set_credentials_result, sessions_find_by_prefix_request, sessions_find_by_prefix_result, sessions_find_by_task_id_request, sessions_find_by_task_id_result, sessions_fork_request, sessions_fork_result, sessions_get_board_entry_count_request, sessions_get_board_entry_count_result, sessions_get_event_file_path_request, sessions_get_event_file_path_result, sessions_get_last_for_context_request, sessions_get_last_for_context_result, sessions_get_persisted_remote_steerable_request, sessions_get_persisted_remote_steerable_result, session_sizes, sessions_list_request, sessions_load_deferred_repo_hooks_request, sessions_open_attach, sessions_open_cloud, sessions_open_create, sessions_open_handoff, sessions_open_handoff_task_type, sessions_open_progress, sessions_open_progress_status, sessions_open_progress_step, sessions_open_remote, sessions_open_resume, sessions_open_resume_last, sessions_open_status, session_source, sessions_poll_spawned_sessions_event, sessions_poll_spawned_sessions_request, sessions_prune_old_request, sessions_register_extension_tools_on_session_options, sessions_release_lock_request, sessions_release_lock_result, sessions_reload_plugin_hooks_request, sessions_reload_plugin_hooks_result, sessions_save_request, sessions_save_result, sessions_set_additional_plugins_request, sessions_set_additional_plugins_result, sessions_set_remote_control_steering_request, sessions_start_remote_control_request, sessions_stop_remote_control_request, sessions_transfer_remote_control_request, session_telemetry_engagement, session_update_options_params, session_update_options_result, session_working_directory_context, session_working_directory_context_host_type, shell_cancel_user_requested_request, shell_exec_request, shell_exec_result, shell_execute_user_requested_request, shell_kill_request, shell_kill_result, shell_kill_signal, shutdown_request, skill, skill_discovery_path, skill_discovery_path_list, skill_discovery_scope, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, skills_get_discovery_paths_request, skills_get_invoked_result, skills_invoked_skill, skills_load_diagnostics, slash_command_agent_prompt_result, slash_command_completed_result, slash_command_info, slash_command_input, slash_command_input_completion, slash_command_invocation_result, slash_command_kind, slash_command_select_subcommand_option, slash_command_select_subcommand_result, slash_command_text_result, subagent_settings_entry, subagent_settings_entry_context_tier, task_agent_info, task_agent_progress, task_execution_mode, task_info, task_list, task_progress_line, tasks_cancel_request, tasks_cancel_result, tasks_get_current_promotable_result, tasks_get_progress_request, tasks_get_progress_result, task_shell_info, task_shell_info_attachment_mode, task_shell_progress, tasks_promote_current_to_background_result, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_refresh_result, tasks_remove_request, tasks_remove_result, tasks_send_message_request, tasks_send_message_result, tasks_start_agent_request, tasks_start_agent_result, task_status, tasks_wait_for_pending_result, telemetry_set_feature_overrides_request, token_auth_info, tool, tool_list, tools_get_current_metadata_result, tools_initialize_and_validate_result, tools_list_request, tools_update_subagent_settings_result, ui_auto_mode_switch_response, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_ephemeral_query_request, ui_ephemeral_query_result, ui_exit_plan_mode_action, ui_exit_plan_mode_response, ui_handle_pending_auto_mode_switch_request, ui_handle_pending_elicitation_request, ui_handle_pending_exit_plan_mode_request, ui_handle_pending_result, ui_handle_pending_sampling_request, ui_handle_pending_sampling_response, ui_handle_pending_user_input_request, ui_register_direct_auto_mode_switch_handler_result, ui_unregister_direct_auto_mode_switch_handler_request, ui_unregister_direct_auto_mode_switch_handler_result, ui_user_input_response, update_subagent_settings_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, user_auth_info, user_requested_shell_command_result, workspace_diff_file_change, workspace_diff_file_change_type, workspace_diff_mode, workspace_diff_result, workspaces_checkpoints, workspaces_create_file_request, workspaces_diff_request, workspaces_get_workspace_result, workspaces_list_checkpoints_result, workspaces_list_files_result, workspaces_read_checkpoint_request, workspaces_read_checkpoint_result, workspaces_read_file_request, workspaces_read_file_result, workspaces_save_large_paste_request, workspaces_save_large_paste_result, workspace_summary_host_type, workspaces_workspace_details_host_type, session_context_info, subagent_settings, task_progress, workspace_summary) def to_dict(self) -> dict: result: dict = {} @@ -22723,6 +22795,8 @@ def to_dict(self) -> dict: result["ProviderGetEndpointRequest"] = to_class(ProviderGetEndpointRequest, self.provider_get_endpoint_request) result["ProviderModelConfig"] = to_class(ProviderModelConfig, self.provider_model_config) result["ProviderSessionToken"] = to_class(ProviderSessionToken, self.provider_session_token) + result["ProviderTokenAcquireRequest"] = to_class(ProviderTokenAcquireRequest, self.provider_token_acquire_request) + result["ProviderTokenAcquireResult"] = to_class(ProviderTokenAcquireResult, self.provider_token_acquire_result) result["PushAttachment"] = (self.push_attachment).to_dict() result["PushAttachmentBlob"] = to_class(PushAttachmentBlob, self.push_attachment_blob) result["PushAttachmentDirectory"] = to_class(PushAttachmentDirectory, self.push_attachment_directory) @@ -25096,10 +25170,17 @@ async def invoke(self, params: CanvasProviderInvokeActionRequest) -> Any: "Invokes an action on an open canvas instance via the provider.\n\nArgs:\n params: Canvas action invocation parameters sent to the provider.\n\nReturns:\n Provider-supplied action result." pass +# Experimental: this API group is experimental and may change or be removed. +class ProviderTokenHandler(Protocol): + async def get_token(self, params: ProviderTokenAcquireRequest) -> ProviderTokenAcquireResult: + "Asks the SDK client to get a bearer token for a BYOK provider whose config set `hasBearerTokenProvider: true`. Session-scoped: the runtime calls it back on the connection that created the session, passing the provider name, and uses the returned token as the Authorization header for the outbound model request. The runtime does no caching — it calls this once per outbound request; the SDK consumer owns token acquisition, caching, and refresh.\n\nArgs:\n params: Asks the SDK client to acquire a bearer token for a BYOK provider whose config set `hasBearerTokenProvider: true`. Issued by the runtime before each outbound model request; the runtime does no caching, so this is sent once per request.\n\nReturns:\n A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as `Authorization: Bearer ` on the outbound request and does no caching; the SDK consumer owns token caching and refresh." + pass + @dataclass class ClientSessionApiHandlers: session_fs: SessionFsHandler | None = None canvas: CanvasHandler | None = None + provider_token: ProviderTokenHandler | None = None def register_client_session_api_handlers( client: "JsonRpcClient", @@ -25211,6 +25292,13 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: result = await handler.invoke(request) return result.value if hasattr(result, 'value') else result client.set_request_handler("canvas.action.invoke", handle_canvas_action_invoke) + async def handle_provider_token_get_token(params: dict) -> dict | None: + request = ProviderTokenAcquireRequest.from_dict(params) + handler = get_handlers(request.session_id).provider_token + if handler is None: raise RuntimeError(f"No provider_token handler registered for session: {request.session_id}") + result = await handler.get_token(request) + return result.to_dict() + client.set_request_handler("providerToken.getToken", handle_provider_token_get_token) # Experimental: this API group is experimental and may change or be removed. class LlmInferenceHandler(Protocol): @@ -25776,6 +25864,9 @@ async def handle_llm_inference_http_request_chunk(params: dict) -> dict | None: "ProviderGetEndpointRequest", "ProviderModelConfig", "ProviderSessionToken", + "ProviderTokenAcquireRequest", + "ProviderTokenAcquireResult", + "ProviderTokenHandler", "ProviderType", "ProviderWireAPI", "PurpleSource", diff --git a/python/copilot/session.py b/python/copilot/session.py index f15d6c7d3..c1bea85e7 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -44,6 +44,8 @@ PermissionDecisionApproveOnce, PermissionDecisionRequest, PermissionDecisionUserNotAvailable, + ProviderTokenAcquireRequest, + ProviderTokenAcquireResult, SessionLogLevel, SessionRpc, UIElicitationRequest, @@ -1066,6 +1068,29 @@ class AzureProviderOptions(TypedDict, total=False): api_version: str # Azure API version. Defaults to "2024-10-21". +class ProviderTokenArgs(TypedDict): + """Arguments passed to a :data:`GetBearerToken` callback when the runtime + needs a fresh bearer token for a BYOK provider. + + **Experimental.** Part of the bearer-token-provider surface and may change or + be removed in future SDK or CLI releases. + """ + + # Name of the BYOK provider needing a token. For the singular, whole-session + # ``provider`` this is the implicit provider name ("default"); for + # ``NamedProviderConfig`` entries it is ``NamedProviderConfig.name``. + provider_name: str + + +# Per-request callback that resolves a bearer token on demand for a BYOK +# provider (for example via Azure Managed Identity). The Copilot SDK takes no +# identity dependency: supply a callback backed by your own identity library. +# Never serialized — setting it makes the SDK send ``hasBearerTokenProvider`` on +# the wire and answer the runtime's ``providerToken.getToken`` requests. May be +# sync or async. +GetBearerToken = Callable[[ProviderTokenArgs], str | Awaitable[str]] + + class ProviderConfig(TypedDict, total=False): """Configuration for a custom API provider""" @@ -1102,6 +1127,12 @@ class ProviderConfig(TypedDict, total=False): # Overrides the resolved model's default max output tokens. When hit, the # model stops generating and returns a truncated response. max_output_tokens: int + # Per-request callback that resolves a bearer token on demand for this BYOK + # provider (for example via Azure Managed Identity). Never serialized — the + # SDK sends hasBearerTokenProvider: true on the wire and answers the + # runtime's providerToken.getToken requests with this callback's result. + # Mutually exclusive with api_key and bearer_token. + get_bearer_token: GetBearerToken class NamedProviderConfig(TypedDict, total=False): @@ -1128,6 +1159,11 @@ class NamedProviderConfig(TypedDict, total=False): bearer_token: str azure: AzureProviderOptions # Azure-specific options headers: dict[str, str] + # Per-request bearer-token callback for this named BYOK provider. Never + # serialized; the SDK sends hasBearerTokenProvider: true and answers the + # runtime's providerToken.getToken requests. Mutually exclusive with api_key + # and bearer_token. + get_bearer_token: GetBearerToken class ProviderModelConfig(TypedDict, total=False): @@ -1199,6 +1235,37 @@ def _canvas_handler_error(err: Exception) -> JsonRpcError: ) +class _BearerTokenProviderAdapter: + """Routes runtime ``providerToken.getToken`` requests to the matching + per-provider :data:`GetBearerToken` callback registered on the session. + + The runtime calls this once per outbound request for a BYOK provider that + declared ``hasBearerTokenProvider: true``; it does no caching, so the SDK + consumer's callback (typically backed by an identity library) owns + acquisition, caching, and refresh. + """ + + def __init__(self, session: CopilotSession) -> None: + self._session = session + + async def get_token( + self, params: ProviderTokenAcquireRequest + ) -> ProviderTokenAcquireResult: + provider_name = params.provider_name + with self._session._bearer_token_providers_lock: + callback = self._session._bearer_token_providers.get(provider_name) + if callback is None: + raise JsonRpcError( + -32603, + f"No bearer-token provider registered for provider: {provider_name!r}", + ) + args: ProviderTokenArgs = {"provider_name": provider_name} + result = callback(args) + if inspect.isawaitable(result): + result = await result + return ProviderTokenAcquireResult(token=cast(str, result)) + + class CopilotSession: """ Represents a single conversation session with the Copilot CLI. @@ -1264,6 +1331,8 @@ def __init__( self._transform_callbacks_lock = threading.Lock() self._command_handlers: dict[str, CommandHandler] = {} self._command_handlers_lock = threading.Lock() + self._bearer_token_providers: dict[str, GetBearerToken] = {} + self._bearer_token_providers_lock = threading.Lock() self._elicitation_handler: ElicitationHandler | None = None self._elicitation_handler_lock = threading.Lock() self._capabilities: SessionCapabilities = {} @@ -2009,6 +2078,28 @@ def _register_commands(self, commands: list[CommandDefinition] | None) -> None: for cmd in commands: self._command_handlers[cmd.name] = cmd.handler + def _register_bearer_token_providers( + self, providers: dict[str, GetBearerToken] | None + ) -> None: + """Register per-provider bearer-token callbacks for this session. + + The runtime never receives the callbacks themselves; the SDK strips them + from the provider config and instead sends ``hasBearerTokenProvider: + true``. When the runtime needs a token it issues a session-scoped + ``providerToken.getToken`` request, which the registered handler routes + to the matching per-provider callback. + + Args: + providers: Map of provider name -> callback, or None/empty to clear. + """ + with self._bearer_token_providers_lock: + self._bearer_token_providers.clear() + if not providers: + self._client_session_apis.provider_token = None + return + self._bearer_token_providers.update(providers) + self._client_session_apis.provider_token = _BearerTokenProviderAdapter(self) + def _register_elicitation_handler(self, handler: ElicitationHandler | None) -> None: """Register the elicitation handler for this session. diff --git a/python/e2e/test_byok_bearer_token_provider_e2e.py b/python/e2e/test_byok_bearer_token_provider_e2e.py new file mode 100644 index 000000000..44e238dd1 --- /dev/null +++ b/python/e2e/test_byok_bearer_token_provider_e2e.py @@ -0,0 +1,253 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# -------------------------------------------------------------------------------------------- + +"""E2E coverage for the experimental BYOK bearer-token-provider surface. + +Mirrors ``nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts``. A BYOK +provider config may carry a ``get_bearer_token`` callback; the callback stays +entirely on the SDK/client side. The SDK strips it from the wire config, sets +the ``hasBearerTokenProvider`` flag, and the runtime calls back over the +session-scoped ``providerToken.getToken`` RPC before each outbound model +request, applying the returned token as the ``Authorization`` header. + +Like the other ``copilot_request_*`` tests, this one installs a client-global +``CopilotRequestHandler`` instead of using the CAPI proxy: the handler +fabricates the bootstrap (catalog/policy) responses and intercepts the +runtime's outbound BYOK request in-process, capturing the ``Authorization`` +header and returning a synthetic ``404``. It validates, against a real runtime: + 1. the callback's token reaches the model request as ``Authorization: Bearer ``; + 2. the runtime re-acquires a token per request (no runtime-side caching); + 3. per-provider dispatch routes each provider's turn to its own callback, and + the resulting token reaches that provider's endpoint. +""" + +from __future__ import annotations + +import re + +import httpx +import pytest +import pytest_asyncio + +from copilot import CopilotRequestContext, CopilotRequestHandler +from copilot.session import GetBearerToken, PermissionHandler + +from ._copilot_request_helpers import build_isolated_client, build_non_inference_response +from .testharness import E2ETestContext + +pytestmark = pytest.mark.asyncio(loop_scope="module") + +# Fake BYOK provider base URLs. These hosts are never actually dialed: the +# client-global request interceptor fully answers any request aimed at a +# ``.invalid`` host, so they only need to be syntactically valid, non-resolving +# URLs. Distinct hosts let the per-provider test assert routing by host. +PRIMARY_HOST = "byok-endpoint.invalid" +PRIMARY_BASE_URL = f"https://{PRIMARY_HOST}/v1" +RED_HOST = "byok-red.invalid" +RED_BASE_URL = f"https://{RED_HOST}/v1" +BLUE_HOST = "byok-blue.invalid" +BLUE_BASE_URL = f"https://{BLUE_HOST}/v1" + + +class _CapturingRequestHandler(CopilotRequestHandler): + """Client-global HTTP interceptor used in place of a real BYOK listener. + + The runtime invokes :meth:`send_request` for every model-layer HTTP request. + Requests aimed at a fake BYOK host are captured — recording the + ``Authorization`` header the runtime applied after calling the provider's + ``get_bearer_token`` callback over ``providerToken.getToken`` — and answered + with a synthetic ``404`` (non-retryable, so each outbound model request + yields exactly one capture). Every other request (CAPI bootstrap: model + catalog, policy, …) is fabricated locally so no real network or CAPI proxy + is involved. + """ + + def __init__(self) -> None: + # (host, authorization) for each captured BYOK request, in arrival order. + self.captures: list[tuple[str, str | None]] = [] + + async def send_request( + self, request: httpx.Request, ctx: CopilotRequestContext + ) -> httpx.Response: + url = httpx.URL(request.url) + host = url.host + if host.endswith(".invalid"): + self.captures.append((host, request.headers.get("authorization"))) + return httpx.Response( + 404, + headers={"content-type": "application/json"}, + json={"error": {"message": "fake byok endpoint"}}, + request=request, + ) + return build_non_inference_response(str(request.url)) + + def reset(self) -> None: + self.captures.clear() + + def auth_headers(self) -> list[str]: + """The ``Authorization`` headers captured across BYOK requests, in order.""" + return [auth for (_host, auth) in self.captures if auth is not None] + + def auth_header_for_host(self, host: str) -> str | None: + """The ``Authorization`` header captured for requests aimed at ``host``.""" + for captured_host, auth in self.captures: + if captured_host == host: + return auth + return None + + +@pytest_asyncio.fixture(loop_scope="module") +async def bearer_fixture(ctx: E2ETestContext): + handler = _CapturingRequestHandler() + client = build_isolated_client(ctx, handler) + await client.start() + try: + yield client, handler + finally: + try: + await client.stop() + except Exception: + # Best-effort teardown during fixture cleanup. + pass + + +async def _run_turn(client, providers, models, selection_id: str, prompt: str) -> None: + """Drive one BYOK turn; the synthetic 404 errors the turn, which is expected.""" + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + model=selection_id, + providers=providers, + models=models, + ) + try: + # The interceptor always 404s, so the turn errors after the runtime has + # already sent the (token-bearing) request — which is all we assert on. + try: + await session.send_and_wait(prompt) + except Exception: + pass + finally: + try: + await session.disconnect() + except Exception: + # ignore disconnect errors for the fake BYOK endpoint + pass + + +class TestByokBearerTokenProvider: + async def test_applies_the_callbacks_token_as_the_authorization_header( + self, bearer_fixture + ): + client, handler = bearer_fixture + handler.reset() + + sentinel = "sentinel-bearer-token-abc123" + calls = 0 + + async def get_bearer_token(args) -> str: + nonlocal calls + calls += 1 + return sentinel + + providers = [ + { + "name": "mi", + "type": "openai", + "wire_api": "completions", + "base_url": PRIMARY_BASE_URL, + "get_bearer_token": get_bearer_token, + } + ] + models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}] + + await _run_turn(client, providers, models, "mi/default", "What is 5+5?") + + # The runtime acquired a token via the callback and applied it verbatim + # as the bearer credential on the outbound model request. + assert f"Bearer {sentinel}" in handler.auth_headers() + assert calls >= 1 + + async def test_reacquires_a_fresh_token_for_each_request(self, bearer_fixture): + client, handler = bearer_fixture + handler.reset() + + calls = 0 + + async def get_bearer_token(args) -> str: + nonlocal calls + calls += 1 + # A distinct token per acquisition proves the runtime re-invokes the + # callback per request rather than caching a previous token. + return f"rotating-token-{calls}" + + providers = [ + { + "name": "mi", + "type": "openai", + "wire_api": "completions", + "base_url": PRIMARY_BASE_URL, + "get_bearer_token": get_bearer_token, + } + ] + models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}] + + await _run_turn(client, providers, models, "mi/default", "What is 1+1?") + await _run_turn(client, providers, models, "mi/default", "What is 2+2?") + + # Each outbound request carries a freshly-acquired, distinct token. + auths = handler.auth_headers() + assert len(auths) >= 2 + assert re.match(r"^Bearer rotating-token-\d+$", auths[0]) + assert re.match(r"^Bearer rotating-token-\d+$", auths[1]) + assert auths[0] != auths[1] + assert calls >= 2 + + async def test_dispatches_token_acquisition_per_provider(self, bearer_fixture): + client, handler = bearer_fixture + handler.reset() + + token_by_provider = {"red": "token-for-red", "blue": "token-for-blue"} + acquired_for: list[str] = [] + + def make_callback(provider_name: str) -> GetBearerToken: + async def callback(args) -> str: + # The runtime forwards the requesting provider's name so the + # client can dispatch to the right credential. + assert args["provider_name"] == provider_name + acquired_for.append(provider_name) + return token_by_provider[provider_name] + + return callback + + providers = [ + { + "name": "red", + "type": "openai", + "wire_api": "completions", + "base_url": RED_BASE_URL, + "get_bearer_token": make_callback("red"), + }, + { + "name": "blue", + "type": "openai", + "wire_api": "completions", + "base_url": BLUE_BASE_URL, + "get_bearer_token": make_callback("blue"), + }, + ] + models = [ + {"id": "default", "provider": "red", "wire_model": "byok-gpt-4o"}, + {"id": "default", "provider": "blue", "wire_model": "byok-gpt-4o"}, + ] + + await _run_turn(client, providers, models, "red/default", "What is 3+3?") + await _run_turn(client, providers, models, "blue/default", "What is 4+4?") + + # Each provider's turn was authenticated with its own token AND that + # token was delivered to that provider's endpoint, proving per-provider + # dispatch (not a single session-global credential). + assert handler.auth_header_for_host(RED_HOST) == f"Bearer {token_by_provider['red']}" + assert handler.auth_header_for_host(BLUE_HOST) == f"Bearer {token_by_provider['blue']}" + assert "red" in acquired_for + assert "blue" in acquired_for diff --git a/rust/src/generated/api_types.rs b/rust/src/generated/api_types.rs index 522222b32..05e73dcd4 100644 --- a/rust/src/generated/api_types.rs +++ b/rust/src/generated/api_types.rs @@ -545,6 +545,8 @@ pub mod rpc_methods { pub const CANVAS_CLOSE: &str = "canvas.close"; /// `canvas.action.invoke` pub const CANVAS_ACTION_INVOKE: &str = "canvas.action.invoke"; + /// `providerToken.getToken` + pub const PROVIDERTOKEN_GETTOKEN: &str = "providerToken.getToken"; } /// Parameters for aborting the current turn @@ -5551,6 +5553,9 @@ pub struct NamedProviderConfig { /// Bearer token for authentication. Sets the Authorization header directly. Takes precedence over apiKey when both are set. #[serde(skip_serializing_if = "Option::is_none")] pub bearer_token: Option, + /// When true, the SDK client supplies bearer tokens on demand: the runtime calls the client-session `providerToken.getToken` callback before each request and uses the returned token as the Authorization header. The token-acquiring function itself stays on the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive with `apiKey`/`bearerToken`. + #[serde(skip_serializing_if = "Option::is_none")] + pub has_bearer_token_provider: Option, /// Custom HTTP headers to include in all outbound requests to the provider. #[serde(skip_serializing_if = "Option::is_none")] pub headers: Option>, @@ -7682,6 +7687,9 @@ pub struct ProviderConfig { /// Bearer token for authentication. Sets the Authorization header directly. Takes precedence over apiKey when both are set. #[serde(skip_serializing_if = "Option::is_none")] pub bearer_token: Option, + /// When true, the SDK client supplies bearer tokens on demand: the runtime calls the client-session `providerToken.getToken` callback before each request and uses the returned token as the Authorization header. The token-acquiring function itself stays on the SDK side and is never serialized; only this flag crosses the wire. Mutually exclusive with `apiKey`/`bearerToken`. + #[serde(skip_serializing_if = "Option::is_none")] + pub has_bearer_token_provider: Option, /// Custom HTTP headers to include in all outbound requests to the provider. #[serde(skip_serializing_if = "Option::is_none")] pub headers: Option>, @@ -13100,6 +13108,38 @@ pub struct WorkspaceSummary { pub user_named: Option, } +/// Asks the SDK client to acquire a bearer token for a BYOK provider whose config set `hasBearerTokenProvider: true`. Issued by the runtime before each outbound model request; the runtime does no caching, so this is sent once per request. +/// +///

+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ProviderTokenAcquireRequest { + /// Target session identifier + pub session_id: SessionId, + /// Name of the BYOK provider needing a token. For the legacy whole-session `provider` this is the implicit provider name; for named providers it is `NamedProviderConfig.name`. + pub provider_name: String, +} + +/// A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as `Authorization: Bearer ` on the outbound request and does no caching; the SDK consumer owns token caching and refresh. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ProviderTokenAcquireResult { + /// The bearer token value (without the `Bearer ` prefix). + pub token: String, +} + /// List of Copilot models available to the resolved user, including capabilities and billing metadata. #[derive(Debug, Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -16762,6 +16802,21 @@ pub struct CanvasOpenResult { pub url: Option, } +/// A bearer token supplied by the SDK client for a BYOK provider. The runtime sets it as `Authorization: Bearer ` on the outbound request and does no caching; the SDK consumer owns token caching and refresh. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ProviderTokenGetTokenResult { + /// The bearer token value (without the `Bearer ` prefix). + pub token: String, +} + /// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. /// ///
diff --git a/rust/src/lib.rs b/rust/src/lib.rs index a0986182f..bd4988e86 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -22,6 +22,9 @@ pub mod hooks; mod jsonrpc; /// Permission-policy helpers that produce a [`handler::PermissionHandler`]. pub mod permission; +/// BYOK bearer-token provider callbacks. +pub mod provider_token; +mod provider_token_dispatch; /// GitHub Copilot CLI binary resolution (env var, embedded, dev cache). pub(crate) mod resolve; mod router; @@ -72,6 +75,7 @@ pub(crate) use jsonrpc::{ JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes, }; pub use mode::{BUILTIN_TOOLS_ISOLATED, ClientMode, ToolSet}; +pub use provider_token::{BearerTokenError, BearerTokenProvider, ProviderTokenArgs}; /// Re-exported JSON-RPC internals for integration tests (requires `test-support` feature). #[cfg(feature = "test-support")] diff --git a/rust/src/provider_token.rs b/rust/src/provider_token.rs new file mode 100644 index 000000000..f92715006 --- /dev/null +++ b/rust/src/provider_token.rs @@ -0,0 +1,105 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +//! BYOK bearer-token provider callbacks. +//! +//!
+//! +//! **Experimental.** These types are part of an experimental wire-protocol +//! surface and may change or be removed in future SDK or CLI releases. +//! +//!
+ +use std::future::Future; + +use async_trait::async_trait; + +/// Arguments passed to a BYOK bearer-token provider callback. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol +/// surface and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ProviderTokenArgs { + /// Name of the BYOK provider needing a token. + /// + /// This is `"default"` for the singular whole-session provider, otherwise + /// the named provider's `name`. + pub provider_name: String, +} + +/// Error returned by a [`BearerTokenProvider`]. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol +/// surface and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BearerTokenError { + message: String, +} + +impl BearerTokenError { + /// Construct a bearer-token error with a human-readable message. + pub fn message(message: impl Into) -> Self { + Self { + message: message.into(), + } + } + + /// Return the human-readable error message. + pub fn as_str(&self) -> &str { + &self.message + } +} + +impl std::fmt::Display for BearerTokenError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.message) + } +} + +impl std::error::Error for BearerTokenError {} + +impl From for BearerTokenError { + fn from(message: String) -> Self { + Self::message(message) + } +} + +impl From<&str> for BearerTokenError { + fn from(message: &str) -> Self { + Self::message(message) + } +} + +/// Provider-side callback used to acquire bearer tokens for BYOK providers. +/// +///
+/// +/// **Experimental.** This trait is part of an experimental wire-protocol +/// surface and may change or be removed in future SDK or CLI releases. +/// +///
+#[async_trait] +pub trait BearerTokenProvider: Send + Sync { + /// Acquire a bearer token without the `Bearer ` prefix. + async fn get_token(&self, args: ProviderTokenArgs) -> Result; +} + +#[async_trait] +impl BearerTokenProvider for F +where + F: Fn(ProviderTokenArgs) -> Fut + Send + Sync, + Fut: Future> + Send, +{ + async fn get_token(&self, args: ProviderTokenArgs) -> Result { + (self)(args).await + } +} diff --git a/rust/src/provider_token_dispatch.rs b/rust/src/provider_token_dispatch.rs new file mode 100644 index 000000000..c100443cd --- /dev/null +++ b/rust/src/provider_token_dispatch.rs @@ -0,0 +1,157 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +//! Inbound `providerToken.*` JSON-RPC request dispatch helpers. + +use std::collections::HashMap; +use std::sync::Arc; + +use serde::Serialize; +use serde_json::Value; +use tracing::warn; + +use crate::generated::api_types::{ + ProviderTokenAcquireRequest, ProviderTokenAcquireResult, rpc_methods, +}; +use crate::provider_token::{BearerTokenError, BearerTokenProvider, ProviderTokenArgs}; +use crate::{Client, JsonRpcRequest, JsonRpcResponse, error_codes}; + +async fn respond(client: &Client, request_id: u64, result: T) { + let value = match serde_json::to_value(&result) { + Ok(value) => value, + Err(error) => { + warn!(error = %error, "failed to serialize provider token response"); + send_error( + client, + request_id, + error_codes::INTERNAL_ERROR, + "serialization failure", + ) + .await; + return; + } + }; + + let _ = client + .send_response(&JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request_id, + result: Some(value), + error: None, + }) + .await; +} + +async fn send_error(client: &Client, request_id: u64, code: i32, message: &str) { + let _ = client + .send_response(&JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request_id, + result: None, + error: Some(crate::JsonRpcError { + code, + message: message.to_string(), + data: None, + }), + }) + .await; +} + +async fn parse_params( + client: &Client, + request: &JsonRpcRequest, +) -> Option { + let params = request + .params + .as_ref() + .cloned() + .unwrap_or(Value::Object(serde_json::Map::new())); + match serde_json::from_value(params) { + Ok(params) => Some(params), + Err(error) => { + send_error( + client, + request.id, + error_codes::INVALID_PARAMS, + &format!("invalid params: {error}"), + ) + .await; + None + } + } +} + +fn token_provider_or_err( + providers: &HashMap>, + provider_name: &str, +) -> Result, BearerTokenError> { + providers.get(provider_name).cloned().ok_or_else(|| { + BearerTokenError::message(format!( + "No bearer-token provider installed for BYOK provider {provider_name:?}" + )) + }) +} + +async fn get_token( + client: &Client, + providers: &HashMap>, + request: JsonRpcRequest, +) { + let Some(params) = parse_params::(client, &request).await else { + return; + }; + + let token_provider = match token_provider_or_err(providers, ¶ms.provider_name) { + Ok(provider) => provider, + Err(error) => { + send_error( + client, + request.id, + error_codes::INTERNAL_ERROR, + &error.to_string(), + ) + .await; + return; + } + }; + + match token_provider + .get_token(ProviderTokenArgs { + provider_name: params.provider_name, + }) + .await + { + Ok(token) => respond(client, request.id, ProviderTokenAcquireResult { token }).await, + Err(error) => { + send_error( + client, + request.id, + error_codes::INTERNAL_ERROR, + &format!("Bearer-token provider failed: {error}"), + ) + .await; + } + } +} + +pub(crate) async fn dispatch( + client: &Client, + providers: &HashMap>, + request: JsonRpcRequest, +) { + let method = request.method.as_str(); + match method { + rpc_methods::PROVIDERTOKEN_GETTOKEN => get_token(client, providers, request).await, + _ => { + warn!(method = %method, "unknown providerToken.* method"); + send_error( + client, + request.id, + error_codes::METHOD_NOT_FOUND, + &format!("unknown method: {method}"), + ) + .await; + } + } +} diff --git a/rust/src/session.rs b/rust/src/session.rs index fed6705da..18b91b437 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -21,6 +21,7 @@ use crate::handler::{ PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, }; use crate::hooks::SessionHooks; +use crate::provider_token::BearerTokenProvider; use crate::session_fs::SessionFsProvider; use crate::trace_context::inject_trace_context; use crate::transforms::SystemMessageTransform; @@ -893,6 +894,7 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); + let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -1011,6 +1013,7 @@ impl Client { command_handlers, canvas_handler, session_fs_provider, + bearer_token_providers, channels, idle_waiter.clone(), capabilities.clone(), @@ -1149,6 +1152,7 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); + let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -1183,6 +1187,7 @@ impl Client { command_handlers, canvas_handler, session_fs_provider, + bearer_token_providers, channels, idle_waiter.clone(), capabilities.clone(), @@ -1391,6 +1396,7 @@ fn spawn_event_loop( command_handlers: Arc, canvas_handler: Option>, session_fs_provider: Option>, + bearer_token_providers: HashMap>, channels: crate::router::SessionChannels, idle_waiter: Arc>>, capabilities: Arc>, @@ -1432,6 +1438,7 @@ fn spawn_event_loop( transforms: transforms.as_deref(), canvas_handler: canvas_handler.as_ref(), session_fs_provider: session_fs_provider.as_ref(), + bearer_token_providers: &bearer_token_providers, }; handle_request(&session_id, ctx, request).await; } @@ -2010,6 +2017,7 @@ struct RequestDispatchContext<'a> { transforms: Option<&'a dyn SystemMessageTransform>, canvas_handler: Option<&'a Arc>, session_fs_provider: Option<&'a Arc>, + bearer_token_providers: &'a HashMap>, } /// Process a JSON-RPC request from the CLI. @@ -2025,6 +2033,7 @@ async fn handle_request( let transforms = ctx.transforms; let canvas_handler = ctx.canvas_handler; let session_fs_provider = ctx.session_fs_provider; + let bearer_token_providers = ctx.bearer_token_providers; if request.method.starts_with("sessionFs.") { crate::session_fs_dispatch::dispatch(client, session_fs_provider, request).await; @@ -2036,6 +2045,11 @@ async fn handle_request( return; } + if request.method == crate::generated::api_types::rpc_methods::PROVIDERTOKEN_GETTOKEN { + crate::provider_token_dispatch::dispatch(client, bearer_token_providers, request).await; + return; + } + match request.method.as_str() { "hooks.invoke" => { let params = request.params.as_ref(); diff --git a/rust/src/types.rs b/rust/src/types.rs index e743dbda1..0d46fb811 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -28,6 +28,7 @@ use crate::handler::{ UserInputHandler, }; use crate::hooks::SessionHooks; +use crate::provider_token::BearerTokenProvider; pub use crate::session_fs::{ DirEntry, DirEntryKind, FileInfo, FsError, SessionFsCapabilities, SessionFsConfig, SessionFsConventions, SessionFsProvider, SessionFsSqliteProvider, SessionFsSqliteQueryResult, @@ -1021,7 +1022,7 @@ pub struct McpHttpServerConfig { /// Routes session requests through an alternative model provider /// (OpenAI-compatible, Azure, Anthropic, or local) instead of GitHub /// Copilot's default routing. -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[non_exhaustive] pub struct ProviderConfig { @@ -1049,6 +1050,12 @@ pub struct ProviderConfig { /// API key. Takes precedence over `api_key` when both are set. #[serde(default, skip_serializing_if = "Option::is_none")] pub bearer_token: Option, + /// **Experimental.** Callback used to acquire a bearer token before each + /// outbound request to this provider. + #[serde(skip)] + pub get_bearer_token: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub(crate) has_bearer_token_provider: Option, /// Azure-specific options. #[serde(default, skip_serializing_if = "Option::is_none")] pub azure: Option, @@ -1080,6 +1087,30 @@ pub struct ProviderConfig { pub max_output_tokens: Option, } +impl std::fmt::Debug for ProviderConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ProviderConfig") + .field("provider_type", &self.provider_type) + .field("wire_api", &self.wire_api) + .field("transport", &self.transport) + .field("base_url", &self.base_url) + .field("api_key", &self.api_key) + .field("bearer_token", &self.bearer_token) + .field( + "get_bearer_token", + &self.get_bearer_token.as_ref().map(|_| ""), + ) + .field("has_bearer_token_provider", &self.has_bearer_token_provider) + .field("azure", &self.azure) + .field("headers", &self.headers) + .field("model_id", &self.model_id) + .field("wire_model", &self.wire_model) + .field("max_prompt_tokens", &self.max_prompt_tokens) + .field("max_output_tokens", &self.max_output_tokens) + .finish() + } +} + impl ProviderConfig { /// Construct a [`ProviderConfig`] with the required `base_url` set; /// all other fields default to unset. @@ -1122,6 +1153,16 @@ impl ProviderConfig { self } + /// Set the callback used to acquire a bearer token before each outbound + /// request to this provider. + /// + /// **Experimental.** This method is part of an experimental wire-protocol + /// surface and may change or be removed in a future release. + pub fn with_get_bearer_token(mut self, provider: Arc) -> Self { + self.get_bearer_token = Some(provider); + self + } + /// Set Azure-specific options. pub fn with_azure(mut self, azure: AzureProviderOptions) -> Self { self.azure = Some(azure); @@ -1223,7 +1264,7 @@ pub struct AzureProviderOptions { /// default Copilot routing and exposes these providers' models alongside /// it. Models are attached via [`ProviderModelConfig`], which references a /// provider by [`name`](Self::name). -#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[derive(Clone, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] #[non_exhaustive] pub struct NamedProviderConfig { @@ -1247,6 +1288,12 @@ pub struct NamedProviderConfig { /// directly. Takes precedence over `api_key` when both are set. #[serde(default, skip_serializing_if = "Option::is_none")] pub bearer_token: Option, + /// **Experimental.** Callback used to acquire a bearer token before each + /// outbound request to this provider. + #[serde(skip)] + pub get_bearer_token: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub(crate) has_bearer_token_provider: Option, /// Azure-specific options. #[serde(default, skip_serializing_if = "Option::is_none")] pub azure: Option, @@ -1255,6 +1302,26 @@ pub struct NamedProviderConfig { pub headers: Option>, } +impl std::fmt::Debug for NamedProviderConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NamedProviderConfig") + .field("name", &self.name) + .field("provider_type", &self.provider_type) + .field("wire_api", &self.wire_api) + .field("base_url", &self.base_url) + .field("api_key", &self.api_key) + .field("bearer_token", &self.bearer_token) + .field( + "get_bearer_token", + &self.get_bearer_token.as_ref().map(|_| ""), + ) + .field("has_bearer_token_provider", &self.has_bearer_token_provider) + .field("azure", &self.azure) + .field("headers", &self.headers) + .finish() + } +} + impl NamedProviderConfig { /// Construct a [`NamedProviderConfig`] with the required `name` and /// `base_url` set; all other fields default to unset. @@ -1291,6 +1358,16 @@ impl NamedProviderConfig { self } + /// Set the callback used to acquire a bearer token before each outbound + /// request to this provider. + /// + /// **Experimental.** This method is part of an experimental wire-protocol + /// surface and may change or be removed in a future release. + pub fn with_get_bearer_token(mut self, provider: Arc) -> Self { + self.get_bearer_token = Some(provider); + self + } + /// Set Azure-specific options. pub fn with_azure(mut self, azure: AzureProviderOptions) -> Self { self.azure = Some(azure); @@ -1304,6 +1381,31 @@ impl NamedProviderConfig { } } +fn prepare_bearer_token_providers( + provider: &mut Option, + providers: &mut Option>, +) -> HashMap> { + let mut bearer_token_providers = HashMap::new(); + + if let Some(provider) = provider.as_mut() + && let Some(token_provider) = provider.get_bearer_token.take() + { + provider.has_bearer_token_provider = Some(true); + bearer_token_providers.insert("default".to_string(), token_provider); + } + + if let Some(providers) = providers.as_mut() { + for provider in providers { + if let Some(token_provider) = provider.get_bearer_token.take() { + provider.has_bearer_token_provider = Some(true); + bearer_token_providers.insert(provider.name.clone(), token_provider); + } + } + } + + bearer_token_providers +} + /// A BYOK model definition in the multi-provider registry. /// /// **Experimental.** Multi-provider BYOK configuration is part of an @@ -1909,6 +2011,7 @@ pub(crate) struct SessionConfigRuntime { pub tool_handlers: HashMap>, pub canvas_handler: Option>, pub session_fs_provider: Option>, + pub bearer_token_providers: HashMap>, pub commands: Option>, } @@ -1960,6 +2063,8 @@ impl SessionConfig { }); let wire_canvases = self.canvases.clone(); let canvas_handler = self.canvas_handler.clone(); + let bearer_token_providers = + prepare_bearer_token_providers(&mut self.provider, &mut self.providers); let wire = crate::wire::SessionCreateWire { session_id, @@ -2035,6 +2140,7 @@ impl SessionConfig { tool_handlers, canvas_handler, session_fs_provider: self.session_fs_provider, + bearer_token_providers, commands: self.commands, }; @@ -2897,6 +3003,8 @@ impl ResumeSessionConfig { }); let wire_canvases = self.canvases.clone(); let canvas_handler = self.canvas_handler.clone(); + let bearer_token_providers = + prepare_bearer_token_providers(&mut self.provider, &mut self.providers); let wire = crate::wire::SessionResumeWire { session_id: self.session_id, @@ -2973,6 +3081,7 @@ impl ResumeSessionConfig { tool_handlers, canvas_handler, session_fs_provider: self.session_fs_provider, + bearer_token_providers, commands: self.commands, }; diff --git a/rust/tests/e2e.rs b/rust/tests/e2e.rs index c46630e69..59b83ab27 100644 --- a/rust/tests/e2e.rs +++ b/rust/tests/e2e.rs @@ -7,6 +7,8 @@ mod abort; mod ask_user; #[path = "e2e/builtin_tools.rs"] mod builtin_tools; +#[path = "e2e/byok_bearer_token_provider.rs"] +mod byok_bearer_token_provider; #[path = "e2e/canvas.rs"] mod canvas; #[path = "e2e/client.rs"] diff --git a/rust/tests/e2e/byok_bearer_token_provider.rs b/rust/tests/e2e/byok_bearer_token_provider.rs new file mode 100644 index 000000000..c3cd9ef4b --- /dev/null +++ b/rust/tests/e2e/byok_bearer_token_provider.rs @@ -0,0 +1,314 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use async_trait::async_trait; +use bytes::Bytes; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::{ + BearerTokenError, CopilotHttpRequest, CopilotHttpResponse, CopilotRequestContext, + CopilotRequestError, CopilotRequestHandler, MessageOptions, NamedProviderConfig, + ProviderModelConfig, ProviderTokenArgs, SessionConfig, +}; +use http::HeaderMap; + +use super::support::with_e2e_context_no_snapshot; + +const PRIMARY_BASE_URL: &str = "https://byok-endpoint.invalid/v1"; +const RED_HOST: &str = "byok-red.invalid"; +const RED_BASE_URL: &str = "https://byok-red.invalid/v1"; +const BLUE_HOST: &str = "byok-blue.invalid"; +const BLUE_BASE_URL: &str = "https://byok-blue.invalid/v1"; + +#[derive(Debug, Clone)] +struct CapturedRequest { + host: String, + authorization: Option, +} + +#[derive(Default)] +struct CapturingRequestHandler { + captures: std::sync::Mutex>, +} + +impl CapturingRequestHandler { + fn auth_headers(&self) -> Vec { + self.captures + .lock() + .unwrap() + .iter() + .filter_map(|capture| capture.authorization.clone()) + .collect() + } + + fn auth_header_for_host(&self, host: &str) -> Option { + self.captures + .lock() + .unwrap() + .iter() + .find(|capture| capture.host == host) + .and_then(|capture| capture.authorization.clone()) + } + + fn reset(&self) { + self.captures.lock().unwrap().clear(); + } +} + +#[async_trait] +impl CopilotRequestHandler for CapturingRequestHandler { + async fn send_request( + &self, + request: CopilotHttpRequest, + _ctx: &CopilotRequestContext, + ) -> Result { + let uri: http::Uri = request + .url + .parse() + .map_err(|error| CopilotRequestError::message(format!("invalid URL: {error}")))?; + if let Some(host) = uri.host() + && host.ends_with(".invalid") + { + let authorization = request + .headers + .get("authorization") + .and_then(|value| value.to_str().ok()) + .map(str::to_string); + self.captures.lock().unwrap().push(CapturedRequest { + host: host.to_string(), + authorization, + }); + return Ok(json_response( + 404, + br#"{"error":{"message":"fake byok endpoint"}}"#.to_vec(), + )); + } + + Ok(synth_non_inference_response(&request.url)) + } +} + +fn json_response(status: u16, body: Vec) -> CopilotHttpResponse { + let mut headers = HeaderMap::new(); + headers.insert( + "content-type", + http::HeaderValue::from_static("application/json"), + ); + let body = futures_util::stream::iter([Ok::(Bytes::from(body))]); + CopilotHttpResponse::new(status, None, headers, Box::pin(body)) +} + +fn synth_non_inference_response(url: &str) -> CopilotHttpResponse { + let lower = url.to_lowercase(); + if lower.ends_with("/models") { + return json_response( + 200, + br#"{"data":[{"id":"gpt-4o","name":"GPT-4o","object":"model","vendor":"OpenAI","version":"1","preview":false,"model_picker_enabled":true,"capabilities":{"type":"chat","family":"gpt-4o","tokenizer":"o200k_base","limits":{"max_context_window_tokens":128000,"max_output_tokens":4096},"supports":{"streaming":true,"tool_calls":true,"parallel_tool_calls":true}}}]}"# + .to_vec(), + ); + } + if lower.contains("/models/session") { + return json_response(200, b"{}".to_vec()); + } + if lower.contains("/policy") { + return json_response(200, br#"{"state":"enabled"}"#.to_vec()); + } + json_response(200, b"{}".to_vec()) +} + +async fn run_turn( + client: &github_copilot_sdk::Client, + providers: Vec, + models: Vec, + selection_id: &str, + prompt: &str, +) { + let session = client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_model(selection_id) + .with_providers(providers) + .with_models(models), + ) + .await + .expect("create session"); + let _ = session.send_and_wait(MessageOptions::new(prompt)).await; + let _ = session.disconnect().await; +} + +#[tokio::test] +async fn callback_token_is_applied_as_authorization_header() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let handler = Arc::new(CapturingRequestHandler::default()); + let client = ctx.start_llm_client(handler.clone(), &[]).await; + handler.reset(); + + let calls = Arc::new(AtomicUsize::new(0)); + let callback_calls = calls.clone(); + let providers = vec![ + NamedProviderConfig::new("mi", PRIMARY_BASE_URL) + .with_provider_type("openai") + .with_wire_api("completions") + .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| { + let callback_calls = callback_calls.clone(); + async move { + callback_calls.fetch_add(1, Ordering::SeqCst); + Ok::<_, BearerTokenError>("sentinel-bearer-token-abc123".to_string()) + } + })), + ]; + let models = + vec![ProviderModelConfig::new("default", "mi").with_wire_model("byok-gpt-4o")]; + + run_turn(&client, providers, models, "mi/default", "What is 5+5?").await; + + assert!( + calls.load(Ordering::SeqCst) >= 1, + "expected callback to be invoked" + ); + // Validate the captured Authorization header is the final assertion. + assert!( + handler + .auth_headers() + .contains(&"Bearer sentinel-bearer-token-abc123".to_string()), + "expected captured Authorization headers to include the sentinel token, got {:?}", + handler.auth_headers() + ); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn reacquires_a_fresh_token_for_each_request() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let handler = Arc::new(CapturingRequestHandler::default()); + let client = ctx.start_llm_client(handler.clone(), &[]).await; + handler.reset(); + + let calls = Arc::new(AtomicUsize::new(0)); + let callback_calls = calls.clone(); + let providers = vec![ + NamedProviderConfig::new("mi", PRIMARY_BASE_URL) + .with_provider_type("openai") + .with_wire_api("completions") + .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| { + let callback_calls = callback_calls.clone(); + async move { + let call = callback_calls.fetch_add(1, Ordering::SeqCst) + 1; + Ok::<_, BearerTokenError>(format!("rotating-token-{call}")) + } + })), + ]; + let models = + vec![ProviderModelConfig::new("default", "mi").with_wire_model("byok-gpt-4o")]; + + run_turn( + &client, + providers.clone(), + models.clone(), + "mi/default", + "What is 1+1?", + ) + .await; + run_turn(&client, providers, models, "mi/default", "What is 2+2?").await; + + let auths = handler.auth_headers(); + assert!( + auths.len() >= 2, + "expected at least 2 captured Authorization headers, got {auths:?}" + ); + assert!( + auths[0].starts_with("Bearer rotating-token-") + && auths[1].starts_with("Bearer rotating-token-"), + "expected rotating-token bearer headers, got {auths:?}" + ); + assert!( + calls.load(Ordering::SeqCst) >= 2, + "expected callback to be invoked at least twice" + ); + // Validate the captured Authorization header is the final assertion. + assert_ne!( + auths[0], auths[1], + "expected distinct tokens per request, both were {:?}", + auths[0] + ); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn dispatches_token_acquisition_per_provider() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let handler = Arc::new(CapturingRequestHandler::default()); + let client = ctx.start_llm_client(handler.clone(), &[]).await; + handler.reset(); + + let acquired_for = Arc::new(std::sync::Mutex::new(Vec::new())); + let make_provider = + |name: &'static str, base_url: &'static str, token: &'static str| { + let acquired_for = acquired_for.clone(); + NamedProviderConfig::new(name, base_url) + .with_provider_type("openai") + .with_wire_api("completions") + .with_get_bearer_token(Arc::new(move |args: ProviderTokenArgs| { + let acquired_for = acquired_for.clone(); + async move { + assert_eq!(args.provider_name, name); + acquired_for.lock().unwrap().push(name.to_string()); + Ok::<_, BearerTokenError>(token.to_string()) + } + })) + }; + let providers = vec![ + make_provider("red", RED_BASE_URL, "token-for-red"), + make_provider("blue", BLUE_BASE_URL, "token-for-blue"), + ]; + let models = vec![ + ProviderModelConfig::new("default", "red").with_wire_model("byok-gpt-4o"), + ProviderModelConfig::new("default", "blue").with_wire_model("byok-gpt-4o"), + ]; + + run_turn( + &client, + providers.clone(), + models.clone(), + "red/default", + "What is 3+3?", + ) + .await; + run_turn(&client, providers, models, "blue/default", "What is 4+4?").await; + + let acquired = acquired_for.lock().unwrap().clone(); + assert!(acquired.contains(&"red".to_string())); + assert!(acquired.contains(&"blue".to_string())); + assert_eq!( + handler.auth_header_for_host(RED_HOST).as_deref(), + Some("Bearer token-for-red") + ); + // Validate the captured Authorization header is the final assertion. + assert_eq!( + handler.auth_header_for_host(BLUE_HOST).as_deref(), + Some("Bearer token-for-blue") + ); + + client.stop().await.expect("stop client"); + }) + }) + .await; +}