NIPS 2019
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: you defined a validation_step but have no val_dataloader. Skipping validation loop
warnings.warn(*args, **kwargs)
/home/PU-INN/utils/optimizers.py:88: UserWarning: This overload of add_ is deprecated:
add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
add_(Tensor other, *, Number alpha) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.)
exp_avg.mul_(beta1).add_(1 - beta1, grad)
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: training_step returned None if it was on purpose, ignore this warning...
warnings.warn(*args, **kwargs)

Command used: python train_img_cifar10.py --data cifar10 --actnorm True --save experiments/cifar10

# coeff: contraction coefficient for linear layers / desired Lipschitz constant
# n_power_iter: number of iterations for spectral normalization
# numTraceSamples: number of samples used for trace estimation
# numSeriesTerms: number of terms used in power series for matrix log
# powerIterSpectralNorm: number of power iterations used for spectral norm
# weight_decay: coefficient for weight decay
# inj_pad: initial inj padding
# resume: path to latest checkpoint

lib/layers/elemwise.py — LogitTransform
class LogitTransform(nn.Module):
"""
The preprocessing step used in Real NVP:
y = (sigmoid(x) - a) / (1 - 2a)
x = logit(a + (1 - 2a)*y)
"""
resflow: $$ y = \frac{\sigma(x) - a}{1 - 2a},\qquad x = \mathrm{logit}(a + (1-2a)\,y) $$ iresnet (a = 0): $$ y = \frac{1}{1 + e^{-x}},\qquad x = \mathrm{logit}(y) $$
Glow 中提出了名为 Actnorm 的层来取代 BN。不过,所谓 Actnorm 层事实上只不过是 NICE 中的尺度变换层的一般化,也就是缩放平移变换
$$
\boldsymbol{\hat{z}}=\frac{\boldsymbol{z} - \boldsymbol{\mu}}{\boldsymbol{\sigma}}
$$
其中 μ、σ 为按通道的参数(Actnorm 中通常由首个 batch 的统计量初始化,之后作为可学习参数参与训练)。
激活归一化层,即 Actnorm,使用可学习尺度
Actnorm 操作实际就是把输入的各个通道归一化为 0 均值,单位方差的通道数据后,进行线性变换。其操作如下:
$$ y=s\times x+b $$ 线性变换的逆变换为:
$$
x=\frac{y-b}{s}
$$
其雅可比矩阵为
iResBlock(dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(1): Swish()
(2): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(3): Swish()
(4): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
logdetgrad
g, logdetgrad = self._logdetgrad(x) #[8 3 32 32]
def _logdetgrad(self, x):
    """Returns g(x) and logdet|d(x+g(x))/dx|.

    NOTE(review): this is an annotated excerpt tracing one execution path
    (poisson n_dist, training mode, unbiased Neumann / memory-efficient
    branch); the full method contains further branches not shown here.
    Indentation reconstructed — confirm against the original source.
    """
    with torch.enable_grad():
        if self.n_dist == 'poisson':  # true in this trace
            lamb = self.lamb.item()  # 2.0
            # Samplers driving the randomized truncation of the power series.
            sample_fn = lambda m: poisson_sample(lamb, m)  # poisson sample function
            rcdf_fn = lambda k, offset: poisson_1mcdf(lamb, k, offset)  # poisson 1-minus-cdf function
        if self.training:  # true
            if self.n_power_series is None:  # true
                # Unbiased estimation.
                lamb = self.lamb.item()  # 2.0
                n_samples = sample_fn(self.n_samples)  # 1 -> 1
                n_power_series = max(n_samples) + self.n_exact_terms  # 1 + 2 = 3 at minimum
                # Russian-roulette reweighting coefficient for series term k.
                coeff_fn = lambda k: 1 / rcdf_fn(k, self.n_exact_terms) * \
                    sum(n_samples >= k - self.n_exact_terms) / len(n_samples)
        if not self.exact_trace:  # true: use stochastic trace estimator
            ####################################
            # Power series with trace estimator.
            ####################################
            vareps = torch.randn_like(x)  # Hutchinson probe vector
            # Choose the type of estimator.
            if self.training and self.neumann_grad:  # true
                estimator_fn = neumann_logdet_estimator
            # Do backprop-in-forward to save memory.
            if self.training and self.grad_in_forward:  # true
                g, logdetgrad = mem_eff_wrapper(
                    estimator_fn, self.nnet, x, n_power_series, vareps, coeff_fn, self.training)
        if self.training and self.n_power_series is None:  # true
            # Record sampling statistics (count, first/second moments) for logging.
            self.last_n_samples.copy_(torch.tensor(n_samples).to(self.last_n_samples))
            estimator = logdetgrad.detach()
            self.last_firmom.copy_(torch.mean(estimator).to(self.last_firmom))
            self.last_secmom.copy_(torch.mean(estimator**2).to(self.last_secmom))
        return g, logdetgrad.view(-1, 1)
def poisson_sample(lamb, n_samples):
    """Draw ``n_samples`` i.i.d. Poisson variates with rate ``lamb``.

    For a Poisson distribution the mean and variance both equal ``lamb``.
    """
    return np.random.poisson(lam=lamb, size=n_samples)
def poisson_1mcdf(lamb, k, offset):
    """Return P(N >= k - offset) for N ~ Poisson(lamb).

    The first ``offset`` power-series terms are always computed exactly, so
    the survival probability is 1 for k <= offset.

    Args:
        lamb: Poisson rate (mean and variance of the distribution).
        k: power-series term index.
        offset: number of exactly-computed leading terms (n_exact_terms).

    Returns:
        Probability in (0, 1] used to reweight series term ``k``.

    Fixes vs. original: the return line was paste-corrupted (fused with an
    unrelated note, ``... * sumestimator_fn = ...``), and the accumulator
    shadowed the builtin ``sum``.
    """
    if k <= offset:
        return 1.
    k = k - offset
    # P(N >= k) = 1 - e^{-lamb} * sum_{i=0}^{k-1} lamb^i / i!
    total = 1.  # i = 0 term
    for i in range(1, k):
        total += lamb**i / math.factorial(i)
    return 1 - np.exp(-lamb) * total
# (unrelated note fused onto the original return line, kept for reference)
# estimator_fn = neumann_logdet_estimator
g, logdetgrad = mem_eff_wrapper(estimator_fn, self.nnet, x, n_power_series, vareps, coeff_fn, self.training)
# self.nnet = Sequential(InducedNormConv2d() Swish() InducedNormConv2d() Swish() InducedNormConv2d())
# n_power_series = 4
# vareps = randn
# coeff_fn
# self.training = true
parser.add_argument('--', default=5, type=int, help='')
parser.add_argument('--', default=5, type=int, help='')
parser.add_argument('-', default=0., type=float, help='')
def mem_eff_wrapper(estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training):
    """Dispatch to the memory-efficient log-det estimator autograd Function.

    The custom Function needs direct access to gnet's parameters — there is
    no other way to collect the variables along the execution path — so they
    are unpacked and appended to the argument list.

    Raises:
        ValueError: if ``gnet`` is not an ``nn.Module``.
    """
    if not isinstance(gnet, nn.Module):
        raise ValueError('g is required to be an instance of nn.Module.')
    params = list(gnet.parameters())
    return MemoryEfficientLogDetEstimator.apply(
        estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *params
    )
#####################
# Logdet Estimators
#####################
class MemoryEfficientLogDetEstimator(torch.autograd.Function):
    """Autograd Function computing g(x) and the log-det estimate, doing
    backprop-in-forward so intermediate activations can be freed early.

    NOTE(review): only ``forward`` is visible in this excerpt; the paired
    ``backward`` is defined elsewhere and consumes the tensors stashed here.
    Indentation reconstructed — confirm against the original source.
    """

    @staticmethod
    def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
        ctx.training = training  # true in this trace
        with torch.enable_grad():
            # Re-attach x as a leaf so gradients w.r.t. the input can be formed here.
            x = x.detach().requires_grad_(True)
            g = gnet(x)
            ctx.g = g
            ctx.x = x
            logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)
            if training:
                # Compute input/parameter grads of the log-det term during the
                # forward pass (instead of backward) to save memory.
                grad_x, *grad_params = torch.autograd.grad(
                    logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True
                )
                if grad_x is None:
                    grad_x = torch.zeros_like(x)
                ctx.save_for_backward(grad_x, *g_params, *grad_params)
        return safe_detach(g), safe_detach(logdetgrad)
        # (note fused onto the original return line, translated): the
        # initialization below only fixes the shapes of u and v; its conv2d
        # call is used solely for shape inference.
