Skip to content

Commit 162f8c9

Browse files
committed
specialize duty for session and execution manager for further customization
1 parent 1c00878 commit 162f8c9

4 files changed

Lines changed: 33 additions & 52 deletions

File tree

eval_protocol/mcp/execution/manager.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from ...models import CompletionParams, EvaluationRow, InputMetadata, Message
2424
from ...types import MCPSession, MCPToolCall, TerminationReason, Trajectory
25-
from ..client.connection import MCPConnectionManager
2625

2726
if TYPE_CHECKING:
2827
from ..session.manager import GeneralMCPVectorEnv
@@ -33,43 +32,9 @@
3332

3433
class ExecutionManager:
3534
"""
36-
Unified manager that handles both MCP session lifecycle and rollout execution.
37-
38-
Combines the functionality of SessionManager and RolloutManager for better
39-
organization and reduced complexity.
35+
Manage rollout for MCP environments.
4036
"""
4137

42-
def __init__(self):
43-
"""Initialize the execution manager."""
44-
self.connection_manager = MCPConnectionManager()
45-
46-
async def initialize_sessions(self, sessions: List[MCPSession]) -> None:
47-
"""
48-
Initialize multiple MCP sessions in parallel.
49-
50-
Args:
51-
sessions: List of MCPSessions to initialize
52-
"""
53-
tasks = [self.connection_manager.initialize_session(session) for session in sessions]
54-
await asyncio.gather(*tasks)
55-
56-
async def close_sessions(self, sessions: List[MCPSession]) -> None:
57-
"""
58-
Close multiple MCP sessions in parallel.
59-
60-
Args:
61-
sessions: List of MCPSessions to close
62-
"""
63-
tasks = [asyncio.create_task(self.connection_manager.close_session(session)) for session in sessions]
64-
65-
if tasks:
66-
try:
67-
# Wait for all close operations to complete
68-
await asyncio.gather(*tasks, return_exceptions=True)
69-
except asyncio.CancelledError:
70-
# Handle cancellation gracefully (especially important for Python 3.12)
71-
logger.debug("Close operation was cancelled, but sessions are marked as closed")
72-
7338
async def execute_rollouts(
7439
self,
7540
envs: "GeneralMCPVectorEnv",
@@ -178,7 +143,7 @@ async def _execute_with_semaphore(idx):
178143
for msg in trajectory.conversation_history:
179144
# Create a copy to avoid modifying the original
180145
msg_dict = dict(msg)
181-
146+
182147
# Handle multimodal content (list of content blocks) by extracting text
183148
if isinstance(msg_dict.get("content"), list):
184149
text_content = None
@@ -187,7 +152,7 @@ async def _execute_with_semaphore(idx):
187152
text_content = content_block.get("text")
188153
break
189154
msg_dict["content"] = text_content or ""
190-
155+
191156
messages.append(Message.model_validate(msg_dict))
192157

193158
input_metadata = InputMetadata(

eval_protocol/mcp/mcpgym.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional
115115
# Register tools and control plane endpoints
116116
self._register_tools()
117117
self._discover_and_register_control_plane_endpoints()
118+
self._register_session_reset_endpoint()
118119

119120
def _get_session_id(self, ctx: Context) -> str:
120121
"""
@@ -226,6 +227,19 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
226227

227228
return self.sessions[session_id]
228229

230+
def _register_session_reset_endpoint(self):
231+
232+
@self.mcp.custom_route("/control/reset_session", methods=["POST"])
233+
async def reset_session_endpoint(request: Request, ctx: Context) -> JSONResponse:
234+
session_id = request.headers.get("mcp-session-id")
235+
if not session_id:
236+
return JSONResponse({"error": "Missing mcp-session-id header"}, status_code=400)
237+
with self.session_lock:
238+
if session_id in self.sessions:
239+
del self.sessions[session_id]
240+
self.sessions[session_id] = self._get_or_create_session(ctx)
241+
return JSONResponse({"message": "Session reset successfully"})
242+
229243
def _discover_and_register_control_plane_endpoints(self):
230244
"""
231245
Discover and register control plane endpoints on the subclass instance.

eval_protocol/mcp/session/manager.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1212

1313
from ...types import DatasetRow, MCPSession, MCPToolCall
14-
from ..execution.manager import ExecutionManager
14+
from ..client.connection import MCPConnectionManager
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -44,7 +44,7 @@ def __init__(
4444
self.user_prompt_formatter = user_prompt_formatter or self._default_formatter
4545
self.n = len(sessions)
4646
self.tool_schemas = [] # Discovered from MCP servers
47-
self.execution_manager = ExecutionManager()
47+
self.connection_manager = MCPConnectionManager()
4848
self.usage_stats = {} # llm usage stats for monitoring
4949

5050
if len(sessions) != len(dataset_rows):
@@ -58,17 +58,14 @@ async def reset(self, session: MCPSession) -> Tuple[Any, List[Dict]]:
5858
5959
This is thread-safe and can be called from worker threads.
6060
"""
61-
# Establish a persistent session for each environment.
62-
await self.execution_manager.connection_manager.initialize_session(session)
63-
6461
# Get available tools from MCP server
65-
tool_schemas = await self.execution_manager.connection_manager.discover_tools(session)
62+
tool_schemas = await self.connection_manager.discover_tools(session)
6663

6764
if not self.tool_schemas:
6865
self.tool_schemas = tool_schemas
6966

7067
# PROPER MCP PATTERN: Get initial state from resources during session establishment
71-
initial_observation = await self.execution_manager.connection_manager.get_initial_state(session)
68+
initial_observation = await self.connection_manager.get_initial_state(session)
7269

7370
# Update session state
7471
session.terminated = False
@@ -119,7 +116,7 @@ async def step(self, env_index: int, tool_call: MCPToolCall) -> Tuple[Any, float
119116
)
120117

121118
# Execute the tool call via MCP protocol
122-
observation, reward, done, info = await self.execution_manager.connection_manager.call_tool(
119+
observation, reward, done, info = await self.connection_manager.call_tool(
123120
session, tool_call.tool_name, tool_call.arguments
124121
)
125122

@@ -223,5 +220,6 @@ def _default_formatter(self, template: str, obs: Any, context: Dict) -> Union[st
223220
async def close(self):
224221
"""Closes all MCP sessions."""
225222
print(f"🧹 Closing {self.n} MCP sessions...")
226-
await self.execution_manager.close_sessions(self.sessions)
223+
tasks = [self.connection_manager.close_session(session) for session in self.sessions]
224+
await asyncio.gather(*tasks)
227225
print(f"✅ All MCP sessions closed.")

eval_protocol/mcp_env.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@
5151
from .mcp.session.manager import GeneralMCPVectorEnv
5252
from .models import EvaluationRow
5353
from .types import DatasetRow, MCPSession, MCPToolCall
54+
import asyncio
5455

5556
logger = logging.getLogger(__name__)
5657

5758

58-
def make(
59+
async def make(
5960
env_spec: str,
6061
evaluation_rows: Optional[List[EvaluationRow]] = None,
6162
dataset: Optional[List[Dict]] = None,
@@ -104,17 +105,17 @@ def make(
104105
if evaluation_rows:
105106
for i, row in enumerate(evaluation_rows):
106107
dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}
107-
108+
108109
system_message = row.get_system_message()
109110
system_prompt = system_message.content or ""
110-
111+
111112
dataset_entry = {
112113
"id": row.input_metadata.row_id if row.input_metadata and row.input_metadata.row_id else f"task_{i}",
113114
"system_prompt": system_prompt,
114115
"user_prompt_template": dataset_info.get("user_prompt_template", ""),
115116
"environment_context": dataset_info.get("environment_context", {}),
116117
"user_simulation": dataset_info.get("user_simulation", {}),
117-
"evaluation_criteria": dataset_info.get("evaluation_criteria", {})
118+
"evaluation_criteria": dataset_info.get("evaluation_criteria", {}),
118119
}
119120
internal_dataset.append(dataset_entry)
120121
elif dataset:
@@ -198,7 +199,10 @@ def make(
198199
)
199200
sessions.append(session)
200201

201-
return GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
202+
mcp_envs = GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
203+
tasks = [mcp_envs.connection_manager.initialize_session(session) for session in sessions]
204+
await asyncio.gather(*tasks)
205+
return mcp_envs
202206

203207

204208
async def rollout(

0 commit comments

Comments
 (0)