"""Helpers for building SequentialFlow models out of CNF blocks and for
collecting simple statistics (function evaluations, parameters, time) from
them."""
from .odefunc import ODEfunc, ODEnet
from .normalization import MovingBatchNorm1d
from .cnf import CNF, SequentialFlow


def count_nfe(model):
    """Return the sum of ``num_evals()`` over every CNF module in *model*.

    Modules are visited with ``model.apply``; non-CNF modules are ignored.
    Returns ``0`` when no CNF module is found.
    """
    evals = []

    def _collect(module):
        if isinstance(module, CNF):
            evals.append(module.num_evals())

    model.apply(_collect)
    return sum(evals)


def count_parameters(model):
    """Return the total element count of the parameters of *model* that have
    ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def count_total_time(model):
    """Return the sum of ``sqrt_end_time ** 2`` over every CNF module in
    *model*.

    Modules are visited with ``model.apply``.  Returns ``0`` when no CNF
    module is found.
    """
    # NOTE(review): presumably sqrt_end_time squared is each block's
    # integration end time -- confirm against the CNF implementation.
    squared_times = []

    def _collect(module):
        if isinstance(module, CNF):
            squared_times.append(module.sqrt_end_time * module.sqrt_end_time)

    model.apply(_collect)
    return sum(squared_times)


def build_model(args, input_dim, hidden_dims, context_dim, num_blocks, conditional):
    """Build a SequentialFlow made of ``num_blocks`` CNF blocks.

    Args:
        args: Options object.  Reads ``layer_type``, ``nonlinearity``,
            ``time_length``, ``train_T``, ``solver``, ``use_adjoint``,
            ``atol``, ``rtol`` and ``batch_norm``; when ``batch_norm`` is
            truthy it also reads ``bn_lag`` and ``sync_bn``.
        input_dim: Passed to ODEnet as ``input_shape=(input_dim,)`` and to
            each MovingBatchNorm1d layer.
        hidden_dims: Forwarded to ODEnet as ``hidden_dims``.
        context_dim: Forwarded to ODEnet as ``context_dim``.
        num_blocks: Number of CNF blocks to create.
        conditional: Forwarded to each CNF as ``conditional``.

    Returns:
        A SequentialFlow wrapping the CNF blocks.  With ``args.batch_norm``
        the layer list is ``[bn, cnf, bn, cnf, bn, ...]``: one leading
        MovingBatchNorm1d, then one after every CNF block.
    """

    def _make_block():
        return CNF(
            odefunc=ODEfunc(
                diffeq=ODEnet(
                    hidden_dims=hidden_dims,
                    input_shape=(input_dim,),
                    context_dim=context_dim,
                    layer_type=args.layer_type,
                    nonlinearity=args.nonlinearity,
                ),
            ),
            T=args.time_length,
            train_T=args.train_T,
            conditional=conditional,
            solver=args.solver,
            use_adjoint=args.use_adjoint,
            atol=args.atol,
            rtol=args.rtol,
        )

    layers = [_make_block() for _ in range(num_blocks)]

    if args.batch_norm:

        def _make_norm():
            return MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag, sync=args.sync_bn)

        # The per-block norms are created before the leading one, so the
        # construction order stays: CNF blocks, per-block norms, leading norm.
        per_block_norms = [_make_norm() for _ in range(num_blocks)]
        interleaved = [_make_norm()]
        for flow, norm in zip(layers, per_block_norms):
            interleaved.extend((flow, norm))
        layers = interleaved

    return SequentialFlow(layers)


def get_point_cnf(args):
    """Build the point CNF, move it to CUDA and print its parameter count.

    ``args.dims`` is a hyphen-separated string of integers used as the hidden
    dimensions.  The model is conditional, with input dimension
    ``args.input_dim``, context dimension ``args.zdim`` and
    ``args.num_blocks`` blocks.
    """
    hidden = tuple(int(width) for width in args.dims.split("-"))
    flow = build_model(args, args.input_dim, hidden, args.zdim, args.num_blocks, True)
    flow = flow.cuda()
    n_params = count_parameters(flow)
    print("Number of trainable parameters of Point CNF: {}".format(n_params))
    return flow


def get_latent_cnf(args):
    """Build the latent CNF, move it to CUDA and print its parameter count.

    ``args.latent_dims`` is a hyphen-separated string of integers used as the
    hidden dimensions.  The model is unconditional (context dimension 0),
    with input dimension ``args.zdim`` and ``args.latent_num_blocks`` blocks.
    """
    hidden = tuple(int(width) for width in args.latent_dims.split("-"))
    flow = build_model(args, args.zdim, hidden, 0, args.latent_num_blocks, False)
    flow = flow.cuda()
    n_params = count_parameters(flow)
    print("Number of trainable parameters of Latent CNF: {}".format(n_params))
    return flow