diff --git a/datasets/SemanticKitti.py b/datasets/SemanticKitti.py index a7d2ab9..7a781e9 100644 --- a/datasets/SemanticKitti.py +++ b/datasets/SemanticKitti.py @@ -768,7 +768,11 @@ class SemanticKittiSampler(Sampler): if class_n < class_potentials.shape[0]: _, class_indices = torch.topk(class_potentials, class_n, largest=False) else: - class_indices = torch.randperm(class_potentials.shape[0]) + class_indices = torch.zeros((0,), dtype=torch.int32) + while class_indices.shape < class_n: + new_class_inds = torch.randperm(class_potentials.shape[0]) + class_indices = torch.cat((class_indices, new_class_inds), dim=0) + class_indices = class_indices[:class_n] class_indices = self.dataset.class_frames[i][class_indices] # Add the indices to the generated ones @@ -776,8 +780,9 @@ class SemanticKittiSampler(Sampler): gen_classes.append(class_indices * 0 + c) # Update potentials - self.dataset.potentials[class_indices] = torch.ceil(self.dataset.potentials[class_indices]) - self.dataset.potentials[class_indices] += torch.from_numpy(np.random.rand(class_indices.shape[0]) * 0.1 + 0.1) + update_inds = torch.unique(class_indices) + self.dataset.potentials[update_inds] = torch.ceil(self.dataset.potentials[update_inds]) + self.dataset.potentials[update_inds] += torch.from_numpy(np.random.rand(update_inds.shape[0]) * 0.1 + 0.1) # Stack the chosen indices of all classes gen_indices = torch.cat(gen_indices, dim=0)