Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 50 additions & 10 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,14 @@ class InputMetadata(BaseModel):

model_config = ConfigDict(extra="allow")

row_id: Optional[str] = Field(None, description="Unique string to ID the row")
row_id: Optional[str] = Field(
default=None,
description=(
"Unique string to ID the row. If not provided, a stable hash will be generated "
"based on the row's content. The hash removes fields that are not typically stable "
"across processes such as created_at, execution_metadata, and pid."
),
)
completion_params: CompletionParams = Field(
default_factory=dict, description="Completion endpoint parameters used"
)
Expand Down Expand Up @@ -430,20 +437,53 @@ def get_termination_reason(self) -> str:
return "unknown"

def __hash__(self) -> int:
    """Return a content-based hash that is stable across Python processes.

    Python's built-in ``hash()`` of a string is salted per process
    (PYTHONHASHSEED), so delegating to the SHA-256-based ``_stable_hash``
    keeps the value identical across interpreter restarts.
    """
    return self._stable_hash()

def _stable_hash(self) -> int:
"""Generate a stable hash that works across Python processes."""
import hashlib

# Get the stable JSON representation
json_str = self._stable_json()

# Use SHA-256 for deterministic hashing across processes
hash_obj = hashlib.sha256(json_str.encode("utf-8"))

# Convert to a positive integer (first 8 bytes)
hash_bytes = hash_obj.digest()[:8]
return int.from_bytes(hash_bytes, byteorder="big")

def _stable_json(self) -> str:
"""Generate a stable JSON string representation for hashing."""
# Produce a canonical, key-sorted JSON across nested structures and
# exclude volatile fields that can differ across processes
import json
from enum import Enum

def canonicalize(value):
# Recursively convert to a structure with deterministic key ordering
if isinstance(value, dict):
return {k: canonicalize(value[k]) for k in sorted(value.keys())}
if isinstance(value, list):
return [canonicalize(v) for v in value]
if isinstance(value, Enum):
return value.value
return value

@classmethod
def stable_json(cls, row: "EvaluationRow") -> int:
json_str = row.model_dump_json(
# Dump to a plain Python structure first
data = self.model_dump(
exclude_none=True,
exclude_defaults=True,
by_alias=True,
indent=None,
exclude=["created_at", "execution_metadata"],
exclude={"created_at", "execution_metadata", "pid"},
)
return json_str

# Ensure deterministic ordering for all nested dicts
canonical_data = canonicalize(data)

# Compact, sorted JSON string
return json.dumps(canonical_data, separators=(",", ":"), sort_keys=True, ensure_ascii=False)


# Original dataclass-based models for backwards compatibility
Expand Down
284 changes: 281 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,46 @@
)


def dummy_row() -> EvaluationRow:
    """Build a small, fully-populated EvaluationRow for hashing tests."""
    # Local imports keep the helper picklable/usable from spawned processes.
    from eval_protocol.models import (
        EvaluateResult as _EvaluateResult,
        EvaluationRow as _EvaluationRow,
        InputMetadata as _InputMetadata,
        Message as _Message,
        MetricResult as _MetricResult,
    )

    conversation = [
        _Message(role="system", content="You are a helpful assistant"),
        _Message(role="user", content="Compute 2+2"),
        _Message(role="assistant", content="4"),
    ]
    verdict = _EvaluateResult(
        score=1.0,
        reason="Correct",
        metrics={
            "accuracy": _MetricResult(score=1.0, reason="matches ground truth"),
        },
    )
    metadata = _InputMetadata(
        row_id="arith_0001",
        completion_params={"model": "dummy/local-model", "temperature": 0.0},
        dataset_info={"source": "unit_test", "variant": "subprocess"},
        session_data={"attempt": 1},
    )
    return _EvaluationRow(
        messages=conversation,
        ground_truth="4",
        evaluation_result=verdict,
        input_metadata=metadata,
    )


def _child_compute_hash_value(_unused=None) -> int:
    """Build an identical row in the current (child) process and hash it."""
    return hash(dummy_row())


