Skip to content

Commit aac0214

Browse files
authored
support groupwise scoring (#101)
* support groupwise scoring * format * fix ut * add tests * remove useless test * format * rename listwise to all * fix ut
1 parent 6e557b8 commit aac0214

File tree

8 files changed

+507
-196
lines changed

8 files changed

+507
-196
lines changed

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
135135
async def process_row(row: EvaluationRow) -> EvaluationRow:
136136
"""Process a single row with agent rollout."""
137137
agent = Agent(
138-
model=config.completion_params["model"],
138+
model=row.input_metadata.completion_params["model"],
139139
row=row,
140140
config_path=config.mcp_config_path,
141141
logger=config.logger,

eval_protocol/pytest/evaluation_test.py

Lines changed: 294 additions & 188 deletions
Large diffs are not rendered by default.

eval_protocol/pytest/types.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,11 @@
1919

2020
Dataset = List[EvaluationRow]
2121

22-
EvaluationTestMode = Literal["batch", "pointwise"]
22+
EvaluationTestMode = Literal["pointwise", "groupwise", "all"]
2323
"""
24-
"batch": (default) expects test function to handle full dataset.
25-
"pointwise": applies test function to each row.
26-
27-
How to choose between "batch" and "pointwise":
28-
If your evaluation requires the rollout of all rows to be passed into your eval compute the score, use "batch".
29-
If your evaluation can be computed pointwise, use "pointwise" as EP can pipeline the rollouts and evals to be faster.
24+
"pointwise": (default) applies test function to each row (rollout result).
25+
"groupwise": applies test function to a group of rollout results from the same original row (for use cases such as dpo/grpo).
26+
"all": applies test function to the whole dataset.
3027
"""
3128

3229
"""

tests/pytest/test_pytest_async.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
],
1919
],
2020
completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}],
21+
mode="all",
2122
)
2223
async def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
2324
"""Run math evaluation on sample dataset using pytest interface."""

tests/pytest/test_pytest_default_agent_rollout_processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
],
1919
rollout_processor=AgentRolloutProcessor(),
2020
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}],
21+
mode="all",
2122
)
2223
def test_pytest_default_agent_rollout_processor(rows: List[EvaluationRow]) -> List[EvaluationRow]:
2324
"""Run math evaluation on sample dataset using pytest interface."""
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
from typing import List

from eval_protocol.models import EvaluationRow, Message, EvaluateResult
from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test


@evaluation_test(
    input_messages=[
        [
            Message(role="user", content="What is the capital of France?"),
        ]
    ],
    completion_params=[
        {"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"},
        {"model": "fireworks_ai/accounts/fireworks/models/gpt-4.1"},
    ],
    rollout_processor=SingleTurnRolloutProcessor(),
    mode="groupwise",
)
def test_pytest_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """Exercise "groupwise" mode: the test function receives one rollout per
    completion_params entry, grouped into a single call.

    Asserts that the rows arrive in the same order as completion_params, then
    assigns one winning and one losing score (as a DPO/GRPO-style judge would)
    and returns the scored rows.
    """
    # Groupwise mode delivers one row per completion_params entry, in order.
    assert rows[0].input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"
    assert rows[1].input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-4.1"
    rows[0].evaluation_result = EvaluateResult(score=1.0, reason="test")
    rows[1].evaluation_result = EvaluateResult(score=0.0, reason="test")
    print(rows[0].model_dump_json())
    print(rows[1].model_dump_json())
    return rows

tests/pytest/test_pytest_input_messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
],
1313
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
1414
rollout_processor=SingleTurnRolloutProcessor(),
15+
mode="all",
1516
)
1617
def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[EvaluationRow]:
1718
"""Run math evaluation on sample dataset using pytest interface."""

tests/pytest/test_svgbench.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,74 @@ def evaluate_with_llm_judge(image_path: str, requirements: List[str]) -> Dict[st
264264
raise ValueError("Missing required field in response")
265265

266266

def evaluate_with_llm_judge_groupwise(image_paths: List[str], requirements: List[str]) -> Dict[str, Any]:
    """
    Use an LLM judge to pick the best image from a group of candidates.

    Uses GPT-4.1 for vision capabilities to match project's model preferences.
    (Note: the original repo uses Gemini 2.5 Flash.)

    Args:
        image_paths: Paths to the rendered PNG candidate images; their order
            defines the meaning of the returned ``best_image_index``.
        requirements: List of requirements the images are judged against.

    Returns:
        Dictionary containing at least ``best_image_index`` (an index into
        ``image_paths``) and, when provided by the judge, ``reasoning``.

    Raises:
        ValueError: If the judge returns an empty response or a response
            without the required ``best_image_index`` field.
    """
    # Format requirements for evaluation (exactly as in original)
    requirements_text = "\n".join([f"{i + 1}. {req}" for i, req in enumerate(requirements)])

    # Create evaluation prompt with JSON response format
    evaluate_prompt = f"""Examine the generated images you are given. Based on the following {len(requirements)} requirements, which one is better?

Respond ONLY with a JSON object in this exact format:
{{"best_image_index": <index>, "reasoning": <reasoning_text>}}

Requirements:
{requirements_text}"""

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": evaluate_prompt},
            ],
        }
    ]

    # Read and encode each candidate image; attachment order must match
    # image_paths so best_image_index maps back to the right candidate.
    for image_path in image_paths:
        with open(image_path, "rb") as f:
            image_data = base64.b64encode(f.read()).decode("utf-8")
        messages[0]["content"].append(
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}}
        )

    # Use GPT-4.1 for vision capabilities to match project's OpenAI model preference
    response = litellm.completion(
        model="gpt-4.1",
        messages=messages,
        temperature=0.0,
        response_format={
            "type": "json_schema",
            "json_schema": {"name": "SVGBenchResponse", "schema": SVGBenchResponse.model_json_schema()},
        },
    )

    # Parse response
    response_content = response.choices[0].message.content

    # Handle empty response
    if not response_content or response_content.strip() == "":
        raise ValueError("Empty response from LLM judge")

    result = json.loads(response_content)

    # Validate the result
    if "best_image_index" in result:
        return result
    else:
        raise ValueError("Missing required field in response")
267335
@evaluation_test(
268336
input_dataset=["tests/pytest/data/svgbench_dataset.jsonl"],
269337
dataset_adapter=svgbench_to_evaluation_row,
@@ -279,6 +347,7 @@ def evaluate_with_llm_judge(image_path: str, requirements: List[str]) -> Dict[st
279347
passed_threshold=0.5, # 50% average score to pass
280348
num_runs=1,
281349
mode="pointwise",
350+
max_dataset_rows=1,
282351
max_concurrent_rollouts=50,
283352
)
284353
def test_svg_generation_evaluation(row: EvaluationRow) -> EvaluationRow:
@@ -378,3 +447,111 @@ def test_svg_generation_evaluation(row: EvaluationRow) -> EvaluationRow:
378447
os.unlink(png_path)
379448
except Exception:
380449
pass
450+
451+
@evaluation_test(
    input_dataset=["tests/pytest/data/svgbench_dataset.jsonl"],
    dataset_adapter=svgbench_to_evaluation_row,
    completion_params=[
        {"temperature": 0.0, "model": "gpt-4.1"},
        {
            "temperature": 0.8,
            "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
            "extra_body": {"reasoning_effort": "high"},
        },
    ],
    rollout_processor=SingleTurnRolloutProcessor(),
    passed_threshold=None,
    num_runs=1,
    max_dataset_rows=3,
    mode="groupwise",
    max_concurrent_rollouts=50,
)
def test_svg_generation_evaluation_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """
    Groupwise SVG generation evaluation using SVGBench methodology.

    For each rollout in the group this test:
    1. Extracts SVG code from the model's response
    2. Renders the SVG to PNG using Selenium
    3. Sends all successfully rendered PNGs to an LLM judge, which picks the
       best one; the winner scores 1.0 and every other judged row scores 0.0.

    Rows with no model response, a failed SVG extraction, or a failed render
    are scored 0.0 immediately and excluded from judging, so the judge's
    best_image_index always lines up with the rows that were actually judged.

    Args:
        rows: Group of EvaluationRows produced from the same original dataset
            row, one per completion_params entry.

    Returns:
        The same rows with evaluation_result populated.
    """
    # All rows in the group come from the same original dataset row, so the
    # requirements are shared; take them from the first row.
    requirements = rows[0].input_metadata.dataset_info["requirements"]

    image_paths: List[str] = []  # PNGs handed to the judge, in row order
    judged_rows: List[EvaluationRow] = []  # rows aligned 1:1 with image_paths
    cleanup_paths: List[str] = []  # temp PNGs to delete afterwards

    # Check if we should save debug files
    save_debug_files = os.environ.get("SVGBENCH_SAVE_DEBUG_FILES", "false").lower() == "true"

    for row in rows:
        row_id = row.input_metadata.row_id

        # Get model response
        if not row.messages or len(row.messages) < 2:
            row.evaluation_result = EvaluateResult(score=0.0, reason="No model response found")
            continue

        model_response = row.messages[-1].content

        # Extract SVG code with better error reporting (matching original)
        try:
            svg_code = extract_svg_code(model_response)
            if not svg_code:
                raise ValueError("No valid SVG code found in response")
        except Exception as e:
            logger.error(f"Error extracting SVG code for question {row_id}: {e}")
            if save_debug_files:
                logger.error(f"Full response: {model_response}")
            row.evaluation_result = EvaluateResult(score=0.0, reason=f"SVG extraction failed: {str(e)}")
            continue

        # Setup file paths
        if save_debug_files:
            # Create debug directory
            model = row.input_metadata.completion_params["model"]
            # Sanitize model name for filesystem (replace slashes with underscores)
            safe_model_name = model.replace("/", "_").replace(":", "_")
            debug_dir = "svgbench_debug"
            os.makedirs(debug_dir, exist_ok=True)
            png_path = os.path.join(debug_dir, f"question_{row_id}_{safe_model_name}.png")
            svg_path = os.path.join(debug_dir, f"question_{row_id}_{safe_model_name}.svg")
            # Save SVG file for debugging
            with open(svg_path, "w") as f:
                f.write(svg_code)
        else:
            # Use temporary file
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
                png_path = f.name
            cleanup_paths.append(png_path)

        try:
            # Render SVG to PNG. A failed render must not reach the judge:
            # a broken/empty PNG would both mislead the judge and shift
            # best_image_index out of alignment with the rows.
            if not render_svg_to_png(svg_code, png_path):
                row.evaluation_result = EvaluateResult(score=0.0, reason="Failed to render SVG to PNG")
                continue
        except Exception as e:
            logger.error(f"Evaluation failed for question {row_id}: {e}")
            row.evaluation_result = EvaluateResult(score=0.0, reason=f"Evaluation error: {str(e)}")
            continue

        image_paths.append(png_path)
        judged_rows.append(row)

    # Judge only the rows that produced a rendered image. best_image_index
    # indexes image_paths, which is kept aligned with judged_rows, so rows
    # that already failed above keep their 0.0 score instead of being
    # overwritten by the judge's verdict.
    if judged_rows:
        judge_result = evaluate_with_llm_judge_groupwise(image_paths, requirements)
        print(f"********** judge_result: {judge_result} **********")
        reasoning = judge_result.get("reasoning", "")
        best_index = judge_result.get("best_image_index")
        for i, judged_row in enumerate(judged_rows):
            judged_row.evaluation_result = EvaluateResult(
                score=1.0 if i == best_index else 0.0, reason=reasoning
            )

    # Clean up temporary PNG files (only if not saving debug files)
    for png_path in cleanup_paths:
        if os.path.exists(png_path):
            os.unlink(png_path)

    return rows

0 commit comments

Comments
 (0)