LION/utils/eval_helper.py
xzeng 4d6af1d8b9 fix tensor device issue
https://github.com/nv-tlabs/LION/issues/31#issue-1627736930
the original code index cpu tensor with cuda tensor (works in torch 1.10.2),
may fail in other torch version?
2023-03-16 12:44:57 -04:00

342 lines
14 KiB
Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import os
import json
from comet_ml import Experiment, OfflineExperiment
## import open3d as o3d
import time
import numpy as np
import torch
from loguru import logger
import torchvision
from PIL import Image
from utils.vis_helper import visualize_point_clouds_3d
from utils.data_helper import normalize_point_clouds
from utils.checker import *
import torchvision
import sys
import math
from utils.evaluation_metrics_fast import compute_all_metrics, \
jsd_between_point_cloud_sets, print_results, write_results
from utils.evaluation_metrics_fast import EMD_CD
CD_ONLY = int(os.environ.get('CD_ONLY', 0))
VIS = 1
def pair_vis(gen_x, tr_x, titles, subtitles, writer, step=-1):
img_list = []
num_recon = len(gen_x)
for i in range(num_recon):
points = gen_x[i]
points = normalize_point_clouds([tr_x[i], points])
img = visualize_point_clouds_3d(points, subtitles[i])
img_list.append(torch.as_tensor(img) / 255.0)
grid = torchvision.utils.make_grid(img_list, nrow=num_recon//2)
if writer is not None:
writer.add_image(titles, grid, step)
def compute_NLL_metric(gen_pcs, ref_pcs, device, writer=None, output_name='', batch_size=200, step=-1, tag=''):
# evaluate the reconstrution results
metrics = EMD_CD(gen_pcs.to(device), ref_pcs.to(device),
batch_size=batch_size, accelerated_cd=True, reduced=False)
titles = 'nll/first-10-%s' % tag
k1, k2 = list(metrics.keys())
subtitles = [['ori', 'gen-%s=%.1fx1e-2;%s=%.1fx1e-2' %
(k1, metrics[k1][j]*1e2, k2, metrics[k2][j]*1e2)] for j in range(10)]
pair_vis(gen_pcs[:10], ref_pcs[:10], titles, subtitles, writer, step=step)
results = {}
for k in metrics.keys():
sorted, indices = torch.sort(metrics[k])
worse_ten, worse_score = indices[-10:].cpu(), sorted[-10:].cpu()
titles = 'nll/worst-%s-%s' % (k, tag)
subtitles = [['ori', 'gen-%s=%.2fx1e-2' %
(k, worse_score[j]*1e2)] for j in range(len(worse_score))]
pair_vis(gen_pcs[worse_ten], ref_pcs[worse_ten],
titles, subtitles, writer, step=step)
if 'score_detail' not in results:
results['score_detail'] = metrics[k]
metrics[k] = metrics[k].mean()
logger.info('best 10: {}', indices[:10])
results.update({k: v.item() for k, v in metrics.items()})
output = ''
for k, v in results.items():
if 'detail' in k:
continue
output += '%s=%.3fx1e-2 ' % (k, v*1e2)
logger.info('{}: {}', k, v)
if 'CD' in k:
score = v
url = writer.url if writer is not None else ''
logger.info('\n' + '-'*60 +
f'\n{output_name} | \n{output} step={step} \n {url} \n ' + '-'*60)
return results
def get_ref_num(cats, luo_split=False):
#ref = './scripts/test_data/ref_%s.pt'%cats
#assert(os.path.exists(ref)), f'file not found: {ref}'
num_test = {
'animal': 100,
'airplane': 405,
'airplane_ps': 405,
'chair': 662,
'chair_ps': 662,
'car': 352,
'car_ps': 352,
'all': 1000,
'mug': 22,
'bottle': 43
}
if luo_split:
num_test = {
'airplane': 607,
'chair': 989,
'car': 528
}
assert(cats in num_test), f'not found: {cats} in {num_test}'
return num_test[cats]
def get_cats(cats):
# return the category name for this dataset
all_cats = ['airplane', 'chair', 'car', 'all', 'animal', 'mug', 'bottle']
for c in all_cats:
if c in cats or c == cats:
cats = c
break
assert(cats in all_cats), f'not foud cats for {cats} in {all_cats}'
return cats
def get_ref_pt(cats, data_type="datasets.pointflow_datasets", luo_split=False):
cats = get_cats(cats)
root = './datasets/test_data/'
if 'pointflow' in data_type:
ref = 'ref_val_%s.pt' % cats
elif 'neuralspline_datasets' in data_type:
ref = 'ref_ns_val_%s.pt' % cats
else:
logger.info('get_ref_pt not support data_type: %s' % data_type)
return None
ref = os.path.join(root, ref)
assert(os.path.exists(ref)), f'file not found: {ref}'
return ref
#@torch.no_grad()
#def compute_score_fast(gen_pcs, ref_pcs, m_pcs, s_pcs,
# batch_size_test=256, device_str='cuda', cd_only=1,
# exp=None, verbose=False,
# device=None, accelerated_cd=True, writer=None, norm_box=False, **print_kwargs):
# """ used to eval the pcs during training; all the files will not be dumpped into disk (to save time)
# the ref_pcs will be part of the full dataset only
# Args:
# output_name (str) path to sample obj: tensor: (Nsample.Npoint.3or6)
# ref_name (str) path to torch obj:
# torch.save({'ref': ref_pcs, 'mean': m_pcs, 'std': s_pcs}, ref_name)
# print_kwargs (dict): entries: dataset, hash, step, epoch;
# """
# if gen_pcs.shape[1] > ref_pcs.shape[1]:
# xperm = np.random.permutation(np.arange(gen_pcs.shape[1]))[
# :ref_pcs.shape[1]]
# gen_pcs = gen_pcs[:, xperm]
# if ref_pcs.shape[0] > gen_pcs.shape[0]:
# ref_pcs = ref_pcs[:gen_pcs.shape[0]]
# m_pcs = m_pcs[:gen_pcs.shape[0]]
# s_pcs = s_pcs[:gen_pcs.shape[0]]
# elif ref_pcs.shape[0] < gen_pcs.shape[0]:
# gen_pcs = gen_pcs[:ref_pcs.shape[0]]
#
# device = torch.device(device_str) if device is None else device
# CHECKEQ(ref_pcs.shape[0], gen_pcs.shape[0])
# N_ref = ref_pcs.shape[0] # subset it
# batch_size_test = N_ref # * 0.5
# if gen_pcs.shape[2] == 6: # B,N,3 or 6
# gen_pcs = gen_pcs[:, :, :3]
# ref_pcs = ref_pcs[:, :, :3]
# if norm_box:
# ref_pcs = 0.5 * torch.stack(normalize_point_clouds(ref_pcs), dim=0)
# gen_pcs = 0.5 * torch.stack(normalize_point_clouds(gen_pcs), dim=0)
# print_kwargs['dataset'] = print_kwargs.get('dataset',
# '')+'-normbox'
#
# #ref_pcs = normalize_point_clouds(ref_pcs)
# #gen_pcs = normalize_point_clouds(gen_pcs)
# # print_kwargs['dataset'] = print_kwargs.get('dataset',
# # '')+'-normbox'
# # logger.info('[data shape] ref_pcs: {}, gen_pcs: {}, mean={}, std={}; norm_box={}',
# # ref_pcs.shape, gen_pcs.shape, m_pcs.shape, s_pcs.shape, norm_box)
# elif m_pcs is not None and s_pcs is not None:
# ref_pcs = ref_pcs * s_pcs + m_pcs
# gen_pcs = gen_pcs * s_pcs + m_pcs
# # visualize first few samples:
# if VIS and writer is not None and writer.exp is not None or exp is not None:
# logger.info('vis the result')
# if exp is None:
# exp = writer.exp
# img_list = []
# for i in range(min(20, ref_pcs.shape[0])):
# NORM_VIS = 0
# if NORM_VIS:
# norm_ref, norm_gen = normalize_point_clouds([
# ref_pcs[i], gen_pcs[i]])
# else:
# norm_ref = ref_pcs[i]
# norm_gen = gen_pcs[i]
# img = visualize_point_clouds_3d([norm_ref, norm_gen],
# [f'ref-{i}', f'gen-{i}'], bound=0.5)
# img_list.append(torch.as_tensor(img) / 255.0)
# grid = torchvision.utils.make_grid(img_list)
# # to 3,H,W to H,W,3
# ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(
# 1, 2, 0).to('cpu', torch.uint8).numpy()
# exp.log_image(ndarr, 'samples/verse_%s' %
# print_kwargs.get('hash', '_'), step=print_kwargs.get('step', 0))
# # epoch=print_kwargs.get('epoch', 0))
#
# metric2 = 'EMD' if not cd_only else None
# results = compute_all_metrics(gen_pcs.to(device).float(),
# ref_pcs.to(device).float(), batch_size_test,
# accelerated_cd=accelerated_cd, metric2=metric2,
# verbose=verbose,
# **print_kwargs)
# print_results(results, **print_kwargs)
#
# return results
@torch.no_grad()
def compute_score(output_name, ref_name, batch_size_test=256, device_str='cuda',
device=None, accelerated_cd=True, writer=None,
exp=None,
norm_box=False, skip_write=False, **print_kwargs):
"""
Args:
output_name (str) path to sample obj: tensor: (Nsample.Npoint.3or6)
ref_name (str) path to torch obj:
torch.save({'ref': ref_pcs, 'mean': m_pcs, 'std': s_pcs}, ref_name)
print_kwargs (dict): entries: dataset, hash, step, epoch;
"""
logger.info('[compute sample metric] sample: {} and ref: {}',
output_name, ref_name)
ref = torch.load(ref_name)
ref_pcs = ref['ref'][:, :, :3]
m_pcs, s_pcs = ref['mean'], ref['std']
gen_pcs = torch.load(output_name)
if gen_pcs.shape[1] > ref_pcs.shape[1]:
xperm = np.random.permutation(np.arange(gen_pcs.shape[1]))[
:ref_pcs.shape[1]]
gen_pcs = gen_pcs[:, xperm]
if type(gen_pcs) is dict:
logger.info('WARNING: the gen_pcs is a dict, with key '
'as {}| usuaglly its a tensor '
'you perhaps takes the train data,',
gen_pcs.keys())
gen_pcs = gen_pcs['ref']
device = torch.device(device_str) if device is None else device
# batch_size_test = ref_pcs.shape[0]
logger.info('[data shape] ref_pcs: {}, gen_pcs: {}, mean={}, std={}; norm_box={}',
ref_pcs.shape, gen_pcs.shape, m_pcs.shape, s_pcs.shape, norm_box)
N_ref = ref_pcs.shape[0] # subset it
m_pcs = m_pcs[:N_ref]
s_pcs = s_pcs[:N_ref]
ref_pcs = ref_pcs[:N_ref]
gen_pcs = gen_pcs[:N_ref]
if gen_pcs.shape[2] == 6: # B,N,3 or 6
gen_pcs = gen_pcs[:, :, :3]
ref_pcs = ref_pcs[:, :, :3]
if norm_box:
#ref_pcs = ref_pcs * s_pcs + m_pcs
#gen_pcs = gen_pcs * s_pcs + m_pcs
ref_pcs = 0.5 * torch.stack(normalize_point_clouds(ref_pcs), dim=0)
gen_pcs = 0.5 * torch.stack(normalize_point_clouds(gen_pcs), dim=0)
print_kwargs['dataset'] = print_kwargs.get('dataset',
'')+'-normbox'
else:
ref_pcs = ref_pcs * s_pcs + m_pcs
gen_pcs = gen_pcs * s_pcs + m_pcs
# visualize first few samples:
if VIS:
if exp is not None:
exp = exp
elif writer is not None:
exp = writer.exp
elif os.path.exists('.comet_api'):
comet_args = json.load(open('.comet_api', 'r'))
exp = Experiment(display_summary_level=0,
**comet_args)
else:
exp = OfflineExperiment(offline_directory="/tmp")
img_list = []
gen_list = []
ref_list = []
for i in range(20):
NORM_VIS = 0
if NORM_VIS:
norm_ref, norm_gen = normalize_point_clouds([
ref_pcs[i], gen_pcs[i]])
else:
norm_ref = ref_pcs[i]
norm_gen = gen_pcs[i]
ref_img = visualize_point_clouds_3d([norm_ref],
[f'ref-{i}'], bound=1.0) # 0.8)
gen_img = visualize_point_clouds_3d([norm_gen],
[f'gen-{i}'], bound=1.0) # 0.8)
ref_list.append(torch.as_tensor(ref_img) / 255.0)
gen_list.append(torch.as_tensor(gen_img) / 255.0)
img_list.append(ref_list[-1])
img_list.append(gen_list[-1])
path = output_name.replace('.pt', '_eval.png')
grid = torchvision.utils.make_grid(gen_list)
# to 3,H,W to H,W,3
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(
1, 2, 0).to('cpu', torch.uint8).numpy()
exp.log_image(ndarr, 'samples')
ref_grid = torchvision.utils.make_grid(ref_list)
# to 3,H,W to H,W,3
ref_ndarr = ref_grid.mul(255).add_(0.5).clamp_(0, 255).permute(
1, 2, 0).to('cpu', torch.uint8).numpy()
ndarr = np.concatenate([ndarr, ref_ndarr], axis=0)
exp.log_image(ndarr, 'samples_vs_ref')
torchvision.utils.save_image(img_list, path)
logger.info(exp.url)
logger.info('save vis at {}', path)
metric2 = 'EMD' if not CD_ONLY else None
logger.info('print_kwargs: {}', print_kwargs)
results = compute_all_metrics(gen_pcs.to(device).float(),
ref_pcs.to(device).float(), batch_size_test, accelerated_cd=accelerated_cd, metric2=metric2,
**print_kwargs)
jsd = jsd_between_point_cloud_sets(
gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy())
results['jsd'] = jsd
msg = print_results(results, **print_kwargs)
# with open('../exp/eval_out.txt', 'a') as f:
# run_time = time.strftime('%m%d-%H%M-%S')
# f.write('<< date: %s >>\n' % run_time)
# f.write('%s\n%s\n' % (exp.url, msg))
results['url'] = exp.url
if not skip_write:
os.makedirs('results', exist_ok=True)
msg = write_results(
os.path.join('./results/', 'eval_out.csv'),
results, **print_kwargs)
if metric2 is None:
logger.info('early exit')
exit()
return results