77from eval_protocol .log_utils .elasticsearch_client import ElasticsearchClient
88from eval_protocol .models import EvaluationRow , Status
99from eval_protocol .data_loader .dynamic_data_loader import DynamicDataLoader
10- from eval_protocol .types .remote_rollout_processor import ElasticsearchConfig , InitRequest , RolloutMetadata
10+ from eval_protocol .types .remote_rollout_processor import (
11+ DataLoaderConfig ,
12+ ElasticsearchConfig ,
13+ InitRequest ,
14+ RolloutMetadata ,
15+ )
1116from eval_protocol .adapters .fireworks_tracing import create_fireworks_tracing_adapter
1217from eval_protocol .quickstart .utils import filter_longest_conversation
1318from .rollout_processor import RolloutProcessor
2025logger = logging .getLogger (__name__ )
2126
2227
def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
    """Build the default output data loader backed by the Fireworks tracing proxy.

    Args:
        config: Loader configuration. ``config.rollout_id`` is used to build the
            ``rollout_id:<id>`` tag that selects which traces are fetched;
            ``config.model_base_url``, when set, is the tracing base URL.

    Returns:
        DynamicDataLoader whose single generator fetches the tagged traces and
        whose preprocess step is ``filter_longest_conversation``.
    """

    def fetch_traces() -> List[EvaluationRow]:
        # A missing/empty base URL falls back to the default tracing URL.
        base_url = config.model_base_url or "https://tracing.fireworks.ai"
        adapter = create_fireworks_tracing_adapter(base_url=base_url)
        return adapter.get_evaluation_rows(
            tags=[f"rollout_id:{config.rollout_id}"],
            max_retries=5,
        )

    return DynamicDataLoader(
        generators=[fetch_traces],
        preprocess_fn=filter_longest_conversation,
    )
3844
@@ -54,7 +60,7 @@ def __init__(
5460 model_base_url : str = "https://tracing.fireworks.ai" ,
5561 poll_interval : float = 1.0 ,
5662 timeout_seconds : float = 120.0 ,
57- output_data_loader : Optional [Callable [[str , str ], DynamicDataLoader ]] = None ,
63+ output_data_loader : Optional [Callable [[DataLoaderConfig ], DynamicDataLoader ]] = None ,
5864 disable_elastic_search : bool = False ,
5965 elastic_search_config : Optional [ElasticsearchConfig ] = None ,
6066 ):
@@ -64,7 +70,6 @@ def __init__(
6470 self ._model_base_url = model_base_url
6571 if os .getenv ("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL" ):
6672 self ._remote_base_url = os .getenv ("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL" )
67- self ._model_base_url = model_base_url
6873 _ep_model_base_url = os .getenv ("EP_MODEL_BASE_URL" )
6974 if _ep_model_base_url :
7075 self ._model_base_url = _ep_model_base_url
@@ -268,7 +273,10 @@ def _get_status() -> Dict[str, Any]:
268273 if row .execution_metadata .rollout_id is None :
269274 raise ValueError ("Rollout ID is required in RemoteRolloutProcessor" )
270275
271- data_loader = self ._output_data_loader (row .execution_metadata .rollout_id , model_base_url )
276+ loader_config = DataLoaderConfig (
277+ rollout_id = row .execution_metadata .rollout_id , model_base_url = model_base_url
278+ )
279+ data_loader = self ._output_data_loader (loader_config )
272280
273281 def _load_data ():
274282 return data_loader .load ()
0 commit comments