diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 9859dec90..871bdb85f 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -630,6 +630,7 @@ private CopilotSession InitializeSession( this); session.RegisterTools(config.Tools ?? []); session.RegisterPermissionHandler(config.OnPermissionRequest); + session.RegisterMcpAuthHandler(config.OnMcpAuthRequest); session.RegisterCommands(config.Commands); session.RegisterElicitationHandler(config.OnElicitationRequest); session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest); @@ -937,6 +938,10 @@ public async Task CreateSessionAsync(SessionConfig config, Cance transformCallbacks, hasHooks, "CopilotClient.CreateSessionAsync"); + if (config.OnMcpAuthRequest is not null) + { + await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); + } } try { @@ -1050,6 +1055,12 @@ public async Task CreateSessionAsync(SessionConfig config, Cance $"session.create returned sessionId {response.SessionId} but the caller requested {localSessionId}."); } + // Local IDs registered before create; server-assigned IDs can only register now. + if (localSessionId is null && config.OnMcpAuthRequest is not null) + { + await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); + } + session.WorkspacePath = response.WorkspacePath; session.SetCapabilities(response.Capabilities); session.SetOpenCanvases(response.OpenCanvases); @@ -1136,6 +1147,10 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes transformCallbacks, hasHooks, "CopilotClient.ResumeSessionAsync"); + if (config.OnMcpAuthRequest is not null) + { + await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); + } try { diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 095c1abf7..fb501db9a 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -62,6 +62,7 @@ public sealed partial class CopilotSession : IAsyncDisposable private readonly CopilotClient _parentClient; private volatile Func>? _permissionHandler; + private volatile Func>? _mcpAuthHandler; private volatile Func>? _userInputHandler; private volatile Func>? _elicitationHandler; private volatile Func>? _exitPlanModeHandler; @@ -561,6 +562,11 @@ internal void RegisterPermissionHandler(Func>? handler) + { + _mcpAuthHandler = handler; + } + /// /// Handles a permission request from the Copilot CLI. /// @@ -636,6 +642,37 @@ private async Task HandleBroadcastEventAsync(SessionEvent sessionEvent) break; } + case McpOauthRequiredEvent authEvent: + { + var data = authEvent.Data; + if (string.IsNullOrEmpty(data.RequestId)) + return; + + var handler = _mcpAuthHandler; + if (handler is null) + { + if (_logger.IsEnabled(LogLevel.Warning)) + { + _logger.LogWarning( + "Received MCP OAuth request without a registered MCP auth handler. SessionId={SessionId}, RequestId={RequestId}", + SessionId, + data.RequestId); + } + return; + } + + await ExecuteMcpAuthAndRespondAsync(data.RequestId, new McpAuthContext + { + SessionId = SessionId, + ServerName = data.ServerName, + ServerUrl = data.ServerUrl, + WwwAuthenticateParams = data.WwwAuthenticateParams, + ResourceMetadata = data.ResourceMetadata, + StaticClientConfig = data.StaticClientConfig + }, handler); + break; + } + case CommandExecuteEvent cmdEvent: { var data = cmdEvent.Data; @@ -705,6 +742,40 @@ await HandleElicitationRequestAsync( } } + private async Task ExecuteMcpAuthAndRespondAsync( + string requestId, + McpAuthContext context, + Func> handler) + { + try + { + var result = await handler(context); + McpOauthPendingRequestResponse response = + result is { Cancelled: false, Token: { } token } + ? new McpOauthPendingRequestResponseToken + { + AccessToken = token.AccessToken, + TokenType = token.TokenType, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn + } + : new McpOauthPendingRequestResponseCancelled(); + + await Rpc.Mcp.Oauth.HandlePendingRequestAsync(requestId, response); + } + catch (Exception) + { + try + { + await Rpc.Mcp.Oauth.HandlePendingRequestAsync(requestId, new McpOauthPendingRequestResponseCancelled()); + } + catch (Exception rpcEx) when (rpcEx is IOException or ObjectDisposedException) + { + // Connection lost or RPC error — nothing we can do. + } + } + } + /// /// Executes a tool handler and sends the result back via the HandlePendingToolCall RPC. /// diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 5f9d8f861..730ba8a45 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -1126,6 +1126,69 @@ public sealed class ElicitationContext public string? Url { get; set; } } +/// +/// Context for an MCP OAuth request callback. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class McpAuthContext +{ + /// Identifier of the session that triggered the MCP OAuth request. + public string SessionId { get; set; } = string.Empty; + + /// Display name of the MCP server that requires OAuth. + public string ServerName { get; set; } = string.Empty; + + /// URL of the MCP server that requires OAuth. + public string ServerUrl { get; set; } = string.Empty; + + /// Parsed WWW-Authenticate parameters from the MCP server, if available. + public McpOauthWWWAuthenticateParams? WwwAuthenticateParams { get; set; } + + /// Raw RFC 9728 protected-resource metadata JSON fetched by the runtime, if available. + public string? ResourceMetadata { get; set; } + + /// Static OAuth client configuration, if the server specifies one. + public McpOauthRequiredStaticClientConfig? StaticClientConfig { get; set; } +} + +/// +/// Host-provided OAuth token data for a pending MCP OAuth request. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class McpAuthToken +{ + /// Access token acquired by the SDK host. + public required string AccessToken { get; set; } + + /// OAuth token type. Defaults to Bearer when omitted. + public string? TokenType { get; set; } + + /// Refresh token supplied by the host, if available. + public string? RefreshToken { get; set; } + + /// Token lifetime in seconds, if known. + public long? ExpiresIn { get; set; } +} + +/// +/// Result returned by an MCP auth request handler. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class McpAuthResult +{ + /// Whether the request should be cancelled instead of resolved with a token. + public bool Cancelled { get; set; } + + /// Host-provided token data. Ignored when is true. + public McpAuthToken? Token { get; set; } + + /// Create a token result. + public static McpAuthResult FromToken(McpAuthToken token) => new() { Token = token }; + + /// Create a cancellation result. + public static McpAuthResult Cancel() => new() { Cancelled = true }; +} + // ============================================================================ // Session Capabilities // ============================================================================ @@ -2662,6 +2725,7 @@ protected SessionConfigBase(SessionConfigBase? other) OnElicitationRequest = other.OnElicitationRequest; OnEvent = other.OnEvent; OnExitPlanModeRequest = other.OnExitPlanModeRequest; + OnMcpAuthRequest = other.OnMcpAuthRequest; OnPermissionRequest = other.OnPermissionRequest; OnUserInputRequest = other.OnUserInputRequest; Provider = other.Provider; @@ -3105,6 +3169,14 @@ protected SessionConfigBase(SessionConfigBase? other) [JsonIgnore] public ICanvasHandler? CanvasHandler { get; set; } #pragma warning restore GHCP001 + + /// + /// Optional handler for MCP OAuth requests from MCP servers. + /// When provided, the SDK can satisfy MCP server OAuth requests with host-provided token data or cancellation. + /// + [Experimental(Diagnostics.Experimental)] + [JsonIgnore] + public Func>? OnMcpAuthRequest { get; set; } } /// diff --git a/dotnet/test/Unit/ClientSessionLifetimeTests.cs b/dotnet/test/Unit/ClientSessionLifetimeTests.cs index 2c11c7d6b..2d1dc4efc 100644 --- a/dotnet/test/Unit/ClientSessionLifetimeTests.cs +++ b/dotnet/test/Unit/ClientSessionLifetimeTests.cs @@ -16,6 +16,8 @@ namespace GitHub.Copilot.Test.Unit; public sealed class ClientSessionLifetimeTests { + private sealed record RpcRequestRecord(string Method, JsonElement Params); + [Fact] public async Task StopAsync_Requests_Runtime_Shutdown_For_Owned_Process() { @@ -188,6 +190,124 @@ public async Task ResumeSessionAsync_Throws_When_Same_Client_Already_Tracks_Sess AssertSessionCount(client, sessions: 1); } + [Fact] + public async Task CreateSessionAsync_Registers_McpAuth_Interest_Only_When_Handler_Configured() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var withoutAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnEvent = _ => { } + }); + + Assert.DoesNotContain(server.Requests, request => + request.Method == "session.eventLog.registerInterest" + && request.Params.GetProperty("eventType").GetString() == "mcp.oauth_required"); + Assert.Contains(server.Requests, request => + request.Method == "session.create" + && request.Params.GetProperty("requestPermission").GetBoolean()); + + server.ClearRequests(); + + await using var withAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => Task.FromResult(McpAuthResult.Cancel()) + }); + + Assert.Collection( + server.Requests.Take(2), + request => + { + Assert.Equal("session.eventLog.registerInterest", request.Method); + Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); + }, + request => Assert.Equal("session.create", request.Method)); + } + + [Fact] + public async Task CreateSessionAsync_Registers_McpAuth_Interest_After_Cloud_Create_When_Handler_Configured() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + var cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository + { + Owner = "github", + Name = "copilot-sdk", + Branch = "main" + } + }; + + await using var withoutAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = cloud + }); + + Assert.DoesNotContain(server.Requests, request => + request.Method == "session.eventLog.registerInterest" + && request.Params.GetProperty("eventType").GetString() == "mcp.oauth_required"); + + server.ClearRequests(); + + await using var withAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => Task.FromResult(McpAuthResult.Cancel()), + Cloud = cloud + }); + + Assert.Collection( + server.Requests.Take(2), + request => Assert.Equal("session.create", request.Method), + request => + { + Assert.Equal("session.eventLog.registerInterest", request.Method); + Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); + }); + } + + [Fact] + public async Task ResumeSessionAsync_Registers_McpAuth_Interest_Only_When_Handler_Configured() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var withoutAuth = await client.ResumeSessionAsync("session-without-auth", new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnEvent = _ => { } + }); + + Assert.DoesNotContain(server.Requests, request => + request.Method == "session.eventLog.registerInterest" + && request.Params.GetProperty("eventType").GetString() == "mcp.oauth_required"); + Assert.Contains(server.Requests, request => + request.Method == "session.resume" + && request.Params.GetProperty("requestPermission").GetBoolean()); + + server.ClearRequests(); + + await using var withAuth = await client.ResumeSessionAsync("session-with-auth", new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => Task.FromResult(McpAuthResult.Cancel()) + }); + + Assert.Collection( + server.Requests.Take(2), + request => + { + Assert.Equal("session.eventLog.registerInterest", request.Method); + Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); + }, + request => Assert.Equal("session.resume", request.Method)); + } + [Fact] public async Task Generated_Session_Rpc_Throws_When_Session_Disposed() { @@ -277,6 +397,8 @@ private sealed class FakeCopilotServer : IAsyncDisposable private readonly TaskCompletionSource _destroyStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); private readonly TaskCompletionSource _allowDestroy = new(TaskCreationOptions.RunContinuationsAsynchronously); private readonly Task _serverTask; + private readonly List _requests = []; + private readonly object _requestsLock = new(); private string? _lastSessionId; private bool _delayDestroy; private bool _failRuntimeShutdown; @@ -307,6 +429,25 @@ public static Task StartAsync() public int RuntimeShutdownCount { get; private set; } + public IReadOnlyList Requests + { + get + { + lock (_requestsLock) + { + return _requests.ToArray(); + } + } + } + + public void ClearRequests() + { + lock (_requestsLock) + { + _requests.Clear(); + } + } + public void DelayDestroy() { _delayDestroy = true; @@ -382,6 +523,13 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel return; } + var paramsElement = request.TryGetProperty("params", out var rawParams) + ? rawParams.Clone() + : JsonDocument.Parse("{}").RootElement.Clone(); + lock (_requestsLock) + { + _requests.Add(new RpcRequestRecord(method!, paramsElement)); + } object? result = method switch { "connect" => new Dictionary @@ -392,6 +540,10 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel }, "session.create" => CreateSessionResult(request), "session.resume" => CreateSessionResult(request), + "session.eventLog.registerInterest" => new Dictionary + { + ["id"] = "interest-1" + }, "session.send" => new Dictionary { ["messageId"] = "message-1" diff --git a/dotnet/test/Unit/PublicDtoTests.cs b/dotnet/test/Unit/PublicDtoTests.cs index c81a8a7a6..d1918d2b9 100644 --- a/dotnet/test/Unit/PublicDtoTests.cs +++ b/dotnet/test/Unit/PublicDtoTests.cs @@ -20,6 +20,25 @@ namespace GitHub.Copilot.Test.Unit; /// public class PublicDtoTests { + [Fact] + public void McpAuth_Result_Factories_Represent_Token_And_Cancellation() + { + var token = new McpAuthToken + { + AccessToken = "host-token", + TokenType = "Bearer", + ExpiresIn = 3600, + }; + + var tokenResult = McpAuthResult.FromToken(token); + Assert.Same(token, tokenResult.Token); + Assert.False(tokenResult.Cancelled); + + var cancelled = McpAuthResult.Cancel(); + Assert.True(cancelled.Cancelled); + Assert.Null(cancelled.Token); + } + [Fact] public void Public_Dto_Properties_Can_Be_Set_And_Read() { diff --git a/dotnet/test/Unit/SessionEventSerializationTests.cs b/dotnet/test/Unit/SessionEventSerializationTests.cs index 47b4ac3f7..a0796b134 100644 --- a/dotnet/test/Unit/SessionEventSerializationTests.cs +++ b/dotnet/test/Unit/SessionEventSerializationTests.cs @@ -158,6 +158,11 @@ public class SessionEventSerializationTests GrantType = "client_credentials", PublicClient = false, }, + WwwAuthenticateParams = new McpOauthWWWAuthenticateParams + { + ResourceMetadataUrl = "https://example.com/.well-known/oauth-protected-resource", + }, + ResourceMetadata = """{"resource":"https://example.com/mcp"}""", }, }, "mcp.oauth_required" @@ -281,6 +286,11 @@ public void SessionEvent_ToJson_RoundTrips_JsonElementBackedPayloads(SessionEven .GetProperty("staticClientConfig") .GetProperty("grantType") .GetString()); + Assert.Equal( + """{"resource":"https://example.com/mcp"}""", + root.GetProperty("data") + .GetProperty("resourceMetadata") + .GetString()); break; case "assistant.message_start": @@ -297,4 +307,26 @@ public void SessionEvent_ToJson_RoundTrips_JsonElementBackedPayloads(SessionEven break; } } + + [Fact] + public void McpOauthRequiredData_Allows_Missing_Optional_Metadata() + { + const string json = """ + { + "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + "timestamp": "2026-03-15T21:26:54.987Z", + "parentId": null, + "type": "mcp.oauth_required", + "data": { + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + } + } + """; + + var authEvent = Assert.IsType(SessionEvent.FromJson(json)); + Assert.Null(authEvent.Data.WwwAuthenticateParams); + Assert.Null(authEvent.Data.ResourceMetadata); + } } diff --git a/go/client.go b/go/client.go index 5dc44e027..2144d54de 100644 --- a/go/client.go +++ b/go/client.go @@ -781,6 +781,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses s.registerTools(config.Tools) s.registerPermissionHandler(config.OnPermissionRequest) + s.registerMCPAuthHandler(config.OnMCPAuthRequest) if config.OnUserInputRequest != nil { s.registerUserInputHandler(config.OnUserInputRequest) } @@ -848,6 +849,14 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses } session = s registeredSessionID = localSessionID + if config.OnMCPAuthRequest != nil { + if _, err := c.client.Request(ctx, "session.eventLog.registerInterest", map[string]any{ + "sessionId": localSessionID, + "eventType": "mcp.oauth_required", + }); err != nil { + return nil, err + } + } } // For the server-assigned (cloud) path, register the session @@ -909,6 +918,15 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses c.sessionsMux.Unlock() return nil, fmt.Errorf("session.create returned sessionId %s but the caller requested %s", response.SessionID, localSessionID) } + // Local IDs registered before create; server-assigned IDs can only register now. + if localSessionID == "" && config.OnMCPAuthRequest != nil { + if _, err := c.client.Request(ctx, "session.eventLog.registerInterest", map[string]any{ + "sessionId": session.SessionID, + "eventType": "mcp.oauth_required", + }); err != nil { + return nil, err + } + } session.workspacePath = response.WorkspacePath session.setCapabilities(response.Capabilities) @@ -1077,6 +1095,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) + session.registerMCPAuthHandler(config.OnMCPAuthRequest) if config.OnUserInputRequest != nil { session.registerUserInputHandler(config.OnUserInputRequest) } @@ -1108,6 +1127,17 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, c.sessionsMux.Lock() c.sessions[sessionID] = session c.sessionsMux.Unlock() + if config.OnMCPAuthRequest != nil { + if _, err := c.client.Request(ctx, "session.eventLog.registerInterest", map[string]any{ + "sessionId": sessionID, + "eventType": "mcp.oauth_required", + }); err != nil { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + return nil, err + } + } if c.options.SessionFS != nil { if config.CreateSessionFSProvider == nil { diff --git a/go/client_test.go b/go/client_test.go index 383ee315d..fb33f654a 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -3,6 +3,8 @@ package copilot import ( "context" "encoding/json" + "fmt" + "io" "net" "os" "os/exec" @@ -1315,6 +1317,291 @@ func TestClient_StartStopRace(t *testing.T) { } } +func TestClient_MCPAuthInterestRegistration(t *testing.T) { + t.Run("create skips MCP OAuth interest without auth handler", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + session, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnEvent: func(SessionEvent) {}, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + defer session.Disconnect() + + assertNoMCPAuthInterest(t, requests.snapshot()) + assertRequestMethod(t, requests.snapshot(), "session.create") + assertCreateRequestPermission(t, requests.snapshot()) + }) + + t.Run("create registers MCP OAuth interest before local session create when auth handler is configured", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + session, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(MCPAuthRequest, MCPAuthInvocation) (*MCPAuthResult, error) { + return &MCPAuthResult{Kind: "cancelled"}, nil + }, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + defer session.Disconnect() + + snapshot := requests.snapshot() + assertRequestMethod(t, snapshot, "session.eventLog.registerInterest") + if snapshot[0].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest before session.create, got %s", snapshot[0].Method) + } + if snapshot[1].Method != "session.create" { + t.Fatalf("expected session.create after MCP auth interest, got %s", snapshot[1].Method) + } + assertMCPAuthInterest(t, snapshot[0]) + assertCreateRequestPermission(t, snapshot) + }) + + t.Run("cloud create registers MCP OAuth interest after server assigns id only when auth handler is configured", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + withoutAuth, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + Cloud: &CloudSessionOptions{ + Repository: &CloudSessionRepository{Owner: "github", Name: "copilot-sdk", Branch: "main"}, + }, + }) + if err != nil { + t.Fatalf("CreateSession without auth failed: %v", err) + } + defer withoutAuth.Disconnect() + + assertNoMCPAuthInterest(t, requests.snapshot()) + requests.clear() + + withAuth, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(MCPAuthRequest, MCPAuthInvocation) (*MCPAuthResult, error) { + return &MCPAuthResult{Kind: "cancelled"}, nil + }, + Cloud: &CloudSessionOptions{ + Repository: &CloudSessionRepository{Owner: "github", Name: "copilot-sdk", Branch: "main"}, + }, + }) + if err != nil { + t.Fatalf("CreateSession with auth failed: %v", err) + } + defer withAuth.Disconnect() + + snapshot := requests.snapshot() + if snapshot[0].Method != "session.create" { + t.Fatalf("expected cloud session.create before MCP auth interest, got %s", snapshot[0].Method) + } + if snapshot[1].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest after cloud session.create, got %s", snapshot[1].Method) + } + assertMCPAuthInterest(t, snapshot[1]) + }) + + t.Run("resume conditionally registers MCP OAuth interest before session resume", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + withoutAuth, err := client.ResumeSession(t.Context(), "session-without-auth", &ResumeSessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnEvent: func(SessionEvent) {}, + }) + if err != nil { + t.Fatalf("ResumeSession without auth failed: %v", err) + } + defer withoutAuth.Disconnect() + + assertNoMCPAuthInterest(t, requests.snapshot()) + assertRequestMethod(t, requests.snapshot(), "session.resume") + requests.clear() + + withAuth, err := client.ResumeSession(t.Context(), "session-with-auth", &ResumeSessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(MCPAuthRequest, MCPAuthInvocation) (*MCPAuthResult, error) { + return &MCPAuthResult{Kind: "cancelled"}, nil + }, + }) + if err != nil { + t.Fatalf("ResumeSession with auth failed: %v", err) + } + defer withAuth.Disconnect() + + snapshot := requests.snapshot() + if snapshot[0].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest before session.resume, got %s", snapshot[0].Method) + } + if snapshot[1].Method != "session.resume" { + t.Fatalf("expected session.resume after MCP auth interest, got %s", snapshot[1].Method) + } + assertMCPAuthInterest(t, snapshot[0]) + }) +} + +type recordedRequest struct { + Method string + Params map[string]any +} + +type requestRecorder struct { + mu sync.Mutex + requests []recordedRequest +} + +func (r *requestRecorder) append(request recordedRequest) { + r.mu.Lock() + defer r.mu.Unlock() + r.requests = append(r.requests, request) +} + +func (r *requestRecorder) snapshot() []recordedRequest { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]recordedRequest, len(r.requests)) + copy(out, r.requests) + return out +} + +func (r *requestRecorder) clear() { + r.mu.Lock() + defer r.mu.Unlock() + r.requests = nil +} + +func newInMemoryClient(t *testing.T) (*Client, *requestRecorder, func()) { + t.Helper() + + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + rpcClient := jsonrpc2.NewClient(stdinW, stdoutR) + rpcClient.Start() + + client := NewClient(&ClientOptions{}) + client.client = rpcClient + client.RPC = rpc.NewServerRPC(rpcClient) + client.state = stateConnected + + requests := &requestRecorder{} + done := make(chan struct{}) + go serveInMemoryRuntime(t, stdinR, stdoutW, requests, done) + + cleanup := func() { + rpcClient.Stop() + stdinR.Close() + stdinW.Close() + stdoutR.Close() + stdoutW.Close() + <-done + } + return client, requests, cleanup +} + +func serveInMemoryRuntime(t *testing.T, stdinR *io.PipeReader, stdoutW *io.PipeWriter, requests *requestRecorder, done chan<- struct{}) { + t.Helper() + defer close(done) + + serverAssignedSessions := 0 + for { + frame, err := readTestJSONRPCFrame(stdinR) + if err != nil { + return + } + + var request struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + if err := json.Unmarshal(frame, &request); err != nil { + t.Errorf("failed to unmarshal JSON-RPC request: %v", err) + return + } + requests.append(recordedRequest{Method: request.Method, Params: request.Params}) + + result := map[string]any{} + switch request.Method { + case "session.create", "session.resume": + sessionID, _ := request.Params["sessionId"].(string) + if sessionID == "" { + serverAssignedSessions++ + sessionID = fmt.Sprintf("server-assigned-session-%d", serverAssignedSessions) + } + result = map[string]any{"sessionId": sessionID, "workspacePath": nil} + case "session.eventLog.registerInterest": + result = map[string]any{"id": "interest-1"} + case "session.options.update": + result = map[string]any{"success": true} + case "session.skills.reload", "session.destroy": + result = map[string]any{} + default: + t.Errorf("unexpected JSON-RPC method %s", request.Method) + return + } + + response := map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(request.ID), + "result": result, + } + data, err := json.Marshal(response) + if err != nil { + t.Errorf("failed to marshal JSON-RPC response: %v", err) + return + } + if _, err := fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(data), data); err != nil { + return + } + } +} + +func assertRequestMethod(t *testing.T, requests []recordedRequest, method string) { + t.Helper() + for _, request := range requests { + if request.Method == method { + return + } + } + t.Fatalf("expected %s request in %+v", method, requests) +} + +func assertNoMCPAuthInterest(t *testing.T, requests []recordedRequest) { + t.Helper() + for _, request := range requests { + if request.Method == "session.eventLog.registerInterest" && request.Params["eventType"] == "mcp.oauth_required" { + t.Fatalf("did not expect MCP auth interest registration in %+v", requests) + } + } +} + +func assertMCPAuthInterest(t *testing.T, request recordedRequest) { + t.Helper() + if request.Method != "session.eventLog.registerInterest" { + t.Fatalf("expected registerInterest request, got %s", request.Method) + } + if request.Params["eventType"] != "mcp.oauth_required" { + t.Fatalf("expected mcp.oauth_required interest, got %v", request.Params["eventType"]) + } +} + +func assertCreateRequestPermission(t *testing.T, requests []recordedRequest) { + t.Helper() + for _, request := range requests { + if request.Method == "session.create" { + if request.Params["requestPermission"] != true { + t.Fatalf("expected create requestPermission=true, got %v", request.Params["requestPermission"]) + } + return + } + } + t.Fatalf("session.create request not found in %+v", requests) +} + func TestCreateSessionRequest_Commands(t *testing.T) { t.Run("forwards commands in session.create RPC", func(t *testing.T) { req := createSessionRequest{ diff --git a/go/session.go b/go/session.go index acd698677..db9af5450 100644 --- a/go/session.go +++ b/go/session.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "log" "sync" "time" @@ -61,6 +62,8 @@ type Session struct { toolHandlersM sync.RWMutex permissionHandler PermissionHandlerFunc permissionMux sync.RWMutex + mcpAuthHandler MCPAuthHandler + mcpAuthMu sync.RWMutex userInputHandler UserInputHandler userInputMux sync.RWMutex exitPlanModeHandler ExitPlanModeRequestHandler @@ -863,6 +866,46 @@ func (s *Session) getElicitationHandler() ElicitationHandler { return s.elicitationHandler } +func (s *Session) registerMCPAuthHandler(handler MCPAuthHandler) { + s.mcpAuthMu.Lock() + defer s.mcpAuthMu.Unlock() + s.mcpAuthHandler = handler +} + +func (s *Session) getMCPAuthHandler() MCPAuthHandler { + s.mcpAuthMu.RLock() + defer s.mcpAuthMu.RUnlock() + return s.mcpAuthHandler +} + +func (s *Session) handleMCPAuthRequest(request MCPAuthRequest) { + handler := s.getMCPAuthHandler() + if handler == nil { + return + } + + ctx := context.Background() + cancel := &rpc.MCPOauthPendingRequestResponseCancelled{} + result, err := handler(request, MCPAuthInvocation{SessionID: s.SessionID}) + if err != nil || result == nil || result.Kind == "cancelled" || result.Token == nil { + s.RPC.MCP.Oauth().HandlePendingRequest(ctx, &rpc.MCPOauthHandlePendingRequest{ + RequestID: request.RequestID, + Result: cancel, + }) + return + } + + s.RPC.MCP.Oauth().HandlePendingRequest(ctx, &rpc.MCPOauthHandlePendingRequest{ + RequestID: request.RequestID, + Result: &rpc.MCPOauthPendingRequestResponseToken{ + AccessToken: result.Token.AccessToken, + TokenType: result.Token.TokenType, + RefreshToken: result.Token.RefreshToken, + ExpiresIn: result.Token.ExpiresIn, + }, + }) +} + // handleElicitationRequest dispatches an elicitation.requested event to the registered handler // and sends the result back via the RPC layer. Auto-cancels on error. func (s *Session) handleElicitationRequest(elicitCtx ElicitationContext, requestID string) { @@ -1309,6 +1352,56 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { } s.executePermissionAndRespond(d.RequestID, d.PermissionRequest, handler) + case *MCPOauthRequiredData: + handler := s.getMCPAuthHandler() + if d.RequestID == "" { + return + } + if handler == nil { + log.Printf( + "Received MCP OAuth request without a registered MCP auth handler. SessionId=%s, RequestId=%s", + s.SessionID, + d.RequestID, + ) + return + } + var staticClientConfig *MCPAuthStaticClientConfig + if d.StaticClientConfig != nil { + var grantType string + if d.StaticClientConfig.GrantType != nil { + grantType = string(*d.StaticClientConfig.GrantType) + } + staticClientConfig = &MCPAuthStaticClientConfig{ + ClientID: d.StaticClientConfig.ClientID, + GrantType: grantType, + PublicClient: d.StaticClientConfig.PublicClient, + } + } + request := MCPAuthRequest{ + RequestID: d.RequestID, + ServerName: d.ServerName, + ServerURL: d.ServerURL, + StaticClientConfig: staticClientConfig, + } + if d.ResourceMetadata != nil { + request.ResourceMetadata = *d.ResourceMetadata + } + if d.WwwAuthenticateParams != nil { + var scope, oauthError string + if d.WwwAuthenticateParams.Scope != nil { + scope = *d.WwwAuthenticateParams.Scope + } + if d.WwwAuthenticateParams.Error != nil { + oauthError = *d.WwwAuthenticateParams.Error + } + request.WwwAuthenticateParams = &MCPAuthWwwAuthenticateParams{ + ResourceMetadataURL: d.WwwAuthenticateParams.ResourceMetadataURL, + Scope: scope, + Error: oauthError, + } + } + s.handleMCPAuthRequest(request) + case *CommandExecuteData: s.executeCommandAndRespond(d.RequestID, d.CommandName, d.Command, d.Args) diff --git a/go/session_test.go b/go/session_test.go index 15cfbcf57..b20918868 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -60,6 +60,170 @@ func TestSession_SetModelOmitsContextTierWhenUnset(t *testing.T) { } } +func TestSession_MCPAuthRequestSendsHostToken(t *testing.T) { + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + defer stdinR.Close() + defer stdinW.Close() + defer stdoutR.Close() + defer stdoutW.Close() + + client := jsonrpc2.NewClient(stdinW, stdoutR) + client.Start() + defer client.Stop() + + paramsCh := make(chan map[string]any, 1) + errCh := make(chan error, 1) + + go func() { + frame, err := readTestJSONRPCFrame(stdinR) + if err != nil { + errCh <- err + return + } + + var request struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + if err := json.Unmarshal(frame, &request); err != nil { + errCh <- err + return + } + if request.Method != "session.mcp.oauth.handlePendingRequest" { + errCh <- fmt.Errorf("expected session.mcp.oauth.handlePendingRequest, got %s", request.Method) + return + } + + paramsCh <- request.Params + + response := map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(request.ID), + "result": map[string]any{"success": true}, + } + data, err := json.Marshal(response) + if err != nil { + errCh <- err + return + } + if _, err := fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(data), data); err != nil { + errCh <- err + } + }() + + session := &Session{ + SessionID: "session-1", + client: client, + RPC: rpc.NewSessionRPC(client, "session-1"), + } + var observedRequest MCPAuthRequest + session.registerMCPAuthHandler(func(request MCPAuthRequest, invocation MCPAuthInvocation) (*MCPAuthResult, error) { + observedRequest = request + if invocation.SessionID != "session-1" { + t.Fatalf("expected invocation session-1, got %s", invocation.SessionID) + } + if request.RequestID != "oauth-request" { + t.Fatalf("expected oauth-request, got %s", request.RequestID) + } + tokenType := "Bearer" + return &MCPAuthResult{ + Kind: "token", + Token: &MCPAuthToken{ + AccessToken: "host-token", + TokenType: &tokenType, + }, + }, nil + }) + session.handleMCPAuthRequest(MCPAuthRequest{ + RequestID: "oauth-request", + ResourceMetadata: `{"resource":"https://example.com/mcp"}`, + WwwAuthenticateParams: &MCPAuthWwwAuthenticateParams{ + ResourceMetadataURL: "https://example.com/.well-known/oauth-protected-resource", + }, + }) + if observedRequest.ResourceMetadata != `{"resource":"https://example.com/mcp"}` { + t.Fatalf("expected resource metadata to be propagated, got %q", observedRequest.ResourceMetadata) + } + if observedRequest.WwwAuthenticateParams == nil { + t.Fatal("expected WWW-Authenticate params to be propagated") + } + + select { + case params := <-paramsCh: + if params["sessionId"] != "session-1" { + t.Fatalf("expected sessionId session-1, got %v", params["sessionId"]) + } + if params["requestId"] != "oauth-request" { + t.Fatalf("expected requestId oauth-request, got %v", params["requestId"]) + } + result, ok := params["result"].(map[string]any) + if !ok { + t.Fatalf("expected result object, got %T", params["result"]) + } + if result["kind"] != "token" { + t.Fatalf("expected token kind, got %v", result["kind"]) + } + if result["accessToken"] != "host-token" { + t.Fatalf("expected accessToken host-token, got %v", result["accessToken"]) + } + if result["tokenType"] != "Bearer" { + t.Fatalf("expected tokenType Bearer, got %v", result["tokenType"]) + } + case err := <-errCh: + t.Fatal(err) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for MCP OAuth request") + } +} + +func TestMCPAuthRequestAllowsMissingOptionalMetadata(t *testing.T) { + request := MCPAuthRequest{RequestID: "oauth-request"} + if request.ResourceMetadata != "" { + t.Fatalf("expected no resource metadata, got %q", request.ResourceMetadata) + } + if request.WwwAuthenticateParams != nil { + t.Fatalf("expected no WWW-Authenticate params, got %#v", request.WwwAuthenticateParams) + } +} + +func TestMCPOauthRequiredDataAllowsOptionalMetadata(t *testing.T) { + var withMetadata rpc.MCPOauthRequiredData + if err := json.Unmarshal([]byte(`{ + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\"resource\":\"https://example.com/mcp\"}" + }`), &withMetadata); err != nil { + t.Fatal(err) + } + if withMetadata.ResourceMetadata == nil || *withMetadata.ResourceMetadata != `{"resource":"https://example.com/mcp"}` { + t.Fatalf("expected resource metadata, got %#v", withMetadata.ResourceMetadata) + } + if withMetadata.WwwAuthenticateParams == nil { + t.Fatal("expected WWW-Authenticate params") + } + + var withoutMetadata rpc.MCPOauthRequiredData + if err := json.Unmarshal([]byte(`{ + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + }`), &withoutMetadata); err != nil { + t.Fatal(err) + } + if withoutMetadata.ResourceMetadata != nil { + t.Fatalf("expected no resource metadata, got %#v", withoutMetadata.ResourceMetadata) + } + if withoutMetadata.WwwAuthenticateParams != nil { + t.Fatalf("expected no WWW-Authenticate params, got %#v", withoutMetadata.WwwAuthenticateParams) + } +} + func captureSetModelRequest(t *testing.T, opts *SetModelOptions) map[string]any { t.Helper() diff --git a/go/types.go b/go/types.go index 3ff6d0f9c..9a17427d7 100644 --- a/go/types.go +++ b/go/types.go @@ -319,6 +319,52 @@ type PermissionInvocation struct { SessionID string } +// MCPAuthWwwAuthenticateParams contains parsed parameters from an MCP server's WWW-Authenticate response. +type MCPAuthWwwAuthenticateParams struct { + ResourceMetadataURL string `json:"resourceMetadataUrl"` + Scope string `json:"scope,omitempty"` + Error string `json:"error,omitempty"` +} + +// MCPAuthStaticClientConfig is static OAuth client configuration supplied by an MCP server. +type MCPAuthStaticClientConfig struct { + ClientID string `json:"clientId"` + GrantType string `json:"grantType,omitempty"` + PublicClient *bool `json:"publicClient,omitempty"` +} + +// MCPAuthRequest describes an MCP OAuth request that the SDK host can satisfy with a token. +type MCPAuthRequest struct { + RequestID string `json:"requestId"` + ServerName string `json:"serverName"` + ServerURL string `json:"serverUrl"` + WwwAuthenticateParams *MCPAuthWwwAuthenticateParams `json:"wwwAuthenticateParams,omitempty"` + ResourceMetadata string `json:"resourceMetadata,omitempty"` + StaticClientConfig *MCPAuthStaticClientConfig `json:"staticClientConfig,omitempty"` +} + +// MCPAuthToken is host-provided OAuth token data for a pending MCP OAuth request. +type MCPAuthToken struct { + AccessToken string `json:"accessToken"` + TokenType *string `json:"tokenType,omitempty"` + RefreshToken *string `json:"refreshToken,omitempty"` + ExpiresIn *int64 `json:"expiresIn,omitempty"` +} + +// MCPAuthResult is the result returned by an MCP auth request handler. +type MCPAuthResult struct { + Kind string + Token *MCPAuthToken +} + +// MCPAuthInvocation provides context about an MCP auth handler invocation. +type MCPAuthInvocation struct { + SessionID string +} + +// MCPAuthHandler handles MCP OAuth requests from the runtime. +type MCPAuthHandler func(request MCPAuthRequest, invocation MCPAuthInvocation) (*MCPAuthResult, error) + // UserInputRequest represents a request for user input from the agent type UserInputRequest struct { Question string @@ -968,6 +1014,10 @@ type SessionConfig struct { // When nil, permission requests are surfaced as events and left pending for the // consumer to resolve via pending permission RPCs. OnPermissionRequest PermissionHandlerFunc + // OnMCPAuthRequest is an optional handler for MCP OAuth requests from MCP servers. + // When provided, the SDK can satisfy MCP server OAuth requests with host-provided + // token data or cancellation. + OnMCPAuthRequest MCPAuthHandler // OnUserInputRequest is a handler for user input requests from the agent (enables ask_user tool) OnUserInputRequest UserInputHandler // Hooks configures hook handlers for session lifecycle events @@ -1386,6 +1436,9 @@ type ResumeSessionConfig struct { // When nil, permission requests are surfaced as events and left pending for the // consumer to resolve via pending permission RPCs. OnPermissionRequest PermissionHandlerFunc + // OnMCPAuthRequest is an optional handler for MCP OAuth requests from MCP servers. + // See SessionConfig.OnMCPAuthRequest. + OnMCPAuthRequest MCPAuthHandler // OnUserInputRequest is a handler for user input requests from the agent (enables ask_user tool) OnUserInputRequest UserInputHandler // Hooks configures hook handlers for session lifecycle events diff --git a/java/src/main/java/com/github/copilot/CopilotClient.java b/java/src/main/java/com/github/copilot/CopilotClient.java index 63b70e2df..7015cf752 100644 --- a/java/src/main/java/com/github/copilot/CopilotClient.java +++ b/java/src/main/java/com/github/copilot/CopilotClient.java @@ -27,6 +27,7 @@ import com.github.copilot.generated.rpc.SessionInstalledPlugin; import com.github.copilot.generated.rpc.ConnectParams; import com.github.copilot.generated.rpc.ServerRpc; +import com.github.copilot.generated.rpc.SessionEventLogRegisterInterestParams; import com.github.copilot.rpc.DeleteSessionResponse; import com.github.copilot.rpc.GetAuthStatusResponse; import com.github.copilot.rpc.GetLastSessionIdResponse; @@ -564,6 +565,7 @@ public CompletableFuture createSession(SessionConfig config) { String[] registeredIdHolder = new String[1]; CopilotSession[] preRegisteredSessionHolder = new CopilotSession[1]; + CompletableFuture preCreateInterest = CompletableFuture.completedFuture(null); // Pre-register non-cloud sessions BEFORE issuing the RPC so any // session-scoped requests the CLI emits during session.create @@ -571,6 +573,10 @@ public CompletableFuture createSession(SessionConfig config) { if (localSessionId != null) { preRegisteredSessionHolder[0] = initializeSession.apply(localSessionId); registeredIdHolder[0] = localSessionId; + if (config.getOnMcpAuthRequest() != null) { + preCreateInterest = preRegisteredSessionHolder[0].getRpc().eventLog.registerInterest( + new SessionEventLogRegisterInterestParams(localSessionId, "mcp.oauth_required")); + } } var request = SessionRequestBuilder.buildCreateRequest(config, localSessionId); @@ -620,7 +626,9 @@ public CompletableFuture createSession(SessionConfig config) { } long rpcNanos = System.nanoTime(); - return connection.rpc.invoke("session.create", request, CreateSessionResponse.class) + return preCreateInterest + .thenCompose( + ignored -> connection.rpc.invoke("session.create", request, CreateSessionResponse.class)) .thenCompose(response -> { String returnedId = response.sessionId(); LoggingHelpers.logTiming(LOG, Level.FINE, @@ -638,14 +646,23 @@ public CompletableFuture createSession(SessionConfig config) { ? preRegisteredSessionHolder[0] : initializeSession.apply(returnedId); registeredIdHolder[0] = returnedId; + // Local IDs registered before create; server-assigned IDs can only register + // now. + CompletableFuture interest = config.getOnMcpAuthRequest() != null + && preRegisteredSessionHolder[0] == null + ? session.getRpc().eventLog + .registerInterest(new SessionEventLogRegisterInterestParams(returnedId, + "mcp.oauth_required")) + : CompletableFuture.completedFuture(null); session.setWorkspacePath(response.workspacePath()); session.setCapabilities(response.capabilities()); session.setOpenCanvases(response.openCanvases()); - return updateSessionOptionsForMode(session, config.getSkipCustomInstructions().orElse(null), + return interest.thenCompose(ignored -> updateSessionOptionsForMode(session, + config.getSkipCustomInstructions().orElse(null), config.getCustomAgentsLocalOnly().orElse(null), config.getCoauthorEnabled().orElse(null), - config.getManageScheduleEnabled().orElse(null)).thenApply(v -> { + config.getManageScheduleEnabled().orElse(null))).thenApply(v -> { LoggingHelpers.logTiming(LOG, Level.FINE, "CopilotClient.createSession complete. Elapsed={Elapsed}, SessionId=" + session.getSessionId(), @@ -714,6 +731,10 @@ public CompletableFuture resumeSession(String sessionId, ResumeS if (extracted.transformCallbacks() != null) { session.registerTransformCallbacks(extracted.transformCallbacks()); } + CompletableFuture interest = config.getOnMcpAuthRequest() != null + ? session.getRpc().eventLog.registerInterest( + new SessionEventLogRegisterInterestParams(sessionId, "mcp.oauth_required")) + : CompletableFuture.completedFuture(null); var request = SessionRequestBuilder.buildResumeRequest(sessionId, config); if (extracted.wireSystemMessage() != config.getSystemMessage()) { @@ -760,7 +781,9 @@ public CompletableFuture resumeSession(String sessionId, ResumeS } long rpcNanos = System.nanoTime(); - return connection.rpc.invoke("session.resume", request, ResumeSessionResponse.class) + return interest + .thenCompose( + ignored -> connection.rpc.invoke("session.resume", request, ResumeSessionResponse.class)) .thenCompose(response -> { LoggingHelpers.logTiming(LOG, Level.FINE, "CopilotClient.resumeSession session resume request completed. Elapsed={Elapsed}, SessionId=" diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index adfeac013..fd43d5555 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -33,6 +33,7 @@ import com.github.copilot.generated.rpc.SessionCommandsHandlePendingCommandParams; import com.github.copilot.generated.rpc.SessionLogParams; import com.github.copilot.generated.rpc.SessionLogLevel; +import com.github.copilot.generated.rpc.SessionMcpOauthHandlePendingRequestParams; import com.github.copilot.generated.rpc.ModelCapabilitiesOverride; import com.github.copilot.generated.rpc.ModelCapabilitiesOverrideLimits; import com.github.copilot.generated.rpc.ModelCapabilitiesOverrideSupports; @@ -49,6 +50,7 @@ import com.github.copilot.generated.CommandExecuteEvent; import com.github.copilot.generated.ElicitationRequestedEvent; import com.github.copilot.generated.ExternalToolRequestedEvent; +import com.github.copilot.generated.McpOauthRequiredEvent; import com.github.copilot.generated.PermissionRequestedEvent; import com.github.copilot.generated.SessionCanvasClosedEvent; import com.github.copilot.generated.SessionCanvasOpenedEvent; @@ -79,6 +81,9 @@ import com.github.copilot.rpc.HookInvocation; import com.github.copilot.rpc.InputOptions; import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.McpAuthHandler; +import com.github.copilot.rpc.McpAuthRequest; +import com.github.copilot.rpc.McpAuthResult; import com.github.copilot.rpc.PermissionHandler; import com.github.copilot.rpc.PermissionInvocation; import com.github.copilot.rpc.PermissionRequest; @@ -170,6 +175,7 @@ public final class CopilotSession implements AutoCloseable { private final Map toolHandlers = new ConcurrentHashMap<>(); private final Map commandHandlers = new ConcurrentHashMap<>(); private final AtomicReference permissionHandler = new AtomicReference<>(); + private final AtomicReference mcpAuthHandler = new AtomicReference<>(); private final AtomicReference userInputHandler = new AtomicReference<>(); private final AtomicReference elicitationHandler = new AtomicReference<>(); private final AtomicReference exitPlanModeHandler = new AtomicReference<>(); @@ -838,6 +844,20 @@ private void handleBroadcastEventAsync(SessionEvent event) { } executePermissionAndRespondAsync(data.requestId(), MAPPER.convertValue(data.permissionRequest(), PermissionRequest.class), handler); + } else if (event instanceof McpOauthRequiredEvent authEvent) { + var data = authEvent.getData(); + if (data == null || data.requestId() == null) { + return; + } + McpAuthHandler handler = mcpAuthHandler.get(); + if (handler == null) { + LOG.warning(() -> "Received MCP OAuth request without a registered MCP auth handler. SessionId=" + + sessionId + ", RequestId=" + data.requestId()); + return; + } + executeMcpAuthAndRespondAsync(new McpAuthRequest(sessionId, data.requestId(), data.serverName(), + data.serverUrl(), data.wwwAuthenticateParams(), data.resourceMetadata(), data.staticClientConfig()), + handler); } else if (event instanceof CommandExecuteEvent cmdEvent) { var data = cmdEvent.getData(); if (data == null || data.requestId() == null || data.commandName() == null) { @@ -1005,6 +1025,60 @@ private void executePermissionAndRespondAsync(String requestId, PermissionReques } } + private void executeMcpAuthAndRespondAsync(McpAuthRequest request, McpAuthHandler handler) { + Runnable task = () -> { + try { + handler.handle(request).thenAccept(result -> sendMcpAuthResponse(request.requestId(), result)) + .exceptionally(ex -> { + sendMcpAuthResponse(request.requestId(), McpAuthResult.cancelled()); + return null; + }); + } catch (Exception e) { + LOG.log(Level.WARNING, "Error executing MCP auth handler for requestId=" + request.requestId(), e); + sendMcpAuthResponse(request.requestId(), McpAuthResult.cancelled()); + } + }; + try { + if (executor != null) { + CompletableFuture.runAsync(task, executor); + } else { + CompletableFuture.runAsync(task); + } + } catch (RejectedExecutionException e) { + LOG.log(Level.WARNING, + "Executor rejected MCP auth task for requestId=" + request.requestId() + "; running inline", e); + task.run(); + } + } + + private void sendMcpAuthResponse(String requestId, McpAuthResult result) { + try { + Object response; + if (result == null || result.isCancelled() || result.token() == null) { + response = Map.of("kind", "cancelled"); + } else { + var token = result.token(); + var tokenResponse = new java.util.HashMap(); + tokenResponse.put("kind", "token"); + tokenResponse.put("accessToken", token.accessToken()); + if (token.tokenType() != null) { + tokenResponse.put("tokenType", token.tokenType()); + } + if (token.refreshToken() != null) { + tokenResponse.put("refreshToken", token.refreshToken()); + } + if (token.expiresIn() != null) { + tokenResponse.put("expiresIn", token.expiresIn()); + } + response = tokenResponse; + } + getRpc().mcp.oauth.handlePendingRequest( + new SessionMcpOauthHandlePendingRequestParams(sessionId, requestId, response)); + } catch (Exception e) { + LOG.log(Level.WARNING, "Error sending MCP auth response for requestId=" + requestId, e); + } + } + /** * Registers custom tool handlers for this session. *

