Skip to content

Commit 41ac050

Browse files
author
root
committed
construct the framework of agent-v1
1 parent 287e1f8 commit 41ac050

9 files changed

Lines changed: 774 additions & 2 deletions

File tree

agent_runner.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import argparse
2+
import copy
3+
import os
4+
import os.path as osp
5+
from concurrent.futures import ThreadPoolExecutor, as_completed
6+
from typing import Any, Dict, List, Tuple
7+
8+
from tqdm import tqdm
9+
10+
from scieval.agents.records import EvalRecord, TrajectoryStore
11+
from scieval.agents.smolagents import SmolAgentsAgent
12+
from scieval.dataset import build_dataset
13+
from scieval.smp import dump, get_logger, load, timestr, githash, ls
14+
15+
16+
def _build_dataset_from_config(cfg: Dict[str, Any], dataset_name: str):
    """Instantiate the dataset named *dataset_name* from the config mapping *cfg*.

    Resolution order:
      1. Empty (or null) per-dataset config -> registry default via ``build_dataset``.
      2. Config without a ``class`` key -> ``build_dataset`` with the config as kwargs.
      3. Config with a ``class`` key -> look the class up in ``scieval.dataset`` and
         construct it with only the kwargs its ``__init__`` accepts.

    Raises:
        ValueError: if ``class`` names something not defined in ``scieval.dataset``.
    """
    import inspect

    import scieval.dataset as dataset_mod

    config = copy.deepcopy(cfg[dataset_name])
    # Robustness fix: a YAML/JSON section like ``my_dataset:`` yields None, which
    # previously crashed on the ``"class" not in config`` membership test. Any
    # falsy config now means "use the registry default" (covers the original
    # ``config == {}`` case as well).
    if not config:
        return build_dataset(dataset_name)
    if "class" not in config:
        return build_dataset(dataset_name, **config)
    cls_name = config.pop("class")
    if hasattr(dataset_mod, cls_name):
        cls = getattr(dataset_mod, cls_name)
        # Drop kwargs the constructor does not accept so stale config keys do
        # not break instantiation.
        sig = inspect.signature(cls.__init__)
        valid_params = {k: v for k, v in config.items() if k in sig.parameters}
        return cls(**valid_params)
    raise ValueError(f"Dataset class {cls_name} is not supported in scieval.dataset")
32+
33+
34+
def _build_agent_from_config(cfg: Dict[str, Any], agent_name: str):
35+
config = copy.deepcopy(cfg[agent_name])
36+
cls_name = config.pop("class", "SmolAgentsAgent")
37+
if cls_name not in ["SmolAgentsAgent", "smolagents"]:
38+
raise ValueError(f"Unsupported agent class: {cls_name}")
39+
return SmolAgentsAgent(**config)
40+
41+
42+
def _run_one_sample(
    idx: int,
    agent,
    dataset,
    store: TrajectoryStore,
    judge_kwargs: Dict[str, Any],
    reuse: bool,
    do_infer: bool,
    do_eval: bool,
) -> Tuple[int, Dict[str, Any], str]:
    """Run inference and/or evaluation for a single dataset sample.

    Returns a ``(index, score_dict, final_answer)`` tuple; the score dict is
    empty when evaluation is skipped.
    """
    cached_traj = store.load_traj(idx) if reuse else None

    answer = ""
    if do_infer:
        if cached_traj and cached_traj.get("success"):
            # Reuse the previously stored successful trajectory.
            answer = cached_traj.get("final_answer", "")
        else:
            outcome = agent.run(dataset.build_agent_sample(idx))
            store.save_traj(idx, outcome)
            answer = outcome.final_answer
    elif cached_traj:
        # Eval-only mode: take the answer from the cached trajectory.
        answer = cached_traj.get("final_answer", "")

    if not do_eval:
        return idx, {}, answer

    cached_eval = store.load_eval(idx) if reuse else None
    if cached_eval is not None:
        # A cached eval record short-circuits re-scoring entirely.
        return (
            idx,
            cached_eval.get("score", cached_eval),
            cached_eval.get("final_answer", answer),
        )

    score = dataset.score_agent_sample(idx, answer, **judge_kwargs)
    # Carry question/answer through to the eval record when the scorer exposes them.
    metadata = {key: score[key] for key in ("question", "answer") if key in score}
    record = EvalRecord(index=idx, final_answer=answer, score=score, metadata=metadata)
    store.save_eval(idx, record)
    return idx, score, answer
