Skip to content

Commit de246c6

Browse files
author
Dylan Huang
authored
Stable ids (#99)
* implement human id mapping + hashing for evaluationrow * example w/ row ids from dataset * test_pytest_stable_row_ids
1 parent 159a12f commit de246c6

File tree

6 files changed

+276
-6
lines changed

6 files changed

+276
-6
lines changed

eval_protocol/human_id/__init__.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,68 @@
44

55
from . import dictionary
66

7-
__all__ = ["generate_id"]
7+
__all__ = ["generate_id", "num_combinations"]
88

99
system_random = random.SystemRandom()
1010

1111

12-
def generate_id(separator="-", seed: int | float | str | bytes | bytearray | None = None, word_count=5) -> str:
12+
def generate_id(
13+
separator: str = "-",
14+
seed: int | float | str | bytes | bytearray | None = None,
15+
word_count: int = 5,
16+
index: int | None = None,
17+
) -> str:
1318
"""
1419
Generate a human readable ID
1520
1621
:param separator: The string to use to separate words
17-
:param seed: The seed to use. The same seed will produce the same ID
22+
:param seed: The seed to use. The same seed will produce the same ID or index-based mapping
23+
:param index: Optional non-negative integer providing a 1:1 mapping to an ID.
24+
When provided, the mapping is deterministic and bijective for
25+
all integers in range [0, total_combinations).
1826
:param word_count: The number of words to use. Minimum of 3.
1927
:return: A human readable ID
2028
"""
2129
if word_count < 3:
2230
raise ValueError("word_count cannot be lower than 3")
2331

32+
# If a specific index is provided, use mixed-radix encoding into a fixed
33+
# sequence of parts to guarantee a bijection between integers and IDs.
34+
# The sequence cycles as: verb, adjective, noun, verb, adjective, noun, ...
35+
if index is not None:
36+
if not isinstance(index, int) or index < 0:
37+
raise ValueError("index must be a non-negative integer if provided")
38+
39+
# Prepare category lists; if seed is provided, shuffle deterministically
40+
base_categories = [dictionary.verbs, dictionary.adjectives, dictionary.nouns]
41+
if seed is not None:
42+
rnd = random.Random(seed)
43+
categories = [tuple(rnd.sample(cat, len(cat))) for cat in base_categories]
44+
else:
45+
categories = base_categories
46+
# Build the category order for the desired word_count
47+
ordered_categories = [categories[i % 3] for i in range(word_count)]
48+
49+
# Compute total number of combinations for this word_count
50+
radices = [len(cat) for cat in ordered_categories]
51+
total = num_combinations(word_count)
52+
53+
if index >= total:
54+
raise ValueError(f"index out of range for given word_count. Received {index}, max allowed is {total - 1}")
55+
56+
# Mixed-radix decomposition (least significant position is the last word)
57+
digits: list[int] = []
58+
remaining = index
59+
for base in reversed(radices):
60+
digits.append(remaining % base)
61+
remaining //= base
62+
digits.reverse()
63+
64+
words = [ordered_categories[pos][digits[pos]] for pos in range(word_count)]
65+
return separator.join(words)
66+
2467
random_obj = system_random
25-
if seed:
68+
if seed is not None:
2669
random_obj = random.Random(seed)
2770

2871
parts = {dictionary.verbs: 1, dictionary.adjectives: 1, dictionary.nouns: 1}
@@ -33,3 +76,21 @@ def generate_id(separator="-", seed: int | float | str | bytes | bytearray | Non
3376
parts = itertools.chain.from_iterable(random_obj.sample(part, count) for part, count in parts.items())
3477

3578
return separator.join(parts)
79+
80+
81+
def num_combinations(word_count: int = 5) -> int:
    """
    Return how many distinct IDs exist for the given word_count.

    Word positions cycle through the categories verb, adjective, noun and
    repeat; the total is therefore the product of the category sizes taken
    in that cyclic order. Useful for modding an index before calling
    generate_id(index=...).

    :param word_count: Number of words in the ID; must be at least 3.
    :return: The size of the ID space for this word_count.
    :raises ValueError: If word_count is below the minimum of 3.
    """
    if word_count < 3:
        raise ValueError("word_count cannot be lower than 3")

    # Same cyclic order used by generate_id: verb, adjective, noun, repeat.
    category_cycle = (dictionary.verbs, dictionary.adjectives, dictionary.nouns)
    product = 1
    for position in range(word_count):
        product *= len(category_cycle[position % 3])
    return product

eval_protocol/models.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ class InputMetadata(BaseModel):
211211

212212
model_config = ConfigDict(extra="allow")
213213

214-
row_id: Optional[str] = Field(default_factory=generate_id, description="Unique string to ID the row")
214+
row_id: Optional[str] = Field(None, description="Unique string to ID the row")
215215
completion_params: CompletionParams = Field(
216216
default_factory=dict, description="Completion endpoint parameters used"
217217
)
@@ -429,6 +429,22 @@ def get_termination_reason(self) -> str:
429429
return msg.control_plane_step["termination_reason"]
430430
return "unknown"
431431

432+
def __hash__(self) -> int:
433+
# Use a stable hash by sorting keys and ensuring compact output
434+
json_str = self.stable_json(self)
435+
return hash(json_str)
436+
437+
@classmethod
438+
def stable_json(cls, row: "EvaluationRow") -> int:
439+
json_str = row.model_dump_json(
440+
exclude_none=True,
441+
exclude_defaults=True,
442+
by_alias=True,
443+
indent=None,
444+
exclude=["created_at", "execution_metadata"],
445+
)
446+
return json_str
447+
432448

433449
# Original dataclass-based models for backwards compatibility
434450
# These are deprecated and will be removed in a future version

eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from eval_protocol.dataset_logger import default_logger
1717
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
18-
from eval_protocol.human_id import generate_id
18+
from eval_protocol.human_id import generate_id, num_combinations
1919
from eval_protocol.models import (
2020
CompletionParams,
2121
EvalMetadata,
@@ -294,6 +294,16 @@ def _log_eval_error(
294294
else:
295295
raise ValueError("No input dataset or input messages provided")
296296

297+
for row in data:
298+
# generate a stable row_id for each row
299+
if row.input_metadata.row_id is None:
300+
# Generate a stable, deterministic row_id using the row's hash and num_combinations
301+
index = hash(row)
302+
max_index = num_combinations() - 1
303+
# Ensure index is a non-negative integer within [0, max_index]
304+
index = abs(index) % (max_index + 1)
305+
row.input_metadata.row_id = generate_id(seed=0, index=index)
306+
297307
if "completion_params" not in kwargs or not kwargs["completion_params"]:
298308
raise ValueError(
299309
"No completion parameters provided. Please provide a completion parameters object."
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from typing import List
2+
3+
from eval_protocol.models import EvaluationRow, Message
4+
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
5+
from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
6+
7+
8+
async def test_evaluation_test_decorator_ids_single():
    """Row ids from the dataset stay stable across datasets, params and runs."""
    # Imported lazily so the decorator's side effects happen inside the test.
    from eval_protocol.pytest.evaluation_test import evaluation_test

    # Accumulates every row_id observed across all invocations below.
    row_ids = set()

    # The same dataset file listed twice: identical rows must yield identical ids.
    input_dataset = [
        "tests/pytest/data/markdown_dataset.jsonl",
        "tests/pytest/data/markdown_dataset.jsonl",
    ]
    completion_params_list = [
        {"temperature": 0.0, "model": "dummy/local-model"},
        {"temperature": 1.0, "model": "dummy/local-model"},
    ]

    @evaluation_test(
        input_dataset=input_dataset,
        completion_params=completion_params_list,
        dataset_adapter=markdown_dataset_to_evaluation_row,
        rollout_processor=NoOpRolloutProcessor(),
        mode="pointwise",
        combine_datasets=False,
        num_runs=5,
    )
    def eval_fn(row: EvaluationRow) -> EvaluationRow:
        row_ids.add(row.input_metadata.row_id)
        return row

    # Manually invoke all parameter combinations within a single test
    for ds_path in input_dataset:
        for params in completion_params_list:
            await eval_fn(dataset_path=[ds_path], completion_params=params)

    # Second invocation to ensure that IDs are stable across multiple invocations
    for ds_path in input_dataset:
        for params in completion_params_list:
            await eval_fn(dataset_path=[ds_path], completion_params=params)

    # Assertions on IDs generated by the decorator logic.
    # NOTE(review): assumes markdown_dataset_to_evaluation_row derives row_id
    # from the dataset rows, so duplicates collapse to the 19 dataset rows
    # regardless of dataset repetition, params, runs or re-invocation — confirm
    # against that adapter.
    assert len(row_ids) == 19  # from the markdown dataset
47+
48+
49+
async def test_evaluation_test_generated_row_ids_without_dataset_keys():
    """Auto-generated row ids are stable even when the adapter sets none."""
    # Imported lazily so the decorator's side effects happen inside the test.
    from eval_protocol.pytest.evaluation_test import evaluation_test

    # Adapter that does NOT set row_id; lets evaluation_test generate IDs
    def markdown_dataset_no_row_id_adapter(data: List[dict]) -> List[EvaluationRow]:
        return [
            EvaluationRow(
                messages=[Message(role="user", content=row["prompt"])],
                ground_truth=str(row["num_highlights"]),
            )
            for row in data
        ]

    # Accumulates every row_id observed across all invocations below.
    row_ids = set()

    # Same dataset file twice: generated ids must collapse to one per row.
    input_dataset = ["tests/pytest/data/markdown_dataset.jsonl", "tests/pytest/data/markdown_dataset.jsonl"]
    completion_params = [
        {"temperature": 0.0, "model": "dummy/local-model"},
        {"temperature": 1.0, "model": "dummy/local-model"},
    ]

    @evaluation_test(
        input_dataset=input_dataset,
        completion_params=completion_params,
        dataset_adapter=markdown_dataset_no_row_id_adapter,
        rollout_processor=NoOpRolloutProcessor(),
        mode="pointwise",
        combine_datasets=False,
        num_runs=5,
    )
    def eval_fn(row: EvaluationRow) -> EvaluationRow:
        # row_id should be auto-generated by evaluation_test/InputMetadata
        assert row.input_metadata is not None
        assert row.input_metadata.row_id is not None and isinstance(row.input_metadata.row_id, str)
        row_ids.add(row.input_metadata.row_id)
        return row

    # First sweep: invoke every dataset/param combination (each with num_runs runs).
    for ds_path in input_dataset:
        for params in completion_params:
            await eval_fn(dataset_path=[ds_path], completion_params=params)

    # Second invocation to ensure that IDs are stable across multiple invocations
    for ds_path in input_dataset:
        for params in completion_params:
            await eval_fn(dataset_path=[ds_path], completion_params=params)

    # Even with multiple runs, generated row_ids should be stable within the invocation
    assert len(row_ids) == 19  # equals dataset size when IDs are generated once and preserved across runs

tests/test_human_id.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import re
2+
import pytest
3+
4+
from eval_protocol.human_id import generate_id, num_combinations
5+
6+
7+
def test_generate_id_index_basic_3_words():
    """Index 0 and small increments map predictably at word_count=3."""
    # index 0 maps to the first element of each category (verb, adjective, noun)
    assert generate_id(index=0, word_count=3) == "be-other-time"

    # incrementing index advances the least-significant position (noun)
    assert generate_id(index=1, word_count=3) == "be-other-year"

    # Carry into the adjective when the noun position wraps. The noun-list
    # length is not assumed; scan forward until the adjective flips to 'new'
    # so the test stays robust to dictionary edits.
    # (Fixed: removed the unused local `base` left over from an earlier draft.)
    target = None
    for i in range(1, 2000):
        if generate_id(index=i, word_count=3).startswith("be-new-time"):
            target = i
            break
    assert target is not None, "Expected to find carry into adjective within search bound"
    assert generate_id(index=target, word_count=3) == "be-new-time"
29+
30+
31+
def test_generate_id_index_word_count_cycle():
    """Categories cycle verb, adjective, noun, verb, adjective at word_count=5."""
    first = generate_id(index=0, word_count=5)
    second = generate_id(index=1, word_count=5)
    # Index 0 picks the first entry of each category around the cycle.
    assert first == "be-other-time-be-other"
    # Incrementing the index advances the least-significant (5th) position,
    # which is an adjective for word_count=5.
    assert second == "be-other-time-be-new"
36+
37+
38+
def test_generate_id_index_out_of_range_and_negative():
    """Boundary checks for generate_id's index parameter at word_count=3."""
    # num_combinations gives the exact size of the id space for clean bounds.
    total = num_combinations(word_count=3)
    assert total > 0

    # The largest valid index must not raise...
    generate_id(index=total - 1, word_count=3)

    # ...while `total` itself and any negative value must be rejected.
    for bad_index in (total, -1):
        with pytest.raises(ValueError):
            generate_id(index=bad_index, word_count=3)
50+
51+
52+
def test_generate_id_seed_stability_and_compat():
    """Seeded calls are reproducible; default output keeps the '-' format."""
    first = generate_id(seed=1234)
    second = generate_id(seed=1234)
    # The same seed (without an index) must yield the same id.
    assert first == second

    # Default call: lowercase words joined by '-', with at least 3 components.
    unseeded = generate_id()
    assert re.match(r"^[a-z]+(-[a-z]+){2,}$", unseeded)
61+
62+
63+
def test_generate_id_index_ignores_seed():
    """Seeded index mappings differ by seed but are reproducible per seed."""
    # NOTE(review): the function name is misleading — the assertions below
    # verify that the seed DOES change the index mapping (x != y) while the
    # same seed reproduces it (x == z). Consider renaming to something like
    # test_generate_id_index_respects_seed.
    # With index provided, seed should affect the mapping deterministically
    x = generate_id(index=42, seed=1)
    y = generate_id(index=42, seed=999)
    z = generate_id(index=42, seed=1)
    assert x != y
    assert x == z

tests/test_models.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,23 @@ def test_evaluation_row_creation():
289289
assert not row.is_trajectory_evaluation()
290290

291291

292+
def test_stable_hash():
    """Test the stable hash method."""
    def build_row() -> EvaluationRow:
        # Two independently built rows with identical content.
        return EvaluationRow(
            messages=[
                Message(role="user", content="What is 2+2?"),
                Message(role="assistant", content="2+2 equals 4."),
            ],
            ground_truth="4",
        )

    first_json = EvaluationRow.stable_json(build_row())
    second_json = EvaluationRow.stable_json(build_row())

    # Identical content must serialize identically...
    assert first_json == second_json
    # ...and the volatile fields must be excluded from the canonical form.
    assert "created_at" not in first_json
    assert "execution_metadata" not in first_json
307+
308+
292309
def test_evaluation_row_trajectory_evaluation():
293310
"""Test EvaluationRow with trajectory evaluation."""
294311
messages = [

0 commit comments

Comments
 (0)