# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from loguru import logger
from utils.vis_helper import visualize_point_clouds_3d
from utils.data_helper import normalize_point_clouds
from utils.checker import *

@torch.no_grad()
def validate_inspect_noprior(model,
                             it, writer,
                             sample_num_points, num_samples,
                             need_sample=1, need_val=1, need_train=1,
                             w_prior=None, val_x=None, tr_x=None,
                             test_loader=None,  # unused here; kept for API compat, can be None
                             has_shapelatent=False,
                             bound=1.5, val_class_label=None, tr_class_label=None,
                             cfg=None):
    """Visualize model samples / reconstructions and log the images to ``writer``.

    Only the validation-reconstruction panel is active in the current call
    contract (the assertion below enforces sample=0, val>0, train=0); the
    sample and train panels are kept for when that contract is relaxed.

    Args:
        model: model exposing ``pz(w, npoints)`` and ``recont(x[, class_label])``;
            ``recont`` must return a dict with key ``'final_pred'`` and may
            optionally contain ``'vis/latent_pts'``.
        it (int): global step used as the image tag index for ``writer``.
        writer: TensorBoard-style writer with ``add_image`` and a ``url`` attr.
        sample_num_points (int): points per generated cloud (sample path only).
        num_samples (int): ignored — deliberately overwritten from
            ``w_prior.shape[0]`` below; kept in the signature for compatibility.
        need_sample (int): draw prior samples if truthy (must be 0 currently).
        need_val (int): draw validation reconstructions if truthy (must be >0).
        need_train (int): draw train reconstructions if truthy (must be 0).
        w_prior (Tensor): prior latent input; required (asserted not None).
        val_x (Tensor): validation point clouds to reconstruct; required.
        tr_x (Tensor): train point clouds; required, used only if ``need_train``.
        test_loader: unused, may be None.
        has_shapelatent (bool): must be True; this code path assumes a
            shape-latent model.
        bound (float): axis bound passed to the 3-D visualizer.
        val_class_label / tr_class_label: optional conditioning labels.
        cfg: config object providing ``cfg.viz.viz_order``. Defaults to an
            empty dict (same as the historical ``cfg={}`` default).
    """
    if cfg is None:
        # Replaces the original mutable default argument ``cfg={}``; an empty
        # dict reproduces the old default behavior exactly.
        cfg = {}
    assert has_shapelatent
    assert w_prior is not None and val_x is not None and tr_x is not None
    z_list = []
    # NOTE: the ``num_samples`` parameter is intentionally overwritten here.
    num_samples = w_prior.shape[0] if need_sample else 0
    num_recon = val_x.shape[0]
    num_recon_val = num_recon if need_val else 0
    num_recon_train = num_recon if need_train else 0

    # Current supported mode: validation reconstructions only.
    assert need_sample == 0 and need_val > 0 and need_train == 0
    if need_sample:
        z_prior = model.pz(w_prior, sample_num_points)
        z_list.append(z_prior)
    if val_class_label is not None:
        output = model.recont(val_x, class_label=val_class_label)
    else:
        output = model.recont(val_x)  # torch.cat([val_x, tr_x]))
    gen_x = output['final_pred']
    vis_order = cfg.viz.viz_order
    vis_args = {'vis_order': vis_order}

    # vis the samples
    if num_samples > 0:
        img_list = []
        for i in range(num_samples):
            points = gen_x[i]
            points = normalize_point_clouds([points])[0]
            img = visualize_point_clouds_3d([points], bound=bound, **vis_args)
            img_list.append(img)
        # Tile the per-sample images side by side along the width axis
        # (assumes CHW output from the visualizer — TODO confirm).
        img = np.concatenate(img_list, axis=2)
        writer.add_image('sample', torch.as_tensor(img), it)

    # vis the recont
    if num_recon_val > 0:
        img_list = []
        for i in range(num_recon_val):
            points = gen_x[num_samples + i]
            points = normalize_point_clouds([points])  # val_x[i], points])
            img = visualize_point_clouds_3d(
                points, ['rec#%d' % i], bound=bound, **vis_args)
            img_list.append(img)
        gt_list = []
        for i in range(num_recon_val):
            points = normalize_point_clouds([val_x[i]])
            img = visualize_point_clouds_3d(
                points, ['gt#%d' % i], bound=bound, **vis_args)
            gt_list.append(img)
        img = np.concatenate(img_list, axis=2)
        gt = np.concatenate(gt_list, axis=2)
        # Ground-truth row stacked above the reconstruction row.
        img = np.concatenate([gt, img], axis=1)

        if 'vis/latent_pts' in output:
            # Optionally append a row visualizing the model's latent points.
            latent_pts = output['vis/latent_pts']
            img_list = []
            for i in range(num_recon_val):
                points = latent_pts[num_samples + i]
                points = normalize_point_clouds([points])
                latent = visualize_point_clouds_3d(
                    points, ['latent#%d' % i], bound=bound, **vis_args)
                img_list.append(latent)
            latent_list = np.concatenate(img_list, axis=2)
            img = np.concatenate([img, latent_list], axis=1)

        writer.add_image('valrecont', torch.as_tensor(img), it)

    if num_recon_train > 0:
        img_list = []
        for i in range(num_recon_train):
            # Train reconstructions follow the sample + val blocks in gen_x.
            points = gen_x[num_samples + num_recon_val + i]
            points = normalize_point_clouds([tr_x[i], points])
            img = visualize_point_clouds_3d(
                points, ['ori', 'rec'], bound=bound, **vis_args)
            img_list.append(img)
        img = np.concatenate(img_list, axis=2)
        writer.add_image('train/recont', torch.as_tensor(img), it)

    logger.info('writer: {}', writer.url)