PointFlow/models/normalization.py

import torch
import torch.nn as nn
from torch.nn import Parameter
from utils import reduce_tensor

__all__ = ['MovingBatchNorm1d']
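

# Invertible batch normalization with moving-average statistics, used as a
# layer in PointFlow's continuous normalizing flows. Each layer maps x -> y
# and updates the log-density via the change-of-variables formula,
#     log p(y) = log p(x) - sum_i log|dy_i/dx_i|.
# This appears to be adapted from the FFJORD codebase; `reduce_tensor` is only
# needed for synchronized statistics (sync=True) in distributed training.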
class MovingBatchNormNd(nn.Module):
    def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0.,
                 affine=True, sync=False):
        super(MovingBatchNormNd, self).__init__()
        self.num_features = num_features
        self.sync = sync
        self.affine = affine
        self.eps = eps
        self.decay = decay
        self.bn_lag = bn_lag
        self.register_buffer('step', torch.zeros(1))
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    @property
    def shape(self):
        raise NotImplementedError

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.zero_()
            self.bias.data.zero_()

    def forward(self, x, c=None, logpx=None, reverse=False):
        if reverse:
            return self._reverse(x, logpx)
        else:
            return self._forward(x, logpx)
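
    # Forward direction of the flow: standardize with (possibly lagged) batch
    # statistics, then apply the optional elementwise affine map,
    #     y = (x - mean) / sqrt(var + eps) * exp(weight) + bias,
    # and subtract the log-det-Jacobian of the map from `logpx`
    # (change-of-variables formula); exp(-0.5 * log(v)) below is 1 / sqrt(v).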
    def _forward(self, x, logpx=None):
        num_channels = x.size(-1)
        used_mean = self.running_mean.clone().detach()
        used_var = self.running_var.clone().detach()

        if self.training:
            # compute batch statistics
            x_t = x.transpose(0, 1).reshape(num_channels, -1)
            batch_mean = torch.mean(x_t, dim=1)
            if self.sync:
                batch_ex2 = torch.mean(x_t**2, dim=1)
                batch_mean = reduce_tensor(batch_mean)
                batch_ex2 = reduce_tensor(batch_ex2)
                batch_var = batch_ex2 - batch_mean**2
            else:
                batch_var = torch.var(x_t, dim=1)

            # moving average
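            # With bn_lag > 0, blend fresh batch statistics into the previously
            # used ones, then divide by (1 - bn_lag**(step + 1)) to debias the
            # estimate toward its zero/one initialization (analogous to Adam's
            # bias correction of its moment estimates).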
            if self.bn_lag > 0:
                used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
                used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
                used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
                used_var /= (1. - self.bn_lag**(self.step[0] + 1))

            # update running estimates
            self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
            self.running_var -= self.decay * (self.running_var - batch_var.data)
            self.step += 1

        # perform normalization
        used_mean = used_mean.view(*self.shape).expand_as(x)
        used_var = used_var.view(*self.shape).expand_as(x)

        y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(x)
            bias = self.bias.view(*self.shape).expand_as(x)
            y = y * torch.exp(weight) + bias

        if logpx is None:
            return y
        else:
            return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)
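
    # Inverse direction: undo the affine map, then de-standardize using the
    # *running* statistics. This inverts `_forward` exactly only in eval mode,
    # where `_forward` also normalizes with the running statistics.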
    def _reverse(self, y, logpy=None):
        used_mean = self.running_mean
        used_var = self.running_var

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(y)
            bias = self.bias.view(*self.shape).expand_as(y)
            y = (y - bias) * torch.exp(-weight)

        used_mean = used_mean.view(*self.shape).expand_as(y)
        used_var = used_var.view(*self.shape).expand_as(y)
        x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean

        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True)
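
    # Per-dimension log|dy/dx| of the full map:
    #     log|dy_i/dx_i| = weight_i - 0.5 * log(var_i + eps).
    # Callers sum this over the feature dimension before updating logpx.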
    def _logdetgrad(self, x, used_var):
        logdetgrad = -0.5 * torch.log(used_var + self.eps)
        if self.affine:
            weight = self.weight.view(*self.shape).expand(*x.size())
            logdetgrad += weight
        return logdetgrad

    def __repr__(self):
        return (
            '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
            ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
        )
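

# Note: `stable_var` rescales squared residuals by their per-row maximum before
# averaging, guarding against overflow for large values; it does not appear to
# be called anywhere in this module (likely carried over from the codebase this
# file was adapted from).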
def stable_var(x, mean=None, dim=1):
    if mean is None:
        mean = x.mean(dim, keepdim=True)
    mean = mean.view(-1, 1)
    res = torch.pow(x - mean, 2)
    max_sqr = torch.max(res, dim, keepdim=True)[0]
    var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr
    var = var.view(-1)
    # change nan to zero
    var[var != var] = 0
    return var
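

# 1-D case: the broadcast shape (1, num_features) matches 2-D inputs of shape
# (batch, num_features), so statistics are computed per feature.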
class MovingBatchNorm1d(MovingBatchNormNd):
    @property
    def shape(self):
        return [1, -1]

    def forward(self, x, context=None, logpx=None, integration_times=None,
                reverse=False):
        ret = super(MovingBatchNorm1d, self).forward(
            x, context, logpx=logpx, reverse=reverse)
        return ret
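

# A minimal usage sketch (not part of the original module). The input shape is
# an assumption: a 2-D (batch, num_features) tensor matching the (1, -1)
# broadcast shape above, with `logpx` carrying per-sample log-density terms.
if __name__ == '__main__':
    bn = MovingBatchNorm1d(3)
    bn.eval()  # use running statistics so forward and reverse invert exactly

    x = torch.randn(16, 3)
    logpx = torch.zeros(16, 1)

    y, logpy = bn(x, logpx=logpx)                         # normalize, subtract log-det
    x_rec, logpx_rec = bn(y, logpx=logpy, reverse=True)   # inverse map, add log-det

    print((x - x_rec).abs().max().item())          # ~0 (float round-off)
    print((logpx - logpx_rec).abs().max().item())  # ~0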