Skip to content

Commit 23ba2b3

Browse files
author
Shrey Modi
committed
updates
1 parent a1a973e commit 23ba2b3

5 files changed

Lines changed: 48 additions & 77 deletions

File tree

eval_protocol/pytest/integrations/openenv_trl_vllm.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -387,22 +387,11 @@ async def _run_all() -> List[EvaluationRow]:
387387
if isinstance(extra, dict):
388388
prompt_ids = list(extra.get("prompt_ids", []) or [])
389389
completion_ids = list(extra.get("completion_ids", []) or [])
390+
rewards = [float(r) for r in (extra.get("step_rewards", []) or [])]
390391
except Exception:
391392
prompt_ids = []
392393
completion_ids = []
393-
394-
# Extract step rewards from the sentinel system message
395-
for msg in row.messages:
396-
if msg.role == "system":
397-
try:
398-
content = msg.content or ""
399-
if isinstance(content, str) and content.startswith("__ep_step_rewards__:"):
400-
import json
401-
402-
payload = content.split(":", 1)[1]
403-
rewards = json.loads(payload) or []
404-
except Exception:
405-
pass
394+
rewards = []
406395

407396
# Append accumulated tokens for this episode
408397
episode_prompt_ids.append(prompt_ids if prompt_ids else [0])

eval_protocol/pytest/openenv_rollout_processor.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import time
1818
from itertools import count
1919
from typing import List, Any, Dict, Callable, Generic, TypeVar, Optional, Type
20-
import json
2120

2221
from openai.types import CompletionUsage
2322

@@ -414,26 +413,22 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
414413
)
415414
row.execution_metadata.duration_seconds = time.perf_counter() - start_time
416415

417-
# Store per-step rewards in a sentinel system message so
418-
# evaluation tests and downstream integrations can reconstruct
419-
# episode rewards.
420-
sentinel = "__ep_step_rewards__:" + json.dumps(step_rewards)
421-
messages.append(Message(role="system", content=sentinel))
422-
423-
# Attach accumulated token IDs to execution_metadata.extra for
424-
# training integrations (e.g., TRL GRPO) instead of encoding
425-
# them into synthetic system messages.
426-
if all_prompt_ids or all_completion_ids:
427-
try:
428-
extra = getattr(row.execution_metadata, "extra", None)
429-
if not isinstance(extra, dict):
430-
extra = {}
416+
# Attach per-step rewards and accumulated token IDs to
417+
# execution_metadata.extra for downstream integrations
418+
# (for example, TRL GRPO) instead of encoding them into
419+
# synthetic system messages.
420+
try:
421+
extra = getattr(row.execution_metadata, "extra", None)
422+
if not isinstance(extra, dict):
423+
extra = {}
424+
extra["step_rewards"] = list(step_rewards)
425+
if all_prompt_ids or all_completion_ids:
431426
extra["prompt_ids"] = list(all_prompt_ids)
432427
extra["completion_ids"] = list(all_completion_ids)
433-
row.execution_metadata.extra = extra # type: ignore[attr-defined]
434-
except Exception:
435-
# Non-fatal: training integrations can fall back if tokens are missing
436-
pass
428+
row.execution_metadata.extra = extra # type: ignore[attr-defined]
429+
except Exception:
430+
# Non-fatal: callers can fall back if metadata is missing
431+
pass
437432

438433
total_reward = sum(step_rewards)
439434
logger.info("[OpenEnvRolloutProcessor] ✅ ROLLOUT COMPLETE")

tests/pytest/test_openenv_browsergym_eval.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -275,20 +275,13 @@ def test_openenv_browsergym_eval(row: EvaluationRow) -> EvaluationRow:
275275
"""
276276
if not _HAS_BG:
277277
pytest.skip("OpenEnv (envs.browsergym_env) is not installed; skipping BrowserGym test.")
278-
# Extract step rewards from the sentinel system message injected by the rollout processor
278+
# Extract step rewards from execution metadata (set by OpenEnvRolloutProcessor)
279279
step_rewards: List[float] = []
280280
try:
281-
for msg in row.messages or []:
282-
if (
283-
msg.role == "system"
284-
and isinstance(msg.content, str)
285-
and msg.content.startswith("__ep_step_rewards__:")
286-
):
287-
import json as _json
288-
289-
payload = msg.content.split(":", 1)[1]
290-
step_rewards = _json.loads(payload) or []
291-
break
281+
extra = getattr(row.execution_metadata, "extra", None)
282+
if isinstance(extra, dict):
283+
raw = extra.get("step_rewards") or []
284+
step_rewards = [float(r) for r in raw]
292285
except Exception:
293286
step_rewards = []
294287

tests/pytest/test_openenv_echo_hub.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@
88
from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor
99
import pytest
1010

11+
try:
12+
# Preferred import when using the monolithic `openenv` package
13+
from openenv.envs.echo_env import EchoEnv # type: ignore
14+
15+
_HAS_ECHO = True
16+
except Exception:
17+
_HAS_ECHO = False
18+
1119
# Skip these integration-heavy tests on CI runners by default
1220
pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip OpenEnv integration tests on CI")
1321

@@ -35,20 +43,20 @@ def action_parser(response_text: str):
3543
Convert raw model response to EchoAction.
3644
"""
3745
try:
38-
from envs.echo_env import EchoAction # type: ignore
46+
from openenv.envs.echo_env import EchoAction # type: ignore
3947
except Exception:
40-
pytest.skip("OpenEnv (envs.echo_env) is not installed; skipping Echo hub test.")
48+
pytest.skip("OpenEnv (openenv.envs.echo_env) is not installed; skipping Echo hub test.")
4149
raise
4250
text = response_text.strip() if isinstance(response_text, str) else ""
4351
return EchoAction(message=text or "hello")
4452

4553

46-
try:
47-
from envs.echo_env import EchoEnv # type: ignore
54+
# try:
55+
# from envs.echo_env import EchoEnv # type: ignore
4856

49-
_HAS_ECHO = True
50-
except Exception:
51-
_HAS_ECHO = False
57+
# _HAS_ECHO = True
58+
# except Exception:
59+
# _HAS_ECHO = False
5260

5361

5462
# Inline test data
@@ -93,23 +101,15 @@ def test_openenv_echo_hub(row: EvaluationRow) -> EvaluationRow:
93101
Extracts env rewards (from rollout policy extras) and sets evaluation_result.
94102
"""
95103
if not _HAS_ECHO:
96-
pytest.skip("OpenEnv (envs.echo_env) is not installed; skipping Echo hub test.")
97-
# Try to read rewards/usage left in execution metadata extra or system messages.
104+
pytest.skip("OpenEnv (openenv.envs.echo_env) is not installed; skipping Echo hub test.")
105+
# Try to read rewards/usage left in execution metadata extra.
98106
total_reward = 0.0
99107
try:
100-
# Preferred path: system sentinel "__ep_step_rewards__"
108+
extra = getattr(row.execution_metadata, "extra", None)
101109
step_rewards: List[float] = []
102-
for msg in row.messages or []:
103-
if (
104-
msg.role == "system"
105-
and isinstance(msg.content, str)
106-
and msg.content.startswith("__ep_step_rewards__:")
107-
):
108-
import json as _json
109-
110-
payload = msg.content.split(":", 1)[1]
111-
step_rewards = _json.loads(payload) or []
112-
break
110+
if isinstance(extra, dict):
111+
raw = extra.get("step_rewards") or []
112+
step_rewards = [float(r) for r in raw]
113113
total_reward = float(sum(step_rewards)) if step_rewards else 0.0
114114
except Exception:
115115
total_reward = 0.0

tests/pytest/test_openenv_textarena_docker.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -144,25 +144,19 @@ def test_openenv_textarena_docker(row: EvaluationRow) -> EvaluationRow:
144144
# Extract step rewards and compute score
145145
total_reward = 0.0
146146
try:
147+
extra = getattr(row.execution_metadata, "extra", None)
147148
step_rewards: List[float] = []
148-
for msg in row.messages or []:
149-
if (
150-
msg.role == "system"
151-
and isinstance(msg.content, str)
152-
and msg.content.startswith("__ep_step_rewards__:")
153-
):
154-
import json
155-
156-
payload = msg.content.split(":", 1)[1]
157-
step_rewards = json.loads(payload) or []
158-
break
149+
if isinstance(extra, dict):
150+
raw = extra.get("step_rewards") or []
151+
step_rewards = [float(r) for r in raw]
159152
total_reward = float(sum(step_rewards)) if step_rewards else 0.0
160153
except Exception:
161154
total_reward = 0.0
162155

163156
score = max(0.0, min(1.0, total_reward))
157+
steps = len(step_rewards) if "step_rewards" in locals() else 0
164158
row.evaluation_result = EvaluateResult(
165159
score=score,
166-
reason=f"TextArena total reward={total_reward:.2f} over {len(step_rewards) if 'step_rewards' in locals() else 0} steps",
160+
reason=f"TextArena total reward={total_reward:.2f} over {steps} steps",
167161
)
168162
return row

0 commit comments

Comments (0)