From 3a4a8c6d095c90eaea358d2cce3583565fca4d2a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:47:00 +0000 Subject: [PATCH 01/14] Initial plan From 8e83006ccd02623b632ed8f0e91a5a8f0332aeb7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:57:04 +0000 Subject: [PATCH 02/14] Add document chat agent using Microsoft Agent Framework with Azure AI Search Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- .env.sample | 10 ++ config/config.yaml | 7 ++ src/agent/__init__.py | 3 + src/agent/agent.py | 171 ++++++++++++++++++++++++++++++++++ src/agent/requirements.txt | 5 + tests/agent/__init__.py | 1 + tests/agent/test_agent.py | 182 +++++++++++++++++++++++++++++++++++++ 7 files changed, 379 insertions(+) create mode 100644 src/agent/__init__.py create mode 100644 src/agent/agent.py create mode 100644 src/agent/requirements.txt create mode 100644 tests/agent/__init__.py create mode 100644 tests/agent/test_agent.py diff --git a/.env.sample b/.env.sample index 2a369f9..77684f1 100644 --- a/.env.sample +++ b/.env.sample @@ -8,3 +8,13 @@ AI_FOUNDRY_PROJECT_URI="https://${AI_FOUNDRY_NAME}.services.ai.azure.com/api/pro MANAGED_IDENTITY_CLIENT_ID= MANAGED_IDENTITY_NAME= FUNCTION_APP_NAME= + +# Agent environment variables +# Azure AI Foundry project endpoint (used by the document chat agent) +AZURE_AI_AGENT_ENDPOINT="https://${AI_FOUNDRY_NAME}.services.ai.azure.com/api/projects/${PROJECT_NAME}" +# Model deployment name for the agent (e.g. gpt-4o) +AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME= +# Azure AI Foundry connection ID for the Azure AI Search service +AZURE_AI_SEARCH_CONNECTION_ID= +# Name of the Azure AI Search index the agent will query +AZURE_AI_SEARCH_INDEX_NAME= diff --git a/config/config.yaml b/config/config.yaml index dc1e47c..46fcd64 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -36,3 +36,10 @@ data_pr: data_dev: local_folder: data storage_container: fulldataset + +# Agent configuration +agent_config: + agent_endpoint: ${AZURE_AI_AGENT_ENDPOINT} + agent_model_deployment: ${AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME} + ai_search_connection_id: ${AZURE_AI_SEARCH_CONNECTION_ID} + ai_search_index_name: ${AZURE_AI_SEARCH_INDEX_NAME} diff --git a/src/agent/__init__.py b/src/agent/__init__.py new file mode 100644 index 0000000..a4bad3b --- /dev/null +++ b/src/agent/__init__.py @@ -0,0 +1,3 @@ +"""Agent module for chatting with indexed documents using Azure AI Search.""" + +from src.agent import agent # noqa: F401 diff --git a/src/agent/agent.py b/src/agent/agent.py new file mode 100644 index 0000000..22fec64 --- /dev/null +++ b/src/agent/agent.py @@ -0,0 +1,171 @@ +"""Agent for chatting with documents indexed in Azure AI Search.""" + +import asyncio +import argparse + +from azure.identity.aio import DefaultAzureCredential +from azure.ai.agents.models import AzureAISearchTool, AzureAISearchQueryType +from semantic_kernel.agents import AzureAIAgent, AzureAIAgentSettings +from semantic_kernel.agents import AzureAIAgentThread + + +AGENT_NAME = "DocumentChatAgent" +AGENT_INSTRUCTIONS = ( + "You are a helpful assistant that answers questions about documents " + "stored in an Azure AI Search index. Use the search tool to find relevant " + "information and provide accurate, concise answers based on the indexed content. " + "If the answer is not found in the indexed documents, say so clearly." +) + + +async def create_agent( + ai_search_connection_id: str, + ai_search_index_name: str, + model_deployment_name: str, + endpoint: str, +) -> AzureAIAgent: + """ + Create an AzureAIAgent configured with an Azure AI Search tool. + + Args: + ai_search_connection_id (str): The AI Foundry connection ID for Azure AI Search. + ai_search_index_name (str): The name of the Azure AI Search index to query. + model_deployment_name (str): The model deployment name to use for the agent. + endpoint (str): The Azure AI Foundry project endpoint. + + Returns: + AzureAIAgent: Configured agent instance. + """ + ai_search_tool = AzureAISearchTool( + index_connection_id=ai_search_connection_id, + index_name=ai_search_index_name, + query_type=AzureAISearchQueryType.VECTOR_SEMANTIC_HYBRID, + top_k=5, + ) + + credential = DefaultAzureCredential() + client = AzureAIAgent.create_client(credential=credential, endpoint=endpoint) + agent_definition = await client.agents.create_agent( + model=model_deployment_name, + name=AGENT_NAME, + instructions=AGENT_INSTRUCTIONS, + tools=ai_search_tool.definitions, + tool_resources=ai_search_tool.resources, + ) + + return AzureAIAgent(client=client, definition=agent_definition) + + +async def run_agent_conversation(agent: AzureAIAgent, user_message: str) -> str: + """ + Send a single message to the agent and return the response. + + Args: + agent (AzureAIAgent): The configured agent instance. + user_message (str): The user's query message. + + Returns: + str: The agent's response text. + """ + thread: AzureAIAgentThread | None = None + try: + thread = AzureAIAgentThread(client=agent.client) + response_parts = [] + async for response in agent.invoke( + messages=user_message, + thread=thread, + ): + response_parts.append(str(response.content)) + return "".join(response_parts) + finally: + if thread is not None: + await thread.delete() + + +async def run_local_chat( + ai_search_connection_id: str, + ai_search_index_name: str, + model_deployment_name: str, + endpoint: str, +) -> None: + """ + Run an interactive local chat session with the document agent. + + Args: + ai_search_connection_id (str): The AI Foundry connection ID for Azure AI Search. + ai_search_index_name (str): The name of the Azure AI Search index to query. + model_deployment_name (str): The model deployment name to use for the agent. + endpoint (str): The Azure AI Foundry project endpoint. + """ + print("Initializing Document Chat Agent...") + agent = await create_agent( + ai_search_connection_id=ai_search_connection_id, + ai_search_index_name=ai_search_index_name, + model_deployment_name=model_deployment_name, + endpoint=endpoint, + ) + print(f"Agent '{AGENT_NAME}' is ready. Type 'exit' or 'quit' to stop.\n") + + try: + while True: + user_input = input("You: ").strip() + if not user_input: + continue + if user_input.lower() in {"exit", "quit"}: + print("Goodbye!") + break + response = await run_agent_conversation(agent, user_input) + print(f"Agent: {response}\n") + finally: + await agent.client.agents.delete_agent(agent.id) + + +def main(): + """Run the document chat agent locally using command-line arguments or environment variables.""" + parser = argparse.ArgumentParser( + description="Run an interactive chat session with indexed documents." + ) + parser.add_argument( + "--connection-id", + required=True, + help="Azure AI Foundry connection ID for the Azure AI Search service.", + ) + parser.add_argument( + "--index-name", + required=True, + help="Name of the Azure AI Search index to query.", + ) + parser.add_argument( + "--model", + default=None, + help=( + "Model deployment name. Defaults to AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME " + "environment variable if not specified." + ), + ) + parser.add_argument( + "--endpoint", + default=None, + help=( + "Azure AI Foundry project endpoint. Defaults to AZURE_AI_AGENT_ENDPOINT " + "environment variable if not specified." + ), + ) + args = parser.parse_args() + + settings = AzureAIAgentSettings() + model = args.model or settings.model_deployment_name + endpoint = args.endpoint or settings.endpoint + + asyncio.run( + run_local_chat( + ai_search_connection_id=args.connection_id, + ai_search_index_name=args.index_name, + model_deployment_name=model, + endpoint=endpoint, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/src/agent/requirements.txt b/src/agent/requirements.txt new file mode 100644 index 0000000..abd2cea --- /dev/null +++ b/src/agent/requirements.txt @@ -0,0 +1,5 @@ +semantic-kernel>=1.0.0 +azure-identity>=1.16.1 +azure-ai-projects>=1.0.0 +azure-ai-agents>=1.0.0 +python-dotenv>=0.10.3 diff --git a/tests/agent/__init__.py b/tests/agent/__init__.py new file mode 100644 index 0000000..49cc545 --- /dev/null +++ b/tests/agent/__init__.py @@ -0,0 +1 @@ +"""Agent tests package.""" diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py new file mode 100644 index 0000000..a3dbb76 --- /dev/null +++ b/tests/agent/test_agent.py @@ -0,0 +1,182 @@ +"""Unit tests for the document chat agent.""" + +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + + +class TestCreateAgent(unittest.IsolatedAsyncioTestCase): + """Tests for the create_agent function.""" + + @patch("src.agent.agent.AzureAIAgent") + @patch("src.agent.agent.DefaultAzureCredential") + async def test_create_agent_builds_correct_tool( + self, mock_credential_cls, mock_agent_cls + ): + """Test that create_agent configures the AI Search tool correctly.""" + from src.agent.agent import create_agent, AGENT_NAME, AGENT_INSTRUCTIONS + + mock_credential = AsyncMock() + mock_credential_cls.return_value = mock_credential + + mock_client = MagicMock() + mock_agent_cls.create_client.return_value = mock_client + + mock_agent_definition = MagicMock() + mock_agent_definition.name = AGENT_NAME + mock_agent_definition.id = "agent-123" + mock_agent_definition.description = None + mock_agent_definition.instructions = AGENT_INSTRUCTIONS + mock_client.agents.create_agent = AsyncMock(return_value=mock_agent_definition) + + mock_agent = MagicMock() + mock_agent_cls.return_value = mock_agent + + agent = await create_agent( + ai_search_connection_id="test-connection-id", + ai_search_index_name="test-index", + model_deployment_name="gpt-4o", + endpoint="https://test.services.ai.azure.com/api/projects/test-project", + ) + + mock_client.agents.create_agent.assert_awaited_once() + call_kwargs = mock_client.agents.create_agent.call_args.kwargs + self.assertEqual(call_kwargs["model"], "gpt-4o") + self.assertEqual(call_kwargs["name"], AGENT_NAME) + self.assertEqual(call_kwargs["instructions"], AGENT_INSTRUCTIONS) + self.assertIsNotNone(call_kwargs["tools"]) + self.assertIsNotNone(call_kwargs["tool_resources"]) + + self.assertEqual(agent, mock_agent) + + @patch("src.agent.agent.AzureAIAgent") + @patch("src.agent.agent.DefaultAzureCredential") + async def test_create_agent_uses_provided_endpoint( + self, mock_credential_cls, mock_agent_cls + ): + """Test that create_agent passes the provided endpoint to create_client.""" + from src.agent.agent import create_agent, AGENT_INSTRUCTIONS, AGENT_NAME + + expected_endpoint = "https://my-project.services.ai.azure.com/api/projects/proj" + mock_credential = AsyncMock() + mock_credential_cls.return_value = mock_credential + + mock_client = MagicMock() + mock_agent_cls.create_client.return_value = mock_client + + mock_agent_definition = MagicMock() + mock_agent_definition.name = AGENT_NAME + mock_agent_definition.id = "agent-456" + mock_agent_definition.description = None + mock_agent_definition.instructions = AGENT_INSTRUCTIONS + mock_client.agents.create_agent = AsyncMock(return_value=mock_agent_definition) + + mock_agent_cls.return_value = MagicMock() + + await create_agent( + ai_search_connection_id="conn-id", + ai_search_index_name="idx", + model_deployment_name="gpt-4o", + endpoint=expected_endpoint, + ) + + mock_agent_cls.create_client.assert_called_once_with( + credential=mock_credential, endpoint=expected_endpoint + ) + + +class TestRunAgentConversation(unittest.IsolatedAsyncioTestCase): + """Tests for the run_agent_conversation function.""" + + @patch("src.agent.agent.AzureAIAgentThread") + async def test_run_agent_conversation_returns_response(self, mock_thread_cls): + """Test that run_agent_conversation returns the agent response text.""" + from src.agent.agent import run_agent_conversation + + mock_thread = MagicMock() + mock_thread.delete = AsyncMock() + mock_thread_cls.return_value = mock_thread + + mock_response = MagicMock() + mock_response.content = "This is the answer from indexed documents." + + async def fake_invoke(**kwargs): + yield mock_response + + mock_agent = MagicMock() + mock_agent.invoke = fake_invoke + mock_agent.client = MagicMock() + + result = await run_agent_conversation(mock_agent, "What is this document about?") + + self.assertEqual(result, "This is the answer from indexed documents.") + mock_thread.delete.assert_awaited_once() + + @patch("src.agent.agent.AzureAIAgentThread") + async def test_run_agent_conversation_deletes_thread_on_error(self, mock_thread_cls): + """Test that run_agent_conversation cleans up thread even when an error occurs.""" + from src.agent.agent import run_agent_conversation + + mock_thread = MagicMock() + mock_thread.delete = AsyncMock() + mock_thread_cls.return_value = mock_thread + + async def failing_invoke(**kwargs): + if False: + yield + raise RuntimeError("Search failed") + + mock_agent = MagicMock() + mock_agent.invoke = failing_invoke + mock_agent.client = MagicMock() + + with self.assertRaises(RuntimeError): + await run_agent_conversation(mock_agent, "What is this document about?") + + mock_thread.delete.assert_awaited_once() + + @patch("src.agent.agent.AzureAIAgentThread") + async def test_run_agent_conversation_concatenates_multiple_responses( + self, mock_thread_cls + ): + """Test that multiple response chunks are concatenated.""" + from src.agent.agent import run_agent_conversation + + mock_thread = MagicMock() + mock_thread.delete = AsyncMock() + mock_thread_cls.return_value = mock_thread + + mock_resp1 = MagicMock() + mock_resp1.content = "Part one. " + mock_resp2 = MagicMock() + mock_resp2.content = "Part two." + + async def multi_invoke(**kwargs): + yield mock_resp1 + yield mock_resp2 + + mock_agent = MagicMock() + mock_agent.invoke = multi_invoke + mock_agent.client = MagicMock() + + result = await run_agent_conversation(mock_agent, "Tell me more.") + + self.assertEqual(result, "Part one. Part two.") + mock_thread.delete.assert_awaited_once() + + +class TestAgentConstants(unittest.TestCase): + """Tests for agent module constants and defaults.""" + + def test_agent_name_is_defined(self): + """Test that AGENT_NAME constant is defined.""" + from src.agent.agent import AGENT_NAME + + self.assertIsInstance(AGENT_NAME, str) + self.assertTrue(len(AGENT_NAME) > 0) + + def test_agent_instructions_mention_search(self): + """Test that AGENT_INSTRUCTIONS reference AI Search usage.""" + from src.agent.agent import AGENT_INSTRUCTIONS + + self.assertIsInstance(AGENT_INSTRUCTIONS, str) + self.assertIn("search", AGENT_INSTRUCTIONS.lower()) From 85570b9776061a9ac6350742bcdc98af5345ce6b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 18:21:55 +0000 Subject: [PATCH 03/14] Refactor agent: auto-resolve connection ID and index name from config Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- .env.sample | 8 +--- config/config.yaml | 4 +- src/agent/agent.py | 90 +++++++++++++++++++++++++++------------ tests/agent/test_agent.py | 83 ++++++++++++++++++++++++++++++++++++ 4 files changed, 148 insertions(+), 37 deletions(-) diff --git a/.env.sample b/.env.sample index 77684f1..abfa254 100644 --- a/.env.sample +++ b/.env.sample @@ -10,11 +10,5 @@ MANAGED_IDENTITY_NAME= FUNCTION_APP_NAME= # Agent environment variables -# Azure AI Foundry project endpoint (used by the document chat agent) -AZURE_AI_AGENT_ENDPOINT="https://${AI_FOUNDRY_NAME}.services.ai.azure.com/api/projects/${PROJECT_NAME}" -# Model deployment name for the agent (e.g. gpt-4o) +# Model deployment name for the document chat agent (e.g. gpt-4o) AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME= -# Azure AI Foundry connection ID for the Azure AI Search service -AZURE_AI_SEARCH_CONNECTION_ID= -# Name of the Azure AI Search index the agent will query -AZURE_AI_SEARCH_INDEX_NAME= diff --git a/config/config.yaml b/config/config.yaml index 46fcd64..7e5ba03 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -39,7 +39,5 @@ data_dev: # Agent configuration agent_config: - agent_endpoint: ${AZURE_AI_AGENT_ENDPOINT} + agent_endpoint: ${AI_FOUNDRY_PROJECT_URI} agent_model_deployment: ${AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME} - ai_search_connection_id: ${AZURE_AI_SEARCH_CONNECTION_ID} - ai_search_index_name: ${AZURE_AI_SEARCH_INDEX_NAME} diff --git a/src/agent/agent.py b/src/agent/agent.py index 22fec64..8547390 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -3,11 +3,17 @@ import asyncio import argparse +from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential from azure.identity.aio import DefaultAzureCredential +from azure.ai.projects import AIProjectClient as SyncAIProjectClient +from azure.ai.projects.models import ConnectionType from azure.ai.agents.models import AzureAISearchTool, AzureAISearchQueryType -from semantic_kernel.agents import AzureAIAgent, AzureAIAgentSettings +from semantic_kernel.agents import AzureAIAgent from semantic_kernel.agents import AzureAIAgentThread +from mlops.common.config_utils import MLOpsConfig +from mlops.common.naming_utils import generate_index_name + AGENT_NAME = "DocumentChatAgent" AGENT_INSTRUCTIONS = ( @@ -18,6 +24,44 @@ ) +def get_ai_search_connection_id(endpoint: str, acs_service_name: str) -> str: + """ + Retrieve the AI Foundry connection ID for the given Azure AI Search service. + + Lists all Azure AI Search connections in the AI Foundry project and returns + the ID of the connection whose target URL contains the specified service name. + Falls back to the default AI Search connection if no name match is found. + + Args: + endpoint (str): The Azure AI Foundry project endpoint. + acs_service_name (str): The Azure AI Search service name (e.g. 'my-search'). + + Returns: + str: The connection ID to use with the AI Search tool. + + Raises: + ValueError: If no Azure AI Search connection is found in the project. + """ + credential = SyncDefaultAzureCredential() + client = SyncAIProjectClient(endpoint=endpoint, credential=credential) + connections = list(client.connections.list(connection_type=ConnectionType.AZURE_AI_SEARCH)) + if not connections: + raise ValueError( + "No Azure AI Search connection found in the AI Foundry project. " + "Please add a connection to your Azure AI Search service in AI Foundry." + ) + # Prefer the connection whose target URL contains the configured service name + matched = next( + (c for c in connections if acs_service_name.lower() in c.target.lower()), + None, + ) + if matched: + return matched.id + # Fall back to the default connection, or the first available one + default_conn = next((c for c in connections if c.is_default), connections[0]) + return default_conn.id + + async def create_agent( ai_search_connection_id: str, ai_search_index_name: str, @@ -121,46 +165,38 @@ async def run_local_chat( def main(): - """Run the document chat agent locally using command-line arguments or environment variables.""" + """Run the document chat agent locally using configuration from config.yaml.""" parser = argparse.ArgumentParser( description="Run an interactive chat session with indexed documents." ) parser.add_argument( - "--connection-id", - required=True, - help="Azure AI Foundry connection ID for the Azure AI Search service.", - ) - parser.add_argument( - "--index-name", - required=True, - help="Name of the Azure AI Search index to query.", + "--stage", + default="pr", + help="Stage to find parameters (pr, dev). Defaults to 'pr'.", ) parser.add_argument( "--model", default=None, - help=( - "Model deployment name. Defaults to AZURE_AI_AGENT_MODEL_DEPLOYMENT_NAME " - "environment variable if not specified." - ), - ) - parser.add_argument( - "--endpoint", - default=None, - help=( - "Azure AI Foundry project endpoint. Defaults to AZURE_AI_AGENT_ENDPOINT " - "environment variable if not specified." - ), + help="Model deployment name. Overrides agent_config.agent_model_deployment in config.yaml.", ) args = parser.parse_args() - settings = AzureAIAgentSettings() - model = args.model or settings.model_deployment_name - endpoint = args.endpoint or settings.endpoint + config = MLOpsConfig(environment=args.stage) + agent_config = config.agent_config + acs_config = config.acs_config + + endpoint = agent_config["agent_endpoint"] + model = args.model or agent_config["agent_model_deployment"] + acs_service_name = acs_config["acs_service_name"] + index_name = generate_index_name() + + print(f"Looking up AI Search connection for service '{acs_service_name}'...") + connection_id = get_ai_search_connection_id(endpoint, acs_service_name) asyncio.run( run_local_chat( - ai_search_connection_id=args.connection_id, - ai_search_index_name=args.index_name, + ai_search_connection_id=connection_id, + ai_search_index_name=index_name, model_deployment_name=model, endpoint=endpoint, ) diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index a3dbb76..03e40cc 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -3,6 +3,89 @@ import unittest from unittest.mock import AsyncMock, MagicMock, patch +from azure.ai.projects.models import ConnectionType + + +class TestGetAISearchConnectionId(unittest.TestCase): + """Tests for the get_ai_search_connection_id function.""" + + @patch("src.agent.agent.SyncAIProjectClient") + @patch("src.agent.agent.SyncDefaultAzureCredential") + def test_returns_matching_connection_by_service_name( + self, mock_cred_cls, mock_client_cls + ): + """Test that the connection whose target contains acs_service_name is returned.""" + from src.agent.agent import get_ai_search_connection_id + + conn1 = MagicMock() + conn1.id = "/connections/other-search" + conn1.target = "https://other-search.search.windows.net" + conn1.is_default = False + + conn2 = MagicMock() + conn2.id = "/connections/my-search" + conn2.target = "https://my-search.search.windows.net" + conn2.is_default = False + + mock_client = MagicMock() + mock_client.connections.list.return_value = [conn1, conn2] + mock_client_cls.return_value = mock_client + + result = get_ai_search_connection_id( + endpoint="https://test.services.ai.azure.com/api/projects/p", + acs_service_name="my-search", + ) + + self.assertEqual(result, "/connections/my-search") + mock_client.connections.list.assert_called_once_with( + connection_type=ConnectionType.AZURE_AI_SEARCH + ) + + @patch("src.agent.agent.SyncAIProjectClient") + @patch("src.agent.agent.SyncDefaultAzureCredential") + def test_falls_back_to_default_when_no_name_match( + self, mock_cred_cls, mock_client_cls + ): + """Test fallback to default connection when service name has no match.""" + from src.agent.agent import get_ai_search_connection_id + + conn1 = MagicMock() + conn1.id = "/connections/first" + conn1.target = "https://first.search.windows.net" + conn1.is_default = False + + conn2 = MagicMock() + conn2.id = "/connections/default" + conn2.target = "https://default.search.windows.net" + conn2.is_default = True + + mock_client = MagicMock() + mock_client.connections.list.return_value = [conn1, conn2] + mock_client_cls.return_value = mock_client + + result = get_ai_search_connection_id( + endpoint="https://test.services.ai.azure.com/api/projects/p", + acs_service_name="no-match-here", + ) + + self.assertEqual(result, "/connections/default") + + @patch("src.agent.agent.SyncAIProjectClient") + @patch("src.agent.agent.SyncDefaultAzureCredential") + def test_raises_when_no_connections_found(self, mock_cred_cls, mock_client_cls): + """Test that ValueError is raised when no AI Search connections exist.""" + from src.agent.agent import get_ai_search_connection_id + + mock_client = MagicMock() + mock_client.connections.list.return_value = [] + mock_client_cls.return_value = mock_client + + with self.assertRaises(ValueError): + get_ai_search_connection_id( + endpoint="https://test.services.ai.azure.com/api/projects/p", + acs_service_name="my-search", + ) + class TestCreateAgent(unittest.IsolatedAsyncioTestCase): """Tests for the create_agent function.""" From 8b0f8ace09b4d05db1c505b4e11d3c5a0e9f7da6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 19:22:21 +0000 Subject: [PATCH 04/14] Remove --stage/--model CLI args; create AI Search connection if missing Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- src/agent/agent.py | 98 +++++++++++++++++++------------ src/agent/requirements.txt | 1 + tests/agent/test_agent.py | 117 ++++++++++++++++++++++++++----------- 3 files changed, 145 insertions(+), 71 deletions(-) diff --git a/src/agent/agent.py b/src/agent/agent.py index 8547390..b98e4d4 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -1,13 +1,14 @@ """Agent for chatting with documents indexed in Azure AI Search.""" import asyncio -import argparse from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential from azure.identity.aio import DefaultAzureCredential from azure.ai.projects import AIProjectClient as SyncAIProjectClient from azure.ai.projects.models import ConnectionType from azure.ai.agents.models import AzureAISearchTool, AzureAISearchQueryType +from azure.ai.ml import MLClient +from azure.ai.ml.entities import AzureAISearchConnection from semantic_kernel.agents import AzureAIAgent from semantic_kernel.agents import AzureAIAgentThread @@ -24,42 +25,74 @@ ) -def get_ai_search_connection_id(endpoint: str, acs_service_name: str) -> str: +def _extract_project_name(endpoint: str) -> str: """ - Retrieve the AI Foundry connection ID for the given Azure AI Search service. + Extract the AI Foundry project name from the project endpoint URL. + + Args: + endpoint (str): URL in the form + ``https://.services.ai.azure.com/api/projects/``. + + Returns: + str: The project name (last path segment of the URL). + """ + return endpoint.rstrip("/").split("/")[-1] - Lists all Azure AI Search connections in the AI Foundry project and returns - the ID of the connection whose target URL contains the specified service name. - Falls back to the default AI Search connection if no name match is found. + +def ensure_ai_search_connection_id( + endpoint: str, + acs_service_name: str, + subscription_id: str, + resource_group_name: str, +) -> str: + """ + Return the AI Foundry connection ID for the given Azure AI Search service, + creating the connection if it does not already exist. + + First checks whether a connection whose target URL contains + ``acs_service_name`` is already registered in the project. If no such + connection exists, one is created via the Azure AI ML management SDK using + AAD/managed-identity authentication (no API key required). Args: endpoint (str): The Azure AI Foundry project endpoint. acs_service_name (str): The Azure AI Search service name (e.g. 'my-search'). + subscription_id (str): Azure subscription ID. + resource_group_name (str): Azure resource group name. Returns: str: The connection ID to use with the AI Search tool. - - Raises: - ValueError: If no Azure AI Search connection is found in the project. """ credential = SyncDefaultAzureCredential() client = SyncAIProjectClient(endpoint=endpoint, credential=credential) connections = list(client.connections.list(connection_type=ConnectionType.AZURE_AI_SEARCH)) - if not connections: - raise ValueError( - "No Azure AI Search connection found in the AI Foundry project. " - "Please add a connection to your Azure AI Search service in AI Foundry." - ) - # Prefer the connection whose target URL contains the configured service name + + # Return the connection whose target URL matches the configured service name matched = next( (c for c in connections if acs_service_name.lower() in c.target.lower()), None, ) if matched: return matched.id - # Fall back to the default connection, or the first available one - default_conn = next((c for c in connections if c.is_default), connections[0]) - return default_conn.id + + # No matching connection found — create one using the management SDK + print( + f"No AI Search connection found for '{acs_service_name}'. " + "Creating connection in AI Foundry..." + ) + project_name = _extract_project_name(endpoint) + ml_client = MLClient( + credential=credential, + subscription_id=subscription_id, + resource_group_name=resource_group_name, + workspace_name=project_name, + ) + new_connection = AzureAISearchConnection( + name=acs_service_name, + endpoint=f"https://{acs_service_name}.search.windows.net", + ) + created = ml_client.connections.create_or_update(new_connection) + return created.id async def create_agent( @@ -166,32 +199,25 @@ async def run_local_chat( def main(): """Run the document chat agent locally using configuration from config.yaml.""" - parser = argparse.ArgumentParser( - description="Run an interactive chat session with indexed documents." - ) - parser.add_argument( - "--stage", - default="pr", - help="Stage to find parameters (pr, dev). Defaults to 'pr'.", - ) - parser.add_argument( - "--model", - default=None, - help="Model deployment name. Overrides agent_config.agent_model_deployment in config.yaml.", - ) - args = parser.parse_args() - - config = MLOpsConfig(environment=args.stage) + config = MLOpsConfig() agent_config = config.agent_config acs_config = config.acs_config + sub_config = config.sub_config endpoint = agent_config["agent_endpoint"] - model = args.model or agent_config["agent_model_deployment"] + model = agent_config["agent_model_deployment"] acs_service_name = acs_config["acs_service_name"] + subscription_id = sub_config["subscription_id"] + resource_group_name = sub_config["resource_group_name"] index_name = generate_index_name() print(f"Looking up AI Search connection for service '{acs_service_name}'...") - connection_id = get_ai_search_connection_id(endpoint, acs_service_name) + connection_id = ensure_ai_search_connection_id( + endpoint=endpoint, + acs_service_name=acs_service_name, + subscription_id=subscription_id, + resource_group_name=resource_group_name, + ) asyncio.run( run_local_chat( diff --git a/src/agent/requirements.txt b/src/agent/requirements.txt index abd2cea..c6c85f3 100644 --- a/src/agent/requirements.txt +++ b/src/agent/requirements.txt @@ -2,4 +2,5 @@ semantic-kernel>=1.0.0 azure-identity>=1.16.1 azure-ai-projects>=1.0.0 azure-ai-agents>=1.0.0 +azure-ai-ml>=1.0.0 python-dotenv>=0.10.3 diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index 03e40cc..56d7986 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -6,8 +6,30 @@ from azure.ai.projects.models import ConnectionType -class TestGetAISearchConnectionId(unittest.TestCase): - """Tests for the get_ai_search_connection_id function.""" +class TestExtractProjectName(unittest.TestCase): + """Tests for the _extract_project_name helper.""" + + def test_extracts_last_path_segment(self): + """Test that the project name is parsed from the endpoint URL.""" + from src.agent.agent import _extract_project_name + + result = _extract_project_name( + "https://myhub.services.ai.azure.com/api/projects/myproject" + ) + self.assertEqual(result, "myproject") + + def test_handles_trailing_slash(self): + """Test that a trailing slash is ignored.""" + from src.agent.agent import _extract_project_name + + result = _extract_project_name( + "https://myhub.services.ai.azure.com/api/projects/myproject/" + ) + self.assertEqual(result, "myproject") + + +class TestEnsureAISearchConnectionId(unittest.TestCase): + """Tests for the ensure_ai_search_connection_id function.""" @patch("src.agent.agent.SyncAIProjectClient") @patch("src.agent.agent.SyncDefaultAzureCredential") @@ -15,25 +37,25 @@ def test_returns_matching_connection_by_service_name( self, mock_cred_cls, mock_client_cls ): """Test that the connection whose target contains acs_service_name is returned.""" - from src.agent.agent import get_ai_search_connection_id + from src.agent.agent import ensure_ai_search_connection_id conn1 = MagicMock() conn1.id = "/connections/other-search" conn1.target = "https://other-search.search.windows.net" - conn1.is_default = False conn2 = MagicMock() conn2.id = "/connections/my-search" conn2.target = "https://my-search.search.windows.net" - conn2.is_default = False mock_client = MagicMock() mock_client.connections.list.return_value = [conn1, conn2] mock_client_cls.return_value = mock_client - result = get_ai_search_connection_id( + result = ensure_ai_search_connection_id( endpoint="https://test.services.ai.azure.com/api/projects/p", acs_service_name="my-search", + subscription_id="sub-123", + resource_group_name="rg-test", ) self.assertEqual(result, "/connections/my-search") @@ -41,50 +63,75 @@ def test_returns_matching_connection_by_service_name( connection_type=ConnectionType.AZURE_AI_SEARCH ) + @patch("src.agent.agent.MLClient") @patch("src.agent.agent.SyncAIProjectClient") @patch("src.agent.agent.SyncDefaultAzureCredential") - def test_falls_back_to_default_when_no_name_match( - self, mock_cred_cls, mock_client_cls + def test_creates_connection_when_not_found( + self, mock_cred_cls, mock_client_cls, mock_ml_client_cls ): - """Test fallback to default connection when service name has no match.""" - from src.agent.agent import get_ai_search_connection_id - - conn1 = MagicMock() - conn1.id = "/connections/first" - conn1.target = "https://first.search.windows.net" - conn1.is_default = False - - conn2 = MagicMock() - conn2.id = "/connections/default" - conn2.target = "https://default.search.windows.net" - conn2.is_default = True + """Test that a new connection is created when no matching connection exists.""" + from src.agent.agent import ensure_ai_search_connection_id mock_client = MagicMock() - mock_client.connections.list.return_value = [conn1, conn2] + mock_client.connections.list.return_value = [] mock_client_cls.return_value = mock_client - result = get_ai_search_connection_id( - endpoint="https://test.services.ai.azure.com/api/projects/p", - acs_service_name="no-match-here", - ) + mock_ml_client = MagicMock() + created_conn = MagicMock() + created_conn.id = "/connections/my-search" + mock_ml_client.connections.create_or_update.return_value = created_conn + mock_ml_client_cls.return_value = mock_ml_client - self.assertEqual(result, "/connections/default") + result = ensure_ai_search_connection_id( + endpoint="https://test.services.ai.azure.com/api/projects/myproject", + acs_service_name="my-search", + subscription_id="sub-123", + resource_group_name="rg-test", + ) + self.assertEqual(result, "/connections/my-search") + mock_ml_client_cls.assert_called_once() + call_kwargs = mock_ml_client_cls.call_args.kwargs + self.assertEqual(call_kwargs["subscription_id"], "sub-123") + self.assertEqual(call_kwargs["resource_group_name"], "rg-test") + self.assertEqual(call_kwargs["workspace_name"], "myproject") + mock_ml_client.connections.create_or_update.assert_called_once() + created_connection = mock_ml_client.connections.create_or_update.call_args.args[0] + self.assertEqual(created_connection.name, "my-search") + self.assertIn("my-search", created_connection.endpoint) + + @patch("src.agent.agent.MLClient") @patch("src.agent.agent.SyncAIProjectClient") @patch("src.agent.agent.SyncDefaultAzureCredential") - def test_raises_when_no_connections_found(self, mock_cred_cls, mock_client_cls): - """Test that ValueError is raised when no AI Search connections exist.""" - from src.agent.agent import get_ai_search_connection_id + def test_creates_connection_when_no_name_match( + self, mock_cred_cls, mock_client_cls, mock_ml_client_cls + ): + """Test that a new connection is created when connections exist but none match.""" + from src.agent.agent import ensure_ai_search_connection_id + + conn1 = MagicMock() + conn1.id = "/connections/other-search" + conn1.target = "https://other-search.search.windows.net" mock_client = MagicMock() - mock_client.connections.list.return_value = [] + mock_client.connections.list.return_value = [conn1] mock_client_cls.return_value = mock_client - with self.assertRaises(ValueError): - get_ai_search_connection_id( - endpoint="https://test.services.ai.azure.com/api/projects/p", - acs_service_name="my-search", - ) + mock_ml_client = MagicMock() + created_conn = MagicMock() + created_conn.id = "/connections/my-search" + mock_ml_client.connections.create_or_update.return_value = created_conn + mock_ml_client_cls.return_value = mock_ml_client + + result = ensure_ai_search_connection_id( + endpoint="https://test.services.ai.azure.com/api/projects/myproject", + acs_service_name="my-search", + subscription_id="sub-123", + resource_group_name="rg-test", + ) + + self.assertEqual(result, "/connections/my-search") + mock_ml_client.connections.create_or_update.assert_called_once() class TestCreateAgent(unittest.IsolatedAsyncioTestCase): From 50945bb6e3d996eefdd21fa0faa88d4ead86314d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 19:32:05 +0000 Subject: [PATCH 05/14] Fix pydocstyle D205/D400 in ensure_ai_search_connection_id docstring Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- src/agent/agent.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/agent/agent.py b/src/agent/agent.py index b98e4d4..e7d4464 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -45,9 +45,7 @@ def ensure_ai_search_connection_id( subscription_id: str, resource_group_name: str, ) -> str: - """ - Return the AI Foundry connection ID for the given Azure AI Search service, - creating the connection if it does not already exist. + """Return the AI Foundry connection ID for the given Azure AI Search service. First checks whether a connection whose target URL contains ``acs_service_name`` is already registered in the project. If no such From 0184368c4bcf431933775aad4946343667b54ea2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Mar 2026 19:51:38 +0000 Subject: [PATCH 06/14] Add agent dependencies to root requirements.txt to fix test import error Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- requirements.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/requirements.txt b/requirements.txt index f707f73..efc73fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,7 @@ azure-mgmt-search azure-mgmt-storage azure-search-documents azure-ai-evaluation +semantic-kernel>=1.0.0 +azure-ai-projects>=1.0.0 +azure-ai-agents>=1.0.0 +azure-ai-ml>=1.0.0 From 7411ee2c5981a375c7bf25d52a963b4b959fb51b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 19:20:53 +0000 Subject: [PATCH 07/14] Add src/__init__.py to make src a Python package, fixing ModuleNotFoundError Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- src/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 src/__init__.py diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..066f155 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1 @@ +"""Source package for AI Search MLOps components.""" From 8563f0fde2c5d0e8fb8ce9b36eaf77e52951d933 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 19:34:32 +0000 Subject: [PATCH 08/14] Add [tool:pytest] pythonpath=. to setup.cfg to fix ModuleNotFoundError: No module named 'src' in CI Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index 9e43cff..fd0eba7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,3 +5,6 @@ max-complexity = 10 max-line-length = 120 count = True statistics = True + +[tool:pytest] +pythonpath = . From 8f3cd5a0c2ff5c4ca336421fffd83a0cb1513fe4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 19:51:47 +0000 Subject: [PATCH 09/14] Fix 471 pytest warnings: lazy-import azure-ai-ml and filter marshmallow deprecation warnings Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- setup.cfg | 5 +++++ src/agent/agent.py | 4 ++-- tests/agent/test_agent.py | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index fd0eba7..7f5f5ec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,3 +8,8 @@ statistics = True [tool:pytest] pythonpath = . +filterwarnings = + ignore::DeprecationWarning:marshmallow + ignore::PendingDeprecationWarning:marshmallow + ignore::DeprecationWarning:azure.ai.ml + ignore::PendingDeprecationWarning:azure.ai.ml diff --git a/src/agent/agent.py b/src/agent/agent.py index e7d4464..59c8a83 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -7,8 +7,6 @@ from azure.ai.projects import AIProjectClient as SyncAIProjectClient from azure.ai.projects.models import ConnectionType from azure.ai.agents.models import AzureAISearchTool, AzureAISearchQueryType -from azure.ai.ml import MLClient -from azure.ai.ml.entities import AzureAISearchConnection from semantic_kernel.agents import AzureAIAgent from semantic_kernel.agents import AzureAIAgentThread @@ -78,6 +76,8 @@ def ensure_ai_search_connection_id( f"No AI Search connection found for '{acs_service_name}'. " "Creating connection in AI Foundry..." ) + from azure.ai.ml import MLClient # noqa: PLC0415 + from azure.ai.ml.entities import AzureAISearchConnection # noqa: PLC0415 project_name = _extract_project_name(endpoint) ml_client = MLClient( credential=credential, diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index 56d7986..53cb45e 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -63,7 +63,7 @@ def test_returns_matching_connection_by_service_name( connection_type=ConnectionType.AZURE_AI_SEARCH ) - @patch("src.agent.agent.MLClient") + @patch("azure.ai.ml.MLClient") @patch("src.agent.agent.SyncAIProjectClient") @patch("src.agent.agent.SyncDefaultAzureCredential") def test_creates_connection_when_not_found( @@ -100,7 +100,7 @@ def test_creates_connection_when_not_found( self.assertEqual(created_connection.name, "my-search") self.assertIn("my-search", created_connection.endpoint) - @patch("src.agent.agent.MLClient") + @patch("azure.ai.ml.MLClient") @patch("src.agent.agent.SyncAIProjectClient") @patch("src.agent.agent.SyncDefaultAzureCredential") def test_creates_connection_when_no_name_match( From 4675a7cccd6f8b2edef645890d2bfd3242c01b5c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 20:03:35 +0000 Subject: [PATCH 10/14] Revert 8f3cd5a: restore lazy imports, test patches and filterwarnings changes Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- setup.cfg | 5 ----- src/agent/agent.py | 4 ++-- tests/agent/test_agent.py | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7f5f5ec..fd0eba7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,8 +8,3 @@ statistics = True [tool:pytest] pythonpath = . -filterwarnings = - ignore::DeprecationWarning:marshmallow - ignore::PendingDeprecationWarning:marshmallow - ignore::DeprecationWarning:azure.ai.ml - ignore::PendingDeprecationWarning:azure.ai.ml diff --git a/src/agent/agent.py b/src/agent/agent.py index 59c8a83..e7d4464 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -7,6 +7,8 @@ from azure.ai.projects import AIProjectClient as SyncAIProjectClient from azure.ai.projects.models import ConnectionType from azure.ai.agents.models import AzureAISearchTool, AzureAISearchQueryType +from azure.ai.ml import MLClient +from azure.ai.ml.entities import AzureAISearchConnection from semantic_kernel.agents import AzureAIAgent from semantic_kernel.agents import AzureAIAgentThread @@ -76,8 +78,6 @@ def ensure_ai_search_connection_id( f"No AI Search connection found for '{acs_service_name}'. " "Creating connection in AI Foundry..." ) - from azure.ai.ml import MLClient # noqa: PLC0415 - from azure.ai.ml.entities import AzureAISearchConnection # noqa: PLC0415 project_name = _extract_project_name(endpoint) ml_client = MLClient( credential=credential, diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index 53cb45e..56d7986 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -63,7 +63,7 @@ def test_returns_matching_connection_by_service_name( connection_type=ConnectionType.AZURE_AI_SEARCH ) - @patch("azure.ai.ml.MLClient") + @patch("src.agent.agent.MLClient") @patch("src.agent.agent.SyncAIProjectClient") @patch("src.agent.agent.SyncDefaultAzureCredential") def test_creates_connection_when_not_found( @@ -100,7 +100,7 @@ def test_creates_connection_when_not_found( self.assertEqual(created_connection.name, "my-search") self.assertIn("my-search", created_connection.endpoint) - @patch("azure.ai.ml.MLClient") + @patch("src.agent.agent.MLClient") @patch("src.agent.agent.SyncAIProjectClient") @patch("src.agent.agent.SyncDefaultAzureCredential") def test_creates_connection_when_no_name_match( From e91ac6dfc12bfae991d4a962dea8896c1885ce19 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 20:37:58 +0000 Subject: [PATCH 11/14] Add azure-search-documents version print to buildcontainer Dockerfile Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- .buildcontainer/Dockerfile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.buildcontainer/Dockerfile b/.buildcontainer/Dockerfile index 56d9242..6df655d 100644 --- a/.buildcontainer/Dockerfile +++ b/.buildcontainer/Dockerfile @@ -26,4 +26,8 @@ RUN echo "source /home/$USERNAME/llm-env/bin/activate" >> /home/$USERNAME/.bashr # Add venv to PATH ENV PATH="/home/$USERNAME/llm-env/bin:$PATH" +# Print installed azure-search-documents version to verify it is present +RUN /home/$USERNAME/llm-env/bin/python -c \ + "import importlib.metadata; print('azure-search-documents version:', importlib.metadata.version('azure-search-documents'))" + CMD ["python", "--version"] \ No newline at end of file From b1f0077c501001f8313d98305ab0fcf1ae339b3c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 21:31:05 +0000 Subject: [PATCH 12/14] Fix AI Search connection creation fallback to hub workspace Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- src/agent/agent.py | 50 +++++++++++++++++++++++-------- tests/agent/test_agent.py | 63 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 13 deletions(-) diff --git a/src/agent/agent.py b/src/agent/agent.py index e7d4464..5c5fbf3 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -1,6 +1,7 @@ """Agent for chatting with documents indexed in Azure AI Search.""" import asyncio +from urllib.parse import urlparse from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential from azure.identity.aio import DefaultAzureCredential @@ -9,6 +10,7 @@ from azure.ai.agents.models import AzureAISearchTool, AzureAISearchQueryType from azure.ai.ml import MLClient from azure.ai.ml.entities import AzureAISearchConnection +from azure.core.exceptions import ResourceNotFoundError from semantic_kernel.agents import AzureAIAgent from semantic_kernel.agents import AzureAIAgentThread @@ -39,6 +41,14 @@ def _extract_project_name(endpoint: str) -> str: return endpoint.rstrip("/").split("/")[-1] +def _extract_hub_name(endpoint: str) -> str: + """Extract the AI Foundry hub name from the endpoint host.""" + hostname = urlparse(endpoint).hostname + if not hostname: + return "" + return hostname.split(".")[0] + + def ensure_ai_search_connection_id( endpoint: str, acs_service_name: str, @@ -78,19 +88,33 @@ def ensure_ai_search_connection_id( f"No AI Search connection found for '{acs_service_name}'. " "Creating connection in AI Foundry..." ) - project_name = _extract_project_name(endpoint) - ml_client = MLClient( - credential=credential, - subscription_id=subscription_id, - resource_group_name=resource_group_name, - workspace_name=project_name, - ) - new_connection = AzureAISearchConnection( - name=acs_service_name, - endpoint=f"https://{acs_service_name}.search.windows.net", - ) - created = ml_client.connections.create_or_update(new_connection) - return created.id + workspace_candidates = [_extract_project_name(endpoint), _extract_hub_name(endpoint)] + # Preserve order but remove duplicates when project and hub names are identical. + workspace_candidates = [name for name in dict.fromkeys(workspace_candidates) if name] + last_error = None + for workspace_name in workspace_candidates: + try: + ml_client = MLClient( + credential=credential, + subscription_id=subscription_id, + resource_group_name=resource_group_name, + workspace_name=workspace_name, + ) + new_connection = AzureAISearchConnection( + name=acs_service_name, + endpoint=f"https://{acs_service_name}.search.windows.net", + ) + created = ml_client.connections.create_or_update(new_connection) + return created.id + except ResourceNotFoundError as error: + last_error = error + + assert last_error is not None + attempted_workspaces = ", ".join(workspace_candidates) + raise RuntimeError( + "Unable to create Azure AI Search connection. " + f"Attempted ML workspaces: {attempted_workspaces}." + ) from last_error async def create_agent( diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index 56d7986..40c4202 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -3,6 +3,7 @@ import unittest from unittest.mock import AsyncMock, MagicMock, patch +from azure.core.exceptions import ResourceNotFoundError from azure.ai.projects.models import ConnectionType @@ -28,6 +29,31 @@ def test_handles_trailing_slash(self): self.assertEqual(result, "myproject") +class TestExtractHubName(unittest.TestCase): + """Tests for the _extract_hub_name helper.""" + + def test_extracts_hub_name_from_endpoint(self): + """Test that the hub name is parsed from endpoint hostname.""" + from src.agent.agent import _extract_hub_name + + result = _extract_hub_name( + "https://myhub.services.ai.azure.com/api/projects/myproject" + ) + self.assertEqual(result, "myhub") + + def test_returns_empty_hub_name_for_malformed_endpoint(self): + """Test malformed endpoint handling without host name.""" + from src.agent.agent import _extract_hub_name + + self.assertEqual(_extract_hub_name("not-a-url"), "") + + def test_handles_non_foundry_hostname(self): + """Test generic hostname handling.""" + from src.agent.agent import _extract_hub_name + + self.assertEqual(_extract_hub_name("https://incomplete"), "incomplete") + + class TestEnsureAISearchConnectionId(unittest.TestCase): """Tests for the ensure_ai_search_connection_id function.""" @@ -133,6 +159,43 @@ def test_creates_connection_when_no_name_match( self.assertEqual(result, "/connections/my-search") mock_ml_client.connections.create_or_update.assert_called_once() + @patch("src.agent.agent.MLClient") + @patch("src.agent.agent.SyncAIProjectClient") + @patch("src.agent.agent.SyncDefaultAzureCredential") + def test_retries_with_hub_name_when_project_workspace_not_found( + self, mock_cred_cls, mock_client_cls, mock_ml_client_cls + ): + """Test that creation retries with hub name if project workspace does not exist.""" + from src.agent.agent import ensure_ai_search_connection_id + + mock_client = MagicMock() + mock_client.connections.list.return_value = [] + mock_client_cls.return_value = mock_client + + project_ml_client = MagicMock() + project_ml_client.connections.create_or_update.side_effect = ResourceNotFoundError( + message="project workspace not found" + ) + hub_ml_client = MagicMock() + created_conn = MagicMock() + created_conn.id = "/connections/my-search" + hub_ml_client.connections.create_or_update.return_value = created_conn + mock_ml_client_cls.side_effect = [project_ml_client, hub_ml_client] + + result = ensure_ai_search_connection_id( + endpoint="https://myhub.services.ai.azure.com/api/projects/myproject", + acs_service_name="my-search", + subscription_id="sub-123", + resource_group_name="rg-test", + ) + + self.assertEqual(result, "/connections/my-search") + self.assertEqual(mock_ml_client_cls.call_count, 2) + first_call = mock_ml_client_cls.call_args_list[0].kwargs + second_call = mock_ml_client_cls.call_args_list[1].kwargs + self.assertEqual(first_call["workspace_name"], "myproject") + self.assertEqual(second_call["workspace_name"], "myhub") + class TestCreateAgent(unittest.IsolatedAsyncioTestCase): """Tests for the create_agent function.""" From 5299a8fd103bebb394229ceb2c69999b22147956 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 21:40:52 +0000 Subject: [PATCH 13/14] Revert hub-based workspace fallback for AI Search connection Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- src/agent/agent.py | 50 ++++++++----------------------- tests/agent/test_agent.py | 63 --------------------------------------- 2 files changed, 13 insertions(+), 100 deletions(-) diff --git a/src/agent/agent.py b/src/agent/agent.py index 5c5fbf3..e7d4464 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -1,7 +1,6 @@ """Agent for chatting with documents indexed in Azure AI Search.""" import asyncio -from urllib.parse import urlparse from azure.identity import DefaultAzureCredential as SyncDefaultAzureCredential from azure.identity.aio import DefaultAzureCredential @@ -10,7 +9,6 @@ from azure.ai.agents.models import AzureAISearchTool, AzureAISearchQueryType from azure.ai.ml import MLClient from azure.ai.ml.entities import AzureAISearchConnection -from azure.core.exceptions import ResourceNotFoundError from semantic_kernel.agents import AzureAIAgent from semantic_kernel.agents import AzureAIAgentThread @@ -41,14 +39,6 @@ def _extract_project_name(endpoint: str) -> str: return endpoint.rstrip("/").split("/")[-1] -def _extract_hub_name(endpoint: str) -> str: - """Extract the AI Foundry hub name from the endpoint host.""" - hostname = urlparse(endpoint).hostname - if not hostname: - return "" - return hostname.split(".")[0] - - def ensure_ai_search_connection_id( endpoint: str, acs_service_name: str, @@ -88,33 +78,19 @@ def ensure_ai_search_connection_id( f"No AI Search connection found for '{acs_service_name}'. " "Creating connection in AI Foundry..." ) - workspace_candidates = [_extract_project_name(endpoint), _extract_hub_name(endpoint)] - # Preserve order but remove duplicates when project and hub names are identical. - workspace_candidates = [name for name in dict.fromkeys(workspace_candidates) if name] - last_error = None - for workspace_name in workspace_candidates: - try: - ml_client = MLClient( - credential=credential, - subscription_id=subscription_id, - resource_group_name=resource_group_name, - workspace_name=workspace_name, - ) - new_connection = AzureAISearchConnection( - name=acs_service_name, - endpoint=f"https://{acs_service_name}.search.windows.net", - ) - created = ml_client.connections.create_or_update(new_connection) - return created.id - except ResourceNotFoundError as error: - last_error = error - - assert last_error is not None - attempted_workspaces = ", ".join(workspace_candidates) - raise RuntimeError( - "Unable to create Azure AI Search connection. " - f"Attempted ML workspaces: {attempted_workspaces}." - ) from last_error + project_name = _extract_project_name(endpoint) + ml_client = MLClient( + credential=credential, + subscription_id=subscription_id, + resource_group_name=resource_group_name, + workspace_name=project_name, + ) + new_connection = AzureAISearchConnection( + name=acs_service_name, + endpoint=f"https://{acs_service_name}.search.windows.net", + ) + created = ml_client.connections.create_or_update(new_connection) + return created.id async def create_agent( diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index 40c4202..56d7986 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -3,7 +3,6 @@ import unittest from unittest.mock import AsyncMock, MagicMock, patch -from azure.core.exceptions import ResourceNotFoundError from azure.ai.projects.models import ConnectionType @@ -29,31 +28,6 @@ def test_handles_trailing_slash(self): self.assertEqual(result, "myproject") -class TestExtractHubName(unittest.TestCase): - """Tests for the _extract_hub_name helper.""" - - def test_extracts_hub_name_from_endpoint(self): - """Test that the hub name is parsed from endpoint hostname.""" - from src.agent.agent import _extract_hub_name - - result = _extract_hub_name( - "https://myhub.services.ai.azure.com/api/projects/myproject" - ) - self.assertEqual(result, "myhub") - - def test_returns_empty_hub_name_for_malformed_endpoint(self): - """Test malformed endpoint handling without host name.""" - from src.agent.agent import _extract_hub_name - - self.assertEqual(_extract_hub_name("not-a-url"), "") - - def test_handles_non_foundry_hostname(self): - """Test generic hostname handling.""" - from src.agent.agent import _extract_hub_name - - self.assertEqual(_extract_hub_name("https://incomplete"), "incomplete") - - class TestEnsureAISearchConnectionId(unittest.TestCase): """Tests for the ensure_ai_search_connection_id function.""" @@ -159,43 +133,6 @@ def test_creates_connection_when_no_name_match( self.assertEqual(result, "/connections/my-search") mock_ml_client.connections.create_or_update.assert_called_once() - @patch("src.agent.agent.MLClient") - @patch("src.agent.agent.SyncAIProjectClient") - @patch("src.agent.agent.SyncDefaultAzureCredential") - def test_retries_with_hub_name_when_project_workspace_not_found( - self, mock_cred_cls, mock_client_cls, mock_ml_client_cls - ): - """Test that creation retries with hub name if project workspace does not exist.""" - from src.agent.agent import ensure_ai_search_connection_id - - mock_client = MagicMock() - mock_client.connections.list.return_value = [] - mock_client_cls.return_value = mock_client - - project_ml_client = MagicMock() - project_ml_client.connections.create_or_update.side_effect = ResourceNotFoundError( - message="project workspace not found" - ) - hub_ml_client = MagicMock() - created_conn = MagicMock() - created_conn.id = "/connections/my-search" - hub_ml_client.connections.create_or_update.return_value = created_conn - mock_ml_client_cls.side_effect = [project_ml_client, hub_ml_client] - - result = ensure_ai_search_connection_id( - endpoint="https://myhub.services.ai.azure.com/api/projects/myproject", - acs_service_name="my-search", - subscription_id="sub-123", - resource_group_name="rg-test", - ) - - self.assertEqual(result, "/connections/my-search") - self.assertEqual(mock_ml_client_cls.call_count, 2) - first_call = mock_ml_client_cls.call_args_list[0].kwargs - second_call = mock_ml_client_cls.call_args_list[1].kwargs - self.assertEqual(first_call["workspace_name"], "myproject") - self.assertEqual(second_call["workspace_name"], "myhub") - class TestCreateAgent(unittest.IsolatedAsyncioTestCase): """Tests for the create_agent function.""" From 89a59f3be90c99bd6cc2ffe1c0b0f36fabf313b3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:13:48 +0000 Subject: [PATCH 14/14] Fix workspace name resolution: use Foundry hostname instead of project path Co-authored-by: sbaidachni <10055252+sbaidachni@users.noreply.github.com> --- src/agent/agent.py | 22 ++++++++++++++-------- tests/agent/test_agent.py | 22 +++++++++++----------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/agent/agent.py b/src/agent/agent.py index e7d4464..fe30d3c 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -25,18 +25,24 @@ ) -def _extract_project_name(endpoint: str) -> str: - """ - Extract the AI Foundry project name from the project endpoint URL. +def _extract_workspace_name(endpoint: str) -> str: + """Extract the AI Foundry workspace name from the project endpoint URL. + + In Azure AI Foundry the ARM workspace resource is the Foundry resource + whose name appears as the hostname prefix (before + ``.services.ai.azure.com``), **not** the project name in the URL path. Args: endpoint (str): URL in the form - ``https://.services.ai.azure.com/api/projects/``. + ``https://.services.ai.azure.com/api/projects/``. Returns: - str: The project name (last path segment of the URL). + str: The workspace / Foundry resource name. """ - return endpoint.rstrip("/").split("/")[-1] + from urllib.parse import urlparse + + hostname = urlparse(endpoint).hostname or "" + return hostname.split(".")[0] def ensure_ai_search_connection_id( @@ -78,12 +84,12 @@ def ensure_ai_search_connection_id( f"No AI Search connection found for '{acs_service_name}'. " "Creating connection in AI Foundry..." ) - project_name = _extract_project_name(endpoint) + workspace_name = _extract_workspace_name(endpoint) ml_client = MLClient( credential=credential, subscription_id=subscription_id, resource_group_name=resource_group_name, - workspace_name=project_name, + workspace_name=workspace_name, ) new_connection = AzureAISearchConnection( name=acs_service_name, diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index 56d7986..a6ff615 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -6,26 +6,26 @@ from azure.ai.projects.models import ConnectionType -class TestExtractProjectName(unittest.TestCase): - """Tests for the _extract_project_name helper.""" +class TestExtractWorkspaceName(unittest.TestCase): + """Tests for the _extract_workspace_name helper.""" - def test_extracts_last_path_segment(self): - """Test that the project name is parsed from the endpoint URL.""" - from src.agent.agent import _extract_project_name + def test_extracts_foundry_name_from_hostname(self): + """Test that the workspace name is parsed from the endpoint hostname.""" + from src.agent.agent import _extract_workspace_name - result = _extract_project_name( + result = _extract_workspace_name( "https://myhub.services.ai.azure.com/api/projects/myproject" ) - self.assertEqual(result, "myproject") + self.assertEqual(result, "myhub") def test_handles_trailing_slash(self): """Test that a trailing slash is ignored.""" - from src.agent.agent import _extract_project_name + from src.agent.agent import _extract_workspace_name - result = _extract_project_name( + result = _extract_workspace_name( "https://myhub.services.ai.azure.com/api/projects/myproject/" ) - self.assertEqual(result, "myproject") + self.assertEqual(result, "myhub") class TestEnsureAISearchConnectionId(unittest.TestCase): @@ -94,7 +94,7 @@ def test_creates_connection_when_not_found( call_kwargs = mock_ml_client_cls.call_args.kwargs self.assertEqual(call_kwargs["subscription_id"], "sub-123") self.assertEqual(call_kwargs["resource_group_name"], "rg-test") - self.assertEqual(call_kwargs["workspace_name"], "myproject") + self.assertEqual(call_kwargs["workspace_name"], "test") mock_ml_client.connections.create_or_update.assert_called_once() created_connection = mock_ml_client.connections.create_or_update.call_args.args[0] self.assertEqual(created_connection.name, "my-search")