-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDiffuser.py
More file actions
82 lines (64 loc) · 2.63 KB
/
Diffuser.py
File metadata and controls
82 lines (64 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import yaml
from tqdm import tqdm
from torchvision import transforms
import matplotlib.pyplot as plt
from diffusers import AutoencoderKL
class Diffuser:
    """DDPM-style forward (noising) and reverse (denoising) process.

    Uses a linear beta schedule of ``num_timesteps`` steps between
    ``beta_start`` and ``beta_end``, read from the ``Diffuser`` section of a
    YAML config file. Timesteps ``t`` are 1-based: valid values are
    ``1..num_timesteps``, and index ``t - 1`` of the schedule tensors holds the
    coefficients for timestep ``t``.
    """

    def __init__(self, device, config_path='config.yaml'):
        """Load the noise schedule from ``config_path`` and precompute it on ``device``.

        Args:
            device: torch device on which the schedule tensors are allocated.
            config_path: YAML file containing a ``Diffuser`` mapping with the keys
                ``num_timesteps``, ``beta_start`` and ``beta_end``. Defaults to
                ``'config.yaml'`` relative to the current working directory.
        """
        with open(config_path, 'rb') as f:
            cfg = yaml.safe_load(f)['Diffuser']
        # PyYAML follows YAML 1.1, which loads exponent literals without a dot
        # (e.g. ``1e-4``) as *strings*; coerce explicitly so the schedule builds.
        self.num_timesteps = int(cfg['num_timesteps'])
        self.device = device
        self.beta_start = float(cfg['beta_start'])
        self.beta_end = float(cfg['beta_end'])
        self.betas = torch.linspace(self.beta_start, self.beta_end, self.num_timesteps, device=self.device)
        self.alphas = 1 - self.betas
        # alpha_bars[i] = alphas[0] * ... * alphas[i], i.e. alpha_bar for timestep i + 1.
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        # NOTE: optional latent-space VAE, currently disabled.
        #self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)

    @staticmethod
    def _per_sample(coef, x):
        """Reshape a per-sample coefficient of shape (N,) so it broadcasts against ``x``."""
        return coef.view(-1, *([1] * (x.dim() - 1)))

    def add_noise(self, x_0, t):
        """Sample x_t from x_0 in closed form.

        Args:
            x_0: clean batch whose first dimension is the batch dimension N.
            t: integer tensor of shape (N,) with 1-based timesteps in [1, T].

        Returns:
            ``(x_t, noise)`` where
            ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`` and
            ``noise`` is the standard-normal sample that was used.
        """
        T = self.num_timesteps
        assert (t >= 1).all() and (t <= T).all(), f't must lie in [1, {T}]'
        alpha_bar = self._per_sample(self.alpha_bars[t - 1], x_0)
        noise = torch.randn_like(x_0, device=self.device)
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
        return x_t, noise

    def denoise(self, model, x, t, labels):
        """Run one reverse step, x_t -> x_{t-1}.

        ``model(x, t, labels)`` is evaluated without gradients and its output is
        used as the noise estimate ``eps`` in the DDPM posterior mean. The
        model's train/eval mode is restored afterwards.

        Args:
            model: module called as ``model(x, t, labels)``.
            x: current batch x_t; first dimension is the batch dimension N.
            t: integer tensor of shape (N,) with 1-based timesteps in [1, T].
            labels: conditioning passed through to ``model`` unchanged.

        Returns:
            The sample for the previous timestep. Samples with ``t == 1`` receive
            no added noise.
        """
        T = self.num_timesteps
        assert (t >= 1).all() and (t <= T).all(), f't must lie in [1, {T}]'
        t_idx = t - 1
        alpha = self._per_sample(self.alphas[t_idx], x)
        alpha_bar = self._per_sample(self.alpha_bars[t_idx], x)
        # alpha_bar_{t-1}, with alpha_bar_0 = 1 by definition. A plain
        # ``alpha_bars[t_idx - 1]`` would wrap to the *last* timestep when t == 1.
        alpha_bar_prev = self.alpha_bars[(t_idx - 1).clamp(min=0)]
        alpha_bar_prev[t_idx == 0] = 1.0  # advanced indexing above made a copy
        alpha_bar_prev = self._per_sample(alpha_bar_prev, x)

        # Predict the noise in eval mode, then restore whatever mode the caller had
        # (previously the model was unconditionally switched back to train mode).
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                eps = model(x, t, labels)
        finally:
            model.train(was_training)

        noise = torch.randn_like(x, device=self.device)
        noise[t == 1] = 0  # the final step is deterministic
        mu = (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * eps) / torch.sqrt(alpha)
        std = torch.sqrt((1 - alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))
        return mu + noise * std

    def reverse_to_img(self, x):
        """Convert an image tensor to a PIL image.

        Multiplies by 255, clamps to [0, 255] and casts to uint8 (truncating), so
        it presumably expects values in [0, 1] and a (C, H, W) tensor.
        NOTE(review): confirm against the training-time normalisation.
        """
        x = (x * 255).clamp(0, 255).to(torch.uint8).cpu()
        return transforms.ToPILImage()(x)

    def sample(self, model, x_shape=(20, 1, 28, 28), labels=None, num_labels=10):
        """Generate a batch by running the reverse chain from t = T down to t = 1.

        Args:
            model: noise-prediction module, see ``denoise``.
            x_shape: shape of the generated batch; ``x_shape[0]`` is the batch size.
            labels: conditioning for each sample; when ``None``, labels are drawn
                uniformly from ``range(num_labels)``.
            num_labels: number of classes used when ``labels`` is ``None``.

        Returns:
            ``(x, labels)``: the final samples and the labels used.
        """
        batch_size = x_shape[0]
        # Draw the initial noise directly on the target device (no CPU round-trip).
        x = torch.randn(tuple(x_shape), device=self.device)
        if labels is None:
            labels = torch.randint(0, num_labels, (batch_size,), device=self.device)
        for i in tqdm(range(self.num_timesteps, 0, -1)):
            t = torch.full((batch_size,), i, device=self.device, dtype=torch.long)
            x = self.denoise(model, x, t, labels)
        return x, labels