diff --git a/src/modelgauge/annotators/cheval/annotator.py b/src/modelgauge/annotators/cheval/annotator.py index 0319cde13..9af2a7a2a 100644 --- a/src/modelgauge/annotators/cheval/annotator.py +++ b/src/modelgauge/annotators/cheval/annotator.py @@ -4,12 +4,10 @@ import requests from modelgauge.annotation import EnsembleSafetyAnnotation, SafetyAnnotation -from modelgauge.annotator import Annotator -from modelgauge.annotators.cheval.request import AnnotationRequest -from modelgauge.prompt import ChatPrompt, TextPrompt +from modelgauge.annotators.request import AnnotationRequest +from modelgauge.annotators.sideinfo import SideInformationAwareAnnotator from modelgauge.retry_decorator import retry from modelgauge.secret_values import RequiredSecret, SecretDescription -from modelgauge.sut import SUTResponse _CHEVAL_SCOPE = "cheval" @@ -71,7 +69,7 @@ class ChevalAnnotatorError(Exception): pass -class ChevalAnnotator(Annotator): +class ChevalAnnotator(SideInformationAwareAnnotator): def __init__( self, uid: str, @@ -85,15 +83,6 @@ def __init__( if not self.cheval.knows(self.uid): raise ChevalAnnotatorError(f"Annotator {self.uid} not found or not ready at {self.endpoint_url}") - def translate_prompt(self, prompt: TextPrompt | ChatPrompt, response: SUTResponse): - if not isinstance(prompt, TextPrompt): - raise ChevalAnnotatorError("ChevalAnnotator only supports TextPrompt") - return AnnotationRequest( - annotator=self.uid, - prompt=prompt.text, - response=response.text, - ) - def annotate(self, request: AnnotationRequest) -> SafetyAnnotation: return self.cheval.annotate(request) diff --git a/src/modelgauge/annotators/cheval/request.py b/src/modelgauge/annotators/cheval/request.py deleted file mode 100644 index c24a26890..000000000 --- a/src/modelgauge/annotators/cheval/request.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel - - -class AnnotationRequest(BaseModel): - annotator: str - prompt: str = "" - response: str diff --git a/src/modelgauge/annotators/composed_annotator.py b/src/modelgauge/annotators/composed_annotator.py index 7868a3fb3..b7447343f 100644 --- a/src/modelgauge/annotators/composed_annotator.py +++ b/src/modelgauge/annotators/composed_annotator.py @@ -4,31 +4,21 @@ from modelgauge.annotators.composer.dag import Composer, SuccessfulDAGOutput from modelgauge.annotators.composer.nodes import Arbiter, CacheableNodeMixin, NodeOutput from modelgauge.annotators.composer.verdict import Verdict -from modelgauge.prompt import ChatPrompt, TextPrompt -from modelgauge.prompt_formatting import format_chat +from modelgauge.annotators.request import AnnotationRequest, AnnotatorSideInformation +from modelgauge.annotators.sideinfo import SideInformationAwareAnnotator +from modelgauge.prompt import TextPrompt from modelgauge.sut import SUTResponse -class DAGAnnotator(Annotator): +class DAGAnnotator(SideInformationAwareAnnotator): """Annotator that executes a DAG.""" def __init__(self, uid: str, dag: Composer) -> None: super().__init__(uid) self.dag = dag - def translate_prompt( - self, - prompt: TextPrompt | ChatPrompt, - response: SUTResponse, - ) -> EvalContext: - prompt_str = prompt.text if isinstance(prompt, TextPrompt) else format_chat(prompt) - return EvalContext( - prompt=prompt_str, - response=response.text, - ) - - def annotate(self, annotation_request: EvalContext) -> SuccessfulDAGOutput: - dag_output = self.dag.run(annotation_request) + def annotate(self, annotation_request: AnnotationRequest) -> SuccessfulDAGOutput: + dag_output = self.dag.run(annotation_request.get_eval_context()) if isinstance(dag_output, SuccessfulDAGOutput): return dag_output else: @@ -81,8 +71,16 @@ def __init__(self, name: str, annotator: Annotator) -> None: self.annotator = annotator def run(self, ctx: EvalContext) -> NodeOutput: - prompt = TextPrompt(text=ctx.prompt) - response = SUTResponse(text=ctx.response) - annotation = self.annotator.process(prompt, response) + if isinstance(self.annotator, SideInformationAwareAnnotator): + annotation = self.annotator.process( + prompt=TextPrompt(text=ctx.prompt), + response=SUTResponse(text=ctx.response), + side_information=AnnotatorSideInformation(info=ctx.metadata), + ) + else: + annotation = self.annotator.process( + prompt=TextPrompt(text=ctx.prompt), + response=SUTResponse(text=ctx.response), + ) val = Safety(is_safe=annotation.is_safe) return NodeOutput(value=val, original_ctx=ctx) diff --git a/src/modelgauge/annotators/demo_annotator.py b/src/modelgauge/annotators/demo_annotator.py index 234d544f8..857e9e3b0 100644 --- a/src/modelgauge/annotators/demo_annotator.py +++ b/src/modelgauge/annotators/demo_annotator.py @@ -1,9 +1,12 @@ +from pydantic import BaseModel + from modelgauge.annotation import SafetyAnnotation from modelgauge.annotator import Annotator from modelgauge.annotator_registry import ANNOTATORS +from modelgauge.annotators.request import AnnotationRequest +from modelgauge.annotators.sideinfo import SideInformationAwareAnnotator from modelgauge.prompt import ChatPrompt, TextPrompt from modelgauge.sut import SUTResponse -from pydantic import BaseModel class DemoYBadRequest(BaseModel): @@ -14,6 +17,14 @@ class DemoYBadResponse(BaseModel): score: float +def bad_y_annotate(annotation_request: DemoYBadRequest) -> DemoYBadResponse: + score = 0 + for character in annotation_request.text: + if character in {"Y", "y"}: + score += 1 + return DemoYBadResponse(score=score) + + class DemoYBadAnnotator(Annotator): """A demonstration annotator that dislikes the letter Y. @@ -26,11 +37,21 @@ def translate_prompt(self, prompt: TextPrompt | ChatPrompt, response: SUTRespons return DemoYBadRequest(text=response.text) def annotate(self, annotation_request: DemoYBadRequest) -> DemoYBadResponse: - score = 0 - for character in annotation_request.text: - if character in {"Y", "y"}: - score += 1 - return DemoYBadResponse(score=score) + return bad_y_annotate(annotation_request) + + def translate_response(self, request, response: DemoYBadResponse) -> SafetyAnnotation: + return SafetyAnnotation(is_safe=response.score == 0.0) + + +class SideInfoDemoYBadAnnotator(SideInformationAwareAnnotator): + """A demonstration annotator that dislikes the letter Y unless passed + side-information telling it to change its mind.""" + + def annotate(self, annotation_request: AnnotationRequest) -> DemoYBadResponse: + if annotation_request.side_information and annotation_request.side_information.get("cheat", False): + return DemoYBadResponse(score=0.0) + else: + return bad_y_annotate(DemoYBadRequest(text=annotation_request.prompt)) def translate_response(self, request, response: DemoYBadResponse) -> SafetyAnnotation: return SafetyAnnotation(is_safe=response.score == 0.0) diff --git a/src/modelgauge/annotators/request.py b/src/modelgauge/annotators/request.py new file mode 100644 index 000000000..7658275d7 --- /dev/null +++ b/src/modelgauge/annotators/request.py @@ -0,0 +1,34 @@ +from pydantic import BaseModel + +from modelgauge.annotators.composer.context import EvalContext +from modelgauge.prompt import TextPrompt +from modelgauge.sut import SUTResponse + + +class AnnotatorSideInformation(BaseModel, frozen=True): + """Side information for the annotator.""" + + info: dict + + +class AnnotationRequest(BaseModel): + annotator: str + prompt: str = "" + response: str + side_information: dict = {} # optional side information to pass to evaluator + + def get_text_prompt(self) -> TextPrompt: + return TextPrompt(text=self.prompt) + + def get_sut_response(self) -> SUTResponse: + return SUTResponse(text=self.response) + + def get_annotator_side_information(self) -> AnnotatorSideInformation: + return AnnotatorSideInformation(info=self.side_information) + + def get_eval_context(self) -> EvalContext: + return EvalContext( + prompt=self.prompt, + response=self.response, + metadata=self.side_information, + ) diff --git a/src/modelgauge/annotators/sideinfo.py b/src/modelgauge/annotators/sideinfo.py new file mode 100644 index 000000000..c845f7fd0 --- /dev/null +++ b/src/modelgauge/annotators/sideinfo.py @@ -0,0 +1,38 @@ +from typing import Optional + +from modelgauge.annotation import SafetyAnnotation +from modelgauge.annotator import Annotator +from modelgauge.annotators.request import AnnotationRequest, AnnotatorSideInformation +from modelgauge.prompt import ChatPrompt, TextPrompt +from modelgauge.sut import SUTResponse + + +class SideInformationAwareAnnotator(Annotator): + """Abstract Annotator that can accept side information.""" + + def translate_prompt( + self, + prompt: TextPrompt | ChatPrompt, + response: SUTResponse, + side_information: Optional[AnnotatorSideInformation] = None, + ) -> AnnotationRequest: + if not isinstance(prompt, TextPrompt): + raise ValueError(f"{self.__class__.__name__} only supports TextPrompt") + return AnnotationRequest( + annotator=self.uid, + prompt=prompt.text, + response=response.text, + side_information=side_information.info if side_information else {}, + ) + + def process( + self, + prompt: TextPrompt | ChatPrompt, + response: SUTResponse, + side_information: Optional[AnnotatorSideInformation] = None, + ) -> SafetyAnnotation: + # Proper fix in the future should port this signature to Annotator.process + # `translate_prompt` should be updated to allow side information. + annotator_request = self.translate_prompt(prompt, response, side_information) + annotator_response = self.annotate(annotator_request) + return self.translate_response(annotator_request, annotator_response) diff --git a/tests/modelgauge_tests/annotator_tests/composer_tests/test_safety.py b/tests/modelgauge_tests/annotator_tests/composer_tests/test_safety.py index 3f4c0dbb1..ff376ab07 100644 --- a/tests/modelgauge_tests/annotator_tests/composer_tests/test_safety.py +++ b/tests/modelgauge_tests/annotator_tests/composer_tests/test_safety.py @@ -13,19 +13,28 @@ NodeExecutionError, ) from modelgauge.annotators.composer.verdict import Verdict -from modelgauge.annotators.demo_annotator import DemoYBadAnnotator +from modelgauge.annotators.demo_annotator import SideInfoDemoYBadAnnotator from modelgauge.prompt import TextPrompt from modelgauge.sut import SUTResponse def test_safety_annotator_arbiter(sample_ctx): - annotator = DemoYBadAnnotator("demo_annotator") + annotator = SideInfoDemoYBadAnnotator("demo_annotator") arbiter = AnnotatorArbiter(name="demo_arbiter", annotator=annotator) output = arbiter.run(sample_ctx) assert output.value.is_safe assert isinstance(output.value, Safety) assert arbiter.verdict_type == Safety + bad_y_ctx = sample_ctx.with_prompt("y") + output = arbiter.run(bad_y_ctx) + assert not output.value.is_safe + + cheat_ctx = bad_y_ctx.with_metadata_updates({"cheat": True}) + output = arbiter.run(cheat_ctx) + assert output.value.is_safe + assert "y" in cheat_ctx.prompt + def test_safety_dag_run(simple_dag, sample_ctx): safety_annotator = SafetyDAGAnnotator("safety", simple_dag) @@ -35,6 +44,7 @@ def test_safety_dag_run(simple_dag, sample_ctx): ) assert not output.is_safe assert isinstance(output, SafetyAnnotation) + assert output.metadata is not None assert len(output.metadata["node_outputs"]) == 3 assert output.metadata["verdict"] == "UNSAFE"