532 lines
20 KiB
Python
532 lines
20 KiB
Python
#
|
|
#
|
|
# 0=================================0
|
|
# | Kernel Point Convolutions |
|
|
# 0=================================0
|
|
#
|
|
#
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Class handling the visualization
|
|
#
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Hugues THOMAS - 11/06/2018
|
|
#
|
|
|
|
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Imports and global variables
|
|
# \**********************************/
|
|
#
|
|
|
|
|
|
# Basic libs
|
|
import torch
|
|
import numpy as np
|
|
from sklearn.neighbors import KDTree
|
|
from os import makedirs, remove, rename, listdir
|
|
from os.path import exists, join
|
|
import time
|
|
from mayavi import mlab
|
|
import sys
|
|
|
|
from models.blocks import KPConv
|
|
|
|
# PLY reader
|
|
from utils.ply import write_ply, read_ply
|
|
|
|
# Configuration class
|
|
from utils.config import Config, bcolors
|
|
|
|
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
#           Visualizer Class
|
|
# \*******************/
|
|
#
|
|
|
|
|
|
class ModelVisualizer:
    """
    Reload a trained network from a checkpoint and display, in an interactive mayavi window, the kernel points of
    one of its deformable KPConv layers.
    """

    # Initialization methods
    # ------------------------------------------------------------------------------------------------------------------

    def __init__(self, net, config, chkp_path, on_gpu=True):
        """
        Move the network to the chosen device, restore its weights from a checkpoint and switch it to eval mode.
        :param net: network object (modified in place)
        :param config: configuration object (not read by this constructor, kept for interface compatibility)
        :param chkp_path: path to the checkpoint that needs to be loaded. The checkpoint must hold the entries
                          'model_state_dict' and 'epoch'
        :param on_gpu: run on GPU when one is available (True), otherwise run on CPU
        """

        ############
        # Parameters
        ############

        # Choose to run on CPU or GPU
        if on_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        net.to(self.device)

        ##########################
        # Load previous checkpoint
        ##########################

        # map_location remaps the stored tensors to the device in use, so that a checkpoint saved from a GPU
        # can also be restored on a CPU-only machine
        checkpoint = torch.load(chkp_path, map_location=self.device)

        # Some checkpoints name the layers 'blocs' instead of 'blocks': rename those entries before loading
        new_dict = {}
        for k, v in checkpoint['model_state_dict'].items():
            if 'blocs' in k:
                k = k.replace('blocs', 'blocks')
            new_dict[k] = v

        net.load_state_dict(new_dict)
        self.epoch = checkpoint['epoch']
        net.eval()
        print("\nModel state restored from {:s}.".format(chkp_path))

    # Utilities
    # ------------------------------------------------------------------------------------------------------------------

    @staticmethod
    def _split_by_lengths(stacked, lengths):
        """
        Cut an array/tensor stacked along its first axis into consecutive chunks.
        :param stacked: sliceable object (numpy array or torch tensor)
        :param lengths: iterable of chunk sizes (python ints or 0-dim integer tensors)
        :return: list of slices of stacked, one per entry of lengths, in order
        """
        chunks = []
        start = 0
        for length in lengths:
            stop = start + int(length)
            chunks.append(stacked[start:stop])
            start = stop
        return chunks

    # Main visualization methods
    # ------------------------------------------------------------------------------------------------------------------

    def show_deformable_kernels(self, net, loader, config, deform_idx=0):
        """
        Show some inference with deformable kernels.

        For every batch of the loader, the network is run once and a mayavi window is opened (mlab.show() is
        called once per batch). The window shows the clouds of the batch and the kernel points of the chosen
        deformable convolution, centered on the point closest to the last picked point.

        Keyboard shortcuts of the window:
            b / n : decrease / increase the size of the displayed points
            g / h : previous / next cloud of the batch
            k     : animate the kernel between its rigid and deformed positions
            z     : cycle the displayed clouds (layer points only, both, input points only)
            0     : save the current cloud and its kernel points as PLY files in the folder 'KP_clouds'

        :param net: network object, called as net(batch, config)
        :param loader: iterable of batches exposing .points, .lengths, .features and .to(device)
        :param config: configuration object (max_epoch and in_radius are read)
        :param deform_idx: index of the visualized convolution in the list of deformable KPConv of the network
        :raises IndexError: if deform_idx does not designate a deformable KPConv of the network
        :raises ValueError: if no layer of a batch has as many points as the deformed kernel array
        """

        ##########################################
        # First choose the visualized deformations
        ##########################################

        print('\nList of the deformable convolution available (chosen one highlighted in green)')
        fmt_str = ' {:}{:2d} > KPConv(r={:.3f}, Din={:d}, Dout={:d}){:}'
        deform_convs = []
        for m in net.modules():
            if isinstance(m, KPConv) and m.deformable:
                color = bcolors.OKGREEN if len(deform_convs) == deform_idx else bcolors.FAIL
                print(fmt_str.format(color, len(deform_convs), m.radius, m.in_channels, m.out_channels, bcolors.ENDC))
                deform_convs.append(m)

        # Fail before running any inference if the requested convolution does not exist
        if not -len(deform_convs) <= deform_idx < len(deform_convs):
            raise IndexError('deform_idx={} is out of range: the network has {} deformable KPConv'
                             .format(deform_idx, len(deform_convs)))
        chosen_conv = deform_convs[deform_idx]

        ################
        # Initialization
        ################

        print('\n****************************************************\n')

        # Start visualization loop
        for _ in range(config.max_epoch):

            for batch in loader:

                ##################
                # Processing batch
                ##################

                if 'cuda' in self.device.type:
                    batch.to(self.device)

                # Forward pass. Its output is not used: the kernel positions are read from the chosen convolution
                # right after (presumably deformed_KP is filled by the forward pass - TODO confirm in KPConv)
                net(batch, config)
                original_KP = chosen_conv.kernel_points.cpu().detach().numpy()
                stacked_deformed_KP = chosen_conv.deformed_KP.cpu().detach().numpy()

                if 'cuda' in self.device.type:
                    torch.cuda.synchronize(self.device)

                # Find layer: the last one having as many points as there are deformed kernels
                layer_i = None
                for i, layer_points in enumerate(batch.points):
                    if layer_points.shape[0] == stacked_deformed_KP.shape[0]:
                        layer_i = i
                if layer_i is None:
                    raise ValueError('No layer of the batch has {:d} points, cannot locate the deformed kernels'
                                     .format(stacked_deformed_KP.shape[0]))

                # Get data: unstack the batch into one array per cloud
                in_points = [chunk.cpu().detach().numpy()
                             for chunk in self._split_by_lengths(batch.points[0], batch.lengths[0])]
                if batch.features.shape[1] == 4:
                    # Columns 1: of the features are used as point colors
                    in_colors = [chunk[:, 1:].cpu().detach().numpy()
                                 for chunk in self._split_by_lengths(batch.features, batch.lengths[0])]
                else:
                    in_colors = [None] * len(in_points)

                points = [chunk.cpu().detach().numpy()
                          for chunk in self._split_by_lengths(batch.points[layer_i], batch.lengths[layer_i])]
                deformed_KP = self._split_by_lengths(stacked_deformed_KP, batch.lengths[layer_i])
                lookuptrees = [KDTree(cloud) for cloud in points]

                ###########################
                # Interactive visualization
                ###########################

                # Create figure for features
                fig1 = mlab.figure('Deformations', bgcolor=(1.0, 1.0, 1.0), size=(1280, 920))
                fig1.scene.parallel_projection = False

                # Indices. The state of the window is kept in module-level globals shared with the callbacks below
                global obj_i, point_i, plots, offsets, p_scale, show_in_p, aim_point
                p_scale = 0.03
                obj_i = 0
                point_i = 0
                plots = {}
                offsets = False
                show_in_p = 2
                aim_point = np.zeros((1, 3))

                def picker_callback(picker):
                    """ Picker callback: this get called when on pick events.
                    Move the aim point to the picked glyph and redraw the scene.
                    """
                    global plots, aim_point

                    picked_actors = [o._vtk_obj for o in picker.actors]

                    for key in ('in_points', 'points'):

                        if key not in plots:
                            continue
                        if plots[key].actor.actor._vtk_obj not in picked_actors:
                            continue

                        # The picked id indexes the vertices of all glyphs: divide by the number of vertices of
                        # one glyph to get the index of the point
                        point_rez = plots[key].glyph.glyph_source.glyph_source.output.points.to_array().shape[0]
                        new_point_i = int(np.floor(picker.point_id / point_rez))
                        if 0 <= new_point_i < len(plots[key].mlab_source.points):
                            # The closest point in the layer we are interested in is searched in update_scene
                            aim_point = plots[key].mlab_source.points[new_point_i:new_point_i + 1]
                            update_scene()

                def update_scene():
                    """ Redraw the whole figure from the current state, keeping the camera unchanged.
                    """
                    global plots, offsets, p_scale, show_in_p, aim_point, point_i

                    # Get the current view
                    v = mlab.view()
                    roll = mlab.roll()

                    # clear figure
                    for plot in plots.values():
                        plot.remove()
                    plots = {}

                    # Plot new data feature, rescaled for visu
                    p = points[obj_i] * 1.5 / config.in_radius

                    # Show point cloud of the layer
                    if show_in_p <= 1:
                        plots['points'] = mlab.points3d(p[:, 0],
                                                        p[:, 1],
                                                        p[:, 2],
                                                        resolution=8,
                                                        scale_factor=p_scale,
                                                        scale_mode='none',
                                                        color=(0, 1, 1),
                                                        figure=fig1)

                    # Show input point cloud
                    if show_in_p >= 1:

                        # Get points and colors
                        in_p = in_points[obj_i] * 1.5 / config.in_radius
                        in_c = in_colors[obj_i]

                        # Color point cloud if possible
                        if in_c is not None:

                            # Primitives
                            scalars = np.arange(len(in_p))  # Key point: set an integer for each point

                            # Define color table (including alpha), which must be uint8 and [0,255]
                            colors = np.hstack((in_c, np.ones_like(in_c[:, :1])))
                            colors = (colors * 255).astype(np.uint8)

                            plots['in_points'] = mlab.points3d(in_p[:, 0],
                                                               in_p[:, 1],
                                                               in_p[:, 2],
                                                               scalars,
                                                               resolution=8,
                                                               scale_factor=p_scale*0.8,
                                                               scale_mode='none',
                                                               figure=fig1)
                            plots['in_points'].module_manager.scalar_lut_manager.lut.table = colors

                        else:

                            plots['in_points'] = mlab.points3d(in_p[:, 0],
                                                               in_p[:, 1],
                                                               in_p[:, 2],
                                                               resolution=8,
                                                               scale_factor=p_scale*0.8,
                                                               scale_mode='none',
                                                               figure=fig1)

                    # Get KP locations: the kernel is centered on the layer point closest to the aim point
                    # (the aim point is in display coordinates and must be scaled back first)
                    rescaled_aim_point = aim_point * config.in_radius / 1.5
                    point_i = lookuptrees[obj_i].query(rescaled_aim_point, return_distance=False)[0][0]
                    if offsets:
                        KP = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                        scals = np.ones_like(KP[:, 0])
                    else:
                        KP = points[obj_i][point_i] + original_KP
                        scals = np.zeros_like(KP[:, 0])

                    KP = (KP * 1.5 / config.in_radius)

                    plots['KP'] = mlab.points3d(KP[:, 0],
                                                KP[:, 1],
                                                KP[:, 2],
                                                scals,
                                                colormap='autumn',
                                                resolution=8,
                                                scale_factor=1.2*p_scale,
                                                scale_mode='none',
                                                vmin=0,
                                                vmax=1,
                                                figure=fig1)

                    # Center of the kernel
                    plots['center'] = mlab.points3d(p[point_i, 0],
                                                    p[point_i, 1],
                                                    p[point_i, 2],
                                                    scale_factor=1.1*p_scale,
                                                    scale_mode='none',
                                                    color=(0, 1, 0),
                                                    figure=fig1)

                    # New title
                    plots['title'] = mlab.title(str(obj_i), color=(0, 0, 0), size=0.3, height=0.01)
                    text = '<--- (press g for previous)' + 50 * ' ' + '(press h for next) --->'
                    plots['text'] = mlab.text(0.01, 0.01, text, color=(0, 0, 0), width=0.98)
                    plots['orient'] = mlab.orientation_axes()

                    # Set the saved view
                    mlab.view(*v)
                    mlab.roll(roll)

                def animate_kernel():
                    """ Move the displayed kernel points between their rigid and deformed positions.
                    """
                    global plots, offsets, p_scale, show_in_p

                    # Get KP locations, rescaled for visu
                    KP_def = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                    KP_def = (KP_def * 1.5 / config.in_radius)

                    KP_rigid = points[obj_i][point_i] + original_KP
                    KP_rigid = (KP_rigid * 1.5 / config.in_radius)

                    # Interpolation factor: 0 is the rigid kernel, 1 the deformed one
                    if offsets:
                        t_list = np.linspace(0, 1, 150, dtype=np.float32)
                    else:
                        t_list = np.linspace(1, 0, 150, dtype=np.float32)

                    @mlab.animate(delay=10)
                    def anim():
                        for t in t_list:
                            plots['KP'].mlab_source.set(x=t * KP_def[:, 0] + (1 - t) * KP_rigid[:, 0],
                                                        y=t * KP_def[:, 1] + (1 - t) * KP_rigid[:, 1],
                                                        z=t * KP_def[:, 2] + (1 - t) * KP_rigid[:, 2],
                                                        scalars=t * np.ones_like(KP_def[:, 0]))

                            yield

                    anim()

                def keyboard_callback(vtk_obj, event):
                    """ Keyboard callback: apply the shortcut associated with the pressed key.
                    """
                    global obj_i, point_i, offsets, p_scale, show_in_p

                    key = vtk_obj.GetKeyCode()

                    if key in ['b', 'B']:
                        p_scale /= 1.5
                        update_scene()

                    elif key in ['n', 'N']:
                        p_scale *= 1.5
                        update_scene()

                    elif key in ['g', 'G']:
                        obj_i = (obj_i - 1) % len(deformed_KP)
                        point_i = 0
                        update_scene()

                    elif key in ['h', 'H']:
                        obj_i = (obj_i + 1) % len(deformed_KP)
                        point_i = 0
                        update_scene()

                    elif key in ['k', 'K']:
                        offsets = not offsets
                        animate_kernel()

                    elif key in ['z', 'Z']:
                        show_in_p = (show_in_p + 1) % 3
                        update_scene()

                    elif key in ['0']:

                        print('Saving')

                        # The saving folder is created on first use
                        save_dir = 'KP_clouds'
                        if not exists(save_dir):
                            makedirs(save_dir)

                        # Find a new name
                        file_i = 0
                        file_name = 'KP_{:03d}.ply'.format(file_i)
                        files = [f for f in listdir(save_dir) if f.endswith('.ply')]
                        while file_name in files:
                            file_i += 1
                            file_name = 'KP_{:03d}.ply'.format(file_i)

                        KP_deform = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                        KP_normal = points[obj_i][point_i] + original_KP

                        # Save (colors are only written when the input cloud has some)
                        if in_colors[obj_i] is None:
                            write_ply(join(save_dir, file_name),
                                      [in_points[obj_i]],
                                      ['x', 'y', 'z'])
                        else:
                            write_ply(join(save_dir, file_name),
                                      [in_points[obj_i], in_colors[obj_i]],
                                      ['x', 'y', 'z', 'red', 'green', 'blue'])
                        write_ply(join(save_dir, 'KP_{:03d}_deform.ply'.format(file_i)),
                                  [KP_deform],
                                  ['x', 'y', 'z'])
                        write_ply(join(save_dir, 'KP_{:03d}_normal.ply'.format(file_i)),
                                  [KP_normal],
                                  ['x', 'y', 'z'])
                        print('OK')

                # Draw a first plot
                pick_func = fig1.on_mouse_pick(picker_callback)
                pick_func.tolerance = 0.01
                update_scene()
                fig1.scene.interactor.add_observer('KeyPressEvent', keyboard_callback)
                mlab.show()

        return
