Skip to content

Commit b483e00

Browse files
authored
fix remote dataset info (#347)
* fix remote dataset info * add test
1 parent 361369e commit b483e00

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

eval_protocol/pytest/tracing_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,14 @@ def update_row_with_remote_trace(
171171
row.messages = remote_row.messages
172172
row.tools = remote_row.tools
173173
row.input_metadata.session_data = remote_row.input_metadata.session_data
174-
row.input_metadata.dataset_info = remote_row.input_metadata.dataset_info
174+
remote_info = remote_row.input_metadata.dataset_info or {}
175+
if row.input_metadata.dataset_info is None:
176+
row.input_metadata.dataset_info = dict(remote_info)
177+
else:
178+
for k, v in remote_info.items():
179+
if k not in row.input_metadata.dataset_info:
180+
row.input_metadata.dataset_info[k] = v
181+
175182
row.execution_metadata = remote_row.execution_metadata
176183
return None
177184
else:

tests/remote_server/test_remote_fireworks.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,16 @@ def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
9494

9595

9696
def rows() -> List[EvaluationRow]:
97+
"""Generate local rows with rich input_metadata to verify it survives remote traces."""
98+
base_dataset_info = {
99+
"requirements": ["Answer with the capital city of France."],
100+
"total_requirements": 1,
101+
"original_prompt": "What is the capital of France?",
102+
}
103+
97104
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
105+
row.input_metadata.dataset_info = dict(base_dataset_info)
106+
98107
return [row, row, row]
99108

100109

@@ -127,6 +136,17 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
127136
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
128137
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
129138
)
139+
assert row.input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"
130140
assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level"
131141

142+
assert row.input_metadata.row_id is not None
143+
144+
assert row.input_metadata.dataset_info is not None
145+
assert row.input_metadata.dataset_info["requirements"] == ["Answer with the capital city of France."]
146+
assert row.input_metadata.dataset_info["total_requirements"] == 1
147+
assert row.input_metadata.dataset_info["original_prompt"] == "What is the capital of France?"
148+
149+
assert "data_loader_type" in row.input_metadata.dataset_info
150+
assert "data_loader_num_rows" in row.input_metadata.dataset_info
151+
132152
return row

0 commit comments

Comments
 (0)