diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 85fb8bd34..5a5d34dcd 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -1009,6 +1009,8 @@ public async Task CreateSessionAsync(SessionConfig config, Cance RequestExtensions: config.RequestExtensions, ExtensionSdkPath: config.ExtensionSdkPath, ExtensionInfo: config.ExtensionInfo, + Providers: config.Providers, + Models: config.Models, ToolFilterPrecedence: toolFilter.ToolFilterPrecedence); var rpcTimestamp = Stopwatch.GetTimestamp(); @@ -1207,6 +1209,8 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes ExtensionSdkPath: config.ExtensionSdkPath, ExtensionInfo: config.ExtensionInfo, OpenCanvases: config.OpenCanvases, + Providers: config.Providers, + Models: config.Models, ToolFilterPrecedence: toolFilter.ToolFilterPrecedence); var rpcTimestamp = Stopwatch.GetTimestamp(); @@ -2402,6 +2406,8 @@ internal record CreateSessionRequest( bool? RequestExtensions = null, string? ExtensionSdkPath = null, ExtensionInfo? ExtensionInfo = null, + IList? Providers = null, + IList? Models = null, OptionsUpdateToolFilterPrecedence? ToolFilterPrecedence = null); #pragma warning restore GHCP001 @@ -2494,6 +2500,8 @@ internal record ResumeSessionRequest( string? ExtensionSdkPath = null, ExtensionInfo? ExtensionInfo = null, IList? OpenCanvases = null, + IList? Providers = null, + IList? Models = null, OptionsUpdateToolFilterPrecedence? ToolFilterPrecedence = null); #pragma warning restore GHCP001 @@ -2569,6 +2577,8 @@ internal record HooksInvokeResponse( [JsonSerializable(typeof(EmbeddingCacheStorageMode))] [JsonSerializable(typeof(ModelCapabilitiesOverride))] [JsonSerializable(typeof(ProviderConfig))] + [JsonSerializable(typeof(NamedProviderConfig))] + [JsonSerializable(typeof(ProviderModelConfig))] [JsonSerializable(typeof(ResumeSessionRequest))] [JsonSerializable(typeof(ResumeSessionResponse))] [JsonSerializable(typeof(SessionCapabilities))] diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 706a1ec6b..d7b326afb 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -2070,6 +2070,135 @@ public sealed class AzureOptions public string? ApiVersion { get; set; } } +/// +/// A named BYOK provider connection (transport + credentials only), referenced by +/// entries via . +/// +/// Unlike the singular, whole-session — which bypasses +/// Copilot API authentication — named providers are additive and coexist with Copilot +/// API auth, so models from CAPI and one or more BYOK providers can be mixed within a +/// single session and across sub-agents. Combining named providers/models with +/// is rejected. +/// +/// +[Experimental(Diagnostics.Experimental)] +public sealed class NamedProviderConfig +{ + /// + /// Stable identifier referenced by . + /// Must not contain '/'. + /// + [JsonPropertyName("name")] + public string Name { get; set; } = string.Empty; + + /// + /// Provider type. Defaults to "openai" for generic OpenAI-compatible APIs. + /// + [JsonPropertyName("type")] + public string? Type { get; set; } + + /// + /// Wire API format (openai/azure only). Defaults to "completions". + /// + [JsonPropertyName("wireApi")] + public string? WireApi { get; set; } + + /// + /// API endpoint URL. + /// + [JsonPropertyName("baseUrl")] + public string BaseUrl { get; set; } = string.Empty; + + /// + /// API key. Optional for local providers like Ollama. + /// + [JsonPropertyName("apiKey")] + public string? ApiKey { get; set; } + + /// + /// Bearer token for authentication. Sets the Authorization header directly. + /// Takes precedence over when both are set. + /// + [JsonPropertyName("bearerToken")] + public string? BearerToken { get; set; } + + /// + /// Azure-specific configuration options. + /// + [JsonPropertyName("azure")] + public AzureOptions? Azure { get; set; } + + /// + /// Custom HTTP headers to include in all outbound requests to the provider. + /// + [JsonPropertyName("headers")] + public IDictionary? Headers { get; set; } +} + +/// +/// A BYOK model definition that references a by name +/// and is added to the session's selectable model list. The session-wide selection id +/// (shown in the model list and passed to model switching) is the provider-qualified +/// provider/id, so BYOK ids never collide with bare CAPI ids. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class ProviderModelConfig +{ + /// + /// Provider-local model id, unique within its provider. + /// + [JsonPropertyName("id")] + public string Id { get; set; } = string.Empty; + + /// + /// Name of the that serves this model. + /// + [JsonPropertyName("provider")] + public string Provider { get; set; } = string.Empty; + + /// + /// The model name sent to the provider API for inference. Defaults to . + /// + [JsonPropertyName("wireModel")] + public string? WireModel { get; set; } + + /// + /// Well-known base model id used for behavior/capability/config lookup. Defaults to . + /// + [JsonPropertyName("modelId")] + public string? ModelId { get; set; } + + /// + /// Display name for model pickers. Defaults to the provider-qualified selection id. + /// + [JsonPropertyName("name")] + public string? Name { get; set; } + + /// + /// Maximum prompt/input tokens for the model. + /// + [JsonPropertyName("maxPromptTokens")] + public int? MaxPromptTokens { get; set; } + + /// + /// Maximum context window tokens for the model. + /// + [JsonPropertyName("maxContextWindowTokens")] + public int? MaxContextWindowTokens { get; set; } + + /// + /// Maximum output tokens for the model. + /// + [JsonPropertyName("maxOutputTokens")] + public int? MaxOutputTokens { get; set; } + + /// + /// Optional capability overrides (vision, tool_calls, reasoning, etc.) for the synthesized model. + /// + [JsonPropertyName("capabilities")] + public ModelCapabilitiesOverride? Capabilities { get; set; } +} + // ============================================================================ // MCP Server Configuration Types // ============================================================================ @@ -2494,6 +2623,8 @@ protected SessionConfigBase(SessionConfigBase? other) OnPermissionRequest = other.OnPermissionRequest; OnUserInputRequest = other.OnUserInputRequest; Provider = other.Provider; + Providers = other.Providers is not null ? [.. other.Providers] : null; + Models = other.Models is not null ? [.. other.Models] : null; EnableSessionTelemetry = other.EnableSessionTelemetry; SkipCustomInstructions = other.SkipCustomInstructions; CustomAgentsLocalOnly = other.CustomAgentsLocalOnly; @@ -2649,6 +2780,21 @@ protected SessionConfigBase(SessionConfigBase? other) /// Custom model provider configuration for the session. public ProviderConfig? Provider { get; set; } + /// + /// Named BYOK provider connections (transport + credentials). Additive to Copilot + /// API authentication (unlike ); combine with . + /// Cannot be combined with . + /// + [Experimental(Diagnostics.Experimental)] + public IList? Providers { get; set; } + + /// + /// BYOK model definitions added to the session's selectable model list, each + /// referencing a entry by name. + /// + [Experimental(Diagnostics.Experimental)] + public IList? Models { get; set; } + /// /// Enables or disables internal session telemetry for this session. /// When false, disables session telemetry. When null (the default) or true, diff --git a/dotnet/test/E2E/MultiProviderRegistryE2ETests.cs b/dotnet/test/E2E/MultiProviderRegistryE2ETests.cs new file mode 100644 index 000000000..8a75cc3c1 --- /dev/null +++ b/dotnet/test/E2E/MultiProviderRegistryE2ETests.cs @@ -0,0 +1,208 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.Rpc; +using GitHub.Copilot.Test.Harness; +using System.Text.Json; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.Test.E2E; + +/// +/// End-to-end coverage for the experimental multi-provider BYOK registry +/// ( / ). +/// Validates that several named providers, several models per provider, and +/// custom agents bound to those provider-qualified models can coexist in one +/// session, be launched, and route inference to the configured provider with +/// the configured wire model and headers. +/// +public class MultiProviderRegistryE2ETests(E2ETestFixture fixture, ITestOutputHelper output) + : E2ETestBase(fixture, "multi_provider_registry", output) +{ + /// + /// Builds a heterogeneous registry: two providers of different types, with + /// multiple models each. Provider-qualified selection ids are + /// alpha/sonnet, alpha/haiku, beta/opus, beta/haiku. + /// + private static IList RegistryProviders() => + [ + new() + { + Name = "alpha", + Type = "openai", + WireApi = "completions", + BaseUrl = "https://alpha.example.test/v1", + ApiKey = "alpha-secret", + Headers = new Dictionary { ["X-Provider"] = "alpha" }, + }, + new() + { + Name = "beta", + Type = "anthropic", + BaseUrl = "https://beta.example.test", + BearerToken = "beta-bearer", + Headers = new Dictionary { ["X-Provider"] = "beta" }, + }, + ]; + + private static IList RegistryModels() => + [ + new() { Id = "sonnet", Provider = "alpha", WireModel = "byok-gpt-4o", MaxPromptTokens = 111111 }, + new() { Id = "haiku", Provider = "alpha", WireModel = "byok-gpt-4o-mini" }, + new() { Id = "opus", Provider = "beta", WireModel = "byok-claude-3-opus" }, + new() { Id = "haiku", Provider = "beta", WireModel = "byok-claude-3-haiku" }, + ]; + + private static IList RegistryAgents() => + [ + new() { Name = "orchestrator", DisplayName = "Orchestrator", Description = "Top-level planner.", Prompt = "Plan and delegate.", Model = "alpha/sonnet" }, + new() { Name = "researcher", DisplayName = "Researcher", Description = "Deep research subagent.", Prompt = "Research thoroughly.", Model = "beta/opus" }, + new() { Name = "fast-helper", DisplayName = "Fast Helper", Description = "Quick subagent.", Prompt = "Answer quickly.", Model = "alpha/haiku" }, + new() { Name = "summarizer", DisplayName = "Summarizer", Description = "Summarizing subagent.", Prompt = "Summarize.", Model = "beta/haiku" }, + ]; + + [Fact] + public async Task Should_Register_Multiple_Providers_With_Custom_Agents_Bound_To_Their_Models() + { + var session = await CreateSessionAsync(new SessionConfig + { + Providers = RegistryProviders(), + Models = RegistryModels(), + CustomAgents = RegistryAgents(), + }); + + var agents = (await session.Rpc.Agent.ListAsync()).Agents; + + // All four custom agents coexist in a single session. + Assert.Equal(4, agents.Count); + + // Each agent is bound to its configured provider-qualified BYOK model. + AssertAgentModel(agents, "orchestrator", "alpha/sonnet", "Orchestrator", "Top-level planner."); + AssertAgentModel(agents, "researcher", "beta/opus", "Researcher", "Deep research subagent."); + AssertAgentModel(agents, "fast-helper", "alpha/haiku", "Fast Helper", "Quick subagent."); + AssertAgentModel(agents, "summarizer", "beta/haiku", "Summarizer", "Summarizing subagent."); + + // Models from BOTH providers are represented, proving the two providers + // and their models coexist within the same session. + var boundModels = agents.Select(a => a.Model).ToHashSet(); + Assert.Contains(boundModels, m => m!.StartsWith("alpha/", StringComparison.Ordinal)); + Assert.Contains(boundModels, m => m!.StartsWith("beta/", StringComparison.Ordinal)); + } + + [Fact] + public async Task Should_Route_Alpha_Sonnet_Turn_To_Its_Provider_And_Wire_Model() + => await AssertRoutingAsync("alpha/sonnet", "byok-gpt-4o", "alpha"); + + [Fact] + public async Task Should_Route_Alpha_Haiku_Turn_To_Its_Provider_And_Wire_Model() + => await AssertRoutingAsync("alpha/haiku", "byok-gpt-4o-mini", "alpha"); + + [Fact] + public async Task Should_Route_Delta_Turbo_Turn_To_Its_Provider_And_Wire_Model() + => await AssertRoutingAsync("delta/turbo", "byok-gpt-4-turbo", "delta"); + + /// + /// Selects in a session whose registry holds + /// two OpenAI-compatible providers (each pointed at the replay proxy), runs a + /// turn, and asserts the captured request used the model's configured wire + /// model and carried the owning provider's header and credential. + /// + private async Task AssertRoutingAsync(string selectionId, string expectedWireModel, string expectedProviderHeader) + { + // Two OpenAI-compatible providers, both pointed at the replay proxy so + // their /chat/completions traffic is captured. They are distinguished on + // the wire by their per-provider X-Provider header. "alpha" carries two + // models (multiple models per provider); "delta" carries one. + var providers = new List + { + new() + { + Name = "alpha", + Type = "openai", + WireApi = "completions", + BaseUrl = Ctx.ProxyUrl, + ApiKey = "alpha-secret", + Headers = new Dictionary { ["X-Provider"] = "alpha" }, + }, + new() + { + Name = "delta", + Type = "openai", + WireApi = "completions", + BaseUrl = Ctx.ProxyUrl, + ApiKey = "delta-secret", + Headers = new Dictionary { ["X-Provider"] = "delta" }, + }, + }; + var models = new List + { + new() { Id = "sonnet", Provider = "alpha", WireModel = "byok-gpt-4o" }, + new() { Id = "haiku", Provider = "alpha", WireModel = "byok-gpt-4o-mini" }, + new() { Id = "turbo", Provider = "delta", WireModel = "byok-gpt-4-turbo" }, + }; + + var session = await CreateSessionAsync(new SessionConfig + { + Model = selectionId, + Providers = providers, + Models = models, + }); + + var exchanges = await SendAndWaitForExchangesAsync( + session, + new MessageOptions { Prompt = "What is 5+5?" }); + + var exchange = Assert.Single(exchanges); + + // The wire model sent to the provider is the selected model's WireModel, + // not its provider-qualified selection id. + Assert.Equal(expectedWireModel, exchange.Request.Model); + + // The request carried the owning provider's custom header, proving the + // turn was dispatched against the correct provider connection. + Assert.Equal(expectedProviderHeader, GetHeaderValue(exchange, "X-Provider")); + + // The provider's API key was applied as an Authorization header. + Assert.False(string.IsNullOrEmpty(GetHeaderValue(exchange, "Authorization"))); + } + + private static void AssertAgentModel( + IEnumerable agents, + string name, + string expectedModel, + string expectedDisplayName, + string expectedDescription) + { + var agent = Assert.Single(agents, a => string.Equals(a.Name, name, StringComparison.Ordinal)); + Assert.Equal(expectedModel, agent.Model); + Assert.Equal(expectedDisplayName, agent.DisplayName); + Assert.Equal(expectedDescription, agent.Description); + } + + private static string? GetHeaderValue(ParsedHttpExchange exchange, string name) + { + if (exchange.RequestHeaders == null) + { + return null; + } + + foreach (var kv in exchange.RequestHeaders) + { + if (!string.Equals(kv.Key, name, StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + return kv.Value.ValueKind switch + { + JsonValueKind.String => kv.Value.GetString(), + JsonValueKind.Array when kv.Value.GetArrayLength() > 0 => kv.Value[0].GetString(), + _ => kv.Value.ToString(), + }; + } + + return null; + } +} diff --git a/go/client.go b/go/client.go index ad330e5a0..af9044ad9 100644 --- a/go/client.go +++ b/go/client.go @@ -681,6 +681,8 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses req.ExcludedTools = excludedTools req.ToolFilterPrecedence = precedence req.Provider = config.Provider + req.Providers = config.Providers + req.Models = config.Models req.EnableSessionTelemetry = config.EnableSessionTelemetry req.SkipCustomInstructions = config.SkipCustomInstructions req.CustomAgentsLocalOnly = config.CustomAgentsLocalOnly @@ -976,6 +978,8 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, req.SystemMessage = wireSystemMessage req.Tools = config.Tools req.Provider = config.Provider + req.Providers = config.Providers + req.Models = config.Models req.EnableSessionTelemetry = config.EnableSessionTelemetry req.SkipCustomInstructions = config.SkipCustomInstructions req.CustomAgentsLocalOnly = config.CustomAgentsLocalOnly diff --git a/go/internal/e2e/multi_provider_registry_e2e_test.go b/go/internal/e2e/multi_provider_registry_e2e_test.go new file mode 100644 index 000000000..7bec13414 --- /dev/null +++ b/go/internal/e2e/multi_provider_registry_e2e_test.go @@ -0,0 +1,195 @@ +package e2e + +import ( + "strings" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// TestMultiProviderRegistryE2E exercises the experimental multi-provider BYOK +// registry (Providers / Models on the session config). It validates that +// several named providers, several models per provider, and custom agents +// bound to those provider-qualified models can coexist in one session, be +// launched, and route inference to the configured provider with the configured +// wire model and headers. +func TestMultiProviderRegistryE2E(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient() + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + t.Run("should register multiple providers with custom agents bound to their models", func(t *testing.T) { + ctx.ConfigureForTest(t) + + // A heterogeneous registry: two providers of different types, with + // multiple models each. Provider-qualified selection ids are + // alpha/sonnet, alpha/haiku, beta/opus, beta/haiku. + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Providers: []copilot.NamedProviderConfig{ + { + Name: "alpha", + Type: "openai", + WireAPI: "completions", + BaseURL: "https://alpha.example.test/v1", + APIKey: "alpha-secret", + Headers: map[string]string{"X-Provider": "alpha"}, + }, + { + Name: "beta", + Type: "anthropic", + BaseURL: "https://beta.example.test", + BearerToken: "beta-bearer", + Headers: map[string]string{"X-Provider": "beta"}, + }, + }, + Models: []copilot.ProviderModelConfig{ + {ID: "sonnet", Provider: "alpha", WireModel: "byok-gpt-4o", MaxPromptTokens: 111111}, + {ID: "haiku", Provider: "alpha", WireModel: "byok-gpt-4o-mini"}, + {ID: "opus", Provider: "beta", WireModel: "byok-claude-3-opus"}, + {ID: "haiku", Provider: "beta", WireModel: "byok-claude-3-haiku"}, + }, + CustomAgents: []copilot.CustomAgentConfig{ + {Name: "orchestrator", DisplayName: "Orchestrator", Description: "Top-level planner.", Prompt: "Plan and delegate.", Model: "alpha/sonnet"}, + {Name: "researcher", DisplayName: "Researcher", Description: "Deep research subagent.", Prompt: "Research thoroughly.", Model: "beta/opus"}, + {Name: "fast-helper", DisplayName: "Fast Helper", Description: "Quick subagent.", Prompt: "Answer quickly.", Model: "alpha/haiku"}, + {Name: "summarizer", DisplayName: "Summarizer", Description: "Summarizing subagent.", Prompt: "Summarize.", Model: "beta/haiku"}, + }, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + + result, err := session.RPC.Agent.List(t.Context()) + if err != nil { + t.Fatalf("Agent.List failed: %v", err) + } + + // All four custom agents coexist in a single session. + if len(result.Agents) != 4 { + t.Fatalf("Expected 4 agents, got %d", len(result.Agents)) + } + + // Each agent is bound to its configured provider-qualified BYOK model. + boundModels := map[string]string{} + for _, agent := range result.Agents { + model := "" + if agent.Model != nil { + model = *agent.Model + } + boundModels[agent.Name] = model + } + expected := map[string]string{ + "orchestrator": "alpha/sonnet", + "researcher": "beta/opus", + "fast-helper": "alpha/haiku", + "summarizer": "beta/haiku", + } + for name, want := range expected { + if got := boundModels[name]; got != want { + t.Errorf("Expected agent %q bound to model %q, got %q", name, want, got) + } + } + + // Models from BOTH providers are represented, proving the two providers + // and their models coexist within the same session. + var hasAlpha, hasBeta bool + for _, model := range boundModels { + if strings.HasPrefix(model, "alpha/") { + hasAlpha = true + } + if strings.HasPrefix(model, "beta/") { + hasBeta = true + } + } + if !hasAlpha || !hasBeta { + t.Errorf("Expected both providers represented; hasAlpha=%v hasBeta=%v", hasAlpha, hasBeta) + } + }) + + assertRouting := func(t *testing.T, selectionID, expectedWireModel, expectedProviderHeader string) { + ctx.ConfigureForTest(t) + + // Two OpenAI-compatible providers, both pointed at the replay proxy so + // their /chat/completions traffic is captured. They are distinguished + // on the wire by their per-provider X-Provider header. "alpha" carries + // two models (multiple models per provider); "delta" carries one. + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Model: selectionID, + Providers: []copilot.NamedProviderConfig{ + { + Name: "alpha", + Type: "openai", + WireAPI: "completions", + BaseURL: ctx.ProxyURL, + APIKey: "alpha-secret", + Headers: map[string]string{"X-Provider": "alpha"}, + }, + { + Name: "delta", + Type: "openai", + WireAPI: "completions", + BaseURL: ctx.ProxyURL, + APIKey: "delta-secret", + Headers: map[string]string{"X-Provider": "delta"}, + }, + }, + Models: []copilot.ProviderModelConfig{ + {ID: "sonnet", Provider: "alpha", WireModel: "byok-gpt-4o"}, + {ID: "haiku", Provider: "alpha", WireModel: "byok-gpt-4o-mini"}, + {ID: "turbo", Provider: "delta", WireModel: "byok-gpt-4-turbo"}, + }, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + + if _, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 5+5?"}); err != nil { + t.Fatalf("SendAndWait failed: %v", err) + } + + exchanges, err := ctx.GetExchanges() + if err != nil { + t.Fatalf("GetExchanges failed: %v", err) + } + if len(exchanges) != 1 { + t.Fatalf("Expected exactly 1 exchange, got %d", len(exchanges)) + } + exchange := exchanges[0] + + // The wire model sent to the provider is the selected model's WireModel, + // not its provider-qualified selection id. + if exchange.Request.Model != expectedWireModel { + t.Errorf("Expected request model %q, got %q", expectedWireModel, exchange.Request.Model) + } + + // The request carried the owning provider's custom header, proving the + // turn was dispatched against the correct provider connection. + if !exchangeHasHeader(exchange, "X-Provider", expectedProviderHeader) { + t.Errorf("Expected X-Provider header %q to be present", expectedProviderHeader) + } + + // The provider's API key was applied as an Authorization header. + if !exchangeHasHeader(exchange, "Authorization", "Bearer") { + t.Error("Expected an Authorization header on the dispatched request") + } + } + + t.Run("should route alpha sonnet turn to its provider and wire model", func(t *testing.T) { + assertRouting(t, "alpha/sonnet", "byok-gpt-4o", "alpha") + }) + + t.Run("should route alpha haiku turn to its provider and wire model", func(t *testing.T) { + assertRouting(t, "alpha/haiku", "byok-gpt-4o-mini", "alpha") + }) + + t.Run("should route delta turbo turn to its provider and wire model", func(t *testing.T) { + assertRouting(t, "delta/turbo", "byok-gpt-4-turbo", "delta") + }) +} diff --git a/go/types.go b/go/types.go index 5ed0b6931..ba83c6b6d 100644 --- a/go/types.go +++ b/go/types.go @@ -983,6 +983,18 @@ type SessionConfig struct { IncludeSubAgentStreamingEvents *bool // Provider configures a custom model provider (BYOK) Provider *ProviderConfig + // Providers configures named BYOK provider connections. Additive to Copilot + // API auth (unlike Provider); combine with Models. Cannot be combined with Provider. + // + // Experimental: Providers is part of an experimental multi-provider BYOK + // surface and may change or be removed in future SDK or CLI releases. + Providers []NamedProviderConfig + // Models adds BYOK model definitions to the session's selectable model list, + // each referencing a Providers entry by name. + // + // Experimental: Models is part of an experimental multi-provider BYOK + // surface and may change or be removed in future SDK or CLI releases. + Models []ProviderModelConfig // EnableSessionTelemetry enables or disables internal session telemetry for this session. // When false, disables session telemetry. When nil (the default) or true, // telemetry is enabled for GitHub-authenticated sessions. When a custom @@ -1316,6 +1328,18 @@ type ResumeSessionConfig struct { ExcludedTools []string // Provider configures a custom model provider Provider *ProviderConfig + // Providers configures named BYOK provider connections. Additive to Copilot + // API auth (unlike Provider); combine with Models. Cannot be combined with Provider. + // + // Experimental: Providers is part of an experimental multi-provider BYOK + // surface and may change or be removed in future SDK or CLI releases. + Providers []NamedProviderConfig + // Models adds BYOK model definitions to the session's selectable model list, + // each referencing a Providers entry by name. + // + // Experimental: Models is part of an experimental multi-provider BYOK + // surface and may change or be removed in future SDK or CLI releases. + Models []ProviderModelConfig // EnableSessionTelemetry enables or disables internal session telemetry for this session. // When false, disables session telemetry. When nil (the default) or true, // telemetry is enabled for GitHub-authenticated sessions. When a custom @@ -1546,6 +1570,65 @@ type AzureProviderOptions struct { APIVersion string `json:"apiVersion,omitempty"` } +// NamedProviderConfig is a named BYOK provider connection (transport + +// credentials), referenced by ProviderModelConfig entries via Name. +// +// Unlike the singular Provider (which makes the whole session BYOK and bypasses +// Copilot API authentication), named providers are additive: they coexist with +// Copilot API auth so models from CAPI and one or more BYOK providers can be +// mixed within a single session and across sub-agents. Combining Providers and +// Models with Provider is rejected. +// +// Experimental: NamedProviderConfig is part of an experimental multi-provider +// BYOK surface and may change or be removed in future SDK or CLI releases. +type NamedProviderConfig struct { + // Name is the stable identifier referenced by ProviderModelConfig.Provider. + // Must not contain "/". + Name string `json:"name"` + // Type is the provider type: "openai", "azure", or "anthropic". Defaults to "openai". + Type string `json:"type,omitempty"` + // WireAPI is the API format (openai/azure only): "completions" or "responses". Defaults to "completions". + WireAPI string `json:"wireApi,omitempty"` + // BaseURL is the API endpoint URL. + BaseURL string `json:"baseUrl"` + // APIKey is the API key. Optional for local providers like Ollama. + APIKey string `json:"apiKey,omitempty"` + // BearerToken for authentication. Sets the Authorization header directly. + // Takes precedence over APIKey when both are set. + BearerToken string `json:"bearerToken,omitempty"` + // Azure contains Azure-specific options. + Azure *AzureProviderOptions `json:"azure,omitempty"` + // Headers are custom HTTP headers included in all outbound provider requests. + Headers map[string]string `json:"headers,omitempty"` +} + +// ProviderModelConfig is a BYOK model definition that references a +// NamedProviderConfig by name and is added to the session's selectable model +// list. The session-wide selection id is the provider-qualified "provider/id". +// +// Experimental: ProviderModelConfig is part of an experimental multi-provider +// BYOK surface and may change or be removed in future SDK or CLI releases. +type ProviderModelConfig struct { + // ID is the provider-local model id, unique within its provider. + ID string `json:"id"` + // Provider is the name of the NamedProviderConfig that serves this model. + Provider string `json:"provider"` + // WireModel is the model name sent to the provider API for inference. Defaults to ID. + WireModel string `json:"wireModel,omitempty"` + // ModelID is the well-known base model id used for behavior/capability/config lookup. Defaults to ID. + ModelID string `json:"modelId,omitempty"` + // Name is the display name for model pickers. Defaults to the provider-qualified selection id. + Name string `json:"name,omitempty"` + // MaxPromptTokens is the maximum prompt/input tokens for the model. + MaxPromptTokens int `json:"maxPromptTokens,omitempty"` + // MaxContextWindowTokens is the maximum context window tokens for the model. + MaxContextWindowTokens int `json:"maxContextWindowTokens,omitempty"` + // MaxOutputTokens is the maximum output tokens for the model. + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + // Capabilities holds optional capability overrides for the synthesized model. + Capabilities *rpc.ModelCapabilitiesOverride `json:"capabilities,omitempty"` +} + // ToolBinaryResult represents binary payloads returned by tools. type ToolBinaryResult struct { Data string `json:"data"` @@ -1721,6 +1804,8 @@ type createSessionRequest struct { ExcludedTools []string `json:"excludedTools,omitempty"` ToolFilterPrecedence *rpc.OptionsUpdateToolFilterPrecedence `json:"toolFilterPrecedence,omitempty"` Provider *ProviderConfig `json:"provider,omitempty"` + Providers []NamedProviderConfig `json:"providers,omitempty"` + Models []ProviderModelConfig `json:"models,omitempty"` EnableSessionTelemetry *bool `json:"enableSessionTelemetry,omitempty"` SkipCustomInstructions *bool `json:"skipCustomInstructions,omitempty"` CustomAgentsLocalOnly *bool `json:"customAgentsLocalOnly,omitempty"` @@ -1800,6 +1885,8 @@ type resumeSessionRequest struct { ExcludedTools []string `json:"excludedTools,omitempty"` ToolFilterPrecedence *rpc.OptionsUpdateToolFilterPrecedence `json:"toolFilterPrecedence,omitempty"` Provider *ProviderConfig `json:"provider,omitempty"` + Providers []NamedProviderConfig `json:"providers,omitempty"` + Models []ProviderModelConfig `json:"models,omitempty"` EnableSessionTelemetry *bool `json:"enableSessionTelemetry,omitempty"` SkipCustomInstructions *bool `json:"skipCustomInstructions,omitempty"` CustomAgentsLocalOnly *bool `json:"customAgentsLocalOnly,omitempty"` diff --git a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java index 5697c7060..66d3e4344 100644 --- a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java +++ b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java @@ -113,6 +113,8 @@ static CreateSessionRequest buildCreateRequest(SessionConfig config, String sess request.setAvailableTools(config.getAvailableTools()); request.setExcludedTools(config.getExcludedTools()); request.setProvider(config.getProvider()); + request.setProviders(config.getProviders()); + request.setModels(config.getModels()); config.getEnableSessionTelemetry().ifPresent(request::setEnableSessionTelemetry); if (config.getOnUserInputRequest() != null) { request.setRequestUserInput(true); @@ -225,6 +227,8 @@ static ResumeSessionRequest buildResumeRequest(String sessionId, ResumeSessionCo request.setAvailableTools(config.getAvailableTools()); request.setExcludedTools(config.getExcludedTools()); request.setProvider(config.getProvider()); + request.setProviders(config.getProviders()); + request.setModels(config.getModels()); config.getEnableSessionTelemetry().ifPresent(request::setEnableSessionTelemetry); if (config.getOnUserInputRequest() != null) { request.setRequestUserInput(true); diff --git a/java/src/main/java/com/github/copilot/rpc/AgentInfo.java b/java/src/main/java/com/github/copilot/rpc/AgentInfo.java index 84e512644..1f6f1688f 100644 --- a/java/src/main/java/com/github/copilot/rpc/AgentInfo.java +++ b/java/src/main/java/com/github/copilot/rpc/AgentInfo.java @@ -24,6 +24,9 @@ public class AgentInfo { @JsonProperty("description") private String description; + @JsonProperty("model") + private String model; + /** * Gets the unique identifier of the agent. * @@ -86,4 +89,26 @@ public AgentInfo setDescription(String description) { this.description = description; return this; } + + /** + * Gets the preferred model id for this agent. When omitted, the agent inherits + * the outer agent's model. + * + * @return the preferred model id, or {@code null} if unset + */ + public String getModel() { + return model; + } + + /** + * Sets the preferred model id for this agent. + * + * @param model + * the preferred model id + * @return this instance for chaining + */ + public AgentInfo setModel(String model) { + this.model = model; + return this; + } } diff --git a/java/src/main/java/com/github/copilot/rpc/CreateSessionRequest.java b/java/src/main/java/com/github/copilot/rpc/CreateSessionRequest.java index 7211cc36c..42a431f49 100644 --- a/java/src/main/java/com/github/copilot/rpc/CreateSessionRequest.java +++ b/java/src/main/java/com/github/copilot/rpc/CreateSessionRequest.java @@ -11,6 +11,8 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.github.copilot.CopilotExperimental; + /** * Internal request object for creating a new session. *

