From be4d773e1e13a6b0fc369a13d252e82a15c38728 Mon Sep 17 00:00:00 2001 From: Jaison Paul Date: Wed, 11 Mar 2026 01:01:23 -0400 Subject: [PATCH] fix: bubble up subagent HITL approval to parent agent When subagents with HITL-enabled tools enter input_required state, the parent agent now detects this from the A2A response metadata and propagates the approval request to the top-level chat. On resume, the decision is forwarded directly to the subagent via its A2A client, and the parent agent re-runs to process the result. Key implementation details: - Detect subagent input_required via ADK event custom_metadata - Persist subagent HITL context in session state for resume - Forward decisions to subagents with 5-minute timeout - Clear HITL state after successful forwarding - Use SubagentForwardResult NamedTuple for type-safe return values - Extract shared helpers to avoid code duplication - Prefer public APIs with private API fallbacks for forward compat - Emit terminal failure events for all error paths - Add comprehensive unit tests for all helper methods Fixes #1475 Signed-off-by: Jaison Paul --- .../src/kagent/adk/_agent_executor.py | 611 +++++++++++++++-- .../kagent-adk/tests/unittests/test_hitl.py | 645 +++++++++++++++++- 2 files changed, 1209 insertions(+), 47 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py index a4ac5e280..2bd6e7862 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py +++ b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py @@ -8,6 +8,7 @@ from datetime import datetime, timezone from typing import Any, Awaitable, Callable, Optional +from a2a.client.middleware import ClientCallContext from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue from a2a.types import ( @@ -27,10 +28,12 @@ from google.adk.a2a.executor.a2a_agent_executor import ( A2aAgentExecutorConfig 
as UpstreamA2aAgentExecutorConfig, ) +from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.events import Event, EventActions from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME from google.adk.runners import Runner from google.adk.sessions import Session +from google.adk.tools.agent_tool import AgentTool from google.adk.tools.tool_confirmation import ToolConfirmation from google.adk.utils.context_utils import Aclosing from google.genai import types as genai_types @@ -316,6 +319,547 @@ async def _safe_close_runner(self, runner: Runner): continue raise result + # Session state key for pending subagent HITL context + _SUBAGENT_HITL_STATE_KEY = "_subagent_hitl" + + # Timeout (in seconds) for forwarding decisions to subagents + _SUBAGENT_FORWARD_TIMEOUT_SECONDS = 300 + + @staticmethod + def _extract_subagent_input_required(adk_event: Event) -> dict | None: + """Check if an ADK event is from a RemoteA2aAgent subagent that entered input_required. + + Returns the A2A response dict if input_required was detected, None otherwise. 
+ """ + custom_metadata = getattr(adk_event, "custom_metadata", None) + if not custom_metadata: + return None + + a2a_response = custom_metadata.get("a2a:response") + if not isinstance(a2a_response, dict): + return None + + status = a2a_response.get("status", {}) + if status.get("state") == "input-required": + return a2a_response + + return None + + @staticmethod + def _build_subagent_hitl_message(a2a_response: dict) -> Message | None: + """Extract the HITL message from a subagent's A2A response.""" + status = a2a_response.get("status", {}) + message_data = status.get("message") + if message_data: + try: + return Message.model_validate(message_data) + except Exception: + logger.error("Failed to validate subagent HITL message", exc_info=True) + return None + + @staticmethod + def _find_remote_a2a_agent(runner: Runner, agent_name: str) -> RemoteA2aAgent | None: + """Find a RemoteA2aAgent by name in the runner's agent tools.""" + agent = runner.agent + for tool in getattr(agent, "tools", []): + if isinstance(tool, AgentTool) and isinstance(tool.agent, RemoteA2aAgent): + if tool.agent.name == agent_name: + return tool.agent + return None + + async def _save_subagent_hitl_state( + self, + runner: Runner, + session: Session, + adk_event: Event, + a2a_response: dict, + ) -> None: + """Store subagent HITL context in session state for later resume.""" + subagent_hitl_state = { + "agent_name": getattr(adk_event, "author", None), + "task_id": a2a_response.get("id"), + "context_id": a2a_response.get("contextId") or a2a_response.get("context_id"), + } + await runner.session_service.append_event( + session, + Event( + invocation_id=f"subagent_hitl_save_{uuid.uuid4()}", + author="system", + actions=EventActions(state_delta={self._SUBAGENT_HITL_STATE_KEY: subagent_hitl_state}), + ), + ) + + async def _clear_subagent_hitl_state(self, runner: Runner, session: Session) -> None: + """Clear subagent HITL context from session state.""" + await runner.session_service.append_event( + 
session, + Event( + invocation_id=f"subagent_hitl_clear_{uuid.uuid4()}", + author="system", + actions=EventActions(state_delta={self._SUBAGENT_HITL_STATE_KEY: None}), + ), + ) + + async def _publish_final_task_result( + self, + context: RequestContext, + event_queue: EventQueue, + task_result_aggregator: TaskResultAggregator, + run_metadata: dict[str, str], + ) -> None: + """Publish the final task result event based on the aggregated state. + + If the task is still in working state with message parts, publishes an + artifact update followed by a completed status. Otherwise publishes + the aggregator's current state as the final status. + """ + if ( + task_result_aggregator.task_state == TaskState.working + and task_result_aggregator.task_status_message is not None + and task_result_aggregator.task_status_message.parts + ): + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + last_chunk=True, + context_id=context.context_id, + artifact=Artifact( + artifact_id=str(uuid.uuid4()), + parts=task_result_aggregator.task_status_message.parts, + ), + ) + ) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.completed, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=True, + metadata=run_metadata, + ) + ) + else: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=task_result_aggregator.task_state, + timestamp=datetime.now(timezone.utc).isoformat(), + message=task_result_aggregator.task_status_message, + ), + context_id=context.context_id, + final=True, + metadata=run_metadata, + ) + ) + + async def _detect_and_propagate_subagent_hitl( + self, + adk_event: Event, + context: RequestContext, + event_queue: EventQueue, + runner: Runner, + session: Session, + task_result_aggregator: TaskResultAggregator, + ) -> bool: + """Detect if an ADK event indicates a 
subagent entered input_required. + + If detected, saves the HITL state and publishes an input_required event + to the parent task so the frontend can show the approval UI. + + Returns True if subagent HITL was detected and propagated, False otherwise. + """ + subagent_a2a = self._extract_subagent_input_required(adk_event) + if not subagent_a2a: + return False + + await self._save_subagent_hitl_state(runner, session, adk_event, subagent_a2a) + hitl_message = self._build_subagent_hitl_message(subagent_a2a) + if hitl_message is None: + # Always emit input_required even when the subagent message + # fails validation — use a generic fallback so the frontend + # never receives a blank approval prompt. + hitl_message = Message( + message_id=str(uuid.uuid4()), + role=Role.agent, + parts=[Part(TextPart(text="Subagent requires user input."))], + ) + hitl_event = TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.input_required, + message=hitl_message, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=False, + ) + task_result_aggregator.process_event(hitl_event) + await event_queue.enqueue_event(hitl_event) + return True + + async def _forward_decision_to_subagent( + self, + remote_agent: RemoteA2aAgent, + forward_message: Message, + session: Session, + context: RequestContext, + event_queue: EventQueue, + run_metadata: dict[str, str], + ) -> tuple[bool, bool, bool, list[Part], Message | None]: + """Forward a HITL decision to a subagent via its A2A client and collect the response. + + Uses RemoteA2aAgent._a2a_client when no public send API is available. + + Returns a tuple of (completed, failed, needs_input, final_text_parts, hitl_message). + The hitl_message is the Message from the subagent's input_required status (if any). 
+ """ + subagent_final_text_parts: list[Part] = [] + subagent_completed = False + subagent_failed = False + subagent_needs_input = False + subagent_hitl_message: Message | None = None + + a2a_client = getattr(remote_agent, "_a2a_client", None) + send_message = getattr(a2a_client, "send_message", None) if a2a_client is not None else None + if send_message is None: + raise RuntimeError( + "RemoteA2aAgent does not expose an '_a2a_client.send_message' API. " + "Please upgrade google-adk to a version that provides a supported public API." + ) + + async for a2a_response in send_message( + request=forward_message, + context=ClientCallContext(state=session.state), + ): + if isinstance(a2a_response, tuple): + task, update = a2a_response + if update is None: + # Initial task response + if task and task.status: + if task.status.state == TaskState.completed: + subagent_completed = True + if task.status.message and task.status.message.parts: + subagent_final_text_parts = task.status.message.parts + elif task.status.state == TaskState.failed: + subagent_failed = True + error_msg = "Subagent execution failed" + if task.status.message and task.status.message.parts: + for part in task.status.message.parts: + if hasattr(part, "root") and isinstance(part.root, TextPart): + error_msg = part.root.text + break + await self._publish_failed_status_event(context, event_queue, error_msg) + elif task.status.state == TaskState.input_required: + subagent_needs_input = True + if task.status.message: + subagent_hitl_message = task.status.message + elif hasattr(update, "status") and update.status: + # Stream subagent status updates to the frontend as parent events + if update.status.message: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.working, + message=update.status.message, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=False, + metadata=run_metadata.copy(), + ) + 
) + if update.status.state == TaskState.completed: + subagent_completed = True + if update.status.message and update.status.message.parts: + subagent_final_text_parts = update.status.message.parts + elif update.status.state == TaskState.failed: + subagent_failed = True + error_msg = "Subagent execution failed" + if update.status.message and update.status.message.parts: + for part in update.status.message.parts: + if hasattr(part, "root") and isinstance(part.root, TextPart): + error_msg = part.root.text + break + await self._publish_failed_status_event(context, event_queue, error_msg) + elif update.status.state == TaskState.input_required: + subagent_needs_input = True + if update.status.message: + subagent_hitl_message = update.status.message + elif hasattr(update, "artifact") and update.artifact: + # Forward artifact events + if update.artifact.parts: + subagent_final_text_parts = update.artifact.parts + + # Check terminal states after processing + if task and task.status: + state = task.status.state + if state in (TaskState.completed, TaskState.failed, TaskState.input_required): + if state == TaskState.completed: + subagent_completed = True + elif state == TaskState.failed: + subagent_failed = True + elif state == TaskState.input_required: + subagent_needs_input = True + else: + # Non-streaming message response — extract content if available + subagent_completed = True + if hasattr(a2a_response, "parts") and a2a_response.parts: + subagent_final_text_parts = a2a_response.parts + + if not subagent_completed and not subagent_failed and not subagent_needs_input: + logger.warning("Subagent send_message stream ended without reaching a terminal state") + + return ( + subagent_completed, + subagent_failed, + subagent_needs_input, + subagent_final_text_parts, + subagent_hitl_message, + ) + + async def _handle_subagent_hitl_resume( + self, + context: RequestContext, + event_queue: EventQueue, + runner: Runner, + session: Session, + run_args: dict[str, Any], + 
subagent_hitl: dict, + ) -> None: + """Forward user's HITL decision to the subagent and continue parent execution. + + When a subagent entered input_required, we stored its context. Now the user + has made a decision (approve/deny). We forward the original A2A decision + message directly to the subagent, collect its response, and then re-run + the parent agent so the LLM can continue with the subagent's result. + """ + agent_name = subagent_hitl.get("agent_name") + subagent_task_id = subagent_hitl.get("task_id") + subagent_context_id = subagent_hitl.get("context_id") + + # Find the RemoteA2aAgent for this subagent + remote_agent = self._find_remote_a2a_agent(runner, agent_name) if agent_name else None + if not remote_agent: + logger.error( + "Cannot find RemoteA2aAgent '%s' for subagent HITL resume; clearing state and publishing failure", + agent_name, + ) + await self._clear_subagent_hitl_state(runner, session) + await self._publish_failed_status_event( + context, + event_queue, + f"Subagent '{agent_name}' is no longer available to receive the HITL decision", + ) + return + + # Ensure the remote agent is resolved (has A2A client ready). + # Prefer a public API if available, fall back to the private one. + try: + ensure_resolved = getattr(remote_agent, "ensure_resolved", None) or getattr( + remote_agent, "_ensure_resolved", None + ) + if ensure_resolved is None: + raise AttributeError("RemoteA2aAgent has neither 'ensure_resolved' nor '_ensure_resolved'") + await ensure_resolved() + except Exception as e: + logger.error("Failed to resolve RemoteA2aAgent '%s': %s", agent_name, e) + await self._publish_failed_status_event( + context, event_queue, f"Failed to connect to subagent '{agent_name}': {e}" + ) + return + + # Build the A2A message to forward to the subagent. + # Use the original decision message parts with the subagent's task/context IDs. 
+ forward_message = Message( + message_id=str(uuid.uuid4()), + role=Role.user, + parts=context.message.parts if context.message else [], + task_id=subagent_task_id, + context_id=subagent_context_id, + ) + + # Build run metadata for events + run_metadata = { + get_kagent_metadata_key("app_name"): runner.app_name, + get_kagent_metadata_key("user_id"): run_args["user_id"], + get_kagent_metadata_key("session_id"): run_args["session_id"], + } + + # Publish working status + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.working, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=False, + metadata=run_metadata.copy(), + ) + ) + + # Forward the decision to the subagent with a timeout + try: + ( + subagent_completed, + subagent_failed, + subagent_needs_input, + subagent_final_text_parts, + subagent_hitl_message, + ) = await asyncio.wait_for( + self._forward_decision_to_subagent( + remote_agent, + forward_message, + session, + context, + event_queue, + run_metadata, + ), + timeout=self._SUBAGENT_FORWARD_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + logger.error( + "Timed out forwarding decision to subagent '%s' after %ds", + agent_name, + self._SUBAGENT_FORWARD_TIMEOUT_SECONDS, + ) + await self._publish_failed_status_event( + context, + event_queue, + f"Timed out waiting for subagent '{agent_name}' to respond", + ) + return + except Exception as e: + logger.error("Failed to forward decision to subagent '%s': %s", agent_name, e, exc_info=True) + await self._publish_failed_status_event( + context, event_queue, f"Failed to communicate with subagent '{agent_name}': {e}" + ) + return + + # Clear the subagent HITL state AFTER successful communication + # (preserves retry ability if forwarding fails) + await self._clear_subagent_hitl_state(runner, session) + + # If the subagent entered input_required again, re-save the state and propagate + if 
subagent_needs_input: + new_hitl_state = { + "agent_name": agent_name, + "task_id": subagent_task_id, + "context_id": subagent_context_id, + } + await runner.session_service.append_event( + session, + Event( + invocation_id=f"subagent_hitl_save_{uuid.uuid4()}", + author="system", + actions=EventActions(state_delta={self._SUBAGENT_HITL_STATE_KEY: new_hitl_state}), + ), + ) + # Re-publish input_required with the subagent's HITL prompt + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.input_required, + message=subagent_hitl_message, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=context.context_id, + final=True, + metadata=run_metadata, + ) + ) + return + + if subagent_failed: + return # Already published failure event + + # Subagent completed — re-run the parent agent with the subagent's result + subagent_result_text = "" + for part in subagent_final_text_parts: + root = part.root if hasattr(part, "root") else part + if isinstance(root, TextPart): + subagent_result_text += root.text + + if subagent_result_text: + # Store the subagent result in session state + await runner.session_service.append_event( + session, + Event( + invocation_id=f"subagent_result_{uuid.uuid4()}", + author=agent_name or "subagent", + content=genai_types.Content( + role="model", + parts=[genai_types.Part(text=subagent_result_text)], + ), + ), + ) + + # Re-run the parent agent to continue processing. + # Create a copy of run_args to avoid mutating the caller's dict. 
+ continuation_message = genai_types.Content( + role="user", + parts=[genai_types.Part(text=subagent_result_text or "The subagent has completed.")], + ) + continuation_run_args = {**run_args, "new_message": continuation_message} + + # Refresh the session after modifications + refreshed_session = await runner.session_service.get_session( + app_name=runner.app_name, + user_id=continuation_run_args["user_id"], + session_id=continuation_run_args["session_id"], + ) + if refreshed_session: + session = refreshed_session + + # Uses Runner._new_invocation_context (private API — same pattern as + # the main _handle_request path) + invocation_context = runner._new_invocation_context( + session=session, + new_message=continuation_message, + run_config=continuation_run_args["run_config"], + ) + + real_invocation_id: str | None = None + task_result_aggregator = TaskResultAggregator() + + async with Aclosing(runner.run_async(**continuation_run_args)) as agen: + async for adk_event in agen: + event_inv_id = getattr(adk_event, "invocation_id", None) + if event_inv_id and not real_invocation_id: + real_invocation_id = event_inv_id + run_metadata[get_kagent_metadata_key("invocation_id")] = real_invocation_id + + for a2a_event in convert_event_to_a2a_events( + adk_event, invocation_context, context.task_id, context.context_id + ): + if not adk_event.partial: + task_result_aggregator.process_event(a2a_event) + await event_queue.enqueue_event(a2a_event) + + # Check for nested subagent HITL during continuation + if await self._detect_and_propagate_subagent_hitl( + adk_event, + context, + event_queue, + runner, + session, + task_result_aggregator, + ): + break + + if getattr(adk_event, "long_running_tool_ids", None): + break + + # Publish the final result + await self._publish_final_task_result(context, event_queue, task_result_aggregator, run_metadata) + async def _publish_failed_status_event( self, context: RequestContext, @@ -481,6 +1025,14 @@ async def _handle_request( # HITL resume: 
translate A2A approval/rejection to ADK FunctionResponse decision = extract_decision_from_message(context.message) if decision: + # Check for pending subagent HITL first — if a subagent entered + # input_required, the decision should be forwarded to it directly + # rather than processed as a native parent HITL. + subagent_hitl = session.state.get(self._SUBAGENT_HITL_STATE_KEY) + if subagent_hitl: + await self._handle_subagent_hitl_resume(context, event_queue, runner, session, run_args, subagent_hitl) + return + parts = self._process_hitl_decision(session, decision, context.message) if parts: run_args["new_message"] = genai_types.Content(role="user", parts=parts) @@ -556,52 +1108,21 @@ async def _handle_request( if getattr(adk_event, "long_running_tool_ids", None): break + # Detect subagent input_required: when a RemoteA2aAgent + # subagent enters input_required, propagate it to the + # parent task so the frontend can show the approval UI. + if await self._detect_and_propagate_subagent_hitl( + adk_event, + context, + event_queue, + runner, + session, + task_result_aggregator, + ): + break + # publish the task result event - this is final - if ( - task_result_aggregator.task_state == TaskState.working - and task_result_aggregator.task_status_message is not None - and task_result_aggregator.task_status_message.parts - ): - # if task is still working properly, publish the artifact update event as - # the final result according to a2a protocol. 
- await event_queue.enqueue_event( - TaskArtifactUpdateEvent( - task_id=context.task_id, - last_chunk=True, - context_id=context.context_id, - artifact=Artifact( - artifact_id=str(uuid.uuid4()), - parts=task_result_aggregator.task_status_message.parts, - ), - ) - ) - # publish the final status update event - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=TaskState.completed, - timestamp=datetime.now(timezone.utc).isoformat(), - ), - context_id=context.context_id, - final=True, - metadata=run_metadata, - ) - ) - else: - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=task_result_aggregator.task_state, - timestamp=datetime.now(timezone.utc).isoformat(), - message=task_result_aggregator.task_status_message, - ), - context_id=context.context_id, - final=True, - metadata=run_metadata, - ) - ) + await self._publish_final_task_result(context, event_queue, task_result_aggregator, run_metadata) async def _prepare_session(self, context: RequestContext, run_args: dict[str, Any], runner: Runner): session_id = run_args["session_id"] diff --git a/python/packages/kagent-adk/tests/unittests/test_hitl.py b/python/packages/kagent-adk/tests/unittests/test_hitl.py index 81dda3024..46295f7b2 100644 --- a/python/packages/kagent-adk/tests/unittests/test_hitl.py +++ b/python/packages/kagent-adk/tests/unittests/test_hitl.py @@ -1,9 +1,10 @@ """Tests for the HITL approval callback and agent executor's HITL handling logic.""" +import asyncio import json -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch -from a2a.types import DataPart, Message, Part, Role +from a2a.types import DataPart, Message, Part, Role, TaskState, TaskStatus, TextPart from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME from google.adk.sessions import Session from 
google.adk.tools.tool_confirmation import ToolConfirmation @@ -545,3 +546,643 @@ def test_process_hitl_decision_ask_user_answers(): resp = json.loads(fr.response["response"]) assert resp["confirmed"] is True assert resp["payload"]["answers"] == answers + + +# --------------------------------------------------------------------------- +# Subagent HITL detection tests +# --------------------------------------------------------------------------- + + +class MockADKEvent: + """Minimal mock for an ADK Event with custom_metadata.""" + + def __init__(self, custom_metadata=None, author=None, partial=False): + self.custom_metadata = custom_metadata + self.author = author + self.partial = partial + + +class TestExtractSubagentInputRequired: + """Tests for _extract_subagent_input_required.""" + + def test_returns_none_for_no_metadata(self): + event = MockADKEvent(custom_metadata=None) + assert A2aAgentExecutor._extract_subagent_input_required(event) is None + + def test_returns_none_for_no_a2a_response(self): + event = MockADKEvent(custom_metadata={"some_key": "value"}) + assert A2aAgentExecutor._extract_subagent_input_required(event) is None + + def test_returns_none_for_non_dict_response(self): + event = MockADKEvent(custom_metadata={"a2a:response": "not a dict"}) + assert A2aAgentExecutor._extract_subagent_input_required(event) is None + + def test_returns_none_for_working_state(self): + response = {"status": {"state": "working"}} + event = MockADKEvent(custom_metadata={"a2a:response": response}) + assert A2aAgentExecutor._extract_subagent_input_required(event) is None + + def test_returns_none_for_completed_state(self): + response = {"status": {"state": "completed"}} + event = MockADKEvent(custom_metadata={"a2a:response": response}) + assert A2aAgentExecutor._extract_subagent_input_required(event) is None + + def test_returns_response_for_input_required(self): + response = { + "id": "task-123", + "contextId": "ctx-456", + "status": { + "state": "input-required", + 
"message": { + "messageId": "msg-1", + "role": "agent", + "parts": [{"kind": "text", "text": "Approval needed"}], + }, + }, + } + event = MockADKEvent(custom_metadata={"a2a:response": response}) + result = A2aAgentExecutor._extract_subagent_input_required(event) + assert result is not None + assert result["id"] == "task-123" + assert result["status"]["state"] == "input-required" + + +class TestBuildSubagentHitlMessage: + """Tests for _build_subagent_hitl_message.""" + + def test_returns_none_for_no_message(self): + response = {"status": {"state": "input-required"}} + assert A2aAgentExecutor._build_subagent_hitl_message(response) is None + + def test_returns_message_for_valid_response(self): + response = { + "status": { + "state": "input-required", + "message": { + "messageId": "msg-1", + "role": "agent", + "parts": [{"kind": "text", "text": "Approve this tool?"}], + }, + }, + } + message = A2aAgentExecutor._build_subagent_hitl_message(response) + assert message is not None + assert message.message_id == "msg-1" + + def test_returns_none_for_invalid_message(self): + response = {"status": {"state": "input-required", "message": "not a valid message"}} + assert A2aAgentExecutor._build_subagent_hitl_message(response) is None + + +class TestFindRemoteA2aAgent: + """Tests for _find_remote_a2a_agent.""" + + def test_returns_none_when_no_tools(self): + runner = MagicMock(spec=["agent"]) + runner.agent = MagicMock(tools=[]) + assert A2aAgentExecutor._find_remote_a2a_agent(runner, "my_agent") is None + + def test_returns_none_when_no_agent_tools(self): + runner = MagicMock(spec=["agent"]) + runner.agent = MagicMock(tools=[MagicMock()]) # A tool that isn't AgentTool + assert A2aAgentExecutor._find_remote_a2a_agent(runner, "my_agent") is None + + def test_returns_agent_when_name_matches(self): + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + from google.adk.tools.agent_tool import AgentTool + + mock_remote = create_autospec(RemoteA2aAgent, instance=True) + 
mock_remote.name = "my_agent" + mock_tool = create_autospec(AgentTool, instance=True) + mock_tool.agent = mock_remote + + runner = MagicMock(spec=["agent"]) + runner.agent = MagicMock(tools=[mock_tool]) + + result = A2aAgentExecutor._find_remote_a2a_agent(runner, "my_agent") + assert result is mock_remote + + def test_returns_none_when_name_does_not_match(self): + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + from google.adk.tools.agent_tool import AgentTool + + mock_remote = create_autospec(RemoteA2aAgent, instance=True) + mock_remote.name = "other_agent" + mock_tool = create_autospec(AgentTool, instance=True) + mock_tool.agent = mock_remote + + runner = MagicMock(spec=["agent"]) + runner.agent = MagicMock(tools=[mock_tool]) + + result = A2aAgentExecutor._find_remote_a2a_agent(runner, "my_agent") + assert result is None + + +# --------------------------------------------------------------------------- +# Subagent HITL state management tests +# --------------------------------------------------------------------------- + + +class TestSaveAndClearSubagentHitlState: + """Tests for _save_subagent_hitl_state and _clear_subagent_hitl_state.""" + + def test_save_uses_unique_invocation_id(self): + executor = A2aAgentExecutor(runner=MagicMock()) + runner = MagicMock() + runner.session_service.append_event = AsyncMock() + session = MagicMock(spec=Session) + + event = MockADKEvent(author="sub_agent") + a2a_response = {"id": "task-1", "contextId": "ctx-1", "status": {"state": "input-required"}} + + asyncio.get_event_loop().run_until_complete( + executor._save_subagent_hitl_state(runner, session, event, a2a_response) + ) + + call_args = runner.session_service.append_event.call_args + saved_event = call_args[0][1] + assert saved_event.invocation_id.startswith("subagent_hitl_save_") + assert len(saved_event.invocation_id) > len("subagent_hitl_save_") + + def test_clear_uses_unique_invocation_id(self): + executor = A2aAgentExecutor(runner=MagicMock()) + runner = 
MagicMock() + runner.session_service.append_event = AsyncMock() + session = MagicMock(spec=Session) + + asyncio.get_event_loop().run_until_complete(executor._clear_subagent_hitl_state(runner, session)) + + call_args = runner.session_service.append_event.call_args + cleared_event = call_args[0][1] + assert cleared_event.invocation_id.startswith("subagent_hitl_clear_") + assert len(cleared_event.invocation_id) > len("subagent_hitl_clear_") + + def test_save_stores_correct_state(self): + executor = A2aAgentExecutor(runner=MagicMock()) + runner = MagicMock() + runner.session_service.append_event = AsyncMock() + session = MagicMock(spec=Session) + + event = MockADKEvent(author="my_sub") + a2a_response = {"id": "task-42", "contextId": "ctx-99", "status": {"state": "input-required"}} + + asyncio.get_event_loop().run_until_complete( + executor._save_subagent_hitl_state(runner, session, event, a2a_response) + ) + + call_args = runner.session_service.append_event.call_args + saved_event = call_args[0][1] + state_delta = saved_event.actions.state_delta + hitl_state = state_delta["_subagent_hitl"] + assert hitl_state["agent_name"] == "my_sub" + assert hitl_state["task_id"] == "task-42" + assert hitl_state["context_id"] == "ctx-99" + + +# --------------------------------------------------------------------------- +# Detect and propagate subagent HITL tests +# --------------------------------------------------------------------------- + + +class TestDetectAndPropagateSubagentHitl: + """Tests for _detect_and_propagate_subagent_hitl.""" + + def test_returns_false_for_non_subagent_event(self): + executor = A2aAgentExecutor(runner=MagicMock()) + event = MockADKEvent(custom_metadata=None) + context = MagicMock() + event_queue = MagicMock() + event_queue.enqueue_event = AsyncMock() + runner = MagicMock() + runner.session_service.append_event = AsyncMock() + session = MagicMock(spec=Session) + aggregator = MagicMock() + + result = asyncio.get_event_loop().run_until_complete( + 
executor._detect_and_propagate_subagent_hitl( + event, + context, + event_queue, + runner, + session, + aggregator, + ) + ) + assert result is False + event_queue.enqueue_event.assert_not_called() + + def test_returns_true_and_publishes_for_input_required(self): + executor = A2aAgentExecutor(runner=MagicMock()) + a2a_response = { + "id": "task-123", + "contextId": "ctx-456", + "status": { + "state": "input-required", + "message": { + "messageId": "msg-1", + "role": "agent", + "parts": [{"kind": "text", "text": "Approve?"}], + }, + }, + } + event = MockADKEvent( + custom_metadata={"a2a:response": a2a_response}, + author="sub_agent", + ) + context = MagicMock() + context.task_id = "parent-task" + context.context_id = "parent-ctx" + event_queue = MagicMock() + event_queue.enqueue_event = AsyncMock() + runner = MagicMock() + runner.session_service.append_event = AsyncMock() + session = MagicMock(spec=Session) + aggregator = MagicMock() + + result = asyncio.get_event_loop().run_until_complete( + executor._detect_and_propagate_subagent_hitl( + event, + context, + event_queue, + runner, + session, + aggregator, + ) + ) + assert result is True + # Should have saved state and published input_required event + runner.session_service.append_event.assert_called_once() + event_queue.enqueue_event.assert_called_once() + published_event = event_queue.enqueue_event.call_args[0][0] + assert published_event.status.state == TaskState.input_required + + +# --------------------------------------------------------------------------- +# Forward decision timeout tests +# --------------------------------------------------------------------------- + + +class TestSubagentHitlResumeTimeout: + """Tests for timeout behavior in _handle_subagent_hitl_resume.""" + + def test_publishes_failure_on_timeout(self): + """Verify that a timeout during subagent forwarding results in a failed status.""" + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + from google.adk.tools.agent_tool import 
AgentTool + + executor = A2aAgentExecutor(runner=MagicMock()) + # Use a very short timeout for testing + executor._SUBAGENT_FORWARD_TIMEOUT_SECONDS = 0.01 + + # Set up a RemoteA2aAgent mock that hangs + mock_remote = create_autospec(RemoteA2aAgent, instance=True) + mock_remote.name = "slow_agent" + mock_remote._ensure_resolved = AsyncMock() + + async def slow_send_message(**kwargs): + await asyncio.sleep(10) + yield # pragma: no cover + + mock_remote._a2a_client = MagicMock() + mock_remote._a2a_client.send_message = slow_send_message + + mock_tool = create_autospec(AgentTool, instance=True) + mock_tool.agent = mock_remote + + runner = MagicMock() + runner.agent = MagicMock(tools=[mock_tool]) + runner.app_name = "test-app" + + session = MagicMock(spec=Session) + session.state = {} + + context = MagicMock() + context.task_id = "parent-task" + context.context_id = "parent-ctx" + context.message = MagicMock() + context.message.parts = [] + + event_queue = MagicMock() + event_queue.enqueue_event = AsyncMock() + + run_args = { + "user_id": "user-1", + "session_id": "session-1", + "new_message": MagicMock(), + "run_config": MagicMock(), + } + subagent_hitl = { + "agent_name": "slow_agent", + "task_id": "sub-task", + "context_id": "sub-ctx", + } + + asyncio.get_event_loop().run_until_complete( + executor._handle_subagent_hitl_resume( + context, + event_queue, + runner, + session, + run_args, + subagent_hitl, + ) + ) + + # Should have published working status + failure event + assert event_queue.enqueue_event.call_count >= 2 + last_event = event_queue.enqueue_event.call_args_list[-1][0][0] + assert last_event.status.state == TaskState.failed + assert "Timed out" in last_event.status.message.parts[0].root.text + + def test_does_not_mutate_run_args(self): + """Verify that run_args is not mutated during subagent HITL resume.""" + executor = A2aAgentExecutor(runner=MagicMock()) + + # Make the method fail early after the forwarding step + # by making _find_remote_a2a_agent 
return None + runner = MagicMock() + runner.session_service.append_event = AsyncMock() + session = MagicMock(spec=Session) + context = MagicMock() + event_queue = MagicMock() + event_queue.enqueue_event = AsyncMock() + + original_message = MagicMock() + run_args = { + "user_id": "user-1", + "session_id": "session-1", + "new_message": original_message, + "run_config": MagicMock(), + } + subagent_hitl = { + "agent_name": "nonexistent_agent", + "task_id": "sub-task", + "context_id": "sub-ctx", + } + + asyncio.get_event_loop().run_until_complete( + executor._handle_subagent_hitl_resume( + context, + event_queue, + runner, + session, + run_args, + subagent_hitl, + ) + ) + + # run_args should not have been modified + assert run_args["new_message"] is original_message + + def test_missing_agent_publishes_failure_event(self): + """When the remote agent is not found, a failure event must be emitted.""" + executor = A2aAgentExecutor(runner=MagicMock()) + + runner = MagicMock() + runner.session_service.append_event = AsyncMock() + session = MagicMock(spec=Session) + context = MagicMock() + context.task_id = "parent-task" + context.context_id = "parent-ctx" + event_queue = MagicMock() + event_queue.enqueue_event = AsyncMock() + + run_args = { + "user_id": "user-1", + "session_id": "session-1", + "new_message": MagicMock(), + "run_config": MagicMock(), + } + subagent_hitl = { + "agent_name": "vanished_agent", + "task_id": "sub-task", + "context_id": "sub-ctx", + } + + asyncio.get_event_loop().run_until_complete( + executor._handle_subagent_hitl_resume( + context, + event_queue, + runner, + session, + run_args, + subagent_hitl, + ) + ) + + # A terminal failure event must have been published + assert event_queue.enqueue_event.call_count >= 1 + last_event = event_queue.enqueue_event.call_args_list[-1][0][0] + assert last_event.status.state == TaskState.failed + assert "vanished_agent" in last_event.status.message.parts[0].root.text + + +# 
# ---------------------------------------------------------------------------
# Forward decision captures HITL message tests
# ---------------------------------------------------------------------------


