Skip to content

Commit b969a9c

Browse files
committed
fixes
1 parent a9b0674 commit b969a9c

18 files changed

+350
-512
lines changed

eval_protocol/adapters/huggingface.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,12 @@ def get_evaluation_rows(
188188
return
189189

190190
# Create completion parameters
191-
completion_params = CompletionParams(
192-
model=model_name,
193-
temperature=temperature,
194-
max_tokens=max_tokens,
191+
completion_params: CompletionParams = {
192+
"model": model_name,
193+
"temperature": temperature,
194+
"max_tokens": max_tokens,
195195
**completion_params_kwargs,
196-
)
196+
}
197197

198198
# Convert each row
199199
for i in range(offset, end_idx):

eval_protocol/adapters/langfuse.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from datetime import datetime
99
from typing import Any, Dict, Iterator, List, Optional
1010

11-
from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message
11+
from eval_protocol.models import EvaluationRow, InputMetadata, Message
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -277,20 +277,20 @@ def _create_input_metadata(self, trace: Any, observations: List[Any]) -> InputMe
277277
InputMetadata object
278278
"""
279279
# Extract completion parameters from observations
280-
completion_params = CompletionParams()
280+
completion_params = {}
281281

282282
# Look for model parameters in observations
283283
for obs in observations:
284284
if hasattr(obs, "model") and obs.model:
285-
completion_params.model = obs.model
285+
completion_params["model"] = obs.model
286286
if hasattr(obs, "model_parameters") and obs.model_parameters:
287287
params = obs.model_parameters
288288
if "temperature" in params:
289-
completion_params.temperature = params["temperature"]
289+
completion_params["temperature"] = params["temperature"]
290290
if "max_tokens" in params:
291-
completion_params.max_tokens = params["max_tokens"]
291+
completion_params["max_tokens"] = params["max_tokens"]
292292
if "top_p" in params:
293-
completion_params.top_p = params["top_p"]
293+
completion_params["top_p"] = params["top_p"]
294294
break
295295

296296
# Create dataset info from trace metadata

eval_protocol/benchmarks/suites/gpqa.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import csv
23
import io
34
import re
@@ -8,9 +9,11 @@
89
from eval_protocol.benchmarks.registry import export_benchmark
910
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
1011
from eval_protocol.pytest.default_single_turn_rollout_process import (
11-
default_single_turn_rollout_processor,
12+
SingleTurnRolloutProcessor,
1213
)
1314
from eval_protocol.pytest.evaluation_test import evaluation_test
15+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
16+
from eval_protocol.pytest.types import RolloutProcessorConfig
1417

1518
SYSTEM_PROMPT = (
1619
"You are a helpful assistant. Read the question and options carefully. "
@@ -60,27 +63,40 @@ def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
6063
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
6164

6265

63-
async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[EvaluationRow]:
64-
"""Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to default processor."""
65-
processed: List[EvaluationRow] = []
66-
for r in rows:
67-
gt_tokens = [m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")]
68-
if gt_tokens:
69-
gt_val = gt_tokens[-1].split(":", 1)[1].strip()
70-
r.ground_truth = gt_val
71-
r.messages = [
72-
m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
66+
class GPQAStripGTRolloutProcessor(RolloutProcessor):
67+
"""Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to SingleTurnRolloutProcessor."""
68+
69+
def __init__(self):
70+
super().__init__()
71+
self.single_turn_processor = SingleTurnRolloutProcessor()
72+
73+
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
74+
"""Preprocess rows and delegate to SingleTurnRolloutProcessor."""
75+
processed: List[EvaluationRow] = []
76+
77+
for r in rows:
78+
gt_tokens = [
79+
m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")
7380
]
74-
processed.append(r)
75-
return await default_single_turn_rollout_processor(processed, config)
81+
if gt_tokens:
82+
gt_val = gt_tokens[-1].split(":", 1)[1].strip()
83+
r.ground_truth = gt_val
84+
r.messages = [
85+
m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
86+
]
87+
processed.append(r)
88+
89+
# Delegate to SingleTurnRolloutProcessor
90+
return self.single_turn_processor(processed, config)
7691

7792

7893
@export_benchmark("gpqa")
7994
@evaluation_test(
80-
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
8195
input_messages=_GPQA_INPUT_MESSAGES,
82-
rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
83-
rollout_processor=gpqa_strip_gt_rollout_processor,
96+
completion_params=[
97+
{"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}
98+
],
99+
rollout_processor=GPQAStripGTRolloutProcessor(),
84100
aggregation_method="mean",
85101
passed_threshold=None,
86102
num_runs=8,
eval_protocol/dataset_logger/__init__.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,32 @@
11
import os
22

3+
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
34
from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter
45

6+
57
# Allow disabling sqlite logger to avoid environment-specific constraints in simple CLI runs.
6-
if os.getenv("EP_SQLITE_LOG", "0").strip() == "1":
7-
default_logger = SqliteDatasetLoggerAdapter()
8-
else:
8+
def _get_default_logger():
9+
if os.getenv("DISABLE_EP_SQLITE_LOG", "0").strip() != "1":
10+
return SqliteDatasetLoggerAdapter()
11+
else:
12+
13+
class _NoOpLogger(DatasetLogger):
14+
def log(self, row):
15+
return None
16+
17+
def read(self, rollout_id=None):
18+
return []
19+
20+
return _NoOpLogger()
21+
22+
23+
# Lazy proxy: builds the underlying logger on each log()/read() call rather than at import time
24+
class _LazyLogger(DatasetLogger):
25+
def log(self, row):
26+
return _get_default_logger().log(row)
927

10-
class _NoOpLogger:
11-
def log(self, row):
12-
return None
28+
def read(self, rollout_id=None):
29+
return _get_default_logger().read(rollout_id)
1330

14-
def read(self, rollout_id=None):
15-
return []
1631

17-
default_logger = _NoOpLogger()
32+
default_logger: DatasetLogger = _LazyLogger()

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 97 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -2,129 +2,117 @@
22
import logging
33
import os
44
import time
5-
from typing import AsyncIterator, List
5+
from typing import List
66

7-
import litellm
87
from litellm import acompletion
98
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
109

1110
from eval_protocol.dataset_logger import default_logger
1211
from eval_protocol.models import EvaluationRow, Message
12+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1313
from eval_protocol.pytest.types import RolloutProcessorConfig
1414

1515
logger = logging.getLogger(__name__)
1616

1717

18-
async def default_single_turn_rollout_processor(
19-
rows: List[EvaluationRow], config: RolloutProcessorConfig
20-
) -> AsyncIterator[EvaluationRow]:
21-
"""Generate a single response from any supported model provider using LiteLLM."""
22-
23-
# Quiet LiteLLM logs in test runs unless user overrode
24-
try:
25-
if os.environ.get("LITELLM_LOG") is None:
26-
os.environ["LITELLM_LOG"] = "ERROR"
27-
_llog = logging.getLogger("LiteLLM")
28-
_llog.setLevel(logging.CRITICAL)
29-
_llog.propagate = False
30-
for _h in list(_llog.handlers):
31-
_llog.removeHandler(_h)
32-
except Exception:
33-
pass
34-
35-
# Do not modify global LiteLLM cache. Disable caching per-request instead.
36-
37-
async def process_row(row: EvaluationRow) -> EvaluationRow:
38-
"""Process a single row asynchronously."""
39-
if len(row.messages) == 0:
40-
raise ValueError("Messages is empty. Please provide a non-empty dataset")
41-
42-
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
43-
44-
request_params = {"model": config.model, "messages": messages_payload, **config.input_params}
45-
# Ensure caching is disabled only for this request (review feedback)
46-
request_params["cache"] = {"no-cache": True}
47-
# Single-level reasoning effort: expect `reasoning_effort` only
48-
effort_val = None
49-
if isinstance(config.input_params, dict):
50-
if "reasoning_effort" in config.input_params:
51-
effort_val = str(config.input_params["reasoning_effort"]) # flat shape
18+
class SingleTurnRolloutProcessor(RolloutProcessor):
19+
"""Single turn rollout processor for direct LLM calls."""
20+
21+
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
22+
"""Generate single turn rollout tasks and return them for external handling."""
23+
24+
# Quiet LiteLLM logs in test runs unless user overrode
25+
try:
26+
if os.environ.get("LITELLM_LOG") is None:
27+
os.environ["LITELLM_LOG"] = "ERROR"
28+
_llog = logging.getLogger("LiteLLM")
29+
_llog.setLevel(logging.CRITICAL)
30+
_llog.propagate = False
31+
for _h in list(_llog.handlers):
32+
_llog.removeHandler(_h)
33+
except Exception:
34+
pass
35+
36+
# Do not modify global LiteLLM cache. Disable caching per-request instead.
37+
38+
async def process_row(row: EvaluationRow) -> EvaluationRow:
39+
"""Process a single row asynchronously."""
40+
if len(row.messages) == 0:
41+
raise ValueError("Messages is empty. Please provide a non-empty dataset")
42+
43+
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
44+
45+
request_params = {"messages": messages_payload, **config.completion_params}
46+
# Ensure caching is disabled only for this request (review feedback)
47+
request_params["cache"] = {"no-cache": True}
48+
# Single-level reasoning effort: expect `reasoning_effort` only
49+
effort_val = None
50+
51+
if "reasoning_effort" in config.completion_params:
52+
effort_val = str(config.completion_params["reasoning_effort"]) # flat shape
5253
elif (
53-
isinstance(config.input_params.get("extra_body"), dict)
54-
and "reasoning_effort" in config.input_params["extra_body"]
54+
isinstance(config.completion_params.get("extra_body"), dict)
55+
and "reasoning_effort" in config.completion_params["extra_body"]
5556
):
5657
# Accept if user passed it directly inside extra_body
57-
effort_val = str(config.input_params["extra_body"]["reasoning_effort"]) # already in extra_body
58-
59-
if effort_val:
60-
# Always under extra_body so LiteLLM forwards to provider-specific param set
61-
request_params.setdefault("extra_body", {})
62-
request_params["extra_body"]["reasoning_effort"] = effort_val
63-
# Ensure unsupported top-level keys are not present
64-
if "reasoning_effort" in request_params:
65-
request_params.pop("reasoning_effort", None)
66-
67-
if row.tools is not None:
68-
request_params["tools"] = row.tools
69-
70-
# Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
71-
import importlib
72-
73-
_litellm = importlib.import_module("litellm")
74-
acompletion = getattr(_litellm, "acompletion")
75-
response = await acompletion(**request_params)
76-
77-
assistant_content = response.choices[0].message.content or ""
78-
tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
79-
80-
converted_tool_calls = None
81-
if tool_calls:
82-
converted_tool_calls = [
83-
ChatCompletionMessageToolCall(
84-
id=tool_call.id,
85-
type=tool_call.type,
86-
function={
87-
"name": tool_call.function.name,
88-
"arguments": tool_call.function.arguments,
89-
},
58+
effort_val = str(config.completion_params["extra_body"]["reasoning_effort"]) # already in extra_body
59+
60+
if effort_val:
61+
# Always under extra_body so LiteLLM forwards to provider-specific param set
62+
request_params.setdefault("extra_body", {})
63+
request_params["extra_body"]["reasoning_effort"] = effort_val
64+
# Ensure unsupported top-level keys are not present
65+
if "reasoning_effort" in request_params:
66+
request_params.pop("reasoning_effort", None)
67+
68+
if row.tools is not None:
69+
request_params["tools"] = row.tools
70+
71+
# Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
72+
import importlib
73+
74+
_litellm = importlib.import_module("litellm")
75+
acompletion = getattr(_litellm, "acompletion")
76+
response = await acompletion(**request_params)
77+
78+
assistant_content = response.choices[0].message.content or ""
79+
tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
80+
81+
converted_tool_calls = None
82+
if tool_calls:
83+
converted_tool_calls = [
84+
ChatCompletionMessageToolCall(
85+
id=tool_call.id,
86+
type=tool_call.type,
87+
function={
88+
"name": tool_call.function.name,
89+
"arguments": tool_call.function.arguments,
90+
},
91+
)
92+
for tool_call in tool_calls
93+
]
94+
95+
messages = list(row.messages) + [
96+
Message(
97+
role="assistant",
98+
content=assistant_content,
99+
tool_calls=converted_tool_calls,
90100
)
91-
for tool_call in tool_calls
92101
]
93102

94-
messages = list(row.messages) + [
95-
Message(
96-
role="assistant",
97-
content=assistant_content,
98-
tool_calls=converted_tool_calls,
99-
)
100-
]
101-
102-
row.messages = messages
103-
default_logger.log(row)
104-
return row
105-
106-
# Process rows with bounded concurrency and yield as they complete
107-
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
108-
semaphore = asyncio.Semaphore(max_concurrent)
109-
110-
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
111-
async with semaphore:
112-
try:
113-
return await process_row(r)
114-
except Exception:
115-
return r
116-
117-
# Create all tasks
118-
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
119-
120-
# Yield results as they complete (note that they're not necessarily in original order)
121-
try:
122-
for task in asyncio.as_completed(tasks):
123-
try:
124-
yield await task
125-
except Exception:
126-
logger.exception("Error processing row")
127-
finally:
128-
for t in tasks:
129-
t.cancel()
130-
await asyncio.gather(*tasks, return_exceptions=True)
103+
row.messages = messages
104+
default_logger.log(row)
105+
return row
106+
107+
# Process rows with bounded concurrency
108+
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
109+
semaphore = asyncio.Semaphore(max_concurrent)
110+
111+
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
112+
async with semaphore:
113+
result = await process_row(r)
114+
return result
115+
116+
# Create and return tasks for external handling
117+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
118+
return tasks

0 commit comments

Comments
 (0)