Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ def validate(self, context) -> ActionEntryValidationResult:
run_mode = RunMode(raw_run_mode) if isinstance(raw_run_mode, str) else raw_run_mode

if run_mode == RunMode.BATCH:
kind = normalized_entry.get("kind", "").lower()
if kind in ("tool", "hitl"):
errors.append(
f"{desc} kind '{kind}' does not support batch processing. "
f"Tool and HITL actions execute synchronously. "
f"Set run_mode='online' or change kind to 'llm'."
)

model_vendor = str(normalized_entry.get("model_vendor", "")).lower()

if model_vendor:
Expand Down
10 changes: 9 additions & 1 deletion agent_actions/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ def process_file(params: ProcessParams):
HITL_VENDOR,
] or params.action_config.get("kind") in ["tool", "hitl"]

if params.action_config.get("run_mode") == RunMode.BATCH and not is_synchronous:
run_mode = params.action_config.get("run_mode")
if run_mode == RunMode.BATCH and not is_synchronous:
return ProcessingPipeline._handle_batch_generation(
BatchPipelineParams(
pipeline_action_config=params.action_config,
Expand All @@ -269,6 +270,13 @@ def process_file(params: ProcessParams):
storage_backend=params.storage_backend,
)
)
if run_mode == RunMode.BATCH and is_synchronous:
logger.warning(
"Action '%s' has run_mode=batch but kind '%s' requires synchronous "
"execution. Running in online mode.",
params.action_name,
params.action_config.get("kind", params.action_config.get("model_vendor")),
)
pipeline = create_processing_pipeline_from_params(
action_config=params.action_config,
action_name=params.action_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,62 @@ def test_batch_unknown_vendor_produces_warning(self):
assert not result.errors
assert result.warnings
assert "some_unknown_vendor" in result.warnings[0]


class TestBatchKindValidation:
"""Verify batch mode is rejected for synchronous action kinds (tool, hitl)."""

def test_batch_kind_tool_rejected(self):
"""Batch mode with kind=tool produces an error."""
context = ActionEntryValidationContext(
entry={"run_mode": "batch", "kind": "tool"},
agent_name_context="test_agent",
)
validator = VendorCompatibilityValidator()
result = validator.validate(context)
assert result.errors
assert "tool" in result.errors[0].lower()
assert "batch" in result.errors[0].lower()

def test_batch_kind_hitl_rejected(self):
"""Batch mode with kind=hitl produces an error."""
context = ActionEntryValidationContext(
entry={"run_mode": "batch", "kind": "hitl"},
agent_name_context="test_agent",
)
validator = VendorCompatibilityValidator()
result = validator.validate(context)
assert result.errors
assert "hitl" in result.errors[0].lower()
assert "batch" in result.errors[0].lower()

def test_batch_kind_tool_with_valid_vendor_rejected(self):
"""Batch mode with kind=tool is rejected even when model_vendor is batch-compatible."""
context = ActionEntryValidationContext(
entry={"run_mode": "batch", "kind": "tool", "model_vendor": "openai"},
agent_name_context="test_agent",
)
validator = VendorCompatibilityValidator()
result = validator.validate(context)
assert result.errors
assert "tool" in result.errors[0].lower()

def test_online_kind_tool_passes(self):
"""Online mode with kind=tool produces no errors (valid config)."""
context = ActionEntryValidationContext(
entry={"run_mode": "online", "kind": "tool"},
agent_name_context="test_agent",
)
validator = VendorCompatibilityValidator()
result = validator.validate(context)
assert not result.errors

def test_batch_kind_llm_passes(self):
"""Batch mode with kind=llm and valid vendor produces no errors."""
context = ActionEntryValidationContext(
entry={"run_mode": "batch", "kind": "llm", "model_vendor": "openai"},
agent_name_context="test_agent",
)
validator = VendorCompatibilityValidator()
result = validator.validate(context)
assert not result.errors
72 changes: 72 additions & 0 deletions tests/unit/workflow/test_pipeline_batch_sync_warning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Tests that batch + synchronous action emits a runtime warning."""

from unittest.mock import MagicMock, patch

from agent_actions.config.types import RunMode
from agent_actions.workflow.pipeline import FilePathsConfig, ProcessingPipeline, ProcessParams


def _make_process_params(kind: str, run_mode: RunMode = RunMode.BATCH) -> ProcessParams:
"""Create ProcessParams for pipeline tests."""
return ProcessParams(
action_config={"run_mode": run_mode, "kind": kind},
action_name=f"test_{kind}_action",
paths=FilePathsConfig(
file_path="/tmp/input.json",
base_directory="/tmp",
output_directory="/tmp/output",
),
idx=0,
processor_factory=MagicMock(),
)


class TestBatchSyncWarning:
"""Verify runtime warning when batch mode meets synchronous action kind."""

@patch("agent_actions.workflow.pipeline.create_processing_pipeline_from_params")
@patch("agent_actions.workflow.pipeline.logger")
def test_batch_tool_logs_warning(self, mock_logger, mock_create):
"""Batch + kind=tool should log a warning and proceed in online mode."""
mock_pipeline = MagicMock()
mock_pipeline.process.return_value = "/tmp/output/input.json"
mock_create.return_value = mock_pipeline

params = _make_process_params("tool")
ProcessingPipeline.process_file(params)

mock_logger.warning.assert_called_once()
warning_msg = mock_logger.warning.call_args[0][0] % mock_logger.warning.call_args[0][1:]
assert "run_mode=batch" in warning_msg
assert "tool" in warning_msg
mock_pipeline.process.assert_called_once()

@patch("agent_actions.workflow.pipeline.create_processing_pipeline_from_params")
@patch("agent_actions.workflow.pipeline.logger")
def test_batch_hitl_logs_warning(self, mock_logger, mock_create):
"""Batch + kind=hitl should log a warning and proceed in online mode."""
mock_pipeline = MagicMock()
mock_pipeline.process.return_value = "/tmp/output/input.json"
mock_create.return_value = mock_pipeline

params = _make_process_params("hitl")
ProcessingPipeline.process_file(params)

mock_logger.warning.assert_called_once()
warning_msg = mock_logger.warning.call_args[0][0] % mock_logger.warning.call_args[0][1:]
assert "run_mode=batch" in warning_msg
assert "hitl" in warning_msg
mock_pipeline.process.assert_called_once()

@patch("agent_actions.workflow.pipeline.create_processing_pipeline_from_params")
@patch("agent_actions.workflow.pipeline.logger")
def test_online_tool_no_warning(self, mock_logger, mock_create):
"""Online + kind=tool should NOT log a warning."""
mock_pipeline = MagicMock()
mock_pipeline.process.return_value = "/tmp/output/input.json"
mock_create.return_value = mock_pipeline

params = _make_process_params("tool", RunMode.ONLINE)
ProcessingPipeline.process_file(params)

mock_logger.warning.assert_not_called()
Loading