-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
371 lines (315 loc) · 17.6 KB
/
train.py
File metadata and controls
371 lines (315 loc) · 17.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
import json
from pathlib import Path
import click
import dnnlib
import torch
import numpy as np
import matplotlib.pyplot as plt
# Local imports
from training.networks import MLP, Identity
from training.loss import ColtLoss
from training.training_loop import train_loop, evaluate_all
from baselines.C2ST import training_loop as c2st_training_loop
from data import get_mean_fn, get_std_fn
from data.priors import GaussianPrior, SimpleGaussianPrior, UniformAnglePrior
from data.posteriors import GaussianPosterior, GaussianMixturePosterior, TreePosterior, NoisyAnglePosterior
from data.approximations import GaussianApproximatePosterior, TreeApproximatePosterior, DiffusionApproximatePosterior
from data.sampler import DataSampler
from data.diffusion import train_edm_xpred as diffusion_trainer
import os
# --- Constants ---
# Shared MLP architecture used by both the C2ST classifier and the COLT
# networks: width of each hidden layer and number of hidden layers.
HIDDEN_DIM, NUM_HIDDEN_LAYERS = 256, 3
# --- Helper Functions ---
def setup_environment(opts: dnnlib.EasyDict) -> Path:
    """Prepare a run: create its results folder, seed RNGs, normalise flags.

    The results folder is a single sub-directory of ``opts.output_dir`` whose
    name encodes x_dim, theta_dim, the approximate sampling method, alpha,
    n_sim, sigma and seed, so runs with different settings do not collide.

    Args:
        opts: Run options. Read: output_dir, x_dim, theta_dim,
            approx_sampling_mathod, alpha, n_sim, sigma, seed. Mutated in
            place: use_skip, use_batchnorm, normalize_input and early_stop
            are replaced by their ``bool`` equivalents.

    Returns:
        Path of the (now existing) results folder.
    """
    results_dir = Path(opts.output_dir) / (
        f"|x:{opts.x_dim}|theta:{opts.theta_dim}|sampling:{opts.approx_sampling_mathod}|"
        f"alpha:{opts.alpha}|n_sim:{opts.n_sim}|sigma:{opts.sigma}|seed:{opts.seed}"
    )
    results_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving results to: {results_dir}")

    # Seed both RNG families the pipeline draws from (torch and numpy).
    torch.manual_seed(opts.seed)
    np.random.seed(opts.seed)
    if torch.cuda.is_available():
        # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # The CLI delivers these switches as 0/1 integers; downstream code wants bools.
    for flag_name in ("use_skip", "use_batchnorm", "normalize_input", "early_stop"):
        setattr(opts, flag_name, bool(getattr(opts, flag_name)))

    return results_dir
# Fallback location of the cached diffusion checkpoint used by the "diffusion"
# sampling method. A run may override it with ``opts.diffusion_ckpt``.
_DEFAULT_DIFFUSION_CKPT = "/scratch/CoLT/diffusion_model.pt"

# Sampling methods that build_data_sampler handles itself. Any other value is
# forwarded to GaussianApproximatePosterior.set_sampling_method.
_SPECIAL_SAMPLING_METHODS = ("default", "collapse", "blind", "tree", "diffusion")


def _build_prior(opts, device):
    """Return the prior p(x) that matches ``opts.approx_sampling_mathod``."""
    method = opts.approx_sampling_mathod
    if method == "tree":
        return SimpleGaussianPrior(device=device)
    if method == "diffusion":
        return UniformAnglePrior(device=device)
    return GaussianPrior(
        mu=torch.ones(opts.x_dim, device=device),
        sigma=torch.ones(opts.x_dim, device=device),
        device=device,
    )


def _build_mean_fn(opts, device):
    """Return the posterior mean function selected by ``opts.mean_fn``."""
    if opts.mean_fn == "linear":
        return get_mean_fn("linear", scale=opts.scale, shift=opts.shift)
    if opts.mean_fn == "gaussian_projection":
        # Random projection matrix drawn from the global torch RNG.
        w = torch.randn(opts.x_dim, opts.theta_dim, device=device)
        return get_mean_fn("gaussian_projection", w=w)
    raise ValueError(f"Unknown mean_fn: {opts.mean_fn}")


