PointFlow/models/flow.py
2019-07-13 21:32:26 -07:00

90 lines
2.7 KiB
Python

from .odefunc import ODEfunc, ODEnet
from .normalization import MovingBatchNorm1d
from .cnf import CNF, SequentialFlow
def count_nfe(model):
    """Sum ``num_evals()`` over every ``CNF`` instance reachable from *model*.

    Args:
        model: Object exposing ``apply(fn)``, which is expected to call ``fn``
            once per contained module (presumably a ``torch.nn.Module`` --
            confirm against callers).

    Returns:
        The accumulated evaluation count; the integer ``0`` when no ``CNF``
        module is visited.
    """
    # Let ``apply`` drive the traversal and simply record what it visits, so
    # the visiting order is exactly the one ``apply`` uses.
    visited = []
    model.apply(visited.append)

    num_evals = 0
    for module in visited:
        if isinstance(module, CNF):
            num_evals += module.num_evals()
    return num_evals
def count_parameters(model):
    """Return the total element count of the trainable parameters of *model*.

    Args:
        model: Object exposing ``parameters()``; each yielded parameter must
            provide ``requires_grad`` and ``numel()``.

    Returns:
        Sum of ``numel()`` over the parameters whose ``requires_grad`` is
        truthy; ``0`` when there are none.
    """
    total = 0
    for param in model.parameters():
        # Parameters excluded from gradient computation are not counted.
        if param.requires_grad:
            total += param.numel()
    return total
def count_total_time(model):
    """Sum the squared ``sqrt_end_time`` of every ``CNF`` found in *model*.

    Args:
        model: Object exposing ``apply(fn)``, which is expected to call ``fn``
            once per contained module (presumably a ``torch.nn.Module`` --
            confirm against callers).

    Returns:
        ``0`` plus ``sqrt_end_time * sqrt_end_time`` for each visited ``CNF``,
        added in visiting order; the integer ``0`` when no ``CNF`` is visited.
    """
    # Record the modules in the order ``apply`` visits them, then fold.
    visited = []
    model.apply(visited.append)

    # ``sum`` starts from 0 and uses binary ``+``, matching a running total
    # that is rebuilt (not updated in place) at every step.
    return sum(
        module.sqrt_end_time * module.sqrt_end_time
        for module in visited
        if isinstance(module, CNF)
    )
def build_model(args, input_dim, hidden_dims, context_dim, num_blocks, conditional):
    """Assemble a ``SequentialFlow`` made of ``num_blocks`` ``CNF`` blocks.

    When ``args.batch_norm`` is truthy the resulting layer list is
    ``[norm, cnf_1, norm_1, ..., cnf_n, norm_n]`` where every norm is a
    ``MovingBatchNorm1d``; otherwise it is just the ``CNF`` blocks.

    Args:
        args: Configuration object. Attributes read here: ``layer_type``,
            ``nonlinearity``, ``time_length``, ``train_T``, ``solver``,
            ``use_adjoint``, ``atol``, ``rtol``, ``batch_norm`` and, only when
            ``batch_norm`` is truthy, ``bn_lag`` and ``sync_bn``.
        input_dim: Passed to ``ODEnet`` as ``input_shape=(input_dim,)`` and
            used as the first argument of each ``MovingBatchNorm1d``.
        hidden_dims: Forwarded to ``ODEnet`` as ``hidden_dims``.
        context_dim: Forwarded to ``ODEnet`` as ``context_dim``.
        num_blocks: Number of ``CNF`` blocks to create.
        conditional: Forwarded to each ``CNF`` as ``conditional``.

    Returns:
        ``SequentialFlow`` wrapping the layer list described above.
    """

    def make_block():
        # One block = ODEnet wrapped in an ODEfunc, wrapped in a CNF.
        dynamics = ODEnet(
            hidden_dims=hidden_dims,
            input_shape=(input_dim,),
            context_dim=context_dim,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        wrapped = ODEfunc(diffeq=dynamics)
        return CNF(
            odefunc=wrapped,
            T=args.time_length,
            train_T=args.train_T,
            conditional=conditional,
            solver=args.solver,
            use_adjoint=args.use_adjoint,
            atol=args.atol,
            rtol=args.rtol,
        )

    def make_norm():
        return MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag, sync=args.sync_bn)

    layers = [make_block() for _ in range(num_blocks)]

    if args.batch_norm:
        # Construction order is kept as: all CNF blocks, then the per-block
        # norms, then the leading norm -- only the final list order differs.
        trailing_norms = [make_norm() for _ in range(num_blocks)]
        interleaved = [make_norm()]
        for block, norm in zip(layers, trailing_norms):
            interleaved.extend((block, norm))
        layers = interleaved

    return SequentialFlow(layers)
def get_point_cnf(args):
    """Build the conditional point flow, move it to CUDA and report its size.

    Args:
        args: Configuration object. Reads ``dims`` (a ``"-"``-separated string
            of integers giving the hidden sizes), ``input_dim``, ``zdim``
            (used as the context dimension) and ``num_blocks``, plus whatever
            ``build_model`` reads.

    Returns:
        The value returned by calling ``.cuda()`` on the built model.
    """
    hidden_dims = tuple(int(width) for width in args.dims.split("-"))
    point_cnf = build_model(
        args,
        input_dim=args.input_dim,
        hidden_dims=hidden_dims,
        context_dim=args.zdim,
        num_blocks=args.num_blocks,
        conditional=True,
    )
    # The model is unconditionally placed on the GPU.
    point_cnf = point_cnf.cuda()
    num_trainable = count_parameters(point_cnf)
    print("Number of trainable parameters of Point CNF: {}".format(num_trainable))
    return point_cnf
def get_latent_cnf(args):
    """Build the unconditional latent flow, move it to CUDA and report its size.

    Args:
        args: Configuration object. Reads ``latent_dims`` (a ``"-"``-separated
            string of integers giving the hidden sizes), ``zdim`` (used as the
            input dimension) and ``latent_num_blocks``, plus whatever
            ``build_model`` reads.

    Returns:
        The value returned by calling ``.cuda()`` on the built model.
    """
    hidden_dims = tuple(int(width) for width in args.latent_dims.split("-"))
    latent_cnf = build_model(
        args,
        input_dim=args.zdim,
        hidden_dims=hidden_dims,
        context_dim=0,  # no conditioning context for the latent flow
        num_blocks=args.latent_num_blocks,
        conditional=False,
    )
    # The model is unconditionally placed on the GPU.
    latent_cnf = latent_cnf.cuda()
    num_trainable = count_parameters(latent_cnf)
    print("Number of trainable parameters of Latent CNF: {}".format(num_trainable))
    return latent_cnf