@@ -60,6 +62,12 @@ public final class CreateSessionRequest { @JsonProperty("provider") private ProviderConfig provider; + @JsonProperty("providers") + private List providers; + + @JsonProperty("models") + private List models; + @JsonProperty("enableSessionTelemetry") private Boolean enableSessionTelemetry; @@ -313,6 +321,30 @@ public void setProvider(ProviderConfig provider) { this.provider = provider; } + /** Gets the named provider connections. @return the named providers */ + @CopilotExperimental + public List getProviders() { + return providers; + } + + /** Sets the named provider connections. @param providers the named providers */ + @CopilotExperimental + public void setProviders(List providers) { + this.providers = providers; + } + + /** Gets the BYOK model definitions. @return the models */ + @CopilotExperimental + public List getModels() { + return models; + } + + /** Sets the BYOK model definitions. @param models the models */ + @CopilotExperimental + public void setModels(List models) { + this.models = models; + } + /** Gets enable session telemetry flag. @return the flag */ public Boolean getEnableSessionTelemetry() { return enableSessionTelemetry; diff --git a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java new file mode 100644 index 000000000..dbc157739 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java @@ -0,0 +1,257 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.Collections; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import com.github.copilot.CopilotExperimental; + +/** + * A named BYOK (Bring Your Own Key) provider connection in the multi-provider + * registry. + *

+ * Unlike {@link ProviderConfig}, which routes the entire session through a + * single provider, named providers are additive: the session keeps its default + * Copilot routing and exposes these providers' models alongside it. Models are + * attached via {@link ProviderModelConfig}, which references a provider by + * {@link #getName() name}. All setter methods return {@code this} for method + * chaining. + *

+ * Experimental. Multi-provider BYOK configuration is + * experimental and may change or be removed in future SDK or CLI releases. + * + *

Example Usage

+ * + *
{@code
+ * var provider = new NamedProviderConfig().setName("my-openai").setType("openai")
+ * 		.setBaseUrl("https://api.openai.com/v1").setApiKey("sk-...");
+ * }
+ * + * @see SessionConfig#setProviders(java.util.List) + * @see ProviderModelConfig + * @since 1.0.0 + */ +@CopilotExperimental +@JsonInclude(JsonInclude.Include.NON_NULL) +public class NamedProviderConfig { + + @JsonProperty("name") + private String name; + + @JsonProperty("type") + private String type; + + @JsonProperty("wireApi") + private String wireApi; + + @JsonProperty("baseUrl") + private String baseUrl; + + @JsonProperty("apiKey") + private String apiKey; + + @JsonProperty("bearerToken") + private String bearerToken; + + @JsonProperty("azure") + private AzureOptions azure; + + @JsonProperty("headers") + private Map headers; + + /** + * Gets the unique provider name. + * + * @return the provider name + */ + public String getName() { + return name; + } + + /** + * Sets the unique provider name. + *

+ * Referenced by {@link ProviderModelConfig#setProvider(String)} to attach + * models to this connection. + * + * @param name + * the provider name + * @return this config for method chaining + */ + public NamedProviderConfig setName(String name) { + this.name = name; + return this; + } + + /** + * Gets the provider type. + * + * @return the provider type (e.g., "openai", "azure", "anthropic") + */ + public String getType() { + return type; + } + + /** + * Sets the provider type. + *

+ * Supported types include: + *

    + *
  • "openai" - OpenAI API
  • + *
  • "azure" - Azure OpenAI Service
  • + *
  • "anthropic" - Anthropic API
  • + *
+ * + * @param type + * the provider type + * @return this config for method chaining + */ + public NamedProviderConfig setType(String type) { + this.type = type; + return this; + } + + /** + * Gets the wire API format. + * + * @return the wire API format + */ + public String getWireApi() { + return wireApi; + } + + /** + * Sets the wire API format (openai/azure only). + *

+ * Either "completions" or "responses". Defaults to "completions". + * + * @param wireApi + * the wire API format + * @return this config for method chaining + */ + public NamedProviderConfig setWireApi(String wireApi) { + this.wireApi = wireApi; + return this; + } + + /** + * Gets the base URL for the API. + * + * @return the API base URL + */ + public String getBaseUrl() { + return baseUrl; + } + + /** + * Sets the base URL for the API. + *

+ * For OpenAI, this is typically "https://api.openai.com/v1". + * + * @param baseUrl + * the API base URL + * @return this config for method chaining + */ + public NamedProviderConfig setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** + * Gets the API key. + * + * @return the API key + */ + public String getApiKey() { + return apiKey; + } + + /** + * Sets the API key for authentication. Optional for local providers like + * Ollama. + * + * @param apiKey + * the API key + * @return this config for method chaining + */ + public NamedProviderConfig setApiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Gets the bearer token. + * + * @return the bearer token + */ + public String getBearerToken() { + return bearerToken; + } + + /** + * Sets a bearer token for authentication. + *

+ * Sets the {@code Authorization} header directly and takes precedence over + * {@link #setApiKey(String)} when both are set. + *

+ * Note: The bearer token is a static token + * string. The SDK does not refresh this token automatically. + * + * @param bearerToken + * the bearer token + * @return this config for method chaining + */ + public NamedProviderConfig setBearerToken(String bearerToken) { + this.bearerToken = bearerToken; + return this; + } + + /** + * Gets the Azure-specific options. + * + * @return the Azure options + */ + public AzureOptions getAzure() { + return azure; + } + + /** + * Sets Azure-specific options for Azure OpenAI Service. + * + * @param azure + * the Azure options + * @return this config for method chaining + * @see AzureOptions + */ + public NamedProviderConfig setAzure(AzureOptions azure) { + this.azure = azure; + return this; + } + + /** + * Gets the custom HTTP headers for outbound provider requests. + * + * @return the headers map, or {@code null} if not set + */ + public Map getHeaders() { + return headers == null ? null : Collections.unmodifiableMap(headers); + } + + /** + * Sets custom HTTP headers to include in outbound provider requests. + * + * @param headers + * the headers map + * @return this config for method chaining + */ + public NamedProviderConfig setHeaders(Map headers) { + this.headers = headers; + return this; + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderModelConfig.java b/java/src/main/java/com/github/copilot/rpc/ProviderModelConfig.java new file mode 100644 index 000000000..e191e32d9 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/ProviderModelConfig.java @@ -0,0 +1,298 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.OptionalInt; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import com.github.copilot.CopilotExperimental; + +/** + * A BYOK (Bring Your Own Key) model definition in the multi-provider registry. + *

+ * References a {@link NamedProviderConfig} by {@link #getProvider() provider} + * and becomes selectable under the provider-qualified id {@code provider/id}. + * All setter methods return {@code this} for method chaining. + *

+ * Experimental. Multi-provider BYOK configuration is + * experimental and may change or be removed in future SDK or CLI releases. + * + *

Example Usage

+ * + *
{@code
+ * var model = new ProviderModelConfig().setId("gpt-x").setProvider("my-openai").setWireModel("gpt-x-2025");
+ * }
+ * + * @see SessionConfig#setModels(java.util.List) + * @see NamedProviderConfig + * @since 1.0.0 + */ +@CopilotExperimental +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ProviderModelConfig { + + @JsonProperty("id") + private String id; + + @JsonProperty("provider") + private String provider; + + @JsonProperty("wireModel") + private String wireModel; + + @JsonProperty("modelId") + private String modelId; + + @JsonProperty("name") + private String name; + + @JsonProperty("maxPromptTokens") + private Integer maxPromptTokens; + + @JsonProperty("maxContextWindowTokens") + private Integer maxContextWindowTokens; + + @JsonProperty("maxOutputTokens") + private Integer maxOutputTokens; + + @JsonProperty("capabilities") + private ModelCapabilitiesOverride capabilities; + + /** + * Gets the model identifier. + * + * @return the model id + */ + public String getId() { + return id; + } + + /** + * Sets the model identifier, unique within its provider. + *

+ * Combined with {@link #getProvider() provider} to form the selection id + * {@code provider/id}. + * + * @param id + * the model id + * @return this config for method chaining + */ + public ProviderModelConfig setId(String id) { + this.id = id; + return this; + } + + /** + * Gets the name of the provider this model is served by. + * + * @return the provider name + */ + public String getProvider() { + return provider; + } + + /** + * Sets the name of the {@link NamedProviderConfig} this model is served by. + * + * @param provider + * the provider name + * @return this config for method chaining + */ + public ProviderModelConfig setProvider(String provider) { + this.provider = provider; + return this; + } + + /** + * Gets the model name sent to the provider API for inference. + * + * @return the wire model name, or {@code null} if not set + */ + public String getWireModel() { + return wireModel; + } + + /** + * Sets the model name sent to the provider API for inference. + *

+ * Use this when the provider's model name differs from {@link #getId() id}. + * + * @param wireModel + * the wire model name + * @return this config for method chaining + */ + public ProviderModelConfig setWireModel(String wireModel) { + this.wireModel = wireModel; + return this; + } + + /** + * Gets the well-known model ID used to look up agent config and default token + * limits. + * + * @return the model ID, or {@code null} if not set + */ + public String getModelId() { + return modelId; + } + + /** + * Sets the well-known model ID used to look up agent config and default token + * limits. + * + * @param modelId + * the model ID + * @return this config for method chaining + */ + public ProviderModelConfig setModelId(String modelId) { + this.modelId = modelId; + return this; + } + + /** + * Gets the human-readable display name. + * + * @return the display name, or {@code null} if not set + */ + public String getName() { + return name; + } + + /** + * Sets the human-readable display name. + * + * @param name + * the display name + * @return this config for method chaining + */ + public ProviderModelConfig setName(String name) { + this.name = name; + return this; + } + + /** + * Gets the maximum prompt token override. + * + * @return an {@link java.util.OptionalInt} containing the max prompt tokens, or + * {@link java.util.OptionalInt#empty()} if not set + */ + @JsonIgnore + public OptionalInt getMaxPromptTokens() { + return maxPromptTokens == null ? OptionalInt.empty() : OptionalInt.of(maxPromptTokens); + } + + /** + * Sets the maximum prompt tokens override. + * + * @param maxPromptTokens + * the max prompt tokens + * @return this config for method chaining + */ + public ProviderModelConfig setMaxPromptTokens(int maxPromptTokens) { + this.maxPromptTokens = maxPromptTokens; + return this; + } + + /** + * Clears the maxPromptTokens setting, reverting to the default behavior. + * + * @return this config for method chaining + */ + public ProviderModelConfig clearMaxPromptTokens() { + this.maxPromptTokens = null; + return this; + } + + /** + * Gets the maximum context window token override. + * + * @return an {@link java.util.OptionalInt} containing the max context window + * tokens, or {@link java.util.OptionalInt#empty()} if not set + */ + @JsonIgnore + public OptionalInt getMaxContextWindowTokens() { + return maxContextWindowTokens == null ? OptionalInt.empty() : OptionalInt.of(maxContextWindowTokens); + } + + /** + * Sets the maximum context window tokens override. + * + * @param maxContextWindowTokens + * the max context window tokens + * @return this config for method chaining + */ + public ProviderModelConfig setMaxContextWindowTokens(int maxContextWindowTokens) { + this.maxContextWindowTokens = maxContextWindowTokens; + return this; + } + + /** + * Clears the maxContextWindowTokens setting, reverting to the default behavior. + * + * @return this config for method chaining + */ + public ProviderModelConfig clearMaxContextWindowTokens() { + this.maxContextWindowTokens = null; + return this; + } + + /** + * Gets the maximum output token override. + * + * @return an {@link java.util.OptionalInt} containing the max output tokens, or + * {@link java.util.OptionalInt#empty()} if not set + */ + @JsonIgnore + public OptionalInt getMaxOutputTokens() { + return maxOutputTokens == null ? OptionalInt.empty() : OptionalInt.of(maxOutputTokens); + } + + /** + * Sets the maximum output tokens override. + * + * @param maxOutputTokens + * the max output tokens + * @return this config for method chaining + */ + public ProviderModelConfig setMaxOutputTokens(int maxOutputTokens) { + this.maxOutputTokens = maxOutputTokens; + return this; + } + + /** + * Clears the maxOutputTokens setting, reverting to the default behavior. + * + * @return this config for method chaining + */ + public ProviderModelConfig clearMaxOutputTokens() { + this.maxOutputTokens = null; + return this; + } + + /** + * Gets the per-property model capability overrides. + * + * @return the capabilities override, or {@code null} if not set + */ + public ModelCapabilitiesOverride getCapabilities() { + return capabilities; + } + + /** + * Sets per-property model capability overrides, deep-merged over runtime + * defaults. + * + * @param capabilities + * the capabilities override + * @return this config for method chaining + */ + public ProviderModelConfig setCapabilities(ModelCapabilitiesOverride capabilities) { + this.capabilities = capabilities; + return this; + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java b/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java index 680d337ec..fa900aceb 100644 --- a/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java @@ -13,6 +13,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonIgnore; +import com.github.copilot.CopilotExperimental; import com.github.copilot.generated.SessionEvent; import java.util.Optional; @@ -45,6 +46,8 @@ public class ResumeSessionConfig { private List availableTools; private List excludedTools; private ProviderConfig provider; + private List providers; + private List models; private Boolean enableSessionTelemetry; private Boolean skipCustomInstructions; private Boolean customAgentsLocalOnly; @@ -254,6 +257,58 @@ public ResumeSessionConfig setProvider(ProviderConfig provider) { return this; } + /** + * Gets the named BYOK provider connections. + * + * @return the named provider connections, or {@code null} if not set + */ + @CopilotExperimental + public List getProviders() { + return providers; + } + + /** + * Re-supplies the named BYOK provider connections on resume (additive + * multi-provider registry). + *

+ * Attach models referencing these connections with {@link #setModels(List)}. + * + * @param providers + * the named provider connections + * @return this config instance for method chaining + * @see NamedProviderConfig + */ + @CopilotExperimental + public ResumeSessionConfig setProviders(List providers) { + this.providers = providers; + return this; + } + + /** + * Gets the BYOK model definitions. + * + * @return the model definitions, or {@code null} if not set + */ + @CopilotExperimental + public List getModels() { + return models; + } + + /** + * Re-supplies the BYOK model definitions on resume, each referencing a named + * provider supplied via {@link #setProviders(List)}. + * + * @param models + * the model definitions + * @return this config instance for method chaining + * @see ProviderModelConfig + */ + @CopilotExperimental + public ResumeSessionConfig setModels(List models) { + this.models = models; + return this; + } + /** * Enables or disables internal session telemetry for this session. When * {@code false}, disables session telemetry. When unset (the default) or @@ -1548,6 +1603,8 @@ public ResumeSessionConfig clone() { copy.availableTools = this.availableTools != null ? new ArrayList<>(this.availableTools) : null; copy.excludedTools = this.excludedTools != null ? new ArrayList<>(this.excludedTools) : null; copy.provider = this.provider; + copy.providers = this.providers != null ? new ArrayList<>(this.providers) : null; + copy.models = this.models != null ? new ArrayList<>(this.models) : null; copy.enableSessionTelemetry = this.enableSessionTelemetry; copy.reasoningEffort = this.reasoningEffort; copy.reasoningSummary = this.reasoningSummary; diff --git a/java/src/main/java/com/github/copilot/rpc/ResumeSessionRequest.java b/java/src/main/java/com/github/copilot/rpc/ResumeSessionRequest.java index e88be7a9f..2067a291c 100644 --- a/java/src/main/java/com/github/copilot/rpc/ResumeSessionRequest.java +++ b/java/src/main/java/com/github/copilot/rpc/ResumeSessionRequest.java @@ -11,6 +11,8 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.github.copilot.CopilotExperimental; + /** * Internal request object for resuming an existing session. *

@@ -62,6 +64,12 @@ public final class ResumeSessionRequest { @JsonProperty("provider") private ProviderConfig provider; + @JsonProperty("providers") + private List providers; + + @JsonProperty("models") + private List models; + @JsonProperty("enableSessionTelemetry") private Boolean enableSessionTelemetry; @@ -318,6 +326,30 @@ public void setProvider(ProviderConfig provider) { this.provider = provider; } + /** Gets the named provider connections. @return the named providers */ + @CopilotExperimental + public List getProviders() { + return providers; + } + + /** Sets the named provider connections. @param providers the named providers */ + @CopilotExperimental + public void setProviders(List providers) { + this.providers = providers; + } + + /** Gets the BYOK model definitions. @return the models */ + @CopilotExperimental + public List getModels() { + return models; + } + + /** Sets the BYOK model definitions. @param models the models */ + @CopilotExperimental + public void setModels(List models) { + this.models = models; + } + /** Gets enable session telemetry flag. @return the flag */ public Boolean getEnableSessionTelemetry() { return enableSessionTelemetry; diff --git a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java index ded429867..ef483c410 100644 --- a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java @@ -13,6 +13,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonIgnore; +import com.github.copilot.CopilotExperimental; import com.github.copilot.generated.SessionEvent; import java.util.Optional; @@ -49,6 +50,8 @@ public class SessionConfig { private List availableTools; private List excludedTools; private ProviderConfig provider; + private List providers; + private List models; private Boolean enableSessionTelemetry; private Boolean skipCustomInstructions; private Boolean customAgentsLocalOnly; @@ -355,6 +358,59 @@ public SessionConfig setProvider(ProviderConfig provider) { return this; } + /** + * Gets the named BYOK provider connections. + * + * @return the named provider connections, or {@code null} if not set + */ + @CopilotExperimental + public List getProviders() { + return providers; + } + + /** + * Sets the named BYOK provider connections (additive multi-provider registry). + *

+ * Unlike {@link #setProvider(ProviderConfig)}, these do not switch the whole + * session to BYOK; they are exposed alongside the default Copilot routing. + * Attach models referencing these connections with {@link #setModels(List)}. + * + * @param providers + * the named provider connections + * @return this config instance for method chaining + * @see NamedProviderConfig + */ + @CopilotExperimental + public SessionConfig setProviders(List providers) { + this.providers = providers; + return this; + } + + /** + * Gets the BYOK model definitions. + * + * @return the model definitions, or {@code null} if not set + */ + @CopilotExperimental + public List getModels() { + return models; + } + + /** + * Sets the BYOK model definitions, each referencing a named provider supplied + * via {@link #setProviders(List)}. + * + * @param models + * the model definitions + * @return this config instance for method chaining + * @see ProviderModelConfig + */ + @CopilotExperimental + public SessionConfig setModels(List models) { + this.models = models; + return this; + } + /** * Enables or disables internal session telemetry for this session. When * {@code false}, disables session telemetry. When unset (the default) or @@ -1671,6 +1727,8 @@ public SessionConfig clone() { copy.availableTools = this.availableTools != null ? new ArrayList<>(this.availableTools) : null; copy.excludedTools = this.excludedTools != null ? new ArrayList<>(this.excludedTools) : null; copy.provider = this.provider; + copy.providers = this.providers != null ? new ArrayList<>(this.providers) : null; + copy.models = this.models != null ? new ArrayList<>(this.models) : null; copy.enableSessionTelemetry = this.enableSessionTelemetry; copy.skipCustomInstructions = this.skipCustomInstructions; copy.customAgentsLocalOnly = this.customAgentsLocalOnly; diff --git a/java/src/test/java/com/github/copilot/AgentInfoTest.java b/java/src/test/java/com/github/copilot/AgentInfoTest.java index 3b15f5582..40654292f 100644 --- a/java/src/test/java/com/github/copilot/AgentInfoTest.java +++ b/java/src/test/java/com/github/copilot/AgentInfoTest.java @@ -21,6 +21,7 @@ void defaultValuesAreNull() { assertNull(agent.getName()); assertNull(agent.getDisplayName()); assertNull(agent.getDescription()); + assertNull(agent.getModel()); } @Test @@ -44,14 +45,22 @@ void descriptionGetterSetter() { assertEquals("Helps with coding tasks", agent.getDescription()); } + @Test + void modelGetterSetter() { + var agent = new AgentInfo(); + agent.setModel("alpha/sonnet"); + assertEquals("alpha/sonnet", agent.getModel()); + } + @Test void fluentChainingReturnsThis() { var agent = new AgentInfo().setName("coder").setDisplayName("Code Assistant") - .setDescription("Helps with coding tasks"); + .setDescription("Helps with coding tasks").setModel("alpha/sonnet"); assertEquals("coder", agent.getName()); assertEquals("Code Assistant", agent.getDisplayName()); assertEquals("Helps with coding tasks", agent.getDescription()); + assertEquals("alpha/sonnet", agent.getModel()); } @Test @@ -60,5 +69,6 @@ void fluentChainingReturnsSameInstance() { assertSame(agent, agent.setName("test")); assertSame(agent, agent.setDisplayName("Test")); assertSame(agent, agent.setDescription("A test agent")); + assertSame(agent, agent.setModel("alpha/sonnet")); } } diff --git a/java/src/test/java/com/github/copilot/MultiProviderConfigTest.java b/java/src/test/java/com/github/copilot/MultiProviderConfigTest.java new file mode 100644 index 000000000..171e525cf --- /dev/null +++ b/java/src/test/java/com/github/copilot/MultiProviderConfigTest.java @@ -0,0 +1,190 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import com.github.copilot.rpc.AzureOptions; +import com.github.copilot.rpc.NamedProviderConfig; +import com.github.copilot.rpc.ProviderModelConfig; +import com.github.copilot.rpc.ResumeSessionConfig; +import com.github.copilot.rpc.SessionConfig; + +/** + * Tests for the additive multi-provider BYOK registry: + * {@link NamedProviderConfig}, {@link ProviderModelConfig}, and their + * integration with {@link SessionConfig} and {@link ResumeSessionConfig}. + */ +public class MultiProviderConfigTest { + + private static final ObjectMapper MAPPER = JsonRpcClient.getObjectMapper(); + + @Test + void testNamedProviderConfigDefaultsAreNull() { + var provider = new NamedProviderConfig(); + + assertNull(provider.getName()); + assertNull(provider.getType()); + assertNull(provider.getWireApi()); + assertNull(provider.getBaseUrl()); + assertNull(provider.getApiKey()); + assertNull(provider.getBearerToken()); + assertNull(provider.getAzure()); + assertNull(provider.getHeaders()); + } + + @Test + void testNamedProviderConfigFluentSettersReturnSameInstance() { + var provider = new NamedProviderConfig(); + + NamedProviderConfig result = provider.setName("my-openai").setType("openai").setWireApi("responses") + .setBaseUrl("https://api.openai.com/v1").setApiKey("sk-test").setBearerToken("bearer") + .setAzure(new AzureOptions()).setHeaders(Map.of("X-Custom", "v")); + + assertEquals(provider, result); + } + + @Test + void testSerializeNamedProviderConfig() throws Exception { + var provider = new NamedProviderConfig().setName("my-openai").setType("openai").setWireApi("responses") + .setBaseUrl("https://api.openai.com/v1").setApiKey("sk-test"); + + JsonNode json = MAPPER.valueToTree(provider); + + assertEquals("my-openai", json.get("name").asText()); + assertEquals("openai", json.get("type").asText()); + assertEquals("responses", json.get("wireApi").asText()); + assertEquals("https://api.openai.com/v1", json.get("baseUrl").asText()); + assertEquals("sk-test", json.get("apiKey").asText()); + // Null fields must be omitted (NON_NULL) + assertTrue(json.path("bearerToken").isMissingNode()); + assertTrue(json.path("azure").isMissingNode()); + assertTrue(json.path("headers").isMissingNode()); + } + + @Test + void testProviderModelConfigDefaultsAreNull() { + var model = new ProviderModelConfig(); + + assertNull(model.getId()); + assertNull(model.getProvider()); + assertNull(model.getWireModel()); + assertNull(model.getModelId()); + assertNull(model.getName()); + assertTrue(model.getMaxPromptTokens().isEmpty()); + assertTrue(model.getMaxContextWindowTokens().isEmpty()); + assertTrue(model.getMaxOutputTokens().isEmpty()); + assertNull(model.getCapabilities()); + } + + @Test + void testSerializeProviderModelConfig() throws Exception { + var model = new ProviderModelConfig().setId("gpt-x").setProvider("my-openai").setWireModel("gpt-x-2025") + .setModelId("gpt-4o").setName("My GPT-X").setMaxPromptTokens(100_000).setMaxContextWindowTokens(128_000) + .setMaxOutputTokens(4096); + + JsonNode json = MAPPER.valueToTree(model); + + assertEquals("gpt-x", json.get("id").asText()); + assertEquals("my-openai", json.get("provider").asText()); + assertEquals("gpt-x-2025", json.get("wireModel").asText()); + assertEquals("gpt-4o", json.get("modelId").asText()); + assertEquals("My GPT-X", json.get("name").asText()); + assertEquals(100_000, json.get("maxPromptTokens").asInt()); + assertEquals(128_000, json.get("maxContextWindowTokens").asInt()); + assertEquals(4096, json.get("maxOutputTokens").asInt()); + assertTrue(json.path("capabilities").isMissingNode()); + + // Round-trip + ProviderModelConfig deserialized = MAPPER.readValue(MAPPER.writeValueAsString(model), + ProviderModelConfig.class); + assertEquals("gpt-x", deserialized.getId()); + assertEquals("my-openai", deserialized.getProvider()); + assertEquals(100_000, deserialized.getMaxPromptTokens().getAsInt()); + assertEquals(128_000, deserialized.getMaxContextWindowTokens().getAsInt()); + assertEquals(4096, deserialized.getMaxOutputTokens().getAsInt()); + } + + @Test + void testSessionConfigWithProvidersAndModels() throws Exception { + var config = new SessionConfig().setModel("gpt-4") + .setProviders(List.of(new NamedProviderConfig().setName("my-openai").setType("openai") + .setBaseUrl("https://api.openai.com/v1").setApiKey("sk-test"))) + .setModels(List.of(new ProviderModelConfig().setId("gpt-x").setProvider("my-openai"))); + + JsonNode json = MAPPER.valueToTree(config); + + assertNotNull(json.get("providers")); + assertEquals(1, json.get("providers").size()); + assertEquals("my-openai", json.get("providers").get(0).get("name").asText()); + assertNotNull(json.get("models")); + assertEquals("gpt-x", json.get("models").get(0).get("id").asText()); + assertEquals("my-openai", json.get("models").get(0).get("provider").asText()); + } + + @Test + void testSessionConfigWithoutProvidersOmitsFields() throws Exception { + var config = new SessionConfig().setModel("gpt-4"); + + JsonNode json = MAPPER.valueToTree(config); + + assertTrue(json.path("providers").isMissingNode()); + assertTrue(json.path("models").isMissingNode()); + } + + @Test + void testSessionConfigCopyPreservesProvidersAndModels() { + var config = new SessionConfig().setProviders(List.of(new NamedProviderConfig().setName("my-azure"))) + .setModels(List.of(new ProviderModelConfig().setId("deploy-1").setProvider("my-azure"))); + + SessionConfig copy = config.clone(); + + assertNotNull(copy.getProviders()); + assertEquals(1, copy.getProviders().size()); + assertEquals("my-azure", copy.getProviders().get(0).getName()); + assertNotNull(copy.getModels()); + assertEquals("deploy-1", copy.getModels().get(0).getId()); + } + + @Test + void testResumeSessionConfigWithProvidersAndModels() throws Exception { + var config = new ResumeSessionConfig() + .setProviders(List.of(new NamedProviderConfig().setName("my-azure").setType("azure") + .setBaseUrl("https://example.openai.azure.com") + .setAzure(new AzureOptions().setApiVersion("2024-10-21")))) + .setModels(List + .of(new ProviderModelConfig().setId("deploy-1").setProvider("my-azure").setModelId("gpt-4o"))); + + JsonNode json = MAPPER.valueToTree(config); + + assertNotNull(json.get("providers")); + assertEquals("my-azure", json.get("providers").get(0).get("name").asText()); + assertEquals("2024-10-21", json.get("providers").get(0).get("azure").get("apiVersion").asText()); + assertNotNull(json.get("models")); + assertEquals("deploy-1", json.get("models").get(0).get("id").asText()); + assertEquals("gpt-4o", json.get("models").get(0).get("modelId").asText()); + } + + @Test + void testResumeSessionConfigWithoutProvidersOmitsFields() throws Exception { + var config = new ResumeSessionConfig().setStreaming(true); + + JsonNode json = MAPPER.valueToTree(config); + + assertTrue(json.path("providers").isMissingNode()); + assertTrue(json.path("models").isMissingNode()); + } +} diff --git a/java/src/test/java/com/github/copilot/MultiProviderRegistryE2ETest.java b/java/src/test/java/com/github/copilot/MultiProviderRegistryE2ETest.java new file mode 100644 index 000000000..095543881 --- /dev/null +++ b/java/src/test/java/com/github/copilot/MultiProviderRegistryE2ETest.java @@ -0,0 +1,223 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import com.github.copilot.rpc.AgentInfo; +import com.github.copilot.rpc.CustomAgentConfig; +import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.NamedProviderConfig; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.ProviderModelConfig; +import com.github.copilot.rpc.SessionConfig; + +/** + * End-to-end coverage for the experimental multi-provider BYOK registry + * ({@code SessionConfig.providers} / {@code SessionConfig.models}). Validates + * that several named providers, several models per provider, and custom agents + * bound to those provider-qualified models can coexist in one session, be + * launched, and route inference to the configured provider with the configured + * wire model and headers. + */ +public class MultiProviderRegistryE2ETest { + + private static E2ETestContext ctx; + + @BeforeAll + static void setup() throws Exception { + ctx = E2ETestContext.create(); + } + + @AfterAll + static void teardown() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + /** + * Builds a heterogeneous registry: two providers of different types, with + * multiple models each. Provider-qualified selection ids are + * {@code alpha/sonnet}, {@code alpha/haiku}, {@code beta/opus}, + * {@code beta/haiku}. + */ + private static List registryProviders() { + return List.of( + new NamedProviderConfig().setName("alpha").setType("openai").setWireApi("completions") + .setBaseUrl("https://alpha.example.test/v1").setApiKey("alpha-secret") + .setHeaders(Map.of("X-Provider", "alpha")), + new NamedProviderConfig().setName("beta").setType("anthropic").setBaseUrl("https://beta.example.test") + .setBearerToken("beta-bearer").setHeaders(Map.of("X-Provider", "beta"))); + } + + private static List registryModels() { + return List.of( + new ProviderModelConfig().setId("sonnet").setProvider("alpha").setWireModel("byok-gpt-4o") + .setMaxPromptTokens(111111), + new ProviderModelConfig().setId("haiku").setProvider("alpha").setWireModel("byok-gpt-4o-mini"), + new ProviderModelConfig().setId("opus").setProvider("beta").setWireModel("byok-claude-3-opus"), + new ProviderModelConfig().setId("haiku").setProvider("beta").setWireModel("byok-claude-3-haiku")); + } + + private static List registryAgents() { + return List.of( + new CustomAgentConfig().setName("orchestrator").setDisplayName("Orchestrator") + .setDescription("Top-level planner.").setPrompt("Plan and delegate.").setModel("alpha/sonnet"), + new CustomAgentConfig().setName("researcher").setDisplayName("Researcher") + .setDescription("Deep research subagent.").setPrompt("Research thoroughly.") + .setModel("beta/opus"), + new CustomAgentConfig().setName("fast-helper").setDisplayName("Fast Helper") + .setDescription("Quick subagent.").setPrompt("Answer quickly.").setModel("alpha/haiku"), + new CustomAgentConfig().setName("summarizer").setDisplayName("Summarizer") + .setDescription("Summarizing subagent.").setPrompt("Summarize.").setModel("beta/haiku")); + } + + @Test + void testShouldRegisterMultipleProvidersWithCustomAgentsBoundToTheirModels() throws Exception { + ctx.configureForTest("multi_provider_registry", + "should_register_multiple_providers_with_custom_agents_bound_to_their_models"); + + try (CopilotClient client = ctx.createClient()) { + CopilotSession session = client + .createSession(new SessionConfig().setProviders(registryProviders()).setModels(registryModels()) + .setCustomAgents(registryAgents()).setOnPermissionRequest(PermissionHandler.APPROVE_ALL)) + .get(); + + List agents = session.listAgents().get(30, TimeUnit.SECONDS); + + // All four custom agents coexist in a single session. + assertEquals(4, agents.size(), "Expected all four custom agents to coexist"); + + // Each agent is bound to its configured provider-qualified BYOK model. + assertAgentModel(agents, "orchestrator", "alpha/sonnet", "Orchestrator", "Top-level planner."); + assertAgentModel(agents, "researcher", "beta/opus", "Researcher", "Deep research subagent."); + assertAgentModel(agents, "fast-helper", "alpha/haiku", "Fast Helper", "Quick subagent."); + assertAgentModel(agents, "summarizer", "beta/haiku", "Summarizer", "Summarizing subagent."); + + // Models from BOTH providers are represented, proving the two + // providers and their models coexist within the same session. + Set boundModels = new HashSet<>(); + for (AgentInfo agent : agents) { + boundModels.add(agent.getModel()); + } + assertTrue(boundModels.stream().anyMatch(m -> m != null && m.startsWith("alpha/")), + "Expected a model from provider 'alpha' to be represented"); + assertTrue(boundModels.stream().anyMatch(m -> m != null && m.startsWith("beta/")), + "Expected a model from provider 'beta' to be represented"); + } + } + + @Test + void testShouldRouteAlphaSonnetTurnToItsProviderAndWireModel() throws Exception { + assertRouting("should_route_alpha_sonnet_turn_to_its_provider_and_wire_model", "alpha/sonnet", "byok-gpt-4o", + "alpha"); + } + + @Test + void testShouldRouteAlphaHaikuTurnToItsProviderAndWireModel() throws Exception { + assertRouting("should_route_alpha_haiku_turn_to_its_provider_and_wire_model", "alpha/haiku", "byok-gpt-4o-mini", + "alpha"); + } + + @Test + void testShouldRouteDeltaTurboTurnToItsProviderAndWireModel() throws Exception { + assertRouting("should_route_delta_turbo_turn_to_its_provider_and_wire_model", "delta/turbo", "byok-gpt-4-turbo", + "delta"); + } + + /** + * Selects {@code selectionId} in a session whose registry holds two + * OpenAI-compatible providers (each pointed at the replay proxy), runs a turn, + * and asserts the captured request used the model's configured wire model and + * carried the owning provider's header and credential. + */ + private void assertRouting(String snapshot, String selectionId, String expectedWireModel, + String expectedProviderHeader) throws Exception { + ctx.configureForTest("multi_provider_registry", snapshot); + + try (CopilotClient client = ctx.createClient()) { + // Two OpenAI-compatible providers, both pointed at the replay proxy + // so their /chat/completions traffic is captured. They are + // distinguished on the wire by their per-provider X-Provider header. + // "alpha" carries two models (multiple models per provider); + // "delta" carries one. + List providers = List.of( + new NamedProviderConfig().setName("alpha").setType("openai").setWireApi("completions") + .setBaseUrl(ctx.getProxyUrl()).setApiKey("alpha-secret") + .setHeaders(Map.of("X-Provider", "alpha")), + new NamedProviderConfig().setName("delta").setType("openai").setWireApi("completions") + .setBaseUrl(ctx.getProxyUrl()).setApiKey("delta-secret") + .setHeaders(Map.of("X-Provider", "delta"))); + List models = List.of( + new ProviderModelConfig().setId("sonnet").setProvider("alpha").setWireModel("byok-gpt-4o"), + new ProviderModelConfig().setId("haiku").setProvider("alpha").setWireModel("byok-gpt-4o-mini"), + new ProviderModelConfig().setId("turbo").setProvider("delta").setWireModel("byok-gpt-4-turbo")); + + CopilotSession session = client.createSession(new SessionConfig().setModel(selectionId) + .setProviders(providers).setModels(models).setOnPermissionRequest(PermissionHandler.APPROVE_ALL)) + .get(); + + session.sendAndWait(new MessageOptions().setPrompt("What is 5+5?")).get(30, TimeUnit.SECONDS); + + List> exchanges = ctx.getExchanges(); + assertEquals(1, exchanges.size(), "Expected exactly one captured /chat/completions exchange"); + Map exchange = exchanges.get(0); + + @SuppressWarnings("unchecked") + Map request = (Map) exchange.get("request"); + + // The wire model sent to the provider is the selected model's wire + // model, not its provider-qualified selection id. + assertEquals(expectedWireModel, request.get("model")); + + // The request carried the owning provider's custom header, proving + // the turn was dispatched against the correct provider connection. + assertEquals(expectedProviderHeader, getHeaderValue(exchange, "X-Provider")); + + // The provider's API key was applied as an Authorization header. + String authorization = getHeaderValue(exchange, "Authorization"); + assertNotNull(authorization, "Expected an Authorization header on the dispatched request"); + assertFalse(authorization.isEmpty(), "Expected a non-empty Authorization header"); + } + } + + private static void assertAgentModel(List agents, String name, String expectedModel, + String expectedDisplayName, String expectedDescription) { + AgentInfo agent = agents.stream().filter(a -> name.equals(a.getName())).findFirst() + .orElseThrow(() -> new AssertionError("Expected an agent named '" + name + "'")); + assertEquals(expectedModel, agent.getModel(), "Unexpected model binding for agent '" + name + "'"); + assertEquals(expectedDisplayName, agent.getDisplayName(), "Unexpected display name for agent '" + name + "'"); + assertEquals(expectedDescription, agent.getDescription(), "Unexpected description for agent '" + name + "'"); + } + + @SuppressWarnings("unchecked") + private static String getHeaderValue(Map exchange, String name) { + Object headersObj = exchange.get("requestHeaders"); + if (!(headersObj instanceof Map headers)) { + return null; + } + for (Map.Entry entry : headers.entrySet()) { + if (entry.getKey() != null && entry.getKey().toString().equalsIgnoreCase(name)) { + Object value = entry.getValue(); + if (value instanceof List list) { + return list.isEmpty() ? null : String.valueOf(list.get(0)); + } + return value != null ? value.toString() : null; + } + } + return null; + } +} diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 4deda08b4..6b4aca13e 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -1221,6 +1221,8 @@ export class CopilotClient { excludedTools: toolFilterOptions.excludedTools, toolFilterPrecedence: toolFilterOptions.toolFilterPrecedence, provider: config.provider, + providers: config.providers, + models: config.models, enableSessionTelemetry: config.enableSessionTelemetry, modelCapabilities: config.modelCapabilities, largeOutput: toWireLargeOutput(config.largeOutput), @@ -1405,6 +1407,8 @@ export class CopilotClient { description: cmd.description, })), provider: config.provider, + providers: config.providers, + models: config.models, modelCapabilities: config.modelCapabilities, largeOutput: toWireLargeOutput(config.largeOutput), requestPermission: diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index a7ebbbde0..9b266fc9c 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -88,10 +88,12 @@ export type { ModelCapabilitiesOverride, ModelInfo, ModelPolicy, + NamedProviderConfig, PermissionHandler, PermissionRequest, PermissionRequestResult, ProviderConfig, + ProviderModelConfig, RemoteSessionMode, ResumeSessionConfig, SectionOverride, diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index bad1c33ad..f198a88b3 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -1745,6 +1745,32 @@ export interface SessionConfigBase { */ provider?: ProviderConfig; + /** + * Named BYOK provider connections (transport + credentials), referenced by + * {@link models} entries via {@link NamedProviderConfig.name}. + * + * Unlike the singular {@link provider} — which makes the entire session BYOK + * and bypasses Copilot API authentication — named providers are **additive**: + * they coexist with Copilot API auth so models from CAPI and one or more BYOK + * providers can be mixed within a single session and across sub-agents. + * Combining `providers`/`models` with {@link provider} is rejected. + * + * @experimental This is part of an experimental multi-provider BYOK surface + * and may change or be removed in future SDK or CLI releases. + */ + providers?: NamedProviderConfig[]; + + /** + * BYOK model definitions added to the session's selectable model list, each + * referencing a `providers[].name`. Each model surfaces under the + * provider-qualified selection id `providerName/id`, so BYOK ids never collide + * with — and cannot shadow — bare CAPI ids; duplicate selection ids are rejected. + * + * @experimental This is part of an experimental multi-provider BYOK surface + * and may change or be removed in future SDK or CLI releases. + */ + models?: ProviderModelConfig[]; + /** * Enables or disables internal session telemetry for this session. * When `false`, disables session telemetry. When omitted (the default) or `true`, @@ -2179,8 +2205,133 @@ export interface ProviderConfig { } /** - * Options for sending a message to a session + * A named BYOK provider connection (transport + credentials only), referenced by + * {@link ProviderModelConfig} entries via {@link NamedProviderConfig.name}. + * + * Unlike the singular, whole-session {@link ProviderConfig} — which bypasses + * Copilot API authentication — named providers are **additive** and coexist with + * Copilot API auth, so CAPI and BYOK models can be mixed within one session and + * across sub-agents. See {@link SessionConfigBase.providers}. + * + * @experimental This type is part of an experimental multi-provider BYOK surface + * and may change or be removed in future SDK or CLI releases. */ +export interface NamedProviderConfig { + /** + * Stable identifier referenced by {@link ProviderModelConfig.provider}. + * Must not contain `/`. + */ + name: string; + + /** + * Provider type. Defaults to "openai" for generic OpenAI-compatible APIs. + */ + type?: "openai" | "azure" | "anthropic"; + + /** + * Wire API format (openai/azure only). Defaults to "completions". + */ + wireApi?: "completions" | "responses"; + + /** + * API endpoint URL. + */ + baseUrl: string; + + /** + * API key. Optional for local providers like Ollama. + */ + apiKey?: string; + + /** + * Bearer token for authentication. Sets the Authorization header directly. + * Takes precedence over {@link apiKey} when both are set. + */ + bearerToken?: string; + + /** + * Azure-specific options. + */ + azure?: { + /** + * API version. When set, uses the versioned deployment route. When + * omitted, uses the GA versionless v1 route. + */ + apiVersion?: string; + }; + + /** + * Custom HTTP headers to include in all outbound requests to the provider. + */ + headers?: Record; +} + +/** + * A BYOK model definition that references a {@link NamedProviderConfig} by name + * and is added to the session's selectable model list. + * + * Each model has three identities: + * - {@link id}: the provider-local model id, unique within its provider. The + * session-wide selection id (shown in the model list and passed to model + * switching) is the provider-qualified `provider/id`. + * - {@link modelId}: the well-known behavior base model used for + * capability/config lookup. Defaults to {@link id}. + * - {@link wireModel}: the model name actually sent to the provider API for + * inference. Defaults to {@link id}. + * + * @experimental This type is part of an experimental multi-provider BYOK surface + * and may change or be removed in future SDK or CLI releases. + */ +export interface ProviderModelConfig { + /** + * Provider-local model id, unique within its provider. The session-wide + * selection id is the provider-qualified `provider/id`. + */ + id: string; + + /** + * Name of the {@link NamedProviderConfig} that serves this model. + */ + provider: string; + + /** + * The model name sent to the provider API for inference. Defaults to {@link id}. + */ + wireModel?: string; + + /** + * Well-known base model id used for behavior/capability/config lookup. + * Defaults to {@link id}. + */ + modelId?: string; + + /** + * Display name for model pickers. Defaults to the provider-qualified + * selection id (`provider/id`). + */ + name?: string; + + /** + * Maximum prompt/input tokens for the model. + */ + maxPromptTokens?: number; + + /** + * Maximum context window tokens for the model. + */ + maxContextWindowTokens?: number; + + /** + * Maximum output tokens for the model. + */ + maxOutputTokens?: number; + + /** + * Optional capability overrides (vision, tool_calls, reasoning, etc.) for + * the synthesized model. + */ + capabilities?: ModelCapabilitiesOverride; +} export interface MessageOptions { /** * The prompt/message to send diff --git a/nodejs/test/e2e/multi_provider_registry.e2e.test.ts b/nodejs/test/e2e/multi_provider_registry.e2e.test.ts new file mode 100644 index 000000000..cd0eb5316 --- /dev/null +++ b/nodejs/test/e2e/multi_provider_registry.e2e.test.ts @@ -0,0 +1,213 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll } from "../../src/index.js"; +import type { + CustomAgentConfig, + NamedProviderConfig, + ProviderModelConfig, +} from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; +import { retry } from "./harness/sdkTestHelper.js"; +import type { ParsedHttpExchange } from "../../../test/harness/replayingCapiProxy"; + +/** + * End-to-end coverage for the experimental multi-provider BYOK registry + * (`providers` / `models` on the session config). Validates that several named + * providers, several models per provider, and custom agents bound to those + * provider-qualified models can coexist in one session, be launched, and route + * inference to the configured provider with the configured wire model and + * headers. + */ +describe("Multi-provider BYOK registry", async () => { + const { copilotClient: client, openAiEndpoint } = await createSdkTestContext(); + + async function waitForExchanges(minimumCount = 1): Promise { + await retry( + `capture ${minimumCount} chat completion request(s)`, + async () => { + const exchanges = await openAiEndpoint.getExchanges(); + expect(exchanges.length).toBeGreaterThanOrEqual(minimumCount); + }, + 1_200 + ); + return openAiEndpoint.getExchanges(); + } + + function getHeader(exchange: ParsedHttpExchange, name: string): string | undefined { + const headers = exchange.requestHeaders ?? {}; + const key = Object.keys(headers).find((k) => k.toLowerCase() === name.toLowerCase()); + if (key === undefined) { + return undefined; + } + const value = headers[key]; + return Array.isArray(value) ? value[0] : value; + } + + // A heterogeneous registry: two providers of different types, with multiple + // models each. Provider-qualified selection ids are alpha/sonnet, + // alpha/haiku, beta/opus, beta/haiku. + const registryProviders: NamedProviderConfig[] = [ + { + name: "alpha", + type: "openai", + wireApi: "completions", + baseUrl: "https://alpha.example.test/v1", + apiKey: "alpha-secret", + headers: { "X-Provider": "alpha" }, + }, + { + name: "beta", + type: "anthropic", + baseUrl: "https://beta.example.test", + bearerToken: "beta-bearer", + headers: { "X-Provider": "beta" }, + }, + ]; + const registryModels: ProviderModelConfig[] = [ + { id: "sonnet", provider: "alpha", wireModel: "byok-gpt-4o", maxPromptTokens: 111111 }, + { id: "haiku", provider: "alpha", wireModel: "byok-gpt-4o-mini" }, + { id: "opus", provider: "beta", wireModel: "byok-claude-3-opus" }, + { id: "haiku", provider: "beta", wireModel: "byok-claude-3-haiku" }, + ]; + const registryAgents: CustomAgentConfig[] = [ + { + name: "orchestrator", + displayName: "Orchestrator", + description: "Top-level planner.", + prompt: "Plan and delegate.", + model: "alpha/sonnet", + }, + { + name: "researcher", + displayName: "Researcher", + description: "Deep research subagent.", + prompt: "Research thoroughly.", + model: "beta/opus", + }, + { + name: "fast-helper", + displayName: "Fast Helper", + description: "Quick subagent.", + prompt: "Answer quickly.", + model: "alpha/haiku", + }, + { + name: "summarizer", + displayName: "Summarizer", + description: "Summarizing subagent.", + prompt: "Summarize.", + model: "beta/haiku", + }, + ]; + + it("should register multiple providers with custom agents bound to their models", async () => { + const session = await client.createSession({ + onPermissionRequest: approveAll, + providers: registryProviders, + models: registryModels, + customAgents: registryAgents, + }); + + try { + const { agents } = await session.rpc.agent.list(); + + // All four custom agents coexist in a single session. + expect(agents.length).toBe(4); + + // Each agent is bound to its configured provider-qualified BYOK model. + const byName = new Map(agents.map((a) => [a.name, a])); + expect(byName.get("orchestrator")?.model).toBe("alpha/sonnet"); + expect(byName.get("researcher")?.model).toBe("beta/opus"); + expect(byName.get("fast-helper")?.model).toBe("alpha/haiku"); + expect(byName.get("summarizer")?.model).toBe("beta/haiku"); + + // Models from BOTH providers are represented, proving the two + // providers and their models coexist within the same session. + const boundModels = agents.map((a) => a.model ?? ""); + expect(boundModels.some((m) => m.startsWith("alpha/"))).toBe(true); + expect(boundModels.some((m) => m.startsWith("beta/"))).toBe(true); + } finally { + await session.disconnect(); + } + }); + + async function assertRouting( + selectionId: string, + expectedWireModel: string, + expectedProviderHeader: string + ): Promise { + // Two OpenAI-compatible providers, both pointed at the replay proxy so + // their /chat/completions traffic is captured. They are distinguished on + // the wire by their per-provider X-Provider header. "alpha" carries two + // models (multiple models per provider); "delta" carries one. + const providers: NamedProviderConfig[] = [ + { + name: "alpha", + type: "openai", + wireApi: "completions", + baseUrl: openAiEndpoint.url, + apiKey: "alpha-secret", + headers: { "X-Provider": "alpha" }, + }, + { + name: "delta", + type: "openai", + wireApi: "completions", + baseUrl: openAiEndpoint.url, + apiKey: "delta-secret", + headers: { "X-Provider": "delta" }, + }, + ]; + const models: ProviderModelConfig[] = [ + { id: "sonnet", provider: "alpha", wireModel: "byok-gpt-4o" }, + { id: "haiku", provider: "alpha", wireModel: "byok-gpt-4o-mini" }, + { id: "turbo", provider: "delta", wireModel: "byok-gpt-4-turbo" }, + ]; + + const session = await client.createSession({ + onPermissionRequest: approveAll, + model: selectionId, + providers, + models, + }); + + try { + await session.sendAndWait({ prompt: "What is 5+5?" }); + const exchanges = await waitForExchanges(); + expect(exchanges.length).toBe(1); + const exchange = exchanges[0]; + + // The wire model sent to the provider is the selected model's + // wireModel, not its provider-qualified selection id. + expect(exchange.request.model).toBe(expectedWireModel); + + // The request carried the owning provider's custom header, proving + // the turn was dispatched against the correct provider connection. + expect(getHeader(exchange, "X-Provider")).toBe(expectedProviderHeader); + + // The provider's API key was applied as an Authorization header. + expect(getHeader(exchange, "Authorization")).toBeTruthy(); + } finally { + try { + await session.disconnect(); + } catch { + // disconnect may fail since the BYOK provider URL is fake + } + } + } + + it("should route alpha sonnet turn to its provider and wire model", async () => { + await assertRouting("alpha/sonnet", "byok-gpt-4o", "alpha"); + }); + + it("should route alpha haiku turn to its provider and wire model", async () => { + await assertRouting("alpha/haiku", "byok-gpt-4o-mini", "alpha"); + }); + + it("should route delta turbo turn to its provider and wire model", async () => { + await assertRouting("delta/turbo", "byok-gpt-4-turbo", "delta"); + }); +}); diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index ff2562d68..1bda91072 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -101,6 +101,7 @@ ModelLimitsOverride, ModelSupportsOverride, ModelVisionLimitsOverride, + NamedProviderConfig, PermissionHandler, PermissionNoResult, PermissionRequestResult, @@ -117,6 +118,7 @@ PreToolUseHookInput, PreToolUseHookOutput, ProviderConfig, + ProviderModelConfig, ReasoningSummary, SessionCapabilities, SessionEndHandler, @@ -218,6 +220,7 @@ "ModelSupportsOverride", "ModelVisionLimits", "ModelVisionLimitsOverride", + "NamedProviderConfig", "OpenCanvasInstance", "PermissionHandler", "PermissionNoResult", @@ -237,6 +240,7 @@ "PreToolUseHookInput", "PreToolUseHookOutput", "ProviderConfig", + "ProviderModelConfig", "ReasoningSummary", "RemoteSessionMode", "RuntimeConnection", diff --git a/python/copilot/client.py b/python/copilot/client.py index 0dff0e5ab..2c407149c 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -92,7 +92,9 @@ MCPServerConfig, MemoryConfiguration, ModelCapabilitiesOverride, + NamedProviderConfig, ProviderConfig, + ProviderModelConfig, ReasoningEffort, ReasoningSummary, SectionTransformFn, @@ -1627,6 +1629,8 @@ async def create_session( hooks: SessionHooks | None = None, working_directory: str | None = None, provider: ProviderConfig | None = None, + providers: list[NamedProviderConfig] | None = None, + models: list[ProviderModelConfig] | None = None, enable_session_telemetry: bool | None = None, skip_custom_instructions: bool | None = None, custom_agents_local_only: bool | None = None, @@ -1709,6 +1713,11 @@ async def create_session( hooks: Lifecycle hooks for the session. working_directory: Working directory for the session. provider: Provider configuration for Azure or custom endpoints. + providers: Named BYOK provider connections. Additive to Copilot API + auth (unlike `provider`); combine with `models`. Cannot be + combined with `provider`. + models: BYOK model definitions added to the selectable model list, + each referencing a `providers` entry by name. enable_session_telemetry: Enables or disables internal session telemetry for this session. When False, disables session telemetry. When omitted or True, telemetry is enabled for GitHub-authenticated sessions. When @@ -1916,6 +1925,14 @@ async def create_session( if provider: payload["provider"] = self._convert_provider_to_wire_format(provider) + # Add additive BYOK provider/model registry if provided + if providers: + payload["providers"] = [ + self._convert_named_provider_to_wire_format(p) for p in providers + ] + if models: + payload["models"] = [self._convert_model_to_wire_format(m) for m in models] + if enable_session_telemetry is not None: payload["enableSessionTelemetry"] = enable_session_telemetry @@ -2204,6 +2221,8 @@ async def resume_session( hooks: SessionHooks | None = None, working_directory: str | None = None, provider: ProviderConfig | None = None, + providers: list[NamedProviderConfig] | None = None, + models: list[ProviderModelConfig] | None = None, enable_session_telemetry: bool | None = None, skip_custom_instructions: bool | None = None, custom_agents_local_only: bool | None = None, @@ -2287,6 +2306,11 @@ async def resume_session( hooks: Lifecycle hooks for the session. working_directory: Working directory for the session. provider: Provider configuration for Azure or custom endpoints. + providers: Named BYOK provider connections. Additive to Copilot API + auth (unlike `provider`); combine with `models`. Cannot be + combined with `provider`. + models: BYOK model definitions added to the selectable model list, + each referencing a `providers` entry by name. enable_session_telemetry: Enables or disables internal session telemetry for this session. When False, disables session telemetry. When omitted or True, telemetry is enabled for GitHub-authenticated sessions. When @@ -2437,6 +2461,12 @@ async def resume_session( payload["toolFilterPrecedence"] = "excluded" if provider: payload["provider"] = self._convert_provider_to_wire_format(provider) + if providers: + payload["providers"] = [ + self._convert_named_provider_to_wire_format(p) for p in providers + ] + if models: + payload["models"] = [self._convert_model_to_wire_format(m) for m in models] if enable_session_telemetry is not None: payload["enableSessionTelemetry"] = enable_session_telemetry if model_capabilities: @@ -3159,6 +3189,59 @@ def _convert_provider_to_wire_format( wire_provider["azure"] = wire_azure return wire_provider + def _convert_named_provider_to_wire_format( + self, provider: NamedProviderConfig | dict[str, Any] + ) -> dict[str, Any]: + """Convert a named BYOK provider from snake_case to camelCase wire format.""" + wire: dict[str, Any] = {} + if "name" in provider: + wire["name"] = provider["name"] + if "type" in provider: + wire["type"] = provider["type"] + if "wire_api" in provider: + wire["wireApi"] = provider["wire_api"] + if "base_url" in provider: + wire["baseUrl"] = provider["base_url"] + if "api_key" in provider: + wire["apiKey"] = provider["api_key"] + if "bearer_token" in provider: + wire["bearerToken"] = provider["bearer_token"] + if "headers" in provider: + wire["headers"] = provider["headers"] + if "azure" in provider: + azure = provider["azure"] + wire_azure: dict[str, Any] = {} + if "api_version" in azure: + wire_azure["apiVersion"] = azure["api_version"] + if wire_azure: + wire["azure"] = wire_azure + return wire + + def _convert_model_to_wire_format( + self, model: ProviderModelConfig | dict[str, Any] + ) -> dict[str, Any]: + """Convert a BYOK model definition from snake_case to camelCase wire format.""" + wire: dict[str, Any] = {} + if "id" in model: + wire["id"] = model["id"] + if "provider" in model: + wire["provider"] = model["provider"] + if "wire_model" in model: + wire["wireModel"] = model["wire_model"] + if "model_id" in model: + wire["modelId"] = model["model_id"] + if "name" in model: + wire["name"] = model["name"] + if "max_prompt_tokens" in model: + wire["maxPromptTokens"] = model["max_prompt_tokens"] + if "max_context_window_tokens" in model: + wire["maxContextWindowTokens"] = model["max_context_window_tokens"] + if "max_output_tokens" in model: + wire["maxOutputTokens"] = model["max_output_tokens"] + if "capabilities" in model: + wire["capabilities"] = _capabilities_to_dict(model["capabilities"]) + return wire + def _convert_custom_agent_to_wire_format( self, agent: CustomAgentConfig | dict[str, Any] ) -> dict[str, Any]: diff --git a/python/copilot/session.py b/python/copilot/session.py index 3720af05d..139376ca8 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -1099,6 +1099,61 @@ class ProviderConfig(TypedDict, total=False): max_output_tokens: int +class NamedProviderConfig(TypedDict, total=False): + """A named BYOK provider connection (transport + credentials). + + Referenced by :class:`ProviderModelConfig` entries via ``name``. Unlike the + singular :class:`ProviderConfig` (which makes the whole session BYOK and + bypasses Copilot API authentication), named providers are additive: they + coexist with Copilot API auth so models from CAPI and one or more BYOK + providers can be mixed within a single session and across sub-agents. + + **Experimental.** Multi-provider BYOK configuration is experimental and may + change or be removed in future SDK or CLI releases. + """ + + # Stable identifier referenced by ProviderModelConfig.provider. Must not contain "/". + name: str + type: Literal["openai", "azure", "anthropic"] + wire_api: Literal["completions", "responses"] + base_url: str + api_key: str + # Bearer token for authentication. Sets the Authorization header directly. + # Takes precedence over api_key when both are set. + bearer_token: str + azure: AzureProviderOptions # Azure-specific options + headers: dict[str, str] + + +class ProviderModelConfig(TypedDict, total=False): + """A BYOK model definition that references a :class:`NamedProviderConfig`. + + Added to the session's selectable model list. The session-wide selection id + (shown in the model list and passed to model switching) is the + provider-qualified ``provider/id``, so BYOK ids never collide with bare CAPI + ids. + + **Experimental.** Multi-provider BYOK configuration is experimental and may + change or be removed in future SDK or CLI releases. + """ + + # Provider-local model id, unique within its provider. + id: str + # Name of the NamedProviderConfig that serves this model. + provider: str + # Model name sent to the provider API for inference. Defaults to id. + wire_model: str + # Well-known base model id used for behavior/capability/config lookup. Defaults to id. + model_id: str + # Display name for model pickers. Defaults to the provider-qualified selection id. + name: str + max_prompt_tokens: int + max_context_window_tokens: int + max_output_tokens: int + # Optional capability overrides for the synthesized model. + capabilities: ModelCapabilitiesOverride + + SessionEventHandler = Callable[[SessionEvent], None] diff --git a/python/e2e/test_multi_provider_registry_e2e.py b/python/e2e/test_multi_provider_registry_e2e.py new file mode 100644 index 000000000..a20862455 --- /dev/null +++ b/python/e2e/test_multi_provider_registry_e2e.py @@ -0,0 +1,206 @@ +"""E2E tests for the experimental multi-provider BYOK registry. + +Validates that several named providers, several models per provider, and custom +agents bound to those provider-qualified models can coexist in one session, be +launched, and route inference to the configured provider with the configured +wire model and headers. +""" + +import pytest + +from copilot.session import PermissionHandler + +from .testharness import E2ETestContext + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +def _normalize_headers(headers) -> dict[str, str]: + if isinstance(headers, list): + flat: dict[str, str] = {} + for entry in headers: + if isinstance(entry, dict): + key = entry.get("name") or entry.get("key") + value = entry.get("value") + if key is not None: + flat[str(key).lower()] = str(value) + return flat + if isinstance(headers, dict): + flat = {} + for key, value in headers.items(): + if isinstance(value, list): + flat[str(key).lower()] = ", ".join(str(v) for v in value) + else: + flat[str(key).lower()] = str(value) + return flat + return {} + + +# A heterogeneous registry: two providers of different types, with multiple +# models each. Provider-qualified selection ids are alpha/sonnet, alpha/haiku, +# beta/opus, beta/haiku. +REGISTRY_PROVIDERS = [ + { + "name": "alpha", + "type": "openai", + "wire_api": "completions", + "base_url": "https://alpha.example.test/v1", + "api_key": "alpha-secret", + "headers": {"X-Provider": "alpha"}, + }, + { + "name": "beta", + "type": "anthropic", + "base_url": "https://beta.example.test", + "bearer_token": "beta-bearer", + "headers": {"X-Provider": "beta"}, + }, +] +REGISTRY_MODELS = [ + {"id": "sonnet", "provider": "alpha", "wire_model": "byok-gpt-4o", "max_prompt_tokens": 111111}, + {"id": "haiku", "provider": "alpha", "wire_model": "byok-gpt-4o-mini"}, + {"id": "opus", "provider": "beta", "wire_model": "byok-claude-3-opus"}, + {"id": "haiku", "provider": "beta", "wire_model": "byok-claude-3-haiku"}, +] +REGISTRY_AGENTS = [ + { + "name": "orchestrator", + "display_name": "Orchestrator", + "description": "Top-level planner.", + "prompt": "Plan and delegate.", + "model": "alpha/sonnet", + }, + { + "name": "researcher", + "display_name": "Researcher", + "description": "Deep research subagent.", + "prompt": "Research thoroughly.", + "model": "beta/opus", + }, + { + "name": "fast-helper", + "display_name": "Fast Helper", + "description": "Quick subagent.", + "prompt": "Answer quickly.", + "model": "alpha/haiku", + }, + { + "name": "summarizer", + "display_name": "Summarizer", + "description": "Summarizing subagent.", + "prompt": "Summarize.", + "model": "beta/haiku", + }, +] + + +class TestMultiProviderRegistry: + async def test_should_register_multiple_providers_with_custom_agents_bound_to_their_models( + self, ctx: E2ETestContext + ): + session = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + providers=REGISTRY_PROVIDERS, + models=REGISTRY_MODELS, + custom_agents=REGISTRY_AGENTS, + ) + + try: + result = await session.rpc.agent.list() + + # All four custom agents coexist in a single session. + assert result.agents is not None + assert len(result.agents) == 4 + + # Each agent is bound to its configured provider-qualified BYOK model. + by_name = {agent.name: agent for agent in result.agents} + assert by_name["orchestrator"].model == "alpha/sonnet" + assert by_name["researcher"].model == "beta/opus" + assert by_name["fast-helper"].model == "alpha/haiku" + assert by_name["summarizer"].model == "beta/haiku" + + # Models from BOTH providers are represented, proving the two + # providers and their models coexist within the same session. + bound_models = [agent.model or "" for agent in result.agents] + assert any(m.startswith("alpha/") for m in bound_models) + assert any(m.startswith("beta/") for m in bound_models) + finally: + await session.disconnect() + + async def _assert_routing( + self, + ctx: E2ETestContext, + selection_id: str, + expected_wire_model: str, + expected_provider_header: str, + ): + # Two OpenAI-compatible providers, both pointed at the replay proxy so + # their /chat/completions traffic is captured. They are distinguished on + # the wire by their per-provider X-Provider header. "alpha" carries two + # models (multiple models per provider); "delta" carries one. + providers = [ + { + "name": "alpha", + "type": "openai", + "wire_api": "completions", + "base_url": ctx.proxy_url, + "api_key": "alpha-secret", + "headers": {"X-Provider": "alpha"}, + }, + { + "name": "delta", + "type": "openai", + "wire_api": "completions", + "base_url": ctx.proxy_url, + "api_key": "delta-secret", + "headers": {"X-Provider": "delta"}, + }, + ] + models = [ + {"id": "sonnet", "provider": "alpha", "wire_model": "byok-gpt-4o"}, + {"id": "haiku", "provider": "alpha", "wire_model": "byok-gpt-4o-mini"}, + {"id": "turbo", "provider": "delta", "wire_model": "byok-gpt-4-turbo"}, + ] + + session = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + model=selection_id, + providers=providers, + models=models, + ) + + try: + await session.send_and_wait("What is 5+5?") + + exchanges = await ctx.get_exchanges() + assert len(exchanges) == 1 + exchange = exchanges[0] + + # The wire model sent to the provider is the selected model's + # wire_model, not its provider-qualified selection id. + assert exchange["request"]["model"] == expected_wire_model + + # The request carried the owning provider's custom header, proving + # the turn was dispatched against the correct provider connection. + headers = _normalize_headers(exchange.get("requestHeaders")) + assert headers.get("x-provider") == expected_provider_header + + # The provider's API key was applied as an Authorization header. + assert headers.get("authorization") + finally: + await session.disconnect() + + async def test_should_route_alpha_sonnet_turn_to_its_provider_and_wire_model( + self, ctx: E2ETestContext + ): + await self._assert_routing(ctx, "alpha/sonnet", "byok-gpt-4o", "alpha") + + async def test_should_route_alpha_haiku_turn_to_its_provider_and_wire_model( + self, ctx: E2ETestContext + ): + await self._assert_routing(ctx, "alpha/haiku", "byok-gpt-4o-mini", "alpha") + + async def test_should_route_delta_turbo_turn_to_its_provider_and_wire_model( + self, ctx: E2ETestContext + ): + await self._assert_routing(ctx, "delta/turbo", "byok-gpt-4-turbo", "delta") diff --git a/rust/src/types.rs b/rust/src/types.rs index c0643ec66..5c1c0ddf3 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1156,6 +1156,197 @@ pub struct AzureProviderOptions { pub api_version: Option, } +/// A named BYOK provider connection in the multi-provider registry. +/// +/// **Experimental.** Multi-provider BYOK configuration is part of an +/// experimental surface and may change or be removed in a future release. +/// +/// Unlike [`ProviderConfig`], which routes the whole session through a +/// single provider, named providers are additive: the session keeps its +/// default Copilot routing and exposes these providers' models alongside +/// it. Models are attached via [`ProviderModelConfig`], which references a +/// provider by [`name`](Self::name). +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct NamedProviderConfig { + /// Unique name used by [`ProviderModelConfig::provider`] to reference + /// this connection. + pub name: String, + /// Provider type: `"openai"`, `"azure"`, or `"anthropic"`. Defaults to + /// `"openai"` on the CLI. + #[serde(default, skip_serializing_if = "Option::is_none", rename = "type")] + pub provider_type: Option, + /// API format (openai/azure only): `"completions"` or `"responses"`. + /// Defaults to `"completions"`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub wire_api: Option, + /// API endpoint URL. + pub base_url: String, + /// API key. Optional for local providers like Ollama. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub api_key: Option, + /// Bearer token for authentication. Sets the `Authorization` header + /// directly. Takes precedence over `api_key` when both are set. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub bearer_token: Option, + /// Azure-specific options. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub azure: Option, + /// Custom HTTP headers included in outbound provider requests. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub headers: Option>, +} + +impl NamedProviderConfig { + /// Construct a [`NamedProviderConfig`] with the required `name` and + /// `base_url` set; all other fields default to unset. + pub fn new(name: impl Into, base_url: impl Into) -> Self { + Self { + name: name.into(), + base_url: base_url.into(), + ..Self::default() + } + } + + /// Set the provider type (`"openai"`, `"azure"`, or `"anthropic"`). + pub fn with_provider_type(mut self, provider_type: impl Into) -> Self { + self.provider_type = Some(provider_type.into()); + self + } + + /// Set the API format (`"completions"` or `"responses"`; openai/azure only). + pub fn with_wire_api(mut self, wire_api: impl Into) -> Self { + self.wire_api = Some(wire_api.into()); + self + } + + /// Set the API key. Optional for local providers like Ollama. + pub fn with_api_key(mut self, api_key: impl Into) -> Self { + self.api_key = Some(api_key.into()); + self + } + + /// Set the bearer token used to populate the `Authorization` header. + /// Takes precedence over `api_key` when both are set. + pub fn with_bearer_token(mut self, bearer_token: impl Into) -> Self { + self.bearer_token = Some(bearer_token.into()); + self + } + + /// Set Azure-specific options. + pub fn with_azure(mut self, azure: AzureProviderOptions) -> Self { + self.azure = Some(azure); + self + } + + /// Set the custom HTTP headers attached to outbound provider requests. + pub fn with_headers(mut self, headers: HashMap) -> Self { + self.headers = Some(headers); + self + } +} + +/// A BYOK model definition in the multi-provider registry. +/// +/// **Experimental.** Multi-provider BYOK configuration is part of an +/// experimental surface and may change or be removed in a future release. +/// +/// References a [`NamedProviderConfig`] by [`provider`](Self::provider) and +/// becomes selectable under the provider-qualified id `provider/id`. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct ProviderModelConfig { + /// Model identifier, unique within its provider. Combined with + /// [`provider`](Self::provider) to form the selection id `provider/id`. + pub id: String, + /// Name of the [`NamedProviderConfig`] this model is served by. + pub provider: String, + /// Model name sent to the provider API for inference. Use when the + /// provider's model name differs from [`id`](Self::id). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub wire_model: Option, + /// Well-known model ID used to look up agent config and default token + /// limits. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model_id: Option, + /// Human-readable display name. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub name: Option, + /// Overrides the resolved model's default max prompt tokens. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_prompt_tokens: Option, + /// Overrides the resolved model's default max context window tokens. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_context_window_tokens: Option, + /// Overrides the resolved model's default max output tokens. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + /// Per-property overrides for model capabilities, deep-merged over + /// runtime defaults. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub capabilities: Option, +} + +impl ProviderModelConfig { + /// Construct a [`ProviderModelConfig`] with the required `id` and + /// `provider` set; all other fields default to unset. + pub fn new(id: impl Into, provider: impl Into) -> Self { + Self { + id: id.into(), + provider: provider.into(), + ..Self::default() + } + } + + /// Set the model name sent to the provider API for inference. + pub fn with_wire_model(mut self, wire_model: impl Into) -> Self { + self.wire_model = Some(wire_model.into()); + self + } + + /// Set the well-known model ID used to look up agent config and default + /// token limits. + pub fn with_model_id(mut self, model_id: impl Into) -> Self { + self.model_id = Some(model_id.into()); + self + } + + /// Set the human-readable display name. + pub fn with_name(mut self, name: impl Into) -> Self { + self.name = Some(name.into()); + self + } + + /// Override the resolved model's default max prompt tokens. + pub fn with_max_prompt_tokens(mut self, max: i64) -> Self { + self.max_prompt_tokens = Some(max); + self + } + + /// Override the resolved model's default max context window tokens. + pub fn with_max_context_window_tokens(mut self, max: i64) -> Self { + self.max_context_window_tokens = Some(max); + self + } + + /// Override the resolved model's default max output tokens. + pub fn with_max_output_tokens(mut self, max: i64) -> Self { + self.max_output_tokens = Some(max); + self + } + + /// Set per-property model capability overrides. + pub fn with_capabilities( + mut self, + capabilities: crate::generated::api_types::ModelCapabilitiesOverride, + ) -> Self { + self.capabilities = Some(capabilities); + self + } +} + /// Configuration for creating a new session via the `session.create` RPC. /// /// All fields are optional — the CLI applies sensible defaults. @@ -1341,6 +1532,19 @@ pub struct SessionConfig { /// requests through this provider instead of the default Copilot /// routing. pub provider: Option, + /// **Experimental.** This field is part of an experimental multi-provider + /// BYOK surface and may change or be removed in a future release. + /// + /// Named BYOK provider connections. Additive to the default Copilot + /// routing — unlike [`provider`](Self::provider), these do not switch + /// the whole session to BYOK. Referenced by [`models`](Self::models). + pub providers: Option>, + /// **Experimental.** This field is part of an experimental multi-provider + /// BYOK surface and may change or be removed in a future release. + /// + /// BYOK model definitions, each referencing a [`providers`](Self::providers) + /// entry by name. Selectable under the id `provider/id`. + pub models: Option>, /// Enables or disables internal session telemetry for this session. /// /// When `Some(false)`, disables session telemetry. When `None` or @@ -1594,6 +1798,8 @@ impl Default for SessionConfig { agent: None, infinite_sessions: None, provider: None, + providers: None, + models: None, enable_session_telemetry: None, model_capabilities: None, memory: None, @@ -1737,6 +1943,8 @@ impl SessionConfig { agent: self.agent, infinite_sessions: self.infinite_sessions, provider: self.provider, + providers: self.providers, + models: self.models, enable_session_telemetry: self.enable_session_telemetry, model_capabilities: self.model_capabilities, memory: self.memory, @@ -2154,6 +2362,26 @@ impl SessionConfig { self } + /// **Experimental.** This method is part of an experimental multi-provider + /// BYOK surface and may change or be removed in a future release. + /// + /// Set the named BYOK provider connections (additive multi-provider + /// registry). Attach models referencing these with [`Self::with_models`]. + pub fn with_providers(mut self, providers: Vec) -> Self { + self.providers = Some(providers); + self + } + + /// **Experimental.** This method is part of an experimental multi-provider + /// BYOK surface and may change or be removed in a future release. + /// + /// Set the BYOK model definitions, each referencing a named provider + /// supplied via [`Self::with_providers`]. + pub fn with_models(mut self, models: Vec) -> Self { + self.models = Some(models); + self + } + /// Enable or disable internal session telemetry. /// /// See [`Self::enable_session_telemetry`] for default and BYOK behavior. @@ -2349,6 +2577,18 @@ pub struct ResumeSessionConfig { pub infinite_sessions: Option, /// Re-supply BYOK provider configuration on resume. pub provider: Option, + /// **Experimental.** This field is part of an experimental multi-provider + /// BYOK surface and may change or be removed in a future release. + /// + /// Re-supply named BYOK provider connections on resume. Additive to + /// the default Copilot routing. Referenced by [`models`](Self::models). + pub providers: Option>, + /// **Experimental.** This field is part of an experimental multi-provider + /// BYOK surface and may change or be removed in a future release. + /// + /// Re-supply BYOK model definitions on resume, each referencing a + /// [`providers`](Self::providers) entry by name. + pub models: Option>, /// Enables or disables internal session telemetry for this session. /// /// When `Some(false)`, disables session telemetry. When `None` or @@ -2626,6 +2866,8 @@ impl ResumeSessionConfig { agent: self.agent, infinite_sessions: self.infinite_sessions, provider: self.provider, + providers: self.providers, + models: self.models, enable_session_telemetry: self.enable_session_telemetry, model_capabilities: self.model_capabilities, memory: self.memory, @@ -2703,6 +2945,8 @@ impl ResumeSessionConfig { agent: None, infinite_sessions: None, provider: None, + providers: None, + models: None, enable_session_telemetry: None, model_capabilities: None, memory: None, @@ -3093,6 +3337,26 @@ impl ResumeSessionConfig { self } + /// **Experimental.** This method is part of an experimental multi-provider + /// BYOK surface and may change or be removed in a future release. + /// + /// Re-supply the named BYOK provider connections on resume. Attach + /// models referencing these with [`Self::with_models`]. + pub fn with_providers(mut self, providers: Vec) -> Self { + self.providers = Some(providers); + self + } + + /// **Experimental.** This method is part of an experimental multi-provider + /// BYOK surface and may change or be removed in a future release. + /// + /// Re-supply the BYOK model definitions on resume, each referencing a + /// named provider supplied via [`Self::with_providers`]. + pub fn with_models(mut self, models: Vec) -> Self { + self.models = Some(models); + self + } + /// Enable or disable internal session telemetry on resume. /// /// See [`Self::enable_session_telemetry`] for default and BYOK behavior. @@ -4356,9 +4620,10 @@ mod tests { use super::{ AgentMode, Attachment, AttachmentLineRange, AttachmentSelectionPosition, - AttachmentSelectionRange, ConnectionState, CustomAgentConfig, DeliveryMode, ExtensionInfo, - GitHubReferenceType, InfiniteSessionConfig, LargeToolOutputConfig, MemoryConfiguration, - ProviderConfig, ReasoningSummary, ResumeSessionConfig, SessionConfig, SessionEvent, + AttachmentSelectionRange, AzureProviderOptions, ConnectionState, CustomAgentConfig, + DeliveryMode, ExtensionInfo, GitHubReferenceType, InfiniteSessionConfig, + LargeToolOutputConfig, MemoryConfiguration, NamedProviderConfig, ProviderConfig, + ProviderModelConfig, ReasoningSummary, ResumeSessionConfig, SessionConfig, SessionEvent, SessionId, SystemMessageConfig, Tool, ToolBinaryResult, ToolResult, ToolResultExpanded, ToolResultResponse, ensure_attachment_display_names, }; @@ -4656,6 +4921,80 @@ mod tests { assert!(empty_json.get("cloud").is_none()); } + #[test] + fn session_config_into_wire_serializes_named_providers_and_models() { + let cfg = SessionConfig::default() + .with_providers(vec![ + NamedProviderConfig::new("my-openai", "https://api.example.com/v1") + .with_provider_type("openai") + .with_wire_api("responses") + .with_api_key("sk-test"), + ]) + .with_models(vec![ + ProviderModelConfig::new("gpt-x", "my-openai") + .with_wire_model("gpt-x-2025") + .with_max_output_tokens(2048), + ]); + + let (wire, _) = cfg + .into_wire(Some(SessionId::from("sess-providers"))) + .expect("no duplicate handlers"); + let wire_json = serde_json::to_value(&wire).unwrap(); + assert_eq!(wire_json["providers"][0]["name"], "my-openai"); + assert_eq!( + wire_json["providers"][0]["baseUrl"], + "https://api.example.com/v1" + ); + assert_eq!(wire_json["providers"][0]["type"], "openai"); + assert_eq!(wire_json["providers"][0]["wireApi"], "responses"); + assert_eq!(wire_json["providers"][0]["apiKey"], "sk-test"); + assert_eq!(wire_json["models"][0]["id"], "gpt-x"); + assert_eq!(wire_json["models"][0]["provider"], "my-openai"); + assert_eq!(wire_json["models"][0]["wireModel"], "gpt-x-2025"); + assert_eq!(wire_json["models"][0]["maxOutputTokens"], 2048); + + let (empty_wire, _) = SessionConfig::default() + .into_wire(Some(SessionId::from("empty"))) + .expect("default has no duplicate handlers"); + let empty_json = serde_json::to_value(&empty_wire).unwrap(); + assert!(empty_json.get("providers").is_none()); + assert!(empty_json.get("models").is_none()); + } + + #[test] + fn resume_config_into_wire_serializes_named_providers_and_models() { + let cfg = ResumeSessionConfig::new(SessionId::from("sess-resume")) + .with_providers(vec![ + NamedProviderConfig::new("my-azure", "https://example.openai.azure.com") + .with_provider_type("azure") + .with_azure(AzureProviderOptions { + api_version: Some("2024-10-21".to_string()), + }), + ]) + .with_models(vec![ + ProviderModelConfig::new("deploy-1", "my-azure").with_model_id("gpt-4o"), + ]); + + let (wire, _) = cfg.into_wire().expect("no duplicate handlers"); + let wire_json = serde_json::to_value(&wire).unwrap(); + assert_eq!(wire_json["providers"][0]["name"], "my-azure"); + assert_eq!(wire_json["providers"][0]["type"], "azure"); + assert_eq!( + wire_json["providers"][0]["azure"]["apiVersion"], + "2024-10-21" + ); + assert_eq!(wire_json["models"][0]["id"], "deploy-1"); + assert_eq!(wire_json["models"][0]["provider"], "my-azure"); + assert_eq!(wire_json["models"][0]["modelId"], "gpt-4o"); + + let (empty_wire, _) = ResumeSessionConfig::new(SessionId::from("empty")) + .into_wire() + .expect("default has no duplicate handlers"); + let empty_json = serde_json::to_value(&empty_wire).unwrap(); + assert!(empty_json.get("providers").is_none()); + assert!(empty_json.get("models").is_none()); + } + #[test] fn session_config_into_wire_serializes_plugin_directories_and_large_output() { use std::path::PathBuf; diff --git a/rust/src/wire.rs b/rust/src/wire.rs index 1b58abacd..cc9968100 100644 --- a/rust/src/wire.rs +++ b/rust/src/wire.rs @@ -26,7 +26,7 @@ use crate::generated::session_events::ReasoningSummary; use crate::types::{ CloudSessionOptions, CustomAgentConfig, DefaultAgentConfig, ExtensionInfo, InfiniteSessionConfig, LargeToolOutputConfig, McpServerConfig, MemoryConfiguration, - ProviderConfig, SessionId, SystemMessageConfig, Tool, + NamedProviderConfig, ProviderConfig, ProviderModelConfig, SessionId, SystemMessageConfig, Tool, }; /// Wire representation of a slash command (name + description only). The @@ -130,6 +130,10 @@ pub(crate) struct SessionCreateWire { #[serde(skip_serializing_if = "Option::is_none")] pub provider: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub providers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub models: Option>, + #[serde(skip_serializing_if = "Option::is_none")] pub enable_session_telemetry: Option, #[serde(skip_serializing_if = "Option::is_none")] pub model_capabilities: Option, @@ -239,6 +243,10 @@ pub(crate) struct SessionResumeWire { #[serde(skip_serializing_if = "Option::is_none")] pub provider: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub providers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub models: Option>, + #[serde(skip_serializing_if = "Option::is_none")] pub enable_session_telemetry: Option, #[serde(skip_serializing_if = "Option::is_none")] pub model_capabilities: Option, diff --git a/rust/tests/e2e.rs b/rust/tests/e2e.rs index 0df63e15e..04fe0b2ee 100644 --- a/rust/tests/e2e.rs +++ b/rust/tests/e2e.rs @@ -41,6 +41,8 @@ mod mode_handlers; mod multi_client; #[path = "e2e/multi_client_commands_elicitation.rs"] mod multi_client_commands_elicitation; +#[path = "e2e/multi_provider_registry.rs"] +mod multi_provider_registry; #[path = "e2e/multi_turn.rs"] mod multi_turn; #[path = "e2e/pending_work_resume.rs"] diff --git a/rust/tests/e2e/multi_provider_registry.rs b/rust/tests/e2e/multi_provider_registry.rs new file mode 100644 index 000000000..8c37deaa2 --- /dev/null +++ b/rust/tests/e2e/multi_provider_registry.rs @@ -0,0 +1,243 @@ +use std::collections::HashMap; + +use github_copilot_sdk::{ + CustomAgentConfig, MessageOptions, NamedProviderConfig, ProviderModelConfig, +}; +use serde_json::Value; + +use super::support::with_e2e_context; + +const CATEGORY: &str = "multi_provider_registry"; + +fn headers(provider: &str) -> HashMap { + let mut map = HashMap::new(); + map.insert("X-Provider".to_string(), provider.to_string()); + map +} + +#[tokio::test] +async fn should_register_multiple_providers_with_custom_agents_bound_to_their_models() { + with_e2e_context( + CATEGORY, + "should_register_multiple_providers_with_custom_agents_bound_to_their_models", + |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + + // A heterogeneous registry: two providers of different types, + // with multiple models each. Provider-qualified selection ids + // are alpha/sonnet, alpha/haiku, beta/opus, beta/haiku. + let session = client + .create_session( + ctx.approve_all_session_config() + .with_providers(vec![ + NamedProviderConfig::new("alpha", "https://alpha.example.test/v1") + .with_provider_type("openai") + .with_wire_api("completions") + .with_api_key("alpha-secret") + .with_headers(headers("alpha")), + NamedProviderConfig::new("beta", "https://beta.example.test") + .with_provider_type("anthropic") + .with_bearer_token("beta-bearer") + .with_headers(headers("beta")), + ]) + .with_models(vec![ + ProviderModelConfig::new("sonnet", "alpha") + .with_wire_model("byok-gpt-4o") + .with_max_prompt_tokens(111_111), + ProviderModelConfig::new("haiku", "alpha") + .with_wire_model("byok-gpt-4o-mini"), + ProviderModelConfig::new("opus", "beta") + .with_wire_model("byok-claude-3-opus"), + ProviderModelConfig::new("haiku", "beta") + .with_wire_model("byok-claude-3-haiku"), + ]) + .with_custom_agents([ + CustomAgentConfig::new("orchestrator", "Plan and delegate.") + .with_display_name("Orchestrator") + .with_description("Top-level planner.") + .with_model("alpha/sonnet"), + CustomAgentConfig::new("researcher", "Research thoroughly.") + .with_display_name("Researcher") + .with_description("Deep research subagent.") + .with_model("beta/opus"), + CustomAgentConfig::new("fast-helper", "Answer quickly.") + .with_display_name("Fast Helper") + .with_description("Quick subagent.") + .with_model("alpha/haiku"), + CustomAgentConfig::new("summarizer", "Summarize.") + .with_display_name("Summarizer") + .with_description("Summarizing subagent.") + .with_model("beta/haiku"), + ]), + ) + .await + .expect("create session"); + + let result = session.rpc().agent().list().await.expect("agent list"); + + // All four custom agents coexist in a single session. + assert_eq!(result.agents.len(), 4, "expected 4 custom agents"); + + // Each agent is bound to its configured provider-qualified model. + let bound = |name: &str| { + result + .agents + .iter() + .find(|agent| agent.name == name) + .and_then(|agent| agent.model.clone()) + .unwrap_or_default() + }; + assert_eq!(bound("orchestrator"), "alpha/sonnet"); + assert_eq!(bound("researcher"), "beta/opus"); + assert_eq!(bound("fast-helper"), "alpha/haiku"); + assert_eq!(bound("summarizer"), "beta/haiku"); + + // Models from BOTH providers are represented, proving the two + // providers and their models coexist within the same session. + let models: Vec = result + .agents + .iter() + .filter_map(|agent| agent.model.clone()) + .collect(); + assert!( + models.iter().any(|m| m.starts_with("alpha/")), + "expected an alpha-bound agent", + ); + assert!( + models.iter().any(|m| m.starts_with("beta/")), + "expected a beta-bound agent", + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }, + ) + .await; +} + +async fn assert_routing( + snapshot_name: &'static str, + selection_id: &'static str, + expected_wire_model: &'static str, + expected_provider_header: &'static str, +) { + with_e2e_context(CATEGORY, snapshot_name, move |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + + // Two OpenAI-compatible providers, both pointed at the replay proxy + // so their /chat/completions traffic is captured. They are + // distinguished on the wire by their per-provider X-Provider + // header. "alpha" carries two models (multiple models per + // provider); "delta" carries one. + let proxy_url = ctx.proxy_url().to_string(); + let session = client + .create_session( + ctx.approve_all_session_config() + .with_model(selection_id) + .with_providers(vec![ + NamedProviderConfig::new("alpha", proxy_url.clone()) + .with_provider_type("openai") + .with_wire_api("completions") + .with_api_key("alpha-secret") + .with_headers(headers("alpha")), + NamedProviderConfig::new("delta", proxy_url.clone()) + .with_provider_type("openai") + .with_wire_api("completions") + .with_api_key("delta-secret") + .with_headers(headers("delta")), + ]) + .with_models(vec![ + ProviderModelConfig::new("sonnet", "alpha") + .with_wire_model("byok-gpt-4o"), + ProviderModelConfig::new("haiku", "alpha") + .with_wire_model("byok-gpt-4o-mini"), + ProviderModelConfig::new("turbo", "delta") + .with_wire_model("byok-gpt-4-turbo"), + ]), + ) + .await + .expect("create session"); + + session + .send_and_wait(MessageOptions::new("What is 5+5?")) + .await + .expect("send"); + + let exchanges = ctx.exchanges(); + assert_eq!(exchanges.len(), 1, "expected exactly one captured exchange"); + let exchange = &exchanges[0]; + + // The wire model sent to the provider is the selected model's wire + // model, not its provider-qualified selection id. + let model = exchange + .get("request") + .and_then(|request| request.get("model")) + .and_then(Value::as_str) + .expect("request model"); + assert_eq!(model, expected_wire_model); + + let request_headers = exchange + .get("requestHeaders") + .and_then(Value::as_object) + .expect("request headers"); + + // The request carried the owning provider's custom header, proving + // the turn was dispatched against the correct provider connection. + let provider_header = request_headers + .iter() + .find(|(key, _)| key.eq_ignore_ascii_case("x-provider")) + .and_then(|(_, value)| value.as_str()) + .expect("x-provider header"); + assert_eq!(provider_header, expected_provider_header); + + // The provider's API key was applied as an Authorization header. + let has_authorization = request_headers + .iter() + .any(|(key, _)| key.eq_ignore_ascii_case("authorization")); + assert!(has_authorization, "expected an Authorization header"); + + // disconnect may fail since the BYOK provider URL is the proxy + let _ = session.disconnect().await; + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +#[tokio::test] +async fn should_route_alpha_sonnet_turn_to_its_provider_and_wire_model() { + assert_routing( + "should_route_alpha_sonnet_turn_to_its_provider_and_wire_model", + "alpha/sonnet", + "byok-gpt-4o", + "alpha", + ) + .await; +} + +#[tokio::test] +async fn should_route_alpha_haiku_turn_to_its_provider_and_wire_model() { + assert_routing( + "should_route_alpha_haiku_turn_to_its_provider_and_wire_model", + "alpha/haiku", + "byok-gpt-4o-mini", + "alpha", + ) + .await; +} + +#[tokio::test] +async fn should_route_delta_turbo_turn_to_its_provider_and_wire_model() { + assert_routing( + "should_route_delta_turbo_turn_to_its_provider_and_wire_model", + "delta/turbo", + "byok-gpt-4-turbo", + "delta", + ) + .await; +} diff --git a/test/snapshots/multi_provider_registry/should_register_multiple_providers_with_custom_agents_bound_to_their_models.yaml b/test/snapshots/multi_provider_registry/should_register_multiple_providers_with_custom_agents_bound_to_their_models.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/multi_provider_registry/should_register_multiple_providers_with_custom_agents_bound_to_their_models.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/multi_provider_registry/should_route_alpha_haiku_turn_to_its_provider_and_wire_model.yaml b/test/snapshots/multi_provider_registry/should_route_alpha_haiku_turn_to_its_provider_and_wire_model.yaml new file mode 100644 index 000000000..c669af9ad --- /dev/null +++ b/test/snapshots/multi_provider_registry/should_route_alpha_haiku_turn_to_its_provider_and_wire_model.yaml @@ -0,0 +1,10 @@ +models: + - byok-gpt-4o-mini +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 5+5? + - role: assistant + content: 5 + 5 = 10 diff --git a/test/snapshots/multi_provider_registry/should_route_alpha_sonnet_turn_to_its_provider_and_wire_model.yaml b/test/snapshots/multi_provider_registry/should_route_alpha_sonnet_turn_to_its_provider_and_wire_model.yaml new file mode 100644 index 000000000..faa2379e8 --- /dev/null +++ b/test/snapshots/multi_provider_registry/should_route_alpha_sonnet_turn_to_its_provider_and_wire_model.yaml @@ -0,0 +1,10 @@ +models: + - byok-gpt-4o +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 5+5? + - role: assistant + content: 5 + 5 = 10 diff --git a/test/snapshots/multi_provider_registry/should_route_delta_turbo_turn_to_its_provider_and_wire_model.yaml b/test/snapshots/multi_provider_registry/should_route_delta_turbo_turn_to_its_provider_and_wire_model.yaml new file mode 100644 index 000000000..f0dc69b50 --- /dev/null +++ b/test/snapshots/multi_provider_registry/should_route_delta_turbo_turn_to_its_provider_and_wire_model.yaml @@ -0,0 +1,10 @@ +models: + - byok-gpt-4-turbo +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 5+5? + - role: assistant + content: 5 + 5 = 10