Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion agent_actions/input/preprocessing/staging/initial_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,14 +693,23 @@ def _process_online_mode_with_record_processor(

results = processor.process_batch(data_chunk, processing_context)

processed_items = ResultCollector.collect_results(
processed_items, stats = ResultCollector.collect_results(
results,
ctx.agent_config,
ctx.agent_name,
is_first_stage=True,
storage_backend=ctx.storage_backend,
)

# If input had records but output is empty AND there are actual failures,
# raise so the executor marks the action as failed and the circuit breaker
# skips downstream dependents.
if data_chunk and not processed_items and stats.failed > 0:
raise RuntimeError(
f"Action '{ctx.agent_name}' produced 0 records — "
f"all {len(data_chunk)} input item(s) failed ({stats.failed} failures)"
)

if ctx.storage_backend is None:
raise AgentActionsError(
"Storage backend is required for online initial-stage writes.",
Expand Down
27 changes: 25 additions & 2 deletions agent_actions/processing/result_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import json
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

from agent_actions.errors import AgentActionsError
Expand Down Expand Up @@ -38,6 +39,18 @@ def _get_retry_attempts(result: ProcessingResult) -> str | int:
return "unknown"


@dataclass
class CollectionStats:
    """Per-status record counts produced by ``ResultCollector.collect_results``.

    Returned alongside the flattened output records so callers (e.g. the
    pipeline stages in this PR) can distinguish "empty output because every
    record failed" from "empty output because records were legitimately
    filtered/skipped" and raise only in the former case.
    """

    # Records that processed successfully and contributed to the output list.
    success: int = 0
    # Records that errored during processing (e.g. auth failures) — the
    # signal callers use to decide whether an empty output is fatal.
    failed: int = 0
    # Records deliberately skipped (guard conditions); not an error.
    skipped: int = 0
    # Records removed by filters; legitimately absent from the output.
    filtered: int = 0
    # Records that exhausted their retry budget.
    exhausted: int = 0
    # Records passed through without processing (upstream-unprocessed).
    unprocessed: int = 0


def _safe_set_disposition(
backend: "StorageBackend",
action_name: str,
Expand Down Expand Up @@ -72,9 +85,12 @@ def collect_results(
*,
is_first_stage: bool,
storage_backend: Optional["StorageBackend"] = None,
) -> list[dict[str, Any]]:
) -> tuple[list[dict[str, Any]], CollectionStats]:
"""Flatten ProcessingResult entries into output records.

Returns:
Tuple of (output_records, stats). Stats contain counts by status.

Raises:
AgentActionsError: If on_exhausted=raise and records exhausted retries.
"""
Expand Down Expand Up @@ -294,7 +310,14 @@ def collect_results(
stats["unprocessed"],
)

return output
return output, CollectionStats(
success=stats["success"],
failed=stats["failed"],
skipped=stats["skipped"],
filtered=stats["filtered"],
exhausted=stats["exhausted"],
unprocessed=stats["unprocessed"],
)

@staticmethod
def _check_exhausted_raise(
Expand Down
20 changes: 6 additions & 14 deletions agent_actions/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def _process_by_strategy(
results = self.record_processor.process_batch(data, context)

# Collect success results
output = ResultCollector.collect_results(
output, stats = ResultCollector.collect_results(
results,
cast(dict[str, Any], self.config.action_config),
self.config.action_name,
Expand All @@ -501,19 +501,11 @@ def _process_by_strategy(
# If input had records but output is empty AND there are actual failures
# (not just guard-filtered/skipped records), raise so the executor marks
# the action as failed and the circuit breaker skips downstream dependents.
# Guard filters (SKIPPED/FILTERED status) legitimately produce 0 output —
# only FAILED results indicate processing errors (e.g. 401 auth).
if data and not output:
from agent_actions.processing.types import ProcessingStatus

failed_results = [r for r in results if r.status == ProcessingStatus.FAILED]
if failed_results:
failed_msgs = [r.error for r in failed_results if r.error]
summary = "; ".join(failed_msgs[:3])
raise RuntimeError(
f"Action '{self.config.action_name}' produced 0 records — "
f"all {len(data)} input item(s) failed: {summary}"
)
if data and not output and stats.failed > 0:
raise RuntimeError(
f"Action '{self.config.action_name}' produced 0 records — "
f"all {len(data)} input item(s) failed ({stats.failed} failures)"
)

self.output_handler.save_main_output(output, file_path, base_directory, output_directory)

Expand Down
10 changes: 5 additions & 5 deletions tests/unit/core/test_result_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_result_collector_aggregates_statuses_first_stage():
failed = ProcessingResult.failed(error="Boom", source_guid="src-4")
filtered = ProcessingResult.filtered(source_guid="src-5")

output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
[success, skipped, exhausted, failed, filtered],
agent_config,
"fallback_name",
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_result_collector_uses_input_record_downstream():
)
exhausted.data = [exhausted_data]

output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
[exhausted],
agent_config,
"downstream",
Expand All @@ -140,7 +140,7 @@ def test_result_collector_uses_input_record_downstream():
def test_result_collector_handles_none_data():
result = ProcessingResult(status=ProcessingStatus.SUCCESS, data=None) # type: ignore[arg-type]

output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
[result],
agent_config={},
agent_name="test",
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_result_collector_on_exhausted_return_last_does_not_raise():
exhausted.data = [exhausted_data]

# Should not raise, should return exhausted record
output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
[exhausted],
agent_config,
"test_agent",
Expand Down Expand Up @@ -479,7 +479,7 @@ def test_no_storage_backend_no_crash(self):
failed = ProcessingResult.failed(error="boom", source_guid="src-fail")

# Should not raise
output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
[filtered, failed],
{},
"agent",
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/core/test_upstream_unprocessed_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_counts_unprocessed_separately(self):
ProcessingResult.success(data=[{"content": "ok2"}]),
]

output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
results,
agent_config={"agent_type": "test"},
agent_name="test",
Expand All @@ -181,7 +181,7 @@ def test_unprocessed_preserved_in_output(self):
),
]

output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
results,
agent_config={"agent_type": "test"},
agent_name="test",
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/input/test_initial_pipeline_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ def test_returns_string_path(self, tmp_dirs):
patch("agent_actions.input.preprocessing.staging.initial_pipeline.FileWriter"),
):
MockProc.return_value.process_batch.return_value = [{"result": "ok"}]
MockCollector.collect_results.return_value = [{"result": "ok"}]
from agent_actions.processing.result_collector import CollectionStats

MockCollector.collect_results.return_value = (
[{"result": "ok"}],
CollectionStats(success=1),
)

result = _process_online_mode_with_record_processor(
data_chunk, ctx, str(input_file), str(base), str(output)
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/workflow/test_pipeline_file_mode_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def supports_recovery(self):
assert item_out["source_guid"] == "sg-1"

# Verify ResultCollector flattens correctly
output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
[result], agent_config, agent_name, is_first_stage=False
)
assert len(output) == 3
Expand Down Expand Up @@ -392,7 +392,7 @@ def test_result_collector_does_not_raise_on_partial_failure():
ProcessingResult.failed(error="connection timeout"),
]

output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
results,
agent_config={"kind": "tool"},
agent_name="partial_tool",
Expand All @@ -413,7 +413,7 @@ def test_result_collector_does_not_raise_when_all_filtered():
ProcessingResult.filtered(),
]

output = ResultCollector.collect_results(
output, _ = ResultCollector.collect_results(
results,
agent_config={"kind": "tool"},
agent_name="filter_tool",
Expand Down
Loading