83+
84+
85+
def _is_number(value: Any) -> bool:
86+
return isinstance(value, (int, float)) and not isinstance(value, bool)
87+
88+
89+
def run_agent_eval(
    agent,
    dataset,
    work_dir: str,
    nproc: int = 1,
    reuse: bool = False,
    mode: str = "all",
    judge_kwargs: Dict[str, Any] = None,  # NOTE(review): None is accepted and coerced to {} below — annotation is effectively Optional
):
    """Run *agent* over every sample of *dataset* and aggregate numeric scores.

    Args:
        agent: agent object exposing ``run``, ``name`` and ``model_version``.
        dataset: dataset exposing ``__len__``, ``build_agent_sample`` and
            ``score_agent_sample``.
        work_dir: output root; results land under
            ``<work_dir>/agent_eval/<dataset>/<agent.name>/<agent.model_version>/<eval_id>``.
        nproc: number of concurrent worker threads (1 = sequential).
        reuse: when True, reuse the latest previous run directory and any
            cached trajectories / eval records inside it.
        mode: "all" (infer + eval), "infer" (trajectories only), or "eval".
        judge_kwargs: extra keyword arguments forwarded to the scorer.

    Returns:
        Dict mapping each numeric score key to its mean over all samples.
    """
    logger = get_logger("AGENT_EVAL")
    judge_kwargs = judge_kwargs or {}
    dataset_name = getattr(dataset, "dataset_name", dataset.__class__.__name__)
    root_dir = osp.join(work_dir, "agent_eval", dataset_name, agent.name, agent.model_version)
    # Each fresh run gets a timestamp + git-hash id so runs do not collide.
    eval_id = f"T{timestr('day')}_G{githash(digits=8)}"
    log_dir = osp.join(root_dir, eval_id)
    if reuse and osp.exists(root_dir):
        # Reuse the most recent previous run directory (lexicographically last,
        # which matches the timestamped naming scheme).
        prev_runs = ls(root_dir, mode="dir")
        if prev_runs:
            prev_runs.sort()
            log_dir = prev_runs[-1]
    store = TrajectoryStore(log_dir)
    logger.info(f"Logging directory: {log_dir}")

    do_infer = mode in ["all", "infer"]
    do_eval = mode in ["all", "eval"]

    results: List[Tuple[int, Dict[str, Any], str]] = []
    tasks = list(range(len(dataset)))
    if nproc > 1:
        # Thread pool: presumably agent/judge calls are I/O-bound API requests — TODO confirm.
        with ThreadPoolExecutor(max_workers=nproc) as executor:
            futures = [
                executor.submit(
                    _run_one_sample,
                    idx,
                    agent,
                    dataset,
                    store,
                    judge_kwargs,
                    reuse,
                    do_infer,
                    do_eval,
                )
                for idx in tasks
            ]
            with tqdm(total=len(tasks), desc="Agent Eval", unit="sample") as pbar:
                for fut in as_completed(futures):
                    results.append(fut.result())
                    pbar.update(1)
    else:
        with tqdm(total=len(tasks), desc="Agent Eval", unit="sample") as pbar:
            for idx in tasks:
                results.append(
                    _run_one_sample(
                        idx, agent, dataset, store, judge_kwargs, reuse, do_infer, do_eval
                    )
                )
                pbar.update(1)

    # Completion order is nondeterministic under the thread pool; restore index order.
    results.sort(key=lambda x: x[0])
    predictions = [{"index": idx, "prediction": final_answer} for idx, _, final_answer in results]
    pred_file = osp.join(log_dir, f"{agent.name}_{dataset_name}.json")
    dump(predictions, pred_file)

    # Aggregate only numeric score fields; _is_number excludes bools.
    agg: Dict[str, List[float]] = {}
    for _, score, _ in results:
        for k, v in score.items():
            if _is_number(v):
                agg.setdefault(k, []).append(float(v))

    summary = {k: (sum(v) / len(v) if v else 0.0) for k, v in agg.items()}
    summary_file = osp.join(log_dir, "summary.json")
    dump(summary, summary_file)
    return summary
