diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index 32859d29f..b25d5d49f 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -87,6 +87,40 @@
 
 _MESSAGE_TYPE_UNSET = object()
 
+
+def _format_prompt_fn(
+    prompt_str: str,
+    system_prompt: str | None = None,
+    few_shot: Messages | None = None,
+) -> Messages:
+    """Build a chat prompt: optional system message, few-shot turns, user turn."""
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    if few_shot:
+        messages.extend(few_shot)
+    messages.append({"role": "user", "content": prompt_str})
+    return messages
+
+
+def _prepend_system_prompt(prompt: list[Any], system_prompt: str) -> list[Any]:
+    """Return ``prompt`` led by a system message unless it already starts with one."""
+    assert isinstance(prompt, list), (
+        f"prompt must be a list of messages when system_prompt is provided, got {type(prompt)}"
+    )
+    # Check if a system message already exists (first message)
+    first = prompt[0] if prompt else None
+    first_role = (
+        first.get("role")
+        if isinstance(first, dict)
+        else getattr(first, "role", None)
+    )
+    if first_role == "system":
+        return prompt
+    # Prepend as a plain dict so Arrow/HuggingFace can serialize.
+    # Normalization to Pydantic happens later in init_state.
+    return [{"role": "system", "content": system_prompt}, *prompt]
+
 
 class Environment(ABC):
     """
@@ -282,14 +316,14 @@ def _sync_teardown():
         )
         signal.signal(signal.SIGTERM, lambda _, __: (_sync_teardown(), exit(143)))
 
-    def _ensure_example_id(self, dataset: Dataset) -> Dataset:
+    def _ensure_example_id(self, dataset: Dataset, map_kwargs: dict | None = None) -> Dataset:
         """Ensure example_id column exists and is integer type."""
         if "example_id" in dataset.column_names and not isinstance(
             dataset["example_id"][0], int
         ):
             dataset = dataset.rename_column("example_id", "src_id")
         if "example_id" not in dataset.column_names:
-            dataset = dataset.add_column("example_id", range(len(dataset)))
+            dataset = dataset.map(lambda _, i: {"example_id": i}, with_indices=True, **(map_kwargs or {}))
         return dataset
 
     def _ensure_prompt(
@@ -303,27 +337,17 @@ def _ensure_prompt(
     ) -> Dataset:
         """Ensure prompt column exists."""
         if "prompt" not in dataset.column_names:
-
-            def format_prompt_fn(prompt_str: str) -> Messages:
-                messages = []
-                if system_prompt:
-                    messages.append({"role": "system", "content": system_prompt})
-                if few_shot:
-                    messages.extend(few_shot)
-                messages.append({"role": "user", "content": prompt_str})
-                return messages
-
             if answer_key == "answer":
                 dataset = dataset.map(
                     lambda x: {
-                        "prompt": format_prompt_fn(x[question_key]),
+                        "prompt": _format_prompt_fn(x[question_key], system_prompt, few_shot),
                     },
                     **map_kwargs,
                 )
             else:
                 dataset = dataset.map(
                     lambda x: {
-                        "prompt": format_prompt_fn(x[question_key]),
+                        "prompt": _format_prompt_fn(x[question_key], system_prompt, few_shot),
                         "answer": x[answer_key],
                     },
                     **map_kwargs,
@@ -332,25 +356,8 @@ def format_prompt_fn(prompt_str: str) -> Messages:
         else:
             if system_prompt is not None:
 
-                def prepend_system_prompt(prompt: list[Any]) -> list[Any]:
-                    assert isinstance(prompt, list), (
-                        f"prompt must be a list of messages when system_prompt is provided, got {type(prompt)}"
-                    )
-                    # Check if a system message already exists (first message)
-                    first = prompt[0] if prompt else None
-                    first_role = (
-                        first.get("role")
-                        if isinstance(first, dict)
-                        else getattr(first, "role", None)
-                    )
-                    if first_role == "system":
-                        return prompt
-                    # Prepend as a plain dict so Arrow/HuggingFace can serialize.
-                    # Normalization to Pydantic happens later in init_state.
-                    return [{"role": "system", "content": system_prompt}, *prompt]
-
                 dataset = dataset.map(
-                    lambda x: {"prompt": prepend_system_prompt(x["prompt"])},
+                    lambda x: {"prompt": _prepend_system_prompt(x["prompt"], system_prompt)},
                     **map_kwargs,
                 )
             if few_shot is not None:
@@ -383,7 +390,7 @@ def _format_dataset(
         """
         Format dataset by creating example_id and prompt columns, and setting task column.
         """
-        dataset = self._ensure_example_id(dataset)
+        dataset = self._ensure_example_id(dataset, map_kwargs)
         dataset = self._ensure_prompt(
             dataset, system_prompt, few_shot, question_key, answer_key, map_kwargs
         )
@@ -396,6 +403,6 @@ def _format_completion_dataset(
         """
         Format dataset by creating example_id and prompt columns, and setting task column.
         """
-        dataset = self._ensure_example_id(dataset)
+        dataset = self._ensure_example_id(dataset, map_kwargs)
         dataset = self._ensure_task(dataset, map_kwargs)
         return dataset