diff --git a/packages/ai-proxy/src/provider-dispatcher.ts b/packages/ai-proxy/src/provider-dispatcher.ts index 5f04f3288..ca0df1096 100644 --- a/packages/ai-proxy/src/provider-dispatcher.ts +++ b/packages/ai-proxy/src/provider-dispatcher.ts @@ -77,6 +77,7 @@ export type DispatchBody = { messages: ChatCompletionMessage[]; tools?: ChatCompletionTool[]; tool_choice?: ChatCompletionToolChoice; + parallel_tool_calls?: boolean; }; export class ProviderDispatcher { @@ -101,10 +102,20 @@ export class ProviderDispatcher { throw new AINotConfiguredError(); } - const { tools, messages, tool_choice: toolChoice } = body; + const { + tools, + messages, + tool_choice: toolChoice, + parallel_tool_calls: parallelToolCalls, + } = body; const enrichedTools = this.enrichToolDefinitions(tools); - const model = this.bindToolsIfNeeded(this.chatModel, enrichedTools, toolChoice); + const model = enrichedTools?.length + ? this.chatModel.bindTools(enrichedTools, { + tool_choice: toolChoice, + parallel_tool_calls: parallelToolCalls, + }) + : this.chatModel; try { const response = await model.invoke(messages as BaseMessageLike[]); @@ -136,20 +147,6 @@ export class ProviderDispatcher { } } - private bindToolsIfNeeded( - chatModel: ChatOpenAI, - tools: ChatCompletionTool[] | undefined, - toolChoice?: ChatCompletionToolChoice, - ) { - if (!tools || tools.length === 0) { - return chatModel; - } - - return chatModel.bindTools(tools, { - tool_choice: toolChoice as 'auto' | 'none' | 'required' | undefined, - }); - } - private enrichToolDefinitions(tools?: ChatCompletionTool[]) { if (!tools || !Array.isArray(tools)) return tools; diff --git a/packages/ai-proxy/test/provider-dispatcher.test.ts b/packages/ai-proxy/test/provider-dispatcher.test.ts index c819fb1bc..cbe761325 100644 --- a/packages/ai-proxy/test/provider-dispatcher.test.ts +++ b/packages/ai-proxy/test/provider-dispatcher.test.ts @@ -218,6 +218,45 @@ describe('ProviderDispatcher', () => { }); }); + describe('when parallel_tool_calls is provided', () => { + it('should pass parallel_tool_calls to bindTools', async () => { + const dispatcher = new ProviderDispatcher( + { name: 'gpt4', provider: 'openai', apiKey: 'dev', model: 'gpt-4o' }, + new RemoteTools(apiKeys), + ); + + await dispatcher.dispatch({ + messages: [{ role: 'user', content: 'test' }], + tools: [{ type: 'function', function: { name: 'test', parameters: {} } }], + tool_choice: 'auto', + parallel_tool_calls: false, + } as unknown as DispatchBody); + + expect(bindToolsMock).toHaveBeenCalledWith( + [{ type: 'function', function: { name: 'test', parameters: {} } }], + { tool_choice: 'auto', parallel_tool_calls: false }, + ); + }); + + it('should pass parallel_tool_calls: true when explicitly set', async () => { + const dispatcher = new ProviderDispatcher( + { name: 'gpt4', provider: 'openai', apiKey: 'dev', model: 'gpt-4o' }, + new RemoteTools(apiKeys), + ); + + await dispatcher.dispatch({ + messages: [{ role: 'user', content: 'test' }], + tools: [{ type: 'function', function: { name: 'test', parameters: {} } }], + parallel_tool_calls: true, + } as unknown as DispatchBody); + + expect(bindToolsMock).toHaveBeenCalledWith( + expect.any(Array), + { tool_choice: undefined, parallel_tool_calls: true }, + ); + }); + }); + describe('when there is not remote tool', () => { it('should not enhance the remote tools definition', async () => { const remoteTools = new RemoteTools(apiKeys);