diff --git a/openhands-agent-server/openhands/agent_server/event_service.py b/openhands-agent-server/openhands/agent_server/event_service.py index 0a5a501008..3d5f6beb6b 100644 --- a/openhands-agent-server/openhands/agent_server/event_service.py +++ b/openhands-agent-server/openhands/agent_server/event_service.py @@ -18,8 +18,13 @@ ) from openhands.agent_server.pub_sub import PubSub, Subscriber from openhands.sdk import LLM, AgentBase, Event, Message, get_logger +from openhands.sdk.agent import ACPAgent from openhands.sdk.conversation.base import BaseConversation -from openhands.sdk.conversation.impl.local_conversation import LocalConversation +from openhands.sdk.conversation.impl.local_conversation import ( + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID, + ACP_SUPERSEDE_INFLIGHT_PROMPT, + LocalConversation, +) from openhands.sdk.conversation.response_utils import get_agent_final_response from openhands.sdk.conversation.secret_registry import SecretValue from openhands.sdk.conversation.state import ( @@ -71,6 +76,15 @@ class EventService: # Set when a send_message(run=True) is rejected because a run is still # wrapping up; consumed by _run_and_publish to re-run the stranded message. _rerun_requested: bool = field(default=False, init=False) + # Set only for the internal ACP interrupt/restart path triggered by a new + # send_message(run=True). Explicit user pause/interrupt clears it so user + # stop intent wins over an earlier automatic restart request. + _acp_internal_rerun_requested: bool = field(default=False, init=False) + # Incremented for explicit user pause/interrupt requests. Internal ACP + # supersede restarts compare this generation after their interrupt drains + # so a later Stop/Pause cannot be overwritten by an automatic restart. + _explicit_interrupt_generation: int = field(default=0, init=False) + _closing: bool = field(default=False, init=False) _run_lock: asyncio.Lock = field(default_factory=asyncio.Lock, init=False) _callback_wrapper: AsyncCallbackWrapper | None = field(default=None, init=False) _lease: ConversationLease | None = field(default=None, init=False) @@ -419,11 +433,28 @@ async def batch_get_events(self, event_ids: list[str]) -> list[Event | None]: async def send_message(self, message: Message, run: bool = False): if not self._conversation: raise ValueError("inactive_service") + explicit_interrupt_generation = self._explicit_interrupt_generation loop = asyncio.get_running_loop() await loop.run_in_executor(None, self._conversation.send_message, message) if run: + if self._explicit_interrupt_generation != explicit_interrupt_generation: + return + ( + did_mark_acp_prompt_superseded, + active_acp_prompt_has_latest_message, + ) = await self._mark_running_acp_prompt_superseded() + interrupted_acp = False + if did_mark_acp_prompt_superseded: + self._acp_internal_rerun_requested = True + interrupted_acp = True + await self.interrupt(internal_acp_rerun=True) + if self._explicit_interrupt_generation != explicit_interrupt_generation: + return try: - await self.run() + await self.run( + acp_internal_rerun_generation=explicit_interrupt_generation + ) + self._acp_internal_rerun_requested = False except ValueError as e: # run() refused. If a run is still wrapping up (its # wait_for_pending tail), the message we just appended won't be @@ -433,8 +464,53 @@ async def send_message(self, message: Message, run: bool = False): # is what keeps a deliberate run=False append, or an IDLE reached # via another path, from triggering an unwanted run. # "inactive_service" is terminal and must not re-arm. - if str(e) == "conversation_already_running": + if ( + str(e) == "conversation_already_running" + and not active_acp_prompt_has_latest_message + ): self._rerun_requested = True + if interrupted_acp: + self._acp_internal_rerun_requested = True + + def _mark_running_acp_prompt_superseded_sync(self) -> tuple[bool, bool]: + """Mark the currently running ACP prompt superseded if needed. + + The tuple is ``(did_mark_superseded, active_prompt_has_latest_message)``. + If the running ACP prompt has already advanced to the newly appended + user message, interrupting it would cancel the replacement prompt and + strand that message behind the persisted cursor. + """ + if not self._conversation: + return (False, False) + if self._run_task is None: + return (False, False) + if not isinstance(self._conversation.agent, ACPAgent): + return (False, False) + with self._conversation._state as state: + if state.execution_status != ConversationExecutionStatus.RUNNING: + return (False, False) + inflight_prompt_user_message_id = state.agent_state.get( + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID + ) + last_user_message_id = state.last_user_message_id + if inflight_prompt_user_message_id is None or last_user_message_id is None: + return (False, False) + active_prompt_has_latest_message = ( + inflight_prompt_user_message_id == last_user_message_id + ) + if active_prompt_has_latest_message: + return (False, True) + state.agent_state = { + **state.agent_state, + ACP_SUPERSEDE_INFLIGHT_PROMPT: True, + } + return (True, False) + + async def _mark_running_acp_prompt_superseded(self) -> tuple[bool, bool]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, self._mark_running_acp_prompt_superseded_sync + ) async def subscribe_to_events(self, subscriber: Subscriber[Event]) -> UUID: subscriber_id = self._pub_sub.subscribe(subscriber) @@ -624,41 +700,53 @@ async def start(self): self._pub_sub, loop=asyncio.get_running_loop() ) - # Only wire token streaming if at least one LLM has stream=True. - # The LLM silently ignores on_token when stream is off, but skipping - # the wiring lets us log the decision so operators can tell from a - # log line whether deltas will flow. - streaming_enabled = any(llm.stream for llm in agent.get_all_llms()) + # Only wire token streaming for agents that can actually emit token + # callbacks. SDK LLM agents need stream=True, while ACP agents emit + # AgentMessageChunk text through their bridge without exposing an LLM. + streaming_enabled = isinstance(agent, ACPAgent) or any( + llm.stream for llm in agent.get_all_llms() + ) logger.debug( "Token streaming: %s", "enabled" if streaming_enabled else "disabled (no LLM has stream=True)", ) - def _token_streaming_callback(chunk: LLMStreamChunk) -> None: + def _publish_stream_delta( + content: str | None = None, + reasoning_content: str | None = None, + ) -> None: # Published directly to _pub_sub (not via _callback_wrapper) so # deltas reach subscribers but are NOT persisted to # ConversationState.events. See StreamingDeltaEvent docstring. if not self._main_loop or not self._main_loop.is_running(): return + # Use `is not None` rather than truthiness: some providers + # emit legitimate empty-string chunks at stream boundaries + # (e.g. after a tool call) that we still want to forward. + if content is None and reasoning_content is None: + return + event = StreamingDeltaEvent( + content=content, + reasoning_content=reasoning_content, + ) + with suppress(RuntimeError): # main loop already closed during teardown + asyncio.run_coroutine_threadsafe(self._pub_sub(event), self._main_loop) + + def _token_streaming_callback(chunk: LLMStreamChunk | str) -> None: + if isinstance(chunk, str): + _publish_stream_delta(content=chunk) + return + for choice in chunk.choices or (): delta = choice.delta if delta is None: continue content = getattr(delta, "content", None) reasoning = getattr(delta, "reasoning_content", None) - # Use `is not None` rather than truthiness: some providers - # emit legitimate empty-string chunks at stream boundaries - # (e.g. after a tool call) that we still want to forward. - if content is None and reasoning is None: - continue - event = StreamingDeltaEvent( + _publish_stream_delta( content=content if isinstance(content, str) else None, reasoning_content=reasoning if isinstance(reasoning, str) else None, ) - with suppress(RuntimeError): - asyncio.run_coroutine_threadsafe( - self._pub_sub(event), self._main_loop - ) conversation = LocalConversation( agent=agent, @@ -733,7 +821,7 @@ def _token_streaming_callback(chunk: LLMStreamChunk) -> None: # Publish initial state update await self._publish_state_update() - async def run(self): + async def run(self, acp_internal_rerun_generation: int | None = None): """Run the conversation asynchronously in the background. This method starts the conversation run in a background task and returns @@ -746,7 +834,7 @@ async def run(self): Raises: ValueError: If the service is inactive or conversation is already running. """ - if not self._conversation: + if not self._conversation or self._closing: raise ValueError("inactive_service") # Use lock to make check-and-set atomic, preventing race conditions @@ -756,6 +844,13 @@ async def run(self): == ConversationExecutionStatus.RUNNING ): raise ValueError("conversation_already_running") + if self._closing: + raise ValueError("inactive_service") + if ( + acp_internal_rerun_generation is not None + and self._explicit_interrupt_generation != acp_internal_rerun_generation + ): + return # Check if there's already a running task if self._run_task is not None and not self._run_task.done(): @@ -812,21 +907,53 @@ async def _run_and_publish(): # wrapping up. A send_message(run=True) that arrived during # the wait_for_pending() tail above had its run() rejected as # "conversation_already_running" and suppressed, setting - # _rerun_requested. Honor it only while the conversation is - # still IDLE — i.e. that message is genuinely pending. If the - # run loop was still alive it already absorbed the message - # (LocalConversation.run() keeps looping on FINISHED) and we - # are FINISHED here, so the IDLE guard avoids a redundant run. - # A deliberate run=False append, or an IDLE reached via - # another path, never sets the flag. - if self._rerun_requested: - self._rerun_requested = False - if ( - await self._get_execution_status() - == ConversationExecutionStatus.IDLE - ): - with suppress(ValueError): - await self.run() + # _rerun_requested. Honor it while the conversation is IDLE + # (pending input) or internally ACP-interrupted PAUSED (the + # old task finished its interrupt before the replacement run + # could start). Explicit user pause/interrupt clears the + # internal ACP flag, so user stop intent wins over an older + # automatic restart request. If the run loop was still alive + # it already absorbed the message and we are FINISHED here, + # so the guard avoids a redundant run. A deliberate + # run=False append, or an IDLE reached via another path, + # never sets the flag. + rerun_requested = self._rerun_requested + acp_internal_rerun_requested = self._acp_internal_rerun_requested + rerun_generation = self._explicit_interrupt_generation + self._rerun_requested = False + self._acp_internal_rerun_requested = False + if rerun_requested: + status = await self._get_execution_status() + rerun_generation_still_valid = ( + self._explicit_interrupt_generation == rerun_generation + ) + acp_internal_rerun_still_valid = ( + acp_internal_rerun_requested + and rerun_generation_still_valid + ) + should_restart = rerun_generation_still_valid and ( + status == ConversationExecutionStatus.IDLE + or ( + acp_internal_rerun_still_valid + and status == ConversationExecutionStatus.PAUSED + and isinstance(conversation.agent, ACPAgent) + ) + ) + if should_restart: + try: + await self.run( + acp_internal_rerun_generation=rerun_generation + if acp_internal_rerun_still_valid + else None + ) + except ValueError as e: + if str(e) == "conversation_already_running": + self._rerun_requested = True + self._acp_internal_rerun_requested = ( + acp_internal_rerun_requested + ) + else: + raise # Create task but don't await it - runs in background self._run_task = asyncio.create_task(_run_and_publish()) @@ -857,12 +984,15 @@ async def reject_pending_actions(self, reason: str): async def pause(self): if self._conversation: + self._explicit_interrupt_generation += 1 + self._rerun_requested = False + self._acp_internal_rerun_requested = False loop = asyncio.get_running_loop() await loop.run_in_executor(None, self._conversation.pause) # Publish state update after pause to ensure stats are updated await self._publish_state_update() - async def interrupt(self): + async def interrupt(self, *, internal_acp_rerun: bool = False): """Immediately cancel an in-flight async LLM call. Delegates to :meth:`LocalConversation.interrupt` which cancels the @@ -870,12 +1000,18 @@ async def interrupt(self): back to :meth:`pause`. """ if self._conversation: + if not internal_acp_rerun: + self._explicit_interrupt_generation += 1 + self._rerun_requested = False + self._acp_internal_rerun_requested = False self._conversation.interrupt() # Wait for the run task to finish so we can publish the final - # state update (PAUSED + InterruptEvent) cleanly. + # state update (PAUSED + InterruptEvent) cleanly. The shield keeps + # the 5s timeout from force-cancelling a cleanup that still needs + # to drain its ACP prompt/cancel handshake. if self._run_task is not None and not self._run_task.done(): with suppress(Exception): - await asyncio.wait_for(self._run_task, timeout=5.0) + await asyncio.wait_for(asyncio.shield(self._run_task), timeout=5.0) # Only clear _run_task if it actually finished; if # wait_for timed out the task may still be running and # clearing prematurely would allow a second run() to @@ -912,6 +1048,10 @@ async def set_security_analyzer( ) async def close(self): + self._closing = True + self._explicit_interrupt_generation += 1 + self._rerun_requested = False + self._acp_internal_rerun_requested = False if self._lease_task is not None: self._lease_task.cancel() with suppress(asyncio.CancelledError): diff --git a/openhands-sdk/openhands/sdk/agent/acp_agent.py b/openhands-sdk/openhands/sdk/agent/acp_agent.py index 5f9883954c..1391d4ac1d 100644 --- a/openhands-sdk/openhands/sdk/agent/acp_agent.py +++ b/openhands-sdk/openhands/sdk/agent/acp_agent.py @@ -17,14 +17,16 @@ from __future__ import annotations import asyncio +import inspect import json import os import threading import time import uuid from collections.abc import Generator +from concurrent.futures import Future from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, NamedTuple from acp.client.connection import ClientSideConnection from acp.exceptions import RequestError as ACPRequestError @@ -93,6 +95,13 @@ # These errors can occur when the connection drops mid-conversation but the # session state is still valid on the server side. _ACP_PROMPT_MAX_RETRIES: int = int(os.environ.get("ACP_PROMPT_MAX_RETRIES", "3")) + +# After a timeout/cancellation, wait briefly for the ACP prompt task to react +# to session/cancel before rewiring callbacks for the next turn. +_ACP_CANCEL_DRAIN_TIMEOUT: float = float( + os.environ.get("ACP_CANCEL_DRAIN_TIMEOUT", "2.0") +) + _ACP_PROMPT_RETRY_DELAYS: tuple[float, ...] = (5.0, 15.0, 30.0) # seconds # Exception types that indicate transient connection issues worth retrying @@ -146,6 +155,13 @@ _TERMINAL_TOOL_CALL_STATUSES: frozenset[str] = frozenset({"completed", "failed"}) +class _PromptDrainResult(NamedTuple): + drained: bool + completed: bool + response: PromptResponse | None + error: BaseException | None + + # Stable identifier stamped onto the sentinel LLM so downstream code # (e.g. title_utils) can detect "this LLM cannot be called" without # relying on the model name — which we overwrite with the real model @@ -759,6 +775,8 @@ def model_post_init(self, __context: object) -> None: # "installed" — already in subprocess history; skip further injection _suffix_install_state: str = PrivateAttr(default="unused") _installed_suffix: str | None = PrivateAttr(default=None) + _restart_session_on_next_turn: bool = PrivateAttr(default=False) + _resumed_existing_session: bool = PrivateAttr(default=False) # -- Helpers ----------------------------------------------------------- @@ -922,6 +940,10 @@ def init_state( # will re-inject the suffix on the first turn after upgrade, which # is benign — the suffix is additive LLM-context guidance. suffix_already_installed = bool(state.agent_state.get("acp_suffix_installed")) + # Tests that patch out _start_acp_server rely on this best-effort + # initial value. The real start path overwrites it with whether + # load_session actually succeeded. + self._resumed_existing_session = bool(state.agent_state.get("acp_session_id")) try: self._start_acp_server(state) @@ -942,17 +964,22 @@ def init_state( # in a different working directory would at best silently miss the # prior session and at worst load a different session that happens to # exist at the new cwd. - state.agent_state = { + updated_agent_state = { **state.agent_state, "acp_agent_name": self._agent_name, "acp_agent_version": self._agent_version, "acp_session_id": self._session_id, "acp_session_cwd": self._working_dir, } + if not self._resumed_existing_session: + updated_agent_state.pop("acp_suffix_installed", None) + state.agent_state = updated_agent_state if self._installed_suffix: self._suffix_install_state = ( - "installed" if suffix_already_installed else "pending_first_prompt" + "installed" + if suffix_already_installed and self._resumed_existing_session + else "pending_first_prompt" ) # Emit a placeholder system prompt so the visualizer shows a section @@ -1076,7 +1103,9 @@ def _start_acp_server(self, state: ConversationState) -> None: ) prior_session_id = None - async def _init() -> tuple[Any, Any, Any, str, str, str]: + self._resumed_existing_session = False + + async def _init() -> tuple[Any, Any, Any, str, str, str, bool]: # Spawn the subprocess directly so we can install a # filtering reader that skips non-JSON-RPC lines some # ACP servers (e.g. claude-code-acp v0.1.x) write to @@ -1159,6 +1188,7 @@ async def _init() -> tuple[Any, Any, Any, str, str, str]: # subprocess crash) propagate — there is no working connection to # fall back on, and the outer init_state handler cleans up. session_id: str | None = None + resumed_existing_session = False if prior_session_id is not None: try: await conn.load_session( @@ -1167,6 +1197,7 @@ async def _init() -> tuple[Any, Any, Any, str, str, str]: mcp_servers=[], ) session_id = prior_session_id + resumed_existing_session = True logger.info( "Resumed ACP session: %s (cwd=%s)", session_id, @@ -1205,7 +1236,15 @@ async def _init() -> tuple[Any, Any, Any, str, str, str]: logger.info("Setting ACP session mode: %s", mode_id) await conn.set_session_mode(mode_id=mode_id, session_id=session_id) - return conn, process, filtered_reader, session_id, agent_name, agent_version + return ( + conn, + process, + filtered_reader, + session_id, + agent_name, + agent_version, + resumed_existing_session, + ) result = self._executor.run_async(_init) ( @@ -1215,6 +1254,7 @@ async def _init() -> tuple[Any, Any, Any, str, str, str]: self._session_id, self._agent_name, self._agent_version, + self._resumed_existing_session, ) = result self._working_dir = working_dir @@ -1250,12 +1290,13 @@ def _cancel_inflight_tool_calls(self) -> None: spinning forever. This method closes those cards before we wipe the in-memory accumulator on retry / turn abort. - Uses the bridge's ``on_event`` directly (the same callback driving - live emissions); call this *before* ``_reset_client_for_turn`` so - the callback is still wired up. No-op if ``on_event`` was never - set (e.g. during tests exercising the bridge in isolation). + Captures the bridge's ``on_event`` callback, then unwires the bridge + before emitting synthetic terminal events so trailing updates from the + abandoned portal prompt cannot land after these failures. No-op if + ``on_event`` was never set (e.g. tests exercising the bridge alone). """ on_event = self._client.on_event + self._clear_turn_callbacks() if on_event is None: return for tc in self._client.accumulated_tool_calls: @@ -1282,6 +1323,116 @@ def _cancel_inflight_tool_calls(self) -> None: exc_info=True, ) + async def _arequest_session_cancel(self) -> None: + """Async variant of _request_session_cancel that waits for cancel send.""" + if self._conn is None or self._executor is None or self._session_id is None: + return + session_id = self._session_id + + async def _cancel() -> None: + result = self._conn.cancel(session_id) + if inspect.isawaitable(result): + await result + + try: + future = self._executor.portal.start_task_soon(_cancel) + await asyncio.wait_for( + asyncio.shield(asyncio.wrap_future(future)), + timeout=_ACP_CANCEL_DRAIN_TIMEOUT, + ) + except TimeoutError: + logger.warning( + "Timed out sending ACP session cancel; restarting ACP session" + ) + self._restart_session_on_next_turn = True + except Exception: + logger.warning("Failed to send ACP session cancel", exc_info=True) + + async def _drain_cancelled_prompt( + self, + future: Future[PromptResponse | None] | None, + ) -> _PromptDrainResult: + """Let a cancelled/timed-out portal prompt quiesce before rewiring.""" + if future is None: + return _PromptDrainResult( + drained=True, completed=False, response=None, error=None + ) + if future.cancelled(): + return _PromptDrainResult( + drained=True, completed=False, response=None, error=None + ) + if future.done(): + try: + return _PromptDrainResult( + drained=True, + completed=True, + response=future.result(), + error=None, + ) + except BaseException as exc: + return _PromptDrainResult( + drained=True, completed=True, response=None, error=exc + ) + try: + response = await asyncio.wait_for( + asyncio.shield(asyncio.wrap_future(future)), + timeout=_ACP_CANCEL_DRAIN_TIMEOUT, + ) + return _PromptDrainResult( + drained=True, completed=True, response=response, error=None + ) + except asyncio.CancelledError: + if future.cancelled(): + return _PromptDrainResult( + drained=False, completed=False, response=None, error=None + ) + raise + except TimeoutError: + logger.warning( + "Timed out waiting for cancelled ACP prompt to drain; " + "the ACP session will be restarted before the next turn" + ) + return _PromptDrainResult( + drained=False, completed=False, response=None, error=None + ) + except BaseException as exc: + return _PromptDrainResult( + drained=future.done(), completed=True, response=None, error=exc + ) + + def _restart_session_after_drain_timeout( + self, + state: ConversationState, + on_event: ConversationCallbackType, + ) -> None: + """Restart ACP after a prompt failed to quiesce post-cancel.""" + logger.warning("Restarting ACP session after cancelled prompt drain timeout") + self._clear_turn_callbacks() + self._cleanup() + self._initialized = False + # A local drain timeout means the cancelled prompt did not quiesce + # within our short grace window; it does not prove the ACP server lost + # its persisted session. Preserve the session id so the restarted + # subprocess can load_session() and retain conversation memory. + self.init_state(state, on_event=on_event) + self._restart_session_on_next_turn = False + + def _request_session_cancel(self) -> None: + """Ask the ACP server to cancel the active session prompt.""" + if self._conn is None or self._executor is None or self._session_id is None: + return + session_id = self._session_id + + async def _cancel() -> None: + result = self._conn.cancel(session_id) + if inspect.isawaitable(result): + await result + + try: + self._executor.portal.start_task_soon(_cancel) + except Exception: + logger.warning("Failed to send ACP session cancel", exc_info=True) + def _build_acp_prompt( self, event: MessageEvent ) -> list[TextContentBlock | ImageContentBlock] | None: @@ -1361,6 +1512,32 @@ async def _do_acp_prompt(self, prompt_blocks: list[Any]) -> PromptResponse | Non ) return response + async def _await_prompt_response_with_timeout( + self, + prompt_future: Future[PromptResponse | None], + ) -> PromptResponse | None: + """Await an ACP prompt with a hard turn deadline. + + The terminal tool reports hard command timeouts back to the agent + instead of waiting forever for active commands. ACP prompts follow the + same rule: activity heartbeats keep the server alive, but they do not + extend this prompt deadline. The timeout handler sends ``session/cancel`` + and closes any in-flight tool cards. + """ + try: + return await asyncio.wait_for( + asyncio.shield(asyncio.wrap_future(prompt_future)), + timeout=self.acp_prompt_timeout, + ) + except TimeoutError as exc: + raise TimeoutError( + f"ACP prompt timed out after {self.acp_prompt_timeout:.0f}s" + ) from exc + + @staticmethod + def _prompt_response_was_cancelled(response: PromptResponse | None) -> bool: + return response is not None and response.stop_reason == "cancelled" + def _finalize_successful_turn( self, response: PromptResponse | None, @@ -1494,12 +1671,40 @@ def _emit_turn_error( ) state.execution_status = ConversationExecutionStatus.ERROR + def _handle_cancelled_cleanup_interruption( + self, + prompt_future: Future[PromptResponse | None] | None, + elapsed: float, + state: ConversationState, + on_event: ConversationCallbackType, + ) -> None: + """Repair state when cancellation interrupts cancel/drain cleanup.""" + if prompt_future is not None and prompt_future.done(): + try: + response = prompt_future.result() + except BaseException: + self._cancel_inflight_tool_calls() + self._restart_session_on_next_turn = True + else: + if self._prompt_response_was_cancelled(response): + self._cancel_inflight_tool_calls() + self._restart_session_on_next_turn = True + else: + self._finalize_successful_turn(response, elapsed, state, on_event) + return + + self._cancel_inflight_tool_calls() + if prompt_future is not None: + self._restart_session_on_next_turn = True + def _clear_turn_callbacks(self) -> None: """Unwire per-turn bridge callbacks so trailing ``session_update`` between turns is a no-op (fires on the portal thread with no FIFOLock held by anyone — without unwiring, a stale ``on_event`` there would race with other threads mutating ``state.events``). """ + if self._client is None: + return self._client.on_event = None self._client.on_token = None self._client.on_activity = None @@ -1520,6 +1725,11 @@ def step( """ state = conversation.state + if self._restart_session_on_next_turn: + # If restart initialization fails, let the conversation transition + # to ERROR rather than reusing an ambiguous ACP session. + self._restart_session_after_drain_timeout(state, on_event) + # Conversation implementations already attach per-turn AgentContext # extensions to MessageEvent.extended_content; MessageEvent.to_llm_message() # merges those extensions with the user text. @@ -1607,6 +1817,7 @@ async def _prompt() -> PromptResponse | None: logger.info("ACP prompt returned in %.1fs", elapsed) self._finalize_successful_turn(response, elapsed, state, on_event) except TimeoutError: + self._request_session_cancel() self._emit_turn_timeout(time.monotonic() - t0, state, on_event) except Exception as e: self._emit_turn_error(e, state, on_event) @@ -1623,6 +1834,7 @@ async def astep( conversation: LocalConversation, on_event: ConversationCallbackType, on_token: ConversationTokenCallbackType | None = None, + prompt_message: MessageEvent | None = None, ) -> None: """Native-async variant of :meth:`step`. @@ -1634,36 +1846,39 @@ async def astep( ``on_event(observation)``, ``state.execution_status`` — runs entirely on the caller's thread. - Why this matters: ``LocalConversation.arun`` holds the - conversation state's reentrant ``FIFOLock`` on its loop thread - across ``await self.agent.astep(...)``. The default + Why this matters: ``LocalConversation.arun`` deliberately does + not hold the conversation state's reentrant ``FIFOLock`` across + long ACP prompt awaits, so remote user messages can be persisted + while the subprocess is still working. The default ``AgentBase.astep`` would wrap sync ``step`` in - ``loop.run_in_executor(None, self.step, ...)``, moving every - post-prompt callback to a worker thread. Any ``with state:`` - inside that chain (today: ``stats_callback``; tomorrow: any - callback added to LLM telemetry or the event pipeline) then - blocks on a lock owned by the loop thread that is itself - ``await``-ing ``astep`` to return. Keeping post-prompt work on - the caller's thread sidesteps the whole class of cross-thread - state-lock deadlocks. See #3348 / #3350 for the full diagnosis. + ``loop.run_in_executor(None, self.step, ...)``, moving post-prompt + callbacks and state updates to a worker thread. Keeping this path + native-async leaves finalization on the caller's loop task, where + ``LocalConversation`` can serialize each emitted event with a + short state-lock acquire and avoid the cross-thread deadlocks + diagnosed in #3348 / #3350. Bridge ``session_update`` notifications continue to fire on the - portal thread (no marshalling here) — they reach the user's - ``on_event`` chain via the agent-server's - ``_emit_event_from_thread`` queue, which already handles the - thread hop. Real-time mid-turn delivery of those events is a - separate concern (the queue waits for ``arun()`` to release the - state lock between iterations); it is not part of the deadlock - this fix removes. + portal thread (no marshalling here). The ``on_event`` callback + supplied by ``LocalConversation.arun`` is responsible for taking + the state lock around each individual event. """ state = conversation.state + if self._restart_session_on_next_turn: + # If restart initialization fails, let the conversation transition + # to ERROR rather than reusing an ambiguous ACP session. + self._restart_session_after_drain_timeout(state, on_event) + prompt_blocks: list[Any] | None = None - for event in reversed(list(state.events)): - if isinstance(event, MessageEvent) and event.source == "user": - prompt_blocks = self._build_acp_prompt(event) - if prompt_blocks: - break + if prompt_message is not None: + prompt_blocks = self._build_acp_prompt(prompt_message) + else: + for event in reversed(list(state.events)): + if isinstance(event, MessageEvent) and event.source == "user": + prompt_blocks = self._build_acp_prompt(event) + if prompt_blocks: + break if prompt_blocks is None: logger.warning("No user message found; finishing conversation") state.execution_status = ConversationExecutionStatus.FINISHED @@ -1672,6 +1887,7 @@ async def astep( self._reset_client_for_turn(on_token, on_event) t0 = time.monotonic() + prompt_future: Future[PromptResponse | None] | None = None try: logger.info( "Sending ACP prompt (timeout=%.0fs, blocks=%d, async)", @@ -1686,27 +1902,23 @@ async def astep( try: # Schedule the ACP prompt on the portal loop (where the # connection lives); await the future back on the caller - # loop. On timeout ``asyncio.wait_for`` cancels the - # caller-side asyncio future; the portal task may run to - # completion in the background (anyio starts it - # immediately on ``start_task_soon`` and - # ``concurrent.futures.Future.cancel()`` returns ``False`` - # for an already-running task), but - # ``_clear_turn_callbacks()`` in ``finally`` ensures any - # trailing ``session_update`` from that task is a no-op. - future = portal.start_task_soon(self._do_acp_prompt, prompt_blocks) - response = await asyncio.wait_for( - asyncio.wrap_future(future), - timeout=self.acp_prompt_timeout, + # loop. Shield the portal task from wait_for timeout so + # the timeout/cancellation handlers can send session/cancel + # and briefly drain the task before the next turn rewires + # callbacks. + current_prompt_future: Future[PromptResponse | None] = ( + portal.start_task_soon( + self._do_acp_prompt, + prompt_blocks, + ) + ) + prompt_future = current_prompt_future + response = await self._await_prompt_response_with_timeout( + current_prompt_future ) break - except TimeoutError as exc: - # ``asyncio.TimeoutError`` is ``TimeoutError`` on 3.11+. - # Re-raise as a clean TimeoutError so the outer handler - # branches the same way as the sync path. - raise TimeoutError( - f"ACP prompt timed out after {self.acp_prompt_timeout:.0f}s" - ) from exc + except TimeoutError: + raise except _RETRIABLE_CONNECTION_ERRORS as e: if attempt < max_retries: delay = _ACP_PROMPT_RETRY_DELAYS[ @@ -1750,7 +1962,12 @@ async def astep( elapsed = time.monotonic() - t0 logger.info("ACP prompt returned in %.1fs (async)", elapsed) - self._finalize_successful_turn(response, elapsed, state, on_event) + # ``on_event`` may be LocalConversation._on_event_with_state_lock, + # which re-acquires this same FIFOLock. This is safe because astep() + # finalization runs on the event-loop thread and FIFOLock is + # reentrant for the owning thread. + with state: + self._finalize_successful_turn(response, elapsed, state, on_event) except asyncio.CancelledError: # ``asyncio.CancelledError`` inherits from ``BaseException``, not # ``Exception`` — so it would otherwise bypass the generic handler @@ -1760,14 +1977,68 @@ async def astep( # before cancellation stays live in the event log forever # (``LocalConversation._emit_orphaned_action_errors`` only patches # ``ActionEvent``s, not ``ACPToolCallEvent``s). Cancel-emit on - # the caller thread while callbacks are still wired, then re-raise - # so ``arun()`` can transition to PAUSED. - self._cancel_inflight_tool_calls() + # the caller thread after the portal prompt has observed + # session/cancel, so late cancelled-turn updates cannot overwrite + # the terminal synthetic failures. + try: + await self._arequest_session_cancel() + drain_result = await self._drain_cancelled_prompt(prompt_future) + except asyncio.CancelledError: + with state: + elapsed = time.monotonic() - t0 + self._handle_cancelled_cleanup_interruption( + prompt_future, elapsed, state, on_event + ) + raise + with state: + elapsed = time.monotonic() - t0 + if drain_result.completed and drain_result.error is None: + if self._prompt_response_was_cancelled(drain_result.response): + self._cancel_inflight_tool_calls() + self._restart_session_on_next_turn = True + else: + self._finalize_successful_turn( + drain_result.response, elapsed, state, on_event + ) + raise + if drain_result.completed and drain_result.error is not None: + self._cancel_inflight_tool_calls() + self._restart_session_on_next_turn = True + raise + self._cancel_inflight_tool_calls() + if not drain_result.drained: + self._restart_session_on_next_turn = True raise except TimeoutError: - self._emit_turn_timeout(time.monotonic() - t0, state, on_event) + try: + await self._arequest_session_cancel() + drain_result = await self._drain_cancelled_prompt(prompt_future) + except asyncio.CancelledError: + with state: + elapsed = time.monotonic() - t0 + self._handle_cancelled_cleanup_interruption( + prompt_future, elapsed, state, on_event + ) + raise + with state: + elapsed = time.monotonic() - t0 + if drain_result.completed and drain_result.error is None: + if self._prompt_response_was_cancelled(drain_result.response): + self._emit_turn_timeout(elapsed, state, on_event) + self._restart_session_on_next_turn = True + else: + self._finalize_successful_turn( + drain_result.response, elapsed, state, on_event + ) + elif drain_result.completed and drain_result.error is not None: + self._emit_turn_error(drain_result.error, state, on_event) + self._restart_session_on_next_turn = True + else: + self._emit_turn_timeout(elapsed, state, on_event) + self._restart_session_on_next_turn = True except Exception as e: - self._emit_turn_error(e, state, on_event) + with state: + self._emit_turn_error(e, state, on_event) raise finally: self._clear_turn_callbacks() diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py index a6bb1ec60e..501650ebb8 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py @@ -5,6 +5,7 @@ import uuid from collections.abc import Mapping from pathlib import Path +from typing import TypeGuard from openhands.sdk.agent.acp_agent import ACPAgent from openhands.sdk.agent.base import AgentBase @@ -34,6 +35,7 @@ ActionEvent, AgentErrorEvent, CondensationRequest, + Event, InterruptEvent, MessageEvent, ObservationEvent, @@ -43,7 +45,7 @@ from openhands.sdk.event.conversation_error import ConversationErrorEvent from openhands.sdk.hooks import HookConfig, HookEventProcessor, create_hook_callback from openhands.sdk.io import LocalFileStore -from openhands.sdk.llm import LLM, Message, TextContent +from openhands.sdk.llm import LLM, Message, TextContent, content_to_str from openhands.sdk.llm.llm_profile_store import LLMProfileStore from openhands.sdk.llm.llm_registry import LLMRegistry from openhands.sdk.logger import get_logger @@ -71,6 +73,24 @@ logger = get_logger(__name__) +ACP_LAST_PROMPT_USER_MESSAGE_ID = "acp_last_prompt_user_message_id" +ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID = "acp_inflight_prompt_user_message_id" +ACP_SUPERSEDE_INFLIGHT_PROMPT = "acp_supersede_inflight_prompt" +ACP_STOP_HOOK_FEEDBACK_PREFIX = "[Stop hook feedback]" + + +def _is_acp_prompt_message(event: Event) -> TypeGuard[MessageEvent]: + if not isinstance(event, MessageEvent): + return False + if event.source == "user": + return True + if event.source != "environment" or event.llm_message.role != "user": + return False + return any( + part.startswith(ACP_STOP_HOOK_FEEDBACK_PREFIX) + for part in content_to_str(event.llm_message.content) + ) + class LocalConversation(BaseConversation): agent: AgentBase @@ -765,6 +785,11 @@ def send_message(self, message: str | Message, sender: str | None = None) -> Non ) self._on_event(user_msg_event) + def _on_event_with_state_lock(self, event: Event) -> None: + """Emit an event while holding the conversation state lock.""" + with self._state: + self._on_event(event) + @observe(name="conversation.run") def run(self) -> None: """Runs the conversation until the agent finishes. @@ -817,7 +842,9 @@ def run(self) -> None: if not should_stop: logger.info("Stop hook denied agent stopping") if feedback: - prefixed = f"[Stop hook feedback] {feedback}" + prefixed = ( + f"{ACP_STOP_HOOK_FEEDBACK_PREFIX} {feedback}" + ) feedback_msg = MessageEvent( source="environment", llm_message=Message( @@ -933,11 +960,26 @@ async def arun(self) -> None: observation is patched with a synthetic ``AgentErrorEvent`` so the LLM conversation history stays consistent. """ - self._ensure_agent_ready() self._arun_task = asyncio.current_task() self._cancel_token = CancellationToken() + self._ensure_agent_ready() with self._state: + if isinstance(self.agent, ACPAgent) and self._state.execution_status in ( + ConversationExecutionStatus.FINISHED, + ConversationExecutionStatus.IDLE, + ): + updated_agent_state = dict(self._state.agent_state) + inflight_prompt_user_message_id = updated_agent_state.get( + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID + ) + if inflight_prompt_user_message_id is not None: + updated_agent_state[ACP_LAST_PROMPT_USER_MESSAGE_ID] = ( + inflight_prompt_user_message_id + ) + updated_agent_state.pop(ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID, None) + self._state.agent_state = updated_agent_state + if self._state.execution_status in [ ConversationExecutionStatus.IDLE, ConversationExecutionStatus.PAUSED, @@ -945,11 +987,16 @@ async def arun(self) -> None: ConversationExecutionStatus.STUCK, ]: self._state.execution_status = ConversationExecutionStatus.RUNNING + last_acp_prompt_user_message_id = self._state.agent_state.get( + ACP_LAST_PROMPT_USER_MESSAGE_ID + ) iteration = 0 try: while True: logger.debug(f"Conversation arun iteration {iteration}") + acp_step_user_message_id: str | None = None + acp_step_user_message: MessageEvent | None = None with self._state: if self._state.execution_status in [ ConversationExecutionStatus.PAUSED, @@ -968,7 +1015,9 @@ async def arun(self) -> None: if not should_stop: logger.info("Stop hook denied agent stopping") if feedback: - prefixed = f"[Stop hook feedback] {feedback}" + prefixed = ( + f"{ACP_STOP_HOOK_FEEDBACK_PREFIX} {feedback}" + ) feedback_msg = MessageEvent( source="environment", llm_message=Message( @@ -1000,12 +1049,226 @@ async def arun(self) -> None: ConversationExecutionStatus.RUNNING ) - await self.agent.astep( - self, - on_event=self._on_event, - on_token=self._on_token, - ) + if isinstance(self.agent, ACPAgent): + # Re-scan prompt messages under the lock each time we need + # the latest tail; the list is usually tiny, and correctness + # is more important than caching stale prompt snapshots. + + acp_prompt_messages = [ + event + for event in self._state.events + if _is_acp_prompt_message(event) + ] + if last_acp_prompt_user_message_id is None: + acp_step_user_message = ( + acp_prompt_messages[0] if acp_prompt_messages else None + ) + else: + last_prompt_index = next( + ( + index + for index, event in enumerate(acp_prompt_messages) + if event.id == last_acp_prompt_user_message_id + ), + None, + ) + if last_prompt_index is None: + logger.info( + "ACP prompt cursor %s no longer exists; " + "restarting from first available prompt", + last_acp_prompt_user_message_id, + ) + acp_step_user_message = ( + acp_prompt_messages[0] + if acp_prompt_messages + else None + ) + else: + acp_step_user_message = ( + acp_prompt_messages[last_prompt_index + 1] + if last_prompt_index + 1 < len(acp_prompt_messages) + else None + ) + acp_step_user_message_id = ( + acp_step_user_message.id + if acp_step_user_message is not None + else None + ) + else: + await self.agent.astep( + self, + on_event=self._on_event, + on_token=self._on_token, + ) + iteration += 1 + + if ( + self.state.execution_status + == ConversationExecutionStatus.WAITING_FOR_CONFIRMATION + ): + break + + if iteration >= self.max_iteration_per_run: + if ( + self._state.execution_status + == ConversationExecutionStatus.FINISHED + ): + break + error_msg = ( + f"Agent reached maximum iterations limit " + f"({self.max_iteration_per_run})." + ) + logger.error(error_msg) + self._state.execution_status = ( + ConversationExecutionStatus.ERROR + ) + self._on_event( + ConversationErrorEvent( + source="environment", + code="MaxIterationsReached", + detail=error_msg, + ) + ) + break + + continue + + # ACP prompt round-trips can run for minutes. Keep the state + # lock free while awaiting them so incoming user messages can + # be persisted immediately; event callbacks take the lock only + # for each individual mutation. + if acp_step_user_message is None: + with self._state: + acp_prompt_messages = [ + event + for event in self._state.events + if _is_acp_prompt_message(event) + ] + latest_acp_prompt_message_id = ( + acp_prompt_messages[-1].id if acp_prompt_messages else None + ) + acp_prompt_message_changed = ( + latest_acp_prompt_message_id is not None + and latest_acp_prompt_message_id + != last_acp_prompt_user_message_id + ) + if acp_prompt_message_changed: + if iteration >= self.max_iteration_per_run: + logger.info( + "User message arrived before ACP finish; " + "leaving conversation idle for a follow-up run" + ) + self._state.execution_status = ( + ConversationExecutionStatus.IDLE + ) + break + logger.info( + "User message arrived before ACP finish; continuing run" + ) + self._state.execution_status = ( + ConversationExecutionStatus.RUNNING + ) + continue + self._state.execution_status = ( + ConversationExecutionStatus.FINISHED + ) + break + + acp_step_start_event_count = 0 + with self._state: + if self._state.execution_status in ( + ConversationExecutionStatus.PAUSED, + ConversationExecutionStatus.STUCK, + ): + break + acp_step_start_event_count = len(self._state.events) + if acp_step_user_message_id is not None: + self._state.agent_state = { + **self._state.agent_state, + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID: ( + acp_step_user_message_id + ), + } + + await self.agent.astep( + self, + on_event=self._on_event_with_state_lock, + on_token=self._on_token, + prompt_message=acp_step_user_message, + ) + with self._state: iteration += 1 + pause_requested_during_acp_step = any( + isinstance(event, PauseEvent) + for event in self._state.events[acp_step_start_event_count:] + ) + updated_agent_state = dict(self._state.agent_state) + if ( + updated_agent_state.get(ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID) + == acp_step_user_message_id + ): + updated_agent_state.pop( + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID, None + ) + updated_agent_state.pop(ACP_SUPERSEDE_INFLIGHT_PROMPT, None) + if ( + acp_step_user_message_id is not None + and self._state.execution_status + not in ( + ConversationExecutionStatus.ERROR, + ConversationExecutionStatus.STUCK, + ConversationExecutionStatus.PAUSED, + ) + ): + last_acp_prompt_user_message_id = acp_step_user_message_id + updated_agent_state[ACP_LAST_PROMPT_USER_MESSAGE_ID] = ( + acp_step_user_message_id + ) + self._state.agent_state = updated_agent_state + + if self._state.execution_status in ( + ConversationExecutionStatus.ERROR, + ConversationExecutionStatus.STUCK, + ): + break + if pause_requested_during_acp_step: + self._state.execution_status = ( + ConversationExecutionStatus.PAUSED + ) + break + + acp_prompt_messages = [ + event + for event in self._state.events + if _is_acp_prompt_message(event) + ] + latest_acp_prompt_message_id = ( + acp_prompt_messages[-1].id if acp_prompt_messages else None + ) + acp_prompt_message_changed = ( + latest_acp_prompt_message_id is not None + and latest_acp_prompt_message_id + != last_acp_prompt_user_message_id + ) + if acp_prompt_message_changed and self._state.execution_status in ( + ConversationExecutionStatus.FINISHED, + ConversationExecutionStatus.IDLE, + ): + if iteration >= self.max_iteration_per_run: + logger.info( + "User message arrived during final ACP iteration; " + "leaving conversation idle for a follow-up run" + ) + self._state.execution_status = ( + ConversationExecutionStatus.IDLE + ) + break + logger.info( + "User message arrived during ACP step; continuing run" + ) + self._state.execution_status = ( + ConversationExecutionStatus.RUNNING + ) if ( self.state.execution_status @@ -1042,6 +1305,24 @@ async def arun(self) -> None: # PAUSED so the conversation can be resumed later. logger.info("arun() interrupted via task cancellation") with self._state: + updated_agent_state = dict(self._state.agent_state) + inflight_prompt_user_message_id = updated_agent_state.pop( + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID, None + ) + superseded_by_new_message = bool( + updated_agent_state.pop(ACP_SUPERSEDE_INFLIGHT_PROMPT, False) + ) + completed_cancelled_prompt = ( + self._state.execution_status == ConversationExecutionStatus.FINISHED + ) + if ( + superseded_by_new_message or completed_cancelled_prompt + ) and inflight_prompt_user_message_id is not None: + updated_agent_state[ACP_LAST_PROMPT_USER_MESSAGE_ID] = ( + inflight_prompt_user_message_id + ) + self._state.agent_state = updated_agent_state + # Emit synthetic error observations for any ActionEvents # that were in-flight when the interrupt landed. Without # these the LLM history would contain tool-call requests @@ -1053,6 +1334,10 @@ async def arun(self) -> None: self._on_event(InterruptEvent()) except Exception as e: with self._state: + updated_agent_state = dict(self._state.agent_state) + updated_agent_state.pop(ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID, None) + updated_agent_state.pop(ACP_SUPERSEDE_INFLIGHT_PROMPT, None) + self._state.agent_state = updated_agent_state self._state.execution_status = ConversationExecutionStatus.ERROR self._on_event( ConversationErrorEvent( diff --git a/tests/agent_server/test_event_service.py b/tests/agent_server/test_event_service.py index 2058052c10..407778504e 100644 --- a/tests/agent_server/test_event_service.py +++ b/tests/agent_server/test_event_service.py @@ -23,7 +23,13 @@ ) from openhands.agent_server.pub_sub import Subscriber from openhands.sdk import LLM, Agent, Conversation, Message +from openhands.sdk.agent import ACPAgent from openhands.sdk.conversation.fifo_lock import FIFOLock +from openhands.sdk.conversation.impl.local_conversation import ( + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID, + ACP_SUPERSEDE_INFLIGHT_PROMPT, + LocalConversation, +) from openhands.sdk.conversation.state import ( ConversationExecutionStatus, ConversationState, @@ -825,6 +831,287 @@ async def test_send_message_with_run_true_agent_idle(self, event_service): # Verify run was called since agent was idle conversation.run.assert_called_once() + @pytest.mark.asyncio + async def test_send_message_with_run_true_interrupts_running_acp_turn( + self, event_service, tmp_path + ): + """A new user message should interrupt an in-flight ACP prompt.""" + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=4, + stuck_detection=False, + ) + conversation.send_message("initial request") + event_service._conversation = conversation + event_service._publish_state_update = AsyncMock() + + first_step_started = asyncio.Event() + first_step_cancelled = asyncio.Event() + second_step_seen = asyncio.Event() + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def blocking_astep( + self, # noqa: ARG001 + conv: LocalConversation, # noqa: ARG001 + on_event, # noqa: ARG001 + on_token=None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + if len(prompts_seen) == 1: + first_step_started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + first_step_cancelled.set() + raise + + second_step_seen.set() + conv.state.execution_status = ConversationExecutionStatus.FINISHED + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=blocking_astep), + ): + try: + await event_service.run() + await asyncio.wait_for(first_step_started.wait(), timeout=1.0) + + await event_service.send_message( + Message(role="user", content=[TextContent(text="intervening")]), + run=True, + ) + + await asyncio.wait_for(first_step_cancelled.wait(), timeout=1.0) + await asyncio.wait_for(second_step_seen.wait(), timeout=1.0) + finally: + if ( + event_service._run_task is not None + and not event_service._run_task.done() + ): + conversation.interrupt() + with suppress(asyncio.CancelledError, TimeoutError): + await asyncio.wait_for(event_service._run_task, timeout=1.0) + + assert prompts_seen == ["initial request", "intervening"] + + @pytest.mark.asyncio + async def test_send_message_with_run_true_does_not_interrupt_current_acp_prompt( + self, event_service, tmp_path + ): + """Do not cancel the ACP prompt if it already advanced to the new message.""" + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=4, + stuck_detection=False, + ) + conversation.send_message("initial request") + conversation.state.execution_status = ConversationExecutionStatus.RUNNING + event_service._conversation = conversation + event_service._publish_state_update = AsyncMock() + + release_run = asyncio.Event() + event_service._run_task = asyncio.create_task(release_run.wait()) + original_send_message = conversation.send_message + + def send_and_mark_active_prompt(message): + original_send_message(message) + conversation.state.execution_status = ConversationExecutionStatus.RUNNING + conversation.state.agent_state = { + **conversation.state.agent_state, + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID: ( + conversation.state.last_user_message_id + ), + } + + conversation.send_message = send_and_mark_active_prompt # type: ignore[method-assign] + conversation.interrupt = MagicMock() # type: ignore[method-assign] + + try: + await event_service.send_message( + Message(role="user", content=[TextContent(text="intervening")]), + run=True, + ) + finally: + release_run.set() + await event_service._run_task + event_service._run_task = None + + conversation.interrupt.assert_not_called() + assert event_service._rerun_requested is False + + @pytest.mark.asyncio + async def test_acp_supersede_mark_rechecks_current_prompt( + self, event_service, tmp_path + ): + """Do not attach the supersede marker to a replacement ACP prompt.""" + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=4, + stuck_detection=False, + ) + conversation.send_message("initial request") + conversation.send_message("replacement request") + latest_user_message_id = conversation.state.last_user_message_id + assert latest_user_message_id is not None + conversation.state.execution_status = ConversationExecutionStatus.RUNNING + conversation.state.agent_state = { + **conversation.state.agent_state, + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID: latest_user_message_id, + } + event_service._conversation = conversation + release_run = asyncio.Event() + event_service._run_task = asyncio.create_task(release_run.wait()) + + try: + ( + marked, + active_prompt_has_latest, + ) = await event_service._mark_running_acp_prompt_superseded() + finally: + release_run.set() + await event_service._run_task + event_service._run_task = None + + assert marked is False + assert active_prompt_has_latest is True + assert ACP_SUPERSEDE_INFLIGHT_PROMPT not in conversation.state.agent_state + + @pytest.mark.asyncio + async def test_explicit_interrupt_clears_internal_acp_rerun_request( + self, event_service + ): + """A later explicit stop should win over an earlier internal ACP rerun.""" + conversation = MagicMock() + event_service._conversation = conversation + event_service._publish_state_update = AsyncMock() + event_service._rerun_requested = True + event_service._acp_internal_rerun_requested = True + + await event_service.interrupt() + + conversation.interrupt.assert_called_once() + assert event_service._rerun_requested is False + assert event_service._acp_internal_rerun_requested is False + + @pytest.mark.asyncio + async def test_internal_acp_rerun_does_not_override_explicit_interrupt( + self, event_service + ): + """Explicit Stop/Pause should win while an internal ACP interrupt drains.""" + conversation = MagicMock() + conversation.send_message = MagicMock() + event_service._conversation = conversation + event_service._mark_running_acp_prompt_superseded = AsyncMock( + return_value=(True, False) + ) + event_service.run = AsyncMock() + + async def interrupt_and_simulate_user_stop(internal_acp_rerun=False): + assert internal_acp_rerun is True + event_service._explicit_interrupt_generation += 1 + event_service._rerun_requested = False + event_service._acp_internal_rerun_requested = False + + event_service.interrupt = interrupt_and_simulate_user_stop + + await event_service.send_message(Message(role="user", content=[]), run=True) + + event_service.run.assert_not_awaited() + assert event_service._rerun_requested is False + assert event_service._acp_internal_rerun_requested is False + + @pytest.mark.asyncio + async def test_internal_acp_send_message_restart_rechecks_generation_in_run( + self, event_service, tmp_path + ): + """A late explicit Stop/Pause should prevent direct ACP restart.""" + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + mock_arun = AsyncMock() + event_service._conversation = conversation + event_service._publish_state_update = AsyncMock() + event_service._mark_running_acp_prompt_superseded = AsyncMock( + return_value=(True, False) + ) + event_service.interrupt = AsyncMock() + + async def status_with_late_explicit_interrupt(): + event_service._explicit_interrupt_generation += 1 + event_service._rerun_requested = False + event_service._acp_internal_rerun_requested = False + return ConversationExecutionStatus.PAUSED + + event_service._get_execution_status = status_with_late_explicit_interrupt + + with patch.object(conversation, "arun", mock_arun): + await event_service.send_message(Message(role="user", content=[]), run=True) + + event_service.interrupt.assert_awaited_once_with(internal_acp_rerun=True) + mock_arun.assert_not_awaited() + assert event_service._run_task is None + assert event_service._rerun_requested is False + assert event_service._acp_internal_rerun_requested is False + + @pytest.mark.asyncio + async def test_internal_acp_rerun_rechecks_explicit_interrupt_before_restart( + self, event_service, tmp_path + ): + """Explicit Stop/Pause should win during final restart status checks.""" + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + mock_arun = AsyncMock() + event_service._conversation = conversation + event_service._publish_state_update = AsyncMock() + event_service._rerun_requested = True + event_service._acp_internal_rerun_requested = True + + status_calls = 0 + + async def status_with_late_explicit_interrupt(): + nonlocal status_calls + status_calls += 1 + if status_calls == 1: + return ConversationExecutionStatus.IDLE + event_service._explicit_interrupt_generation += 1 + event_service._rerun_requested = False + event_service._acp_internal_rerun_requested = False + return ConversationExecutionStatus.PAUSED + + event_service._get_execution_status = status_with_late_explicit_interrupt + + with patch.object(conversation, "arun", mock_arun): + await event_service.run() + assert event_service._run_task is not None + await asyncio.wait_for(event_service._run_task, timeout=1.0) + + mock_arun.assert_awaited_once() + assert status_calls == 2 + assert event_service._rerun_requested is False + assert event_service._acp_internal_rerun_requested is False + @pytest.mark.asyncio async def test_send_message_with_run_true_logs_exception(self, event_service): """Test that exceptions from conversation.run() are caught and logged.""" diff --git a/tests/agent_server/test_event_streaming.py b/tests/agent_server/test_event_streaming.py index 5279d73a8c..bfae85edc3 100644 --- a/tests/agent_server/test_event_streaming.py +++ b/tests/agent_server/test_event_streaming.py @@ -12,7 +12,7 @@ from openhands.agent_server.models import StoredConversation from openhands.agent_server.pub_sub import Subscriber from openhands.sdk import Event -from openhands.sdk.agent import Agent +from openhands.sdk.agent import ACPAgent, Agent from openhands.sdk.event import StreamingDeltaEvent from openhands.sdk.llm import LLM from openhands.sdk.workspace import LocalWorkspace @@ -218,6 +218,64 @@ async def test_token_callbacks_not_wired_when_stream_disabled(tmp_path): assert MockConv.call_args.kwargs["token_callbacks"] == [] +@pytest.mark.asyncio +async def test_acp_agents_wire_token_callback_without_llm_streaming(tmp_path): + """ACP AgentMessageChunk text should stream even though ACPAgent has no LLM.""" + service = EventService( + stored=StoredConversation( + id=uuid4(), + agent=ACPAgent(acp_command=["echo", "test"]), + workspace=LocalWorkspace(working_dir=str(tmp_path / "workspace")), + ), + conversations_dir=tmp_path / "conversations", + ) + (tmp_path / "workspace").mkdir(exist_ok=True) + + with _mock_local_conversation() as MockConv: + mock_conv = MagicMock() + mock_conv.state = MagicMock(execution_status="idle") + mock_conv._state = MagicMock() + mock_conv._on_event = MagicMock() + MockConv.return_value = mock_conv + + await service.start() + assert len(MockConv.call_args.kwargs["token_callbacks"]) == 1 + + +@pytest.mark.asyncio +async def test_acp_string_token_callback_publishes_delta(tmp_path): + """ACPAgent invokes token callbacks with plain text chunks.""" + service = EventService( + stored=StoredConversation( + id=uuid4(), + agent=ACPAgent(acp_command=["echo", "test"]), + workspace=LocalWorkspace(working_dir=str(tmp_path / "workspace")), + ), + conversations_dir=tmp_path / "conversations", + ) + collector = _CollectorSubscriber() + service._pub_sub.subscribe(collector) + (tmp_path / "workspace").mkdir(exist_ok=True) + + with _mock_local_conversation() as MockConv: + mock_conv = MagicMock() + mock_conv.state = MagicMock(execution_status="idle") + mock_conv._state = MagicMock() + mock_conv._on_event = MagicMock() + MockConv.return_value = mock_conv + + await service.start() + callback = MockConv.call_args.kwargs["token_callbacks"][0] + + callback("ACP live text") + await asyncio.sleep(0.05) + + delta_events = [e for e in collector.events if isinstance(e, StreamingDeltaEvent)] + assert len(delta_events) == 1 + assert delta_events[0].content == "ACP live text" + assert delta_events[0].reasoning_content is None + + @pytest.mark.asyncio async def test_multiple_chunks_produce_multiple_events(event_service, tmp_path): collector = _CollectorSubscriber() diff --git a/tests/sdk/agent/test_acp_agent.py b/tests/sdk/agent/test_acp_agent.py index 7f0bd2f689..6fb29a0e9e 100644 --- a/tests/sdk/agent/test_acp_agent.py +++ b/tests/sdk/agent/test_acp_agent.py @@ -6,11 +6,13 @@ import json import threading import uuid +from concurrent.futures import Future from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest from acp.exceptions import RequestError as ACPRequestError +from acp.schema import PromptResponse from openhands.sdk.agent.acp_agent import ( ACPAgent, @@ -34,6 +36,7 @@ MessageEvent, SystemPromptEvent, ) +from openhands.sdk.event.conversation_error import ConversationErrorEvent from openhands.sdk.llm import ImageContent, Message, TextContent from openhands.sdk.skills import KeywordTrigger, Skill from openhands.sdk.tool.builtins.finish import FinishAction @@ -1369,9 +1372,9 @@ def _fake_run_async(_coro, **_kwargs): class TestACPAgentAstep: """Native ``ACPAgent.astep`` must not fall back to ``AgentBase.astep`` (which wraps ``step`` in ``loop.run_in_executor``). Doing so would - move post-prompt callbacks onto an executor worker thread, - deadlocking against ``LocalConversation.arun`` which holds the - state's reentrant ``FIFOLock`` on the loop thread. See #3348. + move post-prompt callbacks and state updates onto an executor worker + thread, outside ``LocalConversation.arun``'s controlled event + serialization. See #3348. """ def _make_conversation_with_message(self, tmp_path, text="Hello"): @@ -1401,10 +1404,9 @@ def test_astep_overrides_default_agentbase_implementation(self): def test_astep_runs_post_prompt_callbacks_on_caller_thread(self, tmp_path): """Post-prompt ``on_event`` callbacks must fire on the caller - thread (same thread that holds ``state.lock`` in ``arun``). - ``FIFOLock`` is reentrant per-thread; if astep schedules ``step`` - on a worker thread (the buggy default), callbacks run cross-thread - and block on the lock owner forever — see #3348. + thread. If astep schedules ``step`` on a worker thread (the buggy + default), callbacks and final state updates run outside the async + run task's serialization model — see #3348. """ from openhands.sdk.utils.async_executor import AsyncExecutor @@ -1515,6 +1517,82 @@ def _message_text(ev: MessageEvent) -> str: ) assert conversation.state.execution_status == ConversationExecutionStatus.ERROR + def test_astep_times_out_while_tool_call_is_inflight(self, tmp_path): + """A hard ACP prompt timeout still fires during an active tool call. + + Mirroring OpenHands command handling, active output/heartbeats keep the + runtime alive but do not let a never-ending command suppress the hard + turn deadline. The timeout path must cancel the ACP session and close + any streamed tool cards as failed. + """ + from concurrent.futures import Future + + agent = _make_agent(acp_prompt_timeout=0.02) + conversation = self._make_conversation_with_message(tmp_path) + emitted: list = [] + cancel_called = threading.Event() + + mock_client = _OpenHandsACPBridge() + mock_client.get_turn_usage_update = MagicMock(return_value=object()) + agent._client = mock_client + agent._conn = MagicMock() + agent._session_id = "test-session" + + class _FakePortal: + def __init__(self) -> None: + self.prompt_future: Future = Future() + + def start_task_soon(self, fn, *args): # noqa: ANN001, ANN202 + if args: + entry = { + "tool_call_id": "git-1", + "title": "git status", + "tool_kind": "execute", + "status": "in_progress", + "raw_input": None, + "raw_output": None, + "content": None, + } + mock_client.accumulated_tool_calls.append(entry) + mock_client._emit_tool_call_event(entry) + return self.prompt_future + + cancel_called.set() + cancel_future: Future = Future() + cancel_future.set_result(None) + return cancel_future + + mock_executor = MagicMock() + mock_executor.portal = _FakePortal() + agent._executor = mock_executor + + with patch("openhands.sdk.agent.acp_agent._ACP_CANCEL_DRAIN_TIMEOUT", 0.01): + asyncio.run(agent.astep(conversation, on_event=emitted.append)) + + assert cancel_called.is_set() + assert conversation.state.execution_status == ConversationExecutionStatus.ERROR + assert any( + isinstance(e, ACPToolCallEvent) + and e.tool_call_id == "git-1" + and e.status == "failed" + and e.is_error + for e in emitted + ) + + def _message_text(ev: MessageEvent) -> str: + first = ev.llm_message.content[0] + return first.text if isinstance(first, TextContent) else "" + + assert any( + isinstance(e, MessageEvent) + and "ACP prompt timed out after" in _message_text(e) + for e in emitted + ) + assert not any( + isinstance(e, ActionEvent) and isinstance(e.action, FinishAction) + for e in emitted + ) + def test_astep_emits_failed_tool_calls_on_cancellation(self, tmp_path): """``asyncio.CancelledError`` during astep must close in-flight ``ACPToolCallEvent``s as ``failed`` and re-raise. @@ -1543,6 +1621,8 @@ def test_astep_emits_failed_tool_calls_on_cancellation(self, tmp_path): async def _run_with_cancel() -> None: prompt_entered = asyncio.Event() + cancel_called = asyncio.Event() + prompt_released = threading.Event() caller_loop = asyncio.get_running_loop() async def _fake_prompt(prompt_blocks, session_id): @@ -1565,11 +1645,19 @@ async def _fake_prompt(prompt_blocks, session_id): # Signal caller loop that we're holding inside the prompt # so the cancel races deterministically. caller_loop.call_soon_threadsafe(prompt_entered.set) - # Block long enough for the cancel to land. - await asyncio.sleep(60) + # Block beyond the cancel-drain timeout so this test exercises + # the non-quiesced cancellation path that must synthesize + # failed ACP tool-call events. + released = await asyncio.to_thread(prompt_released.wait, 10.0) + assert released return None + async def _fake_cancel(session_id): + assert session_id == "test-session" + caller_loop.call_soon_threadsafe(cancel_called.set) + agent._conn.prompt = _fake_prompt + agent._conn.cancel = _fake_cancel agent._session_id = "test-session" task = asyncio.create_task( @@ -1577,8 +1665,16 @@ async def _fake_prompt(prompt_blocks, session_id): ) await asyncio.wait_for(prompt_entered.wait(), timeout=5.0) task.cancel() - with pytest.raises(asyncio.CancelledError): - await task + try: + with pytest.raises(asyncio.CancelledError): + with patch( + "openhands.sdk.agent.acp_agent._ACP_CANCEL_DRAIN_TIMEOUT", + 0.01, + ): + await task + await asyncio.wait_for(cancel_called.wait(), timeout=5.0) + finally: + prompt_released.set() try: agent._executor = executor @@ -1599,6 +1695,290 @@ async def _fake_prompt(prompt_blocks, session_id): ) assert failed_tool_events[0].is_error is True + def test_astep_finalizes_and_reraises_completed_cancelled_prompt(self, tmp_path): + """If a cancelled ACP prompt drains successfully, keep the completed turn. + + The ACP server may finish the prompt while ``session/cancel`` is being + delivered. In that case the remote session has accepted the assistant + turn, so OpenHands must finalize the same turn locally instead of + discarding the response and later resuming from diverged session history. + The original cancellation still propagates so explicit user stop intent + wins at the conversation layer. + """ + from acp.schema import AgentMessageChunk, TextContentBlock + + from openhands.sdk.utils.async_executor import AsyncExecutor + + agent = _make_agent() + conversation = self._make_conversation_with_message(tmp_path) + emitted: list = [] + + mock_client = _OpenHandsACPBridge() + mock_client.get_turn_usage_update = MagicMock(return_value=object()) + agent._client = mock_client + agent._conn = MagicMock() + + executor = AsyncExecutor() + + async def _run_with_cancel() -> None: + prompt_entered = asyncio.Event() + cancel_called = asyncio.Event() + prompt_released = threading.Event() + caller_loop = asyncio.get_running_loop() + + async def _fake_prompt(prompt_blocks, session_id): # noqa: ARG001 + caller_loop.call_soon_threadsafe(prompt_entered.set) + released = await asyncio.to_thread(prompt_released.wait, 10.0) + assert released + await mock_client.session_update( + session_id, + AgentMessageChunk( + session_update="agent_message_chunk", + content=TextContentBlock(type="text", text="done"), + ), + ) + return None + + async def _fake_cancel(session_id): + assert session_id == "test-session" + caller_loop.call_soon_threadsafe(cancel_called.set) + prompt_released.set() + + agent._conn.prompt = _fake_prompt + agent._conn.cancel = _fake_cancel + agent._session_id = "test-session" + + task = asyncio.create_task( + agent.astep(conversation, on_event=emitted.append) + ) + await asyncio.wait_for(prompt_entered.wait(), timeout=5.0) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + await asyncio.wait_for(cancel_called.wait(), timeout=5.0) + + try: + agent._executor = executor + asyncio.run(_run_with_cancel()) + finally: + executor.close() + + assert ( + conversation.state.execution_status == ConversationExecutionStatus.FINISHED + ) + assert any( + isinstance(e, ActionEvent) + and isinstance(e.action, FinishAction) + and e.action.message == "done" + for e in emitted + ) + + def test_astep_cancelled_prompt_error_pauses_without_turn_error(self, tmp_path): + """Explicit cancellation should not emit stale prompt errors.""" + from openhands.sdk.utils.async_executor import AsyncExecutor + + agent = _make_agent() + conversation = self._make_conversation_with_message(tmp_path) + emitted: list = [] + + mock_client = _OpenHandsACPBridge() + mock_client.get_turn_usage_update = MagicMock(return_value=object()) + agent._client = mock_client + agent._conn = MagicMock() + + executor = AsyncExecutor() + + async def _run_with_cancel() -> None: + prompt_entered = asyncio.Event() + cancel_called = asyncio.Event() + prompt_released = threading.Event() + caller_loop = asyncio.get_running_loop() + + async def _fake_prompt(prompt_blocks, session_id): # noqa: ARG001 + caller_loop.call_soon_threadsafe(prompt_entered.set) + released = await asyncio.to_thread(prompt_released.wait, 10.0) + assert released + raise RuntimeError("late prompt failure") + + async def _fake_cancel(session_id): + assert session_id == "test-session" + caller_loop.call_soon_threadsafe(cancel_called.set) + prompt_released.set() + + agent._conn.prompt = _fake_prompt + agent._conn.cancel = _fake_cancel + agent._session_id = "test-session" + + task = asyncio.create_task( + agent.astep(conversation, on_event=emitted.append) + ) + await asyncio.wait_for(prompt_entered.wait(), timeout=5.0) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + await asyncio.wait_for(cancel_called.wait(), timeout=5.0) + + try: + agent._executor = executor + asyncio.run(_run_with_cancel()) + finally: + executor.close() + + assert not any( + isinstance(e, MessageEvent) + and e.source == "agent" + and any( + isinstance(c, TextContent) and c.text.startswith("ACP error:") + for c in e.llm_message.content + ) + for e in emitted + ) + assert not any(isinstance(e, ConversationErrorEvent) for e in emitted) + assert agent._restart_session_on_next_turn is True + + def test_astep_double_cancel_during_drain_restarts_next_turn(self, tmp_path): + """A second cancellation during drain should quarantine the live prompt.""" + from openhands.sdk.utils.async_executor import AsyncExecutor + + agent = _make_agent() + conversation = self._make_conversation_with_message(tmp_path) + + mock_client = _OpenHandsACPBridge() + mock_client.get_turn_usage_update = MagicMock(return_value=object()) + agent._client = mock_client + agent._conn = MagicMock() + + executor = AsyncExecutor() + + async def _run_with_double_cancel() -> None: + prompt_entered = asyncio.Event() + prompt_released = threading.Event() + caller_loop = asyncio.get_running_loop() + + async def _fake_prompt(prompt_blocks, session_id): # noqa: ARG001 + caller_loop.call_soon_threadsafe(prompt_entered.set) + released = await asyncio.to_thread(prompt_released.wait, 10.0) + assert released + return None + + async def _fake_cancel(session_id): + assert session_id == "test-session" + + async def _raise_during_drain(self, future): # noqa: ARG001 + assert future is not None + assert not future.done() + raise asyncio.CancelledError + + agent._conn.prompt = _fake_prompt + agent._conn.cancel = _fake_cancel + agent._session_id = "test-session" + + with patch.object( + ACPAgent, + "_drain_cancelled_prompt", + new=_raise_during_drain, + ): + task = asyncio.create_task( + agent.astep(conversation, on_event=lambda _: None) + ) + await asyncio.wait_for(prompt_entered.wait(), timeout=5.0) + task.cancel() + try: + with pytest.raises(asyncio.CancelledError): + await task + finally: + prompt_released.set() + + try: + agent._executor = executor + asyncio.run(_run_with_double_cancel()) + finally: + executor.close() + + assert agent._restart_session_on_next_turn is True + + def test_astep_double_cancel_during_cancel_send_restarts_next_turn(self, tmp_path): + """A second cancellation during session/cancel should quarantine prompt.""" + from openhands.sdk.utils.async_executor import AsyncExecutor + + agent = _make_agent() + conversation = self._make_conversation_with_message(tmp_path) + + mock_client = _OpenHandsACPBridge() + mock_client.get_turn_usage_update = MagicMock(return_value=object()) + agent._client = mock_client + agent._conn = MagicMock() + + executor = AsyncExecutor() + + async def _run_with_cancelled_cancel_send() -> None: + prompt_entered = asyncio.Event() + prompt_released = threading.Event() + caller_loop = asyncio.get_running_loop() + + async def _fake_prompt(prompt_blocks, session_id): # noqa: ARG001 + caller_loop.call_soon_threadsafe(prompt_entered.set) + released = await asyncio.to_thread(prompt_released.wait, 10.0) + assert released + return None + + async def _raise_during_cancel_send(self): # noqa: ARG001 + raise asyncio.CancelledError + + agent._conn.prompt = _fake_prompt + agent._session_id = "test-session" + + with patch.object( + ACPAgent, + "_arequest_session_cancel", + new=_raise_during_cancel_send, + ): + task = asyncio.create_task( + agent.astep(conversation, on_event=lambda _: None) + ) + await asyncio.wait_for(prompt_entered.wait(), timeout=5.0) + task.cancel() + try: + with pytest.raises(asyncio.CancelledError): + await task + finally: + prompt_released.set() + + try: + agent._executor = executor + asyncio.run(_run_with_cancelled_cancel_send()) + finally: + executor.close() + + assert agent._restart_session_on_next_turn is True + + def test_cleanup_interruption_finalizes_completed_prompt(self, tmp_path): + """A completed prompt should be finalized if cleanup is cancelled.""" + agent = _make_agent() + conversation = self._make_conversation_with_message(tmp_path) + mock_client = _OpenHandsACPBridge() + mock_client.get_turn_usage_update = MagicMock(return_value=object()) + agent._client = mock_client + agent._session_id = "test-session" + + prompt_future: Future[PromptResponse | None] = Future() + prompt_future.set_result(None) + emitted = [] + + with conversation.state as state: + agent._handle_cancelled_cleanup_interruption( + prompt_future, + 0.1, + state, + emitted.append, + ) + + assert ( + conversation.state.execution_status == ConversationExecutionStatus.FINISHED + ) + assert agent._restart_session_on_next_turn is False + assert any(isinstance(event, ActionEvent) for event in emitted) + def test_astep_cancellation_does_not_mark_suffix_installed(self, tmp_path): """Cancellation before a turn completes must leave ``_suffix_install_state`` as ``pending_first_prompt``. @@ -1630,14 +2010,20 @@ def test_astep_cancellation_does_not_mark_suffix_installed(self, tmp_path): async def _run_with_cancel() -> None: prompt_entered = asyncio.Event() + prompt_released = threading.Event() caller_loop = asyncio.get_running_loop() async def _fake_prompt(prompt_blocks, session_id): caller_loop.call_soon_threadsafe(prompt_entered.set) - await asyncio.sleep(60) + released = await asyncio.to_thread(prompt_released.wait, 10.0) + assert released return None + async def _fake_cancel(session_id): + assert session_id == "test-session" + agent._conn.prompt = _fake_prompt + agent._conn.cancel = _fake_cancel agent._session_id = "test-session" task = asyncio.create_task( @@ -1645,8 +2031,15 @@ async def _fake_prompt(prompt_blocks, session_id): ) await asyncio.wait_for(prompt_entered.wait(), timeout=5.0) task.cancel() - with pytest.raises(asyncio.CancelledError): - await task + try: + with pytest.raises(asyncio.CancelledError): + with patch( + "openhands.sdk.agent.acp_agent._ACP_CANCEL_DRAIN_TIMEOUT", + 0.01, + ): + await task + finally: + prompt_released.set() try: agent._executor = executor @@ -1673,12 +2066,11 @@ async def _fake_prompt(prompt_blocks, session_id): def test_astep_does_not_deadlock_under_reentrant_state_lock(self, tmp_path): """End-to-end shape of the #3348 bug. - Mirrors ``LocalConversation.arun``: holds ``state.lock`` on the - loop thread across ``await astep(...)``, while a post-prompt - callback re-acquires it (same shape as ``stats_callback``'s - ``with state:``). With astep overridden, the callback runs on - the same thread as the lock owner — FIFOLock's reentrancy lets - it through. Without the override, this hangs. + Covers direct callers that hold ``state.lock`` on the loop thread + across ``await astep(...)`` while a post-prompt callback + re-acquires it. With astep overridden, the callback runs on the + same thread as the lock owner — FIFOLock's reentrancy lets it + through. Without the override, this hangs. """ from openhands.sdk.utils.async_executor import AsyncExecutor @@ -3721,6 +4113,20 @@ def test_first_launch_calls_new_session(self, tmp_path): conn.load_session.assert_not_awaited() assert agent._session_id == "fresh-sess" + def test_cancel_drain_restart_keeps_retry_flag_when_init_fails(self, tmp_path): + """A failed replacement session should leave the deferred restart armed.""" + agent = _make_agent() + state = _make_state(tmp_path) + agent._restart_session_on_next_turn = True + + with patch.object(ACPAgent, "init_state", side_effect=RuntimeError("boom")): + with pytest.raises(RuntimeError, match="boom"): + agent._restart_session_after_drain_timeout( + state, on_event=lambda _: None + ) + + assert agent._restart_session_on_next_turn is True + def test_init_state_writes_session_id_into_agent_state(self, tmp_path): """init_state lands the session id in state.agent_state so ConversationState's base_state.json persistence carries it forward. @@ -3758,6 +4164,34 @@ def test_resume_reads_session_id_from_agent_state(self, tmp_path): conn.new_session.assert_not_awaited() assert agent._session_id == "stored-sess" + def test_cancel_drain_restart_preserves_session_id_for_resume(self, tmp_path): + """A cancelled-prompt drain timeout restarts the subprocess, but should + still load the persisted ACP session so the server keeps conversation + memory. + """ + agent = _make_agent( + agent_context=AgentContext(system_message_suffix="Team rules.") + ) + state = _make_state(tmp_path) + state.agent_state = { + **state.agent_state, + "acp_session_id": "stored-sess", + "acp_session_cwd": str(tmp_path), + "acp_suffix_installed": True, + } + conn = self._make_conn() + + with self._transport_patches(conn): + agent._restart_session_after_drain_timeout(state, on_event=lambda _: None) + + conn.load_session.assert_awaited_once() + conn.new_session.assert_not_awaited() + assert agent._session_id == "stored-sess" + assert state.agent_state["acp_session_id"] == "stored-sess" + assert state.agent_state["acp_session_cwd"] == str(tmp_path) + assert state.agent_state["acp_suffix_installed"] is True + assert agent._suffix_install_state == "installed" + def test_load_session_failure_falls_back_to_new_session(self, tmp_path): """ACPRequestError on load_session → new_session is called.""" agent = _make_agent() @@ -3870,6 +4304,35 @@ def test_fallback_replacement_id_lands_in_agent_state(self, tmp_path): assert state.agent_state["acp_session_id"] == "replacement-sess" assert state.agent_state["acp_session_cwd"] == str(tmp_path) + def test_fallback_replacement_clears_suffix_marker(self, tmp_path): + """If load_session fails, the replacement session has not seen any + suffix yet, even if the stale session had persisted the marker. + """ + agent = _make_agent( + agent_context=AgentContext(system_message_suffix="Team rules.") + ) + state = _make_state(tmp_path) + state.agent_state = { + **state.agent_state, + "acp_session_id": "stale-sess", + "acp_session_cwd": str(tmp_path), + "acp_suffix_installed": True, + } + conn = self._make_conn( + new_session_id="replacement-sess", + load_exc=ACPRequestError(-32602, "unknown session"), + ) + + with self._transport_patches(conn): + agent.init_state(state, on_event=lambda _: None) + + conn.load_session.assert_awaited_once() + conn.new_session.assert_awaited_once() + assert state.agent_state["acp_session_id"] == "replacement-sess" + assert state.agent_state["acp_session_cwd"] == str(tmp_path) + assert state.agent_state.get("acp_suffix_installed") is not True + assert agent._suffix_install_state == "pending_first_prompt" + def test_resume_path_still_applies_session_mode_and_model(self, tmp_path): """load_session must be followed by the same set_session_model and set_session_mode calls as new_session, so a resumed session honours diff --git a/tests/sdk/conversation/local/test_conversation_send_message.py b/tests/sdk/conversation/local/test_conversation_send_message.py index b43c4c7e7b..ec124f2da8 100644 --- a/tests/sdk/conversation/local/test_conversation_send_message.py +++ b/tests/sdk/conversation/local/test_conversation_send_message.py @@ -1,10 +1,16 @@ -from unittest.mock import patch +import asyncio +from unittest.mock import MagicMock, patch +import pytest from pydantic import SecretStr from openhands.sdk.agent.acp_agent import ACPAgent from openhands.sdk.agent.base import AgentBase from openhands.sdk.conversation import Conversation, LocalConversation +from openhands.sdk.conversation.impl.local_conversation import ( + ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID, + ACP_LAST_PROMPT_USER_MESSAGE_ID, +) from openhands.sdk.conversation.state import ( ConversationExecutionStatus, ConversationState, @@ -201,3 +207,645 @@ def _finish_immediately(self, conv, on_event, on_token=None): conversation.state.execution_status == ConversationExecutionStatus.FINISHED ) assert conversation.state.events[-1] == user_event + + +@pytest.mark.asyncio +async def test_acp_arun_accepts_user_message_while_step_is_in_flight(tmp_path): + """ACP user messages should be persisted while a long async turn is running.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=4, + stuck_detection=False, + ) + conversation.send_message("initial request") + + first_step_started = asyncio.Event() + release_first_step = asyncio.Event() + second_step_seen = asyncio.Event() + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def blocking_astep( + self, # noqa: ARG001 + conv: LocalConversation, # noqa: ARG001 + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + if len(prompts_seen) == 1: + first_step_started.set() + await release_first_step.wait() + else: + second_step_seen.set() + conv.state.execution_status = ConversationExecutionStatus.FINISHED + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=blocking_astep), + ): + run_task = asyncio.create_task(conversation.arun()) + send_done = asyncio.Event() + + async def send_intervening_message() -> None: + await asyncio.to_thread(conversation.send_message, "intervening request") + send_done.set() + + await asyncio.wait_for(first_step_started.wait(), timeout=1.0) + send_task = asyncio.create_task(send_intervening_message()) + + try: + await asyncio.wait_for(send_done.wait(), timeout=5.0) + finally: + release_first_step.set() + await asyncio.wait_for(send_task, timeout=1.0) + await asyncio.wait_for(second_step_seen.wait(), timeout=1.0) + await asyncio.wait_for(run_task, timeout=1.0) + + assert prompts_seen == ["initial request", "intervening request"] + + +@pytest.mark.asyncio +async def test_acp_arun_marks_queued_message_running_after_finish_gap(tmp_path): + """Queued ACP messages should resume RUNNING even if send sees FINISHED.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=4, + stuck_detection=False, + ) + conversation.send_message("initial request") + + first_step_finished = asyncio.Event() + release_first_step = asyncio.Event() + second_step_seen = asyncio.Event() + second_step_statuses: list[ConversationExecutionStatus] = [] + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def blocking_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + if len(prompts_seen) == 1: + conv.state.execution_status = ConversationExecutionStatus.FINISHED + first_step_finished.set() + await release_first_step.wait() + else: + second_step_statuses.append(conv.state.execution_status) + conv.state.execution_status = ConversationExecutionStatus.FINISHED + second_step_seen.set() + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=blocking_astep), + ): + run_task = asyncio.create_task(conversation.arun()) + await asyncio.wait_for(first_step_finished.wait(), timeout=1.0) + await asyncio.to_thread(conversation.send_message, "intervening request") + assert conversation.state.execution_status == ConversationExecutionStatus.IDLE + release_first_step.set() + await asyncio.wait_for(second_step_seen.wait(), timeout=1.0) + await asyncio.wait_for(run_task, timeout=1.0) + + assert prompts_seen == ["initial request", "intervening request"] + assert second_step_statuses == [ConversationExecutionStatus.RUNNING] + + +@pytest.mark.asyncio +async def test_acp_arun_processes_multiple_queued_messages_fifo(tmp_path): + """ACP arun should not skip earlier messages queued during a prompt.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=4, + stuck_detection=False, + ) + conversation.send_message("initial request") + + first_step_finished = asyncio.Event() + release_first_step = asyncio.Event() + all_queued_steps_seen = asyncio.Event() + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def blocking_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + conv.state.execution_status = ConversationExecutionStatus.FINISHED + if len(prompts_seen) == 1: + first_step_finished.set() + await release_first_step.wait() + elif prompts_seen[-2:] == ["queued one", "queued two"]: + all_queued_steps_seen.set() + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=blocking_astep), + ): + run_task = asyncio.create_task(conversation.arun()) + await asyncio.wait_for(first_step_finished.wait(), timeout=1.0) + await asyncio.to_thread(conversation.send_message, "queued one") + await asyncio.to_thread(conversation.send_message, "queued two") + release_first_step.set() + await asyncio.wait_for(all_queued_steps_seen.wait(), timeout=1.0) + await asyncio.wait_for(run_task, timeout=1.0) + + assert prompts_seen == ["initial request", "queued one", "queued two"] + + +@pytest.mark.asyncio +async def test_acp_arun_processes_initial_queued_messages_fifo(tmp_path): + """ACP arun should process pre-run queued messages from oldest to newest.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + conversation.send_message("queued one") + conversation.send_message("queued two") + + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def blocking_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + conv.state.execution_status = ConversationExecutionStatus.FINISHED + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=blocking_astep), + ): + await asyncio.wait_for(conversation.arun(), timeout=1.0) + + assert prompts_seen == ["queued one", "queued two"] + + +@pytest.mark.asyncio +async def test_acp_arun_does_not_reprompt_when_cursor_is_current(tmp_path): + """ACP arun should finish when there is no queued user message.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + conversation.send_message("already processed") + conversation.state.agent_state = { + ACP_LAST_PROMPT_USER_MESSAGE_ID: conversation.state.last_user_message_id + } + + prompts_seen: list[MessageEvent | None] = [] + + async def blocking_astep( + self, # noqa: ARG001 + conv: LocalConversation, # noqa: ARG001 + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(prompt_message) + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=blocking_astep), + ): + await asyncio.wait_for(conversation.arun(), timeout=1.0) + + assert prompts_seen == [] + assert conversation.state.execution_status == ConversationExecutionStatus.FINISHED + + +@pytest.mark.asyncio +async def test_acp_arun_recovers_when_persisted_cursor_is_missing(tmp_path): + """A stale persisted ACP cursor should not tight-loop the run.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + conversation.send_message("surviving message") + surviving_message_id = conversation.state.last_user_message_id + conversation.state.agent_state = {ACP_LAST_PROMPT_USER_MESSAGE_ID: "missing-id"} + + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def record_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + conv.state.execution_status = ConversationExecutionStatus.FINISHED + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=record_astep), + ): + await asyncio.wait_for(conversation.arun(), timeout=1.0) + + assert prompts_seen == ["surviving message"] + assert conversation.state.execution_status == ConversationExecutionStatus.FINISHED + assert ( + conversation.state.agent_state.get(ACP_LAST_PROMPT_USER_MESSAGE_ID) + == surviving_message_id + ) + + +@pytest.mark.asyncio +async def test_acp_arun_sends_stop_hook_feedback_to_acp(tmp_path): + """ACP stop-hook feedback should be queued as the next prompt.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + conversation.send_message("initial request") + + hook = MagicMock() + hook.run_stop.side_effect = [(False, "please continue"), (True, None)] + conversation._hook_processor = hook + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def finish_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + conv.state.execution_status = ConversationExecutionStatus.FINISHED + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=finish_astep), + ): + await asyncio.wait_for(conversation.arun(), timeout=1.0) + + assert hook.run_stop.call_count == 2 + assert prompts_seen == [ + "initial request", + "[Stop hook feedback] please continue", + ] + assert conversation.state.execution_status == ConversationExecutionStatus.FINISHED + + +@pytest.mark.asyncio +async def test_acp_arun_rechecks_messages_before_finishing(tmp_path): + """A user message appended in the finish gap should be sent in the same run.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + conversation.send_message("already processed") + conversation.state.agent_state = { + ACP_LAST_PROMPT_USER_MESSAGE_ID: conversation.state.last_user_message_id + } + conversation._agent_ready = True + + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def record_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + conv.state.execution_status = ConversationExecutionStatus.FINISHED + + original_exit = ConversationState.__exit__ + exit_count = 0 + injected = False + + def inject_after_empty_selection( + state: ConversationState, exc_type, exc_val, exc_tb + ) -> None: + nonlocal exit_count, injected + original_exit(state, exc_type, exc_val, exc_tb) + if state is conversation.state: + exit_count += 1 + if exit_count == 2 and not injected: + injected = True + conversation.send_message("arrived in finish gap") + + with ( + patch.object(ConversationState, "__exit__", new=inject_after_empty_selection), + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=record_astep), + ): + await asyncio.wait_for(conversation.arun(), timeout=1.0) + + assert injected is True + assert prompts_seen == ["arrived in finish gap"] + assert conversation.state.execution_status == ConversationExecutionStatus.FINISHED + assert ( + conversation.state.agent_state.get(ACP_LAST_PROMPT_USER_MESSAGE_ID) + == conversation.state.last_user_message_id + ) + + +@pytest.mark.asyncio +async def test_acp_arun_does_not_commit_cursor_on_explicit_interrupt(tmp_path): + """Explicit interruption should leave the in-flight ACP prompt retryable.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + conversation.send_message("cancel me") + first_message_id = conversation.state.last_user_message_id + + prompt_started = asyncio.Event() + + async def blocking_astep( + self, # noqa: ARG001 + conv: LocalConversation, # noqa: ARG001 + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, # noqa: ARG001 + ) -> None: + prompt_started.set() + await asyncio.Event().wait() + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=blocking_astep), + ): + task = asyncio.create_task(conversation.arun()) + await asyncio.wait_for(prompt_started.wait(), timeout=1.0) + conversation.interrupt() + await asyncio.wait_for(task, timeout=1.0) + + assert conversation.state.execution_status == ConversationExecutionStatus.PAUSED + assert ( + conversation.state.agent_state.get(ACP_LAST_PROMPT_USER_MESSAGE_ID) + != first_message_id + ) + assert ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID not in conversation.state.agent_state + + +@pytest.mark.asyncio +async def test_acp_arun_commits_cursor_when_cancelled_prompt_completed(tmp_path): + """Completed ACP prompts should not be replayed after cancellation.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + conversation.send_message("complete during cancel") + first_message_id = conversation.state.last_user_message_id + prompts_seen: list[str] = [] + + async def finishing_cancelled_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + assert prompt_message is not None + content = prompt_message.llm_message.content[0] + assert isinstance(content, TextContent) + prompts_seen.append(content.text) + conv.state.execution_status = ConversationExecutionStatus.FINISHED + raise asyncio.CancelledError + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=finishing_cancelled_astep), + ): + await asyncio.wait_for(conversation.arun(), timeout=1.0) + assert conversation.state.execution_status == ConversationExecutionStatus.PAUSED + await asyncio.wait_for(conversation.arun(), timeout=1.0) + + assert conversation.state.execution_status == ConversationExecutionStatus.FINISHED + assert ( + conversation.state.agent_state.get(ACP_LAST_PROMPT_USER_MESSAGE_ID) + == first_message_id + ) + assert ACP_INFLIGHT_PROMPT_USER_MESSAGE_ID not in conversation.state.agent_state + assert prompts_seen == ["complete during cancel"] + + +@pytest.mark.asyncio +async def test_acp_arun_resumes_queued_messages_fifo_after_iteration_cap(tmp_path): + """Queued ACP messages should remain FIFO across follow-up runs.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=1, + stuck_detection=False, + ) + conversation.send_message("initial request") + + first_step_finished = asyncio.Event() + release_first_step = asyncio.Event() + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def blocking_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + conv.state.execution_status = ConversationExecutionStatus.FINISHED + if len(prompts_seen) == 1: + first_step_finished.set() + await release_first_step.wait() + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=blocking_astep), + ): + first_run = asyncio.create_task(conversation.arun()) + await asyncio.wait_for(first_step_finished.wait(), timeout=1.0) + await asyncio.to_thread(conversation.send_message, "queued one") + await asyncio.to_thread(conversation.send_message, "queued two") + release_first_step.set() + await asyncio.wait_for(first_run, timeout=1.0) + + assert conversation.state.execution_status == ConversationExecutionStatus.IDLE + await asyncio.wait_for(conversation.arun(), timeout=1.0) + await asyncio.wait_for(conversation.arun(), timeout=1.0) + + assert prompts_seen == ["initial request", "queued one", "queued two"] + + +@pytest.mark.asyncio +async def test_acp_arun_stops_after_agent_sets_error(tmp_path): + """ACP timeout/error statuses should not be replaced by max-iteration errors.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=3, + stuck_detection=False, + ) + conversation.send_message("initial request") + prompts_seen: list[str] = [] + + async def failing_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + assert prompt_message is not None + prompts_seen.append("prompt") + conv.state.execution_status = ConversationExecutionStatus.ERROR + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=failing_astep), + ): + await asyncio.wait_for(conversation.arun(), timeout=1.0) + + assert prompts_seen == ["prompt"] + assert conversation.state.execution_status == ConversationExecutionStatus.ERROR + + +@pytest.mark.asyncio +async def test_acp_arun_leaves_queued_message_idle_at_iteration_cap(tmp_path): + """A queued ACP message at the run cap should wait for another run.""" + + agent = ACPAgent(acp_command=["echo", "test"]) + conversation = LocalConversation( + agent=agent, + workspace=str(tmp_path), + max_iteration_per_run=1, + stuck_detection=False, + ) + conversation.send_message("initial request") + + step_finished = asyncio.Event() + release_step = asyncio.Event() + prompts_seen: list[str] = [] + + def user_text(event: MessageEvent | None) -> str: + assert event is not None + content = event.llm_message.content[0] + assert isinstance(content, TextContent) + return content.text + + async def blocking_astep( + self, # noqa: ARG001 + conv: LocalConversation, + on_event: ConversationCallbackType, # noqa: ARG001 + on_token: ConversationTokenCallbackType | None = None, # noqa: ARG001 + prompt_message: MessageEvent | None = None, + ) -> None: + prompts_seen.append(user_text(prompt_message)) + conv.state.execution_status = ConversationExecutionStatus.FINISHED + step_finished.set() + await release_step.wait() + + with ( + patch.object(ACPAgent, "init_state", autospec=True), + patch.object(ACPAgent, "astep", new=blocking_astep), + ): + run_task = asyncio.create_task(conversation.arun()) + await asyncio.wait_for(step_finished.wait(), timeout=1.0) + await asyncio.to_thread(conversation.send_message, "intervening request") + release_step.set() + await asyncio.wait_for(run_task, timeout=1.0) + + assert prompts_seen == ["initial request"] + assert conversation.state.execution_status == ConversationExecutionStatus.IDLE