
Commit d587101 ("format")

1 parent 3406889 commit d587101

File tree: 4 files changed (+12, -10 lines)


eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 2 additions & 2 deletions
@@ -73,11 +73,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
 
     _litellm = importlib.import_module("litellm")
    acompletion = getattr(_litellm, "acompletion")
-    logger.debug(f'********** request_params: {request_params} **********')
+    logger.debug(f"********** request_params: {request_params} **********")
     response = await acompletion(**request_params)
 
     assistant_content = response.choices[0].message.content or ""
-    logger.debug(f'********** assistant_content: {assistant_content} **********')
+    logger.debug(f"********** assistant_content: {assistant_content} **********")
     tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
 
     converted_tool_calls = None
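For context, this hunk sits inside a rollout step that loads litellm lazily and awaits its async completion API. Below is a minimal standalone sketch of that pattern, assuming litellm is installed; run_completion and its example parameters are illustrative names, not code from the repository.

import asyncio
import importlib
import logging

logger = logging.getLogger(__name__)

async def run_completion(request_params: dict) -> str:
    # Import litellm at call time so the dependency is only loaded when a rollout runs.
    _litellm = importlib.import_module("litellm")
    acompletion = getattr(_litellm, "acompletion")
    logger.debug(f"request_params: {request_params}")
    response = await acompletion(**request_params)
    # Mirror the hunk above: fall back to "" when the model returns no text content.
    return response.choices[0].message.content or ""

# Hypothetical usage (requires a configured provider key):
# asyncio.run(run_completion({"model": "gpt-4.1", "messages": [{"role": "user", "content": "hi"}]}))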

eval_protocol/pytest/evaluation_test.py

Lines changed: 3 additions & 1 deletion
@@ -634,7 +634,9 @@ async def _collect_result(config, lst, max_retry):
     for ori_row in fresh_dataset:
         copied_row = ori_row.model_copy(deep=True)
         # overwrite the rollout_id to the index of the completion_params
-        copied_row.execution_metadata.rollout_id = str(ori_row.execution_metadata.rollout_id) + "_" + str(idx)
+        copied_row.execution_metadata.rollout_id = (
+            str(ori_row.execution_metadata.rollout_id) + "_" + str(idx)
+        )
         copied_row.input_metadata.completion_params = cp
         lst.append(copied_row)
     tasks.append(asyncio.create_task(_collect_result(config, lst, max_retry)))
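The wrapped assignment above is part of a fan-out: each dataset row is deep-copied once per completion_params entry, and the copy's rollout_id gets that entry's index as a suffix so the copies stay distinguishable. A minimal sketch of the idea with stand-in pydantic models (the real EvaluationRow has many more fields):

from pydantic import BaseModel

class ExecutionMetadata(BaseModel):  # stand-in for the real metadata model
    rollout_id: str

class Row(BaseModel):  # stand-in for EvaluationRow
    execution_metadata: ExecutionMetadata

def fan_out(row: Row, completion_params: list) -> list[Row]:
    copies = []
    for idx, _cp in enumerate(completion_params):
        copied = row.model_copy(deep=True)  # deep copy so mutations do not alias the original
        copied.execution_metadata.rollout_id = str(row.execution_metadata.rollout_id) + "_" + str(idx)
        copies.append(copied)
    return copies

row = Row(execution_metadata=ExecutionMetadata(rollout_id="r1"))
assert [c.execution_metadata.rollout_id for c in fan_out(row, [{}, {}])] == ["r1_0", "r1_1"]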

tests/pytest/test_pytest_groupwise.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 from eval_protocol.models import EvaluationRow, Message, EvaluateResult
 from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
 
+
 @evaluation_test(
     input_messages=[
         [
@@ -24,4 +25,4 @@ def test_pytest_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]:
     rows[1].evaluation_result = EvaluateResult(score=0.0, reason="test")
     print(rows[0].model_dump_json())
     print(rows[1].model_dump_json())
-    return rows
\ No newline at end of file
+    return rows

tests/pytest/test_svgbench.py

Lines changed: 5 additions & 6 deletions
@@ -288,7 +288,6 @@ def evaluate_with_llm_judge_groupwise(image_paths: List[str], requirements: List
 Requirements:
 {requirements_text}"""
 
-
     messages = [
         {
             "role": "user",
@@ -302,7 +301,9 @@ def evaluate_with_llm_judge_groupwise(image_paths: List[str], requirements: List
     for image_path in image_paths:
         with open(image_path, "rb") as f:
             image_data = base64.b64encode(f.read()).decode("utf-8")
-        messages[0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}})
+        messages[0]["content"].append(
+            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}}
+        )
 
     # Use GPT-4.1 for vision capabilities to match project's OpenAI model preference
     response = litellm.completion(
@@ -331,7 +332,6 @@ def evaluate_with_llm_judge_groupwise(image_paths: List[str], requirements: List
         raise ValueError("Missing required field in response")
 
 
-
 @evaluation_test(
     input_dataset=["tests/pytest/data/svgbench_dataset.jsonl"],
     dataset_adapter=svgbench_to_evaluation_row,
@@ -540,15 +540,14 @@ def test_svg_generation_evaluation_groupwise(rows: List[EvaluationRow]) -> List[
             row.evaluation_result = EvaluateResult(score=0.0, reason=f"Evaluation error: {str(e)}")
 
     judge_result = evaluate_with_llm_judge_groupwise(image_paths, requirements)
-    print(f'********** judge_result: {judge_result} **********')
+    print(f"********** judge_result: {judge_result} **********")
     if judge_result.get("best_image_index") == 0:
         rows[0].evaluation_result = EvaluateResult(score=1.0, reason=judge_result.get("reasoning", ""))
         rows[1].evaluation_result = EvaluateResult(score=0.0, reason=judge_result.get("reasoning", ""))
     else:
         rows[0].evaluation_result = EvaluateResult(score=0.0, reason=judge_result.get("reasoning", ""))
         rows[1].evaluation_result = EvaluateResult(score=1.0, reason=judge_result.get("reasoning", ""))
-
-
+
     # Clean up temporary PNG file (only if not saving debug files)
     if not save_debug_files:
         for png_path in image_paths:
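The append reformatted in the second hunk builds an OpenAI-style multimodal message: each PNG is base64-encoded into a data URL and added to the user message's content list before the litellm.completion call. A self-contained sketch of that construction follows; build_image_message is an illustrative helper, not a function from the repository.

import base64
from typing import List

def build_image_message(prompt: str, image_paths: List[str]) -> dict:
    # Start with the text part, then append one image_url part per PNG.
    content = [{"type": "text", "text": prompt}]
    for image_path in image_paths:
        with open(image_path, "rb") as f:
            image_data = base64.b64encode(f.read()).decode("utf-8")
        content.append(
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}}
        )
    return {"role": "user", "content": content}

# The result is passed as messages=[build_image_message(...)] to litellm.completion(model=..., messages=...),
# which accepts this OpenAI-style multimodal schema for vision-capable models.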
