|
| 1 | +import asyncio |
| 2 | +import os |
| 3 | +from collections import defaultdict |
| 4 | +from dataclasses import dataclass, field |
| 5 | +from typing import Any, Callable, List, Dict, Optional, Union, Awaitable |
| 6 | + |
| 7 | +from eval_protocol.models import EvaluationRow, Status |
| 8 | +from eval_protocol.pytest.types import RolloutProcessorConfig |
| 9 | +from eval_protocol.pytest.rollout_processor import RolloutProcessor |
| 10 | +from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry |
| 11 | +from eval_protocol.pytest.buffer import MiniBatchDataBuffer |
| 12 | +from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
| 13 | +from eval_protocol.human_id import generate_id |
| 14 | + |
@dataclass(order=True)
class RolloutTask:
    """
    Represents a single unit of work for the worker pool.

    Ordering (via ``order=True``) compares only ``priority`` — every other
    field is declared with ``compare=False`` — so an ``asyncio.PriorityQueue``
    pops the task with the lowest (status, row_index) tuple first.

    Priority tuple structure: (status, row_index)
    - status: 0 = High Priority (e.g., subsequent micro-batches of an already started sample)
              1 = Low Priority (e.g., starting a new sample)
    - row_index: Used to maintain dataset order for initial scheduling
    """
    priority: tuple[int, int]

    # Payload (excluded from comparison)
    row: EvaluationRow = field(compare=False)  # sample to roll out (copied per run)
    run_indices: List[int] = field(compare=False)  # Which runs to execute in this task
    config: RolloutProcessorConfig = field(compare=False)
    row_index: int = field(compare=False)  # To track which sample this belongs to

    # History for speculation (injected from previous micro-batches);
    # default_factory avoids a shared mutable default across instances.
    history: List[str] = field(compare=False, default_factory=list)
