From 2568c5e70e7c711e438733865040cf75f456016d Mon Sep 17 00:00:00 2001 From: Vishal Doshi Date: Mon, 22 Jun 2026 10:32:30 -0400 Subject: [PATCH] Maintain audit log of DAG in output metadata. --- .../annotators/composed_annotator.py | 13 +++++++--- src/modelgauge/annotators/composer/context.py | 8 +++--- src/modelgauge/annotators/composer/dag.py | 13 ++++++++++ .../composer_tests/test_composer.py | 25 +++++++++++++++++++ .../composer_tests/test_safety.py | 2 ++ 5 files changed, 54 insertions(+), 7 deletions(-) diff --git a/src/modelgauge/annotators/composed_annotator.py b/src/modelgauge/annotators/composed_annotator.py index 97638bcb..7868a3fb 100644 --- a/src/modelgauge/annotators/composed_annotator.py +++ b/src/modelgauge/annotators/composed_annotator.py @@ -27,10 +27,10 @@ def translate_prompt( response=response.text, ) - def annotate(self, annotation_request: EvalContext) -> Verdict: + def annotate(self, annotation_request: EvalContext) -> SuccessfulDAGOutput: dag_output = self.dag.run(annotation_request) if isinstance(dag_output, SuccessfulDAGOutput): - return dag_output.verdict + return dag_output else: raise dag_output.error @@ -62,10 +62,15 @@ def __init__(self, uid: str, dag: Composer) -> None: def translate_response( self, request: EvalContext, - response: Safety, + response: SuccessfulDAGOutput, ) -> SafetyAnnotation: """Map DAGResult verdict to a SafetyAnnotation (is_safe bool).""" - return SafetyAnnotation(is_safe=response.is_safe) + assert isinstance(response.verdict, Safety), "Safety DAG output verdict must be of type Safety." + return SafetyAnnotation( + is_safe=response.verdict.is_safe, + is_valid=True, + metadata=response.to_dict(skip_cost=True), + ) class AnnotatorArbiter(SafetyArbiter, CacheableNodeMixin): diff --git a/src/modelgauge/annotators/composer/context.py b/src/modelgauge/annotators/composer/context.py index 76820fc5..64bcf56f 100644 --- a/src/modelgauge/annotators/composer/context.py +++ b/src/modelgauge/annotators/composer/context.py @@ -14,13 +14,15 @@ class NodeOutput: realized_cost: RealizedCost = field(default_factory=RealizedCost) updated_ctx: Optional[EvalContext] = None - def to_dict(self) -> dict: - return { + def to_dict(self, skip_cost=False) -> dict: + d = { "value": str(self.value), - "realized_cost": self.realized_cost.to_dict(), "updated_ctx": self.updated_ctx.to_dict() if self.updated_ctx else None, "original_ctx": self.original_ctx.to_dict(), } + if not skip_cost: + d["realized_cost"] = self.realized_cost.to_dict() + return d class EvalContext: diff --git a/src/modelgauge/annotators/composer/dag.py b/src/modelgauge/annotators/composer/dag.py index 03b4de79..5aba35c3 100644 --- a/src/modelgauge/annotators/composer/dag.py +++ b/src/modelgauge/annotators/composer/dag.py @@ -43,11 +43,24 @@ class _DAGOutput: node_outputs: dict[str, NodeOutput] total_cost: RealizedCost + def to_dict(self, skip_cost=False) -> dict: + d = { + "node_outputs": {k: v.to_dict(skip_cost=skip_cost) for k, v in self.node_outputs.items()}, + } + if not skip_cost: + d["total_cost"] = self.total_cost.to_dict() + return d + @dataclass class SuccessfulDAGOutput(_DAGOutput): verdict: Verdict + def to_dict(self, skip_cost=False) -> dict: + d = super().to_dict(skip_cost=skip_cost) + d["verdict"] = self.verdict.name + return d + @dataclass class FailedDAGOutput(_DAGOutput): diff --git a/tests/modelgauge_tests/annotator_tests/composer_tests/test_composer.py b/tests/modelgauge_tests/annotator_tests/composer_tests/test_composer.py index eab6fe2a..28852fda 100644 --- a/tests/modelgauge_tests/annotator_tests/composer_tests/test_composer.py +++ b/tests/modelgauge_tests/annotator_tests/composer_tests/test_composer.py @@ -231,6 +231,31 @@ def test_dag_updated_context_not_passed_to_parallel_nodes(): assert dag_output.node_outputs["lower_scorer"].value == pytest.approx(1.0) +def test_dag_output_to_dict(simple_dag, sample_ctx): + dag_output = simple_dag.run(sample_ctx) + dag_output_dict = dag_output.to_dict() + + assert "verdict" in dag_output_dict + assert "total_cost" in dag_output_dict + assert "node_outputs" in dag_output_dict + for node_output in dag_output_dict["node_outputs"].values(): + assert "value" in node_output + assert "original_ctx" in node_output + assert "updated_ctx" in node_output + assert "realized_cost" in node_output + + dag_output_dict_no_cost = dag_output.to_dict(skip_cost=True) + + assert "verdict" in dag_output_dict_no_cost + assert "total_cost" not in dag_output_dict_no_cost + assert "node_outputs" in dag_output_dict_no_cost + for node_output in dag_output_dict_no_cost["node_outputs"].values(): + assert "value" in node_output + assert "original_ctx" in node_output + assert "updated_ctx" in node_output + assert "realized_cost" not in node_output + + def test_dag_parallel_nodes_different_updated_contexts_raises_error(): # upper caser and lower caser are parallel nodes, they update the dontext differently which should raise an error. ctx = EvalContext(prompt="x", response="HELLO") diff --git a/tests/modelgauge_tests/annotator_tests/composer_tests/test_safety.py b/tests/modelgauge_tests/annotator_tests/composer_tests/test_safety.py index 7888dfec..3f4c0dbb 100644 --- a/tests/modelgauge_tests/annotator_tests/composer_tests/test_safety.py +++ b/tests/modelgauge_tests/annotator_tests/composer_tests/test_safety.py @@ -35,6 +35,8 @@ def test_safety_dag_run(simple_dag, sample_ctx): ) assert not output.is_safe assert isinstance(output, SafetyAnnotation) + assert len(output.metadata["node_outputs"]) == 3 + assert output.metadata["verdict"] == "UNSAFE" def test_safety_dag_with_bad_verdict_type():