support interpolation
This commit is contained in:
parent
110f73f4c0
commit
90fe641054
7
script/interpolate.sh
Normal file
7
script/interpolate.sh
Normal file
|
@ -0,0 +1,7 @@
|
|||
NP=2048
|
||||
model="./lion_ckpt/unconditional/chair/checkpoints/model.pt"
|
||||
python train_dist.py --eval_generation --pretrained $model --skip_nll 1 \
|
||||
data.batch_size_test 32 ddpm.ema 1 trainer.type trainers.interpolate_latent num_val_samples 20 trainer.seed 2 sde.ode_sample 1 \
|
||||
sde.beta_end 20.0 sde.embedding_scale 1000.0 \
|
||||
data.tr_max_sample_points ${NP} data.te_max_sample_points ${NP} shapelatent.decoder_num_points ${NP}
|
||||
|
7
script/interpolate_posterior.sh
Normal file
7
script/interpolate_posterior.sh
Normal file
|
@ -0,0 +1,7 @@
|
|||
NP=2048
|
||||
model="./lion_ckpt/unconditional/chair/checkpoints/model.pt"
|
||||
python train_dist.py --eval_generation --pretrained $model --skip_nll 0 \
|
||||
data.batch_size_test 32 ddpm.ema 1 trainer.type trainers.encode_interp_interp num_val_samples 20 trainer.seed 2 sde.ode_sample 1 \
|
||||
sde.beta_end 20.0 sde.embedding_scale 1000.0 \
|
||||
data.tr_max_sample_points ${NP} data.te_max_sample_points ${NP} shapelatent.decoder_num_points ${NP}
|
||||
|
322
trainers/encode_interp_interp.py
Normal file
322
trainers/encode_interp_interp.py
Normal file
|
@ -0,0 +1,322 @@
|
|||
""" to train VAE-encoder with two prior """
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import torch
|
||||
import torchvision
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from torch import optim
|
||||
from utils.ema import EMA
|
||||
from utils import eval_helper, exp_helper
|
||||
from utils.sr_utils import SpectralNormCalculator
|
||||
from utils.checker import *
|
||||
from utils.utils import AvgrageMeter
|
||||
from utils import utils
|
||||
from torch.optim import Adam as FusedAdam
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from trainers.train_2prior import Trainer as BaseTrainer
|
||||
from utils.data_helper import normalize_point_clouds
|
||||
from utils.vis_helper import visualize_point_clouds_3d
|
||||
from utils.eval_helper import compute_NLL_metric
|
||||
CHECKPOINT = int(os.environ.get('CHECKPOINT', 0))
|
||||
EVAL_LNS = int(os.environ.get('EVAL_LNS', 0))
|
||||
def interpolate_noise(noise):
|
||||
noise_a = noise[0].contiguous() # 1,D,1,1
|
||||
noise_b = noise[-1].contiguous() # 1,D,1,1
|
||||
num_inter = noise.shape[0] - 2
|
||||
for k in range(1, noise.shape[0]-1):
|
||||
p = float(k) / len(noise) # 1/8 to 7/8
|
||||
## logger.info('p={}; eps: {}', p, noise.shape)
|
||||
# noise[k] = p * noise_b + (1-p) * noise_a
|
||||
noise[k] = np.sqrt(p) * noise_b + np.sqrt(1-p) * noise_a
|
||||
return noise
|
||||
#def get_data(num_points=3072, num_selected=50):
|
||||
# import h5py
|
||||
# import torch
|
||||
# import sys
|
||||
# import os
|
||||
# from datasets.data_path import get_path
|
||||
#
|
||||
# synsetid_to_cate = {
|
||||
# "02691156": "airplane",
|
||||
# "02828884": "bench",
|
||||
# "02958343": "car",
|
||||
# "03001627": "chair",
|
||||
# "03211117": "display",
|
||||
# "03636649": "lamp",
|
||||
# "04256520": "sofa",
|
||||
# "04379243": "table",
|
||||
# "04530566": "watercraft",
|
||||
# }
|
||||
#
|
||||
#
|
||||
# root_dir = get_path('pointflow')
|
||||
# out_dict = []
|
||||
# out_name = []
|
||||
# for synset_id in synsetid_to_cate.keys():
|
||||
# filename = os.path.join(root_dir, f'{synset_id}', 'train')
|
||||
# tr_out_full = []
|
||||
# for file in sorted(os.listdir(filename))[:10]:
|
||||
# pts = np.load(filename + '/' + file)
|
||||
# tr_out_full.append(pts)
|
||||
# tr_out_full = np.stack(tr_out_full)
|
||||
# #with h5py.File(filename, "r") as h5f:
|
||||
# # tr_out_full = h5f['surface_points/points'][5:5+num_selected]
|
||||
# if tr_out_full.shape[1] > num_points:
|
||||
# pt_cur = []
|
||||
# for b in range(tr_out_full.shape[0]):
|
||||
# tr_idxs = np.random.choice(tr_out_full.shape[1], num_points)
|
||||
# pt_cur.append(tr_out_full[b, tr_idxs]) ## pt_cur[tr_idxs]
|
||||
# tr_out_full = torch.from_numpy(np.stack(pt_cur))
|
||||
# # out_dict[synsetid_to_cate[synset_id]] = tr_out_full
|
||||
# out_dict.append(tr_out_full)
|
||||
# out_name.extend([synsetid_to_cate[synset_id]] * num_selected)
|
||||
# out_dict = torch.cat(out_dict)
|
||||
# logger.info('created data: {}, ', out_dict.shape)
|
||||
# return out_dict, out_name
|
||||
#
|
||||
#def fun_hash(a, b):
|
||||
# if a < b:
|
||||
# return f'{a}-{b}'
|
||||
# else:
|
||||
# return f'{b}-{a}'
|
||||
#
|
||||
#def fun_select_count_pair(names, a, b):
|
||||
# select_all_pair = []
|
||||
# indexes = list( np.arange(len(names)) )
|
||||
# indexes_a = [i for ii, i in enumerate(indexes) if names[ii] == a]
|
||||
# indexes_b = [i for ii, i in enumerate(indexes) if names[ii] == b]
|
||||
# hash_d = []
|
||||
# logger.info('select paits for: {}, {}', a, b)
|
||||
# for ai in indexes_a:
|
||||
# for bi in indexes_b:
|
||||
# if ai != bi and fun_hash(ai, bi) not in hash_d:
|
||||
# hash_d.append(fun_hash(ai, bi))
|
||||
# select_all_pair.append([ai, bi])
|
||||
# logger.info('get pairs: {}', len(select_all_pair))
|
||||
# return select_all_pair
|
||||
#
|
||||
class Trainer(BaseTrainer):
|
||||
is_diffusion = 0
|
||||
generate_mode_global = 'interpolate'
|
||||
generate_mode_local = 'interpolate'
|
||||
def __init__(self, cfg, args):
|
||||
"""
|
||||
Args:
|
||||
cfg: training config
|
||||
args: used for distributed training
|
||||
"""
|
||||
super().__init__(cfg, args)
|
||||
self.draw_sample_when_vis = 0
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_sample(self, step=0):
|
||||
pass # do nothing
|
||||
|
||||
# -- shared method for all model with vae component -- #
|
||||
@torch.no_grad()
|
||||
def eval_nll(self, step, ntest=None, save_file=False):
|
||||
cfg = self.cfg
|
||||
if cfg.ddpm.ema: self.swap_vae_param_if_need()
|
||||
self.dae_optimizer.swap_parameters_with_ema(store_params_in_ema=True)
|
||||
args = self.args
|
||||
device = torch.device('cuda:%d'%args.local_rank)
|
||||
tag = exp_helper.get_evalname(self.cfg) + 'D%d'%self.cfg.voxel2pts.diffusion_steps[0]
|
||||
ode_sample = self.cfg.sde.ode_sample
|
||||
diffusion = self.diffusion_cont if ode_sample else self.diffusion_disc
|
||||
eval_trainnll = self.cfg.eval_trainnll
|
||||
data_loader = self.train_loader
|
||||
#data_loader = self.test_loader
|
||||
gen_pcs, ref_pcs = [], []
|
||||
gen_x_lns_list = []
|
||||
tag_cur = self.cfg.voxel2pts.diffusion_steps[0]
|
||||
num_selected = 60
|
||||
#input_pts, input_names = get_data(self.num_points, num_selected)
|
||||
|
||||
#is_all_class_model = 'nsall' in self.cfg.save_dir
|
||||
#logger.info('is_all_class_model: {}', is_all_class_model)
|
||||
#if is_all_class_model:
|
||||
# select_count_pair = []
|
||||
# class_pairs = [
|
||||
# ['airplane', 'car'],
|
||||
# #['car', 'watercraft'],
|
||||
# #['watercraft', 'airplane']
|
||||
# #['airplane', 'bench'],
|
||||
# #['bench', 'table'],
|
||||
# #['table', 'chair']
|
||||
# #['chair', 'display'],
|
||||
# ##['chair', 'lamp'],
|
||||
# ##['lamp', 'watercraft']
|
||||
# ]
|
||||
|
||||
# for class_a, class_b in class_pairs:
|
||||
# select_count_pair_cur = fun_select_count_pair(input_names, class_a, class_b)
|
||||
# select_count_pair.extend(select_count_pair_cur)
|
||||
|
||||
#else:
|
||||
# class_a = class_b = self.cfg.data.cates
|
||||
# select_count_pair = fun_select_count_pair(input_names, class_a, class_b)
|
||||
|
||||
#input_pts_list = []
|
||||
#input_pts_cates = []
|
||||
#print('number of pair', len(select_count_pair))
|
||||
#for select_count in select_count_pair:
|
||||
# input_pts_list_pair = []
|
||||
# input_pts_cates_pair = []
|
||||
# for vi in select_count:
|
||||
# input_pts_list_pair.append(input_pts[vi][None])
|
||||
# input_pts_cates_pair.append(input_names[vi])
|
||||
# input_pts_list.append(torch.cat(input_pts_list_pair, dim=0))
|
||||
# input_pts_cates.append('-'.join(input_pts_cates_pair))
|
||||
|
||||
##shuffle_idx = list(range(len(input_pts_list)))
|
||||
##random.Random(38383).shuffle(shuffle_idx)
|
||||
##input_pts_list = [input_pts_list[i] for i in shuffle_idx][:50] # select first 50 pairs
|
||||
##input_pts_cates = [input_pts_cates[i] for i in shuffle_idx][:50] # select first 50 pairs
|
||||
#logger.info('num of input pts: {}, cates: {}',
|
||||
# len(input_pts_list), input_pts_cates[:10])
|
||||
output_dir_template = self.cfg.save_dir + '/enc%d_%s_%s/'%(num_selected, self.generate_mode_global, self.generate_mode_local)
|
||||
vis_output_dir = self.cfg.save_dir + '/vis_enc%d_%s_%s/'%(num_selected, self.generate_mode_global, self.generate_mode_local)
|
||||
if not os.path.exists(vis_output_dir):
|
||||
os.makedirs(vis_output_dir)
|
||||
|
||||
ode_eps = cfg.sde.ode_eps
|
||||
ode_solver_tol = 1e-5
|
||||
enable_autocast = False
|
||||
temp = 1.0
|
||||
noise = None
|
||||
condition_input = None
|
||||
clip_feat = None
|
||||
|
||||
##for vid, val_batch in enumerate(data_loader):
|
||||
## m, s = val_batch['mean'][0:1], val_batch['std'][0:1]
|
||||
## B = val_batch['mean'].shape[0]
|
||||
## break
|
||||
#m = torch.zeros(1,3)
|
||||
#m[:,0] = -0.0308 #-1.0504e-02
|
||||
#m[:,1] = -0.0353 #-4.1844e-03
|
||||
#m[:,2] = -0.0001 #-5.1331e-05
|
||||
#s = torch.zeros(1,1)
|
||||
#s[0] = 0.1512 #0.1694
|
||||
|
||||
#logger.info('mean: {}, s={}', m, s)
|
||||
B = 4
|
||||
for vid, val_batch in enumerate(data_loader):
|
||||
## for vid, (pt_cur, cate_cur) in enumerate(zip(input_pts_list, input_pts_cates)):
|
||||
if vid % 30 == 1:
|
||||
logger.info('eval: {}/{}, BS={}', vid, len(data_loader), B)
|
||||
output_dir = output_dir_template + '/sph_B%d_%04d'%(B, vid)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
pt_cur = val_batch['tr_points']
|
||||
pt_cur = pt_cur[:2] # select two input here
|
||||
B0,N,C = pt_cur.shape
|
||||
input_pts = pt_cur[None].expand(B//B0, -1, -1, -1).contiguous().view(B,N,C).contiguous()
|
||||
|
||||
# input_pts = ( input_pts - m ) / s
|
||||
input_pts = input_pts.cuda().float()
|
||||
file_name = ['%04d'%i for i in range(B)]
|
||||
val_x = input_pts
|
||||
logger.info('val_x: {}', val_x.shape)
|
||||
|
||||
B,N,C = val_x.shape
|
||||
inputs = val_x
|
||||
log_info = {'x_0': val_x, 'vis/input_noise': inputs.view(B,N,-1), 'x_t': val_x}
|
||||
|
||||
## dist = self.model.encode(inputs)
|
||||
# -- global latent -- #
|
||||
dae_index = 0
|
||||
dist = self.model.encode_global(inputs)
|
||||
|
||||
B = inputs.shape[0]
|
||||
shape = self.model.latent_shape()[dae_index]
|
||||
noise_shape_global = [B] + shape
|
||||
eps = dist.sample()[0]
|
||||
eps_shape_global = eps.shape
|
||||
eps = eps.view(noise_shape_global)
|
||||
eps_ori = eps
|
||||
# get eps_0:
|
||||
i = 0
|
||||
dae = self.dae
|
||||
eps_0 = eps
|
||||
eps_global = eps_0.contiguous()
|
||||
condition_input = eps_global
|
||||
|
||||
# style
|
||||
style = self.model.global2style(eps_global.view(eps_shape_global))
|
||||
dist_local = self.model.encode_local(inputs, style)
|
||||
|
||||
# -- local latent -- #
|
||||
dae_index = 1
|
||||
shape = self.model.latent_shape()[dae_index]
|
||||
noise_shape_local = [B] + shape
|
||||
eps = dist_local.sample()[0]
|
||||
eps_shape_local = eps.shape
|
||||
eps = eps.view(noise_shape_local)
|
||||
eps_ori = eps
|
||||
|
||||
# -------------------
|
||||
# start the interplot
|
||||
# -------------------
|
||||
## eps_local = eps_ori
|
||||
eps_T_global_interp = diffusion.compute_ode_nll(
|
||||
dae[0], eps_0, ode_eps, ode_solver_tol,
|
||||
condition_input=None
|
||||
)
|
||||
eps_T_global_interp = interpolate_noise(eps_T_global_interp.contiguous())
|
||||
eps_0_global_interp, _, _= diffusion.sample_model_ode(
|
||||
dae[0], B, shape, ode_eps, ode_solver_tol, enable_autocast, temp,
|
||||
noise=eps_T_global_interp,
|
||||
condition_input=None, clip_feat=clip_feat
|
||||
)
|
||||
|
||||
eps_T_local_interp = diffusion.compute_ode_nll(
|
||||
dae[1], eps, ode_eps, ode_solver_tol,
|
||||
condition_input=eps_global
|
||||
)
|
||||
|
||||
eps_T_local_interp = interpolate_noise(eps_T_local_interp.contiguous())
|
||||
# double check if the eps_0 can denoise to get eps:
|
||||
eps_0_local_interp, _, _= diffusion.sample_model_ode(
|
||||
dae[1], B, shape, ode_eps,
|
||||
ode_solver_tol, enable_autocast, temp,
|
||||
noise=eps_T_local_interp,
|
||||
condition_input=eps_0_global_interp,
|
||||
clip_feat=clip_feat
|
||||
)
|
||||
|
||||
style = self.model.global2style(eps_0_global_interp.view(eps_shape_global))
|
||||
eps_local = eps_0_local_interp.view(eps_shape_local)
|
||||
gen_x = self.model.decoder(None, beta=None, context=eps_local, style=style) # (B,ncenter,3)
|
||||
## gen_x = val_x
|
||||
|
||||
# start the interpretation:
|
||||
|
||||
|
||||
log_info['x_0_pred'] = gen_x.detach().cpu()
|
||||
if True:
|
||||
img_list = []
|
||||
for i in range(B):
|
||||
points = gen_x[i] # N,3
|
||||
points = normalize_point_clouds([points])[0]
|
||||
img = visualize_point_clouds_3d([points])
|
||||
img_list.append(img)
|
||||
img = np.concatenate(img_list, axis=2)
|
||||
self.writer.add_image('%d'%vid, torch.as_tensor(img), 0)
|
||||
img_list = [torch.as_tensor(img) / 255.0 for img in img_list]
|
||||
vis_file = os.path.join(vis_output_dir, '%04d.png'%(vid))
|
||||
torchvision.utils.save_image(img_list, vis_file)
|
||||
logger.info('save img as: {}', vis_file)
|
||||
|
||||
# forward to get output
|
||||
val_x = val_x.contiguous()
|
||||
inputs = inputs.contiguous()
|
||||
output = self.model.get_loss(val_x, it=step, is_eval_nll=1, noisy_input=inputs)
|
||||
#gen_x = gen_x.cpu() * s + m
|
||||
for i, file_name_i in enumerate(file_name):
|
||||
torch.save(gen_x[i], os.path.join(output_dir, file_name_i))
|
||||
logger.info('save output at : {}', output_dir)
|
||||
return 0
|
264
trainers/interpolate_latent.py
Normal file
264
trainers/interpolate_latent.py
Normal file
|
@ -0,0 +1,264 @@
|
|||
""" to train hierarchical VAE model with 2 prior
|
||||
one for style latent, one for latent pts,
|
||||
based on trainers/train_prior.py
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from utils.data_helper import normalize_point_clouds
|
||||
from utils.vis_helper import visualize_point_clouds_3d
|
||||
from utils import model_helper, exp_helper, data_helper
|
||||
from utils.diffusion_pvd import DiffusionDiscretized
|
||||
from utils.diffusion_continuous import make_diffusion, DiffusionBase
|
||||
from utils.checker import *
|
||||
from utils import utils
|
||||
from matplotlib import pyplot as plt
|
||||
from timeit import default_timer as timer
|
||||
from trainers.train_2prior import Trainer as PriorTrainer
|
||||
def linear_interpolate_noise(noise):
|
||||
noise_a = noise[0].contiguous() # 1,D,1,1
|
||||
noise_b = noise[-1].contiguous() # 1,D,1,1
|
||||
num_inter = noise.shape[0] - 2
|
||||
for k in range(1, noise.shape[0]-1):
|
||||
p = float(k) / len(noise) # 1/8 to 7/8
|
||||
## logger.info('p={}; eps: {}', p, noise.shape)
|
||||
noise[k] = p * noise_b + (1-p) * noise_a
|
||||
return noise
|
||||
|
||||
def interpolate_noise(noise):
|
||||
noise_a = noise[0].contiguous() # 1,D,1,1
|
||||
noise_b = noise[-1].contiguous() # 1,D,1,1
|
||||
num_inter = noise.shape[0] - 2
|
||||
for k in range(1, noise.shape[0]-1):
|
||||
p = float(k) / len(noise) # 1/8 to 7/8
|
||||
## logger.info('p={}; eps: {}', p, noise.shape)
|
||||
noise[k] = np.sqrt(p) * noise_b + np.sqrt(1-p) * noise_a
|
||||
return noise
|
||||
|
||||
def subtract_noise(noise):
|
||||
noise_a = noise[12].contiguous() # 1,D,1,1
|
||||
noise_b = noise[15].contiguous() # 1,D,1,1
|
||||
diff = noise_a - noise_b
|
||||
num_inter = noise.shape[0] - 2
|
||||
add_target_1 = noise[9]
|
||||
add_target_2 = noise[10]
|
||||
noise_list = []
|
||||
noise_list.append(noise_a)
|
||||
noise_list.append(noise_b)
|
||||
noise_list.append(add_target_1)
|
||||
noise_list.append(add_target_2)
|
||||
noise_list.append(add_target_1 + diff)
|
||||
noise_list.append(add_target_2 + diff)
|
||||
noise[:6] = torch.stack(noise_list)
|
||||
return noise
|
||||
|
||||
|
||||
VIS_LATENT_PTS = 0
|
||||
@torch.no_grad()
|
||||
def validate_inspect(vis_file, latent_shape,
|
||||
model, dae, diffusion, ode_sample,
|
||||
it, writer,
|
||||
sample_num_points, num_samples,
|
||||
autocast_train=False,
|
||||
need_sample=1, need_val=1, need_train=1,
|
||||
w_prior=None, val_x=None, tr_x=None,
|
||||
val_input=None,
|
||||
m_pcs=None, s_pcs=None,
|
||||
test_loader=None, # can be None
|
||||
has_shapelatent=False, vis_latent_point=False,
|
||||
ddim_step=0, epoch=0, fun_generate_samples_vada=None):
|
||||
""" visualize the samples, and recont if needed
|
||||
Args:
|
||||
has_shapelatent (bool): True when the model has shape latent
|
||||
it (int): step index
|
||||
num_samples:
|
||||
need_* : draw samples for * or not
|
||||
"""
|
||||
assert(has_shapelatent)
|
||||
z_list = []
|
||||
num_samples = w_prior.shape[0] if need_sample else 0
|
||||
num_recon = val_x.shape[0]
|
||||
num_recon_val = num_recon if need_val else 0
|
||||
num_recon_train = num_recon if need_train else 0
|
||||
|
||||
if need_sample:
|
||||
# gen_x: B,N,3
|
||||
gen_x, nstep, ode_time, sample_time, output_dict = \
|
||||
fun_generate_samples_vada(latent_shape, dae, diffusion, model, w_prior.shape[0],
|
||||
enable_autocast=autocast_train,
|
||||
ode_sample=ode_sample, ddim_step=ddim_step)
|
||||
logger.info('cast={}, sample step={}, ode_time={}, sample_time={}',
|
||||
autocast_train,
|
||||
nstep if ddim_step == 0 else ddim_step,
|
||||
ode_time, sample_time)
|
||||
gen_pcs = gen_x
|
||||
else:
|
||||
output_dict = {}
|
||||
# vis the samples
|
||||
if num_samples > 0:
|
||||
img_list = []
|
||||
for i in range(num_samples):
|
||||
points = gen_x[i] # N,3
|
||||
points = normalize_point_clouds([points])[0]
|
||||
img = visualize_point_clouds_3d([points])
|
||||
img_list.append(img)
|
||||
img = np.concatenate(img_list, axis=2)
|
||||
writer.add_image('sample', torch.as_tensor(img), it)
|
||||
img_list = [torch.as_tensor(img) / 255.0 for img in img_list]
|
||||
torchvision.utils.save_image(img_list, vis_file)
|
||||
logger.info('save img as: {}', vis_file)
|
||||
|
||||
return output_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_samples(shape, dae, diffusion, vae, num_samples, enable_autocast, ode_eps=0.00001, ode_solver_tol=1e-5, ## None,
|
||||
ode_sample=False, prior_var=1.0, temp=1.0, vae_temp=1.0, noise=None, need_denoise=False,
|
||||
ddim_step=0, writer=None, generate_mode_global='interpolate', generate_mode_local='freeze'):
|
||||
output = {}
|
||||
if ode_sample:
|
||||
assert isinstance(diffusion, DiffusionBase), 'ODE-based sampling requires cont. diffusion!'
|
||||
assert ode_eps is not None, 'ODE-based sampling requires integration cutoff ode_eps!'
|
||||
assert ode_solver_tol is not None, 'ODE-based sampling requires ode solver tolerance!'
|
||||
start = timer()
|
||||
condition_input = None
|
||||
eps_list = []
|
||||
for i in range(len(dae)):
|
||||
noise = torch.randn(size=[num_samples] + shape[i], device='cuda')
|
||||
if i == 0: # interpolation
|
||||
generate_mode = generate_mode_global
|
||||
else:
|
||||
generate_mode = generate_mode_local
|
||||
logger.info('level: {}, generate_mode: {}', i, generate_mode)
|
||||
if generate_mode == 'subtract':
|
||||
logger.info('interpolate latent between left most and right most')
|
||||
noise = subtract_noise(noise)
|
||||
elif generate_mode == 'interpolate':
|
||||
logger.info('interpolate latent between left most and right most')
|
||||
noise = interpolate_noise(noise)
|
||||
elif generate_mode == 'linear_interpolate':
|
||||
logger.info('linear interpolate latent between left most and right most')
|
||||
noise = linear_interpolate_noise(noise)
|
||||
elif generate_mode == 'freeze':
|
||||
for k in range(1, noise.shape[0]):
|
||||
noise[k] = noise[0] # for local latent, use the same one for all samples
|
||||
|
||||
eps, nfe, time_ode_solve = diffusion.sample_model_ode(
|
||||
dae[i], num_samples, shape[i], ode_eps, ode_solver_tol, enable_autocast, temp, noise,
|
||||
condition_input=condition_input
|
||||
)
|
||||
condition_input = eps
|
||||
eps_list.append(eps)
|
||||
output['sampled_eps'] = eps
|
||||
eps = vae.compose_eps(eps_list)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
output['print/sample_mean_global'] = eps.view(num_samples, -1).mean(-1).mean()
|
||||
output['print/sample_var_global'] = eps.view(num_samples, -1).var(-1).mean()
|
||||
decomposed_eps = vae.decompose_eps(eps)
|
||||
image = vae.sample(num_samples=num_samples, decomposed_eps=decomposed_eps)
|
||||
output['gen_x'] = image
|
||||
|
||||
end = timer()
|
||||
sampling_time = end - start
|
||||
nfe_torch = torch.tensor(nfe * 1.0, device='cuda')
|
||||
sampling_time_torch = torch.tensor(sampling_time * 1.0, device='cuda')
|
||||
time_ode_solve_torch = torch.tensor(time_ode_solve * 1.0, device='cuda')
|
||||
return image, nfe_torch, time_ode_solve_torch, sampling_time_torch, output
|
||||
|
||||
class Trainer(PriorTrainer):
|
||||
is_diffusion = 0
|
||||
generate_mode_global = 'interpolate'
|
||||
generate_mode_local = 'interpolate'
|
||||
def __init__(self, cfg, args):
|
||||
"""
|
||||
Args:
|
||||
cfg: training config
|
||||
args: used for distributed training
|
||||
"""
|
||||
cfg.num_val_samples = 20
|
||||
super().__init__(cfg, args)
|
||||
@torch.no_grad()
|
||||
def vis_sample(self, writer, num_vis=None, step=0, include_pred_x0=True,
|
||||
save_file=None):
|
||||
if self.cfg.ddpm.ema:
|
||||
self.swap_vae_param_if_need()
|
||||
self.dae_optimizer.swap_parameters_with_ema(store_params_in_ema=True)
|
||||
shape = self.model.latent_shape()
|
||||
logger.info('[url]: {}', writer.url)
|
||||
logger.info('Latent shape for prior: {}; num_val_samples: {}', shape, self.num_val_samples)
|
||||
## [self.vae.latent_dim, .num_input_channels, dae.input_size, dae.input_size]
|
||||
ode_sample = self.cfg.sde.ode_sample
|
||||
diffusion = self.diffusion_cont if ode_sample else self.diffusion_disc
|
||||
rank = 0
|
||||
seed = 0
|
||||
torch.manual_seed(rank + seed)
|
||||
np.random.seed(rank + seed)
|
||||
torch.cuda.manual_seed(rank + seed)
|
||||
torch.cuda.manual_seed_all(rank + seed)
|
||||
for idx in range(40):
|
||||
output_dir = os.path.join(self.cfg.save_dir, 'interp',
|
||||
'mode_%s_%s_%d'%(self.generate_mode_global,
|
||||
self.generate_mode_local, self.sample_num_points),
|
||||
'%04d'%idx)
|
||||
vis_dir = os.path.join(self.cfg.save_dir, 'interp',
|
||||
'mode_%s_%s_%d_img'%(self.generate_mode_global,
|
||||
self.generate_mode_local,
|
||||
self.sample_num_points))
|
||||
|
||||
logger.info('will save to {}', output_dir)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
if not os.path.exists(vis_dir):
|
||||
os.makedirs(vis_dir)
|
||||
vis_file = os.path.join(vis_dir, '%04d.png'%idx)
|
||||
output = validate_inspect(vis_file, shape, self.model, self.dae,
|
||||
diffusion, ode_sample,
|
||||
step + idx, self.writer, self.sample_num_points,
|
||||
epoch=self.cur_epoch,
|
||||
autocast_train=self.cfg.sde.autocast_train,
|
||||
need_sample=self.draw_sample_when_vis,
|
||||
need_val=0, need_train=0,
|
||||
num_samples=self.num_val_samples,
|
||||
test_loader=self.test_loader,
|
||||
w_prior=self.w_prior,
|
||||
val_x=self.val_x, tr_x=self.tr_x,
|
||||
val_input=self.val_input,
|
||||
m_pcs=self.m_pcs, s_pcs=self.s_pcs,
|
||||
has_shapelatent=True,
|
||||
vis_latent_point=self.cfg.vis_latent_point,
|
||||
ddim_step=self.cfg.viz.vis_sample_ddim_step,
|
||||
fun_generate_samples_vada=self.fun_generate_samples_vada
|
||||
)
|
||||
gen_x = output['gen_x']
|
||||
logger.info('gen_x shape: {}', gen_x.shape)
|
||||
for idxx in range(len(gen_x)):
|
||||
torch.save(gen_x[idxx], output_dir + '/%04d.pt'%idxx)
|
||||
logger.info('save to {}', output_dir)
|
||||
|
||||
if writer is not None:
|
||||
for n, v in output.items():
|
||||
if 'print/' not in n: continue
|
||||
self.writer.add_scalar('%s'%(n.split('print/')[-1]), v, step)
|
||||
|
||||
if self.cfg.ddpm.ema:
|
||||
self.swap_vae_param_if_need()
|
||||
self.dae_optimizer.swap_parameters_with_ema(store_params_in_ema=True)
|
||||
|
||||
|
||||
def set_writer(self, writer):
|
||||
self.writer = writer
|
||||
self.fun_generate_samples_vada = functools.partial(
|
||||
generate_samples, ode_eps=self.cfg.sde.ode_eps,
|
||||
writer=self.writer,
|
||||
generate_mode_global=self.generate_mode_global,
|
||||
generate_mode_local=self.generate_mode_local
|
||||
)
|
||||
def eval_sample(self, step=0):
|
||||
logger.info('skip eval-sample')
|
||||
return 0
|
Loading…
Reference in a new issue