138 lines
4.5 KiB
138 lines
4.5 KiB
import copy
import torch
import torch.nn as nn
from . import diffeq_layers
# Names exported by a star-import of this module.
__all__ = ["ODEnet", "ODEfunc"]
def divergence_approx(f, y, e=None):
    """Estimate the divergence of ``f`` w.r.t. ``y`` with a single probe vector.

    Computes one vector-Jacobian product ``e^T (df/dy)`` via
    ``torch.autograd.grad``, multiplies it elementwise by ``e`` and sums over
    the last dimension (a Hutchinson-style trace estimate).

    Args:
        f: Tensor computed from ``y``; it must be connected to ``y`` in the
            autograd graph.
        y: Tensor that ``f`` is differentiated with respect to.
        e: Probe tensor, used both as ``grad_outputs`` for ``f`` and as the
            elementwise multiplier of the gradient w.r.t. ``y``. If ``None``,
            a standard-normal probe shaped like ``y`` is sampled.

    Returns:
        ``(e_dzdx * e).sum(dim=-1)``, attached to the autograd graph.

    Raises:
        AssertionError: if the estimate is still not attached to the autograd
            graph after the retries.
    """
    if e is None:
        # The original default of None crashed in ``.mul(e)``; sample the
        # probe the same way ODEfunc.forward does when no noise is fixed.
        e = torch.randn_like(y, requires_grad=True)

    max_retries = 10
    cnt = 0
    while True:
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        e_dzdx_e = e_dzdx * e
        # NOTE(review): the retry is kept from the original code; presumably a
        # workaround for the result occasionally not being attached to the
        # graph -- confirm whether it is still needed.
        if e_dzdx_e.requires_grad or cnt >= max_retries:
            break
        cnt += 1

    approx_tr_dzdx = e_dzdx_e.sum(dim=-1)
    assert approx_tr_dzdx.requires_grad, \
        "(failed to add node to graph) f=%s %s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt:%s" \
        % (
            f.size(), f.requires_grad, y.requires_grad, e_dzdx.requires_grad, e.requires_grad, e_dzdx_e.requires_grad, cnt)
    return approx_tr_dzdx
class Swish(nn.Module):
    """Swish activation ``x * sigmoid(beta * x)`` with a learnable scalar ``beta``."""

    def __init__(self):
        super().__init__()
        # ``beta`` is a trainable scalar parameter initialised to 1.0.
        initial_beta = torch.tensor(1.0)
        self.beta = nn.Parameter(initial_beta)

    def forward(self, x):
        """Return ``x`` gated by ``sigmoid(beta * x)``."""
        gate = torch.sigmoid(self.beta * x)
        return x * gate
class Lambda(nn.Module):
    """Wrap a callable ``f`` so it can be used wherever an ``nn.Module`` is expected."""

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        wrapped = self.f
        return wrapped(x)
# Activation modules selectable by name through ODEnet's ``nonlinearity``.
# NOTE: each entry is a single module instance, so every network (and every
# layer) that selects e.g. "swish" shares the same module and its parameter.
NONLINEARITIES = {
    "tanh": nn.Tanh(),
    "relu": nn.ReLU(),
    "softplus": nn.Softplus(),
    "elu": nn.ELU(),
    "swish": Swish(),
    "square": Lambda(lambda x: x ** 2),
    "identity": Lambda(lambda x: x),
}


