From 31cce85c95b5959863b38958ea30cbe7016bae76 Mon Sep 17 00:00:00 2001 From: HuguesTHOMAS Date: Thu, 23 Apr 2020 09:51:16 -0400 Subject: [PATCH] Corrections --- datasets/NCLT.py | 502 ++++++++++++++++++++++++++++++++++++++ datasets/S3DIS.py | 10 +- datasets/SemanticKitti.py | 8 +- datasets/common.py | 2 +- models/architectures.py | 19 +- models/blocks.py | 22 ++ plot_convergence.py | 82 ++++++- train_NCLT.py | 344 ++++++++++++++++++++++++++ train_S3DIS.py | 23 +- train_SemanticKitti.py | 18 +- utils/tester.py | 2 +- utils/trainer.py | 2 +- utils/visualizer.py | 231 ++++++++---------- visualize_ERFs.py | 205 ++++++++++++++++ visualize_deformations.py | 42 ++-- 15 files changed, 1321 insertions(+), 191 deletions(-) create mode 100644 datasets/NCLT.py create mode 100644 train_NCLT.py create mode 100644 visualize_ERFs.py diff --git a/datasets/NCLT.py b/datasets/NCLT.py new file mode 100644 index 0000000..f262f51 --- /dev/null +++ b/datasets/NCLT.py @@ -0,0 +1,502 @@ +# +# +# 0=================================0 +# | Kernel Point Convolutions | +# 0=================================0 +# +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Class handling SemanticKitti dataset. +# Implements a Dataset, a Sampler, and a collate_fn +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Hugues THOMAS - 11/06/2018 +# + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Imports and global variables +# \**********************************/ +# + +# Common libs +import sys +import struct +import scipy +import time +import numpy as np +import pickle +import torch +import yaml +#from mayavi import mlab +from multiprocessing import Lock + +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + + +# OS functions +from os import listdir +from os.path import exists, join, isdir, getsize + +# Dataset parent class +from datasets.common import * +from torch.utils.data import Sampler, get_worker_info +from utils.mayavi_visu import * +from utils.metrics import fast_confusion + +from datasets.common import grid_subsampling +from utils.config import bcolors + + +def ssc_to_homo(ssc, ssc_in_radians=True): + + # Convert 6-DOF ssc coordinate transformation to 4x4 homogeneous matrix + # transformation + + if ssc.ndim == 1: + reduce = True + ssc = np.expand_dims(ssc, 0) + else: + reduce = False + + if not ssc_in_radians: + ssc[:, 3:] = np.pi / 180.0 * ssc[:, 3:] + + sr = np.sin(ssc[:, 3]) + cr = np.cos(ssc[:, 3]) + + sp = np.sin(ssc[:, 4]) + cp = np.cos(ssc[:, 4]) + + sh = np.sin(ssc[:, 5]) + ch = np.cos(ssc[:, 5]) + + H = np.zeros((ssc.shape[0], 4, 4)) + + H[:, 0, 0] = ch*cp + H[:, 0, 1] = -sh*cr + ch*sp*sr + H[:, 0, 2] = sh*sr + ch*sp*cr + H[:, 1, 0] = sh*cp + H[:, 1, 1] = ch*cr + sh*sp*sr + H[:, 1, 2] = -ch*sr + sh*sp*cr + H[:, 2, 0] = -sp + H[:, 2, 1] = cp*sr + H[:, 2, 2] = cp*cr + + H[:, 0, 3] = ssc[:, 0] + H[:, 1, 3] = ssc[:, 1] + H[:, 2, 3] = ssc[:, 2] + + H[:, 3, 3] = 1 + + if reduce: + H = np.squeeze(H) + + return H + + +def verify_magic(s): + + magic = 44444 + + m = struct.unpack('=4 and m[0] == magic and m[1] == magic and m[2] == magic and m[3] == magic + + +def test_read_hits(): + + data_path = '../../Data/NCLT' + velo_folder = 'velodyne_data' + day = '2012-01-08' + + hits_path = join(data_path, velo_folder, day, 'velodyne_hits.bin') + + all_utimes = [] + all_hits = [] + all_ints = [] + + num_bytes = getsize(hits_path) + current_bytes = 0 + + with open(hits_path, 'rb') as f_bin: + + total_hits = 0 + first_utime = -1 + last_utime = -1 + + while True: + + magic = f_bin.read(8) + if magic == b'': + break + + if not verify_magic(magic): + print('Could not verify magic') + + num_hits = struct.unpack(' 0.1: + break + + current_bytes += 24 + 8 * num_hits + + print('{:d}/{:d} => {:.1f}%'.format(current_bytes, num_bytes, 100 * current_bytes / num_bytes)) + + all_utimes = np.hstack(all_utimes) + all_hits = np.vstack(all_hits) + all_ints = np.hstack(all_ints) + + write_ply('test_hits', + [all_hits, all_ints, all_utimes], + ['x', 'y', 'z', 'intensity', 'utime']) + + print("Read %d total hits from %ld to %ld" % (total_hits, first_utime, last_utime)) + + return 0 + + +def frames_to_ply(show_frames=False): + + # In files + data_path = '../../Data/NCLT' + velo_folder = 'velodyne_data' + + days = np.sort([d for d in listdir(join(data_path, velo_folder))]) + + for day in days: + + # Out files + ply_folder = join(data_path, 'frames_ply', day) + if not exists(ply_folder): + makedirs(ply_folder) + + day_path = join(data_path, velo_folder, day, 'velodyne_sync') + f_names = np.sort([f for f in listdir(day_path) if f[-4:] == '.bin']) + + N = len(f_names) + print('Reading', N, 'files') + + for f_i, f_name in enumerate(f_names): + + ply_name = join(ply_folder, f_name[:-4] + '.ply') + if exists(ply_name): + continue + + + t1 = time.time() + + hits = [] + ints = [] + + with open(join(day_path, f_name), 'rb') as f_bin: + + while True: + x_str = f_bin.read(2) + + # End of file + if x_str == b'': + break + + x = struct.unpack(' np.min(day_gt_t), t_cov < np.max(day_gt_t)) + t_cov = t_cov[t_cov_bool] + + # Note: Interpolation is not needed, this is done as a convinience + interp = scipy.interpolate.interp1d(day_gt_t, day_gt_H[:, :3, 3], kind='nearest', axis=0) + node_poses = interp(t_cov) + + plt.figure() + plt.scatter(day_gt_H[:, 1, 3], day_gt_H[:, 0, 3], 1, c=-day_gt_H[:, 2, 3], linewidth=0) + plt.scatter(node_poses[:, 1], node_poses[:, 0], 1, c=-node_poses[:, 2], linewidth=5) + plt.axis('equal') + plt.title('Ground Truth Position of Nodes in SLAM Graph') + plt.xlabel('East (m)') + plt.ylabel('North (m)') + plt.colorbar() + + plt.show() + + t2 = time.time() + print('Done in {:.1f}s\n'.format(t2 - t0)) + + # Out files + out_folder = join(data_path, 'day_ply') + if not exists(out_folder): + makedirs(out_folder) + + # Focus on a particular point + p0 = np.array([-220, -527, 12]) + center_radius = 10.0 + point_radius = 50.0 + + # Loop on days + for d, day in enumerate(days): + + #if day != '2012-02-05': + # continue + day_min_t = gt_t[d][0] + day_max_t = gt_t[d][-1] + + frames_folder = join(data_path, 'frames_ply', day) + f_times = np.sort([float(f[:-4]) for f in listdir(frames_folder) if f[-4:] == '.ply']) + + # If we want, load only SLAM nodes + if only_SLAM_nodes: + + # Load node timestamps + cov_csv = join(data_path, cov_folder, cov_files[d]) + cov = np.loadtxt(cov_csv, delimiter=',') + t_cov = cov[:, 0] + t_cov_bool = np.logical_and(t_cov > day_min_t, t_cov < day_max_t) + t_cov = t_cov[t_cov_bool] + + # Find closest lidar frames + t_cov = np.expand_dims(t_cov, 1) + diffs = np.abs(t_cov - f_times) + inds = np.argmin(diffs, axis=1) + f_times = f_times[inds] + + # Is this frame in gt + f_t_bool = np.logical_and(f_times > day_min_t, f_times < day_max_t) + f_times = f_times[f_t_bool] + + # Interpolation gt poses to frame timestamps + interp = scipy.interpolate.interp1d(gt_t[d], gt_H[d], kind='nearest', axis=0) + frame_poses = interp(f_times) + + N = len(f_times) + world_points = [] + world_frames = [] + world_frames_c = [] + print('Reading', day, ' => ', N, 'files') + for f_i, f_t in enumerate(f_times): + + t1 = time.time() + + ######### + # GT pose + ######### + + H = frame_poses[f_i].astype(np.float32) + # s = '\n' + # for cc in H: + # for c in cc: + # s += '{:5.2f} '.format(c) + # s += '\n' + # print(s) + + ############# + # Focus check + ############# + + if np.linalg.norm(H[:3, 3] - p0) > center_radius: + continue + + ################################### + # Local frame coordinates for debug + ################################### + + # Create artificial frames + x = np.linspace(0, 1, 50, dtype=np.float32) + points = np.hstack((np.vstack((x, x*0, x*0)), np.vstack((x*0, x, x*0)), np.vstack((x*0, x*0, x)))).T + colors = ((points > 0.1).astype(np.float32) * 255).astype(np.uint8) + + hpoints = np.hstack((points, np.ones_like(points[:, :1]))) + hpoints = np.matmul(hpoints, H.T) + hpoints[:, 3] *= 0 + world_frames += [hpoints[:, :3]] + world_frames_c += [colors] + + ####################### + # Load velo point cloud + ####################### + + # Load frame ply file + f_name = '{:.0f}.ply'.format(f_t) + data = read_ply(join(frames_folder, f_name)) + points = np.vstack((data['x'], data['y'], data['z'])).T + #intensity = data['intensity'] + + hpoints = np.hstack((points, np.ones_like(points[:, :1]))) + hpoints = np.matmul(hpoints, H.T) + hpoints[:, 3] *= 0 + hpoints[:, 3] += np.sqrt(f_t - f_times[0]) + + # focus check + focus_bool = np.linalg.norm(hpoints[:, :3] - p0, axis=1) < point_radius + hpoints = hpoints[focus_bool, :] + + world_points += [hpoints] + + t2 = time.time() + print('File {:s} {:d}/{:d} Done in {:.1f}s'.format(f_name, f_i, N, t2 - t1)) + + if len(world_points) < 2: + continue + + world_points = np.vstack(world_points) + + + ###### DEBUG + world_frames = np.vstack(world_frames) + world_frames_c = np.vstack(world_frames_c) + write_ply('testf.ply', + [world_frames, world_frames_c], + ['x', 'y', 'z', 'red', 'green', 'blue']) + ###### DEBUG + + print(world_points.shape, world_points.dtype) + + # Subsample merged frames + # world_points, features = grid_subsampling(world_points[:, :3], + # features=world_points[:, 3:], + # sampleDl=0.1) + features = world_points[:, 3:] + world_points = world_points[:, :3] + + print(world_points.shape, world_points.dtype) + + write_ply('test' + day + '.ply', + [world_points, features], + ['x', 'y', 'z', 't']) + + + # Generate gt annotations + + # Subsample day ply (for visualization) + + # Save day ply + + # a = 1/0 diff --git a/datasets/S3DIS.py b/datasets/S3DIS.py index f5deb73..adbe8a7 100644 --- a/datasets/S3DIS.py +++ b/datasets/S3DIS.py @@ -131,7 +131,7 @@ class S3DISDataset(PointCloudDataset): # Prepare ply files ################### - self.prepare_S3DIS_ply() + #self.prepare_S3DIS_ply() ################ # Load ply files @@ -1037,7 +1037,7 @@ class S3DISSampler(Sampler): if breaking: break - def calibration(self, dataloader, untouched_ratio=0.9, verbose=False): + def calibration(self, dataloader, untouched_ratio=0.9, verbose=False, force_redo=False): """ Method performing batch and neighbors calibration. Batch calibration: Set "batch_limit" (the maximum number of points allowed in every batch) so that the @@ -1053,7 +1053,7 @@ class S3DISSampler(Sampler): print('\nStarting Calibration (use verbose=True for more details)') t0 = time.time() - redo = False + redo = force_redo # Batch limit # *********** @@ -1075,7 +1075,7 @@ class S3DISSampler(Sampler): self.dataset.config.in_radius, self.dataset.config.first_subsampling_dl, self.dataset.config.batch_num) - if key in batch_lim_dict: + if not redo and key in batch_lim_dict: self.dataset.batch_limit[0] = batch_lim_dict[key] else: redo = True @@ -1116,7 +1116,7 @@ class S3DISSampler(Sampler): if key in neighb_lim_dict: neighb_limits += [neighb_lim_dict[key]] - if len(neighb_limits) == self.dataset.config.num_layers: + if not redo and len(neighb_limits) == self.dataset.config.num_layers: self.dataset.neighborhood_limits = neighb_limits else: redo = True diff --git a/datasets/SemanticKitti.py b/datasets/SemanticKitti.py index 94c677f..75c493d 100644 --- a/datasets/SemanticKitti.py +++ b/datasets/SemanticKitti.py @@ -1115,7 +1115,7 @@ class SemanticKittiSampler(Sampler): # Perform calibration ##################### - self.dataset.batch_limit = self.dataset.max_in_p * (self.dataset.batch_num - 1) + #self.dataset.batch_limit[0] = self.dataset.max_in_p * (self.dataset.batch_num - 1) for epoch in range(10): for batch_i, batch in enumerate(dataloader): @@ -1145,7 +1145,7 @@ class SemanticKittiSampler(Sampler): smooth_errors = smooth_errors[1:] # Update batch limit with P controller - self.dataset.batch_limit += Kp * error + self.dataset.batch_limit[0] += Kp * error # finer low pass filter when closing in if not finer and np.abs(estim_b - target_b) < 1: @@ -1166,7 +1166,7 @@ class SemanticKittiSampler(Sampler): message = 'Step {:5d} estim_b ={:5.2f} batch_limit ={:7d}' print(message.format(i, estim_b, - int(self.dataset.batch_limit))) + int(self.dataset.batch_limit[0]))) if breaking: break @@ -1224,7 +1224,7 @@ class SemanticKittiSampler(Sampler): self.dataset.config.first_subsampling_dl, self.dataset.batch_num, self.dataset.max_in_p) - batch_lim_dict[key] = float(self.dataset.batch_limit) + batch_lim_dict[key] = float(self.dataset.batch_limit[0]) with open(batch_lim_file, 'wb') as file: pickle.dump(batch_lim_dict, file) diff --git a/datasets/common.py b/datasets/common.py index 6d771ab..f5f8a5e 100644 --- a/datasets/common.py +++ b/datasets/common.py @@ -228,7 +228,7 @@ class PointCloudDataset(Dataset): # Add random symmetries to the scale factor symmetries = np.array(self.config.augment_symmetries).astype(np.int32) symmetries *= np.random.randint(2, size=points.shape[1]) - scale = (scale * symmetries * 2 - 1).astype(np.float32) + scale = (scale * (1 - symmetries * 2)).astype(np.float32) ####### # Noise diff --git a/models/architectures.py b/models/architectures.py index 696cb94..42d98e9 100644 --- a/models/architectures.py +++ b/models/architectures.py @@ -183,7 +183,7 @@ class KPCNN(nn.Module): other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach() distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2)) rep_loss = torch.sum(torch.clamp_max(distances - 1.5, max=0.0) ** 2, dim=1) - repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) + repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K @@ -218,7 +218,7 @@ class KPFCNN(nn.Module): ##################### # Save all block operations in a list of modules - self.encoder_blocs = nn.ModuleList() + self.encoder_blocks = nn.ModuleList() self.encoder_skip_dims = [] self.encoder_skips = [] @@ -239,7 +239,7 @@ class KPFCNN(nn.Module): break # Apply the good block function defining tf ops - self.encoder_blocs.append(block_decider(block, + self.encoder_blocks.append(block_decider(block, r, in_dim, out_dim, @@ -264,7 +264,7 @@ class KPFCNN(nn.Module): ##################### # Save all block operations in a list of modules - self.decoder_blocs = nn.ModuleList() + self.decoder_blocks = nn.ModuleList() self.decoder_concats = [] # Find first upsampling block @@ -283,7 +283,7 @@ class KPFCNN(nn.Module): self.decoder_concats.append(block_i) # Apply the good block function defining tf ops - self.decoder_blocs.append(block_decider(block, + self.decoder_blocks.append(block_decider(block, r, in_dim, out_dim, @@ -331,12 +331,12 @@ class KPFCNN(nn.Module): # Loop over consecutive blocks skip_x = [] - for block_i, block_op in enumerate(self.encoder_blocs): + for block_i, block_op in enumerate(self.encoder_blocks): if block_i in self.encoder_skips: skip_x.append(x) x = block_op(x, batch) - for block_i, block_op in enumerate(self.decoder_blocs): + for block_i, block_op in enumerate(self.decoder_blocks): if block_i in self.decoder_concats: x = torch.cat([x, skip_x.pop()], dim=1) x = block_op(x, batch) @@ -434,9 +434,8 @@ class KPFCNN(nn.Module): other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach() distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2)) - rep_loss = torch.sum(torch.clamp_max(distances - 1.5, max=0.0) ** 2, dim=1) - repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) - + rep_loss = torch.sum(torch.clamp_max(distances - 0.5, max=0.0) ** 2, dim=1) + repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K return self.offset_decay * (fitting_loss + repulsive_loss) diff --git a/models/blocks.py b/models/blocks.py index e88d6db..f68f9d6 100644 --- a/models/blocks.py +++ b/models/blocks.py @@ -421,6 +421,7 @@ class BatchNormBlock(nn.Module): super(BatchNormBlock, self).__init__() self.bn_momentum = bn_momentum self.use_bn = use_bn + self.in_dim = in_dim if self.use_bn: self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum) #self.batch_norm = nn.InstanceNorm1d(in_dim, momentum=bn_momentum) @@ -442,6 +443,11 @@ class BatchNormBlock(nn.Module): else: return x + self.bias + def __repr__(self): + return 'BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})'.format(self.in_dim, + self.bn_momentum, + str(not self.use_bn)) + class UnaryBlock(nn.Module): @@ -458,6 +464,8 @@ class UnaryBlock(nn.Module): self.bn_momentum = bn_momentum self.use_bn = use_bn self.no_relu = no_relu + self.in_dim = in_dim + self.out_dim = out_dim self.mlp = nn.Linear(in_dim, out_dim, bias=False) self.batch_norm = BatchNormBlock(out_dim, self.use_bn, self.bn_momentum) if not no_relu: @@ -471,6 +479,12 @@ class UnaryBlock(nn.Module): x = self.leaky_relu(x) return x + def __repr__(self): + return 'UnaryBlock(in_feat: {:d}, out_feat: {:d}, BN: {:s}, ReLU: {:s})'.format(self.in_dim, + self.out_dim, + str(self.use_bn), + str(not self.no_relu)) + class SimpleBlock(nn.Module): @@ -492,6 +506,8 @@ class SimpleBlock(nn.Module): self.use_bn = config.use_batch_norm self.layer_ind = layer_ind self.block_name = block_name + self.in_dim = in_dim + self.out_dim = out_dim # Define the KPConv class self.KPConv = KPConv(config.num_kernel_points, @@ -547,6 +563,8 @@ class ResnetBottleneckBlock(nn.Module): self.use_bn = config.use_batch_norm self.block_name = block_name self.layer_ind = layer_ind + self.in_dim = in_dim + self.out_dim = out_dim # First downscaling mlp if in_dim != out_dim // 4: @@ -639,6 +657,10 @@ class NearestUpsampleBlock(nn.Module): def forward(self, x, batch): return closest_pool(x, batch.upsamples[self.layer_ind - 1]) + def __repr__(self): + return 'NearestUpsampleBlock(layer: {:d} -> {:d})'.format(self.layer_ind, + self.layer_ind - 1) + class MaxPoolBlock(nn.Module): diff --git a/plot_convergence.py b/plot_convergence.py index 4c0ac6e..4370ca9 100644 --- a/plot_convergence.py +++ b/plot_convergence.py @@ -1445,12 +1445,14 @@ def S3DIS_go(old_result_limit): def SemanticKittiFirst(old_result_limit): """ - Test SematicKitti. First exps + Test SematicKitti. First exps. + Try some class weight strategies. It seems that the final score is not impacted so much. With weights, some classes + are better while other are worse, for a final score that remains the same. """ # Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset. start = 'Log_2020-04-07_15-30-17' - end = 'Log_2020-05-07_15-30-17' + end = 'Log_2020-04-11_21-34-16' if end < old_result_limit: res_path = 'old_results' @@ -1464,8 +1466,43 @@ def SemanticKittiFirst(old_result_limit): logs_names = ['R=5.0_dl=0.04', 'R=5.0_dl=0.08', 'R=10.0_dl=0.08', - 'R=10.0_dl=0.08_weigths', - 'R=10.0_dl=0.08_sqrt_weigths', + 'R=10.0_dl=0.08_20*weigths', + 'R=10.0_dl=0.08_20*sqrt_weigths', + 'R=10.0_dl=0.08_100*sqrt_w', + 'R=10.0_dl=0.08_100*sqrt_w_capped', + 'R=10.0_dl=0.08_no_w'] + + logs_names = np.array(logs_names[:len(logs)]) + + return logs, logs_names + + +def SemanticKitti_scale(old_result_limit): + """ + Test SematicKitti. Try different scales of input raduis / subsampling. + """ + + # Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset. + start = 'Log_2020-04-11_21-34-15' + end = 'Log_2020-04-20_11-52-58' + + if end < old_result_limit: + res_path = 'old_results' + else: + res_path = 'results' + + logs = np.sort([join(res_path, l) for l in listdir(res_path) if start <= l <= end]) + logs = logs.astype(' 1: + config.saving_path = sys.argv[1] + + # Initialize datasets + training_dataset = NCLTDataset(config, set='training', + balance_classes=True) + test_dataset = NCLTDataset(config, set='validation', + balance_classes=False) + + # Initialize samplers + training_sampler = NCLTSampler(training_dataset) + test_sampler = NCLTSampler(test_dataset) + + # Initialize the dataloader + training_loader = DataLoader(training_dataset, + batch_size=1, + sampler=training_sampler, + collate_fn=NCLTCollate, + num_workers=config.input_threads, + pin_memory=True) + test_loader = DataLoader(test_dataset, + batch_size=1, + sampler=test_sampler, + collate_fn=NCLTCollate, + num_workers=config.input_threads, + pin_memory=True) + + # Calibrate max_in_point value + training_sampler.calib_max_in(config, training_loader, verbose=True) + test_sampler.calib_max_in(config, test_loader, verbose=True) + + # Calibrate samplers + training_sampler.calibration(training_loader, verbose=True) + test_sampler.calibration(test_loader, verbose=True) + + # debug_timing(training_dataset, training_loader) + # debug_timing(test_dataset, test_loader) + # debug_class_w(training_dataset, training_loader) + + print('\nModel Preparation') + print('*****************') + + # Define network model + t1 = time.time() + net = KPFCNN(config, training_dataset.label_values, training_dataset.ignored_labels) + + debug = False + if debug: + print('\n*************************************\n') + print(net) + print('\n*************************************\n') + for param in net.parameters(): + if param.requires_grad: + print(param.shape) + print('\n*************************************\n') + print("Model size %i" % sum(param.numel() for param in net.parameters() if param.requires_grad)) + print('\n*************************************\n') + + # Define a trainer class + trainer = ModelTrainer(net, config, chkp_path=chosen_chkp) + print('Done in {:.1f}s\n'.format(time.time() - t1)) + + print('\nStart training') + print('**************') + + # Training + trainer.train(net, training_loader, test_loader, config) + + print('Forcing exit now') + os.kill(os.getpid(), signal.SIGINT) diff --git a/train_S3DIS.py b/train_S3DIS.py index 2494348..471607f 100644 --- a/train_S3DIS.py +++ b/train_S3DIS.py @@ -74,22 +74,15 @@ class S3DISConfig(Config): 'resnetb_strided', 'resnetb', 'resnetb', - 'resnetb', 'resnetb_strided', 'resnetb', 'resnetb', - 'resnetb', - 'resnetb', - 'resnetb', 'resnetb_strided', - 'resnetb', - 'resnetb', - 'resnetb', - 'resnetb', - 'resnetb', - 'resnetb_strided', - 'resnetb', - 'resnetb', + 'resnetb_deformable', + 'resnetb_deformable', + 'resnetb_deformable_strided', + 'resnetb_deformable', + 'resnetb_deformable', 'nearest_upsample', 'unary', 'nearest_upsample', @@ -104,7 +97,7 @@ class S3DISConfig(Config): ################### # Radius of the input sphere - in_radius = 1.0 + in_radius = 1.5 # Number of kernel points num_kernel_points = 15 @@ -142,7 +135,7 @@ class S3DISConfig(Config): # 'permissive' only constrains offsets inside the deform radius (NOT implemented yet) # 'fitting' helps deformed kernels to adapt to the geometry by penalizing distance to input points offsets_loss = 'fitting' - offsets_decay = 0.01 + offsets_decay = 0.05 ##################### # Training parameters @@ -158,7 +151,7 @@ class S3DISConfig(Config): grad_clip_norm = 100.0 # Number of batch - batch_num = 8 + batch_num = 6 # Number of steps per epochs epoch_steps = 500 diff --git a/train_SemanticKitti.py b/train_SemanticKitti.py index 61f5414..7d4857f 100644 --- a/train_SemanticKitti.py +++ b/train_SemanticKitti.py @@ -100,21 +100,21 @@ class SemanticKittiConfig(Config): ################### # Radius of the input sphere - in_radius = 10.0 + in_radius = 6.0 val_radius = 51.0 n_frames = 1 max_in_points = 100000 - max_val_points = 100000 + max_val_points = 200000 # Number of batch - batch_num = 10 + batch_num = 8 val_batch_num = 1 # Number of kernel points num_kernel_points = 15 # Size of the first subsampling grid in meter - first_subsampling_dl = 0.08 + first_subsampling_dl = 0.06 # Radius of convolution in "number grid cell". (2.5 is the standard value) conv_radius = 2.5 @@ -133,7 +133,7 @@ class SemanticKittiConfig(Config): # Choice of input features first_features_dim = 128 - in_features_dim = 5 + in_features_dim = 2 # Can the network learn modulations modulated = False @@ -158,7 +158,7 @@ class SemanticKittiConfig(Config): # Learning rate management learning_rate = 1e-2 momentum = 0.98 - lr_decays = {i: 0.1 ** (1 / 100) for i in range(1, max_epoch)} + lr_decays = {i: 0.1 ** (1 / 150) for i in range(1, max_epoch)} grad_clip_norm = 100.0 # Number of steps per epochs @@ -190,8 +190,8 @@ class SemanticKittiConfig(Config): # 2.377, 0.886, 3.863, 0.869, 1.209, 0.594, 3.780, 1.129, 5.505, 11.180] # sqrt(Inverse of proportion * 100) capped (0.5 < X < 5) - class_w = [1.430, 5.000, 5.000, 4.226, 5.000, 5.000, 5.000, 5.000, 0.719, 2.377, - 0.886, 3.863, 0.869, 1.209, 0.594, 3.780, 1.129, 5.000, 5.000] + # class_w = [1.430, 5.000, 5.000, 4.226, 5.000, 5.000, 5.000, 5.000, 0.719, 2.377, + # 0.886, 3.863, 0.869, 1.209, 0.594, 3.780, 1.129, 5.000, 5.000] # Do we nee to save convergence @@ -212,7 +212,7 @@ if __name__ == '__main__': ############################ # Set which gpu is going to be used - GPU_ID = '3' + GPU_ID = '2' # Set GPU visible device os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID diff --git a/utils/tester.py b/utils/tester.py index ca13d2e..333625b 100644 --- a/utils/tester.py +++ b/utils/tester.py @@ -181,7 +181,7 @@ class ModelTester: inds = in_inds[i0:i0 + length] c_i = cloud_inds[b_i] - if test_radius_ratio < 0.99: + if 0 < test_radius_ratio < 1: mask = np.sum(points ** 2, axis=1) < (test_radius_ratio * config.in_radius) ** 2 inds = inds[mask] probs = probs[mask] diff --git a/utils/trainer.py b/utils/trainer.py index fff958e..1223d11 100644 --- a/utils/trainer.py +++ b/utils/trainer.py @@ -259,7 +259,7 @@ class ModelTrainer: # Save checkpoints occasionally if (self.epoch + 1) % config.checkpoint_gap == 0: - checkpoint_path = join(checkpoint_directory, 'chkp_{:04d}.tar'.format(self.epoch)) + checkpoint_path = join(checkpoint_directory, 'chkp_{:04d}.tar'.format(self.epoch + 1)) torch.save(save_dict, checkpoint_path) # Validation diff --git a/utils/visualizer.py b/utils/visualizer.py index 566da39..ef8b16a 100644 --- a/utils/visualizer.py +++ b/utils/visualizer.py @@ -82,7 +82,7 @@ class ModelVisualizer: net.load_state_dict(checkpoint['model_state_dict']) self.epoch = checkpoint['epoch'] net.eval() - print("Model and training state restored.") + print("\nModel state restored from {:s}.".format(chkp_path)) return @@ -679,139 +679,63 @@ class ModelVisualizer: except tf.errors.OutOfRangeError: break - def show_effective_recep_field(self, model, dataset, relu_idx=0): + def show_effective_recep_field(self, net, loader, config, f_idx=0): - ################################################### - # First add a modulation variable on input features - ################################################### + ########################################## + # First choose the visualized deformations + ########################################## - # Tensorflow random seed - random_seed = 42 + blocks = {} - # Create a modulated input feature op - with tf.variable_scope('input_modulations'): - initial = tf.constant(0., shape=[200000, 1]) - input_modulations_var = tf.Variable(initial, name='alphas') - input_modulations = 2 * tf.sigmoid(input_modulations_var) - assert_op = tf.assert_less(tf.shape(model.inputs['features'])[0], tf.shape(input_modulations)[0]) - with tf.control_dependencies([assert_op]): - modulated_input = model.inputs['features'] * input_modulations[:tf.shape(model.inputs['features'])[0]] - modulated_input = tf.identity(modulated_input, name='modulated_features') + named_blocks = [(m_name, m) for m_name, m in net.named_modules() + if len(m_name.split('.')) == 2 and m_name.split('.')[0].endswith('_blocks')] + chosen_block = named_blocks[-1][0] - print('*******************************************') - - # Swap the op with the normal input features - for op in tf.get_default_graph().get_operations(): - - if 'input_modulations' in op.name: - continue - - if model.inputs['features'].name in [in_t.name for in_t in op.inputs]: - input_list = [] - for in_t in op.inputs: - if in_t.name == model.inputs['features'].name: - input_list += [modulated_input] - else: - input_list += [in_t] - print('swapping op ', op.name) - print('old inputs ', [in_t.name for in_t in op.inputs]) - print('new inputs ', [in_t.name for in_t in input_list]) - ge.swap_inputs(op, input_list) - - print('*******************************************') - - ########################## - # Create the ERF optimizer - ########################## - - # This optimizer only computes gradients for the feature modulation variables. We set the ERF loss, which - # consists of modifying the features in one location a the wanted layer - - with tf.variable_scope('ERF_loss'): - - # List all relu ops - all_ops = [op for op in tf.get_default_graph().get_operations() if op.name.startswith('KernelPointNetwork') - and op.name.endswith('LeakyRelu')] - - # Print the chosen one - features_tensor = all_ops[relu_idx].outputs[0] - - # Get parameters - layer_idx = int(features_tensor.name.split('/')[1][6:]) - if 'strided' in all_ops[relu_idx].name and not ('strided' in all_ops[relu_idx + 1].name): - layer_idx += 1 - features_dim = int(features_tensor.shape[1]) - radius = model.config.first_subsampling_dl * model.config.density_parameter * (2 ** layer_idx) - - print('You chose to visualize the output of operation named: ' + all_ops[relu_idx].name) - print('It contains {:d} features.'.format(int(features_tensor.shape[1]))) - - print('\nPossible Relu indices:') - for i, t in enumerate(all_ops): - print(i, ': ', t.name) - - print('\n****************************************************************************') - - # Get the receptive field of a random point - N = tf.shape(features_tensor)[0] - #random_ind = tf.random_uniform([1], minval=0, maxval=N, dtype=np.int32, seed=random_seed)[0] - #chosen_i_holder = tf.placeholder(tf.int32, name='chosen_ind') - aimed_coordinates = tf.placeholder(tf.float32, shape=(1, 3), name='aimed_coordinates') - d2 = tf.reduce_sum(tf.square(model.inputs['points'][layer_idx] - aimed_coordinates), axis=1) - chosen_i_tf = tf.argmin(d2, output_type=tf.int32) - - #test1 = tf.multiply(features_tensor, 2.0, name='test1') - #test2 = tf.multiply(features_tensor, 2.0, name='test2') - - # Gradient scaling operation - @tf.custom_gradient - def scale_grad_layer(x): - def scaled_grad(dy): - p_op = tf.print(x.name, - tf.reduce_mean(tf.abs(x)), - tf.reduce_mean(tf.abs(dy)), - output_stream=sys.stdout) - with tf.control_dependencies([p_op]): - new_dy = 1.0 * dy - return new_dy - return tf.identity(x), scaled_grad - - #test2 = scale_grad_layer(test2) - - # Get the tensor of error for these features (one for the chosen point, zero for the rest) - chosen_f_tf = tf.placeholder(tf.int32, name='feature_ind') - ERF_error = tf.expand_dims(tf.cast(tf.equal(tf.range(N), chosen_i_tf), tf.float32), 1) - ERF_error *= tf.expand_dims(tf.cast(tf.equal(tf.range(features_dim), chosen_f_tf), tf.float32), 0) - - # Get objective for the features (with a stop gradient so that we can get a gradient on the loss) - objective_features = features_tensor + ERF_error - objective_features = tf.stop_gradient(objective_features) - - # Loss is the error but with the features that can be learned to correct it - ERF_loss = tf.reduce_sum(tf.square(objective_features - features_tensor)) + for mi, (m_name, m) in enumerate(named_blocks): - with tf.variable_scope('ERF_optimizer'): + c1 = bcolors.OKBLUE + c2 = bcolors.BOLD + ce = bcolors.ENDC + print('{:}{:}{:s}{:}{:} {:s}'.format(c1, c2, m_name, ce, ce, m.__repr__())) + blocks[m_name] = m - # Create the gradient descent optimizer with a dummy learning rate - optimizer = tf.train.GradientDescentOptimizer(1.0) + if mi == f_idx: + chosen_block = m_name - # Get the gradients with respect to the modulation variable - ERF_var_grads = optimizer.compute_gradients(ERF_loss, var_list=[input_modulations_var]) + print('\nChoose which block output you want to visualize by entering the block name in blue') + override_block = input('Block name: ') + + if len(override_block) > 0: + chosen_block = override_block + print('{:}{:}{:s}{:}{:} {:s}'.format(c1, c2, chosen_block, ce, ce, blocks[chosen_block].__repr__())) + features_dim = blocks[chosen_block].out_dim + + # Fix all the trainable variables in the network (is it needed in eval mode?) + print('\n*************************************\n') + for p_name, param in net.named_parameters(): + if param.requires_grad: + param.requires_grad = False + print('\n*************************************\n') + + # Create modulation variable that requires grad + input_modulations = torch.nn.Parameter(torch.zeros((200000, 1), + dtype=torch.float32), + requires_grad=True) + + print('\n*************************************\n') + for p_name, param in net.named_parameters(): + if param.requires_grad: + print(p_name, param.shape) + print('\n*************************************\n') + + # Create ERF loss + + # Create ERF optimizer - # Gradient of the modulations - ERF_train_op = optimizer.apply_gradients(ERF_var_grads) - ################################ - # Run model on all test examples - ################################ - # Init our modulation variable - self.sess.run(tf.variables_initializer([input_modulations_var])) - # Initialise iterator with test data - self.sess.run(dataset.test_init_op) - count = 0 global plots, p_scale, show_in_p, remove_h, aim_point aim_point = np.zeros((1, 3), dtype=np.float32) @@ -841,10 +765,11 @@ class ModelVisualizer: global points, in_points, grad_values, chosen_point, aim_point, in_colors # Generate clouds until we effectively changed + batch = None if only_points: - for i in range(50): - all_points = self.sess.run(model.inputs['points']) - if all_points[0].shape[0] != in_points.shape[0]: + # get a new batch (index does not matter given our input pipeline) + for batch in loader: + if batch.points[0].shape[0] != in_points.shape[0]: break sum_grads = 0 @@ -853,11 +778,65 @@ class ModelVisualizer: else: num_tries = 10 + ################################################# + # Apply ERF optim to the same batch several times + ################################################# + + if 'cuda' in self.device.type: + batch.to(self.device) + + + for test_i in range(num_tries): print('Updating ERF {:.0f}%'.format((test_i + 1) * 100 / num_tries)) rand_f_i = np.random.randint(features_dim) + # Reset input modulation variable + torch.nn.init.zeros_(input_modulations) + + reset_op = input_modulations_var.assign(tf.zeros_like(input_modulations_var)) + self.sess.run(reset_op) + + # zero the parameter gradients + ERF_optimizer.zero_grad() + + # Forward pass + outputs = net(batch, config) + + loss = net.ERF_loss(outputs) + + # Backward + loss.backward() + + # Get result from hook here? + + ERF_optimizer.step() + torch.cuda.synchronize(self.device) + + + + + + + # Forward pass + outputs = net(batch, config) + original_KP = deform_convs[deform_idx].kernel_points.cpu().detach().numpy() + stacked_deformed_KP = deform_convs[deform_idx].deformed_KP.cpu().detach().numpy() + count += batch.lengths[0].shape[0] + + if 'cuda' in self.device.type: + torch.cuda.synchronize(self.device) + + + + + + + + + + # Reset input modulation variable reset_op = input_modulations_var.assign(tf.zeros_like(input_modulations_var)) self.sess.run(reset_op) @@ -1069,6 +1048,8 @@ class ModelVisualizer: fig1.scene.interactor.add_observer('KeyPressEvent', keyboard_callback) mlab.show() + return + def show_deformable_kernels(self, net, loader, config, deform_idx=0): """ Show some inference with deformable kernels diff --git a/visualize_ERFs.py b/visualize_ERFs.py new file mode 100644 index 0000000..bd908a3 --- /dev/null +++ b/visualize_ERFs.py @@ -0,0 +1,205 @@ +# +# +# 0=================================0 +# | Kernel Point Convolutions | +# 0=================================0 +# +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Callable script to start a training on ModelNet40 dataset +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Hugues THOMAS - 06/03/2020 +# + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Imports and global variables +# \**********************************/ +# + +# Common libs +import signal +import os +import numpy as np +import sys +import torch + +# Dataset +from datasets.ModelNet40 import * +from datasets.S3DIS import * +from torch.utils.data import DataLoader + +from utils.config import Config +from utils.visualizer import ModelVisualizer +from models.architectures import KPCNN, KPFCNN + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Main Call +# \***************/ +# + +def model_choice(chosen_log): + + ########################### + # Call the test initializer + ########################### + + # Automatically retrieve the last trained model + if chosen_log in ['last_ModelNet40', 'last_ShapeNetPart', 'last_S3DIS']: + + # Dataset name + test_dataset = '_'.join(chosen_log.split('_')[1:]) + + # List all training logs + logs = np.sort([os.path.join('results', f) for f in os.listdir('results') if f.startswith('Log')]) + + # Find the last log of asked dataset + for log in logs[::-1]: + log_config = Config() + log_config.load(log) + if log_config.dataset.startswith(test_dataset): + chosen_log = log + break + + if chosen_log in ['last_ModelNet40', 'last_ShapeNetPart', 'last_S3DIS']: + raise ValueError('No log of the dataset "' + test_dataset + '" found') + + # Check if log exists + if not os.path.exists(chosen_log): + raise ValueError('The given log does not exists: ' + chosen_log) + + return chosen_log + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Main Call +# \***************/ +# + +if __name__ == '__main__': + + ############################### + # Choose the model to visualize + ############################### + + # Here you can choose which model you want to test with the variable test_model. Here are the possible values : + # + # > 'last_XXX': Automatically retrieve the last trained model on dataset XXX + # > '(old_)results/Log_YYYY-MM-DD_HH-MM-SS': Directly provide the path of a trained model + + # chosen_log = 'results/Log_2020-04-04_10-04-42' # => ModelNet40 + # chosen_log = 'results/Log_2020-04-04_10-04-42' # => S3DIS + chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected + + # You can also choose the index of the snapshot to load (last by default) + chkp_idx = -1 + + # Eventually you can choose which feature is visualized (index of the deform operation in the network) + f_idx = -1 + + # Deal with 'last_XXX' choices + chosen_log = model_choice(chosen_log) + + ############################ + # Initialize the environment + ############################ + + # Set which gpu is going to be used + GPU_ID = '0' + + # Set GPU visible device + os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID + + ############### + # Previous chkp + ############### + + # Find all checkpoints in the chosen training folder + chkp_path = os.path.join(chosen_log, 'checkpoints') + chkps = [f for f in os.listdir(chkp_path) if f[:4] == 'chkp'] + + # Find which snapshot to restore + if chkp_idx is None: + chosen_chkp = 'current_chkp.tar' + else: + chosen_chkp = np.sort(chkps)[chkp_idx] + chosen_chkp = os.path.join(chosen_log, 'checkpoints', chosen_chkp) + + # Initialize configuration class + config = Config() + config.load(chosen_log) + + ################################## + # Change model parameters for test + ################################## + + # Change parameters for the test here. For example, you can stop augmenting the input data. + + config.augment_noise = 0.0001 + #config.augment_symmetries = False + config.batch_num = 1 + config.in_radius = 2.0 + config.input_threads = 0 + + ############## + # Prepare Data + ############## + + print() + print('Data Preparation') + print('****************') + + # Initiate dataset + if config.dataset.startswith('ModelNet40'): + test_dataset = ModelNet40Dataset(config, train=False) + test_sampler = ModelNet40Sampler(test_dataset) + collate_fn = ModelNet40Collate + elif config.dataset == 'S3DIS': + test_dataset = S3DISDataset(config, set='validation', use_potentials=True) + test_sampler = S3DISSampler(test_dataset) + collate_fn = S3DISCollate + else: + raise ValueError('Unsupported dataset : ' + config.dataset) + + # Data loader + test_loader = DataLoader(test_dataset, + batch_size=1, + sampler=test_sampler, + collate_fn=collate_fn, + num_workers=config.input_threads, + pin_memory=True) + + # Calibrate samplers + test_sampler.calibration(test_loader, verbose=True) + + print('\nModel Preparation') + print('*****************') + + # Define network model + t1 = time.time() + if config.dataset_task == 'classification': + net = KPCNN(config) + elif config.dataset_task in ['cloud_segmentation', 'slam_segmentation']: + net = KPFCNN(config, test_dataset.label_values, test_dataset.ignored_labels) + else: + raise ValueError('Unsupported dataset_task for deformation visu: ' + config.dataset_task) + + # Define a visualizer class + visualizer = ModelVisualizer(net, config, chkp_path=chosen_chkp, on_gpu=False) + print('Done in {:.1f}s\n'.format(time.time() - t1)) + + print('\nStart visualization') + print('*******************') + + # Training + visualizer.show_effective_recep_field(net, test_loader, config, f_idx) + + + diff --git a/visualize_deformations.py b/visualize_deformations.py index 1d9bf90..cb2ad0c 100644 --- a/visualize_deformations.py +++ b/visualize_deformations.py @@ -30,11 +30,12 @@ import torch # Dataset from datasets.ModelNet40 import * +from datasets.S3DIS import * from torch.utils.data import DataLoader from utils.config import Config from utils.visualizer import ModelVisualizer -from models.architectures import KPCNN +from models.architectures import KPCNN, KPFCNN # ---------------------------------------------------------------------------------------------------------------------- @@ -93,10 +94,12 @@ if __name__ == '__main__': # > 'last_XXX': Automatically retrieve the last trained model on dataset XXX # > '(old_)results/Log_YYYY-MM-DD_HH-MM-SS': Directly provide the path of a trained model - chosen_log = 'results/Log_2020-03-23_22-18-26' # => ModelNet40 + # chosen_log = 'results/Log_2020-04-04_10-04-42' # => ModelNet40 + # chosen_log = 'results/Log_2020-04-22_11-53-45' # => S3DIS + chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected # You can also choose the index of the snapshot to load (last by default) - chkp_idx = None + chkp_idx = -1 # Eventually you can choose which feature is visualized (index of the deform operation in the network) deform_idx = 0 @@ -139,10 +142,11 @@ if __name__ == '__main__': # Change parameters for the test here. For example, you can stop augmenting the input data. - #config.augment_noise = 0.0001 + config.augment_noise = 0.0001 #config.augment_symmetries = False - #config.batch_num = 3 - #config.in_radius = 4 + config.batch_num = 1 + config.in_radius = 2.0 + config.input_threads = 0 ############## # Prepare Data @@ -152,22 +156,28 @@ if __name__ == '__main__': print('Data Preparation') print('****************') - # Initialize datasets - test_dataset = ModelNet40Dataset(config, train=False) + # Initiate dataset + if config.dataset.startswith('ModelNet40'): + test_dataset = ModelNet40Dataset(config, train=False) + test_sampler = ModelNet40Sampler(test_dataset) + collate_fn = ModelNet40Collate + elif config.dataset == 'S3DIS': + test_dataset = S3DISDataset(config, set='validation', use_potentials=True) + test_sampler = S3DISSampler(test_dataset) + collate_fn = S3DISCollate + else: + raise ValueError('Unsupported dataset : ' + config.dataset) - # Initialize samplers - test_sampler = ModelNet40Sampler(test_dataset) - - # Initialize the dataloader + # Data loader test_loader = DataLoader(test_dataset, batch_size=1, sampler=test_sampler, - collate_fn=ModelNet40Collate, - num_workers=0, + collate_fn=collate_fn, + num_workers=config.input_threads, pin_memory=True) # Calibrate samplers - test_sampler.calibration(test_loader) + test_sampler.calibration(test_loader, verbose=True) print('\nModel Preparation') print('*****************') @@ -176,6 +186,8 @@ if __name__ == '__main__': t1 = time.time() if config.dataset_task == 'classification': net = KPCNN(config) + elif config.dataset_task in ['cloud_segmentation', 'slam_segmentation']: + net = KPFCNN(config, test_dataset.label_values, test_dataset.ignored_labels) else: raise ValueError('Unsupported dataset_task for deformation visu: ' + config.dataset_task)