Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/modelgauge/annotators/composed_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def translate_prompt(
response=response.text,
)

def annotate(self, annotation_request: EvalContext) -> Verdict:
def annotate(self, annotation_request: EvalContext) -> SuccessfulDAGOutput:
dag_output = self.dag.run(annotation_request)
if isinstance(dag_output, SuccessfulDAGOutput):
return dag_output.verdict
return dag_output
else:
raise dag_output.error

Expand Down Expand Up @@ -62,10 +62,15 @@ def __init__(self, uid: str, dag: Composer) -> None:
def translate_response(
self,
request: EvalContext,
response: Safety,
response: SuccessfulDAGOutput,
) -> SafetyAnnotation:
"""Map DAGResult verdict to a SafetyAnnotation (is_safe bool)."""
return SafetyAnnotation(is_safe=response.is_safe)
assert isinstance(response.verdict, Safety), "Safety DAG output verdict must be of type Safety."
return SafetyAnnotation(
is_safe=response.verdict.is_safe,
is_valid=True,
metadata=response.to_dict(skip_cost=True),
)


class AnnotatorArbiter(SafetyArbiter, CacheableNodeMixin):
Expand Down
8 changes: 5 additions & 3 deletions src/modelgauge/annotators/composer/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ class NodeOutput:
realized_cost: RealizedCost = field(default_factory=RealizedCost)
updated_ctx: Optional[EvalContext] = None

def to_dict(self) -> dict:
return {
def to_dict(self, skip_cost=False) -> dict:
d = {
"value": str(self.value),
"realized_cost": self.realized_cost.to_dict(),
"updated_ctx": self.updated_ctx.to_dict() if self.updated_ctx else None,
"original_ctx": self.original_ctx.to_dict(),
}
if not skip_cost:
d["realized_cost"] = self.realized_cost.to_dict()
return d


class EvalContext:
Expand Down
13 changes: 13 additions & 0 deletions src/modelgauge/annotators/composer/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,24 @@ class _DAGOutput:
node_outputs: dict[str, NodeOutput]
total_cost: RealizedCost

def to_dict(self, skip_cost=False) -> dict:
    """Serialize this DAG output into a plain dictionary.

    Args:
        skip_cost: When true, the ``"total_cost"`` entry is left out and the
            same flag is forwarded to every node output's ``to_dict``.

    Returns:
        A dict with a ``"node_outputs"`` mapping (node key -> serialized node
        output) and, unless ``skip_cost`` is set, a ``"total_cost"`` entry.
    """
    serialized_nodes = {}
    for node_name, node_output in self.node_outputs.items():
        serialized_nodes[node_name] = node_output.to_dict(skip_cost=skip_cost)

    result = {"node_outputs": serialized_nodes}
    if skip_cost:
        # Cost information was explicitly excluded by the caller.
        return result

    result["total_cost"] = self.total_cost.to_dict()
    return result


@dataclass
class SuccessfulDAGOutput(_DAGOutput):
    """DAG output that carries a verdict on top of the base ``_DAGOutput`` fields."""

    verdict: Verdict

    def to_dict(self, skip_cost=False) -> dict:
        """Serialize the output, adding the verdict's name under ``"verdict"``.

        Args:
            skip_cost: Forwarded unchanged to the parent class's ``to_dict``.

        Returns:
            The parent's dictionary extended with a ``"verdict"`` entry holding
            ``self.verdict.name``.
        """
        return {
            **super().to_dict(skip_cost=skip_cost),
            "verdict": self.verdict.name,
        }


@dataclass
class FailedDAGOutput(_DAGOutput):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,31 @@ def test_dag_updated_context_not_passed_to_parallel_nodes():
assert dag_output.node_outputs["lower_scorer"].value == pytest.approx(1.0)


def test_dag_output_to_dict(simple_dag, sample_ctx):
    """Check the keys emitted by ``to_dict`` with and without cost data."""
    result = simple_dag.run(sample_ctx)

    # First pass: default call (cost keys expected).
    # Second pass: skip_cost=True (cost keys must be absent).
    scenarios = (
        ({}, True),
        ({"skip_cost": True}, False),
    )
    for call_kwargs, cost_expected in scenarios:
        serialized = result.to_dict(**call_kwargs)

        assert "verdict" in serialized
        assert ("total_cost" in serialized) is cost_expected
        assert "node_outputs" in serialized

        for node_dict in serialized["node_outputs"].values():
            # These keys are present regardless of the skip_cost flag.
            for always_present in ("value", "original_ctx", "updated_ctx"):
                assert always_present in node_dict
            assert ("realized_cost" in node_dict) is cost_expected


def test_dag_parallel_nodes_different_updated_contexts_raises_error():
# upper caser and lower caser are parallel nodes; they update the context differently, which should raise an error.
ctx = EvalContext(prompt="x", response="HELLO")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def test_safety_dag_run(simple_dag, sample_ctx):
)
assert not output.is_safe
assert isinstance(output, SafetyAnnotation)
assert len(output.metadata["node_outputs"]) == 3
assert output.metadata["verdict"] == "UNSAFE"


def test_safety_dag_with_bad_verdict_type():
Expand Down
Loading