def _initialize_u_v(self):
    """Initialize the left/right singular-vector buffers u and v.

    The shapes are inferred here: for kxk convolutions a forward conv is run
    once purely to deduce the output dimensionality (the conv2d call is used
    only for shape inference). Afterwards several random restarts are tried
    to find the u, v pair giving the largest estimated operator norm.
    """
    with torch.no_grad():
        domain, codomain = self.compute_domain_codomain()
        if self.kernel_size == (1, 1):
            # 1x1 conv: u and v live directly in channel space.
            self.u.resize_(self.out_channels).normal_(0, 1)
            self.u.copy_(normalize_u(self.u, codomain))
            self.v.resize_(self.in_channels).normal_(0, 1)
            self.v.copy_(normalize_v(self.v, domain))
        else:
            # kxk conv: u and v are flattened output/input feature maps.
            c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
            with torch.no_grad():
                num_input_dim = c * h * w
                self.v.resize_(num_input_dim).normal_(0, 1)
                self.v.copy_(normalize_v(self.v, domain))
                # Forward call to infer the output shape.
                u = F.conv2d(
                    self.v.view(1, c, h, w), self.weight, stride=self.stride, padding=self.padding, bias=None
                )
                num_output_dim = u.shape[0] * u.shape[1] * u.shape[2] * u.shape[3]
                # Overwrite u with a random init of the inferred size.
                self.u.resize_(num_output_dim).normal_(0, 1)
                self.u.copy_(normalize_u(self.u, codomain))
        self.initialized.fill_(1)  # flip the flag before compute_weight recurses
        # Try different random seeds to find the best u and v.
        self.compute_weight(True)
        best_scale = self.scale.clone()
        best_u, best_v = self.u.clone(), self.v.clone()
        if not (domain == 2 and codomain == 2):
            for _ in range(10):
                if self.kernel_size == (1, 1):
                    self.u.copy_(normalize_u(self.weight.new_empty(self.out_channels).normal_(0, 1), codomain))
                    self.v.copy_(normalize_v(self.weight.new_empty(self.in_channels).normal_(0, 1), domain))
                else:
                    self.u.copy_(normalize_u(torch.randn(num_output_dim).to(self.weight), codomain))
                    self.v.copy_(normalize_v(torch.randn(num_input_dim).to(self.weight), domain))
                self.compute_weight(True, n_iterations=200)
                if self.scale > best_scale:
                    # BUG FIX: track the running best. Previously best_scale was
                    # never updated, so a later restart whose scale beat only the
                    # *initial* scale could overwrite a better u, v pair.
                    best_scale = self.scale.clone()
                    best_u, best_v = self.u.clone(), self.v.clone()
        self.u.copy_(best_u)
        self.v.copy_(best_v)
