diff --git a/packages/agent/src/routes/ai/ai-proxy.ts b/packages/agent/src/routes/ai/ai-proxy.ts index 85eb69521..581216a29 100644 --- a/packages/agent/src/routes/ai/ai-proxy.ts +++ b/packages/agent/src/routes/ai/ai-proxy.ts @@ -8,6 +8,7 @@ import { AIError, AINotFoundError, Router as AiProxyRouter, + injectOauthTokens, } from '@forestadmin/ai-proxy'; import { BadRequestError, @@ -40,11 +41,25 @@ export default class AiProxyRoute extends BaseRoute { private async handleAiProxy(context: Context): Promise { try { + const mcpOauthTokensHeader = context.request.headers['x-mcp-oauth-tokens'] as string; + let mcpOAuthTokens: Record | undefined; + + if (mcpOauthTokensHeader) { + try { + mcpOAuthTokens = JSON.parse(mcpOauthTokensHeader); + } catch { + throw new BadRequestError('Invalid JSON in x-mcp-oauth-tokens header'); + } + } + + const mcpConfigs = + await this.options.forestAdminClient.mcpServerConfigService.getConfiguration(); + context.response.body = await this.aiProxyRouter.route({ route: context.params.route, body: context.request.body, query: context.query, - mcpConfigs: await this.options.forestAdminClient.mcpServerConfigService.getConfiguration(), + mcpConfigs: injectOauthTokens(mcpConfigs, mcpOAuthTokens), }); context.response.status = HttpCode.Ok; } catch (error) { diff --git a/packages/agent/test/routes/ai/ai-proxy.test.ts b/packages/agent/test/routes/ai/ai-proxy.test.ts index 0034914d1..0a7747e8b 100644 --- a/packages/agent/test/routes/ai/ai-proxy.test.ts +++ b/packages/agent/test/routes/ai/ai-proxy.test.ts @@ -85,7 +85,7 @@ describe('AiProxyRoute', () => { expect(context.response.body).toEqual(expectedResponse); }); - test('should pass route, body, query and mcpConfigs to router', async () => { + test('should pass route, body, query, mcpConfigs and mcpOAuthTokens to router', async () => { const route = new AiProxyRoute(services, options, aiConfigurations); mockRoute.mockResolvedValueOnce({}); @@ -108,6 +108,71 @@ describe('AiProxyRoute', () => { }); }); + test('should inject oauth tokens into mcpConfigs when header is provided', async () => { + const route = new AiProxyRoute(services, options, aiConfigurations); + mockRoute.mockResolvedValueOnce({}); + + // Mock mcpServerConfigService to return actual configs + const mcpConfigs = { + configs: { + server1: { type: 'http' as const, url: 'https://server1.com' }, + server2: { type: 'http' as const, url: 'https://server2.com' }, + }, + }; + options.forestAdminClient.mcpServerConfigService.getConfiguration = jest + .fn() + .mockResolvedValue(mcpConfigs); + + const tokens = { server1: 'token1', server2: 'token2' }; + const context = createMockContext({ + customProperties: { + params: { route: 'ai-query' }, + }, + requestBody: { messages: [] }, + headers: { 'x-mcp-oauth-tokens': JSON.stringify(tokens) }, + }); + context.query = {}; + + await (route as any).handleAiProxy(context); + + expect(mockRoute).toHaveBeenCalledWith( + expect.objectContaining({ + mcpConfigs: { + configs: { + server1: { + type: 'http', + url: 'https://server1.com', + headers: { Authorization: 'token1' }, + }, + server2: { + type: 'http', + url: 'https://server2.com', + headers: { Authorization: 'token2' }, + }, + }, + }, + }), + ); + }); + + test('should throw BadRequestError when x-mcp-oauth-tokens header contains invalid JSON', async () => { + const route = new AiProxyRoute(services, options, aiConfigurations); + + const context = createMockContext({ + customProperties: { + params: { route: 'ai-query' }, + }, + requestBody: { messages: [] }, + headers: { 'x-mcp-oauth-tokens': '{ invalid json }' }, + }); + context.query = {}; + + await expect((route as any).handleAiProxy(context)).rejects.toThrow(BadRequestError); + await expect((route as any).handleAiProxy(context)).rejects.toThrow( + 'Invalid JSON in x-mcp-oauth-tokens header', + ); + }); + describe('error handling', () => { test('should convert AINotConfiguredError to UnprocessableError', async () => { const route = new AiProxyRoute(services, options, aiConfigurations); @@ -212,7 +277,6 @@ describe('AiProxyRoute', () => { }); context.query = {}; - // eslint-disable-next-line @typescript-eslint/no-explicit-any const promise = (route as any).handleAiProxy(context); await expect(promise).rejects.toBe(unknownError); diff --git a/packages/ai-proxy/src/mcp-client.ts b/packages/ai-proxy/src/mcp-client.ts index b30212a8f..b0211eea7 100644 --- a/packages/ai-proxy/src/mcp-client.ts +++ b/packages/ai-proxy/src/mcp-client.ts @@ -5,10 +5,59 @@ import { MultiServerMCPClient } from '@langchain/mcp-adapters'; import { McpConnectionError } from './types/errors'; import McpServerRemoteTool from './types/mcp-server-remote-tool'; +export type McpAuthenticationType = 'none' | 'bearer' | 'oauth2'; + +export type McpServerConfig = MultiServerMCPClient['config']['mcpServers'][string]; + export type McpConfiguration = { configs: MultiServerMCPClient['config']['mcpServers']; + authenticationTypes?: Record; } & Omit; +/** + * Injects the OAuth token as Authorization header into HTTP-based transport configurations. + * For stdio transports, returns the config unchanged. + */ +export function injectOauthToken(serverConfig: McpServerConfig, token?: string): McpServerConfig { + if (!token) return serverConfig; + + // Only inject token for HTTP-based transports (sse, http) + if (serverConfig.type === 'http' || serverConfig.type === 'sse') { + const { oauthConfig, ...headers } = serverConfig.headers || {}; + + return { + ...serverConfig, + headers: { + ...headers, + Authorization: token, + }, + }; + } + + return serverConfig; +} + +/** + * Injects OAuth tokens into all server configurations. + * Returns a new McpConfiguration with tokens injected, or undefined if no configs provided. + */ +export function injectOauthTokens( + mcpConfigs: McpConfiguration | undefined, + mcpOAuthTokens: Record | undefined, +): McpConfiguration | undefined { + if (!mcpConfigs) return undefined; + if (!mcpOAuthTokens) return mcpConfigs; + + const configsWithTokens = Object.fromEntries( + Object.entries(mcpConfigs.configs).map(([name, serverConfig]) => [ + name, + injectOauthToken(serverConfig, mcpOAuthTokens[name]), + ]), + ); + + return { ...mcpConfigs, configs: configsWithTokens }; +} + export default class McpClient { private readonly mcpClients: Record = {}; private readonly logger?: Logger; diff --git a/packages/ai-proxy/test/mcp-client.test.ts b/packages/ai-proxy/test/mcp-client.test.ts index caa09eaf2..bc70fbb73 100644 --- a/packages/ai-proxy/test/mcp-client.test.ts +++ b/packages/ai-proxy/test/mcp-client.test.ts @@ -3,7 +3,7 @@ import type { McpConfiguration } from '../src'; import { tool } from '@langchain/core/tools'; import { McpConnectionError } from '../src'; -import McpClient from '../src/mcp-client'; +import McpClient, { injectOauthToken, injectOauthTokens } from '../src/mcp-client'; import McpServerRemoteTool from '../src/types/mcp-server-remote-tool'; const getToolsMock = jest.fn(); @@ -19,6 +19,13 @@ jest.mock('@langchain/mcp-adapters', () => { }; }); +// eslint-disable-next-line import/first +import { MultiServerMCPClient } from '@langchain/mcp-adapters'; + +const MockedMultiServerMCPClient = MultiServerMCPClient as jest.MockedClass< + typeof MultiServerMCPClient +>; + describe('McpClient', () => { beforeEach(() => { jest.clearAllMocks(); @@ -236,4 +243,167 @@ describe('McpClient', () => { }); }); }); + + describe('injectOauthToken', () => { + it('should inject OAuth token as Authorization header into HTTP type transport', () => { + const serverConfig = { + type: 'http' as const, + url: 'https://example.com/mcp', + }; + + const result = injectOauthToken(serverConfig, 'my-oauth-token'); + + expect(result).toEqual({ + type: 'http', + url: 'https://example.com/mcp', + headers: { + Authorization: 'my-oauth-token', + }, + }); + }); + + it('should inject OAuth token as Authorization header into SSE type transport', () => { + const serverConfig = { + type: 'sse' as const, + url: 'https://example.com/mcp', + }; + + const result = injectOauthToken(serverConfig, 'my-oauth-token'); + + expect(result).toEqual({ + type: 'sse', + url: 'https://example.com/mcp', + headers: { + Authorization: 'my-oauth-token', + }, + }); + }); + + it('should merge Authorization header with existing headers and strip oauthConfig', () => { + const serverConfig = { + type: 'http' as const, + url: 'https://example.com/mcp', + headers: { + 'x-custom-header': 'custom-value', + oauthConfig: { clientId: 'test' } as unknown as string, + }, + }; + + const result = injectOauthToken(serverConfig, 'my-oauth-token'); + + expect(result).toEqual({ + type: 'http', + url: 'https://example.com/mcp', + headers: { + 'x-custom-header': 'custom-value', + Authorization: 'my-oauth-token', + }, + }); + }); + + it('should not inject OAuth token into stdio transport even if token provided', () => { + const serverConfig = { + transport: 'stdio' as const, + command: 'npx', + args: ['-y', '@modelcontextprotocol/server-slack'], + env: {}, + }; + + const result = injectOauthToken(serverConfig, 'my-oauth-token'); + + expect(result).toEqual(serverConfig); + }); + + it('should not modify config when no token is provided', () => { + const serverConfig = { + type: 'http' as const, + url: 'https://example.com/mcp', + }; + + const result = injectOauthToken(serverConfig); + + expect(result).toEqual(serverConfig); + }); + + it('should return same reference when no token is provided', () => { + const serverConfig = { + type: 'http' as const, + url: 'https://example.com/mcp', + }; + + const result = injectOauthToken(serverConfig); + + expect(result).toBe(serverConfig); + }); + }); + + describe('injectOauthTokens', () => { + it('should inject tokens into all matching server configs', () => { + const mcpConfigs = { + configs: { + server1: { type: 'http' as const, url: 'https://server1.com' }, + server2: { type: 'http' as const, url: 'https://server2.com' }, + }, + }; + const tokens = { server1: 'token1', server2: 'token2' }; + + const result = injectOauthTokens(mcpConfigs, tokens); + + expect(result).toEqual({ + configs: { + server1: { + type: 'http', + url: 'https://server1.com', + headers: { Authorization: 'token1' }, + }, + server2: { + type: 'http', + url: 'https://server2.com', + headers: { Authorization: 'token2' }, + }, + }, + }); + }); + + it('should only inject tokens for servers that have matching tokens', () => { + const mcpConfigs = { + configs: { + server1: { type: 'http' as const, url: 'https://server1.com' }, + server2: { type: 'http' as const, url: 'https://server2.com' }, + }, + }; + const tokens = { server1: 'token1' }; + + const result = injectOauthTokens(mcpConfigs, tokens); + + expect(result).toEqual({ + configs: { + server1: { + type: 'http', + url: 'https://server1.com', + headers: { Authorization: 'token1' }, + }, + server2: { type: 'http', url: 'https://server2.com' }, + }, + }); + }); + + it('should return undefined when mcpConfigs is undefined', () => { + const result = injectOauthTokens(undefined, { server1: 'token1' }); + + expect(result).toBeUndefined(); + }); + + it('should return mcpConfigs unchanged when tokens is undefined', () => { + const mcpConfigs = { + configs: { + server1: { type: 'http' as const, url: 'https://server1.com' }, + }, + }; + + const result = injectOauthTokens(mcpConfigs, undefined); + + expect(result).toBe(mcpConfigs); + }); + }); });