From af12e4df3ab8b7a8ba7247bd2dc8811d81c9cf5f Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 19 Aug 2025 10:59:20 -0700 Subject: [PATCH 1/3] add tests regarding hashes --- eval_protocol/models.py | 51 ++++++-- tests/test_models.py | 284 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 323 insertions(+), 12 deletions(-) diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 988ca8ea..b591bea8 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -430,20 +430,53 @@ def get_termination_reason(self) -> str: return "unknown" def __hash__(self) -> int: - # Use a stable hash by sorting keys and ensuring compact output - json_str = self.stable_json(self) - return hash(json_str) + # Use a stable hash that works across Python processes + return self._stable_hash() + + def _stable_hash(self) -> int: + """Generate a stable hash that works across Python processes.""" + import hashlib + + # Get the stable JSON representation + json_str = self._stable_json() + + # Use SHA-256 for deterministic hashing across processes + hash_obj = hashlib.sha256(json_str.encode("utf-8")) + + # Convert to a positive integer (first 8 bytes) + hash_bytes = hash_obj.digest()[:8] + return int.from_bytes(hash_bytes, byteorder="big") + + def _stable_json(self) -> str: + """Generate a stable JSON string representation for hashing.""" + # Produce a canonical, key-sorted JSON across nested structures and + # exclude volatile fields that can differ across processes + import json + from enum import Enum + + def canonicalize(value): + # Recursively convert to a structure with deterministic key ordering + if isinstance(value, dict): + return {k: canonicalize(value[k]) for k in sorted(value.keys())} + if isinstance(value, list): + return [canonicalize(v) for v in value] + if isinstance(value, Enum): + return value.value + return value - @classmethod - def stable_json(cls, row: "EvaluationRow") -> int: - json_str = row.model_dump_json( + # Dump to a plain Python structure first + data = self.model_dump( exclude_none=True, exclude_defaults=True, by_alias=True, - indent=None, - exclude=["created_at", "execution_metadata"], + exclude={"created_at", "execution_metadata", "pid"}, ) - return json_str + + # Ensure deterministic ordering for all nested dicts + canonical_data = canonicalize(data) + + # Compact, sorted JSON string + return json.dumps(canonical_data, separators=(",", ":"), sort_keys=True, ensure_ascii=False) # Original dataclass-based models for backwards compatibility diff --git a/tests/test_models.py b/tests/test_models.py index 817c4e7c..d147dad0 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,6 +13,46 @@ ) +def dummy_row() -> EvaluationRow: + from eval_protocol.models import ( + EvaluateResult as _EvaluateResult, + EvaluationRow as _EvaluationRow, + InputMetadata as _InputMetadata, + Message as _Message, + MetricResult as _MetricResult, + ) + + msgs = [ + _Message(role="system", content="You are a helpful assistant"), + _Message(role="user", content="Compute 2+2"), + _Message(role="assistant", content="4"), + ] + eval_res = _EvaluateResult( + score=1.0, + reason="Correct", + metrics={ + "accuracy": _MetricResult(score=1.0, reason="matches ground truth"), + }, + ) + child_row = _EvaluationRow( + messages=msgs, + ground_truth="4", + evaluation_result=eval_res, + input_metadata=_InputMetadata( + row_id="arith_0001", + completion_params={"model": "dummy/local-model", "temperature": 0.0}, + dataset_info={"source": "unit_test", "variant": "subprocess"}, + session_data={"attempt": 1}, + ), + ) + return child_row + + +def _child_compute_hash_value(_unused=None) -> int: + row = dummy_row() + return hash(row) + + def test_metric_result_creation(): """Test creating a MetricResult.""" metric = MetricResult(score=0.5, reason="Test reason", is_score_valid=True) @@ -289,7 +329,7 @@ def test_evaluation_row_creation(): assert not row.is_trajectory_evaluation() -def test_stable_hash(): +def test_stable_json(): """Test the stable hash method.""" row = EvaluationRow( messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")], @@ -299,8 +339,8 @@ def test_stable_hash(): messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")], ground_truth="4", ) - stable_json = EvaluationRow.stable_json(row) - stable_json2 = EvaluationRow.stable_json(row2) + stable_json = row.model_dump_json() + stable_json2 = row2.model_dump_json() assert stable_json == stable_json2 assert "created_at" not in stable_json assert "execution_metadata" not in stable_json @@ -382,3 +422,241 @@ def test_message_creation_requires_role(): msg_none_content = Message(role="user") # content defaults to "" assert msg_none_content.role == "user" assert msg_none_content.content == "" + + +def test_stable_hash_consistency(): + """Test that the same EvaluationRow produces the same hash value consistently.""" + row1 = EvaluationRow( + messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")], + ground_truth="4", + ) + row2 = EvaluationRow( + messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")], + ground_truth="4", + ) + + # Same content should produce same hash + assert hash(row1) == hash(row2) + + # Hash should be consistent across multiple calls + hash1_first = hash(row1) + hash1_second = hash(row1) + hash1_third = hash(row1) + + assert hash1_first == hash1_second == hash1_third + + # Hash should be a positive integer + assert isinstance(hash1_first, int) + assert hash1_first > 0 + + +def test_stable_hash_different_content(): + """Test that different content produces different hash values.""" + row1 = EvaluationRow( + messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")], + ground_truth="4", + ) + row2 = EvaluationRow( + messages=[Message(role="user", content="What is 3+3?"), Message(role="assistant", content="3+3 equals 6.")], + ground_truth="6", + ) + + # Different content should produce different hashes + assert hash(row1) != hash(row2) + + +def test_stable_hash_ignores_volatile_fields(): + """Test that volatile fields like timestamps don't affect the hash.""" + messages = [Message(role="user", content="Test"), Message(role="assistant", content="Response")] + + # Create rows with different timestamps + row1 = EvaluationRow(messages=messages, ground_truth="test") + row2 = EvaluationRow(messages=messages, ground_truth="test") + + # Wait a moment to ensure different timestamps + import time + + time.sleep(0.001) + + # Create another row + row3 = EvaluationRow(messages=messages, ground_truth="test") + + # All should have the same hash despite different timestamps + assert hash(row1) == hash(row2) == hash(row3) + + +def test_stable_hash_with_complex_data(): + """Test stable hashing with complex nested data structures.""" + complex_messages = [ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Solve this math problem: 15 * 23"), + Message( + role="assistant", + content="Let me solve this step by step:\n1. 15 * 20 = 300\n2. 15 * 3 = 45\n3. 300 + 45 = 345", + ), + Message(role="user", content="Thank you!"), + Message(role="assistant", content="You're welcome! Let me know if you need help with anything else."), + ] + + complex_evaluation = EvaluateResult( + score=0.95, + reason="Excellent step-by-step solution with clear explanation", + metrics={ + "accuracy": MetricResult(score=1.0, reason="Correct mathematical calculation"), + "explanation_quality": MetricResult(score=0.9, reason="Clear step-by-step breakdown"), + "completeness": MetricResult(score=0.95, reason="Covers all aspects of the problem"), + }, + ) + + row1 = EvaluationRow( + messages=complex_messages, + ground_truth="345", + evaluation_result=complex_evaluation, + input_metadata=InputMetadata( + row_id="complex_math_001", + completion_params={"model": "gpt-4", "temperature": 0.1}, + dataset_info={"source": "math_eval", "difficulty": "medium"}, + session_data={"user_id": "test_user", "session_id": "session_123"}, + ), + ) + + row2 = EvaluationRow( + messages=complex_messages, + ground_truth="345", + evaluation_result=complex_evaluation, + input_metadata=InputMetadata( + row_id="complex_math_001", + completion_params={"model": "gpt-4", "temperature": 0.1}, + dataset_info={"source": "math_eval", "difficulty": "medium"}, + session_data={"user_id": "test_user", "session_id": "session_123"}, + ), + ) + + # Complex data should still produce consistent hashes + assert hash(row1) == hash(row2) + + # Hash should be different from simple rows + simple_row = EvaluationRow( + messages=[Message(role="user", content="Simple"), Message(role="assistant", content="Response")], + ground_truth="test", + ) + assert hash(row1) != hash(simple_row) + + +def test_stable_hash_json_representation(): + """Test that the stable JSON representation is consistent and excludes volatile fields.""" + row = EvaluationRow( + messages=[Message(role="user", content="Test"), Message(role="assistant", content="Response")], + ground_truth="test", + ) + + # Get the stable JSON representation + stable_json = row._stable_json() + + # Should be a valid JSON string + parsed = json.loads(stable_json) + + # Should contain the core data + assert "messages" in parsed + assert "ground_truth" in parsed + assert parsed["ground_truth"] == "test" + + # Should NOT contain volatile fields + assert "created_at" not in parsed + assert "execution_metadata" not in parsed + + # Should be deterministic (same content produces same JSON) + stable_json2 = row._stable_json() + assert stable_json == stable_json2 + + +def test_stable_hash_consistency_for_identical_rows(): + """Test that identical EvaluationRow objects produce the same stable hash. + + This simulates the behavior expected across Python process restarts by + creating multiple identical objects and ensuring their hashes match. + """ + # Create a complex evaluation row + messages = [ + Message(role="user", content="What is the capital of France?"), + Message(role="assistant", content="The capital of France is Paris."), + Message(role="user", content="What about Germany?"), + Message(role="assistant", content="The capital of Germany is Berlin."), + ] + + evaluation_result = EvaluateResult( + score=0.9, + reason="Correct answers for both questions", + metrics={ + "geography_knowledge": MetricResult(score=1.0, reason="Both capitals correctly identified"), + "response_quality": MetricResult(score=0.8, reason="Clear and concise responses"), + }, + ) + + # Create multiple identical rows + rows = [] + for i in range(5): + row = EvaluationRow( + messages=messages, + ground_truth="Paris, Berlin", + evaluation_result=evaluation_result, + input_metadata=InputMetadata( + completion_params={"model": "gpt-4"}, + dataset_info={"source": "geography_eval"}, + ), + ) + rows.append(row) + + # All rows should have identical hashes + first_hash = hash(rows[0]) + for row in rows[1:]: + assert hash(row) == first_hash + + # The hash should be a large positive integer (SHA-256 first 8 bytes) + assert first_hash > 0 + assert first_hash < 2**64 # 8 bytes = 64 bits + + +def test_stable_hash_edge_cases(): + """Test stable hashing with edge cases like empty data and None values.""" + # Empty messages + empty_row = EvaluationRow(messages=[], ground_truth="") + empty_hash = hash(empty_row) + assert isinstance(empty_hash, int) + assert empty_hash > 0 + + # None values in optional fields + none_row = EvaluationRow( + messages=[Message(role="user", content="Test")], ground_truth=None, evaluation_result=None + ) + none_hash = hash(none_row) + assert isinstance(none_hash, int) + assert none_hash > 0 + + # Different from empty row + assert empty_hash != none_hash + + # Row with only required fields + minimal_row = EvaluationRow(messages=[Message(role="user", content="Minimal")]) + minimal_hash = hash(minimal_row) + assert isinstance(minimal_hash, int) + assert minimal_hash > 0 + + # Should be different from other edge cases + assert minimal_hash != empty_hash + assert minimal_hash != none_hash + + +def test_stable_hash_across_subprocess(): + """Verify the same EvaluationRow produces the same hash in a separate Python process.""" + import multiprocessing as mp + + row = dummy_row() + parent_hash = hash(row) + # Compute the same hash in a fresh interpreter via Pool.map (spawned process) + ctx = mp.get_context("spawn") + with ctx.Pool(processes=1) as pool: + [child_hash] = pool.map(_child_compute_hash_value, [None]) + + assert isinstance(child_hash, int) + assert parent_hash == child_hash From 61c7328f14e157b91ab0edd4e7883eef1cc3cb81 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 19 Aug 2025 12:10:45 -0700 Subject: [PATCH 2/3] update --- eval_protocol/models.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/eval_protocol/models.py b/eval_protocol/models.py index b591bea8..83a0f178 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -211,7 +211,14 @@ class InputMetadata(BaseModel): model_config = ConfigDict(extra="allow") - row_id: Optional[str] = Field(None, description="Unique string to ID the row") + row_id: Optional[str] = Field( + default=None, + description=( + "Unique string to ID the row. If not provided, a stable hash will be generated " + "based on the row's content. The hash removes fields that are not typically stable " + "across processes such as created_at, execution_metadata, and pid." + ), + ) completion_params: CompletionParams = Field( default_factory=dict, description="Completion endpoint parameters used" ) From 20e65420ef3f58f88d27cc2dc98995e63323c23a Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 19 Aug 2025 12:17:35 -0700 Subject: [PATCH 3/3] fix stable json test --- tests/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index d147dad0..3e1f7706 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -339,8 +339,8 @@ def test_stable_json(): messages=[Message(role="user", content="What is 2+2?"), Message(role="assistant", content="2+2 equals 4.")], ground_truth="4", ) - stable_json = row.model_dump_json() - stable_json2 = row2.model_dump_json() + stable_json = row._stable_json() + stable_json2 = row2._stable_json() assert stable_json == stable_json2 assert "created_at" not in stable_json assert "execution_metadata" not in stable_json