Skip to content

Commit b57b87e

Browse files
Benny ChenBenny Chen
authored andcommitted
fix more tests
1 parent 847ff69 commit b57b87e

File tree

8 files changed

+80
-24
lines changed

8 files changed

+80
-24
lines changed

eval_protocol/benchmarks/test_livebench_data_analysis.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import re
44
from typing import Any, Dict, List, Optional
55

6-
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
6+
from eval_protocol.models import (
7+
EvaluateResult,
8+
EvaluationRow,
9+
Message,
10+
MetricResult,
11+
ChatCompletionContentPartTextParam,
12+
)
713
from eval_protocol.pytest.default_single_turn_rollout_process import (
814
SingleTurnRolloutProcessor,
915
)
@@ -31,6 +37,12 @@ def _extract_last_boxed_segment(text: str) -> Optional[str]:
3137
return matches[-1]
3238

3339

40+
def _coerce_content_to_str(content: str | list[ChatCompletionContentPartTextParam] | None) -> str:
41+
if isinstance(content, list):
42+
return "".join([getattr(p, "text", str(p)) for p in content])
43+
return str(content or "")
44+
45+
3446
def _cta_process_results(ground_truth: str, llm_answer: str) -> int:
3547
parsed_answer = llm_answer
3648
if "\\boxed{" in parsed_answer or "\\framebox{" in parsed_answer:
@@ -420,7 +432,8 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
420432
)
421433
def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
422434
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
423-
content = assistant_msgs[-1].content if assistant_msgs else ""
435+
raw_content = assistant_msgs[-1].content if assistant_msgs else ""
436+
content = _coerce_content_to_str(raw_content)
424437
payload = _extract_gt(row)
425438
gt = payload.get("ground_truth")
426439
gt_str = str(gt) if gt is not None else ""
@@ -462,9 +475,9 @@ def test_livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
462475
)
463476
def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
464477
user_msgs = [m for m in row.messages if m.role == "user"]
465-
question = user_msgs[-1].content if user_msgs else ""
478+
question = _coerce_content_to_str(user_msgs[-1].content if user_msgs else "")
466479
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
467-
content = assistant_msgs[-1].content if assistant_msgs else ""
480+
content = _coerce_content_to_str(assistant_msgs[-1].content if assistant_msgs else "")
468481
payload = _extract_gt(row)
469482
gt = payload.get("ground_truth")
470483

@@ -505,9 +518,9 @@ def test_livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
505518
)
506519
def test_livebench_tablereformat_pointwise(row: EvaluationRow) -> EvaluationRow:
507520
user_msgs = [m for m in row.messages if m.role == "user"]
508-
question = user_msgs[-1].content if user_msgs else ""
521+
question = _coerce_content_to_str(user_msgs[-1].content if user_msgs else "")
509522
assistant_msgs = [m for m in row.messages if m.role == "assistant"]
510-
content = assistant_msgs[-1].content if assistant_msgs else ""
523+
content = _coerce_content_to_str(assistant_msgs[-1].content if assistant_msgs else "")
511524
payload = _extract_gt(row)
512525
gt = payload.get("ground_truth")
513526
release = payload.get("release") or ""

