SemanticKitti inf while loop correction
This commit is contained in:
parent
7fdbc57f9b
commit
7f5f52b067
|
@ -762,27 +762,29 @@ class SemanticKittiSampler(Sampler):
|
||||||
# Get the potentials of the frames containing this class
|
# Get the potentials of the frames containing this class
|
||||||
class_potentials = self.dataset.potentials[self.dataset.class_frames[i]]
|
class_potentials = self.dataset.potentials[self.dataset.class_frames[i]]
|
||||||
|
|
||||||
# Get the indices to generate thanks to potentials
|
if class_potentials.shape[0] > 0:
|
||||||
used_classes = self.dataset.num_classes - len(self.dataset.ignored_labels)
|
|
||||||
class_n = num_centers // used_classes + 1
|
|
||||||
if class_n < class_potentials.shape[0]:
|
|
||||||
_, class_indices = torch.topk(class_potentials, class_n, largest=False)
|
|
||||||
else:
|
|
||||||
class_indices = torch.zeros((0,), dtype=torch.int32)
|
|
||||||
while class_indices.shape[0] < class_n:
|
|
||||||
new_class_inds = torch.randperm(class_potentials.shape[0]).type(torch.int32)
|
|
||||||
class_indices = torch.cat((class_indices, new_class_inds), dim=0)
|
|
||||||
class_indices = class_indices[:class_n]
|
|
||||||
class_indices = self.dataset.class_frames[i][class_indices]
|
|
||||||
|
|
||||||
# Add the indices to the generated ones
|
# Get the indices to generate thanks to potentials
|
||||||
gen_indices.append(class_indices)
|
used_classes = self.dataset.num_classes - len(self.dataset.ignored_labels)
|
||||||
gen_classes.append(class_indices * 0 + c)
|
class_n = num_centers // used_classes + 1
|
||||||
|
if class_n < class_potentials.shape[0]:
|
||||||
|
_, class_indices = torch.topk(class_potentials, class_n, largest=False)
|
||||||
|
else:
|
||||||
|
class_indices = torch.zeros((0,), dtype=torch.int32)
|
||||||
|
while class_indices.shape[0] < class_n:
|
||||||
|
new_class_inds = torch.randperm(class_potentials.shape[0]).type(torch.int32)
|
||||||
|
class_indices = torch.cat((class_indices, new_class_inds), dim=0)
|
||||||
|
class_indices = class_indices[:class_n]
|
||||||
|
class_indices = self.dataset.class_frames[i][class_indices]
|
||||||
|
|
||||||
# Update potentials
|
# Add the indices to the generated ones
|
||||||
update_inds = torch.unique(class_indices)
|
gen_indices.append(class_indices)
|
||||||
self.dataset.potentials[update_inds] = torch.ceil(self.dataset.potentials[update_inds])
|
gen_classes.append(class_indices * 0 + c)
|
||||||
self.dataset.potentials[update_inds] += torch.from_numpy(np.random.rand(update_inds.shape[0]) * 0.1 + 0.1)
|
|
||||||
|
# Update potentials
|
||||||
|
update_inds = torch.unique(class_indices)
|
||||||
|
self.dataset.potentials[update_inds] = torch.ceil(self.dataset.potentials[update_inds])
|
||||||
|
self.dataset.potentials[update_inds] += torch.from_numpy(np.random.rand(update_inds.shape[0]) * 0.1 + 0.1)
|
||||||
|
|
||||||
# Stack the chosen indices of all classes
|
# Stack the chosen indices of all classes
|
||||||
gen_indices = torch.cat(gen_indices, dim=0)
|
gen_indices = torch.cat(gen_indices, dim=0)
|
||||||
|
|
Loading…
Reference in a new issue