diff --git a/datasets/SemanticKitti.py b/datasets/SemanticKitti.py index a1664e2..d2860b5 100644 --- a/datasets/SemanticKitti.py +++ b/datasets/SemanticKitti.py @@ -818,6 +818,10 @@ class SemanticKittiSampler(Sampler): _, gen_indices = torch.topk(self.dataset.potentials, num_centers, largest=False, sorted=True) else: gen_indices = torch.randperm(self.dataset.potentials.shape[0]) + while gen_indices.shape[0] < num_centers: + new_gen_indices = torch.randperm(self.dataset.potentials.shape[0]).type(torch.int32) + gen_indices = torch.cat((gen_indices, new_gen_indices), dim=0) + gen_indices = gen_indices[:num_centers] # Update potentials (Change the order for the next epoch) self.dataset.potentials[gen_indices] = torch.ceil(self.dataset.potentials[gen_indices])