diff --git a/tests/test_rubric_group.py b/tests/test_rubric_group.py
index ca380e1fb..f84ef3e5c 100644
--- a/tests/test_rubric_group.py
+++ b/tests/test_rubric_group.py
@@ -417,3 +417,79 @@ def reward_func(completion, parser, answer, **_):
 
         assert state["reward"] == 1.0
         assert recorded_parsers == [xml_parser]
+
+    @pytest.mark.asyncio
+    async def test_rubric_group_score_rollout_timing(self):
+        """Test that generation_ms + scoring_ms == total_ms after score_rollout."""
+
+        def func1(completion, **kwargs):
+            return 1.0
+
+        def func2(completion, **kwargs):
+            return 0.5
+
+        rubric1 = Rubric(funcs=[func1], weights=[1.0])
+        rubric2 = Rubric(funcs=[func2], weights=[1.0])
+
+        group = RubricGroup(rubrics=[rubric1, rubric2])
+
+        state = State(
+            input=RolloutInput(
+                prompt=[{"role": "user", "content": "test"}],
+                answer="test",
+                task="default",
+                example_id=0,
+            )
+        )
+        state["completion"] = [{"role": "assistant", "content": "test"}]
+        state["trajectory"] = []
+        state["timing"] = RolloutTiming(
+            generation_ms=100.0,
+            scoring_ms=0.0,
+            total_ms=100.0,
+            start_time=0.0,
+        )
+
+        await group.score_rollout(state)
+
+        assert state["timing"]["generation_ms"] == 100.0
+        assert state["timing"]["scoring_ms"] > 0.0
+        assert state["timing"]["total_ms"] == 100.0 + state["timing"]["scoring_ms"]
+
+    @pytest.mark.asyncio
+    async def test_rubric_group_score_group_timing(self):
+        """Test that generation_ms + scoring_ms == total_ms after score_group."""
+
+        def func1(completion, **kwargs):
+            return 1.0
+
+        def func2(completion, **kwargs):
+            return 0.5
+
+        rubric1 = Rubric(funcs=[func1], weights=[1.0])
+        rubric2 = Rubric(funcs=[func2], weights=[1.0])
+
+        group = RubricGroup(rubrics=[rubric1, rubric2])
+
+        state = State(
+            input=RolloutInput(
+                prompt=[{"role": "user", "content": "test"}],
+                answer="test",
+                task="default",
+                example_id=0,
+            )
+        )
+        state["completion"] = [{"role": "assistant", "content": "test"}]
+        state["trajectory"] = []
+        state["timing"] = RolloutTiming(
+            generation_ms=100.0,
+            scoring_ms=0.0,
+            total_ms=100.0,
+            start_time=0.0,
+        )
+
+        await group.score_group([state])
+
+        assert state["timing"]["generation_ms"] == 100.0
+        assert state["timing"]["scoring_ms"] > 0.0
+        assert state["timing"]["total_ms"] == 100.0 + state["timing"]["scoring_ms"]
diff --git a/verifiers/rubrics/rubric_group.py b/verifiers/rubrics/rubric_group.py
index b0d8bbb71..b80ad60fd 100644
--- a/verifiers/rubrics/rubric_group.py
+++ b/verifiers/rubrics/rubric_group.py
@@ -1,3 +1,4 @@
+import time
 from typing import Any
 
 from verifiers.rubrics.rubric import Rubric
@@ -56,12 +57,14 @@ async def score_rollout(self, state: State):
         """
         Evaluate all reward functions in-place for a single rollout.
         """
+        start_time = time.perf_counter()  # monotonic clock for elapsed time
         total_reward = 0.0
         aggregated_metrics: dict[str, float] = {}
         original_reward = state.get("reward", 0.0)
         original_metrics = (
             state.get("metrics", {}).copy() if state.get("metrics") else {}
         )
+        original_timing = state["timing"].copy()  # restored after each rubric
         for rubric in self.rubrics:
             await rubric.score_rollout(state)
             rubric_reward = state.get("reward", 0.0)
@@ -74,13 +77,19 @@ async def score_rollout(self, state: State):
             # restore original values for next rubric
             state["reward"] = original_reward
             state["metrics"] = original_metrics.copy()
+            state["timing"] = original_timing.copy()
         state["reward"] = total_reward
         state["metrics"] = aggregated_metrics
+        end_time = time.perf_counter()
+        scoring_ms = (end_time - start_time) * 1000  # seconds -> ms
+        state["timing"]["scoring_ms"] = scoring_ms
+        state["timing"]["total_ms"] += scoring_ms
 
     async def score_group(self, states: list[State]):
         """
         Evaluate all reward functions in-place for a group of rollouts.
         """
+        start_time = time.perf_counter()  # monotonic clock for elapsed time
         aggregated_rewards = [0.0] * len(states)
         aggregated_metrics: dict[str, list[float]] = {}
         original_rewards = [state.get("reward", 0.0) for state in states]
@@ -88,6 +97,7 @@ async def score_group(self, states: list[State]):
             state.get("metrics", {}).copy() if state.get("metrics") else {}
             for state in states
         ]
+        original_timings = [state["timing"].copy() for state in states]
         for rubric in self.rubrics:
             await rubric.score_group(states)
             for i, state in enumerate(states):
@@ -102,6 +112,9 @@ async def score_group(self, states: list[State]):
                     aggregated_metrics[key][i] += value
                 state["reward"] = original_rewards[i]
                 state["metrics"] = original_metrics[i].copy()
+                state["timing"] = original_timings[i].copy()
+        end_time = time.perf_counter()
+        scoring_ms = (end_time - start_time) * 1000  # seconds -> ms
         for i, state in enumerate(states):
             state["reward"] = aggregated_rewards[i]
             if aggregated_metrics:
@@ -109,3 +122,5 @@ async def score_group(self, states: list[State]):
                 state["metrics"] = {}
                 for key, values in aggregated_metrics.items():
                     state["metrics"][key] = values[i]
+            state["timing"]["scoring_ms"] = scoring_ms  # whole-group duration
+            state["timing"]["total_ms"] += scoring_ms