|
import argparse
import copy
import os
import os.path as osp
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm

from scieval.agents.records import EvalRecord, TrajectoryStore
from scieval.agents.smolagents import SmolAgentsAgent
from scieval.dataset import build_dataset
from scieval.smp import dump, get_logger, load, timestr, githash, ls
| 14 | + |
| 15 | + |
def _build_dataset_from_config(cfg: Dict[str, Any], dataset_name: str):
    """Instantiate the dataset named ``dataset_name`` from its config entry.

    Resolution order:
      1. Empty config dict  -> ``build_dataset(dataset_name)`` with defaults.
      2. No ``class`` key   -> config entries forwarded to ``build_dataset``.
      3. ``class`` present  -> the class is looked up in ``scieval.dataset``
         and constructed with only the kwargs its ``__init__`` accepts.

    Raises:
        ValueError: if the requested class is not defined in ``scieval.dataset``.
    """
    import inspect
    import scieval.dataset as dataset_mod

    options = copy.deepcopy(cfg[dataset_name])
    if options == {}:
        return build_dataset(dataset_name)
    if "class" not in options:
        return build_dataset(dataset_name, **options)

    class_name = options.pop("class")
    if not hasattr(dataset_mod, class_name):
        raise ValueError(f"Dataset class {class_name} is not supported in scieval.dataset")
    dataset_cls = getattr(dataset_mod, class_name)
    accepted = inspect.signature(dataset_cls.__init__).parameters
    # Silently drop config keys the constructor does not accept.
    return dataset_cls(**{k: v for k, v in options.items() if k in accepted})
| 32 | + |
| 33 | + |
| 34 | +def _build_agent_from_config(cfg: Dict[str, Any], agent_name: str): |
| 35 | + config = copy.deepcopy(cfg[agent_name]) |
| 36 | + cls_name = config.pop("class", "SmolAgentsAgent") |
| 37 | + if cls_name not in ["SmolAgentsAgent", "smolagents"]: |
| 38 | + raise ValueError(f"Unsupported agent class: {cls_name}") |
| 39 | + return SmolAgentsAgent(**config) |
| 40 | + |
| 41 | + |
def _run_one_sample(
    idx: int,
    agent,
    dataset,
    store: TrajectoryStore,
    judge_kwargs: Dict[str, Any],
    reuse: bool,
    do_infer: bool,
    do_eval: bool,
) -> Tuple[int, Dict[str, Any], str]:
    """Run inference and/or scoring for a single dataset sample.

    Returns a ``(idx, score_dict, final_answer)`` tuple. When ``reuse`` is
    set, cached trajectories and evaluations in ``store`` are honored;
    freshly computed results are persisted back through ``store``.
    """
    answer = ""
    cached_traj = store.load_traj(idx) if reuse else None

    if do_infer:
        if cached_traj and cached_traj.get("success"):
            # Successful cached trajectory: skip the expensive agent call.
            answer = cached_traj.get("final_answer", "")
        else:
            run_result = agent.run(dataset.build_agent_sample(idx))
            store.save_traj(idx, run_result)
            answer = run_result.final_answer
    elif cached_traj:
        answer = cached_traj.get("final_answer", "")

    if not do_eval:
        return idx, {}, answer

    cached_eval = store.load_eval(idx) if reuse else None
    if cached_eval is not None:
        # Cache entries may be bare score dicts without the record wrapper.
        return (
            idx,
            cached_eval.get("score", cached_eval),
            cached_eval.get("final_answer", answer),
        )

    score = dataset.score_agent_sample(idx, answer, **judge_kwargs)
    # Keep question/answer alongside the score for later inspection.
    extras = {key: score[key] for key in ("question", "answer") if key in score}
    store.save_eval(idx, EvalRecord(index=idx, final_answer=answer, score=score, metadata=extras))
    return idx, score, answer