162+
163+
164+
def run_agent_eval_from_config(cfg: Dict[str, Any], args) -> Dict[str, Any]:
    """Evaluate every configured agent on every configured dataset.

    Expects an ``agent``/``agents`` section and a ``data`` section in *cfg*.
    Returns a mapping from ``"<agent>:<dataset>"`` to that run's summary dict.
    """
    logger = get_logger("AGENT_RUN")
    agent_cfg = cfg.get("agent") or cfg.get("agents")
    data_cfg = cfg.get("data")
    if not agent_cfg or not data_cfg:
        raise ValueError("Config must include 'agent' and 'data' sections for agent evaluation.")

    # A single inline agent spec (it carries a "class" key directly) is wrapped
    # so both config layouts walk the same loop below.
    if isinstance(agent_cfg, dict) and "class" in agent_cfg:
        agents_cfg = {"agent": agent_cfg}
    else:
        agents_cfg = agent_cfg

    results: Dict[str, Any] = {}
    for agent_name in agents_cfg:
        agent = _build_agent_from_config(agents_cfg, agent_name)
        for dataset_name in data_cfg:
            dataset = _build_dataset_from_config(data_cfg, dataset_name)
            if dataset is None:
                logger.error(f"Dataset {dataset_name} is not valid, skipping.")
                continue
            results[f"{agent_name}:{dataset_name}"] = run_agent_eval(
                agent,
                dataset,
                work_dir=args.work_dir,
                nproc=args.api_nproc,
                reuse=args.reuse,
                mode=args.mode,
                judge_kwargs={
                    "model": getattr(args, "judge", None),
                    "api_key": os.environ.get("OPENAI_API_KEY", ""),
                    "api_base": os.environ.get("OPENAI_API_BASE", ""),
                },
            )
    return results
199+
200+
201+
def parse_args():
    """Parse command-line options for the agent evaluation runner."""
    ap = argparse.ArgumentParser(description="Agent evaluation runner")
    ap.add_argument("--config", type=str, required=True, help="Path to agent eval config JSON")
    ap.add_argument("--work-dir", type=str, default="./outputs", help="Output directory")
    ap.add_argument("--mode", type=str, default="all", choices=["all", "infer", "eval"])
    ap.add_argument("--api-nproc", type=int, default=1, help="Parallel agent calls")
    ap.add_argument("--reuse", action="store_true")
    ap.add_argument("--judge", type=str, default=None)
    return ap.parse_args()
210+
211+
212+
def main():
    """CLI entry point: parse arguments, load the config file, run evaluation."""
    cli_args = parse_args()
    cfg = load(cli_args.config)
    run_agent_eval_from_config(cfg, cli_args)
216+
217+
218+
# Script entry point: ``python agent_runner.py --config <path> [...]``.
if __name__ == "__main__":
    main()

scieval/agents/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .base import AgentBase, EvalSample
2+
from .records import EvalResult, StepResult, ToolCalling, TrajectoryStore
3+
from .smolagents import SmolAgentsAgent
4+
5+
# Public API of the scieval.agents package: names re-exported from the
# base / records / smolagents submodules imported above.
__all__ = [
    "AgentBase",
    "EvalSample",
    "EvalResult",
    "StepResult",
    "ToolCalling",
    "TrajectoryStore",
    "SmolAgentsAgent",
]

scieval/agents/base.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Dict, List, Optional
3+
4+
5+
class EvalSample:
    """One agent-evaluation input: a text prompt plus optional attachments.

    Attributes (all normalized to concrete containers):
        prompt: task text handed to the agent.
        images: image paths/URIs; empty list when not provided.
        files: file-name -> path mapping; empty dict when not provided.
        metadata: extra string key/value pairs; empty dict when not provided.
    """

    def __init__(
        self,
        prompt: str,
        images: Optional[List[str]] = None,
        files: Optional[Dict[str, str]] = None,
        metadata: Optional[Dict[str, str]] = None,
    ):
        self.prompt = prompt
        # Falsy arguments are replaced by fresh empty containers.
        self.images = images if images else []
        self.files = files if files else {}
        self.metadata = metadata if metadata else {}
17+
18+
19+
class AgentBase(ABC):
    """Abstract interface that every evaluation agent implements.

    Subclasses must provide :meth:`run`, which consumes an ``EvalSample``
    and produces the agent's result.
    """

    # Default identifier; subclasses may override it as a class attribute.
    name = "agent"

    def __init__(self, name: Optional[str] = None, model_version: Optional[str] = None, **kwargs):
        # An explicit constructor name wins; otherwise fall back to the class
        # attribute, then to the lower-cased class name.
        self.name = name if name else getattr(self, "name", self.__class__.__name__.lower())
        self.model_version = model_version if model_version else "default"

    @abstractmethod
    def run(self, sample: EvalSample):
        """Execute the agent on *sample*; must be implemented by subclasses."""

