Corrections

This commit is contained in:
HuguesTHOMAS 2020-04-23 09:51:16 -04:00
parent 46e2f7d3e6
commit 31cce85c95
15 changed files with 1321 additions and 191 deletions

502
datasets/NCLT.py Normal file
View file

@ -0,0 +1,502 @@
#
#
# 0=================================0
# | Kernel Point Convolutions |
# 0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Class handling SemanticKitti dataset.
# Implements a Dataset, a Sampler, and a collate_fn
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Hugues THOMAS - 11/06/2018
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Imports and global variables
# \**********************************/
#
# Common libs
import sys
import struct
import scipy
import time
import numpy as np
import pickle
import torch
import yaml
#from mayavi import mlab
from multiprocessing import Lock
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# OS functions
from os import listdir
from os.path import exists, join, isdir, getsize
# Dataset parent class
from datasets.common import *
from torch.utils.data import Sampler, get_worker_info
from utils.mayavi_visu import *
from utils.metrics import fast_confusion
from datasets.common import grid_subsampling
from utils.config import bcolors
def ssc_to_homo(ssc, ssc_in_radians=True):
# Convert 6-DOF ssc coordinate transformation to 4x4 homogeneous matrix
# transformation
if ssc.ndim == 1:
reduce = True
ssc = np.expand_dims(ssc, 0)
else:
reduce = False
if not ssc_in_radians:
ssc[:, 3:] = np.pi / 180.0 * ssc[:, 3:]
sr = np.sin(ssc[:, 3])
cr = np.cos(ssc[:, 3])
sp = np.sin(ssc[:, 4])
cp = np.cos(ssc[:, 4])
sh = np.sin(ssc[:, 5])
ch = np.cos(ssc[:, 5])
H = np.zeros((ssc.shape[0], 4, 4))
H[:, 0, 0] = ch*cp
H[:, 0, 1] = -sh*cr + ch*sp*sr
H[:, 0, 2] = sh*sr + ch*sp*cr
H[:, 1, 0] = sh*cp
H[:, 1, 1] = ch*cr + sh*sp*sr
H[:, 1, 2] = -ch*sr + sh*sp*cr
H[:, 2, 0] = -sp
H[:, 2, 1] = cp*sr
H[:, 2, 2] = cp*cr
H[:, 0, 3] = ssc[:, 0]
H[:, 1, 3] = ssc[:, 1]
H[:, 2, 3] = ssc[:, 2]
H[:, 3, 3] = 1
if reduce:
H = np.squeeze(H)
return H
def verify_magic(s):
magic = 44444
m = struct.unpack('<HHHH', s)
return len(m)>=4 and m[0] == magic and m[1] == magic and m[2] == magic and m[3] == magic
def test_read_hits():
data_path = '../../Data/NCLT'
velo_folder = 'velodyne_data'
day = '2012-01-08'
hits_path = join(data_path, velo_folder, day, 'velodyne_hits.bin')
all_utimes = []
all_hits = []
all_ints = []
num_bytes = getsize(hits_path)
current_bytes = 0
with open(hits_path, 'rb') as f_bin:
total_hits = 0
first_utime = -1
last_utime = -1
while True:
magic = f_bin.read(8)
if magic == b'':
break
if not verify_magic(magic):
print('Could not verify magic')
num_hits = struct.unpack('<I', f_bin.read(4))[0]
utime = struct.unpack('<Q', f_bin.read(8))[0]
# Do not convert padding (it is an int always equal to zero)
padding = f_bin.read(4)
total_hits += num_hits
if first_utime == -1:
first_utime = utime
last_utime = utime
hits = []
ints = []
for i in range(num_hits):
x = struct.unpack('<H', f_bin.read(2))[0]
y = struct.unpack('<H', f_bin.read(2))[0]
z = struct.unpack('<H', f_bin.read(2))[0]
i = struct.unpack('B', f_bin.read(1))[0]
l = struct.unpack('B', f_bin.read(1))[0]
hits += [[x, y, z]]
ints += [i]
utimes = np.full((num_hits,), utime - first_utime, dtype=np.int32)
ints = np.array(ints, dtype=np.uint8)
hits = np.array(hits, dtype=np.float32)
hits *= 0.005
hits += -100.0
all_utimes.append(utimes)
all_hits.append(hits)
all_ints.append(ints)
if 100 * current_bytes / num_bytes > 0.1:
break
current_bytes += 24 + 8 * num_hits
print('{:d}/{:d} => {:.1f}%'.format(current_bytes, num_bytes, 100 * current_bytes / num_bytes))
all_utimes = np.hstack(all_utimes)
all_hits = np.vstack(all_hits)
all_ints = np.hstack(all_ints)
write_ply('test_hits',
[all_hits, all_ints, all_utimes],
['x', 'y', 'z', 'intensity', 'utime'])
print("Read %d total hits from %ld to %ld" % (total_hits, first_utime, last_utime))
return 0
def frames_to_ply(show_frames=False):
# In files
data_path = '../../Data/NCLT'
velo_folder = 'velodyne_data'
days = np.sort([d for d in listdir(join(data_path, velo_folder))])
for day in days:
# Out files
ply_folder = join(data_path, 'frames_ply', day)
if not exists(ply_folder):
makedirs(ply_folder)
day_path = join(data_path, velo_folder, day, 'velodyne_sync')
f_names = np.sort([f for f in listdir(day_path) if f[-4:] == '.bin'])
N = len(f_names)
print('Reading', N, 'files')
for f_i, f_name in enumerate(f_names):
ply_name = join(ply_folder, f_name[:-4] + '.ply')
if exists(ply_name):
continue
t1 = time.time()
hits = []
ints = []
with open(join(day_path, f_name), 'rb') as f_bin:
while True:
x_str = f_bin.read(2)
# End of file
if x_str == b'':
break
x = struct.unpack('<H', x_str)[0]
y = struct.unpack('<H', f_bin.read(2))[0]
z = struct.unpack('<H', f_bin.read(2))[0]
intensity = struct.unpack('B', f_bin.read(1))[0]
l = struct.unpack('B', f_bin.read(1))[0]
hits += [[x, y, z]]
ints += [intensity]
ints = np.array(ints, dtype=np.uint8)
hits = np.array(hits, dtype=np.float32)
hits *= 0.005
hits += -100.0
write_ply(ply_name,
[hits, ints],
['x', 'y', 'z', 'intensity'])
t2 = time.time()
print('File {:s} {:d}/{:d} Done in {:.1f}s'.format(f_name, f_i, N, t2 - t1))
if show_frames:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(hits[:, 0], hits[:, 1], -hits[:, 2], c=-hits[:, 2], s=5, linewidths=0)
plt.show()
return 0
def merge_day_pointclouds(show_day_trajectory=False, only_SLAM_nodes=False):
"""
Recreate the whole day point cloud thks to gt pose
Generate gt_annotation of mobile objects
"""
# In files
data_path = '../../Data/NCLT'
gt_folder = 'ground_truth'
cov_folder = 'ground_truth_cov'
# Transformation from body to velodyne frame (from NCLT paper)
x_body_velo = np.array([0.002, -0.004, -0.957, 0.807, 0.166, -90.703])
H_body_velo = ssc_to_homo(x_body_velo, ssc_in_radians=False)
H_velo_body = np.linalg.inv(H_body_velo)
x_body_lb3 = np.array([0.035, 0.002, -1.23, -179.93, -0.23, 0.50])
H_body_lb3 = ssc_to_homo(x_body_lb3, ssc_in_radians=False)
H_lb3_body = np.linalg.inv(H_body_lb3)
# Get gt files and days
gt_files = np.sort([gt_f for gt_f in listdir(join(data_path, gt_folder)) if gt_f[-4:] == '.csv'])
cov_files = np.sort([cov_f for cov_f in listdir(join(data_path, cov_folder)) if cov_f[-4:] == '.csv'])
days = [d[:-4].split('_')[1] for d in gt_files]
# Load all gt poses
print('\nLoading days groundtruth poses...')
t0 = time.time()
gt_H = []
gt_t = []
for d, gt_f in enumerate(gt_files):
t1 = time.time()
gt_pkl_file = join(data_path, gt_folder, gt_f[:-4] + '.pkl')
if exists(gt_pkl_file):
# Read pkl
with open(gt_pkl_file, 'rb') as f:
day_gt_t, day_gt_H = pickle.load(f)
else:
# File paths
gt_csv = join(data_path, gt_folder, gt_f)
# Load gt
gt = np.loadtxt(gt_csv, delimiter=',')
# Convert gt to homogenous rotation/translation matrix
day_gt_t = gt[:, 0]
day_gt_H = ssc_to_homo(gt[:, 1:])
# Save pickle
with open(gt_pkl_file, 'wb') as f:
pickle.dump([day_gt_t, day_gt_H], f)
t2 = time.time()
print('{:s} {:d}/{:d} Done in {:.1f}s'.format(gt_f, d, gt_files.shape[0], t2 - t1))
gt_t += [day_gt_t]
gt_H += [day_gt_H]
if show_day_trajectory:
cov_csv = join(data_path, cov_folder, cov_files[d])
cov = np.loadtxt(cov_csv, delimiter=',')
t_cov = cov[:, 0]
t_cov_bool = np.logical_and(t_cov > np.min(day_gt_t), t_cov < np.max(day_gt_t))
t_cov = t_cov[t_cov_bool]
# Note: Interpolation is not needed, this is done as a convinience
interp = scipy.interpolate.interp1d(day_gt_t, day_gt_H[:, :3, 3], kind='nearest', axis=0)
node_poses = interp(t_cov)
plt.figure()
plt.scatter(day_gt_H[:, 1, 3], day_gt_H[:, 0, 3], 1, c=-day_gt_H[:, 2, 3], linewidth=0)
plt.scatter(node_poses[:, 1], node_poses[:, 0], 1, c=-node_poses[:, 2], linewidth=5)
plt.axis('equal')
plt.title('Ground Truth Position of Nodes in SLAM Graph')
plt.xlabel('East (m)')
plt.ylabel('North (m)')
plt.colorbar()
plt.show()
t2 = time.time()
print('Done in {:.1f}s\n'.format(t2 - t0))
# Out files
out_folder = join(data_path, 'day_ply')
if not exists(out_folder):
makedirs(out_folder)
# Focus on a particular point
p0 = np.array([-220, -527, 12])
center_radius = 10.0
point_radius = 50.0
# Loop on days
for d, day in enumerate(days):
#if day != '2012-02-05':
# continue
day_min_t = gt_t[d][0]
day_max_t = gt_t[d][-1]
frames_folder = join(data_path, 'frames_ply', day)
f_times = np.sort([float(f[:-4]) for f in listdir(frames_folder) if f[-4:] == '.ply'])
# If we want, load only SLAM nodes
if only_SLAM_nodes:
# Load node timestamps
cov_csv = join(data_path, cov_folder, cov_files[d])
cov = np.loadtxt(cov_csv, delimiter=',')
t_cov = cov[:, 0]
t_cov_bool = np.logical_and(t_cov > day_min_t, t_cov < day_max_t)
t_cov = t_cov[t_cov_bool]
# Find closest lidar frames
t_cov = np.expand_dims(t_cov, 1)
diffs = np.abs(t_cov - f_times)
inds = np.argmin(diffs, axis=1)
f_times = f_times[inds]
# Is this frame in gt
f_t_bool = np.logical_and(f_times > day_min_t, f_times < day_max_t)
f_times = f_times[f_t_bool]
# Interpolation gt poses to frame timestamps
interp = scipy.interpolate.interp1d(gt_t[d], gt_H[d], kind='nearest', axis=0)
frame_poses = interp(f_times)
N = len(f_times)
world_points = []
world_frames = []
world_frames_c = []
print('Reading', day, ' => ', N, 'files')
for f_i, f_t in enumerate(f_times):
t1 = time.time()
#########
# GT pose
#########
H = frame_poses[f_i].astype(np.float32)
# s = '\n'
# for cc in H:
# for c in cc:
# s += '{:5.2f} '.format(c)
# s += '\n'
# print(s)
#############
# Focus check
#############
if np.linalg.norm(H[:3, 3] - p0) > center_radius:
continue
###################################
# Local frame coordinates for debug
###################################
# Create artificial frames
x = np.linspace(0, 1, 50, dtype=np.float32)
points = np.hstack((np.vstack((x, x*0, x*0)), np.vstack((x*0, x, x*0)), np.vstack((x*0, x*0, x)))).T
colors = ((points > 0.1).astype(np.float32) * 255).astype(np.uint8)
hpoints = np.hstack((points, np.ones_like(points[:, :1])))
hpoints = np.matmul(hpoints, H.T)
hpoints[:, 3] *= 0
world_frames += [hpoints[:, :3]]
world_frames_c += [colors]
#######################
# Load velo point cloud
#######################
# Load frame ply file
f_name = '{:.0f}.ply'.format(f_t)
data = read_ply(join(frames_folder, f_name))
points = np.vstack((data['x'], data['y'], data['z'])).T
#intensity = data['intensity']
hpoints = np.hstack((points, np.ones_like(points[:, :1])))
hpoints = np.matmul(hpoints, H.T)
hpoints[:, 3] *= 0
hpoints[:, 3] += np.sqrt(f_t - f_times[0])
# focus check
focus_bool = np.linalg.norm(hpoints[:, :3] - p0, axis=1) < point_radius
hpoints = hpoints[focus_bool, :]
world_points += [hpoints]
t2 = time.time()
print('File {:s} {:d}/{:d} Done in {:.1f}s'.format(f_name, f_i, N, t2 - t1))
if len(world_points) < 2:
continue
world_points = np.vstack(world_points)
###### DEBUG
world_frames = np.vstack(world_frames)
world_frames_c = np.vstack(world_frames_c)
write_ply('testf.ply',
[world_frames, world_frames_c],
['x', 'y', 'z', 'red', 'green', 'blue'])
###### DEBUG
print(world_points.shape, world_points.dtype)
# Subsample merged frames
# world_points, features = grid_subsampling(world_points[:, :3],
# features=world_points[:, 3:],
# sampleDl=0.1)
features = world_points[:, 3:]
world_points = world_points[:, :3]
print(world_points.shape, world_points.dtype)
write_ply('test' + day + '.ply',
[world_points, features],
['x', 'y', 'z', 't'])
# Generate gt annotations
# Subsample day ply (for visualization)
# Save day ply
# a = 1/0

