Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 81 additions & 10 deletions python/semantic_kernel/agents/azure_ai/agent_thread_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
from semantic_kernel.agents.open_ai.assistant_content_generation import merge_streaming_function_results
from semantic_kernel.agents.open_ai.function_action_result import FunctionActionResult
from semantic_kernel.agents.open_ai.run_polling_options import RunPollingOptions
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.function_choice_type import FunctionChoiceType
from semantic_kernel.connectors.ai.function_calling_utils import kernel_function_metadata_to_function_call_format
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
Expand Down Expand Up @@ -124,6 +126,7 @@ async def invoke(
parallel_tool_calls: bool | None = None,
metadata: dict[str, str] | None = None,
polling_options: RunPollingOptions | None = None,
function_choice_behavior: FunctionChoiceBehavior | None = None,
**kwargs: Any,
) -> AsyncIterable[tuple[bool, "ChatMessageContent"]]:
"""Invoke the message in the thread.
Expand All @@ -139,7 +142,9 @@ async def invoke(
additional_messages: The additional messages to add to the thread. Only supports messages with
role = User or Assistant.
https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-additional_messages
tools: The tools.
tools: The SDK-level tools (e.g. CodeInterpreter, FileSearch, AzureAISearch). When provided,
overrides the tools from the agent definition. Does not affect kernel function availability;
use function_choice_behavior for that.
temperature: The temperature.
top_p: The top p.
max_prompt_tokens: The max prompt tokens.
Expand All @@ -150,6 +155,9 @@ async def invoke(
metadata: The metadata.
polling_options: The polling options defined at the run-level. These will override the agent-level
polling options.
function_choice_behavior: Controls which kernel functions are allowed to execute during this run.
Use FunctionChoiceBehavior.Auto(filters={"included_functions": [...]}) to restrict to specific
functions. Only Auto with auto-invoke enabled is supported; any other type, or
Auto(auto_invoke=False), raises AgentInvokeException.
kwargs: Additional keyword arguments.

Returns:
Expand All @@ -158,7 +166,10 @@ async def invoke(
arguments = KernelArguments() if arguments is None else KernelArguments(**arguments, **kwargs)
kernel = kernel or agent.kernel

tools = cls._get_tools(agent=agent, kernel=kernel) # type: ignore
cls._validate_function_choice_behavior(function_choice_behavior)

tools = cls._get_tools(agent=agent, kernel=kernel, tools_override=tools,
function_choice_behavior=function_choice_behavior) # type: ignore

base_instructions = await agent.format_instructions(kernel=kernel, arguments=arguments)

Expand Down Expand Up @@ -232,7 +243,11 @@ async def invoke(

chat_history = ChatHistory() if kwargs.get("chat_history") is None else kwargs["chat_history"]
_ = await cls._invoke_function_calls(
kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments
kernel=kernel,
fccs=fccs,
chat_history=chat_history,
arguments=arguments,
function_choice_behavior=function_choice_behavior,
)

tool_outputs = cls._format_tool_outputs(fccs, chat_history)
Expand Down Expand Up @@ -467,6 +482,7 @@ async def invoke_stream(
temperature: float | None = None,
top_p: float | None = None,
truncation_strategy: TruncationObject | None = None,
function_choice_behavior: FunctionChoiceBehavior | None = None,
**kwargs: Any,
) -> AsyncIterable["StreamingChatMessageContent"]:
"""Invoke the agent stream and yield ChatMessageContent continuously.
Expand All @@ -489,10 +505,15 @@ async def invoke_stream(
formed from the streamed chunks.
parallel_tool_calls: Whether to configure parallel tool calls.
response_format: The response format.
tools: The tools.
tools: The SDK-level tools (e.g. CodeInterpreter, FileSearch, AzureAISearch). When provided,
overrides the tools from the agent definition. Does not affect kernel function availability;
use function_choice_behavior for that.
temperature: The temperature.
top_p: The top p.
truncation_strategy: The truncation strategy.
function_choice_behavior: Controls which kernel functions are allowed to execute during this run.
Use FunctionChoiceBehavior.Auto(filters={"included_functions": [...]}) to restrict to specific
functions. Only Auto with auto-invoke enabled is supported; any other type, or
Auto(auto_invoke=False), raises AgentInvokeException.
kwargs: Additional keyword arguments.

Returns:
Expand All @@ -502,7 +523,10 @@ async def invoke_stream(
kernel = kernel or agent.kernel
arguments = agent._merge_arguments(arguments)

tools = cls._get_tools(agent=agent, kernel=kernel) # type: ignore
cls._validate_function_choice_behavior(function_choice_behavior)

tools = cls._get_tools(agent=agent, kernel=kernel, tools_override=tools,
function_choice_behavior=function_choice_behavior) # type: ignore

base_instructions = await agent.format_instructions(kernel=kernel, arguments=arguments)

Expand Down Expand Up @@ -549,6 +573,7 @@ async def invoke_stream(
arguments=arguments,
function_steps=function_steps,
active_messages=active_messages,
function_choice_behavior=function_choice_behavior,
):
if content:
yield content
Expand All @@ -564,6 +589,7 @@ async def _process_stream_events(
function_steps: dict[str, FunctionCallContent],
active_messages: dict[str, RunStep],
output_messages: "list[ChatMessageContent] | None" = None,
function_choice_behavior: FunctionChoiceBehavior | None = None,
) -> AsyncIterable["StreamingChatMessageContent"]:
"""Process events from the main stream and delegate tool output handling as needed."""
thread_msg_id = None
Expand Down Expand Up @@ -671,6 +697,7 @@ async def _process_stream_events(
run=run,
function_steps=function_steps,
arguments=arguments,
function_choice_behavior=function_choice_behavior,
)
if action_result is None:
raise RuntimeError(
Expand Down Expand Up @@ -959,11 +986,51 @@ def _deduplicate_tools(existing_tools: list[dict], new_tools: list[dict]) -> lis
}
return [tool for tool in new_tools if tool.get("function", {}).get("name") not in existing_names]

@staticmethod
def _validate_function_choice_behavior(
    function_choice_behavior: FunctionChoiceBehavior | None,
) -> None:
    """Reject function choice behaviors that an agent invocation cannot honor.

    Passing ``None`` is accepted and performs no checks.

    Args:
        function_choice_behavior: The behavior supplied by the caller, or None.

    Raises:
        AgentInvokeException: If the behavior's ``type_`` is not
            ``FunctionChoiceType.AUTO``, or if its
            ``auto_invoke_kernel_functions`` flag is falsy.
    """
    if function_choice_behavior is None:
        return

    # The type check runs first so a non-Auto behavior always reports the type error.
    requested_type = function_choice_behavior.type_
    if requested_type != FunctionChoiceType.AUTO:
        unsupported_type_message = (
            f"FunctionChoiceBehavior with type '{requested_type}' is not supported for agent "
            "invocations. Use FunctionChoiceBehavior.Auto(filters=...) to control which kernel functions "
            "are available."
        )
        raise AgentInvokeException(unsupported_type_message)

    if function_choice_behavior.auto_invoke_kernel_functions:
        return

    raise AgentInvokeException(
        "FunctionChoiceBehavior.Auto(auto_invoke=False) is not supported for agent invocations. "
        "The agent run loop manages tool invocation; disabling auto_invoke is not compatible."
    )

@classmethod
def _get_tools(cls: type[_T], agent: "AzureAIAgent", kernel: "Kernel") -> list[dict[str, Any] | ToolDefinition]:
"""Get the tools for the agent."""
tools: list[Any] = list(agent.definition.tools)
funcs = kernel.get_full_list_of_function_metadata()
def _get_tools(
cls: type[_T],
agent: "AzureAIAgent",
kernel: "Kernel",
tools_override: list[ToolDefinition] | None = None,
function_choice_behavior: FunctionChoiceBehavior | None = None,
) -> list[dict[str, Any] | ToolDefinition]:
"""Get the tools for the agent.

Args:
agent: The agent instance.
kernel: The kernel to use for function metadata.
tools_override: When provided, overrides agent.definition.tools (SDK-level tools only).
function_choice_behavior: When provided, filters which kernel functions are included.
"""
tools: list[Any] = list(tools_override) if tools_override is not None else list(agent.definition.tools)

# Determine kernel function metadata based on function_choice_behavior
if function_choice_behavior is not None and not function_choice_behavior.enable_kernel_functions:
funcs = []
elif function_choice_behavior is not None and function_choice_behavior.filters:
funcs = kernel.get_list_of_function_metadata(function_choice_behavior.filters)
else:
funcs = kernel.get_full_list_of_function_metadata()

cls._validate_function_tools_registered(tools, funcs)
dict_defs = [kernel_function_metadata_to_function_call_format(f) for f in funcs]
deduped_defs = cls._deduplicate_tools(tools, dict_defs)
Expand Down Expand Up @@ -1071,6 +1138,7 @@ async def _invoke_function_calls(
fccs: list["FunctionCallContent"],
chat_history: "ChatHistory",
arguments: KernelArguments,
function_choice_behavior: FunctionChoiceBehavior | None = None,
) -> list["AutoFunctionInvocationContext | None"]:
"""Invoke the function calls."""
return await asyncio.gather(
Expand All @@ -1079,6 +1147,7 @@ async def _invoke_function_calls(
function_call=function_call,
chat_history=chat_history,
arguments=arguments,
function_behavior=function_choice_behavior,
)
for function_call in fccs
],
Expand Down Expand Up @@ -1111,6 +1180,7 @@ async def _handle_streaming_requires_action(
run: ThreadRun,
function_steps: dict[str, "FunctionCallContent"],
arguments: KernelArguments,
function_choice_behavior: FunctionChoiceBehavior | None = None,
**kwargs: Any,
) -> FunctionActionResult | None:
"""Handle the requires action event for a streaming run."""
Expand All @@ -1121,7 +1191,8 @@ async def _handle_streaming_requires_action(

chat_history = ChatHistory() if kwargs.get("chat_history") is None else kwargs["chat_history"]
results = await cls._invoke_function_calls(
kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments
kernel=kernel, fccs=fccs, chat_history=chat_history, arguments=arguments,
function_choice_behavior=function_choice_behavior,
)

function_result_streaming_content = merge_streaming_function_results(
Expand Down
13 changes: 13 additions & 0 deletions python/semantic_kernel/agents/azure_ai/azure_ai_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from semantic_kernel.agents.azure_ai.azure_ai_channel import AzureAIChannel
from semantic_kernel.agents.channels.agent_channel import AgentChannel
from semantic_kernel.agents.open_ai.run_polling_options import RunPollingOptions
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.function_calling_utils import kernel_function_metadata_to_function_call_format
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.utils.author_role import AuthorRole
Expand Down Expand Up @@ -647,6 +648,7 @@ async def get_response(
parallel_tool_calls: bool | None = None,
metadata: dict[str, str] | None = None,
polling_options: RunPollingOptions | None = None,
function_choice_behavior: FunctionChoiceBehavior | None = None,
**kwargs: Any,
) -> AgentResponseItem[ChatMessageContent]:
"""Get a response from the agent on a thread.
Expand All @@ -671,6 +673,8 @@ async def get_response(
parallel_tool_calls: Whether to allow parallel tool calls.
metadata: Metadata for the agent.
polling_options: The polling options for the agent.
function_choice_behavior: The function choice behavior to control which kernel
functions are available. Only Auto is supported; other types will raise an error.
**kwargs: Additional keyword arguments.

Returns:
Expand Down Expand Up @@ -716,6 +720,7 @@ async def get_response(
thread_id=thread.id,
kernel=kernel,
arguments=arguments,
function_choice_behavior=function_choice_behavior,
**run_level_params, # type: ignore
):
if is_visible and response.metadata.get("code") is not True:
Expand Down Expand Up @@ -752,6 +757,7 @@ async def invoke(
parallel_tool_calls: bool | None = None,
metadata: dict[str, str] | None = None,
polling_options: RunPollingOptions | None = None,
function_choice_behavior: FunctionChoiceBehavior | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentResponseItem[ChatMessageContent]]:
"""Invoke the agent on the specified thread.
Expand All @@ -777,6 +783,8 @@ async def invoke(
parallel_tool_calls: Whether to allow parallel tool calls.
polling_options: The polling options for the agent.
metadata: Metadata for the agent.
function_choice_behavior: The function choice behavior to control which kernel
functions are available. Only Auto is supported; other types will raise an error.
**kwargs: Additional keyword arguments.

Yields:
Expand Down Expand Up @@ -821,6 +829,7 @@ async def invoke(
thread_id=thread.id,
kernel=kernel,
arguments=arguments,
function_choice_behavior=function_choice_behavior,
**run_level_params, # type: ignore
):
message.metadata["thread_id"] = thread.id
Expand Down Expand Up @@ -856,6 +865,7 @@ async def invoke_stream(
response_format: AgentsApiResponseFormatOption | None = None,
parallel_tool_calls: bool | None = None,
metadata: dict[str, str] | None = None,
function_choice_behavior: FunctionChoiceBehavior | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentResponseItem["StreamingChatMessageContent"]]:
"""Invoke the agent on the specified thread with a stream of messages.
Expand All @@ -881,6 +891,8 @@ async def invoke_stream(
response_format: Response format for the agent.
parallel_tool_calls: Whether to allow parallel tool calls.
metadata: Metadata for the agent.
function_choice_behavior: The function choice behavior to control which kernel
functions are available. Only Auto is supported; other types will raise an error.
**kwargs: Additional keyword arguments.

Yields:
Expand Down Expand Up @@ -928,6 +940,7 @@ async def invoke_stream(
output_messages=collected_messages,
kernel=kernel,
arguments=arguments,
function_choice_behavior=function_choice_behavior,
**run_level_params, # type: ignore
):
# Before yielding the current streamed message, emit any new full messages first
Expand Down
Loading
Loading