Skip to content

Commit d9ab3d4

Browse files
committed
add
1 parent 7aa064f commit d9ab3d4

File tree

3 files changed

+334
-0
lines changed

3 files changed

+334
-0
lines changed

eval_protocol/pytest/buffer.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import asyncio
2+
import os
3+
from collections import defaultdict
4+
from typing import List, Dict
5+
6+
from eval_protocol.models import EvaluationRow
7+
8+
class MiniBatchDataBuffer:
    """Accumulate per-run evaluation results and persist them in minibatches.

    A sample is considered finished once ``num_runs`` rows sharing the same
    row_id have been added. Finished samples are held in memory until
    ``minibatch_size`` of them are available, then flushed to disk as JSONL.
    """

    def __init__(self, num_runs: int, minibatch_size: int, output_path_template: str):
        self.num_runs = num_runs
        self.minibatch_size = minibatch_size
        self.output_path_template = output_path_template
        # row_id -> rows collected so far for that (incomplete) sample
        self.pending_samples: Dict[str, List[EvaluationRow]] = defaultdict(list)
        # groups of rows for samples whose runs have all completed
        self.completed_samples_buffer: List[List[EvaluationRow]] = []
        self.batch_index = 0
        self.lock = asyncio.Lock()

    async def add_result(self, row: EvaluationRow):
        """Record one run's result. Coroutine-safe via an internal lock."""
        async with self.lock:
            row_id = row.input_metadata.row_id
            if not row_id:
                # A unique row_id is required to group runs of the same sample;
                # a missing id should not occur in a valid EP workflow.
                return

            bucket = self.pending_samples[row_id]
            bucket.append(row)
            if len(bucket) < self.num_runs:
                return

            # All runs for this sample are in: move it to the completed buffer.
            self.completed_samples_buffer.append(self.pending_samples.pop(row_id))

            if len(self.completed_samples_buffer) >= self.minibatch_size:
                await self._flush_unsafe()

    async def _flush_unsafe(self):
        """Write buffered samples to disk. Caller must already hold ``self.lock``."""
        if not self.completed_samples_buffer:
            return

        indexed = "{index}" in self.output_path_template
        if indexed:
            output_path = self.output_path_template.format(index=self.batch_index)
        else:
            output_path = self.output_path_template
        # Per-batch files are overwritten; a single fixed path is appended to.
        mode = "w" if indexed else "a"

        # Ensure the target directory exists before opening the file.
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        # Write the rows flattened, one JSON object per line.
        with open(output_path, mode) as f:
            for sample_rows in self.completed_samples_buffer:
                for row in sample_rows:
                    f.write(row.model_dump_json() + "\n")

        self.completed_samples_buffer = []
        self.batch_index += 1

    async def close(self):
        """Flush whatever completed samples remain in the buffer."""
        async with self.lock:
            if self.completed_samples_buffer:
                await self._flush_unsafe()

eval_protocol/pytest/evaluation_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dataclasses import dataclass, field
12
import asyncio
23
import inspect
34
import os
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
import asyncio
2+
import os
3+
from collections import defaultdict
4+
from dataclasses import dataclass, field
5+
from typing import Any, Callable, List, Dict, Optional, Union, Awaitable
6+
7+
from eval_protocol.models import EvaluationRow, Status
8+
from eval_protocol.pytest.types import RolloutProcessorConfig
9+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
10+
from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry
11+
from eval_protocol.pytest.buffer import MiniBatchDataBuffer
12+
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
13+
from eval_protocol.human_id import generate_id
14+
15+
@dataclass(order=True)
class RolloutTask:
    """
    A single unit of work for the worker pool: one micro-batch of runs for
    one dataset row.

    Ordering is driven solely by ``priority`` — every other field is declared
    with ``compare=False`` — so instances sort correctly inside an
    ``asyncio.PriorityQueue``.

    Priority tuple structure: ``(status, row_index)``
    - status: 0 = high priority (subsequent micro-batches of an already
      started sample), 1 = low priority (starting a new sample)
    - row_index: dataset index; keeps initial scheduling in dataset order
    """
    priority: tuple[int, int]

    # Payload (excluded from comparison so the queue orders purely by priority)
    row: EvaluationRow = field(compare=False)
    run_indices: List[int] = field(compare=False)  # Which runs to execute in this task
    config: RolloutProcessorConfig = field(compare=False)
    row_index: int = field(compare=False)  # To track which sample this belongs to

    # History for speculation (injected from previous micro-batches)
    history: List[str] = field(compare=False, default_factory=list)
