Skip to content

Commit 521e756

Browse files
committed
fix
1 parent 4240f64 commit 521e756

1 file changed

Lines changed: 12 additions & 17 deletions

File tree

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ def __init__(
4949
self._timeout_seconds = timeout_seconds
5050
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
5151
self._session: Optional[aiohttp.ClientSession] = None
52+
self._active_runs = 0
5253

5354
def _get_or_create_session(self) -> aiohttp.ClientSession:
5455
if self._session is None or self._session.closed:
5556
self._session = aiohttp.ClientSession()
5657
return self._session
5758

5859
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
60+
self._active_runs += 1
5961
tasks: List[asyncio.Task[EvaluationRow]] = []
6062

6163
# Start with constructor values
@@ -112,19 +114,6 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
112114
raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}")
113115
resp.raise_for_status()
114116
await resp.read() # Drain the response body and release the connection back to the pool
115-
except asyncio.CancelledError:
116-
# Distinguish intentional cancellation (Ctrl+C, test teardown) from
117-
# aiohttp-internal cancellation caused by a poisoned DNS resolver
118-
# after a server disconnect. Task.cancelling() returns the number
119-
# of pending cancel() calls; > 0 means someone explicitly cancelled
120-
# this task.
121-
current = asyncio.current_task()
122-
if current is not None and current.cancelling() > 0: # pyright: ignore[reportAttributeAccessIssue]
123-
raise # Intentional cancellation — propagate immediately
124-
# Network-level failure; discard the session so retries get a
125-
# fresh connection pool.
126-
self._session = None
127-
raise ConnectionError("Remote server connection lost (request cancelled)")
128117
except asyncio.TimeoutError:
129118
raise TimeoutError(
130119
f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds."
@@ -220,19 +209,25 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
220209
return tasks
221210

222211
async def acleanup(self) -> None:
223-
"""Async cleanup - preferred when you can await."""
224-
if self._session and not self._session.closed:
212+
"""Async cleanup - only closes the session when the last run finishes.
213+
214+
rollout_processor_with_retry calls acleanup() per-run, but the session
215+
is shared across parallel runs. Closing it early would cancel in-flight
216+
requests in other runs.
217+
"""
218+
self._active_runs = max(0, self._active_runs - 1)
219+
if self._active_runs == 0 and self._session and not self._session.closed:
225220
await self._session.close()
226221

227222
def cleanup(self) -> None:
228223
"""Sync cleanup - best-effort, schedules close if event loop is running."""
224+
if self._active_runs > 0:
225+
return
229226
if self._session and not self._session.closed:
230227
try:
231228
loop = asyncio.get_running_loop()
232229
loop.create_task(self._session.close())
233230
except RuntimeError:
234-
# No running event loop - can't safely close the session.
235-
# The session will be garbage collected eventually, but warn about it.
236231
logger.warning(
237232
"RemoteRolloutProcessor.cleanup() called outside of async context. "
238233
"Session may not be properly closed. Use `await processor.acleanup()` when possible."

0 commit comments

Comments
 (0)