Skip to content

Commit 5d7e5cb

Browse files
author
Dylan Huang
authored
RemoteRolloutProcessor / evaluation_test improvements (#237)
* Add row_ids parameter to evaluation_test function for filtering evaluations - Introduced row_ids as an optional parameter to allow filtering of evaluation rows based on specified identifiers. - Updated documentation to reflect the new parameter and its usage in the evaluation process. * Handle timeout in RemoteRolloutProcessor by updating rollout status - Added logic to set the rollout status to an error when the polling loop completes without a successful break, indicating a timeout. - Enhanced error handling to provide clearer feedback on rollout timeouts. * Add optional status field to StatusResponse model in remote rollout processor - Introduced an optional status indicator in the StatusResponse model to differentiate between successful and failed rollouts. - Updated documentation to clarify the purpose of the new status field for better understanding in the eval-protocol context. * Rename row_ids parameter to filtered_row_ids in evaluation_test function for clarity - Updated the parameter name from row_ids to filtered_row_ids to better reflect its purpose in filtering evaluation rows. - Adjusted related documentation to ensure consistency and clarity regarding the new parameter name.
1 parent 2a8ace1 commit 5d7e5cb

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def evaluation_test(
7979
aggregation_method: AggregationMethod = "mean",
8080
passed_threshold: EvaluationThreshold | float | EvaluationThresholdDict | None = None,
8181
num_runs: int = 1,
82+
filtered_row_ids: Sequence[str] | None = None,
8283
max_dataset_rows: int | None = None,
8384
mcp_config_path: str | None = None,
8485
max_concurrent_rollouts: int = 8,
@@ -146,6 +147,7 @@ def evaluation_test(
146147
Success rate must be above success, and if set, standard error must be below standard_error.
147148
Success rate +/- one standard_error is equivalent to 68% confidence interval.
148149
num_runs: Number of times to repeat the rollout and evaluations.
150+
filtered_row_ids: List of row_ids to filter for the evaluation. If provided, only the rows with the given row_ids will be evaluated.
149151
max_dataset_rows: Limit dataset to the first N rows.
150152
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
151153
max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
@@ -286,6 +288,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
286288
else:
287289
raise ValueError("No input dataset, input messages, or input rows provided")
288290

291+
if filtered_row_ids is not None:
292+
data = [row for row in data if row.input_metadata.row_id in filtered_row_ids]
293+
289294
"""
290295
data_loaders handles preprocess_fn internally so we want
291296
to specially handle data_loaders here so we don't double

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,13 @@ def _get_status() -> Dict[str, Any]:
155155
except Exception:
156156
# transient errors; continue polling
157157
pass
158+
158159
await asyncio.sleep(poll_interval)
160+
else:
161+
# Loop completed without breaking, which means we timed out
162+
row.rollout_status = Status.rollout_error(
163+
f"Rollout {row.execution_metadata.rollout_id} timed out after {timeout_seconds} seconds"
164+
)
159165

160166
# Update duration, regardless of termination
161167
row.execution_metadata.duration_seconds = time.perf_counter() - start_time

eval_protocol/types/remote_rollout_processor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import Any, Dict, List, Optional
66
from pydantic import BaseModel, Field
7-
from eval_protocol.models import Message
7+
from eval_protocol.models import Message, Status
88

99

1010
class RolloutMetadata(BaseModel):
@@ -40,6 +40,12 @@ class StatusResponse(BaseModel):
4040
terminated: bool
4141
info: Optional[Dict[str, Any]] = None
4242

43+
status: Optional[Status] = None
44+
"""
45+
Optional status indicator for the rollout to be used by eval-protocol. This
46+
is useful to distinguish between successful and failed rollouts.
47+
"""
48+
4349

4450
def create_langfuse_config_tags(init_request: InitRequest) -> List[str]:
4551
"""Create Langfuse tags from InitRequest metadata."""

0 commit comments

Comments
 (0)