Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions src/iac_code/a2a/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,15 @@ async def send_message(
*,
cwd: str,
context_id: str | None = None,
model: str | None = None,
) -> A2AClientResponse:
payload = self._message_payload(method="SendMessage", prompt=prompt, cwd=cwd, context_id=context_id)
payload = self._message_payload(
method="SendMessage",
prompt=prompt,
cwd=cwd,
context_id=context_id,
model=model,
)
transport = self._make_transport_client(url)
response = await transport.send(payload)
return A2AClientResponse(payload=response)
Expand All @@ -144,8 +151,15 @@ async def stream_message(
*,
cwd: str,
context_id: str | None = None,
model: str | None = None,
) -> AsyncIterator[dict[str, Any]]:
payload = self._message_payload(method="SendStreamingMessage", prompt=prompt, cwd=cwd, context_id=context_id)
payload = self._message_payload(
method="SendStreamingMessage",
prompt=prompt,
cwd=cwd,
context_id=context_id,
model=model,
)
transport = self._make_transport_client(url)
async for event in transport.stream(payload):
yield event
Expand Down Expand Up @@ -332,12 +346,25 @@ def _transport_options(self, binding: A2ATransportBinding) -> TransportClientOpt
api_key_header=auth.api_key_header,
)

def _message_payload(self, *, method: str, prompt: str, cwd: str, context_id: str | None) -> dict[str, Any]:
def _message_payload(
self,
*,
method: str,
prompt: str,
cwd: str,
context_id: str | None,
model: str | None,
) -> dict[str, Any]:
iac_code_metadata = {"cwd": cwd}
if model:
stripped_model = model.strip()
if stripped_model:
iac_code_metadata["iac_code_model"] = stripped_model
message: dict[str, Any] = {
"messageId": str(uuid.uuid4()),
"role": "ROLE_USER",
"parts": [{"text": prompt}],
"metadata": {"iac_code": {"cwd": cwd}},
"metadata": {"iac_code": iac_code_metadata},
}
if context_id:
message["contextId"] = context_id
Expand Down
124 changes: 109 additions & 15 deletions src/iac_code/a2a/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from iac_code.a2a.events import make_text_part, publish_stream_event
from iac_code.a2a.exposure import normalize_a2a_exposure_types
from iac_code.a2a.metrics import A2AMetrics, NoOpA2AMetrics
from iac_code.a2a.parts import allowed_cwd_roots, is_relative_to, parts_to_prompt
from iac_code.a2a.parts import allowed_cwd_roots, is_relative_to, parts_to_prompt, resolve_workspace_path
from iac_code.a2a.pipeline_executor import IacCodeA2APipelineExecutor, recoverable_task_id_from_sidecar
from iac_code.a2a.pipeline_paths import existing_a2a_pipeline_dir_for_session
from iac_code.a2a.pipeline_snapshot import A2APipelineSnapshotStore
Expand All @@ -33,6 +33,7 @@
from iac_code.i18n import _
from iac_code.pipeline.config import RunMode, get_run_mode
from iac_code.services.agent_factory import AgentFactoryOptions, create_agent_runtime
from iac_code.services.providers.aliyun import DEFAULT_REGION, AliyunCredential, use_aliyun_credential
from iac_code.services.session_storage import SessionStorage
from iac_code.services.telemetry import use_session_id, use_user_id
from iac_code.utils.public_errors import public_exception_summary, sanitize_public_text
Expand Down Expand Up @@ -90,6 +91,9 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
)
cwd = self._resolve_cwd(metadata)
user_id = self._resolve_user_id(metadata)
metadata_model = self._resolve_model(metadata)
model = metadata_model or self._model
aliyun_credential = self._resolve_aliyun_credential(metadata)
prompt = self._prompt_from_context(context, cwd=cwd)
pipeline_mode = get_run_mode() == RunMode.PIPELINE
if pipeline_mode and requested_task_id is None:
Expand Down Expand Up @@ -121,6 +125,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
await self._notify_terminal_task(task_id=task.task_id, context_id=task.context_id, state=task.state)
self._metrics.record_executor_error()
return
self._log_executor_exception("setup", task_id=task_id, context_id=context_id)
await self._publish_status(
event_queue,
task_id=task_id,
Expand Down Expand Up @@ -156,7 +161,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non
if pipeline_mode and not route_pipeline_handoff_to_normal:
pipeline_executor = IacCodeA2APipelineExecutor(
task_store=self._task_store,
model=self._model,
model=model,
metrics=self._metrics,
artifact_store=self._artifact_store,
push_notifier=self._push_notifier,
Expand Down Expand Up @@ -185,23 +190,28 @@ def runtime_factory(session_id: str) -> Any:
resume_messages = SessionStorage.repair_interrupted(loaded) if loaded else None
return create_agent_runtime(
AgentFactoryOptions(
model=self._model,
model=model,
session_id=session_id,
cwd=cwd,
resume_messages=resume_messages,
)
)

try:
ctx = await self._task_store.get_or_create_context(
context_id=context_id,
cwd=cwd,
runtime_factory=runtime_factory,
aliyun_credential_ctx = (
use_aliyun_credential(aliyun_credential) if aliyun_credential else contextlib.nullcontext()
)
if not hasattr(ctx.runtime, "agent_loop"):
ctx.runtime = runtime_factory(ctx.session_id)
self._task_store.mirror_context(ctx)
with aliyun_credential_ctx:
ctx = await self._task_store.get_or_create_context(
context_id=context_id,
cwd=cwd,
runtime_factory=runtime_factory,
)
if not hasattr(ctx.runtime, "agent_loop"):
ctx.runtime = runtime_factory(ctx.session_id)
self._task_store.mirror_context(ctx)
except Exception as exc:
self._log_executor_exception("runtime setup", task_id=task_id, context_id=context_id)
await self._publish_status(
event_queue,
task_id=task_id,
Expand Down Expand Up @@ -272,7 +282,12 @@ def runtime_factory(session_id: str) -> Any:
state=TaskState.TASK_STATE_WORKING,
)
user_id_ctx = use_user_id(user_id) if user_id else contextlib.nullcontext()
with use_session_id(ctx.session_id), user_id_ctx:
aliyun_credential_ctx = (
use_aliyun_credential(aliyun_credential) if aliyun_credential else contextlib.nullcontext()
)
with use_session_id(ctx.session_id), user_id_ctx, aliyun_credential_ctx:
self._configure_runtime_model(runtime, model, from_metadata=metadata_model is not None)
self._refresh_runtime_cloud_tools(runtime)
async for event in runtime.agent_loop.run_streaming(prompt):
text_chunk = await publish_stream_event(
event_queue,
Expand Down Expand Up @@ -323,6 +338,7 @@ def runtime_factory(session_id: str) -> Any:
self._metrics.record_executor_error()
else:
task.state = TASK_STATE_FAILED
self._log_executor_exception("streaming", task_id=task_id, context_id=context_id)
await self._publish_status(
event_queue,
task_id=task_id,
Expand Down Expand Up @@ -377,26 +393,29 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None
)

def _resolve_cwd(self, metadata: Any | None) -> str:
cwd = os.getcwd()
if metadata is not None and hasattr(metadata, "DESCRIPTOR"):
metadata = MessageToDict(metadata, preserving_proto_field_name=False)
cwd: str | None = None
if metadata:
raw_iac_meta = metadata.get("iac_code") if isinstance(metadata, Mapping) else None
if isinstance(raw_iac_meta, Mapping):
raw_cwd = raw_iac_meta.get("cwd")
if isinstance(raw_cwd, str):
cwd = raw_cwd
if cwd is None:
cwd = os.getcwd()
if not isinstance(cwd, str) or not Path(cwd).is_absolute():
raise ValueError("Invalid A2A workspace metadata.")
resolved_cwd = Path(cwd).resolve()
logical_cwd = os.path.normpath(cwd)
resolved_cwd = resolve_workspace_path(Path(logical_cwd))
if not any(_is_relative_to(resolved_cwd, root) for root in _allowed_cwd_roots()):
raise ValueError("Invalid A2A workspace metadata.")
if resolved_cwd.exists():
if not resolved_cwd.is_dir():
raise ValueError("Invalid A2A workspace metadata.")
else:
resolved_cwd.mkdir(parents=True, exist_ok=True)
return str(resolved_cwd)
return logical_cwd

def _resolve_user_id(self, metadata: Any | None) -> str | None:
if metadata is not None and hasattr(metadata, "DESCRIPTOR"):
Expand All @@ -411,6 +430,47 @@ def _resolve_user_id(self, metadata: Any | None) -> str | None:
return raw_user_id.strip()
return None

def _resolve_model(self, metadata: Any | None) -> str | None:
if metadata is not None and hasattr(metadata, "DESCRIPTOR"):
metadata = MessageToDict(metadata, preserving_proto_field_name=False)
if not isinstance(metadata, Mapping):
return None
raw_iac_meta = metadata.get("iac_code")
if not isinstance(raw_iac_meta, Mapping):
return None
raw_model = raw_iac_meta.get("iac_code_model")
if isinstance(raw_model, str) and raw_model.strip():
return raw_model.strip()
return None

def _resolve_aliyun_credential(self, metadata: Any | None) -> AliyunCredential | None:
if metadata is not None and hasattr(metadata, "DESCRIPTOR"):
metadata = MessageToDict(metadata, preserving_proto_field_name=False)
if not isinstance(metadata, Mapping):
return None
raw_iac_meta = metadata.get("iac_code")
if not isinstance(raw_iac_meta, Mapping):
return None

def _read(name: str) -> str | None:
raw_value = raw_iac_meta.get(name)
if isinstance(raw_value, str) and raw_value.strip():
return raw_value.strip()
return None

access_key_id = _read("alibaba_cloud_access_key_id")
access_key_secret = _read("alibaba_cloud_access_key_secret")
if not access_key_id or not access_key_secret:
return None
sts_token = _read("alibaba_cloud_security_token") or ""
return AliyunCredential(
mode="StsToken" if sts_token else "AK",
access_key_id=access_key_id,
access_key_secret=access_key_secret,
region_id=_read("alibaba_cloud_region_id") or DEFAULT_REGION,
sts_token=sts_token,
)

def _prompt_from_context(self, context: RequestContext, *, cwd: str) -> str:
message = getattr(context, "message", None)
if not isinstance(message, Message):
Expand All @@ -427,7 +487,6 @@ def _sanitize_error(self, exc: Exception) -> str:
status = getattr(exc, "status_code", None) or getattr(exc, "status", None)
if status == 401:
return "Authentication required. Please configure your API credentials."
logger.exception("Unhandled A2A executor error")
return _format_exception(exc)

async def _should_route_pipeline_handoff_to_normal(self, *, context_id: str, cwd: str) -> bool:
Expand Down Expand Up @@ -488,6 +547,9 @@ async def _recoverable_pipeline_task_id_for_context(self, *, context_id: str, cw
logger.debug("Failed to recover A2A pipeline task id", exc_info=True)
return None

def _log_executor_exception(self, stage: str, *, task_id: str, context_id: str) -> None:
logger.exception("A2A executor %s failed (task_id=%s, context_id=%s)", stage, task_id, context_id)

async def _publish_status(
self,
event_queue: EventQueue,
Expand Down Expand Up @@ -528,6 +590,38 @@ async def _publish_initial_task(
task.history.append(message)
await event_queue.enqueue_event(task)

def _refresh_runtime_cloud_tools(self, runtime: Any) -> None:
refresh_cloud_tools = getattr(runtime, "refresh_cloud_tools", None)
if callable(refresh_cloud_tools):
refresh_cloud_tools()
return
tool_registry = getattr(runtime, "tool_registry", None)
if tool_registry is None:
return
from iac_code.services.cloud_credentials import CloudCredentials
from iac_code.tools.cloud.registry import register_cloud_tools

register_cloud_tools(tool_registry, CloudCredentials())

def _configure_runtime_model(self, runtime: Any, model: str, *, from_metadata: bool) -> None:
provider_manager = getattr(runtime, "provider_manager", None)
reconfigure = getattr(provider_manager, "reconfigure", None)
if not callable(reconfigure):
return
was_metadata_model = bool(getattr(runtime, "_iac_code_a2a_metadata_model_applied", False))
if not from_metadata and not was_metadata_model:
return

from iac_code.config import load_credentials

provider_key_override = getattr(provider_manager, "_provider_key_override", None)
base_url_override = getattr(provider_manager, "_base_url_override", None)
credentials = getattr(provider_manager, "_credentials", None)
if not isinstance(credentials, dict) or provider_key_override is None:
credentials = load_credentials(model=model)
reconfigure(model, credentials, provider_key_override, base_url_override)
setattr(runtime, "_iac_code_a2a_metadata_model_applied", from_metadata)

async def _notify_terminal_task(self, *, task_id: str, context_id: str, state: str) -> None:
if self._push_notifier is None:
return
Expand Down
24 changes: 23 additions & 1 deletion src/iac_code/a2a/parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def allowed_cwd_roots() -> list[Path]:
candidates = [Path(item) for item in raw.split(os.pathsep) if item]
else:
candidates = [Path.cwd(), Path(tempfile.gettempdir())]
return [path.resolve() for path in candidates if path.exists() and path.is_dir()]
return [resolve_workspace_path(path) for path in candidates if path.exists() and path.is_dir()]


def is_relative_to(path: Path, root: Path) -> bool:
Expand All @@ -81,6 +81,28 @@ def is_relative_to(path: Path, root: Path) -> bool:
return True


def resolve_workspace_path(path: Path) -> Path:
try:
return path.resolve()
except FileNotFoundError:
if not path.is_absolute() or _has_symlink_component(path):
raise
return path.absolute()


def _has_symlink_component(path: Path) -> bool:
current = Path(path.anchor) if path.anchor else Path()
parts = path.parts[1:] if path.anchor else path.parts
for part in parts:
current /= part
try:
if current.is_symlink():
return True
except OSError:
return False
return False


def parts_to_prompt(message_parts: Iterable[Any], *, cwd: str | Path) -> str:
values = [part_to_prompt(part, cwd=cwd) for part in message_parts]
return "\n".join(value for value in values if value)
Expand Down
Loading
Loading