Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions src/modelgauge/annotators/cheval/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
import requests

from modelgauge.annotation import EnsembleSafetyAnnotation, SafetyAnnotation
from modelgauge.annotator import Annotator
from modelgauge.annotators.cheval.request import AnnotationRequest
from modelgauge.prompt import ChatPrompt, TextPrompt
from modelgauge.annotators.request import AnnotationRequest
from modelgauge.annotators.sideinfo import SideInformationAwareAnnotator
from modelgauge.retry_decorator import retry
from modelgauge.secret_values import RequiredSecret, SecretDescription
from modelgauge.sut import SUTResponse

_CHEVAL_SCOPE = "cheval"

Expand Down Expand Up @@ -71,7 +69,7 @@ class ChevalAnnotatorError(Exception):
pass


class ChevalAnnotator(Annotator):
class ChevalAnnotator(SideInformationAwareAnnotator):
def __init__(
self,
uid: str,
Expand All @@ -85,15 +83,6 @@ def __init__(
if not self.cheval.knows(self.uid):
raise ChevalAnnotatorError(f"Annotator {self.uid} not found or not ready at {self.endpoint_url}")

def translate_prompt(self, prompt: TextPrompt | ChatPrompt, response: SUTResponse):
if not isinstance(prompt, TextPrompt):
raise ChevalAnnotatorError("ChevalAnnotator only supports TextPrompt")
return AnnotationRequest(
annotator=self.uid,
prompt=prompt.text,
response=response.text,
)

def annotate(self, request: AnnotationRequest) -> SafetyAnnotation:
return self.cheval.annotate(request)

Expand Down
7 changes: 0 additions & 7 deletions src/modelgauge/annotators/cheval/request.py

This file was deleted.

36 changes: 17 additions & 19 deletions src/modelgauge/annotators/composed_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,21 @@
from modelgauge.annotators.composer.dag import Composer, SuccessfulDAGOutput
from modelgauge.annotators.composer.nodes import Arbiter, CacheableNodeMixin, NodeOutput
from modelgauge.annotators.composer.verdict import Verdict
from modelgauge.prompt import ChatPrompt, TextPrompt
from modelgauge.prompt_formatting import format_chat
from modelgauge.annotators.request import AnnotationRequest, AnnotatorSideInformation
from modelgauge.annotators.sideinfo import SideInformationAwareAnnotator
from modelgauge.prompt import TextPrompt
from modelgauge.sut import SUTResponse


class DAGAnnotator(Annotator):
class DAGAnnotator(SideInformationAwareAnnotator):
"""Annotator that executes a DAG."""

def __init__(self, uid: str, dag: Composer) -> None:
super().__init__(uid)
self.dag = dag

def translate_prompt(
self,
prompt: TextPrompt | ChatPrompt,
response: SUTResponse,
) -> EvalContext:
prompt_str = prompt.text if isinstance(prompt, TextPrompt) else format_chat(prompt)
return EvalContext(
prompt=prompt_str,
response=response.text,
)

def annotate(self, annotation_request: EvalContext) -> SuccessfulDAGOutput:
dag_output = self.dag.run(annotation_request)
def annotate(self, annotation_request: AnnotationRequest) -> SuccessfulDAGOutput:
dag_output = self.dag.run(annotation_request.get_eval_context())
if isinstance(dag_output, SuccessfulDAGOutput):
return dag_output
else:
Expand Down Expand Up @@ -81,8 +71,16 @@ def __init__(self, name: str, annotator: Annotator) -> None:
self.annotator = annotator

def run(self, ctx: EvalContext) -> NodeOutput:
prompt = TextPrompt(text=ctx.prompt)
response = SUTResponse(text=ctx.response)
annotation = self.annotator.process(prompt, response)
if isinstance(self.annotator, SideInformationAwareAnnotator):
annotation = self.annotator.process(
prompt=TextPrompt(text=ctx.prompt),
response=SUTResponse(text=ctx.response),
side_information=AnnotatorSideInformation(info=ctx.metadata),
)
else:
annotation = self.annotator.process(
prompt=TextPrompt(text=ctx.prompt),
response=SUTResponse(text=ctx.response),
)
val = Safety(is_safe=annotation.is_safe)
return NodeOutput(value=val, original_ctx=ctx)
33 changes: 27 additions & 6 deletions src/modelgauge/annotators/demo_annotator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from pydantic import BaseModel

from modelgauge.annotation import SafetyAnnotation
from modelgauge.annotator import Annotator
from modelgauge.annotator_registry import ANNOTATORS
from modelgauge.annotators.request import AnnotationRequest
from modelgauge.annotators.sideinfo import SideInformationAwareAnnotator
from modelgauge.prompt import ChatPrompt, TextPrompt
from modelgauge.sut import SUTResponse
from pydantic import BaseModel


class DemoYBadRequest(BaseModel):
Expand All @@ -14,6 +17,14 @@ class DemoYBadResponse(BaseModel):
score: float


def bad_y_annotate(annotation_request: DemoYBadRequest) -> DemoYBadResponse:
score = 0
for character in annotation_request.text:
if character in {"Y", "y"}:
score += 1
return DemoYBadResponse(score=score)


class DemoYBadAnnotator(Annotator):
"""A demonstration annotator that dislikes the letter Y.

Expand All @@ -26,11 +37,21 @@ def translate_prompt(self, prompt: TextPrompt | ChatPrompt, response: SUTRespons
return DemoYBadRequest(text=response.text)

def annotate(self, annotation_request: DemoYBadRequest) -> DemoYBadResponse:
score = 0
for character in annotation_request.text:
if character in {"Y", "y"}:
score += 1
return DemoYBadResponse(score=score)
return bad_y_annotate(annotation_request)

def translate_response(self, request, response: DemoYBadResponse) -> SafetyAnnotation:
return SafetyAnnotation(is_safe=response.score == 0.0)


class SideInfoDemoYBadAnnotator(SideInformationAwareAnnotator):
"""A demonstration annotator that dislikes the letter Y unless passed
side-information telling it to change its mind."""

def annotate(self, annotation_request: AnnotationRequest) -> DemoYBadResponse:
if annotation_request.side_information and annotation_request.side_information.get("cheat", False):
return DemoYBadResponse(score=0.0)
else:
return bad_y_annotate(DemoYBadRequest(text=annotation_request.prompt))

def translate_response(self, request, response: DemoYBadResponse) -> SafetyAnnotation:
return SafetyAnnotation(is_safe=response.score == 0.0)
Expand Down
34 changes: 34 additions & 0 deletions src/modelgauge/annotators/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from pydantic import BaseModel

from modelgauge.annotators.composer.context import EvalContext
from modelgauge.prompt import TextPrompt
from modelgauge.sut import SUTResponse


class AnnotatorSideInformation(BaseModel, frozen=True):
"""Side information for the annotator."""

info: dict


class AnnotationRequest(BaseModel):
annotator: str
prompt: str = ""
response: str
side_information: dict = {} # optional side information to pass to evaluator

def get_text_prompt(self) -> TextPrompt:
return TextPrompt(text=self.prompt)

def get_sut_response(self) -> SUTResponse:
return SUTResponse(text=self.response)

def get_annotator_side_information(self) -> AnnotatorSideInformation:
return AnnotatorSideInformation(info=self.side_information)

def get_eval_context(self) -> EvalContext:
return EvalContext(
prompt=self.prompt,
response=self.response,
metadata=self.side_information,
)
38 changes: 38 additions & 0 deletions src/modelgauge/annotators/sideinfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Optional

from modelgauge.annotation import SafetyAnnotation
from modelgauge.annotator import Annotator
from modelgauge.annotators.request import AnnotationRequest, AnnotatorSideInformation
from modelgauge.prompt import ChatPrompt, TextPrompt
from modelgauge.sut import SUTResponse


class SideInformationAwareAnnotator(Annotator):
"""Abstract Annotator that can accept side information."""

def translate_prompt(
self,
prompt: TextPrompt | ChatPrompt,
response: SUTResponse,
side_information: Optional[AnnotatorSideInformation] = None,
) -> AnnotationRequest:
if not isinstance(prompt, TextPrompt):
raise ValueError(f"{self.__class__.__name__} only supports TextPrompt")
return AnnotationRequest(
annotator=self.uid,
prompt=prompt.text,
response=response.text,
side_information=side_information.info if side_information else {},
)

def process(
self,
prompt: TextPrompt | ChatPrompt,
response: SUTResponse,
side_information: Optional[AnnotatorSideInformation] = None,
) -> SafetyAnnotation:
# Proper fix in the future should port this signature to Annotator.process
# `translate_prompt` should be updated to allow side information.
annotator_request = self.translate_prompt(prompt, response, side_information)
annotator_response = self.annotate(annotator_request)
return self.translate_response(annotator_request, annotator_response)
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,28 @@
NodeExecutionError,
)
from modelgauge.annotators.composer.verdict import Verdict
from modelgauge.annotators.demo_annotator import DemoYBadAnnotator
from modelgauge.annotators.demo_annotator import SideInfoDemoYBadAnnotator
from modelgauge.prompt import TextPrompt
from modelgauge.sut import SUTResponse


def test_safety_annotator_arbiter(sample_ctx):
annotator = DemoYBadAnnotator("demo_annotator")
annotator = SideInfoDemoYBadAnnotator("demo_annotator")
arbiter = AnnotatorArbiter(name="demo_arbiter", annotator=annotator)
output = arbiter.run(sample_ctx)
assert output.value.is_safe
assert isinstance(output.value, Safety)
assert arbiter.verdict_type == Safety

bad_y_ctx = sample_ctx.with_prompt("y")
output = arbiter.run(bad_y_ctx)
assert not output.value.is_safe

cheat_ctx = bad_y_ctx.with_metadata_updates({"cheat": True})
output = arbiter.run(cheat_ctx)
assert output.value.is_safe
assert "y" in cheat_ctx.prompt


def test_safety_dag_run(simple_dag, sample_ctx):
safety_annotator = SafetyDAGAnnotator("safety", simple_dag)
Expand All @@ -35,6 +44,7 @@ def test_safety_dag_run(simple_dag, sample_ctx):
)
assert not output.is_safe
assert isinstance(output, SafetyAnnotation)
assert output.metadata is not None
assert len(output.metadata["node_outputs"]) == 3
assert output.metadata["verdict"] == "UNSAFE"

Expand Down
Loading