-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquantum_pulse_train.py
More file actions
executable file
·73 lines (49 loc) · 2.15 KB
/
quantum_pulse_train.py
File metadata and controls
executable file
·73 lines (49 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
from utils.util import MemoryCleaner
from dataset.quantum_dataset import quantum_pulse_dataset
from models.unet import UNet
from scheduler.scheduler_ddim import DDIMScheduler
from pipeline.diffusion_pipeline import DiffusionPipeline
import functools
import torch
import torch.nn as nn
import argparse
def main(store_dir, file, batch_size, epochs, lr=5e-5):
    """Train a diffusion model on a quantum-pulse dataset and store the pipeline.

    Args:
        store_dir: Directory where the trained pipeline config/weights are saved.
        file: Path to the training data file consumed by the dataset loader.
        batch_size: Mini-batch size used by the data loaders.
        epochs: Number of training epochs.
        lr: Peak learning rate for Adam / the OneCycleLR schedule
            (default 5e-5, the value previously hard-coded in the body).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    qds = quantum_pulse_dataset(file, device, batch_size)
    print(qds.clip_dataset.x.shape)
    num_chs = qds.get_num_chs()
    print(f"num_chs: {num_chs}")
    data_loaders = qds.get_data_loaders()
    text_encoder = qds.get_text_encoder()
    # NOTE(review): model widths / embedding size / group count are fixed here —
    # confirm they match the stored config if checkpoints are reloaded elsewhere.
    model = UNet(model_features=[32, 32, 64], num_chs=num_chs, t_emb_size=256, group_num=32)
    model = model.to(device)
    scheduler = DDIMScheduler(device=device)
    scheduler.set_timesteps(100)  # number of DDIM denoising steps used at sampling time
    pipeline = DiffusionPipeline(
        scheduler=scheduler, model=model, text_encoder=text_encoder, device=device
    )
    pipeline.guidance_sample_mode = "rescaled"
    pipeline.compile(torch.optim.Adam, nn.MSELoss)
    # OneCycleLR needs the total step count up front: epochs * batches-per-epoch.
    sched = functools.partial(
        torch.optim.lr_scheduler.OneCycleLR,
        max_lr=lr,
        total_steps=epochs * len(data_loaders.train),
    )
    pipeline.fit(epochs, data_loaders, lr=lr, lr_sched=sched)
    print(f"store_dir: {store_dir}")
    pipeline.store_pipeline(config_path=store_dir, save_path=store_dir)
    # Drop loader references and purge cached memory before returning, since
    # the data loaders can pin significant device memory.
    del data_loaders
    MemoryCleaner.purge_mem()
if __name__ == "__main__":
    # CLI entry point: collect training options and hand them to main().
    cli = argparse.ArgumentParser(
        description="Process a file with a given store directory."
    )
    cli.add_argument(
        "--store_dir",
        type=str,
        required=True,
        help="Directory to store diffusion model files",
    )
    cli.add_argument(
        "--file",
        type=str,
        required=True,
        help="Training file to process",
    )
    cli.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="Batch size for training",
    )
    cli.add_argument(
        "--epoch",
        type=int,
        default=300,
        help="Number of epochs for training",
    )
    opts = cli.parse_args()
    main(opts.store_dir, opts.file, opts.batch_size, opts.epoch)