Skip to content

Commit a4cb45b

Browse files
Fix critical issues identified in PR review
This commit addresses all critical and important issues found during the comprehensive PR review by three specialized agents.

## Critical Fixes

1. **Fixed RobotState immutability conflict** (multi_robot_coordinator.py)
   - Removed numpy array immutability that conflicted with `update_robot_state()`
   - Arrays are now mutable to allow state updates as intended
   - Updated docstring to clarify arrays remain mutable

2. **Fixed CoordinatedTask field definition** (multi_robot_coordinator.py)
   - Moved `completion_callback` field before the `__post_init__` method
   - Was incorrectly defined after `__post_init__`, causing syntax error

3. **Fixed state extraction in RL environment** (rl_integration.py)
   - Changed from `response.get("state", {}).get("qpos")` to `response.get("qpos")`
   - Server returns qpos/qvel directly in response, not nested under "state"
   - Added comment explaining the correct structure

4. **Added missing logger** (rl_integration.py)
   - Added module-level `logger = logging.getLogger(__name__)`
   - Fixes NameError when `logger.error()` was called at line 692

5. **Added error handling to `ModelViewer.__init__`** (mujoco_viewer_server.py)
   - Wrapped model loading in try/except with specific error types
   - Added context-rich error messages for debugging
   - Handles FileNotFoundError and generic model loading errors separately
   - Added error handling for MjData creation and viewer launch

6. **Replaced dangerous BaseException suppression** (mujoco_viewer_server.py)
   - Replaced `contextlib.suppress(BaseException)` with specific exception types
   - Never suppresses KeyboardInterrupt or SystemExit
   - Added proper error logging for cleanup failures
   - Added a warning when the simulation thread doesn't terminate within the timeout

## Important Fixes

7. **Added thread safety to `_handle_ping`** (mujoco_viewer_server.py)
   - Acquire `viewer_lock` before accessing `current_viewer` and `current_model_id`
   - Prevents race conditions with concurrent model loading

8. **Improved exception handling in `handle_command`** (mujoco_viewer_server.py)
   - Distinguish between expected errors (KeyError, ValueError, TypeError)
   - Handle RuntimeError separately (expected runtime failures)
   - Log unexpected exceptions with full stack traces
   - Clearer error messages that separate invalid user input from internal bugs

9. **Fixed connection state in viewer_client** (viewer_client.py)
   - Mark connection as failed for JSONDecodeError and UnicodeDecodeError
   - Previously only OSError marked connection as failed
   - Prevents continued attempts to use corrupted connections
   - Updated docstring to reflect ValueError instead of JSONDecodeError

## Impact

These fixes address:
- 2 critical bugs that would cause runtime failures
- 1 syntax error in dataclass definition
- 1 missing module-level logger causing NameError
- 2 dangerous exception handling patterns
- 3 thread safety and error handling improvements

All changes preserve functionality while significantly improving:
- Error handling robustness
- Thread safety
- Error message clarity
- Connection state consistency

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent eff6240 commit a4cb45b

4 files changed

Lines changed: 69 additions & 31 deletions

File tree

mujoco_viewer_server.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,34 @@ def __init__(self, model_id: str, model_source: str):
4646
self.created_time = time.time()
4747

4848
# Load model - supports file path or XML string
49-
if os.path.exists(model_source):
50-
# If it's a file path, use from_xml_path to load
51-
# (so relative paths are resolved correctly)
52-
self.model = mujoco.MjModel.from_xml_path(model_source)
53-
else:
54-
# Otherwise assume it's an XML string
55-
self.model = mujoco.MjModel.from_xml_string(model_source)
49+
# Paths are resolved relative to the XML file's directory
50+
try:
51+
if os.path.exists(model_source):
52+
logger.info(f"Loading model {model_id} from file: {model_source}")
53+
self.model = mujoco.MjModel.from_xml_path(model_source)
54+
else:
55+
logger.info(f"Loading model {model_id} from XML string")
56+
self.model = mujoco.MjModel.from_xml_string(model_source)
57+
except FileNotFoundError as e:
58+
logger.error(f"Model file not found for {model_id}: {model_source}")
59+
raise RuntimeError(f"Failed to load model {model_id}: file not found at {model_source}") from e
60+
except Exception as e:
61+
logger.error(f"Failed to load MuJoCo model {model_id}: {e}")
62+
raise RuntimeError(f"Failed to load model {model_id}: {e}") from e
5663

