#
#
#      0=================================0
#      |    Kernel Point Convolutions    |
#      0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
#      Class handling the visualization
#
# ----------------------------------------------------------------------------------------------------------------------
#
#      Hugues THOMAS - 11/06/2018
#


# ----------------------------------------------------------------------------------------------------------------------
#
#           Imports and global variables
#       \**********************************/
#


# Basic libs
import torch
import numpy as np
from sklearn.neighbors import KDTree
from os import listdir, makedirs
from os.path import join
import time
from mayavi import mlab
from models.blocks import KPConv

# PLY reader
from utils.ply import write_ply

# Configuration class
from utils.config import bcolors


# ----------------------------------------------------------------------------------------------------------------------
#
#           Visualizer Class
#       \**********************/
#


class ModelVisualizer:
    """Restore a trained network and interactively visualize its deformable kernels with mayavi."""

    # Initialization methods
    # ------------------------------------------------------------------------------------------------------------------

    def __init__(self, net, config, chkp_path, on_gpu=True):
        """
        Move the network to the chosen device and restore its weights from a checkpoint.

        :param net: network object; it is moved to the device, loaded and put in eval mode in place
        :param config: configuration object (not used here)
        :param chkp_path: path to the checkpoint to load. It must contain "model_state_dict" and "epoch" entries.
        :param on_gpu: use "cuda:0" when True and CUDA is available, otherwise the CPU
        """

        ############
        # Parameters
        ############

        # Choose to run on CPU or GPU
        if on_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        net.to(self.device)

        ##########################
        # Load previous checkpoint
        ##########################

        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
        checkpoint = torch.load(chkp_path, map_location=self.device)

        # Rename "blocs" to "blocks" in parameter names (presumably to accept checkpoints saved with an older
        # naming of the network modules)
        new_dict = {k.replace("blocs", "blocks"): v for k, v in checkpoint["model_state_dict"].items()}

        net.load_state_dict(new_dict)
        self.epoch = checkpoint["epoch"]
        net.eval()
        print(f"\nModel state restored from {chkp_path}.")

    # Main visualization methods
    # ------------------------------------------------------------------------------------------------------------------

    def show_deformable_kernels(self, net, loader, config, deform_idx=0):
        """
        Show some inference with deformable kernels.

        For every batch of every epoch (config.max_epoch), the network is run on the batch and a mayavi figure
        shows the input clouds of the batch and, centered on a picked point of the matching layer cloud, the
        rigid or deformed kernel points of the chosen deformable KPConv. mlab.show() is called once per batch.

        Controls in the figure:
            mouse pick : choose the point on which the kernel is centered
            g / h      : previous / next object of the batch
            k          : animate the kernel between its rigid and deformed positions
            b / n      : shrink / grow the displayed points
            z          : cycle the displayed clouds (layer points only / both / input points only)
            0          : save the current input cloud and kernel points as ply files in "KP_clouds"

        :param net: network object, on self.device
        :param loader: iterable of batches
        :param config: configuration object (uses max_epoch and in_radius)
        :param deform_idx: index of the deformable KPConv to show, in the order of net.modules()
        :raises IndexError: if the network has no deformable KPConv of index deform_idx
        :raises ValueError: if no layer of a batch matches the size of the deformed kernels
        """

        ##########################################
        # First choose the visualized deformations
        ##########################################

        print("\nList of the deformable convolution available (chosen one highlighted in green)")
        fmt_str = " {:}{:2d} > KPConv(r={:.3f}, Din={:d}, Dout={:d}){:}"
        deform_convs = []
        for m in net.modules():
            if isinstance(m, KPConv) and m.deformable:
                color = bcolors.OKGREEN if len(deform_convs) == deform_idx else bcolors.FAIL
                print(fmt_str.format(color, len(deform_convs), m.radius, m.in_channels, m.out_channels, bcolors.ENDC))
                deform_convs.append(m)

        # Fail early with a clear message instead of at the first batch
        if not -len(deform_convs) <= deform_idx < len(deform_convs):
            raise IndexError(
                f"deform_idx={deform_idx} but the network has {len(deform_convs)} deformable KPConv layer(s)"
            )
        chosen_conv = deform_convs[deform_idx]

        ################
        # Initialization
        ################

        print("\n****************************************************\n")

        # Start visualization loop
        for _ in range(config.max_epoch):
            for batch in loader:

                ##################
                # Processing batch
                ##################

                if "cuda" in self.device.type:
                    batch.to(self.device)

                # Forward pass: the chosen convolution keeps its deformed kernel points as an attribute
                net(batch, config)
                original_KP = chosen_conv.kernel_points.cpu().detach().numpy()
                stacked_deformed_KP = chosen_conv.deformed_KP.cpu().detach().numpy()

                if "cuda" in self.device.type:
                    torch.cuda.synchronize(self.device)

                # Find the layer whose stacked point count matches the deformed kernels (last match wins)
                layer = None
                for i, p in enumerate(batch.points):
                    if p.shape[0] == stacked_deformed_KP.shape[0]:
                        layer = i
                if layer is None:
                    raise ValueError(
                        "No layer of the batch has as many points as the deformed kernels of the chosen convolution"
                    )

                # Split the stacked input points (and colors, when features have 4 channels) per object.
                # NOTE(review): assumes features are [1, r, g, b] when they have 4 channels - confirm.
                has_colors = batch.features.shape[1] == 4
                in_points = []
                in_colors = []
                i0 = 0
                for length in batch.lengths[0]:
                    length = int(length)
                    in_points.append(batch.points[0][i0:i0 + length].cpu().detach().numpy())
                    if has_colors:
                        in_colors.append(batch.features[i0:i0 + length, 1:].cpu().detach().numpy())
                    else:
                        in_colors.append(None)
                    i0 += length

                # Split the layer points and deformed kernels per object, with a KDTree to snap picks to points
                points = []
                deformed_KP = []
                lookuptrees = []
                i0 = 0
                for length in batch.lengths[layer]:
                    length = int(length)
                    points.append(batch.points[layer][i0:i0 + length].cpu().detach().numpy())
                    deformed_KP.append(stacked_deformed_KP[i0:i0 + length])
                    lookuptrees.append(KDTree(points[-1]))
                    i0 += length

                ###########################
                # Interactive visualization
                ###########################

                # Create figure for features
                fig1 = mlab.figure("Deformations", bgcolor=(1.0, 1.0, 1.0), size=(1280, 920))
                fig1.scene.parallel_projection = False

                # Display state, shared with the mayavi callbacks below through module-level globals
                global obj_i, point_i, plots, offsets, p_scale, show_in_p, aim_point
                p_scale = 0.03  # size of the displayed points
                obj_i = 0  # index of the displayed object in the batch
                point_i = 0  # index, in the layer cloud, of the point the kernel is centered on
                plots = {}  # name -> mayavi object currently drawn
                offsets = False  # True: deformed kernel points, False: rigid kernel points
                show_in_p = 2  # 0: layer points only, 1: both clouds, 2: input points only
                aim_point = np.zeros((1, 3))  # last picked location, in display coordinates

                def to_visu(x):
                    """Rescale data coordinates to display coordinates."""
                    return x * 1.5 / config.in_radius

                def picker_callback(picker):
                    """Picker callback: this get called when on pick events."""
                    global plots, aim_point

                    picked_actors = [o._vtk_obj for o in picker.actors]
                    for key in ("in_points", "points"):
                        if key not in plots or plots[key].actor.actor._vtk_obj not in picked_actors:
                            continue
                        # Each displayed point is a glyph made of several vtk points
                        point_rez = plots[key].glyph.glyph_source.glyph_source.output.points.to_array().shape[0]
                        new_point_i = picker.point_id // point_rez
                        if 0 <= new_point_i < len(plots[key].mlab_source.points):
                            # Aim at the picked point; update_scene snaps it to the closest point of the layer
                            aim_point = plots[key].mlab_source.points[new_point_i:new_point_i + 1]
                            update_scene()

                def update_scene():
                    """Redraw the figure for the current object, aimed point and display options."""
                    global plots, offsets, p_scale, show_in_p, aim_point, point_i

                    # Get the current view
                    v = mlab.view()
                    roll = mlab.roll()

                    # Clear figure
                    for plot in plots.values():
                        plot.remove()
                    plots = {}

                    # Layer points of the current object, rescaled for display
                    p = to_visu(points[obj_i])

                    # Show point cloud
                    if show_in_p <= 1:
                        plots["points"] = mlab.points3d(
                            p[:, 0], p[:, 1], p[:, 2],
                            resolution=8,
                            scale_factor=p_scale,
                            scale_mode="none",
                            color=(0, 1, 1),
                            figure=fig1,
                        )

                    if show_in_p >= 1:
                        in_p = to_visu(in_points[obj_i])

                        # Color point cloud if possible
                        in_c = in_colors[obj_i]
                        if in_c is not None:
                            # One integer scalar per point, so that the lookup table gives each point its color
                            scalars = np.arange(len(in_p))

                            # Color table (including alpha), which must be uint8 and [0,255]
                            colors = np.hstack((in_c, np.ones_like(in_c[:, :1])))
                            colors = (colors * 255).astype(np.uint8)

                            plots["in_points"] = mlab.points3d(
                                in_p[:, 0], in_p[:, 1], in_p[:, 2],
                                scalars,
                                resolution=8,
                                scale_factor=p_scale * 0.8,
                                scale_mode="none",
                                figure=fig1,
                            )
                            plots["in_points"].module_manager.scalar_lut_manager.lut.table = colors
                        else:
                            plots["in_points"] = mlab.points3d(
                                in_p[:, 0], in_p[:, 1], in_p[:, 2],
                                resolution=8,
                                scale_factor=p_scale * 0.8,
                                scale_mode="none",
                                figure=fig1,
                            )

                    # Center the kernel on the layer point closest to the aimed location (in data coordinates)
                    rescaled_aim_point = aim_point * config.in_radius / 1.5
                    point_i = lookuptrees[obj_i].query(rescaled_aim_point, return_distance=False)[0][0]
                    center = points[obj_i][point_i]
                    if offsets:
                        KP = center + deformed_KP[obj_i][point_i]
                        scals = np.ones_like(KP[:, 0])
                    else:
                        KP = center + original_KP
                        scals = np.zeros_like(KP[:, 0])
                    KP = to_visu(KP)

                    plots["KP"] = mlab.points3d(
                        KP[:, 0], KP[:, 1], KP[:, 2],
                        scals,
                        colormap="autumn",
                        resolution=8,
                        scale_factor=1.2 * p_scale,
                        scale_mode="none",
                        vmin=0,
                        vmax=1,
                        figure=fig1,
                    )

                    plots["center"] = mlab.points3d(
                        p[point_i, 0], p[point_i, 1], p[point_i, 2],
                        scale_factor=1.1 * p_scale,
                        scale_mode="none",
                        color=(0, 1, 0),
                        figure=fig1,
                    )

                    # New title
                    plots["title"] = mlab.title(str(obj_i), color=(0, 0, 0), size=0.3, height=0.01)
                    text = "<--- (press g for previous)" + 50 * " " + "(press h for next) --->"
                    plots["text"] = mlab.text(0.01, 0.01, text, color=(0, 0, 0), width=0.98)
                    plots["orient"] = mlab.orientation_axes()

                    # Set the saved view
                    mlab.view(*v)
                    mlab.roll(roll)

                def animate_kernel():
                    """Animate the kernel points from rigid to deformed (offsets True) or back (offsets False)."""
                    global plots, offsets, p_scale, show_in_p

                    # Kernel point locations, rescaled for display
                    center = points[obj_i][point_i]
                    KP_def = to_visu(center + deformed_KP[obj_i][point_i])
                    KP_rigid = to_visu(center + original_KP)

                    # t = 0: rigid kernel, t = 1: deformed kernel
                    if offsets:
                        t_list = np.linspace(0, 1, 150, dtype=np.float32)
                    else:
                        t_list = np.linspace(1, 0, 150, dtype=np.float32)

                    @mlab.animate(delay=10)
                    def anim():
                        for t in t_list:
                            plots["KP"].mlab_source.set(
                                x=t * KP_def[:, 0] + (1 - t) * KP_rigid[:, 0],
                                y=t * KP_def[:, 1] + (1 - t) * KP_rigid[:, 1],
                                z=t * KP_def[:, 2] + (1 - t) * KP_rigid[:, 2],
                                scalars=t * np.ones_like(KP_def[:, 0]),
                            )
                            yield

                    anim()

                def save_current_kernel():
                    """Save the current input cloud and its deformed/rigid kernel points as ply files in "KP_clouds".

                    Coordinates are saved in data units (not rescaled for display).
                    """
                    print("Saving")

                    # Find the first unused file index (create the folder if needed)
                    makedirs("KP_clouds", exist_ok=True)
                    files = {f for f in listdir("KP_clouds") if f.endswith(".ply")}
                    file_i = 0
                    while "KP_{:03d}.ply".format(file_i) in files:
                        file_i += 1
                    file_name = "KP_{:03d}.ply".format(file_i)

                    KP_deform = points[obj_i][point_i] + deformed_KP[obj_i][point_i]
                    KP_normal = points[obj_i][point_i] + original_KP

                    # Save the input cloud, with its colors only when they exist
                    if in_colors[obj_i] is not None:
                        write_ply(
                            join("KP_clouds", file_name),
                            [in_points[obj_i], in_colors[obj_i]],
                            ["x", "y", "z", "red", "green", "blue"],
                        )
                    else:
                        write_ply(join("KP_clouds", file_name), [in_points[obj_i]], ["x", "y", "z"])
                    write_ply(
                        join("KP_clouds", "KP_{:03d}_deform.ply".format(file_i)),
                        [KP_deform],
                        ["x", "y", "z"],
                    )
                    write_ply(
                        join("KP_clouds", "KP_{:03d}_normal.ply".format(file_i)),
                        [KP_normal],
                        ["x", "y", "z"],
                    )
                    print("OK")

                def keyboard_callback(vtk_obj, event):
                    """Dispatch the key presses listed in the show_deformable_kernels docstring."""
                    global obj_i, point_i, offsets, p_scale, show_in_p

                    key = vtk_obj.GetKeyCode()
                    if key in ["b", "B"]:
                        p_scale /= 1.5
                        update_scene()
                    elif key in ["n", "N"]:
                        p_scale *= 1.5
                        update_scene()
                    elif key in ["g", "G"]:
                        obj_i = (obj_i - 1) % len(deformed_KP)
                        point_i = 0
                        update_scene()
                    elif key in ["h", "H"]:
                        obj_i = (obj_i + 1) % len(deformed_KP)
                        point_i = 0
                        update_scene()
                    elif key in ["k", "K"]:
                        offsets = not offsets
                        animate_kernel()
                    elif key in ["z", "Z"]:
                        show_in_p = (show_in_p + 1) % 3
                        update_scene()
                    elif key in ["0"]:
                        save_current_kernel()

                # Draw a first plot, then hand over to the mayavi GUI event loop for this batch
                pick_func = fig1.on_mouse_pick(picker_callback)
                pick_func.tolerance = 0.01
                update_scene()
                fig1.scene.interactor.add_observer("KeyPressEvent", keyboard_callback)
                mlab.show()