scieval/agents/records.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import base64
2+
import io
3+
import os
4+
from dataclasses import dataclass
5+
from typing import Any, Dict, List, Optional
6+
7+
from PIL import Image
8+
9+
from ..smp import dump, load
10+
11+
12+
def _image_to_base64(image: Image.Image) -> str:
    """Serialize a PIL image to a PNG data-URI string."""
    raw = io.BytesIO()
    image.save(raw, format="PNG")
    encoded = base64.b64encode(raw.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"
17+
18+
19+
@dataclass
class ToolCalling:
    """Record of a single tool invocation made by the agent.

    Attributes:
        tool_name: identifier of the invoked tool.
        tool_input: arguments passed to the tool.
        tool_output: textual response, filled in via :meth:`add_response`.
    """

    tool_name: Any
    tool_input: Any
    tool_output: Optional[str] = None

    def add_response(self, response: str) -> None:
        """Attach the tool's response text to this record."""
        self.tool_output = response

    def to_dict(self) -> Dict[str, Any]:
        """Serialize into a JSON-friendly dict."""
        return dict(
            tool_name=self.tool_name,
            tool_input=self.tool_input,
            tool_output=self.tool_output,
        )
34+
35+
36+
class StepResult:
    """One conversational step (role + content items + tool calls) of a trajectory."""

    def __init__(self, role: str, content: Optional[List[Dict[str, Any]]]):
        self.role = role
        self.content = [] if content is None else content
        # Tool invocations made during this step, in order.
        self.tool_calling: List[ToolCalling] = []

    def add_tool_calling(self, tool_result: ToolCalling) -> None:
        """Append one tool-call record to this step."""
        self.tool_calling.append(tool_result)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the step; in-memory PIL images become PNG data URIs."""
        items: List[Dict[str, Any]] = []
        for entry in self.content:
            if entry.get("type") == "image" and "image" in entry:
                img = entry["image"]
                if isinstance(img, Image.Image):
                    # Copy the item so the caller's content is not mutated.
                    entry = {**entry, "image": _image_to_base64(img)}
            items.append(entry)

        return {
            "role": self.role,
            "content": items,
            "tool_calling": [call.to_dict() for call in self.tool_calling],
        }
60+
61+
62+
class EvalResult:
    """Outcome of one agent run: success flag, final answer, and step trace."""

    def __init__(self, success: bool, final_answer: str):
        self.success = success
        self.final_answer = final_answer
        # Ordered trace of the agent's intermediate steps.
        self.steps: List[StepResult] = []

    def add_step(self, step: StepResult) -> None:
        """Append one step to the trace."""
        self.steps.append(step)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result; step serialization is delegated to each step."""
        serialized = [entry.to_dict() for entry in self.steps]
        return {
            "success": self.success,
            "final_answer": self.final_answer,
            "steps": serialized,
        }
77+
78+
79+
class TrajectoryStore:
    """Per-sample persistence of agent trajectories and evaluation records.

    Files are stored flat under *root_dir* as ``sample_<idx>_traj.json`` and
    ``sample_<idx>_eval.json``.
    """

    def __init__(self, root_dir: str):
        self.root_dir = root_dir
        os.makedirs(self.root_dir, exist_ok=True)

    def traj_path(self, idx: int) -> str:
        """Path of the trajectory file for sample *idx*."""
        return os.path.join(self.root_dir, f"sample_{idx}_traj.json")

    def eval_path(self, idx: int) -> str:
        """Path of the evaluation-record file for sample *idx*."""
        return os.path.join(self.root_dir, f"sample_{idx}_eval.json")

    def _load_if_present(self, path: str) -> Optional[Dict[str, Any]]:
        """Return the parsed file at *path*, or None when it does not exist."""
        if os.path.exists(path):
            return load(path)
        return None

    def load_traj(self, idx: int) -> Optional[Dict[str, Any]]:
        """Cached trajectory for sample *idx*, or None."""
        return self._load_if_present(self.traj_path(idx))

    def load_eval(self, idx: int) -> Optional[Dict[str, Any]]:
        """Cached evaluation record for sample *idx*, or None."""
        return self._load_if_present(self.eval_path(idx))

    def save_traj(self, idx: int, result: EvalResult) -> None:
        """Persist a trajectory for sample *idx*."""
        dump(result.to_dict(), self.traj_path(idx))

    def save_eval(self, idx: int, record: Any) -> None:
        """Persist an evaluation record; objects with ``to_dict`` are serialized first."""
        payload = record.to_dict() if hasattr(record, "to_dict") else record
        dump(payload, self.eval_path(idx))
109+
110+
111+
class EvalRecord:
    """Persistable per-sample evaluation record (index, answer, score, metadata)."""

    def __init__(
        self,
        index: int,
        final_answer: str,
        score: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.index = index
        self.final_answer = final_answer
        self.score = score
        # A missing/falsy metadata argument becomes a fresh empty dict.
        self.metadata = metadata if metadata else {}

    def to_dict(self) -> Dict[str, Any]:
        """Serialize into a JSON-friendly dict."""
        return dict(
            index=self.index,
            final_answer=self.final_answer,
            score=self.score,
            metadata=self.metadata,
        )

0 commit comments

Comments
 (0)