def create_coverage_auditor(llm: LLM) -> Agent:
    """Build the sub-agent that the generated workflow fans out per area.

    The agent is given the terminal and file-editor tools, one
    ``coverage_audit`` skill (created with ``trigger=None``), and a system
    message suffix asking for a concise, non-modifying assessment.

    Args:
        llm: The language model the sub-agent should use.

    Returns:
        A configured ``Agent`` instance.
    """
    audit_skill = Skill(
        name="coverage_audit",
        content=(
            "You audit whether source code has meaningful test "
            "coverage. Use read-only inspection commands and file "
            "views. Compare source modules against the matching "
            "tests under tests/sdk, tests/tools, tests/workspace, "
            "or tests/agent_server. Identify risky untested "
            "behavior, and recommend the "
            "next tests to add. Use at most three tool calls, "
            "avoid broad dumps, and do not edit files."
        ),
        trigger=None,
    )
    auditor_context = AgentContext(
        skills=[audit_skill],
        system_message_suffix=(
            "Return a concise coverage assessment with evidence, gaps, "
            "and recommended tests. Keep command output under 200 lines "
            "and do not modify the repository."
        ),
    )
    # The prompt asks for read-only use, but the tools themselves are the
    # ordinary terminal and file-editor tools.
    auditor_tools = [
        Tool(name=TerminalTool.name),
        Tool(name=FileEditorTool.name),
    ]
    return Agent(llm=llm, tools=auditor_tools, agent_context=auditor_context)
def attach_parent(self, conversation: LocalConversation) -> None:
    """Register ``conversation`` as the parent used to create sub-agent tasks.

    Safe to call repeatedly. When a *different* conversation is already
    attached, a warning is logged so that two subsystems competing to
    register a parent are noticed; the call is then still forwarded to
    ``_ensure_parent``, which only stores the conversation when none is
    attached yet.

    Args:
        conversation: The parent conversation to attach.
    """
    current = self._parent_conversation
    conflicts_with_existing = current is not None and current is not conversation
    if conflicts_with_existing:
        logger.warning(
            "attach_parent called with a different conversation; ignoring."
        )
    # Always delegate: _ensure_parent owns the actual attach logic.
    self._ensure_parent(conversation)
class WorkflowAction(Action):
    """Schema for running a Python dynamic workflow script.

    Instances carry everything the workflow executor needs: a label for the
    run, the script source, and the concurrency budget for sub-agent tasks.
    """

    # Label for this run; the executor copies it onto the resulting observation.
    name: str = Field(description="A short name for this workflow run.")
    # Source text of the workflow. It is validated and executed by the
    # workflow executor, which requires a single `async def main(wf):`.
    script: str = Field(
        description=(
            "Python workflow script to run. It must define `async def main(wf):` "
            "and coordinate work only through the provided `wf` object."
        )
    )
    # Context-wide cap on concurrent sub-agent tasks; constrained to 1..64
    # by the `ge`/`le` bounds below.
    max_concurrency: int = Field(
        default=8,
        ge=1,
        le=64,
        description=(
            "Maximum number of sub-agent tasks to run concurrently. "
            "Consider 2–4 for LLM-heavy workflows to avoid hitting API rate limits."
        ),
    )
