KPConv-PyTorch/utils/visualizer.py

532 lines
20 KiB
Python
Raw Normal View History

2020-03-31 19:42:35 +00:00
#
#
# 0=================================0
# | Kernel Point Convolutions |
# 0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Class handling the visualization
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Hugues THOMAS - 11/06/2018
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Imports and global variables
# \**********************************/
#
# Basic libs
import torch
import numpy as np
from sklearn.neighbors import KDTree
from os import makedirs, remove, rename, listdir
from os.path import exists, join
import time
from mayavi import mlab
import sys
from models.blocks import KPConv
# PLY reader
from utils.ply import write_ply, read_ply
# Configuration class
from utils.config import Config, bcolors
# ----------------------------------------------------------------------------------------------------------------------
#
# Trainer Class
# \*******************/
#
class ModelVisualizer:
# Initialization methods
# ------------------------------------------------------------------------------------------------------------------
def __init__(self, net, config, chkp_path, on_gpu=True):
"""
Initialize training parameters and reload previous model for restore/finetune
:param net: network object
:param config: configuration object
:param chkp_path: path to the checkpoint that needs to be loaded (None for new training)
:param finetune: finetune from checkpoint (True) or restore training from checkpoint (False)
:param on_gpu: Train on GPU or CPU
"""
############
# Parameters
############
# Choose to train on CPU or GPU
if on_gpu and torch.cuda.is_available():
self.device = torch.device("cuda:0")
else:
self.device = torch.device("cpu")
net.to(self.device)
##########################
# Load previous checkpoint
##########################
checkpoint = torch.load(chkp_path)
2020-04-24 16:00:11 +00:00
new_dict = {}
for k, v in checkpoint['model_state_dict'].items():
if 'blocs' in k:
k = k.replace('blocs', 'blocks')
new_dict[k] = v
net.load_state_dict(new_dict)
2020-03-31 19:42:35 +00:00
self.epoch = checkpoint['epoch']
net.eval()
2020-04-23 13:51:16 +00:00
print("\nModel state restored from {:s}.".format(chkp_path))
2020-03-31 19:42:35 +00:00
return
# Main visualization methods
# ------------------------------------------------------------------------------------------------------------------
def show_deformable_kernels(self, net, loader, config, deform_idx=0):
"""
Show some inference with deformable kernels
"""
##########################################
# First choose the visualized deformations
##########################################
print('\nList of the deformable convolution available (chosen one highlighted in green)')
fmt_str = ' {:}{:2d} > KPConv(r={:.3f}, Din={:d}, Dout={:d}){:}'
deform_convs = []
for m in net.modules():
if isinstance(m, KPConv) and m.deformable:
if len(deform_convs) == deform_idx:
color = bcolors.OKGREEN
else:
color = bcolors.FAIL
print(fmt_str.format(color, len(deform_convs), m.radius, m.in_channels, m.out_channels, bcolors.ENDC))
deform_convs.append(m)
################
# Initialization
################
print('\n****************************************************\n')
# Loop variables
t0 = time.time()
t = [time.time()]
last_display = time.time()
mean_dt = np.zeros(1)
count = 0
# Start training loop
for epoch in range(config.max_epoch):
for batch in loader:
##################
# Processing batch
##################
# New time
t = t[-1:]
t += [time.time()]
if 'cuda' in self.device.type:
batch.to(self.device)
# Forward pass
outputs = net(batch, config)
original_KP = deform_convs[deform_idx].kernel_points.cpu().detach().numpy()
stacked_deformed_KP = deform_convs[deform_idx].deformed_KP.cpu().detach().numpy()
count += batch.lengths[0].shape[0]
if 'cuda' in self.device.type:
torch.cuda.synchronize(self.device)
# Find layer
l = None
for i, p in enumerate(batch.points):
if p.shape[0] == stacked_deformed_KP.shape[0]:
l = i
t += [time.time()]
# Get data
in_points = []
in_colors = []
deformed_KP = []
points = []
lookuptrees = []
i0 = 0
for b_i, length in enumerate(batch.lengths[0]):
in_points.append(batch.points[0][i0:i0 + length].cpu().detach().numpy())
if batch.features.shape[1] == 4:
in_colors.append(batch.features[i0:i0 + length, 1:].cpu().detach().numpy())
else:
in_colors.append(None)
i0 += length
i0 = 0
for b_i, length in enumerate(batch.lengths[l]):
points.append(batch.points[l][i0:i0 + length].cpu().detach().numpy())
deformed_KP.append(stacked_deformed_KP[i0:i0 + length])
lookuptrees.append(KDTree(points[-1]))
i0 += length
###########################
# Interactive visualization
###########################
# Create figure for features
fig1 = mlab.figure('Deformations', bgcolor=(1.0, 1.0, 1.0), size=(1280, 920))
fig1.scene.parallel_projection = False
# Indices
global obj_i, point_i, plots, offsets, p_scale, show_in_p, aim_point
p_scale = 0.03
obj_i = 0
point_i = 0
plots = {}
offsets = False
show_in_p = 2
aim_point = np.zeros((1, 3))
def picker_callback(picker):
""" Picker callback: this get called when on pick events.
"""
global plots, aim_point
if 'in_points' in plots:
if plots['in_points'].actor.actor._vtk_obj in [o._vtk_obj for o in picker.actors]:
point_rez = plots['in_points'].glyph.glyph_source.glyph_source.output.points.to_array().shape[0]
new_point_i = int(np.floor(picker.point_id / point_rez))
if new_point_i < len(plots['in_points'].mlab_source.points):
# Get closest point in the layer we are interested in
aim_point = plots['in_points'].mlab_source.points[new_point_i:new_point_i + 1]
update_scene()
if 'points' in plots:
if plots['points'].actor.actor._vtk_obj in [o._vtk_obj for o in picker.actors]:
point_rez = plots['points'].glyph.glyph_source.glyph_source.output.points.to_array().shape[0]
new_point_i = int(np.floor(picker.point_id / point_rez))
if new_point_i < len(plots['points'].mlab_source.points):
# Get closest point in the layer we are interested in
aim_point = plots['points'].mlab_source.points[new_point_i:new_point_i + 1]
update_scene()
def update_scene():
global plots, offsets, p_scale, show_in_p, aim_point, point_i
# Get the current view
v = mlab.view()
roll = mlab.roll()
# clear figure
for key in plots.keys():
plots[key].remove()
plots = {}
# Plot new data feature
p = points[obj_i]
# Rescale points for visu
p = (p * 1.5 / config.in_radius)
# Show point cloud
if show_in_p <= 1:
plots['points'] = mlab.points3d(p[:, 0],
p[:, 1],
p[:, 2],
resolution=8,
scale_factor=p_scale,
scale_mode='none',
color=(0, 1, 1),
figure=fig1)
if show_in_p >= 1:
# Get points and colors
in_p = in_points[obj_i]
in_p = (in_p * 1.5 / config.in_radius)
# Color point cloud if possible
in_c = in_colors[obj_i]
if in_c is not None:
# Primitives
scalars = np.arange(len(in_p)) # Key point: set an integer for each point
# Define color table (including alpha), which must be uint8 and [0,255]
colors = np.hstack((in_c, np.ones_like(in_c[:, :1])))
colors = (colors * 255).astype(np.uint8)
plots['in_points'] = mlab.points3d(in_p[:, 0],
in_p[:, 1],
in_p[:, 2],
scalars,
resolution=8,
scale_factor=p_scale*0.8,
scale_mode='none',
figure=fig1)
plots['in_points'].module_manager.scalar_lut_manager.lut.table = colors
else:
plots['in_points'] = mlab.points3d(in_p[:, 0],
in_p[:, 1],
in_p[:, 2],
resolution=8,
scale_factor=p_scale*0.8,
scale_mode='none',
figure=fig1)
# Get KP locations
rescaled_aim_point = aim_point * config.in_radius / 1.5
point_i = lookuptrees[obj_i].query(rescaled_aim_point, return_distance=False)[0][0]
if offsets:
KP = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
scals = np.ones_like(KP[:, 0])
else:
KP = points[obj_i][point_i] + original_KP
scals = np.zeros_like(KP[:, 0])
KP = (KP * 1.5 / config.in_radius)
plots['KP'] = mlab.points3d(KP[:, 0],
KP[:, 1],
KP[:, 2],
scals,
colormap='autumn',
resolution=8,
scale_factor=1.2*p_scale,
scale_mode='none',
vmin=0,
vmax=1,
figure=fig1)
if True:
plots['center'] = mlab.points3d(p[point_i, 0],
p[point_i, 1],
p[point_i, 2],
scale_factor=1.1*p_scale,
scale_mode='none',
color=(0, 1, 0),
figure=fig1)
# New title
plots['title'] = mlab.title(str(obj_i), color=(0, 0, 0), size=0.3, height=0.01)
text = '<--- (press g for previous)' + 50 * ' ' + '(press h for next) --->'
plots['text'] = mlab.text(0.01, 0.01, text, color=(0, 0, 0), width=0.98)
plots['orient'] = mlab.orientation_axes()
# Set the saved view
mlab.view(*v)
mlab.roll(roll)
return
def animate_kernel():
global plots, offsets, p_scale, show_in_p
# Get KP locations
KP_def = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
KP_def = (KP_def * 1.5 / config.in_radius)
KP_def_color = (1, 0, 0)
KP_rigid = points[obj_i][point_i] + original_KP
KP_rigid = (KP_rigid * 1.5 / config.in_radius)
KP_rigid_color = (1, 0.7, 0)
if offsets:
t_list = np.linspace(0, 1, 150, dtype=np.float32)
else:
t_list = np.linspace(1, 0, 150, dtype=np.float32)
@mlab.animate(delay=10)
def anim():
for t in t_list:
plots['KP'].mlab_source.set(x=t * KP_def[:, 0] + (1 - t) * KP_rigid[:, 0],
y=t * KP_def[:, 1] + (1 - t) * KP_rigid[:, 1],
z=t * KP_def[:, 2] + (1 - t) * KP_rigid[:, 2],
scalars=t * np.ones_like(KP_def[:, 0]))
yield
anim()
return
def keyboard_callback(vtk_obj, event):
global obj_i, point_i, offsets, p_scale, show_in_p
if vtk_obj.GetKeyCode() in ['b', 'B']:
p_scale /= 1.5
update_scene()
elif vtk_obj.GetKeyCode() in ['n', 'N']:
p_scale *= 1.5
update_scene()
if vtk_obj.GetKeyCode() in ['g', 'G']:
obj_i = (obj_i - 1) % len(deformed_KP)
point_i = 0
update_scene()
elif vtk_obj.GetKeyCode() in ['h', 'H']:
obj_i = (obj_i + 1) % len(deformed_KP)
point_i = 0
update_scene()
elif vtk_obj.GetKeyCode() in ['k', 'K']:
offsets = not offsets
animate_kernel()
elif vtk_obj.GetKeyCode() in ['z', 'Z']:
show_in_p = (show_in_p + 1) % 3
update_scene()
elif vtk_obj.GetKeyCode() in ['0']:
print('Saving')
# Find a new name
file_i = 0
file_name = 'KP_{:03d}.ply'.format(file_i)
files = [f for f in listdir('KP_clouds') if f.endswith('.ply')]
while file_name in files:
file_i += 1
file_name = 'KP_{:03d}.ply'.format(file_i)
KP_deform = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
KP_normal = points[obj_i][point_i] + original_KP
# Save
write_ply(join('KP_clouds', file_name),
[in_points[obj_i], in_colors[obj_i]],
['x', 'y', 'z', 'red', 'green', 'blue'])
write_ply(join('KP_clouds', 'KP_{:03d}_deform.ply'.format(file_i)),
[KP_deform],
['x', 'y', 'z'])
write_ply(join('KP_clouds', 'KP_{:03d}_normal.ply'.format(file_i)),
[KP_normal],
['x', 'y', 'z'])
print('OK')
return
# Draw a first plot
pick_func = fig1.on_mouse_pick(picker_callback)
pick_func.tolerance = 0.01
update_scene()
fig1.scene.interactor.add_observer('KeyPressEvent', keyboard_callback)
mlab.show()
return
# Utilities
# ------------------------------------------------------------------------------------------------------------------
def show_ModelNet_models(all_points):
###########################
# Interactive visualization
###########################
# Create figure for features
fig1 = mlab.figure('Models', bgcolor=(1, 1, 1), size=(1000, 800))
fig1.scene.parallel_projection = False
# Indices
global file_i
file_i = 0
def update_scene():
# clear figure
mlab.clf(fig1)
# Plot new data feature
points = all_points[file_i]
# Rescale points for visu
points = (points * 1.5 + np.array([1.0, 1.0, 1.0])) * 50.0
# Show point clouds colorized with activations
activations = mlab.points3d(points[:, 0],
points[:, 1],
points[:, 2],
points[:, 2],
scale_factor=3.0,
scale_mode='none',
figure=fig1)
# New title
mlab.title(str(file_i), color=(0, 0, 0), size=0.3, height=0.01)
text = '<--- (press g for previous)' + 50 * ' ' + '(press h for next) --->'
mlab.text(0.01, 0.01, text, color=(0, 0, 0), width=0.98)
mlab.orientation_axes()
return
def keyboard_callback(vtk_obj, event):
global file_i
if vtk_obj.GetKeyCode() in ['g', 'G']:
file_i = (file_i - 1) % len(all_points)
update_scene()
elif vtk_obj.GetKeyCode() in ['h', 'H']:
file_i = (file_i + 1) % len(all_points)
update_scene()
return
# Draw a first plot
update_scene()
fig1.scene.interactor.add_observer('KeyPressEvent', keyboard_callback)
mlab.show()