| 83 | + |
| 84 | + |
| 85 | +def _is_number(value: Any) -> bool: |
| 86 | + return isinstance(value, (int, float)) and not isinstance(value, bool) |
| 87 | + |
| 88 | + |
def run_agent_eval(
    agent,
    dataset,
    work_dir: str,
    nproc: int = 1,
    reuse: bool = False,
    mode: str = "all",
    judge_kwargs: Optional[Dict[str, Any]] = None,
):
    """Run an agent over a dataset, optionally score its answers, dump results.

    Args:
        agent: agent instance exposing ``name``, ``model_version`` and ``run``.
        dataset: dataset exposing ``build_agent_sample`` / ``score_agent_sample``.
        work_dir: output root; logs land under
            ``<work_dir>/agent_eval/<dataset>/<agent>/<model>/<eval_id>``.
        nproc: number of samples processed concurrently (threads) when > 1.
        reuse: reuse the latest previous run directory and its cached
            trajectories / evaluations.
        mode: ``"all"`` (infer + eval), ``"infer"`` or ``"eval"``.
        judge_kwargs: extra keyword arguments forwarded to the dataset scorer.
            Defaults to an empty dict.

    Returns:
        Dict mapping each numeric score key to its mean across all samples.
    """
    logger = get_logger("AGENT_EVAL")
    judge_kwargs = judge_kwargs or {}
    dataset_name = getattr(dataset, "dataset_name", dataset.__class__.__name__)
    root_dir = osp.join(work_dir, "agent_eval", dataset_name, agent.name, agent.model_version)
    eval_id = f"T{timestr('day')}_G{githash(digits=8)}"
    log_dir = osp.join(root_dir, eval_id)
    if reuse and osp.exists(root_dir):
        # Pick the most recent previous run; directory names embed the
        # day timestamp, so a lexicographic sort orders them by time.
        prev_runs = ls(root_dir, mode="dir")
        if prev_runs:
            prev_runs.sort()
            log_dir = prev_runs[-1]
    store = TrajectoryStore(log_dir)
    logger.info(f"Logging directory: {log_dir}")

    do_infer = mode in ["all", "infer"]
    do_eval = mode in ["all", "eval"]

    results: List[Tuple[int, Dict[str, Any], str]] = []
    tasks = list(range(len(dataset)))
    if nproc > 1:
        with ThreadPoolExecutor(max_workers=nproc) as executor:
            futures = [
                executor.submit(
                    _run_one_sample,
                    idx,
                    agent,
                    dataset,
                    store,
                    judge_kwargs,
                    reuse,
                    do_infer,
                    do_eval,
                )
                for idx in tasks
            ]
            with tqdm(total=len(tasks), desc="Agent Eval", unit="sample") as pbar:
                for fut in as_completed(futures):
                    results.append(fut.result())
                    pbar.update(1)
    else:
        with tqdm(total=len(tasks), desc="Agent Eval", unit="sample") as pbar:
            for idx in tasks:
                results.append(
                    _run_one_sample(
                        idx, agent, dataset, store, judge_kwargs, reuse, do_infer, do_eval
                    )
                )
                pbar.update(1)

    # as_completed yields out of order; restore dataset order by index.
    results.sort(key=lambda x: x[0])
    predictions = [{"index": idx, "prediction": final_answer} for idx, _, final_answer in results]
    pred_file = osp.join(log_dir, f"{agent.name}_{dataset_name}.json")
    dump(predictions, pred_file)

    # Aggregate only numeric score fields into per-key means.
    agg: Dict[str, List[float]] = {}
    for _, score, _ in results:
        for k, v in score.items():
            if _is_number(v):
                agg.setdefault(k, []).append(float(v))

    summary = {k: (sum(v) / len(v) if v else 0.0) for k, v in agg.items()}
    summary_file = osp.join(log_dir, "summary.json")
    dump(summary, summary_file)
    return summary
| 162 | + |
| 163 | + |
def run_agent_eval_from_config(cfg: Dict[str, Any], args) -> Dict[str, Any]:
    """Run every configured agent against every configured dataset.

    Returns a mapping from ``"<agent>:<dataset>"`` to that run's summary dict.

    Raises:
        ValueError: if the config lacks an agent or data section.
    """
    logger = get_logger("AGENT_RUN")
    agent_section = cfg.get("agent") or cfg.get("agents")
    data_section = cfg.get("data")
    if not agent_section or not data_section:
        raise ValueError("Config must include 'agent' and 'data' sections for agent evaluation.")

    # A single inline agent spec (has a "class" key) becomes a one-entry mapping.
    if isinstance(agent_section, dict) and "class" in agent_section:
        agent_section = {"agent": agent_section}

    summaries: Dict[str, Any] = {}
    for agent_name in agent_section:
        agent = _build_agent_from_config(agent_section, agent_name)
        for dataset_name in data_section:
            dataset = _build_dataset_from_config(data_section, dataset_name)
            if dataset is None:
                logger.error(f"Dataset {dataset_name} is not valid, skipping.")
                continue
            summaries[f"{agent_name}:{dataset_name}"] = run_agent_eval(
                agent,
                dataset,
                work_dir=args.work_dir,
                nproc=args.api_nproc,
                reuse=args.reuse,
                mode=args.mode,
                judge_kwargs={
                    "model": getattr(args, "judge", None),
                    "api_key": os.environ.get("OPENAI_API_KEY", ""),
                    "api_base": os.environ.get("OPENAI_API_BASE", ""),
                },
            )
    return summaries
| 199 | + |
| 200 | + |
def parse_args():
    """Parse command-line options for the agent evaluation runner."""
    ap = argparse.ArgumentParser(description="Agent evaluation runner")
    ap.add_argument("--config", type=str, required=True, help="Path to agent eval config JSON")
    ap.add_argument("--work-dir", type=str, default="./outputs", help="Output directory")
    ap.add_argument("--mode", type=str, default="all", choices=["all", "infer", "eval"])
    ap.add_argument("--api-nproc", type=int, default=1, help="Parallel agent calls")
    ap.add_argument("--reuse", action="store_true")
    ap.add_argument("--judge", type=str, default=None)
    return ap.parse_args()
| 210 | + |
| 211 | + |
def main():
    """Entry point: parse CLI args, load the config, run the evaluation."""
    cli_args = parse_args()
    run_agent_eval_from_config(load(cli_args.config), cli_args)
| 216 | + |
| 217 | + |
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()