import torch
import torch.nn as nn
from torch.nn import Parameter

from utils import reduce_tensor
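
# Note: `reduce_tensor` comes from a project-level `utils` module and is only
# used when sync=True. It is assumed here to average a tensor across
# distributed workers; a minimal sketch under that assumption would be:
#
#     def reduce_tensor(tensor):
#         rt = tensor.clone()
#         torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
#         return rt / torch.distributed.get_world_size()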

__all__ = ['MovingBatchNorm1d']


class MovingBatchNormNd(nn.Module):
    """Invertible batch normalization with running statistics.

    Normalization uses running (moving-average) statistics, optionally blended
    with batch statistics via `bn_lag`, so the map stays a bijection; an
    optional log-density term `logpx` is updated with the log-determinant of
    the Jacobian, as required for normalizing-flow layers.
    """

    def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True, sync=False):
        super(MovingBatchNormNd, self).__init__()
        self.num_features = num_features
        self.sync = sync
        self.affine = affine
        self.eps = eps
        self.decay = decay
        self.bn_lag = bn_lag
        self.register_buffer('step', torch.zeros(1))
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    @property
    def shape(self):
        raise NotImplementedError

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.zero_()
            self.bias.data.zero_()

    def forward(self, x, c=None, logpx=None, reverse=False):
        if reverse:
            return self._reverse(x, logpx)
        else:
            return self._forward(x, logpx)

    def _forward(self, x, logpx=None):
        num_channels = x.size(-1)
        used_mean = self.running_mean.clone().detach()
        used_var = self.running_var.clone().detach()

        if self.training:
            # compute batch statistics
            x_t = x.transpose(0, 1).reshape(num_channels, -1)
            batch_mean = torch.mean(x_t, dim=1)

            if self.sync:
                batch_ex2 = torch.mean(x_t**2, dim=1)
                batch_mean = reduce_tensor(batch_mean)
                batch_ex2 = reduce_tensor(batch_ex2)
                batch_var = batch_ex2 - batch_mean**2
            else:
                batch_var = torch.var(x_t, dim=1)

            # moving average
            if self.bn_lag > 0:
                used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
                used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
                used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
                used_var /= (1. - self.bn_lag**(self.step[0] + 1))

            # update running estimates
            self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
            self.running_var -= self.decay * (self.running_var - batch_var.data)
            self.step += 1

        # perform normalization
        used_mean = used_mean.view(*self.shape).expand_as(x)
        used_var = used_var.view(*self.shape).expand_as(x)

        y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(x)
            bias = self.bias.view(*self.shape).expand_as(x)
            y = y * torch.exp(weight) + bias

        if logpx is None:
            return y
        else:
            return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)

    def _reverse(self, y, logpy=None):
        used_mean = self.running_mean
        used_var = self.running_var

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(y)
            bias = self.bias.view(*self.shape).expand_as(y)
            y = (y - bias) * torch.exp(-weight)

        used_mean = used_mean.view(*self.shape).expand_as(y)
        used_var = used_var.view(*self.shape).expand_as(y)
        x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean

        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True)

    def _logdetgrad(self, x, used_var):
        logdetgrad = -0.5 * torch.log(used_var + self.eps)
        if self.affine:
            weight = self.weight.view(*self.shape).expand(*x.size())
            logdetgrad += weight
        return logdetgrad
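
    # Note: the forward map is the elementwise affine bijection
    #     y = exp(weight) * (x - used_mean) / sqrt(used_var + eps) + bias,
    # so its per-dimension log-derivative is weight - 0.5 * log(used_var + eps),
    # which is exactly what `_logdetgrad` returns (the `weight` term is dropped
    # when affine=False).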

    def __repr__(self):
        return (
            '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
            ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
        )

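
# Note: `stable_var` computes a biased (population) variance over the second
# dimension. Factoring out the largest squared deviation before averaging
# keeps the intermediate ratios in [0, 1], which helps avoid overflow for
# large-magnitude inputs; constant rows (max_sqr == 0) would otherwise yield
# NaN and are mapped to zero instead.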
def stable_var(x, mean=None, dim=1):
    if mean is None:
        mean = x.mean(dim, keepdim=True)
    mean = mean.view(-1, 1)
    res = torch.pow(x - mean, 2)
    max_sqr = torch.max(res, dim, keepdim=True)[0]
    var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr
    var = var.view(-1)
    # change nan to zero
    var[var != var] = 0
    return var


class MovingBatchNorm1d(MovingBatchNormNd):
    @property
    def shape(self):
        return [1, -1]

    def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
        ret = super(MovingBatchNorm1d, self).forward(
            x, context, logpx=logpx, reverse=reverse)
        return ret
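

# A minimal usage sketch (added for illustration; not part of the original
# module). It only requires the module's own imports to resolve: after a few
# training-mode passes to populate the running statistics, the eval-mode
# transform should be exactly invertible and the log-density corrections of
# the forward and reverse passes should cancel.
if __name__ == '__main__':
    torch.manual_seed(0)
    bn = MovingBatchNorm1d(4)

    # Populate the running mean/variance with a few training-mode passes.
    bn.train()
    for _ in range(10):
        bn(torch.randn(32, 4) * 2.0 + 1.0)

    # In eval mode the layer is a fixed elementwise affine bijection.
    bn.eval()
    x = torch.randn(8, 4)
    logpx = torch.zeros(8, 1)
    y, logpy = bn(x, logpx=logpx)
    x_rec, logpx_rec = bn(y, logpx=logpy, reverse=True)
    print('max |x - reverse(forward(x))|:', (x - x_rec).abs().max().item())
    print('max |logpx - recovered logpx|:', (logpx - logpx_rec).abs().max().item())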