LION/trainers/common_fun.py

106 lines
4.3 KiB
Python
Raw Normal View History

2023-01-23 05:14:49 +00:00
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from loguru import logger
from utils.vis_helper import visualize_point_clouds_3d
from utils.data_helper import normalize_point_clouds
from utils.checker import *
def _render_row(clouds, titles, bound, vis_args):
    """Render each point cloud with its title and tile the images along axis 2.

    Every cloud is normalized on its own (``normalize_point_clouds([points])``)
    before being drawn by ``visualize_point_clouds_3d``.
    NOTE(review): axis 2 is presumably image width (CHW images, as consumed by
    ``writer.add_image``) — confirm.
    """
    images = []
    for points, title in zip(clouds, titles):
        normalized = normalize_point_clouds([points])
        images.append(visualize_point_clouds_3d(
            normalized, [title], bound=bound, **vis_args))
    return np.concatenate(images, axis=2)


def _viz_kwargs(cfg):
    """Return extra keyword arguments for ``visualize_point_clouds_3d``.

    Reads ``cfg.viz.viz_order`` when ``cfg`` has a ``viz`` attribute. Otherwise
    (``cfg`` is None, or a plain dict such as the old ``{}`` default) it
    returns ``{}``, so the visualizer uses its own default order.
    Assumes ``vis_order`` is optional in ``visualize_point_clouds_3d`` — TODO confirm.
    """
    viz_cfg = getattr(cfg, 'viz', None)
    if viz_cfg is None:
        return {}
    return {'vis_order': viz_cfg.viz_order}


@torch.no_grad()
def validate_inspect_noprior(model,
                             it, writer,
                             sample_num_points, num_samples,
                             need_sample=1, need_val=1, need_train=1,
                             w_prior=None, val_x=None, tr_x=None,
                             test_loader=None,  # can be None
                             has_shapelatent=False,
                             bound=1.5, val_class_label=None, tr_class_label=None,
                             cfg=None):
    """Log ground-truth / reconstruction images for the validation batch.

    ``val_x`` is reconstructed with ``model.recont`` (passing ``class_label``
    when ``val_class_label`` is given). Three rows are built and stacked along
    axis 1:

    * ground truth (``gt#i``),
    * reconstructions taken from ``output['final_pred']`` (``rec#i``),
    * latent points from ``output['vis/latent_pts']`` (``latent#i``), only when
      that key is present.

    The stacked image goes to ``writer.add_image('valrecont', ..., it)``.
    ``writer.url`` is logged at the end. Nothing is written when ``val_x`` is
    empty along its first dimension.

    Only the validation-reconstruction mode is supported. The assertions
    require ``need_sample == 0``, ``need_val > 0`` and ``need_train == 0``.
    Consequently ``sample_num_points``, ``num_samples``, ``test_loader``,
    ``tr_class_label`` and the contents of ``w_prior`` / ``tr_x`` are unused.
    They are kept so existing callers keep working.

    Args:
        model: object providing ``recont(x[, class_label=...])``, which returns
            a dict containing ``'final_pred'``.
        it: global step passed to ``writer.add_image``.
        writer: logger exposing ``add_image(tag, img, step)`` and ``url``.
        val_x: validation batch, indexed along dim 0.
        bound: plotting bound forwarded to ``visualize_point_clouds_3d``.
        val_class_label: optional conditioning label for ``model.recont``.
        cfg: optional config; ``cfg.viz.viz_order`` is forwarded as
            ``vis_order`` when present. Defaults to None (previously a
            mutable ``{}``, which always failed on ``cfg.viz``).

    Raises:
        AssertionError: if ``has_shapelatent`` is false, if any of
            ``w_prior``/``val_x``/``tr_x`` is None, or if an unsupported
            ``need_*`` combination is requested.
    """
    assert has_shapelatent
    assert w_prior is not None and val_x is not None and tr_x is not None
    # The removed prior-sample / train-reconstruction branches indexed gen_x as
    # if it also held samples and train reconstructions. Only val_x is ever
    # reconstructed, so those modes are rejected here.
    assert need_sample == 0 and need_val > 0 and need_train == 0

    if val_class_label is not None:
        output = model.recont(val_x, class_label=val_class_label)
    else:
        output = model.recont(val_x)
    gen_x = output['final_pred']
    vis_args = _viz_kwargs(cfg)

    num_recon_val = val_x.shape[0]
    if num_recon_val > 0:
        indices = range(num_recon_val)
        rec_row = _render_row([gen_x[i] for i in indices],
                              ['rec#%d' % i for i in indices], bound, vis_args)
        gt_row = _render_row([val_x[i] for i in indices],
                             ['gt#%d' % i for i in indices], bound, vis_args)
        img = np.concatenate([gt_row, rec_row], axis=1)
        if 'vis/latent_pts' in output:
            latent_pts = output['vis/latent_pts']
            latent_row = _render_row([latent_pts[i] for i in indices],
                                     ['latent#%d' % i for i in indices],
                                     bound, vis_args)
            img = np.concatenate([img, latent_row], axis=1)
        writer.add_image('valrecont', torch.as_tensor(img), it)
    logger.info('writer: {}', writer.url)