-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexp_bandwidth.py
More file actions
128 lines (104 loc) · 4.49 KB
/
exp_bandwidth.py
File metadata and controls
128 lines (104 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import numpy as np
import torch
import random
import wandb
from tqdm import tqdm
from utils import *
from scorers import *
from models import *
from hash_visualizer import *
# Seed every RNG the experiment touches (Python, torch, NumPy) so runs are reproducible.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(21)
del _seed_fn

# Path parameters
experiment_name = "bandwidth"

# Model parameters
signal_type = "fourier"  # train() accepts "fourier" or "piecewise"
MODEL = 'ngp_feature2d'
# MODEL_NAME may contain these substrings to pick a hash-table variant in train():
#   "ordered" -> ordered hash table
#   "frozen"  -> frozen hash table
#   "flipped" -> flipped hash table
MODEL_NAME = f"{MODEL}_trainable"

# Training parameters
n_trials = 10       # number of trials (one generated signal per trial)
n_seeds = 3         # model initialisations trained per bandwidth
n_samples = 50000   # points in the np.linspace(0, 1, n_samples) sample grid
n = 1000            # NOTE(review): not referenced in this file — confirm whether still needed
epoch = 10000

# Bandwidth sweep: train() walks from max_bandwidth down in steps of bandwidth_decrement.
if signal_type == "fourier":
    max_bandwidth, bandwidth_decrement = 100, 10
else:
    max_bandwidth, bandwidth_decrement = 500, 50

# Animation parameters
nframes = 30
def train(base_path, trial, n_seeds, signal_type="fourier", device="cuda", use_wandb=False):
    """Train one model per (bandwidth, seed) pair on a progressively band-limited signal.

    A full-bandwidth signal of type ``signal_type`` is generated once for this
    trial. For every bandwidth in ``range(max_bandwidth, 1, -bandwidth_decrement)``
    a reduced signal is derived from it, and ``n_seeds`` freshly initialised
    models are trained on that reduced signal.

    Files written under ``base_path``:
      * ``data_{bandwidth}.npy``            -- samples and reduced signal
      * ``preds_{bandwidth}_{seed}.mp4``    -- animation of model predictions
      * ``configs.json``                    -- model configs (rewritten per seed)
      * ``loss_{bandwidth}_{seed}.txt``     -- training loss
      * ``weights_{bandwidth}_{seed}.pth``  -- model ``state_dict``

    Args:
        base_path: Directory that receives this trial's output files.
        trial: Trial index. Seeds torch before signal generation and is the
            ``seed`` passed to ``generate_piecewise_signal``.
        n_seeds: Number of model initialisations trained per bandwidth.
        signal_type: ``"fourier"`` or ``"piecewise"``.
        device: Torch device for the samples, signals and models.
        use_wandb: If True, each (bandwidth, seed) run is logged to Weights & Biases.

    Raises:
        ValueError: If ``signal_type`` is neither ``"fourier"`` nor ``"piecewise"``.
    """
    torch.manual_seed(trial)
    print("generating samples...")
    sample = torch.tensor(np.linspace(0, 1, n_samples)).to(torch.float32).to(device)

    # Generate the full-bandwidth signal once; every lower-bandwidth signal is derived from it.
    if signal_type == "fourier":
        full_band_signal, coeffs, freqs, phases = generate_fourier_signal(sample, max_bandwidth, device=device)
    elif signal_type == "piecewise":
        full_band_signal, knot_idx, _, _ = generate_piecewise_signal(sample, max_bandwidth, seed=trial, device=device)
    else:
        raise ValueError("Signal type not recognized")

    # NOTE(review): max_bandwidth / bandwidth_decrement are module globals derived from the
    # module-level signal_type, not from this function's signal_type argument.
    for bandwidth in range(max_bandwidth, 1, -bandwidth_decrement):
        if signal_type == "fourier":
            signal = decrement_fourier_signal(sample, coeffs, freqs, phases, bandwidth, device=device)
        else:  # "piecewise": any other value was rejected above
            signal = decrement_piecewise_signal(sample, full_band_signal, knot_idx, bandwidth)

        # Save the data for this bandwidth
        save_data(sample.cpu().numpy(), signal.cpu().numpy(), f"{base_path}/data_{bandwidth}.npy")

        for seed in range(n_seeds):
            torch.manual_seed(seed)

            # Build a fresh model from the default configs
            configs = get_default_model_configs(MODEL)
            model = get_model(MODEL, 1, 1, [1], configs, device=device)

            # Substrings in MODEL_NAME select the hash-table initialisation variant
            if "ordered" in MODEL_NAME:
                model.init_weights(ordered=True)
            elif "flipped" in MODEL_NAME:
                model.init_weights(flipped=True)
            else:
                model.init_weights()
            if "frozen" in MODEL_NAME:
                model.freeze_hash_table()
                print("hash table weights frozen")

            optim, scheduler = get_default_model_opts(MODEL, model, epoch)

            if use_wandb:
                wandb.init(
                    project="1d-input-2d-feature",
                    entity="utmist-parsimony",
                    config=configs.NET._asdict(),
                    group=f"{MODEL_NAME}",
                    name=f"{trial}_{bandwidth}_{seed}",
                )
            try:
                model_loss, model_preds = trainer(sample.unsqueeze(1), signal.unsqueeze(1), model, optim, scheduler, epoch, nframes, use_wandb=use_wandb)
            finally:
                # Close the run so the next wandb.init() creates a new run instead of
                # returning this still-active one (and mixing all logs into a single run).
                if use_wandb:
                    wandb.finish()

            # Animate model predictions
            animate_model_preds(sample, signal, model_preds, nframes, f"{base_path}/preds_{bandwidth}_{seed}.mp4")
            # Save model configs, loss and weights
            save_configs(configs, f"{base_path}/configs.json")
            save_vals([model_loss], f"{base_path}/loss_{bandwidth}_{seed}.txt")
            weights_path = f"{base_path}/weights_{bandwidth}_{seed}.pth"
            torch.save(model.state_dict(), weights_path)
            print(f"model weights saved at {weights_path}")
if __name__ == "__main__":
    # Output layout: vis/<experiment_name>/<MODEL_NAME>/{empirical,figures}
    DEVICE = "cuda:1"
    BASE_PATH = f"vis/{experiment_name}/{MODEL_NAME}"
    EMPIRICAL_PATH = f"{BASE_PATH}/empirical"
    FIGURE_PATH = f"{BASE_PATH}/figures"
    for _out_dir in (EMPIRICAL_PATH, FIGURE_PATH):
        create_subdirectories(_out_dir)

    # Train: each trial writes its outputs into its own numbered sub-directory.
    for trial_idx in range(n_trials):
        trial_dir = f"{EMPIRICAL_PATH}/{trial_idx}"
        create_subdirectories(trial_dir)
        train(trial_dir, trial_idx, n_seeds=n_seeds, signal_type=signal_type, device=DEVICE, use_wandb=True)

    # Plot the summary figures from all trials.
    # NOTE(review): hashing is True only when MODEL is exactly "ngp"; with
    # MODEL = 'ngp_feature2d' this passes False — confirm that is intended.
    uses_hashing = MODEL == "ngp"
    plot_segment_summary(EMPIRICAL_PATH, FIGURE_PATH, hashing=uses_hashing, device=DEVICE)