# KPConv-PyTorch/utils/visualizer.py
# Last commit: Laurent FAINSIN d0cdb8e4ee — "🎨 black + ruff" (2023-05-15 17:18:10 +02:00)
# File stats: 563 lines, 20 KiB, Python
#
#
# 0=================================0
# | Kernel Point Convolutions |
# 0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Class handling the visualization
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Hugues THOMAS - 11/06/2018
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Imports and global variables
# \**********************************/
#
# Basic libs
import torch
import numpy as np
from sklearn.neighbors import KDTree
from os import listdir
from os.path import join
import time
from mayavi import mlab
from models.blocks import KPConv
# PLY reader
from utils.ply import write_ply
# Configuration class
from utils.config import bcolors
# ----------------------------------------------------------------------------------------------------------------------
#
# Trainer Class
# \*******************/
#
class ModelVisualizer:
    """Visualization of a trained KPConv model (deformable kernels, features)."""

    # Initialization methods
    # ------------------------------------------------------------------------------------------------------------------

    def __init__(self, net, config, chkp_path, on_gpu=True):
        """
        Initialize visualization parameters and reload a previously trained model.

        :param net: network object whose weights are restored in place
        :param config: configuration object (stored by callers; unused here)
        :param chkp_path: path to the checkpoint that needs to be loaded
        :param on_gpu: run on GPU when available (True) or force CPU (False)
        """

        ############
        # Parameters
        ############

        # Choose to run on CPU or GPU
        if on_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        net.to(self.device)

        ##########################
        # Load previous checkpoint
        ##########################

        # map_location makes a GPU-saved checkpoint loadable on a CPU-only
        # machine (and vice versa) instead of raising a device error.
        checkpoint = torch.load(chkp_path, map_location=self.device)

        # Older checkpoints used the misspelled "blocs" prefix in parameter
        # names; rename so they match the current "blocks" module names.
        new_dict = {
            k.replace("blocs", "blocks"): v
            for k, v in checkpoint["model_state_dict"].items()
        }
        net.load_state_dict(new_dict)
        self.epoch = checkpoint["epoch"]

        # Visualization only ever runs inference.
        net.eval()
        print("\nModel state restored from {:s}.".format(chkp_path))
        return
