modules.py
import math

import torch as th
import torch.nn as nn


class SiLU(nn.Module):
    """Sigmoid Linear Unit activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * th.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


class EmbedFC(nn.Module):
    """Two-layer MLP that embeds a flat conditioning vector into `emb_dim` dimensions."""

    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Flatten everything except the feature dimension before the MLP.
        x = x.view(-1, self.input_dim)
        return self.model(x)


def conv_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D convolution module."""
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """Create a linear module."""
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D average-pooling module."""
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """Update target parameters toward the source parameters with an exponential moving average.

    :param rate: the EMA rate (closer to 1 means slower updates).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def zero_module(module):
    """Zero out the parameters of a module and return it."""
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """Scale the parameters of a module and return it."""
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """Take the mean over all non-batch dimensions."""
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """Make a standard normalization layer: GroupNorm with 32 groups."""
    return GroupNorm32(32, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        # Pad with a zero column when dim is odd.
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


# def cond_embedding(cond, out_channels, use_scale_shift_norm=False):
#     # print(cond)
#     emb_size = []
#     emb_loc = []
#     cond = cond.cpu()
#     size_layer1 = linear(1, 2 * out_channels if use_scale_shift_norm else out_channels)
#     size_layer2 = linear(2, 2 * out_channels if use_scale_shift_norm else out_channels)
#     loc_layer1 = linear(2, 2 * out_channels if use_scale_shift_norm else out_channels)
#     loc_layer2 = linear(4, 2 * out_channels if use_scale_shift_norm else out_channels)
#     for i in range(cond.shape[0]):
#         if int(cond[i][6]) == 1:
#             emb_size.append(size_layer1(cond[i][0].unsqueeze(0)))
#             emb_loc.append(loc_layer1(cond[i][2:4].unsqueeze(0).squeeze()))
#         elif int(cond[i][7]) == 1:
#             emb_size.append(size_layer2(cond[i][0:2].unsqueeze(0).squeeze()))
#             emb_loc.append(loc_layer2(cond[i][2:6].unsqueeze(0).squeeze()))
#         else:
#             emb_size.append(th.zeros(out_channels))
#             emb_loc.append(th.zeros(out_channels))
#     # print(th.stack(emb_size), th.cat(emb_loc, dim=0))
#     return th.stack(emb_size), th.stack(emb_loc)


def checkpoint(func, inputs, params, flag):
    """Evaluate a function without caching intermediate activations.

    This trades compute for memory: when `flag` is True the forward pass is
    recomputed during the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # Stash the function and its inputs; run the forward pass without
        # building a graph so no intermediate activations are kept.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Re-run the forward pass with gradients enabled, then backpropagate
        # through it to recover gradients for inputs and parameters.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
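

# A minimal usage sketch, not part of the original module: a quick smoke test of
# timestep_embedding and checkpoint. The batch size, dimensions, and the toy `run`
# closure below are illustrative assumptions, not values used by this repository.
if __name__ == "__main__":
    # Sinusoidal embeddings: one 128-dim vector per timestep index.
    t = th.arange(0, 4)
    emb = timestep_embedding(t, dim=128)
    assert emb.shape == (4, 128)

    # Gradient checkpointing: the forward pass runs under no_grad and is
    # recomputed during backward, so gradients still reach inputs and parameters.
    layer = nn.Linear(8, 8)
    x = th.randn(2, 8, requires_grad=True)
    run = lambda inp: layer(inp).sum()
    loss = checkpoint(run, (x,), list(layer.parameters()), flag=True)
    loss.backward()
    assert x.grad is not None and layer.weight.grad is not None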