def _build_std_fn(opts, device):
    """Return the posterior std function selected by ``opts.std_fn``."""
    if opts.std_fn == "one_diag":
        return get_std_fn("one_diag", sigma=1.0, dim=opts.theta_dim)
    if opts.std_fn == "scale_diag":
        return get_std_fn("scale_diag", scale_list=[opts.sigma] * opts.theta_dim)  # Using sigma option
    if opts.std_fn == "ar1":
        return get_std_fn("ar1", corr=0.9, dim=opts.theta_dim, sigma=opts.sigma)
    if opts.std_fn == "x_ar1":
        # NOTE(review): wx is assumed to be an (x_dim, 1) weight; confirm against get_std_fn.
        wx = torch.randn(opts.x_dim, 1, device=device)
        return get_std_fn("x_ar1", corr=0.9, dim=opts.theta_dim, sigma=opts.sigma, wx=wx)
    raise ValueError(f"Unknown std_fn: {opts.std_fn}")


def _build_posterior(opts, mean_fn, std_fn, device):
    """Return the ground-truth posterior for the configured sampling method."""
    method = opts.approx_sampling_mathod
    if method == "collapse":
        return GaussianMixturePosterior(mu_fn=mean_fn, sigma_fn=std_fn, alpha=opts.alpha, device=device)
    if method == "tree":
        return TreePosterior(device=device, sigma_max=1e-2)
    if method == "diffusion":
        return NoisyAnglePosterior(device=device, noise_std=1e-4)
    return GaussianPosterior(mu_fn=mean_fn, sigma_fn=std_fn, device=device)


def _load_or_train_diffusion(prior, posterior, device, ckpt_path):
    """Load the cached diffusion model from ``ckpt_path``, or train and cache it.

    Training data is 32768 draws of x from ``prior`` paired with one draw of
    theta per x from ``posterior``. These draws happen in both branches, so the
    RNG stream is the same whether or not the checkpoint exists.

    NOTE(review): the cache is keyed only by path, not by any run option, so a
    checkpoint trained under different settings is silently reused.
    """
    x_samples = prior.sample(32768)
    theta_p_samples = posterior.sample(x_samples, n_posterior_per_x=1).squeeze(0)
    ckpt = Path(ckpt_path)
    if ckpt.exists():
        # A 1-epoch run is used only to instantiate the architecture;
        # its weights are overwritten by the checkpoint.
        model = diffusion_trainer(x_samples, theta_p_samples, epochs=1, device=device, lr=1e-4)
        model.load_state_dict(torch.load(ckpt, map_location=device))
        model.to(device)
        model.eval()
    else:
        model = diffusion_trainer(x_samples, theta_p_samples, epochs=50000, device=device, lr=1e-4)
        model.eval()
        # Create the cache directory first. Otherwise torch.save fails only
        # after the long training run, losing the trained model.
        ckpt.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), ckpt)
    return model


