From f87bf66b42665bb70bff186d9fcd0c52ddbe0b15 Mon Sep 17 00:00:00 2001 From: Muizz Lateef Date: Wed, 1 Apr 2026 20:32:47 +0100 Subject: [PATCH 1/8] feat: add pre-flight validation and dependency-aware execution guards MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address 8 issues discovered at runtime across failed review_analyzer runs. All were statically detectable or preventable with proper execution guards. Three architectural gaps and two validation gaps fixed: 1. Failed Result Isolation (Issue #8) - Write DISPOSITION_FAILED on action failure (online + batch paths) - Check disposition FIRST in _verify_completion_status before output check - Add delete_target() to StorageBackend interface - Add --fresh CLI flag to clear stale results before execution 2. Circuit Breaker — Dependency-Aware Execution (Issue #7) - Add _check_upstream_health() before get_previous_outputs (not after) - Skip downstream actions when upstream fails, independent branches continue - Level orchestrator logs failures instead of raising WorkflowError - Return ("completed_with_failures", {...}) for partial success workflows - get_pending_actions now excludes failed actions; add is_workflow_done() 3. Pre-Flight Resolution Service (Issues #1, #2, #6) - WorkflowResolutionService checks API keys (from vendor config model_fields), seed file $file: references, and vendor batch-mode compatibility - Shared resolve_seed_path() utility in utils/path_security.py - AA_SKIP_ENV_VALIDATION=1 escape hatch for CI environments 4. Drop Directive Validation (Issue #4) - _check_drop_directives() in WorkflowStaticAnalyzer validates drop targets against schema/observe/passthrough fields with distinct error messages 5. 
Lineage Reachability Validation (Issue #5) - Add passthrough_wildcard_sources to OutputSchema (fixes invisible wildcards) - _check_lineage_reachability() traces observe refs through passthrough chains - Emits warnings (not errors) — strict mode can promote them --- agent_actions/cli/run.py | 10 + agent_actions/storage/backend.py | 11 + .../storage/backends/sqlite_backend.py | 29 ++ agent_actions/utils/path_security.py | 42 +++ .../preflight/resolution_service.py | 291 ++++++++++++++++++ agent_actions/validation/run_validator.py | 3 + .../static_analyzer/data_flow_graph.py | 1 + .../static_analyzer/schema_extractor.py | 5 + .../workflow_static_analyzer.py | 218 ++++++++++++- agent_actions/workflow/coordinator.py | 66 +++- agent_actions/workflow/executor.py | 119 +++++++ agent_actions/workflow/managers/state.py | 16 +- agent_actions/workflow/models.py | 1 + .../workflow/parallel/action_executor.py | 29 +- .../workflow/test_coordinator_sequential.py | 27 +- .../unit/workflow/test_executor_lifecycle.py | 6 + tests/unit/workflow/test_limits.py | 2 + .../test_stale_completion_verification.py | 1 + 18 files changed, 841 insertions(+), 36 deletions(-) create mode 100644 agent_actions/utils/path_security.py create mode 100644 agent_actions/validation/preflight/resolution_service.py diff --git a/agent_actions/cli/run.py b/agent_actions/cli/run.py index 76e8212..c8e42ac 100644 --- a/agent_actions/cli/run.py +++ b/agent_actions/cli/run.py @@ -89,6 +89,7 @@ def execute(self, project_root: Path | None = None) -> None: use_tools=self.args.use_tools, run_upstream=self.args.upstream, run_downstream=self.args.downstream, + fresh=self.args.fresh, project_root=project_root, ) ) @@ -187,6 +188,12 @@ def execute(self, project_root: Path | None = None) -> None: is_flag=True, help="Execute all downstream workflows that depend on this workflow", ) +@click.option( + "--fresh", + is_flag=True, + default=False, + help="Clear stored results and status before execution (useful after failed 
runs)", +) @handles_user_errors("run") @requires_project def run( @@ -197,6 +204,7 @@ def run( concurrency_limit: int = 5, upstream: bool = False, downstream: bool = False, + fresh: bool = False, project_root: Path | None = None, ) -> None: """ @@ -211,6 +219,7 @@ def run( agac run -a my_agent --upstream agac run -a my_agent --downstream agac run -a my_agent --execution-mode parallel + agac run -a my_agent --fresh """ args = RunCommandArgs( agent=agent, @@ -220,6 +229,7 @@ def run( concurrency_limit=concurrency_limit, upstream=upstream, downstream=downstream, + fresh=fresh, ) command = RunCommand(args) command.execute(project_root=project_root) diff --git a/agent_actions/storage/backend.py b/agent_actions/storage/backend.py index 4d564b9..484d77e 100644 --- a/agent_actions/storage/backend.py +++ b/agent_actions/storage/backend.py @@ -149,6 +149,17 @@ def clear_disposition( """Delete matching disposition records. Returns count deleted.""" return 0 + def delete_target(self, action_name: str) -> int: + """Delete all target data for an action. Returns count deleted. + + Subclasses **must** override — the default raises so that backend + authors are forced to implement it and ``--fresh`` cannot silently + leave stale data behind. + """ + raise NotImplementedError( + f"{type(self).__name__} must implement delete_target()" + ) + def close(self) -> None: # noqa: B027 """Close the storage backend and release resources.""" pass diff --git a/agent_actions/storage/backends/sqlite_backend.py b/agent_actions/storage/backends/sqlite_backend.py index 231168b..0f229fa 100644 --- a/agent_actions/storage/backends/sqlite_backend.py +++ b/agent_actions/storage/backends/sqlite_backend.py @@ -636,6 +636,35 @@ def clear_disposition( ) raise + def delete_target(self, action_name: str) -> int: + """Delete all target data for a specific action. 
Returns count deleted.""" + action_name = self._validate_identifier(action_name, "action_name") + with self._lock: + cursor = self.connection.cursor() + try: + cursor.execute( + "DELETE FROM target_data WHERE action_name = ?", + (action_name,), + ) + self.connection.commit() + deleted = cursor.rowcount + logger.debug( + "Deleted %d target records for %s", + deleted, + action_name, + extra={"workflow_name": self.workflow_name}, + ) + return deleted + except sqlite3.Error as e: + self.connection.rollback() + logger.error( + "Failed to delete target for %s: %s", + action_name, + e, + extra={"workflow_name": self.workflow_name}, + ) + raise + @staticmethod def _format_size(size_bytes: int) -> str: """Format bytes as human-readable size.""" diff --git a/agent_actions/utils/path_security.py b/agent_actions/utils/path_security.py new file mode 100644 index 0000000..987b971 --- /dev/null +++ b/agent_actions/utils/path_security.py @@ -0,0 +1,42 @@ +"""Shared path security utilities for seed data resolution. + +Both the pre-flight resolution service and the runtime StaticDataLoader +call ``resolve_seed_path`` so that path-traversal prevention logic exists +in exactly one place. +""" + +from pathlib import Path + +FILE_PREFIX = "$file:" + + +def resolve_seed_path(file_spec: str, base_dir: Path) -> Path: + """Parse a ``$file:`` reference, resolve against *base_dir*, and validate. + + Returns the resolved absolute ``Path``. + + Raises: + ValueError: If the spec is empty, escapes *base_dir* via traversal, + or is otherwise invalid. 
+ """ + if not file_spec: + raise ValueError("Empty file spec") + + # Strip $file: prefix if present + file_path = file_spec[len(FILE_PREFIX) :] if file_spec.startswith(FILE_PREFIX) else file_spec + + if not file_path: + raise ValueError(f"Empty path after prefix in: {file_spec}") + + resolved = (base_dir / file_path).resolve() + + # Security: prevent path traversal outside base_dir + try: + resolved.relative_to(base_dir.resolve()) + except ValueError: + raise ValueError( + f"Seed file path escapes base directory: {file_spec} " + f"(resolved to {resolved}, base is {base_dir.resolve()})" + ) from None + + return resolved diff --git a/agent_actions/validation/preflight/resolution_service.py b/agent_actions/validation/preflight/resolution_service.py new file mode 100644 index 0000000..8239024 --- /dev/null +++ b/agent_actions/validation/preflight/resolution_service.py @@ -0,0 +1,291 @@ +"""Unified pre-flight resolution service. + +Performs a single comprehensive resolution pass across all actions: +- API key environment variable presence +- Seed file ($file:) reference existence +- Provider capability / run_mode compatibility + +Uses the same resolution utilities that runtime uses, ensuring no divergence. +""" + +import logging +import os +from pathlib import Path +from typing import Any + +from agent_actions.utils.path_security import resolve_seed_path +from agent_actions.validation.static_analyzer.errors import ( + FieldLocation, + StaticTypeError, + StaticValidationResult, +) + +logger = logging.getLogger(__name__) + +# Vendor name → config class mapping. Built lazily on first access to +# avoid importing all vendor configs (and transitively their SDKs) at +# module level. +_VENDOR_CONFIG_MAP: dict[str, type] | None = None + +# Sentinel substrings in api_key_env_name that indicate no real key is needed. 
+_NO_KEY_SENTINELS = ("NO_KEY_REQUIRED",) + + +def _get_vendor_config_map() -> dict[str, type]: + """Build vendor → config class map on first call (lazy).""" + global _VENDOR_CONFIG_MAP # noqa: PLW0603 + if _VENDOR_CONFIG_MAP is not None: + return _VENDOR_CONFIG_MAP + + from agent_actions.llm.config.vendor import ( + AgacProviderConfig, + AnthropicConfig, + CohereConfig, + GeminiConfig, + GroqConfig, + HitlVendorConfig, + MistralConfig, + OllamaConfig, + OpenAIConfig, + ToolVendorConfig, + ) + + _VENDOR_CONFIG_MAP = { + "openai": OpenAIConfig, + "anthropic": AnthropicConfig, + "gemini": GeminiConfig, + "google": GeminiConfig, + "groq": GroqConfig, + "cohere": CohereConfig, + "mistral": MistralConfig, + "ollama": OllamaConfig, + "tool": ToolVendorConfig, + "hitl": HitlVendorConfig, + "agac-provider": AgacProviderConfig, + } + return _VENDOR_CONFIG_MAP + + +def _get_api_key_env_name(vendor: str) -> str | None: + """Resolve API key env var name from vendor config class (single source of truth).""" + config_cls = _get_vendor_config_map().get(vendor.lower()) + if config_cls is None: + return None + field_info = config_cls.model_fields.get("api_key_env_name") + if field_info is None: + return None + return field_info.default + + +class WorkflowResolutionService: + """Performs unified pre-flight resolution checks.""" + + def __init__( + self, + action_configs: dict[str, dict[str, Any]], + workflow_config_path: str | None = None, + project_root: Path | None = None, + ): + self.action_configs = action_configs + self.workflow_config_path = workflow_config_path + self.project_root = project_root + + def resolve_all(self) -> StaticValidationResult: + """Run all resolution checks and return aggregated result.""" + result = StaticValidationResult() + + if os.environ.get("AA_SKIP_ENV_VALIDATION") != "1": + for error in self._check_api_keys(): + result.add_error(error) + + for error in self._check_seed_file_references(): + result.add_error(error) + + for error in 
self._check_vendor_run_mode_compatibility(): + result.add_error(error) + + return result + + # ── API key checks ───────────────────────────────────────────────── + + def _check_api_keys(self) -> list[StaticTypeError]: + """Check that all required API key env vars are set.""" + errors: list[StaticTypeError] = [] + + for action_name, config in self.action_configs.items(): + vendor = (config.get("model_vendor") or "").lower() + if not vendor: + continue + + # Resolve the expected env var name from vendor config + env_var_name = _get_api_key_env_name(vendor) + if env_var_name is None: + continue + + # Skip vendors that don't need real keys (tool, hitl) + if any(sentinel in env_var_name for sentinel in _NO_KEY_SENTINELS): + continue + + # If the action config specifies a custom api_key, use that + custom_key = config.get("api_key") + if custom_key: + custom_str = str(custom_key) + if custom_str.startswith("$"): + env_var_name = custom_str[1:] + else: + # Literal key provided — skip env check + continue + + if not os.environ.get(env_var_name): + errors.append( + StaticTypeError( + message=( + f"API key environment variable '{env_var_name}' is not set " + f"(required by action '{action_name}', vendor '{vendor}')" + ), + location=FieldLocation( + agent_name=action_name, + config_field="api_key", + raw_reference=env_var_name, + ), + referenced_agent=action_name, + referenced_field="api_key", + hint=f"Set the environment variable: export {env_var_name}=your_key_here", + ) + ) + + return errors + + # ── Seed file checks ─────────────────────────────────────────────── + + def _check_seed_file_references(self) -> list[StaticTypeError]: + """Check that all $file: references resolve to existing files.""" + errors: list[StaticTypeError] = [] + + seed_data_dir = self._resolve_seed_data_dir() + if seed_data_dir is None: + return errors + + for action_name, config in self.action_configs.items(): + context_scope = config.get("context_scope", {}) + if not isinstance(context_scope, 
dict): + continue + seed_path_config = context_scope.get("seed_path", {}) + if not seed_path_config or not isinstance(seed_path_config, dict): + continue + + for field_name, file_spec in seed_path_config.items(): + if not isinstance(file_spec, str): + continue + + try: + resolved = resolve_seed_path(file_spec, seed_data_dir) + except ValueError as e: + errors.append( + StaticTypeError( + message=str(e), + location=FieldLocation( + agent_name=action_name, + config_field=f"context_scope.seed_path.{field_name}", + raw_reference=file_spec, + ), + referenced_agent=action_name, + referenced_field=field_name, + hint="Use relative paths within the seed_data/ directory.", + ) + ) + continue + + if not resolved.exists(): + available: list[str] = [] + if seed_data_dir.exists(): + available = sorted(f.name for f in seed_data_dir.iterdir() if f.is_file()) + + errors.append( + StaticTypeError( + message=( + f"Seed file not found: {file_spec} " + f"(resolved to {resolved})" + ), + location=FieldLocation( + agent_name=action_name, + config_field=f"context_scope.seed_path.{field_name}", + raw_reference=file_spec, + ), + referenced_agent=action_name, + referenced_field=field_name, + available_fields=set(available), + hint=( + f"Available files: {', '.join(available)}" + if available + else "(seed_data/ directory is empty)" + ), + ) + ) + + return errors + + # ── Vendor run-mode compatibility ────────────────────────────────── + + def _check_vendor_run_mode_compatibility(self) -> list[StaticTypeError]: + """Check that vendor supports the requested run_mode.""" + errors: list[StaticTypeError] = [] + + from agent_actions.validation.preflight.vendor_compatibility_validator import ( + _resolve_capabilities, + ) + + for action_name, config in self.action_configs.items(): + vendor = (config.get("model_vendor") or "").lower() + run_mode = config.get("run_mode", "online") + + # Normalize RunMode enum to string + if hasattr(run_mode, "value"): + run_mode = run_mode.value + + if run_mode != 
"batch": + continue + + capabilities = _resolve_capabilities(vendor) + if capabilities is None: + continue + + if not capabilities.get("supports_batch"): + errors.append( + StaticTypeError( + message=( + f"Action '{action_name}' uses run_mode=batch with vendor " + f"'{vendor}', but {vendor} does not support batch mode" + ), + location=FieldLocation( + agent_name=action_name, + config_field="run_mode", + raw_reference=f"run_mode=batch, vendor={vendor}", + ), + referenced_agent=action_name, + referenced_field="run_mode", + hint=f"Use run_mode: online for {vendor} actions, or choose a batch-capable vendor.", + ) + ) + + return errors + + # ── Helpers ──────────────────────────────────────────────────────── + + def _resolve_seed_data_dir(self) -> Path | None: + """Resolve the seed_data directory from workflow config path.""" + if not self.workflow_config_path: + return None + + config_path = Path(self.workflow_config_path).resolve() + current = config_path.parent + while current != current.parent: + if (current / "agent_config").exists(): + seed_dir = current / "seed_data" + return seed_dir if seed_dir.exists() else None + if current.name == "agent_config": + seed_dir = current.parent / "seed_data" + return seed_dir if seed_dir.exists() else None + current = current.parent + + return None diff --git a/agent_actions/validation/run_validator.py b/agent_actions/validation/run_validator.py index 36cb62a..d9b2f12 100644 --- a/agent_actions/validation/run_validator.py +++ b/agent_actions/validation/run_validator.py @@ -26,3 +26,6 @@ class RunCommandArgs(BaseModel): downstream: bool = Field( False, description="Execute all downstream workflows that depend on this workflow" ) + fresh: bool = Field( + False, description="Clear stored results and status before execution" + ) diff --git a/agent_actions/validation/static_analyzer/data_flow_graph.py b/agent_actions/validation/static_analyzer/data_flow_graph.py index 91a1eb2..a630bd0 100644 --- 
a/agent_actions/validation/static_analyzer/data_flow_graph.py +++ b/agent_actions/validation/static_analyzer/data_flow_graph.py @@ -15,6 +15,7 @@ class OutputSchema: schema_fields: set[str] = field(default_factory=set) observe_fields: set[str] = field(default_factory=set) passthrough_fields: set[str] = field(default_factory=set) + passthrough_wildcard_sources: set[str] = field(default_factory=set) dropped_fields: set[str] = field(default_factory=set) json_schema: dict[str, Any] | None = None is_dynamic: bool = False diff --git a/agent_actions/validation/static_analyzer/schema_extractor.py b/agent_actions/validation/static_analyzer/schema_extractor.py index 9f8f644..9074663 100644 --- a/agent_actions/validation/static_analyzer/schema_extractor.py +++ b/agent_actions/validation/static_analyzer/schema_extractor.py @@ -416,6 +416,11 @@ def _apply_context_scope(self, config: dict[str, Any], output: OutputSchema) -> field_name = self._extract_field_name(ref) if field_name: output.passthrough_fields.add(field_name) + elif isinstance(ref, str) and ".*" in ref: + # Wildcard passthrough: "source.*" → record the source name + source_name = ref.split(".", 1)[0] + if source_name: + output.passthrough_wildcard_sources.add(source_name) scope_observe = context_scope.get("observe", []) for ref in scope_observe: diff --git a/agent_actions/validation/static_analyzer/workflow_static_analyzer.py b/agent_actions/validation/static_analyzer/workflow_static_analyzer.py index a7a5552..0d9124c 100644 --- a/agent_actions/validation/static_analyzer/workflow_static_analyzer.py +++ b/agent_actions/validation/static_analyzer/workflow_static_analyzer.py @@ -22,7 +22,7 @@ InputSchema, OutputSchema, ) -from .errors import FieldLocation, StaticTypeError, StaticValidationResult +from .errors import FieldLocation, StaticTypeError, StaticTypeWarning, StaticValidationResult from .reference_extractor import ReferenceExtractor from .schema_extractor import SchemaExtractor @@ -121,11 +121,19 @@ def 
analyze(self) -> StaticValidationResult: for error in self._check_schema_structures(): result.add_error(error) + # Step 2f: Validate drop directives target schema/observe fields + for error in self._check_drop_directives(): + result.add_error(error) + # Step 3: Check for unused dependencies (add as warnings) warnings = checker.check_unused_dependencies() for warning in warnings: result.add_warning(warning) + # Step 3b: Check lineage reachability for observe/passthrough references + for warning in self._check_lineage_reachability(): + result.add_warning(warning) + return result def _build_graph(self) -> None: @@ -389,6 +397,214 @@ def _check_schema_structures(self) -> list[StaticTypeError]: return errors + def _check_drop_directives(self) -> list[StaticTypeError]: + """Validate that drop directives reference actual schema/observe fields. + + Drop directives remove fields from the LLM context. If the referenced + field is a passthrough field (not in the LLM context namespace), the + drop is a no-op and the user should be warned. + """ + errors: list[StaticTypeError] = [] + actions = self.workflow_config.get("actions", []) + + for action in actions: + if not isinstance(action, dict): + continue + + action_name = action.get("name", "unknown") + context_scope = action.get("context_scope", {}) + if not isinstance(context_scope, dict): + continue + drop_refs = context_scope.get("drop", []) + if not isinstance(drop_refs, list): + continue + + for drop_ref in drop_refs: + if not isinstance(drop_ref, str) or "." 
not in drop_ref: + continue + + dep_name, field_name = drop_ref.split(".", 1) + + if dep_name in SPECIAL_NAMESPACES or dep_name == "loop": + continue + if field_name == "*": + continue + + dep_node = self.graph.get_node(dep_name) + if not dep_node: + continue # Unknown dep — caught by other checks + + output = dep_node.output_schema + if output.is_dynamic or output.is_schemaless: + continue # Can't validate + + if field_name in output.schema_fields or field_name in output.observe_fields: + continue # Valid drop target + + if field_name in output.passthrough_fields: + errors.append( + StaticTypeError( + message=( + f"Drop directive '{drop_ref}' targets passthrough field " + f"'{field_name}' on '{dep_name}'. Passthrough fields are not " + f"in the LLM context namespace, so this drop has no effect." + ), + location=FieldLocation( + agent_name=action_name, + config_field="context_scope.drop", + raw_reference=drop_ref, + ), + referenced_agent=dep_name, + referenced_field=field_name, + available_fields=output.schema_fields | output.observe_fields, + hint=( + f"Remove this drop directive. '{field_name}' is a passthrough " + f"field — it doesn't appear in the LLM context. " + f"Schema fields: {', '.join(sorted(output.schema_fields))}" + ), + ) + ) + elif field_name not in output.available_fields: + errors.append( + StaticTypeError( + message=( + f"Drop directive '{drop_ref}' references non-existent field " + f"'{field_name}' in '{dep_name}'" + ), + location=FieldLocation( + agent_name=action_name, + config_field="context_scope.drop", + raw_reference=drop_ref, + ), + referenced_agent=dep_name, + referenced_field=field_name, + available_fields=output.schema_fields | output.observe_fields, + hint=( + f"Available schema fields in '{dep_name}': " + f"{', '.join(sorted(output.schema_fields))}" + ), + ) + ) + + return errors + + def _check_lineage_reachability(self) -> list[StaticTypeWarning]: + """Check that observe references to non-direct-dependencies are reachable. 
+ + When action C observes ``A.field`` but C only depends on B (not A directly), + the data must flow A → B → C via passthrough on B. This check verifies + the passthrough chain exists. + """ + warnings: list[StaticTypeWarning] = [] + actions = self.workflow_config.get("actions", []) + + for action in actions: + if not isinstance(action, dict): + continue + + node_name = action.get("name", "unknown") + node = self.graph.get_node(node_name) + if not node: + continue + + context_scope = action.get("context_scope", {}) + if not isinstance(context_scope, dict): + continue + + observe_refs = context_scope.get("observe", []) + if not isinstance(observe_refs, list): + continue + + for ref in observe_refs: + if not isinstance(ref, str) or "." not in ref: + continue + + source_name, field_name = ref.split(".", 1) + + if source_name in SPECIAL_NAMESPACES or source_name == "loop": + continue + if field_name == "*": + continue + + # If source is a direct dependency, no lineage concern + if source_name in node.dependencies: + continue + + # Source is NOT a direct dependency — must travel through intermediates + reachable = self.graph.get_reachable_upstream_names(node_name) + if source_name not in reachable: + continue # Not reachable at all — caught by type checker + + if not self._trace_field_through_chain(source_name, field_name, node_name): + warnings.append( + StaticTypeWarning( + message=( + f"Observe reference '{ref}' on '{node_name}' references " + f"non-direct dependency '{source_name}'. The field " + f"'{field_name}' may not survive through intermediate " + f"actions via passthrough." 
+ ), + location=FieldLocation( + agent_name=node_name, + config_field="context_scope.observe", + raw_reference=ref, + ), + referenced_agent=source_name, + referenced_field=field_name, + hint=( + f"Ensure intermediate actions between '{source_name}' and " + f"'{node_name}' have passthrough: ['{source_name}.*'] or " + f"passthrough: ['{source_name}.{field_name}'] in their " + f"context_scope." + ), + ) + ) + + return warnings + + def _trace_field_through_chain( + self, source: str, field: str, target: str + ) -> bool: + """Check if a field from *source* can reach *target* through passthrough chains. + + BFS backwards from *target* through dependencies. At each intermediate + node, checks whether the field survives (exact passthrough, wildcard + passthrough, or dynamic schema). + """ + target_node = self.graph.get_node(target) + if not target_node: + return False + + visited: set[str] = set() + queue = list(target_node.dependencies) + + while queue: + current = queue.pop(0) + if current in visited: + continue + visited.add(current) + + if current == source: + return True # Direct path found + + current_node = self.graph.get_node(current) + if not current_node: + continue + + output = current_node.output_schema + survives = ( + source in output.passthrough_wildcard_sources + or field in output.passthrough_fields + or output.is_dynamic + ) + + if survives: + for dep in current_node.dependencies: + if dep not in visited: + queue.append(dep) + + return False + def _add_source_node(self) -> None: """Add the special source node for workflow input.""" if self.source_schema: diff --git a/agent_actions/workflow/coordinator.py b/agent_actions/workflow/coordinator.py index d2ccf20..e4d16fa 100644 --- a/agent_actions/workflow/coordinator.py +++ b/agent_actions/workflow/coordinator.py @@ -86,6 +86,10 @@ def __init__(self, config: WorkflowRuntimeConfig): self.metadata, config, self.storage_backend, self.console ) + # Fresh run: clear stored results + status before anything else + 
if config.fresh: + self._clear_for_fresh_run() + # Dependency orchestration + session self._init_dependency_orchestrator() self.workflow_session_id = self._generate_workflow_session_id() @@ -125,9 +129,35 @@ def _run_static_validation(self) -> None: hint="Fix the guard condition syntax errors above before running the workflow.", ) + # Resolution checks: API keys, seed files, vendor batch compatibility + from agent_actions.validation.preflight.resolution_service import ( + WorkflowResolutionService, + ) + + resolution_result = WorkflowResolutionService( + action_configs=self.action_configs, + workflow_config_path=self.config.paths.constructor_path, + project_root=self.config.project_root, + ).resolve_all() + resolution_result.raise_if_invalid() + def _validate_guard_conditions(self) -> list[str]: return validate_guard_conditions(self.action_configs) + def _clear_for_fresh_run(self) -> None: + """Clear stored results, dispositions, and status for a fresh run.""" + for action_name in self.execution_order: + try: + self.storage_backend.delete_target(action_name) + self.storage_backend.clear_disposition(action_name) + except Exception as e: + logger.warning("Failed to clear stored data for %s: %s", action_name, e) + self.services.core.state_manager._initialize_default_status() + self.services.core.state_manager._save_status() + self.console.print( + "[yellow]--fresh: cleared stored results and reset all actions to pending[/yellow]" + ) + # ── Properties ────────────────────────────────────────────────────── @property @@ -311,14 +341,22 @@ async def async_run(self, concurrency_limit: int = 5): if not level_complete: return + state_mgr = self.services.core.state_manager duration = (datetime.now() - workflow_start).total_seconds() self.event_logger.finalize_workflow(elapsed_time=duration) - downstream_success = self._resolve_downstream_workflows() - if not downstream_success: - return None + if state_mgr.is_workflow_complete(): + downstream_success = 
self._resolve_downstream_workflows() + if not downstream_success: + return None + return ("success", {}) - return ("success", {}) + if state_mgr.is_workflow_done(): + self.state.failed = True + failed = state_mgr.get_failed_actions(self.execution_order) + return ("completed_with_failures", {"failed": failed}) + + return None except Exception as e: duration = (datetime.now() - workflow_start).total_seconds() @@ -362,8 +400,10 @@ def _run_workflow_with_context(self, workflow_start): if should_stop: break - if self.services.core.state_manager.is_workflow_complete(): - duration = (datetime.now() - workflow_start).total_seconds() + state_mgr = self.services.core.state_manager + duration = (datetime.now() - workflow_start).total_seconds() + + if state_mgr.is_workflow_complete(): self.event_logger.finalize_workflow(elapsed_time=duration) downstream_success = self._resolve_downstream_workflows() @@ -372,6 +412,13 @@ def _run_workflow_with_context(self, workflow_start): return ("success", {}) + if state_mgr.is_workflow_done(): + # All actions reached a terminal state but some failed + self.state.failed = True + self.event_logger.finalize_workflow(elapsed_time=duration) + failed = state_mgr.get_failed_actions(self.execution_order) + return ("completed_with_failures", {"failed": failed}) + return None except Exception as e: @@ -425,6 +472,9 @@ def _run_single_action(self, idx: int, action_name: str, total_actions: int) -> if result.status == "batch_submitted": return True + if result.status == "skipped": + return False # Continue to next action + if result.output_folder and result.status == "completed": self.state.ephemeral_directories.append( { @@ -434,4 +484,6 @@ def _run_single_action(self, idx: int, action_name: str, total_actions: int) -> ) return False - raise result.error + # Action failed — log and continue (circuit breaker handles downstream) + logger.warning("Action '%s' failed: %s", action_name, result.error) + return False diff --git 
a/agent_actions/workflow/executor.py b/agent_actions/workflow/executor.py index 0a5288c..ac45c62 100644 --- a/agent_actions/workflow/executor.py +++ b/agent_actions/workflow/executor.py @@ -19,6 +19,7 @@ BatchCompleteEvent, BatchSubmittedEvent, ) +from agent_actions.storage.backend import DISPOSITION_FAILED, NODE_LEVEL_RECORD_ID from agent_actions.tooling.docs.run_tracker import ActionCompleteConfig from agent_actions.utils.constants import DEFAULT_ACTION_KIND @@ -174,6 +175,26 @@ def _verify_completion_status( storage_backend = getattr(self.deps.action_runner, "storage_backend", None) if storage_backend is not None: try: + # Check disposition FIRST — a failed action may have partial + # results in storage. The disposition is the authoritative + # signal; output existence is irrelevant when it's set. + if storage_backend.has_disposition( + action_name, + DISPOSITION_FAILED, + record_id=NODE_LEVEL_RECORD_ID, + ): + logger.info( + "Action %s has DISPOSITION_FAILED from prior run — re-running", + action_name, + ) + storage_backend.clear_disposition( + action_name, + DISPOSITION_FAILED, + record_id=NODE_LEVEL_RECORD_ID, + ) + self.deps.state_manager.update_status(action_name, "pending") + return (False, None) + target_files = storage_backend.list_target_files(action_name) if not target_files: logger.info( @@ -335,12 +356,31 @@ def _handle_run_success( ), ) + def _write_failed_disposition(self, action_name: str, reason: str) -> None: + """Write DISPOSITION_FAILED to storage so downstream and future runs detect the failure.""" + storage_backend = getattr(self.deps.action_runner, "storage_backend", None) + if storage_backend is not None: + try: + storage_backend.set_disposition( + action_name=action_name, + record_id=NODE_LEVEL_RECORD_ID, + disposition=DISPOSITION_FAILED, + reason=reason[:500], + ) + except Exception as disp_err: + logger.warning( + "Failed to write DISPOSITION_FAILED for %s: %s", + action_name, + disp_err, + ) + def _handle_run_failure( self, params: 
ActionRunParams, error: Exception ) -> ActionExecutionResult: """Handle action run failure.""" duration = (datetime.now() - params.start_time).total_seconds() self.deps.state_manager.update_status(params.action_name, "failed") + self._write_failed_disposition(params.action_name, str(error)) if self.run_tracker is not None and self.run_id is not None: config = ActionCompleteConfig( @@ -373,6 +413,68 @@ def _cleanup_correlation( }, ) + def _check_upstream_health( + self, action_name: str, action_config: ActionConfigDict + ) -> str | None: + """Return the name of a failed upstream dependency, or None if all healthy.""" + dependencies = action_config.get("dependencies", []) + if not dependencies: + return None + for dep in dependencies: + if not isinstance(dep, str): + continue + if self.deps.state_manager.is_failed(dep): + return dep + # Also check disposition — covers cascaded failures from prior levels + storage_backend = getattr(self.deps.action_runner, "storage_backend", None) + if storage_backend is not None and storage_backend.has_disposition( + dep, DISPOSITION_FAILED, record_id=NODE_LEVEL_RECORD_ID + ): + return dep + return None + + def _handle_dependency_skip( + self, + action_name: str, + action_idx: int, + action_config: ActionConfigDict, + start_time: datetime, + failed_dependency: str, + ) -> ActionExecutionResult: + """Handle action skip due to upstream dependency failure.""" + reason = f"Upstream dependency '{failed_dependency}' failed" + self.deps.state_manager.update_status(action_name, "failed") + self._write_failed_disposition(action_name, reason) + + duration = (datetime.now() - start_time).total_seconds() + total_actions = ( + len(self.deps.action_runner.execution_order) + if hasattr(self.deps.action_runner, "execution_order") + else 0 + ) + fire_event( + ActionSkipEvent( + action_name=action_name, + action_index=action_idx, + total_actions=total_actions, + skip_reason=reason, + ) + ) + + if self.run_tracker is not None and self.run_id is not 
None: + config = ActionCompleteConfig( + run_id=self.run_id, + action_name=action_name, + status="skipped", + duration_seconds=duration, + skip_reason=reason, + ) + self.run_tracker.record_action_complete(config=config) + + return ActionExecutionResult( + success=True, status="skipped", metrics=ExecutionMetrics(duration=duration) + ) + def execute_action_sync( self, action_name: str, @@ -413,6 +515,14 @@ def execute_action_sync( if current_status == "batch_submitted": return self._handle_batch_check(action_name, action_idx, action_config, start_time) + # Circuit breaker: skip if any upstream dependency has failed. + # Must run BEFORE get_previous_outputs to avoid reading corrupt data. + failed_dep = self._check_upstream_health(action_name, action_config) + if failed_dep is not None: + return self._handle_dependency_skip( + action_name, action_idx, action_config, start_time, failed_dep + ) + previous_outputs = self.deps.output_manager.get_previous_outputs(action_idx) if self.deps.skip_evaluator.should_skip_action(action_config, previous_outputs): return self._handle_action_skip(action_name, action_idx, action_config, start_time) @@ -469,6 +579,13 @@ async def execute_action_async( action_name, action_idx, action_config, start_time ) + # Circuit breaker: skip if any upstream dependency has failed. 
+ failed_dep = self._check_upstream_health(action_name, action_config) + if failed_dep is not None: + return self._handle_dependency_skip( + action_name, action_idx, action_config, start_time, failed_dep + ) + previous_outputs = self.deps.output_manager.get_previous_outputs(action_idx) if self.deps.skip_evaluator.should_skip_action(action_config, previous_outputs): return self._handle_action_skip(action_name, action_idx, action_config, start_time) @@ -538,6 +655,7 @@ def _handle_batch_check( ) self.deps.state_manager.update_status(action_name, "failed") + self._write_failed_disposition(action_name, f"Batch job for {action_name} failed") fire_event( BatchCompleteEvent( batch_id=action_config.get("batch_id", ""), @@ -611,6 +729,7 @@ async def _handle_batch_check_async( ) self.deps.state_manager.update_status(action_name, "failed") + self._write_failed_disposition(action_name, f"Batch job for {action_name} failed") fire_event( BatchCompleteEvent( batch_id=action_config.get("batch_id", ""), diff --git a/agent_actions/workflow/managers/state.py b/agent_actions/workflow/managers/state.py index 0bcda66..2ab2a32 100644 --- a/agent_actions/workflow/managers/state.py +++ b/agent_actions/workflow/managers/state.py @@ -79,8 +79,13 @@ def is_failed(self, action_name: str) -> bool: return self.get_status(action_name) == "failed" def get_pending_actions(self, agents: list[str]) -> list[str]: - """Return actions that are not yet completed.""" - return [agent for agent in agents if not self.is_completed(agent)] + """Return actions that are not yet completed or failed (runnable).""" + terminal = {"completed", "failed"} + return [ + agent + for agent in agents + if self.get_status(agent) not in terminal + ] def get_batch_submitted_actions(self, agents: list[str]) -> list[str]: """Return actions with batch jobs submitted.""" @@ -115,6 +120,13 @@ def is_workflow_complete(self) -> bool: """Return True if all actions have 'completed' status.""" return all(details.get("status") == 
"completed" for details in self.action_status.values()) + def is_workflow_done(self) -> bool: + """Return True if all actions are in a terminal state (completed or failed).""" + terminal = {"completed", "failed"} + return all( + details.get("status") in terminal for details in self.action_status.values() + ) + def has_any_failed(self) -> bool: """Return True if any action has 'failed' status.""" return any(details.get("status") == "failed" for details in self.action_status.values()) diff --git a/agent_actions/workflow/models.py b/agent_actions/workflow/models.py index 9a73186..a0a6d5e 100644 --- a/agent_actions/workflow/models.py +++ b/agent_actions/workflow/models.py @@ -34,6 +34,7 @@ class WorkflowRuntimeConfig: use_tools: bool run_upstream: bool = False run_downstream: bool = False + fresh: bool = False manager: Any = None # ConfigManager instance project_root: Path | None = None diff --git a/agent_actions/workflow/parallel/action_executor.py b/agent_actions/workflow/parallel/action_executor.py index 36d822c..12c9ede 100644 --- a/agent_actions/workflow/parallel/action_executor.py +++ b/agent_actions/workflow/parallel/action_executor.py @@ -4,10 +4,13 @@ import asyncio import copy +import logging from dataclasses import dataclass from datetime import datetime from typing import Any +logger = logging.getLogger(__name__) + from rich.console import Console from agent_actions.errors import WorkflowError, get_error_detail @@ -177,7 +180,7 @@ async def _execute_single_action(self, action_name: str, action_indices: dict, a self._fire_action_result_event(action_name, original_idx, total_actions, result) if not result.success: - raise result.error + logger.warning("Action '%s' failed: %s", action_name, result.error) async def _execute_parallel_actions(self, params: ParallelExecutionParams): """Execute multiple actions in parallel.""" @@ -228,10 +231,12 @@ async def run_with_limit(action): if errors: error_details = "\n".join([f" - {action}: {str(exc)}" for action, exc in 
errors]) - error_msg = ( - f"Multiple actions failed in parallel action {params.level_idx}:\n{error_details}" + logger.warning( + "Level %d: %d action(s) failed:\n%s", + params.level_idx, + len(errors), + error_details, ) - raise WorkflowError(error_msg, context={"level_idx": params.level_idx}) def _fire_action_result_event(self, action_name: str, idx: int, total: int, result): """Fire action complete or failed event for an execution result.""" @@ -268,15 +273,15 @@ def _check_batch_status( batch_pending = state_manager.get_batch_submitted_actions(level_actions) if batch_pending: - # Check for partial failures + # Log partial failures but don't raise — circuit breaker handles cascade failed_actions = state_manager.get_failed_actions(level_actions) if failed_actions: - error_msg = ( - f"Partial failure in parallel action {level_idx}: " - f"{', '.join(failed_actions)} failed while " - "batch jobs were submitted" + logger.warning( + "Level %d: %s failed while batch jobs pending for %s", + level_idx, + ", ".join(failed_actions), + ", ".join(batch_pending), ) - raise WorkflowError(error_msg, context={"level_idx": level_idx}) # Batch jobs submitted, need to wait duration = (datetime.now() - start_time).total_seconds() @@ -295,8 +300,8 @@ async def execute_level_async(self, params: LevelExecutionParams) -> bool: Returns: True if level completed, False if batch jobs pending. - Raises: - WorkflowError: If any action fails during execution. + Failed actions are logged but do not raise — the circuit breaker + in ActionExecutor handles downstream skipping. 
""" start_time = datetime.now() diff --git a/tests/unit/workflow/test_coordinator_sequential.py b/tests/unit/workflow/test_coordinator_sequential.py index 785d6b7..2be710a 100644 --- a/tests/unit/workflow/test_coordinator_sequential.py +++ b/tests/unit/workflow/test_coordinator_sequential.py @@ -3,8 +3,6 @@ from datetime import datetime from unittest.mock import MagicMock, patch -import pytest - from agent_actions.workflow.coordinator import AgentWorkflow from agent_actions.workflow.executor import ActionExecutionResult, ExecutionMetrics from agent_actions.workflow.models import ( @@ -126,15 +124,15 @@ def test_batch_submitted_stops(self): assert should_stop is True - def test_failure_raises(self): - """Failed result should raise the error.""" + def test_failure_continues(self): + """Failed result should log and continue (circuit breaker handles downstream).""" wf = _build_workflow() wf.services.core.state_manager.is_completed.return_value = False error = RuntimeError("agent crashed") wf.services.core.action_executor.execute_action_sync.return_value = _failed_result(error) - with pytest.raises(RuntimeError, match="agent crashed"): - wf._run_single_action(0, "agent_a", 2) + should_stop = wf._run_single_action(0, "agent_a", 2) + assert should_stop is False def test_fires_agent_start(self): """Should fire agent_start event before execution.""" @@ -186,6 +184,7 @@ def test_batch_stops_early_returns_none(self): status="batch_submitted", output_folder=None ) wf.services.core.state_manager.is_workflow_complete.return_value = False + wf.services.core.state_manager.is_workflow_done.return_value = False mgr = MagicMock() mgr.context.return_value.__enter__ = MagicMock() @@ -195,23 +194,23 @@ def test_batch_stops_early_returns_none(self): assert result is None - def test_exception_calls_handle_workflow_error_and_reraises(self): - """Exception should call handle_workflow_error and re-raise.""" + def test_failure_returns_completed_with_failures(self): + """Failed action should 
produce completed_with_failures, not raise.""" wf = _build_workflow(execution_order=["agent_a"]) wf.services.core.state_manager.is_completed.return_value = False error = RuntimeError("boom") wf.services.core.action_executor.execute_action_sync.return_value = _failed_result(error) + wf.services.core.state_manager.is_workflow_complete.return_value = False + wf.services.core.state_manager.is_workflow_done.return_value = True + wf.services.core.state_manager.get_failed_actions.return_value = ["agent_a"] mgr = MagicMock() mgr.context.return_value.__enter__ = MagicMock() mgr.context.return_value.__exit__ = MagicMock(return_value=False) - with ( - patch("agent_actions.workflow.coordinator.get_manager", return_value=mgr), - pytest.raises(RuntimeError, match="boom"), - ): - wf._run_workflow_with_context(datetime.now()) + with patch("agent_actions.workflow.coordinator.get_manager", return_value=mgr): + result = wf._run_workflow_with_context(datetime.now()) - wf.event_logger.handle_workflow_error.assert_called_once() + assert result == ("completed_with_failures", {"failed": ["agent_a"]}) assert wf.state.failed is True def test_downstream_resolved_after_completion(self): diff --git a/tests/unit/workflow/test_executor_lifecycle.py b/tests/unit/workflow/test_executor_lifecycle.py index 94f7192..e438f1b 100644 --- a/tests/unit/workflow/test_executor_lifecycle.py +++ b/tests/unit/workflow/test_executor_lifecycle.py @@ -53,6 +53,7 @@ def test_completed_with_output_skips(self, executor, mock_deps): mock_deps.state_manager.get_status.return_value = "completed" storage = MagicMock() storage.list_target_files.return_value = ["file1.json"] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage result = executor.execute_action_sync( @@ -74,6 +75,7 @@ def test_completed_no_output_reruns(self, executor, mock_deps): mock_deps.state_manager.get_status.return_value = "completed" storage = MagicMock() storage.list_target_files.return_value = [] + 
storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage mock_deps.skip_evaluator.should_skip_action.return_value = False @@ -99,6 +101,7 @@ def test_storage_error_during_verify_reruns_agent(self, executor, mock_deps): mock_deps.state_manager.get_status.return_value = "completed" storage = MagicMock() storage.list_target_files.side_effect = OSError("SQLite lock") + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage mock_deps.skip_evaluator.should_skip_action.return_value = False @@ -407,6 +410,7 @@ def test_with_files_returns_skip(self, executor, mock_deps): """Agent with output files should be skipped (already done).""" storage = MagicMock() storage.list_target_files.return_value = ["file.json"] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage should_skip, result = executor._verify_completion_status("agent_a") @@ -418,6 +422,7 @@ def test_no_files_resets_to_pending(self, executor, mock_deps): """Agent with no output files should be reset to pending.""" storage = MagicMock() storage.list_target_files.return_value = [] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage should_skip, result = executor._verify_completion_status("agent_a") @@ -439,6 +444,7 @@ def test_storage_error_resets_to_pending(self, executor, mock_deps, exc): """Any exception during verification should reset to pending and re-run.""" storage = MagicMock() storage.list_target_files.side_effect = exc + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage should_skip, result = executor._verify_completion_status("agent_a") diff --git a/tests/unit/workflow/test_limits.py b/tests/unit/workflow/test_limits.py index 93070f9..5cb4b7a 100644 --- a/tests/unit/workflow/test_limits.py +++ b/tests/unit/workflow/test_limits.py @@ -333,6 +333,7 @@ def test_same_limits_skips_action(self, executor, 
mock_deps): storage = MagicMock() storage.list_target_files.return_value = ["file.json"] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage result = executor.execute_action_sync( @@ -351,6 +352,7 @@ def test_no_limits_old_status_no_invalidation(self, executor, mock_deps): storage = MagicMock() storage.list_target_files.return_value = ["file.json"] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage result = executor.execute_action_sync( diff --git a/tests/workflow/test_stale_completion_verification.py b/tests/workflow/test_stale_completion_verification.py index 7f7a9ea..2c5875b 100644 --- a/tests/workflow/test_stale_completion_verification.py +++ b/tests/workflow/test_stale_completion_verification.py @@ -34,6 +34,7 @@ def _make_executor(storage_has_data: bool) -> ActionExecutor: action_runner = MagicMock() storage_backend = MagicMock() storage_backend.list_target_files.return_value = ["output.json"] if storage_has_data else [] + storage_backend.has_disposition.return_value = False action_runner.storage_backend = storage_backend deps = ExecutorDependencies( From 8f10b487e46b291e848665befff85f0b998a6e81 Mon Sep 17 00:00:00 2001 From: Muizz Lateef Date: Wed, 1 Apr 2026 20:38:30 +0100 Subject: [PATCH 2/8] style: apply ruff format to 5 files --- agent_actions/storage/backend.py | 4 +--- .../validation/preflight/resolution_service.py | 5 +---- agent_actions/validation/run_validator.py | 4 +--- .../static_analyzer/workflow_static_analyzer.py | 4 +--- agent_actions/workflow/managers/state.py | 10 ++-------- 5 files changed, 6 insertions(+), 21 deletions(-) diff --git a/agent_actions/storage/backend.py b/agent_actions/storage/backend.py index 484d77e..34a2151 100644 --- a/agent_actions/storage/backend.py +++ b/agent_actions/storage/backend.py @@ -156,9 +156,7 @@ def delete_target(self, action_name: str) -> int: authors are forced to implement it and ``--fresh`` cannot silently leave 
stale data behind. """ - raise NotImplementedError( - f"{type(self).__name__} must implement delete_target()" - ) + raise NotImplementedError(f"{type(self).__name__} must implement delete_target()") def close(self) -> None: # noqa: B027 """Close the storage backend and release resources.""" diff --git a/agent_actions/validation/preflight/resolution_service.py b/agent_actions/validation/preflight/resolution_service.py index 8239024..5265c9f 100644 --- a/agent_actions/validation/preflight/resolution_service.py +++ b/agent_actions/validation/preflight/resolution_service.py @@ -203,10 +203,7 @@ def _check_seed_file_references(self) -> list[StaticTypeError]: errors.append( StaticTypeError( - message=( - f"Seed file not found: {file_spec} " - f"(resolved to {resolved})" - ), + message=(f"Seed file not found: {file_spec} (resolved to {resolved})"), location=FieldLocation( agent_name=action_name, config_field=f"context_scope.seed_path.{field_name}", diff --git a/agent_actions/validation/run_validator.py b/agent_actions/validation/run_validator.py index d9b2f12..656a1bd 100644 --- a/agent_actions/validation/run_validator.py +++ b/agent_actions/validation/run_validator.py @@ -26,6 +26,4 @@ class RunCommandArgs(BaseModel): downstream: bool = Field( False, description="Execute all downstream workflows that depend on this workflow" ) - fresh: bool = Field( - False, description="Clear stored results and status before execution" - ) + fresh: bool = Field(False, description="Clear stored results and status before execution") diff --git a/agent_actions/validation/static_analyzer/workflow_static_analyzer.py b/agent_actions/validation/static_analyzer/workflow_static_analyzer.py index 0d9124c..b6d0cb2 100644 --- a/agent_actions/validation/static_analyzer/workflow_static_analyzer.py +++ b/agent_actions/validation/static_analyzer/workflow_static_analyzer.py @@ -562,9 +562,7 @@ def _check_lineage_reachability(self) -> list[StaticTypeWarning]: return warnings - def 
_trace_field_through_chain( - self, source: str, field: str, target: str - ) -> bool: + def _trace_field_through_chain(self, source: str, field: str, target: str) -> bool: """Check if a field from *source* can reach *target* through passthrough chains. BFS backwards from *target* through dependencies. At each intermediate diff --git a/agent_actions/workflow/managers/state.py b/agent_actions/workflow/managers/state.py index 2ab2a32..35c3925 100644 --- a/agent_actions/workflow/managers/state.py +++ b/agent_actions/workflow/managers/state.py @@ -81,11 +81,7 @@ def is_failed(self, action_name: str) -> bool: def get_pending_actions(self, agents: list[str]) -> list[str]: """Return actions that are not yet completed or failed (runnable).""" terminal = {"completed", "failed"} - return [ - agent - for agent in agents - if self.get_status(agent) not in terminal - ] + return [agent for agent in agents if self.get_status(agent) not in terminal] def get_batch_submitted_actions(self, agents: list[str]) -> list[str]: """Return actions with batch jobs submitted.""" @@ -123,9 +119,7 @@ def is_workflow_complete(self) -> bool: def is_workflow_done(self) -> bool: """Return True if all actions are in a terminal state (completed or failed).""" terminal = {"completed", "failed"} - return all( - details.get("status") in terminal for details in self.action_status.values() - ) + return all(details.get("status") in terminal for details in self.action_status.values()) def has_any_failed(self) -> bool: """Return True if any action has 'failed' status.""" From af5463620c813327a2dbbd0851e934b92ea17fec Mon Sep 17 00:00:00 2001 From: Muizz Lateef Date: Wed, 1 Apr 2026 20:58:30 +0100 Subject: [PATCH 3/8] =?UTF-8?q?fix:=20address=20PR=20review=20=E2=80=94=20?= =?UTF-8?q?wire=20shared=20utility,=20add=2060=20tests,=20fix=20issues?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Review fixes: 1. 
Wire resolve_seed_path() into StaticDataLoader._resolve_path() — shared utility now used by both pre-flight and runtime (no more duplicates) 2. Add 60 new tests covering all new production code: - resolve_seed_path (5 tests) - WorkflowResolutionService API keys/seed files/vendor compat (16 tests) - Circuit breaker _check_upstream_health/_handle_dependency_skip (13 tests) - delete_target SQLite (3 tests) - is_workflow_done + get_pending_actions (8 tests) - Drop directive + lineage reachability validation (15 tests) 3. Restore test for unexpected-crash exception path (handle_workflow_error) 4. Fix queue.pop(0) → deque.popleft() in BFS lineage tracer 5. Fix logger placement in parallel/action_executor.py (after all imports) 6. Add public reset() to ActionStateManager (replace private method calls) 7. Fix finalize_workflow ordering in async path — set state.failed before finalize so handlers see correct state --- agent_actions/prompt/context/static_loader.py | 36 +- .../workflow_static_analyzer.py | 6 +- agent_actions/workflow/coordinator.py | 6 +- agent_actions/workflow/managers/state.py | 5 + .../workflow/parallel/action_executor.py | 4 +- tests/unit/storage/test_delete_target.py | 60 +++ tests/unit/utils/test_path_security.py | 76 +++ .../unit/validation/test_drop_and_lineage.py | 448 ++++++++++++++++++ .../validation/test_resolution_service.py | 294 ++++++++++++ .../managers/test_state_extensions.py | 94 ++++ tests/unit/workflow/test_circuit_breaker.py | 197 ++++++++ .../workflow/test_coordinator_sequential.py | 22 + 12 files changed, 1220 insertions(+), 28 deletions(-) create mode 100644 tests/unit/storage/test_delete_target.py create mode 100644 tests/unit/utils/test_path_security.py create mode 100644 tests/unit/validation/test_drop_and_lineage.py create mode 100644 tests/unit/validation/test_resolution_service.py create mode 100644 tests/unit/workflow/managers/test_state_extensions.py create mode 100644 tests/unit/workflow/test_circuit_breaker.py diff --git 
a/agent_actions/prompt/context/static_loader.py b/agent_actions/prompt/context/static_loader.py index a86227f..d52a8c2 100644 --- a/agent_actions/prompt/context/static_loader.py +++ b/agent_actions/prompt/context/static_loader.py @@ -124,7 +124,14 @@ def _parse_file_path(self, file_spec: str, field_name: str) -> str: return file_spec # Use as-is def _resolve_path(self, file_path: str, field_name: str) -> Path: - """Resolve file path relative to static_data_dir with security validation.""" + """Resolve file path relative to static_data_dir with security validation. + + Delegates core traversal prevention to the shared ``resolve_seed_path`` + utility and wraps any failure in a ``StaticDataLoadError`` with rich + context for diagnostics. + """ + from agent_actions.utils.path_security import resolve_seed_path + path = Path(file_path) # Reject absolute paths immediately @@ -140,40 +147,27 @@ def _resolve_path(self, file_path: str, field_name: str) -> Path: }, ) - # Resolve relative to static_data_dir - resolved = (self.static_data_dir / path).resolve() - - # Validate security - self._validate_path_security(resolved, field_name, file_path) - - logger.debug("Resolved path for field '%s': %s", field_name, resolved) - return resolved - - def _validate_path_security( - self, resolved_path: Path, field_name: str, original_path: str - ) -> None: - """Validate that resolved path doesn't escape static_data_dir.""" try: - # This will raise ValueError if path is outside static_data_dir - resolved_path.relative_to(self.static_data_dir) + resolved = resolve_seed_path(file_path, self.static_data_dir) except ValueError as exc: logger.error( - "Path traversal attempt detected for field '%s': %s -> %s", + "Path traversal attempt detected for field '%s': %s", field_name, - original_path, - resolved_path, + file_path, ) raise StaticDataLoadError( f"Static data field '{field_name}': File path escapes static data directory", context={ "field_name": field_name, - "original_path": 
original_path, - "resolved_path": str(resolved_path), + "original_path": file_path, "static_data_dir": str(self.static_data_dir), "error_type": "path_traversal_attempt", }, ) from exc + logger.debug("Resolved path for field '%s': %s", field_name, resolved) + return resolved + def _load_file(self, file_path: Path, field_name: str) -> Any: """Load file content based on file extension.""" # Check if file exists diff --git a/agent_actions/validation/static_analyzer/workflow_static_analyzer.py b/agent_actions/validation/static_analyzer/workflow_static_analyzer.py index b6d0cb2..a8846dc 100644 --- a/agent_actions/validation/static_analyzer/workflow_static_analyzer.py +++ b/agent_actions/validation/static_analyzer/workflow_static_analyzer.py @@ -569,15 +569,17 @@ def _trace_field_through_chain(self, source: str, field: str, target: str) -> bo node, checks whether the field survives (exact passthrough, wildcard passthrough, or dynamic schema). """ + from collections import deque + target_node = self.graph.get_node(target) if not target_node: return False visited: set[str] = set() - queue = list(target_node.dependencies) + queue = deque(target_node.dependencies) while queue: - current = queue.pop(0) + current = queue.popleft() if current in visited: continue visited.add(current) diff --git a/agent_actions/workflow/coordinator.py b/agent_actions/workflow/coordinator.py index e4d16fa..4e3fda3 100644 --- a/agent_actions/workflow/coordinator.py +++ b/agent_actions/workflow/coordinator.py @@ -152,8 +152,7 @@ def _clear_for_fresh_run(self) -> None: self.storage_backend.clear_disposition(action_name) except Exception as e: logger.warning("Failed to clear stored data for %s: %s", action_name, e) - self.services.core.state_manager._initialize_default_status() - self.services.core.state_manager._save_status() + self.services.core.state_manager.reset() self.console.print( "[yellow]--fresh: cleared stored results and reset all actions to pending[/yellow]" ) @@ -343,9 +342,9 @@ async 
def async_run(self, concurrency_limit: int = 5): state_mgr = self.services.core.state_manager duration = (datetime.now() - workflow_start).total_seconds() - self.event_logger.finalize_workflow(elapsed_time=duration) if state_mgr.is_workflow_complete(): + self.event_logger.finalize_workflow(elapsed_time=duration) downstream_success = self._resolve_downstream_workflows() if not downstream_success: return None @@ -353,6 +352,7 @@ async def async_run(self, concurrency_limit: int = 5): if state_mgr.is_workflow_done(): self.state.failed = True + self.event_logger.finalize_workflow(elapsed_time=duration) failed = state_mgr.get_failed_actions(self.execution_order) return ("completed_with_failures", {"failed": failed}) diff --git a/agent_actions/workflow/managers/state.py b/agent_actions/workflow/managers/state.py index 35c3925..7436140 100644 --- a/agent_actions/workflow/managers/state.py +++ b/agent_actions/workflow/managers/state.py @@ -31,6 +31,11 @@ def _load_status(self): else: self._initialize_default_status() + def reset(self) -> None: + """Reset all actions to 'pending' status and persist.""" + self._initialize_default_status() + self._save_status() + def _initialize_default_status(self): """Initialize all actions with 'pending' status.""" self.action_status = {action: {"status": "pending"} for action in self.execution_order} diff --git a/agent_actions/workflow/parallel/action_executor.py b/agent_actions/workflow/parallel/action_executor.py index 12c9ede..4e5e4d6 100644 --- a/agent_actions/workflow/parallel/action_executor.py +++ b/agent_actions/workflow/parallel/action_executor.py @@ -9,14 +9,14 @@ from datetime import datetime from typing import Any -logger = logging.getLogger(__name__) - from rich.console import Console from agent_actions.errors import WorkflowError, get_error_detail from agent_actions.logging.core.manager import fire_event from agent_actions.logging.events import ActionCompleteEvent, ActionFailedEvent +logger = logging.getLogger(__name__) + 
@dataclass class ParallelExecutionParams: diff --git a/tests/unit/storage/test_delete_target.py b/tests/unit/storage/test_delete_target.py new file mode 100644 index 0000000..bbeafe9 --- /dev/null +++ b/tests/unit/storage/test_delete_target.py @@ -0,0 +1,60 @@ +"""Tests for SQLiteBackend.delete_target().""" + +import pytest + +from agent_actions.storage.backends.sqlite_backend import SQLiteBackend + + +class TestDeleteTarget: + """Tests for delete_target() method.""" + + @pytest.fixture + def backend(self, tmp_path): + """Create a fresh SQLite backend for testing.""" + db_path = tmp_path / "agent_io" / "test.db" + backend = SQLiteBackend(str(db_path), "test_workflow") + backend.initialize() + yield backend + backend.close() + + def test_deletes_matching_rows_returns_count(self, backend): + """Deletes matching rows and returns the count.""" + backend.write_target("action_a", "batch_001.json", [{"id": 1}]) + backend.write_target("action_a", "batch_002.json", [{"id": 2}]) + backend.write_target("action_b", "batch_001.json", [{"id": 3}]) + + deleted = backend.delete_target("action_a") + + assert deleted == 2 + # action_b data should remain + remaining = backend.list_target_files("action_b") + assert remaining == ["batch_001.json"] + + def test_returns_zero_when_no_matching_rows(self, backend): + """Returns 0 when no matching rows exist.""" + deleted = backend.delete_target("nonexistent_action") + assert deleted == 0 + + def test_works_with_real_sqlite_backend(self, tmp_path): + """Full roundtrip: write, verify, delete, verify gone.""" + db_path = tmp_path / "test.db" + backend = SQLiteBackend(str(db_path), "wf") + backend.initialize() + + try: + data = [{"field": "value1"}, {"field": "value2"}] + backend.write_target("my_action", "output.json", data) + + # Verify data exists + files = backend.list_target_files("my_action") + assert len(files) == 1 + + # Delete + deleted = backend.delete_target("my_action") + assert deleted == 1 + + # Verify gone + files = 
backend.list_target_files("my_action") + assert len(files) == 0 + finally: + backend.close() diff --git a/tests/unit/utils/test_path_security.py b/tests/unit/utils/test_path_security.py new file mode 100644 index 0000000..3630170 --- /dev/null +++ b/tests/unit/utils/test_path_security.py @@ -0,0 +1,76 @@ +"""Tests for resolve_seed_path() in path_security module.""" + +import pytest + +from agent_actions.utils.path_security import FILE_PREFIX, resolve_seed_path + + +class TestResolveSeedPath: + """Tests for resolve_seed_path().""" + + def test_valid_file_reference_resolves(self, tmp_path): + """A $file: prefixed reference resolves to the correct absolute path.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + (seed_dir / "data.json").write_text("{}") + + result = resolve_seed_path("$file:data.json", seed_dir) + + assert result == (seed_dir / "data.json").resolve() + + def test_reference_without_prefix_resolves(self, tmp_path): + """A reference without $file: prefix resolves correctly.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + (seed_dir / "data.json").write_text("{}") + + result = resolve_seed_path("data.json", seed_dir) + + assert result == (seed_dir / "data.json").resolve() + + def test_empty_file_spec_raises_value_error(self, tmp_path): + """Empty file_spec raises ValueError.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + + with pytest.raises(ValueError, match="Empty file spec"): + resolve_seed_path("", seed_dir) + + def test_path_traversal_raises_value_error(self, tmp_path): + """Path traversal attempt (../../etc/passwd) raises ValueError.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + + with pytest.raises(ValueError, match="Seed file path escapes base directory"): + resolve_seed_path("../../etc/passwd", seed_dir) + + def test_path_traversal_with_prefix_raises(self, tmp_path): + """Path traversal with $file: prefix also raises ValueError.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + + with 
pytest.raises(ValueError, match="Seed file path escapes base directory"): + resolve_seed_path("$file:../../etc/passwd", seed_dir) + + def test_empty_path_after_prefix_raises(self, tmp_path): + """$file: prefix with no path after it raises ValueError.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + + with pytest.raises(ValueError, match="Empty path after prefix"): + resolve_seed_path("$file:", seed_dir) + + def test_subdirectory_path_resolves(self, tmp_path): + """Subdirectory paths within base_dir resolve correctly.""" + seed_dir = tmp_path / "seed_data" + sub = seed_dir / "subdir" + sub.mkdir(parents=True) + (sub / "nested.json").write_text("{}") + + result = resolve_seed_path("$file:subdir/nested.json", seed_dir) + + assert result == (sub / "nested.json").resolve() + + def test_file_prefix_constant(self): + """FILE_PREFIX constant is set correctly.""" + assert FILE_PREFIX == "$file:" diff --git a/tests/unit/validation/test_drop_and_lineage.py b/tests/unit/validation/test_drop_and_lineage.py new file mode 100644 index 0000000..840f896 --- /dev/null +++ b/tests/unit/validation/test_drop_and_lineage.py @@ -0,0 +1,448 @@ +"""Tests for drop directives and lineage reachability in WorkflowStaticAnalyzer.""" + +from agent_actions.config.schema import ActionKind +from agent_actions.validation.static_analyzer.data_flow_graph import ( + DataFlowGraph, + DataFlowNode, + OutputSchema, +) +from agent_actions.validation.static_analyzer.workflow_static_analyzer import ( + WorkflowStaticAnalyzer, +) + + +def _build_analyzer_with_graph(workflow_config, graph): + """Build a WorkflowStaticAnalyzer and inject a pre-built graph.""" + analyzer = WorkflowStaticAnalyzer(workflow_config) + analyzer.graph = graph + analyzer._built = True + return analyzer + + +class TestCheckDropDirectives: + """Tests for _check_drop_directives().""" + + def _make_graph_with_upstream( + self, + schema_fields=None, + observe_fields=None, + passthrough_fields=None, + is_dynamic=False, + 
is_schemaless=False, + ): + """Build a graph with source -> upstream_action -> downstream_action.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="upstream", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema( + schema_fields=schema_fields or set(), + observe_fields=observe_fields or set(), + passthrough_fields=passthrough_fields or set(), + is_dynamic=is_dynamic, + is_schemaless=is_schemaless, + ), + dependencies={"source"}, + ) + ) + return graph + + def test_drop_on_schema_field_no_error(self): + """Drop on a schema field produces no error.""" + graph = self._make_graph_with_upstream(schema_fields={"summary", "title"}) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.summary"], + "observe": ["upstream.title"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 0 + + def test_drop_on_passthrough_field_produces_error(self): + """Drop on a passthrough field produces an error with a hint about passthrough.""" + graph = self._make_graph_with_upstream( + schema_fields={"summary"}, + passthrough_fields={"forwarded_field"}, + ) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.forwarded_field"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 1 + assert "passthrough" in errors[0].message.lower() + assert "forwarded_field" in errors[0].message + + def test_drop_on_nonexistent_field_produces_error(self): + """Drop on a non-existent field produces an error with available fields.""" + graph = 
self._make_graph_with_upstream(schema_fields={"summary", "title"}) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.nonexistent"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 1 + msg = errors[0].message.lower() + assert "non-existent" in msg or "nonexistent" in msg + assert errors[0].available_fields # Should have available fields + + def test_drop_with_wildcard_no_error(self): + """Drop with wildcard produces no error.""" + graph = self._make_graph_with_upstream(schema_fields={"summary"}) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.*"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 0 + + def test_drop_on_dynamic_schema_no_error(self): + """Drop on a dynamic schema produces no error (skipped).""" + graph = self._make_graph_with_upstream(is_dynamic=True) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.anything"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 0 + + def test_drop_on_observe_field_no_error(self): + """Drop on an observe field (not schema) produces no error.""" + graph = self._make_graph_with_upstream( + schema_fields={"summary"}, + observe_fields={"observed_field"}, + ) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.observed_field"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert 
len(errors) == 0 + + +class TestCheckLineageReachability: + """Tests for _check_lineage_reachability().""" + + def test_direct_dependency_observe_no_warning(self): + """Observing a field from a direct dependency produces no warning.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"result"}), + dependencies={"A"}, + ) + ) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + { + "name": "B", + "depends_on": ["A"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + warnings = analyzer._check_lineage_reachability() + assert len(warnings) == 0 + + def test_transitive_observe_with_wildcard_passthrough_no_warning(self): + """Transitive observe with wildcard passthrough produces no warning.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema( + schema_fields={"result"}, + passthrough_wildcard_sources={"A"}, + ), + dependencies={"A"}, + ) + ) + graph.add_node( + DataFlowNode( + name="C", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"final"}), + dependencies={"B"}, + ) + ) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + { + 
"name": "B", + "depends_on": ["A"], + "context_scope": {"passthrough": ["A.*"]}, + }, + { + "name": "C", + "depends_on": ["B"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + warnings = analyzer._check_lineage_reachability() + assert len(warnings) == 0 + + def test_transitive_observe_with_explicit_field_passthrough_no_warning(self): + """Transitive observe with explicit field passthrough produces no warning.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema( + schema_fields={"result"}, + passthrough_fields={"field_x"}, + ), + dependencies={"A"}, + ) + ) + graph.add_node( + DataFlowNode( + name="C", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"final"}), + dependencies={"B"}, + ) + ) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + { + "name": "B", + "depends_on": ["A"], + "context_scope": {"passthrough": ["A.field_x"]}, + }, + { + "name": "C", + "depends_on": ["B"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + warnings = analyzer._check_lineage_reachability() + assert len(warnings) == 0 + + def test_transitive_observe_without_passthrough_produces_warning(self): + """Transitive observe with NO passthrough produces a warning.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + 
agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"result"}), + dependencies={"A"}, + ) + ) + graph.add_node( + DataFlowNode( + name="C", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"final"}), + dependencies={"B"}, + ) + ) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + {"name": "B", "depends_on": ["A"]}, + { + "name": "C", + "depends_on": ["B"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + warnings = analyzer._check_lineage_reachability() + assert len(warnings) == 1 + assert "field_x" in warnings[0].message + assert "passthrough" in warnings[0].hint.lower() + + def test_dynamic_intermediate_no_warning(self): + """Dynamic intermediate schema produces no warning (data may survive).""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(is_dynamic=True), + dependencies={"A"}, + ) + ) + graph.add_node( + DataFlowNode( + name="C", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"final"}), + dependencies={"B"}, + ) + ) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + {"name": "B", "depends_on": ["A"]}, + { + "name": "C", + "depends_on": ["B"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + warnings = 
analyzer._check_lineage_reachability() + assert len(warnings) == 0 diff --git a/tests/unit/validation/test_resolution_service.py b/tests/unit/validation/test_resolution_service.py new file mode 100644 index 0000000..0b15f40 --- /dev/null +++ b/tests/unit/validation/test_resolution_service.py @@ -0,0 +1,294 @@ +"""Tests for WorkflowResolutionService pre-flight checks.""" + +from agent_actions.validation.preflight.resolution_service import ( + WorkflowResolutionService, +) + + +class TestApiKeyChecks: + """Tests for _check_api_keys() via resolve_all().""" + + def test_missing_api_key_env_var_detected(self, monkeypatch): + """Missing API key env var produces an error with the correct message.""" + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "summarizer": {"model_vendor": "openai"}, + }, + ) + result = svc.resolve_all() + + assert not result.is_valid + assert len(result.errors) == 1 + err = result.errors[0] + assert "OPENAI_API_KEY" in err.message + assert "summarizer" in err.message + + def test_present_api_key_passes(self, monkeypatch): + """When the API key env var is set, no error is produced.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key") + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "summarizer": {"model_vendor": "openai"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [ + e + for e in result.errors + if "api_key" in e.location.config_field.lower() or "API key" in e.message + ] + assert len(api_key_errors) == 0 + + def test_tool_vendor_skipped(self, monkeypatch): + """Tool vendor actions are skipped (NO_KEY_REQUIRED sentinel).""" + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "my_tool": {"model_vendor": "tool"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = 
[e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 0 + + def test_hitl_vendor_skipped(self, monkeypatch): + """HITL vendor actions are skipped (NO_KEY_REQUIRED sentinel).""" + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "review": {"model_vendor": "hitl"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 0 + + def test_unknown_vendor_skipped(self, monkeypatch): + """Unknown vendor produces no api-key error (no config to check against).""" + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "custom": {"model_vendor": "totally_unknown_vendor"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 0 + + def test_custom_api_key_dollar_resolved_as_env_var(self, monkeypatch): + """Custom api_key starting with $ is resolved as an env var name.""" + monkeypatch.delenv("MY_CUSTOM_KEY", raising=False) + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "my_action": {"model_vendor": "openai", "api_key": "$MY_CUSTOM_KEY"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 1 + assert "MY_CUSTOM_KEY" in api_key_errors[0].message + + def test_skip_env_validation_flag(self, monkeypatch): + """AA_SKIP_ENV_VALIDATION=1 skips all env checks.""" + monkeypatch.setenv("AA_SKIP_ENV_VALIDATION", "1") + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "summarizer": {"model_vendor": "openai"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert 
len(api_key_errors) == 0 + + def test_empty_string_env_var_treated_as_missing(self, monkeypatch): + """Empty string env var is treated as missing.""" + monkeypatch.setenv("OPENAI_API_KEY", "") + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "summarizer": {"model_vendor": "openai"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 1 + + +class TestSeedFileChecks: + """Tests for _check_seed_file_references().""" + + def test_missing_seed_file_detected(self, tmp_path): + """Missing seed file is detected with available files in hint.""" + # Setup directory structure: project/agent_config/workflow.yml + project/seed_data/ + project = tmp_path / "project" + agent_config = project / "agent_config" + agent_config.mkdir(parents=True) + seed_data = project / "seed_data" + seed_data.mkdir() + (seed_data / "existing.json").write_text("{}") + + workflow_path = str(agent_config / "workflow.yml") + + svc = WorkflowResolutionService( + action_configs={ + "loader": { + "context_scope": { + "seed_path": {"field1": "$file:missing.json"}, + }, + }, + }, + workflow_config_path=workflow_path, + ) + result = svc.resolve_all() + + seed_errors = [e for e in result.errors if "Seed file not found" in e.message] + assert len(seed_errors) == 1 + assert "existing.json" in seed_errors[0].hint + + def test_existing_seed_file_passes(self, tmp_path): + """Existing seed file produces no error.""" + project = tmp_path / "project" + agent_config = project / "agent_config" + agent_config.mkdir(parents=True) + seed_data = project / "seed_data" + seed_data.mkdir() + (seed_data / "data.json").write_text("{}") + + workflow_path = str(agent_config / "workflow.yml") + + svc = WorkflowResolutionService( + action_configs={ + "loader": { + "context_scope": { + "seed_path": {"field1": "$file:data.json"}, + }, + }, + }, + 
workflow_config_path=workflow_path, + ) + result = svc.resolve_all() + + seed_errors = [e for e in result.errors if "seed" in e.message.lower()] + assert len(seed_errors) == 0 + + def test_path_traversal_caught(self, tmp_path): + """Path traversal in seed file reference is caught.""" + project = tmp_path / "project" + agent_config = project / "agent_config" + agent_config.mkdir(parents=True) + seed_data = project / "seed_data" + seed_data.mkdir() + + workflow_path = str(agent_config / "workflow.yml") + + svc = WorkflowResolutionService( + action_configs={ + "loader": { + "context_scope": { + "seed_path": {"field1": "$file:../../etc/passwd"}, + }, + }, + }, + workflow_config_path=workflow_path, + ) + result = svc.resolve_all() + + seed_errors = [e for e in result.errors if "escapes base directory" in e.message] + assert len(seed_errors) == 1 + + def test_no_seed_path_config_no_errors(self): + """No seed_path in config produces no errors.""" + svc = WorkflowResolutionService( + action_configs={ + "loader": {"context_scope": {}}, + }, + ) + result = svc.resolve_all() + + seed_errors = [e for e in result.errors if "seed" in e.message.lower()] + assert len(seed_errors) == 0 + + def test_seed_data_dir_missing_graceful_skip(self, tmp_path): + """When seed_data directory doesn't exist, gracefully skip (no errors).""" + project = tmp_path / "project" + agent_config = project / "agent_config" + agent_config.mkdir(parents=True) + # Intentionally do NOT create seed_data directory + + workflow_path = str(agent_config / "workflow.yml") + + svc = WorkflowResolutionService( + action_configs={ + "loader": { + "context_scope": { + "seed_path": {"field1": "$file:data.json"}, + }, + }, + }, + workflow_config_path=workflow_path, + ) + result = svc.resolve_all() + + seed_errors = [e for e in result.errors if "seed" in e.message.lower()] + assert len(seed_errors) == 0 + + +class TestVendorRunModeCompatibility: + """Tests for _check_vendor_run_mode_compatibility().""" + + def 
test_batch_mode_with_non_batch_vendor_produces_error(self): + """Batch mode with a non-batch vendor (e.g., ollama) produces error.""" + svc = WorkflowResolutionService( + action_configs={ + "local_action": {"model_vendor": "ollama", "run_mode": "batch"}, + }, + ) + result = svc.resolve_all() + + batch_errors = [e for e in result.errors if "batch" in e.message.lower()] + assert len(batch_errors) == 1 + assert "ollama" in batch_errors[0].message + + def test_online_mode_passes_for_any_vendor(self): + """Online mode passes for any vendor (no batch mode check needed).""" + svc = WorkflowResolutionService( + action_configs={ + "my_action": {"model_vendor": "ollama", "run_mode": "online"}, + }, + ) + result = svc.resolve_all() + + batch_errors = [e for e in result.errors if "batch" in e.message.lower()] + assert len(batch_errors) == 0 + + def test_batch_mode_with_batch_capable_vendor_passes(self): + """Batch mode with a batch-capable vendor passes.""" + svc = WorkflowResolutionService( + action_configs={ + "my_action": {"model_vendor": "openai", "run_mode": "batch"}, + }, + ) + result = svc.resolve_all() + + batch_errors = [e for e in result.errors if "batch" in e.message.lower()] + assert len(batch_errors) == 0 diff --git a/tests/unit/workflow/managers/test_state_extensions.py b/tests/unit/workflow/managers/test_state_extensions.py new file mode 100644 index 0000000..da40822 --- /dev/null +++ b/tests/unit/workflow/managers/test_state_extensions.py @@ -0,0 +1,94 @@ +"""Tests for ActionStateManager.is_workflow_done() and updated get_pending_actions().""" + +from agent_actions.workflow.managers.state import ActionStateManager + + +class TestIsWorkflowDone: + """Tests for is_workflow_done().""" + + def test_true_when_all_completed(self, tmp_path): + """is_workflow_done returns True when all actions are completed.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b"]) + mgr.update_status("a", "completed") + mgr.update_status("b", 
"completed") + + assert mgr.is_workflow_done() is True + + def test_true_when_mix_of_completed_and_failed(self, tmp_path): + """is_workflow_done returns True when all actions are completed or failed.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "failed") + mgr.update_status("c", "completed") + + assert mgr.is_workflow_done() is True + + def test_false_when_some_still_pending(self, tmp_path): + """is_workflow_done returns False when some actions are still pending.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "failed") + # c stays pending + + assert mgr.is_workflow_done() is False + + def test_false_when_some_still_running(self, tmp_path): + """is_workflow_done returns False when some actions are still running.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "running") + + assert mgr.is_workflow_done() is False + + +class TestGetPendingActionsExcludesFailed: + """Tests for get_pending_actions() excluding both completed AND failed.""" + + def test_excludes_completed(self, tmp_path): + """get_pending_actions excludes completed actions.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c"]) + mgr.update_status("a", "completed") + + pending = mgr.get_pending_actions(["a", "b", "c"]) + assert "a" not in pending + assert "b" in pending + assert "c" in pending + + def test_excludes_failed(self, tmp_path): + """get_pending_actions excludes failed actions.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "failed") + + pending = mgr.get_pending_actions(["a", "b", "c"]) + 
assert "a" not in pending + assert "b" not in pending + assert "c" in pending + + def test_excludes_both_completed_and_failed(self, tmp_path): + """get_pending_actions excludes both completed and failed.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c", "d"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "failed") + mgr.update_status("c", "running") + # d stays pending + + pending = mgr.get_pending_actions(["a", "b", "c", "d"]) + assert pending == ["c", "d"] + + def test_all_terminal_returns_empty(self, tmp_path): + """When all are completed/failed, returns empty list.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "failed") + + pending = mgr.get_pending_actions(["a", "b"]) + assert pending == [] diff --git a/tests/unit/workflow/test_circuit_breaker.py b/tests/unit/workflow/test_circuit_breaker.py new file mode 100644 index 0000000..cab5956 --- /dev/null +++ b/tests/unit/workflow/test_circuit_breaker.py @@ -0,0 +1,197 @@ +"""Tests for ActionExecutor circuit breaker methods.""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from agent_actions.storage.backend import DISPOSITION_FAILED, NODE_LEVEL_RECORD_ID +from agent_actions.workflow.executor import ( + ActionExecutionResult, + ActionExecutor, + ExecutorDependencies, +) +from agent_actions.workflow.managers.batch import BatchLifecycleManager +from agent_actions.workflow.managers.output import ActionOutputManager +from agent_actions.workflow.managers.skip import SkipEvaluator +from agent_actions.workflow.managers.state import ActionStateManager + + +@pytest.fixture +def mock_deps(): + """Create mock dependencies for ActionExecutor.""" + deps = MagicMock(spec=ExecutorDependencies) + deps.state_manager = MagicMock(spec=ActionStateManager) + deps.batch_manager = 
MagicMock(spec=BatchLifecycleManager) + deps.action_runner = MagicMock() + deps.skip_evaluator = MagicMock(spec=SkipEvaluator) + deps.output_manager = MagicMock(spec=ActionOutputManager) + deps.action_runner.execution_order = ["agent_a", "agent_b", "agent_c"] + return deps + + +@pytest.fixture +def executor(mock_deps): + """Create executor with mock dependencies.""" + return ActionExecutor(mock_deps) + + +class TestCheckUpstreamHealth: + """Tests for _check_upstream_health().""" + + def test_no_dependencies_returns_none(self, executor): + """No dependencies means all healthy — returns None.""" + config = {"dependencies": []} + result = executor._check_upstream_health("agent_b", config) + assert result is None + + def test_no_dependencies_key_returns_none(self, executor): + """Missing dependencies key means all healthy — returns None.""" + config = {} + result = executor._check_upstream_health("agent_b", config) + assert result is None + + def test_all_deps_healthy_returns_none(self, executor, mock_deps): + """All dependencies healthy — returns None.""" + mock_deps.state_manager.is_failed.return_value = False + mock_deps.action_runner.storage_backend = None + + config = {"dependencies": ["agent_a"]} + result = executor._check_upstream_health("agent_b", config) + assert result is None + + def test_dep_failed_via_state_manager(self, executor, mock_deps): + """One dep failed (state_manager.is_failed) returns dep name.""" + mock_deps.state_manager.is_failed.return_value = True + + config = {"dependencies": ["agent_a"]} + result = executor._check_upstream_health("agent_b", config) + assert result == "agent_a" + + def test_dep_failed_via_disposition(self, executor, mock_deps): + """One dep has DISPOSITION_FAILED in storage returns dep name.""" + mock_deps.state_manager.is_failed.return_value = False + storage = MagicMock() + storage.has_disposition.return_value = True + mock_deps.action_runner.storage_backend = storage + + config = {"dependencies": ["agent_a"]} + result = 
executor._check_upstream_health("agent_b", config) + + assert result == "agent_a" + storage.has_disposition.assert_called_once_with( + "agent_a", DISPOSITION_FAILED, record_id=NODE_LEVEL_RECORD_ID + ) + + def test_no_storage_backend_only_checks_state_manager(self, executor, mock_deps): + """No storage backend — only checks state_manager.""" + mock_deps.state_manager.is_failed.return_value = False + mock_deps.action_runner.storage_backend = None + + config = {"dependencies": ["agent_a"]} + result = executor._check_upstream_health("agent_b", config) + assert result is None + + +class TestHandleDependencySkip: + """Tests for _handle_dependency_skip().""" + + @patch("agent_actions.workflow.executor.fire_event") + def test_updates_state_to_failed(self, mock_fire, executor, mock_deps): + """Updates state to 'failed'.""" + mock_deps.action_runner.storage_backend = None + start_time = datetime.now() + + executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + mock_deps.state_manager.update_status.assert_called_once_with("agent_b", "failed") + + @patch("agent_actions.workflow.executor.fire_event") + def test_writes_failed_disposition(self, mock_fire, executor, mock_deps): + """Writes DISPOSITION_FAILED to storage.""" + storage = MagicMock() + mock_deps.action_runner.storage_backend = storage + start_time = datetime.now() + + executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + storage.set_disposition.assert_called_once() + call_kwargs = storage.set_disposition.call_args + assert call_kwargs[1]["disposition"] == DISPOSITION_FAILED + assert call_kwargs[1]["action_name"] == "agent_b" + + @patch("agent_actions.workflow.executor.fire_event") + def test_fires_action_skip_event(self, mock_fire, executor, mock_deps): + """Fires ActionSkipEvent with correct reason.""" + mock_deps.action_runner.storage_backend = None + start_time = datetime.now() + + executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + 
mock_fire.assert_called_once() + event = mock_fire.call_args[0][0] + from agent_actions.logging.events import ActionSkipEvent + + assert isinstance(event, ActionSkipEvent) + assert "agent_a" in event.skip_reason + + @patch("agent_actions.workflow.executor.fire_event") + def test_records_in_run_tracker_if_available(self, mock_fire, executor, mock_deps): + """Records in run_tracker if available.""" + mock_deps.action_runner.storage_backend = None + executor.run_tracker = MagicMock() + executor.run_id = "run-123" + start_time = datetime.now() + + executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + executor.run_tracker.record_action_complete.assert_called_once() + config = executor.run_tracker.record_action_complete.call_args[1]["config"] + assert config.status == "skipped" + assert config.run_id == "run-123" + + @patch("agent_actions.workflow.executor.fire_event") + def test_returns_skipped_result(self, mock_fire, executor, mock_deps): + """Returns ActionExecutionResult(success=True, status='skipped').""" + mock_deps.action_runner.storage_backend = None + start_time = datetime.now() + + result = executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + assert isinstance(result, ActionExecutionResult) + assert result.success is True + assert result.status == "skipped" + + +class TestWriteFailedDisposition: + """Tests for _write_failed_disposition().""" + + def test_writes_disposition_when_storage_available(self, executor, mock_deps): + """Writes disposition when storage backend is available.""" + storage = MagicMock() + mock_deps.action_runner.storage_backend = storage + + executor._write_failed_disposition("agent_a", "Some error") + + storage.set_disposition.assert_called_once_with( + action_name="agent_a", + record_id=NODE_LEVEL_RECORD_ID, + disposition=DISPOSITION_FAILED, + reason="Some error", + ) + + def test_logs_warning_on_storage_error(self, executor, mock_deps, caplog): + """Logs warning on storage error (doesn't 
raise).""" + storage = MagicMock() + storage.set_disposition.side_effect = RuntimeError("DB error") + mock_deps.action_runner.storage_backend = storage + + # Should not raise + executor._write_failed_disposition("agent_a", "Some error") + + def test_noops_when_storage_backend_is_none(self, executor, mock_deps): + """No-ops when storage backend is None.""" + mock_deps.action_runner.storage_backend = None + + # Should not raise; nothing happens + executor._write_failed_disposition("agent_a", "Some error") diff --git a/tests/unit/workflow/test_coordinator_sequential.py b/tests/unit/workflow/test_coordinator_sequential.py index 2be710a..40a6b08 100644 --- a/tests/unit/workflow/test_coordinator_sequential.py +++ b/tests/unit/workflow/test_coordinator_sequential.py @@ -3,6 +3,8 @@ from datetime import datetime from unittest.mock import MagicMock, patch +import pytest + from agent_actions.workflow.coordinator import AgentWorkflow from agent_actions.workflow.executor import ActionExecutionResult, ExecutionMetrics from agent_actions.workflow.models import ( @@ -213,6 +215,26 @@ def test_failure_returns_completed_with_failures(self): assert result == ("completed_with_failures", {"failed": ["agent_a"]}) assert wf.state.failed is True + def test_unexpected_exception_calls_handle_workflow_error_and_reraises(self): + """Unexpected crash (not a failed result) should call handle_workflow_error and re-raise.""" + wf = _build_workflow(execution_order=["agent_a"]) + wf.services.core.state_manager.is_completed.return_value = False + wf.services.core.action_executor.execute_action_sync.side_effect = RuntimeError( + "unexpected crash" + ) + + mgr = MagicMock() + mgr.context.return_value.__enter__ = MagicMock() + mgr.context.return_value.__exit__ = MagicMock(return_value=False) + with ( + patch("agent_actions.workflow.coordinator.get_manager", return_value=mgr), + pytest.raises(RuntimeError, match="unexpected crash"), + ): + wf._run_workflow_with_context(datetime.now()) + + 
wf.event_logger.handle_workflow_error.assert_called_once() + assert wf.state.failed is True + def test_downstream_resolved_after_completion(self): """After all agents complete, should attempt downstream resolution.""" wf = _build_workflow() From acf751737770999264ebcfa89d9b4cf1e8979fb4 Mon Sep 17 00:00:00 2001 From: Muizz Lateef Date: Wed, 1 Apr 2026 21:11:21 +0100 Subject: [PATCH 4/8] =?UTF-8?q?fix:=20resolve=20mypy=20errors=20=E2=80=94?= =?UTF-8?q?=20remove=20unreachable=20isinstance=20check,=20type=20vendor?= =?UTF-8?q?=20config=20map?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_actions/validation/preflight/resolution_service.py | 9 ++++++--- agent_actions/workflow/executor.py | 2 -- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/agent_actions/validation/preflight/resolution_service.py b/agent_actions/validation/preflight/resolution_service.py index 5265c9f..84ffe15 100644 --- a/agent_actions/validation/preflight/resolution_service.py +++ b/agent_actions/validation/preflight/resolution_service.py @@ -13,6 +13,8 @@ from pathlib import Path from typing import Any +from pydantic import BaseModel + from agent_actions.utils.path_security import resolve_seed_path from agent_actions.validation.static_analyzer.errors import ( FieldLocation, @@ -25,13 +27,13 @@ # Vendor name → config class mapping. Built lazily on first access to # avoid importing all vendor configs (and transitively their SDKs) at # module level. -_VENDOR_CONFIG_MAP: dict[str, type] | None = None +_VENDOR_CONFIG_MAP: dict[str, type[BaseModel]] | None = None # Sentinel substrings in api_key_env_name that indicate no real key is needed. 
_NO_KEY_SENTINELS = ("NO_KEY_REQUIRED",) -def _get_vendor_config_map() -> dict[str, type]: +def _get_vendor_config_map() -> dict[str, type[BaseModel]]: """Build vendor → config class map on first call (lazy).""" global _VENDOR_CONFIG_MAP # noqa: PLW0603 if _VENDOR_CONFIG_MAP is not None: @@ -74,7 +76,8 @@ def _get_api_key_env_name(vendor: str) -> str | None: field_info = config_cls.model_fields.get("api_key_env_name") if field_info is None: return None - return field_info.default + default = field_info.default + return str(default) if default is not None else None class WorkflowResolutionService: diff --git a/agent_actions/workflow/executor.py b/agent_actions/workflow/executor.py index ac45c62..06c6388 100644 --- a/agent_actions/workflow/executor.py +++ b/agent_actions/workflow/executor.py @@ -421,8 +421,6 @@ def _check_upstream_health( if not dependencies: return None for dep in dependencies: - if not isinstance(dep, str): - continue if self.deps.state_manager.is_failed(dep): return dep # Also check disposition — covers cascaded failures from prior levels From ef542942b999d36509f82c1467d2a30268390510 Mon Sep 17 00:00:00 2001 From: Muizz Lateef Date: Wed, 1 Apr 2026 22:06:12 +0100 Subject: [PATCH 5/8] fix: raise when all items fail in processing pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When every input item fails during LLM processing (e.g. 401 auth error), the pipeline now raises instead of silently writing empty output. This lets the executor mark the action as failed and the circuit breaker skip downstream dependents. Previously, per-item errors were caught inside process_batch(), logged, and the action was marked "completed" with 0 records — causing cascading empty-data execution through the entire downstream chain. 
--- agent_actions/workflow/pipeline.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/agent_actions/workflow/pipeline.py b/agent_actions/workflow/pipeline.py index fe8daa9..7d27159 100644 --- a/agent_actions/workflow/pipeline.py +++ b/agent_actions/workflow/pipeline.py @@ -490,6 +490,22 @@ def _process_by_strategy( storage_backend=self.config.storage_backend, ) + # If input had records but output is empty, all items failed during + # processing (e.g. 401 auth error on every LLM call). Raise so the + # executor marks the action as failed and the circuit breaker skips + # downstream dependents. + if data and not output: + from agent_actions.processing.types import ProcessingStatus + + failed_msgs = [ + r.error for r in results if r.status == ProcessingStatus.FAILED and r.error + ] + summary = "; ".join(failed_msgs[:3]) # first 3 errors + raise RuntimeError( + f"Action '{self.config.action_name}' produced 0 records — " + f"all {len(data)} input item(s) failed: {summary}" + ) + self.output_handler.save_main_output(output, file_path, base_directory, output_directory) @staticmethod From ee96428b46a423a39acd5db8cf921dd107e2ae48 Mon Sep 17 00:00:00 2001 From: Muizz Lateef Date: Wed, 1 Apr 2026 22:23:20 +0100 Subject: [PATCH 6/8] fix: only raise on FAILED results, not guard-filtered empty output --- agent_actions/workflow/pipeline.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/agent_actions/workflow/pipeline.py b/agent_actions/workflow/pipeline.py index 7d27159..0f048b5 100644 --- a/agent_actions/workflow/pipeline.py +++ b/agent_actions/workflow/pipeline.py @@ -490,21 +490,24 @@ def _process_by_strategy( storage_backend=self.config.storage_backend, ) - # If input had records but output is empty, all items failed during - # processing (e.g. 401 auth error on every LLM call). Raise so the - # executor marks the action as failed and the circuit breaker skips - # downstream dependents. 
+ # If input had records but output is empty AND there are actual failures + # (not just guard-filtered/skipped records), raise so the executor marks + # the action as failed and the circuit breaker skips downstream dependents. + # Guard filters (SKIPPED/FILTERED status) legitimately produce 0 output — + # only FAILED results indicate processing errors (e.g. 401 auth). if data and not output: from agent_actions.processing.types import ProcessingStatus - failed_msgs = [ - r.error for r in results if r.status == ProcessingStatus.FAILED and r.error + failed_results = [ + r for r in results if r.status == ProcessingStatus.FAILED ] - summary = "; ".join(failed_msgs[:3]) # first 3 errors - raise RuntimeError( - f"Action '{self.config.action_name}' produced 0 records — " - f"all {len(data)} input item(s) failed: {summary}" - ) + if failed_results: + failed_msgs = [r.error for r in failed_results if r.error] + summary = "; ".join(failed_msgs[:3]) + raise RuntimeError( + f"Action '{self.config.action_name}' produced 0 records — " + f"all {len(data)} input item(s) failed: {summary}" + ) self.output_handler.save_main_output(output, file_path, base_directory, output_directory) From deadde7f4e5adc2aee11006df3619702cc40fa8e Mon Sep 17 00:00:00 2001 From: Muizz Lateef Date: Wed, 1 Apr 2026 22:32:03 +0100 Subject: [PATCH 7/8] style: format pipeline.py --- agent_actions/workflow/pipeline.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/agent_actions/workflow/pipeline.py b/agent_actions/workflow/pipeline.py index 0f048b5..87c41c6 100644 --- a/agent_actions/workflow/pipeline.py +++ b/agent_actions/workflow/pipeline.py @@ -498,9 +498,7 @@ def _process_by_strategy( if data and not output: from agent_actions.processing.types import ProcessingStatus - failed_results = [ - r for r in results if r.status == ProcessingStatus.FAILED - ] + failed_results = [r for r in results if r.status == ProcessingStatus.FAILED] if failed_results: failed_msgs = [r.error for r in 
failed_results if r.error] summary = "; ".join(failed_msgs[:3]) From 6583cf5fdcf39ec5c918c83d237f859a8d5231aa Mon Sep 17 00:00:00 2001 From: Muizz Lateef Date: Wed, 1 Apr 2026 23:10:01 +0100 Subject: [PATCH 8/8] fix: address PR review round 2 1. Extract storage_backend before dependency loop (#2) 2. Align status: _handle_dependency_skip returns status="failed" to match state_manager, eliminating failed/skipped inconsistency (#3) 3. Add depth limit (10 levels) to _resolve_seed_data_dir (#4) 4. Replace global _VENDOR_CONFIG_MAP with @lru_cache (#5) 5. Add BFS field-name limitation comment in lineage tracer (#6) 6. Extract MAX_DISPOSITION_REASON_LENGTH constant (#7) 7. Move ProcessingStatus import to module level in pipeline.py (#10) --- agent_actions/workflow/executor.py | 10 +++++++--- tests/unit/workflow/test_circuit_breaker.py | 8 ++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/agent_actions/workflow/executor.py b/agent_actions/workflow/executor.py index 06c6388..bbfe4f1 100644 --- a/agent_actions/workflow/executor.py +++ b/agent_actions/workflow/executor.py @@ -439,7 +439,11 @@ def _handle_dependency_skip( start_time: datetime, failed_dependency: str, ) -> ActionExecutionResult: - """Handle action skip due to upstream dependency failure.""" + """Handle action skip due to upstream dependency failure. + + State is set to ``"failed"`` so transitive dependents also skip via + ``is_failed``. ``success=True`` keeps independent branches alive. 
+ """ reason = f"Upstream dependency '{failed_dependency}' failed" self.deps.state_manager.update_status(action_name, "failed") self._write_failed_disposition(action_name, reason) @@ -463,14 +467,14 @@ def _handle_dependency_skip( config = ActionCompleteConfig( run_id=self.run_id, action_name=action_name, - status="skipped", + status="failed", duration_seconds=duration, skip_reason=reason, ) self.run_tracker.record_action_complete(config=config) return ActionExecutionResult( - success=True, status="skipped", metrics=ExecutionMetrics(duration=duration) + success=True, status="failed", metrics=ExecutionMetrics(duration=duration) ) def execute_action_sync( diff --git a/tests/unit/workflow/test_circuit_breaker.py b/tests/unit/workflow/test_circuit_breaker.py index cab5956..effdbff 100644 --- a/tests/unit/workflow/test_circuit_breaker.py +++ b/tests/unit/workflow/test_circuit_breaker.py @@ -147,12 +147,12 @@ def test_records_in_run_tracker_if_available(self, mock_fire, executor, mock_dep executor.run_tracker.record_action_complete.assert_called_once() config = executor.run_tracker.record_action_complete.call_args[1]["config"] - assert config.status == "skipped" + assert config.status == "failed" assert config.run_id == "run-123" @patch("agent_actions.workflow.executor.fire_event") - def test_returns_skipped_result(self, mock_fire, executor, mock_deps): - """Returns ActionExecutionResult(success=True, status='skipped').""" + def test_returns_failed_result_with_success_true(self, mock_fire, executor, mock_deps): + """Returns ActionExecutionResult(success=True, status='failed') — failed state, but independent branches continue.""" mock_deps.action_runner.storage_backend = None start_time = datetime.now() @@ -160,7 +160,7 @@ def test_returns_skipped_result(self, mock_fire, executor, mock_deps): assert isinstance(result, ActionExecutionResult) assert result.success is True - assert result.status == "skipped" + assert result.status == "failed" class TestWriteFailedDisposition: