138 lines
4.5 KiB
Python
138 lines
4.5 KiB
Python
|
import copy
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from . import diffeq_layers
|
||
|
|
||
|
# Public API of this module: the layer-stack builder (ODEnet) and the CNF dynamics
# wrapper that adds a divergence estimate (ODEfunc).
__all__ = ["ODEnet", "ODEfunc"]
|
||
|
|
||
|
|
||
|
def divergence_approx(f, y, e=None):
    """Stochastic (Hutchinson-style) estimate of the divergence of ``f`` w.r.t. ``y``.

    Computes ``sum(e * (e^T df/dy), dim=-1)`` using a single vector-Jacobian
    product. The graph is kept (``create_graph=True``) so the estimate can itself
    be back-propagated through.

    Args:
        f: Tensor computed from ``y``. ``e`` is used both as the ``grad_outputs``
            for ``f`` and as a multiplier of the gradient w.r.t. ``y``, so ``f``,
            ``y`` and ``e`` must share a shape.
        y: Tensor that ``f`` was computed from (must be part of ``f``'s graph).
        e: Probe noise shaped like ``y``. If None, standard-normal noise is drawn
            with ``torch.randn_like(y)``.

    Returns:
        Tensor of shape ``y.shape[:-1]`` (the last dimension is summed out).

    Raises:
        AssertionError: if the estimate is not attached to the autograd graph.
    """
    if e is None:
        # The signature advertises e as optional, but None used to crash inside
        # autograd.grad / e.mul. Sample Gaussian probes instead.
        e = torch.randn_like(y)

    e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
    e_dzdx_e = e_dzdx * e

    # Defensive retry inherited from the original implementation.
    # NOTE(review): autograd.grad is deterministic for a fixed graph and grad mode,
    # so recomputing should not change requires_grad; believed redundant.
    cnt = 0
    while not e_dzdx_e.requires_grad and cnt < 10:
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        e_dzdx_e = e_dzdx * e
        cnt += 1

    approx_tr_dzdx = e_dzdx_e.sum(dim=-1)
    assert approx_tr_dzdx.requires_grad, \
        "(failed to add node to graph) f=%s %s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt:%s" \
        % (
            f.size(), f.requires_grad, y.requires_grad, e_dzdx.requires_grad, e.requires_grad, e_dzdx_e.requires_grad, cnt)
    return approx_tr_dzdx
|
||
|
|
||
|
|
||
|
class Swish(nn.Module):
    """Swish activation ``x * sigmoid(beta * x)`` with a learnable scalar ``beta``.

    ``beta`` is registered as an ``nn.Parameter`` and starts at 1.0.
    """

    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        gate = torch.sigmoid(x * self.beta)
        return gate * x
|
||
|
|
||
|
|
||
|
class Lambda(nn.Module):
    """Wrap an arbitrary callable ``f`` as an ``nn.Module``.

    ``forward(x)`` simply returns ``f(x)``. This lets plain functions sit in
    module containers such as ``nn.ModuleList``.
    """

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        fn = self.f
        return fn(x)
|
||
|
|
||
|
|
||
|
# Activation lookup table, indexed by ODEnet's ``nonlinearity`` argument.
# Each value is a single module instance created at import time, so every lookup
# returns the same object. Stateful entries (Swish carries a learnable ``beta``)
# are therefore shared by every caller that uses them without copying.
NONLINEARITIES = dict(
    tanh=nn.Tanh(),
    relu=nn.ReLU(),
    softplus=nn.Softplus(),
    elu=nn.ELU(),
    swish=Swish(),
    square=Lambda(lambda v: v ** 2),
    identity=Lambda(lambda v: v),
)
|
||
|
|
||
|
|
||
|
class ODEnet(nn.Module):
    """
    Helper class to make neural nets for use in continuous normalizing flows.

    Builds a stack of ``diffeq_layers`` layers with widths
    ``input_shape[0] -> hidden_dims[0] -> ... -> hidden_dims[-1] -> input_shape[0]``.
    A nonlinearity is applied between consecutive layers, but not after the last one.

    Args:
        hidden_dims: Sequence (tuple or list) of hidden-layer widths; may be empty.
        input_shape: Shape whose first entry is the input/output feature width.
        context_dim: Passed as the third constructor argument of every layer.
            Presumably the width of the context tensor; TODO confirm against
            diffeq_layers.
        layer_type: Selects the layer class (keys of the mapping in ``__init__``).
        nonlinearity: Key into ``NONLINEARITIES``.

    Raises:
        KeyError: for an unknown ``layer_type`` or ``nonlinearity``.
    """

    def __init__(self, hidden_dims, input_shape, context_dim, layer_type="concat", nonlinearity="softplus"):
        super(ODEnet, self).__init__()
        base_layer = {
            "ignore": diffeq_layers.IgnoreLinear,
            "squash": diffeq_layers.SquashLinear,
            "scale": diffeq_layers.ScaleLinear,
            "concat": diffeq_layers.ConcatLinear,
            "concat_v2": diffeq_layers.ConcatLinear_v2,
            "concatsquash": diffeq_layers.ConcatSquashLinear,
            "concatscale": diffeq_layers.ConcatScaleLinear,
        }[layer_type]
        # Look the activation up once, up front, so an unknown name raises KeyError
        # even when hidden_dims is empty.
        activation = NONLINEARITIES[nonlinearity]

        layers = []
        dim_in = input_shape[0]
        # tuple() lets a list be passed as hidden_dims; list + tuple raises TypeError.
        for dim_out in tuple(hidden_dims) + (input_shape[0],):
            layers.append(base_layer(dim_in, dim_out, context_dim))
            dim_in = dim_out
        self.layers = nn.ModuleList(layers)

        # NONLINEARITIES holds single shared instances. Deep-copy per layer so that
        # stateful activations (e.g. Swish's learnable beta) are not shared between
        # layers or between separate ODEnet instances.
        # One activation per hidden layer: none after the output layer.
        self.activation_fns = nn.ModuleList(
            copy.deepcopy(activation) for _ in range(len(layers) - 1)
        )

    def forward(self, context, y):
        """Run ``y`` through the layer stack, passing ``context`` to every layer.

        Args:
            context: Passed unchanged as the first argument of each layer.
            y: Input to the first layer.

        Returns:
            Output of the final layer; no nonlinearity is applied after it.
        """
        dx = y
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            dx = layer(context, dx)
            # if not last layer, use nonlinearity
            if i < last:
                dx = self.activation_fns[i](dx)
        return dx
