This commit is contained in:
HuguesTHOMAS 2022-07-11 10:26:41 -04:00
parent 3a774ff8d5
commit 843629422b

View file

@ -818,6 +818,10 @@ class SemanticKittiSampler(Sampler):
_, gen_indices = torch.topk(self.dataset.potentials, num_centers, largest=False, sorted=True)
else:
gen_indices = torch.randperm(self.dataset.potentials.shape[0])
while gen_indices.shape[0] < num_centers:
new_gen_indices = torch.randperm(self.dataset.potentials.shape[0]).type(torch.int32)
gen_indices = torch.cat((gen_indices, new_gen_indices), dim=0)
gen_indices = gen_indices[:num_centers]
# Update potentials (Change the order for the next epoch)
self.dataset.potentials[gen_indices] = torch.ceil(self.dataset.potentials[gen_indices])