Skip to content

Commit 661f950

Browse files
authored
Merge pull request #9 from reward-protocol/derekx/refactor-environment
Environment Creation moved from GymProductionServer to McpGym
2 parents 009e3b8 + 8ba3a1e commit 661f950

4 files changed

Lines changed: 24 additions & 45 deletions

File tree

examples/lunar_lander_mcp/mcp_server/lunar_lander_mcp_server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from reward_kit.mcp import GymProductionServer
3636

37+
# TODO: FAST FOLLOW. refactor this entire file to use McpGym, leaving logic below incorrect for now.
3738

3839
class LunarLanderProdServer(GymProductionServer):
3940
"""LunarLander production server with visual rendering support."""
@@ -188,4 +189,4 @@ def main():
188189

189190

190191
if __name__ == "__main__":
191-
main()
192+
main()

examples/taxi_mcp_complete/mcp_server/taxi_mcp_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
from mcp.server.fastmcp import Context
2626
from taxi_adapter import TaxiAdapter
2727

28-
from reward_protocol.mcp import GymProductionServer
28+
from reward_kit.mcp import GymProductionServer
2929

30+
# TODO: FAST FOLLOW. refactor this entire file to use McpGym, leaving logic below incorrect for now.
3031

3132
class TaxiProdServer(GymProductionServer):
3233
"""Taxi production server using unified framework."""
@@ -53,6 +54,7 @@ def taxi_move(action: str, ctx: Context) -> Dict[str, Any]:
5354
"""
5455
# Extract seed from client info and reinitialize if needed
5556
self.extract_seed_from_context(ctx)
57+
# note for later, should be: session_data = self._get_or_create_session(ctx)
5658

5759
# Validate action
5860
if not action or not isinstance(action, str):
@@ -169,4 +171,4 @@ def main():
169171

170172

171173
if __name__ == "__main__":
172-
main()
174+
main()

reward_kit/mcp/gym_production_server.py

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ def __init__(self, name: str, adapter: EnvironmentAdapter):
5454
"""
5555
self.adapter = adapter
5656

57-
# For backward compatibility, keep single-session support
58-
self.env, self.obs, _info = self._new_env()
59-
6057
# Multi-session support
6158
self.sessions = (
6259
{}
@@ -74,20 +71,6 @@ def __init__(self, name: str, adapter: EnvironmentAdapter):
7471
self._register_resources()
7572
self._register_tools()
7673

77-
def _new_env(self, seed: Optional[int] = None) -> Tuple[Any, Any, Dict]:
78-
"""Create new environment and return initial state."""
79-
if hasattr(self.adapter, "create_environment_with_seed"):
80-
env, obs, info = self.adapter.create_environment_with_seed(
81-
self.adapter.get_default_config(), seed=seed
82-
)
83-
else:
84-
env = self.adapter.create_environment(self.adapter.get_default_config())
85-
obs, info = self.adapter.reset_environment(env, seed=seed)
86-
return env, obs, info
87-
88-
def _render(self, obs) -> Dict[str, Any]:
89-
"""Format observation using subclass implementation."""
90-
return self.format_observation(obs, self.env)
9174

9275
def _register_resources(self):
9376
"""Register standard MCP resources."""
@@ -250,28 +233,6 @@ def _get_or_create_session(self, ctx: Context) -> Dict[str, Any]:
250233

251234
return self.sessions[session_id]
252235

253-
def extract_seed_from_context(self, ctx: Context) -> Optional[int]:
254-
"""
255-
Extract seed from MCP client info if available.
256-
257-
NOTE: This method is kept for backward compatibility. New code should use
258-
_get_or_create_session() which handles seed extraction automatically.
259-
"""
260-
if hasattr(ctx, "session") and hasattr(ctx.session, "client_params"):
261-
client_params = ctx.session.client_params
262-
if hasattr(client_params, "clientInfo"):
263-
client_info = client_params.clientInfo
264-
if client_info and hasattr(client_info, "_extra"):
265-
extra_data = client_info._extra
266-
if extra_data and isinstance(extra_data, dict):
267-
seed = extra_data.get("seed")
268-
if seed is not None:
269-
print(f"🌱 Reinitializing with seed from client: {seed}")
270-
self.env, self.obs, _info = self._new_env(seed=seed)
271-
return seed
272-
273-
return None
274-
275236
# Abstract methods that subclasses must implement
276237

277238
@abstractmethod
@@ -293,4 +254,4 @@ def run(self, transport: str = "streamable-http", **kwargs):
293254
print("🔗 Initial state resource: game://initial_state")
294255

295256
# Run the server
296-
self.mcp.run(transport=transport, **kwargs)
257+
self.mcp.run(transport=transport, **kwargs)

reward_kit/mcp/mcpgym.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ def __init__(
9898
}
9999

100100
# Reset with seed if provided
101-
if seed is not None:
102-
self.env, self.obs, _info = self._new_env(seed=seed)
101+
self.env, self.obs, _info = self._new_env(seed=seed)
103102

104103
# Discover and register control plane endpoints
105104
self._discover_and_register_control_plane_endpoints()
@@ -434,3 +433,19 @@ def format_observation(obs: Any, env: Any) -> Dict[str, Any]:
434433
Formatted observation dictionary (DATA PLANE ONLY)
435434
"""
436435
pass
436+
437+
def _new_env(self, seed: Optional[int] = None) -> Tuple[Any, Any, Dict]:
438+
"""Create new environment and return initial state."""
439+
config = self.adapter.get_default_config()
440+
441+
try:
442+
env, obs, info = self.adapter.create_environment_with_seed(config, seed=seed)
443+
except AttributeError:
444+
env = self.adapter.create_environment(config)
445+
obs, info = self.adapter.reset_environment(env, seed=seed)
446+
447+
return env, obs, info
448+
449+
def _render(self, obs) -> Dict[str, Any]:
450+
"""Format observation using subclass implementation."""
451+
return self.format_observation(obs, self.env)

0 commit comments

Comments
 (0)