853 lines
36 KiB
Python
853 lines
36 KiB
Python
|
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
|
#
|
||
|
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||
|
# and proprietary rights in and to this software, related documentation
|
||
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
||
|
# distribution of this software and related documentation without an express
|
||
|
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||
|
import os
|
||
|
import time
|
||
|
from abc import ABC, abstractmethod
|
||
|
from comet_ml import Experiment
|
||
|
import torch
|
||
|
import importlib
|
||
|
import numpy as np
|
||
|
from PIL import Image
|
||
|
from loguru import logger
|
||
|
import torchvision
|
||
|
import torch.distributed as dist
|
||
|
from utils.evaluation_metrics_fast import print_results
|
||
|
from utils.checker import *
|
||
|
from utils.vis_helper import visualize_point_clouds_3d
|
||
|
from utils.eval_helper import compute_score, get_ref_pt, get_ref_num
|
||
|
from utils import model_helper, exp_helper, data_helper
|
||
|
from utils.utils import infer_active_variables
|
||
|
from utils.data_helper import normalize_point_clouds
|
||
|
from utils.eval_helper import compute_NLL_metric
|
||
|
from utils.utils import AvgrageMeter
|
||
|
import clip
|
||
|
|
||
|
class BaseTrainer(ABC):
    """Base class for trainers: shared checkpointing, logging, visualization
    and evaluation logic for all concrete trainer implementations.

    Subclasses must implement ``train_iter`` and ``sample``.
    """

    def __init__(self, cfg, args):
        """
        Args:
            cfg: experiment config (attribute-style access assumed).
            args: launcher args; reads ``local_rank`` and (elsewhere)
                ``global_rank`` / ``distributed``.
        """
        self.cfg, self.args = cfg, args
        self.scheduler = None  # set by subclass when an LR scheduler is used
        self.local_rank = args.local_rank
        self.cur_epoch = 0
        self.start_epoch = 0
        self.epoch = 0
        self.step = 0
        self.writer = None  # experiment writer, attached via set_writer
        self.encoder = None
        self.num_val_samples = cfg.num_val_samples
        self.train_iter_kwargs = {}  # extra kwargs forwarded to train_iter
        self.num_points = self.cfg.data.tr_max_sample_points
        self.best_eval_epoch = 0
        self.best_eval_score = -1  # negative sentinel: "no eval run yet"
        self.use_grad_scalar = cfg.trainer.use_grad_scalar
        # NOTE(review): ``device`` is assigned but never used below;
        # ``self.device_str`` carries the same information.
        device = torch.device('cuda:%d' % args.local_rank)
        self.device_str = 'cuda:%d' % args.local_rank
        self.t2s_input = []  # text/image prompts used for CLIP-conditioned vis
        if cfg.clipforge.enable:
            # loads CLIP and pre-computes conditioning features
            self.prepare_clip_model_data()
        else:
            self.clip_feat_list = None
|
||
|
|
||
|
def set_writer(self, writer):
|
||
|
self.writer = writer
|
||
|
logger.info(
|
||
|
'\n'+'-'*10 + f'\n[url]: {self.writer.url}\n{self.cfg.save_dir}\n' + '-'*10)
|
||
|
|
||
|
    @abstractmethod
    def train_iter(self, data, *args, **kwargs):
        """Run a single training iteration on one batch.

        Subclasses implement forward/backward/optimizer logic and return a
        dict of logging info; ``train_epochs`` reads its ``'loss'`` entry.
        """
        pass
|
||
|
|
||
|
    @abstractmethod
    def sample(self, *args, **kwargs):
        """Draw point-cloud samples from the model.

        Used both for metric evaluation (``eval_sample``) and for
        visualization (``vis_sample``); subclasses define the signature
        details they support.
        """
        pass
|
||
|
|
||
|
def log_val(self, val_info, writer=None, step=None, epoch=None, **kwargs):
|
||
|
if writer is not None:
|
||
|
for k, v in val_info.items():
|
||
|
if step is not None:
|
||
|
writer.add_scalar(k, v, step)
|
||
|
else:
|
||
|
writer.add_scalar(k, v, epoch)
|
||
|
|
||
|
    def epoch_start(self, epoch):
        # Hook invoked at the beginning of every training epoch.
        # No-op in the base class; subclasses may override.
        pass
