|
| 1 | +import asyncio |
| 2 | +import os |
| 3 | +import time |
| 4 | +from typing import Any, Callable, Dict, List, Optional |
| 5 | + |
| 6 | +import requests |
| 7 | +from datetime import datetime, timezone, timedelta |
| 8 | +from eval_protocol.models import EvaluationRow, Status |
| 9 | +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader |
| 10 | +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig |
| 11 | + |
| 12 | +from .rollout_processor import RolloutProcessor |
| 13 | +from .types import RolloutProcessorConfig |
| 14 | +from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace |
| 15 | + |
| 16 | + |
class GithubActionRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that dispatches and monitors a GitHub Actions workflow per evaluation row.

    Expected GitHub Actions workflow:
    - Workflow dispatch with inputs: model, metadata (JSON), model_base_url
    - Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
    - Traces are fetched later via output_data_loader using rollout_id tags

    NOTE: GHA has a rate limit of 5000 requests per hour.
    """

    def __init__(
        self,
        *,
        owner: str,
        repo: str,
        workflow_id: str,
        ref: str = "main",
        model_base_url: str = "https://tracing.fireworks.ai",
        poll_interval: float = 10.0,
        timeout_seconds: float = 1800.0,
        max_find_workflow_retries: int = 5,
        github_token: Optional[str] = None,
        output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
    ):
        """
        Args:
            owner: GitHub organization or user that owns the repository.
            repo: Repository name.
            workflow_id: Workflow file name or numeric ID to dispatch.
            ref: Git ref (branch/tag) the workflow is dispatched on.
            model_base_url: Base URL forwarded to the workflow for model calls.
                Overridden by the EP_MODEL_BASE_URL environment variable if set.
            poll_interval: Seconds between run-status polls.
            timeout_seconds: Maximum seconds to wait for a run to complete.
            max_find_workflow_retries: Attempts to locate the dispatched run
                (workflow_dispatch returns 204 with no run id, so the run must
                be found by searching recent runs for a matching name).
            github_token: GitHub token; falls back to the GITHUB_TOKEN env var.
            output_data_loader: Factory for the loader used to fetch traces;
                defaults to the Fireworks trace loader.
        """
        self.owner = owner
        self.repo = repo
        self.workflow_id = workflow_id
        self.ref = ref
        self.model_base_url = model_base_url
        # Environment override takes precedence over the constructor argument.
        _ep_model_base_url = os.getenv("EP_MODEL_BASE_URL")
        if _ep_model_base_url:
            self.model_base_url = _ep_model_base_url
        self.poll_interval = poll_interval
        self.timeout_seconds = timeout_seconds
        self.max_find_workflow_retries = max_find_workflow_retries
        self.github_token = github_token
        self._output_data_loader = output_data_loader or default_fireworks_output_data_loader

    def _headers(self) -> Dict[str, str]:
        """Build GitHub API request headers.

        Raises:
            ValueError: If no token is provided and GITHUB_TOKEN is unset.
        """
        headers = {"Accept": "application/vnd.github+json"}
        token = self.github_token or os.getenv("GITHUB_TOKEN")
        if not token:
            raise ValueError(
                "GitHub token is required. Provide it via github_token parameter or GITHUB_TOKEN environment variable"
            )
        headers["Authorization"] = f"Bearer {token}"
        return headers

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Dispatch one workflow run per row; return tasks that resolve when each run completes."""
        # Calculate max_pages based on number of rows we're processing
        num_rows = len(rows)
        max_pages = (num_rows + 99) // 100  # Round up pages

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            # All of these IDs are required to tag the run and later fetch its traces.
            if row.execution_metadata.invocation_id is None:
                raise ValueError("Invocation ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.experiment_id is None:
                raise ValueError("Experiment ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.rollout_id is None:
                raise ValueError("Rollout ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.run_id is None:
                raise ValueError("Run ID is required in GithubActionRolloutProcessor")
            if row.input_metadata.row_id is None:
                raise ValueError("Row ID is required in GithubActionRolloutProcessor")

            init_request = build_init_request(row, config, self.model_base_url)

            def _dispatch_workflow():
                """Fire the workflow_dispatch event (GitHub returns 204 with no run id)."""
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"
                payload = {
                    "ref": self.ref,
                    "inputs": {
                        "model": init_request.model,
                        "metadata": init_request.metadata.model_dump_json(),
                        "model_base_url": init_request.model_base_url,
                    },
                }
                r = requests.post(url, json=payload, headers=self._headers(), timeout=30)
                r.raise_for_status()

            await asyncio.to_thread(_dispatch_workflow)

            run = None
            target_name = f"rollout:{row.execution_metadata.rollout_id}"

            # Look for runs created in the last 15 minutes (we just dispatched it)
            cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=15)
            cutoff_iso = cutoff_time.isoformat()

            for attempt in range(self.max_find_workflow_retries):
                try:
                    page = 1
                    while page <= max_pages:

                        def _list_runs():
                            url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/runs"
                            params = {
                                "event": "workflow_dispatch",
                                "branch": self.ref,
                                "per_page": 100,  # Max per_page is 100, minimize total number of pages
                                "page": page,
                                "created": f">={cutoff_iso}",  # Only look at recent runs
                            }

                            r = requests.get(url, params=params, headers=self._headers(), timeout=30)
                            r.raise_for_status()
                            return r.json()

                        runs_data = await asyncio.to_thread(_list_runs)

                        # Search for our target run in this page
                        for candidate_run in runs_data.get("workflow_runs", []):
                            if candidate_run.get("name") == target_name:
                                run = candidate_run
                                break  # BUGFIX: stop scanning once the target run is found

                        # BUGFIX: stop paginating as soon as the run is found
                        if run is not None:
                            break

                        # If we got fewer results than 100, we've reached the end, since we paginate in chunks of 100
                        if len(runs_data.get("workflow_runs", [])) < 100:
                            break

                        page += 1

                    # BUGFIX: previously there was no success exit, so the retry loop
                    # always ran to exhaustion (with backoff sleeps) even after the
                    # run was found — wasting time and GHA rate-limit quota.
                    if run is not None:
                        break

                    # If no run found, GHA might still be populating it, retry
                    if attempt < self.max_find_workflow_retries - 1:
                        delay = 2**attempt  # Exponential backoff
                        await asyncio.sleep(delay)

                except requests.exceptions.HTTPError as e:
                    # Retry on rate limits (HTTP 429).
                    # BUGFIX: requests.Response.__bool__ returns `self.ok`, which is
                    # False for any error status — so `if e.response and ...` was
                    # always False for a 429 and rate limits were never retried.
                    # Compare against None explicitly.
                    if e.response is not None and e.response.status_code == 429:
                        if attempt < self.max_find_workflow_retries - 1:
                            delay = 2**attempt  # Exponential backoff
                            await asyncio.sleep(delay)
                        else:
                            # Give up after max attempts
                            raise e
                    else:
                        raise e

            if not run:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            run_id = run.get("id")
            if not run_id:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            # Poll the specific run until completion
            deadline = time.time() + self.timeout_seconds

            def _get_run() -> Dict[str, Any]:
                """Get status of a specific workflow run."""
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/runs/{run_id}"
                r = requests.get(url, headers=self._headers(), timeout=30)
                r.raise_for_status()
                return r.json()

            while time.time() < deadline:
                run_data = await asyncio.to_thread(_get_run)

                # NOTE(review): only `status` is checked; a run that completed with
                # conclusion "failure" is treated the same as "success" here and the
                # trace fetch below decides the outcome — confirm this is intended.
                if run_data.get("status") == "completed":
                    break

                await asyncio.sleep(self.poll_interval)
            else:
                # while-else: the loop exhausted its deadline without breaking.
                row.rollout_status = Status.rollout_error(
                    f"GitHub Actions run timed out after {self.timeout_seconds} seconds"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            def _update_with_trace() -> None:
                return update_row_with_remote_trace(row, self._output_data_loader, self.model_base_url)

            await asyncio.to_thread(_update_with_trace)

            # Add GitHub Actions run URL to session data
            if run_id:
                github_run_url = f"https://github.com/{self.owner}/{self.repo}/actions/runs/{run_id}"
                if not row.input_metadata.session_data:
                    row.input_metadata.session_data = {}
                row.input_metadata.session_data["github_actions_run_url"] = github_run_url

            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # Bound per-row concurrency with the caller-supplied semaphore.
            async with semaphore:
                return await _process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    def cleanup(self) -> None:
        """No resources to release; present to satisfy the RolloutProcessor interface."""
        return None
0 commit comments