
Commit a439e76

fix ut
1 parent 2431dbe commit a439e76

File tree

6 files changed: +189 -9 lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py
eval_protocol/pytest/evaluation_test.py
tests/pytest/test_pytest_async.py
tests/pytest/test_pytest_default_agent_rollout_processor.py
tests/pytest/test_pytest_input_messages.py
tests/pytest/test_svgbench.py


eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -73,9 +73,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
 
         _litellm = importlib.import_module("litellm")
         acompletion = getattr(_litellm, "acompletion")
+        logger.debug(f'********** request_params: {request_params} **********')
         response = await acompletion(**request_params)
 
         assistant_content = response.choices[0].message.content or ""
+        logger.debug(f'********** assistant_content: {assistant_content} **********')
         tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
 
         converted_tool_calls = None
```
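The two added logger.debug calls expose the outgoing request parameters and the assistant's raw content; they only show up when this module's logger is allowed to emit DEBUG records. A minimal sketch for enabling that locally, assuming the module obtains its logger via logging.getLogger(__name__) and no custom handler filters out DEBUG:

```python
# Sketch only -- assumes the module uses logging.getLogger(__name__) and default handlers.
import logging

logging.basicConfig(level=logging.DEBUG)  # simplest: emit everything at DEBUG
# Or scope it to this module (logger name derived from the file path above; an assumption):
logging.getLogger("eval_protocol.pytest.default_single_turn_rollout_process").setLevel(logging.DEBUG)
```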

eval_protocol/pytest/evaluation_test.py

Lines changed: 6 additions & 9 deletions
```diff
@@ -501,10 +501,6 @@ def _log_eval_error(
     index = abs(index) % (max_index + 1)
     row.input_metadata.row_id = generate_id(seed=0, index=index)
 
-    if "completion_params" not in kwargs or not kwargs["completion_params"]:
-        raise ValueError(
-            "No completion parameters provided. Please provide a completion parameters object."
-        )
     completion_params = kwargs["completion_params"]
     if completion_params and ("model" not in completion_params or not completion_params["model"]):
         raise ValueError(
@@ -638,7 +634,7 @@ async def _collect_result(config, lst, max_retry):
             for ori_row in fresh_dataset:
                 copied_row = ori_row.model_copy(deep=True)
                 # overwrite the rollout_id to the index of the completion_params
-                copied_row.execution_metadata.rollout_id = str(idx)
+                copied_row.execution_metadata.rollout_id = str(ori_row.execution_metadata.rollout_id) + "_" + str(idx)
                 copied_row.input_metadata.completion_params = cp
                 lst.append(copied_row)
             tasks.append(asyncio.create_task(_collect_result(config, lst, max_retry)))
@@ -698,17 +694,18 @@
            results_by_group = [
                [[] for _ in range(num_runs)] for _ in range(len(original_completion_params_list))
            ]
-            for i, result in enumerate(all_results):
+            for i_run, result in enumerate(all_results):
                for r in result:
-                    results_by_group[int(r.execution_metadata.rollout_id)][i].append(r)
-            for i, result in enumerate(results_by_group):
+                    completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1])
+                    results_by_group[completion_param_idx][i_run].append(r)
+            for rollout_id, result in enumerate(results_by_group):
                postprocess(
                    result,
                    aggregation_method,
                    threshold,
                    active_logger,
                    mode,
-                    original_completion_params_list[i],
+                    original_completion_params_list[rollout_id],
                    test_func.__name__,
                    num_runs,
                )
```
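With this change, each copied row's rollout_id packs two pieces of information: the row's original rollout_id and the index of the completion_params entry it was copied for, joined by an underscore; the grouping loop later splits the string to recover that index. A small illustrative round trip with made-up values (not part of eval_protocol):

```python
# Illustrative only: how the composite rollout_id from the diff encodes and decodes.
original_rollout_id = "42"   # rollout_id already present on the copied row (made-up value)
completion_param_idx = 1     # index into the completion_params list (made-up value)

composite_id = f"{original_rollout_id}_{completion_param_idx}"  # -> "42_1"

# Decoding, mirroring the results_by_group bucketing above: take the second
# underscore-separated field and use it to pick the completion-param group.
recovered_idx = int(composite_id.split("_")[1])                 # -> 1
assert recovered_idx == completion_param_idx
```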

tests/pytest/test_pytest_async.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,6 +18,7 @@
         ],
     ],
     completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}],
+    mode="listwise",
 )
 async def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
     """Run math evaluation on sample dataset using pytest interface."""
```

tests/pytest/test_pytest_default_agent_rollout_processor.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,6 +18,7 @@
     ],
     rollout_processor=AgentRolloutProcessor(),
     completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}],
+    mode="listwise",
 )
 def test_pytest_default_agent_rollout_processor(rows: List[EvaluationRow]) -> List[EvaluationRow]:
     """Run math evaluation on sample dataset using pytest interface."""
```

tests/pytest/test_pytest_input_messages.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -12,6 +12,7 @@
     ],
     completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
     rollout_processor=SingleTurnRolloutProcessor(),
+    mode="listwise",
 )
 def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[EvaluationRow]:
     """Run math evaluation on sample dataset using pytest interface."""
```

tests/pytest/test_svgbench.py

Lines changed: 178 additions & 0 deletions
```diff
@@ -264,6 +264,74 @@ def evaluate_with_llm_judge(image_path: str, requirements: List[str]) -> Dict[str, Any]:
         raise ValueError("Missing required field in response")
 
 
+def evaluate_with_llm_judge_groupwise(image_paths: List[str], requirements: List[str]) -> Dict[str, Any]:
+    """
+    Use an LLM judge to compare the rendered images and pick the one that best satisfies the requirements.
+    Uses GPT-4.1 for vision capabilities to match the project's model preferences. (Note: the original repo uses Gemini 2.5 Flash.)
+
+    Args:
+        image_paths: Paths to the rendered PNG images, one per candidate
+        requirements: List of requirements to evaluate against
+
+    Returns:
+        Dictionary with the judge's decision ("best_image_index" and "reasoning")
+    """
+    # Format requirements for evaluation (exactly as in original)
+    requirements_text = "\n".join([f"{i + 1}. {req}" for i, req in enumerate(requirements)])
+
+    # Create evaluation prompt with JSON response format
+    evaluate_prompt = f"""Examine the generated images you are given. Based on the following {len(requirements)} requirements, which one is better?
+
+Respond ONLY with a JSON object in this exact format:
+{{"best_image_index": <index>, "reasoning": <reasoning_text>}}
+
+Requirements:
+{requirements_text}"""
+
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": evaluate_prompt},
+            ],
+        }
+    ]
+
+    # Read and encode each image, attaching it to the judge message
+    for image_path in image_paths:
+        with open(image_path, "rb") as f:
+            image_data = base64.b64encode(f.read()).decode("utf-8")
+        messages[0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}})
+
+    # Use GPT-4.1 for vision capabilities to match project's OpenAI model preference
+    response = litellm.completion(
+        model="gpt-4.1",
+        messages=messages,
+        temperature=0.0,
+        response_format={
+            "type": "json_schema",
+            "json_schema": {"name": "SVGBenchResponse", "schema": SVGBenchResponse.model_json_schema()},
+        },
+    )
+
+    # Parse response
+    response_content = response.choices[0].message.content
+
+    # Handle empty response
+    if not response_content or response_content.strip() == "":
+        raise ValueError("Empty response from LLM judge")
+
+    result = json.loads(response_content)
+
+    # Validate the result
+    if "best_image_index" in result:
+        return result
+    else:
+        raise ValueError("Missing required field in response")
+
+
+
 @evaluation_test(
     input_dataset=["tests/pytest/data/svgbench_dataset.jsonl"],
     dataset_adapter=svgbench_to_evaluation_row,
```
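The new groupwise judge above requests structured output against SVGBenchResponse.model_json_schema(), while its prompt asks for a JSON object with "best_image_index" and "reasoning". A hypothetical sketch of a Pydantic model matching that prompt format (the actual SVGBenchResponse defined earlier in this file may have different fields):

```python
# Hypothetical sketch only -- field names are taken from the judge prompt above,
# not from the SVGBenchResponse definition that already exists in test_svgbench.py.
from pydantic import BaseModel


class GroupwiseJudgeResponse(BaseModel):
    best_image_index: int  # position of the winning image in the order sent to the judge
    reasoning: str         # judge's justification for the choice


# It would be passed to litellm the same way the diff passes SVGBenchResponse:
# response_format={
#     "type": "json_schema",
#     "json_schema": {"name": "GroupwiseJudgeResponse", "schema": GroupwiseJudgeResponse.model_json_schema()},
# }
```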
```diff
@@ -279,6 +347,7 @@ def evaluate_with_llm_judge(image_path: str, requirements: List[str]) -> Dict[str, Any]:
     passed_threshold=0.5,  # 50% average score to pass
     num_runs=1,
     mode="pointwise",
+    max_dataset_rows=1,
     max_concurrent_rollouts=50,
 )
 def test_svg_generation_evaluation(row: EvaluationRow) -> EvaluationRow:
```
```diff
@@ -378,3 +447,112 @@ def test_svg_generation_evaluation(row: EvaluationRow) -> EvaluationRow:
                 os.unlink(png_path)
         except Exception:
             pass
+
+
+@evaluation_test(
+    input_dataset=["tests/pytest/data/svgbench_dataset.jsonl"],
+    dataset_adapter=svgbench_to_evaluation_row,
+    completion_params=[
+        {"temperature": 0.0, "model": "gpt-4.1"},
+        {
+            "temperature": 0.8,
+            "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
+            "extra_body": {"reasoning_effort": "high"},
+        },
+    ],
+    rollout_processor=SingleTurnRolloutProcessor(),
+    passed_threshold=None,
+    num_runs=1,
+    max_dataset_rows=3,
+    mode="groupwise",
+    max_concurrent_rollouts=50,
+)
+def test_svg_generation_evaluation_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]:
+    """
+    Test SVG generation and groupwise evaluation using SVGBench methodology.
+
+    This test:
+    1. Extracts SVG code from each model's response
+    2. Renders each SVG to PNG using Selenium
+    3. Uses an LLM judge to pick the best rendering in the group
+    4. Scores the winning row 1.0 and the other row 0.0
+
+    Args:
+        rows: EvaluationRows with the models' SVG generation responses, one per completion_params entry
+
+    Returns:
+        The EvaluationRows with evaluation results attached
+    """
+    # Extract dataset info
+    image_paths = []
+    requirements = rows[0].input_metadata.dataset_info["requirements"]
+    for row in rows:
+        row_id = row.input_metadata.row_id
+
+        # Check if we should save debug files
+        save_debug_files = os.environ.get("SVGBENCH_SAVE_DEBUG_FILES", "false").lower() == "true"
+
+        # Get model response
+        if not row.messages or len(row.messages) < 2:
+            row.evaluation_result = EvaluateResult(score=0.0, reason="No model response found")
+            continue
+
+        model_response = row.messages[-1].content
+
+        # Extract SVG code with better error reporting (matching original)
+        try:
+            svg_code = extract_svg_code(model_response)
+            if not svg_code:
+                raise ValueError("No valid SVG code found in response")
+        except Exception as e:
+            logger.error(f"Error extracting SVG code for question {row_id}: {e}")
+            if save_debug_files:
+                logger.error(f"Full response: {model_response}")
+
+            row.evaluation_result = EvaluateResult(score=0.0, reason=f"SVG extraction failed: {str(e)}")
+            continue
+
+        # Setup file paths
+        if save_debug_files:
+            # Create debug directory
+            model = row.input_metadata.completion_params["model"]
+            # Sanitize model name for filesystem (replace slashes with underscores)
+            safe_model_name = model.replace("/", "_").replace(":", "_")
+            debug_dir = "svgbench_debug"
+            os.makedirs(debug_dir, exist_ok=True)
+            png_path = os.path.join(debug_dir, f"question_{row_id}_{safe_model_name}.png")
+            svg_path = os.path.join(debug_dir, f"question_{row_id}_{safe_model_name}.svg")
+            # Save SVG file for debugging
+            with open(svg_path, "w") as f:
+                f.write(svg_code)
+        else:
+            # Use temporary file
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
+                png_path = f.name
+        image_paths.append(png_path)
+        try:
+            # Render SVG to PNG
+            if not render_svg_to_png(svg_code, png_path):
+                row.evaluation_result = EvaluateResult(score=0.0, reason="Failed to render SVG to PNG")
+
+        except Exception as e:
+            logger.error(f"Evaluation failed for question {row_id}: {e}")
+            row.evaluation_result = EvaluateResult(score=0.0, reason=f"Evaluation error: {str(e)}")
+
+    judge_result = evaluate_with_llm_judge_groupwise(image_paths, requirements)
+    print(f'********** judge_result: {judge_result} **********')
+    if judge_result.get("best_image_index") == 0:
+        rows[0].evaluation_result = EvaluateResult(score=1.0, reason=judge_result.get("reasoning", ""))
+        rows[1].evaluation_result = EvaluateResult(score=0.0, reason=judge_result.get("reasoning", ""))
+    else:
+        rows[0].evaluation_result = EvaluateResult(score=0.0, reason=judge_result.get("reasoning", ""))
+        rows[1].evaluation_result = EvaluateResult(score=1.0, reason=judge_result.get("reasoning", ""))
+
+
+    # Clean up temporary PNG files (only if not saving debug files)
+    if not save_debug_files:
+        for png_path in image_paths:
+            if os.path.exists(png_path):
+                os.unlink(png_path)
+
+    return rows
```
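The scoring block in this new test assumes exactly two entries in completion_params (it writes to rows[0] and rows[1] directly). A sketch of the same winner-take-all idea for an arbitrary group size, reusing the names from the test above (illustrative only, not part of the commit):

```python
# Illustrative generalization of the groupwise scoring above; reuses judge_result,
# rows, and EvaluateResult from the surrounding test, so it is a fragment, not standalone.
best = judge_result.get("best_image_index", 0)
for i, r in enumerate(rows):
    r.evaluation_result = EvaluateResult(
        score=1.0 if i == best else 0.0,
        reason=judge_result.get("reasoning", ""),
    )
```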
