diff --git a/datasets/SemanticKitti.py b/datasets/SemanticKitti.py index 8eae642..ffd355f 100644 --- a/datasets/SemanticKitti.py +++ b/datasets/SemanticKitti.py @@ -772,9 +772,9 @@ class SemanticKittiSampler(Sampler): if class_n < class_potentials.shape[0]: _, class_indices = torch.topk(class_potentials, class_n, largest=False) else: - class_indices = torch.zeros((0,), dtype=torch.int32) + class_indices = torch.zeros((0,), dtype=torch.int64) while class_indices.shape[0] < class_n: - new_class_inds = torch.randperm(class_potentials.shape[0]).type(torch.int32) + new_class_inds = torch.randperm(class_potentials.shape[0]).type(torch.int64) class_indices = torch.cat((class_indices, new_class_inds), dim=0) class_indices = class_indices[:class_n] class_indices = self.dataset.class_frames[i][class_indices]