def test_metric_result_creation():
"""Test creating a MetricResult."""
metric = MetricResult(score=0.5, reason="Test reason", is_score_valid=True)
Expand Down Expand Up @@ -289,7 +329,7 @@ def test_evaluation_row_creation():
assert not row.is_trajectory_evaluation()


def test_stable_hash():
def test_stable_json():
"""Test the stable hash method."""
row = EvaluationRow(
messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")],
Expand All @@ -299,8 +339,8 @@ def test_stable_hash():
messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")],
ground_truth="4",
)
stable_json = EvaluationRow.stable_json(row)
stable_json2 = EvaluationRow.stable_json(row2)
stable_json = row._stable_json()
stable_json2 = row2._stable_json()
assert stable_json == stable_json2
assert "created_at" not in stable_json
assert "execution_metadata" not in stable_json
Expand Down Expand Up @@ -382,3 +422,241 @@ def test_message_creation_requires_role():
msg_none_content = Message(role="user") # content defaults to ""
assert msg_none_content.role == "user"
assert msg_none_content.content == ""


def test_stable_hash_consistency():
    """Test that the same EvaluationRow produces the same hash value consistently."""
    shared_messages = [
        Message(role="user", content="What is 2+2?"),
        Message(role="assistant", content="2+2 equals 4."),
    ]
    first = EvaluationRow(messages=shared_messages, ground_truth="4")
    second = EvaluationRow(messages=shared_messages, ground_truth="4")

    # Identical content must hash identically.
    assert hash(first) == hash(second)

    # Repeated calls on the same object must agree.
    repeated = [hash(first) for _ in range(3)]
    assert len(set(repeated)) == 1

    # The digest-derived hash is a positive integer.
    assert isinstance(repeated[0], int)
    assert repeated[0] > 0


def test_stable_hash_different_content():
    """Test that different content produces different hash values."""
    arithmetic_row = EvaluationRow(
        messages=[
            Message(role="user", content="What is 2+2?"),
            Message(role="assistant", content="2+2 equals 4."),
        ],
        ground_truth="4",
    )
    other_row = EvaluationRow(
        messages=[
            Message(role="user", content="What is 3+3?"),
            Message(role="assistant", content="3+3 equals 6."),
        ],
        ground_truth="6",
    )

    # Distinct content must not produce colliding hashes.
    assert hash(arithmetic_row) != hash(other_row)


def test_stable_hash_ignores_volatile_fields():
    """Test that volatile fields like timestamps don't affect the hash."""
    import time

    messages = [Message(role="user", content="Test"), Message(role="assistant", content="Response")]

    early_a = EvaluationRow(messages=messages, ground_truth="test")
    early_b = EvaluationRow(messages=messages, ground_truth="test")

    # Ensure the final row is created at a measurably later timestamp.
    time.sleep(0.001)
    late = EvaluationRow(messages=messages, ground_truth="test")

    # created_at differences must not change the hash.
    assert hash(early_a) == hash(early_b) == hash(late)


def test_stable_hash_with_complex_data():
    """Test stable hashing with complex nested data structures."""
    conversation = [
        Message(role="system", content="You are a helpful assistant"),
        Message(role="user", content="Solve this math problem: 15 * 23"),
        Message(
            role="assistant",
            content="Let me solve this step by step:\n1. 15 * 20 = 300\n2. 15 * 3 = 45\n3. 300 + 45 = 345",
        ),
        Message(role="user", content="Thank you!"),
        Message(role="assistant", content="You're welcome! Let me know if you need help with anything else."),
    ]

    verdict = EvaluateResult(
        score=0.95,
        reason="Excellent step-by-step solution with clear explanation",
        metrics={
            "accuracy": MetricResult(score=1.0, reason="Correct mathematical calculation"),
            "explanation_quality": MetricResult(score=0.9, reason="Clear step-by-step breakdown"),
            "completeness": MetricResult(score=0.95, reason="Covers all aspects of the problem"),
        },
    )

    def build_row() -> EvaluationRow:
        # Independently constructed rows with identical content.
        return EvaluationRow(
            messages=conversation,
            ground_truth="345",
            evaluation_result=verdict,
            input_metadata=InputMetadata(
                row_id="complex_math_001",
                completion_params={"model": "gpt-4", "temperature": 0.1},
                dataset_info={"source": "math_eval", "difficulty": "medium"},
                session_data={"user_id": "test_user", "session_id": "session_123"},
            ),
        )

    # Complex data should still produce consistent hashes.
    assert hash(build_row()) == hash(build_row())

    # Hash should differ from a simple row's hash.
    plain = EvaluationRow(
        messages=[Message(role="user", content="Simple"), Message(role="assistant", content="Response")],
        ground_truth="test",
    )
    assert hash(build_row()) != hash(plain)