def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None):
    """Return the spectrally-normalized weight, initializing u/v on first use.

    Dispatches to the 1x1-specific routine when the kernel is 1x1, otherwise
    to the general kxk implementation.
    """
    if not self.initialized:
        self._initialize_u_v()
    compute = (
        self._compute_weight_1x1
        if self.kernel_size == (1, 1)
        else self._compute_weight_kxk
    )
    return compute(update, n_iterations, atol, rtol)
def forward(self, input):
    """Apply the convolution using the spectrally-normalized weight.

    On the very first call (before u/v initialization) the input's spatial
    size is recorded, since _initialize_u_v needs it to size v.
    """
    if not self.initialized:
        # Record H, W of the incoming feature map, e.g. [32, 32] for CIFAR-10.
        self.spatial_dims.copy_(torch.tensor(input.shape[2:4]).to(self.spatial_dims))
    normalized_weight = self.compute_weight(update=False)
    return F.conv2d(input, normalized_weight, self.bias, self.stride, self.padding, 1, 1)
ResidualFlow(
(init_layer): LogitTransform(0.05)
(transforms): ModuleList(
(0): StackediResBlocks(
(chain): ModuleList(
(0): LogitTransform(0.05)
(1): ActNorm2d(3)
(2): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(1): Swish()
(2): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(3): Swish()
(4): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(3): ActNorm2d(3)
(4): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(5): ActNorm2d(3)
(6): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(7): ActNorm2d(3)
(8): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(9): ActNorm2d(3)
(10): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(11): ActNorm2d(3)
(12): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(13): ActNorm2d(3)
(14): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(15): ActNorm2d(3)
(16): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(17): ActNorm2d(3)
(18): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(19): ActNorm2d(3)
(20): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(21): ActNorm2d(3)
(22): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(23): ActNorm2d(3)
(24): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(25): ActNorm2d(3)
(26): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(27): ActNorm2d(3)
(28): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(29): ActNorm2d(3)
(30): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(31): ActNorm2d(3)
(32): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(33): ActNorm2d(3)
(34): SqueezeLayer()
)
)
(1): StackediResBlocks(
(chain): ModuleList(
(0): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(1): ActNorm2d(12)
(2): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(3): ActNorm2d(12)
(4): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(5): ActNorm2d(12)
(6): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(7): ActNorm2d(12)
(8): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(9): ActNorm2d(12)
(10): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(11): ActNorm2d(12)
(12): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(13): ActNorm2d(12)
(14): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(15): ActNorm2d(12)
(16): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(17): ActNorm2d(12)
(18): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(19): ActNorm2d(12)
(20): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(21): ActNorm2d(12)
(22): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(23): ActNorm2d(12)
(24): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(25): ActNorm2d(12)
(26): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(27): ActNorm2d(12)
(28): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(29): ActNorm2d(12)
(30): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(12, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(31): ActNorm2d(12)
(32): SqueezeLayer()
)
)
(2): StackediResBlocks(
(chain): ModuleList(
(0): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(1): ActNorm2d(48)
(2): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(3): ActNorm2d(48)
(4): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(5): ActNorm2d(48)
(6): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(7): ActNorm2d(48)
(8): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(9): ActNorm2d(48)
(10): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(11): ActNorm2d(48)
(12): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(13): ActNorm2d(48)
(14): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(15): ActNorm2d(48)
(16): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(17): ActNorm2d(48)
(18): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(19): ActNorm2d(48)
(20): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(21): ActNorm2d(48)
(22): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(23): ActNorm2d(48)
(24): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(25): ActNorm2d(48)
(26): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(27): ActNorm2d(48)
(28): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(29): ActNorm2d(48)
(30): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): Sequential(
(0): Swish()
(1): InducedNormConv2d(48, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormConv2d(512, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
(31): ActNorm2d(48)
(32): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): FCNet(
(nnet): Sequential(
(0): Swish()
(1): InducedNormLinear(in_features=3072, out_features=128, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormLinear(in_features=128, out_features=128, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormLinear(in_features=128, out_features=3072, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
)
(33): FCWrapper(
(fc_module): ActNorm1d(3072)
)
(34): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): FCNet(
(nnet): Sequential(
(0): Swish()
(1): InducedNormLinear(in_features=3072, out_features=128, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormLinear(in_features=128, out_features=128, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormLinear(in_features=128, out_features=3072, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
)
(35): FCWrapper(
(fc_module): ActNorm1d(3072)
)
(36): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): FCNet(
(nnet): Sequential(
(0): Swish()
(1): InducedNormLinear(in_features=3072, out_features=128, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormLinear(in_features=128, out_features=128, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormLinear(in_features=128, out_features=3072, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
)
(37): FCWrapper(
(fc_module): ActNorm1d(3072)
)
(38): iResBlock(
dist=poisson, n_samples=1, n_power_series=None, neumann_grad=True, exact_trace=False, brute_force=False
(nnet): FCNet(
(nnet): Sequential(
(0): Swish()
(1): InducedNormLinear(in_features=3072, out_features=128, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(2): Swish()
(3): InducedNormLinear(in_features=128, out_features=128, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
(4): Swish()
(5): InducedNormLinear(in_features=128, out_features=3072, bias=True, coeff=0.98, domain=2.00, codomain=2.00, n_iters=None, atol=0.001, rtol=0.001, learnable_ord=False)
)
)
)
(39): FCWrapper(
(fc_module): ActNorm1d(3072)
)
)
)
)
)基于流的生成模型通过可逆变换对概率分布进行参数化,并且可以通过最大似然来进行训练。可逆残差网络提供了灵活的转换系列,其中仅需要 Lipschitz 条件而不是严格的体系结构约束即可实施可逆性。但是,先前的工作依靠有偏的对数密度估计值训练可逆残差网络以进行密度估计,其偏差随着网络表达的增加而增加。我们使用 “俄罗斯轮盘赌” 估算器给出对数密度的易于处理的无偏估计,并通过使用替代的无穷级数梯度来减少训练期间所需的内存。此外,我们通过提议使用避免导数饱和的激活函数以及将 Lipschitz 条件推广到诱导的混合范数来改善可逆残差块。产生的方法称为“残差流”,在基于流的模型中实现了最新的密度估计性能,并且优于在联合生成和判别模型中使用 coupling blocks 的网络
最大似然是核心机器学习范式,它将学习视为分布对齐问题。然而,我们通常不清楚应该使用什么分布族(family of distributions)来拟合高维连续数据。在这方面,变量变化定理(the change of variables theorem)提供了一种构建灵活分布的吸引人的方法,该分布允许易于处理的精确采样和对其密度的有效评估。 这类模型通常称为可逆或基于流的生成模型(Deco,1995;Rezende,Variational inference with normalizing flows,2015)。
以可逆性作为其核心设计原则,基于流的模型(也称为标准化流)已证明能够生成逼真的图像(Kingma,Glow,2018 ),并且可以实现与竞争状态相当的密度估计性能 - 最先进的方法(Ho,Flow++,2019 )。 在应用中,它们已被应用于研究对抗性鲁棒性(Jacobsen 等人,2019 年),并用于使用加权最大似然目标训练具有生成和分类能力的混合模型(Nalisnick 等人,2019 年)。
现有的基于流的模型(Rezende,Variational inference with normalizing flows,2015;Kingma ,Improved variational inference with inverse autoregressive flow,2016;Dinh,NICE,2014;Chen,NODE,2018)利用具有稀疏或结构化雅可比矩阵的受限变换(Fig. 1)。这些允许在模型下有效计算对数概率,但以架构工程为代价。扩展到高维数据的转换依赖于专门的架构,例如耦合块(Dinh,2014,2017)或求解常微分方程(Grathwohl,Ffjord,2019)。这种方法具有很强的归纳偏差,可能会阻碍它们在其他任务中的应用,例如适用于生成和判别任务的学习表示。
Figure 1: Pathways to designing scalable normalizing flows and their enforced Jacobian structure. Residual Flows fall under unbiased estimation with free-form Jacobian.
设计可扩展标准化流及其强制雅可比结构的途径。残差流属于具有自由形式雅可比行列式的无偏估计。
Behrmann,2019 表明,残差网络(He,2016)可以通过简单地强制执行 Lipschitz 约束而变得可逆,从而允许使用非常成功的判别深度网络架构进行基于流的无监督建模。不幸的是,密度评估需要计算无穷级数。Behrmann 使用的固定截断估计量的选择导致大量偏差与网络的表现力紧密耦合,并且不能说是执行最大似然,因为偏差是在目标和梯度中引入的。
在这项工作中,我们引入了残差流,这是一种基于流的生成模型,它产生对数密度的无偏估计,并通过对数密度计算具有内存高效的反向传播。这使我们能够使用富有表现力的架构并通过最大似然进行训练。此外,我们提出并尝试使用激活函数,以避免 Lipschitz 约束神经网络的导数饱和和诱导混合范数。
为了使用随机梯度下降执行最大似然,有一个梯度的无偏估计器就足够了
$$
\begin{equation}
\nabla_{\theta} D_{\mathrm{KL}}\left(p_{\text{data}} \,\|\, p_{\theta}\right)=-\nabla_{\theta} \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log p_{\theta}(x)\right] =-\mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\nabla_{\theta}\log p_{\theta}(x)\right]
\end{equation}\tag{1}
$$
其中
通过可逆变换
残差网络由简单的变换 $ y=f(x)=x+g(x) $ 组成。Behrmann (2019) 指出,如果
将 i-ResNets 应用于变量变化公式 Eq. 2,等式如下
$$
\begin{equation}
\log p(x)=\log p(f(x))+\operatorname{tr}\left(\sum_{k=1}^{\infty} \frac{(-1)^{k+1}}{k}\left[J_{g}(x)\right]^{k}\right)
\end{equation}\tag{3}
$$
其中 $ J_{g}(x)=\frac{d g(x)}{d x} $ 。此外,Skilling-Hutchinson 估计量(Skilling,1989;Hutchinson,1990)用于估计幂级数中的迹。Behrmann (2019) 使用固定截断来近似 Eq. 3 中的无限级数。 然而,这种朴素的方法有一个偏差,它随着
因此,固定截断估计器需要在偏差和表现力之间谨慎平衡,并且不能扩展到更高维的数据。 在不解耦目标和估计偏差的情况下,i-ResNet 最终会针对偏差进行优化,而没有改进实际的最大似然目标(见 Fig. 2)。
Fig. 2 : i-ResNets suffer from substantial bias when using expressive networks, whereas Residual Flows principledly perform maximum likelihood with unbiased stochastic gradients.
i-ResNets 在使用表达网络时会受到很大的偏差,而残差流原则上使用无偏随机梯度执行最大似然。
最大似然估计的无偏对数密度估计
由于幂级数,Eq. 3 中精确对数密度函数 $ \log p_{\theta}(\cdot) $ 的评估需要无限时间。相反,我们依靠随机化来推导出一个无偏估计量,该估计量可以基于现有概念(Kahn,1955)在有限时间内(概率为 1)计算。
让
有趣的是,虽然朴素计算总是使用无限计算,但这个无偏估计量在有限时间内被评估的概率为
$$
\begin{equation}
\sum_{k=1}^{\infty} \Delta_{k}=\mathbb{E}_{n \sim p(N)}\left[\sum_{k=1}^{n} \frac{\Delta_{k}}{\mathbb{P}(N \geq k)}\right]
\end{equation}\tag{5}
$$
我们注意到上面的解释只是作为一个直观的指导,而不是一个正式的推导。处理无限量的特殊性要求我们必须对
令
我们使用了 Skilling-Hutchinson 迹估计器(Skilling,1989;Hutchinson,1990)来估计矩阵
请注意,由于
Theorem 1 构成了残差流的核心,因为我们现在可以通过反向传播 Eq. 6 来执行最大似然训练以获得无偏梯度。这使我们能够训练更有表现力的网络,其中有偏差的估计器会失败(Fig. 2)。 我们为无偏估计器付出的代价是可变计算和内存,因为对数密度的每个样本都使用幂级数中的随机数项。
内存高效的反向传播
内存可能是一种稀缺资源,并且由于来自无偏估计器的大量样本而耗尽内存可能会意外停止训练。 为此,我们提出了两种方法来减少训练期间的内存消耗。
为了了解朴素的反向传播有多大问题,梯度 w.r.t. 通过幂级数(6)直接微分,参数
$$
\log p(x)=\log p(f(x)) + \mathbb{E}_{n,v}\left[
\sum_{k=1}^{n} \frac{(-1)^{k+1}}{k} \frac{\partial\, v^{T}\left(J_{g}(x, \theta)^{k}\right) v}{\partial \theta} \right]\tag{7}
$$
不幸的是,这个估计器需要将每一项都存储在内存中,因为 $ \partial / \partial \theta $ 需要应用于每一项。 总内存成本为 $ \mathcal{O}(n \cdot m) $,其中
诺依曼梯度级数
相反,我们可以将梯度具体表示为从 Neumann 级数导出的幂级数(参见 Appendix C)。 应用俄罗斯轮盘赌和迹估计器,我们得到以下定理。
让
$$
\begin{equation}
\frac{\partial}{\partial \theta} \log \operatorname{det}\left(I+J_{g}(x, \theta)\right)= \mathbb{E}_{n,v}\left[ \left(\sum_{k=0}^{n} \frac{(-1)^{k}}{\mathbb{P}(N \geq k)} v^{T} J_{g}(x, \theta)^{k}\right) \frac{\partial\left(J_{g}(x, \theta)\right)}{\partial \theta} v \right]
\end{equation}\tag{8}
$$
其中
由于不需要对 Eq. 8 中的幂级数进行微分,因此使用它可以将内存需求减少
Backward-in-forward: early computation of gradients 梯度的早期计算
我们可以通过在前向评估期间部分执行反向传播来进一步减少内存。 利用
$$
\begin{equation}
\frac{\partial \mathcal{L}}{\partial \theta}=\underbrace{\frac{\partial \mathcal{L}}{\partial \log \operatorname{det}\left(I+J_{g}(x, \theta)\right)}}_{\text{scalar}} \underbrace{\frac{\partial \log \operatorname{det}\left(I+J_{g}(x, \theta)\right)}{\partial \theta}}_{\text{vector}}
\end{equation}
$$
对于每个残差块,我们计算 $ \partial \log \operatorname{det}\left(I+J_{g}(x, \theta)\right) / \partial \theta $ 连同前向传递,释放计算图的内存,然后简单地乘以 $ \partial \mathcal{L} / \partial \log \operatorname{det}\left(I+J_{g}(x, \theta)\right) $ 稍后在主要反向传播期间。这将内存减少了另一个因子
请注意,虽然这两个技巧从通过
Fig. 3 在相应的幂级数中计算 n = 10 项时,每个小批量 64 个样本的内存使用量 (GB)。CIFAR10-small 在任何残差块之前使用立即下采样。
使用 LipSwish 激活函数避免导数饱和
由于对数密度取决于通过雅可比
因此,我们希望我们的激活函数 $ \phi(z) $ 具有两个属性:
- 对于所有
$z$ ,一阶导数必须限定为 $ \left|\phi^{\prime}(z)\right| \leq 1 $ 。 - 当 $ \left|\phi^{\prime}(z)\right| $ 接近一时,二阶导数不应渐近消失。
虽然许多激活函数满足条件 1,但大多数不满足条件 2。我们认为 ELU 和 softplus 激活由于导数饱和而不是最优的。Fig. 4 显示,当 softplus 和 ELU 在单位 Lipschitz 的区域处饱和时,二阶导数变为零,这会导致训练期间梯度消失。
我们发现满足条件 2 的良好激活函数是平滑 smooth 且 non-monotonic 的函数,例如 Swish (Ramachandran, 2017)。 但是,默认情况下 Swish 不满足条件 1,因为 $ \max_{z}\left|\frac{d}{d z} \operatorname{Swish}(z)\right| \lesssim 1.1 $
Fig. 4 常见的平滑 Lipschitz 激活函数 $ \phi $ 通常在 $ \phi^{\prime} $最大时消失 $ \phi^{\prime \prime} $ 。$\operatorname{LipSwish}$ 在 $ \phi^{\prime} $ 接近 1 的区域具有非消失的 $ \phi^{\prime \prime} $ 。
我们使用与 Behrmann 类似的架构。除了在图像像素级没有立即可逆下采样(Dinh,2017)。由于每一层都有更多的空间维度,因此删除它会显著增加所需的内存量(如 Fig. 3 所示),但会提高整体性能。我们还将每个权重矩阵的 Lipschitz 常数的界限增加到 0.98 ,而 Behrmann 使用 0.90 来减少有偏估计量的误差。 更详细的架构描述在 Appendix E 中。
与使用多个 GPU、 large batch sizes 和几百个 epoch 的先前工作不同,残差流模型使用 64 的标准批量大小进行训练,并在 MNIST 和 CIFAR-10 的大约 300-350 个 epoch 中收敛。 尽管我们在实验中使用了 4 个 GPU 来加速训练,但大多数网络设置都可以安装在单个 GPU 上(参见 Fig. 3)。 在 CelebA-HQ 上,Glow 必须使用每个 GPU 的 1 批量大小和 40 个 GPU 的预算,而我们使用每个 GPU 的 3 批量大小和 4 个 GPU 的预算来训练我们的模型,这是由于模型较小和内存高效的反向传播 .
表 1 报告了标准基准数据集 MNIST、CIFAR-10、下采样 ImageNet 和 CelebA-HQ 上的每维位数($-\log_2 p(x)/d$,其中 $x \in \mathbb{R}^d$)。 我们在所有数据集上实现了与最先进的基于流的模型的竞争性能。 为了评估,我们计算了幂级数 (3) 的 20 项,并使用无偏估计量 (6) 来估计剩余的项。 这将每维测试位的无偏估计的标准偏差降低到可以忽略不计的水平。
此外,可以将残余流的 Lipschitz 条件推广到任意 p 范数甚至混合矩阵范数。 通过与模型共同学习范数阶数,与频谱归一化相比,我们在 CIFAR-10 上实现了 0.003 位/dim 的小增益。 此外,我们表明其他规范如 p = ∞ 产生了更适合低维数据的约束。
有关如何概括 Lipschitz 条件以及对 2D 问题和图像数据的不同范数约束的探索,请参见附录 D。
我们首先制定一个引理,它给出了随机截断级数在相当一般的设定下是无偏估计量的条件。 之后,我们研究我们的特定估计量并证明满足引理的假设。 请注意,类似的条件在以前的工作中已有陈述,例如 McLeish (2011) 以及 Rhee 和 Glynn (2012)。 然而,我们使用 Bouchard-Côté (2018) 的条件,它只要求 p(N) 有足够的支撑。 为了使推导自成一体,我们按以下方式重新表述 Bouchard-Côté (2018) 中的条件:
引理 3(无偏随机截断级数)。令 $Y_k$ 为实值随机变量,且对某个 $a \in \mathbb{R}$ 有 $\lim_{k \to \infty} \mathbb{E}[Y_k] = a$。此外,令 $\Delta_0 = Y_0$,且对 $k \geq 1$ 有 $\Delta_k = Y_k - Y_{k-1}$。假设
$$
\mathbb{E}\left[\sum_{k=0}^{\infty}\left|\Delta_k\right|\right]<\infty
$$
并令 $N$ 为支撑集覆盖正整数的随机变量,$n \sim p(N)$。那么对于
$$
Z=\sum_{k=0}^{n} \frac{\Delta_k}{\mathbb{P}(N \geq k)}
$$
有 $a=\lim_{k \to \infty} \mathbb{E}[Y_k]=\mathbb{E}_{n \sim p(N)}[Z]$。
证明。首先,记
$$
Z_M=\sum_{k=0}^{M} \frac{\mathbb{1}[N \geq k]\, \Delta_k}{\mathbb{P}(N \geq k)}, \qquad B_M=\sum_{k=0}^{M} \frac{\mathbb{1}[N \geq k]\left|\Delta_k\right|}{\mathbb{P}(N \geq k)}
$$
由三角不等式有 $|Z_M| \leq B_M$。由于 $B_M$ 非递减,单调收敛定理允许交换期望与极限:$\mathbb{E}[B]=\mathbb{E}[\lim_{M \to \infty} B_M]=\lim_{M \to \infty} \mathbb{E}[B_M]$。此外,
$$
\mathbb{E}[B]=\lim_{M \to \infty} \sum_{k=0}^{M} \mathbb{E}\left[\frac{\mathbb{1}[N \geq k]\left|\Delta_k\right|}{\mathbb{P}(N \geq k)}\right]=\lim_{M \to \infty} \sum_{k=0}^{M} \frac{\mathbb{P}(N \geq k)\, \mathbb{E}\left|\Delta_k\right|}{\mathbb{P}(N \geq k)}=\mathbb{E}\left[\sum_{k=0}^{\infty}\left|\Delta_k\right|\right]<\infty
$$
其中最后一步使用了假设。据此,可用控制收敛定理交换 $Z_M$ 的极限与期望。由与上面类似的推导,
$$
\mathbb{E}[Z]=\lim_{M \to \infty} \mathbb{E}[Z_M]=\lim_{M \to \infty} \mathbb{E}\left[\sum_{k=0}^{M} \Delta_k\right]=\lim_{M \to \infty} \mathbb{E}[Y_M]=a
$$
其中我们使用了 $Y_M$ 的定义和 $\lim_{k \to \infty} \mathbb{E}[Y_k]=a$。
我们使用标准设置将数据通过 “unsquashing” 层(我们使用 realnvp中的 logit 变换),然后交替多个块和挤压层(realnvp)。 我们在每个残差块之前和之后使用激活归一化(Glow)。 每个残差连接由下组成
LipSwish → 3×3 Conv → LipSwish → 1×1 Conv → LipSwish → 3×3 Conv
隐藏维度为 512。以下是每个数据集的架构。
MNIST.
CIFAR-10
对于 MNIST 和 CIFAR-10 上的密度建模,我们在网络末端添加了 4 个全连接残差块,中间隐藏维度为 128。这些残差块未用于混合建模实验或其他数据集。 对于图像大小高于 32×32 的数据集,我们在除第一个挤压操作之外的每一次挤压操作后分解出一半的变量。
对于 CIFAR-10 上的混合建模,我们通过在训练数据中减去均值并除以标准差的标准预处理将 logit 变换替换为归一化。 用于混合建模的 MNIST 和 SVHN 架构与用于密度建模的架构相同。
为了在混合建模实验中使用分类器增强基于流的模型,我们在每个挤压层之后和网络末端添加了一个额外的分支。 每个分支由下组成
3×3 Conv → ActNorm → ReLU → AdaptiveAveragePooling((1,1))
其中自适应平均池化在所有空间维度上取平均值并产生一个维度为 256 的向量。每个尺度的输出连接在一起并馈送 进入线性 softmax 分类器。
自适应幂迭代次数。 我们对卷积使用了频谱归一化(Gouk,2018)。 为了考虑训练期间的可变权重更新,我们实施了频谱归一化的自适应版本,我们根据需要执行尽可能多的迭代,直到估计的频谱范数的相对变化足够小。 由于这作为一种摊销,在权重更新较小时减少了迭代次数,因此不会导致比固定次数的幂迭代更高的时间成本,同时,作为 Lipschitz 的更可靠保证 有界。
优化。 对于随机梯度下降,我们使用 Adam(Kingma,2015),在自适应学习率计算之外应用了 0.001 的学习率和 0.0005 的权重衰减(Loshchilov,2019;Zhang ,2019)。 我们使用 Polyak 平均(Polyak,1992)进行评估,衰减为 0.999。
预处理。 对于密度估计实验,我们对 CIFAR10、CelebA-HQ 64 和 CelebA-HQ 256 使用随机水平翻转。对于 CelebA-HQ 64 和 256,我们将样本预处理为 5 位。 对于混合建模和分类实验,我们在 SVHN 和 CIFAR-10 的 4 个像素的反射填充后使用随机裁剪;CIFAR-10 还包括随机水平翻转。