def build_data_sampler(opts: dnnlib.EasyDict) -> DataSampler:
    """Construct the DataSampler (prior, true posterior, approximate posterior).

    ``opts.approx_sampling_mathod`` selects the experiment family:

    * "tree" / "diffusion": dedicated priors, posteriors and approximations.
    * "collapse": Gaussian-mixture true posterior with a default Gaussian
      approximation.
    * "blind": Gaussian approximation in "blind" mode when alpha > 0,
      otherwise identical to the truth.
    * "default": Gaussian true posterior with the default Gaussian
      approximation.
    * anything else: forwarded, together with alpha, to
      ``GaussianApproximatePosterior.set_sampling_method``.

    For "diffusion", a model is loaded from ``opts.diffusion_ckpt``, if that
    key is present, otherwise from ``_DEFAULT_DIFFUSION_CKPT``. If the file is
    missing, the model is trained and written there.

    Raises:
        ValueError: if ``opts.mean_fn`` or ``opts.std_fn`` is unknown.
    """
    device = opts.device
    method = opts.approx_sampling_mathod

    # Construction order matches the original so the global RNG stream is
    # consumed identically (mean_fn / std_fn may draw random weights).
    prior = _build_prior(opts, device)
    mean_fn = _build_mean_fn(opts, device)
    std_fn = _build_std_fn(opts, device)
    posterior = _build_posterior(opts, mean_fn, std_fn, device)

    # Always built first; the tree/diffusion branches replace it below.
    approx_posterior = GaussianApproximatePosterior(mean_fn=mean_fn, std_fn=std_fn, prior=prior, device=device)

    if method not in _SPECIAL_SAMPLING_METHODS:
        approx_posterior.set_sampling_method(method=method, alpha=opts.alpha)
    elif method == "blind":
        # alpha > 0 mixes in blind samples; otherwise the approximation equals the truth.
        approx_posterior.set_sampling_method(method="blind" if opts.alpha > 0.0 else "default")
    elif method == "tree":
        approx_posterior = TreeApproximatePosterior(device=device, sigma_max=1e-2, alpha=opts.alpha)
    elif method == "diffusion":
        ckpt_path = opts.get("diffusion_ckpt", _DEFAULT_DIFFUSION_CKPT)
        diffusion_model = _load_or_train_diffusion(prior, posterior, device, ckpt_path)
        # Sampling with all 20 steps serves as the ground truth.
        posterior = DiffusionApproximatePosterior(diffusion_model, step=20, device=device, num_steps=20)
        # alpha sets how many of the 20 steps the approximation drops (truncated to int).
        approx_posterior = DiffusionApproximatePosterior(
            diffusion_model, step=int(20 - opts.alpha), device=device, num_steps=20
        )
    # "default" and "collapse" keep the Gaussian approximation's default sampling.

    return DataSampler(prior, posterior, approx_posterior, opts.forward_operator, opts.theta_dim, opts.device, opts.seed)
def sample_data(sampler: DataSampler, opts: dnnlib.EasyDict):
    """Draw the training data from ``sampler``.

    Args:
        sampler: Object that provides sample_x, sample_posterior and
            sample_approx_posterior.
        opts: Reads ``n_sim`` (number of x draws) and ``n_posterior_per_x``
            (approximate-posterior draws per x).

    Returns:
        Tuple ``(x_samples, theta_p_samples, theta_q_samples)``.
        ``theta_p_samples`` comes from one ground-truth posterior draw per x,
        with ``squeeze(0)`` applied (presumably dropping the singleton
        per-x sample axis; confirm against DataSampler).
    """
    print("Sampling data...")
    x = sampler.sample_x(n_sim=opts.n_sim)
    single_truth_draw = sampler.sample_posterior(x, n_posterior_per_x=1)
    approx_draws = sampler.sample_approx_posterior(x, n_posterior_per_x=opts.n_posterior_per_x)
    print("Data sampling complete.")
    return x, single_truth_draw.squeeze(0), approx_draws