|
|
|
|
# Utilities
|
|
# ------------------------------------------------------------------------------------------------------------------
|
|
|
|
|
|
def show_ModelNet_models(all_points):
    """
    Browse a collection of point clouds in an interactive mayavi window.

    One cloud is displayed at a time. Pressing g / G shows the previous cloud and h / H the next one, wrapping
    around at both ends of the collection.

    :param all_points: indexable collection of clouds. Each cloud is indexed as cloud[:, 0], cloud[:, 1] and
                       cloud[:, 2] for its coordinates
    """

    ###########################
    # Interactive visualization
    ###########################

    # Index of the displayed cloud (module-level global, rebound by the keyboard callback)
    global file_i
    file_i = 0

    # Create figure for the models
    fig1 = mlab.figure('Models', bgcolor=(1, 1, 1), size=(1000, 800))
    fig1.scene.parallel_projection = False

    def update_scene():
        """ Clear the figure and draw the cloud designated by file_i.
        """

        # Start again from an empty figure
        mlab.clf(fig1)

        # Current cloud, rescaled and shifted for visu
        cloud = (all_points[file_i] * 1.5 + np.ones(3)) * 50.0
        x, y, z = cloud[:, 0], cloud[:, 1], cloud[:, 2]

        # Show the cloud (the third coordinate is given a second time, as scalar value of the points)
        mlab.points3d(x, y, z, z,
                      scale_factor=3.0,
                      scale_mode='none',
                      figure=fig1)

        # Title, help text and axes
        mlab.title(str(file_i), color=(0, 0, 0), size=0.3, height=0.01)
        help_text = '<--- (press g for previous)' + 50 * ' ' + '(press h for next) --->'
        mlab.text(0.01, 0.01, help_text, color=(0, 0, 0), width=0.98)
        mlab.orientation_axes()

    def keyboard_callback(vtk_obj, event):
        """ Step to the previous / next cloud when g / h is pressed, ignore any other key.
        """
        global file_i

        pressed = vtk_obj.GetKeyCode()
        if pressed in ('g', 'G'):
            step = -1
        elif pressed in ('h', 'H'):
            step = 1
        else:
            return

        file_i = (file_i + step) % len(all_points)
        update_scene()

    # Draw a first plot
    update_scene()
    fig1.scene.interactor.add_observer('KeyPressEvent', keyboard_callback)
    mlab.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|