
Commit 33c52a8

Author: Dylan Huang

Aggregated metrics part 2 (#48)
* parameterize logger
* fix logging too much
* fix sqlite stuff to use rollout_id
* fix
* fix sqlite logger tests
* fix default_agent_rollout_processor.py
* for generating some test data
* comment out saving logs code
* Refactor pivot test to handle JSONL logs and improve data parsing
  - Updated the test to read from a new JSONL log file instead of JSON, enhancing data handling.
  - Adjusted the parsing logic to accommodate line-by-line JSON parsing for better flexibility.
  - Modified row and column field definitions to include additional identifiers for improved data aggregation.
* everything is rollout id based
1 parent a50c3f6 · commit 33c52a8

16 files changed: +348 −83 lines

eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py

Lines changed: 3 additions & 4 deletions
@@ -22,18 +22,17 @@ def __init__(self, db_path: Optional[str] = None, store: Optional[SqliteEvaluationRowStore] = None):
         self._store = SqliteEvaluationRowStore(self.db_path)

     def log(self, row: "EvaluationRow") -> None:
-        row_id = row.input_metadata.row_id
         data = row.model_dump(exclude_none=True, mode="json")
-        self._store.upsert_row(row_id=row_id, data=data)
+        self._store.upsert_row(data=data)
         try:
             event_bus.emit(LOG_EVENT_TYPE, EvaluationRow(**data))
         except Exception as e:
             # Avoid breaking storage due to event emission issues
             logger.error(f"Failed to emit row_upserted event: {e}")
             pass

-    def read(self, row_id: Optional[str] = None) -> List["EvaluationRow"]:
+    def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]:
         from eval_protocol.models import EvaluationRow

-        results = self._store.read_rows(row_id=row_id)
+        results = self._store.read_rows(rollout_id=rollout_id)
         return [EvaluationRow(**data) for data in results]
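With this change the adapter keys storage on `rollout_id` end to end. A minimal round-trip sketch (assuming `EvaluationRow` populates `rollout_id` on construction, which the updated tests below rely on):

    from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
    from eval_protocol.models import EvaluationRow, Message

    logger = SqliteDatasetLoggerAdapter()  # uses the default db_path
    row = EvaluationRow(messages=[Message(role="user", content="Hello")])

    logger.log(row)  # serializes the row and upserts it, keyed by row.rollout_id
    saved = logger.read(rollout_id=row.rollout_id)[0]
    assert saved.messages == row.messages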

eval_protocol/dataset_logger/sqlite_evaluation_row_store.py

Lines changed: 14 additions & 11 deletions
@@ -11,7 +11,7 @@ class SqliteEvaluationRowStore:
     """
     Lightweight reusable SQLite store for evaluation rows.

-    Stores arbitrary row data as JSON keyed by a unique string `row_id`.
+    Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
     """

     def __init__(self, db_path: str):
@@ -24,7 +24,7 @@ class Meta:
                 database = self._db

         class EvaluationRow(BaseModel):  # type: ignore
-            row_id = CharField(unique=True)
+            rollout_id = CharField(unique=True)
             data = JSONField()

         self._EvaluationRow = EvaluationRow
@@ -36,22 +36,25 @@ class EvaluationRow(BaseModel):  # type: ignore
     def db_path(self) -> str:
         return self._db_path

-    def upsert_row(self, row_id: str, data: dict) -> None:
-        if self._EvaluationRow.select().where(self._EvaluationRow.row_id == row_id).exists():
-            self._EvaluationRow.update(data=data).where(self._EvaluationRow.row_id == row_id).execute()
+    def upsert_row(self, data: dict) -> None:
+        rollout_id = data["rollout_id"]
+        if "rollout_id" not in data:
+            raise ValueError("rollout_id is required to upsert a row")
+        if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists():
+            self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute()
         else:
-            self._EvaluationRow.create(row_id=row_id, data=data)
+            self._EvaluationRow.create(rollout_id=rollout_id, data=data)

-    def read_rows(self, row_id: Optional[str] = None) -> List[dict]:
-        if row_id is None:
+    def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
+        if rollout_id is None:
             query = self._EvaluationRow.select().dicts()
         else:
-            query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.row_id == row_id)
+            query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.rollout_id == rollout_id)
         results = list(query)
         return [result["data"] for result in results]

-    def delete_row(self, row_id: str) -> int:
-        return self._EvaluationRow.delete().where(self._EvaluationRow.row_id == row_id).execute()
+    def delete_row(self, rollout_id: str) -> int:
+        return self._EvaluationRow.delete().where(self._EvaluationRow.rollout_id == rollout_id).execute()

     def delete_all_rows(self) -> int:
         return self._EvaluationRow.delete().execute()
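One caveat with the committed `upsert_row`: it indexes `data["rollout_id"]` before the membership guard runs, so a payload missing that key raises `KeyError` instead of the intended `ValueError`. A guard-first sketch of the same method:

    def upsert_row(self, data: dict) -> None:
        # Check for the key before indexing so the intended ValueError is raised
        if "rollout_id" not in data:
            raise ValueError("rollout_id is required to upsert a row")
        rollout_id = data["rollout_id"]
        if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists():
            self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute()
        else:
            self._EvaluationRow.create(rollout_id=rollout_id, data=data)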

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 5 additions & 4 deletions
@@ -8,7 +8,7 @@
 from openai.types.chat import ChatCompletionContentPartTextParam, ChatCompletionMessage, ChatCompletionToolParam
 from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

-from eval_protocol.dataset_logger import default_logger
+from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.mcp.execution.policy import LiteLLMPolicy
 from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
 from eval_protocol.models import EvaluationRow, Message
@@ -20,12 +20,13 @@ class Agent:
     A really simple agent that calls the model until no more tool calls are needed.
     """

-    def __init__(self, model: str, row: EvaluationRow, config_path: str):
+    def __init__(self, model: str, row: EvaluationRow, config_path: str, logger: DatasetLogger):
         self.model = model
         self.evaluation_row: EvaluationRow = row
         self._policy = LiteLLMPolicy(model_id=model)
         self.mcp_client = MCPMultiClient(config_path=config_path) if config_path else None
         self.tools: Union[List[ChatCompletionToolParam], NotGiven] = NOT_GIVEN
+        self.logger: DatasetLogger = logger

     async def setup(self):
         if self.mcp_client:
@@ -42,7 +43,7 @@ def messages(self) -> list[Message]:

     def append_message_and_log(self, message: Message):
         self.messages.append(message)
-        default_logger.log(self.evaluation_row)
+        self.logger.log(self.evaluation_row)

     async def call_agent(self) -> str:
         """
@@ -116,7 +117,7 @@ async def default_agent_rollout_processor(
 ) -> List[EvaluationRow]:
     dataset: Dataset = []
     for row in rows:
-        agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path)
+        agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
         await agent.setup()
         await agent.call_agent()
         dataset.append(agent.evaluation_row)
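With the logger injected rather than imported as a module-level global, callers can swap in their own sink. A sketch of a custom logger passed through `config.logger` (the `log`/`read` surface is inferred from how this commit uses `DatasetLogger`; the class and field names below are illustrative):

    from typing import List, Optional

    from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
    from eval_protocol.models import EvaluationRow


    class InMemoryLogger(DatasetLogger):
        """Collects rows in memory; handy for asserting on mid-rollout logs in tests."""

        def __init__(self) -> None:
            self.rows: List[EvaluationRow] = []

        def log(self, row: EvaluationRow) -> None:
            self.rows.append(row)

        def read(self, rollout_id: Optional[str] = None) -> List[EvaluationRow]:
            return [r for r in self.rows if rollout_id is None or r.rollout_id == rollout_id]

Every `append_message_and_log` call then lands in `InMemoryLogger.rows` instead of the global SQLite store.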

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 4 additions & 5 deletions
@@ -1,11 +1,9 @@
 import asyncio
-from typing import List
-
 import logging
 import os
+from typing import List

-from eval_protocol.dataset_logger import default_logger
-from eval_protocol.models import EvaluationRow, Message, ChatCompletionMessageToolCall
+from eval_protocol.models import ChatCompletionMessageToolCall, EvaluationRow, Message
 from eval_protocol.pytest.types import RolloutProcessorConfig

@@ -49,6 +47,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:

         # Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
         import importlib
+
         _litellm = importlib.import_module("litellm")
         acompletion = getattr(_litellm, "acompletion")
         response = await acompletion(**request_params)
@@ -79,7 +78,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
         ]

         row.messages = messages
-        default_logger.log(row)
+        config.logger.log(row)
         return row

     # Process rows with bounded concurrency if configured

eval_protocol/pytest/evaluation_test.py

Lines changed: 29 additions & 3 deletions
@@ -8,6 +8,7 @@
 import pytest

 from eval_protocol.dataset_logger import default_logger
+from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
 from eval_protocol.human_id import generate_id
 from eval_protocol.models import CompletionParams, EvalMetadata, EvaluationRow, InputMetadata, Message
 from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
@@ -55,6 +56,7 @@ def evaluation_test(  # noqa: C901
     steps: int = 30,
     mode: EvaluationTestMode = "batch",
     combine_datasets: bool = True,
+    logger: Optional[DatasetLogger] = None,
 ) -> Callable[
     [TestFunction],
     TestFunction,
@@ -117,8 +119,11 @@ def evaluation_test(  # noqa: C901
         mode: Evaluation mode. "batch" (default) expects test function to handle
             full dataset. "pointwise" applies test function to each row. If your evaluation requires
             the full rollout of all rows to compute the score, use
+        logger: DatasetLogger to use for logging. If not provided, a default logger will be used.
     """

+    active_logger: DatasetLogger = logger if logger else default_logger
+
     def decorator(
         test_func: TestFunction,
     ):
@@ -287,7 +292,7 @@ def wrapper_body(**kwargs):
            def _log_eval_error(
                status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool
            ) -> None:
-               log_eval_status_and_rows(eval_metadata, rows, status, passed, default_logger)
+               log_eval_status_and_rows(eval_metadata, rows, status, passed, active_logger)

            try:
                # Handle dataset loading
@@ -369,7 +374,6 @@ def _log_eval_error(
                # has to be done in the pytest main process since it's
                # used to determine whether this eval has stopped
                row.pid = os.getpid()
-               default_logger.log(row)

            # Prepare rollout processor config once; we will generate fresh outputs per run
            config = RolloutProcessorConfig(
@@ -379,6 +383,7 @@ def _log_eval_error(
                max_concurrent_rollouts=max_concurrent_rollouts,
                server_script_path=server_script_path,
                steps=steps,
+               logger=active_logger,
            )

            for _ in range(num_runs):
@@ -395,6 +400,10 @@ def _log_eval_error(
                for row in fresh_dataset:
                    row.rollout_id = generate_id()

+               # log the fresh_dataset
+               for row in fresh_dataset:
+                   active_logger.log(row)
+
                processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config)

                if mode == "pointwise":
@@ -463,7 +472,7 @@ def _log_eval_error(
                    if r.eval_metadata is not None:
                        r.eval_metadata.status = "finished"
                        r.eval_metadata.passed = passed
-                   default_logger.log(r)
+                   active_logger.log(r)

                # Optional: print and/or persist a summary artifact for CI
                try:
@@ -587,6 +596,23 @@ def _extract_effort_tag(params: dict) -> str | None:
                    # Do not fail evaluation if summary writing fails
                    pass

+               # # Write all rows from active_logger.read() to a JSONL file in the same directory as the summary
+               # try:
+               #     if active_logger is not None:
+               #         rows = active_logger.read()
+               #         # Write to a .jsonl file alongside the summary file
+               #         jsonl_path = "logs.jsonl"
+               #         import json
+
+               #         with open(jsonl_path, "w", encoding="utf-8") as f_jsonl:
+               #             for row in rows:
+               #                 json.dump(row.model_dump(exclude_none=True, mode="json"), f_jsonl)
+               #                 f_jsonl.write("\n")
+               # except Exception as e:
+               #     # Do not fail evaluation if log writing fails
+               #     print(e)
+               #     pass
+
                # Check threshold after logging
                if threshold_of_success is not None and not passed:
                    assert (
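The net effect for test authors: `evaluation_test` now accepts a `logger=` argument that flows into `RolloutProcessorConfig` and replaces every former `default_logger` call site. A sketch of opting in (decorator arguments borrowed from the new test file below; `InMemoryLogger` is the illustrative sink from earlier, and the test name is hypothetical):

    from eval_protocol.models import EvaluateResult, EvaluationRow, Message
    from eval_protocol.pytest import default_no_op_rollout_processor, evaluation_test

    @evaluation_test(
        input_messages=[[Message(role="user", content="Return HEADS or TAILS at random.")]],
        model=["dummy/local-model"],
        rollout_processor=default_no_op_rollout_processor,
        mode="pointwise",
        logger=InMemoryLogger(),  # falls back to default_logger when omitted
    )
    def test_with_custom_logger(row: EvaluationRow) -> EvaluationRow:
        row.evaluation_result = EvaluateResult(score=1.0, reason="always passes")
        return row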

eval_protocol/pytest/types.py

Lines changed: 8 additions & 2 deletions
@@ -5,6 +5,9 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Literal, Optional

+from eval_protocol.dataset_logger import default_logger
+from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
+
 from ..models import EvaluationRow, Message

 ModelParam = str  # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
@@ -39,10 +42,13 @@
 class RolloutProcessorConfig:
     model: ModelParam
     input_params: RolloutInputParam  # optional input parameters for inference
-    mcp_config_path: str
-    server_script_path: Optional[str] = None  # TODO: change from server_script_path to mcp_config_path for agent rollout processor
+    mcp_config_path: str
+    server_script_path: Optional[str] = (
+        None  # TODO: change from server_script_path to mcp_config_path for agent rollout processor
+    )
     max_concurrent_rollouts: int = 8  # maximum number of concurrent rollouts
     steps: int = 30  # max number of rollout steps
+    logger: DatasetLogger = default_logger  # logger to use during rollout for mid-rollout logs


 RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[EvaluationRow]]
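For rollout processors invoked outside the decorator, the logger now rides along on the config. A sketch of constructing the dataclass directly (values are illustrative; `input_params` is assumed to be a plain dict per `RolloutInputParam`):

    config = RolloutProcessorConfig(
        model="gpt-4o-mini",
        input_params={},
        mcp_config_path="mcp.json",  # illustrative path
        logger=InMemoryLogger(),     # defaults to default_logger when omitted
    )
    # default_agent_rollout_processor and the single-turn processor both call
    # config.logger.log(...) for their mid-rollout logs.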

tests/dataset_logger/test_sqlite_dataset_logger_adapter.py

Lines changed: 4 additions & 4 deletions
@@ -18,13 +18,13 @@ def test_update_log_and_read():
     messages = [Message(role="user", content="Hello")]
     input_metadata = InputMetadata(row_id="1")
     row = EvaluationRow(input_metadata=input_metadata, messages=messages)
-    store.upsert_row(row_id="1", data=row.model_dump(exclude_none=True, mode="json"))
+    store.upsert_row(data=row.model_dump(exclude_none=True, mode="json"))

     row.messages.append(Message(role="assistant", content="Hello"))

-    logger = SqliteDatasetLoggerAdapter()
+    logger = SqliteDatasetLoggerAdapter(store=store)
     logger.log(row)
-    saved = logger.read(row_id="1")[0]
+    saved = logger.read(row.rollout_id)[0]
     assert row.messages == saved.messages
     assert row.input_metadata == saved.input_metadata

@@ -42,7 +42,7 @@ def test_create_log_and_read():
     row = EvaluationRow(input_metadata=input_metadata, messages=messages)

     logger.log(row)
-    saved = logger.read(row_id="1")[0]
+    saved = logger.read(rollout_id=row.rollout_id)[0]
     assert row.messages == saved.messages
     assert row.input_metadata == saved.input_metadata

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import os
+import random
+from typing import List
+
+import pytest
+
+from eval_protocol.models import EvaluateResult, EvaluationRow, Message
+from eval_protocol.pytest import default_no_op_rollout_processor, evaluation_test
+
+
+# skip in CI since it will intentionally fail. This is useful for local generation of logs
+@pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping flaky test in CI")
+@evaluation_test(
+    input_messages=[[Message(role="user", content="Return HEADS or TAILS at random.")]],
+    model=["dummy/local-model"],
+    rollout_processor=default_no_op_rollout_processor,
+    mode="pointwise",
+    num_runs=5,
+)
+def test_flaky_passes_sometimes(row: EvaluationRow) -> EvaluationRow:
+    """
+    A deliberately flaky evaluation that only passes occasionally.
+
+    With num_runs=5 and a success probability of ~0.3 per run, the aggregated mean
+    will clear the threshold (0.8) only rarely. Uses the no-op rollout to avoid any
+    actual model calls.
+    """
+    # Stochastic score: 1.0 with 30% probability, else 0.0
+    score = 1.0 if random.random() < 0.3 else 0.0
+    row.evaluation_result = EvaluateResult(score=score, reason=f"stochastic={score}")
+    return row