+ +Available `wf` methods: +- `await wf.run_agent(prompt, subagent_type="general-purpose", description=None)` +- `await wf.map_agents(items, prompt, subagent_type="general-purpose", + max_concurrency=None, description=None)` +- `await wf.reduce_agent(items, prompt, subagent_type="general-purpose", + description=None)` +- `wf.flatten(values)` — flatten one level of nesting (not recursive) + +`subagent_type` must be a sub-agent type registered in the parent application. +Use the same type names you registered when building your agent. + +Scripts must use only the documented `wf` methods listed above; calling +`wf.close()` or any other undocumented attribute is not supported. + +`print()` is available for debugging but writes to the server logs, not to +the workflow observation seen by the LLM; use the return value of `main()` to +surface results. + +If one or more `map_agents` items fail, the whole call raises an +`ExceptionGroup`. The name `ExceptionGroup` is not available by name in the +workflow sandbox, so scripts cannot use `except*` for selective group handling. +A plain `except Exception` will still catch the entire group. To handle partial +failures and collect all results, design sub-agent prompts to return an error +sentinel value instead of raising. + +`map_agents` accepts either a callable prompt, such as +`lambda item: f"Review this finding: {item}"`, or a string template containing +`{item}`. 
+ +Example: +```python +async def main(wf): + strategies = ["minimal fix", "test-first", "security-focused"] + plans = await wf.map_agents( + items=strategies, + subagent_type="general-purpose", + max_concurrency=3, + prompt=lambda strategy: f"Create a plan using this strategy: {strategy}", + ) + critiques = await wf.map_agents( + items=plans, + subagent_type="code-reviewer", + prompt=lambda plan: f"Adversarially critique this plan: {plan}", + ) + return await wf.reduce_agent( + items={"plans": plans, "critiques": critiques}, + prompt="Synthesize the safest and simplest final plan.", + ) +``` + +This MVP executes generated Python in-process after best-effort validation. Treat +running a workflow as approving generated code execution. +""" + + +class WorkflowTool(ToolDefinition[WorkflowAction, WorkflowObservation]): + """Low-level tool for explicit executor injection. + + Prefer ``WorkflowToolSet`` for standard SDK auto-create usage. + Use ``WorkflowTool`` when you need to inject a custom executor + (e.g., in tests or extensions). 
class WorkflowToolSet(ToolDefinition[WorkflowAction, WorkflowObservation]):
    """Registry entry point that produces the dynamic workflow tool.

    ``create`` builds a fresh ``WorkflowExecutor`` and hands it to
    ``WorkflowTool.create``, so callers receive ready-to-use
    ``WorkflowTool`` instances.
    """

    @classmethod
    def create(
        cls,
        conv_state: "ConversationState",  # noqa: ARG003
    ) -> Sequence[WorkflowTool]:
        """Return the workflow tool(s) backed by a new executor.

        Args:
            conv_state: Accepted for the tool-factory signature; not used here.
        """
        # Imported lazily: the impl module itself imports from this module,
        # so a top-level import would be circular.
        from openhands.tools.workflow.impl import WorkflowExecutor

        fresh_executor = WorkflowExecutor()
        return WorkflowTool.create(executor=fresh_executor)
class _TaskStarter(Protocol):
    """Structural interface for the task manager used by ``WorkflowContext``.

    Anything exposing these two methods can be injected as ``manager``
    (the real ``TaskManager`` or a test double).
    """

    # Start one sub-agent task and return an object exposing ``result`` and
    # ``error``. WorkflowContext invokes this through ``asyncio.to_thread``,
    # so implementations are called from worker threads.
    def start_task(
        self,
        prompt: str,
        subagent_type: str = "default",
        resume: str | None = None,
        description: str | None = None,
        conversation: LocalConversation | None = None,
    ) -> _TaskLike: ...

    # Release whatever the manager holds; called once when the owning
    # WorkflowContext is closed.
    def close(self) -> None: ...
