Skip to content

Commit 6025f78

Browse files
author
Dylan Huang
committed
fix sqlite stuff to use rollout_id
1 parent c436eaf commit 6025f78

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,18 @@ def __init__(self, db_path: Optional[str] = None, store: Optional[SqliteEvaluati
2222
self._store = SqliteEvaluationRowStore(self.db_path)
2323

2424
def log(self, row: "EvaluationRow") -> None:
25-
row_id = row.input_metadata.row_id
25+
rollout_id = row.rollout_id
2626
data = row.model_dump(exclude_none=True, mode="json")
27-
self._store.upsert_row(row_id=row_id, data=data)
27+
self._store.upsert_row(rollout_id=rollout_id, data=data)
2828
try:
2929
event_bus.emit(LOG_EVENT_TYPE, EvaluationRow(**data))
3030
except Exception as e:
3131
# Avoid breaking storage due to event emission issues
3232
logger.error(f"Failed to emit row_upserted event: {e}")
3333
pass
3434

35-
def read(self, row_id: Optional[str] = None) -> List["EvaluationRow"]:
35+
def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]:
3636
from eval_protocol.models import EvaluationRow
3737

38-
results = self._store.read_rows(row_id=row_id)
38+
results = self._store.read_rows(rollout_id=rollout_id)
3939
return [EvaluationRow(**data) for data in results]

eval_protocol/dataset_logger/sqlite_evaluation_row_store.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class SqliteEvaluationRowStore:
1111
"""
1212
Lightweight reusable SQLite store for evaluation rows.
1313
14-
Stores arbitrary row data as JSON keyed by a unique string `row_id`.
14+
Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
1515
"""
1616

1717
def __init__(self, db_path: str):
@@ -24,7 +24,7 @@ class Meta:
2424
database = self._db
2525

2626
class EvaluationRow(BaseModel): # type: ignore
27-
row_id = CharField(unique=True)
27+
rollout_id = CharField(unique=True)
2828
data = JSONField()
2929

3030
self._EvaluationRow = EvaluationRow
@@ -36,22 +36,22 @@ class EvaluationRow(BaseModel): # type: ignore
3636
def db_path(self) -> str:
3737
return self._db_path
3838

39-
def upsert_row(self, row_id: str, data: dict) -> None:
40-
if self._EvaluationRow.select().where(self._EvaluationRow.row_id == row_id).exists():
41-
self._EvaluationRow.update(data=data).where(self._EvaluationRow.row_id == row_id).execute()
39+
def upsert_row(self, rollout_id: str, data: dict) -> None:
40+
if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists():
41+
self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute()
4242
else:
43-
self._EvaluationRow.create(row_id=row_id, data=data)
43+
self._EvaluationRow.create(rollout_id=rollout_id, data=data)
4444

45-
def read_rows(self, row_id: Optional[str] = None) -> List[dict]:
46-
if row_id is None:
45+
def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
46+
if rollout_id is None:
4747
query = self._EvaluationRow.select().dicts()
4848
else:
49-
query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.row_id == row_id)
49+
query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.rollout_id == rollout_id)
5050
results = list(query)
5151
return [result["data"] for result in results]
5252

53-
def delete_row(self, row_id: str) -> int:
54-
return self._EvaluationRow.delete().where(self._EvaluationRow.row_id == row_id).execute()
53+
def delete_row(self, rollout_id: str) -> int:
54+
return self._EvaluationRow.delete().where(self._EvaluationRow.rollout_id == rollout_id).execute()
5555

5656
def delete_all_rows(self) -> int:
5757
return self._EvaluationRow.delete().execute()

0 commit comments

Comments
 (0)