# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
import numpy as np
import time
from loguru import logger
from utils.ema import EMA
from torch import optim
from torch.optim import Adam as FusedAdam
from torch.cuda.amp import autocast, GradScaler
from utils.sr_utils import SpectralNormCalculator
from utils import utils
from utils.vis_helper import visualize_point_clouds_3d
from utils.diffusion_pvd import DiffusionDiscretized
# DiffusionBase is referenced in generate_samples_vada_2prior below but was not
# imported; assumed to live in utils.diffusion_continuous in this codebase
from utils.diffusion_continuous import DiffusionBase
from utils.eval_helper import compute_NLL_metric
from utils import model_helper, exp_helper, data_helper
from timeit import default_timer as timer
from utils.data_helper import normalize_point_clouds


def init_optimizer_train_2prior(cfg, vae, dae, cond_enc=None):
    args = cfg.sde
    param_dict_dae = dae.parameters()
    # optimizer for the prior (dae)
    if args.learning_rate_mlogit > 0:
        raise NotImplementedError
    if args.use_adamax:
        from utils.adamax import Adamax
        dae_optimizer = Adamax(param_dict_dae, args.learning_rate_dae,
                               weight_decay=args.weight_decay, eps=1e-4)
    elif args.use_adam:
        cfgopt = cfg.trainer.opt
        dae_optimizer = optim.Adam(param_dict_dae,
                                   lr=args.learning_rate_dae,
                                   betas=(cfgopt.beta1, cfgopt.beta2),
                                   weight_decay=cfgopt.weight_decay)
    else:
        dae_optimizer = FusedAdam(param_dict_dae, args.learning_rate_dae,
                                  weight_decay=args.weight_decay, eps=1e-4)
    # add EMA functionality to the optimizer
    dae_optimizer = EMA(dae_optimizer, ema_decay=args.ema_decay)
    dae_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        dae_optimizer, float(args.epochs - args.warmup_epochs - 1),
        eta_min=args.learning_rate_min_dae)

    # optimizer for the VAE
    if args.use_adamax:
        from utils.adamax import Adamax
        vae_optimizer = Adamax(vae.parameters(), args.learning_rate_vae,
                               weight_decay=args.weight_decay, eps=1e-3)
    elif args.use_adam:
        cfgopt = cfg.trainer.opt
        # use the base VAE learning rate here; learning_rate_min_vae is only
        # the scheduler floor (eta_min) below, not the initial lr
        vae_optimizer = optim.Adam(vae.parameters(),
                                   lr=args.learning_rate_vae,
                                   betas=(cfgopt.beta1, cfgopt.beta2),
                                   weight_decay=cfgopt.weight_decay)
    else:
        vae_optimizer = FusedAdam(vae.parameters(), args.learning_rate_vae,
                                  weight_decay=args.weight_decay, eps=1e-3)
    vae_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        vae_optimizer, float(args.epochs - args.warmup_epochs - 1),
        eta_min=args.learning_rate_min_vae)

    logger.info('[grad_scalar] enabled={}', args.autocast_train)
    if not args.autocast_train:
        grad_scalar = utils.DummyGradScalar()
    else:
        grad_scalar = GradScaler(2**10, enabled=True)

    # create spectral-norm calculators
    vae_sn_calculator = SpectralNormCalculator()
    dae_sn_calculator = SpectralNormCalculator()
    if args.train_vae:
        # TODO: require using layer in layers/neural_operations
        vae_sn_calculator.add_bn_layers(vae)
    dae_sn_calculator.add_bn_layers(dae)
    return {
        'vae_scheduler': vae_scheduler,
        'vae_optimizer': vae_optimizer,
        'vae_sn_calculator': vae_sn_calculator,
        'dae_scheduler': dae_scheduler,
        'dae_optimizer': dae_optimizer,
        'dae_sn_calculator': dae_sn_calculator,
        'grad_scalar': grad_scalar
    }

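# Illustrative sketch (hypothetical helper, not part of the original trainer):
# one way the dictionary returned by init_optimizer_train_2prior could drive a
# single prior update. It assumes utils.DummyGradScalar mirrors the GradScaler
# scale/step/update interface and that the EMA wrapper behaves as a regular
# optimizer inside GradScaler.step; the real training loop also handles warmup,
# spectral-norm regularization, and gradient clipping.
def _example_prior_step_sketch(opt_dict, loss):
    grad_scalar = opt_dict['grad_scalar']      # GradScaler or DummyGradScalar
    dae_optimizer = opt_dict['dae_optimizer']  # EMA-wrapped optimizer
    dae_optimizer.zero_grad()
    grad_scalar.scale(loss).backward()         # scaling is a no-op when autocast is off
    grad_scalar.step(dae_optimizer)            # skips the step on inf/nan gradients
    grad_scalar.update()
    # the cosine schedule is advanced once per epoch after the warmup phase:
    # opt_dict['dae_scheduler'].step()
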
@torch.no_grad()
def validate_inspect(latent_shape, model, dae, diffusion, ode_sample,
                     it, writer, sample_num_points, num_samples,
                     autocast_train=False,
                     need_sample=1, need_val=1, need_train=1,
                     w_prior=None, val_x=None, tr_x=None, val_input=None,
                     prior_cond=None, m_pcs=None, s_pcs=None,
                     test_loader=None,  # can be None
                     has_shapelatent=False, vis_latent_point=False,
                     ddim_step=0, epoch=0,
                     fun_generate_samples_vada=None,
                     clip_feat=None, cls_emb=None, cfg={}):
    """ visualize the samples, and the reconstructions if needed
    Args:
        has_shapelatent (bool): True when the model has a shape latent
        it (int): step index
        num_samples (int): number of samples to draw (overridden by
            w_prior.shape[0] when need_sample is set)
        need_* (int): whether to draw samples / visualize val / train reconstructions
    """
    assert has_shapelatent
    assert w_prior is not None and val_x is not None and tr_x is not None
    z_list = []
    num_samples = w_prior.shape[0] if need_sample else 0
    num_recon = val_x.shape[0]
    num_recon_val = num_recon if need_val else 0
    num_recon_train = num_recon if need_train else 0
    kwargs = {}
    if cls_emb is not None:
        kwargs['cls_emb'] = cls_emb
    assert need_sample >= 0 and need_val > 0 and need_train == 0

    # draw samples
    if need_sample:
        # gen_x: B,N,3
        gen_x, nstep, ode_time, sample_time, output_dict = \
            fun_generate_samples_vada(latent_shape, dae, diffusion, model,
                                      w_prior.shape[0],
                                      enable_autocast=autocast_train,
                                      prior_cond=prior_cond,
                                      ode_sample=ode_sample,
                                      ddim_step=ddim_step,
                                      clip_feat=clip_feat, **kwargs)
        logger.info('cast={}, sample step={}, ode_time={}, sample_time={}',
                    autocast_train,
                    nstep if ddim_step == 0 else ddim_step,
                    ode_time, sample_time)
        gen_pcs = gen_x
    else:
        output_dict = {}

    rgb_as_normal = not cfg.data.has_color  # if the data has color, rgb is not treated as normals
    vis_order = cfg.viz.viz_order
    vis_args = {'rgb_as_normal': rgb_as_normal,
                'vis_order': vis_order,
                'is_omap': 'omap' in cfg.data.type}

    # visualize the samples
    if not vis_latent_point and num_samples > 0:
        img_list = []
        for i in range(num_samples):
            points = gen_x[i]  # N,3
            points = normalize_point_clouds([points])[0]
            img = visualize_point_clouds_3d([points], **vis_args)
            img_list.append(img)
        img = np.concatenate(img_list, axis=2)
        writer.add_image('sample', torch.as_tensor(img), it)

    if vis_latent_point and num_samples > 0:
        img_list = []
        eps_list = []
        prior_cond_list = []
        eps = output_dict['sampled_eps'].view(
            num_samples, dae.num_points, dae.num_classes)[:, :, :cfg.ddpm.input_dim]
        for i in range(num_samples):
            points = gen_x[i]
            points = normalize_point_clouds([points])[0]
            img = visualize_point_clouds_3d([points], ['samples'], **vis_args)
            img_list.append(img)

            points = eps[i]
            points = normalize_point_clouds([points])[0]
            img = visualize_point_clouds_3d([points], ['eps'], **vis_args)
            eps_list.append(img)

            if prior_cond is not None:
                points = prior_cond[i]
                if len(points.shape) > 2:
                    # points shape is (1,X,Y,Z): recover the occupied voxel coordinates
                    output_voxel_XYZ = points[0].cpu().numpy()  # X,Y,Z
                    coordsid = np.where(output_voxel_XYZ == 1)
                    coordsid = np.stack(coordsid, axis=1)  # N,3
                    points = torch.from_numpy(coordsid)
                    voxel_size = 1.0
                    X, Y, Z = output_voxel_XYZ.shape
                    c = torch.tensor([X, Y, Z]).view(1, 3) * 0.5
                    points = points - c  # center at the origin
                    vis_points = points
                    bound = max(X, Y, Z) * 0.5
                    # logger.info('voxel_size: {}, output_voxel_XYZ: {}, bound: {}',
                    #             voxel_size, output_voxel_XYZ.shape, bound)
                elif vis_args['is_omap']:
                    vis_points = points * s_pcs[i]  # range before normalization
                    bound = s_pcs[0].max().item()
                    voxel_size = cfg.data.voxel_size
                else:
                    vis_points = points
                    voxel_size = cfg.data.voxel_size
                    bound = 1.5  # 2.0
                img = visualize_point_clouds_3d([vis_points], ['cond'], is_voxel=1,
                                                voxel_size=voxel_size, bound=bound,
                                                **vis_args)
                points = normalize_point_clouds([points])[0]
                ## print('points', points.shape, points.numpy().min(0), points.numpy().max(0), points[:3])
                img2 = visualize_point_clouds_3d([points], ['cond_center'], **vis_args)
                img = np.concatenate([img, img2], axis=1)
                prior_cond_list.append(img)
        img = np.concatenate(img_list, axis=2)
        img_eps = np.concatenate(eps_list, axis=2)
        prior_cond_list = np.concatenate(prior_cond_list, axis=2) if len(
            prior_cond_list) else prior_cond_list
        img = np.concatenate([img, img_eps], axis=1)
        img = np.concatenate([img, prior_cond_list], axis=1) if len(
            prior_cond_list) else img
        writer.add_image('sample', torch.as_tensor(img), it)

    inputs = val_input if val_input is not None else val_x
    output = model.recont(inputs) if cls_emb is None else model.recont(
        inputs, cls_emb=cls_emb)
    gen_x = output['final_pred']

    # visualize the reconstructions on the val set
    if num_recon_val > 0:
        img_list = []
        for i in range(num_recon_val):
            points = gen_x[i]
            points = normalize_point_clouds([points])
            img = visualize_point_clouds_3d(points, ['rec#%d' % i], **vis_args)
            img_list.append(img)
        gt_list = []
        for i in range(num_recon_val):
            points = normalize_point_clouds([val_x[i]])
            img = visualize_point_clouds_3d(points, ['gt#%d' % i], **vis_args)
            gt_list.append(img)
        img = np.concatenate(img_list, axis=2)
        gt = np.concatenate(gt_list, axis=2)
        img = np.concatenate([gt, img], axis=1)
        if val_input is not None:
            # also visualize the input, used when we take voxel points as input
            input_list = []
            for i in range(num_recon_val):
                points = val_input[i]
                points = normalize_point_clouds([points])
                input_img = visualize_point_clouds_3d(
                    points, ['input#%d' % i], **vis_args)
                input_list.append(input_img)
            input_list = np.concatenate(input_list, axis=2)
            img = np.concatenate([img, input_list], axis=1)
        writer.add_image('valrecont', torch.as_tensor(img), it)

    # visualize the reconstructions on the train set
    if num_recon_train > 0:
        img_list = []
        for i in range(num_recon_train):
            points = gen_x[num_recon_val + i]
            points = normalize_point_clouds([tr_x[i], points])
            img = visualize_point_clouds_3d(points, ['ori', 'rec'], **vis_args)
            img_list.append(img)
        img = np.concatenate(img_list, axis=2)
        writer.add_image('train/recont', torch.as_tensor(img), it)
    logger.info('writer: {}', writer.url)
    return output_dict

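# Illustrative sketch (not called anywhere in this file): the image-grid
# convention used by validate_inspect above. visualize_point_clouds_3d is
# assumed to return a CHW numpy array, so concatenating on axis=2 tiles panels
# horizontally (one column per sample) while axis=1 stacks rows vertically
# (e.g. ground truth over reconstruction); writer.add_image expects CHW by
# default.
def _example_image_grid_sketch(row1_imgs, row2_imgs):
    row1 = np.concatenate(row1_imgs, axis=2)     # tile samples left-to-right
    row2 = np.concatenate(row2_imgs, axis=2)
    return np.concatenate([row1, row2], axis=1)  # stack the two rows top-to-bottom
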
@torch.no_grad()
def generate_samples_vada_2prior(shape, dae, diffusion, vae, num_samples,
                                 enable_autocast, ode_eps=0.00001,
                                 ode_solver_tol=1e-5, ode_sample=False,
                                 prior_var=1.0, temp=1.0, vae_temp=1.0,
                                 noise=None, need_denoise=False,
                                 prior_cond=None, device=None, cfg=None,
                                 ddim_step=0, clip_feat=None, cls_emb=None):
    """ this function is copied from trainers/train_2prior.py and
        used by trainers/cond_prior.py;
        it should also support trainers/train_2prior.py, but this is not tested yet
    """
    output = {}
    if ode_sample == 1:
        assert isinstance(
            diffusion, DiffusionBase), 'ODE-based sampling requires cont. diffusion!'
        assert ode_eps is not None, 'ODE-based sampling requires integration cutoff ode_eps!'
        assert ode_solver_tol is not None, 'ODE-based sampling requires ode solver tolerance!'
        start = timer()
        condition_input = None
        eps_list = []
        for i in range(2):
            assert cls_emb is None, 'not supported yet'
            eps, nfe, time_ode_solve = diffusion.sample_model_ode(
                dae[i], num_samples, shape[i], ode_eps, ode_solver_tol,
                enable_autocast, temp, noise,
                condition_input=condition_input,
                clip_feat=clip_feat,
            )
            condition_input = eps
            eps_list.append(eps)
        output['sampled_eps'] = eps
        eps = vae.compose_eps(eps_list)  # torch.cat(eps, dim=1)
    elif ode_sample == 0:
        assert isinstance(
            diffusion, DiffusionDiscretized), 'Regular sampling requires disc. diffusion!'
        assert noise is None, 'Noise is not used in ancestral sampling.'
        nfe = diffusion._diffusion_steps
        time_ode_solve = 999.999  # dummy value; ancestral sampling does not use an ODE solver
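        # The two priors are sampled in sequence: the first pass (i == 0)
        # denoises the global shape latent, which is mapped through
        # vae.global2style and fed as condition_input to the second pass
        # (i == 1), which denoises the point-level latent. When
        # cfg.data.cond_on_voxel is set, dae[2] encodes the voxel condition
        # into a global embedding (used by both passes) and a grid embedding
        # (used only by the second pass).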
        start = timer()
        dae_kwarg = {'is_image': False, 'prior_var': prior_var}
        dae_kwarg['clip_feat'] = clip_feat
        if cfg.data.cond_on_voxel:
            output['prior_cond'] = prior_cond
            # embed the voxel condition
            voxel_grid_enc_out = dae[2](prior_cond.to(device))
            condition_input = voxel_grid_enc_out['global_emb']
        else:
            condition_input = None if cls_emb is None else cls_emb
        all_eps = []
        for i in range(2):
            if i == 1 and cfg.data.cond_on_voxel:
                dae_kwarg['grid_emb'] = voxel_grid_enc_out['grid_emb']
            if ddim_step > 0:
                assert cls_emb is None, 'not supported yet'
                eps, eps_list = diffusion.run_ddim(dae[i], num_samples, shape[i],
                                                   temp, enable_autocast,
                                                   ddim_step=ddim_step,
                                                   condition_input=condition_input,
                                                   skip_type=cfg.sde.ddim_skip_type,
                                                   kappa=cfg.sde.ddim_kappa,
                                                   dae_index=i, **dae_kwarg)
            else:
                eps, eps_list = diffusion.run_denoising_diffusion(dae[i], num_samples,
                                                                  shape[i], temp,
                                                                  enable_autocast,
                                                                  condition_input=condition_input,
                                                                  **dae_kwarg)
            condition_input = eps
            if cls_emb is not None:
                condition_input = torch.cat(
                    [condition_input, cls_emb.unsqueeze(-1).unsqueeze(-1)], dim=1)
            if i == 0:
                condition_input = vae.global2style(condition_input)
            all_eps.append(eps)
        output['sampled_eps'] = eps
        eps = vae.compose_eps(all_eps)
        output['eps_list'] = eps_list
    output['print/sample_mean_global'] = eps.view(
        num_samples, -1).mean(-1).mean()
    output['print/sample_var_global'] = eps.view(
        num_samples, -1).var(-1).mean()
    decomposed_eps = vae.decompose_eps(eps)
    image = vae.sample(num_samples=num_samples,
                       decomposed_eps=decomposed_eps, cls_emb=cls_emb)
    end = timer()
    sampling_time = end - start

    # average over GPUs
    nfe_torch = torch.tensor(nfe * 1.0, device='cuda')
    sampling_time_torch = torch.tensor(sampling_time * 1.0, device='cuda')
    time_ode_solve_torch = torch.tensor(time_ode_solve * 1.0, device='cuda')
    return image, nfe_torch, time_ode_solve_torch, sampling_time_torch, output

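# Illustrative sketch (hypothetical caller): how generate_samples_vada_2prior
# might be invoked for ancestral sampling. The latent shapes and cfg fields are
# assumptions based on the signatures above; dae is list-like with the global
# prior at index 0, the local prior at index 1, and (when
# cfg.data.cond_on_voxel is set) a voxel-condition encoder at index 2.
def _example_generate_sketch(latent_shape, dae, diffusion, vae, cfg):
    gen_x, nfe, ode_time, sample_time, out = generate_samples_vada_2prior(
        latent_shape, dae, diffusion, vae,
        num_samples=4,
        enable_autocast=False,
        ode_sample=0,   # 0: ancestral sampling with DiffusionDiscretized
        device='cuda',
        cfg=cfg,
        ddim_step=0)    # 0: full denoising chain; >0: DDIM with that many steps
    logger.info('sampled {} shapes in {:.2f} sec, nfe={}',
                gen_x.shape[0], sample_time.item(), int(nfe.item()))
    return gen_x, out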