class WorkflowContext:
    """Capability object passed to generated workflow scripts as ``wf``.

    It exposes a small public surface (``run_agent``, ``map_agents``,
    ``reduce_agent``, ``flatten``) and keeps the parent conversation, the
    task manager and the concurrency limiter in private attributes.
    """

    def __init__(
        self,
        parent_conversation: LocalConversation,
        max_concurrency: int,
        manager: _TaskStarter | None = None,
    ) -> None:
        """Create a context bound to ``parent_conversation``.

        Args:
            parent_conversation: Conversation passed along to every task.
            max_concurrency: Context-wide cap on concurrent tasks (>= 1).
            manager: Optional task manager; when omitted a ``TaskManager``
                is created and attached to the parent conversation.

        Raises:
            ValueError: If ``max_concurrency`` is below 1.
        """
        if max_concurrency < 1:
            raise ValueError("max_concurrency must be at least 1")
        self._parent_conversation = parent_conversation
        self._max_concurrency = max_concurrency
        self._manager: _TaskStarter
        if manager is not None:
            self._manager = manager
        else:
            owned_manager = TaskManager()
            owned_manager.attach_parent(parent_conversation)
            self._manager = owned_manager
        # Created lazily on first use rather than at construction time.
        self._semaphore: asyncio.Semaphore | None = None
        self._closed = False

    @property
    def _default_semaphore(self) -> asyncio.Semaphore:
        """Context-wide semaphore sized by ``max_concurrency`` (lazy)."""
        limiter = self._semaphore
        if limiter is None:
            limiter = asyncio.Semaphore(self._max_concurrency)
            self._semaphore = limiter
        return limiter

    async def run_agent(
        self,
        prompt: str,
        subagent_type: str = "general-purpose",
        description: str | None = None,
    ) -> str:
        """Run one sub-agent task under the context-wide limit.

        Returns:
            The task's final result, or ``""`` when it produced none.
        """
        async with self._default_semaphore:
            outcome = await self._run_agent_task(
                prompt=prompt,
                subagent_type=subagent_type,
                description=description,
            )
        return outcome

    async def _run_agent_task(
        self,
        prompt: str,
        subagent_type: str,
        description: str | None,
    ) -> str:
        """Start a fresh task on a worker thread and unwrap its outcome.

        Raises:
            WorkflowScriptError: If the context has already been closed.
            RuntimeError: If the task reports an error.
        """
        # `start_task` also accepts `resume`, but workflow tasks are always
        # started fresh, so it is deliberately not passed here.
        if self._closed:
            raise WorkflowScriptError("WorkflowContext is already closed")
        started = await asyncio.to_thread(
            self._manager.start_task,
            prompt=prompt,
            subagent_type=subagent_type,
            description=description,
            conversation=self._parent_conversation,
        )
        if started.error:
            raise RuntimeError(started.error)
        return started.result or ""

    async def map_agents(
        self,
        items: Sequence[Any],
        prompt: Callable[[Any], str] | str,
        subagent_type: str = "general-purpose",
        max_concurrency: int | None = None,
        description: Callable[[Any], str] | str | None = None,
    ) -> list[str]:
        """Run one sub-agent per item; results come back in item order.

        A per-call ``max_concurrency`` gets its own semaphore for this call
        only, sized to the smaller of the requested value and the context
        limit. Without it, the context-wide semaphore is used.

        Raises:
            ValueError: If ``max_concurrency`` is given and below 1.
            ExceptionGroup: If any item failed; it holds every failure.
        """
        if max_concurrency is not None and max_concurrency < 1:
            raise ValueError("max_concurrency must be at least 1")
        if max_concurrency is None:
            limiter = self._default_semaphore
        else:
            limiter = asyncio.Semaphore(min(max_concurrency, self._max_concurrency))

        async def run_one(position: int, item: Any) -> str:
            # Rendering happens before waiting for a slot; only errors from
            # the task itself are re-labelled with the 1-based item number.
            item_prompt = _render_required_template(prompt, item)
            item_description = _render_template(description, item)
            async with limiter:
                try:
                    return await self._run_agent_task(
                        prompt=item_prompt,
                        subagent_type=subagent_type,
                        description=item_description,
                    )
                except Exception as exc:
                    raise RuntimeError(f"[item {position}] {exc}") from exc

        pending = [run_one(number, item) for number, item in enumerate(items, start=1)]
        outcomes = await asyncio.gather(*pending, return_exceptions=True)

        errors: list[Exception] = []
        for outcome in outcomes:
            if not isinstance(outcome, BaseException):
                continue
            # ExceptionGroup only accepts Exception instances, so anything
            # else is wrapped.
            if isinstance(outcome, Exception):
                errors.append(outcome)
            else:
                errors.append(RuntimeError(str(outcome)))
        if errors:
            raise ExceptionGroup(
                "map_agents: one or more sub-agents failed",
                errors,
            )
        return [str(outcome) for outcome in outcomes]

    async def reduce_agent(
        self,
        items: Any,
        prompt: str,
        subagent_type: str = "general-purpose",
        description: str | None = None,
    ) -> str:
        """Run a single reducer sub-agent over serialized intermediate results.

        The items are formatted with ``_format_value`` and appended to
        ``prompt``; the call then goes through ``run_agent`` and therefore
        takes a slot from the context-wide semaphore.
        """
        reducer_prompt = f"{prompt}\n\nInput:\n{_format_value(items)}"
        return await self.run_agent(
            prompt=reducer_prompt,
            subagent_type=subagent_type,
            description=description,
        )

    def flatten(self, values: list[Any]) -> list[Any]:
        """Flatten exactly one level of ``list`` nesting (not recursive)."""
        merged: list[Any] = []
        for entry in values:
            # Only real lists are expanded; tuples and other values are kept.
            merged.extend(entry if isinstance(entry, list) else [entry])
        return merged

    def close(self) -> None:
        """Close the underlying manager once; later calls are no-ops."""
        if not self._closed:
            self._closed = True
            self._manager.close()
