Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions src/iac_code/a2a/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from iac_code.a2a.parts import (
allowed_cwd_roots,
is_relative_to,
parts_to_pipeline_input,
parts_to_prompt,
parts_to_user_input,
resolve_workspace_path,
)
from iac_code.a2a.pipeline_events import PipelineA2AContext, PipelineEventTranslator
Expand All @@ -50,6 +50,7 @@
TASK_STATE_INPUT_REQUIRED,
TASK_STATE_WORKING,
)
from iac_code.agent.message import ContentBlock
from iac_code.agent.message import Message as AgentMessage
from iac_code.config import get_active_provider_key, get_provider_config, load_credentials
from iac_code.i18n import _
Expand Down Expand Up @@ -683,14 +684,15 @@ def _public_cleanup_error(value: Any) -> str | None:
async def _stream_a2a_normal_events(
*,
runtime: Any,
prompt: str,
prompt: str | list[ContentBlock],
prompt_text: str,
cleanup_ledger: CleanupLedger | None,
cleanup_publisher: PipelineA2AEventPublisher | None,
cwd: str,
session_id: str,
) -> AsyncIterator[Any]:
if _a2a_cleanup_ledger_unavailable(cleanup_ledger, runtime=runtime, cwd=cwd, session_id=session_id):
if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt):
if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt_text):
yield TextDeltaEvent(
text=_("Rollback cleanup deferred prompt state is unavailable. Please repair it before continuing.")
)
Expand All @@ -702,7 +704,7 @@ async def _stream_a2a_normal_events(

if cleanup_ledger is not None and cleanup_ledger.load_failed():
if _runtime_has_cleanup_prompt(runtime) or _session_has_cleanup_prompt(cwd=cwd, session_id=session_id):
if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt):
if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt_text):
yield TextDeltaEvent(
text=_("Rollback cleanup deferred prompt state is unavailable. Please repair it before continuing.")
)
Expand All @@ -728,7 +730,7 @@ async def _stream_a2a_normal_events(
async for event in cleanup_stream:
yield event
if cleanup_ledger.pending_resources():
if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt):
if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt_text):
yield TextDeltaEvent(
text=_("Rollback cleanup deferred prompt state is unavailable. Please repair it before continuing.")
)
Expand All @@ -740,13 +742,18 @@ async def _stream_a2a_normal_events(
_mark_completed_cleanup_prompts(runtime=runtime, cwd=cwd, session_id=session_id, ledger=cleanup_ledger)
_prune_completed_cleanup_prompt_from_runtime(runtime, cleanup_ledger)

prompts_after_cleanup = _a2a_prompts_after_cleanup(cwd=cwd, session_id=session_id, prompt=prompt)
prompts_after_cleanup = _a2a_prompts_after_cleanup(cwd=cwd, session_id=session_id, prompt=prompt_text)
if prompts_after_cleanup is None:
yield TextDeltaEvent(
text=_("Rollback cleanup deferred prompt state is unavailable. Please repair it before continuing.")
)
return
prompts_to_run, has_deferred_prompts = prompts_after_cleanup
deferred_prompts, has_deferred_prompts = prompts_after_cleanup
prompts_to_run: list[str | list[ContentBlock]] = []
if has_deferred_prompts:
prompts_to_run.extend(deferred_prompts)
else:
prompts_to_run.append(prompt)
for prompt_to_run in prompts_to_run:
prompt_stream = runtime.agent_loop.run_streaming(prompt_to_run)
if cleanup_ledger is not None:
Expand Down Expand Up @@ -818,15 +825,20 @@ async def publish_initial_task_if_missing() -> None:
cwd=cwd,
)
pipeline_input: PipelineUserInput | None = None
normal_input: PipelineUserInput | None = None
if pipeline_mode and not route_pipeline_handoff_to_normal:
try:
pipeline_input = self._pipeline_input_from_context(context, cwd=cwd)
except ValueError as exc:
raise InvalidParamsError(sanitize_public_text(str(exc))) from exc
prompt = pipeline_input.display_text
self._validate_pipeline_request_input(pipeline_input, model=model)
else:
prompt = self._prompt_from_context(context, cwd=cwd)
try:
normal_input = self._normal_input_from_context(context, cwd=cwd)
except ValueError as exc:
raise InvalidParamsError(sanitize_public_text(str(exc))) from exc
if normal_input.has_images:
self._validate_pipeline_request_input(normal_input, model=model)
if pipeline_mode and requested_task_id is None:
recovered_task_id = await self._recoverable_pipeline_task_id_for_context(context_id=context_id, cwd=cwd)
if recovered_task_id is not None:
Expand Down Expand Up @@ -873,7 +885,11 @@ async def publish_initial_task_if_missing() -> None:
self._metrics.record_task_failed()
return

