563 lines
20 KiB
Python
563 lines
20 KiB
Python
#
|
|
#
|
|
# 0=================================0
|
|
# | Kernel Point Convolutions |
|
|
# 0=================================0
|
|
#
|
|
#
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Class handling the visualization
|
|
#
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Hugues THOMAS - 11/06/2018
|
|
#
|
|
|
|
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Imports and global variables
|
|
# \**********************************/
|
|
#
|
|
|
|
|
|
# Basic libs
|
|
import torch
|
|
import numpy as np
|
|
from sklearn.neighbors import KDTree
|
|
from os import listdir
|
|
from os.path import join
|
|
import time
|
|
from mayavi import mlab
|
|
|
|
from models.blocks import KPConv
|
|
|
|
# PLY reader
|
|
from utils.ply import write_ply
|
|
|
|
# Configuration class
|
|
from utils.config import bcolors
|
|
|
|
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Visualizer Class
|
|
# \*******************/
|
|
#
|
|
|
|
|
|
class ModelVisualizer:
    """
    Restore a network from a checkpoint and display its deformable KPConv kernels in an interactive mayavi window.
    """

    # Initialization methods
    # ------------------------------------------------------------------------------------------------------------------

    def __init__(self, net, config, chkp_path, on_gpu=True):
        """
        Move the network to the chosen device, load its weights from a checkpoint and switch it to eval mode.
        :param net: network object
        :param config: configuration object (not read by this constructor)
        :param chkp_path: path to the checkpoint that needs to be loaded
        :param on_gpu: use the GPU when one is available (True), otherwise use the CPU
        """

        ############
        # Parameters
        ############

        # Choose to run on CPU or GPU
        if on_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        net.to(self.device)

        ##########################
        # Load previous checkpoint
        ##########################

        # map_location lets a checkpoint saved from a GPU be restored on a CPU-only machine
        checkpoint = torch.load(chkp_path, map_location=self.device)

        # Keys containing "blocs" are renamed to "blocks" before loading
        # (presumably checkpoints saved under an older attribute name -- confirm)
        new_dict = {
            key.replace("blocs", "blocks"): value
            for key, value in checkpoint["model_state_dict"].items()
        }

        net.load_state_dict(new_dict)
        self.epoch = checkpoint["epoch"]
        net.eval()
        print("\nModel state restored from {:s}.".format(chkp_path))

    # Main visualization methods
    # ------------------------------------------------------------------------------------------------------------------

    def show_deformable_kernels(self, net, loader, config, deform_idx=0):
        """
        Run the network on batches of the loader and show, for one deformable KPConv, the rigid and deformed
        kernel points around a picked point. One blocking mayavi window is opened per batch.

        Keyboard controls in the window:
            b / n : shrink / enlarge the displayed points
            g / h : previous / next object of the batch
            k     : toggle between rigid and deformed kernel points (animated)
            z     : cycle the displayed clouds (input points, layer points, both)
            0     : save the current cloud and kernel points as ply files in "KP_clouds"

        :param net: network object, called as net(batch, config)
        :param loader: iterable yielding batches exposing points, lengths, features and to(device)
        :param config: configuration object (max_epoch and in_radius are read)
        :param deform_idx: index of the visualized layer among the deformable KPConv modules of net
        :raises IndexError: if deform_idx does not designate a deformable KPConv of net
        :raises ValueError: if no layer of the batch matches the size of the deformed kernel points
        """

        ##########################################
        # First choose the visualized deformations
        ##########################################

        print(
            "\nList of the deformable convolution available (chosen one highlighted in green)"
        )
        fmt_str = " {:}{:2d} > KPConv(r={:.3f}, Din={:d}, Dout={:d}){:}"
        deform_convs = []
        for m in net.modules():
            if isinstance(m, KPConv) and m.deformable:
                if len(deform_convs) == deform_idx:
                    color = bcolors.OKGREEN
                else:
                    color = bcolors.FAIL
                print(
                    fmt_str.format(
                        color,
                        len(deform_convs),
                        m.radius,
                        m.in_channels,
                        m.out_channels,
                        bcolors.ENDC,
                    )
                )
                deform_convs.append(m)

        # Fail before running any forward pass when the requested layer does not exist
        if not -len(deform_convs) <= deform_idx < len(deform_convs):
            raise IndexError(
                "deform_idx={:d} is out of range: the network has {:d} deformable KPConv".format(
                    deform_idx, len(deform_convs)
                )
            )
        chosen_conv = deform_convs[deform_idx]

        ################
        # Initialization
        ################

        print("\n****************************************************\n")

        # Start visualization loop
        for _ in range(config.max_epoch):
            for batch in loader:
                ##################
                # Processing batch
                ##################

                if "cuda" in self.device.type:
                    batch.to(self.device)

                # Forward pass (fills the deformed kernel points of the chosen convolution)
                net(batch, config)
                original_KP = chosen_conv.kernel_points.cpu().detach().numpy()
                stacked_deformed_KP = chosen_conv.deformed_KP.cpu().detach().numpy()

                if "cuda" in self.device.type:
                    torch.cuda.synchronize(self.device)

                # Find layer: the last one whose number of points matches the deformed kernel points
                layer_i = None
                for i, p in enumerate(batch.points):
                    if p.shape[0] == stacked_deformed_KP.shape[0]:
                        layer_i = i
                if layer_i is None:
                    raise ValueError(
                        "No layer of the batch has {:d} points, cannot locate the deformed kernel points".format(
                            stacked_deformed_KP.shape[0]
                        )
                    )

                # Get data: split the stacked input cloud into one array per object
                in_points = []
                in_colors = []
                i0 = 0
                for length in batch.lengths[0]:
                    length = int(length)
                    in_points.append(
                        batch.points[0][i0 : i0 + length].cpu().detach().numpy()
                    )
                    if batch.features.shape[1] == 4:
                        # With 4 feature columns, columns 1: are used as point colors
                        in_colors.append(
                            batch.features[i0 : i0 + length, 1:].cpu().detach().numpy()
                        )
                    else:
                        in_colors.append(None)
                    i0 += length

                # Same split for the points of the chosen layer and their deformed kernel points
                points = []
                deformed_KP = []
                lookuptrees = []
                i0 = 0
                for length in batch.lengths[layer_i]:
                    length = int(length)
                    points.append(
                        batch.points[layer_i][i0 : i0 + length].cpu().detach().numpy()
                    )
                    deformed_KP.append(stacked_deformed_KP[i0 : i0 + length])
                    lookuptrees.append(KDTree(points[-1]))
                    i0 += length

                ###########################
                # Interactive visualization
                ###########################

                # Create figure for features
                fig1 = mlab.figure(
                    "Deformations", bgcolor=(1.0, 1.0, 1.0), size=(1280, 920)
                )
                fig1.scene.parallel_projection = False

                # State shared by the callbacks below (rebound through nonlocal)
                p_scale = 0.03
                obj_i = 0
                point_i = 0
                plots = {}
                offsets = False
                show_in_p = 2
                aim_point = np.zeros((1, 3))

                def picker_callback(picker):
                    """Picker callback: this get called when on pick events."""
                    nonlocal aim_point

                    picked_actors = [o._vtk_obj for o in picker.actors]
                    for key in ("in_points", "points"):
                        if key not in plots:
                            continue
                        if plots[key].actor.actor._vtk_obj not in picked_actors:
                            continue

                        # Each point is drawn as a glyph made of several vertices
                        point_rez = (
                            plots[key]
                            .glyph.glyph_source.glyph_source.output.points.to_array()
                            .shape[0]
                        )
                        new_point_i = int(np.floor(picker.point_id / point_rez))
                        cloud = plots[key].mlab_source.points

                        # A negative id would silently produce an empty aim point
                        if 0 <= new_point_i < len(cloud):
                            # Aim at the picked point, update_scene finds the closest point in the layer
                            aim_point = cloud[new_point_i : new_point_i + 1]
                            update_scene()
                            # The scene was rebuilt, the remaining plot no longer belongs to this pick
                            break

                def update_scene():
                    """Redraw every plot from the current state while keeping the camera view."""
                    nonlocal plots, point_i

                    # Get the current view
                    v = mlab.view()
                    roll = mlab.roll()

                    # clear figure
                    for plot in plots.values():
                        plot.remove()

                    plots = {}

                    # Plot new data feature
                    p = points[obj_i]

                    # Rescale points for visu
                    p = p * 1.5 / config.in_radius

                    # Show point cloud
                    if show_in_p <= 1:
                        plots["points"] = mlab.points3d(
                            p[:, 0],
                            p[:, 1],
                            p[:, 2],
                            resolution=8,
                            scale_factor=p_scale,
                            scale_mode="none",
                            color=(0, 1, 1),
                            figure=fig1,
                        )

                    if show_in_p >= 1:
                        # Get points and colors
                        in_p = in_points[obj_i]
                        in_p = in_p * 1.5 / config.in_radius

                        # Color point cloud if possible
                        in_c = in_colors[obj_i]
                        if in_c is not None:
                            # Primitives
                            scalars = np.arange(
                                len(in_p)
                            )  # Key point: set an integer for each point

                            # Define color table (including alpha), which must be uint8 and [0,255]
                            colors = np.hstack((in_c, np.ones_like(in_c[:, :1])))
                            colors = (colors * 255).astype(np.uint8)

                            plots["in_points"] = mlab.points3d(
                                in_p[:, 0],
                                in_p[:, 1],
                                in_p[:, 2],
                                scalars,
                                resolution=8,
                                scale_factor=p_scale * 0.8,
                                scale_mode="none",
                                figure=fig1,
                            )
                            plots[
                                "in_points"
                            ].module_manager.scalar_lut_manager.lut.table = colors

                        else:
                            plots["in_points"] = mlab.points3d(
                                in_p[:, 0],
                                in_p[:, 1],
                                in_p[:, 2],
                                resolution=8,
                                scale_factor=p_scale * 0.8,
                                scale_mode="none",
                                figure=fig1,
                            )

                    # Get KP locations: the kernel is centered on the layer point closest to the aim point
                    rescaled_aim_point = aim_point * config.in_radius / 1.5
                    point_i = lookuptrees[obj_i].query(
                        rescaled_aim_point, return_distance=False
                    )[0][0]
                    if offsets:
                        KP = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                        scals = np.ones_like(KP[:, 0])
                    else:
                        KP = points[obj_i][point_i] + original_KP
                        scals = np.zeros_like(KP[:, 0])

                    KP = KP * 1.5 / config.in_radius

                    plots["KP"] = mlab.points3d(
                        KP[:, 0],
                        KP[:, 1],
                        KP[:, 2],
                        scals,
                        colormap="autumn",
                        resolution=8,
                        scale_factor=1.2 * p_scale,
                        scale_mode="none",
                        vmin=0,
                        vmax=1,
                        figure=fig1,
                    )

                    # Highlight the kernel center
                    plots["center"] = mlab.points3d(
                        p[point_i, 0],
                        p[point_i, 1],
                        p[point_i, 2],
                        scale_factor=1.1 * p_scale,
                        scale_mode="none",
                        color=(0, 1, 0),
                        figure=fig1,
                    )

                    # New title
                    plots["title"] = mlab.title(
                        str(obj_i), color=(0, 0, 0), size=0.3, height=0.01
                    )
                    text = (
                        "<--- (press g for previous)"
                        + 50 * " "
                        + "(press h for next) --->"
                    )
                    plots["text"] = mlab.text(
                        0.01, 0.01, text, color=(0, 0, 0), width=0.98
                    )
                    plots["orient"] = mlab.orientation_axes()

                    # Set the saved view
                    mlab.view(*v)
                    mlab.roll(roll)

                def animate_kernel():
                    """Animate the kernel points between their rigid and deformed locations."""

                    # Get KP locations
                    KP_def = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                    KP_def = KP_def * 1.5 / config.in_radius

                    KP_rigid = points[obj_i][point_i] + original_KP
                    KP_rigid = KP_rigid * 1.5 / config.in_radius

                    # Blend factor goes towards the deformed kernel when offsets is on, back to rigid otherwise
                    if offsets:
                        t_list = np.linspace(0, 1, 150, dtype=np.float32)
                    else:
                        t_list = np.linspace(1, 0, 150, dtype=np.float32)

                    @mlab.animate(delay=10)
                    def anim():
                        for t in t_list:
                            plots["KP"].mlab_source.set(
                                x=t * KP_def[:, 0] + (1 - t) * KP_rigid[:, 0],
                                y=t * KP_def[:, 1] + (1 - t) * KP_rigid[:, 1],
                                z=t * KP_def[:, 2] + (1 - t) * KP_rigid[:, 2],
                                scalars=t * np.ones_like(KP_def[:, 0]),
                            )

                            yield

                    anim()

                def keyboard_callback(vtk_obj, event):
                    """React to the key presses listed in the docstring of show_deformable_kernels."""
                    nonlocal obj_i, point_i, offsets, p_scale, show_in_p

                    key = vtk_obj.GetKeyCode()

                    if key in ["b", "B"]:
                        p_scale /= 1.5
                        update_scene()

                    elif key in ["n", "N"]:
                        p_scale *= 1.5
                        update_scene()

                    elif key in ["g", "G"]:
                        obj_i = (obj_i - 1) % len(deformed_KP)
                        point_i = 0
                        update_scene()

                    elif key in ["h", "H"]:
                        obj_i = (obj_i + 1) % len(deformed_KP)
                        point_i = 0
                        update_scene()

                    elif key in ["k", "K"]:
                        offsets = not offsets
                        animate_kernel()

                    elif key in ["z", "Z"]:
                        show_in_p = (show_in_p + 1) % 3
                        update_scene()

                    elif key in ["0"]:
                        print("Saving")

                        # Local import: the file only imports listdir from os
                        from os import makedirs

                        # listdir fails on a missing directory, so create it first
                        makedirs("KP_clouds", exist_ok=True)

                        # Find a new name
                        file_i = 0
                        file_name = "KP_{:03d}.ply".format(file_i)
                        files = {f for f in listdir("KP_clouds") if f.endswith(".ply")}
                        while file_name in files:
                            file_i += 1
                            file_name = "KP_{:03d}.ply".format(file_i)

                        KP_deform = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                        KP_normal = points[obj_i][point_i] + original_KP

                        # Save (colors are only written when the batch provides them)
                        if in_colors[obj_i] is not None:
                            write_ply(
                                join("KP_clouds", file_name),
                                [in_points[obj_i], in_colors[obj_i]],
                                ["x", "y", "z", "red", "green", "blue"],
                            )
                        else:
                            write_ply(
                                join("KP_clouds", file_name),
                                [in_points[obj_i]],
                                ["x", "y", "z"],
                            )
                        write_ply(
                            join("KP_clouds", "KP_{:03d}_deform.ply".format(file_i)),
                            [KP_deform],
                            ["x", "y", "z"],
                        )
                        write_ply(
                            join("KP_clouds", "KP_{:03d}_normal.ply".format(file_i)),
                            [KP_normal],
                            ["x", "y", "z"],
                        )
                        print("OK")

                # Draw a first plot
                pick_func = fig1.on_mouse_pick(picker_callback)
                pick_func.tolerance = 0.01
                update_scene()
                fig1.scene.interactor.add_observer("KeyPressEvent", keyboard_callback)
                mlab.show()
