LION/trainers/train_prior.py
2023-01-23 00:14:49 -05:00

742 lines
33 KiB
Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" to train hierarchical VAE model with single prior """
import os
import time
from PIL import Image
import gc
import psutil
import functools
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
import numpy as np
from loguru import logger
import torch.distributed as dist
from torch import optim
from trainers.base_trainer import BaseTrainer
from utils.ema import EMA
from utils.model_helper import import_model, loss_fn
from utils.vis_helper import visualize_point_clouds_3d
from utils.eval_helper import compute_NLL_metric
from utils import model_helper, exp_helper, data_helper
from utils.data_helper import normalize_point_clouds
from utils.diffusion_pvd import DiffusionDiscretized
from utils.diffusion_continuous import make_diffusion, DiffusionBase
from utils.checker import *
from utils import utils
from matplotlib import pyplot as plt
import third_party.pvcnn.functional as pvcnn_fn
from timeit import default_timer as timer
from torch.optim import Adam as FusedAdam
from torch.cuda.amp import autocast, GradScaler
from trainers import common_fun_prior_train
@torch.no_grad()
def generate_samples_vada(shape, dae, diffusion, vae, num_samples,
                          enable_autocast, ode_eps=0.00001, ode_solver_tol=1e-5,  # None,
                          ode_sample=False, prior_var=1.0, temp=1.0, vae_temp=1.0,
                          noise=None, need_denoise=False, ddim_step=0, clip_feat=None):
    """Sample latents from the prior ``dae`` and decode them with ``vae``.

    Args:
        shape: latent shape passed to the diffusion sampler.
        dae: the prior (denoising) network.
        diffusion: ``DiffusionBase`` instance when ``ode_sample == 1``,
            ``DiffusionDiscretized`` instance when ``ode_sample == 0``.
        vae: model providing ``decompose_eps`` and ``sample``.
        num_samples: number of shapes to draw.
        enable_autocast: forwarded to the diffusion sampler.
        ode_eps, ode_solver_tol: integration cutoff / tolerance (ODE sampling only).
        ode_sample: 1 -> ODE-based sampling, 0 -> ancestral (or DDIM) sampling.
        prior_var: prior variance forwarded to ancestral / DDIM sampling.
        temp: sampling temperature forwarded to the diffusion sampler.
        noise: initial noise for ODE sampling; must be None for ancestral sampling.
        ddim_step: when > 0 (and ``ode_sample == 0``), use DDIM with this many steps.
        vae_temp, need_denoise, clip_feat: accepted for interface compatibility;
            not used in this function.

    Returns:
        ``(image, nfe, time_ode_solve, sampling_time, output)`` where the three
        middle entries are scalar tensors on the same device as the sampled
        latents, and ``output`` holds the sampled latents and their statistics.

    Raises:
        NotImplementedError: if ``ode_sample`` is neither 0 nor 1.
    """
    output = {}
    if ode_sample == 1:
        assert isinstance(
            diffusion, DiffusionBase), 'ODE-based sampling requires cont. diffusion!'
        assert ode_eps is not None, 'ODE-based sampling requires integration cutoff ode_eps!'
        assert ode_solver_tol is not None, 'ODE-based sampling requires ode solver tolerance!'
        start = timer()
        eps, eps_list, nfe, time_ode_solve = diffusion.sample_model_ode(
            dae, num_samples, shape, ode_eps,
            ode_solver_tol, enable_autocast, temp, noise, return_all_sample=True)
        output['sampled_eps'] = eps
        output['eps_list'] = eps_list
        logger.info('ode_eps={}', ode_eps)
    elif ode_sample == 0:
        assert isinstance(
            diffusion, DiffusionDiscretized), 'Regular sampling requires disc. diffusion!'
        assert noise is None, 'Noise is not used in ancestral sampling.'
        nfe = diffusion._diffusion_steps
        # ancestral sampling has no ODE solve; use a sentinel value for the solve time
        time_ode_solve = 999.999
        start = timer()
        if ddim_step > 0:
            eps, eps_list = diffusion.run_ddim(dae,
                                               num_samples, shape, temp, enable_autocast,
                                               is_image=False, prior_var=prior_var, ddim_step=ddim_step)
        else:
            eps, eps_list = diffusion.run_denoising_diffusion(dae,
                                                              num_samples, shape, temp, enable_autocast,
                                                              is_image=False, prior_var=prior_var)
        output['sampled_eps'] = eps  # latent pts
        output['eps_list'] = eps_list
    else:
        raise NotImplementedError(
            f'ode_sample must be 0 (ancestral) or 1 (ODE), got {ode_sample!r}')
    output['print/sample_mean_global'] = eps.view(
        num_samples, -1).mean(-1).mean()
    output['print/sample_var_global'] = eps.view(
        num_samples, -1).var(-1).mean()
    decomposed_eps = vae.decompose_eps(eps)
    image = vae.sample(num_samples=num_samples, decomposed_eps=decomposed_eps)
    end = timer()
    # sampling time covers both the latent sampling and the VAE decoding
    sampling_time = end - start
    # Report the statistics as tensors on the samples' device (no cross-GPU
    # reduction is performed here); this also works on CPU-only runs.
    device = eps.device
    nfe_torch = torch.tensor(nfe * 1.0, device=device)
    sampling_time_torch = torch.tensor(sampling_time * 1.0, device=device)
    time_ode_solve_torch = torch.tensor(time_ode_solve * 1.0, device=device)
    return image, nfe_torch, time_ode_solve_torch, sampling_time_torch, output