def validate_workflow_script(script: str) -> None:
    """Run best-effort static checks on a generated workflow script.

    The script must be within the size limit, parse as Python, and define
    exactly one top-level ``async def main(wf):``. It may not import
    modules, touch dunder names or attributes, reach private ``wf``
    attributes or ``wf.close``, access attributes rooted at a deny-listed
    module name, or call a deny-listed builtin by name.

    The ``wf`` guard matches the literal root name ``wf`` only, so an alias
    (``x = wf; x._attr``) is not detected by this check.

    Raises:
        WorkflowScriptError: On the first violated rule.
    """
    script_length = len(script)
    if script_length > _MAX_SCRIPT_CHARS:
        raise WorkflowScriptError(
            f"Workflow script is too large: {script_length} > {_MAX_SCRIPT_CHARS}"
        )

    try:
        tree = ast.parse(script)
    except SyntaxError as e:
        raise WorkflowScriptError(f"Workflow script has invalid syntax: {e}") from e

    # Only top-level definitions count as the entry point.
    entry_points = []
    for statement in tree.body:
        if isinstance(statement, ast.AsyncFunctionDef) and statement.name == "main":
            entry_points.append(statement)
    if len(entry_points) != 1:
        raise WorkflowScriptError(
            "Workflow script must define exactly one async main(wf)"
        )

    params = entry_points[0].args
    takes_only_wf = len(params.args) == 1 and params.args[0].arg == "wf"
    extras = (
        params.kwonlyargs,
        params.vararg,
        params.kwarg,
        params.defaults,
        params.posonlyargs,
    )
    if not takes_only_wf or any(extras):
        raise WorkflowScriptError("Workflow entry point must be `async def main(wf):`")

    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            raise WorkflowScriptError("Workflow scripts may not import modules")

        if isinstance(node, ast.Name):
            if node.id.startswith("__"):
                raise WorkflowScriptError(
                    "Workflow scripts may not access dunder names"
                )
            continue

        if isinstance(node, ast.Attribute):
            attr = node.attr
            if attr.startswith("__"):
                raise WorkflowScriptError(
                    "Workflow scripts may not access dunder attributes"
                )
            root = _attribute_root_name(node)
            if root == "wf" and (attr.startswith("_") or attr == "close"):
                raise WorkflowScriptError(
                    "Workflow scripts may not access private wf attributes"
                    " or call wf.close()"
                )
            if root in _UNSAFE_ATTRIBUTE_ROOTS:
                raise WorkflowScriptError(
                    "Workflow scripts may not access unsafe modules"
                )
            continue

        if isinstance(node, ast.Call):
            callee = node.func
            # Only direct calls by bare name are checked here.
            if isinstance(callee, ast.Name) and callee.id in _UNSAFE_CALLS:
                raise WorkflowScriptError(
                    f"Workflow scripts may not call `{callee.id}`"
                )
