LION/trainers/train_2prior.py
2023-01-23 00:14:49 -05:00

450 lines
21 KiB
Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" to train hierarchical VAE model with 2 prior
one for style latent, one for latent pts,
based on trainers/train_prior.py
"""
import os
import time
from PIL import Image
import gc
import functools
import psutil
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
import numpy as np
from loguru import logger
import torch.distributed as dist
from torch import optim
from utils.ema import EMA
from utils.model_helper import import_model, loss_fn
from utils.vis_helper import visualize_point_clouds_3d
from utils.eval_helper import compute_NLL_metric
from utils import model_helper, exp_helper, data_helper
from utils.data_helper import normalize_point_clouds
## from utils.diffusion_discretized import DiffusionDiscretized
from utils.diffusion_pvd import DiffusionDiscretized
from utils.diffusion_continuous import make_diffusion, DiffusionBase
from utils.checker import *
from utils import utils
from matplotlib import pyplot as plt
import third_party.pvcnn.functional as pvcnn_fn
from timeit import default_timer as timer
from torch.optim import Adam as FusedAdam
from torch.cuda.amp import autocast, GradScaler
from trainers.train_prior import Trainer as PriorTrainer
from trainers.train_prior import validate_inspect # import Trainer as PriorTrainer
# A non-zero integer in the ``quiet`` environment variable suppresses verbose
# logging (e.g. the DAE model summary printed by Trainer.build_prior).
quiet = int(os.getenv('quiet', '0'))
# Switch for latent-point visualization; not referenced in the visible part of this file.
VIS_LATENT_PTS = 0
@torch.no_grad()
def generate_samples_vada_2prior(shape, dae, diffusion, vae, num_samples, enable_autocast,
                                 ode_eps=0.00001, ode_solver_tol=1e-5,  # None,
                                 ode_sample=False, prior_var=1.0, temp=1.0, vae_temp=1.0, noise=None,
                                 need_denoise=False, ddim_step=0, clip_feat=None, cls_emb=None,
                                 ddim_skip_type='uniform', ddim_kappa=1.0):
    """Sample latents from the chain of priors in ``dae`` and decode them with ``vae``.

    The priors are sampled in order. After the first prior, its sample
    (concatenated with ``cls_emb`` when given) is mapped through
    ``vae.global2style`` and used as ``condition_input`` for the next prior.
    This is the same conditioning used when training the second prior.

    Args:
        shape: per-prior sample shapes, indexed like ``dae``.
        dae: sequence of denoising networks, sampled in order.
        diffusion: a ``DiffusionBase`` when ``ode_sample == 1``; a
            ``DiffusionDiscretized`` when ``ode_sample == 0``.
        vae: object providing ``compose_eps``, ``decompose_eps``,
            ``global2style`` and ``sample``.
        num_samples: number of samples to draw.
        enable_autocast: forwarded to the diffusion sampler.
        ode_eps, ode_solver_tol: ODE integration cutoff / tolerance (ODE mode only).
        ode_sample: 1 for ODE sampling, 0 for ancestral (or DDIM) sampling.
        prior_var, temp: forwarded to the ancestral / DDIM sampler.
        vae_temp, need_denoise: accepted for interface compatibility; unused here.
        noise: forwarded to the ODE sampler; must be None in ancestral mode.
        ddim_step: if > 0, use DDIM with this many steps (ancestral mode).
        clip_feat: optional conditioning feature forwarded to every prior.
        cls_emb: optional class embedding (not supported for ODE or DDIM sampling).
        ddim_skip_type, ddim_kappa: DDIM schedule options.

    Returns:
        ``(image, nfe, time_ode_solve, sampling_time, output)``:
        - ``image`` is the result of ``vae.sample``.
        - The next three are 0-d float tensors. They are placed on ``image``'s
          device when ``image`` is a tensor, else on 'cuda'.
        - ``output`` holds diagnostics: 'sampled_eps', 'eps_list',
          'print/sample_mean_global' and 'print/sample_var_global'.

    Raises:
        NotImplementedError: if ``ode_sample`` is neither 0 nor 1.
    """
    # Validate up front: an unsupported mode would otherwise fall through both
    # branches and fail later with an UnboundLocalError.
    if ode_sample not in (0, 1):
        raise NotImplementedError(
            'ode_sample=%r is not supported; use 0 (ancestral/DDIM) or 1 (ODE)' % (ode_sample,))
    output = {}
    if ode_sample == 1:
        assert isinstance(
            diffusion, DiffusionBase), 'ODE-based sampling requires cont. diffusion!'
        assert ode_eps is not None, 'ODE-based sampling requires integration cutoff ode_eps!'
        assert ode_solver_tol is not None, 'ODE-based sampling requires ode solver tolerance!'
        assert cls_emb is None, ' not support yet'
        start = timer()
        condition_input = None
        eps_list = []
        for i in range(len(dae)):
            eps, nfe, time_ode_solve = diffusion.sample_model_ode(
                dae[i], num_samples, shape[i], ode_eps, ode_solver_tol, enable_autocast, temp, noise,
                condition_input=condition_input, clip_feat=clip_feat,
            )
            condition_input = eps
            if i == 0:
                # Same conditioning as training and the ancestral branch: the
                # next prior sees the style code, not the raw first-level eps.
                condition_input = vae.global2style(condition_input)
            eps_list.append(eps)
        output['sampled_eps'] = eps
        eps = vae.compose_eps(eps_list)
    else:
        assert isinstance(
            diffusion, DiffusionDiscretized), 'Regular sampling requires disc. diffusion!'
        assert noise is None, 'Noise is not used in ancestral sampling.'
        if ddim_step > 0:
            assert cls_emb is None, 'not support yet'
        nfe = diffusion._diffusion_steps
        time_ode_solve = 999.999  # placeholder: there is no ODE solve in this mode
        start = timer()
        # the first prior is conditioned on the class embedding (None if absent)
        condition_input = cls_emb
        all_eps = []
        for i in range(len(dae)):
            if ddim_step > 0:
                eps, eps_list = diffusion.run_ddim(
                    dae[i], num_samples, shape[i], temp, enable_autocast,
                    is_image=False, prior_var=prior_var, ddim_step=ddim_step,
                    condition_input=condition_input, clip_feat=clip_feat,
                    skip_type=ddim_skip_type, kappa=ddim_kappa)
            else:
                eps, eps_list = diffusion.run_denoising_diffusion(
                    dae[i], num_samples, shape[i], temp, enable_autocast,
                    is_image=False, prior_var=prior_var,
                    condition_input=condition_input, clip_feat=clip_feat,
                )
            condition_input = eps
            if cls_emb is not None:
                condition_input = torch.cat(
                    [condition_input, cls_emb.unsqueeze(-1).unsqueeze(-1)], dim=1)
            if i == 0:
                condition_input = vae.global2style(condition_input)
            all_eps.append(eps)
        output['sampled_eps'] = eps
        eps = vae.compose_eps(all_eps)
    output['eps_list'] = eps_list
    flat_eps = eps.view(num_samples, -1)
    output['print/sample_mean_global'] = flat_eps.mean(-1).mean()
    output['print/sample_var_global'] = flat_eps.var(-1).mean()
    decomposed_eps = vae.decompose_eps(eps)
    image = vae.sample(num_samples=num_samples,
                       decomposed_eps=decomposed_eps, cls_emb=cls_emb)
    sampling_time = timer() - start
    # Returned as tensors, presumably so callers can average them over GPUs.
    # Use the sample's device rather than assuming CUDA is available.
    device = image.device if torch.is_tensor(image) else 'cuda'
    nfe_torch = torch.tensor(nfe * 1.0, device=device)
    sampling_time_torch = torch.tensor(sampling_time * 1.0, device=device)
    time_ode_solve_torch = torch.tensor(time_ode_solve * 1.0, device=device)
    return image, nfe_torch, time_ode_solve_torch, sampling_time_torch, output
class Trainer(PriorTrainer):
    """Prior trainer for the hierarchical VAE with two latent priors.

    ``self.dae`` (built in ``build_prior``) is an ``nn.ModuleList``:
    - index 0 is the style prior;
    - index 1 is the latent-points prior. It is conditioned on
      ``self.model.global2style`` of the style latent.

    The VAE is not trained here. ``compute_loss_vae`` only encodes, and
    requesting VAE training raises NotImplementedError.
    """
    is_diffusion = 0

    def __init__(self, cfg, args):
        """
        Args:
            cfg: training config
            args: used for distributed training
        """
        super().__init__(cfg, args)
        # Two-prior sampler with the config-dependent ODE/DDIM options bound.
        # Presumably consumed by the sampling code in PriorTrainer (not visible here).
        self.fun_generate_samples_vada = functools.partial(
            generate_samples_vada_2prior, ode_eps=cfg.sde.ode_eps,
            ddim_skip_type=cfg.sde.ddim_skip_type,
            ddim_kappa=cfg.sde.ddim_kappa)

    def _train_diffusion(self):
        """Return the diffusion process used for training, chosen by ``cfg.sde.ode_sample``.

        - 2 (mixing solvers) is not supported for training.
        - Any other truthy value selects the continuous diffusion.
        - A falsy value selects the discretized one.
        """
        ode_sample = self.cfg.sde.ode_sample
        if ode_sample == 2:
            raise NotImplementedError  # not support training with different solver
        return self.diffusion_cont if ode_sample else self.diffusion_disc

    def _log_mixing_histograms(self, writer, global_step):
        """Write a histogram of sigmoid(mixing_logit) for each prior, skipping NaN values."""
        for i, prior in enumerate(self.dae):
            m = torch.sigmoid(prior.mixing_logit)
            if not torch.isnan(m).any():
                writer.add_histogram(
                    'mixing_prob_%d' % i, m.detach().cpu().numpy(), global_step)

    def compute_loss_vae(self, tr_pts, global_step, **kwargs):
        """Encode ``tr_pts`` with the (frozen) VAE for prior training.

        Input:
            tr_pts: points
            global_step: int (unused)
            kwargs: accepted for interface compatibility (unused)
        Returns:
            output dict including entries:
                'eps': the encoded latent from ``self.model.encode``, made 4-D
                    (a 2-D tensor gets two trailing singleton dims, otherwise one)
                'q_loss': torch.zeros(1) (the VAE is not trained)
                'x_0_pred', 'x_0_target', 'x_0', 'final_pred': ``tr_pts`` itself
        Raises:
            NotImplementedError: if cfg.sde.ode_sample == 2 or cfg.sde.train_vae is set.
        """
        args = self.cfg.sde
        if args.ode_sample == 2:
            raise NotImplementedError
        if args.train_vae:
            # joint VAE training is not implemented for the 2-prior model
            raise NotImplementedError
        with torch.no_grad(), autocast(enabled=args.autocast_train):
            all_eps, _, _ = self.model.encode(tr_pts)
            if len(all_eps.shape) == 2:
                eps = all_eps.unsqueeze(-1).unsqueeze(-1)
            else:
                eps = all_eps.unsqueeze(-1)
        # NOTE(review): train_iter reads output['cls_emb'] when cfg.data.cond_on_cat
        # is set, but it is never populated here -- confirm where cls_emb should come from.
        return {'eps': eps, 'q_loss': torch.zeros(1),
                'x_0_pred': tr_pts, 'x_0_target': tr_pts,
                'x_0': tr_pts, 'final_pred': tr_pts}

    # ------------------------------------------- #
    #   training fun                              #
    # ------------------------------------------- #
    def train_iter(self, data, *args, **kwargs):
        """Run one forward pass and, if cfg.sde.train_dae is set, step the DAE optimizer.

        Args:
            data: (dict) with:
                'tr_points': shape (B, N, 3).
                'input_pts': optional.
                'tr_img': optional, (B, nimg, 3, H, W); encoded with
                    self.clip_model.
                'cate_idx': required when cfg.data.cond_on_cat is set.
            kwargs: 'step' is the global step, used for the lr schedule and
                logging.
        Returns:
            dict with:
                'loss': float; 0.0 when the prior is not trained this step.
                'x_0_pred', 'x_0', 'x_t', 't': passthrough tensors.
                every 'vis/*' entry of the forward output.
        """
        dae = self.dae
        dae.train()
        diffusion = self._train_diffusion()
        dae_optimizer = self.dae_optimizer
        vae_optimizer = self.vae_optimizer
        args = self.cfg.sde
        device = torch.device(self.device_str)
        distributed = self.args.distributed
        dae_sn_calculator = self.dae_sn_calculator
        grad_scalar = self.grad_scalar
        global_step = step = kwargs.get('step', None)

        # update_lr
        warmup_iters = len(self.train_loader) * args.warmup_epochs
        utils.update_lr(args, global_step, warmup_iters,
                        dae_optimizer, vae_optimizer)

        # input
        tr_pts = data['tr_points'].to(device)  # (B, Npoints, 3)
        inputs = data['input_pts'].to(
            device) if 'input_pts' in data else None  # the noisy points
        tr_img = data['tr_img'].to(device) if 'tr_img' in data else None
        model_kwargs = {}
        if self.cfg.data.cond_on_cat:
            class_label_int = data['cate_idx'].view(-1)
            class_label = F.one_hot(class_label_int, self.cfg.data.nclass)
            model_kwargs['class_label'] = class_label.float().to(device)
        B = batch_size = tr_pts.size(0)
        if tr_img is not None:
            # tr_img: B, nimg, 3, H, W -> CLIP feature averaged over the nimg views
            nimg = tr_img.shape[1]
            tr_img = tr_img.view(B * nimg, *tr_img.shape[2:])
            clip_feat = self.clip_model.encode_image(
                tr_img).view(B, nimg, -1).mean(1).float()
        else:
            clip_feat = None

        # optimize vae params
        vae_optimizer.zero_grad()
        output = self.compute_loss_vae(
            tr_pts, global_step, inputs=inputs, **model_kwargs)
        # the interface between VAE and DAE is eps.
        eps = output['eps'].detach()  # 4d: B,D,-1,1
        CHECK4D(eps)
        dae_kwarg = {}
        if self.cfg.data.cond_on_cat:
            dae_kwarg['condition_input'] = output['cls_emb']

        # Default so the returned 'loss' is defined when the prior is not trained.
        loss = torch.zeros(())
        # train prior
        if args.train_dae:
            dae_optimizer.zero_grad()
            # Sum of the per-prior regularization terms; used only for logging.
            reg_p_total = 0.
            with autocast(enabled=args.autocast_train):
                # get diffusion quantities for p sampling scheme and reweighting for q
                t_p, var_t_p, m_t_p, obj_weight_t_p, _, g2_t_p = \
                    diffusion.iw_quantities(B, args.time_eps,
                                            args.iw_sample_p, args.iw_subvp_like_vp_sde)
                decomposed_eps = self.vae.decompose_eps(eps)
                output['vis/eps'] = decomposed_eps[1].view(
                    -1, self.dae.num_points, self.dae.num_classes)[:, :, :3]
                p_loss_list = []
                for latent_id, eps_i in enumerate(decomposed_eps):
                    noise_p = torch.randn(size=eps_i.size(), device=device)
                    eps_t_p = diffusion.sample_q(eps_i, noise_p, var_t_p, m_t_p)
                    # run the score model
                    eps_t_p.requires_grad_(True)
                    mixing_component = diffusion.mixing_component(
                        eps_t_p, var_t_p, t_p, enabled=args.mixed_prediction)
                    if latent_id == 0:
                        pred_params_p = dae[latent_id](
                            eps_t_p, t_p, x0=eps_i, clip_feat=clip_feat, **dae_kwarg)
                    else:
                        # later priors are conditioned on the style code of the first latent
                        condition_input = decomposed_eps[0]
                        if self.cfg.data.cond_on_cat:
                            condition_input = torch.cat(
                                [condition_input, output['cls_emb'].unsqueeze(-1).unsqueeze(-1)], dim=1)
                        condition_input = self.model.global2style(condition_input)
                        pred_params_p = dae[latent_id](eps_t_p, t_p, x0=eps_i,
                                                       condition_input=condition_input,
                                                       clip_feat=clip_feat)
                    params = utils.get_mixed_prediction(
                        args.mixed_prediction, pred_params_p,
                        dae[latent_id].mixing_logit, mixing_component)
                    if self.cfg.latent_pts.pvd_mse_loss:
                        p_loss = F.mse_loss(
                            params.contiguous().view(B, -1), noise_p.view(B, -1),
                            reduction='mean')
                    else:
                        l2_term_p = torch.square(params - noise_p)
                        p_objective = torch.sum(
                            obj_weight_t_p * l2_term_p, dim=[1, 2, 3])
                        regularization_p, dae_norm_loss, dae_bn_loss, dae_wdn_coeff, \
                            jac_reg_loss, kin_reg_loss = utils.dae_regularization(
                                args, dae_sn_calculator, diffusion, dae, step, t_p,
                                pred_params_p, eps_t_p, var_t_p, m_t_p, g2_t_p)
                        # Regularize this prior's own mixing logit: dae is a
                        # ModuleList and has no mixing_logit attribute itself.
                        reg_mlogit = ((torch.sum(torch.sigmoid(dae[latent_id].mixing_logit)) -
                                       args.regularize_mlogit_margin) ** 2) * args.regularize_mlogit \
                            if args.regularize_mlogit else 0
                        p_loss = torch.mean(p_objective) + \
                            regularization_p + reg_mlogit
                        reg_p_total = reg_p_total + regularization_p
                    if self.writer is not None:
                        self.writer.avg_meter(
                            'train/p_loss_%d' % latent_id, p_loss.detach().item())
                    p_loss_list.append(p_loss)
                p_loss = sum(p_loss_list)
            loss = p_loss

            # update dae parameters
            grad_scalar.scale(p_loss).backward()
            utils.average_gradients(dae.parameters(), distributed)
            if args.grad_clip_max_norm > 0.:  # apply gradient clipping
                grad_scalar.unscale_(dae_optimizer)
                torch.nn.utils.clip_grad_norm_(dae.parameters(),
                                               max_norm=args.grad_clip_max_norm)
            grad_scalar.step(dae_optimizer)
            # update grad scaler
            grad_scalar.update()
            if args.bound_mlogit:
                for prior in dae:
                    prior.mixing_logit.data.clamp_(max=args.bound_mlogit_value)

            # Bookkeeping!
            writer = self.writer
            if writer is not None:
                writer.avg_meter('train/lr_dae', dae_optimizer.state_dict()[
                    'param_groups'][0]['lr'], global_step)
                writer.avg_meter('train/lr_vae', vae_optimizer.state_dict()[
                    'param_groups'][0]['lr'], global_step)
                if self.cfg.latent_pts.pvd_mse_loss:
                    # no other loss
                    writer.avg_meter(
                        'train/p_loss', p_loss.item(), global_step)
                    if args.mixed_prediction and global_step % 500 == 0:
                        self._log_mixing_histograms(writer, global_step)
                else:
                    # p_loss sums all priors, so remove every prior's regularization term
                    writer.avg_meter(
                        'train/p_loss', (p_loss - reg_p_total).item(), global_step)
                    if torch.is_tensor(reg_p_total):
                        writer.avg_meter(
                            'train/reg_p', reg_p_total.item(), global_step)
                    if args.regularize_mlogit:
                        writer.avg_meter(
                            'train/m_logit', reg_mlogit / args.regularize_mlogit, global_step)
                    if args.mixed_prediction:
                        m_logit_sum = sum(torch.sum(torch.sigmoid(prior.mixing_logit))
                                          for prior in dae)
                        writer.avg_meter(
                            'train/m_logit_sum', m_logit_sum.detach().cpu(), global_step)
                    if global_step % 500 == 0:
                        writer.add_scalar(
                            'train/norm_loss_dae', dae_norm_loss, global_step)
                        writer.add_scalar('train/bn_loss_dae',
                                          dae_bn_loss, global_step)
                        writer.add_scalar(
                            'train/norm_coeff_dae', dae_wdn_coeff, global_step)
                        if args.mixed_prediction:
                            self._log_mixing_histograms(writer, global_step)

        # write stats
        if self.writer is not None:
            for k, v in output.items():
                if 'print/' in k and step is not None:
                    self.writer.avg_meter(k.split('print/')[-1],
                                          v.mean().item() if torch.is_tensor(v) else v,
                                          step=step)
        output_dict = {
            'loss': loss.detach().cpu().item(),
            'x_0_pred': output['x_0_pred'].detach().cpu(),  # perturbed data
            'x_0': output['x_0'].detach().cpu(),
            'x_t': output['final_pred'].detach().view(batch_size, -1, output['x_0'].shape[-1]),
            't': output.get('t', None),
        }
        output_dict.update({k: v for k, v in output.items() if 'vis/' in k})
        return output_dict

    # --------------------------------------------- #
    #   visualization function and sampling function #
    # --------------------------------------------- #
    def build_prior(self):
        """Build the two denoising priors and both diffusion processes, then sync params.

        - dae[0] is the style prior (``cfg.latent_pts.style_prior``, width
          ``cfg.latent_pts.style_dim``).
        - dae[1] is the conditional prior (``cfg.sde.prior_model``, width
          ``cfg.shapelatent.latent_dim``).
        - Weights are optionally loaded from ``cfg.sde.dae_checkpoint``.
        """
        args = self.cfg.sde
        device = torch.device(self.device_str)
        num_input_channels = self.cfg.shapelatent.latent_dim
        style_prior = import_model(self.cfg.latent_pts.style_prior)(
            args, self.cfg.latent_pts.style_dim, self.cfg)
        # global prior, conditional model
        latent_prior = import_model(self.cfg.sde.prior_model)(
            args, num_input_channels, self.cfg)
        self.dae = nn.ModuleList([style_prior, latent_prior]).to(device)
        # Bad solution! it is used in validate_inspect function
        self.dae.num_points = self.dae[1].num_points
        self.dae.num_classes = self.dae[1].num_classes
        if len(self.cfg.sde.dae_checkpoint):
            logger.info('Load dae checkpoint: {}',
                        self.cfg.sde.dae_checkpoint)
            checkpoint = torch.load(
                self.cfg.sde.dae_checkpoint, map_location='cpu')
            self.dae.load_state_dict(checkpoint['dae_state_dict'])
        self.diffusion_cont = make_diffusion(args)
        self.diffusion_disc = DiffusionDiscretized(
            args, self.diffusion_cont.var, self.cfg)
        if not quiet:
            logger.info('DAE: {}', self.dae)
            logger.info('DAE: param size = %fM ' %
                        utils.count_parameters_in_M(self.dae))
        # sync all parameters between all gpus by sending param from rank 0 to all gpus.
        utils.broadcast_params(self.dae.parameters(), self.args.distributed)