From 2e51c2d30714b7409691f5108b56a381a3b48e93 Mon Sep 17 00:00:00 2001 From: Yuchen Zhang Date: Thu, 14 May 2026 15:14:28 -0700 Subject: [PATCH] ensure connection lifecycle is owned by one dedicated lifecycle task Signed-off-by: Yuchen Zhang --- .../src/nat/plugins/mcp/client/client_base.py | 307 ++++++++++++------ .../tests/client/test_mcp_client_base.py | 153 ++++++--- 2 files changed, 316 insertions(+), 144 deletions(-) diff --git a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client/client_base.py b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client/client_base.py index d158c5eb66..3d55ee3ac2 100644 --- a/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client/client_base.py +++ b/packages/nvidia_nat_mcp/src/nat/plugins/mcp/client/client_base.py @@ -118,6 +118,7 @@ async def _get_auth_headers(self, # Build headers from credentials from nat.data_models.authentication import BearerTokenCred from nat.data_models.authentication import HeaderCred + headers = {} for cred in auth_result.credentials: @@ -150,30 +151,34 @@ class MCPBaseClient(ABC): reconnect_max_backoff (float): Maximum backoff delay in seconds for reconnection attempts """ - def __init__(self, - transport: str = 'streamable-http', - auth_provider: AuthProviderBase | None = None, - user_id: str | None = None, - tool_call_timeout: timedelta = timedelta(seconds=60), - auth_flow_timeout: timedelta = timedelta(seconds=300), - reconnect_enabled: bool = True, - reconnect_max_attempts: int = 2, - reconnect_initial_backoff: float = 0.5, - reconnect_max_backoff: float = 50.0): + def __init__( + self, + transport: str = "streamable-http", + auth_provider: AuthProviderBase | None = None, + user_id: str | None = None, + tool_call_timeout: timedelta = timedelta(seconds=60), + auth_flow_timeout: timedelta = timedelta(seconds=300), + reconnect_enabled: bool = True, + reconnect_max_attempts: int = 2, + reconnect_initial_backoff: float = 0.5, + reconnect_max_backoff: float = 50.0, + ): self._tools = None self._transport = transport.lower() - if self._transport not in ['sse', 'stdio', 'streamable-http']: + if self._transport not in ["sse", "stdio", "streamable-http"]: raise ValueError("transport must be either 'sse', 'stdio' or 'streamable-http'") self._exit_stack: AsyncExitStack | None = None self._session: ClientSession | None = None # Main session self._connection_established = False self._initial_connection = False + self._lifecycle_task: asyncio.Task | None = None + self._lifecycle_commands: asyncio.Queue[tuple[str, asyncio.Future[None]]] | None = None # Convert auth provider to AuthAdapter self._auth_provider = auth_provider # Use provided user_id or fall back to auth provider's default_user_id (if available) - effective_user_id = user_id or (getattr(auth_provider.config, 'default_user_id', None) + effective_user_id = user_id or (getattr(auth_provider.config, "default_user_id", None) if auth_provider else None) self._httpx_auth = AuthAdapter(auth_provider, effective_user_id) if auth_provider else None @@ -196,27 +201,39 @@ def transport(self) -> str: return self._transport async def __aenter__(self): - if self._exit_stack: + if self._lifecycle_task and not self._lifecycle_task.done(): raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.") - self._exit_stack = AsyncExitStack() - - # Establish connection with httpx.Auth - self._session = await self._exit_stack.enter_async_context(self.connect_to_server()) - - self._initial_connection = True - self._connection_established = True + self._lifecycle_commands = asyncio.Queue() + self._lifecycle_task = asyncio.create_task(self._lifecycle_worker(), name=f"mcp-client-{self.server_name}") + try: + await self._run_lifecycle_command("connect") + except Exception: + self._lifecycle_task.cancel() + try: + await self._lifecycle_task + except asyncio.CancelledError: + pass + self._lifecycle_task = None + self._lifecycle_commands = None + raise return self async def __aexit__(self, exc_type, exc_value, traceback): - if self._exit_stack: - # Close session - await self._exit_stack.aclose() - self._session = None - self._exit_stack = None + lifecycle_task = self._lifecycle_task + if lifecycle_task and not lifecycle_task.done(): + try: + await self._run_lifecycle_command("close") + finally: + await lifecycle_task + + self._lifecycle_task = None + self._lifecycle_commands = None self._connection_established = False + self._session = None + self._exit_stack = None self._tools = None @property @@ -251,15 +268,7 @@ async def _reconnect(self): while attempt in range(0, self._reconnect_max_attempts): attempt += 1 try: - # Close the existing stack and ClientSession - if self._exit_stack: - await self._exit_stack.aclose() - # Create a fresh stack and session - self._exit_stack = AsyncExitStack() - self._session = await self._exit_stack.enter_async_context(self.connect_to_server()) - - self._connection_established = True - self._tools = None + await self._run_lifecycle_command("reconnect") logger.info("Reconnected to MCP server (%s) on attempt %d", self.server_name, attempt) return @@ -275,6 +284,80 @@ async def _reconnect(self): if last_error: raise last_error + async def _run_lifecycle_command(self, command: str) -> None: + """Run a connection lifecycle command in the task that owns the transport stack.""" + if self._lifecycle_commands is None or self._lifecycle_task is None or self._lifecycle_task.done(): + raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.") + + future: asyncio.Future[None] = asyncio.get_running_loop().create_future() + await self._lifecycle_commands.put((command, future)) + await future + + async def _lifecycle_worker(self) -> None: + """Own MCP transport context entry and exit to keep AnyIO cancel scopes task-local.""" + if self._lifecycle_commands is None: + raise RuntimeError("MCPBaseClient lifecycle command queue is not initialized") + + while True: + command, future = await self._lifecycle_commands.get() + + if command == "close": + try: + await self._close_connection() + except Exception as e: + if not future.done(): + future.set_exception(e) + else: + if not future.done(): + future.set_result(None) + return + + try: + if command == "connect": + await self._connect_connection() + elif command == "reconnect": + await self._close_connection() + await self._connect_connection() + else: + raise RuntimeError(f"Unsupported MCP client lifecycle command: {command}") + except Exception as e: + self._connection_established = False + self._session = None + self._tools = None + if not future.done(): + future.set_exception(e) + else: + if not future.done(): + future.set_result(None) + + async def _connect_connection(self) -> None: + """Enter the MCP transport context in the lifecycle task.""" + stack = AsyncExitStack() + try: + session = await stack.enter_async_context(self.connect_to_server()) + except Exception: + await stack.aclose() + self._exit_stack = None + self._session = None + self._connection_established = False + raise + + self._exit_stack = stack + self._session = session + self._initial_connection = True + self._connection_established = True + + async def _close_connection(self) -> None: + """Exit the MCP transport context in the same task that entered it.""" + stack = self._exit_stack + self._exit_stack = None + self._session = None + self._connection_established = False + self._tools = None + + if stack is not None: + await stack.aclose() + async def _with_reconnect(self, coro): """ Execute an awaited operation, reconnecting once on errors. @@ -317,9 +400,9 @@ async def _has_cached_auth_token(self) -> bool: try: # Check if OAuth2 provider has tokens cached - if hasattr(self._auth_provider, '_auth_code_provider'): + if hasattr(self._auth_provider, "_auth_code_provider"): provider = self._auth_provider._auth_code_provider - if provider and hasattr(provider, '_authenticated_tokens'): + if provider and hasattr(provider, "_authenticated_tokens"): # Check if we have at least one non-expired token for auth_result in provider._authenticated_tokens.values(): if not auth_result.is_expired(): @@ -362,6 +445,7 @@ async def _get_tools(): tools = await session.list_tools() except TimeoutError as e: from nat.plugins.mcp.exceptions import MCPTimeoutError + raise MCPTimeoutError(self.server_name, e) return tools @@ -374,11 +458,13 @@ async def _get_tools(): return { tool.name: - MCPToolClient(session=self._session, - tool_name=tool.name, - tool_description=tool.description, - tool_input_schema=tool.inputSchema, - parent_client=self) + MCPToolClient( + session=self._session, + tool_name=tool.name, + tool_description=tool.description, + tool_input_schema=tool.inputSchema, + parent_client=self, + ) for tool in response.tools } @@ -431,21 +517,25 @@ class MCPSSEClient(MCPBaseClient): url (str): The url of the MCP server """ - def __init__(self, - url: str, - tool_call_timeout: timedelta = timedelta(seconds=60), - auth_flow_timeout: timedelta = timedelta(seconds=300), - reconnect_enabled: bool = True, - reconnect_max_attempts: int = 2, - reconnect_initial_backoff: float = 0.5, - reconnect_max_backoff: float = 50.0): - super().__init__("sse", - tool_call_timeout=tool_call_timeout, - auth_flow_timeout=auth_flow_timeout, - reconnect_enabled=reconnect_enabled, - reconnect_max_attempts=reconnect_max_attempts, - reconnect_initial_backoff=reconnect_initial_backoff, - reconnect_max_backoff=reconnect_max_backoff) + def __init__( + self, + url: str, + tool_call_timeout: timedelta = timedelta(seconds=60), + auth_flow_timeout: timedelta = timedelta(seconds=300), + reconnect_enabled: bool = True, + reconnect_max_attempts: int = 2, + reconnect_initial_backoff: float = 0.5, + reconnect_max_backoff: float = 50.0, + ): + super().__init__( + "sse", + tool_call_timeout=tool_call_timeout, + auth_flow_timeout=auth_flow_timeout, + reconnect_enabled=reconnect_enabled, + reconnect_max_attempts=reconnect_max_attempts, + reconnect_initial_backoff=reconnect_initial_backoff, + reconnect_max_backoff=reconnect_max_backoff, + ) self._url = url @property @@ -480,23 +570,27 @@ class MCPStdioClient(MCPBaseClient): env (dict[str, str] | None): Environment variables to set for the process """ - def __init__(self, - command: str, - args: list[str] | None = None, - env: dict[str, str] | None = None, - tool_call_timeout: timedelta = timedelta(seconds=60), - auth_flow_timeout: timedelta = timedelta(seconds=300), - reconnect_enabled: bool = True, - reconnect_max_attempts: int = 2, - reconnect_initial_backoff: float = 0.5, - reconnect_max_backoff: float = 50.0): - super().__init__("stdio", - tool_call_timeout=tool_call_timeout, - auth_flow_timeout=auth_flow_timeout, - reconnect_enabled=reconnect_enabled, - reconnect_max_attempts=reconnect_max_attempts, - reconnect_initial_backoff=reconnect_initial_backoff, - reconnect_max_backoff=reconnect_max_backoff) + def __init__( + self, + command: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + tool_call_timeout: timedelta = timedelta(seconds=60), + auth_flow_timeout: timedelta = timedelta(seconds=300), + reconnect_enabled: bool = True, + reconnect_max_attempts: int = 2, + reconnect_initial_backoff: float = 0.5, + reconnect_max_backoff: float = 50.0, + ): + super().__init__( + "stdio", + tool_call_timeout=tool_call_timeout, + auth_flow_timeout=auth_flow_timeout, + reconnect_enabled=reconnect_enabled, + reconnect_max_attempts=reconnect_max_attempts, + reconnect_initial_backoff=reconnect_initial_backoff, + reconnect_max_backoff=reconnect_max_backoff, + ) self._command = command self._args = args self._env = env @@ -548,26 +642,30 @@ class MCPStreamableHTTPClient(MCPBaseClient): reconnect_max_backoff (float): Maximum backoff delay in seconds """ - def __init__(self, - url: str, - auth_provider: AuthProviderBase | None = None, - user_id: str | None = None, - custom_headers: dict[str, str] | None = None, - tool_call_timeout: timedelta = timedelta(seconds=60), - auth_flow_timeout: timedelta = timedelta(seconds=300), - reconnect_enabled: bool = True, - reconnect_max_attempts: int = 2, - reconnect_initial_backoff: float = 0.5, - reconnect_max_backoff: float = 50.0): - super().__init__("streamable-http", - auth_provider=auth_provider, - user_id=user_id, - tool_call_timeout=tool_call_timeout, - auth_flow_timeout=auth_flow_timeout, - reconnect_enabled=reconnect_enabled, - reconnect_max_attempts=reconnect_max_attempts, - reconnect_initial_backoff=reconnect_initial_backoff, - reconnect_max_backoff=reconnect_max_backoff) + def __init__( + self, + url: str, + auth_provider: AuthProviderBase | None = None, + user_id: str | None = None, + custom_headers: dict[str, str] | None = None, + tool_call_timeout: timedelta = timedelta(seconds=60), + auth_flow_timeout: timedelta = timedelta(seconds=300), + reconnect_enabled: bool = True, + reconnect_max_attempts: int = 2, + reconnect_initial_backoff: float = 0.5, + reconnect_max_backoff: float = 50.0, + ): + super().__init__( + "streamable-http", + auth_provider=auth_provider, + user_id=user_id, + tool_call_timeout=tool_call_timeout, + auth_flow_timeout=auth_flow_timeout, + reconnect_enabled=reconnect_enabled, + reconnect_max_attempts=reconnect_max_attempts, + reconnect_initial_backoff=reconnect_initial_backoff, + reconnect_max_backoff=reconnect_max_backoff, + ) self._url = url self._custom_headers = custom_headers or {} # Callback to retrieve MCP session ID from the transport layer @@ -633,8 +731,11 @@ async def connect_to_server(self): try: async with http_client: - async with streamable_http_client(url=self._url, - http_client=http_client) as (read, write, get_session_id): + async with streamable_http_client(url=self._url, http_client=http_client) as ( + read, + write, + get_session_id, + ): # Store the session ID callback for later retrieval self._get_mcp_session_id = get_session_id async with ClientSession(read, write) as session: @@ -658,16 +759,18 @@ class MCPToolClient: parent_client (MCPBaseClient): The parent MCP client for auth management. """ - def __init__(self, - session: ClientSession, - parent_client: MCPBaseClient, - tool_name: str, - tool_description: str | None, - tool_input_schema: dict | None = None): + def __init__( + self, + session: ClientSession, + parent_client: MCPBaseClient, + tool_name: str, + tool_description: str | None, + tool_input_schema: dict | None = None, + ): self._session = session self._tool_name = tool_name self._tool_description = tool_description - self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None) + self._input_schema = model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None self._parent_client = parent_client if self._parent_client is None: diff --git a/packages/nvidia_nat_mcp/tests/client/test_mcp_client_base.py b/packages/nvidia_nat_mcp/tests/client/test_mcp_client_base.py index 9def908723..67cfcc9f26 100644 --- a/packages/nvidia_nat_mcp/tests/client/test_mcp_client_base.py +++ b/packages/nvidia_nat_mcp/tests/client/test_mcp_client_base.py @@ -136,7 +136,6 @@ async def mcp_client_fixture(request: pytest.FixtureRequest, unused_tcp_port_fac async def test_mcp_client_base_methods(mcp_client: MCPBaseClient): async with mcp_client: - # Test get_tools tools = await mcp_client.get_tools() assert len(tools) == 2 @@ -158,7 +157,6 @@ async def test_mcp_client_base_methods(mcp_client: MCPBaseClient): @pytest.mark.skip(reason="Temporarily disabled while debugging MCP server hang") async def test_error_handling(mcp_client: MCPBaseClient): async with mcp_client: - tool = await mcp_client.get_tool("throw_error") with pytest.raises(RuntimeError) as e: @@ -199,9 +197,8 @@ def __init__(self, client): async def __aenter__(self): self.client.connect_call_count += 1 # Only fail during reconnect attempts, not initial connection for most tests - if (self.client.connect_should_fail and self.client.connect_call_count > 1 - and # Allow first connection to succeed - self.client.connect_call_count <= self.client.connect_failure_count + 1): + if (self.client.connect_should_fail and self.client.connect_call_count > 1 # Allow first connection to succeed + and self.client.connect_call_count <= self.client.connect_failure_count + 1): raise ConnectionError(f"Mock connection failure #{self.client.connect_call_count}") # Return a mock session @@ -224,13 +221,58 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): pass +class TaskBoundMockMCPClient(MCPBaseClient): + """Mock client whose transport context must be exited by the task that entered it.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.connect_call_count = 0 + self.list_tools_call_count = 0 + self.contexts: list[TaskBoundAsyncContextManager] = [] + + def connect_to_server(self): # type: ignore + return TaskBoundAsyncContextManager(self) + + +class TaskBoundAsyncContextManager: + """Context manager that mimics AnyIO task-bound transport cleanup.""" + + def __init__(self, client: TaskBoundMockMCPClient): + self.client = client + self.enter_task: asyncio.Task | None = None + self.exit_task: asyncio.Task | None = None + + async def __aenter__(self): + self.enter_task = asyncio.current_task() + self.client.connect_call_count += 1 + self.client.contexts.append(self) + + mock_session = AsyncMock(spec=ClientSession) + + async def mock_list_tools(): + self.client.list_tools_call_count += 1 + if self.client.list_tools_call_count == 1: + raise ConnectionError("Connection lost") + return MagicMock(tools=[]) + + mock_session.list_tools.side_effect = mock_list_tools + return mock_session + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_task = asyncio.current_task() + if self.exit_task is not self.enter_task: + raise RuntimeError("Attempted to exit a cancel scope that isn't the current task's current cancel scope") + + async def test_reconnect_configuration(): """Test that reconnect configuration parameters are properly set.""" - client = MockMCPClient(transport="streamable-http", - reconnect_enabled=False, - reconnect_max_attempts=5, - reconnect_initial_backoff=1.0, - reconnect_max_backoff=100.0) + client = MockMCPClient( + transport="streamable-http", + reconnect_enabled=False, + reconnect_max_attempts=5, + reconnect_initial_backoff=1.0, + reconnect_max_backoff=100.0, + ) assert client._reconnect_enabled is False assert client._reconnect_max_attempts == 5 assert client._reconnect_initial_backoff == 1.0 @@ -271,7 +313,8 @@ async def test_reconnect_success_after_failure(): reconnect_enabled=True, reconnect_max_attempts=2, reconnect_initial_backoff=0.01, # Fast for testing - reconnect_max_backoff=0.02) + reconnect_max_backoff=0.02, + ) # Mock the session to fail once, then succeed call_count = 0 @@ -294,13 +337,34 @@ async def mock_list_tools(): assert call_count == 2 +async def test_reconnect_exits_transport_stack_in_lifecycle_task(): + """Reconnect from a request task should not exit task-bound transports in that request task.""" + client = TaskBoundMockMCPClient( + transport="stdio", + reconnect_enabled=True, + reconnect_max_attempts=2, + reconnect_initial_backoff=0.01, + reconnect_max_backoff=0.02, + ) + + async with client: + result = await asyncio.create_task(client.get_tools()) + + assert result == {} + assert client.connect_call_count == 2 + assert client.list_tools_call_count == 2 + assert all(context.exit_task is context.enter_task for context in client.contexts) + + async def test_reconnect_max_attempts_exceeded(): """Test that reconnect gives up after max attempts.""" - client = MockMCPClient(transport="streamable-http", - reconnect_enabled=True, - reconnect_max_attempts=2, - reconnect_initial_backoff=0.01, - reconnect_max_backoff=0.02) + client = MockMCPClient( + transport="streamable-http", + reconnect_enabled=True, + reconnect_max_attempts=2, + reconnect_initial_backoff=0.01, + reconnect_max_backoff=0.02, + ) # Configure client to fail connection attempts during reconnect client.connect_should_fail = True @@ -320,11 +384,13 @@ async def always_fail(): @pytest.mark.skip(reason="This test might fail in CI due to race conditions") async def test_reconnect_backoff_timing(): """Test that reconnect backoff timing works correctly.""" - client = MockMCPClient(transport="streamable-http", - reconnect_enabled=True, - reconnect_max_attempts=3, - reconnect_initial_backoff=0.1, - reconnect_max_backoff=0.5) + client = MockMCPClient( + transport="streamable-http", + reconnect_enabled=True, + reconnect_max_attempts=3, + reconnect_initial_backoff=0.1, + reconnect_max_backoff=0.5, + ) # Track timing of reconnect attempts attempt_times = [] @@ -350,7 +416,7 @@ async def mock_list_tools(): client.list_tools_side_effect = mock_list_tools - with patch('asyncio.sleep', mock_sleep): + with patch("asyncio.sleep", mock_sleep): async with client: # Should eventually succeed await client.get_tools() @@ -369,7 +435,7 @@ async def test_reconnect_max_backoff_limit(): reconnect_enabled=True, reconnect_max_attempts=4, reconnect_initial_backoff=0.2, - reconnect_max_backoff=0.3 # Low max for testing + reconnect_max_backoff=0.3, # Low max for testing ) attempt_times = [] @@ -388,7 +454,7 @@ async def always_fail(): client.list_tools_side_effect = always_fail - with patch('asyncio.sleep', mock_sleep): + with patch("asyncio.sleep", mock_sleep): async with client: with pytest.raises(MCPConnectionError): await client.get_tools() @@ -598,11 +664,13 @@ def test_tool_client_with_input_schema(self): input_schema = {"type": "object", "properties": {"arg1": {"type": "string"}, "arg2": {"type": "number"}}} # Create MCPToolClient instance - tool_client = MCPToolClient(session=mock_session, - parent_client=mock_parent_client, - tool_name="test_tool", - tool_description="Test tool", - tool_input_schema=input_schema) + tool_client = MCPToolClient( + session=mock_session, + parent_client=mock_parent_client, + tool_name="test_tool", + tool_description="Test tool", + tool_input_schema=input_schema, + ) # Verify input schema is processed assert tool_client.input_schema is not None @@ -616,10 +684,12 @@ def test_tool_client_description_override(self): mock_parent_client = MagicMock() # Create MCPToolClient instance - tool_client = MCPToolClient(session=mock_session, - parent_client=mock_parent_client, - tool_name="test_tool", - tool_description="Original description") + tool_client = MCPToolClient( + session=mock_session, + parent_client=mock_parent_client, + tool_name="test_tool", + tool_description="Original description", + ) # Override description tool_client.set_description("New description") @@ -712,8 +782,8 @@ async def test_connect_to_server_sets_session_id_callback(self): async def mock_streamable_client(*args, **kwargs): yield (AsyncMock(), AsyncMock(), mock_session_id_callback) - with patch('nat.plugins.mcp.client.client_base.streamable_http_client', mock_streamable_client): - with patch('nat.plugins.mcp.client.client_base.ClientSession') as MockClientSession: + with patch("nat.plugins.mcp.client.client_base.streamable_http_client", mock_streamable_client): + with patch("nat.plugins.mcp.client.client_base.ClientSession") as MockClientSession: mock_session_cm = AsyncMock() mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session) mock_session_cm.__aexit__ = AsyncMock(return_value=None) @@ -739,14 +809,14 @@ async def test_connect_to_server_passes_custom_headers(self): @asynccontextmanager async def mock_streamable_client(*args, **kwargs): nonlocal captured_http_client - captured_http_client = kwargs.get('http_client') + captured_http_client = kwargs.get("http_client") yield (AsyncMock(), AsyncMock(), MagicMock(return_value=None)) mock_session = AsyncMock() mock_session.initialize = AsyncMock() - with patch('nat.plugins.mcp.client.client_base.streamable_http_client', mock_streamable_client): - with patch('nat.plugins.mcp.client.client_base.ClientSession') as MockClientSession: + with patch("nat.plugins.mcp.client.client_base.streamable_http_client", mock_streamable_client): + with patch("nat.plugins.mcp.client.client_base.ClientSession") as MockClientSession: mock_session_cm = AsyncMock() mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session) mock_session_cm.__aexit__ = AsyncMock(return_value=None) @@ -770,14 +840,14 @@ async def test_connect_to_server_no_headers_when_empty(self): @asynccontextmanager async def mock_streamable_client(*args, **kwargs): nonlocal captured_http_client - captured_http_client = kwargs.get('http_client') + captured_http_client = kwargs.get("http_client") yield (AsyncMock(), AsyncMock(), MagicMock(return_value=None)) mock_session = AsyncMock() mock_session.initialize = AsyncMock() - with patch('nat.plugins.mcp.client.client_base.streamable_http_client', mock_streamable_client): - with patch('nat.plugins.mcp.client.client_base.ClientSession') as MockClientSession: + with patch("nat.plugins.mcp.client.client_base.streamable_http_client", mock_streamable_client): + with patch("nat.plugins.mcp.client.client_base.ClientSession") as MockClientSession: mock_session_cm = AsyncMock() mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session) mock_session_cm.__aexit__ = AsyncMock(return_value=None) @@ -835,7 +905,6 @@ def test_transport_defaults_to_streamable_http(self): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="MCP Server") parser.add_argument("--transport", type=str, default="stdio", help="Transport to use for the server")