def build_c2st_components(opts: dnnlib.EasyDict):
    """Build the C2ST classifier together with its optimizer and LR scheduler.

    The classifier is an MLP with input size ``theta_dim + x_dim`` and a single
    output, sharing the file-wide HIDDEN_DIM / NUM_HIDDEN_LAYERS architecture.

    Args:
        opts: Reads theta_dim, x_dim, device, use_skip, use_batchnorm,
            normalize_input, lr, c2st_epochs (falling back to epochs).

    Returns:
        dict with keys 'classifier', 'optimizer', 'scheduler'.
    """
    print("Building C2ST components...")
    device = opts.device
    classifier = MLP(
        input_dim=opts.theta_dim + opts.x_dim,
        output_dim=1,
        hidden_dim=HIDDEN_DIM,
        hidden_layers=NUM_HIDDEN_LAYERS,
        device=device,
        use_skip=opts.use_skip,
        use_batchnorm=opts.use_batchnorm,
        normalize_input=opts.normalize_input,
    )
    optimizer = torch.optim.Adam(classifier.parameters(), lr=opts.lr)
    # C2ST has its own epoch budget (--c2st_epochs), so the cosine schedule is
    # sized to it rather than to the COLT budget (--epochs). Otherwise,
    # assuming one scheduler step per C2ST epoch, the LR would climb back up
    # (c2st_epochs > epochs) or never fully anneal (c2st_epochs < epochs).
    # Option sets without c2st_epochs fall back to the old behaviour.
    t_max = opts.get("c2st_epochs", opts.epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
    print("C2ST components built.")
    return {'classifier': classifier, 'optimizer': optimizer, 'scheduler': scheduler}
def build_colt_components(opts: dnnlib.EasyDict):
    """Assemble networks, optimizers, schedulers and losses for COLT.

    Two parallel setups are created:

    * main: ``ref_net`` built as ``MLP(x_dim, theta_dim, ...)`` plus a learned
      ``distance_net`` built as ``MLP(theta_dim, distance_dim, ...)``;
    * identity baseline: ``ref_net_id`` built as ``MLP(x_dim, theta_dim, ...)``
      paired with an ``Identity`` module in place of the distance net.

    Each setup gets an Adam optimizer over both of its modules' parameters,
    a cosine LR schedule spanning ``opts.epochs``, and
    ``ColtLoss(dist_distance="s")``.

    NOTE(review): Identity receives no dimension argument, so distance_dim
    presumably does not apply to the baseline; confirm in training.networks.

    Returns:
        dict with keys 'ref_net', 'distance_net', 'optimizer', 'scheduler',
        'loss_fn', 'ref_net_id', 'identity_distance_net',
        'optimizer_identity', 'scheduler_identity', 'loss_fn_identity'.
    """
    print("Building COLT components...")
    device = opts.device

    def make_mlp(in_dim, out_dim):
        # Every MLP here shares the same depth/width and option switches.
        return MLP(in_dim, out_dim, HIDDEN_DIM, NUM_HIDDEN_LAYERS, device,
                   opts.use_skip, opts.use_batchnorm, opts.normalize_input)

    def make_optimization(*modules):
        # One Adam over all modules' parameters (in module order) + cosine LR.
        params = [p for module in modules for p in module.parameters()]
        opt = torch.optim.Adam(params, lr=opts.lr)
        return opt, torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=opts.epochs)

    # Creation order is kept fixed so network initialisation consumes the
    # RNG stream in the same sequence on every run.
    ref_net = make_mlp(opts.x_dim, opts.theta_dim)
    distance_net = make_mlp(opts.theta_dim, opts.distance_dim)
    optimizer, scheduler = make_optimization(ref_net, distance_net)
    loss_fn = ColtLoss(dist_distance="s")

    ref_net_id = make_mlp(opts.x_dim, opts.theta_dim)
    identity_distance_net = Identity(device=device)
    optimizer_identity, scheduler_identity = make_optimization(ref_net_id, identity_distance_net)
    loss_fn_identity = ColtLoss(dist_distance="s")

    print("COLT components built.")
    return dict(
        ref_net=ref_net,
        distance_net=distance_net,
        optimizer=optimizer,
        scheduler=scheduler,
        loss_fn=loss_fn,
        ref_net_id=ref_net_id,
        identity_distance_net=identity_distance_net,
        optimizer_identity=optimizer_identity,
        scheduler_identity=scheduler_identity,
        loss_fn_identity=loss_fn_identity,
    )
def train_c2st_model(x_samples, theta_p_samples, theta_q_samples, c2st_components, opts: dnnlib.EasyDict):
    """Run the C2ST training loop on the sampled data.

    ``c2st_training_loop`` takes the samples in the order
    (theta_q, theta_p, x), followed by classifier, optimizer, scheduler and
    the options. The classifier is trained via that loop; nothing is returned.
    """
    print("Starting C2ST training...")
    classifier, optimizer, scheduler = (
        c2st_components[key] for key in ("classifier", "optimizer", "scheduler")
    )
    c2st_training_loop(theta_q_samples, theta_p_samples, x_samples, classifier, optimizer, scheduler, opts)
    print("C2ST training finished.")
