correction of use_potentials=False on validation

This commit is contained in:
HuguesTHOMAS 2021-07-29 15:49:30 +00:00
parent 6c31df8156
commit 1527e7d09e
4 changed files with 116 additions and 31 deletions

2
.gitignore vendored
View file

@ -4,6 +4,8 @@
/results /results
/test /test
/docker_scripts /docker_scripts
/kernels/dispositions
core
# VSCode related # VSCode related
*.code-workspace *.code-workspace

View file

@ -206,8 +206,7 @@ class S3DISDataset(PointCloudDataset):
self.potentials = None self.potentials = None
self.min_potentials = None self.min_potentials = None
self.argmin_potentials = None self.argmin_potentials = None
N = config.epoch_steps * config.batch_num self.epoch_inds = torch.from_numpy(np.zeros((2, self.epoch_n), dtype=np.int64))
self.epoch_inds = torch.from_numpy(np.zeros((2, N), dtype=np.int64))
self.epoch_i = torch.from_numpy(np.zeros((1,), dtype=np.int64)) self.epoch_i = torch.from_numpy(np.zeros((1,), dtype=np.int64))
self.epoch_i.share_memory_() self.epoch_i.share_memory_()
self.epoch_inds.share_memory_() self.epoch_inds.share_memory_()
@ -515,6 +514,9 @@ class S3DISDataset(PointCloudDataset):
# Update epoch indice # Update epoch indice
self.epoch_i += 1 self.epoch_i += 1
if self.epoch_i >= int(self.epoch_inds.shape[1]):
self.epoch_i -= int(self.epoch_inds.shape[1])
# Get points from tree structure # Get points from tree structure
points = np.array(self.input_trees[cloud_ind].data, copy=False) points = np.array(self.input_trees[cloud_ind].data, copy=False)
@ -1157,10 +1159,17 @@ class S3DISSampler(Sampler):
estim_b = 0 estim_b = 0
target_b = self.dataset.config.batch_num target_b = self.dataset.config.batch_num
# Calibration parameters # Expected batch size order of magnitude
low_pass_T = 10 expected_N = 100000
Kp = 100.0
# Calibration parameters. Higher means faster but can also become unstable
# Reduce Kp and Kd if your GP Uis small as the total number of points per batch will be smaller
low_pass_T = 100
Kp = expected_N / 200
Ki = 0.001 * Kp
Kd = 5 * Kp
finer = False finer = False
stabilized = False
# Convergence parameters # Convergence parameters
smooth_errors = [] smooth_errors = []
@ -1170,12 +1179,22 @@ class S3DISSampler(Sampler):
last_display = time.time() last_display = time.time()
i = 0 i = 0
breaking = False breaking = False
error_I = 0
error_D = 0
last_error = 0
debug_in = []
debug_out = []
debug_b = []
debug_estim_b = []
##################### #####################
# Perform calibration # Perform calibration
##################### #####################
for epoch in range(10): # number of batch per epoch
sample_batches = 999
for epoch in range((sample_batches // self.N) + 1):
for batch_i, batch in enumerate(dataloader): for batch_i, batch in enumerate(dataloader):
# Update neighborhood histogram # Update neighborhood histogram
@ -1191,14 +1210,25 @@ class S3DISSampler(Sampler):
# Estimate error (noisy) # Estimate error (noisy)
error = target_b - b error = target_b - b
error_I += error
error_D = error - last_error
last_error = error
# Save smooth errors for convergene check # Save smooth errors for convergene check
smooth_errors.append(target_b - estim_b) smooth_errors.append(target_b - estim_b)
if len(smooth_errors) > 10: if len(smooth_errors) > 30:
smooth_errors = smooth_errors[1:] smooth_errors = smooth_errors[1:]
# Update batch limit with P controller # Update batch limit with P controller
self.dataset.batch_limit += Kp * error self.dataset.batch_limit += Kp * error + Ki * error_I + Kd * error_D
# Unstability detection
if not stabilized and self.dataset.batch_limit < 0:
Kp *= 0.1
Ki *= 0.1
Kd *= 0.1
stabilized = True
# finer low pass filter when closing in # finer low pass filter when closing in
if not finer and np.abs(estim_b - target_b) < 1: if not finer and np.abs(estim_b - target_b) < 1:
@ -1221,14 +1251,42 @@ class S3DISSampler(Sampler):
estim_b, estim_b,
int(self.dataset.batch_limit))) int(self.dataset.batch_limit)))
# Debug plots
debug_in.append(int(batch.points[0].shape[0]))
debug_out.append(int(self.dataset.batch_limit))
debug_b.append(b)
debug_estim_b.append(estim_b)
if breaking: if breaking:
break break
# Plot in case we did not reach convergence
if not breaking:
import matplotlib.pyplot as plt
print("ERROR: It seems that the calibration have not reached convergence. Here are some plot to understand why:")
print("If you notice unstability, reduce the expected_N value")
print("If convergece is too slow, increase the expected_N value")
plt.figure()
plt.plot(debug_in)
plt.plot(debug_out)
plt.figure()
plt.plot(debug_b)
plt.plot(debug_estim_b)
plt.show()
a = 1/0
# Use collected neighbor histogram to get neighbors limit # Use collected neighbor histogram to get neighbors limit
cumsum = np.cumsum(neighb_hists.T, axis=0) cumsum = np.cumsum(neighb_hists.T, axis=0)
percentiles = np.sum(cumsum < (untouched_ratio * cumsum[hist_n - 1, :]), axis=0) percentiles = np.sum(cumsum < (untouched_ratio * cumsum[hist_n - 1, :]), axis=0)
self.dataset.neighborhood_limits = percentiles self.dataset.neighborhood_limits = percentiles
if verbose: if verbose:
# Crop histogram # Crop histogram

