.
This commit is contained in:
parent
e5be3ce280
commit
d0ac818366
|
@ -764,6 +764,7 @@ class SemanticKittiSampler(Sampler):
|
||||||
# Get the potentials of the frames containing this class
|
# Get the potentials of the frames containing this class
|
||||||
class_potentials = self.dataset.potentials[self.dataset.class_frames[i]]
|
class_potentials = self.dataset.potentials[self.dataset.class_frames[i]]
|
||||||
|
|
||||||
|
|
||||||
if class_potentials.shape[0] > 0:
|
if class_potentials.shape[0] > 0:
|
||||||
|
|
||||||
# Get the indices to generate thanks to potentials
|
# Get the indices to generate thanks to potentials
|
||||||
|
@ -788,6 +789,15 @@ class SemanticKittiSampler(Sampler):
|
||||||
self.dataset.potentials[update_inds] = torch.ceil(self.dataset.potentials[update_inds])
|
self.dataset.potentials[update_inds] = torch.ceil(self.dataset.potentials[update_inds])
|
||||||
self.dataset.potentials[update_inds] += torch.from_numpy(np.random.rand(update_inds.shape[0]) * 0.1 + 0.1)
|
self.dataset.potentials[update_inds] += torch.from_numpy(np.random.rand(update_inds.shape[0]) * 0.1 + 0.1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
error_message = '\nIt seems there is a problem with the class statistics of your dataset, saved in the variable dataset.class_frames.\n'
|
||||||
|
error_message += 'Here are the current statistics:\n'
|
||||||
|
error_message += '{:>15s} {:>15s}\n'.format('Class', '# of frames')
|
||||||
|
for iii, ccc in enumerate(self.dataset.label_values):
|
||||||
|
error_message += '{:>15s} {:>15d}\n'.format(self.dataset.label_names[iii], len(self.dataset.class_frames[iii]))
|
||||||
|
error_message = '\nThis error is raised if one of the classes is not ignored and does not appear in any of the frames of the dataset.\n'
|
||||||
|
raise ValueError(error_message)
|
||||||
|
|
||||||
# Stack the chosen indices of all classes
|
# Stack the chosen indices of all classes
|
||||||
gen_indices = torch.cat(gen_indices, dim=0)
|
gen_indices = torch.cat(gen_indices, dim=0)
|
||||||
gen_classes = torch.cat(gen_classes, dim=0)
|
gen_classes = torch.cat(gen_classes, dim=0)
|
||||||
|
|
Loading…
Reference in a new issue