diff --git a/.gitignore b/.gitignore
index 8f21526..40c4558 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,8 @@
 /results
 /test
 /docker_scripts
+/kernels/dispositions
+core
 
 # VSCode related
 *.code-workspace
diff --git a/datasets/S3DIS.py b/datasets/S3DIS.py
index 227d4ac..f5965cb 100644
--- a/datasets/S3DIS.py
+++ b/datasets/S3DIS.py
@@ -206,8 +206,7 @@ class S3DISDataset(PointCloudDataset):
             self.potentials = None
             self.min_potentials = None
             self.argmin_potentials = None
-            N = config.epoch_steps * config.batch_num
-            self.epoch_inds = torch.from_numpy(np.zeros((2, N), dtype=np.int64))
+            self.epoch_inds = torch.from_numpy(np.zeros((2, self.epoch_n), dtype=np.int64))
             self.epoch_i = torch.from_numpy(np.zeros((1,), dtype=np.int64))
             self.epoch_i.share_memory_()
             self.epoch_inds.share_memory_()
@@ -515,6 +514,9 @@ class S3DISDataset(PointCloudDataset):
             # Update epoch indice
             self.epoch_i += 1
 
+            if self.epoch_i >= int(self.epoch_inds.shape[1]):
+                self.epoch_i -= int(self.epoch_inds.shape[1])
+
             # Get points from tree structure
             points = np.array(self.input_trees[cloud_ind].data, copy=False)
 
@@ -1156,11 +1158,18 @@ class S3DISSampler(Sampler):
         # Estimated average batch size and target value
         estim_b = 0
         target_b = self.dataset.config.batch_num
+
+        # Expected batch size order of magnitude
+        expected_N = 100000
 
-        # Calibration parameters
-        low_pass_T = 10
-        Kp = 100.0
+        # Calibration parameters. Higher means faster but can also become unstable
+        # Reduce Kp and Kd if your GPU is small as the total number of points per batch will be smaller
+        low_pass_T = 100
+        Kp = expected_N / 200
+        Ki = 0.001 * Kp
+        Kd = 5 * Kp
         finer = False
+        stabilized = False
 
         # Convergence parameters
         smooth_errors = []
@@ -1170,12 +1179,22 @@ class S3DISSampler(Sampler):
         last_display = time.time()
         i = 0
         breaking = False
+        error_I = 0
+        error_D = 0
+        last_error = 0
+
+        debug_in = []
+        debug_out = []
+        debug_b = []
+        debug_estim_b = []
 
         #####################
         # Perform calibration
         #####################
 
-        for epoch in range(10):
+        # Number of batches per epoch
+        sample_batches = 999
+        for epoch in range((sample_batches // self.N) + 1):
             for batch_i, batch in enumerate(dataloader):
 
                 # Update neighborhood histogram
@@ -1191,14 +1210,25 @@ class S3DISSampler(Sampler):
 
                 # Estimate error (noisy)
                 error = target_b - b
+                error_I += error
+                error_D = error - last_error
+                last_error = error
+
 
                 # Save smooth errors for convergene check
                 smooth_errors.append(target_b - estim_b)
-                if len(smooth_errors) > 10:
+                if len(smooth_errors) > 30:
                     smooth_errors = smooth_errors[1:]
 
                 # Update batch limit with P controller
-                self.dataset.batch_limit += Kp * error
+                self.dataset.batch_limit += Kp * error + Ki * error_I + Kd * error_D
+
+                # Instability detection
+                if not stabilized and self.dataset.batch_limit < 0:
+                    Kp *= 0.1
+                    Ki *= 0.1
+                    Kd *= 0.1
+                    stabilized = True
 
                 # finer low pass filter when closing in
                 if not finer and np.abs(estim_b - target_b) < 1:
@@ -1221,14 +1251,42 @@ class S3DISSampler(Sampler):
                                          estim_b,
                                          int(self.dataset.batch_limit)))
 
+                # Debug plots
+                debug_in.append(int(batch.points[0].shape[0]))
+                debug_out.append(int(self.dataset.batch_limit))
+                debug_b.append(b)
+                debug_estim_b.append(estim_b)
+
             if breaking:
                 break
 
+        # Plot in case we did not reach convergence
+        if not breaking:
+            import matplotlib.pyplot as plt
+
+            print("ERROR: It seems that the calibration has not reached convergence. Here are some plots to understand why:")
+            print("If you notice instability, reduce the expected_N value")
+            print("If convergence is too slow, increase the expected_N value")
+
+            plt.figure()
+            plt.plot(debug_in)
+            plt.plot(debug_out)
+
+            plt.figure()
+            plt.plot(debug_b)
+            plt.plot(debug_estim_b)
+
+            plt.show()
+
+            a = 1/0
+
+
         # Use collected neighbor histogram to get neighbors limit
         cumsum = np.cumsum(neighb_hists.T, axis=0)
         percentiles = np.sum(cumsum < (untouched_ratio * cumsum[hist_n - 1, :]), axis=0)
         self.dataset.neighborhood_limits = percentiles
+
 
         if verbose:
 
             # Crop histogram
diff --git a/train_S3DIS.py b/train_S3DIS.py
index fb61566..984fdc6 100644
--- a/train_S3DIS.py
+++ b/train_S3DIS.py
@@ -65,6 +65,30 @@ class S3DISConfig(Config):
     # Architecture definition
     #########################
 
+    # # Define layers
+    # architecture = ['simple',
+    #                 'resnetb',
+    #                 'resnetb_strided',
+    #                 'resnetb',
+    #                 'resnetb',
+    #                 'resnetb_strided',
+    #                 'resnetb_deformable',
+    #                 'resnetb_deformable',
+    #                 'resnetb_deformable_strided',
+    #                 'resnetb_deformable',
+    #                 'resnetb_deformable',
+    #                 'resnetb_deformable_strided',
+    #                 'resnetb_deformable',
+    #                 'resnetb_deformable',
+    #                 'nearest_upsample',
+    #                 'unary',
+    #                 'nearest_upsample',
+    #                 'unary',
+    #                 'nearest_upsample',
+    #                 'unary',
+    #                 'nearest_upsample',
+    #                 'unary']
+
     # Define layers
     architecture = ['simple',
                     'resnetb',
@@ -72,14 +96,14 @@ class S3DISConfig(Config):
                     'resnetb',
                     'resnetb',
                     'resnetb_strided',
-                    'resnetb_deformable',
-                    'resnetb_deformable',
-                    'resnetb_deformable_strided',
-                    'resnetb_deformable',
-                    'resnetb_deformable',
-                    'resnetb_deformable_strided',
-                    'resnetb_deformable',
-                    'resnetb_deformable',
+                    'resnetb',
+                    'resnetb',
+                    'resnetb_strided',
+                    'resnetb',
+                    'resnetb',
+                    'resnetb_strided',
+                    'resnetb',
+                    'resnetb',
                     'nearest_upsample',
                     'unary',
                     'nearest_upsample',
@@ -97,7 +121,7 @@ class S3DISConfig(Config):
     num_kernel_points = 15
 
     # Radius of the input sphere (decrease value to reduce memory cost)
-    in_radius = 1.2
+    in_radius = 1.0
 
     # Size of the first subsampling grid in meter (increase value to reduce memory cost)
     first_subsampling_dl = 0.02
@@ -244,8 +268,8 @@ if __name__ == '__main__':
         config.saving_path = sys.argv[1]
 
     # Initialize datasets
-    training_dataset = S3DISDataset(config, set='training', use_potentials=True)
-    test_dataset = S3DISDataset(config, set='validation', use_potentials=True)
+    training_dataset = S3DISDataset(config, set='training', use_potentials=False)
+    test_dataset = S3DISDataset(config, set='validation', use_potentials=False)
 
     # Initialize samplers
     training_sampler = S3DISSampler(training_dataset)
diff --git a/utils/trainer.py b/utils/trainer.py
index 5b0aadb..b5d2c41 100644
--- a/utils/trainer.py
+++ b/utils/trainer.py
@@ -579,18 +579,19 @@ class ModelTrainer:
                     text_file.write(line)
 
             # Save potentials
-            pot_path = join(config.saving_path, 'potentials')
-            if not exists(pot_path):
-                makedirs(pot_path)
-            files = val_loader.dataset.files
-            for i, file_path in enumerate(files):
-                pot_points = np.array(val_loader.dataset.pot_trees[i].data, copy=False)
-                cloud_name = file_path.split('/')[-1]
-                pot_name = join(pot_path, cloud_name)
-                pots = val_loader.dataset.potentials[i].numpy().astype(np.float32)
-                write_ply(pot_name,
-                          [pot_points.astype(np.float32), pots],
-                          ['x', 'y', 'z', 'pots'])
+            if val_loader.dataset.use_potentials:
+                pot_path = join(config.saving_path, 'potentials')
+                if not exists(pot_path):
+                    makedirs(pot_path)
+                files = val_loader.dataset.files
+                for i, file_path in enumerate(files):
+                    pot_points = np.array(val_loader.dataset.pot_trees[i].data, copy=False)
+                    cloud_name = file_path.split('/')[-1]
+                    pot_name = join(pot_path, cloud_name)
+                    pots = val_loader.dataset.potentials[i].numpy().astype(np.float32)
+                    write_ply(pot_name,
+                              [pot_points.astype(np.float32), pots],
+                              ['x', 'y', 'z', 'pots'])
 
             t6 = time.time()
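Note on the S3DISSampler calibration hunks above: the patch replaces the original proportional-only update of self.dataset.batch_limit with a PID update driven by the batch-size error (gains Kp, Ki, Kd scaled by expected_N), plus a gain-reduction fallback if the limit goes negative. Below is a minimal, self-contained sketch of that control loop; the toy noisy measurement model, the step count, and the calibrate_batch_limit helper are illustrative assumptions, not code from the repository. Only the gain choices and the update rule mirror the patch.

# Hypothetical stand-alone illustration of the PID batch-limit calibration.
# The fake "measurement" of the batch size below is an assumption made for
# demonstration only; the real loop measures it from the dataloader.
import numpy as np


def calibrate_batch_limit(target_b=6, expected_N=100000, steps=300, seed=0):
    rng = np.random.default_rng(seed)

    # Gains scaled by the expected number of points per batch, as in the patch
    Kp = expected_N / 200
    Ki = 0.001 * Kp
    Kd = 5 * Kp

    batch_limit = float(expected_N)   # point budget per batch being calibrated
    error_I = 0.0                     # accumulated (integral) error
    last_error = 0.0

    for _ in range(steps):
        # Toy noisy measurement: batches hold more clouds when the budget grows
        b = batch_limit * target_b / expected_N + rng.normal(0.0, 0.5)

        # PID terms computed from the batch-size error, as in the diff
        error = target_b - b
        error_I += error
        error_D = error - last_error
        last_error = error

        batch_limit += Kp * error + Ki * error_I + Kd * error_D

    return batch_limit


if __name__ == '__main__':
    # The budget should settle near expected_N, where batches average target_b
    print(int(calibrate_batch_limit()))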