class ODEnet(nn.Module):
    """
    Helper class to make neural nets for use in continuous normalizing flows

    The network is a stack of context-conditioned layers taken from
    ``diffeq_layers``; the chosen nonlinearity is applied after every layer
    except the last one.
    """

    def __init__(self, hidden_dims, input_shape, context_dim, layer_type="concat", nonlinearity="softplus"):
        """Build the layer stack.

        Args:
            hidden_dims: Sequence with the output width of each hidden layer.
            input_shape: Shape whose first entry is the input width; the
                final layer maps back to that width.
            context_dim: Forwarded as third argument to every layer
                constructor.
            layer_type: Key selecting the layer class (see the table below).
            nonlinearity: Key into ``NONLINEARITIES``.

        Raises:
            KeyError: if ``layer_type`` or ``nonlinearity`` is not a known key.
        """
        super(ODEnet, self).__init__()
        base_layer = {
            "ignore": diffeq_layers.IgnoreLinear,
            "squash": diffeq_layers.SquashLinear,
            "scale": diffeq_layers.ScaleLinear,
            "concat": diffeq_layers.ConcatLinear,
            "concat_v2": diffeq_layers.ConcatLinear_v2,
            "concatsquash": diffeq_layers.ConcatSquashLinear,
            "concatscale": diffeq_layers.ConcatScaleLinear,
        }[layer_type]
        activation = NONLINEARITIES[nonlinearity]

        # build models and add them
        layers = []
        activation_fns = []
        dim_in = input_shape[0]
        # tuple() lets callers pass hidden_dims as a list as well as a tuple.
        for dim_out in tuple(hidden_dims) + (input_shape[0],):
            layers.append(base_layer(dim_in, dim_out, context_dim))
            activation_fns.append(activation)
            dim_in = dim_out

        self.layers = nn.ModuleList(layers)
        # No activation after the output layer.
        self.activation_fns = nn.ModuleList(activation_fns[:-1])

    def forward(self, context, y):
        """Run ``y`` through all layers, passing ``context`` to each one."""
        dx = y
        last = len(self.layers) - 1
        for l, layer in enumerate(self.layers):
            dx = layer(context, dx)
            # if not last layer, use nonlinearity
            if l < last:
                dx = self.activation_fns[l](dx)
        return dx
class ODEfunc(nn.Module):
    """ODE right-hand side that returns the dynamics and the negated divergence.

    Wraps ``diffeq`` (called as ``diffeq(context, y)``) and pairs its output
    with ``-divergence_fn(dy, y, e)`` so that both can be integrated together.
    """

    def __init__(self, diffeq):
        """Store the dynamics callable and set up the evaluation counter.

        Args:
            diffeq: Callable invoked as ``diffeq(t_or_tc, y)`` in ``forward``.
        """
        super(ODEfunc, self).__init__()
        self.diffeq = diffeq
        self.divergence_fn = divergence_approx
        # Initialise the probe so forward() works even when before_odeint()
        # was never called (it is then sampled lazily on first use).
        self._e = None
        self.register_buffer("_num_evals", torch.tensor(0.))

    def before_odeint(self, e=None):
        """Set the probe noise used by the divergence estimate.

        Passing ``None`` makes the next ``forward`` call sample fresh noise.
        """
        # NOTE(review): ``_num_evals`` is not reset here; confirm whether
        # callers expect a per-integration or a cumulative count.
        self._e = e

    def forward(self, t, states):
        """Evaluate the dynamics at time ``t``.

        Args:
            t: Scalar time tensor.
            states: Sequence of 2 tensors ``(y, <second state>)`` for the
                unconditional case, or 3 tensors with a context tensor ``c``
                as third entry for the conditional case. The second entry is
                not read here.

        Returns:
            ``(dy, -divergence)`` for 2 states, or
            ``(dy, -divergence, zeros_like(c))`` for 3 states.

        Raises:
            AssertionError: if ``len(states)`` is neither 2 nor 3.
        """
        y = states[0]
        # Broadcast the scalar time to one row per sample in ``y``.
        t = torch.ones(y.size(0), 1).to(y) * t.clone().detach().requires_grad_(True).type_as(y)
        self._num_evals += 1
        # The divergence estimate differentiates ``dy`` w.r.t. ``y``, so the
        # states must track gradients even if the caller passed plain tensors.
        for state in states:
            state.requires_grad_(True)

        # Sample and fix the noise.
        if self._e is None:
            self._e = torch.randn_like(y, requires_grad=True).to(y)

        with torch.set_grad_enabled(True):
            if len(states) == 3:  # conditional CNF
                c = states[2]
                tc = torch.cat([t, c.view(y.size(0), -1)], dim=1)
                dy = self.diffeq(tc, y)
                divergence = self.divergence_fn(dy, y, e=self._e).unsqueeze(-1)
                # The context gets a zero derivative.
                return dy, -divergence, torch.zeros_like(c).requires_grad_(True)
            elif len(states) == 2:  # unconditional CNF
                dy = self.diffeq(t, y)
                divergence = self.divergence_fn(dy, y, e=self._e).view(-1, 1)
                return dy, -divergence
            else:
                # Explicit raise (rather than ``assert 0``) so the check is
                # not stripped when Python runs with -O.
                raise AssertionError("`len(states)` should be 2 or 3")