def train_colt_models(x_samples, theta_p_samples, theta_q_samples, colt_components, opts: dnnlib.EasyDict, output_path: Path):
    """Train the identity-baseline COLT models, then the main COLT models.

    Only the main run's logs are persisted. They are written to
    ``output_path / "train_logs.json"``, and the loss curve is plotted to
    ``output_path / "loss.png"``. Each log entry must provide "epoch" and
    "loss" keys.
    """
    def run_training(ref_key, dist_key, loss_key, optim_key, sched_key):
        # train_loop takes (theta_q, theta_p, x, ref, dist, loss, optim, sched, opts).
        parts = (colt_components[k] for k in (ref_key, dist_key, loss_key, optim_key, sched_key))
        return train_loop(theta_q_samples, theta_p_samples, x_samples, *parts, opts)

    print("Starting COLT (Identity baseline) training...")
    # NOTE(review): the baseline's returned logs are not saved anywhere.
    run_training("ref_net_id", "identity_distance_net", "loss_fn_identity",
                 "optimizer_identity", "scheduler_identity")
    print("COLT (Identity baseline) training finished.")

    print("Starting main COLT training...")
    logs = run_training("ref_net", "distance_net", "loss_fn", "optimizer", "scheduler")
    print("Main COLT training finished.")

    with open(output_path / "train_logs.json", "w") as log_file:
        json.dump(logs, log_file, indent=2)

    fig, ax = plt.subplots()
    ax.plot([row["epoch"] for row in logs], [row["loss"] for row in logs], marker='o')
    ax.set_title("COLT Training Loss over Epochs")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True)
    fig.savefig(output_path / "loss.png")
    plt.close(fig)
    print(f"Saved training logs and plot to {output_path}")
def evaluate_pipeline(sampler: DataSampler, c2st_components, colt_components, opts: dnnlib.EasyDict, output_path: Path):
    """Evaluate every trained model with ``evaluate_all`` and persist the output.

    Writes three JSON files into ``output_path``:
    ``eval_snapshot.json`` (per-run results), ``eval_result.json`` (summary)
    and ``opts.json`` (the options used).

    Returns:
        The summary object produced by ``evaluate_all``.
    """
    print("Starting evaluation...")
    colt = colt_components
    results, summary = evaluate_all(
        sampler,
        colt['ref_net'], colt['distance_net'], colt['loss_fn'],
        c2st_components['classifier'],
        colt['ref_net_id'], colt['identity_distance_net'], colt['loss_fn_identity'],
        opts,
    )
    print("Evaluation finished.")

    artefacts = (
        ("eval_snapshot.json", results),
        ("eval_result.json", summary),
        ("opts.json", dict(opts)),
    )
    for file_name, payload in artefacts:
        with open(output_path / file_name, "w") as handle:
            json.dump(payload, handle, indent=2)
        print(f"Saved {file_name} to {output_path}")
    return summary