View file

@ -131,7 +131,7 @@ class S3DISDataset(PointCloudDataset):
# Prepare ply files
###################
self.prepare_S3DIS_ply()
#self.prepare_S3DIS_ply()
################
# Load ply files
@ -1037,7 +1037,7 @@ class S3DISSampler(Sampler):
if breaking:
break
def calibration(self, dataloader, untouched_ratio=0.9, verbose=False):
def calibration(self, dataloader, untouched_ratio=0.9, verbose=False, force_redo=False):
"""
Method performing batch and neighbors calibration.
Batch calibration: Set "batch_limit" (the maximum number of points allowed in every batch) so that the
@ -1053,7 +1053,7 @@ class S3DISSampler(Sampler):
print('\nStarting Calibration (use verbose=True for more details)')
t0 = time.time()
redo = False
redo = force_redo
# Batch limit
# ***********
@ -1075,7 +1075,7 @@ class S3DISSampler(Sampler):
self.dataset.config.in_radius,
self.dataset.config.first_subsampling_dl,
self.dataset.config.batch_num)
if key in batch_lim_dict:
if not redo and key in batch_lim_dict:
self.dataset.batch_limit[0] = batch_lim_dict[key]
else:
redo = True
@ -1116,7 +1116,7 @@ class S3DISSampler(Sampler):
if key in neighb_lim_dict:
neighb_limits += [neighb_lim_dict[key]]
if len(neighb_limits) == self.dataset.config.num_layers:
if not redo and len(neighb_limits) == self.dataset.config.num_layers:
self.dataset.neighborhood_limits = neighb_limits
else:
redo = True

View file

@ -1115,7 +1115,7 @@ class SemanticKittiSampler(Sampler):
# Perform calibration
#####################
self.dataset.batch_limit = self.dataset.max_in_p * (self.dataset.batch_num - 1)
#self.dataset.batch_limit[0] = self.dataset.max_in_p * (self.dataset.batch_num - 1)
for epoch in range(10):
for batch_i, batch in enumerate(dataloader):
@ -1145,7 +1145,7 @@ class SemanticKittiSampler(Sampler):
smooth_errors = smooth_errors[1:]
# Update batch limit with P controller
self.dataset.batch_limit += Kp * error
self.dataset.batch_limit[0] += Kp * error
# finer low pass filter when closing in
if not finer and np.abs(estim_b - target_b) < 1:
@ -1166,7 +1166,7 @@ class SemanticKittiSampler(Sampler):
message = 'Step {:5d} estim_b ={:5.2f} batch_limit ={:7d}'
print(message.format(i,
estim_b,
int(self.dataset.batch_limit)))
int(self.dataset.batch_limit[0])))
if breaking:
break
@ -1224,7 +1224,7 @@ class SemanticKittiSampler(Sampler):
self.dataset.config.first_subsampling_dl,
self.dataset.batch_num,
self.dataset.max_in_p)
batch_lim_dict[key] = float(self.dataset.batch_limit)
batch_lim_dict[key] = float(self.dataset.batch_limit[0])
with open(batch_lim_file, 'wb') as file:
pickle.dump(batch_lim_dict, file)

View file

@ -228,7 +228,7 @@ class PointCloudDataset(Dataset):
# Add random symmetries to the scale factor
symmetries = np.array(self.config.augment_symmetries).astype(np.int32)
symmetries *= np.random.randint(2, size=points.shape[1])
scale = (scale * symmetries * 2 - 1).astype(np.float32)
scale = (scale * (1 - symmetries * 2)).astype(np.float32)
#######
# Noise

View file

@ -183,7 +183,7 @@ class KPCNN(nn.Module):
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
rep_loss = torch.sum(torch.clamp_max(distances - 1.5, max=0.0) ** 2, dim=1)
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss))
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K
@ -218,7 +218,7 @@ class KPFCNN(nn.Module):
#####################
# Save all block operations in a list of modules
self.encoder_blocs = nn.ModuleList()
self.encoder_blocks = nn.ModuleList()
self.encoder_skip_dims = []
self.encoder_skips = []
@ -239,7 +239,7 @@ class KPFCNN(nn.Module):
break
# Apply the good block function defining tf ops
self.encoder_blocs.append(block_decider(block,
self.encoder_blocks.append(block_decider(block,
r,
in_dim,
out_dim,
@ -264,7 +264,7 @@ class KPFCNN(nn.Module):
#####################
# Save all block operations in a list of modules
self.decoder_blocs = nn.ModuleList()
self.decoder_blocks = nn.ModuleList()
self.decoder_concats = []
# Find first upsampling block
@ -283,7 +283,7 @@ class KPFCNN(nn.Module):
self.decoder_concats.append(block_i)
# Apply the good block function defining tf ops
self.decoder_blocs.append(block_decider(block,
self.decoder_blocks.append(block_decider(block,
r,
in_dim,
out_dim,
@ -331,12 +331,12 @@ class KPFCNN(nn.Module):
# Loop over consecutive blocks
skip_x = []
for block_i, block_op in enumerate(self.encoder_blocs):
for block_i, block_op in enumerate(self.encoder_blocks):
if block_i in self.encoder_skips:
skip_x.append(x)
x = block_op(x, batch)
for block_i, block_op in enumerate(self.decoder_blocs):
for block_i, block_op in enumerate(self.decoder_blocks):
if block_i in self.decoder_concats:
x = torch.cat([x, skip_x.pop()], dim=1)
x = block_op(x, batch)
@ -434,9 +434,8 @@ class KPFCNN(nn.Module):
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
rep_loss = torch.sum(torch.clamp_max(distances - 1.5, max=0.0) ** 2, dim=1)
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss))
rep_loss = torch.sum(torch.clamp_max(distances - 0.5, max=0.0) ** 2, dim=1)
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K
return self.offset_decay * (fitting_loss + repulsive_loss)

View file

@ -421,6 +421,7 @@ class BatchNormBlock(nn.Module):
super(BatchNormBlock, self).__init__()
self.bn_momentum = bn_momentum
self.use_bn = use_bn
self.in_dim = in_dim
if self.use_bn:
self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum)
#self.batch_norm = nn.InstanceNorm1d(in_dim, momentum=bn_momentum)
@ -442,6 +443,11 @@ class BatchNormBlock(nn.Module):
else:
return x + self.bias
def __repr__(self):
return 'BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})'.format(self.in_dim,
self.bn_momentum,
str(not self.use_bn))
class UnaryBlock(nn.Module):
@ -458,6 +464,8 @@ class UnaryBlock(nn.Module):
self.bn_momentum = bn_momentum
self.use_bn = use_bn
self.no_relu = no_relu
self.in_dim = in_dim
self.out_dim = out_dim
self.mlp = nn.Linear(in_dim, out_dim, bias=False)
self.batch_norm = BatchNormBlock(out_dim, self.use_bn, self.bn_momentum)
if not no_relu:
@ -471,6 +479,12 @@ class UnaryBlock(nn.Module):
x = self.leaky_relu(x)
return x
def __repr__(self):
return 'UnaryBlock(in_feat: {:d}, out_feat: {:d}, BN: {:s}, ReLU: {:s})'.format(self.in_dim,
self.out_dim,
str(self.use_bn),
str(not self.no_relu))
class SimpleBlock(nn.Module):
@ -492,6 +506,8 @@ class SimpleBlock(nn.Module):
self.use_bn = config.use_batch_norm
self.layer_ind = layer_ind
self.block_name = block_name
self.in_dim = in_dim
self.out_dim = out_dim
# Define the KPConv class
self.KPConv = KPConv(config.num_kernel_points,
@ -547,6 +563,8 @@ class ResnetBottleneckBlock(nn.Module):
self.use_bn = config.use_batch_norm
self.block_name = block_name
self.layer_ind = layer_ind
self.in_dim = in_dim
self.out_dim = out_dim
# First downscaling mlp
if in_dim != out_dim // 4:
@ -639,6 +657,10 @@ class NearestUpsampleBlock(nn.Module):
def forward(self, x, batch):
return closest_pool(x, batch.upsamples[self.layer_ind - 1])
def __repr__(self):
return 'NearestUpsampleBlock(layer: {:d} -> {:d})'.format(self.layer_ind,
self.layer_ind - 1)
class MaxPoolBlock(nn.Module):

View file

@ -1445,12 +1445,14 @@ def S3DIS_go(old_result_limit):
def SemanticKittiFirst(old_result_limit):
"""
Test SematicKitti. First exps
Test SematicKitti. First exps.
Try some class weight strategies. It seems that the final score is not impacted so much. With weights, some classes
are better while other are worse, for a final score that remains the same.
"""
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
start = 'Log_2020-04-07_15-30-17'
end = 'Log_2020-05-07_15-30-17'
end = 'Log_2020-04-11_21-34-16'
if end < old_result_limit:
res_path = 'old_results'
@ -1464,8 +1466,43 @@ def SemanticKittiFirst(old_result_limit):
logs_names = ['R=5.0_dl=0.04',
'R=5.0_dl=0.08',
'R=10.0_dl=0.08',
'R=10.0_dl=0.08_weigths',
'R=10.0_dl=0.08_sqrt_weigths',
'R=10.0_dl=0.08_20*weigths',
'R=10.0_dl=0.08_20*sqrt_weigths',
'R=10.0_dl=0.08_100*sqrt_w',
'R=10.0_dl=0.08_100*sqrt_w_capped',
'R=10.0_dl=0.08_no_w']
logs_names = np.array(logs_names[:len(logs)])
return logs, logs_names
def SemanticKitti_scale(old_result_limit):
"""
Test SematicKitti. Try different scales of input raduis / subsampling.
"""
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
start = 'Log_2020-04-11_21-34-15'
end = 'Log_2020-04-20_11-52-58'
if end < old_result_limit:
res_path = 'old_results'
else:
res_path = 'results'
logs = np.sort([join(res_path, l) for l in listdir(res_path) if start <= l <= end])
logs = logs.astype('<U50')
# Give names to the logs (for legends)
logs_names = ['R=10.0_dl=0.08',
'R=4.0_dl=0.04',
'R=6.0_dl=0.06',
'R=6.0_dl=0.06_inF=2',
'test',
'test',
'test',
'test',
'test']
logs_names = np.array(logs_names[:len(logs)])
@ -1473,6 +1510,41 @@ def SemanticKittiFirst(old_result_limit):
return logs, logs_names
def S3DIS_deform(old_result_limit):
"""
Debug S3DIS deformable.
At checkpoint 50, the points seem to start fitting the shape, but then, they just get further away from each other
and do not care about input points. The fitting loss seems broken?
"""
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
start = 'Log_2020-04-22_11-52-58'
end = 'Log_2020-05-22_11-52-58'
if end < old_result_limit:
res_path = 'old_results'
else:
res_path = 'results'
logs = np.sort([join(res_path, l) for l in listdir(res_path) if start <= l <= end])
logs = logs.astype('<U50')
logs = np.insert(logs, 0, 'results/Log_2020-04-04_10-04-42')
# Give names to the logs (for legends)
logs_names = ['off_d=0.01_baseline',
'off_d=0.01',
'off_d=0.05',
'off_d=0.05_corrected',
'off_d=0.05_norepulsive',
'off_d=0.05_repulsive0.5',
'test']
logs_names = np.array(logs_names[:len(logs)])
return logs, logs_names
@ -1489,7 +1561,7 @@ if __name__ == '__main__':
old_res_lim = 'Log_2020-03-25_19-30-17'
# My logs: choose the logs to show
logs, logs_names = SemanticKittiFirst(old_res_lim)
logs, logs_names = S3DIS_deform(old_res_lim)
#os.environ['QT_DEBUG_PLUGINS'] = '1'
######################################################

344
train_NCLT.py Normal file
View file

@ -0,0 +1,344 @@
#
#
# 0=================================0
# | Kernel Point Convolutions |
# 0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Callable script to start a training on NCLT dataset
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Hugues THOMAS - 06/03/2020
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Imports and global variables
# \**********************************/
#
# Common libs
import signal
import os
import numpy as np
import sys
import torch
# Dataset
from datasets.NCLT import *
from torch.utils.data import DataLoader
from utils.config import Config
from utils.trainer import ModelTrainer
from models.architectures import KPFCNN
# ----------------------------------------------------------------------------------------------------------------------
#
# Config Class
# \******************/
#
class NCLTConfig(Config):
"""
Override the parameters you want to modify for this dataset
"""
####################
# Dataset parameters
####################
# Dataset name
dataset = 'NCLT'
# Number of classes in the dataset (This value is overwritten by dataset class when Initializating dataset).
num_classes = None
# Type of task performed on this dataset (also overwritten)
dataset_task = ''
# Number of CPU threads for the input pipeline
input_threads = 10
#########################
# Architecture definition
#########################
# Define layers
architecture = ['simple',
'resnetb',
'resnetb_strided',
'resnetb',
'resnetb',
'resnetb_strided',
'resnetb',
'resnetb',
'resnetb',
'resnetb',
'resnetb_strided',
'resnetb',
'resnetb',
'resnetb',
'resnetb_strided',
'resnetb',
'resnetb',
'nearest_upsample',
'unary',
'nearest_upsample',
'unary',
'nearest_upsample',
'unary',
'nearest_upsample',
'unary']
###################
# KPConv parameters
###################
# Radius of the input sphere
in_radius = 6.0
val_radius = 51.0
n_frames = 1
max_in_points = 100000
max_val_points = 200000
# Number of batch
batch_num = 8
val_batch_num = 1
# Number of kernel points
num_kernel_points = 15
# Size of the first subsampling grid in meter
first_subsampling_dl = 0.06
# Radius of convolution in "number grid cell". (2.5 is the standard value)
conv_radius = 2.5
# Radius of deformable convolution in "number grid cell". Larger so that deformed kernel can spread out
deform_radius = 6.0
# Radius of the area of influence of each kernel point in "number grid cell". (1.0 is the standard value)
KP_extent = 1.5
# Behavior of convolutions in ('constant', 'linear', 'gaussian')
KP_influence = 'linear'
# Aggregation function of KPConv in ('closest', 'sum')
aggregation_mode = 'sum'
# Choice of input features
first_features_dim = 128
in_features_dim = 2
# Can the network learn modulations
modulated = False
# Batch normalization parameters
use_batch_norm = True
batch_norm_momentum = 0.02
# Offset loss
# 'permissive' only constrains offsets inside the deform radius (NOT implemented yet)
# 'fitting' helps deformed kernels to adapt to the geometry by penalizing distance to input points
offsets_loss = 'fitting'
offsets_decay = 0.01
#####################
# Training parameters
#####################
# Maximal number of epochs
max_epoch = 800
# Learning rate management
learning_rate = 1e-2
momentum = 0.98
lr_decays = {i: 0.1 ** (1 / 150) for i in range(1, max_epoch)}
grad_clip_norm = 100.0
# Number of steps per epochs
epoch_steps = 500
# Number of validation examples per epoch
validation_size = 200
# Number of epoch between each checkpoint
checkpoint_gap = 50
# Augmentations
augment_scale_anisotropic = True
augment_symmetries = [True, False, False]
augment_rotation = 'vertical'
augment_scale_min = 0.8
augment_scale_max = 1.2
augment_noise = 0.001
augment_color = 0.8
# Choose weights for class (used in segmentation loss). Empty list for no weights
# class proportion for R=10.0 and dl=0.08 (first is unlabeled)
# 19.1 48.9 0.5 1.1 5.6 3.6 0.7 0.6 0.9 193.2 17.7 127.4 6.7 132.3 68.4 283.8 7.0 78.5 3.3 0.8
#
#
# sqrt(Inverse of proportion * 100)
# class_w = [1.430, 14.142, 9.535, 4.226, 5.270, 11.952, 12.910, 10.541, 0.719,
# 2.377, 0.886, 3.863, 0.869, 1.209, 0.594, 3.780, 1.129, 5.505, 11.180]
# sqrt(Inverse of proportion * 100) capped (0.5 < X < 5)
# class_w = [1.430, 5.000, 5.000, 4.226, 5.000, 5.000, 5.000, 5.000, 0.719, 2.377,
# 0.886, 3.863, 0.869, 1.209, 0.594, 3.780, 1.129, 5.000, 5.000]
# Do we nee to save convergence
saving = True
saving_path = None
# ----------------------------------------------------------------------------------------------------------------------
#
# Main Call
# \***************/
#
if __name__ == '__main__':
#test_read_hits()
#frames_to_ply()
merge_day_pointclouds()
a = 1/0
############################
# Initialize the environment
############################
# Set which gpu is going to be used
GPU_ID = '2'
# Set GPU visible device
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
###############
# Previous chkp
###############
# Choose here if you want to start training from a previous snapshot (None for new training)
# previous_training_path = 'Log_2020-03-19_19-53-27'
previous_training_path = ''
# Choose index of checkpoint to start from. If None, uses the latest chkp
chkp_idx = None
if previous_training_path:
# Find all snapshot in the chosen training folder
chkp_path = os.path.join('results', previous_training_path, 'checkpoints')
chkps = [f for f in os.listdir(chkp_path) if f[:4] == 'chkp']
# Find which snapshot to restore
if chkp_idx is None:
chosen_chkp = 'current_chkp.tar'
else:
chosen_chkp = np.sort(chkps)[chkp_idx]
chosen_chkp = os.path.join('results', previous_training_path, 'checkpoints', chosen_chkp)
else:
chosen_chkp = None
##############
# Prepare Data
##############
print()
print('Data Preparation')
print('****************')
# Initialize configuration class
config = NCLTConfig()
if previous_training_path:
config.load(os.path.join('results', previous_training_path))
config.saving_path = None
# Get path from argument if given
if len(sys.argv) > 1:
config.saving_path = sys.argv[1]
# Initialize datasets
training_dataset = NCLTDataset(config, set='training',
balance_classes=True)
test_dataset = NCLTDataset(config, set='validation',
balance_classes=False)
# Initialize samplers
training_sampler = NCLTSampler(training_dataset)
test_sampler = NCLTSampler(test_dataset)
# Initialize the dataloader
training_loader = DataLoader(training_dataset,
batch_size=1,
sampler=training_sampler,
collate_fn=NCLTCollate,
num_workers=config.input_threads,
pin_memory=True)
test_loader = DataLoader(test_dataset,
batch_size=1,
sampler=test_sampler,
collate_fn=NCLTCollate,
num_workers=config.input_threads,
pin_memory=True)
# Calibrate max_in_point value
training_sampler.calib_max_in(config, training_loader, verbose=True)
test_sampler.calib_max_in(config, test_loader, verbose=True)
# Calibrate samplers
training_sampler.calibration(training_loader, verbose=True)
test_sampler.calibration(test_loader, verbose=True)
# debug_timing(training_dataset, training_loader)
# debug_timing(test_dataset, test_loader)
# debug_class_w(training_dataset, training_loader)
print('\nModel Preparation')
print('*****************')
# Define network model
t1 = time.time()
net = KPFCNN(config, training_dataset.label_values, training_dataset.ignored_labels)
debug = False
if debug:
print('\n*************************************\n')
print(net)
print('\n*************************************\n')
for param in net.parameters():
if param.requires_grad:
print(param.shape)
print('\n*************************************\n')
print("Model size %i" % sum(param.numel() for param in net.parameters() if param.requires_grad))
print('\n*************************************\n')
# Define a trainer class
trainer = ModelTrainer(net, config, chkp_path=chosen_chkp)
print('Done in {:.1f}s\n'.format(time.time() - t1))
print('\nStart training')
print('**************')
# Training
trainer.train(net, training_loader, test_loader, config)
print('Forcing exit now')
os.kill(os.getpid(), signal.SIGINT)

View file

@ -74,22 +74,15 @@ class S3DISConfig(Config):
'resnetb_strided',
'resnetb',
'resnetb',
'resnetb',
'resnetb_strided',
'resnetb',
'resnetb',
'resnetb',
'resnetb',
'resnetb',
'resnetb_strided',
'resnetb',
'resnetb',
'resnetb',
'resnetb',
'resnetb',
'resnetb_strided',
'resnetb',
'resnetb',
'resnetb_deformable',
'resnetb_deformable',
'resnetb_deformable_strided',
'resnetb_deformable',
'resnetb_deformable',
'nearest_upsample',
'unary',
'nearest_upsample',
@ -104,7 +97,7 @@ class S3DISConfig(Config):
###################
# Radius of the input sphere
in_radius = 1.0
in_radius = 1.5
# Number of kernel points
num_kernel_points = 15
@ -142,7 +135,7 @@ class S3DISConfig(Config):
# 'permissive' only constrains offsets inside the deform radius (NOT implemented yet)
# 'fitting' helps deformed kernels to adapt to the geometry by penalizing distance to input points
offsets_loss = 'fitting'
offsets_decay = 0.01
offsets_decay = 0.05
#####################
# Training parameters
@ -158,7 +151,7 @@ class S3DISConfig(Config):
grad_clip_norm = 100.0
# Number of batch
batch_num = 8
batch_num = 6
# Number of steps per epochs
epoch_steps = 500

View file

@ -100,21 +100,21 @@ class SemanticKittiConfig(Config):
###################
# Radius of the input sphere
in_radius = 10.0
in_radius = 6.0
val_radius = 51.0
n_frames = 1
max_in_points = 100000
max_val_points = 100000
max_val_points = 200000
# Number of batch
batch_num = 10
batch_num = 8
val_batch_num = 1
# Number of kernel points
num_kernel_points = 15
# Size of the first subsampling grid in meter
first_subsampling_dl = 0.08
first_subsampling_dl = 0.06
# Radius of convolution in "number grid cell". (2.5 is the standard value)
conv_radius = 2.5
@ -133,7 +133,7 @@ class SemanticKittiConfig(Config):
# Choice of input features
first_features_dim = 128
in_features_dim = 5
in_features_dim = 2
# Can the network learn modulations
modulated = False
@ -158,7 +158,7 @@ class SemanticKittiConfig(Config):
# Learning rate management
learning_rate = 1e-2
momentum = 0.98
lr_decays = {i: 0.1 ** (1 / 100) for i in range(1, max_epoch)}
lr_decays = {i: 0.1 ** (1 / 150) for i in range(1, max_epoch)}
grad_clip_norm = 100.0
# Number of steps per epochs
@ -190,8 +190,8 @@ class SemanticKittiConfig(Config):
# 2.377, 0.886, 3.863, 0.869, 1.209, 0.594, 3.780, 1.129, 5.505, 11.180]
# sqrt(Inverse of proportion * 100) capped (0.5 < X < 5)
class_w = [1.430, 5.000, 5.000, 4.226, 5.000, 5.000, 5.000, 5.000, 0.719, 2.377,
0.886, 3.863, 0.869, 1.209, 0.594, 3.780, 1.129, 5.000, 5.000]
# class_w = [1.430, 5.000, 5.000, 4.226, 5.000, 5.000, 5.000, 5.000, 0.719, 2.377,
# 0.886, 3.863, 0.869, 1.209, 0.594, 3.780, 1.129, 5.000, 5.000]
# Do we nee to save convergence
@ -212,7 +212,7 @@ if __name__ == '__main__':
############################
# Set which gpu is going to be used
GPU_ID = '3'
GPU_ID = '2'
# Set GPU visible device
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

View file

@ -181,7 +181,7 @@ class ModelTester:
inds = in_inds[i0:i0 + length]
c_i = cloud_inds[b_i]
if test_radius_ratio < 0.99:
if 0 < test_radius_ratio < 1:
mask = np.sum(points ** 2, axis=1) < (test_radius_ratio * config.in_radius) ** 2
inds = inds[mask]
probs = probs[mask]

View file

@ -259,7 +259,7 @@ class ModelTrainer:
# Save checkpoints occasionally
if (self.epoch + 1) % config.checkpoint_gap == 0:
checkpoint_path = join(checkpoint_directory, 'chkp_{:04d}.tar'.format(self.epoch))
checkpoint_path = join(checkpoint_directory, 'chkp_{:04d}.tar'.format(self.epoch + 1))
torch.save(save_dict, checkpoint_path)
# Validation

View file

@ -82,7 +82,7 @@ class ModelVisualizer:
net.load_state_dict(checkpoint['model_state_dict'])
self.epoch = checkpoint['epoch']
net.eval()
print("Model and training state restored.")
print("\nModel state restored from {:s}.".format(chkp_path))
return
@ -679,139 +679,63 @@ class ModelVisualizer:
except tf.errors.OutOfRangeError:
break
def show_effective_recep_field(self, model, dataset, relu_idx=0):
def show_effective_recep_field(self, net, loader, config, f_idx=0):
###################################################
# First add a modulation variable on input features
###################################################
##########################################
# First choose the visualized deformations
##########################################
# Tensorflow random seed
random_seed = 42
blocks = {}
# Create a modulated input feature op
with tf.variable_scope('input_modulations'):
initial = tf.constant(0., shape=[200000, 1])
input_modulations_var = tf.Variable(initial, name='alphas')
input_modulations = 2 * tf.sigmoid(input_modulations_var)
assert_op = tf.assert_less(tf.shape(model.inputs['features'])[0], tf.shape(input_modulations)[0])
with tf.control_dependencies([assert_op]):
modulated_input = model.inputs['features'] * input_modulations[:tf.shape(model.inputs['features'])[0]]
modulated_input = tf.identity(modulated_input, name='modulated_features')
named_blocks = [(m_name, m) for m_name, m in net.named_modules()
if len(m_name.split('.')) == 2 and m_name.split('.')[0].endswith('_blocks')]
chosen_block = named_blocks[-1][0]
print('*******************************************')
# Swap the op with the normal input features
for op in tf.get_default_graph().get_operations():
if 'input_modulations' in op.name:
continue
if model.inputs['features'].name in [in_t.name for in_t in op.inputs]:
input_list = []
for in_t in op.inputs:
if in_t.name == model.inputs['features'].name:
input_list += [modulated_input]
else:
input_list += [in_t]
print('swapping op ', op.name)
print('old inputs ', [in_t.name for in_t in op.inputs])
print('new inputs ', [in_t.name for in_t in input_list])
ge.swap_inputs(op, input_list)
print('*******************************************')
##########################
# Create the ERF optimizer
##########################
# This optimizer only computes gradients for the feature modulation variables. We set the ERF loss, which
# consists of modifying the features in one location a the wanted layer
with tf.variable_scope('ERF_loss'):
# List all relu ops
all_ops = [op for op in tf.get_default_graph().get_operations() if op.name.startswith('KernelPointNetwork')
and op.name.endswith('LeakyRelu')]
# Print the chosen one
features_tensor = all_ops[relu_idx].outputs[0]
# Get parameters
layer_idx = int(features_tensor.name.split('/')[1][6:])
if 'strided' in all_ops[relu_idx].name and not ('strided' in all_ops[relu_idx + 1].name):
layer_idx += 1
features_dim = int(features_tensor.shape[1])
radius = model.config.first_subsampling_dl * model.config.density_parameter * (2 ** layer_idx)
print('You chose to visualize the output of operation named: ' + all_ops[relu_idx].name)
print('It contains {:d} features.'.format(int(features_tensor.shape[1])))
print('\nPossible Relu indices:')
for i, t in enumerate(all_ops):
print(i, ': ', t.name)
print('\n****************************************************************************')
# Get the receptive field of a random point
N = tf.shape(features_tensor)[0]
#random_ind = tf.random_uniform([1], minval=0, maxval=N, dtype=np.int32, seed=random_seed)[0]
#chosen_i_holder = tf.placeholder(tf.int32, name='chosen_ind')
aimed_coordinates = tf.placeholder(tf.float32, shape=(1, 3), name='aimed_coordinates')
d2 = tf.reduce_sum(tf.square(model.inputs['points'][layer_idx] - aimed_coordinates), axis=1)
chosen_i_tf = tf.argmin(d2, output_type=tf.int32)
#test1 = tf.multiply(features_tensor, 2.0, name='test1')
#test2 = tf.multiply(features_tensor, 2.0, name='test2')
# Gradient scaling operation
@tf.custom_gradient
def scale_grad_layer(x):
def scaled_grad(dy):
p_op = tf.print(x.name,
tf.reduce_mean(tf.abs(x)),
tf.reduce_mean(tf.abs(dy)),
output_stream=sys.stdout)
with tf.control_dependencies([p_op]):
new_dy = 1.0 * dy
return new_dy
return tf.identity(x), scaled_grad
#test2 = scale_grad_layer(test2)
# Get the tensor of error for these features (one for the chosen point, zero for the rest)
chosen_f_tf = tf.placeholder(tf.int32, name='feature_ind')
ERF_error = tf.expand_dims(tf.cast(tf.equal(tf.range(N), chosen_i_tf), tf.float32), 1)
ERF_error *= tf.expand_dims(tf.cast(tf.equal(tf.range(features_dim), chosen_f_tf), tf.float32), 0)
# Get objective for the features (with a stop gradient so that we can get a gradient on the loss)
objective_features = features_tensor + ERF_error
objective_features = tf.stop_gradient(objective_features)
# Loss is the error but with the features that can be learned to correct it
ERF_loss = tf.reduce_sum(tf.square(objective_features - features_tensor))
for mi, (m_name, m) in enumerate(named_blocks):
with tf.variable_scope('ERF_optimizer'):
c1 = bcolors.OKBLUE
c2 = bcolors.BOLD
ce = bcolors.ENDC
print('{:}{:}{:s}{:}{:} {:s}'.format(c1, c2, m_name, ce, ce, m.__repr__()))
blocks[m_name] = m
# Create the gradient descent optimizer with a dummy learning rate
optimizer = tf.train.GradientDescentOptimizer(1.0)
if mi == f_idx:
chosen_block = m_name
# Get the gradients with respect to the modulation variable
ERF_var_grads = optimizer.compute_gradients(ERF_loss, var_list=[input_modulations_var])
print('\nChoose which block output you want to visualize by entering the block name in blue')
override_block = input('Block name: ')
if len(override_block) > 0:
chosen_block = override_block
print('{:}{:}{:s}{:}{:} {:s}'.format(c1, c2, chosen_block, ce, ce, blocks[chosen_block].__repr__()))
features_dim = blocks[chosen_block].out_dim
# Fix all the trainable variables in the network (is it needed in eval mode?)
print('\n*************************************\n')
for p_name, param in net.named_parameters():
if param.requires_grad:
param.requires_grad = False
print('\n*************************************\n')
# Create modulation variable that requires grad
input_modulations = torch.nn.Parameter(torch.zeros((200000, 1),
dtype=torch.float32),
requires_grad=True)
print('\n*************************************\n')
for p_name, param in net.named_parameters():
if param.requires_grad:
print(p_name, param.shape)
print('\n*************************************\n')
# Create ERF loss
# Create ERF optimizer
# Gradient of the modulations
ERF_train_op = optimizer.apply_gradients(ERF_var_grads)
################################
# Run model on all test examples
################################
# Init our modulation variable
self.sess.run(tf.variables_initializer([input_modulations_var]))
# Initialise iterator with test data
self.sess.run(dataset.test_init_op)
count = 0
global plots, p_scale, show_in_p, remove_h, aim_point
aim_point = np.zeros((1, 3), dtype=np.float32)
@ -841,10 +765,11 @@ class ModelVisualizer:
global points, in_points, grad_values, chosen_point, aim_point, in_colors
# Generate clouds until we effectively changed
batch = None
if only_points:
for i in range(50):
all_points = self.sess.run(model.inputs['points'])
if all_points[0].shape[0] != in_points.shape[0]:
# get a new batch (index does not matter given our input pipeline)
for batch in loader:
if batch.points[0].shape[0] != in_points.shape[0]:
break
sum_grads = 0
@ -853,11 +778,65 @@ class ModelVisualizer:
else:
num_tries = 10
#################################################
# Apply ERF optim to the same batch several times
#################################################
if 'cuda' in self.device.type:
batch.to(self.device)
for test_i in range(num_tries):
print('Updating ERF {:.0f}%'.format((test_i + 1) * 100 / num_tries))
rand_f_i = np.random.randint(features_dim)
# Reset input modulation variable
torch.nn.init.zeros_(input_modulations)
reset_op = input_modulations_var.assign(tf.zeros_like(input_modulations_var))
self.sess.run(reset_op)
# zero the parameter gradients
ERF_optimizer.zero_grad()
# Forward pass
outputs = net(batch, config)
loss = net.ERF_loss(outputs)
# Backward
loss.backward()
# Get result from hook here?
ERF_optimizer.step()
torch.cuda.synchronize(self.device)
# Forward pass
outputs = net(batch, config)
original_KP = deform_convs[deform_idx].kernel_points.cpu().detach().numpy()
stacked_deformed_KP = deform_convs[deform_idx].deformed_KP.cpu().detach().numpy()
count += batch.lengths[0].shape[0]
if 'cuda' in self.device.type:
torch.cuda.synchronize(self.device)
# Reset input modulation variable
reset_op = input_modulations_var.assign(tf.zeros_like(input_modulations_var))
self.sess.run(reset_op)
@ -1069,6 +1048,8 @@ class ModelVisualizer:
fig1.scene.interactor.add_observer('KeyPressEvent', keyboard_callback)
mlab.show()
return
def show_deformable_kernels(self, net, loader, config, deform_idx=0):
"""
Show some inference with deformable kernels