|
|
# Utilities
|
|
# ------------------------------------------------------------------------------------------------------------------
|
|
|
|
|
|
def show_ModelNet_models(all_points):
    """
    Browse a collection of point clouds in an interactive mayavi window (g: previous cloud, h: next cloud).
    :param all_points: non-empty sequence of point arrays, indexed as cloud[:, 0], cloud[:, 1], cloud[:, 2]
    :raises ValueError: if all_points is empty
    """

    # Fail before opening a window: an empty collection cannot be indexed nor cycled through
    if len(all_points) == 0:
        raise ValueError("show_ModelNet_models needs at least one point cloud to display")

    ###########################
    # Interactive visualization
    ###########################

    # Create figure for features
    fig1 = mlab.figure("Models", bgcolor=(1, 1, 1), size=(1000, 800))
    fig1.scene.parallel_projection = False

    # Index of the displayed cloud, rebound by the keyboard callback
    file_i = 0

    def update_scene():
        """Clear the figure and draw the cloud designated by file_i."""

        # clear figure
        mlab.clf(fig1)

        # Plot new data feature
        points = all_points[file_i]

        # Rescale points for visu
        points = (points * 1.5 + np.array([1.0, 1.0, 1.0])) * 50.0

        # Show point cloud colorized with the z coordinate (passed as scalars)
        mlab.points3d(
            points[:, 0],
            points[:, 1],
            points[:, 2],
            points[:, 2],
            scale_factor=3.0,
            scale_mode="none",
            figure=fig1,
        )

        # New title
        mlab.title(str(file_i), color=(0, 0, 0), size=0.3, height=0.01)
        text = "<--- (press g for previous)" + 50 * " " + "(press h for next) --->"
        mlab.text(0.01, 0.01, text, color=(0, 0, 0), width=0.98)
        mlab.orientation_axes()

    def keyboard_callback(vtk_obj, event):
        """Move to the previous (g) or next (h) cloud, wrapping around at both ends."""
        nonlocal file_i

        key = vtk_obj.GetKeyCode()

        if key in ["g", "G"]:
            file_i = (file_i - 1) % len(all_points)
            update_scene()

        elif key in ["h", "H"]:
            file_i = (file_i + 1) % len(all_points)
            update_scene()

    # Draw a first plot
    update_scene()
    fig1.scene.interactor.add_observer("KeyPressEvent", keyboard_callback)
    mlab.show()