# Main visualization methods
# ------------------------------------------------------------------------------------------------------------------
def show_deformable_kernels(self, net, loader, config, deform_idx=0):
    """
    Run forward passes and interactively show the deformable kernel points
    of one chosen deformable KPConv layer on top of the input point clouds.

    :param net: network object (weights already restored by __init__)
    :param loader: data loader yielding batches to visualize
    :param config: configuration object (in_radius, max_epoch are read)
    :param deform_idx: index of the deformable KPConv layer to visualize
    """

    ##########################################
    # First choose the visualized deformations
    ##########################################

    print(
        "\nList of the deformable convolution available (chosen one highlighted in green)"
    )
    fmt_str = " {:}{:2d} > KPConv(r={:.3f}, Din={:d}, Dout={:d}){:}"
    deform_convs = []
    for m in net.modules():
        if isinstance(m, KPConv) and m.deformable:
            # Chosen layer printed in green, the others in red
            if len(deform_convs) == deform_idx:
                color = bcolors.OKGREEN
            else:
                color = bcolors.FAIL
            print(
                fmt_str.format(
                    color,
                    len(deform_convs),
                    m.radius,
                    m.in_channels,
                    m.out_channels,
                    bcolors.ENDC,
                )
            )
            deform_convs.append(m)

    ################
    # Initialization
    ################

    print("\n****************************************************\n")

    # Loop variables (timestamps kept as a rolling two-element list)
    t = [time.time()]
    count = 0

    # Browse batches until the user stops the program
    for epoch in range(config.max_epoch):
        for batch in loader:

            ##################
            # Processing batch
            ##################

            # New time
            t = t[-1:]
            t += [time.time()]

            if "cuda" in self.device.type:
                batch.to(self.device)

            # Forward pass (fills deformed_KP of every deformable layer)
            net(batch, config)

            # Rigid kernel point positions and their deformed counterparts
            original_KP = (
                deform_convs[deform_idx].kernel_points.cpu().detach().numpy()
            )
            stacked_deformed_KP = (
                deform_convs[deform_idx].deformed_KP.cpu().detach().numpy()
            )
            count += batch.lengths[0].shape[0]

            if "cuda" in self.device.type:
                torch.cuda.synchronize(self.device)

            # Find the subsampling layer this deformable conv operates on:
            # the one whose stacked point count matches the deformed KP stack.
            layer = None
            for i, p in enumerate(batch.points):
                if p.shape[0] == stacked_deformed_KP.shape[0]:
                    layer = i

            t += [time.time()]

            # Unstack batch tensors into one numpy cloud per batch element
            in_points = []
            in_colors = []
            deformed_KP = []
            points = []
            lookuptrees = []
            i0 = 0
            for b_i, length in enumerate(batch.lengths[0]):
                in_points.append(
                    batch.points[0][i0 : i0 + length].cpu().detach().numpy()
                )
                if batch.features.shape[1] == 4:
                    # Features look like (1, r, g, b): keep the color channels
                    in_colors.append(
                        batch.features[i0 : i0 + length, 1:].cpu().detach().numpy()
                    )
                else:
                    in_colors.append(None)
                i0 += length
            i0 = 0
            for b_i, length in enumerate(batch.lengths[layer]):
                points.append(
                    batch.points[layer][i0 : i0 + length].cpu().detach().numpy()
                )
                deformed_KP.append(stacked_deformed_KP[i0 : i0 + length])
                # KD-tree used to snap a picked 3D point onto this cloud
                lookuptrees.append(KDTree(points[-1]))
                i0 += length

            ###########################
            # Interactive visualization
            ###########################

            # Create figure for features
            fig1 = mlab.figure(
                "Deformations", bgcolor=(1.0, 1.0, 1.0), size=(1280, 920)
            )
            fig1.scene.parallel_projection = False

            # State shared with the nested mayavi callbacks below
            global obj_i, point_i, plots, offsets, p_scale, show_in_p, aim_point
            p_scale = 0.03
            obj_i = 0
            point_i = 0
            plots = {}
            offsets = False
            show_in_p = 2
            aim_point = np.zeros((1, 3))

            def picker_callback(picker):
                """Picker callback: called on mouse pick events."""
                global plots, aim_point

                if "in_points" in plots:
                    if plots["in_points"].actor.actor._vtk_obj in [
                        o._vtk_obj for o in picker.actors
                    ]:
                        # Number of vtk vertices per sphere glyph, used to map
                        # the picked vertex id back to a point index.
                        point_rez = (
                            plots["in_points"]
                            .glyph.glyph_source.glyph_source.output.points.to_array()
                            .shape[0]
                        )
                        new_point_i = int(np.floor(picker.point_id / point_rez))
                        if new_point_i < len(plots["in_points"].mlab_source.points):
                            # Get closest point in the layer we are interested in
                            aim_point = plots["in_points"].mlab_source.points[
                                new_point_i : new_point_i + 1
                            ]
                            update_scene()

                if "points" in plots:
                    if plots["points"].actor.actor._vtk_obj in [
                        o._vtk_obj for o in picker.actors
                    ]:
                        point_rez = (
                            plots["points"]
                            .glyph.glyph_source.glyph_source.output.points.to_array()
                            .shape[0]
                        )
                        new_point_i = int(np.floor(picker.point_id / point_rez))
                        if new_point_i < len(plots["points"].mlab_source.points):
                            # Get closest point in the layer we are interested in
                            aim_point = plots["points"].mlab_source.points[
                                new_point_i : new_point_i + 1
                            ]
                            update_scene()

            def update_scene():
                """Redraw the whole figure for the current object and point."""
                global plots, offsets, p_scale, show_in_p, aim_point, point_i

                # Get the current view so it survives the redraw
                v = mlab.view()
                roll = mlab.roll()

                # clear figure
                for key in plots.keys():
                    plots[key].remove()
                plots = {}

                # Plot new data feature
                p = points[obj_i]

                # Rescale points for visu
                p = p * 1.5 / config.in_radius

                # Show point cloud (layer points)
                if show_in_p <= 1:
                    plots["points"] = mlab.points3d(
                        p[:, 0],
                        p[:, 1],
                        p[:, 2],
                        resolution=8,
                        scale_factor=p_scale,
                        scale_mode="none",
                        color=(0, 1, 1),
                        figure=fig1,
                    )

                if show_in_p >= 1:
                    # Get points and colors
                    in_p = in_points[obj_i]
                    in_p = in_p * 1.5 / config.in_radius

                    # Color point cloud if possible
                    in_c = in_colors[obj_i]
                    if in_c is not None:
                        # Primitives
                        scalars = np.arange(
                            len(in_p)
                        )  # Key point: set an integer for each point

                        # Define color table (including alpha), which must be uint8 and [0,255]
                        colors = np.hstack((in_c, np.ones_like(in_c[:, :1])))
                        colors = (colors * 255).astype(np.uint8)

                        plots["in_points"] = mlab.points3d(
                            in_p[:, 0],
                            in_p[:, 1],
                            in_p[:, 2],
                            scalars,
                            resolution=8,
                            scale_factor=p_scale * 0.8,
                            scale_mode="none",
                            figure=fig1,
                        )
                        plots[
                            "in_points"
                        ].module_manager.scalar_lut_manager.lut.table = colors
                    else:
                        plots["in_points"] = mlab.points3d(
                            in_p[:, 0],
                            in_p[:, 1],
                            in_p[:, 2],
                            resolution=8,
                            scale_factor=p_scale * 0.8,
                            scale_mode="none",
                            figure=fig1,
                        )

                # Get KP locations: snap the aimed point back onto the cloud
                rescaled_aim_point = aim_point * config.in_radius / 1.5
                point_i = lookuptrees[obj_i].query(
                    rescaled_aim_point, return_distance=False
                )[0][0]
                if offsets:
                    KP = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                    scals = np.ones_like(KP[:, 0])
                else:
                    KP = points[obj_i][point_i] + original_KP
                    scals = np.zeros_like(KP[:, 0])
                KP = KP * 1.5 / config.in_radius

                plots["KP"] = mlab.points3d(
                    KP[:, 0],
                    KP[:, 1],
                    KP[:, 2],
                    scals,
                    colormap="autumn",
                    resolution=8,
                    scale_factor=1.2 * p_scale,
                    scale_mode="none",
                    vmin=0,
                    vmax=1,
                    figure=fig1,
                )

                # Highlight the picked center point in green
                plots["center"] = mlab.points3d(
                    p[point_i, 0],
                    p[point_i, 1],
                    p[point_i, 2],
                    scale_factor=1.1 * p_scale,
                    scale_mode="none",
                    color=(0, 1, 0),
                    figure=fig1,
                )

                # New title
                plots["title"] = mlab.title(
                    str(obj_i), color=(0, 0, 0), size=0.3, height=0.01
                )
                text = (
                    "<--- (press g for previous)"
                    + 50 * " "
                    + "(press h for next) --->"
                )
                plots["text"] = mlab.text(
                    0.01, 0.01, text, color=(0, 0, 0), width=0.98
                )
                plots["orient"] = mlab.orientation_axes()

                # Set the saved view
                mlab.view(*v)
                mlab.roll(roll)
                return

            def animate_kernel():
                """Morph kernel points between rigid and deformed positions."""
                global plots, offsets, p_scale, show_in_p

                # Get KP locations
                KP_def = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                KP_def = KP_def * 1.5 / config.in_radius
                KP_rigid = points[obj_i][point_i] + original_KP
                KP_rigid = KP_rigid * 1.5 / config.in_radius

                # Interpolation direction depends on the current offsets state
                if offsets:
                    t_list = np.linspace(0, 1, 150, dtype=np.float32)
                else:
                    t_list = np.linspace(1, 0, 150, dtype=np.float32)

                @mlab.animate(delay=10)
                def anim():
                    for t in t_list:
                        plots["KP"].mlab_source.set(
                            x=t * KP_def[:, 0] + (1 - t) * KP_rigid[:, 0],
                            y=t * KP_def[:, 1] + (1 - t) * KP_rigid[:, 1],
                            z=t * KP_def[:, 2] + (1 - t) * KP_rigid[:, 2],
                            scalars=t * np.ones_like(KP_def[:, 0]),
                        )
                        yield

                anim()
                return

            def keyboard_callback(vtk_obj, event):
                """Keys: b/n point size, g/h browse objects, k animate offsets, z display mode, 0 save."""
                global obj_i, point_i, offsets, p_scale, show_in_p

                if vtk_obj.GetKeyCode() in ["b", "B"]:
                    p_scale /= 1.5
                    update_scene()
                elif vtk_obj.GetKeyCode() in ["n", "N"]:
                    p_scale *= 1.5
                    update_scene()

                if vtk_obj.GetKeyCode() in ["g", "G"]:
                    obj_i = (obj_i - 1) % len(deformed_KP)
                    point_i = 0
                    update_scene()
                elif vtk_obj.GetKeyCode() in ["h", "H"]:
                    obj_i = (obj_i + 1) % len(deformed_KP)
                    point_i = 0
                    update_scene()
                elif vtk_obj.GetKeyCode() in ["k", "K"]:
                    offsets = not offsets
                    animate_kernel()
                elif vtk_obj.GetKeyCode() in ["z", "Z"]:
                    show_in_p = (show_in_p + 1) % 3
                    update_scene()
                elif vtk_obj.GetKeyCode() in ["0"]:
                    print("Saving")

                    # Find a new name
                    file_i = 0
                    file_name = "KP_{:03d}.ply".format(file_i)
                    files = [f for f in listdir("KP_clouds") if f.endswith(".ply")]
                    while file_name in files:
                        file_i += 1
                        file_name = "KP_{:03d}.ply".format(file_i)

                    KP_deform = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                    KP_normal = points[obj_i][point_i] + original_KP

                    # Save
                    # NOTE(review): in_colors[obj_i] can be None when inputs
                    # carry no colors — this write would fail then; confirm
                    # against the datasets this is used with.
                    write_ply(
                        join("KP_clouds", file_name),
                        [in_points[obj_i], in_colors[obj_i]],
                        ["x", "y", "z", "red", "green", "blue"],
                    )
                    write_ply(
                        join("KP_clouds", "KP_{:03d}_deform.ply".format(file_i)),
                        [KP_deform],
                        ["x", "y", "z"],
                    )
                    write_ply(
                        join("KP_clouds", "KP_{:03d}_normal.ply".format(file_i)),
                        [KP_normal],
                        ["x", "y", "z"],
                    )
                    print("OK")
                return

            # Draw a first plot
            pick_func = fig1.on_mouse_pick(picker_callback)
            pick_func.tolerance = 0.01
            update_scene()
            fig1.scene.interactor.add_observer("KeyPressEvent", keyboard_callback)
            mlab.show()

    return
