@@ -49,13 +49,15 @@ def __init__(
4949 self ._timeout_seconds = timeout_seconds
5050 self ._tracing_adapter = FireworksTracingAdapter (base_url = self ._model_base_url )
5151 self ._session : Optional [aiohttp .ClientSession ] = None
52+ self ._active_runs = 0
5253
5354 def _get_or_create_session (self ) -> aiohttp .ClientSession :
5455 if self ._session is None or self ._session .closed :
5556 self ._session = aiohttp .ClientSession ()
5657 return self ._session
5758
5859 def __call__ (self , rows : List [EvaluationRow ], config : RolloutProcessorConfig ) -> List [asyncio .Task [EvaluationRow ]]:
60+ self ._active_runs += 1
5961 tasks : List [asyncio .Task [EvaluationRow ]] = []
6062
6163 # Start with constructor values
@@ -112,19 +114,6 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
112114 raise RuntimeError (f"Remote /init failed (HTTP { resp .status } ): { body } " )
113115 resp .raise_for_status ()
114116 await resp .read () # Drain the response body and release the connection back to the pool
115- except asyncio .CancelledError :
116- # Distinguish intentional cancellation (Ctrl+C, test teardown) from
117- # aiohttp-internal cancellation caused by a poisoned DNS resolver
118- # after a server disconnect. Task.cancelling() returns the number
119- # of pending cancel() calls; > 0 means someone explicitly cancelled
120- # this task.
121- current = asyncio .current_task ()
122- if current is not None and current .cancelling () > 0 : # pyright: ignore[reportAttributeAccessIssue]
123- raise # Intentional cancellation — propagate immediately
124- # Network-level failure; discard the session so retries get a
125- # fresh connection pool.
126- self ._session = None
127- raise ConnectionError ("Remote server connection lost (request cancelled)" )
128117 except asyncio .TimeoutError :
129118 raise TimeoutError (
130119 f"The /init endpoint tried { init_url } with { init_payload .model_dump ()} but timed out after 300 seconds."
@@ -220,19 +209,25 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
220209 return tasks
221210
async def acleanup(self) -> None:
    """Release one run and close the session once no runs remain.

    rollout_processor_with_retry calls acleanup() per-run, but the session
    is shared across parallel runs, so closing it early would cancel
    in-flight requests belonging to other runs. Each call therefore
    decrements the active-run counter (never below zero) and only closes
    the session when that counter has reached zero.
    """
    remaining = self._active_runs - 1
    # Clamp at zero so an extra cleanup call cannot drive the counter negative.
    self._active_runs = remaining if remaining > 0 else 0
    if self._active_runs != 0:
        # Other runs still share the session; leave it open.
        return
    if not self._session:
        return
    if self._session.closed:
        return
    await self._session.close()
226221
227222 def cleanup (self ) -> None :
228223 """Sync cleanup - best-effort, schedules close if event loop is running."""
224+ if self ._active_runs > 0 :
225+ return
229226 if self ._session and not self ._session .closed :
230227 try :
231228 loop = asyncio .get_running_loop ()
232229 loop .create_task (self ._session .close ())
233230 except RuntimeError :
234- # No running event loop - can't safely close the session.
235- # The session will be garbage collected eventually, but warn about it.
236231 logger .warning (
237232 "RemoteRolloutProcessor.cleanup() called outside of async context. "
238233 "Session may not be properly closed. Use `await processor.acleanup()` when possible."
0 commit comments