Skip to content

Commit 7657edb

Browse files
authored
GithubActionRolloutProcessor (#273)
* gha test
* switch to prompt
* fix rollout worker
* test naming
* test
* update test
* working example now
* remove unused code
* fix test
* merged wrong rollout yml
* remove some uneeded imports
* make better comment
* add rate limit retry logic
* addressed comments
1 parent e8dde79 commit 7657edb

26 files changed

+748
-329
lines changed

.github/workflows/rollout.yml

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,26 @@
11
name: Eval Protocol Rollout
22

3-
run-name: rollout:${{ inputs.rollout_id }}
3+
run-name: rollout:${{ fromJSON(inputs.metadata).rollout_id }}
44

55
on:
66
workflow_dispatch:
77
inputs:
88
model:
9-
description: 'Model to use for the rollout'
9+
description: 'Model to use'
1010
required: true
1111
type: string
12-
rollout_id:
13-
description: 'Rollout ID for tracking'
12+
metadata:
13+
description: 'JSON serialized metadata object'
1414
required: true
1515
type: string
16-
prompt:
17-
description: 'User prompt for the rollout'
16+
model_base_url:
17+
description: 'Base URL for the model API'
1818
required: true
1919
type: string
2020

2121
jobs:
2222
rollout:
2323
runs-on: ubuntu-latest
24-
name: rollout-${{ inputs.rollout_id }}
2524

2625
steps:
2726
- name: Checkout code
@@ -43,13 +42,5 @@ jobs:
4342
run: |
4443
python tests/github_actions/rollout_worker.py \
4544
--model "${{ inputs.model }}" \
46-
--rollout-id "${{ inputs.rollout_id }}" \
47-
--prompt "${{ inputs.prompt }}"
48-
49-
- name: Upload rollout trace
50-
uses: actions/upload-artifact@v4
51-
if: always() # Upload even if the rollout failed
52-
with:
53-
name: rollout-trace-${{ inputs.rollout_id }}
54-
path: rollout_trace_${{ inputs.rollout_id }}.json
55-
retention-days: 7
45+
--metadata '${{ inputs.metadata }}' \
46+
--model-base-url "${{ inputs.model_base_url }}"

eval_protocol/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from .reward_function import RewardFunction
3131
from .typed_interface import reward_function
3232
from .quickstart import aha_judge, multi_turn_assistant_to_ground_truth, assistant_to_ground_truth
33-
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor
33+
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
3434
from .pytest.remote_rollout_processor import create_elasticsearch_config_from_env
3535
from .pytest.parameterize import DefaultParameterIdGenerator
3636
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
@@ -93,6 +93,7 @@
9393
"DataLoaderConfig",
9494
"Status",
9595
"RemoteRolloutProcessor",
96+
"GithubActionRolloutProcessor",
9697
"InputMetadata",
9798
"EvaluationRow",
9899
"DefaultParameterIdGenerator",

eval_protocol/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,9 @@ class EvaluationRow(BaseModel):
598598
model_config = ConfigDict(extra="allow")
599599

600600
# Core OpenAI ChatCompletion compatible conversation data
601-
messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
601+
messages: List[Message] = Field(
602+
default_factory=list, description="List of messages in the conversation. Also known as a trajectory."
603+
)
602604

603605
# Tool and function call information
604606
tools: Optional[List[Dict[str, Any]]] = Field(

eval_protocol/pytest/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .default_no_op_rollout_processor import NoOpRolloutProcessor
55
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
66
from .remote_rollout_processor import RemoteRolloutProcessor
7+
from .github_action_rollout_processor import GithubActionRolloutProcessor
78
from .evaluation_test import evaluation_test
89
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
910
from .rollout_processor import RolloutProcessor
@@ -33,6 +34,7 @@
3334
"RolloutProcessor",
3435
"SingleTurnRolloutProcessor",
3536
"RemoteRolloutProcessor",
37+
"GithubActionRolloutProcessor",
3638
"NoOpRolloutProcessor",
3739
"default_dataset_adapter",
3840
"RolloutProcessorConfig",

eval_protocol/pytest/evaluation_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
)
4848

4949

50-
from eval_protocol.pytest.utils import (
50+
from eval_protocol.pytest.evaluation_test_utils import (
5151
AggregationMethod,
5252
add_cost_metrics,
5353
log_eval_status_and_rows,

eval_protocol/pytest/evaluation_test_postprocess.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
from eval_protocol.models import CompletionParams, EvaluationRow, EvaluationThreshold, Status
1111
from eval_protocol.pytest.handle_persist_flow import handle_persist_flow
1212
from eval_protocol.pytest.types import EvaluationTestMode
13-
from eval_protocol.pytest.utils import AggregationMethod, aggregate, extract_effort_tag, sanitize_filename
13+
from eval_protocol.pytest.evaluation_test_utils import (
14+
AggregationMethod,
15+
aggregate,
16+
extract_effort_tag,
17+
sanitize_filename,
18+
)
1419
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
1520

1621

File renamed without changes.

eval_protocol/pytest/generate_parameter_combinations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from eval_protocol.data_loader.models import EvaluationDataLoader
33
from eval_protocol.models import CompletionParams, EvaluationRow
44
from eval_protocol.pytest.types import Dataset, DatasetPathParam, EvaluationInputParam, InputMessagesParam
5-
from eval_protocol.pytest.utils import parse_ep_max_rows
5+
from eval_protocol.pytest.evaluation_test_utils import parse_ep_max_rows
66
from collections.abc import Sequence
77

88

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
import asyncio
2+
import os
3+
import time
4+
from typing import Any, Callable, Dict, List, Optional
5+
6+
import requests
7+
from datetime import datetime, timezone, timedelta
8+
from eval_protocol.models import EvaluationRow, Status
9+
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
10+
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
11+
12+
from .rollout_processor import RolloutProcessor
13+
from .types import RolloutProcessorConfig
14+
from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
15+
16+
17+
class GithubActionRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that dispatches and monitors a GitHub Actions workflow per evaluation row.

    Expected GitHub Actions workflow:
    - Workflow dispatch with inputs: model, metadata (JSON), model_base_url
    - Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
    - Traces are fetched later via output_data_loader using rollout_id tags

    NOTE: GHA has a rate limit of 5000 requests per hour.
    """

    def __init__(
        self,
        *,
        owner: str,
        repo: str,
        workflow_id: str,
        ref: str = "main",
        model_base_url: str = "https://tracing.fireworks.ai",
        poll_interval: float = 10.0,
        timeout_seconds: float = 1800.0,
        max_find_workflow_retries: int = 5,
        github_token: Optional[str] = None,
        output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
    ):
        """
        Args:
            owner: GitHub organization or user that owns the repository.
            repo: Repository name containing the workflow.
            workflow_id: Workflow file name or numeric ID to dispatch.
            ref: Git ref (branch or tag) the workflow is dispatched on.
            model_base_url: Base URL for the model API passed to the workflow.
                Overridden by the EP_MODEL_BASE_URL environment variable if set.
            poll_interval: Seconds between status polls of the dispatched run.
            timeout_seconds: Maximum seconds to wait for a run to complete.
            max_find_workflow_retries: Attempts made to locate the dispatched
                run (the workflow_dispatch API does not return a run id).
            github_token: GitHub token; falls back to the GITHUB_TOKEN env var.
            output_data_loader: Factory building the data loader used to fetch
                traces after the run completes; defaults to the Fireworks loader.
        """
        self.owner = owner
        self.repo = repo
        self.workflow_id = workflow_id
        self.ref = ref
        self.model_base_url = model_base_url
        # Environment override takes precedence over the constructor argument.
        _ep_model_base_url = os.getenv("EP_MODEL_BASE_URL")
        if _ep_model_base_url:
            self.model_base_url = _ep_model_base_url
        self.poll_interval = poll_interval
        self.timeout_seconds = timeout_seconds
        self.max_find_workflow_retries = max_find_workflow_retries
        self.github_token = github_token
        self._output_data_loader = output_data_loader or default_fireworks_output_data_loader

    def _headers(self) -> Dict[str, str]:
        """Build GitHub API request headers, raising if no token is available."""
        headers = {"Accept": "application/vnd.github+json"}
        token = self.github_token or os.getenv("GITHUB_TOKEN")
        if not token:
            raise ValueError(
                "GitHub token is required. Provide it via github_token parameter or GITHUB_TOKEN environment variable"
            )
        headers["Authorization"] = f"Bearer {token}"
        return headers

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Dispatch one workflow run per row; return tasks resolving to the updated rows.

        Each task dispatches the workflow, locates the resulting run by its
        ``rollout:<rollout_id>`` run-name, polls it to completion, then attaches
        the remote trace and the GitHub run URL to the row. Errors and timeouts
        are recorded on ``row.rollout_status`` rather than raised.
        """
        # Calculate max_pages based on number of rows we're processing
        num_rows = len(rows)
        max_pages = (num_rows + 99) // 100  # Round up pages

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            # All ids below are required to tag and later retrieve the trace.
            if row.execution_metadata.invocation_id is None:
                raise ValueError("Invocation ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.experiment_id is None:
                raise ValueError("Experiment ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.rollout_id is None:
                raise ValueError("Rollout ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.run_id is None:
                raise ValueError("Run ID is required in GithubActionRolloutProcessor")
            if row.input_metadata.row_id is None:
                raise ValueError("Row ID is required in GithubActionRolloutProcessor")

            init_request = build_init_request(row, config, self.model_base_url)

            def _dispatch_workflow():
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"
                payload = {
                    "ref": self.ref,
                    "inputs": {
                        "model": init_request.model,
                        "metadata": init_request.metadata.model_dump_json(),
                        "model_base_url": init_request.model_base_url,
                    },
                }
                r = requests.post(url, json=payload, headers=self._headers(), timeout=30)
                r.raise_for_status()

            await asyncio.to_thread(_dispatch_workflow)

            # workflow_dispatch returns 204 with no run id, so we must find the
            # run by the run-name the workflow derives from our metadata.
            run = None
            target_name = f"rollout:{row.execution_metadata.rollout_id}"

            # Look for runs created in the last 15 minutes (we just dispatched it)
            cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=15)
            cutoff_iso = cutoff_time.isoformat()

            for attempt in range(self.max_find_workflow_retries):
                try:
                    page = 1
                    while page <= max_pages:

                        def _list_runs():
                            url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/runs"
                            params = {
                                "event": "workflow_dispatch",
                                "branch": self.ref,
                                "per_page": 100,  # Max per_page is 100, minimize total number of pages
                                "page": page,
                                "created": f">={cutoff_iso}",  # Only look at recent runs
                            }

                            r = requests.get(url, params=params, headers=self._headers(), timeout=30)
                            r.raise_for_status()
                            return r.json()

                        runs_data = await asyncio.to_thread(_list_runs)

                        # Search for our target run in this page
                        for candidate_run in runs_data.get("workflow_runs", []):
                            if candidate_run.get("name") == target_name:
                                run = candidate_run
                                # rollout_ids are unique, so the first match is our run
                                break

                        if run is not None:
                            break

                        # If we got fewer results than 100, we've reached the end, since we paginate in chunks of 100
                        if len(runs_data.get("workflow_runs", [])) < 100:
                            break

                        page += 1

                    # BUGFIX: stop once the run is located. Previously the loop kept
                    # sleeping and re-listing runs for every remaining attempt even
                    # after a successful match, wasting time and GitHub rate limit.
                    if run is not None:
                        break

                    # If no run found, GHA might still be populating it, retry
                    if attempt < self.max_find_workflow_retries - 1:
                        delay = 2**attempt  # Exponential backoff
                        await asyncio.sleep(delay)

                except requests.exceptions.HTTPError as e:
                    # Retry on rate limits (HTTP 429)
                    if e.response and e.response.status_code == 429:
                        if attempt < self.max_find_workflow_retries - 1:
                            delay = 2**attempt  # Exponential backoff
                            await asyncio.sleep(delay)
                        else:
                            # Give up after max attempts
                            raise e
                    else:
                        raise e

            if not run:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            run_id = run.get("id")
            if not run_id:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            # Poll the specific run until completion
            deadline = time.time() + self.timeout_seconds

            def _get_run() -> Dict[str, Any]:
                """Get status of a specific workflow run."""
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/runs/{run_id}"
                r = requests.get(url, headers=self._headers(), timeout=30)
                r.raise_for_status()
                return r.json()

            while time.time() < deadline:
                run_data = await asyncio.to_thread(_get_run)

                if run_data.get("status") == "completed":
                    break

                await asyncio.sleep(self.poll_interval)
            else:
                # while/else: only reached when the deadline expired without a break
                row.rollout_status = Status.rollout_error(
                    f"GitHub Actions run timed out after {self.timeout_seconds} seconds"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            def _update_with_trace() -> None:
                return update_row_with_remote_trace(row, self._output_data_loader, self.model_base_url)

            await asyncio.to_thread(_update_with_trace)

            # Add GitHub Actions run URL to session data
            if run_id:
                github_run_url = f"https://github.com/{self.owner}/{self.repo}/actions/runs/{run_id}"
                if not row.input_metadata.session_data:
                    row.input_metadata.session_data = {}
                row.input_metadata.session_data["github_actions_run_url"] = github_run_url

            return row

        # Bound concurrency with the shared semaphore from the config.
        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await _process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    def cleanup(self) -> None:
        """No persistent resources to release."""
        return None

0 commit comments

Comments
 (0)