# LION/trainers/encode_interp_interp.py
""" to train VAE-encoder with two prior """
import os
import time
import random
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from loguru import logger
from torch import optim
from utils.ema import EMA
from utils import eval_helper, exp_helper
from utils.sr_utils import SpectralNormCalculator
from utils.checker import *
from utils.utils import AvgrageMeter
from utils import utils
from torch.optim import Adam as FusedAdam
from torch.cuda.amp import autocast, GradScaler
from trainers.train_2prior import Trainer as BaseTrainer
from utils.data_helper import normalize_point_clouds
from utils.vis_helper import visualize_point_clouds_3d
from utils.eval_helper import compute_NLL_metric
CHECKPOINT = int(os.environ.get('CHECKPOINT', 0))
EVAL_LNS = int(os.environ.get('EVAL_LNS', 0))
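# Both flags are read once at import time, so they must be set in the environment
# of the launching process, e.g. (hypothetical invocation; the actual entry
# point may differ):
#   CHECKPOINT=1 EVAL_LNS=1 python train_dist.py ...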
def interpolate_noise(noise):
    """Overwrite rows 1..B-2 of `noise` with interpolations of rows 0 and B-1.

    Uses sqrt weights so that each interpolant of two independent unit-Gaussian
    endpoints keeps (approximately) unit variance, which the diffusion prior
    expects; plain linear weights would shrink the norm towards the middle.
    """
    noise_a = noise[0].contiguous()   # first endpoint, e.g. shape (D, 1, 1)
    noise_b = noise[-1].contiguous()  # last endpoint
    num_inter = noise.shape[0] - 2    # number of rows to interpolate (unused below)
    for k in range(1, noise.shape[0] - 1):
        p = float(k) / len(noise)  # interpolation weight, e.g. 1/8 ... 6/8 for B=8
        # logger.info('p={}; eps: {}', p, noise.shape)
        # linear version: noise[k] = p * noise_b + (1 - p) * noise_a
        noise[k] = np.sqrt(p) * noise_b + np.sqrt(1 - p) * noise_a
    return noise
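# A minimal sanity-check sketch (hypothetical helper, not part of the original
# pipeline): with sqrt weights, interpolants of two independent N(0, I)
# endpoints stay roughly unit-variance.
def _demo_interpolate_noise():
    eps_T = torch.randn(8, 128, 1, 1)  # rows 0 and -1 act as the two endpoints
    eps_T = interpolate_noise(eps_T)
    # every row should report a variance close to 1.0
    print([round(eps_T[k].var().item(), 2) for k in range(eps_T.shape[0])])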
# def get_data(num_points=3072, num_selected=50):
#     import h5py
#     import torch
#     import sys
#     import os
#     from datasets.data_path import get_path
#
#     synsetid_to_cate = {
#         "02691156": "airplane",
#         "02828884": "bench",
#         "02958343": "car",
#         "03001627": "chair",
#         "03211117": "display",
#         "03636649": "lamp",
#         "04256520": "sofa",
#         "04379243": "table",
#         "04530566": "watercraft",
#     }
#
#     root_dir = get_path('pointflow')
#     out_dict = []
#     out_name = []
#     for synset_id in synsetid_to_cate.keys():
#         filename = os.path.join(root_dir, f'{synset_id}', 'train')
#         tr_out_full = []
#         for file in sorted(os.listdir(filename))[:10]:
#             pts = np.load(filename + '/' + file)
#             tr_out_full.append(pts)
#         tr_out_full = np.stack(tr_out_full)
#         # with h5py.File(filename, "r") as h5f:
#         #     tr_out_full = h5f['surface_points/points'][5:5+num_selected]
#         if tr_out_full.shape[1] > num_points:
#             pt_cur = []
#             for b in range(tr_out_full.shape[0]):
#                 tr_idxs = np.random.choice(tr_out_full.shape[1], num_points)
#                 pt_cur.append(tr_out_full[b, tr_idxs])  # pt_cur[tr_idxs]
#             tr_out_full = torch.from_numpy(np.stack(pt_cur))
#         # out_dict[synsetid_to_cate[synset_id]] = tr_out_full
#         out_dict.append(tr_out_full)
#         out_name.extend([synsetid_to_cate[synset_id]] * num_selected)
#     out_dict = torch.cat(out_dict)
#     logger.info('created data: {}', out_dict.shape)
#     return out_dict, out_name
#
# def fun_hash(a, b):
#     if a < b:
#         return f'{a}-{b}'
#     else:
#         return f'{b}-{a}'
#
# def fun_select_count_pair(names, a, b):
#     select_all_pair = []
#     indexes = list(np.arange(len(names)))
#     indexes_a = [i for ii, i in enumerate(indexes) if names[ii] == a]
#     indexes_b = [i for ii, i in enumerate(indexes) if names[ii] == b]
#     hash_d = []
#     logger.info('select pairs for: {}, {}', a, b)
#     for ai in indexes_a:
#         for bi in indexes_b:
#             if ai != bi and fun_hash(ai, bi) not in hash_d:
#                 hash_d.append(fun_hash(ai, bi))
#                 select_all_pair.append([ai, bi])
#     logger.info('got pairs: {}', len(select_all_pair))
#     return select_all_pair
#
class Trainer(BaseTrainer):
    is_diffusion = 0
    generate_mode_global = 'interpolate'
    generate_mode_local = 'interpolate'

    def __init__(self, cfg, args):
        """
        Args:
            cfg: training config
            args: used for distributed training
        """
        super().__init__(cfg, args)
        self.draw_sample_when_vis = 0

    @torch.no_grad()
    def eval_sample(self, step=0):
        pass  # do nothing
    # -- shared method for all models with a VAE component -- #
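    # Rough shape of the interpolation pipeline in eval_nll (as read from the
    # code below), per batch:
    #   1. encode the two endpoint shapes into a global and a local latent (eps)
    #   2. run the prior's ODE forward (compute_ode_nll) to map each latent to
    #      its base noise eps_T
    #   3. blend the two eps_T endpoints with interpolate_noise
    #   4. run the ODE backward (sample_model_ode) to map the blended noise
    #      back to latents
    #   5. decode the latents into point clouds, then visualize and save them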
    @torch.no_grad()
    def eval_nll(self, step, ntest=None, save_file=False):
        cfg = self.cfg
        if cfg.ddpm.ema:
            self.swap_vae_param_if_need()
        self.dae_optimizer.swap_parameters_with_ema(store_params_in_ema=True)
        args = self.args
        device = torch.device('cuda:%d' % args.local_rank)
        tag = exp_helper.get_evalname(self.cfg) + 'D%d' % self.cfg.voxel2pts.diffusion_steps[0]
        ode_sample = self.cfg.sde.ode_sample
        diffusion = self.diffusion_cont if ode_sample else self.diffusion_disc
        eval_trainnll = self.cfg.eval_trainnll
        data_loader = self.train_loader
        # data_loader = self.test_loader
        gen_pcs, ref_pcs = [], []
        gen_x_lns_list = []
        tag_cur = self.cfg.voxel2pts.diffusion_steps[0]
        num_selected = 60
        # input_pts, input_names = get_data(self.num_points, num_selected)
        # is_all_class_model = 'nsall' in self.cfg.save_dir
        # logger.info('is_all_class_model: {}', is_all_class_model)
        # if is_all_class_model:
        #     select_count_pair = []
        #     class_pairs = [
        #         ['airplane', 'car'],
        #         # ['car', 'watercraft'],
        #         # ['watercraft', 'airplane'],
        #         # ['airplane', 'bench'],
        #         # ['bench', 'table'],
        #         # ['table', 'chair'],
        #         # ['chair', 'display'],
        #         # ['chair', 'lamp'],
        #         # ['lamp', 'watercraft'],
        #     ]
        #     for class_a, class_b in class_pairs:
        #         select_count_pair_cur = fun_select_count_pair(input_names, class_a, class_b)
        #         select_count_pair.extend(select_count_pair_cur)
        # else:
        #     class_a = class_b = self.cfg.data.cates
        #     select_count_pair = fun_select_count_pair(input_names, class_a, class_b)
        # input_pts_list = []
        # input_pts_cates = []
        # print('number of pairs', len(select_count_pair))
        # for select_count in select_count_pair:
        #     input_pts_list_pair = []
        #     input_pts_cates_pair = []
        #     for vi in select_count:
        #         input_pts_list_pair.append(input_pts[vi][None])
        #         input_pts_cates_pair.append(input_names[vi])
        #     input_pts_list.append(torch.cat(input_pts_list_pair, dim=0))
        #     input_pts_cates.append('-'.join(input_pts_cates_pair))
        # # shuffle_idx = list(range(len(input_pts_list)))
        # # random.Random(38383).shuffle(shuffle_idx)
        # # input_pts_list = [input_pts_list[i] for i in shuffle_idx][:50]  # select first 50 pairs
        # # input_pts_cates = [input_pts_cates[i] for i in shuffle_idx][:50]  # select first 50 pairs
        # logger.info('num of input pts: {}, cates: {}',
        #             len(input_pts_list), input_pts_cates[:10])
        output_dir_template = self.cfg.save_dir + '/enc%d_%s_%s/' % (
            num_selected, self.generate_mode_global, self.generate_mode_local)
        vis_output_dir = self.cfg.save_dir + '/vis_enc%d_%s_%s/' % (
            num_selected, self.generate_mode_global, self.generate_mode_local)
        if not os.path.exists(vis_output_dir):
            os.makedirs(vis_output_dir)
        ode_eps = cfg.sde.ode_eps
        ode_solver_tol = 1e-5
        enable_autocast = False
        temp = 1.0
        noise = None
        condition_input = None
        clip_feat = None
        # for vid, val_batch in enumerate(data_loader):
        #     m, s = val_batch['mean'][0:1], val_batch['std'][0:1]
        #     B = val_batch['mean'].shape[0]
        #     break
        # m = torch.zeros(1, 3)
        # m[:, 0] = -0.0308  # -1.0504e-02
        # m[:, 1] = -0.0353  # -4.1844e-03
        # m[:, 2] = -0.0001  # -5.1331e-05
        # s = torch.zeros(1, 1)
        # s[0] = 0.1512  # 0.1694
        # logger.info('mean: {}, s={}', m, s)
        B = 4
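        # With B = 4 and two endpoint shapes a, b, the replicated batch built in
        # the loop below is [a, b, a, b]; since interpolate_noise overwrites rows
        # 1..B-2, rows 0 and B-1 keep the endpoints and the two middle rows
        # become the interpolants.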
        # for vid, (pt_cur, cate_cur) in enumerate(zip(input_pts_list, input_pts_cates)):
        for vid, val_batch in enumerate(data_loader):
            if vid % 30 == 1:
                logger.info('eval: {}/{}, BS={}', vid, len(data_loader), B)
            output_dir = output_dir_template + '/sph_B%d_%04d' % (B, vid)
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            pt_cur = val_batch['tr_points']
            pt_cur = pt_cur[:2]  # keep two inputs: the interpolation endpoints
            B0, N, C = pt_cur.shape
            input_pts = pt_cur[None].expand(B // B0, -1, -1, -1).contiguous().view(B, N, C).contiguous()
            # input_pts = (input_pts - m) / s
            input_pts = input_pts.cuda().float()
            file_name = ['%04d' % i for i in range(B)]
            val_x = input_pts
            logger.info('val_x: {}', val_x.shape)
            B, N, C = val_x.shape
            inputs = val_x
            log_info = {'x_0': val_x, 'vis/input_noise': inputs.view(B, N, -1), 'x_t': val_x}
            # dist = self.model.encode(inputs)
            # -- global latent -- #
            dae_index = 0
            dist = self.model.encode_global(inputs)
            B = inputs.shape[0]
            shape = self.model.latent_shape()[dae_index]
            noise_shape_global = [B] + shape
            eps = dist.sample()[0]
            eps_shape_global = eps.shape
            eps = eps.view(noise_shape_global)
            eps_ori = eps
            # get eps_0:
            i = 0
            dae = self.dae
            eps_0 = eps
            eps_global = eps_0.contiguous()
            condition_input = eps_global
            # style: map the global latent to the style vector that conditions the local encoder
            style = self.model.global2style(eps_global.view(eps_shape_global))
            dist_local = self.model.encode_local(inputs, style)
            # -- local latent -- #
            dae_index = 1
            shape = self.model.latent_shape()[dae_index]
            noise_shape_local = [B] + shape
            eps = dist_local.sample()[0]
            eps_shape_local = eps.shape
            eps = eps.view(noise_shape_local)
            eps_ori = eps
            # -------------------------
            # start the interpolation
            # -------------------------
            # eps_local = eps_ori
            # map the global latent to its base noise via the prior's ODE
            eps_T_global_interp = diffusion.compute_ode_nll(
                dae[0], eps_0, ode_eps, ode_solver_tol,
                condition_input=None)
            eps_T_global_interp = interpolate_noise(eps_T_global_interp.contiguous())
            # integrate the ODE back to obtain the interpolated global latents
            eps_0_global_interp, _, _ = diffusion.sample_model_ode(
                dae[0], B, shape, ode_eps, ode_solver_tol, enable_autocast, temp,
                noise=eps_T_global_interp,
                condition_input=None, clip_feat=clip_feat)
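            # Note: rows 0 and B-1 of eps_0_global_interp are ODE round-trip
            # reconstructions of the two encoded endpoints; the middle rows decode
            # the blended noise. `shape` still holds the *local* latent shape at
            # this point; presumably sample_model_ode ignores it when an explicit
            # `noise` tensor is passed.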
            # same round trip for the local latent, conditioned on the global latent
            eps_T_local_interp = diffusion.compute_ode_nll(
                dae[1], eps, ode_eps, ode_solver_tol,
                condition_input=eps_global)
            eps_T_local_interp = interpolate_noise(eps_T_local_interp.contiguous())
            # double-check that eps_0 can be denoised back to eps:
            eps_0_local_interp, _, _ = diffusion.sample_model_ode(
                dae[1], B, shape, ode_eps,
                ode_solver_tol, enable_autocast, temp,
                noise=eps_T_local_interp,
                condition_input=eps_0_global_interp,
                clip_feat=clip_feat)
            style = self.model.global2style(eps_0_global_interp.view(eps_shape_global))
            eps_local = eps_0_local_interp.view(eps_shape_local)
            gen_x = self.model.decoder(None, beta=None, context=eps_local, style=style)  # (B, ncenter, 3)
            # gen_x = val_x
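            # Rows 0 and B-1 of gen_x should approximately reconstruct the two
            # input shapes (up to ODE solver error); the middle rows are the
            # interpolations between them.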
            # visualize the decoded interpolation results:
            log_info['x_0_pred'] = gen_x.detach().cpu()
            if True:
                img_list = []
                for i in range(B):
                    points = gen_x[i]  # (N, 3)
                    points = normalize_point_clouds([points])[0]
                    img = visualize_point_clouds_3d([points])
                    img_list.append(img)
                img = np.concatenate(img_list, axis=2)
                self.writer.add_image('%d' % vid, torch.as_tensor(img), 0)
                img_list = [torch.as_tensor(img) / 255.0 for img in img_list]
                vis_file = os.path.join(vis_output_dir, '%04d.png' % vid)
                torchvision.utils.save_image(img_list, vis_file)
                logger.info('save img as: {}', vis_file)
            # forward pass to get the eval-NLL output for this batch
            val_x = val_x.contiguous()
            inputs = inputs.contiguous()
            output = self.model.get_loss(val_x, it=step, is_eval_nll=1, noisy_input=inputs)
            # gen_x = gen_x.cpu() * s + m
            for i, file_name_i in enumerate(file_name):
                torch.save(gen_x[i], os.path.join(output_dir, file_name_i))
            logger.info('save output at: {}', output_dir)
        return 0