# --- Main Click Command ---
@click.command()
# Output settings
@click.option("--output_dir", type=click.Path(), default="results", help="Directory to save logs and evaluation.")
# Training settings
@click.option("--epochs", type=int, default=1000, help="Number of training epochs for colt.")
@click.option("--c2st_epochs", type=int, default=1000, help="Number of training epochs for c2st.")
@click.option("--seed", type=int, default=42, help="Random seed for reproducibility.")
@click.option("--lr", type=float, default=1e-4, help="Learning rate for optimizers.")
@click.option("--device", default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use for training (e.g., 'cuda:0', 'cpu').")
@click.option("--log_interval", type=int, default=50, help="Interval (in epochs) to log training progress.")
# Model architecture settings
@click.option("--distance_dim", type=int, default=8, help="Output dimension of the distance network.")
@click.option("--use_skip", type=click.IntRange(0, 1), default=1, help="Use skip connections in MLPs (1 for True, 0 for False).")
@click.option("--use_batchnorm", type=click.IntRange(0, 1), default=1, help="Use batch normalization in MLPs (1 for True, 0 for False).")
@click.option("--normalize_input", type=click.IntRange(0, 1), default=1, help="Normalize MLP inputs (1 for True, 0 for False).")
# Early stopping
@click.option("--early_stop", type=click.IntRange(0, 1), default=0, help="Enable early stopping (1 for True, 0 for False).")
@click.option("--tol", type=float, default=0.1, help="Tolerance for early stopping criterion.")
@click.option("--patience", type=int, default=200, help="Patience for early stopping (epochs to wait for improvement).")
# Data sampling
@click.option("--n_sim", type=int, default=1000, help="Number of simulations (x samples) to generate.")
@click.option("--x_dim", type=int, default=50, help="Dimension of the observation space (x).")
@click.option("--theta_dim", type=int, default=10, help="Dimension of the parameter space (theta).")
# Mean and Std functions for Posterior. click.Choice rejects typos at parse time,
# before setup_environment creates a results directory for a run that
# build_data_sampler would then abort with ValueError.
@click.option("--mean_fn", type=click.Choice(["linear", "gaussian_projection"]), default="linear", help="Mean function for the ground truth posterior. Choose from ['linear', 'gaussian_projection'].")
@click.option("--scale", type=float, default=1.0, help="Scale parameter for linear mean function.")
@click.option("--shift", type=float, default=0.0, help="Shift parameter for linear mean function.")
@click.option("--std_fn", type=click.Choice(["one_diag", "scale_diag", "ar1", "x_ar1"]), default="x_ar1", help="Std function for the ground truth posterior. Choose from ['one_diag', 'scale_diag', 'ar1', 'x_ar1'].")
@click.option("--sigma", type=float, default=1.0, help="Base scale parameter for std functions.")
@click.option("--forward_operator", type=int, default=0, help="whether to use forward operator (1 for True, 0 for False).")
# Approximate posterior sampling. Left as free text: names not handled
# explicitly by build_data_sampler are forwarded to set_sampling_method.
@click.option("--approx_sampling_mathod", type=str, default="default",
              help="Method for sampling the approximate posterior. Choose from ['default', 'perturbed_var', 'perturbed_mean', 'distorted_var', 'tail', 'blind', 'mixture', 'collapse', 'tree', 'diffusion'].")
@click.option("--alpha", type=float, default=0.1, help="Parameter alpha used in some approximate sampling methods.")
@click.option("--eps", type=float, default=1e-6, help="Small epsilon value for numerical stability.")
# Evaluation
@click.option("--n_posterior_per_x", type=int, default=100, help="Number of approximate posterior samples per x for training and evaluation.")
@click.option("--n_eval", type=int, default=100, help="Number of x samples to use for final evaluation.")
@click.option("--normalize_theta", type=bool, default=False, help="Normalize theta samples during evaluation.")
@click.option("--p_threshold", type=float, default=0.05, help="P-value threshold for C2ST evaluation.")
@click.option("--c2st_balance", type=bool, default=True, help="Balance positive/negative samples for C2ST evaluation.")
def main(**kwargs):
    """
    Main function to run the COLT and C2ST experiments.
    Sets up environment, builds data sampler and models, trains, and evaluates.
    """
    # Attribute-style access to the parsed CLI options.
    opts = dnnlib.EasyDict(kwargs)
    # Passed as an int flag on the CLI; downstream expects a bool.
    opts.forward_operator = bool(opts.forward_operator)
    print(f"Running with options:\n{json.dumps(dict(opts), indent=2)}")
    # 1. Setup environment (output directory, seeds, boolean conversion)
    output_path = setup_environment(opts)
    # 2. Build data sampler (prior, posterior, approx posterior, sampler)
    sampler = build_data_sampler(opts)
    # 3. Sample data for training and evaluation
    x_samples, theta_p_samples, theta_q_samples = sample_data(sampler, opts)
    # 4. Build C2ST components (classifier, optimizer, scheduler)
    c2st_components = build_c2st_components(opts)
    # 5. Build COLT components (ref_net, distance_net, identity baseline, optimizers, schedulers, losses)
    colt_components = build_colt_components(opts)
    # 6. Train C2ST classifier
    train_c2st_model(x_samples, theta_p_samples, theta_q_samples, c2st_components, opts)
    # 7. Train COLT models (main and identity baseline)
    train_colt_models(x_samples, theta_p_samples, theta_q_samples, colt_components, opts, output_path)
    # 8. Evaluate all models
    summary = evaluate_pipeline(sampler, c2st_components, colt_components, opts, output_path)
    # Final print of key results
    print(f"\nRun Summary (sigma: {opts.sigma}, alpha: {opts.alpha}):")
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()