class PriorityRolloutScheduler:
    """
    Manages a priority queue of rollout tasks and a pool of workers.

    Once a sample starts processing, its subsequent micro-batches are enqueued
    at high priority (status 0) ahead of not-yet-started samples (status 1),
    so an in-flight sample completes as quickly as possible.
    """

    def __init__(
        self,
        rollout_processor: RolloutProcessor,
        max_concurrent_rollouts: int,
        active_logger: DatasetLogger,
        eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]],  # Callback to run evaluation
        mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None,
    ):
        """
        Args:
            rollout_processor: Executes model rollouts for a batch of rows.
            max_concurrent_rollouts: Number of concurrent worker tasks to spawn.
            active_logger: Logger each prepared row is written to before execution.
            eval_executor: Async callback that evaluates a finished rollout; may
                return a single row (pointwise) or a list of rows (groupwise).
            mini_batch_data_buffer: Optional buffer that streams results out as
                they complete.
        """
        self.rollout_processor = rollout_processor
        self.max_concurrent_rollouts = max_concurrent_rollouts
        self.active_logger = active_logger
        self.eval_executor = eval_executor
        self.mini_batch_data_buffer = mini_batch_data_buffer

        # Priority Queue of pending work; lowest (status, row_index) pops first.
        self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue()

        # Set by run(): total runs per sample and runs per micro-batch.
        self.num_runs = 0
        self.micro_batch_size = 0

    async def schedule_dataset(
        self,
        dataset: List[EvaluationRow],
        base_config: RolloutProcessorConfig,
    ):
        """
        Populates the queue with initial tasks (the first micro-batch for each sample).
        """
        for i, row in enumerate(dataset):
            # Guard against micro_batch_size <= 0: run everything in one batch
            # instead of producing empty/stuck tasks.
            safe_batch_size = self.micro_batch_size if self.micro_batch_size > 0 else self.num_runs
            run_indices = list(range(min(safe_batch_size, self.num_runs)))

            # num_runs == 0 yields no work for this sample; enqueueing an empty
            # task previously crashed _process_task on run_indices[-1].
            if not run_indices:
                continue

            task = RolloutTask(
                priority=(1, i),  # Low priority, ordered by dataset index
                row=row,
                run_indices=run_indices,
                config=base_config,
                row_index=i,
                history=[],  # Initial batch has no speculation history
            )
            self.queue.put_nowait(task)

    async def worker(self):
        """
        Worker loop: fetch task -> execute micro-batch -> schedule next batch (if any).

        Note: asyncio.Queue.get() blocks until an item is available and never
        raises QueueEmpty (only get_nowait does), so termination is driven by
        run() cancelling the workers after queue.join() completes.
        """
        while True:
            task: RolloutTask = await self.queue.get()
            try:
                await self._process_task(task)
            except Exception as e:
                print(f"Error processing task for row {task.row.input_metadata.row_id}: {e}")
            finally:
                # Always mark the task done so queue.join() cannot hang.
                self.queue.task_done()

    @staticmethod
    def _extract_prediction_text(row: EvaluationRow) -> str:
        """Return the text of the row's last assistant message ("" if absent).

        Multi-part (list) content is flattened by concatenating its "text"
        parts; empty string keeps the history list aligned for failed turns.
        """
        last_msg = row.last_assistant_message()
        if not (last_msg and last_msg.content):
            return ""
        content = last_msg.content
        if isinstance(content, list):
            return "".join(p["text"] for p in content if p["type"] == "text")
        return str(content)

    async def _process_task(self, task: RolloutTask):
        """
        Executes a single micro-batch task: prepare rows, roll out, evaluate,
        then enqueue the sample's next micro-batch at high priority.
        """
        # Empty tasks have nothing to run and would crash on run_indices[-1].
        if not task.run_indices:
            return

        # 1. Prepare Config & Row for this micro-batch
        current_batch_rows = []
        for run_idx in task.run_indices:
            row_copy = task.row.model_copy(deep=True)

            row_copy.execution_metadata.run_id = generate_id()
            row_copy.execution_metadata.rollout_id = generate_id()

            # Inject speculation history from earlier micro-batches as the
            # "prediction" field of completion_params.extra_body.
            if task.history:
                cp = row_copy.input_metadata.completion_params
                if not isinstance(cp, dict):
                    cp = {}
                extra_body = cp.get("extra_body")
                if not isinstance(extra_body, dict):
                    extra_body = {}

                extra_body["prediction"] = task.history
                cp["extra_body"] = extra_body
                row_copy.input_metadata.completion_params = cp

            current_batch_rows.append(row_copy)
            self.active_logger.log(row_copy)

        # 2. Execute Rollout
        batch_results: List[EvaluationRow] = []
        representative_run_idx = task.run_indices[0]
        async for result_row in rollout_processor_with_retry(
            self.rollout_processor, current_batch_rows, task.config, representative_run_idx
        ):
            batch_results.append(result_row)

        # 3. Evaluate and Collect History
        current_batch_history_updates = []
        for res in batch_results:
            eval_res = await self.eval_executor(res)

            # eval_executor returns a single row (pointwise) or a list
            # (groupwise); normalize to a list so both shapes share one path.
            results = eval_res if isinstance(eval_res, list) else [eval_res]
            for r in results:
                if self.mini_batch_data_buffer:
                    await self.mini_batch_data_buffer.add_result(r)
                current_batch_history_updates.append(self._extract_prediction_text(r))

        # 4. Schedule Next Micro-batch (High Priority) so this sample finishes
        #    before new samples are started.
        next_start = task.run_indices[-1] + 1
        if next_start < self.num_runs:
            next_end = min(next_start + self.micro_batch_size, self.num_runs)
            new_task = RolloutTask(
                priority=(0, task.row_index),
                row=task.row,
                run_indices=list(range(next_start, next_end)),
                config=task.config,
                row_index=task.row_index,
                history=task.history + current_batch_history_updates,
            )
            self.queue.put_nowait(new_task)

    async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_size: int, base_config: RolloutProcessorConfig):
        """
        Schedule the dataset, run the worker pool to completion, and clean up.

        Returns an empty dict: results are delivered via side effects
        (active_logger and mini_batch_data_buffer).
        """
        self.num_runs = num_runs
        self.micro_batch_size = micro_batch_size

        # 1. Schedule initial tasks
        await self.schedule_dataset(dataset, base_config)

        # 2. Start Workers
        workers = [asyncio.create_task(self.worker()) for _ in range(self.max_concurrent_rollouts)]

        try:
            # 3. Wait for all enqueued tasks (including rescheduled
            #    micro-batches) to be marked done.
            await self.queue.join()
        finally:
            # 4. Cleanup — cancel workers even if join() is cancelled.
            for w in workers:
                w.cancel()
            if workers:
                await asyncio.gather(*workers, return_exceptions=True)

        return {}
| 238 | + |
async def execute_priority_rollouts(
    dataset: List[EvaluationRow],
    num_runs: int,
    micro_batch_size: int,
    rollout_processor: RolloutProcessor,
    config: RolloutProcessorConfig,
    max_concurrent_rollouts: int,
    active_logger: DatasetLogger,
    eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]],
    mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None,
):
    """Convenience wrapper: build a PriorityRolloutScheduler and run the dataset.

    Forwards its arguments to the scheduler and returns whatever
    PriorityRolloutScheduler.run returns; results themselves are delivered
    through active_logger / mini_batch_data_buffer side effects.
    """
    scheduler = PriorityRolloutScheduler(
        rollout_processor,
        max_concurrent_rollouts,
        active_logger,
        eval_executor,
        mini_batch_data_buffer,
    )
    return await scheduler.run(
        dataset=dataset,
        num_runs=num_runs,
        micro_batch_size=micro_batch_size,
        base_config=config,
    )
0 commit comments