-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
122 lines (97 loc) · 4.01 KB
/
inference.py
File metadata and controls
122 lines (97 loc) · 4.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import matplotlib.pyplot as plt
from models.DDPM.model import UNET
from models.DiT.model import DiT
import os
import hydra
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
@hydra.main(version_base=None, config_path="./configs", config_name="load")
def main(cfg: DictConfig) -> None:
    """Sample images from a trained diffusion model (UNET or DiT) checkpoint.

    Expects ``cfg.path`` to name a run directory under ``./runs/``. The
    training-time config is reloaded from that directory, the model and
    optimizer state are restored from ``model.pt``, and the DDPM reverse
    process is run from pure noise. Intermediate samples are written every
    100 steps to ``<run>/samples/`` and the final image to
    ``<run>/generated_image.png``.

    Raises:
        ValueError: if ``cfg.path`` is missing/None, or the config names an
            unknown model.
    """
    if "path" in cfg and cfg.path is not None:
        path = f"./runs/{cfg.path}"
        assert os.path.exists(path), f"path, {path}, does not exist"
        # Swap the "load" config for the run's own training-time config so
        # all model/training hyperparameters below match the checkpoint.
        cfg = OmegaConf.load(os.path.join(path, "config.yaml"))
    else:
        # BUG FIX: the original built this exception without `raise`, so
        # execution fell through with `path` undefined (NameError below).
        raise ValueError("load not in config")
    os.makedirs(f"{path}/samples", exist_ok=True)
    assert cfg.model in ["UNET", "DiT"], "model must be UNET or DiT"
    assert not (cfg.model == "DiT" and cfg.dit_model_config is None), (
        "DiT model config must be provided"
    )
    assert not (cfg.model == "UNET" and cfg.unet_model_config is None), (
        "UNET model config must be provided"
    )
    print(OmegaConf.to_yaml(cfg))
    print(f"Saving images at {path}")

    lr = cfg.training.lr
    size = cfg.training.size
    if cfg.training.dataset == "cifar":
        size = (32, 32)  # CIFAR images are fixed at 32x32
    # Diffusion schedule hyperparameters: beta range [B_1, B_T] over T steps.
    B_1 = cfg.training.B_1
    B_T = cfg.training.B_T
    T = cfg.training.T

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "mps" if torch.backends.mps.is_available() else "cpu"

    # Linear noise schedule: beta_t, alpha_t = 1 - beta_t, and the
    # cumulative product alpha_bar_t used by the DDPM posterior.
    beta_array = torch.linspace(B_1, B_T, T, dtype=torch.float32).to(device)
    alpha_array = 1.0 - beta_array
    alpha_bar_array = torch.cumprod(alpha_array, dim=0, dtype=torch.float32)

    def save_image(x, path):
        # Map from model range [-1, 1] to displayable uint8 [0, 255],
        # then CHW -> HWC for matplotlib.
        img = (x + 1.0) * 255.0 / 2.0
        img = img.type(torch.uint8)
        plt.imsave(path, img.permute(1, 2, 0).cpu().detach().numpy(), format="png")

    if cfg.model == "UNET":
        model = UNET(T=T, **cfg.unet_model_config)
    elif cfg.model == "DiT":
        model = DiT(T=T, length=size[0], **cfg.dit_model_config)
    else:
        raise ValueError("model must be UNET or DiT")
    model = model.to(device)

    # Optimizer is restored only so the checkpoint loads completely; it is
    # not stepped during inference.
    optimizer = torch.optim.Adam(model.parameters(), lr)
    checkpoint = torch.load(f"{path}/model.pt", weights_only=True, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    print(
        "loaded model from checkpoint at step",
        checkpoint["step"],
        "with loss",
        checkpoint["loss"],
    )
    print("Files loaded, setting up model ...\n\n")
    print("device", device)
    print("model params", sum(p.numel() for p in model.parameters()))
    print("starting inference ... \n\n")

    with torch.no_grad():
        model.eval()
        batch_size = 1
        # x_T ~ N(0, I): start the reverse process from pure Gaussian noise.
        x_t = torch.randn(batch_size, 3, size[0], size[1]).to(device)
        # total=T so tqdm can render a proper bar over the reversed iterator.
        for t in tqdm(reversed(range(T)), total=T, desc="Generating images"):
            timesteps = t * torch.ones(batch_size).to(device).long()
            eps_pred = model(x_t, timesteps)
            alpha_bar_t = alpha_bar_array[timesteps].view(-1, 1, 1, 1)
            beta_t = beta_array[timesteps].view(-1, 1, 1, 1)
            alpha_t = alpha_array[timesteps].view(-1, 1, 1, 1)
            # DDPM posterior mean (Ho et al. 2020, Eq. 11).
            mean = (1 / alpha_t.sqrt()) * (
                x_t - (beta_t / ((1 - alpha_bar_t).sqrt())) * eps_pred
            )
            # NOTE: clamping the mean to [-1, 1] here was tried for numerical
            # stability but is disabled.
            epsilon = torch.zeros_like(x_t)
            if t > 0:
                # Posterior variance beta_tilde_t; no noise is added at t=0.
                alpha_bar_t_sub1 = alpha_bar_array[timesteps - 1].view(-1, 1, 1, 1)
                beta_tilde = beta_t * (1 - alpha_bar_t_sub1) / (1 - alpha_bar_t)
                z = torch.randn_like(x_t)
                epsilon = (beta_tilde).sqrt() * z
            x_t = mean + epsilon
            if t % 100 == 0:
                save_image(x_t[0], f"{path}/samples/{t}.png")
    img = torch.clamp(x_t, -1, 1)
    save_image(img[0], f"{path}/generated_image.png")
    print(f"saved image at {path}/generated_image.png")
# Script entry point: hydra parses CLI overrides and injects the config.
if __name__ == "__main__":
    main()