-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_seeds.py
More file actions
50 lines (41 loc) · 1.61 KB
/
run_seeds.py
File metadata and controls
50 lines (41 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import sys
import os
import json
import numpy as np
import torch
from datetime import datetime
# Make modules under the relative "Physics" directory importable. The entry is
# a relative path, so it is resolved against the current working directory.
sys.path.append("Physics")
# NOTE(review): train is presumably defined in Physics/train.py — confirm.
from train import train
# Directory that receives the result files; created at import time if missing.
RESULTS_DIR = "paper_results"
os.makedirs(RESULTS_DIR, exist_ok=True)
def run_main_results():
    """Train all model types on each main-experiment seed, checkpointing to JSON.

    Results accumulate in ``<RESULTS_DIR>/main_results_20_seeds.json`` as a list
    of ``{"seed": ..., "metrics": ...}`` entries. If that file already exists,
    its entries are loaded and the seeds recorded there are skipped, so an
    interrupted run can be resumed by calling this function again.

    A seed whose training or checkpointing raises is reported on stdout and the
    loop continues with the next seed (deliberate best-effort behaviour).

    Returns:
        The list of result entries: those loaded from disk plus those produced
        by this call.
    """
    # NOTE(review): 1011 is in both the hand-picked list and range(1000, 1015),
    # so the raw list has 20 entries but only 19 distinct seeds. De-duplicate
    # (order preserved) so a fresh run does not train and record seed 1011
    # twice; add another seed here if 20 distinct seeds are intended.
    seeds = list(dict.fromkeys([42, 123, 456, 789, 1011] + list(range(1000, 1015))))
    model_types = ['std', 'versor', 'gns', 'hnn', 'egnn']

    all_results = []
    filename = f"{RESULTS_DIR}/main_results_20_seeds.json"

    # Resume support: skip every seed that already has an entry on disk.
    if os.path.exists(filename):
        with open(filename, 'r', encoding="utf-8") as f:
            all_results = json.load(f)
        done_seeds = {r['seed'] for r in all_results}
        seeds = [s for s in seeds if s not in done_seeds]

    for seed in seeds:
        print(f"\n>>> Running Core Seeds: Seed {seed} (Main Experiments)")
        try:
            res = train(seed=seed, model_types=model_types)
            all_results.append({"seed": seed, "metrics": res})

            # Save progress. Serialize fully before touching the disk, then
            # write to a temp file and swap it in, so a serialization error or
            # an interruption never leaves a truncated checkpoint behind.
            payload = json.dumps(all_results, indent=2)
            tmp_filename = f"{filename}.tmp"
            with open(tmp_filename, 'w', encoding="utf-8") as f:
                f.write(payload)
            os.replace(tmp_filename, filename)
        except Exception as e:
            # Best-effort: one failing seed must not abort the remaining seeds.
            print(f"Error in seed {seed}: {e}")

    return all_results
def run_ablation_results():
    """Placeholder for the ablation experiments; currently performs no work.

    TODO: the ablation configurations cannot be run through train() as it
    stands, because train() does not expose architecture flags. Either train()
    needs further changes or several model classes are required. According to
    the original author's note, train.py uses VersorRotorRNN, which is one of
    the ablation points; train.py or models.py may need modification.

    Returns:
        None.
    """
    # Seeds earmarked for the ablation runs; nothing consumes them yet.
    ablation_seeds = [42, 123, 456, 789, 1011]
    return None
# Run the main experiments only when this file is executed as a script.
if __name__ == "__main__":
    run_main_results()