-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexperiments.py
More file actions
107 lines (91 loc) · 4.11 KB
/
experiments.py
File metadata and controls
107 lines (91 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import os
import time
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Any, Dict, List

from tqdm import tqdm

from configs import configs
from fabolas2 import FabolasHPO
from hyperband import HyperbandHPO
@dataclass
class ExperimentResult:
    """Record of a single HPO experiment run.

    Captures the experiment's configuration, the metrics it produced, how
    much of the resource budget it consumed, and its lifecycle status, so a
    grid of runs can be checkpointed to JSON and resumed later.
    """

    # BUG FIX: was `Dict[str, any]` — `any` is the builtin function, not a
    # type; the correct annotation is `typing.Any`.
    config: Dict[str, Any]
    # Metric name -> value produced by the run (empty until it completes).
    metrics: Dict[str, float]
    # Units of the resource budget consumed by this experiment.
    resource_used: int
    status: str  # 'completed', 'running', 'failed'
    # Creation time, formatted "%Y-%m-%d %H:%M:%S" by the caller.
    timestamp: str
    # Wall-clock seconds the run took; 0.0 until the run finishes.
    duration: float = 0.0

    def to_dict(self) -> Dict:
        """Return the record as a plain dict (JSON-serializable when the
        field values themselves are)."""
        return asdict(self)
class Experiments:
    """Grid-runner for HPO experiments defined in ``configs``.

    Builds one ExperimentResult per (resource, optimizer) pair, runs every
    experiment that has not completed yet, and checkpoints all results to a
    JSON file so an interrupted run can be resumed via ``results_path``.
    """

    def __init__(self, results_path=None):
        """Set up the result list, either fresh or resumed from a checkpoint.

        Args:
            results_path: Path to a checkpoint JSON file. If given and the
                file exists, previous results are loaded and the run resumes;
                otherwise a new timestamped checkpoint is created under
                ./results.
        """
        self.configs = configs
        self.results_path = results_path
        if self.results_path:
            self.checkpoint_path = self.results_path
            is_resume = os.path.exists(self.checkpoint_path)
        else:
            self.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
            results_dir = "./results"
            os.makedirs(results_dir, exist_ok=True)
            self.checkpoint_path = os.path.join(results_dir, f"cp{self.run_id}.json")
            is_resume = False

        self.results: List[ExperimentResult] = []
        if is_resume:
            try:
                # The checkpoint is a JSON list of field dicts (see
                # save_results): rebuild the dataclasses via keyword args.
                # BUG FIX: the old `object_hook=lambda d: ExperimentResult(*d)`
                # unpacked dict *keys* positionally and also fired on the
                # nested `config`/`metrics` dicts; it also opened the text
                # file in 'rb' mode.
                with open(self.checkpoint_path, 'r') as f:
                    raw = json.load(f)
                self.results = [ExperimentResult(**entry) for entry in raw]
            except Exception as e:
                # Best-effort resume: fall back to an empty run on any
                # malformed/unreadable checkpoint.
                print(f"[ERROR] - Failed to load previous results: {e}")
                self.results = []
        else:
            # Full cartesian grid of resource budgets x optimizers; the
            # optimizer index selects its matching hyperparameter set.
            for resources in self.configs["RESOURCES"]:
                for j, optimizer in enumerate(self.configs["OPTIMIZER"]):
                    self.results.append(
                        ExperimentResult(
                            config={
                                "optimizer": optimizer,
                                "resources": resources,
                                "opt_hyperparams": self.configs["OPT_HYPERPARAMS"][j],
                                "dataset": self.configs["DATASET"],
                            },
                            metrics={},
                            resource_used=0,
                            status='running',
                            timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        )
                    )

    def run(self):
        """Execute every pending ('running' or 'failed') experiment.

        Results are checkpointed on normal completion, on Ctrl-C, and — via
        ``finally`` — on any unexpected error.
        """
        try:
            for result in tqdm(self.results):
                if result.status in ('running', 'failed'):
                    print(f"[INFO] - Running experiment with config: {result.config}")
                    start = time.perf_counter()
                    pipeline = result.config["optimizer"](
                        domains=self.configs["DOMAINS"],
                        metric_to_monitor="test_error",
                        # NOTE(review): passes the builtin `min` callable —
                        # confirm the optimizer does not expect the string
                        # "min" here.
                        monitor_mode=min,
                        dataset_name=result.config["dataset"],
                        **result.config["opt_hyperparams"],
                        resource_type=result.config["resources"],
                        time_unit=self.configs["TIME_UNIT"],
                        n_runs=self.configs["N_RUNS"],
                    )
                    history = pipeline.optimize()
                    pipeline.plot()
                    result.resource_used = pipeline.total_resources_used
                    print("best config:", history["best_config"])
                    print("best_metric:", history["best_metric"])
                    # BUG FIX: was `start - time.perf_counter()`, which always
                    # produced a negative duration.
                    result.duration = time.perf_counter() - start
                    print(f"[INFO] - Experiment completed in {result.duration:.2f} seconds.")
                    result.status = "completed"
        except KeyboardInterrupt:
            print("[INFO] - Experiment interrupted by user.")
        finally:
            # Always persist progress, even if an experiment raised.
            self.save_results()

    def save_results(self):
        """Serialize all results to the checkpoint file as a JSON list of
        field dicts (the format __init__'s resume path reads back)."""
        try:
            with open(self.checkpoint_path, 'w') as f:
                json.dump([result.to_dict() for result in self.results], f, indent=4)
        except Exception as e:
            # Best-effort persistence: report but do not crash the run.
            print(f"[ERROR] - Failed to save results: {e}")