diff --git a/agent_actions/cli/run.py b/agent_actions/cli/run.py index 76e8212..c8e42ac 100644 --- a/agent_actions/cli/run.py +++ b/agent_actions/cli/run.py @@ -89,6 +89,7 @@ def execute(self, project_root: Path | None = None) -> None: use_tools=self.args.use_tools, run_upstream=self.args.upstream, run_downstream=self.args.downstream, + fresh=self.args.fresh, project_root=project_root, ) ) @@ -187,6 +188,12 @@ def execute(self, project_root: Path | None = None) -> None: is_flag=True, help="Execute all downstream workflows that depend on this workflow", ) +@click.option( + "--fresh", + is_flag=True, + default=False, + help="Clear stored results and status before execution (useful after failed runs)", +) @handles_user_errors("run") @requires_project def run( @@ -197,6 +204,7 @@ def run( concurrency_limit: int = 5, upstream: bool = False, downstream: bool = False, + fresh: bool = False, project_root: Path | None = None, ) -> None: """ @@ -211,6 +219,7 @@ def run( agac run -a my_agent --upstream agac run -a my_agent --downstream agac run -a my_agent --execution-mode parallel + agac run -a my_agent --fresh """ args = RunCommandArgs( agent=agent, @@ -220,6 +229,7 @@ def run( concurrency_limit=concurrency_limit, upstream=upstream, downstream=downstream, + fresh=fresh, ) command = RunCommand(args) command.execute(project_root=project_root) diff --git a/agent_actions/prompt/context/static_loader.py b/agent_actions/prompt/context/static_loader.py index a86227f..d52a8c2 100644 --- a/agent_actions/prompt/context/static_loader.py +++ b/agent_actions/prompt/context/static_loader.py @@ -124,7 +124,14 @@ def _parse_file_path(self, file_spec: str, field_name: str) -> str: return file_spec # Use as-is def _resolve_path(self, file_path: str, field_name: str) -> Path: - """Resolve file path relative to static_data_dir with security validation.""" + """Resolve file path relative to static_data_dir with security validation. 
+ + Delegates core traversal prevention to the shared ``resolve_seed_path`` + utility and wraps any failure in a ``StaticDataLoadError`` with rich + context for diagnostics. + """ + from agent_actions.utils.path_security import resolve_seed_path + path = Path(file_path) # Reject absolute paths immediately @@ -140,40 +147,27 @@ def _resolve_path(self, file_path: str, field_name: str) -> Path: }, ) - # Resolve relative to static_data_dir - resolved = (self.static_data_dir / path).resolve() - - # Validate security - self._validate_path_security(resolved, field_name, file_path) - - logger.debug("Resolved path for field '%s': %s", field_name, resolved) - return resolved - - def _validate_path_security( - self, resolved_path: Path, field_name: str, original_path: str - ) -> None: - """Validate that resolved path doesn't escape static_data_dir.""" try: - # This will raise ValueError if path is outside static_data_dir - resolved_path.relative_to(self.static_data_dir) + resolved = resolve_seed_path(file_path, self.static_data_dir) except ValueError as exc: logger.error( - "Path traversal attempt detected for field '%s': %s -> %s", + "Path traversal attempt detected for field '%s': %s", field_name, - original_path, - resolved_path, + file_path, ) raise StaticDataLoadError( f"Static data field '{field_name}': File path escapes static data directory", context={ "field_name": field_name, - "original_path": original_path, - "resolved_path": str(resolved_path), + "original_path": file_path, "static_data_dir": str(self.static_data_dir), "error_type": "path_traversal_attempt", }, ) from exc + logger.debug("Resolved path for field '%s': %s", field_name, resolved) + return resolved + def _load_file(self, file_path: Path, field_name: str) -> Any: """Load file content based on file extension.""" # Check if file exists diff --git a/agent_actions/storage/backend.py b/agent_actions/storage/backend.py index 4d564b9..34a2151 100644 --- a/agent_actions/storage/backend.py +++ 
def delete_target(self, action_name: str) -> int:
    """Remove every stored target record for *action_name*.

    This default implementation deletes nothing and always raises, so a
    backend that does not override it fails loudly instead of letting a
    ``--fresh`` run leave stale data behind.

    Args:
        action_name: Name of the action whose target data should be removed.

    Returns:
        The number of records deleted (in overriding implementations).

    Raises:
        NotImplementedError: Always, in this default implementation.
    """
    backend_name = type(self).__name__
    raise NotImplementedError(f"{backend_name} must implement delete_target()")
"""Shared path security utilities for seed data resolution.

Both the pre-flight resolution service and the runtime StaticDataLoader
call ``resolve_seed_path`` so that path-traversal prevention logic exists
in exactly one place.
"""

from pathlib import Path

FILE_PREFIX = "$file:"


def resolve_seed_path(file_spec: str, base_dir: Path) -> Path:
    """Parse a ``$file:`` reference, resolve against *base_dir*, and validate.

    Args:
        file_spec: A path, optionally prefixed with ``$file:``.
        base_dir: Directory the path must stay inside once resolved.

    Returns:
        The resolved absolute ``Path``. The file is not required to exist.

    Raises:
        ValueError: If the spec is empty, is empty after the prefix is
            stripped, or resolves to a location outside *base_dir*.
    """
    if not file_spec:
        raise ValueError("Empty file spec")

    # Strip the $file: prefix if present; a spec without it is used as-is.
    file_path = file_spec.removeprefix(FILE_PREFIX)
    if not file_path:
        raise ValueError(f"Empty path after prefix in: {file_spec}")

    # Resolve the base once: it is needed for both the check and the message.
    resolved_base = base_dir.resolve()
    resolved = (base_dir / file_path).resolve()

    # Security: prevent path traversal outside base_dir. Comparing resolved
    # paths means ".." segments and absolute specs are both caught.
    if not resolved.is_relative_to(resolved_base):
        raise ValueError(
            f"Seed file path escapes base directory: {file_spec} "
            f"(resolved to {resolved}, base is {resolved_base})"
        )

    return resolved
def _get_api_key_env_name(vendor: str) -> str | None:
    """Resolve the API key env var name from the vendor's config class.

    The name is read from the declared default of the ``api_key_env_name``
    field on the config class registered for *vendor* (matched
    case-insensitively).

    Args:
        vendor: Vendor name as written in the action config.

    Returns:
        The default env var name, or ``None`` when the vendor is unknown, the
        config class has no ``api_key_env_name`` field, or that field has no
        string default.
    """
    config_cls = _get_vendor_config_map().get(vendor.lower())
    if config_cls is None:
        return None

    field_info = config_cls.model_fields.get("api_key_env_name")
    if field_info is None:
        return None

    default = field_info.default
    # Only a real string is a usable env var name. Stringifying any other
    # value would invent a name: a field declared without a static default
    # carries a non-string sentinel (presumably ``PydanticUndefined`` — verify
    # against the pydantic version in use), which would become the literal
    # name "PydanticUndefined".
    return default if isinstance(default, str) else None
def _check_seed_file_references(self) -> list[StaticTypeError]:
    """Check that every ``seed_path`` entry resolves to an existing regular file.

    For each action in ``self.action_configs``, the string values under
    ``context_scope.seed_path`` are resolved with ``resolve_seed_path``
    against the directory returned by ``_resolve_seed_data_dir``. Non-dict
    ``context_scope`` / ``seed_path`` values and non-string entries are
    ignored.

    Returns:
        One ``StaticTypeError`` per reference that ``resolve_seed_path``
        rejects or that does not point at a regular file. Empty when no seed
        data directory could be located.
    """
    errors: list[StaticTypeError] = []

    seed_data_dir = self._resolve_seed_data_dir()
    if seed_data_dir is None:
        return errors

    # Directory listing used in hints; built at most once, on first failure.
    available: list[str] | None = None

    for action_name, config in self.action_configs.items():
        context_scope = config.get("context_scope", {})
        if not isinstance(context_scope, dict):
            continue
        seed_path_config = context_scope.get("seed_path", {})
        if not seed_path_config or not isinstance(seed_path_config, dict):
            continue

        for field_name, file_spec in seed_path_config.items():
            if not isinstance(file_spec, str):
                continue

            config_field = f"context_scope.seed_path.{field_name}"

            try:
                resolved = resolve_seed_path(file_spec, seed_data_dir)
            except ValueError as e:
                errors.append(
                    StaticTypeError(
                        message=str(e),
                        location=FieldLocation(
                            agent_name=action_name,
                            config_field=config_field,
                            raw_reference=file_spec,
                        ),
                        referenced_agent=action_name,
                        referenced_field=field_name,
                        hint="Use relative paths within the seed_data/ directory.",
                    )
                )
                continue

            # A directory "exists" but cannot be loaded as a seed file, so
            # require a regular file rather than mere existence.
            if resolved.is_file():
                continue

            if available is None:
                available = (
                    sorted(f.name for f in seed_data_dir.iterdir() if f.is_file())
                    if seed_data_dir.is_dir()
                    else []
                )

            if resolved.exists():
                message = (
                    f"Seed path is not a regular file: {file_spec} (resolved to {resolved})"
                )
            else:
                message = f"Seed file not found: {file_spec} (resolved to {resolved})"

            errors.append(
                StaticTypeError(
                    message=message,
                    location=FieldLocation(
                        agent_name=action_name,
                        config_field=config_field,
                        raw_reference=file_spec,
                    ),
                    referenced_agent=action_name,
                    referenced_field=field_name,
                    available_fields=set(available),
                    hint=(
                        f"Available files: {', '.join(available)}"
                        if available
                        else "(seed_data/ directory is empty)"
                    ),
                )
            )

    return errors
hasattr(run_mode, "value"): + run_mode = run_mode.value + + if run_mode != "batch": + continue + + capabilities = _resolve_capabilities(vendor) + if capabilities is None: + continue + + if not capabilities.get("supports_batch"): + errors.append( + StaticTypeError( + message=( + f"Action '{action_name}' uses run_mode=batch with vendor " + f"'{vendor}', but {vendor} does not support batch mode" + ), + location=FieldLocation( + agent_name=action_name, + config_field="run_mode", + raw_reference=f"run_mode=batch, vendor={vendor}", + ), + referenced_agent=action_name, + referenced_field="run_mode", + hint=f"Use run_mode: online for {vendor} actions, or choose a batch-capable vendor.", + ) + ) + + return errors + + # ── Helpers ──────────────────────────────────────────────────────── + + def _resolve_seed_data_dir(self) -> Path | None: + """Resolve the seed_data directory from workflow config path.""" + if not self.workflow_config_path: + return None + + config_path = Path(self.workflow_config_path).resolve() + current = config_path.parent + while current != current.parent: + if (current / "agent_config").exists(): + seed_dir = current / "seed_data" + return seed_dir if seed_dir.exists() else None + if current.name == "agent_config": + seed_dir = current.parent / "seed_data" + return seed_dir if seed_dir.exists() else None + current = current.parent + + return None diff --git a/agent_actions/validation/run_validator.py b/agent_actions/validation/run_validator.py index 36cb62a..656a1bd 100644 --- a/agent_actions/validation/run_validator.py +++ b/agent_actions/validation/run_validator.py @@ -26,3 +26,4 @@ class RunCommandArgs(BaseModel): downstream: bool = Field( False, description="Execute all downstream workflows that depend on this workflow" ) + fresh: bool = Field(False, description="Clear stored results and status before execution") diff --git a/agent_actions/validation/static_analyzer/data_flow_graph.py b/agent_actions/validation/static_analyzer/data_flow_graph.py 
index 91a1eb2..a630bd0 100644 --- a/agent_actions/validation/static_analyzer/data_flow_graph.py +++ b/agent_actions/validation/static_analyzer/data_flow_graph.py @@ -15,6 +15,7 @@ class OutputSchema: schema_fields: set[str] = field(default_factory=set) observe_fields: set[str] = field(default_factory=set) passthrough_fields: set[str] = field(default_factory=set) + passthrough_wildcard_sources: set[str] = field(default_factory=set) dropped_fields: set[str] = field(default_factory=set) json_schema: dict[str, Any] | None = None is_dynamic: bool = False diff --git a/agent_actions/validation/static_analyzer/schema_extractor.py b/agent_actions/validation/static_analyzer/schema_extractor.py index 9f8f644..9074663 100644 --- a/agent_actions/validation/static_analyzer/schema_extractor.py +++ b/agent_actions/validation/static_analyzer/schema_extractor.py @@ -416,6 +416,11 @@ def _apply_context_scope(self, config: dict[str, Any], output: OutputSchema) -> field_name = self._extract_field_name(ref) if field_name: output.passthrough_fields.add(field_name) + elif isinstance(ref, str) and ".*" in ref: + # Wildcard passthrough: "source.*" → record the source name + source_name = ref.split(".", 1)[0] + if source_name: + output.passthrough_wildcard_sources.add(source_name) scope_observe = context_scope.get("observe", []) for ref in scope_observe: diff --git a/agent_actions/validation/static_analyzer/workflow_static_analyzer.py b/agent_actions/validation/static_analyzer/workflow_static_analyzer.py index a7a5552..a8846dc 100644 --- a/agent_actions/validation/static_analyzer/workflow_static_analyzer.py +++ b/agent_actions/validation/static_analyzer/workflow_static_analyzer.py @@ -22,7 +22,7 @@ InputSchema, OutputSchema, ) -from .errors import FieldLocation, StaticTypeError, StaticValidationResult +from .errors import FieldLocation, StaticTypeError, StaticTypeWarning, StaticValidationResult from .reference_extractor import ReferenceExtractor from .schema_extractor import SchemaExtractor 
def _check_drop_directives(self) -> list[StaticTypeError]:
    """Validate that drop directives reference actual schema/observe fields.

    Every ``context_scope.drop`` entry of the form ``dep.field`` is checked
    against the output schema of ``dep`` in ``self.graph``. Two situations are
    reported, both as errors (not warnings):

    * the field is a passthrough field of ``dep``, so the drop has no effect;
    * the field is not among ``dep``'s available fields at all.

    Entries are skipped without a report when they are not ``dep.field``
    strings, use a special namespace or ``loop``, use the ``*`` wildcard,
    name a dependency missing from the graph, or target a dynamic or
    schemaless output.

    Returns:
        One ``StaticTypeError`` per offending drop directive.
    """
    errors: list[StaticTypeError] = []
    actions = self.workflow_config.get("actions", [])

    for action in actions:
        if not isinstance(action, dict):
            continue

        action_name = action.get("name", "unknown")
        context_scope = action.get("context_scope", {})
        if not isinstance(context_scope, dict):
            continue
        drop_refs = context_scope.get("drop", [])
        if not isinstance(drop_refs, list):
            continue

        for drop_ref in drop_refs:
            if not isinstance(drop_ref, str) or "." not in drop_ref:
                continue

            dep_name, field_name = drop_ref.split(".", 1)

            if dep_name in SPECIAL_NAMESPACES or dep_name == "loop":
                continue
            if field_name == "*":
                continue

            dep_node = self.graph.get_node(dep_name)
            if not dep_node:
                continue  # Unknown dep — caught by other checks

            output = dep_node.output_schema
            if output.is_dynamic or output.is_schemaless:
                continue  # Can't validate

            if field_name in output.schema_fields or field_name in output.observe_fields:
                continue  # Valid drop target

            # Only the message and hint differ between the two failure modes;
            # the rest of the error is built once below.
            if field_name in output.passthrough_fields:
                message = (
                    f"Drop directive '{drop_ref}' targets passthrough field "
                    f"'{field_name}' on '{dep_name}'. Passthrough fields are not "
                    f"in the LLM context namespace, so this drop has no effect."
                )
                hint = (
                    f"Remove this drop directive. '{field_name}' is a passthrough "
                    f"field — it doesn't appear in the LLM context. "
                    f"Schema fields: {', '.join(sorted(output.schema_fields))}"
                )
            elif field_name not in output.available_fields:
                message = (
                    f"Drop directive '{drop_ref}' references non-existent field "
                    f"'{field_name}' in '{dep_name}'"
                )
                hint = (
                    f"Available schema fields in '{dep_name}': "
                    f"{', '.join(sorted(output.schema_fields))}"
                )
            else:
                continue

            errors.append(
                StaticTypeError(
                    message=message,
                    location=FieldLocation(
                        agent_name=action_name,
                        config_field="context_scope.drop",
                        raw_reference=drop_ref,
                    ),
                    referenced_agent=dep_name,
                    referenced_field=field_name,
                    available_fields=output.schema_fields | output.observe_fields,
                    hint=hint,
                )
            )

    return errors
+ + When action C observes ``A.field`` but C only depends on B (not A directly), + the data must flow A → B → C via passthrough on B. This check verifies + the passthrough chain exists. + """ + warnings: list[StaticTypeWarning] = [] + actions = self.workflow_config.get("actions", []) + + for action in actions: + if not isinstance(action, dict): + continue + + node_name = action.get("name", "unknown") + node = self.graph.get_node(node_name) + if not node: + continue + + context_scope = action.get("context_scope", {}) + if not isinstance(context_scope, dict): + continue + + observe_refs = context_scope.get("observe", []) + if not isinstance(observe_refs, list): + continue + + for ref in observe_refs: + if not isinstance(ref, str) or "." not in ref: + continue + + source_name, field_name = ref.split(".", 1) + + if source_name in SPECIAL_NAMESPACES or source_name == "loop": + continue + if field_name == "*": + continue + + # If source is a direct dependency, no lineage concern + if source_name in node.dependencies: + continue + + # Source is NOT a direct dependency — must travel through intermediates + reachable = self.graph.get_reachable_upstream_names(node_name) + if source_name not in reachable: + continue # Not reachable at all — caught by type checker + + if not self._trace_field_through_chain(source_name, field_name, node_name): + warnings.append( + StaticTypeWarning( + message=( + f"Observe reference '{ref}' on '{node_name}' references " + f"non-direct dependency '{source_name}'. The field " + f"'{field_name}' may not survive through intermediate " + f"actions via passthrough." 
+ ), + location=FieldLocation( + agent_name=node_name, + config_field="context_scope.observe", + raw_reference=ref, + ), + referenced_agent=source_name, + referenced_field=field_name, + hint=( + f"Ensure intermediate actions between '{source_name}' and " + f"'{node_name}' have passthrough: ['{source_name}.*'] or " + f"passthrough: ['{source_name}.{field_name}'] in their " + f"context_scope." + ), + ) + ) + + return warnings + + def _trace_field_through_chain(self, source: str, field: str, target: str) -> bool: + """Check if a field from *source* can reach *target* through passthrough chains. + + BFS backwards from *target* through dependencies. At each intermediate + node, checks whether the field survives (exact passthrough, wildcard + passthrough, or dynamic schema). + """ + from collections import deque + + target_node = self.graph.get_node(target) + if not target_node: + return False + + visited: set[str] = set() + queue = deque(target_node.dependencies) + + while queue: + current = queue.popleft() + if current in visited: + continue + visited.add(current) + + if current == source: + return True # Direct path found + + current_node = self.graph.get_node(current) + if not current_node: + continue + + output = current_node.output_schema + survives = ( + source in output.passthrough_wildcard_sources + or field in output.passthrough_fields + or output.is_dynamic + ) + + if survives: + for dep in current_node.dependencies: + if dep not in visited: + queue.append(dep) + + return False + def _add_source_node(self) -> None: """Add the special source node for workflow input.""" if self.source_schema: diff --git a/agent_actions/workflow/coordinator.py b/agent_actions/workflow/coordinator.py index d2ccf20..4e3fda3 100644 --- a/agent_actions/workflow/coordinator.py +++ b/agent_actions/workflow/coordinator.py @@ -86,6 +86,10 @@ def __init__(self, config: WorkflowRuntimeConfig): self.metadata, config, self.storage_backend, self.console ) + # Fresh run: clear stored results + 
status before anything else + if config.fresh: + self._clear_for_fresh_run() + # Dependency orchestration + session self._init_dependency_orchestrator() self.workflow_session_id = self._generate_workflow_session_id() @@ -125,9 +129,34 @@ def _run_static_validation(self) -> None: hint="Fix the guard condition syntax errors above before running the workflow.", ) + # Resolution checks: API keys, seed files, vendor batch compatibility + from agent_actions.validation.preflight.resolution_service import ( + WorkflowResolutionService, + ) + + resolution_result = WorkflowResolutionService( + action_configs=self.action_configs, + workflow_config_path=self.config.paths.constructor_path, + project_root=self.config.project_root, + ).resolve_all() + resolution_result.raise_if_invalid() + def _validate_guard_conditions(self) -> list[str]: return validate_guard_conditions(self.action_configs) + def _clear_for_fresh_run(self) -> None: + """Clear stored results, dispositions, and status for a fresh run.""" + for action_name in self.execution_order: + try: + self.storage_backend.delete_target(action_name) + self.storage_backend.clear_disposition(action_name) + except Exception as e: + logger.warning("Failed to clear stored data for %s: %s", action_name, e) + self.services.core.state_manager.reset() + self.console.print( + "[yellow]--fresh: cleared stored results and reset all actions to pending[/yellow]" + ) + # ── Properties ────────────────────────────────────────────────────── @property @@ -311,14 +340,23 @@ async def async_run(self, concurrency_limit: int = 5): if not level_complete: return + state_mgr = self.services.core.state_manager duration = (datetime.now() - workflow_start).total_seconds() - self.event_logger.finalize_workflow(elapsed_time=duration) - downstream_success = self._resolve_downstream_workflows() - if not downstream_success: - return None + if state_mgr.is_workflow_complete(): + self.event_logger.finalize_workflow(elapsed_time=duration) + downstream_success = 
self._resolve_downstream_workflows() + if not downstream_success: + return None + return ("success", {}) - return ("success", {}) + if state_mgr.is_workflow_done(): + self.state.failed = True + self.event_logger.finalize_workflow(elapsed_time=duration) + failed = state_mgr.get_failed_actions(self.execution_order) + return ("completed_with_failures", {"failed": failed}) + + return None except Exception as e: duration = (datetime.now() - workflow_start).total_seconds() @@ -362,8 +400,10 @@ def _run_workflow_with_context(self, workflow_start): if should_stop: break - if self.services.core.state_manager.is_workflow_complete(): - duration = (datetime.now() - workflow_start).total_seconds() + state_mgr = self.services.core.state_manager + duration = (datetime.now() - workflow_start).total_seconds() + + if state_mgr.is_workflow_complete(): self.event_logger.finalize_workflow(elapsed_time=duration) downstream_success = self._resolve_downstream_workflows() @@ -372,6 +412,13 @@ def _run_workflow_with_context(self, workflow_start): return ("success", {}) + if state_mgr.is_workflow_done(): + # All actions reached a terminal state but some failed + self.state.failed = True + self.event_logger.finalize_workflow(elapsed_time=duration) + failed = state_mgr.get_failed_actions(self.execution_order) + return ("completed_with_failures", {"failed": failed}) + return None except Exception as e: @@ -425,6 +472,9 @@ def _run_single_action(self, idx: int, action_name: str, total_actions: int) -> if result.status == "batch_submitted": return True + if result.status == "skipped": + return False # Continue to next action + if result.output_folder and result.status == "completed": self.state.ephemeral_directories.append( { @@ -434,4 +484,6 @@ def _run_single_action(self, idx: int, action_name: str, total_actions: int) -> ) return False - raise result.error + # Action failed — log and continue (circuit breaker handles downstream) + logger.warning("Action '%s' failed: %s", action_name, 
result.error) + return False diff --git a/agent_actions/workflow/executor.py b/agent_actions/workflow/executor.py index 0a5288c..bbfe4f1 100644 --- a/agent_actions/workflow/executor.py +++ b/agent_actions/workflow/executor.py @@ -19,6 +19,7 @@ BatchCompleteEvent, BatchSubmittedEvent, ) +from agent_actions.storage.backend import DISPOSITION_FAILED, NODE_LEVEL_RECORD_ID from agent_actions.tooling.docs.run_tracker import ActionCompleteConfig from agent_actions.utils.constants import DEFAULT_ACTION_KIND @@ -174,6 +175,26 @@ def _verify_completion_status( storage_backend = getattr(self.deps.action_runner, "storage_backend", None) if storage_backend is not None: try: + # Check disposition FIRST — a failed action may have partial + # results in storage. The disposition is the authoritative + # signal; output existence is irrelevant when it's set. + if storage_backend.has_disposition( + action_name, + DISPOSITION_FAILED, + record_id=NODE_LEVEL_RECORD_ID, + ): + logger.info( + "Action %s has DISPOSITION_FAILED from prior run — re-running", + action_name, + ) + storage_backend.clear_disposition( + action_name, + DISPOSITION_FAILED, + record_id=NODE_LEVEL_RECORD_ID, + ) + self.deps.state_manager.update_status(action_name, "pending") + return (False, None) + target_files = storage_backend.list_target_files(action_name) if not target_files: logger.info( @@ -335,12 +356,31 @@ def _handle_run_success( ), ) + def _write_failed_disposition(self, action_name: str, reason: str) -> None: + """Write DISPOSITION_FAILED to storage so downstream and future runs detect the failure.""" + storage_backend = getattr(self.deps.action_runner, "storage_backend", None) + if storage_backend is not None: + try: + storage_backend.set_disposition( + action_name=action_name, + record_id=NODE_LEVEL_RECORD_ID, + disposition=DISPOSITION_FAILED, + reason=reason[:500], + ) + except Exception as disp_err: + logger.warning( + "Failed to write DISPOSITION_FAILED for %s: %s", + action_name, + disp_err, + ) + 
def _handle_run_failure( self, params: ActionRunParams, error: Exception ) -> ActionExecutionResult: """Handle action run failure.""" duration = (datetime.now() - params.start_time).total_seconds() self.deps.state_manager.update_status(params.action_name, "failed") + self._write_failed_disposition(params.action_name, str(error)) if self.run_tracker is not None and self.run_id is not None: config = ActionCompleteConfig( @@ -373,6 +413,70 @@ def _cleanup_correlation( }, ) + def _check_upstream_health( + self, action_name: str, action_config: ActionConfigDict + ) -> str | None: + """Return the name of a failed upstream dependency, or None if all healthy.""" + dependencies = action_config.get("dependencies", []) + if not dependencies: + return None + for dep in dependencies: + if self.deps.state_manager.is_failed(dep): + return dep + # Also check disposition — covers cascaded failures from prior levels + storage_backend = getattr(self.deps.action_runner, "storage_backend", None) + if storage_backend is not None and storage_backend.has_disposition( + dep, DISPOSITION_FAILED, record_id=NODE_LEVEL_RECORD_ID + ): + return dep + return None + + def _handle_dependency_skip( + self, + action_name: str, + action_idx: int, + action_config: ActionConfigDict, + start_time: datetime, + failed_dependency: str, + ) -> ActionExecutionResult: + """Handle action skip due to upstream dependency failure. + + State is set to ``"failed"`` so transitive dependents also skip via + ``is_failed``. ``success=True`` keeps independent branches alive. 
+ """ + reason = f"Upstream dependency '{failed_dependency}' failed" + self.deps.state_manager.update_status(action_name, "failed") + self._write_failed_disposition(action_name, reason) + + duration = (datetime.now() - start_time).total_seconds() + total_actions = ( + len(self.deps.action_runner.execution_order) + if hasattr(self.deps.action_runner, "execution_order") + else 0 + ) + fire_event( + ActionSkipEvent( + action_name=action_name, + action_index=action_idx, + total_actions=total_actions, + skip_reason=reason, + ) + ) + + if self.run_tracker is not None and self.run_id is not None: + config = ActionCompleteConfig( + run_id=self.run_id, + action_name=action_name, + status="failed", + duration_seconds=duration, + skip_reason=reason, + ) + self.run_tracker.record_action_complete(config=config) + + return ActionExecutionResult( + success=True, status="failed", metrics=ExecutionMetrics(duration=duration) + ) + def execute_action_sync( self, action_name: str, @@ -413,6 +517,14 @@ def execute_action_sync( if current_status == "batch_submitted": return self._handle_batch_check(action_name, action_idx, action_config, start_time) + # Circuit breaker: skip if any upstream dependency has failed. + # Must run BEFORE get_previous_outputs to avoid reading corrupt data. + failed_dep = self._check_upstream_health(action_name, action_config) + if failed_dep is not None: + return self._handle_dependency_skip( + action_name, action_idx, action_config, start_time, failed_dep + ) + previous_outputs = self.deps.output_manager.get_previous_outputs(action_idx) if self.deps.skip_evaluator.should_skip_action(action_config, previous_outputs): return self._handle_action_skip(action_name, action_idx, action_config, start_time) @@ -469,6 +581,13 @@ async def execute_action_async( action_name, action_idx, action_config, start_time ) + # Circuit breaker: skip if any upstream dependency has failed. 
+ failed_dep = self._check_upstream_health(action_name, action_config) + if failed_dep is not None: + return self._handle_dependency_skip( + action_name, action_idx, action_config, start_time, failed_dep + ) + previous_outputs = self.deps.output_manager.get_previous_outputs(action_idx) if self.deps.skip_evaluator.should_skip_action(action_config, previous_outputs): return self._handle_action_skip(action_name, action_idx, action_config, start_time) @@ -538,6 +657,7 @@ def _handle_batch_check( ) self.deps.state_manager.update_status(action_name, "failed") + self._write_failed_disposition(action_name, f"Batch job for {action_name} failed") fire_event( BatchCompleteEvent( batch_id=action_config.get("batch_id", ""), @@ -611,6 +731,7 @@ async def _handle_batch_check_async( ) self.deps.state_manager.update_status(action_name, "failed") + self._write_failed_disposition(action_name, f"Batch job for {action_name} failed") fire_event( BatchCompleteEvent( batch_id=action_config.get("batch_id", ""), diff --git a/agent_actions/workflow/managers/state.py b/agent_actions/workflow/managers/state.py index 0bcda66..7436140 100644 --- a/agent_actions/workflow/managers/state.py +++ b/agent_actions/workflow/managers/state.py @@ -31,6 +31,11 @@ def _load_status(self): else: self._initialize_default_status() + def reset(self) -> None: + """Reset all actions to 'pending' status and persist.""" + self._initialize_default_status() + self._save_status() + def _initialize_default_status(self): """Initialize all actions with 'pending' status.""" self.action_status = {action: {"status": "pending"} for action in self.execution_order} @@ -79,8 +84,9 @@ def is_failed(self, action_name: str) -> bool: return self.get_status(action_name) == "failed" def get_pending_actions(self, agents: list[str]) -> list[str]: - """Return actions that are not yet completed.""" - return [agent for agent in agents if not self.is_completed(agent)] + """Return actions that are not yet completed or failed (runnable).""" + 
terminal = {"completed", "failed"} + return [agent for agent in agents if self.get_status(agent) not in terminal] def get_batch_submitted_actions(self, agents: list[str]) -> list[str]: """Return actions with batch jobs submitted.""" @@ -115,6 +121,11 @@ def is_workflow_complete(self) -> bool: """Return True if all actions have 'completed' status.""" return all(details.get("status") == "completed" for details in self.action_status.values()) + def is_workflow_done(self) -> bool: + """Return True if all actions are in a terminal state (completed or failed).""" + terminal = {"completed", "failed"} + return all(details.get("status") in terminal for details in self.action_status.values()) + def has_any_failed(self) -> bool: """Return True if any action has 'failed' status.""" return any(details.get("status") == "failed" for details in self.action_status.values()) diff --git a/agent_actions/workflow/models.py b/agent_actions/workflow/models.py index 9a73186..a0a6d5e 100644 --- a/agent_actions/workflow/models.py +++ b/agent_actions/workflow/models.py @@ -34,6 +34,7 @@ class WorkflowRuntimeConfig: use_tools: bool run_upstream: bool = False run_downstream: bool = False + fresh: bool = False manager: Any = None # ConfigManager instance project_root: Path | None = None diff --git a/agent_actions/workflow/parallel/action_executor.py b/agent_actions/workflow/parallel/action_executor.py index 36d822c..4e5e4d6 100644 --- a/agent_actions/workflow/parallel/action_executor.py +++ b/agent_actions/workflow/parallel/action_executor.py @@ -4,6 +4,7 @@ import asyncio import copy +import logging from dataclasses import dataclass from datetime import datetime from typing import Any @@ -14,6 +15,8 @@ from agent_actions.logging.core.manager import fire_event from agent_actions.logging.events import ActionCompleteEvent, ActionFailedEvent +logger = logging.getLogger(__name__) + @dataclass class ParallelExecutionParams: @@ -177,7 +180,7 @@ async def _execute_single_action(self, action_name: str, 
action_indices: dict, a self._fire_action_result_event(action_name, original_idx, total_actions, result) if not result.success: - raise result.error + logger.warning("Action '%s' failed: %s", action_name, result.error) async def _execute_parallel_actions(self, params: ParallelExecutionParams): """Execute multiple actions in parallel.""" @@ -228,10 +231,12 @@ async def run_with_limit(action): if errors: error_details = "\n".join([f" - {action}: {str(exc)}" for action, exc in errors]) - error_msg = ( - f"Multiple actions failed in parallel action {params.level_idx}:\n{error_details}" + logger.warning( + "Level %d: %d action(s) failed:\n%s", + params.level_idx, + len(errors), + error_details, ) - raise WorkflowError(error_msg, context={"level_idx": params.level_idx}) def _fire_action_result_event(self, action_name: str, idx: int, total: int, result): """Fire action complete or failed event for an execution result.""" @@ -268,15 +273,15 @@ def _check_batch_status( batch_pending = state_manager.get_batch_submitted_actions(level_actions) if batch_pending: - # Check for partial failures + # Log partial failures but don't raise — circuit breaker handles cascade failed_actions = state_manager.get_failed_actions(level_actions) if failed_actions: - error_msg = ( - f"Partial failure in parallel action {level_idx}: " - f"{', '.join(failed_actions)} failed while " - "batch jobs were submitted" + logger.warning( + "Level %d: %s failed while batch jobs pending for %s", + level_idx, + ", ".join(failed_actions), + ", ".join(batch_pending), ) - raise WorkflowError(error_msg, context={"level_idx": level_idx}) # Batch jobs submitted, need to wait duration = (datetime.now() - start_time).total_seconds() @@ -295,8 +300,8 @@ async def execute_level_async(self, params: LevelExecutionParams) -> bool: Returns: True if level completed, False if batch jobs pending. - Raises: - WorkflowError: If any action fails during execution. 
+ Failed actions are logged but do not raise — the circuit breaker + in ActionExecutor handles downstream skipping. """ start_time = datetime.now() diff --git a/agent_actions/workflow/pipeline.py b/agent_actions/workflow/pipeline.py index fe8daa9..87c41c6 100644 --- a/agent_actions/workflow/pipeline.py +++ b/agent_actions/workflow/pipeline.py @@ -490,6 +490,23 @@ def _process_by_strategy( storage_backend=self.config.storage_backend, ) + # If input had records but output is empty AND there are actual failures + # (not just guard-filtered/skipped records), raise so the executor marks + # the action as failed and the circuit breaker skips downstream dependents. + # Guard filters (SKIPPED/FILTERED status) legitimately produce 0 output — + # only FAILED results indicate processing errors (e.g. 401 auth). + if data and not output: + from agent_actions.processing.types import ProcessingStatus + + failed_results = [r for r in results if r.status == ProcessingStatus.FAILED] + if failed_results: + failed_msgs = [r.error for r in failed_results if r.error] + summary = "; ".join(failed_msgs[:3]) + raise RuntimeError( + f"Action '{self.config.action_name}' produced 0 records — " + f"all {len(data)} input item(s) failed: {summary}" + ) + self.output_handler.save_main_output(output, file_path, base_directory, output_directory) @staticmethod diff --git a/tests/unit/storage/test_delete_target.py b/tests/unit/storage/test_delete_target.py new file mode 100644 index 0000000..bbeafe9 --- /dev/null +++ b/tests/unit/storage/test_delete_target.py @@ -0,0 +1,60 @@ +"""Tests for SQLiteBackend.delete_target().""" + +import pytest + +from agent_actions.storage.backends.sqlite_backend import SQLiteBackend + + +class TestDeleteTarget: + """Tests for delete_target() method.""" + + @pytest.fixture + def backend(self, tmp_path): + """Create a fresh SQLite backend for testing.""" + db_path = tmp_path / "agent_io" / "test.db" + backend = SQLiteBackend(str(db_path), "test_workflow") + 
backend.initialize() + yield backend + backend.close() + + def test_deletes_matching_rows_returns_count(self, backend): + """Deletes matching rows and returns the count.""" + backend.write_target("action_a", "batch_001.json", [{"id": 1}]) + backend.write_target("action_a", "batch_002.json", [{"id": 2}]) + backend.write_target("action_b", "batch_001.json", [{"id": 3}]) + + deleted = backend.delete_target("action_a") + + assert deleted == 2 + # action_b data should remain + remaining = backend.list_target_files("action_b") + assert remaining == ["batch_001.json"] + + def test_returns_zero_when_no_matching_rows(self, backend): + """Returns 0 when no matching rows exist.""" + deleted = backend.delete_target("nonexistent_action") + assert deleted == 0 + + def test_works_with_real_sqlite_backend(self, tmp_path): + """Full roundtrip: write, verify, delete, verify gone.""" + db_path = tmp_path / "test.db" + backend = SQLiteBackend(str(db_path), "wf") + backend.initialize() + + try: + data = [{"field": "value1"}, {"field": "value2"}] + backend.write_target("my_action", "output.json", data) + + # Verify data exists + files = backend.list_target_files("my_action") + assert len(files) == 1 + + # Delete + deleted = backend.delete_target("my_action") + assert deleted == 1 + + # Verify gone + files = backend.list_target_files("my_action") + assert len(files) == 0 + finally: + backend.close() diff --git a/tests/unit/utils/test_path_security.py b/tests/unit/utils/test_path_security.py new file mode 100644 index 0000000..3630170 --- /dev/null +++ b/tests/unit/utils/test_path_security.py @@ -0,0 +1,76 @@ +"""Tests for resolve_seed_path() in path_security module.""" + +import pytest + +from agent_actions.utils.path_security import FILE_PREFIX, resolve_seed_path + + +class TestResolveSeedPath: + """Tests for resolve_seed_path().""" + + def test_valid_file_reference_resolves(self, tmp_path): + """A $file: prefixed reference resolves to the correct absolute path.""" + seed_dir = tmp_path 
/ "seed_data" + seed_dir.mkdir() + (seed_dir / "data.json").write_text("{}") + + result = resolve_seed_path("$file:data.json", seed_dir) + + assert result == (seed_dir / "data.json").resolve() + + def test_reference_without_prefix_resolves(self, tmp_path): + """A reference without $file: prefix resolves correctly.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + (seed_dir / "data.json").write_text("{}") + + result = resolve_seed_path("data.json", seed_dir) + + assert result == (seed_dir / "data.json").resolve() + + def test_empty_file_spec_raises_value_error(self, tmp_path): + """Empty file_spec raises ValueError.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + + with pytest.raises(ValueError, match="Empty file spec"): + resolve_seed_path("", seed_dir) + + def test_path_traversal_raises_value_error(self, tmp_path): + """Path traversal attempt (../../etc/passwd) raises ValueError.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + + with pytest.raises(ValueError, match="Seed file path escapes base directory"): + resolve_seed_path("../../etc/passwd", seed_dir) + + def test_path_traversal_with_prefix_raises(self, tmp_path): + """Path traversal with $file: prefix also raises ValueError.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + + with pytest.raises(ValueError, match="Seed file path escapes base directory"): + resolve_seed_path("$file:../../etc/passwd", seed_dir) + + def test_empty_path_after_prefix_raises(self, tmp_path): + """$file: prefix with no path after it raises ValueError.""" + seed_dir = tmp_path / "seed_data" + seed_dir.mkdir() + + with pytest.raises(ValueError, match="Empty path after prefix"): + resolve_seed_path("$file:", seed_dir) + + def test_subdirectory_path_resolves(self, tmp_path): + """Subdirectory paths within base_dir resolve correctly.""" + seed_dir = tmp_path / "seed_data" + sub = seed_dir / "subdir" + sub.mkdir(parents=True) + (sub / "nested.json").write_text("{}") + + result = 
resolve_seed_path("$file:subdir/nested.json", seed_dir) + + assert result == (sub / "nested.json").resolve() + + def test_file_prefix_constant(self): + """FILE_PREFIX constant is set correctly.""" + assert FILE_PREFIX == "$file:" diff --git a/tests/unit/validation/test_drop_and_lineage.py b/tests/unit/validation/test_drop_and_lineage.py new file mode 100644 index 0000000..840f896 --- /dev/null +++ b/tests/unit/validation/test_drop_and_lineage.py @@ -0,0 +1,448 @@ +"""Tests for drop directives and lineage reachability in WorkflowStaticAnalyzer.""" + +from agent_actions.config.schema import ActionKind +from agent_actions.validation.static_analyzer.data_flow_graph import ( + DataFlowGraph, + DataFlowNode, + OutputSchema, +) +from agent_actions.validation.static_analyzer.workflow_static_analyzer import ( + WorkflowStaticAnalyzer, +) + + +def _build_analyzer_with_graph(workflow_config, graph): + """Build a WorkflowStaticAnalyzer and inject a pre-built graph.""" + analyzer = WorkflowStaticAnalyzer(workflow_config) + analyzer.graph = graph + analyzer._built = True + return analyzer + + +class TestCheckDropDirectives: + """Tests for _check_drop_directives().""" + + def _make_graph_with_upstream( + self, + schema_fields=None, + observe_fields=None, + passthrough_fields=None, + is_dynamic=False, + is_schemaless=False, + ): + """Build a graph with source -> upstream_action -> downstream_action.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="upstream", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema( + schema_fields=schema_fields or set(), + observe_fields=observe_fields or set(), + passthrough_fields=passthrough_fields or set(), + is_dynamic=is_dynamic, + is_schemaless=is_schemaless, + ), + dependencies={"source"}, + ) + ) + return graph + + def test_drop_on_schema_field_no_error(self): + """Drop on a 
schema field produces no error.""" + graph = self._make_graph_with_upstream(schema_fields={"summary", "title"}) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.summary"], + "observe": ["upstream.title"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 0 + + def test_drop_on_passthrough_field_produces_error(self): + """Drop on a passthrough field produces an error with a hint about passthrough.""" + graph = self._make_graph_with_upstream( + schema_fields={"summary"}, + passthrough_fields={"forwarded_field"}, + ) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.forwarded_field"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 1 + assert "passthrough" in errors[0].message.lower() + assert "forwarded_field" in errors[0].message + + def test_drop_on_nonexistent_field_produces_error(self): + """Drop on a non-existent field produces an error with available fields.""" + graph = self._make_graph_with_upstream(schema_fields={"summary", "title"}) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.nonexistent"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 1 + msg = errors[0].message.lower() + assert "non-existent" in msg or "nonexistent" in msg + assert errors[0].available_fields # Should have available fields + + def test_drop_with_wildcard_no_error(self): + """Drop with wildcard produces no error.""" + graph = self._make_graph_with_upstream(schema_fields={"summary"}) + workflow_config = { + 
"actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.*"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 0 + + def test_drop_on_dynamic_schema_no_error(self): + """Drop on a dynamic schema produces no error (skipped).""" + graph = self._make_graph_with_upstream(is_dynamic=True) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.anything"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 0 + + def test_drop_on_observe_field_no_error(self): + """Drop on an observe field (not schema) produces no error.""" + graph = self._make_graph_with_upstream( + schema_fields={"summary"}, + observe_fields={"observed_field"}, + ) + workflow_config = { + "actions": [ + { + "name": "downstream", + "depends_on": ["upstream"], + "context_scope": { + "drop": ["upstream.observed_field"], + }, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + errors = analyzer._check_drop_directives() + assert len(errors) == 0 + + +class TestCheckLineageReachability: + """Tests for _check_lineage_reachability().""" + + def test_direct_dependency_observe_no_warning(self): + """Observing a field from a direct dependency produces no warning.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"result"}), + dependencies={"A"}, + ) + 
) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + { + "name": "B", + "depends_on": ["A"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + warnings = analyzer._check_lineage_reachability() + assert len(warnings) == 0 + + def test_transitive_observe_with_wildcard_passthrough_no_warning(self): + """Transitive observe with wildcard passthrough produces no warning.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema( + schema_fields={"result"}, + passthrough_wildcard_sources={"A"}, + ), + dependencies={"A"}, + ) + ) + graph.add_node( + DataFlowNode( + name="C", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"final"}), + dependencies={"B"}, + ) + ) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + { + "name": "B", + "depends_on": ["A"], + "context_scope": {"passthrough": ["A.*"]}, + }, + { + "name": "C", + "depends_on": ["B"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + warnings = analyzer._check_lineage_reachability() + assert len(warnings) == 0 + + def test_transitive_observe_with_explicit_field_passthrough_no_warning(self): + """Transitive observe with explicit field passthrough produces no warning.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + 
agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema( + schema_fields={"result"}, + passthrough_fields={"field_x"}, + ), + dependencies={"A"}, + ) + ) + graph.add_node( + DataFlowNode( + name="C", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"final"}), + dependencies={"B"}, + ) + ) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + { + "name": "B", + "depends_on": ["A"], + "context_scope": {"passthrough": ["A.field_x"]}, + }, + { + "name": "C", + "depends_on": ["B"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + warnings = analyzer._check_lineage_reachability() + assert len(warnings) == 0 + + def test_transitive_observe_without_passthrough_produces_warning(self): + """Transitive observe with NO passthrough produces a warning.""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"result"}), + dependencies={"A"}, + ) + ) + graph.add_node( + DataFlowNode( + name="C", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"final"}), + dependencies={"B"}, + ) + ) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + {"name": "B", "depends_on": ["A"]}, + { + "name": "C", + "depends_on": ["B"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + 
warnings = analyzer._check_lineage_reachability() + assert len(warnings) == 1 + assert "field_x" in warnings[0].message + assert "passthrough" in warnings[0].hint.lower() + + def test_dynamic_intermediate_no_warning(self): + """Dynamic intermediate schema produces no warning (data may survive).""" + graph = DataFlowGraph() + graph.add_node( + DataFlowNode( + name="source", + agent_kind=ActionKind.SOURCE, + output_schema=OutputSchema(is_dynamic=True), + ) + ) + graph.add_node( + DataFlowNode( + name="A", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"field_x"}), + dependencies={"source"}, + ) + ) + graph.add_node( + DataFlowNode( + name="B", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(is_dynamic=True), + dependencies={"A"}, + ) + ) + graph.add_node( + DataFlowNode( + name="C", + agent_kind=ActionKind.LLM, + output_schema=OutputSchema(schema_fields={"final"}), + dependencies={"B"}, + ) + ) + + workflow_config = { + "actions": [ + {"name": "A", "depends_on": ["source"]}, + {"name": "B", "depends_on": ["A"]}, + { + "name": "C", + "depends_on": ["B"], + "context_scope": {"observe": ["A.field_x"]}, + }, + ], + } + analyzer = _build_analyzer_with_graph(workflow_config, graph) + + warnings = analyzer._check_lineage_reachability() + assert len(warnings) == 0 diff --git a/tests/unit/validation/test_resolution_service.py b/tests/unit/validation/test_resolution_service.py new file mode 100644 index 0000000..0b15f40 --- /dev/null +++ b/tests/unit/validation/test_resolution_service.py @@ -0,0 +1,294 @@ +"""Tests for WorkflowResolutionService pre-flight checks.""" + +from agent_actions.validation.preflight.resolution_service import ( + WorkflowResolutionService, +) + + +class TestApiKeyChecks: + """Tests for _check_api_keys() via resolve_all().""" + + def test_missing_api_key_env_var_detected(self, monkeypatch): + """Missing API key env var produces an error with the correct message.""" + monkeypatch.delenv("OPENAI_API_KEY", raising=False) 
+ monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "summarizer": {"model_vendor": "openai"}, + }, + ) + result = svc.resolve_all() + + assert not result.is_valid + assert len(result.errors) == 1 + err = result.errors[0] + assert "OPENAI_API_KEY" in err.message + assert "summarizer" in err.message + + def test_present_api_key_passes(self, monkeypatch): + """When the API key env var is set, no error is produced.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key") + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "summarizer": {"model_vendor": "openai"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [ + e + for e in result.errors + if "api_key" in e.location.config_field.lower() or "API key" in e.message + ] + assert len(api_key_errors) == 0 + + def test_tool_vendor_skipped(self, monkeypatch): + """Tool vendor actions are skipped (NO_KEY_REQUIRED sentinel).""" + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "my_tool": {"model_vendor": "tool"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 0 + + def test_hitl_vendor_skipped(self, monkeypatch): + """HITL vendor actions are skipped (NO_KEY_REQUIRED sentinel).""" + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "review": {"model_vendor": "hitl"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 0 + + def test_unknown_vendor_skipped(self, monkeypatch): + """Unknown vendor produces no api-key error (no config to check against).""" + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + 
action_configs={ + "custom": {"model_vendor": "totally_unknown_vendor"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 0 + + def test_custom_api_key_dollar_resolved_as_env_var(self, monkeypatch): + """Custom api_key starting with $ is resolved as an env var name.""" + monkeypatch.delenv("MY_CUSTOM_KEY", raising=False) + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "my_action": {"model_vendor": "openai", "api_key": "$MY_CUSTOM_KEY"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 1 + assert "MY_CUSTOM_KEY" in api_key_errors[0].message + + def test_skip_env_validation_flag(self, monkeypatch): + """AA_SKIP_ENV_VALIDATION=1 skips all env checks.""" + monkeypatch.setenv("AA_SKIP_ENV_VALIDATION", "1") + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "summarizer": {"model_vendor": "openai"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 0 + + def test_empty_string_env_var_treated_as_missing(self, monkeypatch): + """Empty string env var is treated as missing.""" + monkeypatch.setenv("OPENAI_API_KEY", "") + monkeypatch.delenv("AA_SKIP_ENV_VALIDATION", raising=False) + + svc = WorkflowResolutionService( + action_configs={ + "summarizer": {"model_vendor": "openai"}, + }, + ) + result = svc.resolve_all() + + api_key_errors = [e for e in result.errors if "API key" in e.message] + assert len(api_key_errors) == 1 + + +class TestSeedFileChecks: + """Tests for _check_seed_file_references().""" + + def test_missing_seed_file_detected(self, tmp_path): + """Missing seed file is detected with available files in hint.""" + # Setup directory structure: 
project/agent_config/workflow.yml + project/seed_data/ + project = tmp_path / "project" + agent_config = project / "agent_config" + agent_config.mkdir(parents=True) + seed_data = project / "seed_data" + seed_data.mkdir() + (seed_data / "existing.json").write_text("{}") + + workflow_path = str(agent_config / "workflow.yml") + + svc = WorkflowResolutionService( + action_configs={ + "loader": { + "context_scope": { + "seed_path": {"field1": "$file:missing.json"}, + }, + }, + }, + workflow_config_path=workflow_path, + ) + result = svc.resolve_all() + + seed_errors = [e for e in result.errors if "Seed file not found" in e.message] + assert len(seed_errors) == 1 + assert "existing.json" in seed_errors[0].hint + + def test_existing_seed_file_passes(self, tmp_path): + """Existing seed file produces no error.""" + project = tmp_path / "project" + agent_config = project / "agent_config" + agent_config.mkdir(parents=True) + seed_data = project / "seed_data" + seed_data.mkdir() + (seed_data / "data.json").write_text("{}") + + workflow_path = str(agent_config / "workflow.yml") + + svc = WorkflowResolutionService( + action_configs={ + "loader": { + "context_scope": { + "seed_path": {"field1": "$file:data.json"}, + }, + }, + }, + workflow_config_path=workflow_path, + ) + result = svc.resolve_all() + + seed_errors = [e for e in result.errors if "seed" in e.message.lower()] + assert len(seed_errors) == 0 + + def test_path_traversal_caught(self, tmp_path): + """Path traversal in seed file reference is caught.""" + project = tmp_path / "project" + agent_config = project / "agent_config" + agent_config.mkdir(parents=True) + seed_data = project / "seed_data" + seed_data.mkdir() + + workflow_path = str(agent_config / "workflow.yml") + + svc = WorkflowResolutionService( + action_configs={ + "loader": { + "context_scope": { + "seed_path": {"field1": "$file:../../etc/passwd"}, + }, + }, + }, + workflow_config_path=workflow_path, + ) + result = svc.resolve_all() + + seed_errors = [e for e 
in result.errors if "escapes base directory" in e.message] + assert len(seed_errors) == 1 + + def test_no_seed_path_config_no_errors(self): + """No seed_path in config produces no errors.""" + svc = WorkflowResolutionService( + action_configs={ + "loader": {"context_scope": {}}, + }, + ) + result = svc.resolve_all() + + seed_errors = [e for e in result.errors if "seed" in e.message.lower()] + assert len(seed_errors) == 0 + + def test_seed_data_dir_missing_graceful_skip(self, tmp_path): + """When seed_data directory doesn't exist, gracefully skip (no errors).""" + project = tmp_path / "project" + agent_config = project / "agent_config" + agent_config.mkdir(parents=True) + # Intentionally do NOT create seed_data directory + + workflow_path = str(agent_config / "workflow.yml") + + svc = WorkflowResolutionService( + action_configs={ + "loader": { + "context_scope": { + "seed_path": {"field1": "$file:data.json"}, + }, + }, + }, + workflow_config_path=workflow_path, + ) + result = svc.resolve_all() + + seed_errors = [e for e in result.errors if "seed" in e.message.lower()] + assert len(seed_errors) == 0 + + +class TestVendorRunModeCompatibility: + """Tests for _check_vendor_run_mode_compatibility().""" + + def test_batch_mode_with_non_batch_vendor_produces_error(self): + """Batch mode with a non-batch vendor (e.g., ollama) produces error.""" + svc = WorkflowResolutionService( + action_configs={ + "local_action": {"model_vendor": "ollama", "run_mode": "batch"}, + }, + ) + result = svc.resolve_all() + + batch_errors = [e for e in result.errors if "batch" in e.message.lower()] + assert len(batch_errors) == 1 + assert "ollama" in batch_errors[0].message + + def test_online_mode_passes_for_any_vendor(self): + """Online mode passes for any vendor (no batch mode check needed).""" + svc = WorkflowResolutionService( + action_configs={ + "my_action": {"model_vendor": "ollama", "run_mode": "online"}, + }, + ) + result = svc.resolve_all() + + batch_errors = [e for e in result.errors 
if "batch" in e.message.lower()] + assert len(batch_errors) == 0 + + def test_batch_mode_with_batch_capable_vendor_passes(self): + """Batch mode with a batch-capable vendor passes.""" + svc = WorkflowResolutionService( + action_configs={ + "my_action": {"model_vendor": "openai", "run_mode": "batch"}, + }, + ) + result = svc.resolve_all() + + batch_errors = [e for e in result.errors if "batch" in e.message.lower()] + assert len(batch_errors) == 0 diff --git a/tests/unit/workflow/managers/test_state_extensions.py b/tests/unit/workflow/managers/test_state_extensions.py new file mode 100644 index 0000000..da40822 --- /dev/null +++ b/tests/unit/workflow/managers/test_state_extensions.py @@ -0,0 +1,94 @@ +"""Tests for ActionStateManager.is_workflow_done() and updated get_pending_actions().""" + +from agent_actions.workflow.managers.state import ActionStateManager + + +class TestIsWorkflowDone: + """Tests for is_workflow_done().""" + + def test_true_when_all_completed(self, tmp_path): + """is_workflow_done returns True when all actions are completed.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "completed") + + assert mgr.is_workflow_done() is True + + def test_true_when_mix_of_completed_and_failed(self, tmp_path): + """is_workflow_done returns True when all actions are completed or failed.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "failed") + mgr.update_status("c", "completed") + + assert mgr.is_workflow_done() is True + + def test_false_when_some_still_pending(self, tmp_path): + """is_workflow_done returns False when some actions are still pending.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "failed") + # c stays pending 
+ + assert mgr.is_workflow_done() is False + + def test_false_when_some_still_running(self, tmp_path): + """is_workflow_done returns False when some actions are still running.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "running") + + assert mgr.is_workflow_done() is False + + +class TestGetPendingActionsExcludesFailed: + """Tests for get_pending_actions() excluding both completed AND failed.""" + + def test_excludes_completed(self, tmp_path): + """get_pending_actions excludes completed actions.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c"]) + mgr.update_status("a", "completed") + + pending = mgr.get_pending_actions(["a", "b", "c"]) + assert "a" not in pending + assert "b" in pending + assert "c" in pending + + def test_excludes_failed(self, tmp_path): + """get_pending_actions excludes failed actions.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "failed") + + pending = mgr.get_pending_actions(["a", "b", "c"]) + assert "a" not in pending + assert "b" not in pending + assert "c" in pending + + def test_excludes_both_completed_and_failed(self, tmp_path): + """get_pending_actions excludes both completed and failed.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b", "c", "d"]) + mgr.update_status("a", "completed") + mgr.update_status("b", "failed") + mgr.update_status("c", "running") + # d stays pending + + pending = mgr.get_pending_actions(["a", "b", "c", "d"]) + assert pending == ["c", "d"] + + def test_all_terminal_returns_empty(self, tmp_path): + """When all are completed/failed, returns empty list.""" + status_file = tmp_path / "status.json" + mgr = ActionStateManager(status_file, ["a", "b"]) + mgr.update_status("a", "completed") + 
mgr.update_status("b", "failed") + + pending = mgr.get_pending_actions(["a", "b"]) + assert pending == [] diff --git a/tests/unit/workflow/test_circuit_breaker.py b/tests/unit/workflow/test_circuit_breaker.py new file mode 100644 index 0000000..effdbff --- /dev/null +++ b/tests/unit/workflow/test_circuit_breaker.py @@ -0,0 +1,197 @@ +"""Tests for ActionExecutor circuit breaker methods.""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from agent_actions.storage.backend import DISPOSITION_FAILED, NODE_LEVEL_RECORD_ID +from agent_actions.workflow.executor import ( + ActionExecutionResult, + ActionExecutor, + ExecutorDependencies, +) +from agent_actions.workflow.managers.batch import BatchLifecycleManager +from agent_actions.workflow.managers.output import ActionOutputManager +from agent_actions.workflow.managers.skip import SkipEvaluator +from agent_actions.workflow.managers.state import ActionStateManager + + +@pytest.fixture +def mock_deps(): + """Create mock dependencies for ActionExecutor.""" + deps = MagicMock(spec=ExecutorDependencies) + deps.state_manager = MagicMock(spec=ActionStateManager) + deps.batch_manager = MagicMock(spec=BatchLifecycleManager) + deps.action_runner = MagicMock() + deps.skip_evaluator = MagicMock(spec=SkipEvaluator) + deps.output_manager = MagicMock(spec=ActionOutputManager) + deps.action_runner.execution_order = ["agent_a", "agent_b", "agent_c"] + return deps + + +@pytest.fixture +def executor(mock_deps): + """Create executor with mock dependencies.""" + return ActionExecutor(mock_deps) + + +class TestCheckUpstreamHealth: + """Tests for _check_upstream_health().""" + + def test_no_dependencies_returns_none(self, executor): + """No dependencies means all healthy — returns None.""" + config = {"dependencies": []} + result = executor._check_upstream_health("agent_b", config) + assert result is None + + def test_no_dependencies_key_returns_none(self, executor): + """Missing dependencies key 
means all healthy — returns None.""" + config = {} + result = executor._check_upstream_health("agent_b", config) + assert result is None + + def test_all_deps_healthy_returns_none(self, executor, mock_deps): + """All dependencies healthy — returns None.""" + mock_deps.state_manager.is_failed.return_value = False + mock_deps.action_runner.storage_backend = None + + config = {"dependencies": ["agent_a"]} + result = executor._check_upstream_health("agent_b", config) + assert result is None + + def test_dep_failed_via_state_manager(self, executor, mock_deps): + """One dep failed (state_manager.is_failed) returns dep name.""" + mock_deps.state_manager.is_failed.return_value = True + + config = {"dependencies": ["agent_a"]} + result = executor._check_upstream_health("agent_b", config) + assert result == "agent_a" + + def test_dep_failed_via_disposition(self, executor, mock_deps): + """One dep has DISPOSITION_FAILED in storage returns dep name.""" + mock_deps.state_manager.is_failed.return_value = False + storage = MagicMock() + storage.has_disposition.return_value = True + mock_deps.action_runner.storage_backend = storage + + config = {"dependencies": ["agent_a"]} + result = executor._check_upstream_health("agent_b", config) + + assert result == "agent_a" + storage.has_disposition.assert_called_once_with( + "agent_a", DISPOSITION_FAILED, record_id=NODE_LEVEL_RECORD_ID + ) + + def test_no_storage_backend_only_checks_state_manager(self, executor, mock_deps): + """No storage backend — only checks state_manager.""" + mock_deps.state_manager.is_failed.return_value = False + mock_deps.action_runner.storage_backend = None + + config = {"dependencies": ["agent_a"]} + result = executor._check_upstream_health("agent_b", config) + assert result is None + + +class TestHandleDependencySkip: + """Tests for _handle_dependency_skip().""" + + @patch("agent_actions.workflow.executor.fire_event") + def test_updates_state_to_failed(self, mock_fire, executor, mock_deps): + """Updates state 
to 'failed'.""" + mock_deps.action_runner.storage_backend = None + start_time = datetime.now() + + executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + mock_deps.state_manager.update_status.assert_called_once_with("agent_b", "failed") + + @patch("agent_actions.workflow.executor.fire_event") + def test_writes_failed_disposition(self, mock_fire, executor, mock_deps): + """Writes DISPOSITION_FAILED to storage.""" + storage = MagicMock() + mock_deps.action_runner.storage_backend = storage + start_time = datetime.now() + + executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + storage.set_disposition.assert_called_once() + call_kwargs = storage.set_disposition.call_args + assert call_kwargs[1]["disposition"] == DISPOSITION_FAILED + assert call_kwargs[1]["action_name"] == "agent_b" + + @patch("agent_actions.workflow.executor.fire_event") + def test_fires_action_skip_event(self, mock_fire, executor, mock_deps): + """Fires ActionSkipEvent with correct reason.""" + mock_deps.action_runner.storage_backend = None + start_time = datetime.now() + + executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + mock_fire.assert_called_once() + event = mock_fire.call_args[0][0] + from agent_actions.logging.events import ActionSkipEvent + + assert isinstance(event, ActionSkipEvent) + assert "agent_a" in event.skip_reason + + @patch("agent_actions.workflow.executor.fire_event") + def test_records_in_run_tracker_if_available(self, mock_fire, executor, mock_deps): + """Records in run_tracker if available.""" + mock_deps.action_runner.storage_backend = None + executor.run_tracker = MagicMock() + executor.run_id = "run-123" + start_time = datetime.now() + + executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + executor.run_tracker.record_action_complete.assert_called_once() + config = executor.run_tracker.record_action_complete.call_args[1]["config"] + assert config.status == "failed" + assert 
config.run_id == "run-123" + + @patch("agent_actions.workflow.executor.fire_event") + def test_returns_failed_result_with_success_true(self, mock_fire, executor, mock_deps): + """Returns ActionExecutionResult(success=True, status='failed') — failed state, but independent branches continue.""" + mock_deps.action_runner.storage_backend = None + start_time = datetime.now() + + result = executor._handle_dependency_skip("agent_b", 1, {}, start_time, "agent_a") + + assert isinstance(result, ActionExecutionResult) + assert result.success is True + assert result.status == "failed" + + +class TestWriteFailedDisposition: + """Tests for _write_failed_disposition().""" + + def test_writes_disposition_when_storage_available(self, executor, mock_deps): + """Writes disposition when storage backend is available.""" + storage = MagicMock() + mock_deps.action_runner.storage_backend = storage + + executor._write_failed_disposition("agent_a", "Some error") + + storage.set_disposition.assert_called_once_with( + action_name="agent_a", + record_id=NODE_LEVEL_RECORD_ID, + disposition=DISPOSITION_FAILED, + reason="Some error", + ) + + def test_logs_warning_on_storage_error(self, executor, mock_deps, caplog): + """Logs warning on storage error (doesn't raise).""" + storage = MagicMock() + storage.set_disposition.side_effect = RuntimeError("DB error") + mock_deps.action_runner.storage_backend = storage + + # Should not raise + executor._write_failed_disposition("agent_a", "Some error") + + def test_noops_when_storage_backend_is_none(self, executor, mock_deps): + """No-ops when storage backend is None.""" + mock_deps.action_runner.storage_backend = None + + # Should not raise; nothing happens + executor._write_failed_disposition("agent_a", "Some error") diff --git a/tests/unit/workflow/test_coordinator_sequential.py b/tests/unit/workflow/test_coordinator_sequential.py index 785d6b7..40a6b08 100644 --- a/tests/unit/workflow/test_coordinator_sequential.py +++ 
b/tests/unit/workflow/test_coordinator_sequential.py @@ -126,15 +126,15 @@ def test_batch_submitted_stops(self): assert should_stop is True - def test_failure_raises(self): - """Failed result should raise the error.""" + def test_failure_continues(self): + """Failed result should log and continue (circuit breaker handles downstream).""" wf = _build_workflow() wf.services.core.state_manager.is_completed.return_value = False error = RuntimeError("agent crashed") wf.services.core.action_executor.execute_action_sync.return_value = _failed_result(error) - with pytest.raises(RuntimeError, match="agent crashed"): - wf._run_single_action(0, "agent_a", 2) + should_stop = wf._run_single_action(0, "agent_a", 2) + assert should_stop is False def test_fires_agent_start(self): """Should fire agent_start event before execution.""" @@ -186,6 +186,7 @@ def test_batch_stops_early_returns_none(self): status="batch_submitted", output_folder=None ) wf.services.core.state_manager.is_workflow_complete.return_value = False + wf.services.core.state_manager.is_workflow_done.return_value = False mgr = MagicMock() mgr.context.return_value.__enter__ = MagicMock() @@ -195,19 +196,39 @@ def test_batch_stops_early_returns_none(self): assert result is None - def test_exception_calls_handle_workflow_error_and_reraises(self): - """Exception should call handle_workflow_error and re-raise.""" + def test_failure_returns_completed_with_failures(self): + """Failed action should produce completed_with_failures, not raise.""" wf = _build_workflow(execution_order=["agent_a"]) wf.services.core.state_manager.is_completed.return_value = False error = RuntimeError("boom") wf.services.core.action_executor.execute_action_sync.return_value = _failed_result(error) + wf.services.core.state_manager.is_workflow_complete.return_value = False + wf.services.core.state_manager.is_workflow_done.return_value = True + wf.services.core.state_manager.get_failed_actions.return_value = ["agent_a"] + + mgr = MagicMock() + 
mgr.context.return_value.__enter__ = MagicMock() + mgr.context.return_value.__exit__ = MagicMock(return_value=False) + with patch("agent_actions.workflow.coordinator.get_manager", return_value=mgr): + result = wf._run_workflow_with_context(datetime.now()) + + assert result == ("completed_with_failures", {"failed": ["agent_a"]}) + assert wf.state.failed is True + + def test_unexpected_exception_calls_handle_workflow_error_and_reraises(self): + """Unexpected crash (not a failed result) should call handle_workflow_error and re-raise.""" + wf = _build_workflow(execution_order=["agent_a"]) + wf.services.core.state_manager.is_completed.return_value = False + wf.services.core.action_executor.execute_action_sync.side_effect = RuntimeError( + "unexpected crash" + ) mgr = MagicMock() mgr.context.return_value.__enter__ = MagicMock() mgr.context.return_value.__exit__ = MagicMock(return_value=False) with ( patch("agent_actions.workflow.coordinator.get_manager", return_value=mgr), - pytest.raises(RuntimeError, match="boom"), + pytest.raises(RuntimeError, match="unexpected crash"), ): wf._run_workflow_with_context(datetime.now()) diff --git a/tests/unit/workflow/test_executor_lifecycle.py b/tests/unit/workflow/test_executor_lifecycle.py index 94f7192..e438f1b 100644 --- a/tests/unit/workflow/test_executor_lifecycle.py +++ b/tests/unit/workflow/test_executor_lifecycle.py @@ -53,6 +53,7 @@ def test_completed_with_output_skips(self, executor, mock_deps): mock_deps.state_manager.get_status.return_value = "completed" storage = MagicMock() storage.list_target_files.return_value = ["file1.json"] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage result = executor.execute_action_sync( @@ -74,6 +75,7 @@ def test_completed_no_output_reruns(self, executor, mock_deps): mock_deps.state_manager.get_status.return_value = "completed" storage = MagicMock() storage.list_target_files.return_value = [] + storage.has_disposition.return_value = False 
mock_deps.action_runner.storage_backend = storage mock_deps.skip_evaluator.should_skip_action.return_value = False @@ -99,6 +101,7 @@ def test_storage_error_during_verify_reruns_agent(self, executor, mock_deps): mock_deps.state_manager.get_status.return_value = "completed" storage = MagicMock() storage.list_target_files.side_effect = OSError("SQLite lock") + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage mock_deps.skip_evaluator.should_skip_action.return_value = False @@ -407,6 +410,7 @@ def test_with_files_returns_skip(self, executor, mock_deps): """Agent with output files should be skipped (already done).""" storage = MagicMock() storage.list_target_files.return_value = ["file.json"] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage should_skip, result = executor._verify_completion_status("agent_a") @@ -418,6 +422,7 @@ def test_no_files_resets_to_pending(self, executor, mock_deps): """Agent with no output files should be reset to pending.""" storage = MagicMock() storage.list_target_files.return_value = [] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage should_skip, result = executor._verify_completion_status("agent_a") @@ -439,6 +444,7 @@ def test_storage_error_resets_to_pending(self, executor, mock_deps, exc): """Any exception during verification should reset to pending and re-run.""" storage = MagicMock() storage.list_target_files.side_effect = exc + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage should_skip, result = executor._verify_completion_status("agent_a") diff --git a/tests/unit/workflow/test_limits.py b/tests/unit/workflow/test_limits.py index 93070f9..5cb4b7a 100644 --- a/tests/unit/workflow/test_limits.py +++ b/tests/unit/workflow/test_limits.py @@ -333,6 +333,7 @@ def test_same_limits_skips_action(self, executor, mock_deps): storage = MagicMock() 
storage.list_target_files.return_value = ["file.json"] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage result = executor.execute_action_sync( @@ -351,6 +352,7 @@ def test_no_limits_old_status_no_invalidation(self, executor, mock_deps): storage = MagicMock() storage.list_target_files.return_value = ["file.json"] + storage.has_disposition.return_value = False mock_deps.action_runner.storage_backend = storage result = executor.execute_action_sync( diff --git a/tests/workflow/test_stale_completion_verification.py b/tests/workflow/test_stale_completion_verification.py index 7f7a9ea..2c5875b 100644 --- a/tests/workflow/test_stale_completion_verification.py +++ b/tests/workflow/test_stale_completion_verification.py @@ -34,6 +34,7 @@ def _make_executor(storage_has_data: bool) -> ActionExecutor: action_runner = MagicMock() storage_backend = MagicMock() storage_backend.list_target_files.return_value = ["output.json"] if storage_has_data else [] + storage_backend.has_disposition.return_value = False action_runner.storage_backend = storage_backend deps = ExecutorDependencies(