Skip to content

Commit b35c4bc

Browse files
committed
fix
1 parent ffc9351 commit b35c4bc

2 files changed

Lines changed: 38 additions & 2 deletions

File tree

eval_protocol/reward_function.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from .models import EvaluateResult, MetricResult
1313
from .typed_interface import reward_function
1414

15-
logging.basicConfig(level=logging.INFO)
1615
logger = logging.getLogger(__name__)
1716

1817
T = TypeVar("T", bound=Callable[..., EvaluateResult])

tests/remote_server/test_remote_fireworks.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# AUTO SERVER STARTUP: Server is automatically started and stopped by the test
22

3+
import logging
34
import subprocess
45
import socket
56
import time
@@ -19,10 +20,23 @@
1920
ROLLOUT_IDS = set()
2021

2122

23+
class StatusLogCaptureHandler(logging.Handler):
24+
"""Custom handler to capture status log messages."""
25+
26+
def __init__(self):
27+
super().__init__()
28+
self.status_100_messages: List[str] = []
29+
30+
def emit(self, record):
31+
msg = record.getMessage() # Use getMessage(), not .message attribute
32+
if "Found Fireworks log" in msg and "with status code 100" in msg:
33+
self.status_100_messages.append(msg)
34+
35+
2236
@pytest.fixture(autouse=True)
2337
def check_rollout_coverage(monkeypatch):
2438
"""
25-
Ensure we attempted to fetch remote traces for each rollout.
39+
Ensure we attempted to fetch remote traces for each rollout and received status logs.
2640
2741
This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
2842
and tracks rollout_ids passed through its DataLoaderConfig.
@@ -37,9 +51,32 @@ def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader:
3751
return original_loader(config)
3852

3953
monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader)
54+
55+
# Add custom handler to capture status logs
56+
status_handler = StatusLogCaptureHandler()
57+
status_handler.setLevel(logging.INFO)
58+
rrp_logger = logging.getLogger("eval_protocol.pytest.remote_rollout_processor")
59+
rrp_logger.addHandler(status_handler)
60+
# Ensure the logger level allows INFO messages through
61+
original_level = rrp_logger.level
62+
rrp_logger.setLevel(logging.INFO)
63+
4064
yield
65+
66+
# Cleanup handler and restore level
67+
rrp_logger.removeHandler(status_handler)
68+
rrp_logger.setLevel(original_level)
69+
70+
# After test completes, verify we saw status logs for all 3 rollouts
4171
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
4272

73+
# Check that we received "Found Fireworks log ... with status code 100" for each rollout
74+
assert len(status_handler.status_100_messages) == 3, (
75+
f"Expected 3 'Found Fireworks log ... with status code 100' messages, but found {len(status_handler.status_100_messages)}. "
76+
f"This means the status logs from the remote server were not received. "
77+
f"Messages captured: {status_handler.status_100_messages}"
78+
)
79+
4380

4481
def find_available_port() -> int:
4582
"""Find an available port on localhost"""

0 commit comments

Comments
 (0)