2424from concurrent .futures import ThreadPoolExecutor
2525from datetime import date , datetime
2626from enum import Enum
27- from typing import Any , Callable , Dict , Optional , Tuple
27+ from typing import Any , Callable , Dict , Optional , Tuple , Literal , cast
2828
2929import uvicorn
3030from mcp .server .fastmcp import Context , FastMCP
@@ -146,22 +146,20 @@ def _get_session_id(self, ctx: Context) -> str:
146146 print (f"🔍 _get_session_id: hasattr(ctx, 'session'): { hasattr (ctx , 'session' )} " )
147147
148148 # Use stable session ID based on client info (following simulation_server.py pattern)
149- if hasattr (ctx , "session" ) and hasattr ( ctx . session , "client_params" ) :
150- client_params = ctx .session . client_params
149+ if hasattr (ctx , "session" ):
150+ client_params = getattr ( ctx .session , " client_params" , None )
151151 print (f"🔍 _get_session_id: client_params type: { type (client_params )} " )
152- print (f"🔍 _get_session_id: hasattr(client_params, 'clientInfo'): { hasattr (client_params , 'clientInfo' )} " )
153-
154- if hasattr (client_params , "clientInfo" ):
155- client_info = client_params .clientInfo
152+ if client_params is not None and hasattr (client_params , "clientInfo" ):
153+ client_info = getattr (client_params , "clientInfo" , None )
156154 print (f"🔍 _get_session_id: client_info: { client_info } " )
157- print (f"🔍 _get_session_id: hasattr(client_info, '_extra'): { hasattr (client_info , '_extra' )} " )
158155
159- if client_info and hasattr (client_info , "_extra" ):
160- extra_data = client_info ._extra
156+ if client_info is not None :
157+ # Access private _extra with a cast to satisfy type checker
158+ extra_data = cast (Any , getattr (client_info , "_extra" , None ))
161159 print (f"🔍 _get_session_id: extra_data: { extra_data } " )
162160 print (f"🔍 _get_session_id: extra_data type: { type (extra_data )} " )
163161
164- if extra_data and isinstance (extra_data , dict ):
162+ if isinstance (extra_data , dict ):
165163 # use the client generated session id
166164 if "session_id" in extra_data :
167165 print (f"🔍 _get_session_id: using client generated session_id: { extra_data ['session_id' ]} " )
@@ -181,8 +179,8 @@ def _get_session_id(self, ctx: Context) -> str:
181179 "config" : config_value ,
182180 "dataset_row_id" : dataset_row_id_value ,
183181 "model_id" : model_id_value ,
184- "name" : client_info . name ,
185- "version" : client_info . version ,
182+ "name" : getattr ( client_info , " name" , None ) ,
183+ "version" : getattr ( client_info , " version" , None ) ,
186184 }
187185
188186 print (f"🔍 _get_session_id: stable_data: { stable_data } " )
@@ -205,6 +203,15 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
205203 """
206204 session_id = self ._get_session_id (ctx )
207205 print (f"🔍 _get_or_create_session: session_id: { session_id } " )
206+ if session_id not in self .sessions :
207+ env , obs , info = self ._new_env (seed = None )
208+ with self .session_lock :
209+ self .sessions [session_id ] = {
210+ "env" : env ,
211+ "obs" : obs ,
212+ "session_data" : {},
213+ "session_id" : session_id ,
214+ }
208215 return self .sessions [session_id ]
209216
210217 def _register_session_reset_endpoint (self ):
@@ -400,6 +407,15 @@ def _execute_session_environment_step(self, session_id: str, action: Any) -> Dic
400407 Returns:
401408 Data plane response (observation only, no rewards)
402409 """
410+ if session_id not in self .sessions :
411+ env , obs , info = self ._new_env (seed = None )
412+ with self .session_lock :
413+ self .sessions [session_id ] = {
414+ "env" : env ,
415+ "obs" : obs ,
416+ "session_data" : {},
417+ "session_id" : session_id ,
418+ }
403419 session_data = self .sessions [session_id ]
404420 env = session_data ["env" ]
405421
@@ -558,7 +574,8 @@ async def run_with_high_concurrency():
558574 if not kwargs .get ("redirect_slashes" , True ) and hasattr (starlette_app , "router" ):
559575 starlette_app .router .redirect_slashes = False
560576
561- starlette_app .add_middleware (ProxyHeadersMiddleware , trusted_hosts = "*" )
577+ # Add middleware with proper type cast to satisfy basedpyright
578+ starlette_app .add_middleware (cast (Any , ProxyHeadersMiddleware ), trusted_hosts = "*" )
562579
563580 config = uvicorn .Config (
564581 starlette_app ,
@@ -580,7 +597,15 @@ async def run_with_high_concurrency():
580597 asyncio .run (run_with_high_concurrency ())
581598 else :
582599 # Use default FastMCP run for other transports
583- self .mcp .run (transport = transport , ** kwargs )
600+ # Constrain transport to the allowed literal values for type checker
601+ allowed_transport : Literal ["stdio" , "sse" , "streamable-http" ]
602+ if transport in ("stdio" , "sse" , "streamable-http" ):
603+ allowed_transport = cast (Literal ["stdio" , "sse" , "streamable-http" ], transport )
604+ else :
605+ # Default to streamable-http if unknown
606+ allowed_transport = cast (Literal ["stdio" , "sse" , "streamable-http" ], "streamable-http" )
607+
608+ self .mcp .run (transport = allowed_transport , ** kwargs )
584609
585610 def _to_json_serializable (self , obj : Any ) -> Any :
586611 """Convert any object to JSON-serializable format.
@@ -607,7 +632,8 @@ def _to_json_serializable(self, obj: Any) -> Any:
607632
608633 # Handle dataclasses
609634 elif dataclasses .is_dataclass (obj ):
610- return dataclasses .asdict (obj )
635+ # Cast for type checker because protocol uses ClassVar on __dataclass_fields__
636+ return dataclasses .asdict (cast (Any , obj ))
611637
612638 # Handle dictionaries
613639 elif isinstance (obj , dict ):
0 commit comments