This commit is contained in:
HuguesTHOMAS 2022-07-18 10:40:10 -04:00
parent e5be3ce280
commit d0ac818366

View file

@ -764,6 +764,7 @@ class SemanticKittiSampler(Sampler):
# Get the potentials of the frames containing this class # Get the potentials of the frames containing this class
class_potentials = self.dataset.potentials[self.dataset.class_frames[i]] class_potentials = self.dataset.potentials[self.dataset.class_frames[i]]
if class_potentials.shape[0] > 0: if class_potentials.shape[0] > 0:
# Get the indices to generate thanks to potentials # Get the indices to generate thanks to potentials
@ -788,6 +789,15 @@ class SemanticKittiSampler(Sampler):
self.dataset.potentials[update_inds] = torch.ceil(self.dataset.potentials[update_inds]) self.dataset.potentials[update_inds] = torch.ceil(self.dataset.potentials[update_inds])
self.dataset.potentials[update_inds] += torch.from_numpy(np.random.rand(update_inds.shape[0]) * 0.1 + 0.1) self.dataset.potentials[update_inds] += torch.from_numpy(np.random.rand(update_inds.shape[0]) * 0.1 + 0.1)
else:
error_message = '\nIt seems there is a problem with the class statistics of your dataset, saved in the variable dataset.class_frames.\n'
error_message += 'Here are the current statistics:\n'
error_message += '{:>15s} {:>15s}\n'.format('Class', '# of frames')
for iii, ccc in enumerate(self.dataset.label_values):
error_message += '{:>15s} {:>15d}\n'.format(self.dataset.label_names[iii], len(self.dataset.class_frames[iii]))
error_message = '\nThis error is raised if one of the classes is not ignored and does not appear in any of the frames of the dataset.\n'
raise ValueError(error_message)
# Stack the chosen indices of all classes # Stack the chosen indices of all classes
gen_indices = torch.cat(gen_indices, dim=0) gen_indices = torch.cat(gen_indices, dim=0)
gen_classes = torch.cat(gen_classes, dim=0) gen_classes = torch.cat(gen_classes, dim=0)