Skip to content

Commit 1165ff1

Browse files
committed
MINIMAL REPRO
1 parent 8f4557b commit 1165ff1

File tree

2 files changed

+220
-0
lines changed

2 files changed

+220
-0
lines changed

test_burst_client.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Burst Client Test - Simulates 50 threads calling envs.reset() -> get_initial_state
4+
Exact pattern: _execute_rollout() -> envs.reset() -> get_initial_state -> client.get()
5+
"""
6+
7+
import asyncio
8+
import threading
9+
import time
10+
from typing import Any, Dict, List
11+
12+
import httpx
13+
14+
15+
class EnvResetClient:
16+
"""
17+
Simulates the exact pattern from your code:
18+
50 threads -> _execute_rollout() -> envs.reset() -> get_initial_state -> client.get()
19+
"""
20+
21+
def __init__(self, base_url: str = "http://localhost:8000"):
22+
self.base_url = base_url
23+
self.initial_state_url = f"{base_url}/control/initial_state"
24+
25+
async def get_initial_state(self, thread_id: int) -> Dict[str, Any]:
26+
"""
27+
Simulates the get_initial_state call from your McpGym code.
28+
This is the slow HTTP call that happens during envs.reset().
29+
"""
30+
headers = {"Content-Type": "application/json", "Accept": "application/json"}
31+
32+
start_time = time.time()
33+
34+
try:
35+
async with httpx.AsyncClient(timeout=30.0) as client:
36+
# This is the exact pattern from your code
37+
initial_state_response = await client.get(
38+
self.initial_state_url,
39+
headers=headers,
40+
timeout=30.0,
41+
)
42+
initial_state_response.raise_for_status()
43+
result = initial_state_response.json()
44+
45+
end_time = time.time()
46+
duration = end_time - start_time
47+
48+
return {"thread_id": thread_id, "success": True, "duration": duration, "initial_state": result}
49+
50+
except Exception as e:
51+
end_time = time.time()
52+
duration = end_time - start_time
53+
return {"thread_id": thread_id, "success": False, "duration": duration, "error": str(e)}
54+
55+
async def envs_reset(self, thread_id: int) -> Dict[str, Any]:
56+
"""
57+
Simulates envs.reset() which internally calls get_initial_state.
58+
This is what gets called from _execute_rollout().
59+
"""
60+
print(f"🔄 Thread {thread_id}: envs.reset() called")
61+
62+
# This simulates the envs.reset() -> get_initial_state call chain
63+
return await self.get_initial_state(thread_id)
64+
65+
66+
async def _execute_rollout(thread_id: int, client: EnvResetClient) -> Dict[str, Any]:
67+
"""
68+
Simulates _execute_rollout() function that calls envs.reset().
69+
This runs concurrently using asyncio, matching your actual pattern.
70+
"""
71+
print(f"🚀 Rollout {thread_id}: _execute_rollout() started")
72+
73+
# This is where envs.reset() gets called
74+
result = await client.envs_reset(thread_id)
75+
return result
76+
77+
78+
async def run_burst_test(num_clients: int = 50, server_url: str = "http://localhost:8000"):
79+
"""
80+
Run burst test simulating 50 concurrent _execute_rollout() calls.
81+
Each one calls envs.reset() -> get_initial_state -> client.get()
82+
"""
83+
print(f"🚀 Starting burst test with {num_clients} concurrent rollouts")
84+
print(f"🎯 Target server: {server_url}")
85+
print(f"📋 Pattern: _execute_rollout() -> envs.reset() -> get_initial_state -> client.get()")
86+
87+
client = EnvResetClient(server_url)
88+
89+
# Create tasks for concurrent rollouts (simulating your threading pattern)
90+
start_time = time.time()
91+
tasks = [_execute_rollout(i, client) for i in range(num_clients)]
92+
93+
# Run all rollouts concurrently
94+
results = await asyncio.gather(*tasks, return_exceptions=True)
95+
96+
end_time = time.time()
97+
total_duration = end_time - start_time
98+
99+
# Analyze results
100+
successful = [r for r in results if isinstance(r, dict) and r.get("success")]
101+
failed = [r for r in results if isinstance(r, dict) and not r.get("success")]
102+
exceptions = [r for r in results if not isinstance(r, dict)]
103+
104+
print(f"\n📊 BURST TEST RESULTS:")
105+
print(f" Total rollouts: {num_clients}")
106+
print(f" Total time: {total_duration:.3f}s")
107+
print(f" Successful: {len(successful)}")
108+
print(f" Failed: {len(failed)}")
109+
print(f" Exceptions: {len(exceptions)}")
110+
111+
if successful:
112+
avg_duration = sum(r["duration"] for r in successful) / len(successful)
113+
min_duration = min(r["duration"] for r in successful)
114+
max_duration = max(r["duration"] for r in successful)
115+
116+
print(f" Average rollout duration: {avg_duration:.3f}s")
117+
print(f" Min rollout duration: {min_duration:.3f}s")
118+
print(f" Max rollout duration: {max_duration:.3f}s")
119+
120+
# Show sample successful result
121+
sample = successful[0]
122+
print(f"\n✅ Sample successful rollout:")
123+
print(f" Thread ID: {sample['thread_id']}")
124+
print(f" Initial state: {sample['initial_state']['observation']}")
125+
print(f" Timestamp: {sample['initial_state']['timestamp']}")
126+
127+
if failed:
128+
print(f"\n❌ Sample failed rollouts:")
129+
for fail in failed[:3]: # Show first 3 failures
130+
print(f" Thread {fail['thread_id']}: {fail['error']}")
131+
132+
if exceptions:
133+
print(f"\n💥 Sample exceptions:")
134+
for exc in exceptions[:3]: # Show first 3 exceptions
135+
print(f" {type(exc).__name__}: {exc}")
136+
137+
# Key test: If concurrent, should take ~1 second. If sequential, ~50 seconds.
138+
if total_duration < 5: # Allow some overhead
139+
print(f"\n🎉 CONCURRENCY WORKING! Total time {total_duration:.3f}s (expected ~1s for concurrent)")
140+
else:
141+
print(f"\n⚠️ POSSIBLE SEQUENTIAL EXECUTION! Total time {total_duration:.3f}s (expected ~1s for concurrent)")
142+
143+
return len(successful) == num_clients
144+
145+
146+
def main():
147+
"""Run the burst test."""
148+
import argparse
149+
150+
parser = argparse.ArgumentParser(
151+
description="Envs Reset Burst Test - Simulates 50 rollouts calling get_initial_state"
152+
)
153+
parser.add_argument("--rollouts", type=int, default=50, help="Number of concurrent rollouts")
154+
parser.add_argument("--server", default="http://localhost:8000", help="Server URL")
155+
156+
args = parser.parse_args()
157+
158+
success = asyncio.run(run_burst_test(args.rollouts, args.server))
159+
160+
if success:
161+
print(f"\n🎉 ALL {args.rollouts} ROLLOUTS SUCCESSFUL!")
162+
exit(0)
163+
else:
164+
print(f"\n💥 SOME ROLLOUTS FAILED!")
165+
exit(1)
166+
167+
168+
if __name__ == "__main__":
169+
main()

test_simple_mcp_server.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Simple MCP Server for Testing get_initial_state Concurrency
4+
Simulates the exact pattern: envs.reset() -> get_initial_state -> slow HTTP endpoint
5+
"""
6+
7+
import asyncio
8+
import os
9+
import time
10+
11+
from fastmcp import FastMCP
12+
from starlette.requests import Request
13+
from starlette.responses import JSONResponse
14+
15+
# Create a simple MCP server
16+
mcp = FastMCP(name="TestServer")
17+
18+
19+
@mcp.custom_route("/control/initial_state", methods=["GET"])
20+
async def get_initial_state_endpoint(request: Request) -> JSONResponse:
21+
"""
22+
Simulate the get_initial_state endpoint that's slow.
23+
This mimics the pattern in your McpGym code.
24+
"""
25+
print(f"🔍 get_initial_state called at {time.time()}")
26+
27+
# Simulate the slow operation (like environment initialization)
28+
time.sleep(1) # 1 second delay to test concurrency
29+
30+
# Return a dummy initial state
31+
return JSONResponse({"observation": "dummy_initial_state", "session_id": "test_session", "timestamp": time.time()})
32+
33+
34+
@mcp.tool
35+
def dummy_tool() -> str:
36+
"""Dummy tool for MCP compatibility."""
37+
return "dummy"
38+
39+
40+
def main():
41+
"""Run the test server."""
42+
port = int(os.environ.get("PORT", 8000))
43+
print(f"🚀 Starting get_initial_state test server on port {port}")
44+
print(f"📡 Endpoint: http://localhost:{port}/control/initial_state")
45+
46+
# Use FastMCP 2.0 run method with streamable-http transport
47+
mcp.run(transport="http", host="0.0.0.0", port=port)
48+
49+
50+
if __name__ == "__main__":
51+
main()

0 commit comments

Comments
 (0)