eval_protocol/integrations/braintrust.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ def scorer_to_reward_fn(
1818
"""Wrap a Braintrust scorer as an Eval Protocol reward function."""
1919

2020
@reward_function
21-
def reward_fn(messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs) -> EvaluateResult:
21+
def reward_fn(
22+
messages: List[Message], ground_truth: Optional[List[Message]] = None, **kwargs: Any
23+
) -> EvaluateResult:
2224
input_val = messages_to_input(messages) if messages_to_input else messages[0].content
2325
output_val = messages[-1].content
2426
expected_val = None

eval_protocol/integrations/deepeval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _build_case_kwargs() -> Dict[str, Any]:
7979
case_kwargs["actual_output"] = output
8080
return case_kwargs
8181

82-
if isinstance(metric, BaseConversationalMetric):
82+
if BaseConversationalMetric is not None and isinstance(metric, BaseConversationalMetric):
8383
turns = []
8484
for i, msg in enumerate(messages):
8585
turn_input = messages[i - 1].get("content", "") if i > 0 else ""

eval_protocol/models.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,33 @@ class ChatCompletionContentPartTextParam(BaseModel):
224224
text: str = Field(..., description="The text content.")
225225
type: Literal["text"] = Field("text", description="The type of the content part.")
226226

227+
# Provide dict-like access for tests and ergonomic usage
228+
def __getitem__(self, key: str) -> Any:
229+
if key == "text":
230+
return self.text
231+
if key == "type":
232+
return self.type
233+
raise KeyError(key)
234+
235+
def get(self, key: str, default: Any = None) -> Any:
236+
try:
237+
return self[key]
238+
except KeyError:
239+
return default
240+
241+
def keys(self):
242+
return (k for k in ("text", "type"))
243+
244+
def values(self):
245+
return (self.text, self.type)
246+
247+
def items(self):
248+
return [("text", self.text), ("type", self.type)]
249+
250+
def __iter__(self):
251+
# Iterate over keys only
252+
return iter(["text", "type"])
253+
227254

228255
class Message(BaseModel):
229256
"""Chat message model with trajectory evaluation support."""
@@ -293,10 +320,12 @@ def values(self):
293320
return [getattr(self, key) for key in self.__fields__.keys()] # Changed to __fields__
294321

295322
def items(self):
296-
return [(key, getattr(self, key)) for key in self.__fields__.keys()] # Changed to __fields__
323+
# Exclude 'data' from items to keep items hashable and match tests
324+
return [(key, getattr(self, key)) for key in self.__fields__.keys() if key != "data"] # Changed to __fields__
297325

298326
def __iter__(self):
299-
return iter(self.__fields__.keys()) # Changed to __fields__
327+
# Exclude 'data' to match expectations in tests
328+
return iter([k for k in self.__fields__.keys() if k != "data"]) # Changed to __fields__
300329

301330

302331
class StepOutput(BaseModel):

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def _get_tools(self) -> Optional[List[dict[str, Any]]]:
5858
if f is not None and not isinstance(f, dict):
5959
f_name = getattr(f, "name", None)
6060
f_params = getattr(f, "parameters", None)
61-
if hasattr(f_params, "model_dump"):
61+
if f_params is not None and hasattr(f_params, "model_dump"):
6262
f_params = f_params.model_dump()
6363
func_obj = FunctionLike(name=f_name, parameters=f_params)
6464
t = {"type": t.get("type", "function"), "function": func_obj}
@@ -70,7 +70,7 @@ async def _get_tools(self) -> Optional[List[dict[str, Any]]]:
7070
# Construct a dict from object-like tool
7171
name = getattr(func, "name", None)
7272
params = getattr(func, "parameters", None)
73-
if hasattr(params, "model_dump"):
73+
if params is not None and hasattr(params, "model_dump"):
7474
params_payload = params.model_dump()
7575
elif isinstance(params, dict):
7676
params_payload = params
@@ -135,15 +135,15 @@ async def _call_model(self, messages: list[Message], tools: Optional[List[dict[s
135135
for tool in tools or []:
136136
if isinstance(tool, dict):
137137
fn = tool.get("function")
138-
if hasattr(fn, "model_dump"):
138+
if fn is not None and hasattr(fn, "model_dump"):
139139
fn_payload = fn.model_dump()
140140
elif isinstance(fn, dict):
141141
fn_payload = fn
142142
else:
143143
# Best effort fallback
144144
name = getattr(fn, "name", None)
145145
params = getattr(fn, "parameters", None)
146-
if hasattr(params, "model_dump"):
146+
if params is not None and hasattr(params, "model_dump"):
147147
params_payload = params.model_dump()
148148
elif isinstance(params, dict):
149149
params_payload = params
@@ -157,7 +157,7 @@ async def _call_model(self, messages: list[Message], tools: Optional[List[dict[s
157157
func = getattr(tool, "function", None)
158158
name = getattr(func, "name", None)
159159
params = getattr(func, "parameters", None)
160-
if hasattr(params, "model_dump"):
160+
if params is not None and hasattr(params, "model_dump"):
161161
params_payload = params.model_dump()
162162
elif isinstance(params, dict):
163163
params_payload = params
@@ -192,11 +192,11 @@ async def _execute_tool_call(
192192
return tool_call_id, content
193193

194194
def _get_content_from_tool_result(self, tool_result: CallToolResult | str) -> List[TextContent]:
195+
if isinstance(tool_result, str):
196+
return [TextContent(text=tool_result, type="text")]
195197
if getattr(tool_result, "structuredContent", None):
196198
return [TextContent(text=json.dumps(tool_result.structuredContent), type="text")]
197199
normalized: List[TextContent] = []
198-
if isinstance(tool_result, str):
199-
return [TextContent(text=tool_result, type="text")]
200200
for content in getattr(tool_result, "content", []) or []:
201201
if isinstance(content, TextContent):
202202
normalized.append(content)

eval_protocol/pytest/default_langchain_rollout_processor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
3535
from langchain_core.messages import HumanMessage
3636
except Exception:
3737
# Fallback minimal message if langchain_core is unavailable
38-
class HumanMessage: # type: ignore
38+
class HumanMessage(BaseMessage): # type: ignore
3939
def __init__(self, content: str):
4040
self.content = content
41+
self.type = "human"
4142

4243
lm_messages: List[BaseMessage] = []
4344
if row.messages:
@@ -67,8 +68,12 @@ async def _invoke_wrapper(payload):
6768
else:
6869
raise TypeError("Unsupported invoke target for LangGraphRolloutProcessor")
6970

70-
result = await invoke_fn({"messages": lm_messages})
71-
result_messages: List[BaseMessage] = result.get("messages", [])
71+
result_obj = await invoke_fn({"messages": lm_messages})
72+
# Accept both dicts and objects with .get/.messages
73+
if isinstance(result_obj, dict):
74+
result_messages: List[BaseMessage] = result_obj.get("messages", [])
75+
else:
76+
result_messages = getattr(result_obj, "messages", [])
7277

7378
def _serialize_message(msg: BaseMessage) -> Message:
7479
# Prefer SDK-level serializer

eval_protocol/pytest/plugin.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,19 @@ def pytest_configure(config) -> None:
282282
def pytest_sessionfinish(session, exitstatus):
283283
"""Print all collected Fireworks experiment links from pytest stash."""
284284
try:
285-
from .evaluation_test import EXPERIMENT_LINKS_STASH_KEY
285+
# Late import to avoid circulars; if missing key, skip printing
286+
EXPERIMENT_LINKS_STASH_KEY: StashKey[list[dict]] | None = None
287+
try:
288+
from .evaluation_test import EXPERIMENT_LINKS_STASH_KEY as _KEY # type: ignore
289+
290+
EXPERIMENT_LINKS_STASH_KEY = _KEY
291+
except Exception:
292+
EXPERIMENT_LINKS_STASH_KEY = None
286293

287294
# Get links from pytest stash using shared key
288295
links = []
289296

290-
if EXPERIMENT_LINKS_STASH_KEY in session.stash:
297+
if EXPERIMENT_LINKS_STASH_KEY is not None and EXPERIMENT_LINKS_STASH_KEY in session.stash:
291298
links = session.stash[EXPERIMENT_LINKS_STASH_KEY]
292299

293300
if links:
@@ -303,5 +310,5 @@ def pytest_sessionfinish(session, exitstatus):
303310

304311
print("=" * 80, file=sys.__stderr__)
305312
sys.__stderr__.flush()
306-
except Exception as e:
313+
except Exception:
307314
pass

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def test_metric_result_dict_access():
175175
assert metric.get("invalid_key", "default_val") == "default_val"
176176

177177
# keys()
178-
assert set(metric.keys()) == {"score", "reason", "is_score_valid"}
178+
assert set(metric.keys()) == {"score", "reason", "is_score_valid", "data"}
179179

180180
# values() - order might not be guaranteed by model_fields, so check content
181181
# Pydantic model_fields preserves declaration order.

0 commit comments

Comments
 (0)