|
||
|
|
||
|
def epoch_end(self, epoch, writer=None, **kwargs):
|
||
|
# Signal now that the epoch ends....
|
||
|
if self.scheduler is not None:
|
||
|
self.scheduler.step(epoch=epoch)
|
||
|
if writer is not None:
|
||
|
writer.add_scalar(
|
||
|
'train/opt_lr', self.scheduler.get_lr()[0], epoch)
|
||
|
if writer is not None:
|
||
|
writer.upload_meter(epoch=epoch, step=kwargs.get('step', None))
|
||
|
|
||
|
# --- util function --
|
||
|
def save(self, save_name=None, epoch=None, step=None, appendix=None, save_dir=None, **kwargs):
|
||
|
d = {
|
||
|
'opt': self.optimizer.state_dict(),
|
||
|
'model': self.model.state_dict(),
|
||
|
'epoch': epoch,
|
||
|
'step': step
|
||
|
}
|
||
|
if appendix is not None:
|
||
|
d.update(appendix)
|
||
|
if self.use_grad_scalar:
|
||
|
d.update({'grad_scalar': self.grad_scalar.state_dict()})
|
||
|
save_name = "epoch_%s_iters_%s.pt" % (
|
||
|
epoch, step) if save_name is None else save_name
|
||
|
save_dir = self.cfg.save_dir if save_dir is None else save_dir
|
||
|
path = os.path.join(save_dir, "checkpoints", save_name)
|
||
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||
|
logger.info('save model as : {}', path)
|
||
|
torch.save(d, path)
|
||
|
return path
|
||
|
|
||
|
def filter_name(self, ckpt):
|
||
|
ckpt_new = {}
|
||
|
for k, v in ckpt.items():
|
||
|
if k[:7] == 'module.':
|
||
|
kn = k[7:]
|
||
|
elif k[:13] == 'model.module.':
|
||
|
kn = k[13:]
|
||
|
else:
|
||
|
kn = k
|
||
|
ckpt_new[kn] = v
|
||
|
return ckpt_new
|
||
|
|
||
|
def resume(self, path, strict=True, **kwargs):
|
||
|
ckpt = torch.load(path)
|
||
|
strict = True
|
||
|
model_weight = ckpt['model'] if 'model' in ckpt else ckpt['model_state']
|
||
|
vae_weight = self.filter_name(model_weight)
|
||
|
self.model.load_state_dict(vae_weight, strict=strict)
|
||
|
if 'opt' in ckpt:
|
||
|
self.optimizer.load_state_dict(ckpt['opt'])
|
||
|
else:
|
||
|
logger.info('no optimizer found in ckpt')
|
||
|
|
||
|
start_epoch = ckpt['epoch']
|
||
|
self.epoch = start_epoch
|
||
|
self.cur_epoch = start_epoch
|
||
|
self.step = ckpt.get('step', 0)
|
||
|
logger.info('resume from : {}, epo={}', path, start_epoch)
|
||
|
if self.use_grad_scalar:
|
||
|
assert('grad_scalar' in ckpt), 'otherwise set it false'
|
||
|
self.grad_scalar.load_state_dict(ckpt['grad_scalar'])
|
||
|
return start_epoch
|
||
|
|
||
|
def build_model(self):
|
||
|
cfg, args = self.cfg, self.args
|
||
|
if args.distributed:
|
||
|
dist.barrier()
|
||
|
model_lib = importlib.import_module(cfg.shapelatent.model)
|
||
|
model = model_lib.Model(cfg)
|
||
|
return model
|
||
|
|
||
|
def build_data(self):
|
||
|
logger.info('start build_data')
|
||
|
cfg, args = self.cfg, self.args
|
||
|
self.args.eval_trainnll = cfg.eval_trainnll
|
||
|
data_lib = importlib.import_module(cfg.data.type)
|
||
|
loaders = data_lib.get_data_loaders(cfg.data, args)
|
||
|
train_loader = loaders['train_loader']
|
||
|
test_loader = loaders['test_loader']
|
||
|
return train_loader, test_loader
|
||
|
|
||
|
    def train_epochs(self):
        """ train for number of epochs;

        Main training loop: per iteration calls ``train_iter`` and, on global
        rank 0, handles periodic console logging, loss-curve logging,
        reconstruction/sample visualization, checkpoint saving, timed
        snapshots and validation (``eval_nll``), tracking the best eval score.
        """
        # main training loop
        cfg, args = self.cfg, self.args
        train_loader = self.train_loader
        writer = self.writer

        # freq <= -1 means "that many times per epoch": convert to iterations
        if cfg.viz.log_freq <= -1:  # treat as per epoch
            cfg.viz.log_freq = int(- cfg.viz.log_freq * len(train_loader))
        if cfg.viz.viz_freq <= -1:
            cfg.viz.viz_freq = - cfg.viz.viz_freq * len(train_loader)

        logger.info("[rank=%d] Start epoch: %d End epoch: %d, batch-size=%d | "
                    "Niter/epo=%d | log freq=%d, viz freq %d, val freq %d " %
                    (args.local_rank,
                     self.start_epoch, cfg.trainer.epochs, cfg.data.batch_size,
                     len(train_loader),
                     cfg.viz.log_freq, cfg.viz.viz_freq, cfg.viz.val_freq))
        tic0 = time.time()  # last-snapshot wall-clock timer
        step = 0
        if args.global_rank == 0:
            tic_log = time.time()  # last-console-log timer (rank 0 only)
        self.num_total_iter = cfg.trainer.epochs * len(train_loader)
        self.model.num_total_iter = self.num_total_iter

        for epoch in range(self.start_epoch, cfg.trainer.epochs):
            self.cur_epoch = epoch
            if args.global_rank == 0:
                tic_epo = time.time()
            if args.distributed:
                # reshuffle differently every epoch under DDP
                train_loader.sampler.set_epoch(epoch)
            if args.global_rank == 0 and cfg.trainer.type in ['trainers.voxel2pts', 'trainers.voxel2pts_ada'] and epoch == 0:
                # these trainer types evaluate once before any training
                self.eval_nll(step=step)
            epoch_loss = []
            self.epoch_start(epoch)

            # remove disabled latent variables by setting their mixing component to a small value
            if epoch == 0 and cfg.sde.mixed_prediction and cfg.sde.drop_inactive_var:
                raise NotImplementedError

            ## -- train for one epoch -- ##
            for bidx, data in enumerate(train_loader):
                # let step start from 0 instead of 1
                step = bidx + len(train_loader) * epoch

                if args.global_rank == 0 and self.writer is not None:
                    tic_iter = time.time()

                # -- train for one iter -- #
                logs_info = self.train_iter(data, step=step,
                                            **self.train_iter_kwargs)

                # -- log information within epoch -- #
                if self.args.global_rank == 0:
                    epoch_loss.append(logs_info['loss'])
                    if self.args.global_rank == 0 and (
                        time.time() - tic_log > 60
                    ):  # log per min
                        logger.info('[R%d] | E%d iter[%3d/%3d] | [Loss] %2.2f | '
                                    '[exp] %s | [step] %5d | [url] %s ' % (
                                        args.global_rank, epoch, bidx, len(train_loader),
                                        np.array(epoch_loss).mean(),
                                        cfg.save_dir, step, writer.url
                                    ))
                        tic_log = time.time()

                    # -- visualize rec and samples -- #
                    if step % int(cfg.viz.log_freq) == 0 and \
                            args.global_rank == 0 and not (
                            step == 0 and cfg.sde.ode_sample and
                            (cfg.trainer.type == 'trainers.train_prior' or cfg.trainer.type ==
                             'trainers.train_2prior')  # this case, skip sampling at first step
                    ):
                        avg_loss = np.array(epoch_loss).mean()
                        # NOTE(review): this resets ``epo_loss``, a variable
                        # that is never read; the inline comment suggests
                        # ``epoch_loss`` was intended ("clean up epoch loss").
                        # Confirm intent before changing: the end-of-epoch log
                        # below averages the uncleared ``epoch_loss``.
                        epo_loss = []  # clean up epoch loss
                        self.log_loss({'epo_loss': avg_loss},
                                      writer=writer, step=step)
                    visualize = int(cfg.viz.viz_freq) > 0 and \
                        (step) % int(cfg.viz.viz_freq) == 0
                    vis_recont = visualize
                    if vis_recont:
                        self.vis_recont(logs_info, writer, step)
                    if visualize:
                        self.model.eval()
                        self.vis_sample(writer, step=step,
                                        include_pred_x0=False)
                        self.model.train()

                # -- timer -- #
                if args.global_rank == 0 and self.writer is not None:
                    time_iter = time.time() - tic_iter
                    self.writer.avg_meter('time_iter', time_iter, step=step)
            ## -- log information after one epoch -- ##
            if args.global_rank == 0:
                epo_time = (time.time() - tic_epo) / 60.0  # min
                logger.info('[R%d] | E%d iter[%3d/%3d] | [Loss] %2.2f '
                            '| [exp] %s | [step] %5d | [url] %s | [time] %.1fm (~%dh) |'
                            '[best] %d %.3fx1e-2 ' % (
                                args.global_rank, epoch, bidx, len(train_loader),
                                np.array(epoch_loss).mean(),
                                cfg.save_dir, step, writer.url,
                                epo_time, epo_time * (cfg.trainer.epochs - epoch) / 60,
                                self.best_eval_epoch, self.best_eval_score*1e2
                            ))
                tic_log = time.time()  # reset tic_log

            ## -- save model -- ##
            if (epoch + 1) % int(cfg.viz.save_freq) == 0 and \
                    int(cfg.viz.save_freq) > 0 and args.global_rank == 0:
                self.save(epoch=epoch, step=step)
            if ((time.time() - tic0) / 60 > cfg.snapshot_min) and \
                    args.global_rank == 0:  # save every 30 min
                # write to a temp name then rename, so a crash mid-save never
                # corrupts the existing 'snapshot'
                file_name = self.save(
                    save_name='snapshot_bak', epoch=epoch, step=step)
                if file_name is None:
                    file_name = os.path.join(
                        self.cfg.save_dir, "checkpoints", "snapshot_bak")
                os.rename(file_name, file_name.replace(
                    'snapshot_bak', 'snapshot'))
                tic0 = time.time()

            ## -- run eval -- ##
            if int(cfg.viz.val_freq) > 0 and (epoch + 1) % int(cfg.viz.val_freq) == 0 and \
                    args.global_rank == 0:
                eval_score = self.eval_nll(step=step, save_file=False)
                # lower is better; best_eval_score < 0 means "no eval yet"
                if eval_score < self.best_eval_score or self.best_eval_score < 0:
                    self.save(save_name='best_eval.pth',  # save_dir=snapshot_dir,
                              epoch=epoch, step=step)
                    self.best_eval_score = eval_score
                    self.best_eval_epoch = epoch

            ## -- Signal the trainer to cleanup now that an epoch has ended -- ##
            self.epoch_end(epoch, writer=writer, step=step)
        ### -- end of the training -- ###
        if args.global_rank == 0:
            self.eval_nll(step=step)
        if self.cfg.trainer.type == 'trainers.train_prior':  # and args.global_rank == 0:
            self.model.eval()
            self.eval_sample(step)
            logger.info('debugging eval-sample; exit now')