@torch.no_grad()
def validate_inspect(latent_shape,
                     model, dae, diffusion, ode_sample,
                     it, writer,
                     sample_num_points, num_samples,
                     autocast_train=False,
                     need_sample=1, need_val=1, need_train=1,
                     w_prior=None, val_x=None, tr_x=None,
                     val_input=None,
                     m_pcs=None, s_pcs=None,
                     test_loader=None,  # can be None
                     has_shapelatent=False, vis_latent_point=False,
                     ddim_step=0, epoch=0, fun_generate_samples_vada=None, clip_feat=None,
                     cls_emb=None, cfg=None):
    """Visualize prior samples and validation reconstructions in ``writer``.

    Args:
        has_shapelatent (bool): must be True (asserted).
        it (int): step index used for the logged images.
        num_samples: ignored; the number of samples is ``w_prior.shape[0]``
            when ``need_sample`` is truthy, otherwise 0.
        need_*: draw samples / validation recon / train recon or not. The
            assertion below requires ``need_val > 0`` and ``need_train == 0``.
        w_prior, val_x, tr_x: must not be None (asserted).
        val_input: optional model input for reconstruction (defaults to ``val_x``).
        cls_emb: optional class embedding forwarded to ``model.recont``.
        cfg: config; ``cfg.viz.viz_order`` is read.
        sample_num_points, m_pcs, s_pcs, test_loader, epoch: unused here.

    Returns:
        the output dict from ``fun_generate_samples_vada`` (empty if no
        samples were drawn).
    """
    assert has_shapelatent
    assert w_prior is not None and val_x is not None and tr_x is not None
    num_samples = w_prior.shape[0] if need_sample else 0
    num_recon = val_x.shape[0]
    num_recon_val = num_recon if need_val else 0
    num_recon_train = num_recon if need_train else 0
    assert need_sample >= 0 and need_val > 0 and need_train == 0

    output_dict = {}
    if need_sample:
        # gen_x: B,N,3
        gen_x, nstep, ode_time, sample_time, output_dict = fun_generate_samples_vada(
            latent_shape, dae, diffusion, model, w_prior.shape[0],
            enable_autocast=autocast_train, ode_sample=ode_sample,
            ddim_step=ddim_step, clip_feat=clip_feat)
        logger.info('cast={}, sample step={}, ode_time={}, sample_time={}',
                    autocast_train,
                    nstep if ddim_step == 0 else ddim_step,
                    ode_time, sample_time)
    vis_args = {'vis_order': cfg.viz.viz_order}

    # vis the samples (and, optionally, the sampled latent points below them)
    if num_samples > 0:
        sample_imgs = []
        latent_imgs = []
        if vis_latent_point:
            eps = output_dict['sampled_eps'].view(
                num_samples, dae.num_points, dae.num_classes)[:, :, :3]
        for i in range(num_samples):
            points = normalize_point_clouds([gen_x[i]])[0]  # N,3
            sample_imgs.append(visualize_point_clouds_3d([points], **vis_args))
            if vis_latent_point:
                latent_pts = normalize_point_clouds([eps[i]])[0]
                latent_imgs.append(
                    visualize_point_clouds_3d([latent_pts], **vis_args))
        img = np.concatenate(sample_imgs, axis=2)
        if vis_latent_point:
            img = np.concatenate(
                [img, np.concatenate(latent_imgs, axis=2)], axis=1)
        writer.add_image('sample', torch.as_tensor(img), it)

    logger.info('call recont')
    inputs = val_input if val_input is not None else val_x
    if cls_emb is None:
        output = model.recont(inputs)
    else:
        output = model.recont(inputs, cls_emb=cls_emb)
    rec_x = output['final_pred']

    # vis the recont: ground truth row above the reconstruction row
    if num_recon_val > 0:
        rec_imgs = []
        for i in range(num_recon_val):
            points = normalize_point_clouds([rec_x[i]])
            rec_imgs.append(visualize_point_clouds_3d(
                points, ['rec#%d' % i], **vis_args))
        gt_imgs = []
        for i in range(num_recon_val):
            points = normalize_point_clouds([val_x[i]])
            gt_imgs.append(visualize_point_clouds_3d(
                points, ['gt#%d' % i], **vis_args))
        img = np.concatenate([np.concatenate(gt_imgs, axis=2),
                              np.concatenate(rec_imgs, axis=2)], axis=1)
        if 'vis/latent_pts' in output:
            # also vis the input, used when we take voxel points as input
            input_imgs = []
            for i in range(num_recon_val):
                points = normalize_point_clouds(
                    [output['vis/latent_pts'][i, :, :3]])
                input_imgs.append(visualize_point_clouds_3d(
                    points, ['input#%d' % i], **vis_args))
            img = np.concatenate(
                [img, np.concatenate(input_imgs, axis=2)], axis=1)
        writer.add_image('valrecont', torch.as_tensor(img), it)

    if num_recon_train > 0:
        # NOTE(review): unreachable while the assertion above requires
        # need_train == 0. recont() only ran on the validation inputs, so
        # rec_x[num_recon_val + i] would be out of range if enabled.
        train_imgs = []
        for i in range(num_recon_train):
            points = normalize_point_clouds([tr_x[i], rec_x[num_recon_val + i]])
            train_imgs.append(visualize_point_clouds_3d(
                points, ['ori', 'rec'], **vis_args))
        img = np.concatenate(train_imgs, axis=2)
        writer.add_image('train/recont', torch.as_tensor(img), it)
    logger.info('writer: {}', writer.url)
    return output_dict
