1111from typing import Any , Callable , Dict , List , Optional , Tuple , Union
1212
1313from ...types import DatasetRow , MCPSession , MCPToolCall
14- from ..execution . manager import ExecutionManager
14+ from ..client . connection import MCPConnectionManager
1515
1616logger = logging .getLogger (__name__ )
1717
@@ -44,7 +44,7 @@ def __init__(
4444 self .user_prompt_formatter = user_prompt_formatter or self ._default_formatter
4545 self .n = len (sessions )
4646 self .tool_schemas = [] # Discovered from MCP servers
47- self .execution_manager = ExecutionManager ()
47+ self .connection_manager = MCPConnectionManager ()
4848 self .usage_stats = {} # llm usage stats for monitoring
4949
5050 if len (sessions ) != len (dataset_rows ):
@@ -58,17 +58,14 @@ async def reset(self, session: MCPSession) -> Tuple[Any, List[Dict]]:
5858
5959 This is thread-safe and can be called from worker threads.
6060 """
61- # Establish a persistent session for each environment.
62- await self .execution_manager .connection_manager .initialize_session (session )
63-
6461 # Get available tools from MCP server
65- tool_schemas = await self .execution_manager . connection_manager .discover_tools (session )
62+ tool_schemas = await self .connection_manager .discover_tools (session )
6663
6764 if not self .tool_schemas :
6865 self .tool_schemas = tool_schemas
6966
7067 # PROPER MCP PATTERN: Get initial state from resources during session establishment
71- initial_observation = await self .execution_manager . connection_manager .get_initial_state (session )
68+ initial_observation = await self .connection_manager .get_initial_state (session )
7269
7370 # Update session state
7471 session .terminated = False
@@ -119,7 +116,7 @@ async def step(self, env_index: int, tool_call: MCPToolCall) -> Tuple[Any, float
119116 )
120117
121118 # Execute the tool call via MCP protocol
122- observation , reward , done , info = await self .execution_manager . connection_manager .call_tool (
119+ observation , reward , done , info = await self .connection_manager .call_tool (
123120 session , tool_call .tool_name , tool_call .arguments
124121 )
125122
@@ -223,5 +220,6 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st
223220 async def close (self ):
224221 """Closes all MCP sessions."""
225222 print (f"🧹 Closing { self .n } MCP sessions..." )
226- await self .execution_manager .close_sessions (self .sessions )
223+ tasks = [self .connection_manager .close_session (session ) for session in self .sessions ]
224+ await asyncio .gather (* tasks )
227225 print (f"✅ All MCP sessions closed." )
0 commit comments