|
||
|
|
||
|
@torch.no_grad()
|
||
|
def log_loss(self, train_info, writer=None, step=None, **kwargs):
|
||
|
""" write to tensorboard and visualize
|
||
|
"""
|
||
|
if writer is None:
|
||
|
return
|
||
|
|
||
|
# Log training information to tensorboard
|
||
|
train_info = {
|
||
|
k: (v.cpu() if not isinstance(v, float) else v)
|
||
|
for k, v in train_info.items()
|
||
|
}
|
||
|
for k, v in train_info.items():
|
||
|
if not ('loss' in k):
|
||
|
continue
|
||
|
if step is not None:
|
||
|
writer.add_scalar('train/' + k, v, step)
|
||
|
else:
|
||
|
assert epoch is not None
|
||
|
writer.add_scalar('train/' + k, v, epoch)
|
||
|
|
||
|
# --------------------------------------------- #
|
||
|
# visulization function and sampling function #
|
||
|
# --------------------------------------------- #
|
||
|
    @torch.no_grad()
    def vis_recont(self, output, writer, step, normalize_pts=False):
        """Render reconstruction panels (prediction / x_t / target, plus any
        'vis/*' extras) for a few shapes and write the grid to the writer.

        Args:
            output: dict from ``train_iter``; expects 'x_0_pred', 'x_0',
                'x_t' entries (each indexable per shape) and optionally 't'.
            writer: experiment writer; silently returns 0 when None.
            step: global step used as the image tag index.
            normalize_pts: if True, normalize point clouds before plotting.
        """
        if writer is None:
            return 0
        # x_0: target
        # x_0_pred: recont
        # x_t: intermidiate sample at t (if t is not None)
        x_0_pred, x_0, x_t = output.get('x_0_pred', None), \
            output.get('x_0', None), output.get('x_t', None)
        if x_0_pred is None or x_0 is None or x_t is None:
            logger.info('x_0_pred: None? {}; x_0: None? {}, x_t: None? {}',
                        x_0_pred is None, x_0 is None, x_t is None)
            return 0

        CHECK3D(x_0)
        CHECK3D(x_t)
        CHECK3D(x_0_pred)

        t = output.get('t', None)
        # visualize between 2 and 5 shapes
        # NOTE(review): when batch size is 1 this yields nvis=2 and would
        # index out of range below -- confirm callers always pass B >= 2.
        nvis = min(max(x_0.shape[0], 2), 5)
        img_list = []
        for b in range(nvis):
            x_list, name_list = [], []
            x_list.append(x_0_pred[b])
            name_list.append('pred')

            # show the intermediate noisy sample only for t > 0
            if t is not None and t[b] > 0:
                x_t_name = 'x_t%d' % t[b].item()
                name_list.append(x_t_name)
                x_list.append(x_t[b])

            x_list.append(x_0[b])
            name_list.append('target')

            # extra debug entries the trainer tagged with 'vis/'
            for k, v in output.items():
                if 'vis/' in k:
                    x_list.append(v[b])
                    name_list.append(k)
            if normalize_pts:
                x_list = normalize_point_clouds(x_list)

            vis_order = self.cfg.viz.viz_order
            vis_args = {'vis_order': vis_order}

            img = visualize_point_clouds_3d(x_list, name_list, **vis_args)
            img_list.append(img)
        img_list = torchvision.utils.make_grid(
            [torch.as_tensor(a) for a in img_list], pad_value=0)
        writer.add_image('vis_out/recont-train', img_list, step)
