diff --git a/src/modelgauge/annotators/composer/dag.py b/src/modelgauge/annotators/composer/dag.py index 5aba35c3..5a653053 100644 --- a/src/modelgauge/annotators/composer/dag.py +++ b/src/modelgauge/annotators/composer/dag.py @@ -268,14 +268,14 @@ def _run_traced(self, ctx: EvalContext) -> tuple[SuccessfulDAGOutput | FailedDAG for node_name in self._ordered: if node_name not in reachable: continue - ctx = ctx.with_parent_outputs( - {pred: node_outputs[pred] for pred in self._predecessors[node_name] if pred in node_outputs} - ) - node = self._nodes[node_name] try: + ctx = ctx.with_parent_outputs( + {pred: node_outputs[pred] for pred in self._predecessors[node_name] if pred in node_outputs} + ) + node = self._nodes[node_name] output = self._run_node(node, ctx) except Exception as e: - wrapped_error = NodeExecutionError(node.name, e) + wrapped_error = NodeExecutionError(node_name, e) return ( FailedDAGOutput( node_outputs=node_outputs, diff --git a/tests/modelgauge_tests/annotator_tests/composer_tests/test_composer.py b/tests/modelgauge_tests/annotator_tests/composer_tests/test_composer.py index 28852fda..3a7e2f8b 100644 --- a/tests/modelgauge_tests/annotator_tests/composer_tests/test_composer.py +++ b/tests/modelgauge_tests/annotator_tests/composer_tests/test_composer.py @@ -19,7 +19,7 @@ from modelgauge.annotators.composed_annotator import Safety from modelgauge.annotators.composer.context import EvalContext -from modelgauge.annotators.composer.dag import Composer, ComposerColumnNames +from modelgauge.annotators.composer.dag import Composer, ComposerColumnNames, FailedDAGOutput def test_dag_outputs(simple_dag): @@ -274,11 +274,8 @@ def test_dag_parallel_nodes_different_updated_contexts_raises_error(): .add_node(LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"])) .add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5)) ) - with pytest.raises( - ValueError, - match="all parent outputs must have the same updated prompt/response", - ): - dag.run(ctx) + result = dag.run(ctx) + assert isinstance(result, FailedDAGOutput) def test_dag_run_with_dataframe_json_md(simple_dag):