From 7e71d1d40106db163e6ea428ee4ef4737bb39ba7 Mon Sep 17 00:00:00 2001 From: rzx Date: Fri, 26 Jun 2026 16:13:41 +0800 Subject: [PATCH] fix a2a normal multimodal input --- src/iac_code/a2a/executor.py | 53 +++++++++++++++++++++++++--------- src/iac_code/a2a/parts.py | 14 ++++----- tests/a2a/fakes.py | 4 +-- tests/a2a/test_executor.py | 55 ++++++++++++++++++++++++++++++++---- tests/a2a/test_parts.py | 42 +++++++++++++-------------- 5 files changed, 119 insertions(+), 49 deletions(-) diff --git a/src/iac_code/a2a/executor.py b/src/iac_code/a2a/executor.py index 9b394ac2..4854ea80 100644 --- a/src/iac_code/a2a/executor.py +++ b/src/iac_code/a2a/executor.py @@ -27,8 +27,8 @@ from iac_code.a2a.parts import ( allowed_cwd_roots, is_relative_to, - parts_to_pipeline_input, parts_to_prompt, + parts_to_user_input, resolve_workspace_path, ) from iac_code.a2a.pipeline_events import PipelineA2AContext, PipelineEventTranslator @@ -50,6 +50,7 @@ TASK_STATE_INPUT_REQUIRED, TASK_STATE_WORKING, ) +from iac_code.agent.message import ContentBlock from iac_code.agent.message import Message as AgentMessage from iac_code.config import get_active_provider_key, get_provider_config, load_credentials from iac_code.i18n import _ @@ -683,14 +684,15 @@ def _public_cleanup_error(value: Any) -> str | None: async def _stream_a2a_normal_events( *, runtime: Any, - prompt: str, + prompt: str | list[ContentBlock], + prompt_text: str, cleanup_ledger: CleanupLedger | None, cleanup_publisher: PipelineA2AEventPublisher | None, cwd: str, session_id: str, ) -> AsyncIterator[Any]: if _a2a_cleanup_ledger_unavailable(cleanup_ledger, runtime=runtime, cwd=cwd, session_id=session_id): - if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt): + if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt_text): yield TextDeltaEvent( text=_("Rollback cleanup deferred prompt state is unavailable. Please repair it before continuing.") ) @@ -702,7 +704,7 @@ async def _stream_a2a_normal_events( if cleanup_ledger is not None and cleanup_ledger.load_failed(): if _runtime_has_cleanup_prompt(runtime) or _session_has_cleanup_prompt(cwd=cwd, session_id=session_id): - if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt): + if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt_text): yield TextDeltaEvent( text=_("Rollback cleanup deferred prompt state is unavailable. Please repair it before continuing.") ) @@ -728,7 +730,7 @@ async def _stream_a2a_normal_events( async for event in cleanup_stream: yield event if cleanup_ledger.pending_resources(): - if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt): + if not _append_a2a_deferred_cleanup_prompt(cwd=cwd, session_id=session_id, prompt=prompt_text): yield TextDeltaEvent( text=_("Rollback cleanup deferred prompt state is unavailable. Please repair it before continuing.") ) @@ -740,13 +742,18 @@ async def _stream_a2a_normal_events( _mark_completed_cleanup_prompts(runtime=runtime, cwd=cwd, session_id=session_id, ledger=cleanup_ledger) _prune_completed_cleanup_prompt_from_runtime(runtime, cleanup_ledger) - prompts_after_cleanup = _a2a_prompts_after_cleanup(cwd=cwd, session_id=session_id, prompt=prompt) + prompts_after_cleanup = _a2a_prompts_after_cleanup(cwd=cwd, session_id=session_id, prompt=prompt_text) if prompts_after_cleanup is None: yield TextDeltaEvent( text=_("Rollback cleanup deferred prompt state is unavailable. Please repair it before continuing.") ) return - prompts_to_run, has_deferred_prompts = prompts_after_cleanup + deferred_prompts, has_deferred_prompts = prompts_after_cleanup + prompts_to_run: list[str | list[ContentBlock]] = [] + if has_deferred_prompts: + prompts_to_run.extend(deferred_prompts) + else: + prompts_to_run.append(prompt) for prompt_to_run in prompts_to_run: prompt_stream = runtime.agent_loop.run_streaming(prompt_to_run) if cleanup_ledger is not None: @@ -818,15 +825,20 @@ async def publish_initial_task_if_missing() -> None: cwd=cwd, ) pipeline_input: PipelineUserInput | None = None + normal_input: PipelineUserInput | None = None if pipeline_mode and not route_pipeline_handoff_to_normal: try: pipeline_input = self._pipeline_input_from_context(context, cwd=cwd) except ValueError as exc: raise InvalidParamsError(sanitize_public_text(str(exc))) from exc - prompt = pipeline_input.display_text self._validate_pipeline_request_input(pipeline_input, model=model) else: - prompt = self._prompt_from_context(context, cwd=cwd) + try: + normal_input = self._normal_input_from_context(context, cwd=cwd) + except ValueError as exc: + raise InvalidParamsError(sanitize_public_text(str(exc))) from exc + if normal_input.has_images: + self._validate_pipeline_request_input(normal_input, model=model) if pipeline_mode and requested_task_id is None: recovered_task_id = await self._recoverable_pipeline_task_id_for_context(context_id=context_id, cwd=cwd) if recovered_task_id is not None: @@ -873,7 +885,11 @@ async def publish_initial_task_if_missing() -> None: self._metrics.record_task_failed() return - if not (pipeline_mode and not route_pipeline_handoff_to_normal) and not prompt.strip(): + if ( + not (pipeline_mode and not route_pipeline_handoff_to_normal) + and normal_input is not None + and normal_input.is_empty + ): task.state = TASK_STATE_FAILED await self._publish_status( event_queue, @@ -1041,9 +1057,11 @@ def runtime_factory(session_id: str) -> Any: artifact_store=self._artifact_store, exposure_types=self._thinking_exposure_types, ) + assert normal_input is not None stream = _stream_a2a_normal_events( runtime=runtime, - prompt=prompt, + prompt=normal_input.content, + prompt_text=normal_input.display_text, cleanup_ledger=cleanup_ledger, cleanup_publisher=cleanup_publisher, cwd=cwd, @@ -1256,17 +1274,26 @@ def _prompt_from_context(self, context: RequestContext, *, cwd: str) -> str: return context.get_user_input() return parts_to_prompt(message.parts, cwd=cwd) + def _normal_input_from_context(self, context: RequestContext, *, cwd: str) -> PipelineUserInput: + message = getattr(context, "message", None) + if not isinstance(message, Message): + return normalize_pipeline_user_input(context.get_user_input()) + user_input = parts_to_user_input(message.parts, cwd=cwd) + if user_input.has_images: + return user_input + return normalize_pipeline_user_input(user_input.display_text) + def _pipeline_input_from_context(self, context: RequestContext, *, cwd: str) -> PipelineUserInput: message = getattr(context, "message", None) if not isinstance(message, Message): return normalize_pipeline_user_input(context.get_user_input()) - return parts_to_pipeline_input(message.parts, cwd=cwd) + return parts_to_user_input(message.parts, cwd=cwd) def validate_pipeline_message_request(self, message: Message) -> None: metadata = getattr(message, "metadata", None) try: cwd = self._resolve_cwd(metadata) - pipeline_input = parts_to_pipeline_input(message.parts, cwd=cwd) + pipeline_input = parts_to_user_input(message.parts, cwd=cwd) except ValueError as exc: raise InvalidParamsError(sanitize_public_text(str(exc))) from exc model = self._resolve_model(metadata) or self._model diff --git a/src/iac_code/a2a/parts.py b/src/iac_code/a2a/parts.py index d28fded1..67db6404 100644 --- a/src/iac_code/a2a/parts.py +++ b/src/iac_code/a2a/parts.py @@ -113,10 +113,10 @@ def parts_to_prompt(message_parts: Iterable[Any], *, cwd: str | Path) -> str: return "\n".join(value for value in values if value) -def parts_to_pipeline_input(message_parts: Iterable[Any], *, cwd: str | Path) -> PipelineUserInput: +def parts_to_user_input(message_parts: Iterable[Any], *, cwd: str | Path) -> PipelineUserInput: blocks: list[ContentBlock] = [] for part in message_parts: - converted = part_to_pipeline_block(part, cwd=cwd) + converted = part_to_content_block(part, cwd=cwd) if isinstance(converted, list): blocks.extend(converted) elif converted: @@ -131,7 +131,7 @@ def parts_to_pipeline_input(message_parts: Iterable[Any], *, cwd: str | Path) -> return PipelineUserInput(content=text, display_text=text, has_images=False) -def part_to_pipeline_block(part: Any, *, cwd: str | Path) -> str | list[ContentBlock]: +def part_to_content_block(part: Any, *, cwd: str | Path) -> str | list[ContentBlock]: media_type = _media_type(part) if _has_field(part, "text"): _ensure_text_like(media_type) @@ -140,7 +140,7 @@ def part_to_pipeline_block(part: Any, *, cwd: str | Path) -> str | list[ContentB if media_type in SUPPORTED_IMAGE_MIME_TYPES: return [_image_block_from_binary(_binary_data_part_bytes(part), requested_media_type=media_type)] if _is_multimodal(media_type): - raise ValueError("A2A pipeline input has unsupported image media type.") + raise ValueError("A2A input has unsupported image media type.") if media_type != "application/json": raise ValueError("A2A data parts must use application/json media type.") data = MessageToDict(part.data, preserving_proto_field_name=False) @@ -153,7 +153,7 @@ def part_to_pipeline_block(part: Any, *, cwd: str | Path) -> str | list[ContentB _ensure_size(raw, limit=MAX_BINARY_INLINE_BYTES, label="A2A binary raw part") return [_image_block_from_binary(raw, requested_media_type=media_type)] if _is_multimodal(media_type): - raise ValueError("A2A pipeline input has unsupported image media type.") + raise ValueError("A2A input has unsupported image media type.") _ensure_text_like(media_type) _ensure_size(raw, limit=MAX_INLINE_BYTES, label="A2A raw part") try: @@ -167,7 +167,7 @@ def part_to_pipeline_block(part: Any, *, cwd: str | Path) -> str | list[ContentB raise ValueError("A2A binary file URL part content is too large.") return [_image_block_from_binary(path.read_bytes(), requested_media_type=media_type)] if _is_multimodal(media_type): - raise ValueError("A2A pipeline input has unsupported image media type.") + raise ValueError("A2A input has unsupported image media type.") _ensure_text_like(media_type) return _read_file_url_part(str(part.url), cwd=Path(cwd)) raise ValueError("A2A server supports text, JSON data, raw text, or workspace file URL parts only.") @@ -283,7 +283,7 @@ def _binary_data_part_bytes(part: Any) -> bytes: def _image_block_from_binary(raw: bytes, *, requested_media_type: str) -> ImageBlock: if requested_media_type not in SUPPORTED_IMAGE_MIME_TYPES: - raise ValueError("A2A pipeline input has unsupported image media type.") + raise ValueError("A2A input has unsupported image media type.") resized = maybe_resize_and_downsample(raw) return ImageBlock( media_type=resized.media_type, diff --git a/tests/a2a/fakes.py b/tests/a2a/fakes.py index fe0295be..7e4b61d3 100644 --- a/tests/a2a/fakes.py +++ b/tests/a2a/fakes.py @@ -40,9 +40,9 @@ def pending_future() -> asyncio.Future[bool]: class FakeAgentLoop: def __init__(self, events: list[Any]) -> None: self.events = events - self.prompts: list[str] = [] + self.prompts: list[Any] = [] - async def run_streaming(self, prompt: str): + async def run_streaming(self, prompt: Any): self.prompts.append(prompt) for event in self.events: await asyncio.sleep(0) diff --git a/tests/a2a/test_executor.py b/tests/a2a/test_executor.py index 4df909f7..ad9d5c13 100644 --- a/tests/a2a/test_executor.py +++ b/tests/a2a/test_executor.py @@ -1,6 +1,8 @@ import asyncio +import base64 import logging from pathlib import Path +from types import SimpleNamespace import pytest from a2a.types import TaskStatusUpdateEvent @@ -14,7 +16,7 @@ from iac_code.a2a.pipeline_journal import A2APipelineJournal from iac_code.a2a.pipeline_paths import a2a_pipeline_dir_for_session from iac_code.a2a.task_store import A2ATaskStore -from iac_code.agent.message import ImageBlock +from iac_code.agent.message import ImageBlock, TextBlock from iac_code.pipeline.engine.user_input import PipelineUserInput from iac_code.types.stream_events import PermissionRequestEvent, TextDeltaEvent, ToolResultEvent @@ -706,6 +708,44 @@ async def test_executor_runs_normal_mode_when_iac_code_mode_is_normal( assert dumped["status"]["state"] == "TASK_STATE_INPUT_REQUIRED" +@pytest.mark.asyncio +async def test_normal_mode_image_request_passes_image_blocks_to_agent_loop( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + from a2a.types import Message, Part, Role + + monkeypatch.setenv("IAC_CODE_MODE", "normal") + monkeypatch.setattr( + "iac_code.a2a.parts.maybe_resize_and_downsample", + lambda raw: SimpleNamespace(data=b"resized-image", media_type="image/webp"), + ) + loop = FakeAgentLoop([TextDeltaEvent(text="normal")]) + runtime = FakeRuntime(agent_loop=loop, session_id="session-1") + monkeypatch.setattr("iac_code.a2a.executor.create_agent_runtime", lambda options: runtime) + + context = FakeRequestContext(metadata={"iac_code": {"cwd": str(tmp_path)}}) + context.message = Message( + role=Role.ROLE_USER, + parts=[ + Part(text="请识别附件架构图", media_type="text/plain"), + Part(raw=b"fake-image", media_type="image/png", filename="diagram.png"), + ], + message_id="msg-normal-image", + ) + + store = A2ATaskStore(metrics=NoOpA2AMetrics()) + executor = IacCodeA2AExecutor(task_store=store, model="qwen3.6-plus") + await executor.execute(context, FakeEventQueue()) + + assert loop.prompts == [ + [ + TextBlock(text="请识别附件架构图"), + ImageBlock(media_type="image/webp", data=base64.b64encode(b"resized-image").decode("ascii")), + ] + ] + + @pytest.mark.asyncio async def test_cancel_bypasses_context_lock(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: started = asyncio.Event() @@ -1026,7 +1066,7 @@ async def execute(self, **kwargs) -> None: @pytest.mark.asyncio -async def test_pipeline_handoff_image_request_uses_normal_manifest_prompt( +async def test_pipeline_handoff_image_request_passes_image_blocks_to_normal_agent( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: @@ -1038,6 +1078,10 @@ async def test_pipeline_handoff_image_request_uses_normal_manifest_prompt( from iac_code.services.session_storage import SessionStorage monkeypatch.setenv("IAC_CODE_MODE", "pipeline") + monkeypatch.setattr( + "iac_code.a2a.parts.maybe_resize_and_downsample", + lambda raw: SimpleNamespace(data=b"resized-handoff-image", media_type="image/png"), + ) config_dir = tmp_path / "config" config_dir.mkdir() monkeypatch.setenv("IAC_CODE_CONFIG_DIR", str(config_dir)) @@ -1093,10 +1137,9 @@ async def execute(self, **kwargs) -> None: FakeEventQueue(), ) - assert loop.prompts - assert "A2A multimodal attachment:" in loop.prompts[0] - assert "mediaType=image/png" in loop.prompts[0] - assert "[Image input]" not in loop.prompts[0] + assert loop.prompts == [ + [ImageBlock(media_type="image/png", data=base64.b64encode(b"resized-handoff-image").decode("ascii"))] + ] @pytest.mark.asyncio diff --git a/tests/a2a/test_parts.py b/tests/a2a/test_parts.py index 1d8deeab..ecd6f521 100644 --- a/tests/a2a/test_parts.py +++ b/tests/a2a/test_parts.py @@ -10,7 +10,7 @@ from PIL import Image from iac_code.a2a import parts -from iac_code.a2a.parts import parts_to_pipeline_input +from iac_code.a2a.parts import parts_to_user_input from iac_code.agent.message import ImageBlock, TextBlock @@ -222,12 +222,12 @@ def _tiny_png_bytes() -> bytes: return buf.getvalue() -def test_parts_to_pipeline_input_converts_raw_image(monkeypatch, tmp_path) -> None: +def test_parts_to_user_input_converts_raw_image(monkeypatch, tmp_path) -> None: raw = b"fake png bytes" resized = b"resized raw image" calls = _resize_spy(monkeypatch, output=resized, media_type="image/webp") - value = parts_to_pipeline_input([Part(raw=raw, media_type="image/png")], cwd=tmp_path) + value = parts_to_user_input([Part(raw=raw, media_type="image/png")], cwd=tmp_path) assert calls == [raw] assert value.has_images is True @@ -235,12 +235,12 @@ def test_parts_to_pipeline_input_converts_raw_image(monkeypatch, tmp_path) -> No assert value.content == [ImageBlock(media_type="image/webp", data=base64.b64encode(resized).decode("ascii"))] -def test_parts_to_pipeline_input_preserves_text_plus_image_order(monkeypatch, tmp_path) -> None: +def test_parts_to_user_input_preserves_text_plus_image_order(monkeypatch, tmp_path) -> None: raw = b"fake jpeg bytes" resized = b"resized jpeg bytes" calls = _resize_spy(monkeypatch, output=resized, media_type="image/jpeg") - value = parts_to_pipeline_input( + value = parts_to_user_input( [ Part(text="inspect this", media_type="text/plain"), Part(raw=raw, media_type="image/jpeg"), @@ -256,35 +256,35 @@ def test_parts_to_pipeline_input_preserves_text_plus_image_order(monkeypatch, tm ] -def test_parts_to_pipeline_input_converts_base64_data_image(monkeypatch, tmp_path) -> None: +def test_parts_to_user_input_converts_base64_data_image(monkeypatch, tmp_path) -> None: raw = b"fake data image" resized = b"resized data image" encoded = base64.b64encode(raw).decode("ascii") calls = _resize_spy(monkeypatch, output=resized, media_type="image/png") - value = parts_to_pipeline_input([_binary_data_part({"bytes": encoded}, media_type="image/png")], cwd=tmp_path) + value = parts_to_user_input([_binary_data_part({"bytes": encoded}, media_type="image/png")], cwd=tmp_path) assert calls == [raw] assert value.content == [ImageBlock(media_type="image/png", data=base64.b64encode(resized).decode("ascii"))] -def test_parts_to_pipeline_input_converts_safe_file_url_image(monkeypatch, tmp_path) -> None: +def test_parts_to_user_input_converts_safe_file_url_image(monkeypatch, tmp_path) -> None: raw = b"file image bytes" resized = b"resized file image" source = tmp_path / "diagram.png" source.write_bytes(raw) calls = _resize_spy(monkeypatch, output=resized, media_type="image/png") - value = parts_to_pipeline_input([Part(url=source.as_uri(), media_type="image/png")], cwd=tmp_path) + value = parts_to_user_input([Part(url=source.as_uri(), media_type="image/png")], cwd=tmp_path) assert calls == [raw] assert value.content == [ImageBlock(media_type="image/png", data=base64.b64encode(resized).decode("ascii"))] -def test_parts_to_pipeline_input_uses_real_resizer_for_valid_image_bytes(tmp_path) -> None: +def test_parts_to_user_input_uses_real_resizer_for_valid_image_bytes(tmp_path) -> None: raw = _tiny_bmp_bytes() - value = parts_to_pipeline_input([Part(raw=raw, media_type="image/png")], cwd=tmp_path) + value = parts_to_user_input([Part(raw=raw, media_type="image/png")], cwd=tmp_path) assert isinstance(value.content, list) block = value.content[0] @@ -293,10 +293,10 @@ def test_parts_to_pipeline_input_uses_real_resizer_for_valid_image_bytes(tmp_pat assert base64.b64decode(block.data).startswith(b"\x89PNG\r\n\x1a\n") -def test_parts_to_pipeline_input_accepts_tiny_png_without_monkeypatch(tmp_path) -> None: +def test_parts_to_user_input_accepts_tiny_png_without_monkeypatch(tmp_path) -> None: raw = _tiny_png_bytes() - value = parts_to_pipeline_input([Part(raw=raw, media_type="image/png")], cwd=tmp_path) + value = parts_to_user_input([Part(raw=raw, media_type="image/png")], cwd=tmp_path) assert isinstance(value.content, list) block = value.content[0] @@ -305,26 +305,26 @@ def test_parts_to_pipeline_input_accepts_tiny_png_without_monkeypatch(tmp_path) assert base64.b64decode(block.data).startswith(b"\x89PNG\r\n\x1a\n") -def test_parts_to_pipeline_input_rejects_unsafe_file_url_image(tmp_path) -> None: +def test_parts_to_user_input_rejects_unsafe_file_url_image(tmp_path) -> None: outside = tmp_path.parent / "outside-diagram.png" outside.write_bytes(b"outside") with pytest.raises(ValueError, match="outside the allowed workspace"): - parts_to_pipeline_input([Part(url=outside.as_uri(), media_type="image/png")], cwd=tmp_path) + parts_to_user_input([Part(url=outside.as_uri(), media_type="image/png")], cwd=tmp_path) -def test_parts_to_pipeline_input_rejects_invalid_base64_data_image(tmp_path) -> None: +def test_parts_to_user_input_rejects_invalid_base64_data_image(tmp_path) -> None: with pytest.raises(ValueError, match="valid base64"): - parts_to_pipeline_input([_binary_data_part({"bytes": "not-base64!"}, media_type="image/png")], cwd=tmp_path) + parts_to_user_input([_binary_data_part({"bytes": "not-base64!"}, media_type="image/png")], cwd=tmp_path) -def test_parts_to_pipeline_input_rejects_oversized_raw_image(monkeypatch, tmp_path) -> None: +def test_parts_to_user_input_rejects_oversized_raw_image(monkeypatch, tmp_path) -> None: monkeypatch.setattr("iac_code.a2a.parts.MAX_BINARY_INLINE_BYTES", 3) with pytest.raises(ValueError, match="too large"): - parts_to_pipeline_input([Part(raw=b"abcd", media_type="image/png")], cwd=tmp_path) + parts_to_user_input([Part(raw=b"abcd", media_type="image/png")], cwd=tmp_path) -def test_parts_to_pipeline_input_rejects_audio_as_true_image(tmp_path) -> None: +def test_parts_to_user_input_rejects_audio_as_true_image(tmp_path) -> None: with pytest.raises(ValueError, match="unsupported image media type"): - parts_to_pipeline_input([Part(raw=b"audio", media_type="audio/wav")], cwd=tmp_path) + parts_to_user_input([Part(raw=b"audio", media_type="audio/wav")], cwd=tmp_path)