LION/trainers/hvae_trainer.py

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" to train hierarchical VAE model
this trainer only train the vae without prior
"""
import os
import sys
import torch
import torch.nn.functional as F
import numpy as np
from loguru import logger
import torch.distributed as dist
from trainers.base_trainer import BaseTrainer
from utils.eval_helper import compute_NLL_metric
from utils import model_helper, exp_helper, data_helper
from utils.checker import *
from utils import utils
from trainers.common_fun import validate_inspect_noprior
from torch.cuda.amp import autocast, GradScaler
import third_party.pvcnn.functional as pvcnn_fn
from calmsize import size as calmsize
class Trainer(BaseTrainer):
def __init__(self, cfg, args):
"""
Args:
cfg: training config
args: used for distributed training
"""
super().__init__(cfg, args)
self.train_iter_kwargs = {}
self.sample_num_points = cfg.data.tr_max_sample_points
device = torch.device('cuda:%d' % args.local_rank)
self.device_str = 'cuda:%d' % args.local_rank
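# Mixed-precision setup below: the training loop calls scale()/step()/update()
# unconditionally, so DummyGradScalar is presumably a no-op stand-in that mirrors
# the GradScaler interface when gradient scaling is disabled.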
if not cfg.trainer.use_grad_scalar:
self.grad_scalar = utils.DummyGradScalar()
else:
logger.info('Init GradScaler!')
self.grad_scalar = GradScaler(2**10, enabled=True)
self.model = self.build_model().to(device)
if len(self.cfg.sde.vae_checkpoint):
logger.info('Load vae_checkpoint: {}', self.cfg.sde.vae_checkpoint)
self.model.load_state_dict(
torch.load(self.cfg.sde.vae_checkpoint)['model'])
logger.info('broadcast_params: device={}', device)
utils.broadcast_params(self.model.parameters(),
args.distributed)
self.build_other_module()
if args.distributed:
logger.info('waiting for barrier, device={}', device)
dist.barrier()
logger.info('pass barrier, device={}', device)
self.train_loader, self.test_loader = self.build_data()
# The optimizer
self.optimizer, self.scheduler = utils.get_opt(
self.model.parameters(),
self.cfg.trainer.opt,
cfg.ddpm.ema, self.cfg)
# Build Spectral Norm Regularization if needed
if self.cfg.trainer.sn_reg_vae:
raise NotImplementedError
# Prepare variables for the summary
self.num_points = self.cfg.data.tr_max_sample_points
logger.info('done init trainer @{}', device)
# Prepare for evaluation
# init the latents used for validation
self.prepare_vis_data()
# ------------------------------------------- #
# training fun #
# ------------------------------------------- #
def epoch_start(self, epoch):
pass
def epoch_end(self, epoch, writer=None, **kwargs):
return super().epoch_end(epoch, writer=writer)
def train_iter(self, data, *args, **kwargs):
""" forward one iteration; and step optimizer
Args:
data: (dict) tr_points shape: (B,N,3)
see get_loss in models/shapelatent_diffusion.py
"""
self.model.train()
step = kwargs.get('step', None)
assert(step is not None), 'require step as input'
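# Warmup window: warmup_iters covers vae_lr_warmup_epochs epochs' worth of
# iterations; update_vae_lr presumably ramps the VAE learning rate toward its
# configured value while step is still inside this window.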
warmup_iters = len(self.train_loader) * \
self.cfg.trainer.opt.vae_lr_warmup_epochs
utils.update_vae_lr(self.cfg, step, warmup_iters, self.optimizer)
no_update = kwargs.get('no_update', False)
if not no_update:
self.model.train()
self.optimizer.zero_grad()
device = torch.device(self.device_str)
tr_pts = data['tr_points'].to(device) # (B, Npoints, 3)
batch_size = tr_pts.size(0)
model_kwargs = {}
with autocast(enabled=self.cfg.sde.autocast_train):
res = self.model.get_loss(tr_pts, writer=self.writer,
it=step, **model_kwargs)
loss = res['loss'].mean()
lossv = loss.detach().cpu().item()
if not no_update:
self.grad_scalar.scale(loss).backward()
utils.average_gradients(self.model.parameters(),
self.args.distributed)
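# Note: when GradScaler is active the gradients are still scaled at this point;
# PyTorch's AMP recipe calls grad_scalar.unscale_(self.optimizer) before
# clip_grad_norm_ so that max_norm applies to the unscaled gradients.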
if self.cfg.trainer.opt.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(self.model.parameters(),
max_norm=self.cfg.trainer.opt.grad_clip)
self.grad_scalar.step(self.optimizer)
self.grad_scalar.update()
output = {}
if self.writer is not None:
for k, v in res.items():
if 'print/' in k and step is not None:
v0 = v.mean().item() if torch.is_tensor(v) else v
self.writer.avg_meter(k.split('print/')[-1], v0,
step=step)
if 'hist/' in k:
output[k] = v
output.update({
'loss': lossv,
'x_0_pred': res['x_0_pred'].detach().cpu(), # perturbed data
'x_0': res['x_0'].detach().cpu(),
'x_t': res['final_pred'].detach().view(batch_size, -1, res['x_0'].shape[-1]),
't': res.get('t', None)
})
for k, v in res.items():
if 'vis/' in k or 'msg/' in k:
output[k] = v
# if 'x_ref_pred' in res:
# output['x_ref_pred'] = res['x_ref_pred'].detach().cpu()
# if 'x_ref_pred_input' in res:
# output['x_ref_pred_input'] = res['x_ref_pred_input'].detach().cpu()
return output
# --------------------------------------------- #
# visualization function and sampling function #
# --------------------------------------------- #
@torch.no_grad()
def vis_diffusion(self, data, writer):
pass
def diffusion_sample(self, *args, **kwargs):
pass
@torch.no_grad()
def vis_sample(self, writer, num_vis=None, step=0, include_pred_x0=True,
save_file=None):
bound = 1.5 if 'chair' in self.cfg.data.cates else 1.0
assert(not self.cfg.data.cond_on_cat)
val_class_label = tr_class_label = None
validate_inspect_noprior(self.model,
step, self.writer, self.sample_num_points,
need_sample=0, need_val=1, need_train=0,
num_samples=self.num_val_samples,
test_loader=self.test_loader,
w_prior=self.w_prior,
val_x=self.val_x, tr_x=self.tr_x,
val_class_label=val_class_label,
tr_class_label=tr_class_label,
has_shapelatent=True,
bound=bound, cfg=self.cfg
)
@torch.no_grad()
def sample(self, num_shapes=2, num_points=2048, device_str='cuda',
for_vis=True, use_ddim=False, save_file=None, ddim_step=500):
""" return the final samples in shape [B,3,N] """
# switch to EMA parameters
if self.cfg.ddpm.ema:
self.optimizer.swap_parameters_with_ema(store_params_in_ema=True)
self.model.eval()
# ---- forward sampling ---- #
gen_x = self.model.sample(
num_samples=num_shapes, device_str=self.device_str)
# gen_x: BNC
CHECKEQ(gen_x.shape[2], self.cfg.ddpm.input_dim)
traj = gen_x.permute(0, 2, 1).contiguous() # BN3->B3N
# switch back to original parameters
if self.cfg.ddpm.ema:
self.optimizer.swap_parameters_with_ema(store_params_in_ema=True)
return traj
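# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustration only, not part of the trainer). It assumes
# `cfg` and `args` objects shaped like the ones read above (cfg.data, cfg.trainer,
# cfg.sde, cfg.ddpm, args.local_rank, args.distributed); how they are constructed
# is project-specific, so load_cfg_and_args() below is a hypothetical helper and
# cfg.trainer.epochs is an assumed config field.
# --------------------------------------------------------------------------- #
# if __name__ == '__main__':
#     cfg, args = load_cfg_and_args()              # hypothetical config/arg loader
#     trainer = Trainer(cfg, args)
#     global_step = 0
#     for epoch in range(cfg.trainer.epochs):      # assumed field name
#         trainer.epoch_start(epoch)
#         for batch in trainer.train_loader:
#             out = trainer.train_iter(batch, step=global_step)
#             global_step += 1
#         trainer.epoch_end(epoch, writer=trainer.writer)
#     # draw a few shapes from the trained VAE; returns a tensor of shape (B, 3, N)
#     samples = trainer.sample(num_shapes=4, num_points=trainer.sample_num_points)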