Initial commit
This commit is contained in:
parent
5efecbbe20
commit
3d05a41368
|
@ -52,7 +52,7 @@ from utils.config import bcolors
|
|||
|
||||
|
||||
class S3DISDataset(PointCloudDataset):
|
||||
"""Class to handle Modelnet 40 dataset."""
|
||||
"""Class to handle S3DIS dataset."""
|
||||
|
||||
def __init__(self, config, set='training', use_potentials=True, load_data=True):
|
||||
"""
|
||||
|
@ -138,7 +138,23 @@ class S3DISDataset(PointCloudDataset):
|
|||
################
|
||||
|
||||
# List of training files
|
||||
self.train_files = [join(ply_path, f + '.ply') for f in self.cloud_names]
|
||||
self.files = []
|
||||
for i, f in enumerate(self.cloud_names):
|
||||
if self.set == 'training':
|
||||
if self.all_splits[i] != self.validation_split:
|
||||
self.files += [join(ply_path, f + '.ply')]
|
||||
elif self.set in ['validation', 'test', 'ERF']:
|
||||
if self.all_splits[i] == self.validation_split:
|
||||
self.files += [join(ply_path, f + '.ply')]
|
||||
else:
|
||||
raise ValueError('Unknown set for S3DIS data: ', self.set)
|
||||
|
||||
if self.set == 'training':
|
||||
self.cloud_names = [f for i, f in enumerate(self.cloud_names)
|
||||
if self.all_splits[i] != self.validation_split]
|
||||
elif self.set in ['validation', 'test', 'ERF']:
|
||||
self.cloud_names = [f for i, f in enumerate(self.cloud_names)
|
||||
if self.all_splits[i] == self.validation_split]
|
||||
|
||||
if 0 < self.config.first_subsampling_dl <= 0.01:
|
||||
raise ValueError('subsampling_parameter too low (should be over 1 cm')
|
||||
|
@ -149,7 +165,7 @@ class S3DISDataset(PointCloudDataset):
|
|||
self.input_labels = []
|
||||
self.pot_trees = []
|
||||
self.num_clouds = 0
|
||||
self.validation_proj = []
|
||||
self.test_proj = []
|
||||
self.validation_labels = []
|
||||
|
||||
# Start loading
|
||||
|
@ -624,21 +640,11 @@ class S3DISDataset(PointCloudDataset):
|
|||
# Load KDTrees
|
||||
##############
|
||||
|
||||
for i, file_path in enumerate(self.train_files):
|
||||
for i, file_path in enumerate(self.files):
|
||||
|
||||
# Restart timer
|
||||
t0 = time.time()
|
||||
|
||||
# Skip split that is not in current set
|
||||
if self.set == 'training':
|
||||
if self.all_splits[i] == self.validation_split:
|
||||
continue
|
||||
elif self.set in ['validation', 'test', 'ERF']:
|
||||
if self.all_splits[i] != self.validation_split:
|
||||
continue
|
||||
else:
|
||||
raise ValueError('Unknown set for S3DIS data: ', self.set)
|
||||
|
||||
# Get cloud name
|
||||
cloud_name = self.cloud_names[i]
|
||||
|
||||
|
@ -714,17 +720,7 @@ class S3DISDataset(PointCloudDataset):
|
|||
pot_dl = self.config.in_radius / 10
|
||||
cloud_ind = 0
|
||||
|
||||
for i, file_path in enumerate(self.train_files):
|
||||
|
||||
# Skip split that is not in current set
|
||||
if self.set == 'training':
|
||||
if self.all_splits[i] == self.validation_split:
|
||||
continue
|
||||
elif self.set in ['validation', 'test', 'ERF']:
|
||||
if self.all_splits[i] != self.validation_split:
|
||||
continue
|
||||
else:
|
||||
raise ValueError('Unknown set for S3DIS data: ', self.set)
|
||||
for i, file_path in enumerate(self.files):
|
||||
|
||||
# Get cloud name
|
||||
cloud_name = self.cloud_names[i]
|
||||
|
@ -769,12 +765,7 @@ class S3DISDataset(PointCloudDataset):
|
|||
print('\nPreparing reprojection indices for testing')
|
||||
|
||||
# Get validation/test reprojection indices
|
||||
i_cloud = 0
|
||||
for i, file_path in enumerate(self.train_files):
|
||||
|
||||
# Skip split that is not in current set
|
||||
if self.all_splits[i] != self.validation_split:
|
||||
continue
|
||||
for i, file_path in enumerate(self.files):
|
||||
|
||||
# Restart timer
|
||||
t0 = time.time()
|
||||
|
@ -795,7 +786,7 @@ class S3DISDataset(PointCloudDataset):
|
|||
labels = data['class']
|
||||
|
||||
# Compute projection inds
|
||||
idxs = self.input_trees[i_cloud].query(points, return_distance=False)
|
||||
idxs = self.input_trees[i].query(points, return_distance=False)
|
||||
#dists, idxs = self.input_trees[i_cloud].kneighbors(points)
|
||||
proj_inds = np.squeeze(idxs).astype(np.int32)
|
||||
|
||||
|
@ -803,9 +794,8 @@ class S3DISDataset(PointCloudDataset):
|
|||
with open(proj_file, 'wb') as f:
|
||||
pickle.dump([proj_inds, labels], f)
|
||||
|
||||
self.validation_proj += [proj_inds]
|
||||
self.test_proj += [proj_inds]
|
||||
self.validation_labels += [labels]
|
||||
i_cloud += 1
|
||||
print('{:s} done in {:.1f}s'.format(cloud_name, time.time() - t0))
|
||||
|
||||
print()
|
||||
|
@ -819,6 +809,9 @@ class S3DISDataset(PointCloudDataset):
|
|||
# Get original points
|
||||
data = read_ply(file_path)
|
||||
return np.vstack((data['x'], data['y'], data['z'])).T
|
||||
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Utility classes definition
|
||||
|
|
1407
datasets/SemanticKitti.py
Normal file
1407
datasets/SemanticKitti.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -345,7 +345,7 @@ class PointCloudDataset(Dataset):
|
|||
else:
|
||||
# No pooling in the end of this layer, no pooling indices required
|
||||
pool_i = np.zeros((0, 1), dtype=np.int32)
|
||||
pool_p = np.zeros((0, 3), dtype=np.float32)
|
||||
pool_p = np.zeros((0, 1), dtype=np.float32)
|
||||
pool_b = np.zeros((0,), dtype=np.int32)
|
||||
|
||||
# Reduce size of neighbors matrices by eliminating furthest point
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# Hugues THOMAS - 06/03/2020
|
||||
#
|
||||
|
||||
|
||||
from models.blocks import *
|
||||
import numpy as np
|
||||
|
||||
|
@ -201,7 +200,7 @@ class KPFCNN(nn.Module):
|
|||
Class defining KPFCNN
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, lbl_values, ign_lbls):
|
||||
super(KPFCNN, self).__init__()
|
||||
|
||||
############
|
||||
|
@ -214,6 +213,7 @@ class KPFCNN(nn.Module):
|
|||
in_dim = config.in_features_dim
|
||||
out_dim = config.first_features_dim
|
||||
self.K = config.num_kernel_points
|
||||
self.C = len(lbl_values) - len(ign_lbls)
|
||||
|
||||
#####################
|
||||
# List Encoder blocks
|
||||
|
@ -303,21 +303,21 @@ class KPFCNN(nn.Module):
|
|||
out_dim = out_dim // 2
|
||||
|
||||
self.head_mlp = UnaryBlock(out_dim, config.first_features_dim, False, 0)
|
||||
self.head_softmax = UnaryBlock(config.first_features_dim, config.num_classes, False, 0)
|
||||
self.head_softmax = UnaryBlock(config.first_features_dim, self.C, False, 0)
|
||||
|
||||
################
|
||||
# Network Losses
|
||||
################
|
||||
|
||||
# List of valid labels (those not ignored in loss)
|
||||
self.valid_labels = np.sort([c for c in lbl_values if c not in ign_lbls])
|
||||
|
||||
# Choose segmentation loss
|
||||
if config.segloss_balance == 'none':
|
||||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
elif config.segloss_balance == 'class':
|
||||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
elif config.segloss_balance == 'batch':
|
||||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
if len(config.class_w) > 0:
|
||||
class_w = torch.from_numpy(np.array(config.class_w, dtype=np.float32))
|
||||
self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1)
|
||||
else:
|
||||
raise ValueError('Unknown segloss_balance:', config.segloss_balance)
|
||||
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
|
||||
self.offset_loss = config.offsets_loss
|
||||
self.offset_decay = config.offsets_decay
|
||||
self.output_loss = 0
|
||||
|
@ -357,12 +357,18 @@ class KPFCNN(nn.Module):
|
|||
:return: loss
|
||||
"""
|
||||
|
||||
# Set all ignored labels to -1 and correct the other label to be in [0, C-1] range
|
||||
target = - torch.ones_like(labels)
|
||||
for i, c in enumerate(self.valid_labels):
|
||||
target[labels == c] = i
|
||||
|
||||
# Reshape to have a minibatch size of 1
|
||||
outputs = torch.transpose(outputs, 0, 1)
|
||||
outputs = outputs.unsqueeze(0)
|
||||
labels = labels.unsqueeze(0)
|
||||
target = target.unsqueeze(0)
|
||||
|
||||
# Cross entropy loss
|
||||
self.output_loss = self.criterion(outputs, labels)
|
||||
self.output_loss = self.criterion(outputs, target)
|
||||
|
||||
# Regularization of deformable offsets
|
||||
self.reg_loss = self.offset_regularizer()
|
||||
|
@ -370,8 +376,7 @@ class KPFCNN(nn.Module):
|
|||
# Combined loss
|
||||
return self.output_loss + self.reg_loss
|
||||
|
||||
@staticmethod
|
||||
def accuracy(outputs, labels):
|
||||
def accuracy(self, outputs, labels):
|
||||
"""
|
||||
Computes accuracy of the current batch
|
||||
:param outputs: logits predicted by the network
|
||||
|
@ -379,9 +384,14 @@ class KPFCNN(nn.Module):
|
|||
:return: accuracy value
|
||||
"""
|
||||
|
||||
# Set all ignored labels to -1 and correct the other label to be in [0, C-1] range
|
||||
target = - torch.ones_like(labels)
|
||||
for i, c in enumerate(self.valid_labels):
|
||||
target[labels == c] = i
|
||||
|
||||
predicted = torch.argmax(outputs.data, dim=1)
|
||||
total = labels.size(0)
|
||||
correct = (predicted == labels).sum().item()
|
||||
total = target.size(0)
|
||||
correct = (predicted == target).sum().item()
|
||||
|
||||
return correct / total
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ from utils.ply import read_ply
|
|||
# Datasets
|
||||
from datasets.ModelNet40 import ModelNet40Dataset
|
||||
from datasets.S3DIS import S3DISDataset
|
||||
from datasets.SemanticKitti import SemanticKittiDataset
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
|
@ -239,7 +240,7 @@ def load_multi_snap_clouds(path, dataset, file_i, only_last=False):
|
|||
else:
|
||||
for f in listdir(cloud_folder):
|
||||
if f.endswith('.ply') and not f.endswith('sub.ply'):
|
||||
if np.any([cloud_path.endswith(f) for cloud_path in dataset.train_files]):
|
||||
if np.any([cloud_path.endswith(f) for cloud_path in dataset.files]):
|
||||
data = read_ply(join(cloud_folder, f))
|
||||
labels = data['class']
|
||||
preds = data['preds']
|
||||
|
@ -971,20 +972,21 @@ def compare_convergences_SLAM(dataset, list_of_paths, list_of_names=None):
|
|||
class_list = [dataset.label_to_names[label] for label in dataset.label_values
|
||||
if label not in dataset.ignored_labels]
|
||||
|
||||
s = '{:^10}|'.format('mean')
|
||||
s = '{:^6}|'.format('mean')
|
||||
for c in class_list:
|
||||
s += '{:^10}'.format(c)
|
||||
s += '{:^6}'.format(c[:4])
|
||||
print(s)
|
||||
print(10*'-' + '|' + 10*config.num_classes*'-')
|
||||
print(6*'-' + '|' + 6*config.num_classes*'-')
|
||||
for path in list_of_paths:
|
||||
|
||||
# Get validation IoUs
|
||||
nc_model = dataset.num_classes - len(dataset.ignored_labels)
|
||||
file = join(path, 'val_IoUs.txt')
|
||||
val_IoUs = load_single_IoU(file, config.num_classes)
|
||||
val_IoUs = load_single_IoU(file, nc_model)
|
||||
|
||||
# Get Subpart IoUs
|
||||
file = join(path, 'subpart_IoUs.txt')
|
||||
subpart_IoUs = load_single_IoU(file, config.num_classes)
|
||||
subpart_IoUs = load_single_IoU(file, nc_model)
|
||||
|
||||
# Get mean IoU
|
||||
val_class_IoUs, val_mIoUs = IoU_class_metrics(val_IoUs, smooth_n)
|
||||
|
@ -997,22 +999,21 @@ def compare_convergences_SLAM(dataset, list_of_paths, list_of_names=None):
|
|||
all_subpart_mIoUs += [subpart_mIoUs]
|
||||
all_subpart_class_IoUs += [subpart_class_IoUs]
|
||||
|
||||
s = '{:^10.1f}|'.format(100*subpart_mIoUs[-1])
|
||||
s = '{:^6.1f}|'.format(100*subpart_mIoUs[-1])
|
||||
for IoU in subpart_class_IoUs[-1]:
|
||||
s += '{:^10.1f}'.format(100*IoU)
|
||||
s += '{:^6.1f}'.format(100*IoU)
|
||||
print(s)
|
||||
|
||||
|
||||
print(10*'-' + '|' + 10*config.num_classes*'-')
|
||||
print(6*'-' + '|' + 6*config.num_classes*'-')
|
||||
for snap_IoUs in all_val_class_IoUs:
|
||||
if len(snap_IoUs) > 0:
|
||||
s = '{:^10.1f}|'.format(100*np.mean(snap_IoUs[-1]))
|
||||
s = '{:^6.1f}|'.format(100*np.mean(snap_IoUs[-1]))
|
||||
for IoU in snap_IoUs[-1]:
|
||||
s += '{:^10.1f}'.format(100*IoU)
|
||||
s += '{:^6.1f}'.format(100*IoU)
|
||||
else:
|
||||
s = '{:^10s}'.format('-')
|
||||
s = '{:^6s}'.format('-')
|
||||
for _ in range(config.num_classes):
|
||||
s += '{:^10s}'.format('-')
|
||||
s += '{:^6s}'.format('-')
|
||||
print(s)
|
||||
|
||||
# Plots
|
||||
|
@ -1038,7 +1039,7 @@ def compare_convergences_SLAM(dataset, list_of_paths, list_of_names=None):
|
|||
#ax.set_yticks(np.arange(0.8, 1.02, 0.02))
|
||||
|
||||
displayed_classes = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
displayed_classes = []
|
||||
#displayed_classes = []
|
||||
for c_i, c_name in enumerate(class_list):
|
||||
if c_i in displayed_classes:
|
||||
|
||||
|
@ -1410,14 +1411,14 @@ def S3DIS_first(old_result_limit):
|
|||
return logs, logs_names
|
||||
|
||||
|
||||
def S3DIS_(old_result_limit):
|
||||
def S3DIS_go(old_result_limit):
|
||||
"""
|
||||
Test S3DIS.
|
||||
"""
|
||||
|
||||
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
|
||||
start = 'Log_2020-04-03_11-12-07'
|
||||
end = 'Log_2020-04-25_19-30-17'
|
||||
end = 'Log_2020-04-07_15-30-17'
|
||||
|
||||
if end < old_result_limit:
|
||||
res_path = 'old_results'
|
||||
|
@ -1430,6 +1431,11 @@ def S3DIS_(old_result_limit):
|
|||
# Give names to the logs (for legends)
|
||||
logs_names = ['R=2.0_r=0.04_Din=128_potential',
|
||||
'R=2.0_r=0.04_Din=64_potential',
|
||||
'R=1.8_r=0.03',
|
||||
'R=1.8_r=0.03_deeper',
|
||||
'R=1.8_r=0.03_deform',
|
||||
'R=2.0_r=0.03_megadeep',
|
||||
'R=2.5_r=0.03_megadeep',
|
||||
'test']
|
||||
|
||||
logs_names = np.array(logs_names[:len(logs)])
|
||||
|
@ -1437,17 +1443,52 @@ def S3DIS_(old_result_limit):
|
|||
return logs, logs_names
|
||||
|
||||
|
||||
def SemanticKittiFirst(old_result_limit):
|
||||
"""
|
||||
Test SematicKitti. First exps
|
||||
"""
|
||||
|
||||
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
|
||||
start = 'Log_2020-04-07_15-30-17'
|
||||
end = 'Log_2020-05-07_15-30-17'
|
||||
|
||||
if end < old_result_limit:
|
||||
res_path = 'old_results'
|
||||
else:
|
||||
res_path = 'results'
|
||||
|
||||
logs = np.sort([join(res_path, l) for l in listdir(res_path) if start <= l <= end])
|
||||
logs = logs.astype('<U50')
|
||||
|
||||
# Give names to the logs (for legends)
|
||||
logs_names = ['R=5.0_dl=0.04',
|
||||
'R=5.0_dl=0.08',
|
||||
'R=10.0_dl=0.08',
|
||||
'test']
|
||||
|
||||
logs_names = np.array(logs_names[:len(logs)])
|
||||
|
||||
return logs, logs_names
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
######################################################
|
||||
# Choose a list of log to plot together for comparison
|
||||
######################################################
|
||||
|
||||
# TODO: test deformable on S3DIS to see of fitting loss works
|
||||
# TODO: GOOOO SemanticKitti for wednesday at least have a timing to give to them
|
||||
# TODO: try class weights on S3DIS (very low weight for beam)
|
||||
|
||||
# Old result limit
|
||||
old_res_lim = 'Log_2020-03-25_19-30-17'
|
||||
|
||||
# My logs: choose the logs to show
|
||||
logs, logs_names = S3DIS_(old_res_lim)
|
||||
logs, logs_names = SemanticKittiFirst(old_res_lim)
|
||||
#os.environ['QT_DEBUG_PLUGINS'] = '1'
|
||||
|
||||
######################################################
|
||||
|
@ -1482,6 +1523,10 @@ if __name__ == '__main__':
|
|||
if config.dataset.startswith('S3DIS'):
|
||||
dataset = S3DISDataset(config, load_data=False)
|
||||
compare_convergences_segment(dataset, logs, logs_names)
|
||||
elif config.dataset_task == 'slam_segmentation':
|
||||
if config.dataset.startswith('SemanticKitti'):
|
||||
dataset = SemanticKittiDataset(config)
|
||||
compare_convergences_SLAM(dataset, logs, logs_names)
|
||||
else:
|
||||
raise ValueError('Unsupported dataset : ' + plot_dataset)
|
||||
|
||||
|
|
227
test_models.py
Normal file
227
test_models.py
Normal file
|
@ -0,0 +1,227 @@
|
|||
#
|
||||
#
|
||||
# 0=================================0
|
||||
# | Kernel Point Convolutions |
|
||||
# 0=================================0
|
||||
#
|
||||
#
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Callable script to start a training on ModelNet40 dataset
|
||||
#
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Hugues THOMAS - 06/03/2020
|
||||
#
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Imports and global variables
|
||||
# \**********************************/
|
||||
#
|
||||
|
||||
# Common libs
|
||||
import signal
|
||||
import os
|
||||
import numpy as np
|
||||
import sys
|
||||
import torch
|
||||
|
||||
# Dataset
|
||||
from datasets.ModelNet40 import *
|
||||
from datasets.S3DIS import *
|
||||
from datasets.SemanticKitti import *
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from utils.config import Config
|
||||
from utils.tester import ModelTester
|
||||
from models.architectures import KPCNN, KPFCNN
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Main Call
|
||||
# \***************/
|
||||
#
|
||||
|
||||
def model_choice(chosen_log):
|
||||
|
||||
###########################
|
||||
# Call the test initializer
|
||||
###########################
|
||||
|
||||
# Automatically retrieve the last trained model
|
||||
if chosen_log in ['last_ModelNet40', 'last_ShapeNetPart', 'last_S3DIS']:
|
||||
|
||||
# Dataset name
|
||||
test_dataset = '_'.join(chosen_log.split('_')[1:])
|
||||
|
||||
# List all training logs
|
||||
logs = np.sort([os.path.join('results', f) for f in os.listdir('results') if f.startswith('Log')])
|
||||
|
||||
# Find the last log of asked dataset
|
||||
for log in logs[::-1]:
|
||||
log_config = Config()
|
||||
log_config.load(log)
|
||||
if log_config.dataset.startswith(test_dataset):
|
||||
chosen_log = log
|
||||
break
|
||||
|
||||
if chosen_log in ['last_ModelNet40', 'last_ShapeNetPart', 'last_S3DIS']:
|
||||
raise ValueError('No log of the dataset "' + test_dataset + '" found')
|
||||
|
||||
# Check if log exists
|
||||
if not os.path.exists(chosen_log):
|
||||
raise ValueError('The given log does not exists: ' + chosen_log)
|
||||
|
||||
return chosen_log
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Main Call
|
||||
# \***************/
|
||||
#
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
###############################
|
||||
# Choose the model to visualize
|
||||
###############################
|
||||
|
||||
# Here you can choose which model you want to test with the variable test_model. Here are the possible values :
|
||||
#
|
||||
# > 'last_XXX': Automatically retrieve the last trained model on dataset XXX
|
||||
# > '(old_)results/Log_YYYY-MM-DD_HH-MM-SS': Directly provide the path of a trained model
|
||||
|
||||
chosen_log = 'results/Log_2020-04-07_18-22-18' # => ModelNet40
|
||||
|
||||
# You can also choose the index of the snapshot to load (last by default)
|
||||
chkp_idx = None
|
||||
|
||||
# Choose to test on validation or test split
|
||||
on_val = True
|
||||
|
||||
# Deal with 'last_XXXXXX' choices
|
||||
chosen_log = model_choice(chosen_log)
|
||||
|
||||
############################
|
||||
# Initialize the environment
|
||||
############################
|
||||
|
||||
# Set which gpu is going to be used
|
||||
GPU_ID = '3'
|
||||
|
||||
# Set GPU visible device
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
|
||||
|
||||
###############
|
||||
# Previous chkp
|
||||
###############
|
||||
|
||||
# Find all checkpoints in the chosen training folder
|
||||
chkp_path = os.path.join(chosen_log, 'checkpoints')
|
||||
chkps = [f for f in os.listdir(chkp_path) if f[:4] == 'chkp']
|
||||
|
||||
# Find which snapshot to restore
|
||||
if chkp_idx is None:
|
||||
chosen_chkp = 'current_chkp.tar'
|
||||
else:
|
||||
chosen_chkp = np.sort(chkps)[chkp_idx]
|
||||
chosen_chkp = os.path.join(chosen_log, 'checkpoints', chosen_chkp)
|
||||
|
||||
# Initialize configuration class
|
||||
config = Config()
|
||||
config.load(chosen_log)
|
||||
|
||||
##################################
|
||||
# Change model parameters for test
|
||||
##################################
|
||||
|
||||
# Change parameters for the test here. For example, you can stop augmenting the input data.
|
||||
|
||||
#config.augment_noise = 0.0001
|
||||
#config.augment_symmetries = False
|
||||
#config.batch_num = 3
|
||||
#config.in_radius = 4
|
||||
config.validation_size = 200
|
||||
config.input_threads = 0
|
||||
|
||||
##############
|
||||
# Prepare Data
|
||||
##############
|
||||
|
||||
print()
|
||||
print('Data Preparation')
|
||||
print('****************')
|
||||
|
||||
if on_val:
|
||||
set = 'validation'
|
||||
else:
|
||||
set = 'test'
|
||||
|
||||
# Initiate dataset
|
||||
if config.dataset.startswith('ModelNet40'):
|
||||
test_dataset = ModelNet40Dataset(config, train=False)
|
||||
test_sampler = ModelNet40Sampler(test_dataset)
|
||||
collate_fn = ModelNet40Collate
|
||||
elif config.dataset == 'S3DIS':
|
||||
test_dataset = S3DISDataset(config, set='validation', use_potentials=True)
|
||||
test_sampler = S3DISSampler(test_dataset)
|
||||
collate_fn = S3DISCollate
|
||||
elif config.dataset == 'SemanticKitti':
|
||||
test_dataset = SemanticKittiDataset(config, set=set, balance_classes=False)
|
||||
test_sampler = SemanticKittiSampler(test_dataset)
|
||||
collate_fn = SemanticKittiCollate
|
||||
else:
|
||||
raise ValueError('Unsupported dataset : ' + config.dataset)
|
||||
|
||||
# Data loader
|
||||
test_loader = DataLoader(test_dataset,
|
||||
batch_size=1,
|
||||
sampler=test_sampler,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=config.input_threads,
|
||||
pin_memory=True)
|
||||
|
||||
# Calibrate samplers
|
||||
test_sampler.calibration(test_loader, verbose=True)
|
||||
|
||||
print('\nModel Preparation')
|
||||
print('*****************')
|
||||
|
||||
# Define network model
|
||||
t1 = time.time()
|
||||
if config.dataset_task == 'classification':
|
||||
net = KPCNN(config)
|
||||
elif config.dataset_task in ['cloud_segmentation', 'slam_segmentation']:
|
||||
net = KPFCNN(config, test_dataset.label_values, test_dataset.ignored_labels)
|
||||
else:
|
||||
raise ValueError('Unsupported dataset_task for testing: ' + config.dataset_task)
|
||||
|
||||
# Define a visualizer class
|
||||
tester = ModelTester(net, chkp_path=chosen_chkp)
|
||||
print('Done in {:.1f}s\n'.format(time.time() - t1))
|
||||
|
||||
print('\nStart test')
|
||||
print('**********\n')
|
||||
|
||||
# Training
|
||||
if config.dataset_task == 'classification':
|
||||
a = 1/0
|
||||
elif config.dataset_task == 'cloud_segmentation':
|
||||
tester.cloud_segmentation_test(net, test_loader, config)
|
||||
elif config.dataset_task == 'slam_segmentation':
|
||||
tester.slam_segmentation_test(net, test_loader, config)
|
||||
else:
|
||||
raise ValueError('Unsupported dataset_task for testing: ' + config.dataset_task)
|
||||
|
||||
|
||||
# TODO: For test and also for training. When changing epoch do not restart the worker initiation. Keep workers
|
||||
# active with a while loop instead of using for loops.
|
||||
# For training and validation, keep two sets of worker active in parallel? is it possible?
|
||||
|
||||
# TODO: We have to verify if training on smaller spheres and testing on whole frame changes the score because
|
||||
# batchnorm may not have the same result as distribution of points will be different.
|
||||
|
|
@ -73,12 +73,23 @@ class S3DISConfig(Config):
|
|||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'nearest_upsample',
|
||||
'unary',
|
||||
'nearest_upsample',
|
||||
|
@ -93,13 +104,13 @@ class S3DISConfig(Config):
|
|||
###################
|
||||
|
||||
# Radius of the input sphere
|
||||
in_radius = 2.0
|
||||
in_radius = 2.5
|
||||
|
||||
# Number of kernel points
|
||||
num_kernel_points = 15
|
||||
|
||||
# Size of the first subsampling grid in meter
|
||||
first_subsampling_dl = 0.04
|
||||
first_subsampling_dl = 0.03
|
||||
|
||||
# Radius of convolution in "number grid cell". (2.5 is the standard value)
|
||||
conv_radius = 2.5
|
||||
|
@ -108,7 +119,7 @@ class S3DISConfig(Config):
|
|||
deform_radius = 6.0
|
||||
|
||||
# Radius of the area of influence of each kernel point in "number grid cell". (1.0 is the standard value)
|
||||
KP_extent = 1.2
|
||||
KP_extent = 1.5
|
||||
|
||||
# Behavior of convolutions in ('constant', 'linear', 'gaussian')
|
||||
KP_influence = 'linear'
|
||||
|
@ -117,7 +128,7 @@ class S3DISConfig(Config):
|
|||
aggregation_mode = 'sum'
|
||||
|
||||
# Choice of input features
|
||||
first_features_dim = 64
|
||||
first_features_dim = 128
|
||||
in_features_dim = 5
|
||||
|
||||
# Can the network learn modulations
|
||||
|
@ -143,17 +154,17 @@ class S3DISConfig(Config):
|
|||
# Learning rate management
|
||||
learning_rate = 1e-2
|
||||
momentum = 0.98
|
||||
lr_decays = {i: 0.1**(1/100) for i in range(1, max_epoch)}
|
||||
lr_decays = {i: 0.1 ** (1 / 150) for i in range(1, max_epoch)}
|
||||
grad_clip_norm = 100.0
|
||||
|
||||
# Number of batch
|
||||
batch_num = 10
|
||||
batch_num = 4
|
||||
|
||||
# Number of steps per epochs
|
||||
epoch_steps = 500
|
||||
|
||||
# Number of validation examples per epoch
|
||||
validation_size = 30
|
||||
validation_size = 50
|
||||
|
||||
# Number of epoch between each checkpoint
|
||||
checkpoint_gap = 50
|
||||
|
@ -191,7 +202,7 @@ if __name__ == '__main__':
|
|||
############################
|
||||
|
||||
# Set which gpu is going to be used
|
||||
GPU_ID = '2'
|
||||
GPU_ID = '3'
|
||||
|
||||
# Set GPU visible device
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
|
||||
|
@ -275,7 +286,7 @@ if __name__ == '__main__':
|
|||
|
||||
# Define network model
|
||||
t1 = time.time()
|
||||
net = KPFCNN(config)
|
||||
net = KPFCNN(config, training_dataset.label_values, training_dataset.ignored_labels)
|
||||
|
||||
debug = False
|
||||
if debug:
|
||||
|
@ -297,14 +308,7 @@ if __name__ == '__main__':
|
|||
print('**************')
|
||||
|
||||
# Training
|
||||
try:
|
||||
trainer.train(net, training_loader, test_loader, config)
|
||||
except:
|
||||
print('Caught an error')
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
print('Forcing exit now')
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
|
||||
|
||||
|
|
321
train_SemanticKitti.py
Normal file
321
train_SemanticKitti.py
Normal file
|
@ -0,0 +1,321 @@
|
|||
#
|
||||
#
|
||||
# 0=================================0
|
||||
# | Kernel Point Convolutions |
|
||||
# 0=================================0
|
||||
#
|
||||
#
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Callable script to start a training on SemanticKitti dataset
|
||||
#
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Hugues THOMAS - 06/03/2020
|
||||
#
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Imports and global variables
|
||||
# \**********************************/
|
||||
#
|
||||
|
||||
# Common libs
|
||||
import signal
|
||||
import os
|
||||
import numpy as np
|
||||
import sys
|
||||
import torch
|
||||
|
||||
# Dataset
|
||||
from datasets.SemanticKitti import *
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from utils.config import Config
|
||||
from utils.trainer import ModelTrainer
|
||||
from models.architectures import KPFCNN
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Config Class
|
||||
# \******************/
|
||||
#
|
||||
|
||||
class SemanticKittiConfig(Config):
|
||||
"""
|
||||
Override the parameters you want to modify for this dataset
|
||||
"""
|
||||
|
||||
####################
|
||||
# Dataset parameters
|
||||
####################
|
||||
|
||||
# Dataset name
|
||||
dataset = 'SemanticKitti'
|
||||
|
||||
# Number of classes in the dataset (This value is overwritten by dataset class when Initializating dataset).
|
||||
num_classes = None
|
||||
|
||||
# Type of task performed on this dataset (also overwritten)
|
||||
dataset_task = ''
|
||||
|
||||
# Number of CPU threads for the input pipeline
|
||||
input_threads = 20
|
||||
|
||||
#########################
|
||||
# Architecture definition
|
||||
#########################
|
||||
|
||||
# Define layers
|
||||
architecture = ['simple',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'nearest_upsample',
|
||||
'unary',
|
||||
'nearest_upsample',
|
||||
'unary',
|
||||
'nearest_upsample',
|
||||
'unary',
|
||||
'nearest_upsample',
|
||||
'unary']
|
||||
|
||||
###################
|
||||
# KPConv parameters
|
||||
###################
|
||||
|
||||
# Radius of the input sphere
|
||||
in_radius = 10.0
|
||||
val_radius = 51.0
|
||||
n_frames = 1
|
||||
max_in_points = 10000
|
||||
max_val_points = 50000
|
||||
|
||||
# Number of batch
|
||||
batch_num = 6
|
||||
val_batch_num = 1
|
||||
|
||||
# Number of kernel points
|
||||
num_kernel_points = 15
|
||||
|
||||
# Size of the first subsampling grid in meter
|
||||
first_subsampling_dl = 0.08
|
||||
|
||||
# Radius of convolution in "number grid cell". (2.5 is the standard value)
|
||||
conv_radius = 2.5
|
||||
|
||||
# Radius of deformable convolution in "number grid cell". Larger so that deformed kernel can spread out
|
||||
deform_radius = 6.0
|
||||
|
||||
# Radius of the area of influence of each kernel point in "number grid cell". (1.0 is the standard value)
|
||||
KP_extent = 1.5
|
||||
|
||||
# Behavior of convolutions in ('constant', 'linear', 'gaussian')
|
||||
KP_influence = 'linear'
|
||||
|
||||
# Aggregation function of KPConv in ('closest', 'sum')
|
||||
aggregation_mode = 'sum'
|
||||
|
||||
# Choice of input features
|
||||
first_features_dim = 128
|
||||
in_features_dim = 5
|
||||
|
||||
# Can the network learn modulations
|
||||
modulated = False
|
||||
|
||||
# Batch normalization parameters
|
||||
use_batch_norm = True
|
||||
batch_norm_momentum = 0.02
|
||||
|
||||
# Offset loss
|
||||
# 'permissive' only constrains offsets inside the deform radius (NOT implemented yet)
|
||||
# 'fitting' helps deformed kernels to adapt to the geometry by penalizing distance to input points
|
||||
offsets_loss = 'fitting'
|
||||
offsets_decay = 0.01
|
||||
|
||||
#####################
|
||||
# Training parameters
|
||||
#####################
|
||||
|
||||
# Maximal number of epochs
|
||||
max_epoch = 500
|
||||
|
||||
# Learning rate management
|
||||
learning_rate = 1e-2
|
||||
momentum = 0.98
|
||||
lr_decays = {i: 0.1 ** (1 / 100) for i in range(1, max_epoch)}
|
||||
grad_clip_norm = 100.0
|
||||
|
||||
# Number of steps per epochs
|
||||
epoch_steps = 500
|
||||
|
||||
# Number of validation examples per epoch
|
||||
validation_size = 50
|
||||
|
||||
# Number of epoch between each checkpoint
|
||||
checkpoint_gap = 50
|
||||
|
||||
# Augmentations
|
||||
augment_scale_anisotropic = True
|
||||
augment_symmetries = [True, False, False]
|
||||
augment_rotation = 'vertical'
|
||||
augment_scale_min = 0.8
|
||||
augment_scale_max = 1.2
|
||||
augment_noise = 0.001
|
||||
augment_color = 0.8
|
||||
|
||||
# Choose weights for class (used in segmentation loss). Empty list for no weights
|
||||
class_w = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
||||
|
||||
# Do we nee to save convergence
|
||||
saving = True
|
||||
saving_path = None
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Main Call
|
||||
# \***************/
|
||||
#
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
############################
|
||||
# Initialize the environment
|
||||
############################
|
||||
|
||||
# Set which gpu is going to be used
|
||||
GPU_ID = '2'
|
||||
|
||||
# Set GPU visible device
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
|
||||
|
||||
###############
|
||||
# Previous chkp
|
||||
###############
|
||||
|
||||
# Choose here if you want to start training from a previous snapshot (None for new training)
|
||||
# previous_training_path = 'Log_2020-03-19_19-53-27'
|
||||
previous_training_path = ''
|
||||
|
||||
# Choose index of checkpoint to start from. If None, uses the latest chkp
|
||||
chkp_idx = None
|
||||
if previous_training_path:
|
||||
|
||||
# Find all snapshot in the chosen training folder
|
||||
chkp_path = os.path.join('results', previous_training_path, 'checkpoints')
|
||||
chkps = [f for f in os.listdir(chkp_path) if f[:4] == 'chkp']
|
||||
|
||||
# Find which snapshot to restore
|
||||
if chkp_idx is None:
|
||||
chosen_chkp = 'current_chkp.tar'
|
||||
else:
|
||||
chosen_chkp = np.sort(chkps)[chkp_idx]
|
||||
chosen_chkp = os.path.join('results', previous_training_path, 'checkpoints', chosen_chkp)
|
||||
|
||||
else:
|
||||
chosen_chkp = None
|
||||
|
||||
##############
|
||||
# Prepare Data
|
||||
##############
|
||||
|
||||
print()
|
||||
print('Data Preparation')
|
||||
print('****************')
|
||||
|
||||
# Initialize configuration class
|
||||
config = SemanticKittiConfig()
|
||||
if previous_training_path:
|
||||
config.load(os.path.join('results', previous_training_path))
|
||||
config.saving_path = None
|
||||
|
||||
# Get path from argument if given
|
||||
if len(sys.argv) > 1:
|
||||
config.saving_path = sys.argv[1]
|
||||
|
||||
# Initialize datasets
|
||||
training_dataset = SemanticKittiDataset(config, set='training',
|
||||
balance_classes=True)
|
||||
test_dataset = SemanticKittiDataset(config, set='validation',
|
||||
balance_classes=False)
|
||||
|
||||
# Initialize samplers
|
||||
training_sampler = SemanticKittiSampler(training_dataset)
|
||||
test_sampler = SemanticKittiSampler(test_dataset)
|
||||
|
||||
# Initialize the dataloader
|
||||
training_loader = DataLoader(training_dataset,
|
||||
batch_size=1,
|
||||
sampler=training_sampler,
|
||||
collate_fn=SemanticKittiCollate,
|
||||
num_workers=config.input_threads,
|
||||
pin_memory=True)
|
||||
test_loader = DataLoader(test_dataset,
|
||||
batch_size=1,
|
||||
sampler=test_sampler,
|
||||
collate_fn=SemanticKittiCollate,
|
||||
num_workers=config.input_threads,
|
||||
pin_memory=True)
|
||||
|
||||
# Calibrate max_in_point value
|
||||
training_sampler.calib_max_in(config, training_loader, verbose=False)
|
||||
test_sampler.calib_max_in(config, test_loader, verbose=False)
|
||||
|
||||
# Calibrate samplers
|
||||
training_sampler.calibration(training_loader, verbose=True)
|
||||
test_sampler.calibration(test_loader, verbose=True)
|
||||
|
||||
# debug_timing(training_dataset, training_loader)
|
||||
# debug_timing(test_dataset, test_loader)
|
||||
debug_class_w(training_dataset, training_loader)
|
||||
|
||||
print('\nModel Preparation')
|
||||
print('*****************')
|
||||
|
||||
# Define network model
|
||||
t1 = time.time()
|
||||
net = KPFCNN(config, training_dataset.label_values, training_dataset.ignored_labels)
|
||||
|
||||
debug = False
|
||||
if debug:
|
||||
print('\n*************************************\n')
|
||||
print(net)
|
||||
print('\n*************************************\n')
|
||||
for param in net.parameters():
|
||||
if param.requires_grad:
|
||||
print(param.shape)
|
||||
print('\n*************************************\n')
|
||||
print("Model size %i" % sum(param.numel() for param in net.parameters() if param.requires_grad))
|
||||
print('\n*************************************\n')
|
||||
|
||||
# Define a trainer class
|
||||
trainer = ModelTrainer(net, config, chkp_path=chosen_chkp)
|
||||
print('Done in {:.1f}s\n'.format(time.time() - t1))
|
||||
|
||||
print('\nStart training')
|
||||
print('**************')
|
||||
|
||||
# Training
|
||||
trainer.train(net, training_loader, test_loader, config)
|
||||
|
||||
print('Forcing exit now')
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
# TODO: Create a function debug_class_weights that shows class distribution in input sphere. Use that as
|
||||
# indication for the class weights during training
|
|
@ -117,8 +117,10 @@ class Config:
|
|||
# For SLAM datasets like SemanticKitti number of frames used (minimum one)
|
||||
n_frames = 1
|
||||
|
||||
# For SLAM datasets like SemanticKitti max number of point in input cloud
|
||||
# For SLAM datasets like SemanticKitti max number of point in input cloud + validation
|
||||
max_in_points = 0
|
||||
val_radius = 51.0
|
||||
max_val_points = 50000
|
||||
|
||||
#####################
|
||||
# Training parameters
|
||||
|
@ -151,18 +153,19 @@ class Config:
|
|||
# Regularization loss importance
|
||||
weight_decay = 1e-3
|
||||
|
||||
# The way we balance segmentation loss
|
||||
# > 'none': Each point in the whole batch has the same contribution.
|
||||
# > 'class': Each class has the same contribution (points are weighted according to class balance)
|
||||
# > 'batch': Each cloud in the batch has the same contribution (points are weighted according cloud sizes)
|
||||
# The way we balance segmentation loss DEPRECATED
|
||||
segloss_balance = 'none'
|
||||
|
||||
# Choose weights for class (used in segmentation loss). Empty list for no weights
|
||||
class_w = []
|
||||
|
||||
# Offset regularization loss
|
||||
offsets_loss = 'permissive'
|
||||
offsets_decay = 1e-2
|
||||
|
||||
# Number of batch
|
||||
batch_num = 10
|
||||
val_batch_num = 10
|
||||
|
||||
# Maximal number of epochs
|
||||
max_epoch = 1000
|
||||
|
@ -253,6 +256,9 @@ class Config:
|
|||
else:
|
||||
self.num_classes = int(line_info[2])
|
||||
|
||||
elif line_info[0] == 'class_w':
|
||||
self.class_w = [float(w) for w in line_info[2:]]
|
||||
|
||||
else:
|
||||
attr_type = type(getattr(self, line_info[0]))
|
||||
if attr_type == bool:
|
||||
|
@ -320,6 +326,8 @@ class Config:
|
|||
text_file.write('modulated = {:d}\n'.format(int(self.modulated)))
|
||||
text_file.write('n_frames = {:d}\n'.format(self.n_frames))
|
||||
text_file.write('max_in_points = {:d}\n\n'.format(self.max_in_points))
|
||||
text_file.write('max_val_points = {:d}\n\n'.format(self.max_val_points))
|
||||
text_file.write('val_radius = {:.3f}\n\n'.format(self.val_radius))
|
||||
|
||||
# Training parameters
|
||||
text_file.write('# Training parameters\n')
|
||||
|
@ -350,9 +358,14 @@ class Config:
|
|||
|
||||
text_file.write('weight_decay = {:f}\n'.format(self.weight_decay))
|
||||
text_file.write('segloss_balance = {:s}\n'.format(self.segloss_balance))
|
||||
text_file.write('class_w =')
|
||||
for a in self.class_w:
|
||||
text_file.write(' {:.3f}'.format(a))
|
||||
text_file.write('\n')
|
||||
text_file.write('offsets_loss = {:s}\n'.format(self.offsets_loss))
|
||||
text_file.write('offsets_decay = {:f}\n'.format(self.offsets_decay))
|
||||
text_file.write('batch_num = {:d}\n'.format(self.batch_num))
|
||||
text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num))
|
||||
text_file.write('max_epoch = {:d}\n'.format(self.max_epoch))
|
||||
if self.epoch_steps is None:
|
||||
text_file.write('epoch_steps = None\n')
|
||||
|
|
688
utils/tester.py
Normal file
688
utils/tester.py
Normal file
|
@ -0,0 +1,688 @@
|
|||
#
|
||||
#
|
||||
# 0=================================0
|
||||
# | Kernel Point Convolutions |
|
||||
# 0=================================0
|
||||
#
|
||||
#
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Class handling the test of any model
|
||||
#
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Hugues THOMAS - 11/06/2018
|
||||
#
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Imports and global variables
|
||||
# \**********************************/
|
||||
#
|
||||
|
||||
|
||||
# Basic libs
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from os import makedirs, listdir
|
||||
from os.path import exists, join
|
||||
import time
|
||||
import json
|
||||
from sklearn.neighbors import KDTree
|
||||
|
||||
# PLY reader
|
||||
from utils.ply import read_ply, write_ply
|
||||
|
||||
# Metrics
|
||||
from utils.metrics import IoU_from_confusions, fast_confusion
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
#from utils.visualizer import show_ModelNet_models
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
#
|
||||
# Tester Class
|
||||
# \******************/
|
||||
#
|
||||
|
||||
|
||||
class ModelTester:
|
||||
|
||||
# Initialization methods
|
||||
# ------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def __init__(self, net, chkp_path=None, on_gpu=True):
|
||||
|
||||
############
|
||||
# Parameters
|
||||
############
|
||||
|
||||
# Choose to train on CPU or GPU
|
||||
if on_gpu and torch.cuda.is_available():
|
||||
self.device = torch.device("cuda:0")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
net.to(self.device)
|
||||
|
||||
##########################
|
||||
# Load previous checkpoint
|
||||
##########################
|
||||
|
||||
checkpoint = torch.load(chkp_path)
|
||||
net.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.epoch = checkpoint['epoch']
|
||||
net.eval()
|
||||
print("Model and training state restored.")
|
||||
|
||||
return
|
||||
|
||||
# Test main methods
|
||||
# ------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def cloud_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False):
|
||||
"""
|
||||
Test method for cloud segmentation models
|
||||
"""
|
||||
|
||||
############
|
||||
# Initialize
|
||||
############
|
||||
|
||||
# Choose test smoothing parameter (0 for no smothing, 0.99 for big smoothing)
|
||||
test_smooth = 0.98
|
||||
softmax = torch.nn.Softmax(1)
|
||||
|
||||
# Number of classes including ignored labels
|
||||
nc_tot = test_loader.dataset.num_classes
|
||||
|
||||
# Number of classes predicted by the model
|
||||
nc_model = config.num_classes
|
||||
|
||||
# Initiate global prediction over test clouds
|
||||
self.test_probs = [np.zeros((l.shape[0], nc_model)) for l in test_loader.dataset.input_labels]
|
||||
|
||||
# Test saving path
|
||||
if config.saving:
|
||||
test_path = join('test', config.saving_path.split('/')[-1])
|
||||
if not exists(test_path):
|
||||
makedirs(test_path)
|
||||
if not exists(join(test_path, 'predictions')):
|
||||
makedirs(join(test_path, 'predictions'))
|
||||
if not exists(join(test_path, 'probs')):
|
||||
makedirs(join(test_path, 'probs'))
|
||||
if not exists(join(test_path, 'potentials')):
|
||||
makedirs(join(test_path, 'potentials'))
|
||||
else:
|
||||
test_path = None
|
||||
|
||||
# If on validation directly compute score
|
||||
if test_loader.dataset.set == 'validation':
|
||||
val_proportions = np.zeros(nc_model, dtype=np.float32)
|
||||
i = 0
|
||||
for label_value in test_loader.dataset.label_values:
|
||||
if label_value not in test_loader.dataset.ignored_labels:
|
||||
val_proportions[i] = np.sum([np.sum(labels == label_value)
|
||||
for labels in test_loader.dataset.validation_labels])
|
||||
i += 1
|
||||
else:
|
||||
val_proportions = None
|
||||
|
||||
#####################
|
||||
# Network predictions
|
||||
#####################
|
||||
|
||||
test_epoch = 0
|
||||
last_min = -0.5
|
||||
|
||||
t = [time.time()]
|
||||
last_display = time.time()
|
||||
mean_dt = np.zeros(1)
|
||||
|
||||
# Start test loop
|
||||
while True:
|
||||
print('Initialize workers')
|
||||
for i, batch in enumerate(test_loader):
|
||||
|
||||
# New time
|
||||
t = t[-1:]
|
||||
t += [time.time()]
|
||||
|
||||
if i == 0:
|
||||
print('Done in {:.1f}s'.format(t[1] - t[0]))
|
||||
|
||||
if 'cuda' in self.device.type:
|
||||
batch.to(self.device)
|
||||
|
||||
# Forward pass
|
||||
outputs = net(batch, config)
|
||||
|
||||
t += [time.time()]
|
||||
|
||||
# Get probs and labels
|
||||
stacked_probs = softmax(outputs).cpu().detach().numpy()
|
||||
lengths = batch.lengths[0].cpu().numpy()
|
||||
in_inds = batch.input_inds.cpu().numpy()
|
||||
cloud_inds = batch.cloud_inds.cpu().numpy()
|
||||
torch.cuda.synchronize(self.device)
|
||||
|
||||
# Get predictions and labels per instance
|
||||
# ***************************************
|
||||
|
||||
i0 = 0
|
||||
for b_i, length in enumerate(lengths):
|
||||
|
||||
# Get prediction
|
||||
probs = stacked_probs[i0:i0 + length]
|
||||
inds = in_inds[i0:i0 + length]
|
||||
c_i = cloud_inds[b_i]
|
||||
|
||||
# Update current probs in whole cloud
|
||||
self.test_probs[c_i][inds] = test_smooth * self.test_probs[c_i][inds] + (1 - test_smooth) * probs
|
||||
i0 += length
|
||||
|
||||
# Average timing
|
||||
t += [time.time()]
|
||||
if i < 2:
|
||||
mean_dt = np.array(t[1:]) - np.array(t[:-1])
|
||||
else:
|
||||
mean_dt = 0.9 * mean_dt + 0.1 * (np.array(t[1:]) - np.array(t[:-1]))
|
||||
|
||||
# Display
|
||||
if (t[-1] - last_display) > 1.0:
|
||||
last_display = t[-1]
|
||||
message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f})'
|
||||
print(message.format(test_epoch, i,
|
||||
100 * i / config.validation_size,
|
||||
1000 * (mean_dt[0]),
|
||||
1000 * (mean_dt[1]),
|
||||
1000 * (mean_dt[2])))
|
||||
|
||||
# Update minimum od potentials
|
||||
new_min = torch.min(test_loader.dataset.min_potentials)
|
||||
print('Test epoch {:d}, end. Min potential = {:.1f}'.format(test_epoch, new_min))
|
||||
#print([np.mean(pots) for pots in test_loader.dataset.potentials])
|
||||
|
||||
# Save predicted cloud
|
||||
if last_min + 1 < new_min:
|
||||
|
||||
# Update last_min
|
||||
last_min += 1
|
||||
|
||||
# Show vote results (On subcloud so it is not the good values here)
|
||||
if test_loader.dataset.set == 'validation':
|
||||
print('\nConfusion on sub clouds')
|
||||
Confs = []
|
||||
for i, file_path in enumerate(test_loader.dataset.files):
|
||||
|
||||
# Insert false columns for ignored labels
|
||||
probs = np.array(self.test_probs[i], copy=True)
|
||||
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
probs = np.insert(probs, l_ind, 0, axis=1)
|
||||
|
||||
# Predicted labels
|
||||
preds = test_loader.dataset.label_values[np.argmax(probs, axis=1)].astype(np.int32)
|
||||
|
||||
# Targets
|
||||
targets = test_loader.dataset.input_labels[i]
|
||||
|
||||
# Confs
|
||||
Confs += [fast_confusion(targets, preds, test_loader.dataset.label_values)]
|
||||
|
||||
# Regroup confusions
|
||||
C = np.sum(np.stack(Confs), axis=0).astype(np.float32)
|
||||
|
||||
# Remove ignored labels from confusions
|
||||
for l_ind, label_value in reversed(list(enumerate(test_loader.dataset.label_values))):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
C = np.delete(C, l_ind, axis=0)
|
||||
C = np.delete(C, l_ind, axis=1)
|
||||
|
||||
# Rescale with the right number of point per class
|
||||
C *= np.expand_dims(val_proportions / (np.sum(C, axis=1) + 1e-6), 1)
|
||||
|
||||
# Compute IoUs
|
||||
IoUs = IoU_from_confusions(C)
|
||||
mIoU = np.mean(IoUs)
|
||||
s = '{:5.2f} | '.format(100 * mIoU)
|
||||
for IoU in IoUs:
|
||||
s += '{:5.2f} '.format(100 * IoU)
|
||||
print(s + '\n')
|
||||
|
||||
# Save real IoU once in a while
|
||||
if int(np.ceil(new_min)) % 10 == 0:
|
||||
|
||||
# Project predictions
|
||||
print('\nReproject Vote #{:d}'.format(int(np.floor(new_min))))
|
||||
t1 = time.time()
|
||||
proj_probs = []
|
||||
for i, file_path in enumerate(test_loader.dataset.files):
|
||||
|
||||
print(i, file_path, test_loader.dataset.test_proj[i].shape, self.test_probs[i].shape)
|
||||
|
||||
print(test_loader.dataset.test_proj[i].dtype, np.max(test_loader.dataset.test_proj[i]))
|
||||
print(test_loader.dataset.test_proj[i][:5])
|
||||
|
||||
# Reproject probs on the evaluations points
|
||||
probs = self.test_probs[i][test_loader.dataset.test_proj[i], :]
|
||||
proj_probs += [probs]
|
||||
|
||||
t2 = time.time()
|
||||
print('Done in {:.1f} s\n'.format(t2 - t1))
|
||||
|
||||
# Show vote results
|
||||
if test_loader.dataset.set == 'validation':
|
||||
print('Confusion on full clouds')
|
||||
t1 = time.time()
|
||||
Confs = []
|
||||
for i, file_path in enumerate(test_loader.dataset.files):
|
||||
|
||||
# Insert false columns for ignored labels
|
||||
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
proj_probs[i] = np.insert(proj_probs[i], l_ind, 0, axis=1)
|
||||
|
||||
# Get the predicted labels
|
||||
preds = test_loader.dataset.label_values[np.argmax(proj_probs[i], axis=1)].astype(np.int32)
|
||||
|
||||
# Confusion
|
||||
targets = test_loader.dataset.validation_labels[i]
|
||||
Confs += [fast_confusion(targets, preds, test_loader.dataset.label_values)]
|
||||
|
||||
t2 = time.time()
|
||||
print('Done in {:.1f} s\n'.format(t2 - t1))
|
||||
|
||||
# Regroup confusions
|
||||
C = np.sum(np.stack(Confs), axis=0)
|
||||
|
||||
# Remove ignored labels from confusions
|
||||
for l_ind, label_value in reversed(list(enumerate(test_loader.dataset.label_values))):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
C = np.delete(C, l_ind, axis=0)
|
||||
C = np.delete(C, l_ind, axis=1)
|
||||
|
||||
IoUs = IoU_from_confusions(C)
|
||||
mIoU = np.mean(IoUs)
|
||||
s = '{:5.2f} | '.format(100 * mIoU)
|
||||
for IoU in IoUs:
|
||||
s += '{:5.2f} '.format(100 * IoU)
|
||||
print('-' * len(s))
|
||||
print(s)
|
||||
print('-' * len(s) + '\n')
|
||||
|
||||
# Save predictions
|
||||
print('Saving clouds')
|
||||
t1 = time.time()
|
||||
for i, file_path in enumerate(test_loader.dataset.files):
|
||||
|
||||
# Get file
|
||||
points = test_loader.dataset.load_evaluation_points(file_path)
|
||||
|
||||
# Get the predicted labels
|
||||
preds = test_loader.dataset.label_values[np.argmax(proj_probs[i], axis=1)].astype(np.int32)
|
||||
|
||||
# Save plys
|
||||
cloud_name = file_path.split('/')[-1]
|
||||
test_name = join(test_path, 'predictions', cloud_name)
|
||||
write_ply(test_name,
|
||||
[points, preds],
|
||||
['x', 'y', 'z', 'preds'])
|
||||
test_name2 = join(test_path, 'probs', cloud_name)
|
||||
prob_names = ['_'.join(test_loader.dataset.label_to_names[label].split())
|
||||
for label in test_loader.dataset.label_values]
|
||||
write_ply(test_name2,
|
||||
[points, proj_probs[i]],
|
||||
['x', 'y', 'z'] + prob_names)
|
||||
|
||||
# Save potentials
|
||||
pot_points = np.array(test_loader.dataset.pot_trees[i].data, copy=False)
|
||||
pot_name = join(test_path, 'potentials', cloud_name)
|
||||
pots = test_loader.dataset.potentials[i].numpy().astype(np.float32)
|
||||
write_ply(pot_name,
|
||||
[pot_points.astype(np.float32), pots],
|
||||
['x', 'y', 'z', 'pots'])
|
||||
|
||||
# Save ascii preds
|
||||
if test_loader.dataset.set == 'test':
|
||||
if test_loader.dataset.name.startswith('Semantic3D'):
|
||||
ascii_name = join(test_path, 'predictions', test_loader.dataset.ascii_files[cloud_name])
|
||||
else:
|
||||
ascii_name = join(test_path, 'predictions', cloud_name[:-4] + '.txt')
|
||||
np.savetxt(ascii_name, preds, fmt='%d')
|
||||
|
||||
t2 = time.time()
|
||||
print('Done in {:.1f} s\n'.format(t2 - t1))
|
||||
|
||||
test_epoch += 1
|
||||
|
||||
# Break when reaching number of desired votes
|
||||
if last_min > num_votes:
|
||||
break
|
||||
|
||||
return
|
||||
|
||||
def slam_segmentation_test(self, net, test_loader, config, num_votes=100, debug=False):
|
||||
"""
|
||||
Test method for slam segmentation models
|
||||
"""
|
||||
|
||||
############
|
||||
# Initialize
|
||||
############
|
||||
|
||||
# Choose validation smoothing parameter (0 for no smothing, 0.99 for big smoothing)
|
||||
test_smooth = 0
|
||||
last_min = -0.5
|
||||
softmax = torch.nn.Softmax(1)
|
||||
|
||||
# Number of classes including ignored labels
|
||||
nc_tot = test_loader.dataset.num_classes
|
||||
nc_model = net.C
|
||||
|
||||
# Test saving path
|
||||
test_path = None
|
||||
report_path = None
|
||||
if config.saving:
|
||||
test_path = join('test', config.saving_path.split('/')[-1])
|
||||
if not exists(test_path):
|
||||
makedirs(test_path)
|
||||
report_path = join(test_path, 'reports')
|
||||
if not exists(report_path):
|
||||
makedirs(report_path)
|
||||
|
||||
if test_loader.dataset.set == 'validation':
|
||||
for folder in ['val_predictions', 'val_probs']:
|
||||
if not exists(join(test_path, folder)):
|
||||
makedirs(join(test_path, folder))
|
||||
else:
|
||||
for folder in ['predictions', 'probs']:
|
||||
if not exists(join(test_path, folder)):
|
||||
makedirs(join(test_path, folder))
|
||||
|
||||
# Init validation container
|
||||
all_f_preds = []
|
||||
all_f_labels = []
|
||||
if test_loader.dataset.set == 'validation':
|
||||
for i, seq_frames in enumerate(test_loader.dataset.frames):
|
||||
all_f_preds.append([np.zeros((0,), dtype=np.int32) for _ in seq_frames])
|
||||
all_f_labels.append([np.zeros((0,), dtype=np.int32) for _ in seq_frames])
|
||||
|
||||
#####################
|
||||
# Network predictions
|
||||
#####################
|
||||
|
||||
predictions = []
|
||||
targets = []
|
||||
test_epoch = 0
|
||||
|
||||
t = [time.time()]
|
||||
last_display = time.time()
|
||||
mean_dt = np.zeros(1)
|
||||
|
||||
# Start test loop
|
||||
while True:
|
||||
print('Initialize workers')
|
||||
for i, batch in enumerate(test_loader):
|
||||
|
||||
# New time
|
||||
t = t[-1:]
|
||||
t += [time.time()]
|
||||
|
||||
if i == 0:
|
||||
print('Done in {:.1f}s'.format(t[1] - t[0]))
|
||||
|
||||
if 'cuda' in self.device.type:
|
||||
batch.to(self.device)
|
||||
|
||||
# Forward pass
|
||||
outputs = net(batch, config)
|
||||
|
||||
# Get probs and labels
|
||||
stk_probs = softmax(outputs).cpu().detach().numpy()
|
||||
lengths = batch.lengths[0].cpu().numpy()
|
||||
f_inds = batch.frame_inds.cpu().numpy()
|
||||
r_inds_list = batch.reproj_inds
|
||||
r_mask_list = batch.reproj_masks
|
||||
labels_list = batch.val_labels
|
||||
torch.cuda.synchronize(self.device)
|
||||
|
||||
t += [time.time()]
|
||||
|
||||
# Get predictions and labels per instance
|
||||
# ***************************************
|
||||
|
||||
i0 = 0
|
||||
for b_i, length in enumerate(lengths):
|
||||
|
||||
# Get prediction
|
||||
probs = stk_probs[i0:i0 + length]
|
||||
proj_inds = r_inds_list[b_i]
|
||||
proj_mask = r_mask_list[b_i]
|
||||
frame_labels = labels_list[b_i]
|
||||
s_ind = f_inds[b_i, 0]
|
||||
f_ind = f_inds[b_i, 1]
|
||||
|
||||
# Project predictions on the frame points
|
||||
proj_probs = probs[proj_inds]
|
||||
|
||||
# Safe check if only one point:
|
||||
if proj_probs.ndim < 2:
|
||||
proj_probs = np.expand_dims(proj_probs, 0)
|
||||
|
||||
# Save probs in a binary file (uint8 format for lighter weight)
|
||||
seq_name = test_loader.dataset.sequences[s_ind]
|
||||
if test_loader.dataset.set == 'validation':
|
||||
folder = 'val_probs'
|
||||
pred_folder = 'val_predictions'
|
||||
else:
|
||||
folder = 'probs'
|
||||
pred_folder = 'predictions'
|
||||
filename = '{:s}_{:07d}.npy'.format(seq_name, f_ind)
|
||||
filepath = join(test_path, folder, filename)
|
||||
if exists(filepath):
|
||||
frame_probs_uint8 = np.load(filepath)
|
||||
else:
|
||||
frame_probs_uint8 = np.zeros((proj_mask.shape[0], nc_model), dtype=np.uint8)
|
||||
frame_probs = frame_probs_uint8[proj_mask, :].astype(np.float32) / 255
|
||||
frame_probs = test_smooth * frame_probs + (1 - test_smooth) * proj_probs
|
||||
frame_probs_uint8[proj_mask, :] = (frame_probs * 255).astype(np.uint8)
|
||||
np.save(filepath, frame_probs_uint8)
|
||||
|
||||
# Save some prediction in ply format for visual
|
||||
if test_loader.dataset.set == 'validation':
|
||||
|
||||
# Insert false columns for ignored labels
|
||||
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
frame_probs_uint8 = np.insert(frame_probs_uint8, l_ind, 0, axis=1)
|
||||
|
||||
# Predicted labels
|
||||
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8,
|
||||
axis=1)].astype(np.int32)
|
||||
|
||||
# Save some of the frame pots
|
||||
if f_ind % 20 == 0:
|
||||
seq_path = join(test_loader.dataset.path, 'sequences', test_loader.dataset.sequences[s_ind])
|
||||
velo_file = join(seq_path, 'velodyne', test_loader.dataset.frames[s_ind][f_ind] + '.bin')
|
||||
frame_points = np.fromfile(velo_file, dtype=np.float32)
|
||||
frame_points = frame_points.reshape((-1, 4))
|
||||
predpath = join(test_path, pred_folder, filename[:-4] + '.ply')
|
||||
#pots = test_loader.dataset.f_potentials[s_ind][f_ind]
|
||||
pots = np.zeros((0,))
|
||||
if pots.shape[0] > 0:
|
||||
write_ply(predpath,
|
||||
[frame_points[:, :3], frame_labels, frame_preds, pots],
|
||||
['x', 'y', 'z', 'gt', 'pre', 'pots'])
|
||||
else:
|
||||
write_ply(predpath,
|
||||
[frame_points[:, :3], frame_labels, frame_preds],
|
||||
['x', 'y', 'z', 'gt', 'pre'])
|
||||
|
||||
# keep frame preds in memory
|
||||
all_f_preds[s_ind][f_ind] = frame_preds
|
||||
all_f_labels[s_ind][f_ind] = frame_labels
|
||||
|
||||
else:
|
||||
|
||||
# Save some of the frame preds
|
||||
if f_inds[b_i, 1] % 100 == 0:
|
||||
|
||||
# Insert false columns for ignored labels
|
||||
for l_ind, label_value in enumerate(test_loader.dataset.label_values):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
frame_probs_uint8 = np.insert(frame_probs_uint8, l_ind, 0, axis=1)
|
||||
|
||||
# Predicted labels
|
||||
frame_preds = test_loader.dataset.label_values[np.argmax(frame_probs_uint8,
|
||||
axis=1)].astype(np.int32)
|
||||
|
||||
# Load points
|
||||
seq_path = join(test_loader.dataset.path, 'sequences', test_loader.dataset.sequences[s_ind])
|
||||
velo_file = join(seq_path, 'velodyne', test_loader.dataset.frames[s_ind][f_ind] + '.bin')
|
||||
frame_points = np.fromfile(velo_file, dtype=np.float32)
|
||||
frame_points = frame_points.reshape((-1, 4))
|
||||
predpath = join(test_path, pred_folder, filename[:-4] + '.ply')
|
||||
#pots = test_loader.dataset.f_potentials[s_ind][f_ind]
|
||||
pots = np.zeros((0,))
|
||||
if pots.shape[0] > 0:
|
||||
write_ply(predpath,
|
||||
[frame_points[:, :3], frame_preds, pots],
|
||||
['x', 'y', 'z', 'pre', 'pots'])
|
||||
else:
|
||||
write_ply(predpath,
|
||||
[frame_points[:, :3], frame_preds],
|
||||
['x', 'y', 'z', 'pre'])
|
||||
|
||||
# Stack all prediction for this epoch
|
||||
i0 += length
|
||||
|
||||
# Average timing
|
||||
t += [time.time()]
|
||||
mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))
|
||||
|
||||
# Display
|
||||
if (t[-1] - last_display) > 1.0:
|
||||
last_display = t[-1]
|
||||
message = 'e{:03d}-i{:04d} => {:.1f}% (timings : {:4.2f} {:4.2f} {:4.2f}) / pots {:d} => {:.1f}%'
|
||||
min_pot = int(torch.floor(torch.min(test_loader.dataset.potentials)))
|
||||
pot_num = torch.sum(test_loader.dataset.potentials > min_pot).type(torch.int32).item()
|
||||
current_num = pot_num + (i0 + 1 - config.validation_size) * config.val_batch_num
|
||||
print(message.format(test_epoch, i,
|
||||
100 * i / config.validation_size,
|
||||
1000 * (mean_dt[0]),
|
||||
1000 * (mean_dt[1]),
|
||||
1000 * (mean_dt[2]),
|
||||
min_pot,
|
||||
100.0 * current_num / len(test_loader.dataset.potentials)))
|
||||
|
||||
|
||||
# Update minimum od potentials
|
||||
new_min = torch.min(test_loader.dataset.potentials)
|
||||
print('Test epoch {:d}, end. Min potential = {:.1f}'.format(test_epoch, new_min))
|
||||
|
||||
if last_min + 1 < new_min:
|
||||
|
||||
# Update last_min
|
||||
last_min += 1
|
||||
|
||||
if test_loader.dataset.set == 'validation' and last_min % 1 == 0:
|
||||
|
||||
#####################################
|
||||
# Results on the whole validation set
|
||||
#####################################
|
||||
|
||||
# Confusions for our subparts of validation set
|
||||
Confs = np.zeros((len(predictions), nc_tot, nc_tot), dtype=np.int32)
|
||||
for i, (preds, truth) in enumerate(zip(predictions, targets)):
|
||||
|
||||
# Confusions
|
||||
Confs[i, :, :] = fast_confusion(truth, preds, test_loader.dataset.label_values).astype(np.int32)
|
||||
|
||||
|
||||
# Show vote results
|
||||
print('\nCompute confusion')
|
||||
|
||||
val_preds = []
|
||||
val_labels = []
|
||||
t1 = time.time()
|
||||
for i, seq_frames in enumerate(test_loader.dataset.frames):
|
||||
val_preds += [np.hstack(all_f_preds[i])]
|
||||
val_labels += [np.hstack(all_f_labels[i])]
|
||||
val_preds = np.hstack(val_preds)
|
||||
val_labels = np.hstack(val_labels)
|
||||
t2 = time.time()
|
||||
C_tot = fast_confusion(val_labels, val_preds, test_loader.dataset.label_values)
|
||||
t3 = time.time()
|
||||
print(' Stacking time : {:.1f}s'.format(t2 - t1))
|
||||
print('Confusion time : {:.1f}s'.format(t3 - t2))
|
||||
|
||||
s1 = '\n'
|
||||
for cc in C_tot:
|
||||
for c in cc:
|
||||
s1 += '{:7.0f} '.format(c)
|
||||
s1 += '\n'
|
||||
if debug:
|
||||
print(s1)
|
||||
|
||||
# Remove ignored labels from confusions
|
||||
for l_ind, label_value in reversed(list(enumerate(test_loader.dataset.label_values))):
|
||||
if label_value in test_loader.dataset.ignored_labels:
|
||||
C_tot = np.delete(C_tot, l_ind, axis=0)
|
||||
C_tot = np.delete(C_tot, l_ind, axis=1)
|
||||
|
||||
# Objects IoU
|
||||
val_IoUs = IoU_from_confusions(C_tot)
|
||||
|
||||
# Compute IoUs
|
||||
mIoU = np.mean(val_IoUs)
|
||||
s2 = '{:5.2f} | '.format(100 * mIoU)
|
||||
for IoU in val_IoUs:
|
||||
s2 += '{:5.2f} '.format(100 * IoU)
|
||||
print(s2 + '\n')
|
||||
|
||||
# Save a report
|
||||
report_file = join(report_path, 'report_{:04d}.txt'.format(int(np.floor(last_min))))
|
||||
str = 'Report of the confusion and metrics\n'
|
||||
str += '***********************************\n\n\n'
|
||||
str += 'Confusion matrix:\n\n'
|
||||
str += s1
|
||||
str += '\nIoU values:\n\n'
|
||||
str += s2
|
||||
str += '\n\n'
|
||||
with open(report_file, 'w') as f:
|
||||
f.write(str)
|
||||
|
||||
test_epoch += 1
|
||||
|
||||
# Break when reaching number of desired votes
|
||||
if last_min > num_votes:
|
||||
break
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
1085
utils/trainer.py
1085
utils/trainer.py
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue