From 90fe6410545f3bcfe714c853d550591b00606286 Mon Sep 17 00:00:00 2001
From: xzeng
Date: Wed, 1 Feb 2023 21:09:13 -0500
Subject: [PATCH] support interpolation

---
 script/interpolate.sh            |   7 +
 script/interpolate_posterior.sh  |   7 +
 trainers/encode_interp_interp.py | 322 +++++++++++++++++++++++++++++++
 trainers/interpolate_latent.py   | 264 +++++++++++++++++++++++++
 4 files changed, 600 insertions(+)
 create mode 100644 script/interpolate.sh
 create mode 100644 script/interpolate_posterior.sh
 create mode 100644 trainers/encode_interp_interp.py
 create mode 100644 trainers/interpolate_latent.py

diff --git a/script/interpolate.sh b/script/interpolate.sh
new file mode 100644
index 0000000..10b1e74
--- /dev/null
+++ b/script/interpolate.sh
@@ -0,0 +1,7 @@
+NP=2048
+model="./lion_ckpt/unconditional/chair/checkpoints/model.pt"
+python train_dist.py --eval_generation --pretrained $model --skip_nll 1 \
+    data.batch_size_test 32 ddpm.ema 1 trainer.type trainers.interpolate_latent num_val_samples 20 trainer.seed 2 sde.ode_sample 1 \
+    sde.beta_end 20.0 sde.embedding_scale 1000.0 \
+    data.tr_max_sample_points ${NP} data.te_max_sample_points ${NP} shapelatent.decoder_num_points ${NP}
+
diff --git a/script/interpolate_posterior.sh b/script/interpolate_posterior.sh
new file mode 100644
index 0000000..257f82d
--- /dev/null
+++ b/script/interpolate_posterior.sh
@@ -0,0 +1,7 @@
+NP=2048
+model="./lion_ckpt/unconditional/chair/checkpoints/model.pt"
+python train_dist.py --eval_generation --pretrained $model --skip_nll 0 \
+    data.batch_size_test 32 ddpm.ema 1 trainer.type trainers.encode_interp_interp num_val_samples 20 trainer.seed 2 sde.ode_sample 1 \
+    sde.beta_end 20.0 sde.embedding_scale 1000.0 \
+    data.tr_max_sample_points ${NP} data.te_max_sample_points ${NP} shapelatent.decoder_num_points ${NP}
+
diff --git a/trainers/encode_interp_interp.py b/trainers/encode_interp_interp.py
new file mode 100644
index 0000000..4611fc7
--- /dev/null
+++ b/trainers/encode_interp_interp.py
@@ -0,0 +1,322 @@
+"""Posterior interpolation: encode two input shapes with the VAE encoders,
+invert both levels of latents through the prior ODE, spherically interpolate
+the resulting noise, and decode the interpolated latents into point clouds.
+"""
+import os
+import time
+import random
+import torch
+import torchvision
+import torch.nn.functional as F
+import torch.nn as nn
+import numpy as np
+from loguru import logger
+from torch import optim
+from utils.ema import EMA
+from utils import eval_helper, exp_helper
+from utils.sr_utils import SpectralNormCalculator
+from utils.checker import *
+from utils.utils import AvgrageMeter
+from utils import utils
+from torch.optim import Adam as FusedAdam
+from torch.cuda.amp import autocast, GradScaler
+from trainers.train_2prior import Trainer as BaseTrainer
+from utils.data_helper import normalize_point_clouds
+from utils.vis_helper import visualize_point_clouds_3d
+from utils.eval_helper import compute_NLL_metric
+
+CHECKPOINT = int(os.environ.get('CHECKPOINT', 0))
+EVAL_LNS = int(os.environ.get('EVAL_LNS', 0))
+
+
+def interpolate_noise(noise):
+    """Fill noise[1:-1] with spherical interpolations of noise[0] and noise[-1];
+    the sqrt weights keep each interpolant unit-variance under the prior."""
+    noise_a = noise[0].contiguous()   # first endpoint, shape 1,D,1,1
+    noise_b = noise[-1].contiguous()  # last endpoint, shape 1,D,1,1
+    for k in range(1, noise.shape[0] - 1):
+        p = float(k) / len(noise)  # interpolation weight, e.g. 1/8 .. 7/8
+        # linear alternative: noise[k] = p * noise_b + (1-p) * noise_a
+        noise[k] = np.sqrt(p) * noise_b + np.sqrt(1 - p) * noise_a
+    return noise
+
+
+#def get_data(num_points=3072, num_selected=50):
+#    import h5py
+#    import torch
+#    import sys
+#    import os
+#    from datasets.data_path import get_path
+#
+#    synsetid_to_cate = {
+#        "02691156": "airplane",
+#        "02828884": "bench",
+#        "02958343": "car",
+#        "03001627": "chair",
"03211117": "display", +# "03636649": "lamp", +# "04256520": "sofa", +# "04379243": "table", +# "04530566": "watercraft", +# } +# +# +# root_dir = get_path('pointflow') +# out_dict = [] +# out_name = [] +# for synset_id in synsetid_to_cate.keys(): +# filename = os.path.join(root_dir, f'{synset_id}', 'train') +# tr_out_full = [] +# for file in sorted(os.listdir(filename))[:10]: +# pts = np.load(filename + '/' + file) +# tr_out_full.append(pts) +# tr_out_full = np.stack(tr_out_full) +# #with h5py.File(filename, "r") as h5f: +# # tr_out_full = h5f['surface_points/points'][5:5+num_selected] +# if tr_out_full.shape[1] > num_points: +# pt_cur = [] +# for b in range(tr_out_full.shape[0]): +# tr_idxs = np.random.choice(tr_out_full.shape[1], num_points) +# pt_cur.append(tr_out_full[b, tr_idxs]) ## pt_cur[tr_idxs] +# tr_out_full = torch.from_numpy(np.stack(pt_cur)) +# # out_dict[synsetid_to_cate[synset_id]] = tr_out_full +# out_dict.append(tr_out_full) +# out_name.extend([synsetid_to_cate[synset_id]] * num_selected) +# out_dict = torch.cat(out_dict) +# logger.info('created data: {}, ', out_dict.shape) +# return out_dict, out_name +# +#def fun_hash(a, b): +# if a < b: +# return f'{a}-{b}' +# else: +# return f'{b}-{a}' +# +#def fun_select_count_pair(names, a, b): +# select_all_pair = [] +# indexes = list( np.arange(len(names)) ) +# indexes_a = [i for ii, i in enumerate(indexes) if names[ii] == a] +# indexes_b = [i for ii, i in enumerate(indexes) if names[ii] == b] +# hash_d = [] +# logger.info('select paits for: {}, {}', a, b) +# for ai in indexes_a: +# for bi in indexes_b: +# if ai != bi and fun_hash(ai, bi) not in hash_d: +# hash_d.append(fun_hash(ai, bi)) +# select_all_pair.append([ai, bi]) +# logger.info('get pairs: {}', len(select_all_pair)) +# return select_all_pair +# +class Trainer(BaseTrainer): + is_diffusion = 0 + generate_mode_global = 'interpolate' + generate_mode_local = 'interpolate' + def __init__(self, cfg, args): + """ + Args: + cfg: training config + args: used for distributed training + """ + super().__init__(cfg, args) + self.draw_sample_when_vis = 0 + + @torch.no_grad() + def eval_sample(self, step=0): + pass # do nothing + + # -- shared method for all model with vae component -- # + @torch.no_grad() + def eval_nll(self, step, ntest=None, save_file=False): + cfg = self.cfg + if cfg.ddpm.ema: self.swap_vae_param_if_need() + self.dae_optimizer.swap_parameters_with_ema(store_params_in_ema=True) + args = self.args + device = torch.device('cuda:%d'%args.local_rank) + tag = exp_helper.get_evalname(self.cfg) + 'D%d'%self.cfg.voxel2pts.diffusion_steps[0] + ode_sample = self.cfg.sde.ode_sample + diffusion = self.diffusion_cont if ode_sample else self.diffusion_disc + eval_trainnll = self.cfg.eval_trainnll + data_loader = self.train_loader + #data_loader = self.test_loader + gen_pcs, ref_pcs = [], [] + gen_x_lns_list = [] + tag_cur = self.cfg.voxel2pts.diffusion_steps[0] + num_selected = 60 + #input_pts, input_names = get_data(self.num_points, num_selected) + + #is_all_class_model = 'nsall' in self.cfg.save_dir + #logger.info('is_all_class_model: {}', is_all_class_model) + #if is_all_class_model: + # select_count_pair = [] + # class_pairs = [ + # ['airplane', 'car'], + # #['car', 'watercraft'], + # #['watercraft', 'airplane'] + # #['airplane', 'bench'], + # #['bench', 'table'], + # #['table', 'chair'] + # #['chair', 'display'], + # ##['chair', 'lamp'], + # ##['lamp', 'watercraft'] + # ] + + # for class_a, class_b in class_pairs: + # select_count_pair_cur = fun_select_count_pair(input_names, 
+        #        select_count_pair.extend(select_count_pair_cur)
+
+        #else:
+        #    class_a = class_b = self.cfg.data.cates
+        #    select_count_pair = fun_select_count_pair(input_names, class_a, class_b)
+
+        #input_pts_list = []
+        #input_pts_cates = []
+        #print('number of pairs', len(select_count_pair))
+        #for select_count in select_count_pair:
+        #    input_pts_list_pair = []
+        #    input_pts_cates_pair = []
+        #    for vi in select_count:
+        #        input_pts_list_pair.append(input_pts[vi][None])
+        #        input_pts_cates_pair.append(input_names[vi])
+        #    input_pts_list.append(torch.cat(input_pts_list_pair, dim=0))
+        #    input_pts_cates.append('-'.join(input_pts_cates_pair))
+
+        ##shuffle_idx = list(range(len(input_pts_list)))
+        ##random.Random(38383).shuffle(shuffle_idx)
+        ##input_pts_list = [input_pts_list[i] for i in shuffle_idx][:50]  # keep the first 50 pairs
+        ##input_pts_cates = [input_pts_cates[i] for i in shuffle_idx][:50]  # keep the first 50 pairs
+        #logger.info('num of input pts: {}, cates: {}',
+        #            len(input_pts_list), input_pts_cates[:10])
+        output_dir_template = self.cfg.save_dir + '/enc%d_%s_%s/' % (num_selected, self.generate_mode_global, self.generate_mode_local)
+        vis_output_dir = self.cfg.save_dir + '/vis_enc%d_%s_%s/' % (num_selected, self.generate_mode_global, self.generate_mode_local)
+        if not os.path.exists(vis_output_dir):
+            os.makedirs(vis_output_dir)
+
+        ode_eps = cfg.sde.ode_eps
+        ode_solver_tol = 1e-5
+        enable_autocast = False
+        temp = 1.0
+        noise = None
+        condition_input = None
+        clip_feat = None
+
+        ##for vid, val_batch in enumerate(data_loader):
+        ##    m, s = val_batch['mean'][0:1], val_batch['std'][0:1]
+        ##    B = val_batch['mean'].shape[0]
+        ##    break
+        #m = torch.zeros(1,3)
+        #m[:,0] = -0.0308  #-1.0504e-02
+        #m[:,1] = -0.0353  #-4.1844e-03
+        #m[:,2] = -0.0001  #-5.1331e-05
+        #s = torch.zeros(1,1)
+        #s[0] = 0.1512  #0.1694
+
+        #logger.info('mean: {}, s={}', m, s)
+        B = 4
+        for vid, val_batch in enumerate(data_loader):
+            ## for vid, (pt_cur, cate_cur) in enumerate(zip(input_pts_list, input_pts_cates)):
+            if vid % 30 == 1:
+                logger.info('eval: {}/{}, BS={}', vid, len(data_loader), B)
+            output_dir = output_dir_template + '/sph_B%d_%04d' % (B, vid)
+            if not os.path.exists(output_dir):
+                os.makedirs(output_dir)
+            pt_cur = val_batch['tr_points']
+            pt_cur = pt_cur[:2]  # keep two inputs: the interpolation endpoints
+            B0, N, C = pt_cur.shape
+            input_pts = pt_cur[None].expand(B // B0, -1, -1, -1).contiguous().view(B, N, C).contiguous()
+
+            # input_pts = (input_pts - m) / s
+            input_pts = input_pts.cuda().float()
+            file_name = ['%04d' % i for i in range(B)]
+            val_x = input_pts
+            logger.info('val_x: {}', val_x.shape)
+
+            B, N, C = val_x.shape
+            inputs = val_x
+            log_info = {'x_0': val_x, 'vis/input_noise': inputs.view(B, N, -1), 'x_t': val_x}
+
+            ## dist = self.model.encode(inputs)
+            # -- global latent -- #
+            dae_index = 0
+            dist = self.model.encode_global(inputs)
+
+            B = inputs.shape[0]
+            shape = self.model.latent_shape()[dae_index]
+            noise_shape_global = [B] + shape
+            eps = dist.sample()[0]
+            eps_shape_global = eps.shape
+            eps = eps.view(noise_shape_global)
+            eps_ori = eps
+            # keep the posterior sample as eps_0, the ODE's time-0 state
+            i = 0
+            dae = self.dae
+            eps_0 = eps
+            eps_global = eps_0.contiguous()
+            condition_input = eps_global
+
+            # global latent -> style vector conditioning the local encoder
+            style = self.model.global2style(eps_global.view(eps_shape_global))
+            dist_local = self.model.encode_local(inputs, style)
+
+            # -- local latent -- #
+            dae_index = 1
+            shape = self.model.latent_shape()[dae_index]
+            noise_shape_local = [B] + shape
+            eps = dist_local.sample()[0]
+            eps_shape_local = eps.shape
+            eps = eps.view(noise_shape_local)
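+            # The interpolation below follows an encode -> invert -> mix ->
+            # re-integrate -> decode pipeline (a sketch of the intent, inferred
+            # from how the diffusion helpers are used in this file):
+            #   1. compute_ode_nll runs the probability-flow ODE backwards,
+            #      mapping each posterior latent eps_0 to prior noise eps_T;
+            #   2. interpolate_noise fills the batch between eps_T[0] and
+            #      eps_T[-1] with spherical interpolations;
+            #   3. sample_model_ode integrates the ODE forwards again, and the
+            #      decoder turns the interpolated latents into point clouds.
+            # This runs first for the global shape latent, then for the local
+            # latent points conditioned on the interpolated global latent.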
+            eps_ori = eps
+
+            # -------------------------
+            # start the interpolation
+            # -------------------------
+            ## eps_local = eps_ori
+            eps_T_global_interp = diffusion.compute_ode_nll(
+                dae[0], eps_0, ode_eps, ode_solver_tol,
+                condition_input=None
+            )
+            eps_T_global_interp = interpolate_noise(eps_T_global_interp.contiguous())
+            eps_0_global_interp, _, _ = diffusion.sample_model_ode(
+                dae[0], B, shape, ode_eps, ode_solver_tol, enable_autocast, temp,
+                noise=eps_T_global_interp,
+                condition_input=None, clip_feat=clip_feat
+            )
+
+            eps_T_local_interp = diffusion.compute_ode_nll(
+                dae[1], eps, ode_eps, ode_solver_tol,
+                condition_input=eps_global
+            )
+
+            eps_T_local_interp = interpolate_noise(eps_T_local_interp.contiguous())
+            # sanity check: the batch endpoints should denoise back to eps
+            eps_0_local_interp, _, _ = diffusion.sample_model_ode(
+                dae[1], B, shape, ode_eps,
+                ode_solver_tol, enable_autocast, temp,
+                noise=eps_T_local_interp,
+                condition_input=eps_0_global_interp,
+                clip_feat=clip_feat
+            )
+
+            style = self.model.global2style(eps_0_global_interp.view(eps_shape_global))
+            eps_local = eps_0_local_interp.view(eps_shape_local)
+            gen_x = self.model.decoder(None, beta=None, context=eps_local, style=style)  # (B,ncenter,3)
+            ## gen_x = val_x
+
+            log_info['x_0_pred'] = gen_x.detach().cpu()
+            if True:
+                img_list = []
+                for i in range(B):
+                    points = gen_x[i]  # N,3
+                    points = normalize_point_clouds([points])[0]
+                    img = visualize_point_clouds_3d([points])
+                    img_list.append(img)
+                img = np.concatenate(img_list, axis=2)
+                self.writer.add_image('%d' % vid, torch.as_tensor(img), 0)
+                img_list = [torch.as_tensor(img) / 255.0 for img in img_list]
+                vis_file = os.path.join(vis_output_dir, '%04d.png' % (vid))
+                torchvision.utils.save_image(img_list, vis_file)
+                logger.info('save img as: {}', vis_file)
+
+            # forward to get output
+            val_x = val_x.contiguous()
+            inputs = inputs.contiguous()
+            output = self.model.get_loss(val_x, it=step, is_eval_nll=1, noisy_input=inputs)
+            #gen_x = gen_x.cpu() * s + m
+            for i, file_name_i in enumerate(file_name):
+                torch.save(gen_x[i], os.path.join(output_dir, file_name_i))
+            logger.info('save output at: {}', output_dir)
+        return 0
diff --git a/trainers/interpolate_latent.py b/trainers/interpolate_latent.py
new file mode 100644
index 0000000..4f3c70b
--- /dev/null
+++ b/trainers/interpolate_latent.py
@@ -0,0 +1,264 @@
+"""Interpolate prior samples of the hierarchical VAE with two priors,
+one for the style latent and one for the latent points;
+based on trainers/train_prior.py
+"""
+import os
+import time
+import torchvision
+from PIL import Image
+import functools
+import torch
+import torch.nn as nn
+import numpy as np
+from loguru import logger
+from utils.data_helper import normalize_point_clouds
+from utils.vis_helper import visualize_point_clouds_3d
+from utils import model_helper, exp_helper, data_helper
+from utils.diffusion_pvd import DiffusionDiscretized
+from utils.diffusion_continuous import make_diffusion, DiffusionBase
+from utils.checker import *
+from utils import utils
+from matplotlib import pyplot as plt
+from timeit import default_timer as timer
+from trainers.train_2prior import Trainer as PriorTrainer
+
+
+def linear_interpolate_noise(noise):
+    """Fill noise[1:-1] with convex combinations of noise[0] and noise[-1]."""
+    noise_a = noise[0].contiguous()   # first endpoint, shape 1,D,1,1
+    noise_b = noise[-1].contiguous()  # last endpoint, shape 1,D,1,1
+    for k in range(1, noise.shape[0] - 1):
+        p = float(k) / len(noise)  # interpolation weight, e.g. 1/8 .. 7/8
+        noise[k] = p * noise_b + (1 - p) * noise_a
+    return noise
+
+
+def interpolate_noise(noise):
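+    """Fill noise[1:-1] with spherical interpolations of noise[0] and noise[-1].
+
+    Assuming the two endpoints are (approximately) independent N(0, I) samples,
+    sqrt(1-p)*noise_a + sqrt(p)*noise_b is again N(0, I) for every p in [0, 1],
+    so each intermediate stays a typical sample under the Gaussian prior; the
+    plain convex combination in linear_interpolate_noise shrinks the norm of
+    the intermediates instead.
+    """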
+    noise_a = noise[0].contiguous()   # first endpoint, shape 1,D,1,1
+    noise_b = noise[-1].contiguous()  # last endpoint, shape 1,D,1,1
+    for k in range(1, noise.shape[0] - 1):
+        p = float(k) / len(noise)  # interpolation weight, e.g. 1/8 .. 7/8
+        noise[k] = np.sqrt(p) * noise_b + np.sqrt(1 - p) * noise_a
+    return noise
+
+
+def subtract_noise(noise):
+    """Latent 'analogy' edit: take the difference between samples 12 and 15
+    and add it to samples 9 and 10; the six resulting codes overwrite
+    noise[:6]. Assumes the batch holds at least 16 samples."""
+    noise_a = noise[12].contiguous()  # 1,D,1,1
+    noise_b = noise[15].contiguous()  # 1,D,1,1
+    diff = noise_a - noise_b
+    add_target_1 = noise[9]
+    add_target_2 = noise[10]
+    noise_list = []
+    noise_list.append(noise_a)
+    noise_list.append(noise_b)
+    noise_list.append(add_target_1)
+    noise_list.append(add_target_2)
+    noise_list.append(add_target_1 + diff)
+    noise_list.append(add_target_2 + diff)
+    noise[:6] = torch.stack(noise_list)
+    return noise
+
+
+VIS_LATENT_PTS = 0
+
+
+@torch.no_grad()
+def validate_inspect(vis_file, latent_shape,
+                     model, dae, diffusion, ode_sample,
+                     it, writer,
+                     sample_num_points, num_samples,
+                     autocast_train=False,
+                     need_sample=1, need_val=1, need_train=1,
+                     w_prior=None, val_x=None, tr_x=None,
+                     val_input=None,
+                     m_pcs=None, s_pcs=None,
+                     test_loader=None,  # can be None
+                     has_shapelatent=False, vis_latent_point=False,
+                     ddim_step=0, epoch=0, fun_generate_samples_vada=None):
+    """Visualize the samples, and reconstruct if needed.
+    Args:
+        has_shapelatent (bool): True when the model has a shape latent
+        it (int): step index
+        num_samples: number of samples to draw (overridden by w_prior's batch size)
+        need_*: draw samples for * or not
+    """
+    assert has_shapelatent
+    z_list = []
+    num_samples = w_prior.shape[0] if need_sample else 0
+    num_recon = val_x.shape[0]
+    num_recon_val = num_recon if need_val else 0
+    num_recon_train = num_recon if need_train else 0
+
+    if need_sample:
+        # gen_x: B,N,3
+        gen_x, nstep, ode_time, sample_time, output_dict = \
+            fun_generate_samples_vada(latent_shape, dae, diffusion, model, w_prior.shape[0],
+                                      enable_autocast=autocast_train,
+                                      ode_sample=ode_sample, ddim_step=ddim_step)
+        logger.info('cast={}, sample step={}, ode_time={}, sample_time={}',
+                    autocast_train,
+                    nstep if ddim_step == 0 else ddim_step,
+                    ode_time, sample_time)
+        gen_pcs = gen_x
+    else:
+        output_dict = {}
+    # vis the samples
+    if num_samples > 0:
+        img_list = []
+        for i in range(num_samples):
+            points = gen_x[i]  # N,3
+            points = normalize_point_clouds([points])[0]
+            img = visualize_point_clouds_3d([points])
+            img_list.append(img)
+        img = np.concatenate(img_list, axis=2)
+        writer.add_image('sample', torch.as_tensor(img), it)
+        img_list = [torch.as_tensor(img) / 255.0 for img in img_list]
+        torchvision.utils.save_image(img_list, vis_file)
+        logger.info('save img as: {}', vis_file)
+
+    return output_dict
+
+
+@torch.no_grad()
+def generate_samples(shape, dae, diffusion, vae, num_samples, enable_autocast, ode_eps=0.00001, ode_solver_tol=1e-5,
+                     ode_sample=False, prior_var=1.0, temp=1.0, vae_temp=1.0, noise=None, need_denoise=False,
+                     ddim_step=0, writer=None, generate_mode_global='interpolate', generate_mode_local='freeze'):
+    output = {}
+    if ode_sample:
+        assert isinstance(diffusion, DiffusionBase), 'ODE-based sampling requires cont. diffusion!'
+        assert ode_eps is not None, 'ODE-based sampling requires integration cutoff ode_eps!'
+        assert ode_solver_tol is not None, 'ODE-based sampling requires ode solver tolerance!'
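+        # Two-level hierarchical sampling: level 0 draws the global shape
+        # latent, level 1 draws the latent points conditioned on it. The
+        # interpolation modes only edit the Gaussian noise handed to the ODE
+        # sampler, so all interpolants are decoded by the same deterministic
+        # probability-flow ODE.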
+        start = timer()
+        condition_input = None
+        eps_list = []
+        for i in range(len(dae)):
+            # fresh Gaussian noise per level (the noise argument is overwritten)
+            noise = torch.randn(size=[num_samples] + shape[i], device='cuda')
+            if i == 0:  # level 0: global shape latent
+                generate_mode = generate_mode_global
+            else:       # level 1: local latent points
+                generate_mode = generate_mode_local
+            logger.info('level: {}, generate_mode: {}', i, generate_mode)
+            if generate_mode == 'subtract':
+                logger.info('subtract latent difference and add it to other samples')
+                noise = subtract_noise(noise)
+            elif generate_mode == 'interpolate':
+                logger.info('interpolate latent between left most and right most')
+                noise = interpolate_noise(noise)
+            elif generate_mode == 'linear_interpolate':
+                logger.info('linear interpolate latent between left most and right most')
+                noise = linear_interpolate_noise(noise)
+            elif generate_mode == 'freeze':
+                for k in range(1, noise.shape[0]):
+                    noise[k] = noise[0]  # for local latent, use the same one for all samples
+
+            eps, nfe, time_ode_solve = diffusion.sample_model_ode(
+                dae[i], num_samples, shape[i], ode_eps, ode_solver_tol, enable_autocast, temp, noise,
+                condition_input=condition_input
+            )
+            condition_input = eps
+            eps_list.append(eps)
+            output['sampled_eps'] = eps
+        eps = vae.compose_eps(eps_list)
+    else:
+        raise NotImplementedError
+    output['print/sample_mean_global'] = eps.view(num_samples, -1).mean(-1).mean()
+    output['print/sample_var_global'] = eps.view(num_samples, -1).var(-1).mean()
+    decomposed_eps = vae.decompose_eps(eps)
+    image = vae.sample(num_samples=num_samples, decomposed_eps=decomposed_eps)
+    output['gen_x'] = image
+
+    end = timer()
+    sampling_time = end - start
+    nfe_torch = torch.tensor(nfe * 1.0, device='cuda')
+    sampling_time_torch = torch.tensor(sampling_time * 1.0, device='cuda')
+    time_ode_solve_torch = torch.tensor(time_ode_solve * 1.0, device='cuda')
+    return image, nfe_torch, time_ode_solve_torch, sampling_time_torch, output
+
+
+class Trainer(PriorTrainer):
+    is_diffusion = 0
+    generate_mode_global = 'interpolate'
+    generate_mode_local = 'interpolate'
+
+    def __init__(self, cfg, args):
+        """
+        Args:
+            cfg: training config
+            args: used for distributed training
+        """
+        cfg.num_val_samples = 20
+        super().__init__(cfg, args)
+
+    @torch.no_grad()
+    def vis_sample(self, writer, num_vis=None, step=0, include_pred_x0=True,
+                   save_file=None):
+        if self.cfg.ddpm.ema:
+            self.swap_vae_param_if_need()
+            self.dae_optimizer.swap_parameters_with_ema(store_params_in_ema=True)
+        shape = self.model.latent_shape()
+        logger.info('[url]: {}', writer.url)
+        logger.info('Latent shape for prior: {}; num_val_samples: {}', shape, self.num_val_samples)
+        ## [self.vae.latent_dim, .num_input_channels, dae.input_size, dae.input_size]
+        ode_sample = self.cfg.sde.ode_sample
+        diffusion = self.diffusion_cont if ode_sample else self.diffusion_disc
+        rank = 0
+        seed = 0
+        torch.manual_seed(rank + seed)
+        np.random.seed(rank + seed)
+        torch.cuda.manual_seed(rank + seed)
+        torch.cuda.manual_seed_all(rank + seed)
+        for idx in range(40):
+            output_dir = os.path.join(self.cfg.save_dir, 'interp',
+                                      'mode_%s_%s_%d' % (self.generate_mode_global,
+                                                         self.generate_mode_local, self.sample_num_points),
+                                      '%04d' % idx)
+            vis_dir = os.path.join(self.cfg.save_dir, 'interp',
+                                   'mode_%s_%s_%d_img' % (self.generate_mode_global,
+                                                          self.generate_mode_local,
+                                                          self.sample_num_points))
+
+            logger.info('will save to {}', output_dir)
+            if not os.path.exists(output_dir):
+                os.makedirs(output_dir)
+            if not os.path.exists(vis_dir):
+                os.makedirs(vis_dir)
+            vis_file = os.path.join(vis_dir, '%04d.png' % idx)
+            output = validate_inspect(
+                vis_file, shape, self.model, self.dae,
+                diffusion, ode_sample,
+                step + idx, self.writer, self.sample_num_points,
+                epoch=self.cur_epoch,
+                autocast_train=self.cfg.sde.autocast_train,
+                need_sample=self.draw_sample_when_vis,
+                need_val=0, need_train=0,
+                num_samples=self.num_val_samples,
+                test_loader=self.test_loader,
+                w_prior=self.w_prior,
+                val_x=self.val_x, tr_x=self.tr_x,
+                val_input=self.val_input,
+                m_pcs=self.m_pcs, s_pcs=self.s_pcs,
+                has_shapelatent=True,
+                vis_latent_point=self.cfg.vis_latent_point,
+                ddim_step=self.cfg.viz.vis_sample_ddim_step,
+                fun_generate_samples_vada=self.fun_generate_samples_vada
+            )
+            gen_x = output['gen_x']
+            logger.info('gen_x shape: {}', gen_x.shape)
+            for idxx in range(len(gen_x)):
+                torch.save(gen_x[idxx], output_dir + '/%04d.pt' % idxx)
+            logger.info('save to {}', output_dir)
+
+        if writer is not None:
+            for n, v in output.items():
+                if 'print/' not in n:
+                    continue
+                self.writer.add_scalar('%s' % (n.split('print/')[-1]), v, step)
+
+        if self.cfg.ddpm.ema:
+            self.swap_vae_param_if_need()
+            self.dae_optimizer.swap_parameters_with_ema(store_params_in_ema=True)
+
+    def set_writer(self, writer):
+        self.writer = writer
+        self.fun_generate_samples_vada = functools.partial(
+            generate_samples, ode_eps=self.cfg.sde.ode_eps,
+            writer=self.writer,
+            generate_mode_global=self.generate_mode_global,
+            generate_mode_local=self.generate_mode_local
+        )
+
+    def eval_sample(self, step=0):
+        logger.info('skip eval-sample')
+        return 0