Skip to content

Commit 260f721

Browse files
committed
add dataloaderconfig
1 parent 712a37d commit 260f721

4 files changed

Lines changed: 35 additions & 17 deletions

File tree

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient
88
from eval_protocol.models import EvaluationRow, Status
99
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
10-
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata
10+
from eval_protocol.types.remote_rollout_processor import (
11+
DataLoaderConfig,
12+
ElasticsearchConfig,
13+
InitRequest,
14+
RolloutMetadata,
15+
)
1116
from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter
1217
from eval_protocol.quickstart.utils import filter_longest_conversation
1318
from .rollout_processor import RolloutProcessor
@@ -20,19 +25,20 @@
2025
logger = logging.getLogger(__name__)
2126

2227

23-
def _default_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader:
28+
def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
2429
"""Default output data loader that fetches traces from Fireworks tracing proxy.
2530
2631
Args:
27-
rollout_id: The rollout ID to filter traces by
32+
config: Configuration containing rollout_id and optional model_base_url
2833
2934
Returns:
3035
DynamicDataLoader configured to fetch and process traces
3136
"""
3237

3338
def fetch_traces() -> List[EvaluationRow]:
39+
base_url = config.model_base_url or "https://tracing.fireworks.ai"
3440
adapter = create_fireworks_tracing_adapter(base_url=base_url)
35-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5)
41+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
3642

3743
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
3844

@@ -54,7 +60,7 @@ def __init__(
5460
model_base_url: str = "https://tracing.fireworks.ai",
5561
poll_interval: float = 1.0,
5662
timeout_seconds: float = 120.0,
57-
output_data_loader: Optional[Callable[[str, str], DynamicDataLoader]] = None,
63+
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
5864
disable_elastic_search: bool = False,
5965
elastic_search_config: Optional[ElasticsearchConfig] = None,
6066
):
@@ -64,7 +70,6 @@ def __init__(
6470
self._model_base_url = model_base_url
6571
if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"):
6672
self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL")
67-
self._model_base_url = model_base_url
6873
_ep_model_base_url = os.getenv("EP_MODEL_BASE_URL")
6974
if _ep_model_base_url:
7075
self._model_base_url = _ep_model_base_url
@@ -268,7 +273,10 @@ def _get_status() -> Dict[str, Any]:
268273
if row.execution_metadata.rollout_id is None:
269274
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
270275

271-
data_loader = self._output_data_loader(row.execution_metadata.rollout_id, model_base_url)
276+
loader_config = DataLoaderConfig(
277+
rollout_id=row.execution_metadata.rollout_id, model_base_url=model_base_url
278+
)
279+
data_loader = self._output_data_loader(loader_config)
272280

273281
def _load_data():
274282
return data_loader.load()

eval_protocol/types/remote_rollout_processor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ class RolloutMetadata(BaseModel):
3434
row_id: str
3535

3636

37+
class DataLoaderConfig(BaseModel):
38+
"""Configuration passed to output_data_loader functions."""
39+
40+
rollout_id: str
41+
model_base_url: Optional[str] = None
42+
43+
3744
class InitRequest(BaseModel):
3845
"""Request model for POST /init endpoint."""
3946

tests/remote_server/test_remote_fireworks.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
2323
from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter
2424
from eval_protocol.quickstart.utils import filter_longest_conversation
25+
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
2526

2627
ROLLOUT_IDS = set()
2728

@@ -36,17 +37,18 @@ def check_rollout_coverage():
3637
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
3738

3839

39-
def fetch_fireworks_traces(rollout_id: str, base_url: str) -> List[EvaluationRow]:
40+
def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
4041
global ROLLOUT_IDS # Track all rollout_ids we've seen
41-
ROLLOUT_IDS.add(rollout_id)
42+
ROLLOUT_IDS.add(config.rollout_id)
4243

44+
base_url = config.model_base_url or "https://tracing.fireworks.ai"
4345
adapter = create_fireworks_tracing_adapter(base_url=base_url)
44-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5)
46+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
4547

4648

47-
def fireworks_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader:
49+
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
4850
return DynamicDataLoader(
49-
generators=[lambda: fetch_fireworks_traces(rollout_id, base_url)], preprocess_fn=filter_longest_conversation
51+
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
5052
)
5153

5254

tests/remote_server/test_remote_langfuse.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
2323
from eval_protocol.adapters.langfuse import create_langfuse_adapter
2424
from eval_protocol.quickstart.utils import filter_longest_conversation
25+
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
2526

2627
ROLLOUT_IDS = set()
2728

@@ -36,17 +37,17 @@ def check_rollout_coverage():
3637
assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}"
3738

3839

39-
def fetch_langfuse_traces(rollout_id: str) -> List[EvaluationRow]:
40+
def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
4041
global ROLLOUT_IDS # Track all rollout_ids we've seen
41-
ROLLOUT_IDS.add(rollout_id)
42+
ROLLOUT_IDS.add(config.rollout_id)
4243

4344
adapter = create_langfuse_adapter()
44-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5)
45+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
4546

4647

47-
def langfuse_output_data_loader(rollout_id: str) -> DynamicDataLoader:
48+
def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
4849
return DynamicDataLoader(
49-
generators=[lambda: fetch_langfuse_traces(rollout_id)], preprocess_fn=filter_longest_conversation
50+
generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation
5051
)
5152

5253

0 commit comments

Comments
 (0)