205
visualize_ERFs.py Normal file
View file

@ -0,0 +1,205 @@
#
#
# 0=================================0
# | Kernel Point Convolutions |
# 0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Callable script to start a training on ModelNet40 dataset
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Hugues THOMAS - 06/03/2020
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Imports and global variables
# \**********************************/
#
# Common libs
import signal
import os
import numpy as np
import sys
import torch
# Dataset
from datasets.ModelNet40 import *
from datasets.S3DIS import *
from torch.utils.data import DataLoader
from utils.config import Config
from utils.visualizer import ModelVisualizer
from models.architectures import KPCNN, KPFCNN
# ----------------------------------------------------------------------------------------------------------------------
#
# Main Call
# \***************/
#
def model_choice(chosen_log):
###########################
# Call the test initializer
###########################
# Automatically retrieve the last trained model
if chosen_log in ['last_ModelNet40', 'last_ShapeNetPart', 'last_S3DIS']:
# Dataset name
test_dataset = '_'.join(chosen_log.split('_')[1:])
# List all training logs
logs = np.sort([os.path.join('results', f) for f in os.listdir('results') if f.startswith('Log')])
# Find the last log of asked dataset
for log in logs[::-1]:
log_config = Config()
log_config.load(log)
if log_config.dataset.startswith(test_dataset):
chosen_log = log
break
if chosen_log in ['last_ModelNet40', 'last_ShapeNetPart', 'last_S3DIS']:
raise ValueError('No log of the dataset "' + test_dataset + '" found')
# Check if log exists
if not os.path.exists(chosen_log):
raise ValueError('The given log does not exists: ' + chosen_log)
return chosen_log
# ----------------------------------------------------------------------------------------------------------------------
#
# Main Call
# \***************/
#
if __name__ == '__main__':
###############################
# Choose the model to visualize
###############################
# Here you can choose which model you want to test with the variable test_model. Here are the possible values :
#
# > 'last_XXX': Automatically retrieve the last trained model on dataset XXX
# > '(old_)results/Log_YYYY-MM-DD_HH-MM-SS': Directly provide the path of a trained model
# chosen_log = 'results/Log_2020-04-04_10-04-42' # => ModelNet40
# chosen_log = 'results/Log_2020-04-04_10-04-42' # => S3DIS
chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected
# You can also choose the index of the snapshot to load (last by default)
chkp_idx = -1
# Eventually you can choose which feature is visualized (index of the deform operation in the network)
f_idx = -1
# Deal with 'last_XXX' choices
chosen_log = model_choice(chosen_log)
############################
# Initialize the environment
############################
# Set which gpu is going to be used
GPU_ID = '0'
# Set GPU visible device
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
###############
# Previous chkp
###############
# Find all checkpoints in the chosen training folder
chkp_path = os.path.join(chosen_log, 'checkpoints')
chkps = [f for f in os.listdir(chkp_path) if f[:4] == 'chkp']
# Find which snapshot to restore
if chkp_idx is None:
chosen_chkp = 'current_chkp.tar'
else:
chosen_chkp = np.sort(chkps)[chkp_idx]
chosen_chkp = os.path.join(chosen_log, 'checkpoints', chosen_chkp)
# Initialize configuration class
config = Config()
config.load(chosen_log)
##################################
# Change model parameters for test
##################################
# Change parameters for the test here. For example, you can stop augmenting the input data.
config.augment_noise = 0.0001
#config.augment_symmetries = False
config.batch_num = 1
config.in_radius = 2.0
config.input_threads = 0
##############
# Prepare Data
##############
print()
print('Data Preparation')
print('****************')
# Initiate dataset
if config.dataset.startswith('ModelNet40'):
test_dataset = ModelNet40Dataset(config, train=False)
test_sampler = ModelNet40Sampler(test_dataset)
collate_fn = ModelNet40Collate
elif config.dataset == 'S3DIS':
test_dataset = S3DISDataset(config, set='validation', use_potentials=True)
test_sampler = S3DISSampler(test_dataset)
collate_fn = S3DISCollate
else:
raise ValueError('Unsupported dataset : ' + config.dataset)
# Data loader
test_loader = DataLoader(test_dataset,
batch_size=1,
sampler=test_sampler,
collate_fn=collate_fn,
num_workers=config.input_threads,
pin_memory=True)
# Calibrate samplers
test_sampler.calibration(test_loader, verbose=True)
print('\nModel Preparation')
print('*****************')
# Define network model
t1 = time.time()
if config.dataset_task == 'classification':
net = KPCNN(config)
elif config.dataset_task in ['cloud_segmentation', 'slam_segmentation']:
net = KPFCNN(config, test_dataset.label_values, test_dataset.ignored_labels)
else:
raise ValueError('Unsupported dataset_task for deformation visu: ' + config.dataset_task)
# Define a visualizer class
visualizer = ModelVisualizer(net, config, chkp_path=chosen_chkp, on_gpu=False)
print('Done in {:.1f}s\n'.format(time.time() - t1))
print('\nStart visualization')
print('*******************')
# Training
visualizer.show_effective_recep_field(net, test_loader, config, f_idx)