57-
self.data = mujoco.MjData(self.model)
64+
# Create simulation data
65+
try:
66+
self.data = mujoco.MjData(self.model)
67+
except Exception as e:
68+
logger.error(f"Failed to create MjData for model {model_id}: {e}")
69+
raise RuntimeError(f"Failed to initialize simulation data for {model_id}: {e}") from e
5870

5971
# Start viewer
60-
self.viewer = mujoco.viewer.launch_passive(self.model, self.data)
72+
try:
73+
self.viewer = mujoco.viewer.launch_passive(self.model, self.data)
74+
except Exception as e:
75+
logger.error(f"Failed to launch viewer for model {model_id}: {e}")
76+
raise RuntimeError(f"Failed to start viewer for {model_id}: {e}") from e
6177

6278
# Start simulation loop
6379
self.simulation_running = True
@@ -107,13 +123,25 @@ def close(self):
107123
self.viewer.close()
108124
elif hasattr(self.viewer, "_window") and self.viewer._window:
109125
# For older MuJoCo versions, try to close the window directly
110-
with contextlib.suppress(builtins.BaseException):
126+
try:
111127
self.viewer._window.close()
128+
except (AttributeError, RuntimeError) as e:
129+
logger.debug(f"Failed to close viewer window for {self.model_id}: {e}")
130+
112131
# Wait for simulation thread to finish
113132
if hasattr(self, "sim_thread") and self.sim_thread.is_alive():
114133
self.sim_thread.join(timeout=2.0)
115-
except Exception as e:
134+
if self.sim_thread.is_alive():
135+
logger.warning(f"Simulation thread for {self.model_id} did not terminate within timeout")
136+
except KeyboardInterrupt:
137+
# Never suppress user interrupts
138+
raise
139+
except (AttributeError, RuntimeError, OSError) as e:
140+
# Expected errors during cleanup
116141
logger.warning(f"Error closing viewer for {self.model_id}: {e}")
142+
except Exception as e:
143+
# Unexpected errors should be logged as errors
144+
logger.error(f"Unexpected error closing viewer for {self.model_id}: {e}")
117145
finally:
118146
self.viewer = None
119147
logger.info(f"Closed ModelViewer for {self.model_id}")
@@ -160,11 +188,21 @@ def handle_command(self, command: Dict[str, Any]) -> Dict[str, Any]:
160188
handler = self._command_handlers.get(cmd_type)
161189
if handler:
162190
return handler(command)
191+
logger.warning(f"Unknown command type received: {cmd_type}")
163192
return {"success": False, "error": f"Unknown command: {cmd_type}"}
164193

165-
except Exception as e:
166-
logger.exception(f"Error handling command {cmd_type}: {e}")
194+
except (KeyError, ValueError, TypeError) as e:
195+
# Expected parameter validation errors
196+
logger.warning(f"Invalid parameters for command {cmd_type}: {e}")
197+
return {"success": False, "error": f"Invalid parameters: {e}"}
198+
except RuntimeError as e:
199+
# Expected runtime errors (model loading failures, etc.)
200+
logger.error(f"Runtime error handling command {cmd_type}: {e}")
167201
return {"success": False, "error": str(e)}
202+
except Exception as e:
203+
# Unexpected errors - these indicate bugs
204+
logger.exception(f"Unexpected error handling command {cmd_type}: {e}")
205+
return {"success": False, "error": f"Internal server error: {str(e)}"}
168206

169207
def _check_viewer_available(self, model_id: str | None) -> Dict[str, Any] | None:
170208
"""Check if viewer is available for the given model. Returns error dict or None if OK."""
@@ -286,12 +324,14 @@ def _handle_list_models(self, command: Dict[str, Any]) -> Dict[str, Any]:
286324