|
||
|
|
||
|
    @torch.no_grad()
    def eval_sample(self, step=0):
        """ compute sample metric: MMD,COV,1-NNA

        Draws ~``num_ref`` samples from the model (reseeded per batch for
        reproducibility, gathered across ranks when distributed), saves them,
        visualizes a few, then scores them against the reference set returned
        by ``get_ref_pt`` via ``compute_score`` and logs the results.
        """
        writer = self.writer
        batch_size_test = self.cfg.data.batch_size_test
        input_dim = self.cfg.ddpm.input_dim
        ddim_step = self.cfg.eval_ddim_step
        device = model_helper.get_device(self.model)
        test_loader = self.test_loader
        test_size = batch_size_test * len(test_loader)
        sample_num_points = self.cfg.data.tr_max_sample_points
        cates = self.cfg.data.cates
        num_ref = get_ref_num(
            cates) if self.cfg.num_ref == 0 else self.cfg.num_ref

        # option for post-processing
        if self.cfg.data.recenter_per_shape or self.cfg.data.normalize_shape_box or self.cfg.data.normalize_range:
            norm_box = True
        else:
            norm_box = False
        logger.info('norm_box: {}, recenter : {}, shapebox: {}',
                    norm_box, self.cfg.data.recenter_per_shape,
                    self.cfg.data.normalize_shape_box)

        # get exp tag and output name
        tag = exp_helper.get_evalname(self.cfg)
        if not self.cfg.sde.ode_sample:
            tag += 'diet'
        else:
            tag += 'ode%d' % self.cfg.sde.ode_sample
        output_name = os.path.join(
            self.cfg.save_dir, f'samples_{step}{tag}.pt')
        logger.info('batch_size_test={}, test_size={}, saved output: {} ',
                    batch_size_test, test_size, output_name)
        gen_pcs = []

        # (dead code removed from comments) reference point clouds used to be
        # collected from the test loader here; they are now loaded from disk
        # via get_ref_pt below.

        # ---- gen_pcs ---- #
        if True:
            # number of batches needed to produce at least ``num_ref`` samples
            len_test_loader = num_ref // batch_size_test + 1
            if self.args.distributed:
                # split generation across ranks; round up to still cover num_ref
                num_gen_iter = max(1, len_test_loader // self.args.global_size)
                if num_gen_iter * batch_size_test * self.args.global_size < num_ref:
                    num_gen_iter = num_gen_iter + 1
            else:
                num_gen_iter = len_test_loader

            index_start = 0
            logger.info('Rank={}, num_gen_iter: {}; num_ref={}, batch_size_test={}',
                        self.args.global_rank, num_gen_iter, num_ref, batch_size_test)
            seed = self.cfg.trainer.seed
            for i in range(0, num_gen_iter):
                # reseed per batch so sampling is reproducible
                torch.manual_seed(seed + i)
                np.random.seed(seed + i)
                torch.cuda.manual_seed(seed + i)
                torch.cuda.manual_seed_all(seed + i)

                logger.info('#%d/%d; BS=%d' %
                            (i, num_gen_iter, batch_size_test))
                # ---- draw samples ---- #
                self.index_start = index_start
                x = self.sample(num_shapes=batch_size_test,
                                num_points=sample_num_points,
                                device_str=device.type,
                                for_vis=False,
                                ddim_step=ddim_step).permute(0, 2, 1).contiguous()  # B,3,N->B,N,3
                assert(
                    x.shape[-1] == input_dim), f'expect x: B,N,{input_dim}; get {x.shape}'
                index_start = index_start + batch_size_test
                gen_pcs.append(x.detach().cpu())

        gen_pcs = torch.cat(gen_pcs, dim=0)
        if self.args.distributed:
            # gather all ranks' samples; only rank 0 saves and evaluates
            gen_pcs = gen_pcs.to(torch.device(self.device_str))
            logger.info('before gather: {}, rank={}',
                        gen_pcs.shape, self.args.global_rank)
            gen_pcs_list = [torch.zeros_like(gen_pcs)
                            for _ in range(self.args.global_size)]
            dist.all_gather(gen_pcs_list, gen_pcs)
            gen_pcs = torch.cat(gen_pcs_list, dim=0).cpu()
            logger.info('after gather: {}, rank={}',
                        gen_pcs.shape, self.args.global_rank)
        logger.info('save as %s' % output_name)
        if self.args.global_rank == 0:
            torch.save(gen_pcs, output_name)
        else:
            logger.info('return for rank {}', self.args.global_rank)
            return  # only do eval on one gpu
        if writer is not None:
            # preview the first few generated shapes
            img_list = []
            for i in range(1):
                gen_list = [gen_pcs[k] for k in range(len(gen_pcs))][:8]
                norm_ref = data_helper.normalize_point_clouds(gen_list)
                img = visualize_point_clouds_3d(norm_ref, [f'gen-{k}' for k in range(len(norm_ref))]
                                                )
                img_list.append(torch.as_tensor(img) / 255.0)
            grid = torchvision.utils.make_grid(img_list)
            logger.info('ndarr: {}, range: {} img list: {} ', grid.shape,
                        grid.max(), img_list[0].shape, img_list[0].max())
            writer.add_image('sample', grid, step)
            logger.info('\n\t' + writer.url)

        shape_str = '{}: gen_pcs: {}'.format(self.cfg.save_dir, gen_pcs.shape)
        logger.info(shape_str)

        ref = get_ref_pt(self.cfg.data.cates, self.cfg.data.type)
        if ref is None:
            logger.info('Not computing score')
            return 1
        step_str = '%dk' % (step / 1000.0)
        epoch_str = '%.1fk' % (self.epoch / 1000.0)
        print_kwargs = {'dataset': self.cfg.data.cates,
                        'hash': self.cfg.hash + tag,
                        'step': step_str,
                        'epoch': epoch_str+'-'+os.path.basename(ref).split('.')[0]}
        # free GPU memory before the metric computation
        self.model = self.model.cpu()
        torch.cuda.empty_cache()
        # -- compute the generation metric -- #
        results = compute_score(output_name, ref_name=ref,
                                writer=writer,
                                batch_size_test=min(
                                    5, self.cfg.data.batch_size_test),
                                norm_box=norm_box,
                                **print_kwargs)

        self.model = self.model.to(device)
        # ---- write to logger ---- #
        writer.add_scalar('test/Coverage_CD', results['lgan_cov-CD'], step)
        writer.add_scalar('test/Coverage_EMD', results['lgan_cov-EMD'], step)
        writer.add_scalar('test/MMD_CD', results['lgan_mmd-CD'], step)
        writer.add_scalar('test/MMD_EMD', results['lgan_mmd-EMD'], step)
        writer.add_scalar('test/1NN_CD', results['1-NN-CD-acc'], step)
        writer.add_scalar('test/1NN_EMD', results['1-NN-EMD-acc'], step)
        writer.add_scalar('test/JSD', results['jsd'], step)
        msg = f'step={step}'
        msg += '\n[Test] MinMatDis | CD %.6f | EMD %.6f' % (
            results['lgan_mmd-CD'], results['lgan_mmd-EMD'])
        msg += '\n[Test] Coverage | CD %.6f | EMD %.6f' % (
            results['lgan_cov-CD'], results['lgan_cov-EMD'])
        msg += '\n[Test] 1NN-Accur | CD %.6f | EMD %.6f' % (
            results['1-NN-CD-acc'], results['1-NN-EMD-acc'])
        msg += '\n[Test] JsnShnDis | %.6f ' % (results['jsd'])

        logger.info(msg)
        with open(os.path.join(self.cfg.save_dir, 'eval_out.txt'), 'a') as f:
            f.write(shape_str+'\n')
            f.write(msg+'\n')
        # self.cfg.data.cates, self.cfg.hash, step_str, epoch_str)
        msg = print_results(results, **print_kwargs)
        with open(os.path.join(self.cfg.save_dir, 'eval_out.txt'), 'a') as f:
            f.write(msg+'\n')
        logger.info('\n\t' + writer.url)
|
||
|
|
||
|
    def vis_sample(self, writer, num_vis=None, step=0, include_pred_x0=True,
                   save_file=None):
        """Sample a few shapes, render selected diffusion-trajectory steps,
        and write the image grid to the writer and to ``save_dir/vis``.

        Args:
            writer: experiment writer (``add_image`` is called on it).
            num_vis: number of shapes to draw; defaults to num_val_samples.
            step: training step used for tags and the output filename.
            include_pred_x0: also render the model's p(x_0|x_t) predictions.
            save_file: forwarded to ``sample``.
        """
        if num_vis is None:
            num_vis = self.num_val_samples
        logger.info("Sampling.. train-step=%s | N=%d" % (step, num_vis))
        tic = time.time()
        # get three list with entry: [L,N,3]
        # traj, traj_x0, time_list
        traj, pred_x0 = self.sample(num_points=self.num_points,
                                    num_shapes=num_vis, for_vis=True, use_ddim=True,
                                    save_file=save_file)

        toc = time.time()
        logger.info('sampling take %.1f sec' % (toc-tic))

        # display only a few steps
        num_shapes = num_vis
        vis_num_steps = len(traj)
        vis_index = list(traj.keys())
        vis_index = vis_index[::-1]
        display_num_step = 5
        step_size = max(1, vis_num_steps // 5)
        display_num_step_list = []
        for k in range(0, vis_num_steps, step_size):
            display_num_step_list.append(vis_index[k])
        # always include the final diffusion step if present
        # NOTE(review): ``self.num_steps`` is not defined in this base class;
        # presumably set by the subclass -- confirm.
        if self.num_steps not in display_num_step_list and self.num_steps in traj:
            display_num_step_list.append(self.num_steps)
        logger.info('saving vis with N={}', len(display_num_step_list))
        alltraj_list = []
        allpred_x0_list = []
        allstep_list = []
        for b in range(num_shapes):
            traj_list = []
            pred_x0_list = []
            step_list = []
            for k in display_num_step_list:
                v = traj[k]
                traj_list.append(v[b].permute(1, 0).contiguous())
                v = pred_x0[k]
                pred_x0_list.append(v[b].permute(1, 0).contiguous())
                step_list.append(k)
            # B3N -> 3,N -> N,3 use first sample only
            alltraj_list.append(traj_list)
            allpred_x0_list.append(pred_x0_list)
            allstep_list.append(step_list)
        traj, traj_x0, time_list = alltraj_list, allpred_x0_list, allstep_list

        # vis the final images,
        all_imgs = []
        all_imgs_torchvis = []  # no preconcat in the image, left to the torchvision
        all_imgs_torchvis_norm = []  # no preconcat in the image, left to the torchvision
        for idx in range(num_vis):
            pcs = traj[idx][0:1]  # 1,N,3
            img = []
            # vis the normalized point cloud
            title_list = ['#%d normed x_%d' % (idx, 0)]
            norm_pcs = data_helper.normalize_point_clouds(pcs)
            img.append(visualize_point_clouds_3d(norm_pcs, title_list,
                                                 self.cfg.viz.viz_order))
            all_imgs_torchvis_norm.append(img[-1] / 255.0)
            if include_pred_x0:
                title_list = ['#%d p(x_0|x_%d,t)' % (idx, 0)]
                img.append(visualize_point_clouds_3d(traj_x0[idx][0:1], title_list,
                                                     self.cfg.viz.viz_order))
            # concat along the height
            all_imgs.append(np.concatenate(img, axis=1))

        # concatenate along the width dimension
        img = np.concatenate(all_imgs, axis=2)
        writer.add_image('summary/sample', torch.as_tensor(img), step)

        path = os.path.join(self.cfg.save_dir, 'vis', 'sample%06d.png' % step)
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path))

        # also save the normalized samples as a PNG grid on disk
        img_list = [torch.as_tensor(a) for a in all_imgs_torchvis_norm]
        grid = torchvision.utils.make_grid(img_list)
        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(
            1, 2, 0).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(ndarr)
        im.save(path)
        logger.info('save as {}; url: {} ', path, writer.url)
