-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreward_function.py
More file actions
79 lines (62 loc) · 2.63 KB
/
reward_function.py
File metadata and controls
79 lines (62 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import importlib
import os
import threading
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Dict, Optional
# Name of the environment variable that is set to the resolved config path
# before the pipeline module is imported or reloaded.
CONFIG_ENV_VAR = "TRAJECTORY_REWARD_CONFIG"
# Config file name used when the caller supplies no path (or a blank one).
DEFAULT_CONFIG_FILENAME = "config.yaml"
# Cached `trajectory_reward.pipeline` module; stays None until the first load.
_pipeline_module: Optional[ModuleType] = None
# Config path that was in effect when the cached module was last (re)loaded.
_loaded_config_path: Optional[str] = None
# Held while the environment variable and the two cache globals are updated.
_module_lock = threading.Lock()
def _resolve_config_path(config_path: Optional[str]) -> str:
    """Turn an optional config path into an absolute path string.

    A missing or blank ``config_path`` falls back to
    ``DEFAULT_CONFIG_FILENAME``. The chosen path has a leading ``~``
    expanded and is then passed through ``Path.resolve()``, so a relative
    path is resolved against the current working directory.

    Args:
        config_path: Caller-supplied path, or ``None``.

    Returns:
        The resolved path as a ``str``.
    """
    chosen = config_path if config_path else DEFAULT_CONFIG_FILENAME
    trimmed = chosen.strip()
    # A whitespace-only argument is treated the same as no argument at all.
    effective = trimmed if trimmed else DEFAULT_CONFIG_FILENAME
    absolute = Path(effective).expanduser().resolve()
    return str(absolute)
def get_config_model_name(config_path: Optional[str] = None) -> str:
    """Return the non-empty ``model.model_name`` value from a YAML config file.

    Args:
        config_path: Path to the YAML config file. ``None`` or a blank
            string is replaced by the default config file name (see
            ``_resolve_config_path``).

    Returns:
        The configured model name with surrounding whitespace removed.

    Raises:
        RuntimeError: If PyYAML is not installed.
        OSError: If the config file cannot be opened.
        ValueError: If the document root or the ``model`` section is not a
            mapping, or if ``model.model_name`` is missing, null, or blank.
    """
    resolved_config_path = _resolve_config_path(config_path)

    # Imported lazily; a missing PyYAML is reported with an install hint.
    try:
        import yaml
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "Missing dependency 'pyyaml'. Install it with: pip install pyyaml"
        ) from exc

    with open(resolved_config_path, "r", encoding="utf-8") as f:
        # An empty document parses to None; treat it as an empty mapping.
        loaded = yaml.safe_load(f) or {}

    if not isinstance(loaded, dict):
        raise ValueError(f"Config root must be a mapping, got: {type(loaded)}")

    model = loaded.get("model", {})
    if not isinstance(model, dict):
        raise ValueError("Config section 'model' must be a mapping.")

    raw_name = model.get("model_name")
    # An explicit YAML null arrives as None. str(None) would produce the
    # truthy string "None" and slip past the emptiness check below, so map
    # None to the empty string instead.
    model_name = "" if raw_name is None else str(raw_name).strip()
    if not model_name:
        raise ValueError("Config 'model.model_name' cannot be empty.")
    return model_name
def _load_run_score_trajectory(config_path: str) -> Callable[..., Dict[str, Any]]:
    """Return ``run_score_trajectory`` from the cached pipeline module.

    While holding ``_module_lock`` this function:

    1. sets the ``CONFIG_ENV_VAR`` environment variable to ``config_path``;
    2. imports ``trajectory_reward.pipeline`` if it has not been loaded yet,
       or, when ``config_path`` differs from the path recorded at the last
       load, reloads ``trajectory_reward.config`` and then the pipeline
       module;
    3. records the module and the config path in the module-level cache.

    NOTE(review): the environment variable is written before the
    import/reload, presumably because the ``trajectory_reward`` modules read
    it at import time -- confirm against those modules.

    Args:
        config_path: Resolved config path to export and to compare against
            the previously loaded one.

    Returns:
        The ``run_score_trajectory`` attribute of the pipeline module.
    """
    global _pipeline_module, _loaded_config_path

    with _module_lock:
        os.environ[CONFIG_ENV_VAR] = config_path

        pipeline = _pipeline_module
        if pipeline is None:
            # First use: a plain import is enough.
            pipeline = importlib.import_module("trajectory_reward.pipeline")
        elif _loaded_config_path != config_path:
            # Config changed: refresh the config module first, then the
            # pipeline module.
            importlib.reload(importlib.import_module("trajectory_reward.config"))
            pipeline = importlib.reload(pipeline)

        # Only reached when the import/reload above succeeded, so a failure
        # leaves the previously cached module and path in place.
        _pipeline_module = pipeline
        _loaded_config_path = config_path

        return pipeline.run_score_trajectory
def score_trajectory(
    traj_dir: str,
    agent_name: Optional[str] = None,
    config_path: Optional[str] = None,
    output_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Score a trajectory directory using the configured pipeline.

    Resolves ``config_path``, obtains the pipeline's ``run_score_trajectory``
    callable for that config, and forwards the remaining arguments to it.

    Args:
        traj_dir: Forwarded unchanged as ``traj_dir``.
        agent_name: Forwarded unchanged as ``agent_name``.
        config_path: Config file path; ``None`` or blank selects the default
            config file name.
        output_dir: Forwarded as the ``task_output_dir`` keyword.

    Returns:
        Whatever ``run_score_trajectory`` returns.
    """
    scorer = _load_run_score_trajectory(_resolve_config_path(config_path))
    return scorer(
        traj_dir=traj_dir,
        agent_name=agent_name,
        task_output_dir=output_dir,
    )