if not (pipeline_mode and not route_pipeline_handoff_to_normal) and not prompt.strip():
if (
not (pipeline_mode and not route_pipeline_handoff_to_normal)
and normal_input is not None
and normal_input.is_empty
):
task.state = TASK_STATE_FAILED
await self._publish_status(
event_queue,
Expand Down Expand Up @@ -1041,9 +1057,11 @@ def runtime_factory(session_id: str) -> Any:
artifact_store=self._artifact_store,
exposure_types=self._thinking_exposure_types,
)
assert normal_input is not None
stream = _stream_a2a_normal_events(
runtime=runtime,
prompt=prompt,
prompt=normal_input.content,
prompt_text=normal_input.display_text,
cleanup_ledger=cleanup_ledger,
cleanup_publisher=cleanup_publisher,
cwd=cwd,
Expand Down Expand Up @@ -1256,17 +1274,26 @@ def _prompt_from_context(self, context: RequestContext, *, cwd: str) -> str:
return context.get_user_input()
return parts_to_prompt(message.parts, cwd=cwd)

def _normal_input_from_context(self, context: RequestContext, *, cwd: str) -> PipelineUserInput:
    """Build the normal-mode user input for an A2A request.

    When the context carries an A2A ``Message``, its parts are converted with
    ``parts_to_user_input``. If that conversion produced images, the converted
    input is returned unchanged; otherwise only its display text is kept and
    passed through ``normalize_pipeline_user_input``. When the context has no
    ``Message``, the text from ``context.get_user_input()`` is normalized
    instead.

    Args:
        context: The incoming request context.
        cwd: Working directory forwarded to the part conversion.

    Returns:
        The user input to hand to the normal-mode agent.
    """
    incoming = getattr(context, "message", None)
    if isinstance(incoming, Message):
        parsed = parts_to_user_input(incoming.parts, cwd=cwd)
        # Image-bearing input is returned as converted; text-only input is
        # rebuilt from its display text.
        return parsed if parsed.has_images else normalize_pipeline_user_input(parsed.display_text)
    return normalize_pipeline_user_input(context.get_user_input())

def _pipeline_input_from_context(self, context: RequestContext, *, cwd: str) -> PipelineUserInput:
message = getattr(context, "message", None)
if not isinstance(message, Message):
return normalize_pipeline_user_input(context.get_user_input())
return parts_to_pipeline_input(message.parts, cwd=cwd)
return parts_to_user_input(message.parts, cwd=cwd)

def validate_pipeline_message_request(self, message: Message) -> None:
metadata = getattr(message, "metadata", None)
try:
cwd = self._resolve_cwd(metadata)
pipeline_input = parts_to_pipeline_input(message.parts, cwd=cwd)
pipeline_input = parts_to_user_input(message.parts, cwd=cwd)
except ValueError as exc:
raise InvalidParamsError(sanitize_public_text(str(exc))) from exc
model = self._resolve_model(metadata) or self._model
Expand Down
14 changes: 7 additions & 7 deletions src/iac_code/a2a/parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ def parts_to_prompt(message_parts: Iterable[Any], *, cwd: str | Path) -> str:
return "\n".join(value for value in values if value)


def parts_to_pipeline_input(message_parts: Iterable[Any], *, cwd: str | Path) -> PipelineUserInput:
def parts_to_user_input(message_parts: Iterable[Any], *, cwd: str | Path) -> PipelineUserInput:
blocks: list[ContentBlock] = []
for part in message_parts:
converted = part_to_pipeline_block(part, cwd=cwd)
converted = part_to_content_block(part, cwd=cwd)
if isinstance(converted, list):
blocks.extend(converted)
elif converted:
Expand All @@ -131,7 +131,7 @@ def parts_to_pipeline_input(message_parts: Iterable[Any], *, cwd: str | Path) ->
return PipelineUserInput(content=text, display_text=text, has_images=False)


def part_to_pipeline_block(part: Any, *, cwd: str | Path) -> str | list[ContentBlock]:
def part_to_content_block(part: Any, *, cwd: str | Path) -> str | list[ContentBlock]:
media_type = _media_type(part)
if _has_field(part, "text"):
_ensure_text_like(media_type)
Expand All @@ -140,7 +140,7 @@ def part_to_pipeline_block(part: Any, *, cwd: str | Path) -> str | list[ContentB
if media_type in SUPPORTED_IMAGE_MIME_TYPES:
return [_image_block_from_binary(_binary_data_part_bytes(part), requested_media_type=media_type)]
if _is_multimodal(media_type):
raise ValueError("A2A pipeline input has unsupported image media type.")
raise ValueError("A2A input has unsupported image media type.")
if media_type != "application/json":
raise ValueError("A2A data parts must use application/json media type.")
data = MessageToDict(part.data, preserving_proto_field_name=False)
Expand All @@ -153,7 +153,7 @@ def part_to_pipeline_block(part: Any, *, cwd: str | Path) -> str | list[ContentB
_ensure_size(raw, limit=MAX_BINARY_INLINE_BYTES, label="A2A binary raw part")
return [_image_block_from_binary(raw, requested_media_type=media_type)]
if _is_multimodal(media_type):
raise ValueError("A2A pipeline input has unsupported image media type.")
raise ValueError("A2A input has unsupported image media type.")
_ensure_text_like(media_type)
_ensure_size(raw, limit=MAX_INLINE_BYTES, label="A2A raw part")
try:
Expand All @@ -167,7 +167,7 @@ def part_to_pipeline_block(part: Any, *, cwd: str | Path) -> str | list[ContentB
raise ValueError("A2A binary file URL part content is too large.")
return [_image_block_from_binary(path.read_bytes(), requested_media_type=media_type)]
if _is_multimodal(media_type):
raise ValueError("A2A pipeline input has unsupported image media type.")
raise ValueError("A2A input has unsupported image media type.")
_ensure_text_like(media_type)
return _read_file_url_part(str(part.url), cwd=Path(cwd))
raise ValueError("A2A server supports text, JSON data, raw text, or workspace file URL parts only.")
Expand Down Expand Up @@ -283,7 +283,7 @@ def _binary_data_part_bytes(part: Any) -> bytes:

def _image_block_from_binary(raw: bytes, *, requested_media_type: str) -> ImageBlock:
if requested_media_type not in SUPPORTED_IMAGE_MIME_TYPES:
raise ValueError("A2A pipeline input has unsupported image media type.")
raise ValueError("A2A input has unsupported image media type.")
resized = maybe_resize_and_downsample(raw)
return ImageBlock(
media_type=resized.media_type,
Expand Down
4 changes: 2 additions & 2 deletions tests/a2a/fakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def pending_future() -> asyncio.Future[bool]:
class FakeAgentLoop:
def __init__(self, events: list[Any]) -> None:
self.events = events
self.prompts: list[str] = []
self.prompts: list[Any] = []

async def run_streaming(self, prompt: str):
async def run_streaming(self, prompt: Any):
self.prompts.append(prompt)
for event in self.events:
await asyncio.sleep(0)
Expand Down
55 changes: 49 additions & 6 deletions tests/a2a/test_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import base64
import logging
from pathlib import Path
from types import SimpleNamespace

import pytest
from a2a.types import TaskStatusUpdateEvent
Expand All @@ -14,7 +16,7 @@
from iac_code.a2a.pipeline_journal import A2APipelineJournal
from iac_code.a2a.pipeline_paths import a2a_pipeline_dir_for_session
from iac_code.a2a.task_store import A2ATaskStore
from iac_code.agent.message import ImageBlock
from iac_code.agent.message import ImageBlock, TextBlock
from iac_code.pipeline.engine.user_input import PipelineUserInput
from iac_code.types.stream_events import PermissionRequestEvent, TextDeltaEvent, ToolResultEvent

Expand Down Expand Up @@ -706,6 +708,44 @@ async def test_executor_runs_normal_mode_when_iac_code_mode_is_normal(
assert dumped["status"]["state"] == "TASK_STATE_INPUT_REQUIRED"


@pytest.mark.asyncio
async def test_normal_mode_image_request_passes_image_blocks_to_agent_loop(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """A normal-mode request with text and an image reaches the agent loop as content blocks.

    The image resize helper is stubbed, so the expected ``ImageBlock`` carries
    the stubbed bytes (base64-encoded) and the stubbed media type rather than
    the raw part that was sent.
    """
    from a2a.types import Message, Part, Role

    def stub_resize(raw):
        # Deterministic stand-in for the real resize/downsample step.
        return SimpleNamespace(data=b"resized-image", media_type="image/webp")

    agent_loop = FakeAgentLoop([TextDeltaEvent(text="normal")])
    fake_runtime = FakeRuntime(agent_loop=agent_loop, session_id="session-1")

    monkeypatch.setenv("IAC_CODE_MODE", "normal")
    monkeypatch.setattr("iac_code.a2a.parts.maybe_resize_and_downsample", stub_resize)
    monkeypatch.setattr("iac_code.a2a.executor.create_agent_runtime", lambda options: fake_runtime)

    request_parts = [
        Part(text="请识别附件架构图", media_type="text/plain"),
        Part(raw=b"fake-image", media_type="image/png", filename="diagram.png"),
    ]
    request_context = FakeRequestContext(metadata={"iac_code": {"cwd": str(tmp_path)}})
    request_context.message = Message(
        role=Role.ROLE_USER,
        parts=request_parts,
        message_id="msg-normal-image",
    )

    task_store = A2ATaskStore(metrics=NoOpA2AMetrics())
    executor = IacCodeA2AExecutor(task_store=task_store, model="qwen3.6-plus")
    await executor.execute(request_context, FakeEventQueue())

    # Exactly one prompt is expected: the text block followed by the resized image block.
    encoded_image = base64.b64encode(b"resized-image").decode("ascii")
    expected_prompt = [
        TextBlock(text="请识别附件架构图"),
        ImageBlock(media_type="image/webp", data=encoded_image),
    ]
    assert agent_loop.prompts == [expected_prompt]


@pytest.mark.asyncio
async def test_cancel_bypasses_context_lock(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
started = asyncio.Event()
Expand Down Expand Up @@ -1026,7 +1066,7 @@ async def execute(self, **kwargs) -> None:


@pytest.mark.asyncio
async def test_pipeline_handoff_image_request_uses_normal_manifest_prompt(
async def test_pipeline_handoff_image_request_passes_image_blocks_to_normal_agent(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
Expand All @@ -1038,6 +1078,10 @@ async def test_pipeline_handoff_image_request_uses_normal_manifest_prompt(
from iac_code.services.session_storage import SessionStorage

monkeypatch.setenv("IAC_CODE_MODE", "pipeline")
monkeypatch.setattr(
"iac_code.a2a.parts.maybe_resize_and_downsample",
lambda raw: SimpleNamespace(data=b"resized-handoff-image", media_type="image/png"),
)
config_dir = tmp_path / "config"
config_dir.mkdir()
monkeypatch.setenv("IAC_CODE_CONFIG_DIR", str(config_dir))
Expand Down Expand Up @@ -1093,10 +1137,9 @@ async def execute(self, **kwargs) -> None:
FakeEventQueue(),
)

assert loop.prompts
assert "A2A multimodal attachment:" in loop.prompts[0]
assert "mediaType=image/png" in loop.prompts[0]
assert "[Image input]" not in loop.prompts[0]
assert loop.prompts == [
[ImageBlock(media_type="image/png", data=base64.b64encode(b"resized-handoff-image").decode("ascii"))]
]


@pytest.mark.asyncio
Expand Down
Loading
Loading