Skip to content

Commit 9942429

Browse files
committed
take out output dataloader
1 parent cee95a9 commit 9942429

7 files changed

Lines changed: 8 additions & 123 deletions

File tree

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import asyncio
22
import os
33
import time
4-
from typing import Any, Callable, Dict, List, Optional
4+
from typing import Any, Dict, List, Optional
55
import json
66
import requests
77
from datetime import datetime, timezone, timedelta
88
from eval_protocol.models import EvaluationRow, Status
99
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
10-
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
1110

1211
from .rollout_processor import RolloutProcessor
1312
from .types import RolloutProcessorConfig
@@ -21,7 +20,7 @@ class GithubActionRolloutProcessor(RolloutProcessor):
2120
Expected GitHub Actions workflow:
2221
- Workflow dispatch with inputs: completion_params, metadata (JSON), model_base_url, api_key
2322
- Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
24-
- Traces are fetched later via output_data_loader using rollout_id tags
23+
- Traces are fetched later via the Fireworks tracing proxy using rollout_id tags
2524
2625
NOTE: GHA has a rate limit of 5000 requests per hour.
2726
"""
@@ -38,7 +37,6 @@ def __init__(
3837
timeout_seconds: float = 1800.0,
3938
max_find_workflow_retries: int = 5,
4039
github_token: Optional[str] = None,
41-
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
4240
):
4341
self.owner = owner
4442
self.repo = repo
@@ -52,7 +50,7 @@ def __init__(
5250
self.timeout_seconds = timeout_seconds
5351
self.max_find_workflow_retries = max_find_workflow_retries
5452
self.github_token = github_token
55-
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
53+
self._output_data_loader = default_fireworks_output_data_loader
5654

5755
def _headers(self) -> Dict[str, str]:
5856
headers = {"Accept": "application/vnd.github+json"}

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import time
3-
from typing import Any, Dict, List, Optional, Callable
3+
from typing import Any, Dict, List, Optional
44

55
import requests
66

@@ -26,8 +26,7 @@ class RemoteRolloutProcessor(RolloutProcessor):
2626
"""
2727
Rollout processor that triggers a remote HTTP server to perform the rollout.
2828
29-
By default, fetches traces from the Fireworks tracing proxy using rollout_id tags.
30-
You can provide a custom output_data_loader for different tracing backends.
29+
Fetches traces from the Fireworks tracing proxy using rollout_id tags.
3130
3231
See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
3332
"""
@@ -39,7 +38,6 @@ def __init__(
3938
model_base_url: str = "https://tracing.fireworks.ai",
4039
poll_interval: float = 1.0,
4140
timeout_seconds: float = 120.0,
42-
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
4341
):
4442
# Prefer constructor-provided configuration. These can be overridden via
4543
# config.kwargs at call time for backward compatibility.
@@ -52,7 +50,7 @@ def __init__(
5250
self._model_base_url = _ep_model_base_url
5351
self._poll_interval = poll_interval
5452
self._timeout_seconds = timeout_seconds
55-
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
53+
self._output_data_loader = default_fireworks_output_data_loader
5654
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
5755

5856
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:

tests/github_actions/test_github_actions_rollout.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
from eval_protocol.models import EvaluationRow, InputMetadata
1313
from eval_protocol.pytest import evaluation_test
1414
from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor
15-
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
16-
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
17-
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
1815

1916
ROLLOUT_IDS = set()
2017

@@ -29,21 +26,6 @@ def check_rollout_coverage():
2926
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
3027

3128

32-
def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
33-
global ROLLOUT_IDS # Track all rollout_ids we've seen
34-
ROLLOUT_IDS.add(config.rollout_id)
35-
36-
base_url = config.model_base_url or "https://tracing.fireworks.ai"
37-
adapter = FireworksTracingAdapter(base_url=base_url)
38-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
39-
40-
41-
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
42-
return DynamicDataLoader(
43-
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
44-
)
45-
46-
4729
def rows() -> List[EvaluationRow]:
4830
return [
4931
EvaluationRow(input_metadata=InputMetadata(row_id=str(i)))
@@ -68,7 +50,6 @@ def rows() -> List[EvaluationRow]:
6850
ref=os.getenv("GITHUB_REF", "main"),
6951
poll_interval=3.0, # For multi-turn, you'll likely want higher poll interval
7052
timeout_seconds=300,
71-
output_data_loader=fireworks_output_data_loader,
7253
),
7354
)
7455
async def test_github_actions_rollout(row: EvaluationRow) -> EvaluationRow:

tests/remote_server/test_remote_fireworks.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# AUTO SERVER STARTUP: Server is automatically started and stopped by the test
22

3-
import os
43
import subprocess
54
import socket
65
import time
@@ -13,9 +12,6 @@
1312
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
1413
from eval_protocol.pytest import evaluation_test
1514
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
16-
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
17-
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
18-
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
1915

2016
ROLLOUT_IDS = set()
2117

@@ -78,21 +74,6 @@ def check_rollout_coverage():
7874
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
7975

8076

81-
def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
82-
global ROLLOUT_IDS # Track all rollout_ids we've seen
83-
ROLLOUT_IDS.add(config.rollout_id)
84-
85-
base_url = config.model_base_url or "https://tracing.fireworks.ai"
86-
adapter = FireworksTracingAdapter(base_url=base_url)
87-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)
88-
89-
90-
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
91-
return DynamicDataLoader(
92-
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
93-
)
94-
95-
9677
def rows() -> List[EvaluationRow]:
9778
"""Generate local rows with rich input_metadata to verify it survives remote traces."""
9879
base_dataset_info = {
@@ -118,7 +99,6 @@ def rows() -> List[EvaluationRow]:
11899
rollout_processor=RemoteRolloutProcessor(
119100
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
120101
timeout_seconds=180,
121-
output_data_loader=fireworks_output_data_loader,
122102
),
123103
)
124104
async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow:
@@ -129,13 +109,11 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
129109
- fetch traces from Langfuse via the Fireworks tracing proxy, filtered by rollout_id tags; FAIL if none found
130110
"""
131111
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
112+
ROLLOUT_IDS.add(row.execution_metadata.rollout_id)
132113

