Initial commit
This commit is contained in:
parent
c131638fa7
commit
5efecbbe20
|
@ -175,7 +175,6 @@ class S3DISDataset(PointCloudDataset):
|
||||||
self.min_potentials += [float(self.potentials[-1][min_ind])]
|
self.min_potentials += [float(self.potentials[-1][min_ind])]
|
||||||
|
|
||||||
# Share potential memory
|
# Share potential memory
|
||||||
self.pot_lock = Lock()
|
|
||||||
self.argmin_potentials = torch.from_numpy(np.array(self.argmin_potentials, dtype=np.int64))
|
self.argmin_potentials = torch.from_numpy(np.array(self.argmin_potentials, dtype=np.int64))
|
||||||
self.min_potentials = torch.from_numpy(np.array(self.min_potentials, dtype=np.float64))
|
self.min_potentials = torch.from_numpy(np.array(self.min_potentials, dtype=np.float64))
|
||||||
self.argmin_potentials.share_memory_()
|
self.argmin_potentials.share_memory_()
|
||||||
|
@ -185,12 +184,20 @@ class S3DISDataset(PointCloudDataset):
|
||||||
|
|
||||||
self.worker_waiting = torch.tensor([0 for _ in range(config.input_threads)], dtype=torch.int32)
|
self.worker_waiting = torch.tensor([0 for _ in range(config.input_threads)], dtype=torch.int32)
|
||||||
self.worker_waiting.share_memory_()
|
self.worker_waiting.share_memory_()
|
||||||
|
self.epoch_inds = None
|
||||||
|
self.epoch_i = 0
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.pot_lock = None
|
|
||||||
self.potentials = None
|
self.potentials = None
|
||||||
self.min_potentials = None
|
self.min_potentials = None
|
||||||
self.argmin_potentials = None
|
self.argmin_potentials = None
|
||||||
|
N = config.epoch_steps * config.batch_num
|
||||||
|
self.epoch_inds = torch.from_numpy(np.zeros((2, N), dtype=np.int64))
|
||||||
|
self.epoch_i = torch.from_numpy(np.zeros((1,), dtype=np.int64))
|
||||||
|
self.epoch_i.share_memory_()
|
||||||
|
self.epoch_inds.share_memory_()
|
||||||
|
|
||||||
|
self.worker_lock = Lock()
|
||||||
|
|
||||||
# For ERF visualization, we want only one cloud per batch and no randomness
|
# For ERF visualization, we want only one cloud per batch and no randomness
|
||||||
if self.set == 'ERF':
|
if self.set == 'ERF':
|
||||||
|
@ -206,12 +213,19 @@ class S3DISDataset(PointCloudDataset):
|
||||||
"""
|
"""
|
||||||
return len(self.cloud_names)
|
return len(self.cloud_names)
|
||||||
|
|
||||||
def __getitem__(self, batch_i, debug_workers=False):
|
def __getitem__(self, batch_i):
|
||||||
"""
|
"""
|
||||||
The main thread gives a list of indices to load a batch. Each worker is going to work in parallel to load a
|
The main thread gives a list of indices to load a batch. Each worker is going to work in parallel to load a
|
||||||
different list of indices.
|
different list of indices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.use_potentials:
|
||||||
|
return self.potential_item(batch_i)
|
||||||
|
else:
|
||||||
|
return self.random_item(batch_i)
|
||||||
|
|
||||||
|
def potential_item(self, batch_i, debug_workers=False):
|
||||||
|
|
||||||
# Initiate concatanation lists
|
# Initiate concatanation lists
|
||||||
p_list = []
|
p_list = []
|
||||||
f_list = []
|
f_list = []
|
||||||
|
@ -226,7 +240,6 @@ class S3DISDataset(PointCloudDataset):
|
||||||
info = get_worker_info()
|
info = get_worker_info()
|
||||||
wid = info.id
|
wid = info.id
|
||||||
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
if debug_workers:
|
if debug_workers:
|
||||||
|
@ -243,7 +256,7 @@ class S3DISDataset(PointCloudDataset):
|
||||||
print(message)
|
print(message)
|
||||||
self.worker_waiting[wid] = 0
|
self.worker_waiting[wid] = 0
|
||||||
|
|
||||||
with self.pot_lock:
|
with self.worker_lock:
|
||||||
|
|
||||||
if debug_workers:
|
if debug_workers:
|
||||||
message = ''
|
message = ''
|
||||||
|
@ -271,7 +284,7 @@ class S3DISDataset(PointCloudDataset):
|
||||||
|
|
||||||
# Add a small noise to center point
|
# Add a small noise to center point
|
||||||
if self.set != 'ERF':
|
if self.set != 'ERF':
|
||||||
center_point += np.random.normal(scale=self.config.in_radius/10, size=center_point.shape)
|
center_point += np.random.normal(scale=self.config.in_radius / 10, size=center_point.shape)
|
||||||
|
|
||||||
# Indices of points in input region
|
# Indices of points in input region
|
||||||
pot_inds, dists = self.pot_trees[cloud_ind].query_radius(center_point,
|
pot_inds, dists = self.pot_trees[cloud_ind].query_radius(center_point,
|
||||||
|
@ -337,7 +350,7 @@ class S3DISDataset(PointCloudDataset):
|
||||||
break
|
break
|
||||||
|
|
||||||
# Randomly drop some points (act as an augmentation process and a safety for GPU memory consumption)
|
# Randomly drop some points (act as an augmentation process and a safety for GPU memory consumption)
|
||||||
#if n > int(self.batch_limit):
|
# if n > int(self.batch_limit):
|
||||||
# input_inds = np.random.choice(input_inds, size=int(self.batch_limit) - 1, replace=False)
|
# input_inds = np.random.choice(input_inds, size=int(self.batch_limit) - 1, replace=False)
|
||||||
# n = input_inds.shape[0]
|
# n = input_inds.shape[0]
|
||||||
|
|
||||||
|
@ -398,6 +411,131 @@ class S3DISDataset(PointCloudDataset):
|
||||||
|
|
||||||
return input_list
|
return input_list
|
||||||
|
|
||||||
|
def random_item(self, batch_i):
|
||||||
|
|
||||||
|
# Initiate concatanation lists
|
||||||
|
p_list = []
|
||||||
|
f_list = []
|
||||||
|
l_list = []
|
||||||
|
i_list = []
|
||||||
|
pi_list = []
|
||||||
|
ci_list = []
|
||||||
|
s_list = []
|
||||||
|
R_list = []
|
||||||
|
batch_n = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
with self.worker_lock:
|
||||||
|
|
||||||
|
# Get potential minimum
|
||||||
|
cloud_ind = int(self.epoch_inds[0, self.epoch_i])
|
||||||
|
point_ind = int(self.epoch_inds[1, self.epoch_i])
|
||||||
|
|
||||||
|
# Update epoch indice
|
||||||
|
self.epoch_i += 1
|
||||||
|
|
||||||
|
# Get points from tree structure
|
||||||
|
points = np.array(self.input_trees[cloud_ind].data, copy=False)
|
||||||
|
|
||||||
|
# Center point of input region
|
||||||
|
center_point = points[point_ind, :].reshape(1, -1)
|
||||||
|
|
||||||
|
# Add a small noise to center point
|
||||||
|
if self.set != 'ERF':
|
||||||
|
center_point += np.random.normal(scale=self.config.in_radius / 10, size=center_point.shape)
|
||||||
|
|
||||||
|
# Indices of points in input region
|
||||||
|
input_inds = self.input_trees[cloud_ind].query_radius(center_point,
|
||||||
|
r=self.config.in_radius)[0]
|
||||||
|
|
||||||
|
# Number collected
|
||||||
|
n = input_inds.shape[0]
|
||||||
|
|
||||||
|
# Collect labels and colors
|
||||||
|
input_points = (points[input_inds] - center_point).astype(np.float32)
|
||||||
|
input_colors = self.input_colors[cloud_ind][input_inds]
|
||||||
|
if self.set in ['test', 'ERF']:
|
||||||
|
input_labels = np.zeros(input_points.shape[0])
|
||||||
|
else:
|
||||||
|
input_labels = self.input_labels[cloud_ind][input_inds]
|
||||||
|
input_labels = np.array([self.label_to_idx[l] for l in input_labels])
|
||||||
|
|
||||||
|
# Data augmentation
|
||||||
|
input_points, scale, R = self.augmentation_transform(input_points)
|
||||||
|
|
||||||
|
# Color augmentation
|
||||||
|
if np.random.rand() > self.config.augment_color:
|
||||||
|
input_colors *= 0
|
||||||
|
|
||||||
|
# Get original height as additional feature
|
||||||
|
input_features = np.hstack((input_colors, input_points[:, 2:] + center_point[:, 2:])).astype(np.float32)
|
||||||
|
|
||||||
|
# Stack batch
|
||||||
|
p_list += [input_points]
|
||||||
|
f_list += [input_features]
|
||||||
|
l_list += [input_labels]
|
||||||
|
pi_list += [input_inds]
|
||||||
|
i_list += [point_ind]
|
||||||
|
ci_list += [cloud_ind]
|
||||||
|
s_list += [scale]
|
||||||
|
R_list += [R]
|
||||||
|
|
||||||
|
# Update batch size
|
||||||
|
batch_n += n
|
||||||
|
|
||||||
|
# In case batch is full, stop
|
||||||
|
if batch_n > int(self.batch_limit):
|
||||||
|
break
|
||||||
|
|
||||||
|
# Randomly drop some points (act as an augmentation process and a safety for GPU memory consumption)
|
||||||
|
# if n > int(self.batch_limit):
|
||||||
|
# input_inds = np.random.choice(input_inds, size=int(self.batch_limit) - 1, replace=False)
|
||||||
|
# n = input_inds.shape[0]
|
||||||
|
|
||||||
|
###################
|
||||||
|
# Concatenate batch
|
||||||
|
###################
|
||||||
|
|
||||||
|
stacked_points = np.concatenate(p_list, axis=0)
|
||||||
|
features = np.concatenate(f_list, axis=0)
|
||||||
|
labels = np.concatenate(l_list, axis=0)
|
||||||
|
point_inds = np.array(i_list, dtype=np.int32)
|
||||||
|
cloud_inds = np.array(ci_list, dtype=np.int32)
|
||||||
|
input_inds = np.concatenate(pi_list, axis=0)
|
||||||
|
stack_lengths = np.array([pp.shape[0] for pp in p_list], dtype=np.int32)
|
||||||
|
scales = np.array(s_list, dtype=np.float32)
|
||||||
|
rots = np.stack(R_list, axis=0)
|
||||||
|
|
||||||
|
# Input features
|
||||||
|
stacked_features = np.ones_like(stacked_points[:, :1], dtype=np.float32)
|
||||||
|
if self.config.in_features_dim == 1:
|
||||||
|
pass
|
||||||
|
elif self.config.in_features_dim == 4:
|
||||||
|
stacked_features = np.hstack((stacked_features, features[:, :3]))
|
||||||
|
elif self.config.in_features_dim == 5:
|
||||||
|
stacked_features = np.hstack((stacked_features, features))
|
||||||
|
else:
|
||||||
|
raise ValueError('Only accepted input dimensions are 1, 4 and 7 (without and with XYZ)')
|
||||||
|
|
||||||
|
#######################
|
||||||
|
# Create network inputs
|
||||||
|
#######################
|
||||||
|
#
|
||||||
|
# Points, neighbors, pooling indices for each layers
|
||||||
|
#
|
||||||
|
|
||||||
|
# Get the whole input list
|
||||||
|
input_list = self.segmentation_inputs(stacked_points,
|
||||||
|
stacked_features,
|
||||||
|
labels,
|
||||||
|
stack_lengths)
|
||||||
|
|
||||||
|
# Add scale and rotation for testing
|
||||||
|
input_list += [scales, rots, cloud_inds, point_inds, input_inds]
|
||||||
|
|
||||||
|
return input_list
|
||||||
|
|
||||||
def prepare_S3DIS_ply(self):
|
def prepare_S3DIS_ply(self):
|
||||||
|
|
||||||
print('\nPreparing ply files')
|
print('\nPreparing ply files')
|
||||||
|
@ -622,14 +760,14 @@ class S3DISDataset(PointCloudDataset):
|
||||||
# Reprojection indices
|
# Reprojection indices
|
||||||
######################
|
######################
|
||||||
|
|
||||||
|
# Get number of clouds
|
||||||
|
self.num_clouds = len(self.input_trees)
|
||||||
|
|
||||||
# Only necessary for validation and test sets
|
# Only necessary for validation and test sets
|
||||||
if self.set in ['validation', 'test']:
|
if self.set in ['validation', 'test']:
|
||||||
|
|
||||||
print('\nPreparing reprojection indices for testing')
|
print('\nPreparing reprojection indices for testing')
|
||||||
|
|
||||||
# Get number of clouds
|
|
||||||
self.num_clouds = len(self.input_trees)
|
|
||||||
|
|
||||||
# Get validation/test reprojection indices
|
# Get validation/test reprojection indices
|
||||||
i_cloud = 0
|
i_cloud = 0
|
||||||
for i, file_path in enumerate(self.train_files):
|
for i, file_path in enumerate(self.train_files):
|
||||||
|
@ -690,7 +828,7 @@ class S3DISDataset(PointCloudDataset):
|
||||||
class S3DISSampler(Sampler):
|
class S3DISSampler(Sampler):
|
||||||
"""Sampler for S3DIS"""
|
"""Sampler for S3DIS"""
|
||||||
|
|
||||||
def __init__(self, dataset: S3DISDataset,):
|
def __init__(self, dataset: S3DISDataset):
|
||||||
Sampler.__init__(self, dataset)
|
Sampler.__init__(self, dataset)
|
||||||
|
|
||||||
# Dataset used by the sampler (no copy is made in memory)
|
# Dataset used by the sampler (no copy is made in memory)
|
||||||
|
@ -710,6 +848,49 @@ class S3DISSampler(Sampler):
|
||||||
(input sphere) in epoch instead of the list of point indices
|
(input sphere) in epoch instead of the list of point indices
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not self.dataset.use_potentials:
|
||||||
|
|
||||||
|
# Initiate current epoch ind
|
||||||
|
self.dataset.epoch_i *= 0
|
||||||
|
self.dataset.epoch_inds *= 0
|
||||||
|
|
||||||
|
# Initiate container for indices
|
||||||
|
all_epoch_inds = np.zeros((2, 0), dtype=np.int32)
|
||||||
|
|
||||||
|
# Number of sphere centers taken per class in each cloud
|
||||||
|
num_centers = self.N * self.dataset.config.batch_num
|
||||||
|
random_pick_n = int(np.ceil(num_centers / (self.dataset.num_clouds * self.dataset.config.num_classes)))
|
||||||
|
|
||||||
|
# Choose random points of each class for each cloud
|
||||||
|
for cloud_ind, cloud_labels in enumerate(self.dataset.input_labels):
|
||||||
|
epoch_indices = np.empty((0,), dtype=np.int32)
|
||||||
|
for label_ind, label in enumerate(self.dataset.label_values):
|
||||||
|
if label not in self.dataset.ignored_labels:
|
||||||
|
label_indices = np.where(np.equal(cloud_labels, label))[0]
|
||||||
|
if len(label_indices) <= random_pick_n:
|
||||||
|
epoch_indices = np.hstack((epoch_indices, label_indices))
|
||||||
|
elif len(label_indices) < 50 * random_pick_n:
|
||||||
|
new_randoms = np.random.choice(label_indices, size=random_pick_n, replace=False)
|
||||||
|
epoch_indices = np.hstack((epoch_indices, new_randoms.astype(np.int32)))
|
||||||
|
else:
|
||||||
|
rand_inds = []
|
||||||
|
while len(rand_inds) < random_pick_n:
|
||||||
|
rand_inds = np.unique(np.random.choice(label_indices, size=5 * random_pick_n, replace=True))
|
||||||
|
epoch_indices = np.hstack((epoch_indices, rand_inds[:random_pick_n].astype(np.int32)))
|
||||||
|
|
||||||
|
# Stack those indices with the cloud index
|
||||||
|
epoch_indices = np.vstack((np.full(epoch_indices.shape, cloud_ind, dtype=np.int32), epoch_indices))
|
||||||
|
|
||||||
|
# Update the global indice container
|
||||||
|
all_epoch_inds = np.hstack((all_epoch_inds, epoch_indices))
|
||||||
|
|
||||||
|
# Random permutation of the indices
|
||||||
|
random_order = np.random.permutation(all_epoch_inds.shape[1])
|
||||||
|
all_epoch_inds = all_epoch_inds[:, random_order].astype(np.int64)
|
||||||
|
|
||||||
|
# Update epoch inds
|
||||||
|
self.dataset.epoch_inds += torch.from_numpy(all_epoch_inds[:, :num_centers])
|
||||||
|
|
||||||
# Generator loop
|
# Generator loop
|
||||||
for i in range(self.N):
|
for i in range(self.N):
|
||||||
yield i
|
yield i
|
||||||
|
@ -718,7 +899,7 @@ class S3DISSampler(Sampler):
|
||||||
"""
|
"""
|
||||||
The number of yielded samples is variable
|
The number of yielded samples is variable
|
||||||
"""
|
"""
|
||||||
return None
|
return self.N
|
||||||
|
|
||||||
def fast_calib(self):
|
def fast_calib(self):
|
||||||
"""
|
"""
|
||||||
|
@ -827,9 +1008,14 @@ class S3DISSampler(Sampler):
|
||||||
batch_lim_dict = {}
|
batch_lim_dict = {}
|
||||||
|
|
||||||
# Check if the batch limit associated with current parameters exists
|
# Check if the batch limit associated with current parameters exists
|
||||||
key = '{:.3f}_{:.3f}_{:d}'.format(self.dataset.config.in_radius,
|
if self.dataset.use_potentials:
|
||||||
self.dataset.config.first_subsampling_dl,
|
sampler_method = 'potentials'
|
||||||
self.dataset.config.batch_num)
|
else:
|
||||||
|
sampler_method = 'random'
|
||||||
|
key = '{:s}_{:.3f}_{:.3f}_{:d}'.format(sampler_method,
|
||||||
|
self.dataset.config.in_radius,
|
||||||
|
self.dataset.config.first_subsampling_dl,
|
||||||
|
self.dataset.config.batch_num)
|
||||||
if key in batch_lim_dict:
|
if key in batch_lim_dict:
|
||||||
self.dataset.batch_limit[0] = batch_lim_dict[key]
|
self.dataset.batch_limit[0] = batch_lim_dict[key]
|
||||||
else:
|
else:
|
||||||
|
@ -903,8 +1089,6 @@ class S3DISSampler(Sampler):
|
||||||
# From config parameter, compute higher bound of neighbors number in a neighborhood
|
# From config parameter, compute higher bound of neighbors number in a neighborhood
|
||||||
hist_n = int(np.ceil(4 / 3 * np.pi * (self.dataset.config.deform_radius + 1) ** 3))
|
hist_n = int(np.ceil(4 / 3 * np.pi * (self.dataset.config.deform_radius + 1) ** 3))
|
||||||
|
|
||||||
print(hist_n)
|
|
||||||
|
|
||||||
# Histogram of neighborhood sizes
|
# Histogram of neighborhood sizes
|
||||||
neighb_hists = np.zeros((self.dataset.config.num_layers, hist_n), dtype=np.int32)
|
neighb_hists = np.zeros((self.dataset.config.num_layers, hist_n), dtype=np.int32)
|
||||||
|
|
||||||
|
@ -1018,9 +1202,14 @@ class S3DISSampler(Sampler):
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Save batch_limit dictionary
|
# Save batch_limit dictionary
|
||||||
key = '{:.3f}_{:.3f}_{:d}'.format(self.dataset.config.in_radius,
|
if self.dataset.use_potentials:
|
||||||
self.dataset.config.first_subsampling_dl,
|
sampler_method = 'potentials'
|
||||||
self.dataset.config.batch_num)
|
else:
|
||||||
|
sampler_method = 'random'
|
||||||
|
key = '{:s}_{:.3f}_{:.3f}_{:d}'.format(sampler_method,
|
||||||
|
self.dataset.config.in_radius,
|
||||||
|
self.dataset.config.first_subsampling_dl,
|
||||||
|
self.dataset.config.batch_num)
|
||||||
batch_lim_dict[key] = float(self.dataset.batch_limit)
|
batch_lim_dict[key] = float(self.dataset.batch_limit)
|
||||||
with open(batch_lim_file, 'wb') as file:
|
with open(batch_lim_file, 'wb') as file:
|
||||||
pickle.dump(batch_lim_dict, file)
|
pickle.dump(batch_lim_dict, file)
|
||||||
|
@ -1232,6 +1421,7 @@ def debug_timing(dataset, loader):
|
||||||
last_display = time.time()
|
last_display = time.time()
|
||||||
mean_dt = np.zeros(2)
|
mean_dt = np.zeros(2)
|
||||||
estim_b = dataset.config.batch_num
|
estim_b = dataset.config.batch_num
|
||||||
|
estim_N = 0
|
||||||
|
|
||||||
for epoch in range(10):
|
for epoch in range(10):
|
||||||
|
|
||||||
|
@ -1244,6 +1434,7 @@ def debug_timing(dataset, loader):
|
||||||
|
|
||||||
# Update estim_b (low pass filter)
|
# Update estim_b (low pass filter)
|
||||||
estim_b += (len(batch.cloud_inds) - estim_b) / 100
|
estim_b += (len(batch.cloud_inds) - estim_b) / 100
|
||||||
|
estim_N += (batch.features.shape[0] - estim_N) / 10
|
||||||
|
|
||||||
# Pause simulating computations
|
# Pause simulating computations
|
||||||
time.sleep(0.05)
|
time.sleep(0.05)
|
||||||
|
@ -1255,11 +1446,12 @@ def debug_timing(dataset, loader):
|
||||||
# Console display (only one per second)
|
# Console display (only one per second)
|
||||||
if (t[-1] - last_display) > -1.0:
|
if (t[-1] - last_display) > -1.0:
|
||||||
last_display = t[-1]
|
last_display = t[-1]
|
||||||
message = 'Step {:08d} -> (ms/batch) {:8.2f} {:8.2f} / batch = {:.2f}'
|
message = 'Step {:08d} -> (ms/batch) {:8.2f} {:8.2f} / batch = {:.2f} - {:.0f}'
|
||||||
print(message.format(batch_i,
|
print(message.format(batch_i,
|
||||||
1000 * mean_dt[0],
|
1000 * mean_dt[0],
|
||||||
1000 * mean_dt[1],
|
1000 * mean_dt[1],
|
||||||
estim_b))
|
estim_b,
|
||||||
|
estim_N))
|
||||||
|
|
||||||
print('************* Epoch ended *************')
|
print('************* Epoch ended *************')
|
||||||
|
|
||||||
|
|
|
@ -65,7 +65,11 @@ class KPCNN(nn.Module):
|
||||||
block_in_layer += 1
|
block_in_layer += 1
|
||||||
|
|
||||||
# Update dimension of input from output
|
# Update dimension of input from output
|
||||||
in_dim = out_dim
|
if 'simple' in block:
|
||||||
|
in_dim = out_dim // 2
|
||||||
|
else:
|
||||||
|
in_dim = out_dim
|
||||||
|
|
||||||
|
|
||||||
# Detect change to a subsampled layer
|
# Detect change to a subsampled layer
|
||||||
if 'pool' in block or 'strided' in block:
|
if 'pool' in block or 'strided' in block:
|
||||||
|
@ -245,7 +249,10 @@ class KPFCNN(nn.Module):
|
||||||
config))
|
config))
|
||||||
|
|
||||||
# Update dimension of input from output
|
# Update dimension of input from output
|
||||||
in_dim = out_dim
|
if 'simple' in block:
|
||||||
|
in_dim = out_dim // 2
|
||||||
|
else:
|
||||||
|
in_dim = out_dim
|
||||||
|
|
||||||
# Detect change to a subsampled layer
|
# Detect change to a subsampled layer
|
||||||
if 'pool' in block or 'strided' in block:
|
if 'pool' in block or 'strided' in block:
|
||||||
|
|
|
@ -497,7 +497,7 @@ class SimpleBlock(nn.Module):
|
||||||
self.KPConv = KPConv(config.num_kernel_points,
|
self.KPConv = KPConv(config.num_kernel_points,
|
||||||
config.in_points_dim,
|
config.in_points_dim,
|
||||||
in_dim,
|
in_dim,
|
||||||
out_dim,
|
out_dim // 2,
|
||||||
current_extent,
|
current_extent,
|
||||||
radius,
|
radius,
|
||||||
fixed_kernel_points=config.fixed_kernel_points,
|
fixed_kernel_points=config.fixed_kernel_points,
|
||||||
|
@ -507,7 +507,7 @@ class SimpleBlock(nn.Module):
|
||||||
modulated=config.modulated)
|
modulated=config.modulated)
|
||||||
|
|
||||||
# Other opperations
|
# Other opperations
|
||||||
self.batch_norm = BatchNormBlock(out_dim, self.use_bn, self.bn_momentum)
|
self.batch_norm = BatchNormBlock(out_dim // 2, self.use_bn, self.bn_momentum)
|
||||||
self.leaky_relu = nn.LeakyReLU(0.1)
|
self.leaky_relu = nn.LeakyReLU(0.1)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -549,16 +549,16 @@ class ResnetBottleneckBlock(nn.Module):
|
||||||
self.layer_ind = layer_ind
|
self.layer_ind = layer_ind
|
||||||
|
|
||||||
# First downscaling mlp
|
# First downscaling mlp
|
||||||
if in_dim != out_dim // 2:
|
if in_dim != out_dim // 4:
|
||||||
self.unary1 = UnaryBlock(in_dim, out_dim // 2, self.use_bn, self.bn_momentum)
|
self.unary1 = UnaryBlock(in_dim, out_dim // 4, self.use_bn, self.bn_momentum)
|
||||||
else:
|
else:
|
||||||
self.unary1 = nn.Identity()
|
self.unary1 = nn.Identity()
|
||||||
|
|
||||||
# KPConv block
|
# KPConv block
|
||||||
self.KPConv = KPConv(config.num_kernel_points,
|
self.KPConv = KPConv(config.num_kernel_points,
|
||||||
config.in_points_dim,
|
config.in_points_dim,
|
||||||
out_dim // 2,
|
out_dim // 4,
|
||||||
out_dim // 2,
|
out_dim // 4,
|
||||||
current_extent,
|
current_extent,
|
||||||
radius,
|
radius,
|
||||||
fixed_kernel_points=config.fixed_kernel_points,
|
fixed_kernel_points=config.fixed_kernel_points,
|
||||||
|
@ -566,10 +566,10 @@ class ResnetBottleneckBlock(nn.Module):
|
||||||
aggregation_mode=config.aggregation_mode,
|
aggregation_mode=config.aggregation_mode,
|
||||||
deformable='deform' in block_name,
|
deformable='deform' in block_name,
|
||||||
modulated=config.modulated)
|
modulated=config.modulated)
|
||||||
self.batch_norm_conv = BatchNormBlock(out_dim // 2, self.use_bn, self.bn_momentum)
|
self.batch_norm_conv = BatchNormBlock(out_dim // 4, self.use_bn, self.bn_momentum)
|
||||||
|
|
||||||
# Second upscaling mlp
|
# Second upscaling mlp
|
||||||
self.unary2 = UnaryBlock(out_dim // 2, out_dim, self.use_bn, self.bn_momentum, no_relu=True)
|
self.unary2 = UnaryBlock(out_dim // 4, out_dim, self.use_bn, self.bn_momentum, no_relu=True)
|
||||||
|
|
||||||
# Shortcut optional mpl
|
# Shortcut optional mpl
|
||||||
if in_dim != out_dim:
|
if in_dim != out_dim:
|
||||||
|
|
|
@ -1376,12 +1376,14 @@ def S3DIS_first(old_result_limit):
|
||||||
"""
|
"""
|
||||||
Test first S3DIS. First two test have all symetries (even vertical), which is not good). We corecct for
|
Test first S3DIS. First two test have all symetries (even vertical), which is not good). We corecct for
|
||||||
the following.
|
the following.
|
||||||
Then we try some experiments with different input scalea and the results are not as high as expected. WHY?
|
Then we try some experiments with different input scalea and the results are not as high as expected.
|
||||||
|
WHY?
|
||||||
|
FOUND IT! Problem resnet bottleneck should divide out-dim by 4 and not by 2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
|
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
|
||||||
start = 'Log_2020-03-25_19-30-17'
|
start = 'Log_2020-03-25_19-30-17'
|
||||||
end = 'Log_2020-04-25_19-30-17'
|
end = 'Log_2020-04-03_11-12-05'
|
||||||
|
|
||||||
if end < old_result_limit:
|
if end < old_result_limit:
|
||||||
res_path = 'old_results'
|
res_path = 'old_results'
|
||||||
|
@ -1399,6 +1401,35 @@ def S3DIS_first(old_result_limit):
|
||||||
'Fin=5_R=2.5_r=0.04',
|
'Fin=5_R=2.5_r=0.04',
|
||||||
'original_normal',
|
'original_normal',
|
||||||
'original_deform',
|
'original_deform',
|
||||||
|
'original_random_sampler',
|
||||||
|
'original_potentials_batch16',
|
||||||
|
'test']
|
||||||
|
|
||||||
|
logs_names = np.array(logs_names[:len(logs)])
|
||||||
|
|
||||||
|
return logs, logs_names
|
||||||
|
|
||||||
|
|
||||||
|
def S3DIS_(old_result_limit):
|
||||||
|
"""
|
||||||
|
Test S3DIS.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
|
||||||
|
start = 'Log_2020-04-03_11-12-07'
|
||||||
|
end = 'Log_2020-04-25_19-30-17'
|
||||||
|
|
||||||
|
if end < old_result_limit:
|
||||||
|
res_path = 'old_results'
|
||||||
|
else:
|
||||||
|
res_path = 'results'
|
||||||
|
|
||||||
|
logs = np.sort([join(res_path, l) for l in listdir(res_path) if start <= l <= end])
|
||||||
|
logs = logs.astype('<U50')
|
||||||
|
|
||||||
|
# Give names to the logs (for legends)
|
||||||
|
logs_names = ['R=2.0_r=0.04_Din=128_potential',
|
||||||
|
'R=2.0_r=0.04_Din=64_potential',
|
||||||
'test']
|
'test']
|
||||||
|
|
||||||
logs_names = np.array(logs_names[:len(logs)])
|
logs_names = np.array(logs_names[:len(logs)])
|
||||||
|
@ -1416,7 +1447,7 @@ if __name__ == '__main__':
|
||||||
old_res_lim = 'Log_2020-03-25_19-30-17'
|
old_res_lim = 'Log_2020-03-25_19-30-17'
|
||||||
|
|
||||||
# My logs: choose the logs to show
|
# My logs: choose the logs to show
|
||||||
logs, logs_names = S3DIS_first(old_res_lim)
|
logs, logs_names = S3DIS_(old_res_lim)
|
||||||
#os.environ['QT_DEBUG_PLUGINS'] = '1'
|
#os.environ['QT_DEBUG_PLUGINS'] = '1'
|
||||||
|
|
||||||
######################################################
|
######################################################
|
||||||
|
|
|
@ -76,9 +76,9 @@ class S3DISConfig(Config):
|
||||||
'resnetb_strided',
|
'resnetb_strided',
|
||||||
'resnetb',
|
'resnetb',
|
||||||
'resnetb_strided',
|
'resnetb_strided',
|
||||||
'resnetb_deformable',
|
'resnetb',
|
||||||
'resnetb_deformable_strided',
|
'resnetb_strided',
|
||||||
'resnetb_deformable',
|
'resnetb',
|
||||||
'nearest_upsample',
|
'nearest_upsample',
|
||||||
'unary',
|
'unary',
|
||||||
'nearest_upsample',
|
'nearest_upsample',
|
||||||
|
@ -117,6 +117,7 @@ class S3DISConfig(Config):
|
||||||
aggregation_mode = 'sum'
|
aggregation_mode = 'sum'
|
||||||
|
|
||||||
# Choice of input features
|
# Choice of input features
|
||||||
|
first_features_dim = 64
|
||||||
in_features_dim = 5
|
in_features_dim = 5
|
||||||
|
|
||||||
# Can the network learn modulations
|
# Can the network learn modulations
|
||||||
|
@ -161,8 +162,8 @@ class S3DISConfig(Config):
|
||||||
augment_scale_anisotropic = True
|
augment_scale_anisotropic = True
|
||||||
augment_symmetries = [True, False, False]
|
augment_symmetries = [True, False, False]
|
||||||
augment_rotation = 'vertical'
|
augment_rotation = 'vertical'
|
||||||
augment_scale_min = 0.9
|
augment_scale_min = 0.8
|
||||||
augment_scale_max = 1.1
|
augment_scale_max = 1.2
|
||||||
augment_noise = 0.001
|
augment_noise = 0.001
|
||||||
augment_color = 0.8
|
augment_color = 0.8
|
||||||
|
|
||||||
|
@ -189,13 +190,8 @@ if __name__ == '__main__':
|
||||||
# Initialize the environment
|
# Initialize the environment
|
||||||
############################
|
############################
|
||||||
|
|
||||||
# TODO: 9 millions de parametres au lieu de 14 millions... Pourquoi?
|
|
||||||
# TODO: radius des strided 2 fois trop grand
|
|
||||||
# TODO: implement un sampler plus simple
|
|
||||||
# TODO: test batch size a 16
|
|
||||||
|
|
||||||
# Set which gpu is going to be used
|
# Set which gpu is going to be used
|
||||||
GPU_ID = '1'
|
GPU_ID = '2'
|
||||||
|
|
||||||
# Set GPU visible device
|
# Set GPU visible device
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
|
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
|
||||||
|
@ -245,8 +241,8 @@ if __name__ == '__main__':
|
||||||
config.saving_path = sys.argv[1]
|
config.saving_path = sys.argv[1]
|
||||||
|
|
||||||
# Initialize datasets
|
# Initialize datasets
|
||||||
training_dataset = S3DISDataset(config, set='training')
|
training_dataset = S3DISDataset(config, set='training', use_potentials=True)
|
||||||
test_dataset = S3DISDataset(config, set='validation')
|
test_dataset = S3DISDataset(config, set='validation', use_potentials=True)
|
||||||
|
|
||||||
# Initialize samplers
|
# Initialize samplers
|
||||||
training_sampler = S3DISSampler(training_dataset)
|
training_sampler = S3DISSampler(training_dataset)
|
||||||
|
@ -280,8 +276,18 @@ if __name__ == '__main__':
|
||||||
# Define network model
|
# Define network model
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
net = KPFCNN(config)
|
net = KPFCNN(config)
|
||||||
print(net)
|
|
||||||
print("Model size %i" % sum(param.numel() for param in net.parameters() if param.requires_grad))
|
debug = False
|
||||||
|
if debug:
|
||||||
|
print('\n*************************************\n')
|
||||||
|
print(net)
|
||||||
|
print('\n*************************************\n')
|
||||||
|
for param in net.parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
print(param.shape)
|
||||||
|
print('\n*************************************\n')
|
||||||
|
print("Model size %i" % sum(param.numel() for param in net.parameters() if param.requires_grad))
|
||||||
|
print('\n*************************************\n')
|
||||||
|
|
||||||
# Define a trainer class
|
# Define a trainer class
|
||||||
trainer = ModelTrainer(net, config, chkp_path=chosen_chkp)
|
trainer = ModelTrainer(net, config, chkp_path=chosen_chkp)
|
||||||
|
|
Loading…
Reference in a new issue