Skip to content

Commit 91f2378

Browse files
committed
various changes
1 parent fd204bb commit 91f2378

5 files changed

Lines changed: 86 additions & 14 deletions

File tree

eval_protocol/adapters/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,9 @@
2525
except ImportError:
2626
pass
2727

28-
try:
29-
from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter
28+
from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter
3029

31-
__all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"])
32-
except ImportError:
33-
pass
30+
__all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"])
3431

3532
try:
3633
from .huggingface import (

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,9 @@ def get_evaluation_rows(
347347

348348
# Make request to proxy
349349
if self.project_id:
350-
url = f"{self.base_url}/v1/project_id/{self.project_id}/langfuse/traces"
350+
url = f"{self.base_url}/v1/project_id/{self.project_id}/traces"
351351
else:
352-
url = f"{self.base_url}/v1/langfuse/traces"
352+
url = f"{self.base_url}/v1/traces"
353353

354354
try:
355355
response = requests.post(url, json=payload, timeout=self.timeout)

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from eval_protocol.models import EvaluationRow, Status
99
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
1010
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata
11+
from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter
12+
from eval_protocol.quickstart.utils import filter_longest_conversation
1113
from .rollout_processor import RolloutProcessor
1214
from .types import RolloutProcessorConfig
1315
from .elasticsearch_setup import ElasticsearchSetup
@@ -18,10 +20,30 @@
1820
logger = logging.getLogger(__name__)
1921

2022

23+
def _default_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader:
24+
"""Default output data loader that fetches traces from Fireworks tracing proxy.
25+
26+
Args:
27+
rollout_id: The rollout ID to filter traces by
28+
29+
Returns:
30+
DynamicDataLoader configured to fetch and process traces
31+
"""
32+
33+
def fetch_traces() -> List[EvaluationRow]:
34+
adapter = create_fireworks_tracing_adapter(base_url=base_url)
35+
return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5)
36+
37+
return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)
38+
39+
2140
class RemoteRolloutProcessor(RolloutProcessor):
2241
"""
2342
Rollout processor that triggers a remote HTTP server to perform the rollout.
2443
44+
By default, fetches traces from the Fireworks tracing proxy using rollout_id tags.
45+
You can provide a custom output_data_loader for different tracing backends.
46+
2547
See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
2648
"""
2749

@@ -32,7 +54,7 @@ def __init__(
3254
model_base_url: str = "https://tracing.fireworks.ai",
3355
poll_interval: float = 1.0,
3456
timeout_seconds: float = 120.0,
35-
output_data_loader: Callable[[str], DynamicDataLoader],
57+
output_data_loader: Optional[Callable[[str, str], DynamicDataLoader]] = None,
3658
disable_elastic_search: bool = False,
3759
elastic_search_config: Optional[ElasticsearchConfig] = None,
3860
):
@@ -44,7 +66,7 @@ def __init__(
4466
self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL")
4567
self._poll_interval = poll_interval
4668
self._timeout_seconds = timeout_seconds
47-
self._output_data_loader = output_data_loader
69+
self._output_data_loader = output_data_loader or _default_output_data_loader
4870
self._disable_elastic_search = disable_elastic_search
4971
self._elastic_search_config = elastic_search_config
5072

@@ -242,7 +264,7 @@ def _get_status() -> Dict[str, Any]:
242264
if row.execution_metadata.rollout_id is None:
243265
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
244266

245-
data_loader = self._output_data_loader(row.execution_metadata.rollout_id)
267+
data_loader = self._output_data_loader(row.execution_metadata.rollout_id, model_base_url)
246268

247269
def _load_data():
248270
return data_loader.load()

tests/remote_server/quickstart.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# MANUAL SERVER STARTUP REQUIRED:
2+
#
3+
# For Python server testing, start:
4+
# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000)
5+
#
6+
# For TypeScript server testing, start:
7+
# cd tests/remote_server/typescript-server
8+
# npm install
9+
# npm start
10+
#
11+
# The TypeScript server should be running on http://127.0.0.1:3000
12+
# You only need to start one of the servers!
13+
14+
import os
15+
from typing import List
16+
17+
import pytest
18+
19+
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
20+
from eval_protocol.models import EvaluationRow, Message
21+
from eval_protocol.pytest import evaluation_test
22+
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
23+
24+
25+
def rows() -> List[EvaluationRow]:
26+
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
27+
return [row, row, row]
28+
29+
30+
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
31+
@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
32+
@evaluation_test(
33+
data_loaders=DynamicDataLoader(
34+
generators=[rows],
35+
),
36+
rollout_processor=RemoteRolloutProcessor(
37+
remote_base_url="http://127.0.0.1:3000",
38+
timeout_seconds=30,
39+
),
40+
)
41+
async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow:
42+
"""
43+
End-to-end test:
44+
- REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server
45+
- trigger remote rollout via RemoteRolloutProcessor (calls init/status)
46+
- fetch traces from Langfuse via Fireworks tracing proxy (uses default FireworksTracingAdapter)
47+
- FAIL if no traces found or rollout_id missing
48+
"""
49+
assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
50+
assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."
51+
assert row.execution_metadata.rollout_id, "Row should have a rollout_id from the remote rollout"
52+
53+
return row

tests/remote_server/test_remote_fireworks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@ def check_rollout_coverage():
3636
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
3737

3838

39-
def fetch_fireworks_traces(rollout_id: str) -> List[EvaluationRow]:
39+
def fetch_fireworks_traces(rollout_id: str, base_url: str) -> List[EvaluationRow]:
4040
global ROLLOUT_IDS # Track all rollout_ids we've seen
4141
ROLLOUT_IDS.add(rollout_id)
4242

43-
adapter = create_fireworks_tracing_adapter()
43+
adapter = create_fireworks_tracing_adapter(base_url=base_url)
4444
return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5)
4545

4646

47-
def fireworks_output_data_loader(rollout_id: str) -> DynamicDataLoader:
47+
def fireworks_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader:
4848
return DynamicDataLoader(
49-
generators=[lambda: fetch_fireworks_traces(rollout_id)], preprocess_fn=filter_longest_conversation
49+
generators=[lambda: fetch_fireworks_traces(rollout_id, base_url)], preprocess_fn=filter_longest_conversation
5050
)
5151

5252

0 commit comments

Comments
 (0)