View file

@ -30,11 +30,12 @@ import torch
# Dataset
from datasets.ModelNet40 import *
from datasets.S3DIS import *
from torch.utils.data import DataLoader
from utils.config import Config
from utils.visualizer import ModelVisualizer
from models.architectures import KPCNN
from models.architectures import KPCNN, KPFCNN
# ----------------------------------------------------------------------------------------------------------------------
@ -93,10 +94,12 @@ if __name__ == '__main__':
# > 'last_XXX': Automatically retrieve the last trained model on dataset XXX
# > '(old_)results/Log_YYYY-MM-DD_HH-MM-SS': Directly provide the path of a trained model
chosen_log = 'results/Log_2020-03-23_22-18-26' # => ModelNet40
# chosen_log = 'results/Log_2020-04-04_10-04-42' # => ModelNet40
# chosen_log = 'results/Log_2020-04-22_11-53-45' # => S3DIS
chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected
# You can also choose the index of the snapshot to load (last by default)
chkp_idx = None
chkp_idx = -1
# Eventually you can choose which feature is visualized (index of the deform operation in the network)
deform_idx = 0
@ -139,10 +142,11 @@ if __name__ == '__main__':
# Change parameters for the test here. For example, you can stop augmenting the input data.
#config.augment_noise = 0.0001
config.augment_noise = 0.0001
#config.augment_symmetries = False
#config.batch_num = 3
#config.in_radius = 4
config.batch_num = 1
config.in_radius = 2.0
config.input_threads = 0
##############
# Prepare Data
@ -152,22 +156,28 @@ if __name__ == '__main__':
print('Data Preparation')
print('****************')
# Initialize datasets
test_dataset = ModelNet40Dataset(config, train=False)
# Initiate dataset
if config.dataset.startswith('ModelNet40'):
test_dataset = ModelNet40Dataset(config, train=False)
test_sampler = ModelNet40Sampler(test_dataset)
collate_fn = ModelNet40Collate
elif config.dataset == 'S3DIS':
test_dataset = S3DISDataset(config, set='validation', use_potentials=True)
test_sampler = S3DISSampler(test_dataset)
collate_fn = S3DISCollate
else:
raise ValueError('Unsupported dataset : ' + config.dataset)
# Initialize samplers
test_sampler = ModelNet40Sampler(test_dataset)
# Initialize the dataloader
# Data loader
test_loader = DataLoader(test_dataset,
batch_size=1,
sampler=test_sampler,
collate_fn=ModelNet40Collate,
num_workers=0,
collate_fn=collate_fn,
num_workers=config.input_threads,
pin_memory=True)
# Calibrate samplers
test_sampler.calibration(test_loader)
test_sampler.calibration(test_loader, verbose=True)
print('\nModel Preparation')
print('*****************')
@ -176,6 +186,8 @@ if __name__ == '__main__':
t1 = time.time()
if config.dataset_task == 'classification':
net = KPCNN(config)
elif config.dataset_task in ['cloud_segmentation', 'slam_segmentation']:
net = KPFCNN(config, test_dataset.label_values, test_dataset.ignored_labels)
else:
raise ValueError('Unsupported dataset_task for deformation visu: ' + config.dataset_task)