diff --git a/python/semantic_kernel/agents/azure_ai/agent_thread_actions.py b/python/semantic_kernel/agents/azure_ai/agent_thread_actions.py index 62fb798bb11b..54b75e928fa1 100644 --- a/python/semantic_kernel/agents/azure_ai/agent_thread_actions.py +++ b/python/semantic_kernel/agents/azure_ai/agent_thread_actions.py @@ -67,6 +67,8 @@ from semantic_kernel.agents.open_ai.assistant_content_generation import merge_streaming_function_results from semantic_kernel.agents.open_ai.function_action_result import FunctionActionResult from semantic_kernel.agents.open_ai.run_polling_options import RunPollingOptions +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior +from semantic_kernel.connectors.ai.function_choice_type import FunctionChoiceType from semantic_kernel.connectors.ai.function_calling_utils import kernel_function_metadata_to_function_call_format from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent @@ -124,6 +126,7 @@ async def invoke( parallel_tool_calls: bool | None = None, metadata: dict[str, str] | None = None, polling_options: RunPollingOptions | None = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, **kwargs: Any, ) -> AsyncIterable[tuple[bool, "ChatMessageContent"]]: """Invoke the message in the thread. @@ -139,7 +142,9 @@ async def invoke( additional_messages: The additional messages to add to the thread. Only supports messages with role = User or Assistant. https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-additional_messages - tools: The tools. + tools: The SDK-level tools (e.g. CodeInterpreter, FileSearch, AzureAISearch). When provided, + overrides the tools from the agent definition. Does not affect kernel function availability; + use function_choice_behavior for that. temperature: The temperature. top_p: The top p. max_prompt_tokens: The max prompt tokens. @@ -150,6 +155,9 @@ async def invoke( metadata: The metadata. polling_options: The polling options defined at the run-level. These will override the agent-level polling options. + function_choice_behavior: Controls which kernel functions are allowed to execute during this run. + Use FunctionChoiceBehavior.Auto(filters={"included_functions": [...]}) to restrict to specific + functions. Only Auto is supported; other types will raise an error. kwargs: Additional keyword arguments. Returns: @@ -158,7 +166,10 @@ async def invoke( arguments = KernelArguments() if arguments is None else KernelArguments(**arguments, **kwargs) kernel = kernel or agent.kernel - tools = cls._get_tools(agent=agent, kernel=kernel) # type: ignore + cls._validate_function_choice_behavior(function_choice_behavior) + + tools = cls._get_tools(agent=agent, kernel=kernel, tools_override=tools, + function_choice_behavior=function_choice_behavior) # type: ignore base_instructions = await agent.format_instructions(kernel=kernel, arguments=arguments) @@ -232,7 +243,11 @@ async def invoke( chat_history = ChatHistory() if kwargs.get("chat_history") is None else kwargs["chat_history"] _ = await cls._invoke_function_calls( - kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments + kernel=kernel, + fccs=fccs, + chat_history=chat_history, + arguments=arguments, + function_choice_behavior=function_choice_behavior, ) tool_outputs = cls._format_tool_outputs(fccs, chat_history) @@ -467,6 +482,7 @@ async def invoke_stream( temperature: float | None = None, top_p: float | None = None, truncation_strategy: TruncationObject | None = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, **kwargs: Any, ) -> AsyncIterable["StreamingChatMessageContent"]: """Invoke the agent stream and yield ChatMessageContent continuously. @@ -489,10 +505,15 @@ async def invoke_stream( formed from the streamed chunks. parallel_tool_calls: Whether to configure parallel tool calls. response_format: The response format. - tools: The tools. + tools: The SDK-level tools (e.g. CodeInterpreter, FileSearch, AzureAISearch). When provided, + overrides the tools from the agent definition. Does not affect kernel function availability; + use function_choice_behavior for that. temperature: The temperature. top_p: The top p. truncation_strategy: The truncation strategy. + function_choice_behavior: Controls which kernel functions are allowed to execute during this run. + Use FunctionChoiceBehavior.Auto(filters={"included_functions": [...]}) to restrict to specific + functions. Only Auto is supported; other types will raise an error. kwargs: Additional keyword arguments. Returns: @@ -502,7 +523,10 @@ async def invoke_stream( kernel = kernel or agent.kernel arguments = agent._merge_arguments(arguments) - tools = cls._get_tools(agent=agent, kernel=kernel) # type: ignore + cls._validate_function_choice_behavior(function_choice_behavior) + + tools = cls._get_tools(agent=agent, kernel=kernel, tools_override=tools, + function_choice_behavior=function_choice_behavior) # type: ignore base_instructions = await agent.format_instructions(kernel=kernel, arguments=arguments) @@ -549,6 +573,7 @@ async def invoke_stream( arguments=arguments, function_steps=function_steps, active_messages=active_messages, + function_choice_behavior=function_choice_behavior, ): if content: yield content @@ -564,6 +589,7 @@ async def _process_stream_events( function_steps: dict[str, FunctionCallContent], active_messages: dict[str, RunStep], output_messages: "list[ChatMessageContent] | None" = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, ) -> AsyncIterable["StreamingChatMessageContent"]: """Process events from the main stream and delegate tool output handling as needed.""" thread_msg_id = None @@ -671,6 +697,7 @@ async def _process_stream_events( run=run, function_steps=function_steps, arguments=arguments, + function_choice_behavior=function_choice_behavior, ) if action_result is None: raise RuntimeError( @@ -959,11 +986,51 @@ def _deduplicate_tools(existing_tools: list[dict], new_tools: list[dict]) -> lis } return [tool for tool in new_tools if tool.get("function", {}).get("name") not in existing_names] + @staticmethod + def _validate_function_choice_behavior( + function_choice_behavior: FunctionChoiceBehavior | None, + ) -> None: + """Validate the function choice behavior is compatible with agent invocations.""" + if function_choice_behavior is None: + return + if function_choice_behavior.type_ != FunctionChoiceType.AUTO: + raise AgentInvokeException( + f"FunctionChoiceBehavior with type '{function_choice_behavior.type_}' is not supported for agent " + "invocations. Use FunctionChoiceBehavior.Auto(filters=...) to control which kernel functions " + "are available." + ) + if not function_choice_behavior.auto_invoke_kernel_functions: + raise AgentInvokeException( + "FunctionChoiceBehavior.Auto(auto_invoke=False) is not supported for agent invocations. " + "The agent run loop manages tool invocation; disabling auto_invoke is not compatible." + ) + @classmethod - def _get_tools(cls: type[_T], agent: "AzureAIAgent", kernel: "Kernel") -> list[dict[str, Any] | ToolDefinition]: - """Get the tools for the agent.""" - tools: list[Any] = list(agent.definition.tools) - funcs = kernel.get_full_list_of_function_metadata() + def _get_tools( + cls: type[_T], + agent: "AzureAIAgent", + kernel: "Kernel", + tools_override: list[ToolDefinition] | None = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, + ) -> list[dict[str, Any] | ToolDefinition]: + """Get the tools for the agent. + + Args: + agent: The agent instance. + kernel: The kernel to use for function metadata. + tools_override: When provided, overrides agent.definition.tools (SDK-level tools only). + function_choice_behavior: When provided, filters which kernel functions are included. + """ + tools: list[Any] = list(tools_override) if tools_override is not None else list(agent.definition.tools) + + # Determine kernel function metadata based on function_choice_behavior + if function_choice_behavior is not None and not function_choice_behavior.enable_kernel_functions: + funcs = [] + elif function_choice_behavior is not None and function_choice_behavior.filters: + funcs = kernel.get_list_of_function_metadata(function_choice_behavior.filters) + else: + funcs = kernel.get_full_list_of_function_metadata() + cls._validate_function_tools_registered(tools, funcs) dict_defs = [kernel_function_metadata_to_function_call_format(f) for f in funcs] deduped_defs = cls._deduplicate_tools(tools, dict_defs) @@ -1071,6 +1138,7 @@ async def _invoke_function_calls( fccs: list["FunctionCallContent"], chat_history: "ChatHistory", arguments: KernelArguments, + function_choice_behavior: FunctionChoiceBehavior | None = None, ) -> list["AutoFunctionInvocationContext | None"]: """Invoke the function calls.""" return await asyncio.gather( @@ -1079,6 +1147,7 @@ async def _invoke_function_calls( function_call=function_call, chat_history=chat_history, arguments=arguments, + function_behavior=function_choice_behavior, ) for function_call in fccs ], @@ -1111,6 +1180,7 @@ async def _handle_streaming_requires_action( run: ThreadRun, function_steps: dict[str, "FunctionCallContent"], arguments: KernelArguments, + function_choice_behavior: FunctionChoiceBehavior | None = None, **kwargs: Any, ) -> FunctionActionResult | None: """Handle the requires action event for a streaming run.""" @@ -1121,7 +1191,8 @@ async def _handle_streaming_requires_action( chat_history = ChatHistory() if kwargs.get("chat_history") is None else kwargs["chat_history"] results = await cls._invoke_function_calls( - kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments + kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments, + function_choice_behavior=function_choice_behavior, ) function_result_streaming_content = merge_streaming_function_results( diff --git a/python/semantic_kernel/agents/azure_ai/azure_ai_agent.py b/python/semantic_kernel/agents/azure_ai/azure_ai_agent.py index 44af0f22fee8..74b64ed534be 100644 --- a/python/semantic_kernel/agents/azure_ai/azure_ai_agent.py +++ b/python/semantic_kernel/agents/azure_ai/azure_ai_agent.py @@ -40,6 +40,7 @@ from semantic_kernel.agents.azure_ai.azure_ai_channel import AzureAIChannel from semantic_kernel.agents.channels.agent_channel import AgentChannel from semantic_kernel.agents.open_ai.run_polling_options import RunPollingOptions +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.connectors.ai.function_calling_utils import kernel_function_metadata_to_function_call_format from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.utils.author_role import AuthorRole @@ -647,6 +648,7 @@ async def get_response( parallel_tool_calls: bool | None = None, metadata: dict[str, str] | None = None, polling_options: RunPollingOptions | None = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, **kwargs: Any, ) -> AgentResponseItem[ChatMessageContent]: """Get a response from the agent on a thread. @@ -671,6 +673,8 @@ async def get_response( parallel_tool_calls: Whether to allow parallel tool calls. metadata: Metadata for the agent. polling_options: The polling options for the agent. + function_choice_behavior: The function choice behavior to control which kernel + functions are available. Only Auto is supported; other types will raise an error. **kwargs: Additional keyword arguments. Returns: @@ -716,6 +720,7 @@ async def get_response( thread_id=thread.id, kernel=kernel, arguments=arguments, + function_choice_behavior=function_choice_behavior, **run_level_params, # type: ignore ): if is_visible and response.metadata.get("code") is not True: @@ -752,6 +757,7 @@ async def invoke( parallel_tool_calls: bool | None = None, metadata: dict[str, str] | None = None, polling_options: RunPollingOptions | None = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseItem[ChatMessageContent]]: """Invoke the agent on the specified thread. @@ -777,6 +783,8 @@ async def invoke( parallel_tool_calls: Whether to allow parallel tool calls. polling_options: The polling options for the agent. metadata: Metadata for the agent. + function_choice_behavior: The function choice behavior to control which kernel + functions are available. Only Auto is supported; other types will raise an error. **kwargs: Additional keyword arguments. Yields: @@ -821,6 +829,7 @@ async def invoke( thread_id=thread.id, kernel=kernel, arguments=arguments, + function_choice_behavior=function_choice_behavior, **run_level_params, # type: ignore ): message.metadata["thread_id"] = thread.id @@ -856,6 +865,7 @@ async def invoke_stream( response_format: AgentsApiResponseFormatOption | None = None, parallel_tool_calls: bool | None = None, metadata: dict[str, str] | None = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseItem["StreamingChatMessageContent"]]: """Invoke the agent on the specified thread with a stream of messages. @@ -881,6 +891,8 @@ async def invoke_stream( response_format: Response format for the agent. parallel_tool_calls: Whether to allow parallel tool calls. metadata: Metadata for the agent. + function_choice_behavior: The function choice behavior to control which kernel + functions are available. Only Auto is supported; other types will raise an error. **kwargs: Additional keyword arguments. Yields: @@ -928,6 +940,7 @@ async def invoke_stream( output_messages=collected_messages, kernel=kernel, arguments=arguments, + function_choice_behavior=function_choice_behavior, **run_level_params, # type: ignore ): # Before yielding the current streamed message, emit any new full messages first diff --git a/python/semantic_kernel/agents/open_ai/assistant_thread_actions.py b/python/semantic_kernel/agents/open_ai/assistant_thread_actions.py index 3a6679df643f..1d2146232a04 100644 --- a/python/semantic_kernel/agents/open_ai/assistant_thread_actions.py +++ b/python/semantic_kernel/agents/open_ai/assistant_thread_actions.py @@ -33,6 +33,8 @@ ) from semantic_kernel.agents.open_ai.function_action_result import FunctionActionResult from semantic_kernel.agents.open_ai.run_polling_options import RunPollingOptions +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior +from semantic_kernel.connectors.ai.function_choice_type import FunctionChoiceType from semantic_kernel.connectors.ai.function_calling_utils import kernel_function_metadata_to_function_call_format from semantic_kernel.contents.file_reference_content import FileReferenceContent from semantic_kernel.contents.function_call_content import FunctionCallContent @@ -154,6 +156,7 @@ async def invoke( top_p: float | None = None, truncation_strategy: "TruncationStrategy | None" = None, polling_options: RunPollingOptions | None = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, **kwargs: Any, ) -> AsyncIterable[tuple[bool, "ChatMessageContent"]]: """Invoke the assistant. @@ -173,12 +176,17 @@ async def invoke( parallel_tool_calls: The parallel tool calls. reasoning_effort: The reasoning effort. response_format: The response format. - tools: The tools. + tools: The SDK-level tools (e.g. CodeInterpreter, FileSearch). When provided, + overrides the tools from the agent definition. Does not affect kernel function availability; + use function_choice_behavior for that. temperature: The temperature. top_p: The top p. truncation_strategy: The truncation strategy. polling_options: The polling options defined at the run-level. These will override the agent-level polling options. + function_choice_behavior: Controls which kernel functions are allowed to execute during this run. + Use FunctionChoiceBehavior.Auto(filters={"included_functions": [...]}) to restrict to specific + functions. Only Auto is supported; other types will raise an error. kwargs: Additional keyword arguments. Returns: @@ -187,7 +195,10 @@ async def invoke( arguments = KernelArguments() if arguments is None else KernelArguments(**arguments, **kwargs) kernel = kernel or agent.kernel - tools = cls._get_tools(agent=agent, kernel=kernel) # type: ignore + cls._validate_function_choice_behavior(function_choice_behavior) + + tools = cls._get_tools(agent=agent, kernel=kernel, tools_override=tools, + function_choice_behavior=function_choice_behavior) # type: ignore base_instructions = await agent.format_instructions(kernel=kernel, arguments=arguments) @@ -260,7 +271,8 @@ async def invoke( chat_history = ChatHistory() _ = await cls._invoke_function_calls( - kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments + kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments, + function_choice_behavior=function_choice_behavior, ) tool_outputs = cls._format_tool_outputs(fccs, chat_history) @@ -374,6 +386,7 @@ async def invoke_stream( temperature: float | None = None, top_p: float | None = None, truncation_strategy: "TruncationStrategy | None" = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, **kwargs: Any, ) -> AsyncIterable["StreamingChatMessageContent"]: """Invoke the assistant. @@ -396,10 +409,15 @@ async def invoke_stream( parallel_tool_calls: The parallel tool calls. reasoning_effort: The reasoning effort. response_format: The response format. - tools: The tools. + tools: The SDK-level tools (e.g. CodeInterpreter, FileSearch). When provided, + overrides the tools from the agent definition. Does not affect kernel function availability; + use function_choice_behavior for that. temperature: The temperature. top_p: The top p. truncation_strategy: The truncation strategy. + function_choice_behavior: Controls which kernel functions are allowed to execute during this run. + Use FunctionChoiceBehavior.Auto(filters={"included_functions": [...]}) to restrict to specific + functions. Only Auto is supported; other types will raise an error. kwargs: Additional keyword arguments. Returns: @@ -408,7 +426,10 @@ async def invoke_stream( arguments = KernelArguments() if arguments is None else KernelArguments(**arguments, **kwargs) kernel = kernel or agent.kernel - tools = cls._get_tools(agent=agent, kernel=kernel) # type: ignore + cls._validate_function_choice_behavior(function_choice_behavior) + + tools = cls._get_tools(agent=agent, kernel=kernel, tools_override=tools, + function_choice_behavior=function_choice_behavior) # type: ignore base_instructions = await agent.format_instructions(kernel=kernel, arguments=arguments) @@ -496,6 +517,7 @@ async def invoke_stream( run, function_steps, arguments, + function_choice_behavior=function_choice_behavior, ) if action_result is None: raise AgentInvokeException( @@ -553,6 +575,7 @@ async def _handle_streaming_requires_action( run: "Run", function_steps: dict[str, "FunctionCallContent"], arguments: KernelArguments, + function_choice_behavior: FunctionChoiceBehavior | None = None, **kwargs: Any, ) -> FunctionActionResult | None: """Handle the requires action event for a streaming run.""" @@ -563,7 +586,8 @@ async def _handle_streaming_requires_action( chat_history = ChatHistory() if kwargs.get("chat_history") is None else kwargs["chat_history"] results = await cls._invoke_function_calls( - kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments + kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments, + function_choice_behavior=function_choice_behavior, ) function_result_streaming_content = merge_streaming_function_results( @@ -658,6 +682,7 @@ async def _invoke_function_calls( fccs: list["FunctionCallContent"], chat_history: "ChatHistory", arguments: KernelArguments, + function_choice_behavior: FunctionChoiceBehavior | None = None, ) -> list["AutoFunctionInvocationContext | None"]: """Invoke the function calls.""" return await asyncio.gather( @@ -666,6 +691,7 @@ async def _invoke_function_calls( function_call=function_call, chat_history=chat_history, arguments=arguments, + function_behavior=function_choice_behavior, ) for function_call in fccs ], @@ -837,22 +863,61 @@ def _get_tool_definition(cls: type[_T], tools: list[Any]) -> Iterable["Additiona if tool_definition := cls.tool_metadata.get(tool): yield from tool_definition + @staticmethod + def _validate_function_choice_behavior( + function_choice_behavior: FunctionChoiceBehavior | None, + ) -> None: + """Validate the function choice behavior is compatible with agent invocations.""" + if function_choice_behavior is None: + return + if function_choice_behavior.type_ != FunctionChoiceType.AUTO: + raise AgentInvokeException( + f"FunctionChoiceBehavior with type '{function_choice_behavior.type_}' is not supported for agent " + "invocations. Use FunctionChoiceBehavior.Auto(filters=...) to control which kernel functions " + "are available." + ) + if not function_choice_behavior.auto_invoke_kernel_functions: + raise AgentInvokeException( + "FunctionChoiceBehavior.Auto(auto_invoke=False) is not supported for agent invocations. " + "The agent run loop manages tool invocation; disabling auto_invoke is not compatible." + ) + @classmethod - def _get_tools(cls: type[_T], agent: "OpenAIAssistantAgent", kernel: "Kernel") -> list[dict[str, str]]: + def _get_tools( + cls: type[_T], + agent: "OpenAIAssistantAgent", + kernel: "Kernel", + tools_override: "list[AssistantToolParam] | None" = None, + function_choice_behavior: FunctionChoiceBehavior | None = None, + ) -> list[dict[str, str]]: """Get the list of tools for the assistant. + Args: + agent: The assistant agent. + kernel: The kernel to use for function metadata. + tools_override: When provided, overrides agent.definition.tools (SDK-level tools only). + function_choice_behavior: When provided, filters which kernel functions are included. + Returns: The list of tools. """ tools: list[Any] = [] - for tool in agent.definition.tools: + source_tools = tools_override if tools_override is not None else agent.definition.tools + for tool in source_tools: if isinstance(tool, CodeInterpreterTool): tools.append({"type": "code_interpreter"}) elif isinstance(tool, FileSearchTool): tools.append({"type": "file_search"}) - funcs = agent.kernel.get_full_list_of_function_metadata() + # Determine kernel function metadata based on function_choice_behavior + if function_choice_behavior is not None and not function_choice_behavior.enable_kernel_functions: + funcs = [] + elif function_choice_behavior is not None and function_choice_behavior.filters: + funcs = kernel.get_list_of_function_metadata(function_choice_behavior.filters) + else: + funcs = kernel.get_full_list_of_function_metadata() + tools.extend([kernel_function_metadata_to_function_call_format(f) for f in funcs]) return tools diff --git a/python/semantic_kernel/agents/open_ai/openai_assistant_agent.py b/python/semantic_kernel/agents/open_ai/openai_assistant_agent.py index a1daaa75f5be..123f19aaf448 100644 --- a/python/semantic_kernel/agents/open_ai/openai_assistant_agent.py +++ b/python/semantic_kernel/agents/open_ai/openai_assistant_agent.py @@ -35,6 +35,7 @@ from semantic_kernel.agents.channels.open_ai_assistant_channel import OpenAIAssistantChannel from semantic_kernel.agents.open_ai.assistant_thread_actions import AssistantThreadActions from semantic_kernel.agents.open_ai.run_polling_options import RunPollingOptions +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.connectors.ai.open_ai.settings.open_ai_settings import OpenAISettings from semantic_kernel.connectors.utils.structured_output_schema import generate_structured_output_response_format_schema from semantic_kernel.contents.chat_message_content import ChatMessageContent @@ -758,6 +759,7 @@ async def get_response( top_p: float | None = None, truncation_strategy: "TruncationStrategy | None" = None, polling_options: RunPollingOptions | None = None, + function_choice_behavior: "FunctionChoiceBehavior | None" = None, **kwargs: Any, ) -> AgentResponseItem[ChatMessageContent]: """Get a response from the agent on a thread. @@ -783,6 +785,8 @@ async def get_response( top_p: The top p. truncation_strategy: The truncation strategy. polling_options: The polling options at the run-level. + function_choice_behavior: The function choice behavior to control which kernel + functions are available. Only Auto is supported; other types will raise an error. kwargs: Additional keyword arguments. Returns: @@ -829,6 +833,7 @@ async def get_response( thread_id=thread.id, kernel=kernel, arguments=arguments, + function_choice_behavior=function_choice_behavior, **run_level_params, # type: ignore ): if is_visible and response.metadata.get("code") is not True: @@ -866,6 +871,7 @@ async def invoke( top_p: float | None = None, truncation_strategy: "TruncationStrategy | None" = None, polling_options: RunPollingOptions | None = None, + function_choice_behavior: "FunctionChoiceBehavior | None" = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseItem[ChatMessageContent]]: """Invoke the agent. @@ -892,6 +898,8 @@ async def invoke( top_p: The top p. truncation_strategy: The truncation strategy. polling_options: The polling options at the run-level. + function_choice_behavior: The function choice behavior to control which kernel + functions are available. Only Auto is supported; other types will raise an error. kwargs: Additional keyword arguments. Yields: @@ -937,6 +945,7 @@ async def invoke( thread_id=thread.id, kernel=kernel, arguments=arguments, + function_choice_behavior=function_choice_behavior, **run_level_params, # type: ignore ): message.metadata["thread_id"] = thread.id @@ -973,6 +982,7 @@ async def invoke_stream( temperature: float | None = None, top_p: float | None = None, truncation_strategy: "TruncationStrategy | None" = None, + function_choice_behavior: "FunctionChoiceBehavior | None" = None, **kwargs: Any, ) -> AsyncIterable[AgentResponseItem[StreamingChatMessageContent]]: """Invoke the agent. @@ -999,6 +1009,8 @@ async def invoke_stream( temperature: The temperature. top_p: The top p. truncation_strategy: The truncation strategy. + function_choice_behavior: The function choice behavior to control which kernel + functions are available. Only Auto is supported; other types will raise an error. kwargs: Additional keyword arguments. Yields: @@ -1047,6 +1059,7 @@ async def invoke_stream( output_messages=collected_messages, kernel=kernel, arguments=arguments, + function_choice_behavior=function_choice_behavior, **run_level_params, # type: ignore ): # Before yielding the current streamed message, emit any new full messages first diff --git a/python/tests/unit/agents/azure_ai_agent/test_agent_thread_actions.py b/python/tests/unit/agents/azure_ai_agent/test_agent_thread_actions.py index 000491d09021..a83bcdef0c4e 100644 --- a/python/tests/unit/agents/azure_ai_agent/test_agent_thread_actions.py +++ b/python/tests/unit/agents/azure_ai_agent/test_agent_thread_actions.py @@ -3,6 +3,8 @@ from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch +import pytest + from azure.ai.agents.models import ( MessageTextContent, MessageTextDetails, @@ -26,9 +28,12 @@ from semantic_kernel.agents.azure_ai.agent_thread_actions import AgentThreadActions from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.contents import FunctionCallContent, FunctionResultContent, TextContent from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.utils.author_role import AuthorRole +from semantic_kernel.exceptions.agent_exceptions import AgentInvokeException +from semantic_kernel.functions.kernel_arguments import KernelArguments from semantic_kernel.kernel import Kernel @@ -354,3 +359,161 @@ async def test_agent_thread_actions_invoke_stream(ai_project_client, ai_agent_de collected_messages.append(content) assert isinstance(content, ChatMessageContent) assert content.metadata.get("message_id") == "msg_1" + + +# region Security tests for tools override and function_choice_behavior + + +async def test_validate_function_choice_behavior_rejects_required(): + """Required FCB is not supported for agent invocations.""" + with pytest.raises(AgentInvokeException, match="not supported"): + AgentThreadActions._validate_function_choice_behavior( + FunctionChoiceBehavior.Required() + ) + + +async def test_validate_function_choice_behavior_accepts_auto(): + """Auto FCB should be accepted without error.""" + AgentThreadActions._validate_function_choice_behavior( + FunctionChoiceBehavior.Auto() + ) + + +async def test_validate_function_choice_behavior_rejects_none_invoke(): + """NoneInvoke FCB is not supported for agent invocations.""" + with pytest.raises(AgentInvokeException, match="not supported"): + AgentThreadActions._validate_function_choice_behavior( + FunctionChoiceBehavior.NoneInvoke() + ) + + +async def test_validate_function_choice_behavior_accepts_none(): + """None (no FCB) should be accepted.""" + AgentThreadActions._validate_function_choice_behavior(None) + + +async def test_validate_function_choice_behavior_rejects_auto_invoke_false(): + """Auto with auto_invoke=False is not supported for agent invocations.""" + with pytest.raises(AgentInvokeException, match="auto_invoke"): + AgentThreadActions._validate_function_choice_behavior( + FunctionChoiceBehavior.Auto(auto_invoke=False) + ) + + +async def test_get_tools_with_tools_override(ai_project_client, ai_agent_definition): + """When tools_override is provided, it should replace agent.definition.tools.""" + from azure.ai.agents.models import CodeInterpreterToolDefinition + + agent = AzureAIAgent(client=ai_project_client, definition=ai_agent_definition) + kernel = MagicMock(spec=Kernel) + kernel.get_full_list_of_function_metadata.return_value = [] + + override_tool = CodeInterpreterToolDefinition() + tools = AgentThreadActions._get_tools( + agent=agent, kernel=kernel, tools_override=[override_tool] + ) + # Should contain the override tool, not agent.definition.tools + assert any( + (isinstance(t, CodeInterpreterToolDefinition) or (isinstance(t, dict) and t.get("type") == "code_interpreter")) + for t in tools + ) + + +async def test_get_tools_with_fcb_filters(ai_project_client, ai_agent_definition): + """When function_choice_behavior has filters, only matching functions should be included.""" + agent = AzureAIAgent(client=ai_project_client, definition=ai_agent_definition) + kernel = MagicMock(spec=Kernel) + + # Simulate filtered metadata + mock_metadata = MagicMock() + mock_metadata.fully_qualified_name = "Plugin-AllowedFunc" + mock_metadata.name = "AllowedFunc" + mock_metadata.plugin_name = "Plugin" + mock_metadata.description = "An allowed function" + mock_metadata.parameters = [] + mock_metadata.is_prompt = False + mock_metadata.return_parameter = MagicMock() + mock_metadata.return_parameter.description = "" + mock_metadata.return_parameter.type_ = "str" + mock_metadata.additional_properties = {} + + kernel.get_list_of_function_metadata.return_value = [mock_metadata] + kernel.get_full_list_of_function_metadata.return_value = [] + + fcb = FunctionChoiceBehavior.Auto( + filters={"included_functions": ["Plugin-AllowedFunc"]} + ) + AgentThreadActions._get_tools( + agent=agent, kernel=kernel, function_choice_behavior=fcb + ) + # Should have called get_list_of_function_metadata with the filters + kernel.get_list_of_function_metadata.assert_called_once_with(fcb.filters) + + +async def test_get_tools_with_fcb_disable_kernel_functions(ai_project_client, ai_agent_definition): + """When enable_kernel_functions=False, no kernel functions should be included.""" + agent = AzureAIAgent(client=ai_project_client, definition=ai_agent_definition) + kernel = MagicMock(spec=Kernel) + + fcb = FunctionChoiceBehavior.Auto(enable_kernel_functions=False) + AgentThreadActions._get_tools( + agent=agent, kernel=kernel, function_choice_behavior=fcb + ) + # Should NOT have called any function metadata methods + kernel.get_full_list_of_function_metadata.assert_not_called() + kernel.get_list_of_function_metadata.assert_not_called() + + +async def test_invoke_function_calls_passes_function_behavior(): + """_invoke_function_calls should pass function_behavior to kernel.invoke_function_call.""" + mock_kernel = AsyncMock(spec=Kernel) + mock_kernel.invoke_function_call.return_value = None + + fcc = FunctionCallContent(name="Plugin-Func", arguments={}, id="call1") + from semantic_kernel.contents.chat_history import ChatHistory + + chat_history = ChatHistory() + fcb = FunctionChoiceBehavior.Auto( + filters={"included_functions": ["Plugin-Func"]} + ) + + await AgentThreadActions._invoke_function_calls( + kernel=mock_kernel, + fccs=[fcc], + chat_history=chat_history, + arguments=KernelArguments(), + function_choice_behavior=fcb, + ) + + mock_kernel.invoke_function_call.assert_awaited_once() + call_kwargs = mock_kernel.invoke_function_call.call_args + assert call_kwargs.kwargs.get("function_behavior") is fcb + + +async def test_invoke_function_calls_passes_disabled_kernel_functions(): + """_invoke_function_calls should pass enable_kernel_functions=False FCB to kernel.""" + mock_kernel = AsyncMock(spec=Kernel) + mock_kernel.invoke_function_call.return_value = None + + fcc = FunctionCallContent(name="Plugin-Func", arguments={}, id="call1") + from semantic_kernel.contents.chat_history import ChatHistory + + chat_history = ChatHistory() + fcb = FunctionChoiceBehavior.Auto(enable_kernel_functions=False) + + await AgentThreadActions._invoke_function_calls( + kernel=mock_kernel, + fccs=[fcc], + chat_history=chat_history, + arguments=KernelArguments(), + function_choice_behavior=fcb, + ) + + mock_kernel.invoke_function_call.assert_awaited_once() + call_kwargs = mock_kernel.invoke_function_call.call_args + passed_behavior = call_kwargs.kwargs.get("function_behavior") + assert passed_behavior is fcb + assert not passed_behavior.enable_kernel_functions + + +# endregion diff --git a/python/tests/unit/agents/azure_ai_agent/test_azure_ai_agent.py b/python/tests/unit/agents/azure_ai_agent/test_azure_ai_agent.py index b5dc1178b6a1..dff8f210f063 100644 --- a/python/tests/unit/agents/azure_ai_agent/test_azure_ai_agent.py +++ b/python/tests/unit/agents/azure_ai_agent/test_azure_ai_agent.py @@ -9,6 +9,7 @@ from semantic_kernel.agents.agent import AgentResponseItem from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent, AzureAIAgentThread from semantic_kernel.agents.channels.agent_channel import AgentChannel +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent @@ -398,3 +399,80 @@ def test_create_client_raises_if_no_endpoint(): assert "Azure AI endpoint" in str(e) else: assert False, "Expected AgentInitializationException to be raised" + + +async def test_azure_ai_agent_get_response_passes_function_choice_behavior(ai_project_client, ai_agent_definition): + agent = AzureAIAgent(client=ai_project_client, definition=ai_agent_definition) + thread = AsyncMock(spec=AzureAIAgentThread) + fcb = FunctionChoiceBehavior.Auto() + captured_kwargs = {} + + async def fake_invoke(*args, **kwargs): + captured_kwargs.update(kwargs) + yield True, ChatMessageContent(role=AuthorRole.ASSISTANT, content="content") + + with patch( + "semantic_kernel.agents.azure_ai.agent_thread_actions.AgentThreadActions.invoke", + side_effect=fake_invoke, + ): + await agent.get_response(messages="message", thread=thread, function_choice_behavior=fcb) + + assert captured_kwargs.get("function_choice_behavior") is fcb + + +async def test_azure_ai_agent_invoke_passes_function_choice_behavior(ai_project_client, ai_agent_definition): + agent = AzureAIAgent(client=ai_project_client, definition=ai_agent_definition) + thread = AsyncMock(spec=AzureAIAgentThread) + fcb = FunctionChoiceBehavior.Auto() + captured_kwargs = {} + + async def fake_invoke(*args, **kwargs): + captured_kwargs.update(kwargs) + yield True, ChatMessageContent(role=AuthorRole.ASSISTANT, content="content") + + with patch( + "semantic_kernel.agents.azure_ai.agent_thread_actions.AgentThreadActions.invoke", + side_effect=fake_invoke, + ): + async for _ in agent.invoke(messages="message", thread=thread, function_choice_behavior=fcb): + pass + + assert captured_kwargs.get("function_choice_behavior") is fcb + + +async def test_azure_ai_agent_invoke_stream_passes_function_choice_behavior(ai_project_client, ai_agent_definition): + agent = AzureAIAgent(client=ai_project_client, definition=ai_agent_definition) + thread = AsyncMock(spec=AzureAIAgentThread) + fcb = FunctionChoiceBehavior.Auto() + captured_kwargs = {} + + async def fake_invoke(*args, **kwargs): + captured_kwargs.update(kwargs) + yield ChatMessageContent(role=AuthorRole.ASSISTANT, content="content") + + with patch( + "semantic_kernel.agents.azure_ai.agent_thread_actions.AgentThreadActions.invoke_stream", + side_effect=fake_invoke, + ): + async for _ in agent.invoke_stream(messages="message", thread=thread, function_choice_behavior=fcb): + pass + + assert captured_kwargs.get("function_choice_behavior") is fcb + + +async def test_azure_ai_agent_get_response_no_fcb_passes_none(ai_project_client, ai_agent_definition): + agent = AzureAIAgent(client=ai_project_client, definition=ai_agent_definition) + thread = AsyncMock(spec=AzureAIAgentThread) + captured_kwargs = {} + + async def fake_invoke(*args, **kwargs): + captured_kwargs.update(kwargs) + yield True, ChatMessageContent(role=AuthorRole.ASSISTANT, content="content") + + with patch( + "semantic_kernel.agents.azure_ai.agent_thread_actions.AgentThreadActions.invoke", + side_effect=fake_invoke, + ): + await agent.get_response(messages="message", thread=thread) + + assert captured_kwargs.get("function_choice_behavior") is None diff --git a/python/tests/unit/agents/openai_assistant/test_assistant_thread_actions.py b/python/tests/unit/agents/openai_assistant/test_assistant_thread_actions.py index 1bb688bb42c0..c9b3436deac0 100644 --- a/python/tests/unit/agents/openai_assistant/test_assistant_thread_actions.py +++ b/python/tests/unit/agents/openai_assistant/test_assistant_thread_actions.py @@ -55,6 +55,7 @@ from semantic_kernel.agents.open_ai.function_action_result import FunctionActionResult from semantic_kernel.agents.open_ai.openai_assistant_agent import OpenAIAssistantAgent from semantic_kernel.agents.open_ai.run_polling_options import RunPollingOptions +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.file_reference_content import FileReferenceContent from semantic_kernel.contents.function_call_content import FunctionCallContent @@ -853,3 +854,185 @@ async def test_handle_streaming_requires_action_returns_none(): dummy_args, ) assert result is None + + +# region Security tests for tools override and function_choice_behavior + + +async def test_validate_function_choice_behavior_rejects_required(): + """Required FCB is not supported for agent invocations.""" + with pytest.raises(AgentInvokeException, match="not supported"): + AssistantThreadActions._validate_function_choice_behavior( + FunctionChoiceBehavior.Required() + ) + + +async def test_validate_function_choice_behavior_accepts_auto(): + """Auto FCB should be accepted without error.""" + AssistantThreadActions._validate_function_choice_behavior( + FunctionChoiceBehavior.Auto() + ) + + +async def test_validate_function_choice_behavior_rejects_none_invoke(): + """NoneInvoke FCB is not supported for agent invocations.""" + with pytest.raises(AgentInvokeException, match="not supported"): + AssistantThreadActions._validate_function_choice_behavior( + FunctionChoiceBehavior.NoneInvoke() + ) + + +async def test_validate_function_choice_behavior_accepts_none(): + """None (no FCB) should be accepted.""" + AssistantThreadActions._validate_function_choice_behavior(None) + + +async def test_validate_function_choice_behavior_rejects_auto_invoke_false(): + """Auto with auto_invoke=False is not supported for agent invocations.""" + with pytest.raises(AgentInvokeException, match="auto_invoke"): + AssistantThreadActions._validate_function_choice_behavior( + FunctionChoiceBehavior.Auto(auto_invoke=False) + ) + + +async def test_get_tools_with_tools_override(): + """When tools_override is provided, it should replace agent.definition.tools.""" + agent = MagicMock(spec=OpenAIAssistantAgent) + agent.definition = MagicMock() + agent.definition.tools = [CodeInterpreterTool(type="code_interpreter")] + agent.kernel = MagicMock(spec=Kernel) + + kernel = MagicMock(spec=Kernel) + kernel.get_full_list_of_function_metadata.return_value = [] + + # Override with file_search only + override_tools = [FileSearchTool(type="file_search")] + tools = AssistantThreadActions._get_tools( + agent=agent, kernel=kernel, tools_override=override_tools + ) + # Should contain file_search from override, not code_interpreter from agent + tool_types = [t.get("type") if isinstance(t, dict) else None for t in tools] + assert "file_search" in tool_types + # Agent's code_interpreter should NOT be in the result + assert "code_interpreter" not in tool_types + + +async def test_get_tools_with_fcb_filters(): + """When function_choice_behavior has filters, only matching functions should be included.""" + agent = MagicMock(spec=OpenAIAssistantAgent) + agent.definition = MagicMock() + agent.definition.tools = [] + agent.kernel = MagicMock(spec=Kernel) + + kernel = MagicMock(spec=Kernel) + + mock_metadata = MagicMock() + mock_metadata.fully_qualified_name = "Plugin-AllowedFunc" + mock_metadata.name = "AllowedFunc" + mock_metadata.plugin_name = "Plugin" + mock_metadata.description = "An allowed function" + mock_metadata.parameters = [] + mock_metadata.is_prompt = False + mock_metadata.return_parameter = MagicMock() + mock_metadata.return_parameter.description = "" + mock_metadata.return_parameter.type_ = "str" + mock_metadata.additional_properties = {} + + kernel.get_list_of_function_metadata.return_value = [mock_metadata] + + fcb = FunctionChoiceBehavior.Auto( + filters={"included_functions": ["Plugin-AllowedFunc"]} + ) + AssistantThreadActions._get_tools( + agent=agent, kernel=kernel, function_choice_behavior=fcb + ) + kernel.get_list_of_function_metadata.assert_called_once_with(fcb.filters) + + +async def test_get_tools_with_fcb_disable_kernel_functions(): + """When enable_kernel_functions=False, no kernel functions should be included.""" + agent = MagicMock(spec=OpenAIAssistantAgent) + agent.definition = MagicMock() + agent.definition.tools = [] + agent.kernel = MagicMock(spec=Kernel) + + kernel = MagicMock(spec=Kernel) + + fcb = FunctionChoiceBehavior.Auto(enable_kernel_functions=False) + AssistantThreadActions._get_tools( + agent=agent, kernel=kernel, function_choice_behavior=fcb + ) + kernel.get_full_list_of_function_metadata.assert_not_called() + kernel.get_list_of_function_metadata.assert_not_called() + + +async def test_invoke_function_calls_passes_function_behavior(): + """_invoke_function_calls should pass function_behavior to kernel.invoke_function_call.""" + mock_kernel = AsyncMock(spec=Kernel) + mock_kernel.invoke_function_call.return_value = None + + fcc = FunctionCallContent(name="Plugin-Func", arguments={}, id="call1") + from semantic_kernel.contents.chat_history import ChatHistory + + chat_history = ChatHistory() + fcb = FunctionChoiceBehavior.Auto( + filters={"included_functions": ["Plugin-Func"]} + ) + + await AssistantThreadActions._invoke_function_calls( + kernel=mock_kernel, + fccs=[fcc], + chat_history=chat_history, + arguments=KernelArguments(), + function_choice_behavior=fcb, + ) + + mock_kernel.invoke_function_call.assert_awaited_once() + call_kwargs = mock_kernel.invoke_function_call.call_args + assert call_kwargs.kwargs.get("function_behavior") is fcb + + +async def test_invoke_function_calls_passes_disabled_kernel_functions(): + """_invoke_function_calls should pass enable_kernel_functions=False FCB to kernel.""" + mock_kernel = AsyncMock(spec=Kernel) + mock_kernel.invoke_function_call.return_value = None + + fcc = FunctionCallContent(name="Plugin-Func", arguments={}, id="call1") + from semantic_kernel.contents.chat_history import ChatHistory + + chat_history = ChatHistory() + fcb = FunctionChoiceBehavior.Auto(enable_kernel_functions=False) + + await AssistantThreadActions._invoke_function_calls( + kernel=mock_kernel, + fccs=[fcc], + chat_history=chat_history, + arguments=KernelArguments(), + function_choice_behavior=fcb, + ) + + mock_kernel.invoke_function_call.assert_awaited_once() + call_kwargs = mock_kernel.invoke_function_call.call_args + passed_behavior = call_kwargs.kwargs.get("function_behavior") + assert passed_behavior is fcb + assert not passed_behavior.enable_kernel_functions + + +async def test_get_tools_uses_passed_kernel_not_agent_kernel(): + """_get_tools should use the passed kernel parameter, not agent.kernel.""" + agent = MagicMock(spec=OpenAIAssistantAgent) + agent.definition = MagicMock() + agent.definition.tools = [] + agent.kernel = MagicMock(spec=Kernel) + agent.kernel.get_full_list_of_function_metadata.return_value = ["should_not_be_used"] + + kernel = MagicMock(spec=Kernel) + kernel.get_full_list_of_function_metadata.return_value = [] + + AssistantThreadActions._get_tools(agent=agent, kernel=kernel) + # Should call the passed kernel, not agent.kernel + kernel.get_full_list_of_function_metadata.assert_called_once() + agent.kernel.get_full_list_of_function_metadata.assert_not_called() + + +# endregion diff --git a/python/tests/unit/agents/openai_assistant/test_openai_assistant_agent.py b/python/tests/unit/agents/openai_assistant/test_openai_assistant_agent.py index 6423ebf39b74..88262c22b925 100644 --- a/python/tests/unit/agents/openai_assistant/test_openai_assistant_agent.py +++ b/python/tests/unit/agents/openai_assistant/test_openai_assistant_agent.py @@ -10,6 +10,7 @@ from semantic_kernel.agents import AgentRegistry, AgentResponseItem, OpenAIAssistantAgent from semantic_kernel.agents.open_ai.openai_assistant_agent import AssistantAgentThread from semantic_kernel.agents.open_ai.run_polling_options import RunPollingOptions +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent @@ -497,3 +498,84 @@ async def test_openai_assistant_agent_from_yaml_invalid_type(): """ with pytest.raises(AgentInitializationException, match="not registered"): await AgentRegistry.create_from_yaml(spec) + + +async def test_openai_assistant_agent_get_response_passes_function_choice_behavior( + openai_client, assistant_definition +): + agent = OpenAIAssistantAgent(client=openai_client, definition=assistant_definition) + thread = AsyncMock(spec=AssistantAgentThread) + fcb = FunctionChoiceBehavior.Auto() + captured_kwargs = {} + + async def fake_invoke(*args, **kwargs): + captured_kwargs.update(kwargs) + yield True, ChatMessageContent(role=AuthorRole.ASSISTANT, content="content") + + with patch( + "semantic_kernel.agents.open_ai.assistant_thread_actions.AssistantThreadActions.invoke", + side_effect=fake_invoke, + ): + await agent.get_response(messages="message", thread=thread, function_choice_behavior=fcb) + + assert captured_kwargs.get("function_choice_behavior") is fcb + + +async def test_openai_assistant_agent_invoke_passes_function_choice_behavior(openai_client, assistant_definition): + agent = OpenAIAssistantAgent(client=openai_client, definition=assistant_definition) + thread = AsyncMock(spec=AssistantAgentThread) + fcb = FunctionChoiceBehavior.Auto() + captured_kwargs = {} + + async def fake_invoke(*args, **kwargs): + captured_kwargs.update(kwargs) + yield True, ChatMessageContent(role=AuthorRole.ASSISTANT, content="content") + + with patch( + "semantic_kernel.agents.open_ai.assistant_thread_actions.AssistantThreadActions.invoke", + side_effect=fake_invoke, + ): + async for _ in agent.invoke(messages="message", thread=thread, function_choice_behavior=fcb): + pass + + assert captured_kwargs.get("function_choice_behavior") is fcb + + +async def test_openai_assistant_agent_invoke_stream_passes_function_choice_behavior( + openai_client, assistant_definition +): + agent = OpenAIAssistantAgent(client=openai_client, definition=assistant_definition) + thread = AsyncMock(spec=AssistantAgentThread) + fcb = FunctionChoiceBehavior.Auto() + captured_kwargs = {} + + async def fake_invoke(*args, **kwargs): + captured_kwargs.update(kwargs) + yield ChatMessageContent(role=AuthorRole.ASSISTANT, content="content") + + with patch( + "semantic_kernel.agents.open_ai.assistant_thread_actions.AssistantThreadActions.invoke_stream", + side_effect=fake_invoke, + ): + async for _ in agent.invoke_stream(messages="message", thread=thread, function_choice_behavior=fcb): + pass + + assert captured_kwargs.get("function_choice_behavior") is fcb + + +async def test_openai_assistant_agent_get_response_no_fcb_passes_none(openai_client, assistant_definition): + agent = OpenAIAssistantAgent(client=openai_client, definition=assistant_definition) + thread = AsyncMock(spec=AssistantAgentThread) + captured_kwargs = {} + + async def fake_invoke(*args, **kwargs): + captured_kwargs.update(kwargs) + yield True, ChatMessageContent(role=AuthorRole.ASSISTANT, content="content") + + with patch( + "semantic_kernel.agents.open_ai.assistant_thread_actions.AssistantThreadActions.invoke", + side_effect=fake_invoke, + ): + await agent.get_response(messages="message", thread=thread) + + assert captured_kwargs.get("function_choice_behavior") is None