class Trainer(BaseTrainer):
    """Train a latent diffusion prior (``self.dae``) on top of a point-cloud VAE.

    ``self.model`` (also exposed as ``self.vae``) encodes training point clouds
    into latents ``eps``. ``self.dae`` is optimized on those latents using the
    continuous (``self.diffusion_cont``) or discretized (``self.diffusion_disc``)
    diffusion built in :meth:`build_prior`. In this class the VAE training path
    raises ``NotImplementedError`` (see :meth:`compute_loss_vae`).
    """
    is_diffusion = 0

    def __init__(self, cfg, args):
        """Build the VAE, the prior, the data loaders and the optimizers.

        Args:
            cfg: training config
            args: used for distributed training; ``args.distributed`` and
                ``args.pretrained`` are read here.
        """
        super().__init__(cfg, args)
        self.draw_sample_when_vis = 1
        self.fun_generate_samples_vada = functools.partial(
            generate_samples_vada, ode_eps=cfg.sde.ode_eps)
        self.train_iter_kwargs = {}
        self.cfg.sde.distributed = args.distributed
        self.sample_num_points = cfg.data.tr_max_sample_points
        self.model_var_type = cfg.ddpm.model_var_type
        self.clip_denoised = cfg.ddpm.clip_denoised
        self.num_steps = cfg.ddpm.num_steps
        self.model_mean_type = cfg.ddpm.model_mean_type
        self.loss_type = cfg.ddpm.loss_type
        device = torch.device(self.device_str)
        self.model = self.build_model().to(device)

        vae_ckpt_path = self.cfg.sde.vae_checkpoint
        # if has pretrained ckpt, we dont need to load the vae ckpt anymore
        if len(vae_ckpt_path) and not args.pretrained and vae_ckpt_path != 'none':
            logger.info('Load vae_checkpoint: {}', vae_ckpt_path)
            # Load onto CPU (as done for the DAE checkpoint); load_state_dict
            # copies the values onto the model's device.
            vae_ckpt = torch.load(vae_ckpt_path, map_location='cpu')
            self.model.load_state_dict(vae_ckpt['model'])
        if self.cfg.shapelatent.model == 'models.hvae_ddpm':
            self.model.build_other_module(device)
        logger.info('broadcast_params: device={}', device)
        utils.broadcast_params(self.model.parameters(), args.distributed)

        self.build_other_module()
        self.build_prior()
        if args.distributed:
            logger.info('waiting for barrier, device={}', device)
            dist.barrier()
            logger.info('pass barrier, device={}', device)
        self.train_loader, self.test_loader = self.build_data()
        # The optimizer
        self.init_optimizer()
        # Prepare variable for summary
        self.num_points = self.cfg.data.tr_max_sample_points
        logger.info('done init trainer @{}', device)
        # Prepare for evaluation: init the latent for validate
        self.prepare_vis_data()
        self.alpha_i = utils.kl_balancer_coeff(
            num_scales=2,
            groups_per_scale=[1, 1], fun='square')

    @property
    def vae(self):
        """Alias of ``self.model``."""
        return self.model

    def init_optimizer(self):
        """Create optimizers, schedulers, spectral-norm calculators and the grad scaler."""
        out_dict = common_fun_prior_train.init_optimizer_train_2prior(
            self.cfg, self.vae, self.dae)
        self.dae_sn_calculator = out_dict['dae_sn_calculator']
        self.vae_sn_calculator = out_dict['vae_sn_calculator']
        self.vae_scheduler = out_dict['vae_scheduler']
        self.vae_optimizer = out_dict['vae_optimizer']
        self.dae_scheduler = out_dict['dae_scheduler']
        self.dae_optimizer = out_dict['dae_optimizer']
        self.grad_scalar = out_dict['grad_scalar']

    def resume(self, path, strict=True, **kwargs):
        """Restore DAE/VAE weights, optimizers, schedulers and the grad scaler.

        Args:
            path: checkpoint file in the format written by :meth:`save`.
            strict: unused; kept for interface compatibility.
        Returns:
            the epoch stored in the checkpoint (also stored in ``self.epoch``).
        Raises:
            NotImplementedError: if ``cfg.eval.load_other_vae_ckpt`` is set.
        """
        checkpoint = torch.load(path, map_location='cpu')
        start_epoch = checkpoint['epoch']
        # load dae
        self.dae.load_state_dict(checkpoint['dae_state_dict'])
        self.dae.cuda()
        self.dae_optimizer.load_state_dict(checkpoint['dae_optimizer'])
        self.dae_scheduler.load_state_dict(checkpoint['dae_scheduler'])
        # load vae
        if self.cfg.eval.load_other_vae_ckpt:
            raise NotImplementedError
        self.vae.load_state_dict(checkpoint['vae_state_dict'])
        self.vae_optimizer.load_state_dict(checkpoint['vae_optimizer'])
        self.vae.cuda()
        # comment this out if loading a regular vae from the voxel2input_ada trainer
        self.vae_scheduler.load_state_dict(checkpoint['vae_scheduler'])
        self.grad_scalar.load_state_dict(checkpoint['grad_scalar'])
        self.epoch = start_epoch
        self.step = checkpoint['global_step']
        logger.info('resumed from : {}, epo={}', path, start_epoch)
        return start_epoch

    def save(self, save_name=None, epoch=None, step=None, appendix=None, save_dir=None, **kwargs):
        """Write a checkpoint to ``<save_dir>/checkpoints/<save_name>``.

        The stored ``'epoch'`` is ``epoch + 1``. ``appendix`` entries are merged
        into the saved dict. ``save_dir`` defaults to ``cfg.save_dir``.
        Returns the written path.
        """
        content = {'epoch': epoch + 1, 'global_step': step,
                   'grad_scalar': self.grad_scalar.state_dict(),
                   'dae_state_dict': self.dae.state_dict(),
                   'dae_optimizer': self.dae_optimizer.state_dict(),
                   'dae_scheduler': self.dae_scheduler.state_dict(),
                   'vae_state_dict': self.vae.state_dict(),
                   'vae_optimizer': self.vae_optimizer.state_dict(),
                   'vae_scheduler': self.vae_scheduler.state_dict()}
        if appendix is not None:
            content.update(appendix)
        if save_name is None:
            save_name = "epoch_%s_iters_%s.pt" % (epoch, step)
        if save_dir is None:
            save_dir = self.cfg.save_dir
        path = os.path.join(save_dir, "checkpoints", save_name)
        # exist_ok avoids a race if several processes create the directory
        os.makedirs(os.path.dirname(path), exist_ok=True)
        logger.info('save model as : {}', path)
        torch.save(content, path)
        return path

    def epoch_start(self, epoch):
        """Step both LR schedulers once the warmup epochs are over."""
        if epoch > self.cfg.sde.warmup_epochs:
            self.dae_scheduler.step()
            self.vae_scheduler.step()

    def compute_loss_vae(self, tr_pts, global_step, **kwargs):
        """ compute forward for VAE model, used in global-only prior training
        Input:
            tr_pts: points
            global_step: int (unused)
        Returns:
            output dict including entry:
            'eps': z ~ posterior, made 4D
            'q_loss': torch.zeros(1) (the VAE is not trained here)
            'x_0_pred', 'x_0_target', 'x_0', 'final_pred': the input points
        Raises:
            NotImplementedError: when ``cfg.sde.ode_sample == 2`` or when
                ``cfg.sde.train_vae`` is set.
        """
        sde_args = self.cfg.sde
        if sde_args.ode_sample == 2:
            raise NotImplementedError
        if sde_args.train_vae:
            raise NotImplementedError

        def make_4d(x):
            # B,D -> B,D,1,1 ; otherwise append one trailing axis
            return x.unsqueeze(-1).unsqueeze(-1) if len(x.shape) == 2 else x.unsqueeze(-1)

        with torch.set_grad_enabled(sde_args.train_vae):
            with autocast(enabled=sde_args.autocast_train):
                # posterior (named to avoid shadowing torch.distributed as `dist`)
                posterior = self.model.encode(tr_pts)
                eps = posterior.sample()[0]  # B,D or B,N,D or BN,D
                eps = make_4d(eps)
                output = {'eps': eps, 'q_loss': torch.zeros(1),
                          'x_0_pred': tr_pts, 'x_0_target': tr_pts,
                          'x_0': tr_pts, 'final_pred': tr_pts}
        return output

    # ------------------------------------------- #
    #   training fun                              #
    # ------------------------------------------- #
    def train_iter(self, data, *args, **kwargs):
        """ forward one iteration; and step optimizer
        Args:
            data: (dict) tr_points shape: (B,N,3); optional 'input_pts'
            step (kwarg): global step used for LR update and logging
        Returns:
            dict with 'loss' (float), 'x_0_pred', 'x_0', 'x_t', 't', and any
            'vis/' or 'msg/' entries of the VAE output.
        """
        vae = self.model
        dae = self.dae
        dae.train()
        diffusion = self.diffusion_cont if self.cfg.sde.ode_sample else self.diffusion_disc
        dae_optimizer = self.dae_optimizer
        vae_optimizer = self.vae_optimizer
        # named sde_args so it does not clobber the *args parameter
        sde_args = self.cfg.sde
        device = torch.device(self.device_str)
        distributed = self.args.distributed
        grad_scalar = self.grad_scalar
        global_step = step = kwargs.get('step', None)

        # update_lr
        warmup_iters = len(self.train_loader) * sde_args.warmup_epochs
        utils.update_lr(sde_args, global_step, warmup_iters,
                        dae_optimizer, vae_optimizer)

        # input
        tr_pts = data['tr_points'].to(device)  # (B, Npoints, 3)
        # the noisy points, used in trainers/voxel2pts.py and trainers/voxel2pts_ada.py
        inputs = data['input_pts'].to(device) if 'input_pts' in data else None
        B = batch_size = tr_pts.size(0)

        # optimize vae params
        vae_optimizer.zero_grad()
        output = self.compute_loss_vae(tr_pts, global_step, inputs=inputs)
        # reported loss when neither network is trained in this iteration
        loss = torch.zeros(())

        # backpropagate q_loss for vae and update vae params, if trained
        if sde_args.train_vae:
            q_loss = output['q_loss']
            loss = q_loss
            grad_scalar.scale(q_loss).backward()
            utils.average_gradients(vae.parameters(), distributed)
            if sde_args.grad_clip_max_norm > 0.:  # apply gradient clipping
                grad_scalar.unscale_(vae_optimizer)
                torch.nn.utils.clip_grad_norm_(vae.parameters(),
                                               max_norm=sde_args.grad_clip_max_norm)
            grad_scalar.step(vae_optimizer)

        # train prior
        if sde_args.train_dae:
            # the interface between VAE and DAE is eps.
            eps = output['eps'].detach()  # 4d: B,D,-1,1
            CHECK4D(eps)
            dae_optimizer.zero_grad()
            with autocast(enabled=sde_args.autocast_train):
                noise_p = torch.randn(size=eps.size(), device=device)
                # get diffusion quantities for p sampling scheme and reweighting for q
                t_p, var_t_p, m_t_p, obj_weight_t_p, _, g2_t_p = \
                    diffusion.iw_quantities(B, sde_args.time_eps,
                                            sde_args.iw_sample_p, sde_args.iw_subvp_like_vp_sde)
                eps_t_p = diffusion.sample_q(eps, noise_p, var_t_p, m_t_p)
                # run the score model
                eps_t_p.requires_grad_(True)
                mixing_component = diffusion.mixing_component(
                    eps_t_p, var_t_p, t_p, enabled=sde_args.mixed_prediction)
                pred_params_p = dae(eps_t_p, t_p, x0=eps)
                params = utils.get_mixed_prediction(sde_args.mixed_prediction,
                                                    pred_params_p, dae.mixing_logit, mixing_component)
                if self.cfg.latent_pts.pvd_mse_loss:
                    p_loss = F.mse_loss(
                        params.contiguous().view(B, -1), noise_p.view(B, -1),
                        reduction='mean')
                else:
                    l2_term_p = torch.square(params - noise_p)
                    p_objective = torch.sum(
                        obj_weight_t_p * l2_term_p, dim=[1, 2, 3])
                    regularization_p, dae_norm_loss, dae_bn_loss, dae_wdn_coeff, \
                        _jac_reg_loss, _kin_reg_loss = utils.dae_regularization(
                            sde_args, self.dae_sn_calculator, diffusion, dae, step, t_p,
                            pred_params_p, eps_t_p, var_t_p, m_t_p, g2_t_p)
                    if sde_args.regularize_mlogit:
                        reg_mlogit = ((torch.sum(torch.sigmoid(dae.mixing_logit)) -
                                       sde_args.regularize_mlogit_margin)**2) * sde_args.regularize_mlogit
                    else:
                        reg_mlogit = 0
                    p_loss = torch.mean(p_objective) + \
                        regularization_p + reg_mlogit
            loss = p_loss
            # update dae parameters
            grad_scalar.scale(p_loss).backward()
            utils.average_gradients(dae.parameters(), distributed)
            if sde_args.grad_clip_max_norm > 0.:  # apply gradient clipping
                grad_scalar.unscale_(dae_optimizer)
                torch.nn.utils.clip_grad_norm_(dae.parameters(),
                                               max_norm=sde_args.grad_clip_max_norm)
            grad_scalar.step(dae_optimizer)

        # update the grad scaler once per iteration, after every optimizer step;
        # GradScaler.update() requires at least one recorded step
        if sde_args.train_vae or sde_args.train_dae:
            grad_scalar.update()
        if sde_args.bound_mlogit:
            dae.mixing_logit.data.clamp_(max=sde_args.bound_mlogit_value)

        # Bookkeeping!
        writer = self.writer
        if writer is not None:
            writer.avg_meter('train/lr_dae', dae_optimizer.state_dict()[
                'param_groups'][0]['lr'], global_step)
            writer.avg_meter('train/lr_vae', vae_optimizer.state_dict()[
                'param_groups'][0]['lr'], global_step)
            if sde_args.train_dae:
                if self.cfg.latent_pts.pvd_mse_loss:
                    writer.avg_meter(
                        'train/p_loss', p_loss.item(), global_step)
                else:
                    writer.avg_meter(
                        'train/p_loss', (p_loss - regularization_p).item(), global_step)
                    if torch.is_tensor(regularization_p):
                        writer.avg_meter(
                            'train/reg_p', regularization_p.item(), global_step)
                    if sde_args.regularize_mlogit:
                        # .item() so the meter does not keep the autograd graph alive
                        writer.avg_meter(
                            'train/m_logit', (reg_mlogit / sde_args.regularize_mlogit).item(),
                            global_step)
                    if sde_args.mixed_prediction:
                        writer.avg_meter(
                            'train/m_logit_sum',
                            torch.sum(torch.sigmoid(dae.mixing_logit)).item(), global_step)
                    if global_step % 500 == 0:
                        writer.add_scalar(
                            'train/norm_loss_dae', dae_norm_loss, global_step)
                        writer.add_scalar('train/bn_loss_dae',
                                          dae_bn_loss, global_step)
                        writer.add_scalar(
                            'train/norm_coeff_dae', dae_wdn_coeff, global_step)
                if sde_args.mixed_prediction and global_step % 500 == 0:
                    m = torch.sigmoid(dae.mixing_logit)
                    if not torch.isnan(m).any():
                        writer.add_histogram(
                            'mixing_prob', m.detach().cpu().numpy(), global_step)
            # write stats
            if step is not None:
                for k, v in output.items():
                    if 'print/' in k:
                        writer.avg_meter(k.split('print/')[-1],
                                         v.mean().item() if torch.is_tensor(v) else v,
                                         step=step)

        output_dict = {
            'loss': loss.detach().cpu().item(),
            'x_0_pred': output['x_0_pred'].detach().cpu(),  # perturbed data
            'x_0': output['x_0'].detach().cpu(),
            'x_t': output['final_pred'].detach().view(batch_size, -1, output['x_0'].shape[-1]),
            't': output.get('t', None),
        }
        output_dict.update(
            {k: v for k, v in output.items() if 'vis/' in k or 'msg/' in k})
        return output_dict

    # --------------------------------------------- #
    #   visulization function and sampling function #
    # --------------------------------------------- #
    def _select_diffusion(self):
        """Return the diffusion object(s) selected by ``cfg.sde.ode_sample``.

        1 -> continuous diffusion, 0 -> discretized diffusion,
        2 -> ``[continuous, discretized]``.
        Raises:
            NotImplementedError: for any other value.
        """
        ode_sample = self.cfg.sde.ode_sample
        if ode_sample == 1:
            return self.diffusion_cont
        if ode_sample == 0:
            return self.diffusion_disc
        if ode_sample == 2:
            return [self.diffusion_cont, self.diffusion_disc]
        raise NotImplementedError(f'unsupported cfg.sde.ode_sample={ode_sample!r}')

    @torch.no_grad()
    def vis_sample(self, writer, num_vis=None, step=0, include_pred_x0=True,
                   save_file=None):
        """Draw prior samples / validation reconstructions and log them.

        EMA parameters (when ``cfg.ddpm.ema``) are swapped in for the duration
        of the visualization and are always swapped back, even on error.
        ``num_vis``, ``include_pred_x0`` and ``save_file`` are unused.
        """
        ode_sample = self.cfg.sde.ode_sample
        diffusion = self._select_diffusion()
        shape = self.model.latent_shape()
        logger.info('Latent shape for prior: {}; num_val_samples: {}',
                    shape, self.num_val_samples)
        if self.cfg.clipforge.enable:
            assert self.clip_feat_test is not None
        use_ema = self.cfg.ddpm.ema
        if use_ema:
            self.swap_vae_param_if_need()
            self.dae_optimizer.swap_parameters_with_ema(
                store_params_in_ema=True)
        try:
            output = validate_inspect(shape, self.model, self.dae,
                                      diffusion, ode_sample,
                                      step, self.writer, self.sample_num_points,
                                      epoch=self.cur_epoch,
                                      autocast_train=self.cfg.sde.autocast_train,
                                      need_sample=self.draw_sample_when_vis,
                                      need_val=1, need_train=0,
                                      num_samples=self.num_val_samples,
                                      test_loader=self.test_loader,
                                      w_prior=self.w_prior,
                                      val_x=self.val_x, tr_x=self.tr_x,
                                      val_input=self.val_input,
                                      m_pcs=self.m_pcs, s_pcs=self.s_pcs,
                                      has_shapelatent=True,
                                      vis_latent_point=self.cfg.vis_latent_point,
                                      ddim_step=self.cfg.viz.vis_sample_ddim_step,
                                      clip_feat=self.clip_feat_test,
                                      cfg=self.cfg,
                                      fun_generate_samples_vada=self.fun_generate_samples_vada,
                                      )
        finally:
            # switch back to the original parameters
            if use_ema:
                self.swap_vae_param_if_need()
                self.dae_optimizer.swap_parameters_with_ema(
                    store_params_in_ema=True)
        if writer is not None:
            # NOTE(review): the `writer` argument is only checked; scalars are
            # written to self.writer, matching the images logged above.
            for n, v in output.items():
                if 'print/' in n:
                    self.writer.add_scalar(n.split('print/')[-1], v, step)

    @torch.no_grad()
    def sample(self, num_shapes=2, num_points=2048, device_str='cuda',
               for_vis=True, use_ddim=False, save_file=None, ddim_step=0, clip_feat=None):
        """ return the final samples in shape [B,3,N]

        EMA parameters (when ``cfg.ddpm.ema``) are swapped in while sampling and
        always swapped back. If ``save_file`` is given, the samples (B,N,3) are
        saved there and the process exits (debug path).
        ``device_str``, ``for_vis`` and ``use_ddim`` are unused.
        """
        assert not self.cfg.clipforge.enable, \
            'not supported yet, not sure what the clip feat will be'
        cfg = self.cfg
        diffusion = self._select_diffusion()
        use_ema = cfg.ddpm.ema
        # switch to EMA parameters
        if use_ema:
            self.swap_vae_param_if_need()
            self.dae_optimizer.swap_parameters_with_ema(
                store_params_in_ema=True)
        try:
            self.model.eval()  # draw samples in eval mode
            logger.info('num_shapes={}, num_points={}, use_ddim={}, Nstep={}',
                        num_shapes, num_points, use_ddim, self.num_steps)
            latent_shape = self.model.latent_shape()
            # ---- forward sampling ---- #
            gen_x, nstep, ode_time, sample_time, output_fsample = \
                self.fun_generate_samples_vada(latent_shape, self.dae,
                                               diffusion, self.model, num_shapes,
                                               enable_autocast=self.cfg.sde.autocast_train,
                                               ode_sample=self.cfg.sde.ode_sample,
                                               need_denoise=self.cfg.eval.need_denoise,
                                               ddim_step=ddim_step,
                                               clip_feat=clip_feat)
            # gen_x: BNC
            CHECKEQ(gen_x.shape[2], self.cfg.ddpm.input_dim)
            if gen_x.shape[1] > self.sample_num_points:
                gen_x = pvcnn_fn.furthest_point_sample(gen_x.permute(0, 2, 1).contiguous(),
                                                       self.sample_num_points).permute(0, 2, 1).contiguous()  # [B,C,npoint]
            traj = gen_x.permute(0, 2, 1).contiguous()  # BN3->B3N
            # ---- debug purpose ---- #
            if save_file:
                save_dir = os.path.dirname(save_file)
                if save_dir:  # a bare filename has no directory to create
                    os.makedirs(save_dir, exist_ok=True)
                torch.save(traj.permute(0, 2, 1), save_file)
                raise SystemExit
        finally:
            # switch back to original parameters
            if use_ema:
                self.dae_optimizer.swap_parameters_with_ema(
                    store_params_in_ema=True)
                self.swap_vae_param_if_need()
        return traj

    def build_prior(self):
        """Create the prior network ``self.dae`` and both diffusion processes,
        optionally loading ``cfg.sde.dae_checkpoint``, then broadcast the DAE
        parameters from rank 0."""
        args = self.cfg.sde
        device = torch.device(self.device_str)
        num_input_channels = self.cfg.shapelatent.latent_dim
        # NOTE(review): NCSNppPointHie / NCSNppPoint are not imported in the
        # visible imports; prior_model == 'sim' would raise NameError unless
        # they come from `utils.checker import *` -- confirm.
        if self.cfg.sde.hier_prior:
            if self.cfg.sde.prior_model == 'sim':
                DAE = NCSNppPointHie
            else:
                DAE = import_model(self.cfg.sde.prior_model)
        elif self.cfg.sde.prior_model == 'sim':
            DAE = NCSNppPoint
        else:
            DAE = import_model(self.cfg.sde.prior_model)
        self.dae = DAE(args, num_input_channels, self.cfg).to(device)
        if len(self.cfg.sde.dae_checkpoint):
            logger.info('Load dae checkpoint: {}',
                        self.cfg.sde.dae_checkpoint)
            checkpoint = torch.load(
                self.cfg.sde.dae_checkpoint, map_location='cpu')
            self.dae.load_state_dict(checkpoint['dae_state_dict'])
        self.diffusion_cont = make_diffusion(args)
        self.diffusion_disc = DiffusionDiscretized(
            args, self.diffusion_cont.var, self.cfg)
        logger.info('DAE: {}', self.dae)
        logger.info('DAE: param size = %fM ' %
                    utils.count_parameters_in_M(self.dae))
        # sync all parameters between all gpus by sending param from rank 0 to all gpus.
        utils.broadcast_params(self.dae.parameters(), self.args.distributed)

    def swap_vae_param_if_need(self):
        """Swap VAE EMA parameters when ``cfg.eval.load_other_vae_ckpt`` is set."""
        # NOTE(review): `self.optimizer` is not assigned in this class (only
        # vae_/dae_optimizer are); presumably provided by BaseTrainer -- confirm.
        if self.cfg.eval.load_other_vae_ckpt:
            self.optimizer.swap_parameters_with_ema(store_params_in_ema=True)