diff --git a/datasets/S3DIS.py b/datasets/S3DIS.py index b12bb36..46818ad 100644 --- a/datasets/S3DIS.py +++ b/datasets/S3DIS.py @@ -52,7 +52,7 @@ from utils.config import bcolors class S3DISDataset(PointCloudDataset): - """Class to handle Modelnet 40 dataset.""" + """Class to handle S3DIS dataset.""" def __init__(self, config, set='training', use_potentials=True, load_data=True): """ @@ -138,7 +138,23 @@ class S3DISDataset(PointCloudDataset): ################ # List of training files - self.train_files = [join(ply_path, f + '.ply') for f in self.cloud_names] + self.files = [] + for i, f in enumerate(self.cloud_names): + if self.set == 'training': + if self.all_splits[i] != self.validation_split: + self.files += [join(ply_path, f + '.ply')] + elif self.set in ['validation', 'test', 'ERF']: + if self.all_splits[i] == self.validation_split: + self.files += [join(ply_path, f + '.ply')] + else: + raise ValueError('Unknown set for S3DIS data: ', self.set) + + if self.set == 'training': + self.cloud_names = [f for i, f in enumerate(self.cloud_names) + if self.all_splits[i] != self.validation_split] + elif self.set in ['validation', 'test', 'ERF']: + self.cloud_names = [f for i, f in enumerate(self.cloud_names) + if self.all_splits[i] == self.validation_split] if 0 < self.config.first_subsampling_dl <= 0.01: raise ValueError('subsampling_parameter too low (should be over 1 cm') @@ -149,7 +165,7 @@ class S3DISDataset(PointCloudDataset): self.input_labels = [] self.pot_trees = [] self.num_clouds = 0 - self.validation_proj = [] + self.test_proj = [] self.validation_labels = [] # Start loading @@ -624,21 +640,11 @@ class S3DISDataset(PointCloudDataset): # Load KDTrees ############## - for i, file_path in enumerate(self.train_files): + for i, file_path in enumerate(self.files): # Restart timer t0 = time.time() - # Skip split that is not in current set - if self.set == 'training': - if self.all_splits[i] == self.validation_split: - continue - elif self.set in ['validation', 'test', 'ERF']: - if self.all_splits[i] != self.validation_split: - continue - else: - raise ValueError('Unknown set for S3DIS data: ', self.set) - # Get cloud name cloud_name = self.cloud_names[i] @@ -714,17 +720,7 @@ class S3DISDataset(PointCloudDataset): pot_dl = self.config.in_radius / 10 cloud_ind = 0 - for i, file_path in enumerate(self.train_files): - - # Skip split that is not in current set - if self.set == 'training': - if self.all_splits[i] == self.validation_split: - continue - elif self.set in ['validation', 'test', 'ERF']: - if self.all_splits[i] != self.validation_split: - continue - else: - raise ValueError('Unknown set for S3DIS data: ', self.set) + for i, file_path in enumerate(self.files): # Get cloud name cloud_name = self.cloud_names[i] @@ -769,12 +765,7 @@ class S3DISDataset(PointCloudDataset): print('\nPreparing reprojection indices for testing') # Get validation/test reprojection indices - i_cloud = 0 - for i, file_path in enumerate(self.train_files): - - # Skip split that is not in current set - if self.all_splits[i] != self.validation_split: - continue + for i, file_path in enumerate(self.files): # Restart timer t0 = time.time() @@ -795,7 +786,7 @@ class S3DISDataset(PointCloudDataset): labels = data['class'] # Compute projection inds - idxs = self.input_trees[i_cloud].query(points, return_distance=False) + idxs = self.input_trees[i].query(points, return_distance=False) #dists, idxs = self.input_trees[i_cloud].kneighbors(points) proj_inds = np.squeeze(idxs).astype(np.int32) @@ -803,9 +794,8 @@ class 
S3DISDataset(PointCloudDataset): with open(proj_file, 'wb') as f: pickle.dump([proj_inds, labels], f) - self.validation_proj += [proj_inds] + self.test_proj += [proj_inds] self.validation_labels += [labels] - i_cloud += 1 print('{:s} done in {:.1f}s'.format(cloud_name, time.time() - t0)) print() @@ -819,6 +809,9 @@ class S3DISDataset(PointCloudDataset): # Get original points data = read_ply(file_path) return np.vstack((data['x'], data['y'], data['z'])).T + + + # ---------------------------------------------------------------------------------------------------------------------- # # Utility classes definition diff --git a/datasets/SemanticKitti.py b/datasets/SemanticKitti.py new file mode 100644 index 0000000..6c66df7 --- /dev/null +++ b/datasets/SemanticKitti.py @@ -0,0 +1,1407 @@ +# +# +# 0=================================0 +# | Kernel Point Convolutions | +# 0=================================0 +# +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Class handling SemanticKitti dataset. +# Implements a Dataset, a Sampler, and a collate_fn +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Hugues THOMAS - 11/06/2018 +# + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Imports and global variables +# \**********************************/ +# + +# Common libs +import time +import numpy as np +import pickle +import torch +import yaml +#from mayavi import mlab +from multiprocessing import Lock + + +# OS functions +from os import listdir +from os.path import exists, join, isdir + +# Dataset parent class +from datasets.common import * +from torch.utils.data import Sampler, get_worker_info +from utils.mayavi_visu import * +from utils.metrics import fast_confusion + +from datasets.common import grid_subsampling +from utils.config import bcolors + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Dataset class definition +# \******************************/ + + +class SemanticKittiDataset(PointCloudDataset): + """Class to handle SemanticKitti dataset.""" + + def __init__(self, config, set='training', balance_classes=True): + PointCloudDataset.__init__(self, 'SemanticKitti') + + ########################## + # Parameters for the files + ########################## + + # Dataset folder + self.path = '../../Data/SemanticKitti' + + # Type of task conducted on this dataset + self.dataset_task = 'slam_segmentation' + + # Training or test set + self.set = set + + # Get a list of sequences + if self.set == 'training': + self.sequences = ['{:02d}'.format(i) for i in range(11) if i != 8] + elif self.set == 'validation': + self.sequences = ['{:02d}'.format(i) for i in range(11) if i == 8] + elif self.set == 'test': + self.sequences = ['{:02d}'.format(i) for i in range(11, 22)] + else: + raise ValueError('Unknown set for SemanticKitti data: ', self.set) + + # List all files in each sequence + self.frames = [] + for seq in self.sequences: + velo_path = join(self.path, 'sequences', seq, 'velodyne') + frames = np.sort([vf[:-4] for vf in listdir(velo_path) if vf.endswith('.bin')]) + self.frames.append(frames) + + ########################### + # Object classes parameters + ########################### + + # Read labels + if config.n_frames == 1: + config_file = join(self.path, 'semantic-kitti.yaml') + elif 
config.n_frames > 1: + config_file = join(self.path, 'semantic-kitti-all.yaml') + else: + raise ValueError('number of frames has to be >= 1') + + with open(config_file, 'r') as stream: + doc = yaml.safe_load(stream) + all_labels = doc['labels'] + learning_map_inv = doc['learning_map_inv'] + learning_map = doc['learning_map'] + self.learning_map = np.zeros((np.max([k for k in learning_map.keys()]) + 1), dtype=np.int32) + for k, v in learning_map.items(): + self.learning_map[k] = v + + self.learning_map_inv = np.zeros((np.max([k for k in learning_map_inv.keys()]) + 1), dtype=np.int32) + for k, v in learning_map_inv.items(): + self.learning_map_inv[k] = v + + # Dict from labels to names + self.label_to_names = {k: all_labels[v] for k, v in learning_map_inv.items()} + + # Initiate a bunch of variables concerning class labels + self.init_labels() + + # List of classes ignored during training (can be empty) + self.ignored_labels = np.sort([0]) + + ################## + # Other parameters + ################## + + # Update number of class and data task in configuration + config.num_classes = self.num_classes + config.dataset_task = self.dataset_task + + # Parameters from config + self.config = config + + ################## + # Load calibration + ################## + + # Init variables + self.calibrations = [] + self.times = [] + self.poses = [] + self.all_inds = None + self.class_proportions = None + self.class_frames = [] + self.val_confs = [] + + # Load everything + self.load_calib_poses() + + ############################ + # Batch selection parameters + ############################ + + # Initialize value for batch limit (max number of points per batch). + self.batch_limit = torch.tensor([1], dtype=torch.float32) + self.batch_limit.share_memory_() + + # Initialize frame potentials + self.potentials = torch.from_numpy(np.random.rand(self.all_inds.shape[0]) * 0.1 + 0.1) + self.potentials.share_memory_() + + # If true, the same amount of frames is picked per class + self.balance_classes = balance_classes + + # Choose batch_num in_R and max_in_p depending on validation or training + if self.set == 'training': + self.batch_num = config.batch_num + self.max_in_p = config.max_in_points + self.in_R = config.in_radius + else: + self.batch_num = config.val_batch_num + self.max_in_p = config.max_val_points + self.in_R = config.val_radius + + # shared epoch indices and classes (in case we want class balanced sampler) + if set == 'training': + N = int(np.ceil(config.epoch_steps * self.batch_num * 1.1)) + else: + N = int(np.ceil(config.validation_size * self.batch_num * 1.1)) + self.epoch_i = torch.from_numpy(np.zeros((1,), dtype=np.int64)) + self.epoch_inds = torch.from_numpy(np.zeros((N,), dtype=np.int64)) + self.epoch_labels = torch.from_numpy(np.zeros((N,), dtype=np.int32)) + self.epoch_i.share_memory_() + self.epoch_inds.share_memory_() + self.epoch_labels.share_memory_() + + self.worker_waiting = torch.tensor([0 for _ in range(config.input_threads)], dtype=torch.int32) + self.worker_waiting.share_memory_() + self.worker_lock = Lock() + + return + + def __len__(self): + """ + Return the length of data here + """ + return len(self.frames) + + def __getitem__(self, batch_i): + """ + The main thread gives a list of indices to load a batch. Each worker is going to work in parallel to load a + different list of indices. 
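+        Spheres are stacked until the calibrated batch point limit is reached;
+        each sphere merges up to config.n_frames pose-aligned frames before
+        grid subsampling.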
+ """ + + t0 = time.time() + + # Initiate concatanation lists + p_list = [] + f_list = [] + l_list = [] + fi_list = [] + p0_list = [] + s_list = [] + R_list = [] + r_inds_list = [] + r_mask_list = [] + val_labels_list = [] + batch_n = 0 + + while True: + + with self.worker_lock: + + # Get potential minimum + ind = int(self.epoch_inds[self.epoch_i]) + wanted_label = int(self.epoch_labels[self.epoch_i]) + + # Update epoch indice + self.epoch_i += 1 + + s_ind, f_ind = self.all_inds[ind] + + t1 = time.time() + + ######################### + # Merge n_frames together + ######################### + + # Initiate merged points + merged_points = np.zeros((0, 3), dtype=np.float32) + merged_labels = np.zeros((0,), dtype=np.int32) + merged_coords = np.zeros((0, 4), dtype=np.float32) + + # In case of validation also keep original point and reproj indices + + + # Get center of the first frame in world coordinates + p_origin = np.zeros((1, 4)) + p_origin[0, 3] = 1 + pose0 = self.poses[s_ind][f_ind] + p0 = p_origin.dot(pose0.T)[:, :3] + p0 = np.squeeze(p0) + o_pts = None + o_labels = None + + t2 = time.time() + + num_merged = 0 + f_inc = 0 + while num_merged < self.config.n_frames and f_ind - f_inc >= 0: + + # Select frame only if center has moved far away (more than X meter). Negative value to ignore + X = -1.0 + pose = self.poses[s_ind][f_ind - f_inc] + diff = p_origin.dot(pose.T)[:, :3] - p_origin.dot(pose0.T)[:, :3] + if num_merged > 0 and np.linalg.norm(diff) < num_merged * X: + f_inc += 1 + continue + + # Path of points and labels + seq_path = join(self.path, 'sequences', self.sequences[s_ind]) + velo_file = join(seq_path, 'velodyne', self.frames[s_ind][f_ind - f_inc] + '.bin') + if self.set == 'test': + label_file = None + else: + label_file = join(seq_path, 'labels', self.frames[s_ind][f_ind - f_inc] + '.label') + + # Read points + frame_points = np.fromfile(velo_file, dtype=np.float32) + points = frame_points.reshape((-1, 4)) + + if self.set == 'test': + # Fake labels + sem_labels = np.zeros((frame_points.shape[0],), dtype=np.int32) + else: + # Read labels + frame_labels = np.fromfile(label_file, dtype=np.int32) + sem_labels = frame_labels & 0xFFFF # semantic label in lower half + sem_labels = self.learning_map[sem_labels] + + # Apply pose + hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1]))) + new_points = hpoints.dot(pose.T) + new_points[:, 3:] = points[:, 3:] + + # In case of validation, keep the original points in memory + if self.set in ['validation', 'test'] and f_inc == 0: + o_pts = new_points[:, :3].astype(np.float32) + o_labels = sem_labels.astype(np.int32) + + # In case radius smaller than 50m, chose new center on a point of the wanted class or not + if self.in_R < 50.0 and f_inc == 0: + if self.balance_classes: + wanted_ind = np.random.choice(np.where(sem_labels == wanted_label)[0]) + else: + wanted_ind = np.random.choice(new_points.shape[0]) + p0 = new_points[wanted_ind, :3] + + # Eliminate points further than config.in_radius + mask = np.sum(np.square(new_points[:, :3] - p0), axis=1) < self.in_R ** 2 + mask_inds = np.where(mask)[0].astype(np.int32) + + # Shuffle points + rand_order = np.random.permutation(mask_inds) + new_points = new_points[rand_order, :] + sem_labels = sem_labels[rand_order] + + # Place points in original frame reference to get coordinates + hpoints = np.hstack((new_points[:, :3], np.ones_like(new_points[:, :1]))) + new_coords = hpoints.dot(pose0) + new_coords[:, 3:] = new_points[:, 3:] + + # Increment merge count + merged_points = 
np.vstack((merged_points, new_points[:, :3])) + merged_labels = np.hstack((merged_labels, sem_labels)) + merged_coords = np.vstack((merged_coords, new_coords)) + num_merged += 1 + f_inc += 1 + + + t3 = time.time() + + ######################### + # Merge n_frames together + ######################### + + # Too see yielding speed with debug timings method, collapse points (reduce mapping time to nearly 0) + #merged_points = merged_points[:100, :] + #merged_labels = merged_labels[:100] + #merged_points *= 0.1 + + # Subsample merged frames + in_pts, in_fts, in_lbls = grid_subsampling(merged_points, + features=merged_coords, + labels=merged_labels, + sampleDl=self.config.first_subsampling_dl) + + t4 = time.time() + + # Number collected + n = in_pts.shape[0] + + # Safe check + if n < 2: + continue + + # Randomly drop some points (augmentation process and safety for GPU memory consumption) + if n > self.max_in_p: + input_inds = np.random.choice(n, size=self.max_in_p, replace=False) + in_pts = in_pts[input_inds, :] + in_fts = in_fts[input_inds, :] + in_lbls = in_lbls[input_inds] + n = input_inds.shape[0] + + t5 = time.time() + + # Before augmenting, compute reprojection inds (only for validation and test) + if self.set in ['validation', 'test']: + + # get val_points that are in range + radiuses = np.sum(np.square(o_pts - p0), axis=1) + reproj_mask = radiuses < (0.99 * self.in_R) ** 2 + + # Project predictions on the frame points + search_tree = KDTree(in_pts, leaf_size=50) + proj_inds = search_tree.query(o_pts[reproj_mask, :], return_distance=False) + proj_inds = np.squeeze(proj_inds).astype(np.int32) + else: + proj_inds = np.zeros((0,)) + reproj_mask = np.zeros((0,)) + + t6 = time.time() + + # Data augmentation + in_pts, scale, R = self.augmentation_transform(in_pts) + + t7 = time.time() + + # Color augmentation + if np.random.rand() > self.config.augment_color: + in_fts[:, 3:] *= 0 + + # Stack batch + p_list += [in_pts] + f_list += [in_fts] + l_list += [np.squeeze(in_lbls)] + fi_list += [[s_ind, f_ind]] + p0_list += [p0] + s_list += [scale] + R_list += [R] + r_inds_list += [proj_inds] + r_mask_list += [reproj_mask] + val_labels_list += [o_labels] + + + t8 = time.time() + + # Update batch size + batch_n += n + + # In case batch is full, stop + if batch_n > int(self.batch_limit): + break + + ################### + # Concatenate batch + ################### + + stacked_points = np.concatenate(p_list, axis=0) + features = np.concatenate(f_list, axis=0) + labels = np.concatenate(l_list, axis=0) + frame_inds = np.array(fi_list, dtype=np.int32) + frame_centers = np.stack(p0_list, axis=0) + stack_lengths = np.array([pp.shape[0] for pp in p_list], dtype=np.int32) + scales = np.array(s_list, dtype=np.float32) + rots = np.stack(R_list, axis=0) + + + # Input features (Use reflectance, input height or all coordinates) + stacked_features = np.ones_like(stacked_points[:, :1], dtype=np.float32) + if self.config.in_features_dim == 1: + pass + elif self.config.in_features_dim == 2: + # Use original height coordinate + stacked_features = np.hstack((stacked_features, features[:, 2:3])) + elif self.config.in_features_dim == 3: + # Use height + reflectance + stacked_features = np.hstack((stacked_features, features[:, 2:])) + elif self.config.in_features_dim == 4: + # Use all coordinates + stacked_features = np.hstack((stacked_features, features[:3])) + elif self.config.in_features_dim == 5: + # Use all coordinates + reflectance + stacked_features = np.hstack((stacked_features, features)) + else: + raise ValueError('Only 
accepted input dimensions are 1, 4 and 7 (without and with XYZ)') + + t9 = time.time() + ####################### + # Create network inputs + ####################### + # + # Points, neighbors, pooling indices for each layers + # + + # Get the whole input list + input_list = self.segmentation_inputs(stacked_points, + stacked_features, + labels.astype(np.int64), + stack_lengths) + + t10 = time.time() + + # Add scale and rotation for testing + input_list += [scales, rots, frame_inds, frame_centers, r_inds_list, r_mask_list, val_labels_list] + + t11 = time.time() + + # Display timings + debug = False + if debug: + print('\n************************\n') + print('Timings:') + print('Lock ...... {:.1f}ms'.format(1000 * (t1 - t0))) + print('Init ...... {:.1f}ms'.format(1000 * (t2 - t1))) + print('Load ...... {:.1f}ms'.format(1000 * (t3 - t2))) + print('subs ...... {:.1f}ms'.format(1000 * (t4 - t3))) + print('drop ...... {:.1f}ms'.format(1000 * (t5 - t4))) + print('reproj .... {:.1f}ms'.format(1000 * (t6 - t5))) + print('augment ... {:.1f}ms'.format(1000 * (t7 - t6))) + print('stack ..... {:.1f}ms'.format(1000 * (t8 - t7))) + print('concat .... {:.1f}ms'.format(1000 * (t9 - t8))) + print('input ..... {:.1f}ms'.format(1000 * (t10 - t9))) + print('stack ..... {:.1f}ms'.format(1000 * (t11 - t10))) + print('\n************************\n') + + # Timings: (in test configuration) + # Lock ...... 0.1ms + # Init ...... 0.0ms + # Load ...... 40.0ms + # subs ...... 143.6ms + # drop ...... 4.6ms + # reproj .... 297.4ms + # augment ... 7.5ms + # stack ..... 0.0ms + # concat .... 1.4ms + # input ..... 816.0ms + # stack ..... 0.0ms + + # TODO: Where can we gain time for the robot real time test? + # > Load: no disk read necessary + pose useless if we only use one frame for testing + # > Drop: We can drop even more points. Random choice could be faster without replace=False + # > reproj: No reprojection needed + # > Augment: See which data agment we want at test time + # > input: MAIN BOTTLENECK. We need to see if we can do faster, maybe with some parallelisation + + return [self.config.num_layers] + input_list + + def load_calib_poses(self): + """ + load calib poses and times. 
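+        Also builds all_inds (one (sequence, frame) index pair per frame) and,
+        for the training and validation sets, per-class frame lists and class
+        proportions.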
+ """ + + ########### + # Load data + ########### + + self.calibrations = [] + self.times = [] + self.poses = [] + + for seq in self.sequences: + + seq_folder = join(self.path, 'sequences', seq) + + # Read Calib + self.calibrations.append(self.parse_calibration(join(seq_folder, "calib.txt"))) + + # Read times + self.times.append(np.loadtxt(join(seq_folder, 'times.txt'), dtype=np.float32)) + + # Read poses + poses_f64 = self.parse_poses(join(seq_folder, 'poses.txt'), self.calibrations[-1]) + self.poses.append([pose.astype(np.float32) for pose in poses_f64]) + + ################################### + # Prepare the indices of all frames + ################################### + + seq_inds = np.hstack([np.ones(len(_), dtype=np.int32) * i for i, _ in enumerate(self.frames)]) + frame_inds = np.hstack([np.arange(len(_), dtype=np.int32) for _ in self.frames]) + self.all_inds = np.vstack((seq_inds, frame_inds)).T + + ################################################ + # For each class list the frames containing them + ################################################ + + if self.set in ['training', 'validation']: + + class_frames_bool = np.zeros((0, self.num_classes), dtype=np.bool) + self.class_proportions = np.zeros((self.num_classes,), dtype=np.int32) + + for s_ind, (seq, seq_frames) in enumerate(zip(self.sequences, self.frames)): + + frame_mode = 'single' + if self.config.n_frames > 1: + frame_mode = 'multi' + seq_stat_file = join(self.path, 'sequences', seq, 'stats_{:s}.pkl'.format(frame_mode)) + + # Check if inputs have already been computed + if exists(seq_stat_file): + # Read pkl + with open(seq_stat_file, 'rb') as f: + seq_class_frames, seq_proportions = pickle.load(f) + + else: + + # Initiate dict + print('Preparing seq {:s} class frames. (Long but one time only)'.format(seq)) + + # Class frames as a boolean mask + seq_class_frames = np.zeros((len(seq_frames), self.num_classes), dtype=np.bool) + + # Proportion of each class + seq_proportions = np.zeros((self.num_classes,), dtype=np.int32) + + # Sequence path + seq_path = join(self.path, 'sequences', seq) + + # Read all frames + for f_ind, frame_name in enumerate(seq_frames): + + # Path of points and labels + label_file = join(seq_path, 'labels', frame_name + '.label') + + # Read labels + frame_labels = np.fromfile(label_file, dtype=np.int32) + sem_labels = frame_labels & 0xFFFF # semantic label in lower half + sem_labels = self.learning_map[sem_labels] + + # Get present labels and there frequency + unique, counts = np.unique(sem_labels, return_counts=True) + + # Add this frame to the frame lists of all class present + frame_labels = np.array([self.label_to_idx[l] for l in unique], dtype=np.int32) + seq_class_frames[f_ind, frame_labels] = True + + # Add proportions + seq_proportions[frame_labels] += counts + + # Save pickle + with open(seq_stat_file, 'wb') as f: + pickle.dump([seq_class_frames, seq_proportions], f) + + class_frames_bool = np.vstack((class_frames_bool, seq_class_frames)) + self.class_proportions += seq_proportions + + # Transform boolean indexing to int indices. 
+ self.class_frames = [] + for i, c in enumerate(self.label_values): + if c in self.ignored_labels: + self.class_frames.append(torch.zeros((0,), dtype=torch.int64)) + else: + integer_inds = np.where(class_frames_bool[:, i])[0] + self.class_frames.append(torch.from_numpy(integer_inds.astype(np.int64))) + + # Add variables for validation + if self.set == 'validation': + self.val_points = [] + self.val_labels = [] + self.val_confs = [] + + for s_ind, seq_frames in enumerate(self.frames): + self.val_confs.append(np.zeros((len(seq_frames), self.num_classes, self.num_classes))) + + return + + def parse_calibration(self, filename): + """ read calibration file with given filename + + Returns + ------- + dict + Calibration matrices as 4x4 numpy arrays. + """ + calib = {} + + calib_file = open(filename) + for line in calib_file: + key, content = line.strip().split(":") + values = [float(v) for v in content.strip().split()] + + pose = np.zeros((4, 4)) + pose[0, 0:4] = values[0:4] + pose[1, 0:4] = values[4:8] + pose[2, 0:4] = values[8:12] + pose[3, 3] = 1.0 + + calib[key] = pose + + calib_file.close() + + return calib + + def parse_poses(self, filename, calibration): + """ read poses file with per-scan poses from given filename + + Returns + ------- + list + list of poses as 4x4 numpy arrays. + """ + file = open(filename) + + poses = [] + + Tr = calibration["Tr"] + Tr_inv = np.linalg.inv(Tr) + + for line in file: + values = [float(v) for v in line.strip().split()] + + pose = np.zeros((4, 4)) + pose[0, 0:4] = values[0:4] + pose[1, 0:4] = values[4:8] + pose[2, 0:4] = values[8:12] + pose[3, 3] = 1.0 + + poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) + + return poses + + +class SemanticKittiSampler(Sampler): + """Sampler for SemanticKitti""" + + def __init__(self, dataset: SemanticKittiDataset): + Sampler.__init__(self, dataset) + + # Dataset used by the sampler (no copy is made in memory) + self.dataset = dataset + + # Number of step per epoch + if dataset.set == 'training': + self.N = dataset.config.epoch_steps + else: + self.N = dataset.config.validation_size + + return + + def __iter__(self): + """ + Yield next batch indices here. 
In this dataset, this is a dummy sampler that yield the index of batch element + (input sphere) in epoch instead of the list of point indices + """ + + if self.dataset.balance_classes: + + # Initiate current epoch ind + self.dataset.epoch_i *= 0 + self.dataset.epoch_inds *= 0 + self.dataset.epoch_labels *= 0 + + # Number of sphere centers taken per class in each cloud + num_centers = self.dataset.epoch_inds.shape[0] + + # Generate a list of indices balancing classes and respecting potentials + gen_indices = [] + gen_classes = [] + for i, c in enumerate(self.dataset.label_values): + if c not in self.dataset.ignored_labels: + + # Get the potentials of the frames containing this class + class_potentials = self.dataset.potentials[self.dataset.class_frames[i]] + + # Get the indices to generate thanks to potentials + used_classes = self.dataset.num_classes - len(self.dataset.ignored_labels) + class_n = num_centers // used_classes + 1 + if class_n < class_potentials.shape[0]: + _, class_indices = torch.topk(class_potentials, class_n, largest=False) + else: + class_indices = torch.randperm(class_potentials.shape[0]) + class_indices = self.dataset.class_frames[i][class_indices] + + # Add the indices to the generated ones + gen_indices.append(class_indices) + gen_classes.append(class_indices * 0 + c) + + # Update potentials + self.dataset.potentials[class_indices] = np.ceil(self.dataset.potentials[class_indices]) + self.dataset.potentials[class_indices] += np.random.rand(class_indices.shape[0]) * 0.1 + 0.1 + + # Stack the chosen indices of all classes + gen_indices = torch.cat(gen_indices, dim=0) + gen_classes = torch.cat(gen_classes, dim=0) + + # Shuffle generated indices + rand_order = torch.randperm(gen_indices.shape[0])[:num_centers] + gen_indices = gen_indices[rand_order] + gen_classes = gen_classes[rand_order] + + # Update potentials (Change the order for the next epoch) + self.dataset.potentials[gen_indices] = torch.ceil(self.dataset.potentials[gen_indices]) + self.dataset.potentials[gen_indices] += torch.from_numpy(np.random.rand(gen_indices.shape[0]) * 0.1 + 0.1) + + # Update epoch inds + self.dataset.epoch_inds += gen_indices + self.dataset.epoch_labels += gen_classes.type(torch.int32) + + else: + + # Initiate current epoch ind + self.dataset.epoch_i *= 0 + self.dataset.epoch_inds *= 0 + self.dataset.epoch_labels *= 0 + + # Number of sphere centers taken per class in each cloud + num_centers = self.dataset.epoch_inds.shape[0] + + # Get the list of indices to generate thanks to potentials + if num_centers < self.dataset.potentials.shape[0]: + _, gen_indices = torch.topk(self.dataset.potentials, num_centers, largest=False, sorted=True) + else: + gen_indices = torch.randperm(self.dataset.potentials.shape[0]) + + # Update potentials (Change the order for the next epoch) + self.dataset.potentials[gen_indices] = torch.ceil(self.dataset.potentials[gen_indices]) + self.dataset.potentials[gen_indices] += torch.from_numpy(np.random.rand(gen_indices.shape[0]) * 0.1 + 0.1) + + # Update epoch inds + self.dataset.epoch_inds += gen_indices + + # Generator loop + for i in range(self.N): + yield i + + def __len__(self): + """ + The number of yielded samples is variable + """ + return self.N + + def calib_max_in(self, config, dataloader, untouched_ratio=0.8, verbose=True): + """ + Method performing batch and neighbors calibration. + Batch calibration: Set "batch_limit" (the maximum number of points allowed in every batch) so that the + average batch size (number of stacked pointclouds) is the one asked. 
+ Neighbors calibration: Set the "neighborhood_limits" (the maximum number of neighbors allowed in convolutions) + so that 90% of the neighborhoods remain untouched. There is a limit for each layer. + """ + + ############################## + # Previously saved calibration + ############################## + + print('\nStarting Calibration of max_in_points value (use verbose=True for more details)') + t0 = time.time() + + redo = False + + # Batch limit + # *********** + + # Load max_in_limit dictionary + max_in_lim_file = join(self.dataset.path, 'max_in_limits.pkl') + if exists(max_in_lim_file): + with open(max_in_lim_file, 'rb') as file: + max_in_lim_dict = pickle.load(file) + else: + max_in_lim_dict = {} + + # Check if the max_in limit associated with current parameters exists + if self.dataset.balance_classes: + sampler_method = 'balanced' + else: + sampler_method = 'random' + key = '{:s}_{:.3f}_{:.3f}'.format(sampler_method, + self.dataset.in_R, + self.dataset.config.first_subsampling_dl) + if key in max_in_lim_dict: + self.dataset.max_in_p = max_in_lim_dict[key] + else: + redo = True + + if verbose: + print('\nPrevious calibration found:') + print('Check max_in limit dictionary') + if key in max_in_lim_dict: + color = bcolors.OKGREEN + v = str(int(max_in_lim_dict[key])) + else: + color = bcolors.FAIL + v = '?' + print('{:}\"{:s}\": {:s}{:}'.format(color, key, v, bcolors.ENDC)) + + if redo: + + ######################## + # Batch calib parameters + ######################## + + # Loop parameters + last_display = time.time() + i = 0 + breaking = False + + all_lengths = [] + N = 1000 + + ##################### + # Perform calibration + ##################### + + for epoch in range(10): + for batch_i, batch in enumerate(dataloader): + + # Control max_in_points value + all_lengths += batch.lengths[0] + + # Convergence + if len(all_lengths) > N: + breaking = True + break + + i += 1 + t = time.time() + + # Console display (only one per second) + if t - last_display > 1.0: + last_display = t + message = 'Collecting {:d} in_points: {:5.1f}%' + print(message.format(N, + 100 * len(all_lengths) / N)) + + if breaking: + break + + # TODO: Compute the percentile np.percentile? + # TODO: optionnally show a plot of the in_points histogram? + + self.dataset.max_in_p = int(np.percentile(all_lengths, 100*untouched_ratio)) + + if verbose: + + # Create histogram + a = 1 + + # Save max_in_limit dictionary + max_in_lim_dict[key] = self.dataset.max_in_p + with open(max_in_lim_file, 'wb') as file: + pickle.dump(max_in_lim_dict, file) + + # Update value in config + if self.dataset.set == 'training': + config.max_in_points = self.dataset.max_in_p + else: + config.max_val_points = self.dataset.max_in_p + + print('Calibration done in {:.1f}s\n'.format(time.time() - t0)) + return + + def calibration(self, dataloader, untouched_ratio=0.9, verbose=False): + """ + Method performing batch and neighbors calibration. + Batch calibration: Set "batch_limit" (the maximum number of points allowed in every batch) so that the + average batch size (number of stacked pointclouds) is the one asked. + Neighbors calibration: Set the "neighborhood_limits" (the maximum number of neighbors allowed in convolutions) + so that 90% of the neighborhoods remain untouched. There is a limit for each layer. 
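+        Computed limits are cached in batch_limits.pkl and neighbors_limits.pkl,
+        keyed by the sampler parameters, so subsequent runs can skip recalibration.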
+ """ + + ############################## + # Previously saved calibration + ############################## + + print('\nStarting Calibration (use verbose=True for more details)') + t0 = time.time() + + redo = False + + # Batch limit + # *********** + + # Load batch_limit dictionary + batch_lim_file = join(self.dataset.path, 'batch_limits.pkl') + if exists(batch_lim_file): + with open(batch_lim_file, 'rb') as file: + batch_lim_dict = pickle.load(file) + else: + batch_lim_dict = {} + + # Check if the batch limit associated with current parameters exists + if self.dataset.balance_classes: + sampler_method = 'balanced' + else: + sampler_method = 'random' + key = '{:s}_{:.3f}_{:.3f}_{:d}_{:d}'.format(sampler_method, + self.dataset.in_R, + self.dataset.config.first_subsampling_dl, + self.dataset.batch_num, + self.dataset.max_in_p) + if key in batch_lim_dict: + self.dataset.batch_limit[0] = batch_lim_dict[key] + else: + redo = True + + if verbose: + print('\nPrevious calibration found:') + print('Check batch limit dictionary') + if key in batch_lim_dict: + color = bcolors.OKGREEN + v = str(int(batch_lim_dict[key])) + else: + color = bcolors.FAIL + v = '?' + print('{:}\"{:s}\": {:s}{:}'.format(color, key, v, bcolors.ENDC)) + + # Neighbors limit + # *************** + + # Load neighb_limits dictionary + neighb_lim_file = join(self.dataset.path, 'neighbors_limits.pkl') + if exists(neighb_lim_file): + with open(neighb_lim_file, 'rb') as file: + neighb_lim_dict = pickle.load(file) + else: + neighb_lim_dict = {} + + # Check if the limit associated with current parameters exists (for each layer) + neighb_limits = [] + for layer_ind in range(self.dataset.config.num_layers): + + dl = self.dataset.config.first_subsampling_dl * (2**layer_ind) + if self.dataset.config.deform_layers[layer_ind]: + r = dl * self.dataset.config.deform_radius + else: + r = dl * self.dataset.config.conv_radius + + key = '{:s}_{:d}_{:.3f}_{:.3f}'.format(sampler_method, self.dataset.max_in_p, dl, r) + if key in neighb_lim_dict: + neighb_limits += [neighb_lim_dict[key]] + + if len(neighb_limits) == self.dataset.config.num_layers: + self.dataset.neighborhood_limits = neighb_limits + else: + redo = True + + if verbose: + print('Check neighbors limit dictionary') + for layer_ind in range(self.dataset.config.num_layers): + dl = self.dataset.config.first_subsampling_dl * (2**layer_ind) + if self.dataset.config.deform_layers[layer_ind]: + r = dl * self.dataset.config.deform_radius + else: + r = dl * self.dataset.config.conv_radius + key = '{:s}_{:d}_{:.3f}_{:.3f}'.format(sampler_method, self.dataset.max_in_p, dl, r) + + if key in neighb_lim_dict: + color = bcolors.OKGREEN + v = str(neighb_lim_dict[key]) + else: + color = bcolors.FAIL + v = '?' 
+ print('{:}\"{:s}\": {:s}{:}'.format(color, key, v, bcolors.ENDC)) + + if redo: + + ############################ + # Neighbors calib parameters + ############################ + + # From config parameter, compute higher bound of neighbors number in a neighborhood + hist_n = int(np.ceil(4 / 3 * np.pi * (self.dataset.config.deform_radius + 1) ** 3)) + + # Histogram of neighborhood sizes + neighb_hists = np.zeros((self.dataset.config.num_layers, hist_n), dtype=np.int32) + + ######################## + # Batch calib parameters + ######################## + + # Estimated average batch size and target value + estim_b = 0 + target_b = self.dataset.batch_num + + # Calibration parameters + low_pass_T = 10 + Kp = 100.0 + finer = False + + # Convergence parameters + smooth_errors = [] + converge_threshold = 0.1 + + # Save input pointcloud sizes to control max_in_points + cropped_n = 0 + all_n = 0 + + # Loop parameters + last_display = time.time() + i = 0 + breaking = False + + ##################### + # Perform calibration + ##################### + + for epoch in range(10): + for batch_i, batch in enumerate(dataloader): + + # Control max_in_points value + are_cropped = batch.lengths[0] > self.dataset.max_in_p - 1 + cropped_n += torch.sum(are_cropped.type(torch.int32)).item() + all_n += int(batch.lengths[0].shape[0]) + + # Update neighborhood histogram + counts = [np.sum(neighb_mat.numpy() < neighb_mat.shape[0], axis=1) for neighb_mat in batch.neighbors] + hists = [np.bincount(c, minlength=hist_n)[:hist_n] for c in counts] + neighb_hists += np.vstack(hists) + + # batch length + b = len(batch.frame_inds) + + # Update estim_b (low pass filter) + estim_b += (b - estim_b) / low_pass_T + + # Estimate error (noisy) + error = target_b - b + + # Save smooth errors for convergene check + smooth_errors.append(target_b - estim_b) + if len(smooth_errors) > 10: + smooth_errors = smooth_errors[1:] + + # Update batch limit with P controller + self.dataset.batch_limit += Kp * error + + # finer low pass filter when closing in + if not finer and np.abs(estim_b - target_b) < 1: + low_pass_T = 100 + finer = True + + # Convergence + if finer and np.max(np.abs(smooth_errors)) < converge_threshold: + breaking = True + break + + i += 1 + t = time.time() + + # Console display (only one per second) + if verbose and (t - last_display) > 1.0: + last_display = t + message = 'Step {:5d} estim_b ={:5.2f} batch_limit ={:7d}' + print(message.format(i, + estim_b, + int(self.dataset.batch_limit))) + + if breaking: + break + + # Use collected neighbor histogram to get neighbors limit + cumsum = np.cumsum(neighb_hists.T, axis=0) + percentiles = np.sum(cumsum < (untouched_ratio * cumsum[hist_n - 1, :]), axis=0) + self.dataset.neighborhood_limits = percentiles + + if verbose: + + # Crop histogram + while np.sum(neighb_hists[:, -1]) == 0: + neighb_hists = neighb_hists[:, :-1] + hist_n = neighb_hists.shape[1] + + print('\n**************************************************\n') + line0 = 'neighbors_num ' + for layer in range(neighb_hists.shape[0]): + line0 += '| layer {:2d} '.format(layer) + print(line0) + for neighb_size in range(hist_n): + line0 = ' {:4d} '.format(neighb_size) + for layer in range(neighb_hists.shape[0]): + if neighb_size > percentiles[layer]: + color = bcolors.FAIL + else: + color = bcolors.OKGREEN + line0 += '|{:}{:10d}{:} '.format(color, + neighb_hists[layer, neighb_size], + bcolors.ENDC) + + print(line0) + + print('\n**************************************************\n') + print('\nchosen neighbors limits: ', percentiles) + 
print() + + # Control max_in_points value + print('\n**************************************************\n') + if cropped_n > 0.3 * all_n: + color = bcolors.FAIL + else: + color = bcolors.OKGREEN + print('Current value of max_in_points {:d}'.format(self.dataset.max_in_p)) + print(' > {:}{:.1f}% inputs are cropped{:}'.format(color, 100 * cropped_n / all_n, bcolors.ENDC)) + if cropped_n > 0.3 * all_n: + print('\nTry a higher max_in_points value\n'.format(100 * cropped_n / all_n)) + #raise ValueError('Value of max_in_points too low') + print('\n**************************************************\n') + + # Save batch_limit dictionary + key = '{:s}_{:.3f}_{:.3f}_{:d}_{:d}'.format(sampler_method, + self.dataset.in_R, + self.dataset.config.first_subsampling_dl, + self.dataset.batch_num, + self.dataset.max_in_p) + batch_lim_dict[key] = float(self.dataset.batch_limit) + with open(batch_lim_file, 'wb') as file: + pickle.dump(batch_lim_dict, file) + + # Save neighb_limit dictionary + for layer_ind in range(self.dataset.config.num_layers): + dl = self.dataset.config.first_subsampling_dl * (2 ** layer_ind) + if self.dataset.config.deform_layers[layer_ind]: + r = dl * self.dataset.config.deform_radius + else: + r = dl * self.dataset.config.conv_radius + key = '{:s}_{:d}_{:.3f}_{:.3f}'.format(sampler_method, self.dataset.max_in_p, dl, r) + neighb_lim_dict[key] = self.dataset.neighborhood_limits[layer_ind] + with open(neighb_lim_file, 'wb') as file: + pickle.dump(neighb_lim_dict, file) + + + print('Calibration done in {:.1f}s\n'.format(time.time() - t0)) + return + + +class SemanticKittiCustomBatch: + """Custom batch definition with memory pinning for SemanticKitti""" + + def __init__(self, input_list): + + # Get rid of batch dimension + input_list = input_list[0] + + # Number of layers + L = int(input_list[0]) + + # Extract input tensors from the list of numpy array + ind = 1 + self.points = [torch.from_numpy(nparray) for nparray in input_list[ind:ind+L]] + ind += L + self.neighbors = [torch.from_numpy(nparray) for nparray in input_list[ind:ind+L]] + ind += L + self.pools = [torch.from_numpy(nparray) for nparray in input_list[ind:ind+L]] + ind += L + self.upsamples = [torch.from_numpy(nparray) for nparray in input_list[ind:ind+L]] + ind += L + self.lengths = [torch.from_numpy(nparray) for nparray in input_list[ind:ind+L]] + ind += L + self.features = torch.from_numpy(input_list[ind]) + ind += 1 + self.labels = torch.from_numpy(input_list[ind]) + ind += 1 + self.scales = torch.from_numpy(input_list[ind]) + ind += 1 + self.rots = torch.from_numpy(input_list[ind]) + ind += 1 + self.frame_inds = torch.from_numpy(input_list[ind]) + ind += 1 + self.frame_centers = torch.from_numpy(input_list[ind]) + ind += 1 + self.reproj_inds = input_list[ind] + ind += 1 + self.reproj_masks = input_list[ind] + ind += 1 + self.val_labels = input_list[ind] + + return + + def pin_memory(self): + """ + Manual pinning of the memory + """ + + self.points = [in_tensor.pin_memory() for in_tensor in self.points] + self.neighbors = [in_tensor.pin_memory() for in_tensor in self.neighbors] + self.pools = [in_tensor.pin_memory() for in_tensor in self.pools] + self.upsamples = [in_tensor.pin_memory() for in_tensor in self.upsamples] + self.lengths = [in_tensor.pin_memory() for in_tensor in self.lengths] + self.features = self.features.pin_memory() + self.labels = self.labels.pin_memory() + self.scales = self.scales.pin_memory() + self.rots = self.rots.pin_memory() + self.frame_inds = self.frame_inds.pin_memory() + self.frame_centers = 
self.frame_centers.pin_memory() + + return self + + def to(self, device): + + self.points = [in_tensor.to(device) for in_tensor in self.points] + self.neighbors = [in_tensor.to(device) for in_tensor in self.neighbors] + self.pools = [in_tensor.to(device) for in_tensor in self.pools] + self.upsamples = [in_tensor.to(device) for in_tensor in self.upsamples] + self.lengths = [in_tensor.to(device) for in_tensor in self.lengths] + self.features = self.features.to(device) + self.labels = self.labels.to(device) + self.scales = self.scales.to(device) + self.rots = self.rots.to(device) + self.frame_inds = self.frame_inds.to(device) + self.frame_centers = self.frame_centers.to(device) + + return self + + def unstack_points(self, layer=None): + """Unstack the points""" + return self.unstack_elements('points', layer) + + def unstack_neighbors(self, layer=None): + """Unstack the neighbors indices""" + return self.unstack_elements('neighbors', layer) + + def unstack_pools(self, layer=None): + """Unstack the pooling indices""" + return self.unstack_elements('pools', layer) + + def unstack_elements(self, element_name, layer=None, to_numpy=True): + """ + Return a list of the stacked elements in the batch at a certain layer. If no layer is given, then return all + layers + """ + + if element_name == 'points': + elements = self.points + elif element_name == 'neighbors': + elements = self.neighbors + elif element_name == 'pools': + elements = self.pools[:-1] + else: + raise ValueError('Unknown element name: {:s}'.format(element_name)) + + all_p_list = [] + for layer_i, layer_elems in enumerate(elements): + + if layer is None or layer == layer_i: + + i0 = 0 + p_list = [] + if element_name == 'pools': + lengths = self.lengths[layer_i+1] + else: + lengths = self.lengths[layer_i] + + for b_i, length in enumerate(lengths): + + elem = layer_elems[i0:i0 + length] + if element_name == 'neighbors': + elem[elem >= self.points[layer_i].shape[0]] = -1 + elem[elem >= 0] -= i0 + elif element_name == 'pools': + elem[elem >= self.points[layer_i].shape[0]] = -1 + elem[elem >= 0] -= torch.sum(self.lengths[layer_i][:b_i]) + i0 += length + + if to_numpy: + p_list.append(elem.numpy()) + else: + p_list.append(elem) + + if layer == layer_i: + return p_list + + all_p_list.append(p_list) + + return all_p_list + + +def SemanticKittiCollate(batch_data): + return SemanticKittiCustomBatch(batch_data) + + +def debug_timing(dataset, loader): + """Timing of generator function""" + + t = [time.time()] + last_display = time.time() + mean_dt = np.zeros(2) + estim_b = dataset.batch_num + estim_N = 0 + + for epoch in range(10): + + for batch_i, batch in enumerate(loader): + # print(batch_i, tuple(points.shape), tuple(normals.shape), labels, indices, in_sizes) + + # New time + t = t[-1:] + t += [time.time()] + + # Update estim_b (low pass filter) + estim_b += (len(batch.frame_inds) - estim_b) / 100 + estim_N += (batch.features.shape[0] - estim_N) / 10 + + # Pause simulating computations + time.sleep(0.05) + t += [time.time()] + + # Average timing + mean_dt = 0.9 * mean_dt + 0.1 * (np.array(t[1:]) - np.array(t[:-1])) + + # Console display (only one per second) + if (t[-1] - last_display) > -1.0: + last_display = t[-1] + message = 'Step {:08d} -> (ms/batch) {:8.2f} {:8.2f} / batch = {:.2f} - {:.0f}' + print(message.format(batch_i, + 1000 * mean_dt[0], + 1000 * mean_dt[1], + estim_b, + estim_N)) + + print('************* Epoch ended *************') + + _, counts = np.unique(dataset.input_labels, return_counts=True) + print(counts) + + +def 
debug_class_w(dataset, loader): + """Timing of generator function""" + + i = 0 + + counts = np.zeros((0, dataset.num_classes,), dtype=np.int64) + + s = '{:^6}|'.format('step') + for c in dataset.label_names: + s += '{:^6}'.format(c[:4]) + print(s) + print(6*'-' + '|' + 6*dataset.num_classes*'-') + + for epoch in range(10): + for batch_i, batch in enumerate(loader): + # print(batch_i, tuple(points.shape), tuple(normals.shape), labels, indices, in_sizes) + + # count labels + new_counts = np.bincount(batch.labels) + counts[:new_counts.shape[0]] += new_counts.astype(np.int64) + + # Update proportions + proportions = 1000 * counts / np.sum(counts) + + print(proportions) + s = '{:^6d}|'.format(i) + for pp in proportions: + s += '{:^6.1f}'.format(pp) + print(s) + i += 1 + diff --git a/datasets/common.py b/datasets/common.py index c20041b..006be36 100644 --- a/datasets/common.py +++ b/datasets/common.py @@ -345,7 +345,7 @@ class PointCloudDataset(Dataset): else: # No pooling in the end of this layer, no pooling indices required pool_i = np.zeros((0, 1), dtype=np.int32) - pool_p = np.zeros((0, 3), dtype=np.float32) + pool_p = np.zeros((0, 1), dtype=np.float32) pool_b = np.zeros((0,), dtype=np.int32) # Reduce size of neighbors matrices by eliminating furthest point diff --git a/models/architectures.py b/models/architectures.py index d3e0261..75b5949 100644 --- a/models/architectures.py +++ b/models/architectures.py @@ -14,7 +14,6 @@ # Hugues THOMAS - 06/03/2020 # - from models.blocks import * import numpy as np @@ -201,7 +200,7 @@ class KPFCNN(nn.Module): Class defining KPFCNN """ - def __init__(self, config): + def __init__(self, config, lbl_values, ign_lbls): super(KPFCNN, self).__init__() ############ @@ -214,6 +213,7 @@ class KPFCNN(nn.Module): in_dim = config.in_features_dim out_dim = config.first_features_dim self.K = config.num_kernel_points + self.C = len(lbl_values) - len(ign_lbls) ##################### # List Encoder blocks @@ -303,21 +303,21 @@ class KPFCNN(nn.Module): out_dim = out_dim // 2 self.head_mlp = UnaryBlock(out_dim, config.first_features_dim, False, 0) - self.head_softmax = UnaryBlock(config.first_features_dim, config.num_classes, False, 0) + self.head_softmax = UnaryBlock(config.first_features_dim, self.C, False, 0) ################ # Network Losses ################ + # List of valid labels (those not ignored in loss) + self.valid_labels = np.sort([c for c in lbl_values if c not in ign_lbls]) + # Choose segmentation loss - if config.segloss_balance == 'none': - self.criterion = torch.nn.CrossEntropyLoss() - elif config.segloss_balance == 'class': - self.criterion = torch.nn.CrossEntropyLoss() - elif config.segloss_balance == 'batch': - self.criterion = torch.nn.CrossEntropyLoss() + if len(config.class_w) > 0: + class_w = torch.from_numpy(np.array(config.class_w, dtype=np.float32)) + self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1) else: - raise ValueError('Unknown segloss_balance:', config.segloss_balance) + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1) self.offset_loss = config.offsets_loss self.offset_decay = config.offsets_decay self.output_loss = 0 @@ -357,12 +357,18 @@ class KPFCNN(nn.Module): :return: loss """ + # Set all ignored labels to -1 and correct the other label to be in [0, C-1] range + target = - torch.ones_like(labels) + for i, c in enumerate(self.valid_labels): + target[labels == c] = i + + # Reshape to have a minibatch size of 1 outputs = torch.transpose(outputs, 0, 1) outputs = outputs.unsqueeze(0) - labels = 
labels.unsqueeze(0) + target = target.unsqueeze(0) # Cross entropy loss - self.output_loss = self.criterion(outputs, labels) + self.output_loss = self.criterion(outputs, target) # Regularization of deformable offsets self.reg_loss = self.offset_regularizer() @@ -370,8 +376,7 @@ class KPFCNN(nn.Module): # Combined loss return self.output_loss + self.reg_loss - @staticmethod - def accuracy(outputs, labels): + def accuracy(self, outputs, labels): """ Computes accuracy of the current batch :param outputs: logits predicted by the network @@ -379,9 +384,14 @@ class KPFCNN(nn.Module): :return: accuracy value """ + # Set all ignored labels to -1 and correct the other label to be in [0, C-1] range + target = - torch.ones_like(labels) + for i, c in enumerate(self.valid_labels): + target[labels == c] = i + predicted = torch.argmax(outputs.data, dim=1) - total = labels.size(0) - correct = (predicted == labels).sum().item() + total = target.size(0) + correct = (predicted == target).sum().item() return correct / total diff --git a/plot_convergence.py b/plot_convergence.py index 8c07821..8b0abe1 100644 --- a/plot_convergence.py +++ b/plot_convergence.py @@ -39,6 +39,7 @@ from utils.ply import read_ply # Datasets from datasets.ModelNet40 import ModelNet40Dataset from datasets.S3DIS import S3DISDataset +from datasets.SemanticKitti import SemanticKittiDataset # ---------------------------------------------------------------------------------------------------------------------- # @@ -239,7 +240,7 @@ def load_multi_snap_clouds(path, dataset, file_i, only_last=False): else: for f in listdir(cloud_folder): if f.endswith('.ply') and not f.endswith('sub.ply'): - if np.any([cloud_path.endswith(f) for cloud_path in dataset.train_files]): + if np.any([cloud_path.endswith(f) for cloud_path in dataset.files]): data = read_ply(join(cloud_folder, f)) labels = data['class'] preds = data['preds'] @@ -971,20 +972,21 @@ def compare_convergences_SLAM(dataset, list_of_paths, list_of_names=None): class_list = [dataset.label_to_names[label] for label in dataset.label_values if label not in dataset.ignored_labels] - s = '{:^10}|'.format('mean') + s = '{:^6}|'.format('mean') for c in class_list: - s += '{:^10}'.format(c) + s += '{:^6}'.format(c[:4]) print(s) - print(10*'-' + '|' + 10*config.num_classes*'-') + print(6*'-' + '|' + 6*config.num_classes*'-') for path in list_of_paths: # Get validation IoUs + nc_model = dataset.num_classes - len(dataset.ignored_labels) file = join(path, 'val_IoUs.txt') - val_IoUs = load_single_IoU(file, config.num_classes) + val_IoUs = load_single_IoU(file, nc_model) # Get Subpart IoUs file = join(path, 'subpart_IoUs.txt') - subpart_IoUs = load_single_IoU(file, config.num_classes) + subpart_IoUs = load_single_IoU(file, nc_model) # Get mean IoU val_class_IoUs, val_mIoUs = IoU_class_metrics(val_IoUs, smooth_n) @@ -997,22 +999,21 @@ def compare_convergences_SLAM(dataset, list_of_paths, list_of_names=None): all_subpart_mIoUs += [subpart_mIoUs] all_subpart_class_IoUs += [subpart_class_IoUs] - s = '{:^10.1f}|'.format(100*subpart_mIoUs[-1]) + s = '{:^6.1f}|'.format(100*subpart_mIoUs[-1]) for IoU in subpart_class_IoUs[-1]: - s += '{:^10.1f}'.format(100*IoU) + s += '{:^6.1f}'.format(100*IoU) print(s) - - print(10*'-' + '|' + 10*config.num_classes*'-') + print(6*'-' + '|' + 6*config.num_classes*'-') for snap_IoUs in all_val_class_IoUs: if len(snap_IoUs) > 0: - s = '{:^10.1f}|'.format(100*np.mean(snap_IoUs[-1])) + s = '{:^6.1f}|'.format(100*np.mean(snap_IoUs[-1])) for IoU in snap_IoUs[-1]: - s += 
'{:^10.1f}'.format(100*IoU) + s += '{:^6.1f}'.format(100*IoU) else: - s = '{:^10s}'.format('-') + s = '{:^6s}'.format('-') for _ in range(config.num_classes): - s += '{:^10s}'.format('-') + s += '{:^6s}'.format('-') print(s) # Plots @@ -1038,7 +1039,7 @@ def compare_convergences_SLAM(dataset, list_of_paths, list_of_names=None): #ax.set_yticks(np.arange(0.8, 1.02, 0.02)) displayed_classes = [0, 1, 2, 3, 4, 5, 6, 7] - displayed_classes = [] + #displayed_classes = [] for c_i, c_name in enumerate(class_list): if c_i in displayed_classes: @@ -1410,14 +1411,14 @@ def S3DIS_first(old_result_limit): return logs, logs_names -def S3DIS_(old_result_limit): +def S3DIS_go(old_result_limit): """ Test S3DIS. """ # Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset. start = 'Log_2020-04-03_11-12-07' - end = 'Log_2020-04-25_19-30-17' + end = 'Log_2020-04-07_15-30-17' if end < old_result_limit: res_path = 'old_results' @@ -1430,6 +1431,11 @@ def S3DIS_(old_result_limit): # Give names to the logs (for legends) logs_names = ['R=2.0_r=0.04_Din=128_potential', 'R=2.0_r=0.04_Din=64_potential', + 'R=1.8_r=0.03', + 'R=1.8_r=0.03_deeper', + 'R=1.8_r=0.03_deform', + 'R=2.0_r=0.03_megadeep', + 'R=2.5_r=0.03_megadeep', 'test'] logs_names = np.array(logs_names[:len(logs)]) @@ -1437,17 +1443,52 @@ def S3DIS_(old_result_limit): return logs, logs_names +def SemanticKittiFirst(old_result_limit): + """ + Test SematicKitti. First exps + """ + + # Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset. + start = 'Log_2020-04-07_15-30-17' + end = 'Log_2020-05-07_15-30-17' + + if end < old_result_limit: + res_path = 'old_results' + else: + res_path = 'results' + + logs = np.sort([join(res_path, l) for l in listdir(res_path) if start <= l <= end]) + logs = logs.astype(' 'last_XXX': Automatically retrieve the last trained model on dataset XXX + # > '(old_)results/Log_YYYY-MM-DD_HH-MM-SS': Directly provide the path of a trained model + + chosen_log = 'results/Log_2020-04-07_18-22-18' # => ModelNet40 + + # You can also choose the index of the snapshot to load (last by default) + chkp_idx = None + + # Choose to test on validation or test split + on_val = True + + # Deal with 'last_XXXXXX' choices + chosen_log = model_choice(chosen_log) + + ############################ + # Initialize the environment + ############################ + + # Set which gpu is going to be used + GPU_ID = '3' + + # Set GPU visible device + os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID + + ############### + # Previous chkp + ############### + + # Find all checkpoints in the chosen training folder + chkp_path = os.path.join(chosen_log, 'checkpoints') + chkps = [f for f in os.listdir(chkp_path) if f[:4] == 'chkp'] + + # Find which snapshot to restore + if chkp_idx is None: + chosen_chkp = 'current_chkp.tar' + else: + chosen_chkp = np.sort(chkps)[chkp_idx] + chosen_chkp = os.path.join(chosen_log, 'checkpoints', chosen_chkp) + + # Initialize configuration class + config = Config() + config.load(chosen_log) + + ################################## + # Change model parameters for test + ################################## + + # Change parameters for the test here. For example, you can stop augmenting the input data. 
+ + #config.augment_noise = 0.0001 + #config.augment_symmetries = False + #config.batch_num = 3 + #config.in_radius = 4 + config.validation_size = 200 + config.input_threads = 0 + + ############## + # Prepare Data + ############## + + print() + print('Data Preparation') + print('****************') + + if on_val: + set = 'validation' + else: + set = 'test' + + # Initiate dataset + if config.dataset.startswith('ModelNet40'): + test_dataset = ModelNet40Dataset(config, train=False) + test_sampler = ModelNet40Sampler(test_dataset) + collate_fn = ModelNet40Collate + elif config.dataset == 'S3DIS': + test_dataset = S3DISDataset(config, set='validation', use_potentials=True) + test_sampler = S3DISSampler(test_dataset) + collate_fn = S3DISCollate + elif config.dataset == 'SemanticKitti': + test_dataset = SemanticKittiDataset(config, set=set, balance_classes=False) + test_sampler = SemanticKittiSampler(test_dataset) + collate_fn = SemanticKittiCollate + else: + raise ValueError('Unsupported dataset : ' + config.dataset) + + # Data loader + test_loader = DataLoader(test_dataset, + batch_size=1, + sampler=test_sampler, + collate_fn=collate_fn, + num_workers=config.input_threads, + pin_memory=True) + + # Calibrate samplers + test_sampler.calibration(test_loader, verbose=True) + + print('\nModel Preparation') + print('*****************') + + # Define network model + t1 = time.time() + if config.dataset_task == 'classification': + net = KPCNN(config) + elif config.dataset_task in ['cloud_segmentation', 'slam_segmentation']: + net = KPFCNN(config, test_dataset.label_values, test_dataset.ignored_labels) + else: + raise ValueError('Unsupported dataset_task for testing: ' + config.dataset_task) + + # Define a visualizer class + tester = ModelTester(net, chkp_path=chosen_chkp) + print('Done in {:.1f}s\n'.format(time.time() - t1)) + + print('\nStart test') + print('**********\n') + + # Training + if config.dataset_task == 'classification': + a = 1/0 + elif config.dataset_task == 'cloud_segmentation': + tester.cloud_segmentation_test(net, test_loader, config) + elif config.dataset_task == 'slam_segmentation': + tester.slam_segmentation_test(net, test_loader, config) + else: + raise ValueError('Unsupported dataset_task for testing: ' + config.dataset_task) + + + # TODO: For test and also for training. When changing epoch do not restart the worker initiation. Keep workers + # active with a while loop instead of using for loops. + # For training and validation, keep two sets of worker active in parallel? is it possible? + + # TODO: We have to verify if training on smaller spheres and testing on whole frame changes the score because + # batchnorm may not have the same result as distribution of points will be different. 
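[Editor's note, not part of the patch: a minimal sketch of the SemanticKitti label decoding used in the dataset class above, for reviewers unfamiliar with the format. Each int32 in a .label file packs an instance id in the upper 16 bits and a raw semantic id in the lower 16 bits; the raw id is then remapped through the dense lookup table built from semantic-kitti.yaml. The two mapping entries and the file path below are illustrative only.]

    import numpy as np

    # Dense lookup table, built the same way as in SemanticKittiDataset.__init__
    # from the 'learning_map' dict of semantic-kitti.yaml (only two entries here)
    learning_map = np.zeros(260, dtype=np.int32)
    learning_map[10] = 1   # raw 'car'  -> training class 1
    learning_map[40] = 9   # raw 'road' -> training class 9

    def decode_labels(label_path):
        raw = np.fromfile(label_path, dtype=np.int32)
        sem_labels = raw & 0xFFFF    # semantic label in lower half
        inst_labels = raw >> 16      # instance id in upper half
        return learning_map[sem_labels], inst_labels

    # sem, inst = decode_labels('sequences/08/labels/000000.label')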
diff --git a/train_S3DIS.py b/train_S3DIS.py index bf5cef4..250289a 100644 --- a/train_S3DIS.py +++ b/train_S3DIS.py @@ -73,12 +73,23 @@ class S3DISConfig(Config): 'resnetb', 'resnetb_strided', 'resnetb', - 'resnetb_strided', + 'resnetb', 'resnetb', 'resnetb_strided', 'resnetb', + 'resnetb', + 'resnetb', + 'resnetb', + 'resnetb', 'resnetb_strided', 'resnetb', + 'resnetb', + 'resnetb', + 'resnetb', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', 'nearest_upsample', 'unary', 'nearest_upsample', @@ -93,13 +104,13 @@ class S3DISConfig(Config): ################### # Radius of the input sphere - in_radius = 2.0 + in_radius = 2.5 # Number of kernel points num_kernel_points = 15 # Size of the first subsampling grid in meters - first_subsampling_dl = 0.04 + first_subsampling_dl = 0.03 # Radius of convolution in "number grid cell". (2.5 is the standard value) conv_radius = 2.5 @@ -108,7 +119,7 @@ class S3DISConfig(Config): deform_radius = 6.0 # Radius of the area of influence of each kernel point in "number grid cell". (1.0 is the standard value) - KP_extent = 1.2 + KP_extent = 1.5 # Behavior of convolutions in ('constant', 'linear', 'gaussian') KP_influence = 'linear' @@ -117,7 +128,7 @@ class S3DISConfig(Config): aggregation_mode = 'sum' # Choice of input features - first_features_dim = 64 + first_features_dim = 128 in_features_dim = 5 # Can the network learn modulations @@ -143,17 +154,17 @@ class S3DISConfig(Config): # Learning rate management learning_rate = 1e-2 momentum = 0.98 - lr_decays = {i: 0.1**(1/100) for i in range(1, max_epoch)} + lr_decays = {i: 0.1 ** (1 / 150) for i in range(1, max_epoch)} grad_clip_norm = 100.0 # Number of batches - batch_num = 10 + batch_num = 4 # Number of steps per epoch epoch_steps = 500 # Number of validation examples per epoch - validation_size = 30 + validation_size = 50 # Number of epochs between each checkpoint checkpoint_gap = 50 @@ -191,7 +202,7 @@ if __name__ == '__main__': ############################ # Set which gpu is going to be used - GPU_ID = '2' + GPU_ID = '3' # Set GPU visible device os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID @@ -201,7 +212,7 @@ if __name__ == '__main__': ############### # Choose here if you want to start training from a previous snapshot (None for new training) - #previous_training_path = 'Log_2020-03-19_19-53-27' + # previous_training_path = 'Log_2020-03-19_19-53-27' previous_training_path = '' # Choose index of checkpoint to start from. If None, uses the latest chkp
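The lr_decays dictionary above applies a factor of 0.1 ** (1 / 150) after every epoch, i.e. the learning rate smoothly loses one order of magnitude every 150 epochs. A quick standalone check of that arithmetic:

lr = 1e-2                 # learning_rate above
decay = 0.1 ** (1 / 150)  # per-epoch factor from lr_decays
for epoch in range(1, 451):
    lr *= decay
    if epoch % 150 == 0:
        print('epoch {:d}: lr = {:.2e}'.format(epoch, lr))
# prints 1.00e-03, 1.00e-04, 1.00e-05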
@@ -266,16 +277,16 @@ if __name__ == '__main__': training_sampler.calibration(training_loader, verbose=True) test_sampler.calibration(test_loader, verbose=True) - #debug_timing(training_dataset, training_loader) - #debug_timing(test_dataset, test_loader) - #debug_upsampling(training_dataset, training_loader) + # debug_timing(training_dataset, training_loader) + # debug_timing(test_dataset, test_loader) + # debug_upsampling(training_dataset, training_loader) print('\nModel Preparation') print('*****************') # Define network model t1 = time.time() - net = KPFCNN(config) + net = KPFCNN(config, training_dataset.label_values, training_dataset.ignored_labels) debug = False if debug: @@ -297,14 +308,7 @@ if __name__ == '__main__': print('**************') # Training - try: - trainer.train(net, training_loader, test_loader, config) - except: - print('Caught an error') - os.kill(os.getpid(), signal.SIGINT) + trainer.train(net, training_loader, test_loader, config) print('Forcing exit now') os.kill(os.getpid(), signal.SIGINT) - - - diff --git a/train_SemanticKitti.py b/train_SemanticKitti.py new file mode 100644 index 0000000..93292fe --- /dev/null +++ b/train_SemanticKitti.py @@ -0,0 +1,321 @@ +# +# +# 0=================================0 +# | Kernel Point Convolutions | +# 0=================================0 +# +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Callable script to start a training on SemanticKitti dataset +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Hugues THOMAS - 06/03/2020 +# + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Imports and global variables +# \**********************************/ +# + +# Common libs +import signal +import os +import numpy as np +import sys +import torch + +# Dataset +from datasets.SemanticKitti import * +from torch.utils.data import DataLoader + +from utils.config import Config +from utils.trainer import ModelTrainer +from models.architectures import KPFCNN + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Config Class +# \******************/ +# + +class SemanticKittiConfig(Config): + """ + Override the parameters you want to modify for this dataset + """ + + #################### + # Dataset parameters + #################### + + # Dataset name + dataset = 'SemanticKitti' + + # Number of classes in the dataset (This value is overwritten by the dataset class when initializing the dataset). 
+ num_classes = None + + # Type of task performed on this dataset (also overwritten) + dataset_task = '' + + # Number of CPU threads for the input pipeline + input_threads = 20 + + ######################### + # Architecture definition + ######################### + + # Define layers + architecture = ['simple', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', + 'resnetb', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', + 'nearest_upsample', + 'unary', + 'nearest_upsample', + 'unary', + 'nearest_upsample', + 'unary', + 'nearest_upsample', + 'unary'] + + ################### + # KPConv parameters + ################### + + # Radius of the input sphere + in_radius = 10.0 + val_radius = 51.0 + n_frames = 1 + max_in_points = 10000 + max_val_points = 50000 + + # Number of batches + batch_num = 6 + val_batch_num = 1 + + # Number of kernel points + num_kernel_points = 15 + + # Size of the first subsampling grid in meters + first_subsampling_dl = 0.08 + + # Radius of convolution in "number grid cell". (2.5 is the standard value) + conv_radius = 2.5 + + # Radius of deformable convolution in "number grid cell". Larger so that deformed kernel can spread out + deform_radius = 6.0 + + # Radius of the area of influence of each kernel point in "number grid cell". (1.0 is the standard value) + KP_extent = 1.5 + + # Behavior of convolutions in ('constant', 'linear', 'gaussian') + KP_influence = 'linear' + + # Aggregation function of KPConv in ('closest', 'sum') + aggregation_mode = 'sum' + + # Choice of input features + first_features_dim = 128 + in_features_dim = 5 + + # Can the network learn modulations + modulated = False + + # Batch normalization parameters + use_batch_norm = True + batch_norm_momentum = 0.02 + + # Offset loss + # 'permissive' only constrains offsets inside the deform radius (NOT implemented yet) + # 'fitting' helps deformed kernels to adapt to the geometry by penalizing distance to input points + offsets_loss = 'fitting' + offsets_decay = 0.01 + + ##################### + # Training parameters + ##################### + + # Maximal number of epochs + max_epoch = 500 + + # Learning rate management + learning_rate = 1e-2 + momentum = 0.98 + lr_decays = {i: 0.1 ** (1 / 100) for i in range(1, max_epoch)} + grad_clip_norm = 100.0 + + # Number of steps per epoch + epoch_steps = 500 + + # Number of validation examples per epoch + validation_size = 50 + + # Number of epochs between each checkpoint + checkpoint_gap = 50 + + # Augmentations + augment_scale_anisotropic = True + augment_symmetries = [True, False, False] + augment_rotation = 'vertical' + augment_scale_min = 0.8 + augment_scale_max = 1.2 + augment_noise = 0.001 + augment_color = 0.8
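Several radii above are given in "number grid cell", so their metric value follows from first_subsampling_dl, which (by the usual KPConv convention) doubles at each strided block of the architecture. A small sketch of that arithmetic with the values above, for illustration only:

dl = 0.08           # first_subsampling_dl
conv_radius = 2.5   # in "number grid cell"
deform_radius = 6.0

for layer in range(5):  # one line per subsampling level
    print('layer {:d}: dl = {:.2f}m, conv = {:.2f}m, deform = {:.2f}m'.format(
        layer, dl, conv_radius * dl, deform_radius * dl))
    dl *= 2  # grid size doubles after each strided block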
+ # Choose weights for class (used in segmentation loss). Empty list for no weights + class_w = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + + # Do we need to save convergence + saving = True + saving_path = None + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Main Call +# \***************/ +# + +if __name__ == '__main__': + + ############################ + # Initialize the environment + ############################ + + # Set which gpu is going to be used + GPU_ID = '2' + + # Set GPU visible device + os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID + + ############### + # Previous chkp + ############### + + # Choose here if you want to start training from a previous snapshot (None for new training) + # previous_training_path = 'Log_2020-03-19_19-53-27' + previous_training_path = '' + + # Choose index of checkpoint to start from. If None, uses the latest chkp + chkp_idx = None + if previous_training_path: + + # Find all snapshots in the chosen training folder + chkp_path = os.path.join('results', previous_training_path, 'checkpoints') + chkps = [f for f in os.listdir(chkp_path) if f[:4] == 'chkp'] + + # Find which snapshot to restore + if chkp_idx is None: + chosen_chkp = 'current_chkp.tar' + else: + chosen_chkp = np.sort(chkps)[chkp_idx] + chosen_chkp = os.path.join('results', previous_training_path, 'checkpoints', chosen_chkp) + + else: + chosen_chkp = None + + ############## + # Prepare Data + ############## + + print() + print('Data Preparation') + print('****************') + + # Initialize configuration class + config = SemanticKittiConfig() + if previous_training_path: + config.load(os.path.join('results', previous_training_path)) + config.saving_path = None + + # Get path from argument if given + if len(sys.argv) > 1: + config.saving_path = sys.argv[1] + + # Initialize datasets + training_dataset = SemanticKittiDataset(config, set='training', + balance_classes=True) + test_dataset = SemanticKittiDataset(config, set='validation', + balance_classes=False) + + # Initialize samplers + training_sampler = SemanticKittiSampler(training_dataset) + test_sampler = SemanticKittiSampler(test_dataset) + + # Initialize the dataloader + training_loader = DataLoader(training_dataset, + batch_size=1, + sampler=training_sampler, + collate_fn=SemanticKittiCollate, + num_workers=config.input_threads, + pin_memory=True) + test_loader = DataLoader(test_dataset, + batch_size=1, + sampler=test_sampler, + collate_fn=SemanticKittiCollate, + num_workers=config.input_threads, + pin_memory=True) + + # Calibrate max_in_point value + training_sampler.calib_max_in(config, training_loader, verbose=False) + test_sampler.calib_max_in(config, test_loader, verbose=False) + + # Calibrate samplers + training_sampler.calibration(training_loader, verbose=True) + test_sampler.calibration(test_loader, verbose=True) + + # debug_timing(training_dataset, training_loader) + # debug_timing(test_dataset, test_loader) + debug_class_w(training_dataset, training_loader) + + print('\nModel Preparation') + print('*****************') + + # Define network model + t1 = time.time() + net = KPFCNN(config, training_dataset.label_values, training_dataset.ignored_labels) + + debug = False + if debug: + print('\n*************************************\n') + print(net) + print('\n*************************************\n') + for param in net.parameters(): + if param.requires_grad: + print(param.shape) + print('\n*************************************\n') + print("Model size %i" % sum(param.numel() for param in net.parameters() if param.requires_grad)) + print('\n*************************************\n') + + # Define a trainer class + trainer = ModelTrainer(net, config, chkp_path=chosen_chkp) + print('Done in {:.1f}s\n'.format(time.time() - t1)) + + print('\nStart training') + print('**************') + + # Training + trainer.train(net, training_loader, test_loader, config) + + print('Forcing exit now') + os.kill(os.getpid(), signal.SIGINT) + + # TODO: Create a function debug_class_weights that shows class distribution in the input spheres. Use that as + # an indication for the class weights during training
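class_w above is a flat list of ones (one weight per SemanticKitti class). The TODO just above suggests using the class distribution in the input spheres as an indication instead; one common recipe, sketched here with hypothetical proportions (not values from the repo), is inverse-frequency weighting normalized to mean one:

import numpy as np

# Hypothetical class proportions measured over input spheres (sum to 1)
proportions = np.array([0.45, 0.30, 0.15, 0.07, 0.03])

class_w = 1.0 / (proportions + 1e-6)  # rare classes get larger weights
class_w = class_w / np.mean(class_w)  # keep the overall loss scale unchanged
print(np.round(class_w, 2))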
diff --git a/utils/config.py b/utils/config.py index ff092d6..12f0bbd 100644 --- a/utils/config.py +++ b/utils/config.py @@ -117,8 +117,10 @@ class Config: # For SLAM datasets like SemanticKitti number of frames used (minimum one) n_frames = 1 - # For SLAM datasets like SemanticKitti max number of point in input cloud + # For SLAM datasets like SemanticKitti max number of points in input cloud + validation max_in_points = 0 + val_radius = 51.0 + max_val_points = 50000 ##################### # Training parameters @@ -151,18 +153,19 @@ class Config: # Regularization loss importance weight_decay = 1e-3 - # The way we balance segmentation loss - # > 'none': Each point in the whole batch has the same contribution. - # > 'class': Each class has the same contribution (points are weighted according to class balance) - # > 'batch': Each cloud in the batch has the same contribution (points are weighted according cloud sizes) + # The way we balance segmentation loss DEPRECATED segloss_balance = 'none' + # Choose weights for class (used in segmentation loss). Empty list for no weights + class_w = [] + # Offset regularization loss offsets_loss = 'permissive' offsets_decay = 1e-2 # Number of batches batch_num = 10 + val_batch_num = 10 # Maximal number of epochs max_epoch = 1000 @@ -253,6 +256,9 @@ class Config: else: self.num_classes = int(line_info[2]) + elif line_info[0] == 'class_w': + self.class_w = [float(w) for w in line_info[2:]] + else: attr_type = type(getattr(self, line_info[0])) if attr_type == bool: @@ -320,6 +326,8 @@ class Config: text_file.write('modulated = {:d}\n'.format(int(self.modulated))) text_file.write('n_frames = {:d}\n'.format(self.n_frames)) text_file.write('max_in_points = {:d}\n\n'.format(self.max_in_points)) + text_file.write('max_val_points = {:d}\n\n'.format(self.max_val_points)) + text_file.write('val_radius = {:.3f}\n\n'.format(self.val_radius)) # Training parameters text_file.write('# Training parameters\n') @@ -350,9 +358,14 @@ class Config: text_file.write('weight_decay = {:f}\n'.format(self.weight_decay)) text_file.write('segloss_balance = {:s}\n'.format(self.segloss_balance)) + text_file.write('class_w =') + for a in self.class_w: + text_file.write(' {:.3f}'.format(a)) + text_file.write('\n') text_file.write('offsets_loss = {:s}\n'.format(self.offsets_loss)) text_file.write('offsets_decay = {:f}\n'.format(self.offsets_decay)) text_file.write('batch_num = {:d}\n'.format(self.batch_num)) + text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num)) text_file.write('max_epoch = {:d}\n'.format(self.max_epoch)) if self.epoch_steps is None: text_file.write('epoch_steps = None\n') diff --git a/utils/tester.py b/utils/tester.py new file mode 100644 index 0000000..a5eb2b9 --- /dev/null +++ b/utils/tester.py @@ -0,0 +1,688 @@ +# +# +# 
---------------------------------------------------------------------------------------------------------------------- +# +# Class handling the test of any model +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Hugues THOMAS - 11/06/2018 +# + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Imports and global variables +# \**********************************/ +# + + +# Basic libs +import torch +import torch.nn as nn +import numpy as np +from os import makedirs, listdir +from os.path import exists, join +import time +import json +from sklearn.neighbors import KDTree + +# PLY reader +from utils.ply import read_ply, write_ply + +# Metrics +from utils.metrics import IoU_from_confusions, fast_confusion +from sklearn.metrics import confusion_matrix + +#from utils.visualizer import show_ModelNet_models + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Tester Class +# \******************/ +# + + +class ModelTester: + + # Initialization methods + # ------------------------------------------------------------------------------------------------------------------ + + def __init__(self, net, chkp_path=None, on_gpu=True): + + ############ + # Parameters + ############ + + # Choose to test on CPU or GPU + if on_gpu and torch.cuda.is_available(): + self.device = torch.device("cuda:0") + else: + self.device = torch.device("cpu") + net.to(self.device) + + ########################## + # Load previous checkpoint + ########################## + + checkpoint = torch.load(chkp_path) + net.load_state_dict(checkpoint['model_state_dict']) + self.epoch = checkpoint['epoch'] + net.eval() + print("Model and training state restored.") + + return + + # Test main methods + # ------------------------------------------------------------------------------------------------------------------ + + def cloud_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False): + """ + Test method for cloud segmentation models + """ + + ############ + # Initialize + ############ + + # Choose test smoothing parameter (0 for no smoothing, 0.99 for big smoothing) + test_smooth = 0.98 + softmax = torch.nn.Softmax(1) + + # Number of classes including ignored labels + nc_tot = test_loader.dataset.num_classes + + # Number of classes predicted by the model + nc_model = config.num_classes + + # Initiate global prediction over test clouds + self.test_probs = [np.zeros((l.shape[0], nc_model)) for l in test_loader.dataset.input_labels] + + # Test saving path + if config.saving: + test_path = join('test', config.saving_path.split('/')[-1]) + if not exists(test_path): + makedirs(test_path) + if not exists(join(test_path, 'predictions')): + makedirs(join(test_path, 'predictions')) + if not exists(join(test_path, 'probs')): + makedirs(join(test_path, 'probs')) + if not exists(join(test_path, 'potentials')): + makedirs(join(test_path, 'potentials')) + else: + test_path = None + + # If on validation directly compute score + if test_loader.dataset.set == 'validation': + val_proportions = np.zeros(nc_model, dtype=np.float32) + i = 0 + for label_value in test_loader.dataset.label_values: + if label_value not in test_loader.dataset.ignored_labels: + val_proportions[i] = np.sum([np.sum(labels == label_value) + for labels in test_loader.dataset.validation_labels]) + i += 1 + else: + val_proportions = None 
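The test_smooth parameter defined above drives an exponential moving average over successive votes: each time a point is seen, its stored probabilities become test_smooth * old + (1 - test_smooth) * new. A standalone sketch showing how the zero initialization fades out under this update:

import numpy as np

test_smooth = 0.98
probs = np.zeros(3)               # stored probs for one point, zero-initialized
vote = np.array([0.1, 0.7, 0.2])  # network output for that point (kept constant here)

for n in range(1, 301):
    probs = test_smooth * probs + (1 - test_smooth) * vote
    if n in (10, 100, 300):
        print('after {:3d} votes: {:s}'.format(n, np.array2string(probs, precision=3)))
# the initialization keeps weight 0.98 ** n, about 0.2% after 300 votes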
+ + ##################### + # Network predictions + ##################### + + test_epoch = 0 + last_min = -0.5 + + t = [time.time()] + last_display = time.time() + mean_dt = np.zeros(1) + + # Start test loop + while True: + print('Initialize workers') + for i, batch in enumerate(test_loader): + + # New time + t = t[-1:] + t += [time.time()] + + if i == 0: + print('Done in {:.1f}s'.format(t[1] - t[0])) + + if 'cuda' in self.device.type: + batch.to(self.device) + + # Forward pass + outputs = net(batch, config) + + t += [time.time()] + + # Get probs and labels + stacked_probs = softmax(outputs).cpu().detach().numpy() + lengths = batch.lengths[0].cpu().numpy() + in_inds = batch.input_inds.cpu().numpy() + cloud_inds = batch.cloud_inds.cpu().numpy() + torch.cuda.synchronize(self.device) + + # Get predictions and labels per instance + # *************************************** + + i0 = 0 + for b_i, length in enumerate(lengths): + + # Get prediction + probs = stacked_probs[i0:i0 + length] + inds = in_inds[i0:i0 + length] + c_i = cloud_inds[b_i] + + # Update current probs in whole cloud + self.test_probs[c_i][inds] = test_smooth * self.test_probs[c_i][inds] + (1 - test_smooth) * probs + i0 += length + + # Average timing + t += [time.time()] + if i < 2: + mean_dt = np.array(t[1:]) - np.array(t[:-1]) + else: + mean_dt = 0.9 * mean_dt + 0.1 * (np.array(t[1:]) - np.array(t[:-1])) + + # Display + if (t[-1] - last_display) > 1.0: + last_display = t[-1] + message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f})' + print(message.format(test_epoch, i, + 100 * i / config.validation_size, + 1000 * (mean_dt[0]), + 1000 * (mean_dt[1]), + 1000 * (mean_dt[2]))) + + # Update minimum of potentials + new_min = torch.min(test_loader.dataset.min_potentials) + print('Test epoch {:d}, end. 
Min potential = {:.1f}'.format(test_epoch, new_min)) + #print([np.mean(pots) for pots in test_loader.dataset.potentials]) + + # Save predicted cloud + if last_min + 1 < new_min: + + # Update last_min + last_min += 1 + + # Show vote results (computed on sub-clouds, so these are not the exact values) + if test_loader.dataset.set == 'validation': + print('\nConfusion on sub clouds') + Confs = [] + for i, file_path in enumerate(test_loader.dataset.files): + + # Insert false columns for ignored labels + probs = np.array(self.test_probs[i], copy=True) + for l_ind, label_value in enumerate(test_loader.dataset.label_values): + if label_value in test_loader.dataset.ignored_labels: + probs = np.insert(probs, l_ind, 0, axis=1) + + # Predicted labels + preds = test_loader.dataset.label_values[np.argmax(probs, axis=1)].astype(np.int32) + + # Targets + targets = test_loader.dataset.input_labels[i] + + # Confs + Confs += [fast_confusion(targets, preds, test_loader.dataset.label_values)] + + # Regroup confusions + C = np.sum(np.stack(Confs), axis=0).astype(np.float32) + + # Remove ignored labels from confusions + for l_ind, label_value in reversed(list(enumerate(test_loader.dataset.label_values))): + if label_value in test_loader.dataset.ignored_labels: + C = np.delete(C, l_ind, axis=0) + C = np.delete(C, l_ind, axis=1) + + # Rescale with the right number of points per class + C *= np.expand_dims(val_proportions / (np.sum(C, axis=1) + 1e-6), 1) + + # Compute IoUs + IoUs = IoU_from_confusions(C) + mIoU = np.mean(IoUs) + s = '{:5.2f} | '.format(100 * mIoU) + for IoU in IoUs: + s += '{:5.2f} '.format(100 * IoU) + print(s + '\n') + + # Save real IoU once in a while + if int(np.ceil(new_min)) % 10 == 0: + + # Project predictions + print('\nReproject Vote #{:d}'.format(int(np.floor(new_min)))) + t1 = time.time() + proj_probs = [] + for i, file_path in enumerate(test_loader.dataset.files): + + print(i, file_path, test_loader.dataset.test_proj[i].shape, self.test_probs[i].shape) + + print(test_loader.dataset.test_proj[i].dtype, np.max(test_loader.dataset.test_proj[i])) + print(test_loader.dataset.test_proj[i][:5]) + + # Reproject probs on the evaluation points + probs = self.test_probs[i][test_loader.dataset.test_proj[i], :] + proj_probs += [probs] + + t2 = time.time() + print('Done in {:.1f} s\n'.format(t2 - t1)) + + # Show vote results + if test_loader.dataset.set == 'validation': + print('Confusion on full clouds') + t1 = time.time() + Confs = [] + for i, file_path in enumerate(test_loader.dataset.files): + + # Insert false columns for ignored labels + for l_ind, label_value in enumerate(test_loader.dataset.label_values): + if label_value in test_loader.dataset.ignored_labels: + proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1) + + # Get the predicted labels + preds = test_loader.dataset.label_values[np.argmax(proj_probs[i], axis=1)].astype(np.int32) + + # Confusion + targets = test_loader.dataset.validation_labels[i] + Confs += [fast_confusion(targets, preds, test_loader.dataset.label_values)] + + t2 = time.time() + print('Done in {:.1f} s\n'.format(t2 - t1)) + + # Regroup confusions + C = np.sum(np.stack(Confs), axis=0) + + # Remove ignored labels from confusions + for l_ind, label_value in reversed(list(enumerate(test_loader.dataset.label_values))): + if label_value in test_loader.dataset.ignored_labels: + C = np.delete(C, l_ind, axis=0) + C = np.delete(C, l_ind, axis=1) + + IoUs = IoU_from_confusions(C) + mIoU = np.mean(IoUs) + s = '{:5.2f} | '.format(100 * mIoU) + for IoU in IoUs: + s += '{:5.2f} 
'.format(100 * IoU) + print('-' * len(s)) + print(s) + print('-' * len(s) + '\n') + + # Save predictions + print('Saving clouds') + t1 = time.time() + for i, file_path in enumerate(test_loader.dataset.files): + + # Get file + points = test_loader.dataset.load_evaluation_points(file_path) + + # Get the predicted labels + preds = test_loader.dataset.label_values[np.argmax(proj_probs[i], axis=1)].astype(np.int32) + + # Save plys + cloud_name = file_path.split('/')[-1] + test_name = join(test_path, 'predictions', cloud_name) + write_ply(test_name, + [points, preds], + ['x', 'y', 'z', 'preds']) + test_name2 = join(test_path, 'probs', cloud_name) + prob_names = ['_'.join(test_loader.dataset.label_to_names[label].split()) + for label in test_loader.dataset.label_values] + write_ply(test_name2, + [points, proj_probs[i]], + ['x', 'y', 'z'] + prob_names) + + # Save potentials + pot_points = np.array(test_loader.dataset.pot_trees[i].data, copy=False) + pot_name = join(test_path, 'potentials', cloud_name) + pots = test_loader.dataset.potentials[i].numpy().astype(np.float32) + write_ply(pot_name, + [pot_points.astype(np.float32), pots], + ['x', 'y', 'z', 'pots']) + + # Save ascii preds + if test_loader.dataset.set == 'test': + if test_loader.dataset.name.startswith('Semantic3D'): + ascii_name = join(test_path, 'predictions', test_loader.dataset.ascii_files[cloud_name]) + else: + ascii_name = join(test_path, 'predictions', cloud_name[:-4] + '.txt') + np.savetxt(ascii_name, preds, fmt='%d') + + t2 = time.time() + print('Done in {:.1f} s\n'.format(t2 - t1)) + + test_epoch += 1 + + # Break when reaching number of desired votes + if last_min > num_votes: + break + + return + + def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False): + """ + Test method for slam segmentation models + """ + + ############ + # Initialize + ############ + + # Choose validation smoothing parameter (0 for no smoothing, 0.99 for big smoothing) + test_smooth = 0 + last_min = -0.5 + softmax = torch.nn.Softmax(1) + + # Number of classes including ignored labels + nc_tot = test_loader.dataset.num_classes + nc_model = net.C + + # Test saving path + test_path = None + report_path = None + if config.saving: + test_path = join('test', config.saving_path.split('/')[-1]) + if not exists(test_path): + makedirs(test_path) + report_path = join(test_path, 'reports') + if not exists(report_path): + makedirs(report_path) + + if test_loader.dataset.set == 'validation': + for folder in ['val_predictions', 'val_probs']: + if not exists(join(test_path, folder)): + makedirs(join(test_path, folder)) + else: + for folder in ['predictions', 'probs']: + if not exists(join(test_path, folder)): + makedirs(join(test_path, folder)) + + # Init validation container + all_f_preds = [] + all_f_labels = [] + if test_loader.dataset.set == 'validation': + for i, seq_frames in enumerate(test_loader.dataset.frames): + all_f_preds.append([np.zeros((0,), dtype=np.int32) for _ in seq_frames]) + all_f_labels.append([np.zeros((0,), dtype=np.int32) for _ in seq_frames]) + + ##################### + # Network predictions + ##################### + + predictions = [] + targets = [] + test_epoch = 0 + + t = [time.time()] + last_display = time.time() + mean_dt = np.zeros(1) + + # Start test loop + while True: + print('Initialize workers') + for i, batch in enumerate(test_loader): + + # New time + t = t[-1:] + t += [time.time()] + + if i == 0: + print('Done in {:.1f}s'.format(t[1] - t[0])) + + if 'cuda' in self.device.type: + batch.to(self.device) + + 
# Forward pass + outputs = net(batch, config) + + # Get probs and labels + stk_probs = softmax(outputs).cpu().detach().numpy() + lengths = batch.lengths[0].cpu().numpy() + f_inds = batch.frame_inds.cpu().numpy() + r_inds_list = batch.reproj_inds + r_mask_list = batch.reproj_masks + labels_list = batch.val_labels + torch.cuda.synchronize(self.device) + + t += [time.time()] + + # Get predictions and labels per instance + # *************************************** + + i0 = 0 + for b_i, length in enumerate(lengths): + + # Get prediction + probs = stk_probs[i0:i0 + length] + proj_inds = r_inds_list[b_i] + proj_mask = r_mask_list[b_i] + frame_labels = labels_list[b_i] + s_ind = f_inds[b_i, 0] + f_ind = f_inds[b_i, 1] + + # Project predictions on the frame points + proj_probs = probs[proj_inds] + + # Safe check if only one point: + if proj_probs.ndim < 2: + proj_probs = np.expand_dims(proj_probs, 0) + + # Save probs in a binary file (uint8 format for lighter weight) + seq_name = test_loader.dataset.sequences[s_ind] + if test_loader.dataset.set == 'validation': + folder = 'val_probs' + pred_folder = 'val_predictions' + else: + folder = 'probs' + pred_folder = 'predictions' + filename = '{:s}_{:07d}.npy'.format(seq_name, f_ind) + filepath = join(test_path, folder, filename) + if exists(filepath): + frame_probs_uint8 = np.load(filepath) + else: + frame_probs_uint8 = np.zeros((proj_mask.shape[0], nc_model), dtype=np.uint8) + frame_probs = frame_probs_uint8[proj_mask, :].astype(np.float32) / 255 + frame_probs = test_smooth * frame_probs + (1 - test_smooth) * proj_probs + frame_probs_uint8[proj_mask, :] = (frame_probs * 255).astype(np.uint8) + np.save(filepath, frame_probs_uint8) + + # Save some predictions in ply format for visualization + if test_loader.dataset.set == 'validation': + + # Insert false columns for ignored labels + for l_ind, label_value in enumerate(test_loader.dataset.label_values): + if label_value in test_loader.dataset.ignored_labels: + frame_probs_uint8 = np.insert(frame_probs_uint8, l_ind, 0, axis=1) + + # Predicted labels + frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8, + axis=1)].astype(np.int32) + + # Save some of the frame pots + if f_ind % 20 == 0: + seq_path = join(test_loader.dataset.path, 'sequences', test_loader.dataset.sequences[s_ind]) + velo_file = join(seq_path, 'velodyne', test_loader.dataset.frames[s_ind][f_ind] + '.bin') + frame_points = np.fromfile(velo_file, dtype=np.float32) + frame_points = frame_points.reshape((-1, 4)) + predpath = join(test_path, pred_folder, filename[:-4] + '.ply') + #pots = test_loader.dataset.f_potentials[s_ind][f_ind] + pots = np.zeros((0,)) + if pots.shape[0] > 0: + write_ply(predpath, + [frame_points[:, :3], frame_labels, frame_preds, pots], + ['x', 'y', 'z', 'gt', 'pre', 'pots']) + else: + write_ply(predpath, + [frame_points[:, :3], frame_labels, frame_preds], + ['x', 'y', 'z', 'gt', 'pre']) + + # Keep frame preds in memory + all_f_preds[s_ind][f_ind] = frame_preds + all_f_labels[s_ind][f_ind] = frame_labels + + else: + + # Save some of the frame preds + if f_inds[b_i, 1] % 100 == 0: + + # Insert false columns for ignored labels + for l_ind, label_value in enumerate(test_loader.dataset.label_values): + if label_value in test_loader.dataset.ignored_labels: + frame_probs_uint8 = np.insert(frame_probs_uint8, l_ind, 0, axis=1) + + # Predicted labels + frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8, + axis=1)].astype(np.int32) + + # Load points + seq_path = join(test_loader.dataset.path, 
'sequences', test_loader.dataset.sequences[s_ind]) + velo_file = join(seq_path, 'velodyne', test_loader.dataset.frames[s_ind][f_ind] + '.bin') + frame_points = np.fromfile(velo_file, dtype=np.float32) + frame_points = frame_points.reshape((-1, 4)) + predpath = join(test_path, pred_folder, filename[:-4] + '.ply') + #pots = test_loader.dataset.f_potentials[s_ind][f_ind] + pots = np.zeros((0,)) + if pots.shape[0] > 0: + write_ply(predpath, + [frame_points[:, :3], frame_preds, pots], + ['x', 'y', 'z', 'pre', 'pots']) + else: + write_ply(predpath, + [frame_points[:, :3], frame_preds], + ['x', 'y', 'z', 'pre']) + + # Stack all predictions for this epoch + i0 += length + + # Average timing + t += [time.time()] + mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1])) + + # Display + if (t[-1] - last_display) > 1.0: + last_display = t[-1] + message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f}) / pots {:d} => {:.1f}%' + min_pot = int(torch.floor(torch.min(test_loader.dataset.potentials))) + pot_num = torch.sum(test_loader.dataset.potentials > min_pot).type(torch.int32).item() + current_num = pot_num + (i0 + 1 - config.validation_size) * config.val_batch_num + print(message.format(test_epoch, i, + 100 * i / config.validation_size, + 1000 * (mean_dt[0]), + 1000 * (mean_dt[1]), + 1000 * (mean_dt[2]), + min_pot, + 100.0 * current_num / len(test_loader.dataset.potentials))) + + + # Update minimum of potentials + new_min = torch.min(test_loader.dataset.potentials) + print('Test epoch {:d}, end. Min potential = {:.1f}'.format(test_epoch, new_min)) + + if last_min + 1 < new_min: + + # Update last_min + last_min += 1 + + if test_loader.dataset.set == 'validation' and last_min % 1 == 0: + + ##################################### + # Results on the whole validation set + ##################################### + + # Confusions for our subparts of validation set + Confs = np.zeros((len(predictions), nc_tot, nc_tot), dtype=np.int32) + for i, (preds, truth) in enumerate(zip(predictions, targets)): + + # Confusions + Confs[i, :, :] = fast_confusion(truth, preds, test_loader.dataset.label_values).astype(np.int32) + + + # Show vote results + print('\nCompute confusion') + + val_preds = [] + val_labels = [] + t1 = time.time() + for i, seq_frames in enumerate(test_loader.dataset.frames): + val_preds += [np.hstack(all_f_preds[i])] + val_labels += [np.hstack(all_f_labels[i])] + val_preds = np.hstack(val_preds) + val_labels = np.hstack(val_labels) + t2 = time.time() + C_tot = fast_confusion(val_labels, val_preds, test_loader.dataset.label_values) + t3 = time.time() + print(' Stacking time : {:.1f}s'.format(t2 - t1)) + print('Confusion time : {:.1f}s'.format(t3 - t2)) + + s1 = '\n' + for cc in C_tot: + for c in cc: + s1 += '{:7.0f} '.format(c) + s1 += '\n' + if debug: + print(s1) + + # Remove ignored labels from confusions + for l_ind, label_value in reversed(list(enumerate(test_loader.dataset.label_values))): + if label_value in test_loader.dataset.ignored_labels: + C_tot = np.delete(C_tot, l_ind, axis=0) + C_tot = np.delete(C_tot, l_ind, axis=1) + + # Objects IoU + val_IoUs = IoU_from_confusions(C_tot) + + # Compute IoUs + mIoU = np.mean(val_IoUs) + s2 = '{:5.2f} | '.format(100 * mIoU) + for IoU in val_IoUs: + s2 += '{:5.2f} '.format(100 * IoU) + print(s2 + '\n') + + # Save a report + report_file = join(report_path, 'report_{:04d}.txt'.format(int(np.floor(last_min)))) + report_str = 'Report of the confusion and metrics\n' + report_str += '***********************************\n\n\n' + report_str += 'Confusion 
matrix:\n\n' + report_str += s1 + report_str += '\nIoU values:\n\n' + report_str += s2 + report_str += '\n\n' + with open(report_file, 'w') as f: + f.write(report_str) + + test_epoch += 1 + + # Break when reaching number of desired votes + if last_min > num_votes: + break + + return + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/utils/trainer.py b/utils/trainer.py index f05fa01..fff958e 100644 --- a/utils/trainer.py +++ b/utils/trainer.py @@ -37,13 +37,13 @@ import sys from utils.ply import read_ply, write_ply # Metrics -from utils.metrics import IoU_from_confusions +from utils.metrics import IoU_from_confusions, fast_confusion from utils.config import Config -from sklearn.metrics import confusion_matrix from sklearn.neighbors import KDTree from models.blocks import KPConv + # ---------------------------------------------------------------------------------------------------------------------- # # Trainer Class @@ -370,14 +370,14 @@ class ModelTrainer: validation_labels = np.array(val_loader.dataset.label_values) # Compute classification results - C1 = confusion_matrix(targets, - np.argmax(probs, axis=1), - validation_labels) + C1 = fast_confusion(targets, + np.argmax(probs, axis=1), + validation_labels) # Compute votes confusion - C2 = confusion_matrix(val_loader.dataset.input_labels, - np.argmax(self.val_probs, axis=1), - validation_labels) + C2 = fast_confusion(val_loader.dataset.input_labels, + np.argmax(self.val_probs, axis=1), + validation_labels) # Saving (optional) @@ -406,7 +406,7 @@ class ModelTrainer: return C1 - def cloud_segmentation_validation(self, net, val_loader, config): + def cloud_segmentation_validation(self, net, val_loader, config, debug=False): """ Validation method for cloud segmentation models """ @@ -415,6 +415,8 @@ # Initialize ############ + t0 = time.time() + # Choose validation smoothing parameter (0 for no smoothing, 0.99 for big smoothing) val_smooth = 0.95 softmax = torch.nn.Softmax(1) @@ -455,6 +457,9 @@ last_display = time.time() mean_dt = np.zeros(1) + + t1 = time.time() + # Start validation loop for i, batch in enumerate(val_loader): @@ -509,6 +514,8 @@ 1000 * (mean_dt[0]), 1000 * (mean_dt[1]))) + t2 = time.time() + # Confusions for our subparts of validation set Confs = np.zeros((len(predictions), nc_tot, nc_tot), dtype=np.int32) for i, (probs, truth) in enumerate(zip(predictions, targets)): @@ -522,7 +529,10 @@ preds = val_loader.dataset.label_values[np.argmax(probs, axis=1)] # Confusions - Confs[i, :, :] = confusion_matrix(truth, preds, val_loader.dataset.label_values) + Confs[i, :, :] = fast_confusion(truth, preds, val_loader.dataset.label_values).astype(np.int32) + + + t3 = time.time() # Sum all confusions C = np.sum(Confs, axis=0).astype(np.float32) @@ -536,9 +546,14 @@ # Balance with real validation proportions C *= np.expand_dims(self.val_proportions / (np.sum(C, axis=1) + 1e-6), 1) + + t4 = time.time() + # Objects IoU IoUs = IoU_from_confusions(C) + t5 = time.time() + # Saving (optional) if config.saving: @@ -563,17 +578,17 @@ pot_path = join(config.saving_path, 'potentials') if not exists(pot_path): makedirs(pot_path) - files = val_loader.dataset.train_files - i_val = 0 + files = val_loader.dataset.files for i, file_path in enumerate(files): - if val_loader.dataset.all_splits[i] == val_loader.dataset.validation_split: - pot_points = np.array(val_loader.dataset.pot_trees[i_val].data, copy=False) - cloud_name = file_path.split('/')[-1] - 
pot_name = join(pot_path, cloud_name) - pots = val_loader.dataset.potentials[i_val].numpy().astype(np.float32) - write_ply(pot_name, - [pot_points.astype(np.float32), pots], - ['x', 'y', 'z', 'pots']) + pot_points = np.array(val_loader.dataset.pot_trees[i].data, copy=False) + cloud_name = file_path.split('/')[-1] + pot_name = join(pot_path, cloud_name) + pots = val_loader.dataset.potentials[i].numpy().astype(np.float32) + write_ply(pot_name, + [pot_points.astype(np.float32), pots], + ['x', 'y', 'z', 'pots']) + + t6 = time.time() # Print instance mean mIoU = 100 * np.mean(IoUs) @@ -581,904 +596,84 @@ class ModelTrainer: # Save predicted cloud occasionally if config.saving and (self.epoch + 1) % config.checkpoint_gap == 0: - val_path = join(config.saving_path, 'val_preds_{:d}'.format(self.epoch)) + val_path = join(config.saving_path, 'val_preds_{:d}'.format(self.epoch + 1)) if not exists(val_path): makedirs(val_path) - files = val_loader.dataset.train_files - i_val = 0 + files = val_loader.dataset.files for i, file_path in enumerate(files): - if val_loader.dataset.all_splits[i] == val_loader.dataset.validation_split: - # Get points - points = val_loader.dataset.load_evaluation_points(file_path) + # Get points + points = val_loader.dataset.load_evaluation_points(file_path) - # Get probs on our own ply points - sub_probs = self.validation_probs[i_val] - - # Insert false columns for ignored labels - for l_ind, label_value in enumerate(val_loader.dataset.label_values): - if label_value in val_loader.dataset.ignored_labels: - sub_probs = np.insert(sub_probs, l_ind, 0, axis=1) - - # Get the predicted labels - sub_preds = val_loader.dataset.label_values[np.argmax(sub_probs, axis=1).astype(np.int32)] - - # Reproject preds on the evaluations points - preds = (sub_preds[val_loader.dataset.validation_proj[i_val]]).astype(np.int32) - - # Path of saved validation file - cloud_name = file_path.split('/')[-1] - val_name = join(val_path, cloud_name) - - # Save file - labels = val_loader.dataset.validation_labels[i_val].astype(np.int32) - write_ply(val_name, - [points, preds, labels], - ['x', 'y', 'z', 'preds', 'class']) - - i_val += 1 - - return - - - - - - def validation_error(self, model, dataset): - """ - Validation method for classification models - """ - - ############ - # Initialize - ############ - - # Choose validation smoothing parameter (0 for no smothing, 0.99 for big smoothing) - val_smooth = 0.95 - - # Initialise iterator with train data - self.sess.run(dataset.val_init_op) - - # Number of classes predicted by the model - nc_model = model.config.num_classes - - # Initialize global prediction over all models - if not hasattr(self, 'val_probs'): - self.val_probs = np.zeros((len(dataset.input_labels['validation']), nc_model)) - - ##################### - # Network predictions - ##################### - - probs = [] - targets = [] - obj_inds = [] - - mean_dt = np.zeros(2) - last_display = time.time() - while True: - try: - # Run one step of the model. 
- t = [time.time()] - ops = (self.prob_logits, model.labels, model.inputs['object_inds']) - prob, labels, inds = self.sess.run(ops, {model.dropout_prob: 1.0}) - t += [time.time()] - - # Get probs and labels - probs += [prob] - targets += [labels] - obj_inds += [inds] - - # Average timing - t += [time.time()] - mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1])) - - # Display - if (t[-1] - last_display) > 1.0: - last_display = t[-1] - message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})' - print(message.format(100 * len(obj_inds) / model.config.validation_size, - 1000 * (mean_dt[0]), - 1000 * (mean_dt[1]))) - - except tf.errors.OutOfRangeError: - break - - # Stack all validation predictions - probs = np.vstack(probs) - targets = np.hstack(targets) - obj_inds = np.hstack(obj_inds) - - ################### - # Voting validation - ################### - - self.val_probs[obj_inds] = val_smooth * self.val_probs[obj_inds] + (1-val_smooth) * probs - - ############ - # Confusions - ############ - - validation_labels = np.array(dataset.label_values) - - # Compute classification results - C1 = confusion_matrix(targets, - np.argmax(probs, axis=1), - validation_labels) - - # Compute training confusion - C2 = confusion_matrix(self.training_labels, - self.training_preds, - validation_labels) - - # Compute votes confusion - C3 = confusion_matrix(dataset.input_labels['validation'], - np.argmax(self.val_probs, axis=1), - validation_labels) - - - # Saving (optionnal) - if model.config.saving: - print("Save confusions") - conf_list = [C1, C2, C3] - file_list = ['val_confs.txt', 'training_confs.txt', 'vote_confs.txt'] - for conf, conf_file in zip(conf_list, file_list): - test_file = join(model.saving_path, conf_file) - if exists(test_file): - with open(test_file, "a") as text_file: - for line in conf: - for value in line: - text_file.write('%d ' % value) - text_file.write('\n') - else: - with open(test_file, "w") as text_file: - for line in conf: - for value in line: - text_file.write('%d ' % value) - text_file.write('\n') - - train_ACC = 100 * np.sum(np.diag(C2)) / (np.sum(C2) + 1e-6) - val_ACC = 100 * np.sum(np.diag(C1)) / (np.sum(C1) + 1e-6) - vote_ACC = 100 * np.sum(np.diag(C3)) / (np.sum(C3) + 1e-6) - print('Accuracies : train = {:.1f}% / val = {:.1f}% / vote = {:.1f}%'.format(train_ACC, val_ACC, vote_ACC)) - - return C1 - - def segment_validation_error(self, model, dataset): - """ - Validation method for single object segmentation models - """ - - ########## - # Initialize - ########## - - # Choose validation smoothing parameter (0 for no smothing, 0.99 for big smoothing) - val_smooth = 0.95 - - # Initialise iterator with train data - self.sess.run(dataset.val_init_op) - - # Number of classes predicted by the model - nc_model = model.config.num_classes - - # Initialize global prediction over all models - if not hasattr(self, 'val_probs'): - self.val_probs = [np.zeros((len(p_l), nc_model)) for p_l in dataset.input_point_labels['validation']] - - ##################### - # Network predictions - ##################### - - probs = [] - targets = [] - obj_inds = [] - mean_dt = np.zeros(2) - last_display = time.time() - for i0 in range(model.config.validation_size): - try: - # Run one step of the model. 
- t = [time.time()] - ops = (self.prob_logits, model.labels, model.inputs['in_batches'], model.inputs['object_inds']) - prob, labels, batches, o_inds = self.sess.run(ops, {model.dropout_prob: 1.0}) - t += [time.time()] - - # Get predictions and labels per instance - # *************************************** - - # Stack all validation predictions for each class separately - max_ind = np.max(batches) - for b_i, b in enumerate(batches): - - # Eliminate shadow indices - b = b[b < max_ind-0.5] - - # Stack all results - probs += [prob[b]] - targets += [labels[b]] - obj_inds += [o_inds[b_i]] - - # Average timing - t += [time.time()] - mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1])) - - # Display - if (t[-1] - last_display) > 1.0: - last_display = t[-1] - message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})' - print(message.format(100 * i0 / model.config.validation_size, - 1000 * (mean_dt[0]), - 1000 * (mean_dt[1]))) - - except tf.errors.OutOfRangeError: - break - - ################### - # Voting validation - ################### - - for o_i, o_probs in zip(obj_inds, probs): - self.val_probs[o_i] = val_smooth * self.val_probs[o_i] + (1 - val_smooth) * o_probs - - ############ - # Confusions - ############ - - # Confusion matrix for each instance - n_parts = model.config.num_classes - Confs = np.zeros((len(probs), n_parts, n_parts), dtype=np.int32) - for i, (pred, truth) in enumerate(zip(probs, targets)): - parts = [j for j in range(pred.shape[1])] - Confs[i, :, :] = confusion_matrix(truth, np.argmax(pred, axis=1), parts) - - # Objects IoU - IoUs = IoU_from_confusions(Confs) - - - # Compute votes confusion - Confs = np.zeros((len(self.val_probs), n_parts, n_parts), dtype=np.int32) - for i, (pred, truth) in enumerate(zip(self.val_probs, dataset.input_point_labels['validation'])): - parts = [j for j in range(pred.shape[1])] - Confs[i, :, :] = confusion_matrix(truth, np.argmax(pred, axis=1), parts) - - # Objects IoU - vote_IoUs = IoU_from_confusions(Confs) - - # Saving (optionnal) - if model.config.saving: - - IoU_list = [IoUs, vote_IoUs] - file_list = ['val_IoUs.txt', 'vote_IoUs.txt'] - for IoUs_to_save, IoU_file in zip(IoU_list, file_list): - - # Name of saving file - test_file = join(model.saving_path, IoU_file) - - # Line to write: - line = '' - for instance_IoUs in IoUs_to_save: - for IoU in instance_IoUs: - line += '{:.3f} '.format(IoU) - line = line + '\n' - - # Write in file - if exists(test_file): - with open(test_file, "a") as text_file: - text_file.write(line) - else: - with open(test_file, "w") as text_file: - text_file.write(line) - - # Print instance mean - mIoU = 100 * np.mean(IoUs) - mIoU2 = 100 * np.mean(vote_IoUs) - print('{:s} : mIoU = {:.1f}% / vote mIoU = {:.1f}%'.format(model.config.dataset, mIoU, mIoU2)) - - return - - def cloud_validation_error(self, model, dataset): - """ - Validation method for cloud segmentation models - """ - - ########## - # Initialize - ########## - - # Choose validation smoothing parameter (0 for no smothing, 0.99 for big smoothing) - val_smooth = 0.95 - - # Do not validate if dataset has no validation cloud - if dataset.validation_split not in dataset.all_splits: - return - - # Initialise iterator with train data - self.sess.run(dataset.val_init_op) - - # Number of classes including ignored labels - nc_tot = dataset.num_classes - - # Number of classes predicted by the model - nc_model = model.config.num_classes - - # Initialize global prediction over validation clouds - if not hasattr(self, 'validation_probs'): - 
self.validation_probs = [np.zeros((l.shape[0], nc_model)) for l in dataset.input_labels['validation']] - self.val_proportions = np.zeros(nc_model, dtype=np.float32) - i = 0 - for label_value in dataset.label_values: - if label_value not in dataset.ignored_labels: - self.val_proportions[i] = np.sum([np.sum(labels == label_value) - for labels in dataset.validation_labels]) - i += 1 - - ##################### - # Network predictions - ##################### - - predictions = [] - targets = [] - mean_dt = np.zeros(2) - last_display = time.time() - for i0 in range(model.config.validation_size): - try: - # Run one step of the model. - t = [time.time()] - ops = (self.prob_logits, - model.labels, - model.inputs['in_batches'], - model.inputs['point_inds'], - model.inputs['cloud_inds']) - stacked_probs, labels, batches, point_inds, cloud_inds = self.sess.run(ops, {model.dropout_prob: 1.0}) - t += [time.time()] - - # Get predictions and labels per instance - # *************************************** - - # Stack all validation predictions for each class separately - max_ind = np.max(batches) - for b_i, b in enumerate(batches): - - # Eliminate shadow indices - b = b[b < max_ind-0.5] - - # Get prediction (only for the concerned parts) - probs = stacked_probs[b] - inds = point_inds[b] - c_i = cloud_inds[b_i] - - # Update current probs in whole cloud - self.validation_probs[c_i][inds] = val_smooth * self.validation_probs[c_i][inds] \ - + (1-val_smooth) * probs - - # Stack all prediction for this epoch - predictions += [probs] - targets += [dataset.input_labels['validation'][c_i][inds]] - - # Average timing - t += [time.time()] - mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1])) - - # Display - if (t[-1] - last_display) > 1.0: - last_display = t[-1] - message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})' - print(message.format(100 * i0 / model.config.validation_size, - 1000 * (mean_dt[0]), - 1000 * (mean_dt[1]))) - - except tf.errors.OutOfRangeError: - break - - # Confusions for our subparts of validation set - Confs = np.zeros((len(predictions), nc_tot, nc_tot), dtype=np.int32) - for i, (probs, truth) in enumerate(zip(predictions, targets)): - - # Insert false columns for ignored labels - for l_ind, label_value in enumerate(dataset.label_values): - if label_value in dataset.ignored_labels: - probs = np.insert(probs, l_ind, 0, axis=1) - - # Predicted labels - preds = dataset.label_values[np.argmax(probs, axis=1)] - - # Confusions - Confs[i, :, :] = confusion_matrix(truth, preds, dataset.label_values) - - # Sum all confusions - C = np.sum(Confs, axis=0).astype(np.float32) - - # Remove ignored labels from confusions - for l_ind, label_value in reversed(list(enumerate(dataset.label_values))): - if label_value in dataset.ignored_labels: - C = np.delete(C, l_ind, axis=0) - C = np.delete(C, l_ind, axis=1) - - # Balance with real validation proportions - C *= np.expand_dims(self.val_proportions / (np.sum(C, axis=1) + 1e-6), 1) - - # Objects IoU - IoUs = IoU_from_confusions(C) - - # Saving (optionnal) - if model.config.saving: - - # Name of saving file - test_file = join(model.saving_path, 'val_IoUs.txt') - - # Line to write: - line = '' - for IoU in IoUs: - line += '{:.3f} '.format(IoU) - line = line + '\n' - - # Write in file - if exists(test_file): - with open(test_file, "a") as text_file: - text_file.write(line) - else: - with open(test_file, "w") as text_file: - text_file.write(line) - - # Print instance mean - mIoU = 100 * np.mean(IoUs) - print('{:s} mean IoU = 
{:.1f}%'.format(model.config.dataset, mIoU)) - - # Save predicted cloud occasionally - if model.config.saving and (self.training_epoch + 1) % model.config.checkpoint_gap == 0: - val_path = join(model.saving_path, 'val_preds_{:d}'.format(self.training_epoch)) - if not exists(val_path): - makedirs(val_path) - files = dataset.train_files - i_val = 0 - for i, file_path in enumerate(files): - if dataset.all_splits[i] == dataset.validation_split: - - # Get points - points = dataset.load_evaluation_points(file_path) - - # Get probs on our own ply points - sub_probs = self.validation_probs[i_val] - - # Insert false columns for ignored labels - for l_ind, label_value in enumerate(dataset.label_values): - if label_value in dataset.ignored_labels: - sub_probs = np.insert(sub_probs, l_ind, 0, axis=1) - - # Get the predicted labels - sub_preds = dataset.label_values[np.argmax(sub_probs, axis=1).astype(np.int32)] - - # Reproject preds on the evaluations points - preds = (sub_preds[dataset.validation_proj[i_val]]).astype(np.int32) - - # Path of saved validation file - cloud_name = file_path.split('/')[-1] - val_name = join(val_path, cloud_name) - - # Save file - labels = dataset.validation_labels[i_val].astype(np.int32) - write_ply(val_name, - [points, preds, labels], - ['x', 'y', 'z', 'preds', 'class']) - - i_val += 1 - - return - - def multi_cloud_validation_error(self, model, multi_dataset): - """ - Validation method for cloud segmentation models - """ - - ########## - # Initialize - ########## - - # Choose validation smoothing parameter (0 for no smothing, 0.99 for big smoothing) - val_smooth = 0.95 - - # Initialise iterator with train data - self.sess.run(multi_dataset.val_init_op) - - if not hasattr(self, 'validation_probs'): - - self.validation_probs = [] - self.val_proportions = [] - - for d_i, dataset in enumerate(multi_dataset.datasets): - - # Do not validate if dataset has no validation cloud - if dataset.validation_split not in dataset.all_splits: - continue - - # Number of classes including ignored labels - nc_tot = dataset.num_classes - - # Number of classes predicted by the model - nc_model = model.config.num_classes[d_i] - - # Initialize global prediction over validation clouds - self.validation_probs.append([np.zeros((l.shape[0], nc_model)) for l in dataset.input_labels['validation']]) - self.val_proportions.append(np.zeros(nc_model, dtype=np.float32)) - i = 0 - for label_value in dataset.label_values: - if label_value not in dataset.ignored_labels: - self.val_proportions[-1][i] = np.sum([np.sum(labels == label_value) - for labels in dataset.validation_labels]) - i += 1 - - ##################### - # Network predictions - ##################### - - pred_d_inds = [] - predictions = [] - targets = [] - mean_dt = np.zeros(2) - last_display = time.time() - for i0 in range(model.config.validation_size): - try: - # Run one step of the model. 
- t = [time.time()] - ops = (self.val_logits, - model.labels, - model.inputs['in_batches'], - model.inputs['point_inds'], - model.inputs['cloud_inds'], - model.inputs['dataset_inds']) - stacked_probs, labels, batches, p_inds, c_inds, d_inds = self.sess.run(ops, {model.dropout_prob: 1.0}) - t += [time.time()] - - # Get predictions and labels per instance - # *************************************** - - # Stack all validation predictions for each class separately - max_ind = np.max(batches) - for b_i, b in enumerate(batches): - - # Eliminate shadow indices - b = b[b < max_ind-0.5] - - # Get prediction (only for the concerned parts) - d_i = d_inds[b_i] - probs = stacked_probs[b, :model.config.num_classes[d_i]] - inds = p_inds[b] - c_i = c_inds[b_i] - - # Update current probs in whole cloud - self.validation_probs[d_i][c_i][inds] = val_smooth * self.validation_probs[d_i][c_i][inds] \ - + (1-val_smooth) * probs - - # Stack all prediction for this epoch - pred_d_inds += [d_i] - predictions += [probs] - targets += [multi_dataset.datasets[d_i].input_labels['validation'][c_i][inds]] - - # Average timing - t += [time.time()] - mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1])) - - # Display - if (t[-1] - last_display) > 1.0: - last_display = t[-1] - message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})' - print(message.format(100 * i0 / model.config.validation_size, - 1000 * (mean_dt[0]), - 1000 * (mean_dt[1]))) - - except tf.errors.OutOfRangeError: - break - - # Convert list to np array for indexing - predictions = np.array(predictions) - targets = np.array(targets) - pred_d_inds = np.array(pred_d_inds, np.int32) - - IoUs = [] - for d_i, dataset in enumerate(multi_dataset.datasets): - - # Do not validate if dataset has no validation cloud - if dataset.validation_split not in dataset.all_splits: - continue - - # Number of classes including ignored labels - nc_tot = dataset.num_classes - - # Number of classes predicted by the model - nc_model = model.config.num_classes[d_i] - - # Extract the spheres from this dataset - tmp_inds = np.where(pred_d_inds == d_i)[0] - - # Confusions for our subparts of validation set - Confs = np.zeros((len(tmp_inds), nc_tot, nc_tot), dtype=np.int32) - for i, (probs, truth) in enumerate(zip(predictions[tmp_inds], targets[tmp_inds])): + # Get probs on our own ply points + sub_probs = self.validation_probs[i] # Insert false columns for ignored labels - for l_ind, label_value in enumerate(dataset.label_values): - if label_value in dataset.ignored_labels: - probs = np.insert(probs, l_ind, 0, axis=1) + for l_ind, label_value in enumerate(val_loader.dataset.label_values): + if label_value in val_loader.dataset.ignored_labels: + sub_probs = np.insert(sub_probs, l_ind, 0, axis=1) - # Predicted labels - preds = dataset.label_values[np.argmax(probs, axis=1)] + # Get the predicted labels + sub_preds = val_loader.dataset.label_values[np.argmax(sub_probs, axis=1).astype(np.int32)] - # Confusions - Confs[i, :, :] = confusion_matrix(truth, preds, dataset.label_values) + # Reproject preds on the evaluation points + preds = (sub_preds[val_loader.dataset.test_proj[i]]).astype(np.int32) - # Sum all confusions - C = np.sum(Confs, axis=0).astype(np.float32) + # Path of saved validation file + cloud_name = file_path.split('/')[-1] + val_name = join(val_path, cloud_name) - # Remove ignored labels from confusions - for l_ind, label_value in reversed(list(enumerate(dataset.label_values))): - if label_value in dataset.ignored_labels: - C = np.delete(C, l_ind, axis=0) - C 
+                # Save file
+                labels = val_loader.dataset.validation_labels[i].astype(np.int32)
+                write_ply(val_name,
+                          [points, preds, labels],
+                          ['x', 'y', 'z', 'preds', 'class'])

-            # Balance with real validation proportions
-            C *= np.expand_dims(self.val_proportions[d_i] / (np.sum(C, axis=1) + 1e-6), 1)
-
-            # Objects IoU
-            IoUs += [IoU_from_confusions(C)]
-
-            # Saving (optionnal)
-            if model.config.saving:
-
-                # Name of saving file
-                test_file = join(model.saving_path, 'val_IoUs_{:d}_{:s}.txt'.format(d_i, dataset.name))
-
-                # Line to write:
-                line = ''
-                for IoU in IoUs[-1]:
-                    line += '{:.3f} '.format(IoU)
-                line = line + '\n'
-
-                # Write in file
-                if exists(test_file):
-                    with open(test_file, "a") as text_file:
-                        text_file.write(line)
-                else:
-                    with open(test_file, "w") as text_file:
-                        text_file.write(line)
-
-            # Print instance mean
-            mIoU = 100 * np.mean(IoUs[-1])
-            print('{:s} mean IoU = {:.1f}%'.format(dataset.name, mIoU))
-
-        # Save predicted cloud occasionally
-        if model.config.saving and (self.training_epoch + 1) % model.config.checkpoint_gap == 0:
-            val_path = join(model.saving_path, 'val_preds_{:d}'.format(self.training_epoch))
-            if not exists(val_path):
-                makedirs(val_path)
-
-            for d_i, dataset in enumerate(multi_dataset.datasets):
-
-                dataset_val_path = join(val_path, '{:d}_{:s}'.format(d_i, dataset.name))
-                if not exists(dataset_val_path):
-                    makedirs(dataset_val_path)
-
-                files = dataset.train_files
-                i_val = 0
-                for i, file_path in enumerate(files):
-                    if dataset.all_splits[i] == dataset.validation_split:
-
-                        # Get points
-                        points = dataset.load_evaluation_points(file_path)
-
-                        # Get probs on our own ply points
-                        sub_probs = self.validation_probs[d_i][i_val]
-
-                        # Insert false columns for ignored labels
-                        for l_ind, label_value in enumerate(dataset.label_values):
-                            if label_value in dataset.ignored_labels:
-                                sub_probs = np.insert(sub_probs, l_ind, 0, axis=1)
-
-                        # Get the predicted labels
-                        sub_preds = dataset.label_values[np.argmax(sub_probs, axis=1).astype(np.int32)]
-
-                        # Reproject preds on the evaluations points
-                        preds = (sub_preds[dataset.validation_proj[i_val]]).astype(np.int32)
-
-                        # Path of saved validation file
-                        cloud_name = file_path.split('/')[-1]
-                        val_name = join(dataset_val_path, cloud_name)
-
-                        # Save file
-                        labels = dataset.validation_labels[i_val].astype(np.int32)
-                        write_ply(val_name,
-                                  [points, preds, labels],
-                                  ['x', 'y', 'z', 'preds', 'class'])
-
-                        i_val += 1
+        # Display timings
+        t7 = time.time()
+        if debug:
+            print('\n************************\n')
+            print('Validation timings:')
+            print('Init ...... {:.1f}s'.format(t1 - t0))
+            print('Loop ...... {:.1f}s'.format(t2 - t1))
+            print('Confs ..... {:.1f}s'.format(t3 - t2))
+            print('Confs bis . {:.1f}s'.format(t4 - t3))
+            print('IoU ....... {:.1f}s'.format(t5 - t4))
+            print('Save1 ..... {:.1f}s'.format(t6 - t5))
+            print('Save2 ..... {:.1f}s'.format(t7 - t6))
+            print('\n************************\n')

         return
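The saving block above relies on precomputed reprojection indices (`test_proj`): a nearest-neighbour lookup from every full-resolution point into the subsampled cloud the network actually processed, so per-point predictions can be written back onto the original ply. A sketch of how such indices can be built with sklearn's KDTree and applied; names are illustrative:

import numpy as np
from sklearn.neighbors import KDTree

def reproject_preds(sub_points, sub_preds, full_points):
    # One nearest neighbour in the subsampled cloud for every original point
    tree = KDTree(sub_points, leaf_size=50)
    proj_inds = np.squeeze(tree.query(full_points, return_distance=False)).astype(np.int32)
    # Each original point inherits the prediction of its nearest subsampled point
    return sub_preds[proj_inds]

Since the subsampled cloud never changes during training, these indices only need to be computed once and can be cached on disk.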
-
-    def multi_validation_error(self, model, dataset):
-        """
-        Validation method for multi object segmentation models
-        """
-
-        ##########
-        # Initialize
-        ##########
-
-        # Choose validation smoothing parameter (0 for no smothing, 0.99 for big smoothing)
-        val_smooth = 0.95
-
-        # Initialise iterator with train data
-        self.sess.run(dataset.val_init_op)
-
-        # Initialize global prediction over all models
-        if not hasattr(self, 'val_probs'):
-            self.val_probs = []
-            for p_l, o_l in zip(dataset.input_point_labels['validation'], dataset.input_labels['validation']):
-                self.val_probs += [np.zeros((len(p_l), dataset.num_parts[o_l]))]
-
-        #####################
-        # Network predictions
-        #####################
-
-        probs = []
-        targets = []
-        objects = []
-        obj_inds = []
-        mean_dt = np.zeros(2)
-        last_display = time.time()
-        for i0 in range(model.config.validation_size):
-            try:
-                # Run one step of the model.
-                t = [time.time()]
-                ops = (model.logits,
-                       model.labels,
-                       model.inputs['super_labels'],
-                       model.inputs['object_inds'],
-                       model.inputs['in_batches'])
-                prob, labels, object_labels, o_inds, batches = self.sess.run(ops, {model.dropout_prob: 1.0})
-                t += [time.time()]
-
-                # Get predictions and labels per instance
-                # ***************************************
-
-                # Stack all validation predictions for each class separately
-                max_ind = np.max(batches)
-                for b_i, b in enumerate(batches):
-
-                    # Eliminate shadow indices
-                    b = b[b < max_ind-0.5]
-
-                    # Get prediction (only for the concerned parts)
-                    obj = object_labels[b[0]]
-                    pred = prob[b][:, :model.config.num_classes[obj]]
-
-                    # Stack all results
-                    objects += [obj]
-                    obj_inds += [o_inds[b_i]]
-                    probs += [prob[b, :model.config.num_classes[obj]]]
-                    targets += [labels[b]]
-
-                # Average timing
-                t += [time.time()]
-                mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))
-
-                # Display
-                if (t[-1] - last_display) > 1.0:
-                    last_display = t[-1]
-                    message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})'
-                    print(message.format(100 * i0 / model.config.validation_size,
-                                         1000 * (mean_dt[0]),
-                                         1000 * (mean_dt[1])))
-
-            except tf.errors.OutOfRangeError:
-                break
-
-        ###################
-        # Voting validation
-        ###################
-
-        for o_i, o_probs in zip(obj_inds, probs):
-            self.val_probs[o_i] = val_smooth * self.val_probs[o_i] + (1 - val_smooth) * o_probs
-
-        ############
-        # Confusions
-        ############
-
-        # Confusion matrix for each object
-        n_objs = [np.sum(np.array(objects) == l) for l in dataset.label_values]
-        Confs = [np.zeros((n_obj, n_parts, n_parts), dtype=np.int32) for n_parts, n_obj in
-                 zip(dataset.num_parts, n_objs)]
-        obj_count = [0 for _ in n_objs]
-        for obj, pred, truth in zip(objects, probs, targets):
-            parts = [i for i in range(pred.shape[1])]
-            Confs[obj][obj_count[obj], :, :] = confusion_matrix(truth, np.argmax(pred, axis=1), parts)
-            obj_count[obj] += 1
-
-        # Objects mIoU
-        IoUs = [IoU_from_confusions(C) for C in Confs]
-
-
-        # Compute votes confusion
-        n_objs = [np.sum(np.array(dataset.input_labels['validation']) == l) for l in dataset.label_values]
-        Confs = [np.zeros((n_obj, n_parts, n_parts), dtype=np.int32) for n_parts, n_obj in
-                 zip(dataset.num_parts, n_objs)]
-        obj_count = [0 for _ in n_objs]
-        for obj, pred, truth in zip(dataset.input_labels['validation'],
-                                    self.val_probs,
-                                    dataset.input_point_labels['validation']):
-            parts = [i for i in range(pred.shape[1])]
-            Confs[obj][obj_count[obj], :, :] = confusion_matrix(truth, np.argmax(pred, axis=1), parts)
-            obj_count[obj] += 1
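Both confusion loops above pass an explicit `parts` list so that every part class keeps its row and column even when it is absent from a sample. The PyTorch code later swaps sklearn's `confusion_matrix` for a faster `fast_confusion`; a bincount-based sketch of how such a fixed-label confusion can be computed, assuming nothing about the label values themselves (name and signature are illustrative):

import numpy as np

def fixed_label_confusion(true_labels, pred_labels, label_values):
    # Map arbitrary label values to contiguous indices so absent classes
    # still get a row and a column in the confusion matrix
    n = len(label_values)
    lookup = {v: i for i, v in enumerate(label_values)}
    t = np.array([lookup[v] for v in true_labels])
    p = np.array([lookup[v] for v in pred_labels])
    # One bincount over the flattened (true, pred) index does all the counting
    return np.bincount(t * n + p, minlength=n * n).reshape((n, n))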
-
-        # Objects mIoU
-        vote_IoUs = [IoU_from_confusions(C) for C in Confs]
-
-        # Saving (optionnal)
-        if model.config.saving:
-
-            IoU_list = [IoUs, vote_IoUs]
-            file_list = ['val_IoUs.txt', 'vote_IoUs.txt']
-
-            for IoUs_to_save, IoU_file in zip(IoU_list, file_list):
-
-                # Name of saving file
-                test_file = join(model.saving_path, IoU_file)
-
-                # Line to write:
-                line = ''
-                for obj_IoUs in IoUs_to_save:
-                    for part_IoUs in obj_IoUs:
-                        for IoU in part_IoUs:
-                            line += '{:.3f} '.format(IoU)
-                    line += '/ '
-                line = line[:-2] + '\n'
-
-                # Write in file
-                if exists(test_file):
-                    with open(test_file, "a") as text_file:
-                        text_file.write(line)
-                else:
-                    with open(test_file, "w") as text_file:
-                        text_file.write(line)
-
-        # Print instance mean
-        mIoU = 100 * np.mean(np.hstack([np.mean(obj_IoUs, axis=1) for obj_IoUs in IoUs]))
-        class_mIoUs = [np.mean(obj_IoUs) for obj_IoUs in IoUs]
-        mcIoU = 100 * np.mean(class_mIoUs)
-        print('Val : mIoU = {:.1f}% / mcIoU = {:.1f}% '.format(mIoU, mcIoU))
-        mIoU = 100 * np.mean(np.hstack([np.mean(obj_IoUs, axis=1) for obj_IoUs in vote_IoUs]))
-        class_mIoUs = [np.mean(obj_IoUs) for obj_IoUs in vote_IoUs]
-        mcIoU = 100 * np.mean(class_mIoUs)
-        print('Vote : mIoU = {:.1f}% / mcIoU = {:.1f}% '.format(mIoU, mcIoU))
-
-        return
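`IoU_from_confusions`, used throughout these validation methods, reduces for a single confusion matrix C to the classic TP / (TP + FP + FN) per class. A sketch; the epsilon guard against empty classes is an assumption, not necessarily the repository's exact handling:

import numpy as np

def iou_from_confusion(C):
    # Diagonal holds the true positives; column and row sums add the
    # false positives and false negatives for each class
    TP = np.diag(C).astype(np.float64)
    FP = np.sum(C, axis=0) - TP
    FN = np.sum(C, axis=1) - TP
    return TP / (TP + FP + FN + 1e-6)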
-
-    def slam_validation_error(self, model, dataset):
+    def slam_segmentation_validation(self, net, val_loader, config, debug=True):
         """
         Validation method for slam segmentation models
         """

-        ##########
+        ############
         # Initialize
-        ##########
+        ############
+
+        t0 = time.time()

         # Do not validate if dataset has no validation cloud
-        if dataset.validation_split not in dataset.seq_splits:
+        if val_loader is None:
             return

+        # Choose validation smoothing parameter (0 for no smoothing, 0.99 for big smoothing)
+        val_smooth = 0.95
+        softmax = torch.nn.Softmax(1)
+
         # Create folder for validation predictions
-        if not exists (join(model.saving_path, 'val_preds')):
-            makedirs(join(model.saving_path, 'val_preds'))
+        if not exists(join(config.saving_path, 'val_preds')):
+            makedirs(join(config.saving_path, 'val_preds'))

-        # Initialize the dataset validation containers
-        dataset.val_points = []
-        dataset.val_labels = []
-
-        # Initialise iterator with train data
-        self.sess.run(dataset.val_init_op)
+        # Initialize the dataset validation containers
+        val_loader.dataset.val_points = []
+        val_loader.dataset.val_labels = []

         # Number of classes including ignored labels
-        nc_tot = dataset.num_classes
-
-        # Number of classes predicted by the model
-        nc_model = model.config.num_classes
+        nc_tot = val_loader.dataset.num_classes

         #####################
         # Network predictions
         #####################

@@ -1487,116 +682,121 @@ class ModelTrainer:
         predictions = []
         targets = []
         inds = []
-        mean_dt = np.zeros(2)
-        last_display = time.time()
         val_i = 0
-        for i0 in range(model.config.validation_size):
-            try:
-                # Run one step of the model.
-                t = [time.time()]
-                ops = (self.prob_logits,
-                       model.labels,
-                       model.inputs['points'][0],
-                       model.inputs['in_batches'],
-                       model.inputs['frame_inds'],
-                       model.inputs['frame_centers'],
-                       model.inputs['augment_scales'],
-                       model.inputs['augment_rotations'])
-                s_probs, s_labels, s_points, batches, f_inds, p0s, S, R = self.sess.run(ops, {model.dropout_prob: 1.0})
-                t += [time.time()]

-                # Get predictions and labels per instance
-                # ***************************************
+        t = [time.time()]
+        last_display = time.time()
+        mean_dt = np.zeros(1)

-                # Stack all validation predictions for each class separately
-                max_ind = np.max(batches)
-                for b_i, b in enumerate(batches):

-                    # Eliminate shadow indices
-                    b = b[b < max_ind-0.5]
+        t1 = time.time()

-                    # Get prediction (only for the concerned parts)
-                    probs = s_probs[b]
-                    labels = s_labels[b]
-                    points = s_points[b, :]
-                    S_i = S[b_i]
-                    R_i = R[b_i]
-                    p0 = p0s[b_i]
+        # Start validation loop
+        for i, batch in enumerate(val_loader):

-                    # Get input points in their original positions
-                    points2 = (points * (1/S_i)).dot(R_i.T)
+            # New time
+            t = t[-1:]
+            t += [time.time()]

-                    # get val_points that are in range
-                    radiuses = np.sum(np.square(dataset.val_points[val_i] - p0), axis=1)
-                    mask = radiuses < (0.9 * model.config.in_radius) ** 2
+            if 'cuda' in self.device.type:
+                batch.to(self.device)

-                    # Project predictions on the frame points
-                    search_tree = KDTree(points2, leaf_size=50)
-                    proj_inds = search_tree.query(dataset.val_points[val_i][mask, :], return_distance=False)
-                    proj_inds = np.squeeze(proj_inds).astype(np.int32)
-                    proj_probs = probs[proj_inds]
-                    #proj_labels = labels[proj_inds]
+            # Forward pass
+            outputs = net(batch, config)

-                    # Safe check if only one point:
-                    if proj_probs.ndim < 2:
-                        proj_probs = np.expand_dims(proj_probs, 0)
+            # Get probs and labels
+            stk_probs = softmax(outputs).cpu().detach().numpy()
+            lengths = batch.lengths[0].cpu().numpy()
+            f_inds = batch.frame_inds.cpu().numpy()
+            r_inds_list = batch.reproj_inds
+            r_mask_list = batch.reproj_masks
+            labels_list = batch.val_labels
+            torch.cuda.synchronize(self.device)

-                    # Insert false columns for ignored labels
-                    for l_ind, label_value in enumerate(dataset.label_values):
-                        if label_value in dataset.ignored_labels:
-                            proj_probs = np.insert(proj_probs, l_ind, 0, axis=1)
+            # Get predictions and labels per instance
+            # ***************************************

-                    # Predicted labels
-                    preds = dataset.label_values[np.argmax(proj_probs, axis=1)]
+            i0 = 0
+            for b_i, length in enumerate(lengths):

-                    # Save predictions in a binary file
-                    filename ='{:02d}_{:07d}.npy'.format(f_inds[b_i, 0], f_inds[b_i, 1])
-                    filepath = join(model.saving_path, 'val_preds', filename)
-                    if exists(filepath):
-                        frame_preds = np.load(filepath)
-                    else:
-                        frame_preds = np.zeros(dataset.val_labels[val_i].shape, dtype=np.uint8)
-                    frame_preds[mask] = preds.astype(np.uint8)
-                    np.save(filepath, frame_preds)
+                # Get prediction
+                probs = stk_probs[i0:i0 + length]
+                proj_inds = r_inds_list[b_i]
+                proj_mask = r_mask_list[b_i]
+                frame_labels = labels_list[b_i]
+                s_ind = f_inds[b_i, 0]
+                f_ind = f_inds[b_i, 1]

-                    # Save some of the frame pots
-                    if f_inds[b_i, 1] % 10 == 0:
-                        pots = dataset.f_potentials['validation'][f_inds[b_i, 0]][f_inds[b_i, 1]]
-                        write_ply(filepath[:-4]+'_pots.ply',
-                                  [dataset.val_points[val_i], dataset.val_labels[val_i], frame_preds, pots],
-                                  ['x', 'y', 'z', 'gt', 'pre', 'pots'])
+                # Project predictions on the frame points
+                proj_probs = probs[proj_inds]

-                    # Update validation confusions
-                    frame_C = confusion_matrix(dataset.val_labels[val_i], frame_preds, dataset.label_values)
-                    dataset.val_confs[f_inds[b_i, 0]][f_inds[b_i, 1], :, :] = frame_C
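Both the removed TensorFlow loop and the new PyTorch loop decode probabilities the same way: re-insert zero columns for the ignored classes the network never predicts, then map argmax indices through `label_values`. A standalone sketch, assuming `label_values` is a numpy array as in the dataset classes (the function name is illustrative):

import numpy as np

def probs_to_labels(probs, label_values, ignored_labels):
    # The network outputs no column for ignored classes, so insert
    # zero-probability columns to realign columns with the full label list
    for l_ind, label_value in enumerate(label_values):
        if label_value in ignored_labels:
            probs = np.insert(probs, l_ind, 0, axis=1)
    # argmax positions now index directly into label_values
    return label_values[np.argmax(probs, axis=1)]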
+                # Safe check if only one point:
+                if proj_probs.ndim < 2:
+                    proj_probs = np.expand_dims(proj_probs, 0)

-                    # Stack all prediction for this epoch
-                    predictions += [preds]
-                    targets += [dataset.val_labels[val_i][mask]]
-                    inds += [f_inds[b_i, :]]
-                    val_i += 1
+                # Insert false columns for ignored labels
+                for l_ind, label_value in enumerate(val_loader.dataset.label_values):
+                    if label_value in val_loader.dataset.ignored_labels:
+                        proj_probs = np.insert(proj_probs, l_ind, 0, axis=1)

-                    # Average timing
-                    t += [time.time()]
-                    mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))
+                # Predicted labels
+                preds = val_loader.dataset.label_values[np.argmax(proj_probs, axis=1)]

-                    # Display
-                    if (t[-1] - last_display) > 1.0:
-                        last_display = t[-1]
-                        message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})'
-                        print(message.format(100 * i0 / model.config.validation_size,
-                                             1000 * (mean_dt[0]),
-                                             1000 * (mean_dt[1])))
+                # Save predictions in a binary file
+                filename = '{:s}_{:07d}.npy'.format(val_loader.dataset.sequences[s_ind], f_ind)
+                filepath = join(config.saving_path, 'val_preds', filename)
+                if exists(filepath):
+                    frame_preds = np.load(filepath)
+                else:
+                    frame_preds = np.zeros(frame_labels.shape, dtype=np.uint8)
+                frame_preds[proj_mask] = preds.astype(np.uint8)
+                np.save(filepath, frame_preds)

-            except tf.errors.OutOfRangeError:
-                break
+                # Save some of the frame points as ply (potentials are not saved anymore)
+                if f_ind % 20 == 0:
+                    seq_path = join(val_loader.dataset.path, 'sequences', val_loader.dataset.sequences[s_ind])
+                    velo_file = join(seq_path, 'velodyne', val_loader.dataset.frames[s_ind][f_ind] + '.bin')
+                    frame_points = np.fromfile(velo_file, dtype=np.float32)
+                    frame_points = frame_points.reshape((-1, 4))
+                    write_ply(filepath[:-4] + '_pots.ply',
+                              [frame_points[:, :3], frame_labels, frame_preds],
+                              ['x', 'y', 'z', 'gt', 'pre'])
+
+                # Update validation confusions
+                frame_C = fast_confusion(frame_labels,
+                                         frame_preds.astype(np.int32),
+                                         val_loader.dataset.label_values)
+                val_loader.dataset.val_confs[s_ind][f_ind, :, :] = frame_C
+
+                # Stack all predictions for this epoch
+                predictions += [preds]
+                targets += [frame_labels[proj_mask]]
+                inds += [f_inds[b_i, :]]
+                val_i += 1
+                i0 += length
+
+            # Average timing
+            t += [time.time()]
+            mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))
+
+            # Display
+            if (t[-1] - last_display) > 1.0:
+                last_display = t[-1]
+                message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})'
+                print(message.format(100 * i / config.validation_size,
+                                     1000 * (mean_dt[0]),
+                                     1000 * (mean_dt[1])))
+
+        t2 = time.time()

         # Confusions for our subparts of validation set
         Confs = np.zeros((len(predictions), nc_tot, nc_tot), dtype=np.int32)
         for i, (preds, truth) in enumerate(zip(predictions, targets)):

             # Confusions
-            Confs[i, :, :] = confusion_matrix(truth, preds, dataset.label_values)
+            Confs[i, :, :] = fast_confusion(truth, preds, val_loader.dataset.label_values).astype(np.int32)
+
+        t3 = time.time()

         #######################################
         # Results on this subpart of validation
         #######################################

         # Sum all confusions
         C = np.sum(Confs, axis=0).astype(np.float32)

         # Balance with real validation proportions
-        C *= np.expand_dims(dataset.class_proportions['validation'] / (np.sum(C, axis=1) + 1e-6), 1)
+        C *= np.expand_dims(val_loader.dataset.class_proportions / (np.sum(C, axis=1) + 1e-6), 1)
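The `C *= np.expand_dims(...)` line above rescales each row of the summed confusion matrix so that classes contribute according to their true proportions in the whole validation set, rather than their frequency in the subset of frames sampled this epoch. The same operation as a small standalone helper (name illustrative):

import numpy as np

def rebalance_confusion(C, class_proportions):
    # Rows count this epoch's sampled ground-truth points per class; rescale
    # each row so classes weigh in with their true validation proportions
    return C * np.expand_dims(class_proportions / (np.sum(C, axis=1) + 1e-6), 1)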
         # Remove ignored labels from confusions
-        for l_ind, label_value in reversed(list(enumerate(dataset.label_values))):
-            if label_value in dataset.ignored_labels:
+        for l_ind, label_value in reversed(list(enumerate(val_loader.dataset.label_values))):
+            if label_value in val_loader.dataset.ignored_labels:
                 C = np.delete(C, l_ind, axis=0)
                 C = np.delete(C, l_ind, axis=1)

@@ -1621,35 +821,40 @@ class ModelTrainer:
         # Results on the whole validation set
         #####################################

+        t4 = time.time()
+
         # Sum all validation confusions
-        C_tot = [np.sum(seq_C, axis=0) for seq_C in dataset.val_confs if len(seq_C) > 0]
+        C_tot = [np.sum(seq_C, axis=0) for seq_C in val_loader.dataset.val_confs if len(seq_C) > 0]
         C_tot = np.sum(np.stack(C_tot, axis=0), axis=0)

-        s = ''
-        for cc in C_tot:
-            for c in cc:
-                s += '{:8.1f} '.format(c)
-            s += '\n'
-        print(s)
+        if debug:
+            s = '\n'
+            for cc in C_tot:
+                for c in cc:
+                    s += '{:8.1f} '.format(c)
+                s += '\n'
+            print(s)

         # Remove ignored labels from confusions
-        for l_ind, label_value in reversed(list(enumerate(dataset.label_values))):
-            if label_value in dataset.ignored_labels:
+        for l_ind, label_value in reversed(list(enumerate(val_loader.dataset.label_values))):
+            if label_value in val_loader.dataset.ignored_labels:
                 C_tot = np.delete(C_tot, l_ind, axis=0)
                 C_tot = np.delete(C_tot, l_ind, axis=1)

         # Objects IoU
         val_IoUs = IoU_from_confusions(C_tot)

+        t5 = time.time()
+
         # Saving (optionnal)
-        if model.config.saving:
+        if config.saving:

             IoU_list = [IoUs, val_IoUs]
             file_list = ['subpart_IoUs.txt', 'val_IoUs.txt']

             for IoUs_to_save, IoU_file in zip(IoU_list, file_list):

                 # Name of saving file
-                test_file = join(model.saving_path, IoU_file)
+                test_file = join(config.saving_path, IoU_file)

                 # Line to write:
                 line = ''
@@ -1667,12 +872,28 @@ class ModelTrainer:

         # Print instance mean
         mIoU = 100 * np.mean(IoUs)
-        print('{:s} : subpart mIoU = {:.1f} %'.format(model.config.dataset, mIoU))
+        print('{:s} : subpart mIoU = {:.1f} %'.format(config.dataset, mIoU))
         mIoU = 100 * np.mean(val_IoUs)
-        print('{:s} : val mIoU = {:.1f} %'.format(model.config.dataset, mIoU))
+        print('{:s} : val mIoU = {:.1f} %'.format(config.dataset, mIoU))
+
+        t6 = time.time()
+
+        # Display timings
+        if debug:
+            print('\n************************\n')
+            print('Validation timings:')
+            print('Init ...... {:.1f}s'.format(t1 - t0))
+            print('Loop ...... {:.1f}s'.format(t2 - t1))
+            print('Confs ..... {:.1f}s'.format(t3 - t2))
+            print('IoU1 ...... {:.1f}s'.format(t4 - t3))
+            print('IoU2 ...... {:.1f}s'.format(t5 - t4))
+            print('Save ...... {:.1f}s'.format(t6 - t5))
+            print('\n************************\n')

         return
+
+
    # Saving methods
    # ------------------------------------------------------------------------------------------------------------------
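The saving block above writes one line per epoch of space-separated per-class IoUs to subpart_IoUs.txt and val_IoUs.txt. A small reader for turning such a log back into an array for plotting training curves; it assumes the simple layout just described (no '/' separators, which only the removed multi-object format used):

import numpy as np

def load_iou_log(path):
    # One epoch per line, space-separated per-class IoUs in [0, 1]
    with open(path, 'r') as f:
        return np.array([[float(v) for v in line.split()] for line in f])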