Correction of the S3DIS random indice generation
This commit is contained in:
parent
73e444d486
commit
e600c1667d
|
@ -28,6 +28,7 @@ import numpy as np
|
||||||
import pickle
|
import pickle
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
|
import warnings
|
||||||
from multiprocessing import Lock
|
from multiprocessing import Lock
|
||||||
|
|
||||||
|
|
||||||
|
@ -914,41 +915,59 @@ class S3DISSampler(Sampler):
|
||||||
self.dataset.epoch_inds *= 0
|
self.dataset.epoch_inds *= 0
|
||||||
|
|
||||||
# Initiate container for indices
|
# Initiate container for indices
|
||||||
all_epoch_inds = np.zeros((2, 0), dtype=np.int32)
|
all_epoch_inds = np.zeros((2, 0), dtype=np.int64)
|
||||||
|
|
||||||
# Number of sphere centers taken per class in each cloud
|
# Number of sphere centers taken per class in each cloud
|
||||||
num_centers = self.N * self.dataset.config.batch_num
|
num_centers = self.N * self.dataset.config.batch_num
|
||||||
random_pick_n = int(np.ceil(num_centers / (self.dataset.num_clouds * self.dataset.config.num_classes)))
|
random_pick_n = int(np.ceil(num_centers / self.dataset.config.num_classes))
|
||||||
|
|
||||||
# Choose random points of each class for each cloud
|
# Choose random points of each class for each cloud
|
||||||
for cloud_ind, cloud_labels in enumerate(self.dataset.input_labels):
|
epoch_indices = np.zeros((2, 0), dtype=np.int64)
|
||||||
epoch_indices = np.empty((0,), dtype=np.int32)
|
for label_ind, label in enumerate(self.dataset.label_values):
|
||||||
for label_ind, label in enumerate(self.dataset.label_values):
|
if label not in self.dataset.ignored_labels:
|
||||||
if label not in self.dataset.ignored_labels:
|
|
||||||
|
# Gather indices of the points with this label in all the input clouds
|
||||||
|
all_label_indices = []
|
||||||
|
for cloud_ind, cloud_labels in enumerate(self.dataset.input_labels):
|
||||||
label_indices = np.where(np.equal(cloud_labels, label))[0]
|
label_indices = np.where(np.equal(cloud_labels, label))[0]
|
||||||
if len(label_indices) <= random_pick_n:
|
all_label_indices.append(np.vstack((np.full(label_indices.shape, cloud_ind, dtype=np.int64), label_indices)))
|
||||||
epoch_indices = np.hstack((epoch_indices, label_indices))
|
|
||||||
elif len(label_indices) < 50 * random_pick_n:
|
|
||||||
new_randoms = np.random.choice(label_indices, size=random_pick_n, replace=False)
|
|
||||||
epoch_indices = np.hstack((epoch_indices, new_randoms.astype(np.int32)))
|
|
||||||
else:
|
|
||||||
rand_inds = []
|
|
||||||
while len(rand_inds) < random_pick_n:
|
|
||||||
rand_inds = np.unique(np.random.choice(label_indices, size=5 * random_pick_n, replace=True))
|
|
||||||
epoch_indices = np.hstack((epoch_indices, rand_inds[:random_pick_n].astype(np.int32)))
|
|
||||||
|
|
||||||
# Stack those indices with the cloud index
|
# Stack them: [2, N1+N2+...]
|
||||||
epoch_indices = np.vstack((np.full(epoch_indices.shape, cloud_ind, dtype=np.int32), epoch_indices))
|
all_label_indices = np.hstack(all_label_indices)
|
||||||
|
|
||||||
# Update the global indice container
|
# Select a a random number amongst them
|
||||||
all_epoch_inds = np.hstack((all_epoch_inds, epoch_indices))
|
N_inds = all_label_indices.shape[1]
|
||||||
|
if N_inds < random_pick_n:
|
||||||
|
chosen_label_inds = np.zeros((2, 0), dtype=np.int64)
|
||||||
|
while chosen_label_inds.shape[1] < random_pick_n:
|
||||||
|
chosen_label_inds = np.hstack((chosen_label_inds, all_label_indices[:, np.random.permutation(N_inds)]))
|
||||||
|
warnings.warn('When choosing random epoch indices (use_potentials=False), \
|
||||||
|
class {:d}: {:s} only had {:d} available points, while we \
|
||||||
|
needed {:d}. Repeating indices in the same epoch'.format(label,
|
||||||
|
self.dataset.label_names[label_ind],
|
||||||
|
N_inds,
|
||||||
|
random_pick_n))
|
||||||
|
|
||||||
|
elif N_inds < 50 * random_pick_n:
|
||||||
|
rand_inds = np.random.choice(N_inds, size=random_pick_n, replace=False)
|
||||||
|
chosen_label_inds = all_label_indices[:, rand_inds]
|
||||||
|
|
||||||
|
else:
|
||||||
|
chosen_label_inds = np.zeros((2, 0), dtype=np.int64)
|
||||||
|
while chosen_label_inds.shape[1] < random_pick_n:
|
||||||
|
rand_inds = np.unique(np.random.choice(N_inds, size=2*random_pick_n, replace=True))
|
||||||
|
chosen_label_inds = np.hstack((chosen_label_inds, all_label_indices[:, rand_inds]))
|
||||||
|
chosen_label_inds = chosen_label_inds[:, :random_pick_n]
|
||||||
|
|
||||||
|
# Stack for each label
|
||||||
|
all_epoch_inds = np.hstack((all_epoch_inds, chosen_label_inds))
|
||||||
|
|
||||||
# Random permutation of the indices
|
# Random permutation of the indices
|
||||||
random_order = np.random.permutation(all_epoch_inds.shape[1])
|
random_order = np.random.permutation(all_epoch_inds.shape[1])[:num_centers]
|
||||||
all_epoch_inds = all_epoch_inds[:, random_order].astype(np.int64)
|
all_epoch_inds = all_epoch_inds[:, random_order].astype(np.int64)
|
||||||
|
|
||||||
# Update epoch inds
|
# Update epoch inds
|
||||||
self.dataset.epoch_inds += torch.from_numpy(all_epoch_inds[:, :num_centers])
|
self.dataset.epoch_inds += torch.from_numpy(all_epoch_inds)
|
||||||
|
|
||||||
# Generator loop
|
# Generator loop
|
||||||
for i in range(self.N):
|
for i in range(self.N):
|
||||||
|
|
Loading…
Reference in a new issue