Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 36 additions & 34 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,35 @@

_MESSAGE_TYPE_UNSET = object()

def _format_prompt_fn(
    prompt_str: str,
    system_prompt: str | None = None,
    few_shot: Messages | None = None,
) -> Messages:
    """Build a chat-style prompt ending in a user message.

    Args:
        prompt_str: Content of the final ``user`` message.
        system_prompt: Content of a leading ``system`` message. Skipped when
            falsy (``None`` or an empty string).
        few_shot: Messages inserted between the system message and the user
            message. Skipped when falsy.

    Returns:
        A new list ordered as: optional system message, few-shot messages,
        then the user message.
    """
    # Module-level (not a closure) so it can be referenced from dataset.map lambdas.
    preamble = [{"role": "system", "content": system_prompt}] if system_prompt else []
    question = {"role": "user", "content": prompt_str}
    return [*preamble, *(few_shot or []), question]

def _prepend_system_prompt(prompt: list[Any], system_prompt: str) -> list[Any]:
    """Return ``prompt`` with a leading system message, adding one if absent.

    Args:
        prompt: List of chat messages. Each message is either a dict or an
            object exposing a ``role`` attribute.
        system_prompt: Content for the system message to insert.

    Returns:
        ``prompt`` itself (the same list object) when its first message
        already has role ``"system"``; otherwise a new list whose first
        element is the system message followed by the original messages.

    Raises:
        AssertionError: If ``prompt`` is not a list.
    """
    # Raised explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``. Without it a string prompt would be splatted
    # character by character into the result. The exception type and message
    # are unchanged from the assert-based check.
    if not isinstance(prompt, list):
        raise AssertionError(
            f"prompt must be a list of messages when system_prompt is provided, got {type(prompt)}"
        )
    # Only the first message is inspected for an existing system role.
    first = prompt[0] if prompt else None
    if isinstance(first, dict):
        first_role = first.get("role")
    else:
        first_role = getattr(first, "role", None)
    if first_role == "system":
        return prompt
    # Prepend as a plain dict so Arrow/HuggingFace can serialize.
    # Normalization to Pydantic happens later in init_state.
    return [{"role": "system", "content": system_prompt}, *prompt]

class Environment(ABC):
"""
Expand Down Expand Up @@ -282,14 +311,14 @@ def _sync_teardown():
)
signal.signal(signal.SIGTERM, lambda _, __: (_sync_teardown(), exit(143)))

def _ensure_example_id(self, dataset: Dataset, map_kwargs: dict | None = None) -> Dataset:
    """Ensure example_id column exists and is integer type.

    Args:
        dataset: Dataset to normalize.
        map_kwargs: Extra keyword arguments forwarded to ``dataset.map`` when
            the ``example_id`` column has to be created. ``None`` means no
            extra arguments.

    Returns:
        The dataset with an ``example_id`` column. When an existing
        ``example_id`` column holds a non-int first value it is renamed to
        ``src_id`` and a fresh ``example_id`` (the row index) is generated.
    """
    # A ``{}`` default would be one dict shared across every call; use None.
    map_kwargs = {} if map_kwargs is None else map_kwargs
    # Only the first row is sampled to decide the column's type. The length
    # guard avoids indexing row 0 of an empty dataset.
    if (
        "example_id" in dataset.column_names
        and len(dataset) > 0
        and not isinstance(dataset["example_id"][0], int)
    ):
        dataset = dataset.rename_column("example_id", "src_id")
    if "example_id" not in dataset.column_names:
        # with_indices=True passes the row index as the second argument.
        dataset = dataset.map(
            lambda _, i: {"example_id": i}, with_indices=True, **map_kwargs
        )
    return dataset

def _ensure_prompt(
Expand All @@ -303,27 +332,17 @@ def _ensure_prompt(
) -> Dataset:
"""Ensure prompt column exists."""
if "prompt" not in dataset.column_names:

def format_prompt_fn(prompt_str: str) -> Messages:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if few_shot:
messages.extend(few_shot)
messages.append({"role": "user", "content": prompt_str})
return messages

if answer_key == "answer":
dataset = dataset.map(
lambda x: {
"prompt": format_prompt_fn(x[question_key]),
"prompt": _format_prompt_fn(x[question_key], system_prompt, few_shot),
},
**map_kwargs,
)
else:
dataset = dataset.map(
lambda x: {
"prompt": format_prompt_fn(x[question_key]),
"prompt": _format_prompt_fn(x[question_key], system_prompt, few_shot),
"answer": x[answer_key],
},
**map_kwargs,
Expand All @@ -332,25 +351,8 @@ def format_prompt_fn(prompt_str: str) -> Messages:
else:
if system_prompt is not None:

def prepend_system_prompt(prompt: list[Any]) -> list[Any]:
assert isinstance(prompt, list), (
f"prompt must be a list of messages when system_prompt is provided, got {type(prompt)}"
)
# Check if a system message already exists (first message)
first = prompt[0] if prompt else None
first_role = (
first.get("role")
if isinstance(first, dict)
else getattr(first, "role", None)
)
if first_role == "system":
return prompt
# Prepend as a plain dict so Arrow/HuggingFace can serialize.
# Normalization to Pydantic happens later in init_state.
return [{"role": "system", "content": system_prompt}, *prompt]

dataset = dataset.map(
lambda x: {"prompt": prepend_system_prompt(x["prompt"])},
lambda x: {"prompt": _prepend_system_prompt(x["prompt"], system_prompt)},
**map_kwargs,
)
if few_shot is not None:
Expand Down Expand Up @@ -383,7 +385,7 @@ def _format_dataset(
"""
Format dataset by creating example_id and prompt columns, and setting task column.
"""
dataset = self._ensure_example_id(dataset)
dataset = self._ensure_example_id(dataset, map_kwargs)
dataset = self._ensure_prompt(
dataset, system_prompt, few_shot, question_key, answer_key, map_kwargs
)
Expand All @@ -396,7 +398,7 @@ def _format_completion_dataset(
"""
Format dataset by creating example_id and prompt columns, and setting task column.
"""
dataset = self._ensure_example_id(dataset)
dataset = self._ensure_example_id(dataset, map_kwargs)
dataset = self._ensure_task(dataset, map_kwargs)
return dataset

Expand Down