Skip to content

Commit 680e719

Browse files
Integrate eval protocol to rllm trainer (#292)
* init
* remove `import signal`
* comment out signal
* print("model_id from eval_protocol: ", model_id)
* new LiteLLMPolicy
* remove `print("model_id from eval_protocol: ", model_id)`
* reverts
* reformat
---------
Co-authored-by: Derek Xu <xzrderek@gmail.com>
1 parent 87a99e0 commit 680e719

1 file changed

Lines changed: 25 additions & 27 deletions

File tree

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 25 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -4,15 +4,16 @@
44
import signal
55
import socket
66
import subprocess
7+
import threading
78
import time
89
from pathlib import Path
910
from typing import List, Optional
1011

1112
import eval_protocol as ep
13+
from eval_protocol.mcp.execution.manager import ExecutionManager
1214
from eval_protocol.models import EvaluationRow
1315
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1416
from eval_protocol.pytest.types import RolloutProcessorConfig
15-
from eval_protocol.mcp.execution.manager import ExecutionManager
1617

1718

1819
class MCPServerManager:
@@ -181,8 +182,9 @@ def _signal_handler(cls, signum, frame):
181182
def _register_cleanup_handlers(cls):
182183
"""Register cleanup handlers - called only once"""
183184
atexit.register(cls._cleanup_all_servers)
184-
signal.signal(signal.SIGINT, cls._signal_handler) # Ctrl+C
185-
signal.signal(signal.SIGTERM, cls._signal_handler) # Termination signal
185+
if threading.current_thread() is threading.main_thread():
186+
signal.signal(signal.SIGINT, cls._signal_handler) # Ctrl+C
187+
signal.signal(signal.SIGTERM, cls._signal_handler) # Termination signal
186188

187189
def __enter__(self):
188190
"""Context manager entry"""
@@ -223,28 +225,6 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
223225
try:
224226
self.server.start()
225227

226-
model_id = str(
227-
(config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini"
228-
)
229-
temperature = config.completion_params.get("temperature", 0.0)
230-
max_tokens = config.completion_params.get("max_tokens", 4096)
231-
232-
# Pass all other completion_params (e.g. stream=True) via kwargs
233-
other_params = {
234-
k: v
235-
for k, v in (config.completion_params or {}).items()
236-
if k not in ["model", "temperature", "max_tokens", "extra_body"]
237-
}
238-
extra_body = config.completion_params.get("extra_body", {}) or {}
239-
240-
self.policy = ep.LiteLLMPolicy(
241-
model_id=model_id,
242-
temperature=temperature,
243-
max_tokens=max_tokens,
244-
**extra_body,
245-
**other_params,
246-
)
247-
248228
except Exception as e:
249229
if self.server:
250230
self.server.stop()
@@ -254,13 +234,31 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
254234

255235
else:
256236
# Reuse existing MCP environments for retry
257-
if not self.server or not self.policy:
237+
if not self.server:
258238
raise RuntimeError(
259239
"Cannot retry without existing server/environments. Call with start_server=True first."
260240
)
261241

242+
model_id = str((config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini")
243+
temperature = config.completion_params.get("temperature", 0.0)
244+
max_tokens = config.completion_params.get("max_tokens", 4096)
245+
246+
# Pass all other completion_params (e.g. stream=True) via kwargs
247+
other_params = {
248+
k: v
249+
for k, v in (config.completion_params or {}).items()
250+
if k not in ["model", "temperature", "max_tokens", "extra_body"]
251+
}
252+
extra_body = config.completion_params.get("extra_body", {}) or {}
253+
254+
self.policy = ep.LiteLLMPolicy(
255+
model_id=model_id,
256+
temperature=temperature,
257+
max_tokens=max_tokens,
258+
**extra_body,
259+
**other_params,
260+
)
262261
# Create MCP environments directly from evaluation_rows
263-
assert self.policy is not None, "Policy must be initialized before rollout"
264262
envs = ep.make(
265263
"http://localhost:9700/mcp/",
266264
evaluation_rows=rows,

0 commit comments

Comments (0)