133114
assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
134115
assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."
135116

136-
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
137-
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
138-
)
139117
assert row.input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"
140118
assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level"
141119

tests/remote_server/test_remote_fireworks_propagate_status.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
from eval_protocol.models import EvaluationRow, Message, Status, EvaluateResult
1313
from eval_protocol.pytest import evaluation_test
1414
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
15-
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
16-
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
17-
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
1815

1916

2017
def find_available_port() -> int:
@@ -67,18 +64,6 @@ def setup_remote_server():
6764
process.wait()
6865

6966

70-
def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
71-
base_url = config.model_base_url or "https://tracing.fireworks.ai"
72-
adapter = FireworksTracingAdapter(base_url=base_url)
73-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)
74-
75-
76-
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
77-
return DynamicDataLoader(
78-
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
79-
)
80-
81-
8267
def rows() -> List[EvaluationRow]:
8368
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
8469
return [row]
@@ -92,7 +77,6 @@ def rows() -> List[EvaluationRow]:
9277
rollout_processor=RemoteRolloutProcessor(
9378
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
9479
timeout_seconds=120,
95-
output_data_loader=fireworks_output_data_loader,
9680
),
9781
)
9882
async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow:
Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
# MANUAL SERVER STARTUP REQUIRED:
2-
#
3-
# For Python server testing, start:
4-
# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000)
5-
#
6-
# For TypeScript server testing, start:
7-
# cd tests/remote_server/typescript-server
8-
# npm install
9-
# npm start
10-
#
11-
# The TypeScript server should be running on http://127.0.0.1:3000
12-
# You only need to start one of the servers!
13-
141
import os
152
from typing import List
163

@@ -20,35 +7,6 @@
207
from eval_protocol.models import EvaluationRow, Message
218
from eval_protocol.pytest import evaluation_test
229
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
23-
from eval_protocol.adapters.langfuse import create_langfuse_adapter
24-
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
25-
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
26-
27-
ROLLOUT_IDS = set()
28-
29-
30-
@pytest.fixture(autouse=True)
31-
def check_rollout_coverage():
32-
"""Ensure we processed all expected rollout_ids"""
33-
global ROLLOUT_IDS
34-
ROLLOUT_IDS.clear()
35-
yield
36-
37-
assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}"
38-
39-
40-
def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
41-
global ROLLOUT_IDS # Track all rollout_ids we've seen
42-
ROLLOUT_IDS.add(config.rollout_id)
43-
44-
adapter = create_langfuse_adapter()
45-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
46-
47-
48-
def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
49-
return DynamicDataLoader(
50-
generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation
51-
)
5210

5311

5412
def rows() -> List[EvaluationRow]:
@@ -62,25 +20,14 @@ def rows() -> List[EvaluationRow]:
6220
data_loaders=DynamicDataLoader(
6321
generators=[rows],
6422
),
65-
rollout_processor=RemoteRolloutProcessor(
66-
remote_base_url="http://127.0.0.1:3000",
67-
timeout_seconds=30,
68-
output_data_loader=langfuse_output_data_loader,
69-
model_base_url="https://tracing.fireworks.ai/project_id/cmg5fd57b0006y107kuxkcrhk",
70-
),
23+
rollout_processor=RemoteRolloutProcessor(remote_base_url="http://127.0.0.1:3000", timeout_seconds=30),
7124
)
7225
async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow:
7326
"""
7427
End-to-end test:
75-
- REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server
7628
- trigger remote rollout via RemoteRolloutProcessor (calls init/status)
77-
- fetch traces from Langfuse filtered by metadata via output_data_loader; FAIL if none found
7829
"""
7930
assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
8031
assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."
8132

82-
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
83-
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
84-
)
85-
8633
return row

tests/remote_server/typescript-server/README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ from eval_protocol import (
120120
data_loaders=[InlineDataLoader(messages=[[Message(role="user", content="Hello")]])],
121121
rollout_processor=RemoteRolloutProcessor(
122122
remote_base_url="http://localhost:3000",
123-
output_data_loader=create_output_data_loader,
124123
)
125124
)
126125
def test_remote_http(row: EvaluationRow) -> EvaluationRow:

0 commit comments

Comments
 (0)