def execute_workflow_script(script: str, context: WorkflowContext) -> Any:
    """Validate and execute a workflow script from a synchronous context.

    The script is validated, executed once to define its top-level names,
    and its ``main`` coroutine function is then run to completion on a new
    event loop with an overall timeout.

    Args:
        script: Source of the workflow; must define ``async def main(wf):``.
        context: The ``wf`` object passed to ``main``.

    Returns:
        Whatever the script's ``main`` returns.

    Raises:
        WorkflowScriptError: If called while an event loop is already
            running, if validation fails, if ``main`` is not a coroutine
            function, or if the workflow exceeds the timeout.
        Exception: Anything raised by the script itself propagates.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.run() below is safe.
        pass
    else:
        raise WorkflowScriptError(
            "Workflow scripts must be executed from a synchronous context"
        )

    validate_workflow_script(script)

    # Execute in ONE namespace. Passing separate globals and locals to exec()
    # runs the code as if it were a class body: top-level names land in the
    # locals dict while functions keep the globals dict as ``__globals__``.
    # ``main`` would then fail with NameError on any helper function or
    # constant the script defines at module level.
    namespace: dict[str, Any] = _safe_globals()
    # A pseudo-filename makes script tracebacks identifiable.
    exec(compile(script, "<workflow>", "exec"), namespace)
    main = namespace.get("main")
    if not inspect.iscoroutinefunction(main):
        raise WorkflowScriptError("Workflow entry point must be async")

    async def _run_with_timeout() -> Any:
        async with asyncio.timeout(_WORKFLOW_TIMEOUT_SECONDS):
            return await main(context)

    try:
        return asyncio.run(_run_with_timeout())
    except TimeoutError:
        # `from None` hides the internal cancellation traceback.
        raise WorkflowScriptError(
            f"Workflow timed out after {_WORKFLOW_TIMEOUT_SECONDS:.0f} seconds"
        ) from None
class WorkflowExecutor(ToolExecutor["WorkflowAction", WorkflowObservation]):
    """Executor that runs a workflow action and reports it as an observation.

    Script failures are converted into error observations rather than
    raised, and the per-run ``WorkflowContext`` is always closed.
    """

    def __call__(
        self,
        action: WorkflowAction,
        conversation: LocalConversation | None = None,
    ) -> WorkflowObservation:
        """Execute ``action.script`` and return the resulting observation.

        Args:
            action: The workflow request (name, script, concurrency cap).
            conversation: Parent conversation; required to start sub-agents.

        Returns:
            A ``completed`` observation holding ``str(result)``, or an
            ``error`` observation describing what went wrong.
        """
        workflow_name = action.name

        # Guard clause: without a conversation no context can be built.
        if conversation is None:
            return WorkflowObservation.from_text(
                text="Workflow tool requires a local conversation context.",
                name=workflow_name,
                status="error",
                is_error=True,
            )

        wf = WorkflowContext(
            parent_conversation=conversation,
            max_concurrency=action.max_concurrency,
        )
        try:
            outcome = execute_workflow_script(action.script, wf)
            return WorkflowObservation.from_text(
                text=str(outcome),
                name=workflow_name,
                status="completed",
            )
        except Exception as e:
            # Format first so grouped sub-agent failures are listed one by one.
            error_text = _format_exception(e)
            logger.warning("Workflow '%s' failed: %s", workflow_name, e, exc_info=True)
            return WorkflowObservation.from_text(
                text=error_text,
                name=workflow_name,
                status="error",
                is_error=True,
            )
        finally:
            wf.close()
class _FakeTaskManager:
    """In-memory task manager double used by the workflow tests.

    Records every prompt (prefixed with the sub-agent type) and description
    it receives, answers each task with ``result:<prompt>``, and remembers
    whether ``close`` was called.
    """

    def __init__(self) -> None:
        self.closed = False
        self.prompts: list[str] = []
        self.descriptions: list[str | None] = []

    def start_task(
        self,
        prompt: str,
        subagent_type: str = "default",
        resume: str | None = None,
        description: str | None = None,
        conversation: LocalConversation | None = None,
    ) -> _FakeTask:
        """Record the call and return a successful fake task."""
        recorded_prompt = f"{subagent_type}: {prompt}"
        self.prompts.append(recorded_prompt)
        self.descriptions.append(description)
        return _FakeTask(result=f"result:{prompt}")

    def close(self) -> None:
        """Mark the manager as closed."""
        self.closed = True
def test_run_agent_returns_task_result() -> None:
    """A single ``wf.run_agent`` call returns the task result verbatim."""
    fake_manager = _FakeTaskManager()
    workflow_source = """