34+
35+
class PriorityRolloutScheduler:
    """
    Manages a priority queue of rollout tasks and a pool of workers.

    Ensures that once a sample starts processing, its subsequent micro-batches
    are scheduled at high priority (status 0) so the sample finishes as
    quickly as possible, while new samples start at low priority (status 1)
    in dataset order.
    """

    def __init__(
        self,
        rollout_processor: RolloutProcessor,
        max_concurrent_rollouts: int,
        active_logger: DatasetLogger,
        eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]],  # Callback to run evaluation
        mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None,
    ):
        """
        Args:
            rollout_processor: Executes the actual rollouts for a batch of rows.
            max_concurrent_rollouts: Number of concurrent worker coroutines.
            active_logger: Logger that records each row before its rollout runs.
            eval_executor: Callback that evaluates a finished rollout; may
                return a single row (pointwise) or a list (groupwise).
            mini_batch_data_buffer: Optional buffer that streams completed
                results to disk in minibatches.
        """
        self.rollout_processor = rollout_processor
        self.max_concurrent_rollouts = max_concurrent_rollouts
        self.active_logger = active_logger
        self.eval_executor = eval_executor
        self.mini_batch_data_buffer = mini_batch_data_buffer

        # Priority Queue: stores RolloutTask, ordered by task.priority.
        self.queue: asyncio.PriorityQueue[RolloutTask] = asyncio.PriorityQueue()

        # Set by run() before any tasks are scheduled.
        self.num_runs = 0
        self.micro_batch_size = 0

    async def schedule_dataset(
        self,
        dataset: List[EvaluationRow],
        base_config: RolloutProcessorConfig,
    ):
        """
        Populates the queue with initial tasks (the first micro-batch for each sample).
        """
        # Guard against a non-positive micro_batch_size (would otherwise produce
        # empty/stuck tasks); fall back to running all runs at once.
        # These values are loop-invariant, so compute them once.
        safe_batch_size = self.micro_batch_size if self.micro_batch_size > 0 else self.num_runs
        batch_end = min(safe_batch_size, self.num_runs)
        initial_indices = list(range(0, batch_end))

        for i, row in enumerate(dataset):
            task = RolloutTask(
                priority=(1, i),  # Initial priority: Low (1), ordered by dataset index
                row=row,
                run_indices=list(initial_indices),  # copy: tasks must not share the list
                config=base_config,
                row_index=i,
                history=[],  # Initial batch has no history
            )
            self.queue.put_nowait(task)

    async def worker(self):
        """
        Worker loop: fetch task -> execute micro-batch -> schedule next batch (if any).

        Note: the async ``queue.get()`` never raises QueueEmpty — it waits for
        an item. Workers therefore run until cancelled by run() after
        ``queue.join()`` completes. (The original ``except asyncio.QueueEmpty``
        branch was dead code and has been removed.)
        """
        while True:
            task: RolloutTask = await self.queue.get()
            try:
                await self._process_task(task)
            except Exception as e:
                print(f"Error processing task for row {task.row.input_metadata.row_id}: {e}")
            finally:
                self.queue.task_done()

    @staticmethod
    def _extract_prediction(row: EvaluationRow) -> str:
        """Return the text of *row*'s last assistant message ('' if none)."""
        last_msg = row.last_assistant_message()
        if not (last_msg and last_msg.content):
            return ""  # Empty string for failed turns
        content = last_msg.content
        if isinstance(content, list):
            # Multi-part content: concatenate only the text parts.
            return "".join(p["text"] for p in content if p["type"] == "text")
        return str(content)

    async def _process_task(self, task: RolloutTask):
        """
        Executes a single micro-batch task: run rollouts, evaluate the results,
        stream them to the buffer, and schedule the sample's next micro-batch
        at high priority if runs remain.
        """
        if not task.run_indices:
            # Nothing to execute (e.g. num_runs == 0). Returning here also
            # avoids the IndexError the old code hit at run_indices[-1] below.
            return

        # 1. Prepare a fresh copy of the row for each run in this micro-batch.
        current_batch_rows = []
        for run_idx in task.run_indices:
            row_copy = task.row.model_copy(deep=True)

            row_copy.execution_metadata.run_id = generate_id()
            row_copy.execution_metadata.rollout_id = generate_id()

            # Inject speculation history from previous micro-batches as a
            # "prediction" hint in the completion params.
            if task.history:
                cp = row_copy.input_metadata.completion_params
                # Ensure safe dict access.
                if not isinstance(cp, dict):
                    cp = {}
                extra_body = cp.get("extra_body")
                if not isinstance(extra_body, dict):
                    extra_body = {}

                extra_body["prediction"] = task.history
                cp["extra_body"] = extra_body
                row_copy.input_metadata.completion_params = cp

            current_batch_rows.append(row_copy)
            self.active_logger.log(row_copy)

        # 2. Execute the rollouts for the whole micro-batch.
        representative_run_idx = task.run_indices[0]
        batch_results: List[EvaluationRow] = []
        async for result_row in rollout_processor_with_retry(
            self.rollout_processor, current_batch_rows, task.config, representative_run_idx
        ):
            batch_results.append(result_row)

        # 3. Evaluate each result and collect predictions for speculation.
        current_batch_history_updates = []
        for res in batch_results:
            eval_res = await self.eval_executor(res)

            # eval_executor returns a single row (pointwise) or a list
            # (groupwise); normalize to a list and handle both uniformly.
            eval_rows = eval_res if isinstance(eval_res, list) else [eval_res]
            for evaluated in eval_rows:
                if self.mini_batch_data_buffer:
                    await self.mini_batch_data_buffer.add_result(evaluated)
                current_batch_history_updates.append(self._extract_prediction(evaluated))

        # 4. Schedule the next micro-batch for this sample, at high priority
        #    (0) so we finish an already-started sample ASAP.
        next_start = task.run_indices[-1] + 1
        if next_start < self.num_runs:
            next_end = min(next_start + self.micro_batch_size, self.num_runs)
            new_task = RolloutTask(
                priority=(0, task.row_index),
                row=task.row,
                run_indices=list(range(next_start, next_end)),
                config=task.config,
                row_index=task.row_index,
                history=task.history + current_batch_history_updates,
            )
            self.queue.put_nowait(new_task)

    async def run(self, dataset: List[EvaluationRow], num_runs: int, micro_batch_size: int, base_config: RolloutProcessorConfig):
        """
        Run the full schedule: enqueue initial tasks, spin up the worker pool,
        and wait for every task (including dynamically added micro-batches)
        to finish.

        Returns an empty dict; results are delivered via side effects
        (active_logger and mini_batch_data_buffer).
        """
        self.num_runs = num_runs
        self.micro_batch_size = micro_batch_size

        # 1. Schedule initial tasks.
        await self.schedule_dataset(dataset, base_config)

        # 2. Start workers.
        workers = [asyncio.create_task(self.worker()) for _ in range(self.max_concurrent_rollouts)]

        # 3. Wait for completion of every queued task and its follow-ups.
        await self.queue.join()

        # 4. Cleanup: workers otherwise wait forever on the empty queue.
        for w in workers:
            w.cancel()

        # Ensure cancellation is complete.
        if workers:
            await asyncio.gather(*workers, return_exceptions=True)

        # Return empty dict as we rely on side effects (streaming buffer).
        return {}
238+
239+
async def execute_priority_rollouts(
    dataset: List[EvaluationRow],
    num_runs: int,
    micro_batch_size: int,
    rollout_processor: RolloutProcessor,
    config: RolloutProcessorConfig,
    max_concurrent_rollouts: int,
    active_logger: DatasetLogger,
    eval_executor: Callable[[Union[EvaluationRow, List[EvaluationRow]]], Awaitable[Union[EvaluationRow, List[EvaluationRow]]]],
    mini_batch_data_buffer: Optional[MiniBatchDataBuffer] = None,
):
    """Convenience wrapper: build a PriorityRolloutScheduler and run it.

    Returns the scheduler's result (currently an empty dict; output is
    delivered via side effects on the logger and the minibatch buffer).
    """
    sched = PriorityRolloutScheduler(
        rollout_processor=rollout_processor,
        max_concurrent_rollouts=max_concurrent_rollouts,
        active_logger=active_logger,
        eval_executor=eval_executor,
        mini_batch_data_buffer=mini_batch_data_buffer,
    )
    return await sched.run(dataset, num_runs, micro_batch_size, config)

0 commit comments

Comments
 (0)