287325
def _handle_ping(self, command: Dict[str, Any]) -> Dict[str, Any]:
288326
"""Ping the server."""
289-
models_count = 1 if self.current_viewer else 0
327+
with self.viewer_lock:
328+
models_count = 1 if self.current_viewer else 0
329+
current_model = self.current_model_id
290330
return {
291331
"success": True,
292332
"pong": True,
293333
"models_count": models_count,
294-
"current_model": self.current_model_id,
334+
"current_model": current_model,
295335
"server_running": self.running,
296336
"server_info": {
297337
"version": "0.7.4",

src/mujoco_mcp/multi_robot_coordinator.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,16 @@ class RobotState:
5959
last_update: float = field(default_factory=time.time)
6060

6161
def __post_init__(self):
62-
"""Validate robot state dimensions and make arrays immutable."""
62+
"""Validate robot state dimensions.
63+
64+
Note: Arrays are kept mutable to allow state updates via update_robot_state().
65+
"""
6366
if len(self.joint_positions) != len(self.joint_velocities):
6467
raise ValueError(
6568
f"joint_positions length ({len(self.joint_positions)}) must match "
6669
f"joint_velocities length ({len(self.joint_velocities)})"
6770
)
6871

69-
# Make numpy arrays immutable
70-
self.joint_positions.flags.writeable = False
71-
self.joint_velocities.flags.writeable = False
72-
if self.end_effector_pos is not None:
73-
self.end_effector_pos.flags.writeable = False
74-
if self.end_effector_vel is not None:
75-
self.end_effector_vel.flags.writeable = False
76-
7772
def is_stale(self, timeout: float = 1.0) -> bool:
7873
"""Check if state is stale"""
7974
return time.time() - self.last_update > timeout
@@ -91,14 +86,14 @@ class CoordinatedTask:
9186
timeout: float = 30.0
9287
status: TaskStatus = TaskStatus.PENDING
9388
start_time: float | None = None
89+
completion_callback: Callable | None = None
9490

9591
def __post_init__(self):
9692
"""Validate coordinated task parameters."""
9793
if not self.robots:
9894
raise ValueError("robots list cannot be empty")
9995
if self.timeout <= 0:
10096
raise ValueError(f"timeout must be positive, got {self.timeout}")
101-
completion_callback: Callable | None = None
10297

10398

10499
class CollisionChecker:

src/mujoco_mcp/rl_integration.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from .viewer_client import MuJoCoViewerClient
2020
from .sensor_feedback import SensorManager
2121

22+
logger = logging.getLogger(__name__)
23+
2224

2325
class ActionSpaceType(Enum):
2426
"""Types of action spaces for RL environments."""
@@ -669,9 +671,9 @@ def _get_observation(self) -> np.ndarray:
669671
response = self.viewer_client.send_command({"type": "get_state", "model_id": self.model_id})
670672

671673
if response.get("success"):
672-
state = response.get("state", {})
673-
qpos = np.array(state.get("qpos", []))
674-
qvel = np.array(state.get("qvel", []))
674+
# Extract qpos and qvel directly from response (not nested under "state")
675+
qpos = np.array(response.get("qpos", []))
676+
qvel = np.array(response.get("qvel", []))
675677

676678
# Combine position and velocity
677679
observation = np.concatenate([qpos, qvel])

src/mujoco_mcp/viewer_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ def send_command(self, command: Dict[str, Any]) -> Dict[str, Any]:
9393
9494
Raises:
9595
ConnectionError: If not connected to viewer server.
96-
ValueError: If response is too large (>1MB).
97-
json.JSONDecodeError: If response is not valid JSON.
96+
ValueError: If response is too large (>1MB), not valid JSON, or cannot be decoded as UTF-8.
9897
OSError: If socket communication fails.
9998
"""
10099
if not self.connected or not self.socket:
@@ -134,9 +133,11 @@ def send_command(self, command: Dict[str, Any]) -> Dict[str, Any]:
134133
raise OSError(f"Failed to communicate with viewer server: {e}") from e
135134
except json.JSONDecodeError as e:
136135
logger.exception(f"Invalid JSON response: {e}")
137-
raise
136+
self.connected = False # Connection is likely corrupted
137+
raise ValueError(f"Server returned invalid JSON: {e}") from e
138138
except UnicodeDecodeError as e:
139139
logger.exception(f"Response decode error: {e}")
140+
self.connected = False # Connection is likely corrupted
140141
raise ValueError(f"Failed to decode server response as UTF-8: {e}") from e
141142

142143
def ping(self) -> bool:

0 commit comments

Comments
 (0)