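"""Continuous normalizing flows (CNF) built on top of torchdiffeq.

`CNF` integrates an ODE-defined transformation (optionally conditioned on a
context vector) while tracking the change in log-density; `SequentialFlow`
chains several such layers together.
"""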
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint
from torchdiffeq import odeint as odeint_normal

__all__ = ["CNF", "SequentialFlow"]

class SequentialFlow(nn.Module):
    """A generalized nn.Sequential container for normalizing flows.

    Each layer in `layer_list` is called as
    `layer(x, context, logpx, integration_times, reverse)`.
    """

    def __init__(self, layer_list):
        super(SequentialFlow, self).__init__()
        self.chain = nn.ModuleList(layer_list)

    def forward(self, x, context, logpx=None, reverse=False, inds=None, integration_times=None):
        if inds is None:
            if reverse:
                # Invert the flow by traversing the layers in reverse order.
                inds = range(len(self.chain) - 1, -1, -1)
            else:
                inds = range(len(self.chain))

        if logpx is None:
            # Transform the samples only.
            for i in inds:
                x = self.chain[i](x, context, logpx, integration_times, reverse)
            return x
        else:
            # Transform the samples and accumulate the log-density change.
            for i in inds:
                x, logpx = self.chain[i](x, context, logpx, integration_times, reverse)
            return x, logpx

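# The odefunc passed to CNF is expected to follow the torchdiffeq convention
# f(t, states) -> d(states)/dt, where `states` is the tuple assembled in
# CNF.forward below, and to expose before_odeint() plus a _num_evals counter
# (see num_evals()).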
class CNF(nn.Module):
    def __init__(self, odefunc, conditional=True, T=1.0, train_T=False, regularization_fns=None,
                 solver='dopri5', atol=1e-5, rtol=1e-5, use_adjoint=True):
        super(CNF, self).__init__()
        self.train_T = train_T
        self.T = T
        if train_T:
            # Parameterize the end time through its square root so the
            # learned integration time stays non-negative.
            self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T))))

        if regularization_fns is not None and len(regularization_fns) > 0:
            raise NotImplementedError("Regularization not supported")
        self.use_adjoint = use_adjoint
        self.odefunc = odefunc
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        # Evaluation-time solver settings default to the training-time ones.
        self.test_solver = solver
        self.test_atol = atol
        self.test_rtol = rtol
        self.solver_options = {}
        self.conditional = conditional

    def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
        if logpx is None:
            _logpx = torch.zeros(*x.shape[:-1], 1).to(x)
        else:
            _logpx = logpx

        if self.conditional:
            assert context is not None
            # Carry the context through the solver as an extra state so the
            # odefunc can condition on it at every step.
            states = (x, _logpx, context)
            atol = [self.atol] * 3
            rtol = [self.rtol] * 3
        else:
            states = (x, _logpx)
            atol = [self.atol] * 2
            rtol = [self.rtol] * 2

        if integration_times is None:
            if self.train_T:
                integration_times = torch.stack(
                    [torch.tensor(0.0).to(x), self.sqrt_end_time * self.sqrt_end_time]
                ).to(x)
            else:
                integration_times = torch.tensor([0., self.T], requires_grad=False).to(x)

        if reverse:
            integration_times = _flip(integration_times, 0)

        # Refresh the odefunc statistics.
        self.odefunc.before_odeint()

        odeint = odeint_adjoint if self.use_adjoint else odeint_normal
        if self.training:
            state_t = odeint(
                self.odefunc,
                states,
                integration_times.to(x),
                atol=atol,
                rtol=rtol,
                method=self.solver,
                options=self.solver_options,
            )
        else:
            state_t = odeint(
                self.odefunc,
                states,
                integration_times.to(x),
                atol=self.test_atol,
                rtol=self.test_rtol,
                method=self.test_solver,
            )

        if len(integration_times) == 2:
            # Only start and end times were requested: keep the final state.
            state_t = tuple(s[1] for s in state_t)

        z_t, logpz_t = state_t[:2]

        if logpx is not None:
            return z_t, logpz_t
        else:
            return z_t

    def num_evals(self):
        return self.odefunc._num_evals.item()

def _flip(x, dim):
    """Reverse the order of entries of `x` along dimension `dim`."""
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
    return x[tuple(indices)]
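
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original file. It assumes a
    # torchdiffeq version that accepts tuple states (as CNF.forward already
    # relies on). `_ToyODEFunc` is a hypothetical stand-in for the repo's real
    # odefunc; it only honors the interface CNF uses: it is called as
    # f(t, states), returns one time derivative per state, and exposes
    # before_odeint() plus a _num_evals counter.
    class _ToyODEFunc(nn.Module):
        def __init__(self, dim, context_dim):
            super(_ToyODEFunc, self).__init__()
            self.net = nn.Linear(dim + context_dim + 1, dim)
            self.register_buffer("_num_evals", torch.tensor(0.))

        def before_odeint(self):
            self._num_evals.fill_(0)

        def forward(self, t, states):
            self._num_evals += 1
            x, _logpx, context = states
            t_vec = torch.ones(*x.shape[:-1], 1).to(x) * t
            dx = self.net(torch.cat([x, context, t_vec], dim=-1))
            # A real odefunc would return the (estimated) divergence term for
            # the log-density state; zeros keep the sketch minimal. The
            # context state is carried along unchanged.
            return dx, torch.zeros_like(_logpx), torch.zeros_like(context)

    dim, context_dim, batch = 3, 4, 8
    flow = SequentialFlow([CNF(_ToyODEFunc(dim, context_dim), conditional=True)])
    flow.eval()  # eval mode uses the scalar test_* solver settings
    x = torch.randn(batch, dim)
    context = torch.randn(batch, context_dim)
    z, delta_logp = flow(x, context, logpx=torch.zeros(batch, 1))  # x -> z
    x_rec = flow(z, context, reverse=True)                         # z -> x
    print(z.shape, delta_logp.shape, x_rec.shape)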