1313"""
1414
1515import asyncio
16+ import dataclasses
1617import hashlib
1718import inspect
1819import json
1920import logging
2021import os
2122import threading
2223from abc import ABC , abstractmethod
24+ from concurrent .futures import ThreadPoolExecutor
25+ from datetime import date , datetime
26+ from enum import Enum
2327from typing import Any , Callable , Dict , Optional , Tuple
2428
2529import uvicorn
2630from mcp .server .fastmcp import Context , FastMCP
31+ from pydantic import BaseModel
2732from starlette .requests import Request
2833from starlette .responses import JSONResponse
2934from uvicorn .middleware .proxy_headers import ProxyHeadersMiddleware
@@ -75,14 +80,23 @@ class McpGym(ABC):
7580 - Environment Implementation: Single-process MCP server per environment
7681 """
7782
78- def __init__ (self , server_name : str , adapter : EnvironmentAdapter , seed : Optional [int ] = None ):
83+ def __init__ (
84+ self ,
85+ server_name : str ,
86+ adapter : EnvironmentAdapter ,
87+ seed : Optional [int ] = None ,
88+ max_workers : Optional [int ] = None ,
89+ ):
7990 """
8091 Initialize the MCP-Gym environment.
8192
8293 Args:
8394 server_name: Name for the MCP server
8495 adapter: Environment adapter instance
8596 seed: Optional seed for reproducible environments
97+ max_workers: Optional maximum number of worker threads for ThreadPoolExecutor.
98+ If None, uses ThreadPoolExecutor default (min(32, (os.cpu_count() or 1) + 4))
99+
86100 """
87101 self .adapter = adapter
88102
@@ -110,6 +124,8 @@ def __init__(self, server_name: str, adapter: EnvironmentAdapter, seed: Optional
110124 "total_reward" : 0.0 ,
111125 }
112126
127+ self .pool = ThreadPoolExecutor (max_workers = max_workers )
128+
113129 # Reset with seed if provided
114130 self .env , self .obs , _info = self ._new_env (seed = seed )
115131
@@ -189,49 +205,7 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
189205 """
190206 session_id = self ._get_session_id (ctx )
191207 print (f"🔍 _get_or_create_session: session_id: { session_id } " )
192-
193- with self .session_lock :
194- if session_id not in self .sessions :
195- print (f"🔍 _get_or_create_session: Creating new session for { session_id } " )
196- # Extract seed from context using proper FastMCP pattern
197- seed = None
198- config = self ._get_default_config ()
199- print (f"🔍 _get_or_create_session: default_config: { config } " )
200-
201- if hasattr (ctx , "session" ) and hasattr (ctx .session , "client_params" ):
202- client_params = ctx .session .client_params
203- if hasattr (client_params , "clientInfo" ):
204- client_info = client_params .clientInfo
205- if client_info and hasattr (client_info , "_extra" ):
206- extra_data = client_info ._extra
207- print (f"🔍 _get_or_create_session: extra_data in session creation: { extra_data } " )
208- if extra_data and isinstance (extra_data , dict ):
209- # Extract seed from client info
210- seed = extra_data .get ("seed" )
211- print (f"🌱 Extracted seed from client_info: { seed } (type: { type (seed )} )" )
212- # Update config with any additional options
213- if "config" in extra_data :
214- config .update (extra_data ["config" ])
215- print (f"🔍 _get_or_create_session: updated config: { config } " )
216-
217- print (f"🔍 _get_or_create_session: About to create environment with seed: { seed } " )
218-
219- env , obs , info = self ._new_env (seed = seed )
220- print (f"🔍 _get_or_create_session: environment created with obs: { obs } , info: { info } " )
221-
222- # Initialize session state
223- self .sessions [session_id ] = {
224- "env" : env ,
225- "obs" : obs ,
226- "session_data" : {}, # Subclasses can store additional data here
227- "session_id" : session_id ,
228- }
229-
230- print (f"🎮 Created new session { session_id [:16 ]} ... with seed { seed } , initial obs: { obs } " )
231- else :
232- print (f"🔍 _get_or_create_session: Returning existing session { session_id } " )
233-
234- return self .sessions [session_id ]
208+ return self .sessions [session_id ]
235209
236210 def _register_session_reset_endpoint (self ):
237211
@@ -243,16 +217,17 @@ async def reset_session_endpoint(request: Request) -> JSONResponse:
243217 print (f"🔍 _register_session_reset_endpoint: Resetting session, session_id: { session_id } , seed: { seed } " )
244218 if not session_id :
245219 return JSONResponse ({"error" : "Missing mcp-session-id header" }, status_code = 400 )
246- with self .session_lock :
247- if session_id in self .sessions :
248- env , obs , _ = self ._new_env (seed = seed )
220+ if session_id in self .sessions :
221+ loop = asyncio .get_running_loop ()
222+ env , obs , info = await loop .run_in_executor (self .pool , self ._new_env , seed )
223+ with self .session_lock :
249224 self .sessions [session_id ] = {
250225 "env" : env ,
251226 "obs" : obs ,
252227 "session_data" : {},
253228 "session_id" : session_id ,
254229 }
255- print (f"🔍 _register_session_reset_endpoint: Finished reset session, session_id: { session_id } " )
230+ print (f"🔍 _register_session_reset_endpoint: Finished reset session, session_id: { session_id } " )
256231 return JSONResponse ({"message" : "Session reset successfully" })
257232
258233 def _discover_and_register_control_plane_endpoints (self ):
@@ -286,29 +261,27 @@ async def endpoint_handler(request: Request) -> JSONResponse:
286261 )
287262
288263 # Get or create session data
264+ session_data = self .sessions .get (session_id )
265+ if not session_data :
266+ if func .__name__ != "get_initial_state_endpoint" :
267+ return JSONResponse (
268+ {"error" : f"Session { session_id } not found" },
269+ status_code = 404 ,
270+ )
271+
272+ loop = asyncio .get_running_loop ()
273+ env , obs , info = await loop .run_in_executor (self .pool , self ._new_env , None )
274+
275+ # Initialize session state with extracted seed from session ID
276+ session_data = {
277+ "env" : env ,
278+ "obs" : obs ,
279+ "session_data" : {}, # Subclasses can store additional data here
280+ "session_id" : session_id ,
281+ }
289282 with self .session_lock :
290- session_data = self .sessions .get (session_id )
291- if not session_data :
292- # For initial state endpoint, we need to create the session
293- # based on the session ID and available information
294- if func .__name__ == "get_initial_state_endpoint" :
295- env , obs , info = self ._new_env (seed = None )
296- # Initialize session state with extracted seed from session ID
297- session_data = {
298- "env" : env ,
299- "obs" : obs ,
300- "session_data" : {}, # Subclasses can store additional data here
301- "session_id" : session_id ,
302- }
303- # Store the session
304- self .sessions [session_id ] = session_data
305- else :
306- return JSONResponse (
307- {"error" : f"Session { session_id } not found" },
308- status_code = 404 ,
309- )
310-
311- # Call the endpoint function with session data
283+ self .sessions [session_id ] = session_data
284+
312285 if inspect .iscoroutinefunction (func ):
313286 result = await func (session_data = session_data )
314287 else :
@@ -356,22 +329,21 @@ def _update_control_plane(self, reward: float, terminated: bool, truncated: bool
356329
357330 def _get_or_create_session_control_plane (self , session_id : str ) -> Dict [str , Any ]:
358331 """Get or create control plane state for a specific session."""
359- with self .session_lock :
360- if session_id not in self .sessions :
361- return {}
362-
363- session_data = self .sessions [session_id ]
364- if "control_plane" not in session_data ["session_data" ]:
365- session_data ["session_data" ]["control_plane" ] = {
366- "reward" : 0.0 ,
367- "terminated" : False ,
368- "truncated" : False ,
369- "info" : {},
370- "step_count" : 0 ,
371- "total_reward" : 0.0 ,
372- }
332+ if session_id not in self .sessions :
333+ raise Exception (f"Session { session_id } not found" )
334+
335+ session_data = self .sessions [session_id ]
336+ if "control_plane" not in session_data ["session_data" ]:
337+ session_data ["session_data" ]["control_plane" ] = {
338+ "reward" : 0.0 ,
339+ "terminated" : False ,
340+ "truncated" : False ,
341+ "info" : {},
342+ "step_count" : 0 ,
343+ "total_reward" : 0.0 ,
344+ }
373345
374- return session_data ["session_data" ]["control_plane" ]
346+ return session_data ["session_data" ]["control_plane" ]
375347
376348 def _update_session_control_plane (
377349 self ,
@@ -396,13 +368,6 @@ def _update_session_control_plane(
396368 f"🎛️ Session { session_id [:16 ]} ... control plane: reward={ reward } , terminated={ terminated } , step={ control_plane ['step_count' ]} , total_reward={ control_plane ['total_reward' ]} "
397369 )
398370
399- def get_control_plane_state (self , session_id : str ) -> Optional [Dict [str , Any ]]:
400- """Get control plane state for a specific session (for rollout system)."""
401- with self .session_lock :
402- if session_id in self .sessions :
403- return self ._get_or_create_session_control_plane (session_id ).copy ()
404- return None
405-
406371 def _execute_environment_step (self , action_int : int ) -> Dict [str , Any ]:
407372 """
408373 Execute environment step and update control plane (single session).
@@ -510,11 +475,11 @@ def get_info_endpoint(self, session_data: Dict[str, Any]) -> Dict[str, Any]:
510475 return control_plane .get ("info" , {})
511476
512477 @control_plane_endpoint ("/control/initial_state" )
513- def get_initial_state_endpoint (self , session_data : Dict [str , Any ]) -> Dict [str , Any ]:
478+ async def get_initial_state_endpoint (self , session_data : Dict [str , Any ]) -> Dict [str , Any ]:
514479 """Get initial state for this session."""
480+ session_id = session_data .get ("session_id" , "unknown" )
515481 env = session_data .get ("env" )
516482 obs = session_data .get ("obs" )
517-
518483 if env and obs is not None :
519484 try :
520485 formatted_obs = self .format_observation (obs , env )
@@ -604,8 +569,8 @@ async def run_with_high_concurrency():
604569 proxy_headers = True ,
605570 forwarded_allow_ips = "*" ,
606571 # HIGH CONCURRENCY SETTINGS
607- limit_concurrency = 200 , # Increase for HTTP endpoints + MCP
608- limit_max_requests = 100000 , # Higher request limit
572+ limit_concurrency = None , # Increase for HTTP endpoints + MCP
573+ limit_max_requests = None , # Higher request limit
609574 timeout_keep_alive = 120 , # Longer keep-alive for control plane
610575 timeout_notify = 180 ,
611576 h11_max_incomplete_event_size = 4 * 1024 * 1024 , # Handle larger events
@@ -624,11 +589,6 @@ def _to_json_serializable(self, obj: Any) -> Any:
624589 Handles Pydantic models, dataclasses, lists, dicts, and primitive types.
625590 This is a utility method that can be used by format_observation implementations.
626591 """
627- import dataclasses
628- from datetime import date , datetime
629- from enum import Enum
630-
631- from pydantic import BaseModel
632592
633593 # Handle None and primitive types
634594 if obj is None or isinstance (obj , (str , int , float , bool )):
0 commit comments