#
#
#      0=================================0
#      |    Kernel Point Convolutions    |
#      0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
#      Class handling the training of any model
#
# ----------------------------------------------------------------------------------------------------------------------
#
#      Hugues THOMAS - 11/06/2018
#


# ----------------------------------------------------------------------------------------------------------------------
#
#           Imports and global variables
#       \**********************************/
#


# Basic libs
import torch
import numpy as np
from os import makedirs, remove
from os.path import exists, join
import time

# PLY reader
from utils.ply import write_ply

# Metrics
from utils.metrics import IoU_from_confusions, fast_confusion
from utils.config import Config


# ----------------------------------------------------------------------------------------------------------------------
#
#           Trainer Class
#       \*******************/
#


class ModelTrainer:

    # Initialization methods
    # ------------------------------------------------------------------------------------------------------------------

    def __init__(self, net, config, chkp_path=None, finetune=False, on_gpu=True):
        """
        Initialize training parameters and reload previous model for restore/finetune
        :param net: network object
        :param config: configuration object
        :param chkp_path: path to the checkpoint that needs to be loaded (None for new training)
        :param finetune: finetune from checkpoint (True) or restore training from checkpoint (False)
        :param on_gpu: Train on GPU or CPU
        """

        ############
        # Parameters
        ############

        # Epoch index
        self.epoch = 0
        self.step = 0

        # Optimizer with specific learning rate for deformable KPConv:
        # parameters whose name contains "offset" use learning_rate * deform_lr_factor
        deform_params = [v for k, v in net.named_parameters() if "offset" in k]
        other_params = [v for k, v in net.named_parameters() if "offset" not in k]
        deform_lr = config.learning_rate * config.deform_lr_factor
        self.optimizer = torch.optim.SGD(
            [{"params": other_params}, {"params": deform_params, "lr": deform_lr}],
            lr=config.learning_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
        )

        # Choose to train on CPU or GPU
        if on_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        net.to(self.device)

        ##########################
        # Load previous checkpoint
        ##########################

        if chkp_path is not None:
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only run
            checkpoint = torch.load(chkp_path, map_location=self.device)
            net.load_state_dict(checkpoint["model_state_dict"])
            if finetune:
                # Finetuning: only the network weights are reused
                net.train()
                print("Model restored and ready for finetuning.")
            else:
                # Restoring: optimizer state and epoch counter are reloaded too
                self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                self.epoch = checkpoint["epoch"]
                net.train()
                print("Model and training state restored.")

        # Path of the result folder
        if config.saving:
            if config.saving_path is None:
                config.saving_path = time.strftime(
                    "results/Log_%Y-%m-%d_%H-%M-%S", time.gmtime()
                )
            if not exists(config.saving_path):
                makedirs(config.saving_path)
            config.save()

    # Device helpers
    # ------------------------------------------------------------------------------------------------------------------

    def _synchronize(self):
        """Wait for pending CUDA work on self.device; no-op when running on CPU."""
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)

    # Training main method
    # ------------------------------------------------------------------------------------------------------------------

    def train(self, net, training_loader, val_loader, config):
        """
        Train the model on a particular dataset.

        When config.saving is set, a "running_PID.txt" file is created in config.saving_path.
        Deleting it stops training at the next batch (kill signal). The file is removed once
        the epoch loop is over.
        :param net: network object
        :param training_loader: data loader for the training set
        :param val_loader: data loader for the validation set
        :param config: configuration object
        """

        ################
        # Initialization
        ################

        if config.saving:
            # Training log file
            with open(join(config.saving_path, "training.txt"), "w") as file:
                file.write("epochs steps out_loss offset_loss train_accuracy time\n")

            # Killing file (simply delete this file when you want to stop the training)
            PID_file = join(config.saving_path, "running_PID.txt")
            if not exists(PID_file):
                with open(PID_file, "w") as file:
                    file.write("Launched with PyCharm")

            # Checkpoints directory
            checkpoint_directory = join(config.saving_path, "checkpoints")
            if not exists(checkpoint_directory):
                makedirs(checkpoint_directory)
        else:
            checkpoint_directory = None
            PID_file = None

        def kill_requested():
            # The kill signal only exists when saving (otherwise PID_file is None)
            return config.saving and not exists(PID_file)

        # Loop variables
        t0 = time.time()
        t = [time.time()]
        last_display = time.time()
        mean_dt = np.zeros(1)

        # Start training loop
        for _ in range(config.max_epoch):

            self.step = 0
            for batch in training_loader:

                # Check kill signal (running_PID.txt deleted): stop consuming the loader
                if kill_requested():
                    break

                ##################
                # Processing batch
                ##################

                # New time
                t = t[-1:]
                t += [time.time()]

                if "cuda" in self.device.type:
                    batch.to(self.device)

                # zero the parameter gradients
                self.optimizer.zero_grad()

                # Forward pass
                outputs = net(batch, config)
                loss = net.loss(outputs, batch.labels)
                acc = net.accuracy(outputs, batch.labels)

                t += [time.time()]

                # Backward + optimize
                loss.backward()

                if config.grad_clip_norm > 0:
                    # NOTE: despite its name, grad_clip_norm is applied as a per-element value clip
                    torch.nn.utils.clip_grad_value_(
                        net.parameters(), config.grad_clip_norm
                    )
                self.optimizer.step()

                if self.device.type == "cuda":
                    torch.cuda.empty_cache()
                self._synchronize()

                t += [time.time()]

                # Average timing
                if self.step < 2:
                    mean_dt = np.array(t[1:]) - np.array(t[:-1])
                else:
                    mean_dt = 0.9 * mean_dt + 0.1 * (np.array(t[1:]) - np.array(t[:-1]))

                # Console display (only one per second)
                if (t[-1] - last_display) > 1.0:
                    last_display = t[-1]
                    message = "e{:03d}-i{:04d} => L={:.3f} acc={:3.0f}% / t(ms): {:5.1f} {:5.1f} {:5.1f}"
                    print(
                        message.format(
                            self.epoch,
                            self.step,
                            loss.item(),
                            100 * acc,
                            1000 * mean_dt[0],
                            1000 * mean_dt[1],
                            1000 * mean_dt[2],
                        )
                    )

                # Log file
                if config.saving:
                    with open(join(config.saving_path, "training.txt"), "a") as file:
                        message = "{:d} {:d} {:.3f} {:.3f} {:.3f} {:.3f}\n"
                        file.write(
                            message.format(
                                self.epoch,
                                self.step,
                                net.output_loss,
                                net.reg_loss,
                                acc,
                                t[-1] - t0,
                            )
                        )

                self.step += 1

            ##############
            # End of epoch
            ##############

            # Check kill signal (running_PID.txt deleted)
            if kill_requested():
                break

            # Update learning rate
            if self.epoch in config.lr_decays:
                for param_group in self.optimizer.param_groups:
                    param_group["lr"] *= config.lr_decays[self.epoch]

            # Update epoch
            self.epoch += 1

            # Saving
            if config.saving:
                # Get current state dict
                save_dict = {
                    "epoch": self.epoch,
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "saving_path": config.saving_path,
                }

                # Save current state of the network (for restoring purposes)
                checkpoint_path = join(checkpoint_directory, "current_chkp.tar")
                torch.save(save_dict, checkpoint_path)

                # Save checkpoints occasionally
                if (self.epoch + 1) % config.checkpoint_gap == 0:
                    checkpoint_path = join(
                        checkpoint_directory, "chkp_{:04d}.tar".format(self.epoch + 1)
                    )
                    torch.save(save_dict, checkpoint_path)

            # Validation
            net.eval()
            self.validation(net, val_loader, config)
            net.train()

        # Training is over: remove the kill-signal file so it does not linger.
        # (Removing it before the last epoch would make that epoch look killed.)
        if config.saving and exists(PID_file):
            remove(PID_file)

        print("Finished Training")

    # Validation methods
    # ------------------------------------------------------------------------------------------------------------------

    def validation(self, net, val_loader, config: Config):
        """
        Dispatch to the validation routine matching config.dataset_task.
        :raises ValueError: if config.dataset_task has no validation routine
        """
        if config.dataset_task == "classification":
            self.object_classification_validation(net, val_loader, config)
        elif config.dataset_task == "segmentation":
            # NOTE(review): object_segmentation_validation is not defined in the class body
            # visible here; confirm it exists, otherwise this branch raises AttributeError.
            self.object_segmentation_validation(net, val_loader, config)
        elif config.dataset_task == "cloud_segmentation":
            self.cloud_segmentation_validation(net, val_loader, config)
        elif config.dataset_task == "slam_segmentation":
            self.slam_segmentation_validation(net, val_loader, config)
        else:
            raise ValueError("No validation method implemented for this network type")

    @torch.no_grad()
    def object_classification_validation(self, net, val_loader, config):
        """
        Perform a round of validation and show/save results
        :param net: network object
        :param val_loader: data loader for validation set
        :param config: configuration object
        :return: confusion matrix of this round's (non-voted) predictions
        """

        ############
        # Initialize
        ############

        # Choose validation smoothing parameter (0 for no smoothing, 0.99 for big smoothing)
        val_smooth = 0.95

        # Number of classes predicted by the model
        nc_model = config.num_classes
        softmax = torch.nn.Softmax(1)

        # Initialize global prediction over all models (persists across validation rounds)
        if not hasattr(self, "val_probs"):
            self.val_probs = np.zeros((val_loader.dataset.num_models, nc_model))

        #####################
        # Network predictions
        #####################

        probs = []
        targets = []
        obj_inds = []

        t = [time.time()]
        last_display = time.time()
        mean_dt = np.zeros(1)

        # Start validation loop
        for batch in val_loader:

            # New time
            t = t[-1:]
            t += [time.time()]

            if "cuda" in self.device.type:
                batch.to(self.device)

            # Forward pass
            outputs = net(batch, config)

            # Get probs and labels
            probs += [softmax(outputs).cpu().detach().numpy()]
            targets += [batch.labels.cpu().numpy()]
            obj_inds += [batch.model_inds.cpu().numpy()]
            self._synchronize()

            # Average timing
            t += [time.time()]
            mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))

            # Display
            if (t[-1] - last_display) > 1.0:
                last_display = t[-1]
                message = "Validation : {:.1f}% (timings : {:4.2f} {:4.2f})"
                print(
                    message.format(
                        100 * len(obj_inds) / config.validation_size,
                        1000 * (mean_dt[0]),
                        1000 * (mean_dt[1]),
                    )
                )

        # Stack all validation predictions
        probs = np.vstack(probs)
        targets = np.hstack(targets)
        obj_inds = np.hstack(obj_inds)

        ###################
        # Voting validation
        ###################

        self.val_probs[obj_inds] = (
            val_smooth * self.val_probs[obj_inds] + (1 - val_smooth) * probs
        )

        ############
        # Confusions
        ############

        validation_labels = np.array(val_loader.dataset.label_values)

        # Compute classification results
        C1 = fast_confusion(targets, np.argmax(probs, axis=1), validation_labels)

        # Compute votes confusion
        C2 = fast_confusion(
            val_loader.dataset.input_labels,
            np.argmax(self.val_probs, axis=1),
            validation_labels,
        )

        # Saving (optional): each confusion matrix is appended as one flattened line
        if config.saving:
            print("Save confusions")
            conf_list = [C1, C2]
            file_list = ["val_confs.txt", "vote_confs.txt"]

            for conf, conf_file in zip(conf_list, file_list):
                test_file = join(config.saving_path, conf_file)
                # Append mode also creates the file on the first round
                with open(test_file, "a") as text_file:
                    for line in conf:
                        for value in line:
                            text_file.write("%d " % value)
                    text_file.write("\n")

        val_ACC = 100 * np.sum(np.diag(C1)) / (np.sum(C1) + 1e-6)
        vote_ACC = 100 * np.sum(np.diag(C2)) / (np.sum(C2) + 1e-6)
        print("Accuracies : val = {:.1f}% / vote = {:.1f}%".format(val_ACC, vote_ACC))

        return C1

    @torch.no_grad()
    def cloud_segmentation_validation(self, net, val_loader, config, debug=False):
        """
        Validation method for cloud segmentation models
        :param net: network object
        :param val_loader: data loader for validation set
        :param config: configuration object
        :param debug: print a timing breakdown at the end
        """

        ############
        # Initialize
        ############

        t0 = time.time()

        # Choose validation smoothing parameter (0 for no smoothing, 0.99 for big smoothing)
        val_smooth = 0.95
        softmax = torch.nn.Softmax(1)

        # Do not validate if dataset has no validation cloud
        if val_loader.dataset.validation_split not in val_loader.dataset.all_splits:
            return

        # Number of classes including ignored labels
        nc_tot = val_loader.dataset.num_classes

        # Number of classes predicted by the model
        nc_model = config.num_classes

        # Initiate global prediction over validation clouds (persists across validation rounds)
        if not hasattr(self, "validation_probs"):
            self.validation_probs = [
                np.zeros((l.shape[0], nc_model)) for l in val_loader.dataset.input_labels
            ]
            # Number of validation points per predicted (non-ignored) class
            self.val_proportions = np.zeros(nc_model, dtype=np.float32)
            i = 0
            for label_value in val_loader.dataset.label_values:
                if label_value not in val_loader.dataset.ignored_labels:
                    self.val_proportions[i] = np.sum(
                        [
                            np.sum(labels == label_value)
                            for labels in val_loader.dataset.validation_labels
                        ]
                    )
                    i += 1

        #####################
        # Network predictions
        #####################

        predictions = []
        targets = []

        t = [time.time()]
        last_display = time.time()
        mean_dt = np.zeros(1)

        t1 = time.time()

        # Start validation loop
        for i, batch in enumerate(val_loader):

            # New time
            t = t[-1:]
            t += [time.time()]

            if "cuda" in self.device.type:
                batch.to(self.device)

            # Forward pass
            outputs = net(batch, config)

            # Get probs and labels
            stacked_probs = softmax(outputs).cpu().detach().numpy()
            labels = batch.labels.cpu().numpy()
            lengths = batch.lengths[0].cpu().numpy()
            in_inds = batch.input_inds.cpu().numpy()
            cloud_inds = batch.cloud_inds.cpu().numpy()
            self._synchronize()

            # Get predictions and labels per instance
            # ***************************************

            i0 = 0
            for b_i, length in enumerate(lengths):

                # Get prediction
                target = labels[i0 : i0 + length]
                probs = stacked_probs[i0 : i0 + length]
                inds = in_inds[i0 : i0 + length]
                c_i = cloud_inds[b_i]

                # Update current probs in whole cloud
                self.validation_probs[c_i][inds] = (
                    val_smooth * self.validation_probs[c_i][inds]
                    + (1 - val_smooth) * probs
                )

                # Stack all prediction for this epoch
                predictions.append(probs)
                targets.append(target)
                i0 += length

            # Average timing
            t += [time.time()]
            mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))

            # Display
            if (t[-1] - last_display) > 1.0:
                last_display = t[-1]
                message = "Validation : {:.1f}% (timings : {:4.2f} {:4.2f})"
                print(
                    message.format(
                        100 * i / config.validation_size,
                        1000 * (mean_dt[0]),
                        1000 * (mean_dt[1]),
                    )
                )

        t2 = time.time()

        # Confusions for our subparts of validation set
        Confs = np.zeros((len(predictions), nc_tot, nc_tot), dtype=np.int32)
        for i, (probs, truth) in enumerate(zip(predictions, targets)):

            # Insert false columns for ignored labels
            for l_ind, label_value in enumerate(val_loader.dataset.label_values):
                if label_value in val_loader.dataset.ignored_labels:
                    probs = np.insert(probs, l_ind, 0, axis=1)

            # Predicted labels
            preds = val_loader.dataset.label_values[np.argmax(probs, axis=1)]

            # Confusions
            Confs[i, :, :] = fast_confusion(
                truth, preds, val_loader.dataset.label_values
            ).astype(np.int32)

        t3 = time.time()

        # Sum all confusions
        C = np.sum(Confs, axis=0).astype(np.float32)

        # Remove ignored labels from confusions (reverse order keeps indices valid)
        for l_ind, label_value in reversed(
            list(enumerate(val_loader.dataset.label_values))
        ):
            if label_value in val_loader.dataset.ignored_labels:
                C = np.delete(C, l_ind, axis=0)
                C = np.delete(C, l_ind, axis=1)

        # Balance with real validation proportions
        C *= np.expand_dims(self.val_proportions / (np.sum(C, axis=1) + 1e-6), 1)

        t4 = time.time()

        # Objects IoU
        IoUs = IoU_from_confusions(C)

        t5 = time.time()

        # Saving (optional)
        if config.saving:

            # Name of saving file
            test_file = join(config.saving_path, "val_IoUs.txt")

            # Line to write:
            line = ""
            for IoU in IoUs:
                line += "{:.3f} ".format(IoU)
            line = line + "\n"

            # Write in file (append mode also creates it on the first round)
            with open(test_file, "a") as text_file:
                text_file.write(line)

            # Save potentials
            if val_loader.dataset.use_potentials:
                pot_path = join(config.saving_path, "potentials")
                if not exists(pot_path):
                    makedirs(pot_path)
                files = val_loader.dataset.files
                for i, file_path in enumerate(files):
                    # asarray avoids a copy when possible (np.array(copy=False) fails on NumPy 2)
                    pot_points = np.asarray(val_loader.dataset.pot_trees[i].data)
                    cloud_name = file_path.split("/")[-1]
                    pot_name = join(pot_path, cloud_name)
                    pots = val_loader.dataset.potentials[i].numpy().astype(np.float32)
                    write_ply(
                        pot_name,
                        [pot_points.astype(np.float32), pots],
                        ["x", "y", "z", "pots"],
                    )

        t6 = time.time()

        # Print instance mean
        mIoU = 100 * np.mean(IoUs)
        print("{:s} mean IoU = {:.1f}%".format(config.dataset, mIoU))

        # Save predicted cloud occasionally
        if config.saving and (self.epoch + 1) % config.checkpoint_gap == 0:
            val_path = join(config.saving_path, "val_preds_{:d}".format(self.epoch + 1))
            if not exists(val_path):
                makedirs(val_path)
            files = val_loader.dataset.files
            for i, file_path in enumerate(files):

                # Get points
                points = val_loader.dataset.load_evaluation_points(file_path)

                # Get probs on our own ply points
                sub_probs = self.validation_probs[i]

                # Insert false columns for ignored labels
                for l_ind, label_value in enumerate(val_loader.dataset.label_values):
                    if label_value in val_loader.dataset.ignored_labels:
                        sub_probs = np.insert(sub_probs, l_ind, 0, axis=1)

                # Get the predicted labels
                sub_preds = val_loader.dataset.label_values[
                    np.argmax(sub_probs, axis=1).astype(np.int32)
                ]

                # Reproject preds on the evaluations points
                preds = (sub_preds[val_loader.dataset.test_proj[i]]).astype(np.int32)

                # Path of saved validation file
                cloud_name = file_path.split("/")[-1]
                val_name = join(val_path, cloud_name)

                # Save file
                labels = val_loader.dataset.validation_labels[i].astype(np.int32)
                write_ply(
                    val_name, [points, preds, labels], ["x", "y", "z", "preds", "class"]
                )

        # Display timings
        t7 = time.time()
        if debug:
            print("\n************************\n")
            print("Validation timings:")
            print("Init ...... {:.1f}s".format(t1 - t0))
            print("Loop ...... {:.1f}s".format(t2 - t1))
            print("Confs ..... {:.1f}s".format(t3 - t2))
            print("Confs bis . {:.1f}s".format(t4 - t3))
            print("IoU ....... {:.1f}s".format(t5 - t4))
            print("Save1 ..... {:.1f}s".format(t6 - t5))
            print("Save2 ..... {:.1f}s".format(t7 - t6))
            print("\n************************\n")

    @torch.no_grad()
    def slam_segmentation_validation(self, net, val_loader, config, debug=True):
        """
        Validation method for slam segmentation models
        :param net: network object
        :param val_loader: data loader for validation set (None to skip validation)
        :param config: configuration object
        :param debug: print the full confusion matrix and a timing breakdown
        """

        ############
        # Initialize
        ############

        t0 = time.time()

        # Do not validate if dataset has no validation cloud
        if val_loader is None:
            return

        softmax = torch.nn.Softmax(1)

        # Create folder for validation predictions
        # NOTE(review): this uses config.saving_path even when config.saving is False — confirm
        if not exists(join(config.saving_path, "val_preds")):
            makedirs(join(config.saving_path, "val_preds"))

        # initiate the dataset validation containers
        val_loader.dataset.val_points = []
        val_loader.dataset.val_labels = []

        # Number of classes including ignored labels
        nc_tot = val_loader.dataset.num_classes

        #####################
        # Network predictions
        #####################

        predictions = []
        targets = []
        inds = []
        val_i = 0

        t = [time.time()]
        last_display = time.time()
        mean_dt = np.zeros(1)

        t1 = time.time()

        # Start validation loop
        for i, batch in enumerate(val_loader):

            # New time
            t = t[-1:]
            t += [time.time()]

            if "cuda" in self.device.type:
                batch.to(self.device)

            # Forward pass
            outputs = net(batch, config)

            # Get probs and labels
            stk_probs = softmax(outputs).cpu().detach().numpy()
            lengths = batch.lengths[0].cpu().numpy()
            f_inds = batch.frame_inds.cpu().numpy()
            r_inds_list = batch.reproj_inds
            r_mask_list = batch.reproj_masks
            labels_list = batch.val_labels
            self._synchronize()

            # Get predictions and labels per instance
            # ***************************************

            i0 = 0
            for b_i, length in enumerate(lengths):

                # Get prediction
                probs = stk_probs[i0 : i0 + length]
                proj_inds = r_inds_list[b_i]
                proj_mask = r_mask_list[b_i]
                frame_labels = labels_list[b_i]
                s_ind = f_inds[b_i, 0]
                f_ind = f_inds[b_i, 1]

                # Project predictions on the frame points
                proj_probs = probs[proj_inds]

                # Safe check if only one point:
                if proj_probs.ndim < 2:
                    proj_probs = np.expand_dims(proj_probs, 0)

                # Insert false columns for ignored labels
                for l_ind, label_value in enumerate(val_loader.dataset.label_values):
                    if label_value in val_loader.dataset.ignored_labels:
                        proj_probs = np.insert(proj_probs, l_ind, 0, axis=1)

                # Predicted labels
                preds = val_loader.dataset.label_values[np.argmax(proj_probs, axis=1)]

                # Save predictions in a binary file
                filename = "{:s}_{:07d}.npy".format(
                    val_loader.dataset.sequences[s_ind], f_ind
                )
                filepath = join(config.saving_path, "val_preds", filename)
                if exists(filepath):
                    frame_preds = np.load(filepath)
                else:
                    frame_preds = np.zeros(frame_labels.shape, dtype=np.uint8)
                frame_preds[proj_mask] = preds.astype(np.uint8)
                np.save(filepath, frame_preds)

                # Save some of the frame pots
                if f_ind % 20 == 0:
                    seq_path = join(
                        val_loader.dataset.path,
                        "sequences",
                        val_loader.dataset.sequences[s_ind],
                    )
                    velo_file = join(
                        seq_path,
                        "velodyne",
                        val_loader.dataset.frames[s_ind][f_ind] + ".bin",
                    )
                    frame_points = np.fromfile(velo_file, dtype=np.float32)
                    frame_points = frame_points.reshape((-1, 4))
                    write_ply(
                        filepath[:-4] + "_pots.ply",
                        [frame_points[:, :3], frame_labels, frame_preds],
                        ["x", "y", "z", "gt", "pre"],
                    )

                # Update validation confusions
                frame_C = fast_confusion(
                    frame_labels,
                    frame_preds.astype(np.int32),
                    val_loader.dataset.label_values,
                )
                val_loader.dataset.val_confs[s_ind][f_ind, :, :] = frame_C

                # Stack all prediction for this epoch
                predictions += [preds]
                targets += [frame_labels[proj_mask]]
                inds += [f_inds[b_i, :]]
                val_i += 1
                i0 += length

            # Average timing
            t += [time.time()]
            mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))

            # Display
            if (t[-1] - last_display) > 1.0:
                last_display = t[-1]
                message = "Validation : {:.1f}% (timings : {:4.2f} {:4.2f})"
                print(
                    message.format(
                        100 * i / config.validation_size,
                        1000 * (mean_dt[0]),
                        1000 * (mean_dt[1]),
                    )
                )

        t2 = time.time()

        # Confusions for our subparts of validation set
        Confs = np.zeros((len(predictions), nc_tot, nc_tot), dtype=np.int32)
        for i, (preds, truth) in enumerate(zip(predictions, targets)):

            # Confusions
            Confs[i, :, :] = fast_confusion(
                truth, preds, val_loader.dataset.label_values
            ).astype(np.int32)

        t3 = time.time()

        #######################################
        # Results on this subpart of validation
        #######################################

        # Sum all confusions
        C = np.sum(Confs, axis=0).astype(np.float32)

        # Balance with real validation proportions
        C *= np.expand_dims(
            val_loader.dataset.class_proportions / (np.sum(C, axis=1) + 1e-6), 1
        )

        # Remove ignored labels from confusions (reverse order keeps indices valid)
        for l_ind, label_value in reversed(
            list(enumerate(val_loader.dataset.label_values))
        ):
            if label_value in val_loader.dataset.ignored_labels:
                C = np.delete(C, l_ind, axis=0)
                C = np.delete(C, l_ind, axis=1)

        # Objects IoU
        IoUs = IoU_from_confusions(C)

        #####################################
        # Results on the whole validation set
        #####################################

        t4 = time.time()

        # Sum all validation confusions
        C_tot = [
            np.sum(seq_C, axis=0)
            for seq_C in val_loader.dataset.val_confs
            if len(seq_C) > 0
        ]
        C_tot = np.sum(np.stack(C_tot, axis=0), axis=0)

        if debug:
            s = "\n"
            for cc in C_tot:
                for c in cc:
                    s += "{:8.1f} ".format(c)
                s += "\n"
            print(s)

        # Remove ignored labels from confusions
        for l_ind, label_value in reversed(
            list(enumerate(val_loader.dataset.label_values))
        ):
            if label_value in val_loader.dataset.ignored_labels:
                C_tot = np.delete(C_tot, l_ind, axis=0)
                C_tot = np.delete(C_tot, l_ind, axis=1)

        # Objects IoU
        val_IoUs = IoU_from_confusions(C_tot)

        t5 = time.time()

        # Saving (optional)
        if config.saving:

            IoU_list = [IoUs, val_IoUs]
            file_list = ["subpart_IoUs.txt", "val_IoUs.txt"]
            for IoUs_to_save, IoU_file in zip(IoU_list, file_list):

                # Name of saving file
                test_file = join(config.saving_path, IoU_file)

                # Line to write:
                line = ""
                for IoU in IoUs_to_save:
                    line += "{:.3f} ".format(IoU)
                line = line + "\n"

                # Write in file (append mode also creates it on the first round)
                with open(test_file, "a") as text_file:
                    text_file.write(line)

        # Print instance mean
        mIoU = 100 * np.mean(IoUs)
        print("{:s} : subpart mIoU = {:.1f} %".format(config.dataset, mIoU))
        mIoU = 100 * np.mean(val_IoUs)
        print("{:s} :     val mIoU = {:.1f} %".format(config.dataset, mIoU))

        t6 = time.time()

        # Display timings
        if debug:
            print("\n************************\n")
            print("Validation timings:")
            print("Init ...... {:.1f}s".format(t1 - t0))
            print("Loop ...... {:.1f}s".format(t2 - t1))
            print("Confs ..... {:.1f}s".format(t3 - t2))
            print("IoU1 ...... {:.1f}s".format(t4 - t3))
            print("IoU2 ...... {:.1f}s".format(t5 - t4))
            print("Save ...... {:.1f}s".format(t6 - t5))
            print("\n************************\n")