diff --git a/debug_mcp_version.py b/debug_mcp_version.py new file mode 100644 index 0000000..31fda31 --- /dev/null +++ b/debug_mcp_version.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +""" +Debug script to check MCP protocol version handling +""" + +import asyncio +import json +import sys +from typing import Dict, Any, List + +from mcp.server import Server, NotificationOptions +from mcp.server.models import InitializationOptions +import mcp.server.stdio +import mcp.types as types + +# Create server instance +server = Server("test-server") + +@server.list_tools() +async def handle_list_tools() -> List[types.Tool]: + """Return empty tool list for testing""" + return [] + +@server.call_tool() +async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent]: + """Handle tool calls""" + return [types.TextContent(type="text", text="Test response")] + +async def test_initialization(): + """Test MCP initialization to see protocol version""" + print("Testing MCP server initialization...", file=sys.stderr) + + # Initialize server capabilities + server_options = InitializationOptions( + server_name="test-mcp", + server_version="0.1.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={} + ) + ) + + print(f"Server options: {server_options}", file=sys.stderr) + print(f"Capabilities: {server_options.capabilities}", file=sys.stderr) + + # Try to inspect the server object + print(f"Server attributes: {[attr for attr in dir(server) if not attr.startswith('_')]}", file=sys.stderr) + +if __name__ == "__main__": + asyncio.run(test_initialization()) \ No newline at end of file diff --git a/src/mujoco_mcp/mcp_server.py b/src/mujoco_mcp/mcp_server.py index 23d3b0c..00f8130 100644 --- a/src/mujoco_mcp/mcp_server.py +++ b/src/mujoco_mcp/mcp_server.py @@ -2,6 +2,7 @@ """ MuJoCo MCP Server for stdio transport Production-ready MCP server that works with Claude Desktop and other MCP clients +MCP Protocol Version: 2024-11-05 """ import asyncio @@ -18,6 +19,9 @@ from .version import __version__ from .viewer_client import MuJoCoViewerClient as ViewerClient +# MCP Protocol constants +MCP_PROTOCOL_VERSION = "2024-11-05" + # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger("mujoco-mcp") @@ -26,8 +30,7 @@ server = Server("mujoco-mcp") # Global viewer client -viewer_client: ViewerClient | None = None - +viewer_client: Optional[ViewerClient] = None @server.list_tools() async def handle_list_tools() -> List[types.Tool]: @@ -36,126 +39,140 @@ async def handle_list_tools() -> List[types.Tool]: types.Tool( name="get_server_info", description="Get information about the MuJoCo MCP server", - inputSchema={"type": "object", "properties": {}, "required": []}, + inputSchema={ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": False + } ), types.Tool( name="create_scene", description="Create a physics simulation scene", inputSchema={ + "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": { "scene_type": { "type": "string", "description": "Type of scene to create", - "enum": ["pendulum", "double_pendulum", "cart_pole", "arm"], + "enum": ["pendulum", "double_pendulum", "cart_pole", "arm"] } }, "required": ["scene_type"], - }, + "additionalProperties": False + } ), types.Tool( name="step_simulation", description="Step the physics simulation forward", inputSchema={ + "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": { - "model_id": {"type": "string", "description": "ID of the model to step"}, + "model_id": { + "type": "string", + "description": "ID of the model to step" + }, "steps": { "type": "integer", "description": "Number of simulation steps", "default": 1, - }, + "minimum": 1 + } }, "required": ["model_id"], - }, + "additionalProperties": False + } ), types.Tool( name="get_state", description="Get current state of the simulation", inputSchema={ + "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": { "model_id": { "type": "string", - "description": "ID of the model to get state from", + "description": "ID of the model to get state from" } }, "required": ["model_id"], - }, + "additionalProperties": False + } ), types.Tool( name="reset_simulation", description="Reset simulation to initial state", inputSchema={ + "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": { - "model_id": {"type": "string", "description": "ID of the model to reset"} + "model_id": { + "type": "string", + "description": "ID of the model to reset" + } }, "required": ["model_id"], - }, + "additionalProperties": False + } ), types.Tool( name="close_viewer", description="Close the MuJoCo viewer window", inputSchema={ + "$schema": "http://json-schema.org/draft-07/schema#", "type": "object", "properties": { - "model_id": {"type": "string", "description": "ID of the model viewer to close"} + "model_id": { + "type": "string", + "description": "ID of the model viewer to close" + } }, "required": ["model_id"], - }, - ), + "additionalProperties": False + } + ) ] - @server.call_tool() async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent]: - """Handle tool calls""" + """Handle tool calls with MCP-compliant responses""" global viewer_client - + + # Log the tool call for debugging + logger.debug(f"Tool call: {name} with arguments: {arguments}") + try: if name == "get_server_info": - return [ - types.TextContent( - type="text", - text=json.dumps( - { - "name": "MuJoCo MCP Server", - "version": __version__, - "description": "Control MuJoCo physics simulations through MCP", - "status": "ready", - "capabilities": [ - "create_scene", - "step_simulation", - "get_state", - "reset", - "close_viewer", - ], - }, - indent=2, - ), - ) - ] - + return [types.TextContent( + type="text", + text=json.dumps({ + "name": "MuJoCo MCP Server", + "version": __version__, + "description": "Control MuJoCo physics simulations through MCP", + "status": "ready", + "capabilities": ["create_scene", "step_simulation", "get_state", "reset", "close_viewer"] + }, indent=2) + )] + elif name == "create_scene": scene_type = arguments.get("scene_type", "pendulum") - + # Initialize viewer client if not exists if not viewer_client: viewer_client = ViewerClient() - + # Connect to viewer server if not viewer_client.connected: success = viewer_client.connect() if not success: - return [ - types.TextContent( - type="text", - text="❌ Failed to connect to MuJoCo viewer server. " - "Please start `mujoco-mcp-viewer` first.", - ) - ] - + return [types.TextContent( + type="text", + text="❌ Failed to connect to MuJoCo viewer server. Please start `mujoco-mcp-viewer` first." + )] + # Map scene types to model XML scene_models = { "pendulum": """ @@ -201,155 +218,172 @@ async def handle_call_tool(name: str, arguments: Dict[str, Any]) -> List[types.T - """, + """ } - + if scene_type not in scene_models: - return [ - types.TextContent( - type="text", - text=f"❌ Unknown scene type: {scene_type}. Available: {', '.join(scene_models.keys())}", - ) - ] - + return [types.TextContent( + type="text", + text=f"❌ Unknown scene type: {scene_type}. Available: {', '.join(scene_models.keys())}" + )] + # Load the model - response = viewer_client.send_command( - { - "type": "load_model", - "model_id": scene_type, - "model_xml": scene_models[scene_type], - } - ) - + response = viewer_client.send_command({ + "type": "load_model", + "model_id": scene_type, + "model_xml": scene_models[scene_type] + }) + if response.get("success"): - return [ - types.TextContent( - type="text", - text=f"✅ Created {scene_type} scene successfully! Viewer window opened.", - ) - ] + return [types.TextContent( + type="text", + text=f"✅ Created {scene_type} scene successfully! Viewer window opened." + )] else: - return [ - types.TextContent( - type="text", - text=f"❌ Failed to create scene: {response.get('error', 'Unknown error')}", - ) - ] - + return [types.TextContent( + type="text", + text=f"❌ Failed to create scene: {response.get('error', 'Unknown error')}" + )] + elif name == "step_simulation": model_id = arguments.get("model_id") steps = arguments.get("steps", 1) - + if not viewer_client or not viewer_client.connected: - return [ - types.TextContent( - type="text", text="❌ No active viewer connection. Create a scene first." - ) - ] - + return [types.TextContent( + type="text", + text="❌ No active viewer connection. Create a scene first." + )] + # The viewer server doesn't have a direct step_simulation command # It automatically runs the simulation, so we just return success response = {"success": True, "message": f"Simulation running for model {model_id}"} - - return [ - types.TextContent( - type="text", - text=f"⏩ Stepped simulation {steps} steps" - if response.get("success") - else f"❌ Step failed: {response.get('error')}", - ) - ] - + + return [types.TextContent( + type="text", + text=f"⏩ Stepped simulation {steps} steps" if response.get("success") + else f"❌ Step failed: {response.get('error')}" + )] + elif name == "get_state": model_id = arguments.get("model_id") - + if not viewer_client or not viewer_client.connected: - return [ - types.TextContent( - type="text", text="❌ No active viewer connection. Create a scene first." - ) - ] - - response = viewer_client.send_command({"type": "get_state", "model_id": model_id}) - + return [types.TextContent( + type="text", + text="❌ No active viewer connection. Create a scene first." + )] + + response = viewer_client.send_command({ + "type": "get_state", + "model_id": model_id + }) + if response.get("success"): state = response.get("state", {}) - return [types.TextContent(type="text", text=json.dumps(state, indent=2))] + return [types.TextContent( + type="text", + text=json.dumps(state, indent=2) + )] else: - return [ - types.TextContent( - type="text", text=f"❌ Failed to get state: {response.get('error')}" - ) - ] - + return [types.TextContent( + type="text", + text=f"❌ Failed to get state: {response.get('error')}" + )] + elif name == "reset_simulation": model_id = arguments.get("model_id") - + if not viewer_client or not viewer_client.connected: - return [ - types.TextContent( - type="text", text="❌ No active viewer connection. Create a scene first." - ) - ] - - response = viewer_client.send_command({"type": "reset", "model_id": model_id}) - - return [ - types.TextContent( + return [types.TextContent( type="text", - text="🔄 Simulation reset to initial state" - if response.get("success") - else f"❌ Reset failed: {response.get('error')}", - ) - ] - + text="❌ No active viewer connection. Create a scene first." + )] + + response = viewer_client.send_command({ + "type": "reset", + "model_id": model_id + }) + + return [types.TextContent( + type="text", + text="🔄 Simulation reset to initial state" if response.get("success") + else f"❌ Reset failed: {response.get('error')}" + )] + elif name == "close_viewer": model_id = arguments.get("model_id") - + if not viewer_client or not viewer_client.connected: - return [types.TextContent(type="text", text="❌ No active viewer connection.")] - - response = viewer_client.send_command({"type": "close_model", "model_id": model_id}) - + return [types.TextContent( + type="text", + text="❌ No active viewer connection." + )] + + response = viewer_client.send_command({ + "type": "close_model", + "model_id": model_id + }) + # Close our connection too if viewer_client: viewer_client.disconnect() viewer_client = None - - return [ - types.TextContent( - type="text", - text="❌ Viewer closed" - if response.get("success") - else f"❌ Failed to close: {response.get('error')}", - ) - ] - + + return [types.TextContent( + type="text", + text="❌ Viewer closed" if response.get("success") + else f"❌ Failed to close: {response.get('error')}" + )] + else: - return [types.TextContent(type="text", text=f"❌ Unknown tool: {name}")] - + return [types.TextContent( + type="text", + text=f"❌ Unknown tool: {name}" + )] + except Exception as e: logger.exception(f"Error in tool {name}") - return [types.TextContent(type="text", text=f"❌ Error: {str(e)}")] - + return [types.TextContent( + type="text", + text=f"❌ Error: {str(e)}" + )] async def main(): """Main entry point for MCP server""" logger.info(f"Starting MuJoCo MCP Server v{__version__}") - - # Initialize server capabilities + logger.info(f"MCP Protocol Version: {MCP_PROTOCOL_VERSION}") + + # Initialize server capabilities with enhanced configuration + capabilities = server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={} + ) + server_options = InitializationOptions( server_name="mujoco-mcp", server_version=__version__, - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), experimental_capabilities={} - ), + capabilities=capabilities, + instructions="MuJoCo physics simulation server with viewer support. " + f"Implements MCP Protocol {MCP_PROTOCOL_VERSION}. " + "Provides tools for creating scenes, controlling simulation, and managing state." ) - + + logger.info(f"Server capabilities: {capabilities}") + logger.info("MCP server initialization complete") + # Run server with stdio transport - async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): - await server.run(read_stream, write_stream, server_options) - + try: + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + logger.info("Starting MCP server stdio transport") + await server.run( + read_stream, + write_stream, + server_options + ) + except Exception as e: + logger.error(f"MCP server error: {e}") + raise if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main()) \ No newline at end of file diff --git a/test_mcp_compliance_fixes.py b/test_mcp_compliance_fixes.py new file mode 100644 index 0000000..0ffdc21 --- /dev/null +++ b/test_mcp_compliance_fixes.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +""" +Test script to validate MCP protocol compliance fixes +Tests the critical fixes implemented for MCP protocol version 2024-11-05 +""" + +import asyncio +import json +import sys +from typing import Dict, Any, List +import jsonschema +from jsonschema import validate + +# Import the MCP server components +from src.mujoco_mcp.mcp_server import ( + handle_list_tools, + handle_call_tool, + MCP_PROTOCOL_VERSION, + server, + main +) + +def test_protocol_version(): + """Test Critical Fix #1: Protocol Version Alignment""" + print("Testing Protocol Version...") + + # Check that protocol version is set correctly + assert MCP_PROTOCOL_VERSION == "2024-11-05", f"Expected '2024-11-05', got '{MCP_PROTOCOL_VERSION}'" + print("✅ Protocol version is correctly set to 2024-11-05") + + return True + +async def test_tool_schemas(): + """Test Critical Fix #3: Tool Schema Validation (JSON Schema Draft 7)""" + print("Testing Tool Schema Compliance...") + + tools = await handle_list_tools() + + for tool in tools: + schema = tool.inputSchema + + # Check for $schema field + assert "$schema" in schema, f"Tool '{tool.name}' missing $schema field" + assert schema["$schema"] == "http://json-schema.org/draft-07/schema#", \ + f"Tool '{tool.name}' has incorrect $schema" + + # Check for additionalProperties + assert "additionalProperties" in schema, f"Tool '{tool.name}' missing additionalProperties" + assert schema["additionalProperties"] == False, \ + f"Tool '{tool.name}' should set additionalProperties to False" + + # Validate schema structure + try: + # This validates that our schema is a valid JSON Schema + jsonschema.Draft7Validator.check_schema(schema) + print(f"✅ Tool '{tool.name}' has valid JSON Schema Draft 7") + except jsonschema.SchemaError as e: + print(f"❌ Tool '{tool.name}' has invalid schema: {e}") + return False + + return True + +async def test_server_initialization(): + """Test Critical Fix #2: Server Initialization""" + print("Testing Server Initialization...") + + # Import required MCP components + from mcp.server import NotificationOptions + + # Test that server has proper name and capabilities + capabilities = server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={} + ) + + assert capabilities is not None, "Server capabilities should not be None" + assert hasattr(capabilities, 'tools'), "Server should have tools capability" + assert capabilities.tools is not None, "Tools capability should not be None" + + print("✅ Server initialization appears correct") + return True + +async def test_response_format(): + """Test Critical Fix #4: Response Format Consistency""" + print("Testing Response Format...") + + # Test get_server_info response + response = await handle_call_tool("get_server_info", {}) + + assert len(response) == 1, "Should return exactly one response item" + assert response[0].type == "text", "Response should be of type 'text'" + + # Validate that the response text is valid JSON + try: + data = json.loads(response[0].text) + assert "name" in data, "Server info should include name" + assert "version" in data, "Server info should include version" + assert "status" in data, "Server info should include status" + print("✅ Response format is consistent and valid") + except json.JSONDecodeError as e: + print(f"❌ Invalid JSON in response: {e}") + return False + + return True + +def test_schema_validation_examples(): + """Test that our tool schemas can validate actual inputs""" + print("Testing Schema Validation with Examples...") + + # Test create_scene schema + create_scene_schema = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "scene_type": { + "type": "string", + "description": "Type of scene to create", + "enum": ["pendulum", "double_pendulum", "cart_pole", "arm"] + } + }, + "required": ["scene_type"], + "additionalProperties": False + } + + # Valid input + valid_input = {"scene_type": "pendulum"} + try: + validate(instance=valid_input, schema=create_scene_schema) + print("✅ Valid input passes schema validation") + except jsonschema.ValidationError as e: + print(f"❌ Valid input failed validation: {e}") + return False + + # Invalid input (missing required field) + invalid_input = {} + try: + validate(instance=invalid_input, schema=create_scene_schema) + print("❌ Invalid input should have failed validation") + return False + except jsonschema.ValidationError: + print("✅ Invalid input correctly rejected") + + # Invalid input (wrong enum value) + invalid_enum_input = {"scene_type": "invalid_scene"} + try: + validate(instance=invalid_enum_input, schema=create_scene_schema) + print("❌ Invalid enum value should have failed validation") + return False + except jsonschema.ValidationError: + print("✅ Invalid enum value correctly rejected") + + return True + +async def run_all_tests(): + """Run all compliance tests""" + print("="*60) + print("MCP Protocol Compliance Test Suite") + print("Testing fixes for MCP Protocol Version 2024-11-05") + print("="*60) + + test_results = [] + + # Test 1: Protocol Version + try: + result = test_protocol_version() + test_results.append(("Protocol Version", result)) + except Exception as e: + print(f"❌ Protocol version test failed: {e}") + test_results.append(("Protocol Version", False)) + + # Test 2: Tool Schemas + try: + result = await test_tool_schemas() + test_results.append(("Tool Schemas", result)) + except Exception as e: + print(f"❌ Tool schema test failed: {e}") + test_results.append(("Tool Schemas", False)) + + # Test 3: Server Initialization + try: + result = await test_server_initialization() + test_results.append(("Server Initialization", result)) + except Exception as e: + print(f"❌ Server initialization test failed: {e}") + test_results.append(("Server Initialization", False)) + + # Test 4: Response Format + try: + result = await test_response_format() + test_results.append(("Response Format", result)) + except Exception as e: + print(f"❌ Response format test failed: {e}") + test_results.append(("Response Format", False)) + + # Test 5: Schema Validation Examples + try: + result = test_schema_validation_examples() + test_results.append(("Schema Validation", result)) + except Exception as e: + print(f"❌ Schema validation test failed: {e}") + test_results.append(("Schema Validation", False)) + + # Summary + print("\n" + "="*60) + print("TEST RESULTS SUMMARY") + print("="*60) + + passed = 0 + total = len(test_results) + + for test_name, result in test_results: + status = "✅ PASS" if result else "❌ FAIL" + print(f"{test_name:20} : {status}") + if result: + passed += 1 + + print(f"\nPassed: {passed}/{total}") + + if passed == total: + print("🎉 All tests passed! MCP compliance fixes are working.") + return True + else: + print("⚠️ Some tests failed. Please review the fixes.") + return False + +if __name__ == "__main__": + success = asyncio.run(run_all_tests()) + sys.exit(0 if success else 1) \ No newline at end of file