This commit is contained in:
HuguesTHOMAS 2022-07-18 10:00:15 -04:00
parent 7d4c03d199
commit d33c4d254d

View file

@ -772,9 +772,9 @@ class SemanticKittiSampler(Sampler):
if class_n < class_potentials.shape[0]: if class_n < class_potentials.shape[0]:
_, class_indices = torch.topk(class_potentials, class_n, largest=False) _, class_indices = torch.topk(class_potentials, class_n, largest=False)
else: else:
class_indices = torch.zeros((0,), dtype=torch.int32) class_indices = torch.zeros((0,), dtype=torch.int64)
while class_indices.shape[0] < class_n: while class_indices.shape[0] < class_n:
new_class_inds = torch.randperm(class_potentials.shape[0]).type(torch.int32) new_class_inds = torch.randperm(class_potentials.shape[0]).type(torch.int64)
class_indices = torch.cat((class_indices, new_class_inds), dim=0) class_indices = torch.cat((class_indices, new_class_inds), dim=0)
class_indices = class_indices[:class_n] class_indices = class_indices[:class_n]
class_indices = self.dataset.class_frames[i][class_indices] class_indices = self.dataset.class_frames[i][class_indices]