|
||
|
|
||
|
    def prepare_vis_data(self):
        """Cache small fixed subsets of val/train data on the GPU for
        repeated visualization (val_x/val_cls/val_input/prior_cond, and
        tr_x/m_pcs/s_pcs/tr_cls), plus a fixed prior noise ``w_prior`` and,
        when CLIP conditioning is enabled, a tiled ``clip_feat_test``."""
        device = torch.device(self.device_str)
        num_val_samples = self.num_val_samples
        c = 0
        val_x = []
        val_input = []
        val_cls = []
        prior_cond = []
        # accumulate test batches until we have enough samples
        for val_batch in self.test_loader:
            val_x.append(val_batch['tr_points'])
            val_cls.append(val_batch['cate_idx'])
            if 'input_pts' in val_batch:  # this is the input_pts to the vae model
                val_input.append(val_batch['input_pts'])
            if 'prior_cond' in val_batch:
                prior_cond.append(val_batch['prior_cond'])
            c += val_x[-1].shape[0]
            if c >= num_val_samples:
                break
        self.val_x = torch.cat(val_x, dim=0)[:num_val_samples].to(device)
        # this line may trigger error, change dataset output cate_idx from string to int can fix this issue
        self.val_cls = torch.cat(val_cls, dim=0)[:num_val_samples].to(device)
        self.prior_cond = torch.cat(prior_cond, dim=0)[:num_val_samples].to(
            device) if len(prior_cond) else None
        self.val_input = torch.cat(val_input, dim=0)[:num_val_samples].to(
            device) if len(val_input) else None
        c = 0
        tr_x = []
        m_x = []
        s_x = []
        tr_cls = []
        logger.info('[prepare_vis_data] len of train_loader: {}',
                    len(self.train_loader))
        assert(len(self.train_loader) > 0), f'get zero length train_loader, it could be the batch_size > the number of training sample, and the train drop_last is turn off'
        # same accumulation over the training loader (also keeps the
        # per-shape normalization stats mean/std)
        for tr_batch in self.train_loader:
            tr_x.append(tr_batch['tr_points'])
            m_x.append(tr_batch['mean'])
            s_x.append(tr_batch['std'])
            tr_cls.append(tr_batch['cate_idx'].view(-1))
            c += tr_x[-1].shape[0]
            if c >= num_val_samples:
                break
        self.tr_cls = torch.cat(tr_cls, dim=0)[:num_val_samples].to(device)
        self.tr_x = torch.cat(tr_x, dim=0)[:num_val_samples].to(device)
        self.m_pcs = torch.cat(m_x, dim=0)[:num_val_samples].to(device)
        self.s_pcs = torch.cat(s_x, dim=0)[:num_val_samples].to(device)
        logger.info('tr_x: {}, m_pcs: {}, s_pcs: {}, val_x: {}',
                    self.tr_x.shape, self.m_pcs.shape, self.s_pcs.shape, self.val_x.shape)

        # fixed prior noise so visualizations are comparable across steps
        self.w_prior = torch.randn(
            [num_val_samples, self.cfg.shapelatent.latent_dim]).to(device)
        if self.clip_feat_list is not None:
            # tile the prepared CLIP features to cover num_val_samples entries
            self.clip_feat_test = []
            for k in range(len(self.clip_feat_list)):
                for i in range(num_val_samples // len(self.clip_feat_list)):
                    self.clip_feat_test.append(self.clip_feat_list[k])
            # pad the remainder with the last feature
            for i in range(num_val_samples - len(self.clip_feat_test)):
                self.clip_feat_test.append(self.clip_feat_list[-1])
            self.clip_feat_test = torch.stack(self.clip_feat_test, dim=0)
            logger.info('[VIS data] clip_feat_test: {}',
                        self.clip_feat_test.shape)
            if self.clip_feat_test.shape[0] > num_val_samples:
                self.clip_feat_test = self.clip_feat_test[:num_val_samples]
        else:
            self.clip_feat_test = None
|
||
|
|
||
|
def build_other_module(self):
|
||
|
logger.info('no other module to build')
|
||
|
pass
|
||
|
|
||
|
def swap_vae_param_if_need(self):
|
||
|
if self.cfg.ddpm.ema:
|
||
|
self.optimizer.swap_parameters_with_ema(store_params_in_ema=True)
|
||
|
|
||
|
# -- shared method for all model with vae component -- #
|
||
|
    # -- shared method for all model with vae component -- #
    @torch.no_grad()
    def eval_nll(self, step, ntest=None, save_file=False):
        """Evaluate reconstruction quality on the test set.

        Runs the model's ``get_loss`` in eval-NLL mode over the loader,
        denormalizes predictions and targets back to original coordinates,
        optionally saves/visualizes reconstructions, and returns a scalar
        score (the CD-based entry from ``compute_NLL_metric``; lower is
        better, used by ``train_epochs`` for best-checkpoint tracking).
        """
        loss_dict = {}
        cfg = self.cfg
        # when EMA is enabled, evaluate the EMA weights (swapped back at the end)
        self.swap_vae_param_if_need()
        args = self.args
        device = torch.device('cuda:%d' % args.local_rank)
        tag = exp_helper.get_evalname(self.cfg)
        # NOTE(review): hard-coded to 0, so the train-NLL branch below is
        # dead; ``self.args.eval_trainnll`` (set in build_data) is not
        # consulted here -- confirm whether that was intended.
        eval_trainnll = 0
        if eval_trainnll:
            data_loader = self.train_loader
            tag += '-train'
        else:
            data_loader = self.test_loader
        gen_pcs, ref_pcs = [], []
        output_name = os.path.join(self.cfg.save_dir, f'recont_{step}{tag}.pt')
        output_name_metric = os.path.join(
            self.cfg.save_dir, f'recont_{step}{tag}_metric.pt')
        shape_id_start = 0
        # NOTE(review): never populated below; the summary loop over it is a no-op.
        batch_metric_all = {}

        for vid, val_batch in enumerate(data_loader):
            if vid % 30 == 1:
                logger.info('eval: {}/{}', vid, len(data_loader))
            val_x = val_batch['tr_points'].to(device)
            m, s = val_batch['mean'], val_batch['std']
            B, N, C = val_x.shape
            m = m.view(B, 1, -1)
            s = s.view(B, 1, -1)
            inputs = val_batch['input_pts'].to(
                device) if 'input_pts' in val_batch else None  # the noisy points
            model_kwargs = {}

            output = self.model.get_loss(
                val_x, it=step, is_eval_nll=1, noisy_input=inputs, **model_kwargs)
            # aggregate every 'print/...' entry into a running average
            for k, v in output.items():
                if 'print/' in k:
                    k = k.split('print/')[-1]
                    if k not in loss_dict:
                        loss_dict[k] = AvgrageMeter()
                    v = v.mean().item() if torch.is_tensor(v) else v
                    loss_dict[k].update(v)

            gen_x = output['final_pred']
            # subsample the prediction to match the target point count
            if gen_x.shape[1] > val_x.shape[1]:
                tr_idxs = np.random.permutation(np.arange(gen_x.shape[1]))[
                    :val_x.shape[1]]
                gen_x = gen_x[:, tr_idxs]

            gen_x = gen_x.cpu()
            val_x = val_x.cpu()
            # undo per-shape normalization: back to original coordinates
            gen_x[:, :, :3] = gen_x[:, :, :3] * s + m
            val_x[:, :, :3] = val_x[:, :, :3] * s + m
            gen_pcs.append(gen_x.detach().cpu())
            ref_pcs.append(val_x.detach().cpu())
            if ntest is not None and shape_id_start >= int(ntest):
                logger.info('!! reach number of test={}; has test: {}',
                            ntest, shape_id_start)
                break
            shape_id_start += B
        # summarized batch-metric if any
        for k, v in batch_metric_all.items():
            if len(v) == 0:
                continue
            v = torch.cat(v, dim=0)
            logger.info('{}={}', k, v.mean())

        gen_pcs = torch.cat(gen_pcs, dim=0)
        ref_pcs = torch.cat(ref_pcs, dim=0)

        # Save
        if self.writer is not None:
            # preview the first 10 reconstructions
            img_list = []
            for i in range(10):
                points = gen_pcs[i]
                points = normalize_point_clouds([points])[0]
                img = visualize_point_clouds_3d([points], bound=1.0)
                img_list.append(img)
            img = np.concatenate(img_list, axis=2)
            self.writer.add_image('nll/rec', torch.as_tensor(img), step)
        if save_file:
            logger.info('reconstruct point clouds..., output shape: {}, save as {}',
                        gen_pcs.shape, output_name)
            torch.save(gen_pcs, output_name)

        results = compute_NLL_metric(
            gen_pcs[:, :, :3], ref_pcs[:, :, :3], device, self.writer, output_name, batch_size=20, step=step)
        score = 0
        for n, v in results.items():
            if 'detail' in n:
                continue
            if self.writer is not None:
                logger.info('add: {}', n)
                self.writer.add_scalar('eval/%s' % (n), v, step)
            if 'CD' in n:
                score = v
        # swap the EMA weights back
        self.swap_vae_param_if_need()
        return score
|
||
|
|
||
|
    def prepare_clip_model_data(self):
        """Load the CLIP model and pre-compute text (and optional image)
        features used for CLIP-conditioned sampling/visualization.

        Populates ``self.clip_model``, ``self.clip_preprocess``,
        ``self.clip_feat_list`` (stacked features) and ``self.t2s_input``
        (the human-readable prompts). Raises NotImplementedError for
        categories without predefined prompts.
        """
        cfg = self.cfg
        self.clip_model, self.clip_preprocess = clip.load(cfg.clipforge.clip_model,
                                                          device=self.device_str)
        self.test_img_path = []  # optional conditioning images (none by default)
        # category-specific text prompts
        if cfg.data.cates == 'chair':
            input_t = [
                "an armchair in the shape of an avocado. an armchair imitating a avocado"]
            text = clip.tokenize(input_t).to(self.device_str)
        elif cfg.data.cates == 'car':
            input_t = ["a ford model T", "a pickup", "an off-road vehicle"]
            text = clip.tokenize(input_t).to(self.device_str)
        elif cfg.data.cates == 'all':
            input_t = ['a boeing', 'an f-16', 'an suv', 'a chunk', 'a limo',
                       'a square chair', 'a swivel chair', 'a sniper rifle']
            text = clip.tokenize(input_t).to(self.device_str)
        else:
            raise NotImplementedError
        if len(self.test_img_path):
            # load and preprocess any conditioning images
            self.test_img = [Image.open(t).convert("RGB")
                             for t in self.test_img_path]
            self.test_img = [self.clip_preprocess(img).unsqueeze(
                0).to(self.device_str) for img in self.test_img]
            self.test_img = torch.cat(self.test_img, dim=0)
        else:
            self.test_img = []
        self.t2s_input = self.test_img_path + input_t
        clip_feat = []
        if len(self.test_img):
            clip_feat.append(
                self.clip_model.encode_image(self.test_img).float())
        clip_feat.append(self.clip_model.encode_text(text).float())
        # image features (if any) come first, then text features
        self.clip_feat_list = torch.cat(clip_feat, dim=0)
|
||
|
|