def test_stable_hash_json_representation():
    """Test that the stable JSON representation is consistent and excludes volatile fields."""
    row = EvaluationRow(
        messages=[Message(role="user", content="Test"), Message(role="assistant", content="Response")],
        ground_truth="test",
    )

    first_json = row._stable_json()

    # The representation must parse as JSON and carry the core payload.
    decoded = json.loads(first_json)
    assert "messages" in decoded
    assert "ground_truth" in decoded
    assert decoded["ground_truth"] == "test"

    # Volatile fields are excluded from the canonical form.
    for volatile in ("created_at", "execution_metadata"):
        assert volatile not in decoded

    # Repeated serialization is deterministic.
    assert row._stable_json() == first_json


def test_stable_hash_consistency_for_identical_rows():
    """Test that identical EvaluationRow objects produce the same stable hash.

    This simulates the behavior expected across Python process restarts by
    creating multiple identical objects and ensuring their hashes match.
    """
    messages = [
        Message(role="user", content="What is the capital of France?"),
        Message(role="assistant", content="The capital of France is Paris."),
        Message(role="user", content="What about Germany?"),
        Message(role="assistant", content="The capital of Germany is Berlin."),
    ]

    evaluation_result = EvaluateResult(
        score=0.9,
        reason="Correct answers for both questions",
        metrics={
            "geography_knowledge": MetricResult(score=1.0, reason="Both capitals correctly identified"),
            "response_quality": MetricResult(score=0.8, reason="Clear and concise responses"),
        },
    )

    # Build several identical rows; the index is unused, so `_` replaces `i`.
    rows = [
        EvaluationRow(
            messages=messages,
            ground_truth="Paris, Berlin",
            evaluation_result=evaluation_result,
            input_metadata=InputMetadata(
                completion_params={"model": "gpt-4"},
                dataset_info={"source": "geography_eval"},
            ),
        )
        for _ in range(5)
    ]

    # All rows should have identical hashes.
    first_hash = hash(rows[0])
    assert all(hash(row) == first_hash for row in rows[1:])

    # The hash is built from the first 8 bytes of SHA-256: 0 < h < 2**64.
    assert 0 < first_hash < 2**64


def test_stable_hash_edge_cases():
    """Test stable hashing with edge cases like empty data and None values."""

    def checked_hash(row: EvaluationRow) -> int:
        # Every edge-case hash must be a positive integer.
        value = hash(row)
        assert isinstance(value, int)
        assert value > 0
        return value

    # Empty messages.
    empty_hash = checked_hash(EvaluationRow(messages=[], ground_truth=""))

    # None values in optional fields.
    none_hash = checked_hash(
        EvaluationRow(messages=[Message(role="user", content="Test")], ground_truth=None, evaluation_result=None)
    )
    assert empty_hash != none_hash

    # Row with only required fields.
    minimal_hash = checked_hash(EvaluationRow(messages=[Message(role="user", content="Minimal")]))

    # All three edge cases must be mutually distinct.
    assert minimal_hash != empty_hash
    assert minimal_hash != none_hash


def test_stable_hash_across_subprocess():
    """Verify the same EvaluationRow produces the same hash in a separate Python process."""
    import multiprocessing as mp

    parent_hash = hash(dummy_row())

    # Recompute the hash in a freshly spawned interpreter via Pool.map.
    spawn_ctx = mp.get_context("spawn")
    with spawn_ctx.Pool(processes=1) as pool:
        child_hash = pool.map(_child_compute_hash_value, [None])[0]

    assert isinstance(child_hash, int)
    assert child_hash == parent_hash
Loading