# Utilities
# ------------------------------------------------------------------------------------------------------------------
def show_ModelNet_models(all_points):
    """Interactively browse a list of ModelNet point clouds with mayavi.

    :param all_points: list of (N, 3) point arrays, one per model;
        g/h keys step backward/forward through the list
    """

    ###########################
    # Interactive visualization
    ###########################

    # Single figure reused for every redraw
    fig1 = mlab.figure("Models", bgcolor=(1, 1, 1), size=(1000, 800))
    fig1.scene.parallel_projection = False

    # Index of the displayed model, shared with the keyboard callback
    global file_i
    file_i = 0

    def update_scene():
        """Redraw the currently selected model."""
        # Start from an empty figure
        mlab.clf(fig1)

        # Shift and scale the cloud so it sits nicely in the view
        pts = (all_points[file_i] * 1.5 + np.array([1.0, 1.0, 1.0])) * 50.0

        # Colorize each point by its height (z value)
        mlab.points3d(
            pts[:, 0],
            pts[:, 1],
            pts[:, 2],
            pts[:, 2],
            scale_factor=3.0,
            scale_mode="none",
            figure=fig1,
        )

        # Title and navigation hint
        mlab.title(str(file_i), color=(0, 0, 0), size=0.3, height=0.01)
        text = "<--- (press g for previous)" + " " * 50 + "(press h for next) --->"
        mlab.text(0.01, 0.01, text, color=(0, 0, 0), width=0.98)
        mlab.orientation_axes()
        return

    def keyboard_callback(vtk_obj, event):
        """Cycle through the models with the g/h keys."""
        global file_i

        key = vtk_obj.GetKeyCode()
        if key in ("g", "G"):
            file_i = (file_i - 1) % len(all_points)
            update_scene()
        elif key in ("h", "H"):
            file_i = (file_i + 1) % len(all_points)
            update_scene()
        return

    # Draw a first plot, then hand control to the mayavi event loop
    update_scene()
    fig1.scene.interactor.add_observer("KeyPressEvent", keyboard_callback)
    mlab.show()