async def main(wf):
    return await wf.run_agent("do the thing", subagent_type="analyst")
"""

    outcome = execute_workflow_script(workflow_source, _context(fake_manager))

    # Exactly one task was started, with the requested sub-agent type.
    assert fake_manager.prompts == ["analyst: do the thing"]
    assert outcome == "result:do the thing"
def test_format_value_truncates_large_intermediate_results() -> None:
    """Over-long reducer input is cut and tagged with the truncation marker."""
    oversized_input = "x" * 12_050

    formatted = _format_value(oversized_input)

    assert formatted.endswith("[truncated workflow intermediate results]")
    # The output must not be the full input plus a marker.
    assert len(formatted) < 12_100
match="wf.close"): + validate_workflow_script(script) + + +def test_validate_workflow_script_rejects_unsafe_module_access() -> None: + script = """ +async def main(wf): + os.system('echo nope') +""" + + with pytest.raises(WorkflowScriptError, match="unsafe modules"): + validate_workflow_script(script) + + +def test_validate_workflow_script_rejects_imports() -> None: + script = """ +import os + +async def main(wf): + return 'nope' +""" + + with pytest.raises(WorkflowScriptError, match="import"): + validate_workflow_script(script) + + +def test_workflow_executor_returns_error_observation_without_conversation() -> None: + observation = WorkflowExecutor()(WorkflowAction(name="demo", script="")) + + assert observation.is_error + assert observation.status == "error" + assert "requires a local conversation" in observation.text + + +def test_workflow_context_helper_flattens_one_level() -> None: + context = _context(_FakeTaskManager()) + + assert context.flatten([[1, 2], 3, [4]]) == [1, 2, 3, 4] + + +def test_workflow_executor_success_path() -> None: + @dataclass + class _FakeState: + persistence_dir: str | None = None + + @dataclass + class _FakeConv: + state: _FakeState + + conv = cast(LocalConversation, _FakeConv(state=_FakeState())) + action = WorkflowAction( + name="trivial", + script="async def main(wf):\n return 'done'", + ) + + obs = WorkflowExecutor()(action, conversation=conv) + + assert not obs.is_error + assert obs.status == "completed" + assert obs.text == "done" + + +def test_workflow_context_close_propagates_to_manager() -> None: + manager = _FakeTaskManager() + context = _context(manager) + + assert not manager.closed + context.close() + assert manager.closed + + +def test_workflow_context_close_is_idempotent() -> None: + manager = _FakeTaskManager() + context = _context(manager) + + context.close() + context.close() # second call must not raise + assert manager.closed + + +def test_run_agent_raises_after_close() -> None: + manager = _FakeTaskManager() + 
context = _context(manager) + context.close() + + with pytest.raises(WorkflowScriptError, match="already closed"): + asyncio.run(context.run_agent("any prompt")) + + +def test_map_agents_respects_context_concurrency_cap() -> None: + """Per-call max_concurrency must be silently capped at context max.""" + + class _PeakTrackingManager(_FakeTaskManager): + def __init__(self) -> None: + super().__init__() + self._active = 0 + self.peak_active = 0 + self._lock = threading.Lock() + + def start_task( + self, + prompt: str, + subagent_type: str = "default", + resume: str | None = None, + description: str | None = None, + conversation: LocalConversation | None = None, + ) -> _FakeTask: + with self._lock: + self._active += 1 + self.peak_active = max(self.peak_active, self._active) + try: + return super().start_task( + prompt, + subagent_type=subagent_type, + resume=resume, + description=description, + conversation=conversation, + ) + finally: + with self._lock: + self._active -= 1 + + # Context capped at 3; per-call max_concurrency=1000 should be min'd to 3 + context_cap = 3 + manager = _PeakTrackingManager() + context = _context(manager, max_concurrency=context_cap) + script = """ +async def main(wf): + return await wf.map_agents( + items=list(range(10)), + prompt="task {item}", + max_concurrency=1000, + ) +""" + execute_workflow_script(script, context) + assert manager.peak_active <= context_cap