diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index f87e4c31..fc956528 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -135,7 +135,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> async def process_row(row: EvaluationRow) -> EvaluationRow: """Process a single row with agent rollout.""" agent = Agent( - model=config.completion_params["model"], + model=row.input_metadata.completion_params["model"], row=row, config_path=config.mcp_config_path, logger=config.logger, diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index ef1a70a7..e46f58c9 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -10,7 +10,9 @@ import time from dataclasses import replace from typing import Any, Callable, Dict, List, Literal, Optional, Union +from collections import defaultdict +from mcp.types import Completion import pytest from eval_protocol.dataset_logger import default_logger @@ -56,6 +58,176 @@ from ..common_utils import load_jsonl +def postprocess( + all_results: List[List[EvaluationRow]], + aggregation_method: AggregationMethod, + threshold: Optional[EvaluationThreshold], + active_logger: DatasetLogger, + mode: EvaluationTestMode, + completion_params: CompletionParams, + test_func_name: str, + num_runs: int, +): + scores = [ + sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result) for result in all_results + ] + agg_score = aggregate(scores, aggregation_method) + + # Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats) + ci_low: float | None = None + ci_high: float | None = None + if aggregation_method == "mean": + try: + result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist]) + _, mu_ci_low, mu_ci_high, standard_error = result_ci + if mu_ci_low is not None and mu_ci_high is not None: + ci_low = float(mu_ci_low) + ci_high = float(mu_ci_high) + # Keep agg_score as-is (mean over scores). For equal repeats per question these match. + except Exception: + ci_low = None + ci_high = None + + # Determine if the evaluation passed based on threshold + passed = None + + if threshold is not None: + success_passed, standard_error_passed = True, True + + success_passed = agg_score >= threshold.success + + if threshold.standard_error is not None and standard_error is not None: + standard_error_passed = standard_error <= threshold.standard_error + + passed = success_passed and standard_error_passed + + # Update eval metadata passed field for all results + for result in all_results: + for r in result: + if r.eval_metadata is not None: + r.eval_metadata.passed = passed + if r.evaluation_result is not None: + r.evaluation_result.agg_score = agg_score + r.evaluation_result.standard_error = standard_error + active_logger.log(r) + + # Optional: print and/or persist a summary artifact for CI + try: + should_print = os.getenv("EP_PRINT_SUMMARY") == "1" + summary_path = os.getenv("EP_SUMMARY_JSON") + suite_name = test_func_name + model_used = completion_params["model"] + total_rows = len([item for sublist in all_results for item in sublist]) + summary_obj = { + "suite": suite_name, + "model": model_used, + "agg_score": float(agg_score) if agg_score is not None else None, + "num_runs": num_runs, + "rows": total_rows, + } + if ci_low is not None and ci_high is not None: + summary_obj["agg_ci_low"] = ci_low + summary_obj["agg_ci_high"] = ci_high + + # Aggregate per-metric mean and 95% CI when available + metrics_summary: Dict[str, Dict[str, float]] = {} + + metric_scores: Dict[str, list] = defaultdict(list) + for r in [item for sublist in all_results for item in sublist]: + if r.evaluation_result and r.evaluation_result.metrics: + for m_name, m_res in r.evaluation_result.metrics.items(): + if m_res is not None and getattr(m_res, "score", None) is not None: + metric_scores[m_name].append(m_res.score) + for m_name, vals in metric_scores.items(): + if len(vals) == 0: + continue + m_mean = sum(vals) / len(vals) + m_low = None + m_high = None + if len(vals) >= 2: + try: + m_std = statistics.stdev(vals) + m_se = m_std / math.sqrt(len(vals)) + m_margin = 1.96 * m_se + m_low = max(0.0, m_mean - m_margin) + m_high = min(1.0, m_mean + m_margin) + except Exception: + m_low = None + m_high = None + entry: Dict[str, float] = {"mean": float(m_mean)} + if m_low is not None and m_high is not None: + entry["ci_low"] = float(m_low) + entry["ci_high"] = float(m_high) + metrics_summary[m_name] = entry + if metrics_summary: + summary_obj["metrics_agg"] = metrics_summary + if should_print: + if ci_low is not None and ci_high is not None: + print( + f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}" + ) + else: + print( + f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} runs={num_runs} rows={total_rows}" + ) + # As per project convention, avoid printing per-metric CI lines to reduce noise + if summary_path: + model_slug = sanitize_filename(model_used) + effort_tag = extract_effort_tag(completion_params) or "" + effort_suffix = f"__effort-{sanitize_filename(effort_tag)}" if effort_tag else "" + base_name = f"{suite_name}__{model_slug}{effort_suffix}__{mode}__runs{num_runs}.json" + + p = pathlib.Path(summary_path) + summary_obj["timestamp"] = int(time.time()) + + # When a directory is provided (or a path without .json), write per-combination files inside it + if p.suffix.lower() != ".json" or summary_path.endswith("/") or p.is_dir(): + out_dir = p + out_dir.mkdir(parents=True, exist_ok=True) + out_file = out_dir / base_name + else: + # A file path was provided + # If multiple parameterizations exist, write side-by-side files with suffixes based on base name + parent = p.parent + parent.mkdir(parents=True, exist_ok=True) + # If we detected an effort tag, fan out to separate files; otherwise write to the exact file + if effort_tag: + out_file = parent / f"{p.stem}__{sanitize_filename(effort_tag)}{p.suffix}" + else: + out_file = p + + with open(out_file, "w", encoding="utf-8") as f: + json.dump(summary_obj, f) + except Exception: + # Do not fail evaluation if summary writing fails + pass + + # # Write all rows from active_logger.read() to a JSONL file in the same directory as the summary + # try: + # if active_logger is not None: + # rows = active_logger.read() + # # Write to a .jsonl file alongside the summary file + # jsonl_path = "logs.jsonl" + # import json + + # with open(jsonl_path, "w", encoding="utf-8") as f_jsonl: + # for row in rows: + # json.dump(row.model_dump(exclude_none=True, mode="json"), f_jsonl) + # f_jsonl.write("\n") + # except Exception as e: + # # Do not fail evaluation if log writing fails + # print(e) + # pass + + # Check threshold after logging + if threshold is not None and not passed: + assert agg_score >= threshold.success, f"Aggregated score {agg_score:.3f} below threshold {threshold.success}" + if threshold.standard_error is not None and standard_error is not None: + assert standard_error <= threshold.standard_error, ( + f"Standard error {standard_error:.3f} above threshold {threshold.standard_error}" + ) + + def evaluation_test( # noqa: C901 *, completion_params: List[CompletionParams], @@ -73,7 +245,7 @@ def evaluation_test( # noqa: C901 max_concurrent_rollouts: int = 8, server_script_path: Optional[str] = None, steps: int = 30, - mode: EvaluationTestMode = "batch", + mode: EvaluationTestMode = "pointwise", combine_datasets: bool = True, logger: Optional[DatasetLogger] = None, ) -> Callable[ @@ -136,9 +308,9 @@ def evaluation_test( # noqa: C901 max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel. server_script_path: Path to the MCP server script to run (default: "examples/tau2_mcp/server.py"). steps: Number of rollout steps to execute (default: 30). - mode: Evaluation mode. "batch" (default) expects test function to handle - full dataset. "pointwise" applies test function to each row. If your evaluation requires - the full rollout of all rows to compute the score, use + mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result). + "groupwise" applies test function to a group of rollout results from the same original row (for use cases such as dpo/grpo). + "all" applies test function to the whole dataset. logger: DatasetLogger to use for logging. If not provided, a default logger will be used. """ @@ -156,8 +328,11 @@ def decorator( threshold = None sig = inspect.signature(test_func) + if not completion_params: + raise ValueError("completion_params is required") - # For pointwise/rowwise mode, we expect a different signature + # For pointwise/groupwise mode, we expect a different signature + # we expect single row to be passed in as the original row if mode == "pointwise": # Pointwise mode: function should accept messages and other row-level params if "row" not in sig.parameters: @@ -170,18 +345,33 @@ def decorator( # validate that the function has a return type of EvaluationRow if sig.return_annotation is not EvaluationRow: raise ValueError("In pointwise mode, your eval function must return an EvaluationRow instance") + + # additional check for groupwise evaluation + elif mode == "groupwise": + if "rows" not in sig.parameters: + raise ValueError("In groupwise mode, your eval function must have a parameter named 'rows'") + + # validate that "Rows" is of type List[EvaluationRow] + if sig.parameters["rows"].annotation is not List[EvaluationRow]: + raise ValueError("In groupwise mode, the 'rows' parameter must be of type List[EvaluationRow") + + # validate that the function has a return type of List[EvaluationRow] + if sig.return_annotation is not List[EvaluationRow]: + raise ValueError("In groupwise mode, your eval function must return a list of EvaluationRow instances") + if len(completion_params) < 2: + raise ValueError("In groupwise mode, you must provide at least 2 completion parameters") else: - # Batch mode: function should accept input_dataset and model + # all mode: function should accept input_dataset and model if "rows" not in sig.parameters: - raise ValueError("In batch mode, your eval function must have a parameter named 'rows'") + raise ValueError("In all mode, your eval function must have a parameter named 'rows'") # validate that "Rows" is of type List[EvaluationRow] if sig.parameters["rows"].annotation is not List[EvaluationRow]: - raise ValueError("In batch mode, the 'rows' parameter must be of type List[EvaluationRow") + raise ValueError("In all mode, the 'rows' parameter must be of type List[EvaluationRow") # validate that the function has a return type of List[EvaluationRow] if sig.return_annotation is not List[EvaluationRow]: - raise ValueError("In batch mode, your eval function must return a list of EvaluationRow instances") + raise ValueError("In all mode, your eval function must return a list of EvaluationRow instances") async def execute_with_params( test_func: TestFunction, @@ -207,16 +397,23 @@ async def execute_with_params( else: return test_func(**kwargs) - # Calculate all possible combinations of parameters + # preserve the original completion_params list for groupwise mode + original_completion_params_list = completion_params - combinations = generate_parameter_combinations( - input_dataset, - completion_params, - input_messages, - evaluation_test_kwargs, - max_dataset_rows, - combine_datasets, - ) + # Calculate all possible combinations of parameters + if mode == "groupwise": + combinations = generate_parameter_combinations( + input_dataset, None, input_messages, evaluation_test_kwargs, max_dataset_rows, combine_datasets + ) + else: + combinations = generate_parameter_combinations( + input_dataset, + completion_params, + input_messages, + evaluation_test_kwargs, + max_dataset_rows, + combine_datasets, + ) if len(combinations) == 0: raise ValueError( "No combinations of parameters were found. Please provide at least a model and one of input_dataset or input_messages." @@ -237,7 +434,7 @@ async def execute_with_params( param_tuple.append(etk) param_tuples.append(tuple(param_tuple)) - # For batch mode, use the original parameter names + # For all mode, preserve the original parameter names test_param_names = [] if input_dataset is not None: test_param_names.append("dataset_path") @@ -304,12 +501,8 @@ def _log_eval_error( index = abs(index) % (max_index + 1) row.input_metadata.row_id = generate_id(seed=0, index=index) - if "completion_params" not in kwargs or not kwargs["completion_params"]: - raise ValueError( - "No completion parameters provided. Please provide a completion parameters object." - ) completion_params = kwargs["completion_params"] - if "model" not in completion_params or not completion_params["model"]: + if completion_params and ("model" not in completion_params or not completion_params["model"]): raise ValueError( "No model provided. Please provide a model in the completion parameters object." ) @@ -338,7 +531,6 @@ def _log_eval_error( passed_threshold=threshold, passed=None, ) - for row in data: if row.input_metadata is None: row.input_metadata = InputMetadata() @@ -366,9 +558,7 @@ def _log_eval_error( logger=active_logger, kwargs=rollout_processor_kwargs or {}, ) - max_retry = int(os.getenv("EP_MAX_RETRY", "0")) - for i in range(num_runs): # Regenerate outputs each run by deep-copying the pristine dataset # so model responses are not reused across runs. @@ -416,7 +606,53 @@ async def _execute_with_semaphore(row): results = await asyncio.gather(*tasks) all_results[i] = results + elif mode == "groupwise": + # rollout all the completion_params for the same row at once, and then send the output to the test_func + row_groups = defaultdict(list) # key: row_id, value: list of rollout_result + tasks: List[asyncio.Task[List[EvaluationRow]]] = [] + # completion_groups = [] + for idx, cp in enumerate(original_completion_params_list): + config = RolloutProcessorConfig( + completion_params=cp, + mcp_config_path=mcp_config_path or "", + max_concurrent_rollouts=max_concurrent_rollouts, + server_script_path=server_script_path, + steps=steps, + logger=active_logger, + kwargs=rollout_processor_kwargs or {}, + ) + lst = [] + + async def _collect_result(config, lst, max_retry): + result = [] + async for row in rollout_processor_with_retry( + rollout_processor, lst, config, max_retry + ): + result.append(row) + return result + for ori_row in fresh_dataset: + copied_row = ori_row.model_copy(deep=True) + # overwrite the rollout_id to the index of the completion_params + copied_row.execution_metadata.rollout_id = ( + str(ori_row.execution_metadata.rollout_id) + "_" + str(idx) + ) + copied_row.input_metadata.completion_params = cp + lst.append(copied_row) + tasks.append(asyncio.create_task(_collect_result(config, lst, max_retry))) + rollout_results = await asyncio.gather(*tasks) + for result in rollout_results: + for row in result: + row_groups[row.input_metadata.row_id].append(row) + results = [] + for row_id, rows in row_groups.items(): + result = await execute_with_params( + test_func, + processed_dataset=rows, + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + ) + results.extend(result) + all_results[i] = results else: # Batch mode: collect all results first, then evaluate (no pipelining) input_dataset = [] @@ -454,168 +690,38 @@ async def _execute_with_semaphore(row): r.eval_metadata.status = "finished" active_logger.log(r) - scores = [ - sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result) - for result in all_results - ] - agg_score = aggregate(scores, aggregation_method) - - # Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats) - ci_low: float | None = None - ci_high: float | None = None - if aggregation_method == "mean": - try: - result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist]) - _, mu_ci_low, mu_ci_high, standard_error = result_ci - if mu_ci_low is not None and mu_ci_high is not None: - ci_low = float(mu_ci_low) - ci_high = float(mu_ci_high) - # Keep agg_score as-is (mean over scores). For equal repeats per question these match. - except Exception: - ci_low = None - ci_high = None - - # Determine if the evaluation passed based on threshold - passed = None - - if threshold is not None: - success_passed, standard_error_passed = True, True - - success_passed = agg_score >= threshold.success - - if threshold.standard_error is not None and standard_error is not None: - standard_error_passed = standard_error <= threshold.standard_error - - passed = success_passed and standard_error_passed - - # Update eval metadata passed field for all results - for result in all_results: - for r in result: - if r.eval_metadata is not None: - r.eval_metadata.passed = passed - if r.evaluation_result is not None: - r.evaluation_result.agg_score = agg_score - r.evaluation_result.standard_error = standard_error - active_logger.log(r) - - # Optional: print and/or persist a summary artifact for CI - try: - should_print = os.getenv("EP_PRINT_SUMMARY") == "1" - summary_path = os.getenv("EP_SUMMARY_JSON") - suite_name = test_func.__name__ - model_used = config.completion_params["model"] - total_rows = len([item for sublist in all_results for item in sublist]) - summary_obj = { - "suite": suite_name, - "model": model_used, - "agg_score": float(agg_score) if agg_score is not None else None, - "num_runs": num_runs, - "rows": total_rows, - } - if ci_low is not None and ci_high is not None: - summary_obj["agg_ci_low"] = ci_low - summary_obj["agg_ci_high"] = ci_high - - # Aggregate per-metric mean and 95% CI when available - metrics_summary: Dict[str, Dict[str, float]] = {} - from collections import defaultdict - - metric_scores: Dict[str, list] = defaultdict(list) - for r in [item for sublist in all_results for item in sublist]: - if r.evaluation_result and r.evaluation_result.metrics: - for m_name, m_res in r.evaluation_result.metrics.items(): - if m_res is not None and getattr(m_res, "score", None) is not None: - metric_scores[m_name].append(m_res.score) - for m_name, vals in metric_scores.items(): - if len(vals) == 0: - continue - m_mean = sum(vals) / len(vals) - m_low = None - m_high = None - if len(vals) >= 2: - try: - m_std = statistics.stdev(vals) - m_se = m_std / math.sqrt(len(vals)) - m_margin = 1.96 * m_se - m_low = max(0.0, m_mean - m_margin) - m_high = min(1.0, m_mean + m_margin) - except Exception: - m_low = None - m_high = None - entry: Dict[str, float] = {"mean": float(m_mean)} - if m_low is not None and m_high is not None: - entry["ci_low"] = float(m_low) - entry["ci_high"] = float(m_high) - metrics_summary[m_name] = entry - if metrics_summary: - summary_obj["metrics_agg"] = metrics_summary - if should_print: - if ci_low is not None and ci_high is not None: - print( - f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] runs={num_runs} rows={total_rows}" - ) - else: - print( - f"EP Summary | suite={suite_name} model={model_used} agg={summary_obj['agg_score']:.3f} runs={num_runs} rows={total_rows}" - ) - # As per project convention, avoid printing per-metric CI lines to reduce noise - if summary_path: - model_slug = sanitize_filename(model_used) - effort_tag = extract_effort_tag(completion_params) or "" - effort_suffix = f"__effort-{sanitize_filename(effort_tag)}" if effort_tag else "" - base_name = f"{suite_name}__{model_slug}{effort_suffix}__{mode}__runs{num_runs}.json" - - p = pathlib.Path(summary_path) - summary_obj["timestamp"] = int(time.time()) - - # When a directory is provided (or a path without .json), write per-combination files inside it - if p.suffix.lower() != ".json" or summary_path.endswith("/") or p.is_dir(): - out_dir = p - out_dir.mkdir(parents=True, exist_ok=True) - out_file = out_dir / base_name - else: - # A file path was provided - # If multiple parameterizations exist, write side-by-side files with suffixes based on base name - parent = p.parent - parent.mkdir(parents=True, exist_ok=True) - # If we detected an effort tag, fan out to separate files; otherwise write to the exact file - if effort_tag: - out_file = parent / f"{p.stem}__{sanitize_filename(effort_tag)}{p.suffix}" - else: - out_file = p - - with open(out_file, "w", encoding="utf-8") as f: - json.dump(summary_obj, f) - except Exception: - # Do not fail evaluation if summary writing fails - pass - - # # Write all rows from active_logger.read() to a JSONL file in the same directory as the summary - # try: - # if active_logger is not None: - # rows = active_logger.read() - # # Write to a .jsonl file alongside the summary file - # jsonl_path = "logs.jsonl" - # import json - - # with open(jsonl_path, "w", encoding="utf-8") as f_jsonl: - # for row in rows: - # json.dump(row.model_dump(exclude_none=True, mode="json"), f_jsonl) - # f_jsonl.write("\n") - # except Exception as e: - # # Do not fail evaluation if log writing fails - # print(e) - # pass - - # Check threshold after logging - if threshold is not None and not passed: - assert agg_score >= threshold.success, ( - f"Aggregated score {agg_score:.3f} below threshold {threshold.success}" - ) - if threshold.standard_error is not None and standard_error is not None: - assert standard_error <= threshold.standard_error, ( - f"Standard error {standard_error:.3f} above threshold {threshold.standard_error}" + # for groupwise mode, the result contains eval otuput from multiple completion_params, we need to differentiate them + # rollout_id is used to differentiate the result from different completion_params + if mode == "groupwise": + results_by_group = [ + [[] for _ in range(num_runs)] for _ in range(len(original_completion_params_list)) + ] + for i_run, result in enumerate(all_results): + for r in result: + completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1]) + results_by_group[completion_param_idx][i_run].append(r) + for rollout_id, result in enumerate(results_by_group): + postprocess( + result, + aggregation_method, + threshold, + active_logger, + mode, + original_completion_params_list[rollout_id], + test_func.__name__, + num_runs, ) + else: + postprocess( + all_results, + aggregation_method, + threshold, + active_logger, + mode, + completion_params, + test_func.__name__, + num_runs, + ) except AssertionError: _log_eval_error("finished", data if "data" in locals() else None, passed=False) diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index 8a3be489..597248d9 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -19,14 +19,11 @@ Dataset = List[EvaluationRow] -EvaluationTestMode = Literal["batch", "pointwise"] +EvaluationTestMode = Literal["pointwise", "groupwise", "all"] """ -"batch": (default) expects test function to handle full dataset. -"pointwise": applies test function to each row. - -How to choose between "batch" and "pointwise": -If your evaluation requires the rollout of all rows to be passed into your eval compute the score, use "batch". -If your evaluation can be computed pointwise, use "pointwise" as EP can pipeline the rollouts and evals to be faster. +"pointwise": (default) applies test function to each row (rollout result). +"groupwise": applies test function to a group of rollout results from the same original row (for use cases such as dpo/grpo). +"all": applies test function to the whole dataset. """ """ diff --git a/tests/pytest/test_pytest_async.py b/tests/pytest/test_pytest_async.py index 1cfc2db6..1acfe48d 100644 --- a/tests/pytest/test_pytest_async.py +++ b/tests/pytest/test_pytest_async.py @@ -18,6 +18,7 @@ ], ], completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}], + mode="all", ) async def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]: """Run math evaluation on sample dataset using pytest interface.""" diff --git a/tests/pytest/test_pytest_default_agent_rollout_processor.py b/tests/pytest/test_pytest_default_agent_rollout_processor.py index bfabe35c..6eec4c5d 100644 --- a/tests/pytest/test_pytest_default_agent_rollout_processor.py +++ b/tests/pytest/test_pytest_default_agent_rollout_processor.py @@ -18,6 +18,7 @@ ], rollout_processor=AgentRolloutProcessor(), completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}], + mode="all", ) def test_pytest_default_agent_rollout_processor(rows: List[EvaluationRow]) -> List[EvaluationRow]: """Run math evaluation on sample dataset using pytest interface.""" diff --git a/tests/pytest/test_pytest_groupwise.py b/tests/pytest/test_pytest_groupwise.py new file mode 100644 index 00000000..295d6a7b --- /dev/null +++ b/tests/pytest/test_pytest_groupwise.py @@ -0,0 +1,28 @@ +from typing import List + +from eval_protocol.models import EvaluationRow, Message, EvaluateResult +from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test + + +@evaluation_test( + input_messages=[ + [ + Message(role="user", content="What is the capital of France?"), + ] + ], + completion_params=[ + {"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}, + {"model": "fireworks_ai/accounts/fireworks/models/gpt-4.1"}, + ], + rollout_processor=SingleTurnRolloutProcessor(), + mode="groupwise", +) +def test_pytest_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """Run math evaluation on sample dataset using pytest interface.""" + assert rows[0].input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b" + assert rows[1].input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-4.1" + rows[0].evaluation_result = EvaluateResult(score=1.0, reason="test") + rows[1].evaluation_result = EvaluateResult(score=0.0, reason="test") + print(rows[0].model_dump_json()) + print(rows[1].model_dump_json()) + return rows diff --git a/tests/pytest/test_pytest_input_messages.py b/tests/pytest/test_pytest_input_messages.py index 7b4f8d9e..f4401f22 100644 --- a/tests/pytest/test_pytest_input_messages.py +++ b/tests/pytest/test_pytest_input_messages.py @@ -12,6 +12,7 @@ ], completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], rollout_processor=SingleTurnRolloutProcessor(), + mode="all", ) def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[EvaluationRow]: """Run math evaluation on sample dataset using pytest interface.""" diff --git a/tests/pytest/test_svgbench.py b/tests/pytest/test_svgbench.py index 90d2f8f0..e8fdb03c 100644 --- a/tests/pytest/test_svgbench.py +++ b/tests/pytest/test_svgbench.py @@ -264,6 +264,74 @@ def evaluate_with_llm_judge(image_path: str, requirements: List[str]) -> Dict[st raise ValueError("Missing required field in response") +def evaluate_with_llm_judge_groupwise(image_paths: List[str], requirements: List[str]) -> Dict[str, Any]: + """ + Use LLM judge to evaluate how many requirements are fulfilled. + Uses GPT-4.1 for vision capabilities to match project's model preferences. (note original repo uses Gemini 2.5 flashs) + + Args: + image_path: Path to rendered PNG image + requirements: List of requirements to evaluate + + Returns: + Dictionary with evaluation results + """ + # Format requirements for evaluation (exactly as in original) + requirements_text = "\n".join([f"{i + 1}. {req}" for i, req in enumerate(requirements)]) + + # Create evaluation prompt with JSON response format + evaluate_prompt = f"""Examine the generated images you are given. Based on the following {len(requirements)} requirements, which one is better? + +Respond ONLY with a JSON object in this exact format: +{{"best_image_index": , "reasoning": }} + +Requirements: +{requirements_text}""" + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": evaluate_prompt}, + ], + } + ] + + # Read and encode image + for image_path in image_paths: + with open(image_path, "rb") as f: + image_data = base64.b64encode(f.read()).decode("utf-8") + messages[0]["content"].append( + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}} + ) + + # Use GPT-4.1 for vision capabilities to match project's OpenAI model preference + response = litellm.completion( + model="gpt-4.1", + messages=messages, + temperature=0.0, + response_format={ + "type": "json_schema", + "json_schema": {"name": "SVGBenchResponse", "schema": SVGBenchResponse.model_json_schema()}, + }, + ) + + # Parse response + response_content = response.choices[0].message.content + + # Handle empty response + if not response_content or response_content.strip() == "": + raise ValueError("Empty response from LLM judge") + + result = json.loads(response_content) + + # Validate the result + if "best_image_index" in result: + return result + else: + raise ValueError("Missing required field in response") + + @evaluation_test( input_dataset=["tests/pytest/data/svgbench_dataset.jsonl"], dataset_adapter=svgbench_to_evaluation_row, @@ -279,6 +347,7 @@ def evaluate_with_llm_judge(image_path: str, requirements: List[str]) -> Dict[st passed_threshold=0.5, # 50% average score to pass num_runs=1, mode="pointwise", + max_dataset_rows=1, max_concurrent_rollouts=50, ) def test_svg_generation_evaluation(row: EvaluationRow) -> EvaluationRow: @@ -378,3 +447,111 @@ def test_svg_generation_evaluation(row: EvaluationRow) -> EvaluationRow: os.unlink(png_path) except Exception: pass + + +@evaluation_test( + input_dataset=["tests/pytest/data/svgbench_dataset.jsonl"], + dataset_adapter=svgbench_to_evaluation_row, + completion_params=[ + {"temperature": 0.0, "model": "gpt-4.1"}, + { + "temperature": 0.8, + "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", + "extra_body": {"reasoning_effort": "high"}, + }, + ], + rollout_processor=SingleTurnRolloutProcessor(), + passed_threshold=None, + num_runs=1, + max_dataset_rows=3, + mode="groupwise", + max_concurrent_rollouts=50, +) +def test_svg_generation_evaluation_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """ + Test SVG generation and evaluation using SVGBench methodology. + + This test: + 1. Extracts SVG code from the model's response + 2. Renders SVG to PNG using Selenium + 3. Uses LLM judge to evaluate requirement fulfillment + 4. Calculates score based on fulfilled requirements ratio + + Args: + row: EvaluationRow with model's SVG generation response + + Returns: + EvaluationRow with evaluation results + """ + # Extract dataset info + image_paths = [] + requirements = rows[0].input_metadata.dataset_info["requirements"] + for row in rows: + row_id = row.input_metadata.row_id + + # Check if we should save debug files + save_debug_files = os.environ.get("SVGBENCH_SAVE_DEBUG_FILES", "false").lower() == "true" + + # Get model response + if not row.messages or len(row.messages) < 2: + row.evaluation_result = EvaluateResult(score=0.0, reason="No model response found") + continue + + model_response = row.messages[-1].content + + # Extract SVG code with better error reporting (matching original) + try: + svg_code = extract_svg_code(model_response) + if not svg_code: + raise ValueError("No valid SVG code found in response") + except Exception as e: + logger.error(f"Error extracting SVG code for question {row_id}: {e}") + if save_debug_files: + logger.error(f"Full response: {model_response}") + + row.evaluation_result = EvaluateResult(score=0.0, reason=f"SVG extraction failed: {str(e)}") + continue + + # Setup file paths + if save_debug_files: + # Create debug directory + model = row.input_metadata.completion_params["model"] + # Sanitize model name for filesystem (replace slashes with underscores) + safe_model_name = model.replace("/", "_").replace(":", "_") + debug_dir = "svgbench_debug" + os.makedirs(debug_dir, exist_ok=True) + png_path = os.path.join(debug_dir, f"question_{row_id}_{safe_model_name}.png") + svg_path = os.path.join(debug_dir, f"question_{row_id}_{safe_model_name}.svg") + # Save SVG file for debugging + with open(svg_path, "w") as f: + f.write(svg_code) + else: + # Use temporary file + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + png_path = f.name + image_paths.append(png_path) + try: + # Render SVG to PNG + if not render_svg_to_png(svg_code, png_path): + row.evaluation_result = EvaluateResult(score=0.0, reason="Failed to render SVG to PNG") + + except Exception as e: + logger.error(f"Evaluation failed for question {row_id}: {e}") + row.evaluation_result = EvaluateResult(score=0.0, reason=f"Evaluation error: {str(e)}") + + judge_result = evaluate_with_llm_judge_groupwise(image_paths, requirements) + print(f"********** judge_result: {judge_result} **********") + if judge_result.get("best_image_index") == 0: + rows[0].evaluation_result = EvaluateResult(score=1.0, reason=judge_result.get("reasoning", "")) + rows[1].evaluation_result = EvaluateResult(score=0.0, reason=judge_result.get("reasoning", "")) + else: + rows[0].evaluation_result = EvaluateResult(score=0.0, reason=judge_result.get("reasoning", "")) + rows[1].evaluation_result = EvaluateResult(score=1.0, reason=judge_result.get("reasoning", "")) + + # Clean up temporary PNG file (only if not saving debug files) + if not save_debug_files: + for png_path in image_paths: + if os.path.exists(png_path): + os.unlink(png_path) + + return rows