# Utilities
# ----------------------------------------------------------------------------------------------------------------------


def show_ModelNet_models(all_points):
    """
    Interactively browse point clouds in a mayavi figure (press g / h for previous / next).

    :param all_points: non-empty sequence of 2D point arrays whose first three columns are x, y, z
    :raises ValueError: if all_points is empty
    """

    if len(all_points) == 0:
        raise ValueError("show_ModelNet_models needs at least one point cloud")

    ###########################
    # Interactive visualization
    ###########################

    # Create figure for features
    fig1 = mlab.figure("Models", bgcolor=(1, 1, 1), size=(1000, 800))
    fig1.scene.parallel_projection = False

    # Index of the displayed cloud, shared with the callbacks through a module-level global
    global file_i
    file_i = 0

    def update_scene():
        """Redraw the figure for the current cloud."""
        # Clear figure
        mlab.clf(fig1)

        # Rescale points for visu
        points = (all_points[file_i] * 1.5 + np.array([1.0, 1.0, 1.0])) * 50.0

        # Show the point cloud, colored by its z coordinate
        mlab.points3d(
            points[:, 0], points[:, 1], points[:, 2],
            points[:, 2],
            scale_factor=3.0,
            scale_mode="none",
            figure=fig1,
        )

        # New title
        mlab.title(str(file_i), color=(0, 0, 0), size=0.3, height=0.01)
        text = "<--- (press g for previous)" + 50 * " " + "(press h for next) --->"
        mlab.text(0.01, 0.01, text, color=(0, 0, 0), width=0.98)
        mlab.orientation_axes()

    def keyboard_callback(vtk_obj, event):
        """g / h: show the previous / next cloud (wrapping around)."""
        global file_i

        key = vtk_obj.GetKeyCode()
        if key in ["g", "G"]:
            file_i = (file_i - 1) % len(all_points)
            update_scene()
        elif key in ["h", "H"]:
            file_i = (file_i + 1) % len(all_points)
            update_scene()

    # Draw a first plot
    update_scene()
    fig1.scene.interactor.add_observer("KeyPressEvent", keyboard_callback)
    mlab.show()