From 6ece97707825d10bfaf51c2ec8f62d3149e26935 Mon Sep 17 00:00:00 2001 From: Muizz Lateef Date: Thu, 2 Apr 2026 13:14:09 +0100 Subject: [PATCH] fix: reject batch run_mode for synchronous tool/HITL actions at validation time Previously, configuring run_mode=batch with kind=tool or kind=hitl silently fell through to online mode at runtime with no warning. This caused config dishonesty and made debugging batch failures difficult. - Extend VendorCompatibilityValidator to emit a hard error when kind is tool or hitl with run_mode=batch - Add logger.warning at the runtime override site as defense-in-depth - Add 8 tests covering validation rejection and runtime warning Closes #85 --- .../vendor_compatibility_validator.py | 8 +++ agent_actions/workflow/pipeline.py | 10 ++- ...t_vendor_compatibility_runmode_coercion.py | 59 +++++++++++++++ .../test_pipeline_batch_sync_warning.py | 72 +++++++++++++++++++ 4 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 tests/unit/workflow/test_pipeline_batch_sync_warning.py diff --git a/agent_actions/validation/action_validators/vendor_compatibility_validator.py b/agent_actions/validation/action_validators/vendor_compatibility_validator.py index 92b1839..cf65a62 100644 --- a/agent_actions/validation/action_validators/vendor_compatibility_validator.py +++ b/agent_actions/validation/action_validators/vendor_compatibility_validator.py @@ -28,6 +28,14 @@ def validate(self, context) -> ActionEntryValidationResult: run_mode = RunMode(raw_run_mode) if isinstance(raw_run_mode, str) else raw_run_mode if run_mode == RunMode.BATCH: + kind = normalized_entry.get("kind", "").lower() + if kind in ("tool", "hitl"): + errors.append( + f"{desc} kind '{kind}' does not support batch processing. " + f"Tool and HITL actions execute synchronously. " + f"Set run_mode='online' or change kind to 'llm'." + ) + model_vendor = str(normalized_entry.get("model_vendor", "")).lower() if model_vendor: diff --git a/agent_actions/workflow/pipeline.py b/agent_actions/workflow/pipeline.py index 87c41c6..5c21e52 100644 --- a/agent_actions/workflow/pipeline.py +++ b/agent_actions/workflow/pipeline.py @@ -256,7 +256,8 @@ def process_file(params: ProcessParams): HITL_VENDOR, ] or params.action_config.get("kind") in ["tool", "hitl"] - if params.action_config.get("run_mode") == RunMode.BATCH and not is_synchronous: + run_mode = params.action_config.get("run_mode") + if run_mode == RunMode.BATCH and not is_synchronous: return ProcessingPipeline._handle_batch_generation( BatchPipelineParams( pipeline_action_config=params.action_config, @@ -269,6 +270,13 @@ def process_file(params: ProcessParams): storage_backend=params.storage_backend, ) ) + if run_mode == RunMode.BATCH and is_synchronous: + logger.warning( + "Action '%s' has run_mode=batch but kind '%s' requires synchronous " + "execution. Running in online mode.", + params.action_name, + params.action_config.get("kind", params.action_config.get("model_vendor")), + ) pipeline = create_processing_pipeline_from_params( action_config=params.action_config, action_name=params.action_name, diff --git a/tests/unit/validation/test_vendor_compatibility_runmode_coercion.py b/tests/unit/validation/test_vendor_compatibility_runmode_coercion.py index c2af3a5..762d1a1 100644 --- a/tests/unit/validation/test_vendor_compatibility_runmode_coercion.py +++ b/tests/unit/validation/test_vendor_compatibility_runmode_coercion.py @@ -74,3 +74,62 @@ def test_batch_unknown_vendor_produces_warning(self): assert not result.errors assert result.warnings assert "some_unknown_vendor" in result.warnings[0] + + +class TestBatchKindValidation: + """Verify batch mode is rejected for synchronous action kinds (tool, hitl).""" + + def test_batch_kind_tool_rejected(self): + """Batch mode with kind=tool produces an error.""" + context = ActionEntryValidationContext( + entry={"run_mode": "batch", "kind": "tool"}, + agent_name_context="test_agent", + ) + validator = VendorCompatibilityValidator() + result = validator.validate(context) + assert result.errors + assert "tool" in result.errors[0].lower() + assert "batch" in result.errors[0].lower() + + def test_batch_kind_hitl_rejected(self): + """Batch mode with kind=hitl produces an error.""" + context = ActionEntryValidationContext( + entry={"run_mode": "batch", "kind": "hitl"}, + agent_name_context="test_agent", + ) + validator = VendorCompatibilityValidator() + result = validator.validate(context) + assert result.errors + assert "hitl" in result.errors[0].lower() + assert "batch" in result.errors[0].lower() + + def test_batch_kind_tool_with_valid_vendor_rejected(self): + """Batch mode with kind=tool is rejected even when model_vendor is batch-compatible.""" + context = ActionEntryValidationContext( + entry={"run_mode": "batch", "kind": "tool", "model_vendor": "openai"}, + agent_name_context="test_agent", + ) + validator = VendorCompatibilityValidator() + result = validator.validate(context) + assert result.errors + assert "tool" in result.errors[0].lower() + + def test_online_kind_tool_passes(self): + """Online mode with kind=tool produces no errors (valid config).""" + context = ActionEntryValidationContext( + entry={"run_mode": "online", "kind": "tool"}, + agent_name_context="test_agent", + ) + validator = VendorCompatibilityValidator() + result = validator.validate(context) + assert not result.errors + + def test_batch_kind_llm_passes(self): + """Batch mode with kind=llm and valid vendor produces no errors.""" + context = ActionEntryValidationContext( + entry={"run_mode": "batch", "kind": "llm", "model_vendor": "openai"}, + agent_name_context="test_agent", + ) + validator = VendorCompatibilityValidator() + result = validator.validate(context) + assert not result.errors diff --git a/tests/unit/workflow/test_pipeline_batch_sync_warning.py b/tests/unit/workflow/test_pipeline_batch_sync_warning.py new file mode 100644 index 0000000..3fd96df --- /dev/null +++ b/tests/unit/workflow/test_pipeline_batch_sync_warning.py @@ -0,0 +1,72 @@ +"""Tests that batch + synchronous action emits a runtime warning.""" + +from unittest.mock import MagicMock, patch + +from agent_actions.config.types import RunMode +from agent_actions.workflow.pipeline import FilePathsConfig, ProcessingPipeline, ProcessParams + + +def _make_process_params(kind: str, run_mode: RunMode = RunMode.BATCH) -> ProcessParams: + """Create ProcessParams for pipeline tests.""" + return ProcessParams( + action_config={"run_mode": run_mode, "kind": kind}, + action_name=f"test_{kind}_action", + paths=FilePathsConfig( + file_path="/tmp/input.json", + base_directory="/tmp", + output_directory="/tmp/output", + ), + idx=0, + processor_factory=MagicMock(), + ) + + +class TestBatchSyncWarning: + """Verify runtime warning when batch mode meets synchronous action kind.""" + + @patch("agent_actions.workflow.pipeline.create_processing_pipeline_from_params") + @patch("agent_actions.workflow.pipeline.logger") + def test_batch_tool_logs_warning(self, mock_logger, mock_create): + """Batch + kind=tool should log a warning and proceed in online mode.""" + mock_pipeline = MagicMock() + mock_pipeline.process.return_value = "/tmp/output/input.json" + mock_create.return_value = mock_pipeline + + params = _make_process_params("tool") + ProcessingPipeline.process_file(params) + + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] % mock_logger.warning.call_args[0][1:] + assert "run_mode=batch" in warning_msg + assert "tool" in warning_msg + mock_pipeline.process.assert_called_once() + + @patch("agent_actions.workflow.pipeline.create_processing_pipeline_from_params") + @patch("agent_actions.workflow.pipeline.logger") + def test_batch_hitl_logs_warning(self, mock_logger, mock_create): + """Batch + kind=hitl should log a warning and proceed in online mode.""" + mock_pipeline = MagicMock() + mock_pipeline.process.return_value = "/tmp/output/input.json" + mock_create.return_value = mock_pipeline + + params = _make_process_params("hitl") + ProcessingPipeline.process_file(params) + + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] % mock_logger.warning.call_args[0][1:] + assert "run_mode=batch" in warning_msg + assert "hitl" in warning_msg + mock_pipeline.process.assert_called_once() + + @patch("agent_actions.workflow.pipeline.create_processing_pipeline_from_params") + @patch("agent_actions.workflow.pipeline.logger") + def test_online_tool_no_warning(self, mock_logger, mock_create): + """Online + kind=tool should NOT log a warning.""" + mock_pipeline = MagicMock() + mock_pipeline.process.return_value = "/tmp/output/input.json" + mock_create.return_value = mock_pipeline + + params = _make_process_params("tool", RunMode.ONLINE) + ProcessingPipeline.process_file(params) + + mock_logger.warning.assert_not_called()