diff --git a/datasets/SemanticKitti.py b/datasets/SemanticKitti.py index 157adf0..a1664e2 100644 --- a/datasets/SemanticKitti.py +++ b/datasets/SemanticKitti.py @@ -762,27 +762,29 @@ class SemanticKittiSampler(Sampler): # Get the potentials of the frames containing this class class_potentials = self.dataset.potentials[self.dataset.class_frames[i]] - # Get the indices to generate thanks to potentials - used_classes = self.dataset.num_classes - len(self.dataset.ignored_labels) - class_n = num_centers // used_classes + 1 - if class_n < class_potentials.shape[0]: - _, class_indices = torch.topk(class_potentials, class_n, largest=False) - else: - class_indices = torch.zeros((0,), dtype=torch.int32) - while class_indices.shape[0] < class_n: - new_class_inds = torch.randperm(class_potentials.shape[0]).type(torch.int32) - class_indices = torch.cat((class_indices, new_class_inds), dim=0) - class_indices = class_indices[:class_n] - class_indices = self.dataset.class_frames[i][class_indices] + if class_potentials.shape[0] > 0: - # Add the indices to the generated ones - gen_indices.append(class_indices) - gen_classes.append(class_indices * 0 + c) + # Get the indices to generate thanks to potentials + used_classes = self.dataset.num_classes - len(self.dataset.ignored_labels) + class_n = num_centers // used_classes + 1 + if class_n < class_potentials.shape[0]: + _, class_indices = torch.topk(class_potentials, class_n, largest=False) + else: + class_indices = torch.zeros((0,), dtype=torch.int32) + while class_indices.shape[0] < class_n: + new_class_inds = torch.randperm(class_potentials.shape[0]).type(torch.int32) + class_indices = torch.cat((class_indices, new_class_inds), dim=0) + class_indices = class_indices[:class_n] + class_indices = self.dataset.class_frames[i][class_indices] - # Update potentials - update_inds = torch.unique(class_indices) - self.dataset.potentials[update_inds] = torch.ceil(self.dataset.potentials[update_inds]) - self.dataset.potentials[update_inds] += torch.from_numpy(np.random.rand(update_inds.shape[0]) * 0.1 + 0.1) + # Add the indices to the generated ones + gen_indices.append(class_indices) + gen_classes.append(class_indices * 0 + c) + + # Update potentials + update_inds = torch.unique(class_indices) + self.dataset.potentials[update_inds] = torch.ceil(self.dataset.potentials[update_inds]) + self.dataset.potentials[update_inds] += torch.from_numpy(np.random.rand(update_inds.shape[0]) * 0.1 + 0.1) # Stack the chosen indices of all classes gen_indices = torch.cat(gen_indices, dim=0)