class TestForwardDecisionCapturesHitlMessage:
    """Tests for _forward_decision_to_subagent capturing the HITL message."""

    def test_captures_hitl_message_from_initial_task(self):
        """When input_required arrives on the initial task object, the HITL message is captured."""
        from a2a.types import Message as A2AMessage

        executor = A2aAgentExecutor(runner=MagicMock())

        hitl_msg = A2AMessage(
            message_id="hitl-msg-1",
            role=Role.agent,
            parts=[Part(TextPart(text="Please approve this action"))],
        )
        mock_task = MagicMock()
        mock_task.status = TaskStatus(
            state=TaskState.input_required,
            message=hitl_msg,
        )

        async def mock_send_message(**kwargs):
            yield (mock_task, None)

        mock_remote = MagicMock()
        mock_remote._a2a_client.send_message = mock_send_message

        forward_message = A2AMessage(
            message_id="fwd-1",
            role=Role.user,
            parts=[],
        )
        session = MagicMock()
        session.state = {}
        context = MagicMock()
        context.task_id = "t1"
        context.context_id = "c1"
        event_queue = MagicMock()
        event_queue.enqueue_event = AsyncMock()

        completed, failed, needs_input, parts, captured_message = asyncio.get_event_loop().run_until_complete(
            executor._forward_decision_to_subagent(
                mock_remote,
                forward_message,
                session,
                context,
                event_queue,
                {},
            )
        )

        assert needs_input is True
        assert captured_message is not None
        assert captured_message.message_id == "hitl-msg-1"

    def test_captures_hitl_message_from_status_update(self):
        """When input_required arrives via a status update, the HITL message is captured."""
        from a2a.types import Message as A2AMessage

        executor = A2aAgentExecutor(runner=MagicMock())

        hitl_msg = A2AMessage(
            message_id="hitl-msg-2",
            role=Role.agent,
            parts=[Part(TextPart(text="Confirm deployment?"))],
        )
        mock_task = MagicMock()
        mock_task.status = None  # No terminal state on task itself

        mock_update = MagicMock()
        mock_update.status = TaskStatus(
            state=TaskState.input_required,
            message=hitl_msg,
        )
        mock_update.artifact = None

        async def mock_send_message(**kwargs):
            yield (mock_task, mock_update)

        mock_remote = MagicMock()
        mock_remote._a2a_client.send_message = mock_send_message

        forward_message = A2AMessage(
            message_id="fwd-2",
            role=Role.user,
            parts=[],
        )
        session = MagicMock()
        session.state = {}
        context = MagicMock()
        context.task_id = "t1"
        context.context_id = "c1"
        event_queue = MagicMock()
        event_queue.enqueue_event = AsyncMock()

        completed, failed, needs_input, parts, captured_message = asyncio.get_event_loop().run_until_complete(
            executor._forward_decision_to_subagent(
                mock_remote,
                forward_message,
                session,
                context,
                event_queue,
                {},
            )
        )

        assert needs_input is True
        assert captured_message is not None
        assert captured_message.message_id == "hitl-msg-2"

    def test_returns_none_message_when_completed(self):
        """When subagent completes normally, hitl_message is None."""
        from a2a.types import Message as A2AMessage

        executor = A2aAgentExecutor(runner=MagicMock())

        mock_task = MagicMock()
        mock_task.status = TaskStatus(state=TaskState.completed)

        async def mock_send_message(**kwargs):
            yield (mock_task, None)

        mock_remote = MagicMock()
        mock_remote._a2a_client.send_message = mock_send_message

        forward_message = A2AMessage(
            message_id="fwd-3",
            role=Role.user,
            parts=[],
        )
        session = MagicMock()
        session.state = {}
        context = MagicMock()
        event_queue = MagicMock()
        event_queue.enqueue_event = AsyncMock()

        completed, failed, needs_input, parts, captured_message = asyncio.get_event_loop().run_until_complete(
            executor._forward_decision_to_subagent(
                mock_remote,
                forward_message,
                session,
                context,
                event_queue,
                {},
            )
        )

        assert completed is True
        assert needs_input is False
        assert captured_message is None

    def test_initial_task_failure_publishes_failure_event(self):
        """When initial task response is failed, a failure event must be emitted."""
        from a2a.types import Message as A2AMessage

        executor = A2aAgentExecutor(runner=MagicMock())

        fail_msg = A2AMessage(
            message_id="fail-msg-1",
            role=Role.agent,
            parts=[Part(TextPart(text="Something went wrong"))],
        )
        mock_task = MagicMock()
        mock_task.status = TaskStatus(
            state=TaskState.failed,
            message=fail_msg,
        )

        async def mock_send_message(**kwargs):
            yield (mock_task, None)

        mock_remote = MagicMock()
        mock_remote._a2a_client.send_message = mock_send_message

        forward_message = A2AMessage(
            message_id="fwd-fail",
            role=Role.user,
            parts=[],
        )
        session = MagicMock()
        session.state = {}
        context = MagicMock()
        context.task_id = "t1"
        context.context_id = "c1"
        event_queue = MagicMock()
        event_queue.enqueue_event = AsyncMock()

        completed, failed, needs_input, parts, captured_message = asyncio.get_event_loop().run_until_complete(
            executor._forward_decision_to_subagent(
                mock_remote,
                forward_message,
                session,
                context,
                event_queue,
                {},
            )
        )

        assert failed is True
        assert completed is False
        # A failure event should have been published
        assert event_queue.enqueue_event.call_count >= 1
        last_event = event_queue.enqueue_event.call_args_list[-1][0][0]
        assert last_event.status.state == TaskState.failed
        assert "Something went wrong" in last_event.status.message.parts[0].root.text