@@ -1268,6 +1342,10 @@ void registerPermissionHandler(PermissionHandler handler) { permissionHandler.set(handler); } + void registerMcpAuthHandler(McpAuthHandler handler) { + mcpAuthHandler.set(handler); + } + /** * Handles a permission request from the Copilot CLI. *

diff --git a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java index c26548a2f..a67da9383 100644 --- a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java +++ b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java @@ -317,6 +317,9 @@ static void configureSession(CopilotSession session, SessionConfig config) { if (config.getOnPermissionRequest() != null) { session.registerPermissionHandler(config.getOnPermissionRequest()); } + if (config.getOnMcpAuthRequest() != null) { + session.registerMcpAuthHandler(config.getOnMcpAuthRequest()); + } if (config.getOnUserInputRequest() != null) { session.registerUserInputHandler(config.getOnUserInputRequest()); } @@ -359,6 +362,9 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) if (config.getOnPermissionRequest() != null) { session.registerPermissionHandler(config.getOnPermissionRequest()); } + if (config.getOnMcpAuthRequest() != null) { + session.registerMcpAuthHandler(config.getOnMcpAuthRequest()); + } if (config.getOnUserInputRequest() != null) { session.registerUserInputHandler(config.getOnUserInputRequest()); } diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java b/java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java new file mode 100644 index 000000000..e92ef1dd4 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java @@ -0,0 +1,24 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.concurrent.CompletableFuture; + +/** + * Handles MCP OAuth requests from the runtime. + * + * @since 1.0.0 + */ +@FunctionalInterface +public interface McpAuthHandler { + /** + * Handles an MCP OAuth request. + * + * @param request + * the MCP OAuth request context + * @return a future resolving to token data or cancellation + */ + CompletableFuture handle(McpAuthRequest request); +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java new file mode 100644 index 000000000..15a184702 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java @@ -0,0 +1,18 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import com.github.copilot.generated.McpOauthRequiredStaticClientConfig; +import com.github.copilot.generated.McpOauthWWWAuthenticateParams; + +/** + * MCP OAuth request that the SDK host can satisfy with a host-acquired token. + * + * @since 1.0.0 + */ +public record McpAuthRequest(String sessionId, String requestId, String serverName, String serverUrl, + McpOauthWWWAuthenticateParams wwwAuthenticateParams, String resourceMetadata, + McpOauthRequiredStaticClientConfig staticClientConfig) { +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java b/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java new file mode 100644 index 000000000..6b7fda34f --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java @@ -0,0 +1,32 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +/** + * Result returned by an MCP auth request handler. + * + * @since 1.0.0 + */ +public record McpAuthResult(boolean isCancelled, McpAuthToken token) { + /** + * Creates a token result. + * + * @param token + * the host-provided OAuth token data + * @return token result + */ + public static McpAuthResult token(McpAuthToken token) { + return new McpAuthResult(false, token); + } + + /** + * Creates a cancellation result. + * + * @return cancellation result + */ + public static McpAuthResult cancelled() { + return new McpAuthResult(true, null); + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthToken.java b/java/src/main/java/com/github/copilot/rpc/McpAuthToken.java new file mode 100644 index 000000000..5df1b33ff --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthToken.java @@ -0,0 +1,13 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +/** + * Host-provided OAuth token data for a pending MCP OAuth request. + * + * @since 1.0.0 + */ +public record McpAuthToken(String accessToken, String tokenType, String refreshToken, Long expiresIn) { +} diff --git a/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java b/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java index df600b0af..e6abf2e02 100644 --- a/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java @@ -59,6 +59,7 @@ public class ResumeSessionConfig { private String contextTier; private ModelCapabilitiesOverride modelCapabilities; private PermissionHandler onPermissionRequest; + private McpAuthHandler onMcpAuthRequest; private UserInputHandler onUserInputRequest; private SessionHooks hooks; private String workingDirectory; @@ -633,6 +634,28 @@ public ResumeSessionConfig setOnPermissionRequest(PermissionHandler onPermission return this; } + /** + * Gets the MCP OAuth request handler. + * + * @return the handler, or {@code null} if not set + */ + @JsonIgnore + public McpAuthHandler getOnMcpAuthRequest() { + return onMcpAuthRequest; + } + + /** + * Sets the MCP OAuth request handler. + * + * @param onMcpAuthRequest + * the handler + * @return this config instance for method chaining + */ + public ResumeSessionConfig setOnMcpAuthRequest(McpAuthHandler onMcpAuthRequest) { + this.onMcpAuthRequest = onMcpAuthRequest; + return this; + } + /** * Gets the user input request handler. * @@ -1671,6 +1694,7 @@ public ResumeSessionConfig clone() { copy.onEvent = this.onEvent; copy.commands = this.commands != null ? new ArrayList<>(this.commands) : null; copy.onElicitationRequest = this.onElicitationRequest; + copy.onMcpAuthRequest = this.onMcpAuthRequest; copy.onExitPlanMode = this.onExitPlanMode; copy.onAutoModeSwitch = this.onAutoModeSwitch; copy.enableMcpApps = this.enableMcpApps; diff --git a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java index 5567d32ac..303034886 100644 --- a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java @@ -59,6 +59,7 @@ public class SessionConfig { private Boolean coauthorEnabled; private Boolean manageScheduleEnabled; private PermissionHandler onPermissionRequest; + private McpAuthHandler onMcpAuthRequest; private UserInputHandler onUserInputRequest; private SessionHooks hooks; private String workingDirectory; @@ -676,6 +677,31 @@ public SessionConfig setOnPermissionRequest(PermissionHandler onPermissionReques return this; } + /** + * Gets the MCP OAuth request handler. + * + * @return the handler, or {@code null} if not set + */ + @JsonIgnore + public McpAuthHandler getOnMcpAuthRequest() { + return onMcpAuthRequest; + } + + /** + * Sets the MCP OAuth request handler. + *

+ * When provided, the SDK can satisfy MCP server OAuth requests with + * host-provided token data or cancellation. + * + * @param onMcpAuthRequest + * the handler + * @return this config instance for method chaining + */ + public SessionConfig setOnMcpAuthRequest(McpAuthHandler onMcpAuthRequest) { + this.onMcpAuthRequest = onMcpAuthRequest; + return this; + } + /** * Gets the user input request handler. * @@ -1795,6 +1821,7 @@ public SessionConfig clone() { copy.onEvent = this.onEvent; copy.commands = this.commands != null ? new ArrayList<>(this.commands) : null; copy.onElicitationRequest = this.onElicitationRequest; + copy.onMcpAuthRequest = this.onMcpAuthRequest; copy.onExitPlanMode = this.onExitPlanMode; copy.onAutoModeSwitch = this.onAutoModeSwitch; copy.enableMcpApps = this.enableMcpApps; diff --git a/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java new file mode 100644 index 000000000..79f7ffa85 --- /dev/null +++ b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java @@ -0,0 +1,277 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.copilot.generated.McpOauthRequiredEvent; +import com.github.copilot.rpc.CloudSessionOptions; +import com.github.copilot.rpc.CloudSessionRepository; +import com.github.copilot.rpc.CopilotClientOptions; +import com.github.copilot.rpc.McpAuthResult; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.ResumeSessionConfig; +import com.github.copilot.rpc.SessionConfig; + +class McpAuthInterestRegistrationTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @Test + void mcpOauthRequiredEventExposesOptionalResourceMetadata() throws Exception { + var data = MAPPER.readValue(""" + { + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\\"resource\\":\\"https://example.com/mcp\\"}" + } + """, McpOauthRequiredEvent.McpOauthRequiredEventData.class); + + assertEquals("{\"resource\":\"https://example.com/mcp\"}", data.resourceMetadata()); + assertNotNull(data.wwwAuthenticateParams()); + + var withoutMetadata = MAPPER.readValue(""" + { + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + } + """, McpOauthRequiredEvent.McpOauthRequiredEventData.class); + + assertNull(withoutMetadata.resourceMetadata()); + assertNull(withoutMetadata.wwwAuthenticateParams()); + } + + @Test + void createSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exception { + try (var server = new RecordingRuntime(); + var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { + try (var session = client.createSession( + new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setOnEvent(event -> { + })).get()) { + assertNotNull(session); + } + + assertNoMcpAuthInterest(server.requests()); + assertTrue(server.requests().stream().anyMatch(request -> "session.create".equals(request.method()) + && request.params().path("requestPermission").asBoolean())); + + server.clearRequests(); + + try (var session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()))) + .get()) { + assertNotNull(session); + } + + List requests = server.requests(); + assertEquals("session.eventLog.registerInterest", requests.get(0).method()); + assertEquals("mcp.oauth_required", requests.get(0).params().path("eventType").asText()); + assertEquals("session.create", requests.get(1).method()); + } + } + + @Test + void cloudCreateSessionRegistersMcpAuthInterestAfterCreateOnlyWhenHandlerConfigured() throws Exception { + try (var server = new RecordingRuntime(); + var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { + var cloud = new CloudSessionOptions().setRepository( + new CloudSessionRepository().setOwner("github").setName("copilot-sdk").setBranch("main")); + + try (var session = client + .createSession( + new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setCloud(cloud)) + .get()) { + assertNotNull(session); + } + + assertNoMcpAuthInterest(server.requests()); + server.clearRequests(); + + try (var session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setCloud(cloud).setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()))) + .get()) { + assertNotNull(session); + } + + List requests = server.requests(); + assertEquals("session.create", requests.get(0).method()); + assertEquals("session.eventLog.registerInterest", requests.get(1).method()); + assertEquals("mcp.oauth_required", requests.get(1).params().path("eventType").asText()); + } + } + + @Test + void resumeSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exception { + try (var server = new RecordingRuntime(); + var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { + try (var session = client.resumeSession("session-without-auth", new ResumeSessionConfig() + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setOnEvent(event -> { + })).get()) { + assertNotNull(session); + } + + assertNoMcpAuthInterest(server.requests()); + assertTrue(server.requests().stream().anyMatch(request -> "session.resume".equals(request.method()) + && request.params().path("requestPermission").asBoolean())); + + server.clearRequests(); + + try (var session = client.resumeSession("session-with-auth", + new ResumeSessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest(request -> java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()))) + .get()) { + assertNotNull(session); + } + + List requests = server.requests(); + assertEquals("session.eventLog.registerInterest", requests.get(0).method()); + assertEquals("mcp.oauth_required", requests.get(0).params().path("eventType").asText()); + assertEquals("session.resume", requests.get(1).method()); + } + } + + private static void assertNoMcpAuthInterest(List requests) { + assertFalse(requests.stream().anyMatch(request -> "session.eventLog.registerInterest".equals(request.method()) + && "mcp.oauth_required".equals(request.params().path("eventType").asText()))); + } + + private record RpcRequest(String method, JsonNode params) { + } + + private static final class RecordingRuntime implements AutoCloseable { + private final ServerSocket listener; + private final Thread thread; + private final List requests = new CopyOnWriteArrayList<>(); + private volatile boolean running = true; + + RecordingRuntime() throws Exception { + listener = new ServerSocket(0); + thread = new Thread(this::run, "mcp-auth-interest-test-runtime"); + thread.setDaemon(true); + thread.start(); + } + + String url() { + return "127.0.0.1:" + listener.getLocalPort(); + } + + List requests() { + return List.copyOf(requests); + } + + void clearRequests() { + requests.clear(); + } + + @Override + public void close() throws Exception { + running = false; + listener.close(); + thread.join(2000); + } + + private void run() { + try (Socket socket = listener.accept()) { + var in = socket.getInputStream(); + var out = socket.getOutputStream(); + while (running) { + JsonNode message = readMessage(in); + if (message == null) { + return; + } + String method = message.path("method").asText(); + requests.add(new RpcRequest(method, message.path("params").deepCopy())); + sendResponse(out, message.path("id").asLong(), resultFor(method, message.path("params"))); + } + } catch (Exception ex) { + if (running) { + throw new RuntimeException(ex); + } + } + } + + private static JsonNode resultFor(String method, JsonNode params) { + ObjectNode result = MAPPER.createObjectNode(); + switch (method) { + case "connect" -> { + result.put("ok", true); + result.put("protocolVersion", 3); + result.put("version", "test"); + } + case "session.create", "session.resume" -> { + String sessionId = params.path("sessionId").asText("server-assigned-session"); + if (sessionId.isEmpty()) { + sessionId = "server-assigned-session"; + } + result.put("sessionId", sessionId); + result.putNull("workspacePath"); + result.putNull("capabilities"); + } + case "session.eventLog.registerInterest" -> result.put("id", "interest-1"); + case "session.options.update" -> result.put("success", true); + case "session.skills.reload", "session.destroy" -> { + } + default -> throw new IllegalStateException("Unexpected RPC method " + method); + } + return result; + } + + private static JsonNode readMessage(java.io.InputStream in) throws Exception { + StringBuilder header = new StringBuilder(); + int b; + while ((b = in.read()) != -1) { + header.append((char) b); + if (header.toString().endsWith("\r\n\r\n")) { + break; + } + } + if (b == -1) { + return null; + } + int contentLength = 0; + for (String line : header.toString().split("\r\n")) { + int colon = line.indexOf(':'); + if (colon > 0 && "Content-Length".equals(line.substring(0, colon))) { + contentLength = Integer.parseInt(line.substring(colon + 1).trim()); + } + } + byte[] body = in.readNBytes(contentLength); + return MAPPER.readTree(body); + } + + private static void sendResponse(OutputStream out, long id, JsonNode result) throws Exception { + ObjectNode response = MAPPER.createObjectNode(); + response.put("jsonrpc", "2.0"); + response.put("id", id); + response.set("result", result); + byte[] body = MAPPER.writeValueAsBytes(response); + out.write(("Content-Length: " + body.length + "\r\n\r\n").getBytes(StandardCharsets.UTF_8)); + out.write(body); + out.flush(); + } + } +} diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 96ac60842..b764c8273 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -1249,7 +1249,8 @@ export class CopilotClient { sessionId, this.connection!, undefined, - this.onGetTraceContext + this.onGetTraceContext, + { mcpAuthHandler: config.onMcpAuthRequest } ); s.registerTools(config.tools); s.registerCanvases(config.canvases); @@ -1392,6 +1393,12 @@ export class CopilotClient { session = initializeSession(returnedSessionId); registeredId = returnedSessionId; } + if (config.onMcpAuthRequest) { + await this.connection!.sendRequest("session.eventLog.registerInterest", { + sessionId: returnedSessionId, + eventType: "mcp.oauth_required", + }); + } session["_workspacePath"] = workspacePath; session.setCapabilities(capabilities); @@ -1441,7 +1448,8 @@ export class CopilotClient { sessionId, this.connection!, undefined, - this.onGetTraceContext + this.onGetTraceContext, + { mcpAuthHandler: config.onMcpAuthRequest } ); session.registerTools(config.tools); session.registerCanvases(config.canvases); @@ -1478,6 +1486,12 @@ export class CopilotClient { } this.sessions.set(sessionId, session); this.setupSessionFs(session, config); + if (config.onMcpAuthRequest) { + await this.connection!.sendRequest("session.eventLog.registerInterest", { + sessionId, + eventType: "mcp.oauth_required", + }); + } const toolFilterOptions = this.resolveToolFilterOptions(config); diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 0ba42ab76..d6c580c84 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -10,7 +10,11 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; import { ConnectionError, ErrorCodes, ResponseError } from "vscode-jsonrpc/node.js"; import { createSessionRpc } from "./generated/rpc.js"; -import type { ClientSessionApiHandlers, CanvasActionInvokeResult } from "./generated/rpc.js"; +import type { + ClientSessionApiHandlers, + CanvasActionInvokeResult, + McpOauthPendingRequestResponse, +} from "./generated/rpc.js"; import { type Canvas, CanvasError } from "./canvas.js"; import type { OpenCanvasInstance } from "./generated/rpc.js"; import { getTraceContext } from "./telemetry.js"; @@ -28,6 +32,8 @@ import type { ExitPlanModeResult, UiInputOptions, MessageOptions, + McpAuthHandler, + McpAuthRequest, PermissionHandler, PermissionRequest, ContextTier, @@ -124,6 +130,7 @@ export class CopilotSession { private canvases: Map = new Map(); private commandHandlers: Map = new Map(); private permissionHandler?: PermissionHandler; + private mcpAuthHandler?: McpAuthHandler; private userInputHandler?: UserInputHandler; private elicitationHandler?: ElicitationHandler; private exitPlanModeHandler?: ExitPlanModeHandler; @@ -152,9 +159,11 @@ export class CopilotSession { public readonly sessionId: string, private connection: MessageConnection, private _workspacePath?: string, - traceContextProvider?: TraceContextProvider + traceContextProvider?: TraceContextProvider, + options?: { mcpAuthHandler?: McpAuthHandler } ) { this.traceContextProvider = traceContextProvider; + this.mcpAuthHandler = options?.mcpAuthHandler; } /** @@ -499,6 +508,19 @@ export class CopilotSession { if (this.permissionHandler) { void this._executePermissionAndRespond(requestId, permissionRequest); } + } else if (event.type === "mcp.oauth_required") { + const data = event.data as McpAuthRequest | undefined; + if (!data?.requestId) { + return; + } + if (!this.mcpAuthHandler) { + console.warn( + "Received MCP OAuth request without a registered MCP auth handler. " + + `SessionId=${this.sessionId}, RequestId=${data.requestId}` + ); + return; + } + void this._executeMcpAuthAndRespond(data); } else if (event.type === "command.execute") { const { requestId, commandName, command, args } = event.data as { requestId: string; @@ -661,6 +683,35 @@ export class CopilotSession { } } + /** + * Executes an MCP auth handler and sends the result back via RPC. + * @internal + */ + private async _executeMcpAuthAndRespond(request: McpAuthRequest): Promise { + try { + const result = await this.mcpAuthHandler!(request, { sessionId: this.sessionId }); + const response: McpOauthPendingRequestResponse = + result && "accessToken" in result + ? { kind: "token", ...result } + : { kind: "cancelled" }; + await this.rpc.mcp.oauth.handlePendingRequest({ + requestId: request.requestId, + result: response, + }); + } catch (_error) { + try { + await this.rpc.mcp.oauth.handlePendingRequest({ + requestId: request.requestId, + result: { kind: "cancelled" }, + }); + } catch (rpcError) { + if (!(rpcError instanceof ConnectionError || rpcError instanceof ResponseError)) { + throw rpcError; + } + } + } + } + /** * Executes a command handler and sends the result back via RPC. * @internal diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index bdf02a7b0..b3c6af529 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -1607,6 +1607,74 @@ export type ReasoningEffort = "low" | "medium" | "high" | "xhigh"; */ export type ContextTier = "default" | "long_context"; +/** Parsed parameters from an MCP server's WWW-Authenticate response. */ +export interface McpAuthWwwAuthenticateParams { + /** Parsed resource_metadata URL used for protected-resource metadata discovery. */ + resourceMetadataUrl: string; + /** Parsed OAuth scope, if present. */ + scope?: string; + /** Parsed OAuth error, if present. */ + error?: string; +} + +/** Static OAuth client configuration supplied by the MCP server, if available. */ +export interface McpAuthStaticClientConfig { + /** OAuth client ID for the server. */ + clientId: string; + /** Optional non-default OAuth grant type. */ + grantType?: "client_credentials"; + /** Whether this is a public OAuth client. */ + publicClient?: boolean; +} + +/** MCP OAuth request that the SDK host can satisfy with a host-acquired token. */ +export interface McpAuthRequest { + /** Unique request identifier used by the SDK when responding. */ + requestId: string; + /** Display name of the MCP server that requires OAuth. */ + serverName: string; + /** URL of the MCP server that requires OAuth. */ + serverUrl: string; + /** Parsed WWW-Authenticate parameters from the MCP server. */ + wwwAuthenticateParams?: McpAuthWwwAuthenticateParams; + /** Raw RFC 9728 protected-resource metadata JSON fetched by the runtime, if available. */ + resourceMetadata?: string; + /** Static OAuth client configuration, if the server specifies one. */ + staticClientConfig?: McpAuthStaticClientConfig; +} + +/** Host-provided OAuth token data for a pending MCP OAuth request. */ +export interface McpAuthToken { + /** Access token acquired by the SDK host. */ + accessToken: string; + /** OAuth token type. Defaults to Bearer when omitted. */ + tokenType?: string; + /** Refresh token supplied by the host, if available. */ + refreshToken?: string; + /** Token lifetime in seconds, if known. */ + expiresIn?: number; +} + +/** + * Result returned by an MCP auth request handler. + * + * Return `null`/`undefined` or `{ kind: "cancelled" }` to cancel the pending + * OAuth request. Return `{ kind: "token", ... }` to provide host-acquired + * OAuth token data. + */ +export type McpAuthResult = ({ kind: "token" } & McpAuthToken) | { kind: "cancelled" }; + +/** Callback invoked when an MCP server requires OAuth and the SDK host opted in. */ +export type McpAuthHandler = ( + request: McpAuthRequest, + context: { sessionId: string } +) => + | McpAuthResult + | McpAuthToken + | null + | undefined + | Promise; + /** * Stable extension identity for session participants that provide canvases. */ @@ -1890,6 +1958,13 @@ export interface SessionConfigBase { */ onPermissionRequest?: PermissionHandler; + /** + * Optional handler for MCP OAuth requests from MCP servers. + * When provided, the SDK can satisfy MCP server OAuth requests with + * host-provided token data or cancellation. + */ + onMcpAuthRequest?: McpAuthHandler; + /** * Handler for user input requests from the agent. * When provided, enables the ask_user tool allowing the agent to ask questions. diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index ef804095f..134a0c862 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -24,6 +24,213 @@ describe("CopilotClient", () => { expect(spy).not.toHaveBeenCalled(); }); + it("responds to MCP OAuth requests with host token data", async () => { + const sendRequest = vi.fn(async () => ({ success: true })); + let observedRequest: any; + const session = new CopilotSession( + "session-1", + { sendRequest } as any, + undefined, + undefined, + { + mcpAuthHandler: async (request) => { + observedRequest = request; + return { + accessToken: "host-token", + tokenType: "Bearer", + expiresIn: 3600, + }; + }, + } + ); + + await (session as any)._executeMcpAuthAndRespond({ + requestId: "oauth-request", + serverName: "oauth-server", + serverUrl: "https://example.com/mcp", + wwwAuthenticateParams: { + resourceMetadataUrl: "https://example.com/.well-known/oauth-protected-resource", + }, + resourceMetadata: '{"resource":"https://example.com/mcp"}', + }); + + expect(observedRequest.resourceMetadata).toBe('{"resource":"https://example.com/mcp"}'); + expect(sendRequest).toHaveBeenCalledWith("session.mcp.oauth.handlePendingRequest", { + sessionId: "session-1", + requestId: "oauth-request", + result: { + kind: "token", + accessToken: "host-token", + tokenType: "Bearer", + expiresIn: 3600, + }, + }); + }); + + it("passes MCP OAuth requests through when optional metadata is absent", async () => { + let observedRequest: any; + const session = new CopilotSession( + "session-1", + { sendRequest: vi.fn(async () => ({ success: true })) } as any, + undefined, + undefined, + { + mcpAuthHandler: async (request) => { + observedRequest = request; + return { kind: "cancelled" }; + }, + } + ); + + await (session as any)._executeMcpAuthAndRespond({ + requestId: "oauth-request", + serverName: "oauth-server", + serverUrl: "https://example.com/mcp", + }); + + expect(observedRequest.resourceMetadata).toBeUndefined(); + expect(observedRequest.wwwAuthenticateParams).toBeUndefined(); + }); + + it("registers interest in MCP OAuth required events after create when an auth handler is configured", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.eventLog.registerInterest") { + return { id: "interest-1" }; + } + if (method === "session.create") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + }); + + expect(spy.mock.calls[0][0]).toBe("session.create"); + expect(spy.mock.calls[1]).toEqual([ + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }), + ]); + expect(spy.mock.calls[1][1].sessionId).toBe(spy.mock.calls[0][1].sessionId); + }); + + it("does not register MCP OAuth interest without an auth handler", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.create") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + onEvent: () => {}, + }); + + expect(spy).not.toHaveBeenCalledWith( + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }) + ); + expect(spy).toHaveBeenCalledWith( + "session.create", + expect.objectContaining({ requestPermission: true }) + ); + }); + + it("registers MCP OAuth interest after cloud create only when an auth handler is configured", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + let cloudCreateCount = 0; + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.eventLog.registerInterest") { + return { id: "interest-1" }; + } + if (method === "session.create") + return { sessionId: `server-assigned-session-${++cloudCreateCount}` }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + cloud: { repository: { owner: "github", name: "copilot-sdk", branch: "main" } }, + }); + + expect(spy).not.toHaveBeenCalledWith( + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }) + ); + + spy.mockClear(); + await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + cloud: { repository: { owner: "github", name: "copilot-sdk", branch: "main" } }, + }); + + expect(spy.mock.calls[0][0]).toBe("session.create"); + expect(spy.mock.calls[1]).toEqual([ + "session.eventLog.registerInterest", + { sessionId: "server-assigned-session-2", eventType: "mcp.oauth_required" }, + ]); + }); + + it("registers MCP OAuth interest before resuming only when an auth handler is configured", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.eventLog.registerInterest") { + return { id: "interest-1" }; + } + if (method === "session.resume") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.resumeSession("session-with-auth", { + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + }); + + expect(spy.mock.calls[0]).toEqual([ + "session.eventLog.registerInterest", + { sessionId: "session-with-auth", eventType: "mcp.oauth_required" }, + ]); + expect(spy.mock.calls[1][0]).toBe("session.resume"); + expect(spy.mock.calls[1][1]).toEqual(expect.objectContaining({ requestPermission: true })); + + spy.mockClear(); + await client.resumeSession("session-without-auth", { + onPermissionRequest: approveAll, + onEvent: () => {}, + }); + + expect(spy).not.toHaveBeenCalledWith( + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }) + ); + expect(spy).toHaveBeenCalledWith( + "session.resume", + expect.objectContaining({ sessionId: "session-without-auth", requestPermission: true }) + ); + }); + it("forwards canvas declarations and request flags in session.create", async () => { const client = new CopilotClient(); await client.start(); diff --git a/nodejs/test/e2e/mcp_oauth.e2e.test.ts b/nodejs/test/e2e/mcp_oauth.e2e.test.ts new file mode 100644 index 000000000..52dc902ed --- /dev/null +++ b/nodejs/test/e2e/mcp_oauth.e2e.test.ts @@ -0,0 +1,172 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { spawn, type ChildProcessWithoutNullStreams } from "node:child_process"; +import { dirname, resolve } from "node:path"; +import { createInterface } from "node:readline"; +import { fileURLToPath } from "node:url"; +import { describe, expect, it, onTestFinished } from "vitest"; +import type { CopilotSession, MCPServerConfig, McpAuthRequest } from "../../src/index.js"; +import { approveAll } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; +import { waitForCondition } from "./harness/sdkTestHelper.js"; + +const __filename = fileURLToPath(import.meta.url); +const __dirname = dirname(__filename); +const TEST_MCP_OAUTH_SERVER = resolve(__dirname, "../../../test/harness/test-mcp-oauth-server.mjs"); +const EXPECTED_TOKEN = "sdk-host-token"; + +describe("MCP OAuth host auth", async () => { + const { copilotClient: client } = await createSdkTestContext(); + + it("should satisfy MCP OAuth using host-provided token", { timeout: 120_000 }, async () => { + const oauthServer = await startOAuthMcpServer(); + const serverName = "oauth-protected-mcp"; + let authRequest: McpAuthRequest | undefined; + + const session = await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: async (request) => { + authRequest = request; + return { + kind: "token", + accessToken: EXPECTED_TOKEN, + tokenType: "Bearer", + expiresIn: 3600, + }; + }, + mcpServers: { + [serverName]: { + type: "http", + url: `${oauthServer.url}/mcp`, + tools: ["*"], + oauthClientId: "sdk-e2e-client", + oauthPublicClient: true, + } as unknown as MCPServerConfig, + }, + }); + onTestFinished(() => disconnectSession(session)); + + await waitForMcpServerStatus(session, serverName); + + const tools = await session.rpc.mcp.listTools({ serverName }); + expect(tools.tools.map((tool) => tool.name)).toContain("whoami"); + + expect(authRequest).toMatchObject({ + requestId: expect.any(String), + serverName, + serverUrl: `${oauthServer.url}/mcp`, + wwwAuthenticateParams: { + resourceMetadataUrl: `${oauthServer.url}/.well-known/oauth-protected-resource`, + scope: "mcp.read", + error: "invalid_token", + }, + resourceMetadata: JSON.stringify({ + resource: `${oauthServer.url}/mcp`, + authorization_servers: [oauthServer.url], + scopes_supported: ["mcp.read"], + bearer_methods_supported: ["header"], + }), + }); + + const requests = await oauthServer.requests(); + expect(requests.some((request) => request.authorization === null)).toBe(true); + expect( + requests.some((request) => request.authorization === `Bearer ${EXPECTED_TOKEN}`) + ).toBe(true); + }); +}); + +async function waitForMcpServerStatus( + session: CopilotSession, + serverName: string, + expectedStatus = "connected" +): Promise { + let lastStatus = ""; + await waitForCondition( + async () => { + const result = await session.rpc.mcp.list(); + const server = result.servers.find((entry) => entry.name === serverName); + lastStatus = server?.status ?? ""; + return server?.status === expectedStatus; + }, + { + timeoutMs: 60_000, + intervalMs: 200, + timeoutMessage: `${serverName} did not reach ${expectedStatus}; last status was ${lastStatus}`, + } + ); +} + +async function startOAuthMcpServer(): Promise<{ + url: string; + requests: () => Promise>; +}> { + const child = spawn(process.execPath, [TEST_MCP_OAUTH_SERVER], { + env: { ...process.env, EXPECTED_TOKEN }, + stdio: ["ignore", "pipe", "pipe"], + }); + onTestFinished(() => stopChild(child)); + + const stderr: string[] = []; + child.stderr.on("data", (chunk) => stderr.push(String(chunk))); + + const url = await new Promise((resolvePromise, reject) => { + const rl = createInterface({ input: child.stdout }); + const timeout = setTimeout(() => { + rl.close(); + reject(new Error(`Timed out waiting for OAuth MCP server. ${stderr.join("")}`)); + }, 10_000); + + child.once("exit", (code, signal) => { + clearTimeout(timeout); + rl.close(); + reject( + new Error( + `OAuth MCP server exited before listening. code=${code} signal=${signal} ${stderr.join("")}` + ) + ); + }); + + rl.on("line", (line) => { + const match = /^Listening: (.+)$/.exec(line); + if (!match) { + return; + } + clearTimeout(timeout); + rl.close(); + resolvePromise(match[1]); + }); + }); + + return { + url, + requests: async () => { + const response = await fetch(`${url}/__requests`); + if (!response.ok) { + throw new Error(`Failed to fetch OAuth MCP requests: ${response.status}`); + } + return response.json(); + }, + }; +} + +async function disconnectSession(session: CopilotSession): Promise { + try { + await session.disconnect(); + } catch { + // Best-effort cleanup. + } +} + +function stopChild(child: ChildProcessWithoutNullStreams): Promise { + if (child.exitCode !== null || child.killed) { + return Promise.resolve(); + } + const exitPromise = new Promise((resolvePromise) => { + child.once("exit", () => resolvePromise()); + }); + child.kill("SIGTERM"); + return exitPromise; +} diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 06ecf4188..e4b105656 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -106,6 +106,12 @@ MCPHTTPServerConfig, MCPServerConfig, MCPStdioServerConfig, + McpAuthHandler, + McpAuthRequest, + McpAuthResult, + McpAuthStaticClientConfig, + McpAuthToken, + McpAuthWwwAuthenticateParams, ModelCapabilitiesOverride, ModelLimitsOverride, ModelSupportsOverride, @@ -223,6 +229,12 @@ "MCPHTTPServerConfig", "MCPServerConfig", "MCPStdioServerConfig", + "McpAuthHandler", + "McpAuthRequest", + "McpAuthResult", + "McpAuthStaticClientConfig", + "McpAuthToken", + "McpAuthWwwAuthenticateParams", "ModelBilling", "ModelBillingTokenPrices", "ModelBillingTokenPricesLongContext", diff --git a/python/copilot/client.py b/python/copilot/client.py index 5dc670903..790ed5b99 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -93,6 +93,7 @@ LargeToolOutputConfig, MCPServerConfig, MemoryConfiguration, + McpAuthHandler, ModelCapabilitiesOverride, NamedProviderConfig, ProviderConfig, @@ -1666,6 +1667,7 @@ async def create_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + on_mcp_auth_request: McpAuthHandler | None = None, enable_mcp_apps: bool = False, on_exit_plan_mode_request: ExitPlanModeHandler | None = None, on_auto_mode_switch_request: AutoModeSwitchHandler | None = None, @@ -2102,6 +2104,7 @@ def _initialize_session(sid: str) -> CopilotSession: s._register_tools(tools) s._register_commands(commands) s._register_permission_handler(on_permission_request) + s._register_mcp_auth_handler(on_mcp_auth_request) if on_user_input_request: s._register_user_input_handler(on_user_input_request) if on_elicitation_request: @@ -2142,6 +2145,11 @@ def _initialize_session(sid: str) -> CopilotSession: if local_session_id is not None: session = _initialize_session(local_session_id) registered_session_id = local_session_id + if on_mcp_auth_request is not None: + await self._client.request( + "session.eventLog.registerInterest", + {"sessionId": local_session_id, "eventType": "mcp.oauth_required"}, + ) try: rpc_start = time.perf_counter() @@ -2181,6 +2189,12 @@ def _register_inline(raw_response: Any) -> None: f"session.create returned sessionId {response.get('sessionId')} " f"but the caller requested {local_session_id}" ) + # Local IDs registered before create; server-assigned IDs can only register now. + if local_session_id is None and on_mcp_auth_request is not None: + await self._client.request( + "session.eventLog.registerInterest", + {"sessionId": session.session_id, "eventType": "mcp.oauth_required"}, + ) session._workspace_path = response.get("workspacePath") capabilities = response.get("capabilities") session._set_capabilities(capabilities) @@ -2271,6 +2285,7 @@ async def resume_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + on_mcp_auth_request: McpAuthHandler | None = None, enable_mcp_apps: bool = False, on_exit_plan_mode_request: ExitPlanModeHandler | None = None, on_auto_mode_switch_request: AutoModeSwitchHandler | None = None, @@ -2659,6 +2674,7 @@ async def resume_session( session._register_tools(tools) session._register_commands(commands) session._register_permission_handler(on_permission_request) + session._register_mcp_auth_handler(on_mcp_auth_request) if on_user_input_request: session._register_user_input_handler(on_user_input_request) if on_elicitation_request: @@ -2677,6 +2693,11 @@ async def resume_session( session.on(on_event) with self._sessions_lock: self._sessions[session_id] = session + if on_mcp_auth_request is not None: + await self._client.request( + "session.eventLog.registerInterest", + {"sessionId": session_id, "eventType": "mcp.oauth_required"}, + ) log_timing( logger, logging.DEBUG, diff --git a/python/copilot/session.py b/python/copilot/session.py index f15d6c7d3..fa004b263 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -39,6 +39,9 @@ ExternalToolTextResultForLlm, HandlePendingToolCallRequest, LogRequest, + MCPOauthHandlePendingRequest, + MCPOauthPendingRequestResponse, + MCPOauthPendingRequestResponseKind, ModelSwitchToRequest, PermissionDecision, PermissionDecisionApproveOnce, @@ -65,6 +68,7 @@ CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, + McpOauthRequiredData, PermissionRequest, PermissionRequestedData, SessionCanvasClosedData, @@ -354,6 +358,63 @@ def approve_all( return PermissionDecisionApproveOnce() +# ============================================================================ +# MCP Auth Types +# ============================================================================ + + +class McpAuthWwwAuthenticateParams(TypedDict, total=False): + """Parsed parameters from an MCP server's WWW-Authenticate response.""" + + resourceMetadataUrl: Required[str] + scope: str + error: str + + +class McpAuthStaticClientConfig(TypedDict, total=False): + """Static OAuth client configuration supplied by the MCP server, if available.""" + + clientId: Required[str] + grantType: Literal["client_credentials"] + publicClient: bool + + +class McpAuthRequest(TypedDict, total=False): + """MCP OAuth request that the SDK host can satisfy with a host-acquired token.""" + + requestId: Required[str] + serverName: Required[str] + serverUrl: Required[str] + wwwAuthenticateParams: McpAuthWwwAuthenticateParams + resourceMetadata: str + staticClientConfig: McpAuthStaticClientConfig + + +class McpAuthToken(TypedDict, total=False): + """Host-provided OAuth token data for a pending MCP OAuth request.""" + + accessToken: Required[str] + tokenType: str + refreshToken: str + expiresIn: int + + +class McpAuthResult(TypedDict, total=False): + """Result returned by an MCP auth request handler.""" + + kind: Required[Literal["token", "cancelled"]] + accessToken: str + tokenType: str + refreshToken: str + expiresIn: int + + +McpAuthHandler = Callable[ + [McpAuthRequest, dict[str, str]], + McpAuthResult | McpAuthToken | None | Awaitable[McpAuthResult | McpAuthToken | None], +] + + # ============================================================================ # User Input Request Types # ============================================================================ @@ -1252,6 +1313,8 @@ def __init__( self._tool_handlers_lock = threading.Lock() self._permission_handler: _PermissionHandlerFn | None = None self._permission_handler_lock = threading.Lock() + self._mcp_auth_handler: McpAuthHandler | None = None + self._mcp_auth_handler_lock = threading.Lock() self._user_input_handler: UserInputHandler | None = None self._user_input_handler_lock = threading.Lock() self._exit_plan_mode_handler: ExitPlanModeHandler | None = None @@ -1639,6 +1702,45 @@ def _handle_broadcast_event(self, event: SessionEvent) -> None: ) ) + case McpOauthRequiredData() as data: + with self._mcp_auth_handler_lock: + handler = self._mcp_auth_handler + if not data.request_id: + return + if not handler: + logger.warning( + "Received MCP OAuth request without a registered MCP auth handler. " + "SessionId=%s, RequestId=%s", + self.session_id, + data.request_id, + ) + return + request: McpAuthRequest = { + "requestId": data.request_id, + "serverName": data.server_name, + "serverUrl": data.server_url, + } + if data.www_authenticate_params is not None: + request["wwwAuthenticateParams"] = { + "resourceMetadataUrl": data.www_authenticate_params.resource_metadata_url, + } + if data.www_authenticate_params.scope is not None: + request["wwwAuthenticateParams"]["scope"] = data.www_authenticate_params.scope + if data.www_authenticate_params.error is not None: + request["wwwAuthenticateParams"]["error"] = data.www_authenticate_params.error + if data.resource_metadata is not None: + request["resourceMetadata"] = data.resource_metadata + if data.static_client_config is not None: + static_client_config: McpAuthStaticClientConfig = { + "clientId": data.static_client_config.client_id, + } + if data.static_client_config.grant_type is not None: + static_client_config["grantType"] = data.static_client_config.grant_type + if data.static_client_config.public_client is not None: + static_client_config["publicClient"] = data.static_client_config.public_client + request["staticClientConfig"] = static_client_config + asyncio.ensure_future(self._execute_mcp_auth_and_respond(request, handler)) + case CommandExecuteData() as data: request_id = data.request_id command_name = data.command_name @@ -1857,6 +1959,58 @@ async def _execute_permission_and_respond( except (JsonRpcError, ProcessExitedError, OSError): pass # Connection lost or RPC error — nothing we can do + async def _execute_mcp_auth_and_respond( + self, + request: McpAuthRequest, + handler: McpAuthHandler, + ) -> None: + """Execute an MCP auth handler and respond via RPC.""" + request_id = request["requestId"] + try: + handler_start = time.perf_counter() + result = handler(request, {"session_id": self.session_id}) + if inspect.isawaitable(result): + result = await result + log_timing( + logger, + logging.DEBUG, + "CopilotSession._execute_mcp_auth_and_respond dispatch", + handler_start, + session_id=self.session_id, + request_id=request_id, + ) + + if result and result.get("kind", "token") == "token": + rpc_result = MCPOauthPendingRequestResponse( + kind=MCPOauthPendingRequestResponseKind.TOKEN, + access_token=result["accessToken"], + expires_in=result.get("expiresIn"), + refresh_token=result.get("refreshToken"), + token_type=result.get("tokenType"), + ) + else: + rpc_result = MCPOauthPendingRequestResponse( + kind=MCPOauthPendingRequestResponseKind.CANCELLED + ) + await self.rpc.mcp.oauth.handle_pending_request( + MCPOauthHandlePendingRequest( + request_id=request_id, + result=rpc_result, + ) + ) + except Exception: + try: + await self.rpc.mcp.oauth.handle_pending_request( + MCPOauthHandlePendingRequest( + request_id=request_id, + result=MCPOauthPendingRequestResponse( + kind=MCPOauthPendingRequestResponseKind.CANCELLED + ), + ) + ) + except (JsonRpcError, ProcessExitedError, OSError): + pass # Connection lost or RPC error — nothing we can do + async def _execute_command_and_respond( self, request_id: str, @@ -2019,6 +2173,11 @@ def _register_elicitation_handler(self, handler: ElicitationHandler | None) -> N with self._elicitation_handler_lock: self._elicitation_handler = handler + def _register_mcp_auth_handler(self, handler: McpAuthHandler | None) -> None: + """Register the MCP auth handler for this session.""" + with self._mcp_auth_handler_lock: + self._mcp_auth_handler = handler + def _register_exit_plan_mode_handler(self, handler: ExitPlanModeHandler | None) -> None: """Register the exit-plan-mode handler for this session.""" with self._exit_plan_mode_handler_lock: diff --git a/python/test_client.py b/python/test_client.py index acb5c0733..eb7f37d1c 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -4,6 +4,7 @@ This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.py instead. """ +import asyncio from datetime import UTC, datetime from unittest.mock import AsyncMock, Mock, patch @@ -28,6 +29,12 @@ ModelSupports, ) from copilot.session import PermissionHandler +from copilot.session_events import ( + McpOauthRequiredData, + McpOauthWWWAuthenticateParams, + SessionEvent, + SessionEventType, +) from e2e.testharness import CLI_PATH @@ -139,6 +146,292 @@ async def test_resume_session_allows_none_permission_handler(self): class TestCreateSessionConfig: + @pytest.mark.asyncio + async def test_mcp_auth_handler_registers_interest_in_create_session(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + captured.append((method, params)) + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + if method == "session.create": + result = {"sessionId": params["sessionId"], "workspacePath": None} + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + return {} + + client._client.request = mock_request + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: {"kind": "cancelled"}, + ) + + interest_method, interest_payload = captured[0] + create_method, create_payload = captured[1] + assert interest_method == "session.eventLog.registerInterest" + assert interest_payload["eventType"] == "mcp.oauth_required" + assert create_method == "session.create" + assert interest_payload["sessionId"] == create_payload["sessionId"] + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_interest_is_not_registered_without_handler(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + captured.append((method, params)) + if method == "session.create": + result = {"sessionId": params["sessionId"], "workspacePath": None} + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + if method == "session.resume": + return {"sessionId": params["sessionId"], "workspacePath": None} + return {} + + client._client.request = mock_request + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_event=lambda event: None, + ) + await client.resume_session( + "session-without-auth", + on_permission_request=PermissionHandler.approve_all, + on_event=lambda event: None, + ) + + assert session.session_id + assert not any( + method == "session.eventLog.registerInterest" + and params["eventType"] == "mcp.oauth_required" + for method, params in captured + ) + assert any( + method == "session.create" and params["requestPermission"] is True + for method, params in captured + ) + assert any( + method == "session.resume" and params["requestPermission"] is True + for method, params in captured + ) + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_handler_registers_interest_before_resume(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + captured.append((method, params)) + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + if method == "session.resume": + return {"sessionId": params["sessionId"], "workspacePath": None} + return {} + + client._client.request = mock_request + await client.resume_session( + "session-with-auth", + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: {"kind": "cancelled"}, + ) + + interest_method, interest_payload = captured[0] + resume_method, resume_payload = captured[1] + assert interest_method == "session.eventLog.registerInterest" + assert interest_payload == { + "sessionId": "session-with-auth", + "eventType": "mcp.oauth_required", + } + assert resume_method == "session.resume" + assert resume_payload["requestPermission"] is True + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_handler_registers_interest_after_cloud_create_only_with_handler(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + create_count = 0 + + async def mock_request(method, params, **kwargs): + nonlocal create_count + captured.append((method, params)) + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + if method == "session.create": + create_count += 1 + result = { + "sessionId": f"server-assigned-session-{create_count}", + "workspacePath": None, + } + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + return {} + + cloud = CloudSessionOptions( + repository=CloudSessionRepository( + owner="github", + name="copilot-sdk", + branch="main", + ) + ) + + client._client.request = mock_request + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + cloud=cloud, + ) + + assert not any( + method == "session.eventLog.registerInterest" + and params["eventType"] == "mcp.oauth_required" + for method, params in captured + ) + + captured.clear() + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: {"kind": "cancelled"}, + cloud=cloud, + ) + + create_method, _create_payload = captured[0] + interest_method, interest_payload = captured[1] + assert create_method == "session.create" + assert interest_method == "session.eventLog.registerInterest" + assert interest_payload == { + "sessionId": "server-assigned-session-2", + "eventType": "mcp.oauth_required", + } + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_required_event_sends_host_token(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + if method == "session.mcp.oauth.handlePendingRequest": + captured.append((method, params)) + return {"success": True} + if method == "session.create": + result = {"sessionId": params["sessionId"], "workspacePath": None} + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + return {} + + client._client.request = mock_request + observed_request = None + + def handle_mcp_auth_request(request, invocation): + nonlocal observed_request + observed_request = request + return { + "accessToken": "host-token", + "tokenType": "Bearer", + } + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=handle_mcp_auth_request, + ) + + session._dispatch_event( + SessionEvent( + data=McpOauthRequiredData( + request_id="oauth-request", + server_name="oauth-server", + server_url="https://example.com/mcp", + www_authenticate_params=McpOauthWWWAuthenticateParams( + resource_metadata_url="https://example.com/.well-known/oauth-protected-resource" + ), + resource_metadata='{"resource":"https://example.com/mcp"}', + ), + id="evt-1", + timestamp="2026-01-01T00:00:00Z", + type=SessionEventType.MCP_OAUTH_REQUIRED, + ephemeral=True, + parent_id=None, + ) + ) + + for _ in range(200): + if captured: + break + await asyncio.sleep(0.005) + + assert observed_request is not None + assert observed_request["resourceMetadata"] == '{"resource":"https://example.com/mcp"}' + assert observed_request["wwwAuthenticateParams"]["resourceMetadataUrl"] == ( + "https://example.com/.well-known/oauth-protected-resource" + ) + assert captured == [ + ( + "session.mcp.oauth.handlePendingRequest", + { + "sessionId": session.session_id, + "requestId": "oauth-request", + "result": { + "kind": "token", + "accessToken": "host-token", + "tokenType": "Bearer", + }, + }, + ) + ] + + observed_request = None + session._dispatch_event( + SessionEvent( + data=McpOauthRequiredData( + request_id="oauth-request-without-metadata", + server_name="oauth-server", + server_url="https://example.com/mcp", + ), + id="evt-2", + timestamp="2026-01-01T00:00:00Z", + type=SessionEventType.MCP_OAUTH_REQUIRED, + ephemeral=True, + parent_id=None, + ) + ) + + for _ in range(200): + if observed_request is not None: + break + await asyncio.sleep(0.005) + + assert observed_request is not None + assert "resourceMetadata" not in observed_request + assert "wwwAuthenticateParams" not in observed_request + finally: + await client.force_stop() + @pytest.mark.asyncio async def test_create_session_forwards_cloud_options(self): client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) diff --git a/rust/src/handler.rs b/rust/src/handler.rs index dadd1706f..1d078bcb5 100644 --- a/rust/src/handler.rs +++ b/rust/src/handler.rs @@ -19,9 +19,12 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::generated::api_types::{ - PermissionDecision, PermissionDecisionApproveOnce, PermissionDecisionReject, - PermissionDecisionUserNotAvailable, + McpOauthPendingRequestResponse, McpOauthPendingRequestResponseCancelled, + McpOauthPendingRequestResponseCancelledKind, McpOauthPendingRequestResponseToken, + McpOauthPendingRequestResponseTokenKind, PermissionDecision, PermissionDecisionApproveOnce, + PermissionDecisionReject, PermissionDecisionUserNotAvailable, }; +use crate::session_events::{McpOauthRequiredStaticClientConfig, McpOauthWWWAuthenticateParams}; use crate::types::{ ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId, SessionId, @@ -159,6 +162,75 @@ pub trait ElicitationHandler: Send + Sync + 'static { ) -> ElicitationResult; } +/// MCP OAuth request that the SDK host can satisfy with a host-acquired token. +#[derive(Debug, Clone)] +pub struct McpAuthRequest { + /// Display name of the MCP server that requires OAuth. + pub server_name: String, + /// URL of the MCP server that requires OAuth. + pub server_url: String, + /// Parsed WWW-Authenticate parameters from the MCP server, if available. + pub www_authenticate_params: Option, + /// Raw RFC 9728 protected-resource metadata JSON fetched by the runtime, if available. + pub resource_metadata: Option, + /// Static OAuth client configuration, if the server specifies one. + pub static_client_config: Option, +} + +/// Result returned by an MCP auth request handler. +#[derive(Debug, Clone)] +pub enum McpAuthResult { + /// Supplies host-acquired OAuth token data. + Token { + /// Access token acquired by the SDK host. + access_token: String, + /// OAuth token type. Defaults to Bearer when omitted. + token_type: Option, + /// Refresh token supplied by the host, if available. + refresh_token: Option, + /// Token lifetime in seconds, if known. + expires_in: Option, + }, + /// Declines or cancels the pending OAuth request. + Cancelled, +} + +impl McpAuthResult { + pub(crate) fn into_wire(self) -> McpOauthPendingRequestResponse { + match self { + Self::Token { + access_token, + token_type, + refresh_token, + expires_in, + } => McpOauthPendingRequestResponse::Token(McpOauthPendingRequestResponseToken { + access_token, + token_type, + refresh_token, + expires_in, + kind: McpOauthPendingRequestResponseTokenKind::Token, + }), + Self::Cancelled => { + McpOauthPendingRequestResponse::Cancelled(McpOauthPendingRequestResponseCancelled { + kind: McpOauthPendingRequestResponseCancelledKind::Cancelled, + }) + } + } + } +} + +/// Handler for MCP server OAuth requests. +#[async_trait] +pub trait McpAuthHandler: Send + Sync + 'static { + /// Resolve an MCP OAuth request with host token data or cancellation. + async fn handle( + &self, + session_id: SessionId, + request_id: RequestId, + request: McpAuthRequest, + ) -> McpAuthResult; +} + /// Handler for `user_input.requested` events from the `ask_user` tool. /// /// When unset, `requestUserInput: false` goes on the wire and the @@ -266,4 +338,24 @@ mod tests { PermissionResult::Decision(PermissionDecision::Reject(_)) )); } + + #[test] + fn mcp_auth_result_token_converts_to_wire_response() { + let wire = McpAuthResult::Token { + access_token: "host-token".to_string(), + token_type: Some("Bearer".to_string()), + refresh_token: None, + expires_in: Some(3600), + } + .into_wire(); + + match wire { + McpOauthPendingRequestResponse::Token(token) => { + assert_eq!(token.access_token, "host-token"); + assert_eq!(token.token_type.as_deref(), Some("Bearer")); + assert_eq!(token.expires_in, Some(3600)); + } + McpOauthPendingRequestResponse::Cancelled(_) => panic!("expected token response"), + } + } } diff --git a/rust/src/session.rs b/rust/src/session.rs index fed6705da..790291029 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -11,14 +11,17 @@ use tokio_util::sync::CancellationToken; use tracing::{Instrument, warn}; use crate::canvas::CanvasHandler; -use crate::generated::api_types::{LogRequest, ModelSwitchToRequest, OpenCanvasInstance}; +use crate::generated::api_types::{ + LogRequest, ModelSwitchToRequest, OpenCanvasInstance, RegisterEventInterestParams, rpc_methods, +}; use crate::generated::session_events::{ - CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, + CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, McpOauthRequiredData, SessionCanvasClosedData, SessionErrorData, SessionEventType, }; use crate::handler::{ AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, ExitPlanModeHandler, - PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, + McpAuthHandler, McpAuthRequest, McpAuthResult, PermissionHandler, PermissionResult, + UserInputHandler, UserInputResponse, }; use crate::hooks::SessionHooks; use crate::session_fs::SessionFsProvider; @@ -48,6 +51,7 @@ use crate::{ pub(crate) struct SessionHandlers { pub permission: Option>, pub elicitation: Option>, + pub mcp_auth: Option>, pub user_input: Option>, pub exit_plan_mode: Option>, pub auto_mode_switch: Option>, @@ -880,6 +884,7 @@ impl Client { let handlers = SessionHandlers { permission: permission_handler, elicitation: runtime.elicitation_handler.take(), + mcp_auth: runtime.mcp_auth_handler.take(), user_input: runtime.user_input_handler.take(), exit_plan_mode: runtime.exit_plan_mode_handler.take(), auto_mode_switch: runtime.auto_mode_switch_handler.take(), @@ -893,6 +898,7 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); + let has_mcp_auth_handler = handlers.mcp_auth.is_some(); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -932,6 +938,9 @@ impl Client { { let channels = self.register_session(sid); *inline_stash.lock() = Some((sid.clone(), channels)); + if has_mcp_auth_handler { + register_mcp_auth_interest(self, sid).await?; + } None } else { let client = self.clone(); @@ -1027,6 +1036,10 @@ impl Client { "Client::create_session local setup complete" ); *capabilities.write() = create_result.capabilities.unwrap_or_default(); + // Local IDs registered before create; server-assigned IDs can only register now. + if has_mcp_auth_handler && local_session_id.is_none() { + register_mcp_auth_interest(self, &session_id).await?; + } tracing::debug!( elapsed_ms = total_start.elapsed().as_millis(), @@ -1136,6 +1149,7 @@ impl Client { let handlers = SessionHandlers { permission: permission_handler, elicitation: runtime.elicitation_handler.take(), + mcp_auth: runtime.mcp_auth_handler.take(), user_input: runtime.user_input_handler.take(), exit_plan_mode: runtime.exit_plan_mode_handler.take(), auto_mode_switch: runtime.auto_mode_switch_handler.take(), @@ -1149,6 +1163,7 @@ impl Client { let command_handlers = build_command_handler_map(runtime.commands.as_deref()); let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); + let has_mcp_auth_handler = handlers.mcp_auth.is_some(); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -1166,6 +1181,9 @@ impl Client { let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); + if has_mcp_auth_handler { + register_mcp_auth_interest(self, &session_id).await?; + } let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); let setup_start = Instant::now(); @@ -1470,6 +1488,17 @@ fn notification_permission_payload(result: &PermissionResult) -> Option { } } +async fn register_mcp_auth_interest(client: &Client, session_id: &SessionId) -> Result<(), Error> { + let mut params = serde_json::to_value(RegisterEventInterestParams { + event_type: "mcp.oauth_required".to_string(), + })?; + params["sessionId"] = Value::String(session_id.to_string()); + client + .call(rpc_methods::SESSION_EVENTLOG_REGISTERINTEREST, Some(params)) + .await?; + Ok(()) +} + fn tool_failure_result(message: impl Into) -> ToolResult { let message = message.into(); ToolResult::Expanded(ToolResultExpanded { @@ -1937,6 +1966,89 @@ async fn handle_notification( .instrument(span), ); } + SessionEventType::McpOauthRequired => { + let Some(request_id) = extract_request_id(¬ification.event.data) else { + return; + }; + let Some(mcp_auth_handler) = handlers.mcp_auth.clone() else { + warn!( + session_id = %session_id, + request_id = %request_id, + "received MCP OAuth request without a registered MCP auth handler" + ); + return; + }; + let data: McpOauthRequiredData = + match serde_json::from_value(notification.event.data.clone()) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize MCP OAuth request"); + return; + } + }; + let request = McpAuthRequest { + server_name: data.server_name, + server_url: data.server_url, + www_authenticate_params: data.www_authenticate_params, + resource_metadata: data.resource_metadata, + static_client_config: data.static_client_config, + }; + let client = client.clone(); + let sid = session_id.clone(); + let span = tracing::error_span!( + "mcp_auth_request_handler", + session_id = %sid, + request_id = %request_id + ); + tokio::spawn( + async move { + let cancel = McpAuthResult::Cancelled; + let handler_task = tokio::spawn({ + let sid = sid.clone(); + let request_id = request_id.clone(); + let span = tracing::error_span!( + "mcp_auth_callback", + session_id = %sid, + request_id = %request_id + ); + async move { + let handler_start = Instant::now(); + let response = mcp_auth_handler + .handle(sid.clone(), request_id.clone(), request) + .await; + tracing::debug!( + elapsed_ms = handler_start.elapsed().as_millis(), + session_id = %sid, + request_id = %request_id, + "McpAuthHandler::handle dispatch" + ); + response + } + .instrument(span) + }); + let result = match handler_task.await { + Ok(result) => result, + Err(_) => cancel, + }; + let rpc_start = Instant::now(); + let _ = client + .call( + "session.mcp.oauth.handlePendingRequest", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": result.into_wire(), + })), + ) + .await; + tracing::debug!( + elapsed_ms = rpc_start.elapsed().as_millis(), + "Session::handle_notification MCP auth response sent" + ); + } + .instrument(span), + ); + } SessionEventType::CommandExecute => { let data: CommandExecuteData = match serde_json::from_value(notification.event.data.clone()) { diff --git a/rust/src/types.rs b/rust/src/types.rs index e743dbda1..3fb51861a 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -24,8 +24,8 @@ use crate::generated::api_types::OpenCanvasInstance; pub use crate::generated::session_events::ContextTier; use crate::generated::session_events::ReasoningSummary; use crate::handler::{ - AutoModeSwitchHandler, ElicitationHandler, ExitPlanModeHandler, PermissionHandler, - UserInputHandler, + AutoModeSwitchHandler, ElicitationHandler, ExitPlanModeHandler, McpAuthHandler, + PermissionHandler, UserInputHandler, }; use crate::hooks::SessionHooks; pub use crate::session_fs::{ @@ -1662,6 +1662,9 @@ pub struct SessionConfig { /// Optional elicitation-request handler. When `None`, /// `requestElicitation: false` goes on the wire. pub elicitation_handler: Option>, + /// Optional MCP OAuth request handler. When set, the SDK can satisfy MCP + /// server OAuth requests with host-acquired token data or cancellation. + pub mcp_auth_handler: Option>, /// Optional user-input handler. When `None`, /// `requestUserInput: false` goes on the wire and the `ask_user` /// tool is disabled. @@ -1790,6 +1793,14 @@ impl std::fmt::Debug for SessionConfig { "elicitation_handler", &self.elicitation_handler.as_ref().map(|_| ""), ) + .field( + "mcp_auth_handler", + &self.mcp_auth_handler.as_ref().map(|_| ""), + ) + .field( + "mcp_auth_handler", + &self.mcp_auth_handler.as_ref().map(|_| ""), + ) .field( "user_input_handler", &self.user_input_handler.as_ref().map(|_| ""), @@ -1878,6 +1889,7 @@ impl Default for SessionConfig { session_fs_provider: None, permission_handler: None, elicitation_handler: None, + mcp_auth_handler: None, user_input_handler: None, exit_plan_mode_handler: None, auto_mode_switch_handler: None, @@ -1901,6 +1913,7 @@ pub(crate) struct SessionConfigRuntime { pub permission_handler: Option>, pub permission_policy: Option, pub elicitation_handler: Option>, + pub mcp_auth_handler: Option>, pub user_input_handler: Option>, pub exit_plan_mode_handler: Option>, pub auto_mode_switch_handler: Option>, @@ -2027,6 +2040,7 @@ impl SessionConfig { permission_handler: self.permission_handler, permission_policy: self.permission_policy, elicitation_handler: self.elicitation_handler, + mcp_auth_handler: self.mcp_auth_handler, user_input_handler: self.user_input_handler, exit_plan_mode_handler: self.exit_plan_mode_handler, auto_mode_switch_handler: self.auto_mode_switch_handler, @@ -2056,6 +2070,12 @@ impl SessionConfig { self } + /// Install an [`McpAuthHandler`] for host-provided MCP OAuth tokens. + pub fn with_mcp_auth_handler(mut self, handler: Arc) -> Self { + self.mcp_auth_handler = Some(handler); + self + } + /// Install a [`UserInputHandler`]. Required for the `ask_user` tool /// to be enabled. pub fn with_user_input_handler(mut self, handler: Arc) -> Self { @@ -2717,6 +2737,8 @@ pub struct ResumeSessionConfig { /// Optional elicitation handler. See /// [`SessionConfig::elicitation_handler`]. pub elicitation_handler: Option>, + /// Optional MCP OAuth handler. See [`SessionConfig::mcp_auth_handler`]. + pub mcp_auth_handler: Option>, /// Optional user-input handler. See /// [`SessionConfig::user_input_handler`]. pub user_input_handler: Option>, @@ -2965,6 +2987,7 @@ impl ResumeSessionConfig { permission_handler: self.permission_handler, permission_policy: self.permission_policy, elicitation_handler: self.elicitation_handler, + mcp_auth_handler: self.mcp_auth_handler, user_input_handler: self.user_input_handler, exit_plan_mode_handler: self.exit_plan_mode_handler, auto_mode_switch_handler: self.auto_mode_switch_handler, @@ -3042,6 +3065,7 @@ impl ResumeSessionConfig { continue_pending_work: None, permission_handler: None, elicitation_handler: None, + mcp_auth_handler: None, user_input_handler: None, exit_plan_mode_handler: None, auto_mode_switch_handler: None, @@ -3067,6 +3091,12 @@ impl ResumeSessionConfig { self } + /// Install an [`McpAuthHandler`] for host-provided MCP OAuth tokens. + pub fn with_mcp_auth_handler(mut self, handler: Arc) -> Self { + self.mcp_auth_handler = Some(handler); + self + } + /// Install a [`UserInputHandler`] for the resumed session. pub fn with_user_input_handler(mut self, handler: Arc) -> Self { self.user_input_handler = Some(handler); diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 99d8c1391..a7a4b360e 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -9,17 +9,19 @@ use async_trait::async_trait; use github_copilot_sdk::canvas::{CanvasDeclaration, CanvasHandler, CanvasResult}; use github_copilot_sdk::handler::{ ApproveAllHandler, AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, - ExitPlanModeHandler, ExitPlanModeResult, UserInputHandler, UserInputResponse, + ExitPlanModeHandler, ExitPlanModeResult, McpAuthHandler, McpAuthRequest, McpAuthResult, + UserInputHandler, UserInputResponse, }; use github_copilot_sdk::rpc::{ CanvasInstanceAvailability, CanvasProviderInvokeActionRequest, CanvasProviderOpenRequest, CanvasProviderOpenResult, OpenCanvasInstance, }; -use github_copilot_sdk::session_events::ReasoningSummary; +use github_copilot_sdk::session_events::{McpOauthRequiredData, ReasoningSummary}; use github_copilot_sdk::types::{ - CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ElicitationRequest, - ElicitationResult, ExitPlanModeData, ExtensionInfo, MessageOptions, RequestId, SessionConfig, - SessionId, SetModelOptions, Tool, ToolInvocation, ToolResult, + CloudSessionOptions, CloudSessionRepository, CommandContext, CommandDefinition, CommandHandler, + DeliveryMode, ElicitationRequest, ElicitationResult, ExitPlanModeData, ExtensionInfo, + MessageOptions, RequestId, SessionConfig, SessionId, SetModelOptions, Tool, ToolInvocation, + ToolResult, }; use github_copilot_sdk::{Client, ContextTier, tool}; use serde_json::Value; @@ -30,6 +32,20 @@ const TIMEOUT: Duration = Duration::from_secs(2); struct TestCanvasHandler; +struct CancelMcpAuthHandler; + +#[async_trait] +impl McpAuthHandler for CancelMcpAuthHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: McpAuthRequest, + ) -> McpAuthResult { + McpAuthResult::Cancelled + } +} + #[async_trait] impl CanvasHandler for TestCanvasHandler { async fn on_open( @@ -220,12 +236,279 @@ fn rand_id() -> u64 { COUNTER.fetch_add(1, Ordering::Relaxed) as u64 } +#[test] +fn mcp_oauth_required_data_allows_optional_metadata() { + let with_metadata: McpOauthRequiredData = serde_json::from_value(serde_json::json!({ + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\"resource\":\"https://example.com/mcp\"}" + })) + .unwrap(); + assert_eq!( + with_metadata.resource_metadata.as_deref(), + Some("{\"resource\":\"https://example.com/mcp\"}") + ); + assert!(with_metadata.www_authenticate_params.is_some()); + + let without_metadata: McpOauthRequiredData = serde_json::from_value(serde_json::json!({ + "requestId": "oauth-request", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + })) + .unwrap(); + assert!(without_metadata.resource_metadata.is_none()); + assert!(without_metadata.www_authenticate_params.is_none()); +} + fn requested_session_id(request: &Value) -> &str { request["params"]["sessionId"] .as_str() .expect("session request should include sessionId") } +#[tokio::test] +async fn create_session_registers_mcp_auth_interest_only_with_handler() { + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler)), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert_eq!(create_req["params"]["requestPermission"], true); + let session_id = requested_session_id(&create_req).to_string(); + server_respond_create(&mut server_write, &create_req, &session_id).await; + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + + let no_extra_request = timeout(Duration::from_millis(50), read_framed(&mut server_read)).await; + assert!(no_extra_request.is_err()); + drop(session); + + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)), + ) + .await + .unwrap() + } + }); + + let interest_req = read_framed(&mut server_read).await; + assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); + assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); + let id = interest_req["id"].as_u64().unwrap(); + write_framed( + &mut server_write, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "id": "interest-1" }, + })) + .unwrap(), + ) + .await; + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert_eq!(create_req["params"]["requestPermission"], true); + let session_id = requested_session_id(&create_req).to_string(); + server_respond_create(&mut server_write, &create_req, &session_id).await; + let _session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn cloud_create_session_registers_mcp_auth_interest_after_create_only_with_handler() { + let cloud = || { + CloudSessionOptions::with_repository( + CloudSessionRepository::new("github", "copilot-sdk").with_branch("main"), + ) + }; + + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_cloud(cloud()), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert!(create_req["params"].get("sessionId").is_none()); + assert_eq!(create_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &create_req, "server-assigned-session-1").await; + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + let no_extra_request = timeout(Duration::from_millis(50), read_framed(&mut server_read)).await; + assert!(no_extra_request.is_err()); + drop(session); + + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)) + .with_cloud(cloud()), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert!(create_req["params"].get("sessionId").is_none()); + assert_eq!(create_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &create_req, "server-assigned-session-2").await; + + let interest_req = read_framed(&mut server_read).await; + assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); + assert_eq!( + interest_req["params"]["sessionId"], + "server-assigned-session-2" + ); + assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); + let id = interest_req["id"].as_u64().unwrap(); + write_framed( + &mut server_write, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "id": "interest-1" }, + })) + .unwrap(), + ) + .await; + let _session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn resume_session_registers_mcp_auth_interest_only_with_handler() { + use github_copilot_sdk::types::ResumeSessionConfig; + + let (client, mut server_read, mut server_write) = make_client(); + let resume_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .resume_session( + ResumeSessionConfig::new(SessionId::from("session-without-auth")) + .with_permission_handler(Arc::new(ApproveAllHandler)), + ) + .await + .unwrap() + } + }); + + let resume_req = read_framed(&mut server_read).await; + assert_eq!(resume_req["method"], "session.resume"); + assert_eq!(resume_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &resume_req, "session-without-auth").await; + respond_to_reload(&mut server_read, &mut server_write).await; + let session = timeout(TIMEOUT, resume_handle).await.unwrap().unwrap(); + let no_extra_request = timeout(Duration::from_millis(50), read_framed(&mut server_read)).await; + assert!(no_extra_request.is_err()); + drop(session); + + let (client, mut server_read, mut server_write) = make_client(); + let resume_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .resume_session( + ResumeSessionConfig::new(SessionId::from("session-with-auth")) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)), + ) + .await + .unwrap() + } + }); + + let interest_req = read_framed(&mut server_read).await; + assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); + assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); + let id = interest_req["id"].as_u64().unwrap(); + write_framed( + &mut server_write, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "id": "interest-1" }, + })) + .unwrap(), + ) + .await; + + let resume_req = read_framed(&mut server_read).await; + assert_eq!(resume_req["method"], "session.resume"); + assert_eq!(resume_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &resume_req, "session-with-auth").await; + respond_to_reload(&mut server_read, &mut server_write).await; + let _session = timeout(TIMEOUT, resume_handle).await.unwrap().unwrap(); +} + +async fn server_respond_create( + writer: &mut (impl AsyncWrite + Unpin), + request: &Value, + session_id: &str, +) { + let id = request["id"].as_u64().unwrap(); + write_framed( + writer, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": session_id, "workspacePath": "/tmp/workspace" }, + })) + .unwrap(), + ) + .await; +} + +async fn respond_to_reload( + reader: &mut (impl tokio::io::AsyncRead + Unpin), + writer: &mut (impl AsyncWrite + Unpin), +) { + let reload = read_framed(reader).await; + assert_eq!(reload["method"], "session.skills.reload"); + let id = reload["id"].as_u64().unwrap(); + write_framed( + writer, + &serde_json::to_vec(&serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} })) + .unwrap(), + ) + .await; +} + #[tokio::test] async fn session_subscribe_yields_events_observe_only() { let (session, mut server) = create_session_pair().await; diff --git a/test/harness/test-mcp-oauth-server.mjs b/test/harness/test-mcp-oauth-server.mjs new file mode 100644 index 000000000..3a642b55a --- /dev/null +++ b/test/harness/test-mcp-oauth-server.mjs @@ -0,0 +1,216 @@ +#!/usr/bin/env node +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +/** + * Minimal OAuth-protected Streamable HTTP MCP server for SDK E2E tests. + * + * The `/mcp` endpoint returns a WWW-Authenticate challenge until requests include + * `Authorization: Bearer `, then serves enough JSON-RPC MCP + * methods for the runtime to initialize and list/call one tool. + */ + +import http from "node:http"; + +const DEFAULT_EXPECTED_TOKEN = "sdk-host-token"; +const PROTOCOL_VERSION = "2025-03-26"; + +export async function startOAuthMcpServer({ + expectedToken = DEFAULT_EXPECTED_TOKEN, + host = "127.0.0.1", + port = 0, +} = {}) { + const requests = []; + + const server = http.createServer(async (req, res) => { + const url = new URL( + req.url ?? "/", + `http://${req.headers.host ?? `${host}:${port}`}`, + ); + const baseUrl = `http://${req.headers.host}`; + + if (req.method === "GET" && url.pathname === "/__requests") { + respondJson(res, 200, requests); + return; + } + + if ( + req.method === "GET" && + url.pathname === "/.well-known/oauth-protected-resource" + ) { + respondJson(res, 200, { + resource: `${baseUrl}/mcp`, + authorization_servers: [baseUrl], + scopes_supported: ["mcp.read"], + bearer_methods_supported: ["header"], + }); + return; + } + + if ( + req.method === "GET" && + url.pathname === "/.well-known/oauth-authorization-server" + ) { + respondJson(res, 200, { + issuer: baseUrl, + authorization_endpoint: `${baseUrl}/authorize`, + token_endpoint: `${baseUrl}/token`, + response_types_supported: ["code"], + grant_types_supported: ["authorization_code"], + }); + return; + } + + if (url.pathname !== "/mcp") { + respondJson(res, 404, { error: "not_found" }); + return; + } + + const body = await readBody(req); + requests.push({ + method: req.method, + path: url.pathname, + authorization: req.headers.authorization ?? null, + body: body ? JSON.parse(body) : null, + }); + + if (req.headers.authorization !== `Bearer ${expectedToken}`) { + const resourceMetadataUrl = `${baseUrl}/.well-known/oauth-protected-resource`; + res.writeHead(401, { + "www-authenticate": `Bearer resource_metadata="${resourceMetadataUrl}", scope="mcp.read", error="invalid_token"`, + "content-type": "application/json", + }); + res.end(JSON.stringify({ error: "missing_or_invalid_token" })); + return; + } + + if (req.method !== "POST") { + respondJson(res, 405, { error: "method_not_allowed" }); + return; + } + + const message = body ? JSON.parse(body) : undefined; + const response = Array.isArray(message) + ? message.map(handleJsonRpcMessage).filter((item) => item !== undefined) + : handleJsonRpcMessage(message); + + if ( + response === undefined || + (Array.isArray(response) && response.length === 0) + ) { + res.writeHead(202, { "mcp-session-id": "oauth-test-session" }); + res.end(); + return; + } + + res.writeHead(200, { + "content-type": "application/json", + "mcp-session-id": "oauth-test-session", + }); + res.end(JSON.stringify(response)); + }); + + await new Promise((resolve, reject) => { + server.once("error", reject); + server.listen(port, host, () => { + server.off("error", reject); + resolve(); + }); + }); + + const address = server.address(); + if (!address || typeof address === "string") { + throw new Error("Expected TCP server address"); + } + + return { + url: `http://${host}:${address.port}`, + requests, + close: () => + new Promise((resolve, reject) => + server.close((err) => (err ? reject(err) : resolve())), + ), + }; +} + +function handleJsonRpcMessage(message) { + if (!message || typeof message !== "object" || !("id" in message)) { + return undefined; + } + + switch (message.method) { + case "initialize": + return { + jsonrpc: "2.0", + id: message.id, + result: { + protocolVersion: message.params?.protocolVersion ?? PROTOCOL_VERSION, + capabilities: { tools: {} }, + serverInfo: { name: "oauth-test-server", version: "1.0.0" }, + }, + }; + case "tools/list": + return { + jsonrpc: "2.0", + id: message.id, + result: { + tools: [ + { + name: "whoami", + description: "Returns the authenticated test principal.", + inputSchema: { + type: "object", + properties: {}, + additionalProperties: false, + }, + }, + ], + }, + }; + case "tools/call": + return { + jsonrpc: "2.0", + id: message.id, + result: { + content: [{ type: "text", text: "oauth-test-user" }], + isError: false, + }, + }; + default: + return { + jsonrpc: "2.0", + id: message.id, + error: { code: -32601, message: `Method not found: ${message.method}` }, + }; + } +} + +function readBody(req) { + return new Promise((resolve, reject) => { + const chunks = []; + req.on("data", (chunk) => chunks.push(chunk)); + req.on("error", reject); + req.on("end", () => resolve(Buffer.concat(chunks).toString("utf8"))); + }); +} + +function respondJson(res, statusCode, body) { + const data = JSON.stringify(body); + res.writeHead(statusCode, { + "content-type": "application/json", + "content-length": Buffer.byteLength(data), + }); + res.end(data); +} + +if (import.meta.url === `file://${process.argv[1]}`) { + const server = await startOAuthMcpServer({ + expectedToken: process.env.EXPECTED_TOKEN ?? DEFAULT_EXPECTED_TOKEN, + }); + console.log(`Listening: ${server.url}`); + process.on("SIGTERM", async () => { + await server.close(); + process.exit(0); + }); +}