Correction of the S3DIS random indices generation
parent 73e444d486
commit e600c1667d
@@ -28,6 +28,7 @@ import numpy as np
 import pickle
 import torch
 import math
+import warnings
 from multiprocessing import Lock
 
 
@@ -914,41 +915,59 @@ class S3DISSampler(Sampler):
             self.dataset.epoch_inds *= 0
 
             # Initiate container for indices
-            all_epoch_inds = np.zeros((2, 0), dtype=np.int32)
+            all_epoch_inds = np.zeros((2, 0), dtype=np.int64)
 
             # Number of sphere centers taken per class in each cloud
             num_centers = self.N * self.dataset.config.batch_num
-            random_pick_n = int(np.ceil(num_centers / (self.dataset.num_clouds * self.dataset.config.num_classes)))
+            random_pick_n = int(np.ceil(num_centers / self.dataset.config.num_classes))
 
             # Choose random points of each class for each cloud
-            for cloud_ind, cloud_labels in enumerate(self.dataset.input_labels):
-                epoch_indices = np.empty((0,), dtype=np.int32)
+            epoch_indices = np.zeros((2, 0), dtype=np.int64)
             for label_ind, label in enumerate(self.dataset.label_values):
                 if label not in self.dataset.ignored_labels:
+
+                    # Gather indices of the points with this label in all the input clouds
+                    all_label_indices = []
+                    for cloud_ind, cloud_labels in enumerate(self.dataset.input_labels):
                         label_indices = np.where(np.equal(cloud_labels, label))[0]
-                        if len(label_indices) <= random_pick_n:
-                            epoch_indices = np.hstack((epoch_indices, label_indices))
-                        elif len(label_indices) < 50 * random_pick_n:
-                            new_randoms = np.random.choice(label_indices, size=random_pick_n, replace=False)
-                            epoch_indices = np.hstack((epoch_indices, new_randoms.astype(np.int32)))
+                        all_label_indices.append(np.vstack((np.full(label_indices.shape, cloud_ind, dtype=np.int64), label_indices)))
+
+                    # Stack them: [2, N1+N2+...]
+                    all_label_indices = np.hstack(all_label_indices)
+
+                    # Select a a random number amongst them
+                    N_inds = all_label_indices.shape[1]
+                    if N_inds < random_pick_n:
+                        chosen_label_inds = np.zeros((2, 0), dtype=np.int64)
+                        while chosen_label_inds.shape[1] < random_pick_n:
+                            chosen_label_inds = np.hstack((chosen_label_inds, all_label_indices[:, np.random.permutation(N_inds)]))
+                        warnings.warn('When choosing random epoch indices (use_potentials=False), \
+                                       class {:d}: {:s} only had {:d} available points, while we \
+                                       needed {:d}. Repeating indices in the same epoch'.format(label,
+                                                                                                self.dataset.label_names[label_ind],
+                                                                                                N_inds,
+                                                                                                random_pick_n))
+
+                    elif N_inds < 50 * random_pick_n:
+                        rand_inds = np.random.choice(N_inds, size=random_pick_n, replace=False)
+                        chosen_label_inds = all_label_indices[:, rand_inds]
+
                     else:
-                            rand_inds = []
-                            while len(rand_inds) < random_pick_n:
-                                rand_inds = np.unique(np.random.choice(label_indices, size=5 * random_pick_n, replace=True))
-                            epoch_indices = np.hstack((epoch_indices, rand_inds[:random_pick_n].astype(np.int32)))
+                        chosen_label_inds = np.zeros((2, 0), dtype=np.int64)
+                        while chosen_label_inds.shape[1] < random_pick_n:
+                            rand_inds = np.unique(np.random.choice(N_inds, size=2*random_pick_n, replace=True))
+                            chosen_label_inds = np.hstack((chosen_label_inds, all_label_indices[:, rand_inds]))
+                        chosen_label_inds = chosen_label_inds[:, :random_pick_n]
 
-                # Stack those indices with the cloud index
-                epoch_indices = np.vstack((np.full(epoch_indices.shape, cloud_ind, dtype=np.int32), epoch_indices))
-
-                # Update the global indice container
-                all_epoch_inds = np.hstack((all_epoch_inds, epoch_indices))
+                    # Stack for each label
+                    all_epoch_inds = np.hstack((all_epoch_inds, chosen_label_inds))
 
             # Random permutation of the indices
-            random_order = np.random.permutation(all_epoch_inds.shape[1])
+            random_order = np.random.permutation(all_epoch_inds.shape[1])[:num_centers]
             all_epoch_inds = all_epoch_inds[:, random_order].astype(np.int64)
 
             # Update epoch inds
-            self.dataset.epoch_inds += torch.from_numpy(all_epoch_inds[:, :num_centers])
+            self.dataset.epoch_inds += torch.from_numpy(all_epoch_inds)
 
         # Generator loop
         for i in range(self.N):
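
For reference, below is a minimal, standalone sketch of the selection scheme introduced by this hunk, run on synthetic labels rather than the real S3DIS clouds. The helper name choose_label_indices, the three fake clouds and the numbers (3 classes, 300 centers) are illustrative assumptions, not part of the commit; only the three-branch per-class selection and the final permutation/truncation mirror the new code. The key behavioural change is that points of a class are pooled over all clouds before sampling, so a class that is rare in one cloud can still fill its quota of sphere centers, and classes with too few points are repeated with a warning instead of being silently under-represented.

import warnings

import numpy as np


def choose_label_indices(all_label_indices, random_pick_n):
    # all_label_indices is a [2, N] array of (cloud_ind, point_ind) pairs for one class.
    n_inds = all_label_indices.shape[1]
    if n_inds < random_pick_n:
        # Not enough points of this class: repeat random permutations and warn.
        chosen = np.zeros((2, 0), dtype=np.int64)
        while chosen.shape[1] < random_pick_n:
            chosen = np.hstack((chosen, all_label_indices[:, np.random.permutation(n_inds)]))
        warnings.warn('Only {:d} points available while {:d} are needed: repeating indices'.format(n_inds, random_pick_n))
        return chosen[:, :random_pick_n]
    elif n_inds < 50 * random_pick_n:
        # Moderate number of points: sample without replacement in one shot.
        rand_inds = np.random.choice(n_inds, size=random_pick_n, replace=False)
        return all_label_indices[:, rand_inds]
    else:
        # Very large number of points: cheaper sampling with replacement, deduplicated until enough.
        chosen = np.zeros((2, 0), dtype=np.int64)
        while chosen.shape[1] < random_pick_n:
            rand_inds = np.unique(np.random.choice(n_inds, size=2 * random_pick_n, replace=True))
            chosen = np.hstack((chosen, all_label_indices[:, rand_inds]))
        return chosen[:, :random_pick_n]


# Synthetic "clouds": one label array per cloud, 3 classes, every class assumed present somewhere.
fake_input_labels = [np.random.randint(0, 3, size=n) for n in (2000, 500, 1500)]
num_centers = 300                              # plays the role of self.N * config.batch_num
random_pick_n = int(np.ceil(num_centers / 3))  # per class, over all clouds (new formula)

all_epoch_inds = np.zeros((2, 0), dtype=np.int64)
for label in range(3):
    # Gather (cloud_ind, point_ind) pairs of this class over every cloud, then sample from the pool.
    all_label_indices = []
    for cloud_ind, cloud_labels in enumerate(fake_input_labels):
        label_indices = np.where(np.equal(cloud_labels, label))[0]
        all_label_indices.append(np.vstack((np.full(label_indices.shape, cloud_ind, dtype=np.int64),
                                            label_indices)))
    all_label_indices = np.hstack(all_label_indices)
    all_epoch_inds = np.hstack((all_epoch_inds, choose_label_indices(all_label_indices, random_pick_n)))

# Shuffle and keep exactly num_centers sphere centers for the epoch.
random_order = np.random.permutation(all_epoch_inds.shape[1])[:num_centers]
all_epoch_inds = all_epoch_inds[:, random_order]
print(all_epoch_inds.shape)                    # (2, 300): row 0 is the cloud index, row 1 the point index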