Correction of the S3DIS random indice generation

This commit is contained in:
HuguesTHOMAS 2022-02-28 12:40:51 -05:00
parent 73e444d486
commit e600c1667d

View file

@ -28,6 +28,7 @@ import numpy as np
import pickle
import torch
import math
import warnings
from multiprocessing import Lock
@ -914,41 +915,59 @@ class S3DISSampler(Sampler):
self.dataset.epoch_inds *= 0
# Initiate container for indices
all_epoch_inds = np.zeros((2, 0), dtype=np.int32)
all_epoch_inds = np.zeros((2, 0), dtype=np.int64)
# Number of sphere centers taken per class in each cloud
num_centers = self.N * self.dataset.config.batch_num
random_pick_n = int(np.ceil(num_centers / (self.dataset.num_clouds * self.dataset.config.num_classes)))
random_pick_n = int(np.ceil(num_centers / self.dataset.config.num_classes))
# Choose random points of each class for each cloud
for cloud_ind, cloud_labels in enumerate(self.dataset.input_labels):
epoch_indices = np.empty((0,), dtype=np.int32)
for label_ind, label in enumerate(self.dataset.label_values):
if label not in self.dataset.ignored_labels:
epoch_indices = np.zeros((2, 0), dtype=np.int64)
for label_ind, label in enumerate(self.dataset.label_values):
if label not in self.dataset.ignored_labels:
# Gather indices of the points with this label in all the input clouds
all_label_indices = []
for cloud_ind, cloud_labels in enumerate(self.dataset.input_labels):
label_indices = np.where(np.equal(cloud_labels, label))[0]
if len(label_indices) <= random_pick_n:
epoch_indices = np.hstack((epoch_indices, label_indices))
elif len(label_indices) < 50 * random_pick_n:
new_randoms = np.random.choice(label_indices, size=random_pick_n, replace=False)
epoch_indices = np.hstack((epoch_indices, new_randoms.astype(np.int32)))
else:
rand_inds = []
while len(rand_inds) < random_pick_n:
rand_inds = np.unique(np.random.choice(label_indices, size=5 * random_pick_n, replace=True))
epoch_indices = np.hstack((epoch_indices, rand_inds[:random_pick_n].astype(np.int32)))
all_label_indices.append(np.vstack((np.full(label_indices.shape, cloud_ind, dtype=np.int64), label_indices)))
# Stack those indices with the cloud index
epoch_indices = np.vstack((np.full(epoch_indices.shape, cloud_ind, dtype=np.int32), epoch_indices))
# Stack them: [2, N1+N2+...]
all_label_indices = np.hstack(all_label_indices)
# Update the global indice container
all_epoch_inds = np.hstack((all_epoch_inds, epoch_indices))
# Select a a random number amongst them
N_inds = all_label_indices.shape[1]
if N_inds < random_pick_n:
chosen_label_inds = np.zeros((2, 0), dtype=np.int64)
while chosen_label_inds.shape[1] < random_pick_n:
chosen_label_inds = np.hstack((chosen_label_inds, all_label_indices[:, np.random.permutation(N_inds)]))
warnings.warn('When choosing random epoch indices (use_potentials=False), \
class {:d}: {:s} only had {:d} available points, while we \
needed {:d}. Repeating indices in the same epoch'.format(label,
self.dataset.label_names[label_ind],
N_inds,
random_pick_n))
elif N_inds < 50 * random_pick_n:
rand_inds = np.random.choice(N_inds, size=random_pick_n, replace=False)
chosen_label_inds = all_label_indices[:, rand_inds]
else:
chosen_label_inds = np.zeros((2, 0), dtype=np.int64)
while chosen_label_inds.shape[1] < random_pick_n:
rand_inds = np.unique(np.random.choice(N_inds, size=2*random_pick_n, replace=True))
chosen_label_inds = np.hstack((chosen_label_inds, all_label_indices[:, rand_inds]))
chosen_label_inds = chosen_label_inds[:, :random_pick_n]
# Stack for each label
all_epoch_inds = np.hstack((all_epoch_inds, chosen_label_inds))
# Random permutation of the indices
random_order = np.random.permutation(all_epoch_inds.shape[1])
random_order = np.random.permutation(all_epoch_inds.shape[1])[:num_centers]
all_epoch_inds = all_epoch_inds[:, random_order].astype(np.int64)
# Update epoch inds
self.dataset.epoch_inds += torch.from_numpy(all_epoch_inds[:, :num_centers])
self.dataset.epoch_inds += torch.from_numpy(all_epoch_inds)
# Generator loop
for i in range(self.N):