View file

@ -65,6 +65,30 @@ class S3DISConfig(Config):
# Architecture definition # Architecture definition
######################### #########################
# # Define layers
# architecture = ['simple',
# 'resnetb',
# 'resnetb_strided',
# 'resnetb',
# 'resnetb',
# 'resnetb_strided',
# 'resnetb_deformable',
# 'resnetb_deformable',
# 'resnetb_deformable_strided',
# 'resnetb_deformable',
# 'resnetb_deformable',
# 'resnetb_deformable_strided',
# 'resnetb_deformable',
# 'resnetb_deformable',
# 'nearest_upsample',
# 'unary',
# 'nearest_upsample',
# 'unary',
# 'nearest_upsample',
# 'unary',
# 'nearest_upsample',
# 'unary']
# Define layers # Define layers
architecture = ['simple', architecture = ['simple',
'resnetb', 'resnetb',
@ -72,14 +96,14 @@ class S3DISConfig(Config):
'resnetb', 'resnetb',
'resnetb', 'resnetb',
'resnetb_strided', 'resnetb_strided',
'resnetb_deformable', 'resnetb',
'resnetb_deformable', 'resnetb',
'resnetb_deformable_strided', 'resnetb_strided',
'resnetb_deformable', 'resnetb',
'resnetb_deformable', 'resnetb',
'resnetb_deformable_strided', 'resnetb_strided',
'resnetb_deformable', 'resnetb',
'resnetb_deformable', 'resnetb',
'nearest_upsample', 'nearest_upsample',
'unary', 'unary',
'nearest_upsample', 'nearest_upsample',
@ -97,7 +121,7 @@ class S3DISConfig(Config):
num_kernel_points = 15 num_kernel_points = 15
# Radius of the input sphere (decrease value to reduce memory cost) # Radius of the input sphere (decrease value to reduce memory cost)
in_radius = 1.2 in_radius = 1.0
# Size of the first subsampling grid in meter (increase value to reduce memory cost) # Size of the first subsampling grid in meter (increase value to reduce memory cost)
first_subsampling_dl = 0.02 first_subsampling_dl = 0.02
@ -244,8 +268,8 @@ if __name__ == '__main__':
config.saving_path = sys.argv[1] config.saving_path = sys.argv[1]
# Initialize datasets # Initialize datasets
training_dataset = S3DISDataset(config, set='training', use_potentials=True) training_dataset = S3DISDataset(config, set='training', use_potentials=False)
test_dataset = S3DISDataset(config, set='validation', use_potentials=True) test_dataset = S3DISDataset(config, set='validation', use_potentials=False)
# Initialize samplers # Initialize samplers
training_sampler = S3DISSampler(training_dataset) training_sampler = S3DISSampler(training_dataset)

View file

@ -579,6 +579,7 @@ class ModelTrainer:
text_file.write(line) text_file.write(line)
# Save potentials # Save potentials
if val_loader.dataset.use_potentials:
pot_path = join(config.saving_path, 'potentials') pot_path = join(config.saving_path, 'potentials')
if not exists(pot_path): if not exists(pot_path):
makedirs(pot_path) makedirs(pot_path)