|
||
|
|
||
|
|
||
|
class ODEfunc(nn.Module):
    """CNF dynamics: evaluates ``diffeq`` together with a stochastic divergence estimate.

    ``forward(t, states)`` has the shape of an ODE right-hand side over a tuple of
    states. Presumably it is meant for a torchdiffeq-style ``odeint``; TODO confirm
    against the caller.

    * ``len(states) == 2``: unconditional. ``diffeq(t_col, y)`` is evaluated and
      ``(dy, -divergence)`` is returned. The second state is not read.
    * ``len(states) == 3``: conditional. ``diffeq`` receives ``[t_col, c]``
      concatenated on dim 1, and ``(dy, -divergence, zeros_like(c))`` is returned.

    ``t_col`` is the scalar time broadcast to a ``(batch, 1)`` column.
    Call ``before_odeint`` before each solve to zero the evaluation counter and to
    set (or clear) the probe noise.
    """

    def __init__(self, diffeq):
        super(ODEfunc, self).__init__()
        self.diffeq = diffeq
        self.divergence_fn = divergence_approx
        # Counts forward() calls; reset by before_odeint().
        self.register_buffer("_num_evals", torch.tensor(0.))
        # Probe noise for divergence_fn; None means "sample lazily in forward()".
        # Initialised here so forward() also works if before_odeint() was never
        # called (previously that raised AttributeError).
        self._e = None

    def before_odeint(self, e=None):
        """Reset per-solve state.

        Stores the probe noise ``e``; None means it is resampled on the next
        forward. Also zeroes the evaluation counter.
        """
        self._e = e
        self._num_evals.fill_(0)

    def forward(self, t, states):
        y = states[0]
        # Broadcast the scalar time to a (batch, 1) column in y's dtype/device.
        # torch.as_tensor additionally accepts a plain Python number for t.
        t_scalar = torch.as_tensor(t, dtype=y.dtype, device=y.device)
        t = torch.ones(y.size(0), 1).to(y) * t_scalar.clone().detach().requires_grad_(True)
        self._num_evals += 1
        for state in states:
            state.requires_grad_(True)

        # Sample the noise once and keep reusing it until before_odeint() resets it.
        if self._e is None:
            self._e = torch.randn_like(y, requires_grad=True)

        with torch.set_grad_enabled(True):
            if len(states) == 3:  # conditional CNF
                c = states[2]
                tc = torch.cat([t, c.view(y.size(0), -1)], dim=1)
                dy = self.diffeq(tc, y)
                divergence = self.divergence_fn(dy, y, e=self._e).unsqueeze(-1)
                # The context does not evolve along the trajectory: zero derivative.
                return dy, -divergence, torch.zeros_like(c).requires_grad_(True)
            elif len(states) == 2:  # unconditional CNF
                dy = self.diffeq(t, y)
                # NOTE(review): this flattens to (-1, 1) while the conditional branch
                # uses unsqueeze(-1). The two agree only when y is 2-D; confirm intended.
                divergence = self.divergence_fn(dy, y, e=self._e).view(-1, 1)
                return dy, -divergence
            else:
                # Explicit raise instead of `assert 0`, so the check survives
                # `python -O`. AssertionError is kept for backward compatibility.
                raise AssertionError("`len(states)` should be 2 or 3")
|