# KPConv-PyTorch/utils/trainer.py
# 2020-04-23 09:51:16 -04:00
#
# 1131 lines
# 40 KiB
# Python

#
#
# 0=================================0
# | Kernel Point Convolutions |
# 0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Class handling the training of any model
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Hugues THOMAS - 11/06/2018
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Imports and global variables
# \**********************************/
#
# Basic libs
import torch
import torch.nn as nn
import numpy as np
import pickle
import os
from os import makedirs, remove
from os.path import exists, join
import time
import sys
# PLY reader
from utils.ply import read_ply, write_ply
# Metrics
from utils.metrics import IoU_from_confusions, fast_confusion
from utils.config import Config
from sklearn.neighbors import KDTree
from models.blocks import KPConv
# ----------------------------------------------------------------------------------------------------------------------
#
# Trainer Class
# \*******************/
#
class ModelTrainer:
    """
    Drive the training loop of a KPConv network and run the task-specific validation after every epoch.

    An instance owns the SGD optimizer, the compute device and the epoch/step counters. Once a validation has run,
    it also keeps the exponentially smoothed validation probabilities (``val_probs`` for classification,
    ``validation_probs`` for cloud segmentation), which are created on the first call and updated on every later call.
    """

    # Initialization methods
    # ------------------------------------------------------------------------------------------------------------------

    def __init__(self, net, config, chkp_path=None, finetune=False, on_gpu=True):
        """
        Initialize training parameters and reload previous model for restore/finetune
        :param net: network object
        :param config: configuration object
        :param chkp_path: path to the checkpoint that needs to be loaded (None for new training)
        :param finetune: finetune from checkpoint (True) or restore training from checkpoint (False)
        :param on_gpu: Train on GPU or CPU (GPU is only used when CUDA is available)
        """

        ############
        # Parameters
        ############

        # Epoch index and step index within the current epoch
        self.epoch = 0
        self.step = 0

        # Optimizer
        self.optimizer = torch.optim.SGD(net.parameters(),
                                         lr=config.learning_rate,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)

        # Choose to train on CPU or GPU
        if on_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        net.to(self.device)

        ##########################
        # Load previous checkpoint
        ##########################

        if chkp_path is not None:
            # map_location remaps tensors saved on another device (e.g. a GPU) onto the device chosen above,
            # so a GPU checkpoint can be restored on a CPU-only machine.
            checkpoint = torch.load(chkp_path, map_location=self.device)
            net.load_state_dict(checkpoint['model_state_dict'])
            if finetune:
                # Finetuning keeps only the weights: fresh optimizer state and epoch counter
                net.train()
                print("Model restored and ready for finetuning.")
            else:
                # Restoring resumes optimizer state and epoch index from the checkpoint
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.epoch = checkpoint['epoch']
                net.train()
                print("Model and training state restored.")

        # Path of the result folder (timestamped by default)
        if config.saving:
            if config.saving_path is None:
                config.saving_path = time.strftime('results/Log_%Y-%m-%d_%H-%M-%S', time.gmtime())
            makedirs(config.saving_path, exist_ok=True)
            config.save()

    # Private helpers
    # ------------------------------------------------------------------------------------------------------------------

    def _synchronize(self):
        """Wait for queued CUDA work so wall-clock timings are meaningful; no-op when running on CPU."""
        # torch.cuda.synchronize fails for a CPU device / when CUDA is unavailable, so only call it on GPU
        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)

    @staticmethod
    def _insert_ignored_columns(probs, label_values, ignored_labels):
        """
        Insert a zero column into the 2D array ``probs`` at the index of every ignored label.

        The model only predicts non-ignored classes. After insertion, argmax column indices can be used directly
        to index ``label_values``.
        """
        for l_ind, label_value in enumerate(label_values):
            if label_value in ignored_labels:
                probs = np.insert(probs, l_ind, 0, axis=1)
        return probs

    @staticmethod
    def _remove_ignored_labels(confusion, label_values, ignored_labels):
        """
        Delete the row and the column of every ignored label from a square confusion matrix.

        Indices are visited in reverse so that each deletion does not shift the indices still to be removed.
        """
        for l_ind, label_value in reversed(list(enumerate(label_values))):
            if label_value in ignored_labels:
                confusion = np.delete(confusion, l_ind, axis=0)
                confusion = np.delete(confusion, l_ind, axis=1)
        return confusion

    # Training main method
    # ------------------------------------------------------------------------------------------------------------------

    def train(self, net, training_loader, val_loader, config):
        """
        Train the model on a particular dataset.

        Runs ``config.max_epoch`` epochs over ``training_loader``. At the end of each epoch:
        - the learning rate is multiplied by ``config.lr_decays[epoch]`` when the epoch index is a key of it;
        - checkpoints are written when ``config.saving`` is set;
        - ``self.validation`` is run on ``val_loader``.

        When saving, training can be stopped early by deleting ``<saving_path>/running_PID.txt``. The current
        epoch is abandoned and the loop exits. The file is deleted once all epochs have completed.
        """

        ################
        # Initialization
        ################

        if config.saving:
            # Training log file
            with open(join(config.saving_path, 'training.txt'), "w") as file:
                file.write('epochs steps out_loss offset_loss train_accuracy time\n')

            # Killing file (simply delete this file when you want to stop the training)
            PID_file = join(config.saving_path, 'running_PID.txt')
            if not exists(PID_file):
                with open(PID_file, "w") as file:
                    file.write('Launched with PyCharm')

            # Checkpoints directory
            checkpoint_directory = join(config.saving_path, 'checkpoints')
            makedirs(checkpoint_directory, exist_ok=True)
        else:
            checkpoint_directory = None
            PID_file = None

        # Loop variables
        t0 = time.time()
        t = [time.time()]
        last_display = time.time()
        mean_dt = np.zeros(1)

        # Start training loop
        for _ in range(config.max_epoch):

            self.step = 0
            for batch in training_loader:

                # Check kill signal (running_PID.txt deleted): stop pulling batches immediately instead of
                # iterating the rest of the loader for nothing
                if config.saving and not exists(PID_file):
                    break

                ##################
                # Processing batch
                ##################

                # New time
                t = t[-1:]
                t += [time.time()]

                if 'cuda' in self.device.type:
                    batch.to(self.device)

                # zero the parameter gradients
                self.optimizer.zero_grad()

                # Forward pass
                outputs = net(batch, config)
                loss = net.loss(outputs, batch.labels)
                acc = net.accuracy(outputs, batch.labels)

                t += [time.time()]

                # Backward + optimize
                loss.backward()

                if config.grad_clip_norm > 0:
                    # NOTE: despite the option name, gradients are clipped element-wise by value, not by norm
                    torch.nn.utils.clip_grad_value_(net.parameters(), config.grad_clip_norm)
                self.optimizer.step()
                self._synchronize()

                t += [time.time()]

                # Average timing (plain value for the first steps, then exponential moving average)
                if self.step < 2:
                    mean_dt = np.array(t[1:]) - np.array(t[:-1])
                else:
                    mean_dt = 0.9 * mean_dt + 0.1 * (np.array(t[1:]) - np.array(t[:-1]))

                # Console display (only one per second)
                if (t[-1] - last_display) > 1.0:
                    last_display = t[-1]
                    message = 'e{:03d}-i{:04d} => L={:.3f} acc={:3.0f}% / t(ms): {:5.1f} {:5.1f} {:5.1f})'
                    print(message.format(self.epoch, self.step,
                                         loss.item(),
                                         100*acc,
                                         1000 * mean_dt[0],
                                         1000 * mean_dt[1],
                                         1000 * mean_dt[2]))

                # Log file (one line per step)
                if config.saving:
                    with open(join(config.saving_path, 'training.txt'), "a") as file:
                        message = '{:d} {:d} {:.3f} {:.3f} {:.3f} {:.3f}\n'
                        file.write(message.format(self.epoch,
                                                  self.step,
                                                  net.output_loss,
                                                  net.reg_loss,
                                                  acc,
                                                  t[-1] - t0))

                self.step += 1

            ##############
            # End of epoch
            ##############

            # Check kill signal (running_PID.txt deleted)
            if config.saving and not exists(PID_file):
                break

            # Update learning rate
            if self.epoch in config.lr_decays:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] *= config.lr_decays[self.epoch]

            # Update epoch
            self.epoch += 1

            # Saving
            if config.saving:
                # Get current state dict
                save_dict = {'epoch': self.epoch,
                             'model_state_dict': net.state_dict(),
                             'optimizer_state_dict': self.optimizer.state_dict(),
                             'saving_path': config.saving_path}

                # Save current state of the network (for restoring purposes)
                checkpoint_path = join(checkpoint_directory, 'current_chkp.tar')
                torch.save(save_dict, checkpoint_path)

                # Save checkpoints occasionally
                if (self.epoch + 1) % config.checkpoint_gap == 0:
                    checkpoint_path = join(checkpoint_directory, 'chkp_{:04d}.tar'.format(self.epoch + 1))
                    torch.save(save_dict, checkpoint_path)

            # Validation
            net.eval()
            self.validation(net, val_loader, config)
            net.train()

        # Delete the kill-signal file only once training is over. Deleting it any earlier would trip the
        # kill-signal check above and silently skip the remaining epoch(s).
        if config.saving and exists(PID_file):
            remove(PID_file)

        print('Finished Training')

    # Validation methods
    # ------------------------------------------------------------------------------------------------------------------

    def validation(self, net, val_loader, config: Config):
        """
        Dispatch to the validation routine matching ``config.dataset_task``.

        :raises ValueError: if the task has no validation routine.
        """
        if config.dataset_task == 'classification':
            self.object_classification_validation(net, val_loader, config)
        elif config.dataset_task == 'segmentation':
            # NOTE(review): object_segmentation_validation is not defined in this class, so this branch raises
            # AttributeError as written.
            self.object_segmentation_validation(net, val_loader, config)
        elif config.dataset_task == 'cloud_segmentation':
            self.cloud_segmentation_validation(net, val_loader, config)
        elif config.dataset_task == 'slam_segmentation':
            self.slam_segmentation_validation(net, val_loader, config)
        else:
            raise ValueError('No validation method implemented for this network type')

    @torch.no_grad()
    def object_classification_validation(self, net, val_loader, config):
        """
        Perform a round of validation and show/save results
        :param net: network object
        :param val_loader: data loader for validation set
        :param config: configuration object
        :return: confusion matrix of this round's predictions (without voting)
        """

        ############
        # Initialize
        ############

        # Choose validation smoothing parameter (0 for no smoothing, 0.99 for big smoothing)
        val_smooth = 0.95

        # Number of classes predicted by the model
        nc_model = config.num_classes
        softmax = torch.nn.Softmax(1)

        # Initialize global prediction over all models (kept across calls for voting)
        if not hasattr(self, 'val_probs'):
            self.val_probs = np.zeros((val_loader.dataset.num_models, nc_model))

        #####################
        # Network predictions
        #####################

        probs = []
        targets = []
        obj_inds = []

        t = [time.time()]
        last_display = time.time()
        mean_dt = np.zeros(1)

        # Start validation loop
        for batch in val_loader:

            # New time
            t = t[-1:]
            t += [time.time()]

            if 'cuda' in self.device.type:
                batch.to(self.device)

            # Forward pass
            outputs = net(batch, config)

            # Get probs and labels
            probs += [softmax(outputs).cpu().detach().numpy()]
            targets += [batch.labels.cpu().numpy()]
            obj_inds += [batch.model_inds.cpu().numpy()]
            self._synchronize()

            # Average timing
            t += [time.time()]
            mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))

            # Display
            if (t[-1] - last_display) > 1.0:
                last_display = t[-1]
                message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})'
                print(message.format(100 * len(obj_inds) / config.validation_size,
                                     1000 * (mean_dt[0]),
                                     1000 * (mean_dt[1])))

        # Stack all validation predictions
        probs = np.vstack(probs)
        targets = np.hstack(targets)
        obj_inds = np.hstack(obj_inds)

        ###################
        # Voting validation
        ###################

        self.val_probs[obj_inds] = val_smooth * self.val_probs[obj_inds] + (1-val_smooth) * probs

        ############
        # Confusions
        ############

        validation_labels = np.array(val_loader.dataset.label_values)

        # Compute classification results
        C1 = fast_confusion(targets,
                            np.argmax(probs, axis=1),
                            validation_labels)

        # Compute votes confusion
        C2 = fast_confusion(val_loader.dataset.input_labels,
                            np.argmax(self.val_probs, axis=1),
                            validation_labels)

        # Saving (optional): each call appends one flattened confusion matrix per line
        if config.saving:
            print("Save confusions")
            conf_list = [C1, C2]
            file_list = ['val_confs.txt', 'vote_confs.txt']
            for conf, conf_file in zip(conf_list, file_list):
                test_file = join(config.saving_path, conf_file)
                # Append mode creates the file on first use
                with open(test_file, "a") as text_file:
                    for line in conf:
                        for value in line:
                            text_file.write('%d ' % value)
                    text_file.write('\n')

        val_ACC = 100 * np.sum(np.diag(C1)) / (np.sum(C1) + 1e-6)
        vote_ACC = 100 * np.sum(np.diag(C2)) / (np.sum(C2) + 1e-6)
        print('Accuracies : val = {:.1f}% / vote = {:.1f}%'.format(val_ACC, vote_ACC))

        return C1

    @torch.no_grad()
    def cloud_segmentation_validation(self, net, val_loader, config, debug=False):
        """
        Validation method for cloud segmentation models

        Accumulates smoothed per-point probabilities over the validation clouds in ``self.validation_probs``.
        It prints the mean IoU of this round. When saving, it appends the IoUs to ``val_IoUs.txt`` and writes
        the sampling potentials. Every ``config.checkpoint_gap`` epochs it also writes the predicted clouds.
        Returns early (None) if the dataset has no validation split.
        """

        ############
        # Initialize
        ############

        t0 = time.time()

        # Choose validation smoothing parameter (0 for no smoothing, 0.99 for big smoothing)
        val_smooth = 0.95
        softmax = torch.nn.Softmax(1)

        # Do not validate if dataset has no validation cloud
        if val_loader.dataset.validation_split not in val_loader.dataset.all_splits:
            return

        # Number of classes including ignored labels
        nc_tot = val_loader.dataset.num_classes

        # Number of classes predicted by the model
        nc_model = config.num_classes

        label_values = val_loader.dataset.label_values
        ignored_labels = val_loader.dataset.ignored_labels

        # Initiate global prediction over validation clouds (kept across calls)
        if not hasattr(self, 'validation_probs'):
            self.validation_probs = [np.zeros((l.shape[0], nc_model))
                                     for l in val_loader.dataset.input_labels]
            # Number of validation points of each non-ignored class
            self.val_proportions = np.zeros(nc_model, dtype=np.float32)
            i = 0
            for label_value in label_values:
                if label_value not in ignored_labels:
                    self.val_proportions[i] = np.sum([np.sum(labels == label_value)
                                                      for labels in val_loader.dataset.validation_labels])
                    i += 1

        #####################
        # Network predictions
        #####################

        predictions = []
        targets = []

        t = [time.time()]
        last_display = time.time()
        mean_dt = np.zeros(1)

        t1 = time.time()

        # Start validation loop
        for i, batch in enumerate(val_loader):

            # New time
            t = t[-1:]
            t += [time.time()]

            if 'cuda' in self.device.type:
                batch.to(self.device)

            # Forward pass
            outputs = net(batch, config)

            # Get probs and labels
            stacked_probs = softmax(outputs).cpu().detach().numpy()
            labels = batch.labels.cpu().numpy()
            lengths = batch.lengths[0].cpu().numpy()
            in_inds = batch.input_inds.cpu().numpy()
            cloud_inds = batch.cloud_inds.cpu().numpy()
            self._synchronize()

            # Get predictions and labels per instance
            # ***************************************

            i0 = 0
            for b_i, length in enumerate(lengths):

                # Get prediction
                target = labels[i0:i0 + length]
                probs = stacked_probs[i0:i0 + length]
                inds = in_inds[i0:i0 + length]
                c_i = cloud_inds[b_i]

                # Update current probs in whole cloud
                self.validation_probs[c_i][inds] = val_smooth * self.validation_probs[c_i][inds] \
                                                   + (1 - val_smooth) * probs

                # Stack all prediction for this epoch
                predictions.append(probs)
                targets.append(target)
                i0 += length

            # Average timing
            t += [time.time()]
            mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))

            # Display
            if (t[-1] - last_display) > 1.0:
                last_display = t[-1]
                message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})'
                print(message.format(100 * i / config.validation_size,
                                     1000 * (mean_dt[0]),
                                     1000 * (mean_dt[1])))

        t2 = time.time()

        # Confusions for our subparts of validation set
        Confs = np.zeros((len(predictions), nc_tot, nc_tot), dtype=np.int32)
        for i, (probs, truth) in enumerate(zip(predictions, targets)):

            # Insert false columns for ignored labels
            probs = self._insert_ignored_columns(probs, label_values, ignored_labels)

            # Predicted labels
            preds = label_values[np.argmax(probs, axis=1)]

            # Confusions
            Confs[i, :, :] = fast_confusion(truth, preds, label_values).astype(np.int32)

        t3 = time.time()

        # Sum all confusions
        C = np.sum(Confs, axis=0).astype(np.float32)

        # Remove ignored labels from confusions
        C = self._remove_ignored_labels(C, label_values, ignored_labels)

        # Balance with real validation proportions (rows rescaled to the true class counts)
        C *= np.expand_dims(self.val_proportions / (np.sum(C, axis=1) + 1e-6), 1)

        t4 = time.time()

        # Objects IoU
        IoUs = IoU_from_confusions(C)

        t5 = time.time()

        # Saving (optional)
        if config.saving:

            # Name of saving file
            test_file = join(config.saving_path, 'val_IoUs.txt')

            # Line to write: every IoU followed by a space, one line per call
            line = ''.join('{:.3f} '.format(IoU) for IoU in IoUs) + '\n'

            # Write in file (append mode creates it on first use)
            with open(test_file, "a") as text_file:
                text_file.write(line)

            # Save potentials
            pot_path = join(config.saving_path, 'potentials')
            makedirs(pot_path, exist_ok=True)
            files = val_loader.dataset.files
            for i, file_path in enumerate(files):
                # asarray avoids a copy when possible (np.array(..., copy=False) errors on NumPy >= 2 if a copy is needed)
                pot_points = np.asarray(val_loader.dataset.pot_trees[i].data)
                cloud_name = os.path.basename(file_path)
                pot_name = join(pot_path, cloud_name)
                pots = val_loader.dataset.potentials[i].numpy().astype(np.float32)
                write_ply(pot_name,
                          [pot_points.astype(np.float32), pots],
                          ['x', 'y', 'z', 'pots'])

        t6 = time.time()

        # Print instance mean
        mIoU = 100 * np.mean(IoUs)
        print('{:s} mean IoU = {:.1f}%'.format(config.dataset, mIoU))

        # Save predicted cloud occasionally
        if config.saving and (self.epoch + 1) % config.checkpoint_gap == 0:
            val_path = join(config.saving_path, 'val_preds_{:d}'.format(self.epoch + 1))
            makedirs(val_path, exist_ok=True)
            files = val_loader.dataset.files
            for i, file_path in enumerate(files):

                # Get points
                points = val_loader.dataset.load_evaluation_points(file_path)

                # Get probs on our own ply points, with false columns for ignored labels
                sub_probs = self._insert_ignored_columns(self.validation_probs[i], label_values, ignored_labels)

                # Get the predicted labels
                sub_preds = label_values[np.argmax(sub_probs, axis=1).astype(np.int32)]

                # Reproject preds on the evaluations points
                preds = (sub_preds[val_loader.dataset.test_proj[i]]).astype(np.int32)

                # Path of saved validation file
                cloud_name = os.path.basename(file_path)
                val_name = join(val_path, cloud_name)

                # Save file
                labels = val_loader.dataset.validation_labels[i].astype(np.int32)
                write_ply(val_name,
                          [points, preds, labels],
                          ['x', 'y', 'z', 'preds', 'class'])

        # Display timings
        t7 = time.time()
        if debug:
            print('\n************************\n')
            print('Validation timings:')
            print('Init ...... {:.1f}s'.format(t1 - t0))
            print('Loop ...... {:.1f}s'.format(t2 - t1))
            print('Confs ..... {:.1f}s'.format(t3 - t2))
            print('Confs bis . {:.1f}s'.format(t4 - t3))
            print('IoU ....... {:.1f}s'.format(t5 - t4))
            print('Save1 ..... {:.1f}s'.format(t6 - t5))
            print('Save2 ..... {:.1f}s'.format(t7 - t6))
            print('\n************************\n')

    @torch.no_grad()
    def slam_segmentation_validation(self, net, val_loader, config, debug=True):
        """
        Validation method for slam segmentation models

        Per-frame predictions are stored as ``.npy`` files under ``<config.saving_path>/val_preds``. Each file
        is re-loaded and updated when the same frame is seen again. Per-frame confusions are written into
        ``val_loader.dataset.val_confs``. Prints the mean IoU of this round (subpart) and of all frames
        seen so far.

        NOTE(review): predictions are always written under ``config.saving_path``, even when
        ``config.saving`` is False; that path must therefore be set.
        """

        ############
        # Initialize
        ############

        t0 = time.time()

        # Do not validate if dataset has no validation cloud
        if val_loader is None:
            return

        softmax = torch.nn.Softmax(1)

        # Create folder for validation predictions
        makedirs(join(config.saving_path, 'val_preds'), exist_ok=True)

        # initiate the dataset validation containers
        val_loader.dataset.val_points = []
        val_loader.dataset.val_labels = []

        # Number of classes including ignored labels
        nc_tot = val_loader.dataset.num_classes

        label_values = val_loader.dataset.label_values
        ignored_labels = val_loader.dataset.ignored_labels

        #####################
        # Network predictions
        #####################

        predictions = []
        targets = []

        t = [time.time()]
        last_display = time.time()
        mean_dt = np.zeros(1)

        t1 = time.time()

        # Start validation loop
        for i, batch in enumerate(val_loader):

            # New time
            t = t[-1:]
            t += [time.time()]

            if 'cuda' in self.device.type:
                batch.to(self.device)

            # Forward pass
            outputs = net(batch, config)

            # Get probs and labels
            stk_probs = softmax(outputs).cpu().detach().numpy()
            lengths = batch.lengths[0].cpu().numpy()
            f_inds = batch.frame_inds.cpu().numpy()
            r_inds_list = batch.reproj_inds
            r_mask_list = batch.reproj_masks
            labels_list = batch.val_labels
            self._synchronize()

            # Get predictions and labels per instance
            # ***************************************

            i0 = 0
            for b_i, length in enumerate(lengths):

                # Get prediction
                probs = stk_probs[i0:i0 + length]
                proj_inds = r_inds_list[b_i]
                proj_mask = r_mask_list[b_i]
                frame_labels = labels_list[b_i]
                s_ind = f_inds[b_i, 0]
                f_ind = f_inds[b_i, 1]

                # Project predictions on the frame points
                proj_probs = probs[proj_inds]

                # Safe check if only one point:
                if proj_probs.ndim < 2:
                    proj_probs = np.expand_dims(proj_probs, 0)

                # Insert false columns for ignored labels
                proj_probs = self._insert_ignored_columns(proj_probs, label_values, ignored_labels)

                # Predicted labels
                preds = label_values[np.argmax(proj_probs, axis=1)]

                # Save predictions in a binary file (merged with earlier predictions for this frame, if any)
                filename = '{:s}_{:07d}.npy'.format(val_loader.dataset.sequences[s_ind], f_ind)
                filepath = join(config.saving_path, 'val_preds', filename)
                if exists(filepath):
                    frame_preds = np.load(filepath)
                else:
                    frame_preds = np.zeros(frame_labels.shape, dtype=np.uint8)
                frame_preds[proj_mask] = preds.astype(np.uint8)
                np.save(filepath, frame_preds)

                # Save some of the frame pots
                if f_ind % 20 == 0:
                    seq_path = join(val_loader.dataset.path, 'sequences', val_loader.dataset.sequences[s_ind])
                    velo_file = join(seq_path, 'velodyne', val_loader.dataset.frames[s_ind][f_ind] + '.bin')
                    frame_points = np.fromfile(velo_file, dtype=np.float32)
                    frame_points = frame_points.reshape((-1, 4))
                    write_ply(filepath[:-4] + '_pots.ply',
                              [frame_points[:, :3], frame_labels, frame_preds],
                              ['x', 'y', 'z', 'gt', 'pre'])

                # Update validation confusions
                frame_C = fast_confusion(frame_labels,
                                         frame_preds.astype(np.int32),
                                         label_values)
                val_loader.dataset.val_confs[s_ind][f_ind, :, :] = frame_C

                # Stack all prediction for this epoch
                predictions += [preds]
                targets += [frame_labels[proj_mask]]
                i0 += length

            # Average timing
            t += [time.time()]
            mean_dt = 0.95 * mean_dt + 0.05 * (np.array(t[1:]) - np.array(t[:-1]))

            # Display
            if (t[-1] - last_display) > 1.0:
                last_display = t[-1]
                message = 'Validation : {:.1f}% (timings : {:4.2f} {:4.2f})'
                print(message.format(100 * i / config.validation_size,
                                     1000 * (mean_dt[0]),
                                     1000 * (mean_dt[1])))

        t2 = time.time()

        # Confusions for our subparts of validation set
        Confs = np.zeros((len(predictions), nc_tot, nc_tot), dtype=np.int32)
        for i, (preds, truth) in enumerate(zip(predictions, targets)):

            # Confusions
            Confs[i, :, :] = fast_confusion(truth, preds, label_values).astype(np.int32)

        t3 = time.time()

        #######################################
        # Results on this subpart of validation
        #######################################

        # Sum all confusions
        C = np.sum(Confs, axis=0).astype(np.float32)

        # Balance with real validation proportions (done before removing ignored labels here)
        C *= np.expand_dims(val_loader.dataset.class_proportions / (np.sum(C, axis=1) + 1e-6), 1)

        # Remove ignored labels from confusions
        C = self._remove_ignored_labels(C, label_values, ignored_labels)

        # Objects IoU
        IoUs = IoU_from_confusions(C)

        #####################################
        # Results on the whole validation set
        #####################################

        t4 = time.time()

        # Sum all validation confusions
        C_tot = [np.sum(seq_C, axis=0) for seq_C in val_loader.dataset.val_confs if len(seq_C) > 0]
        C_tot = np.sum(np.stack(C_tot, axis=0), axis=0)

        if debug:
            s = '\n'
            for cc in C_tot:
                for c in cc:
                    s += '{:8.1f} '.format(c)
                s += '\n'
            print(s)

        # Remove ignored labels from confusions
        C_tot = self._remove_ignored_labels(C_tot, label_values, ignored_labels)

        # Objects IoU
        val_IoUs = IoU_from_confusions(C_tot)

        t5 = time.time()

        # Saving (optional)
        if config.saving:

            IoU_list = [IoUs, val_IoUs]
            file_list = ['subpart_IoUs.txt', 'val_IoUs.txt']
            for IoUs_to_save, IoU_file in zip(IoU_list, file_list):

                # Name of saving file
                test_file = join(config.saving_path, IoU_file)

                # Line to write: every IoU followed by a space, one line per call
                line = ''.join('{:.3f} '.format(IoU) for IoU in IoUs_to_save) + '\n'

                # Write in file (append mode creates it on first use)
                with open(test_file, "a") as text_file:
                    text_file.write(line)

        # Print instance mean
        mIoU = 100 * np.mean(IoUs)
        print('{:s} : subpart mIoU = {:.1f} %'.format(config.dataset, mIoU))
        mIoU = 100 * np.mean(val_IoUs)
        print('{:s} :     val mIoU = {:.1f} %'.format(config.dataset, mIoU))

        t6 = time.time()

        # Display timings
        if debug:
            print('\n************************\n')
            print('Validation timings:')
            print('Init ...... {:.1f}s'.format(t1 - t0))
            print('Loop ...... {:.1f}s'.format(t2 - t1))
            print('Confs ..... {:.1f}s'.format(t3 - t2))
            print('IoU1 ...... {:.1f}s'.format(t4 - t3))
            print('IoU2 ...... {:.1f}s'.format(t5 - t4))
            print('Save ...... {:.1f}s'.format(t6 - t5))
            print('\n************************\n')

    # Saving methods
    # ------------------------------------------------------------------------------------------------------------------

    def save_kernel_points(self, model, epoch):
        """
        Method saving kernel point disposition and current model weights for later visualization

        NOTE(review): legacy TensorFlow code. ``tf`` is not imported in this module and ``self.sess`` is never
        set by this class, so calling this with ``model.config.saving`` True raises NameError.
        """

        if model.config.saving:

            # Create a directory to save kernels of this epoch
            kernels_dir = join(model.saving_path, 'kernel_points', 'epoch{:d}'.format(epoch))
            if not exists(kernels_dir):
                makedirs(kernels_dir)

            # Get points
            all_kernel_points_tf = [v for v in tf.global_variables() if 'kernel_points' in v.name
                                    and v.name.startswith('KernelPoint')]
            all_kernel_points = self.sess.run(all_kernel_points_tf)

            # Get Extents
            if False and 'gaussian' in model.config.convolution_mode:
                all_kernel_params_tf = [v for v in tf.global_variables() if 'kernel_extents' in v.name
                                        and v.name.startswith('KernelPoint')]
                all_kernel_params = self.sess.run(all_kernel_params_tf)
            else:
                all_kernel_params = [None for p in all_kernel_points]

            # Save in ply file
            for kernel_points, kernel_extents, v in zip(all_kernel_points, all_kernel_params, all_kernel_points_tf):

                # Name of saving file
                ply_name = '_'.join(v.name[:-2].split('/')[1:-1]) + '.ply'
                ply_file = join(kernels_dir, ply_name)

                # Data to save
                if kernel_points.ndim > 2:
                    kernel_points = kernel_points[:, 0, :]
                if False and 'gaussian' in model.config.convolution_mode:
                    data = [kernel_points, kernel_extents]
                    keys = ['x', 'y', 'z', 'sigma']
                else:
                    data = kernel_points
                    keys = ['x', 'y', 'z']

                # Save
                write_ply(ply_file, data, keys)

            # Get Weights
            all_kernel_weights_tf = [v for v in tf.global_variables() if 'weights' in v.name
                                     and v.name.startswith('KernelPointNetwork')]
            all_kernel_weights = self.sess.run(all_kernel_weights_tf)

            # Save in numpy file
            for kernel_weights, v in zip(all_kernel_weights, all_kernel_weights_tf):
                np_name = '_'.join(v.name[:-2].split('/')[1:-1]) + '.npy'
                np_file = join(kernels_dir, np_name)
                np.save(np_file, kernel_weights)

    # Debug methods
    # ------------------------------------------------------------------------------------------------------------------

    def show_memory_usage(self, batch_to_feed):
        """
        Print rough per-layer memory estimates (in GB, assuming 4-byte values) for a fed batch.

        NOTE(review): legacy code. It reads ``self.config`` and ``self.in_neighbors_f32``, neither of which is
        set by this class, so calling it raises AttributeError as written.
        """
        for l in range(self.config.num_layers):
            neighb_size = list(batch_to_feed[self.in_neighbors_f32[l]].shape)
            dist_size = neighb_size + [self.config.num_kernel_points, 3]
            dist_memory = np.prod(dist_size) * 4 * 1e-9
            in_feature_size = neighb_size + [self.config.first_features_dim * 2**l]
            in_feature_memory = np.prod(in_feature_size) * 4 * 1e-9
            out_feature_size = [neighb_size[0], self.config.num_kernel_points, self.config.first_features_dim * 2**(l+1)]
            out_feature_memory = np.prod(out_feature_size) * 4 * 1e-9

            print('Layer {:d} => {:.1f}GB {:.1f}GB {:.1f}GB'.format(l,
                                                                   dist_memory,
                                                                   in_feature_memory,
                                                                   out_feature_memory))
        print('************************************')

    def debug_nan(self, model, inputs, logits):
        """
        NaN happened, find where

        Pickles ``inputs`` and ``logits`` under ``model.saving_path``, then prints NaN statistics.

        NOTE(review): legacy TensorFlow code. After the two pickle dumps it uses ``tf`` (not imported in this
        module) and ``self.sess`` (never set by this class), so it raises NameError at that point.
        """

        print('\n\n------------------------ NaN DEBUG ------------------------\n')

        # First save everything to reproduce error
        file1 = join(model.saving_path, 'all_debug_inputs.pkl')
        with open(file1, 'wb') as f1:
            pickle.dump(inputs, f1)

        # Then save all logits
        file1 = join(model.saving_path, 'all_debug_logits.pkl')
        with open(file1, 'wb') as f1:
            pickle.dump(logits, f1)

        # Then print a list of the trainable variables and if they have nan
        print('List of variables :')
        print('*******************\n')
        all_vars = self.sess.run(tf.global_variables())
        for v, value in zip(tf.global_variables(), all_vars):
            nan_percentage = 100 * np.sum(np.isnan(value)) / np.prod(value.shape)
            print(v.name, ' => {:.1f}% of values are NaN'.format(nan_percentage))

        print('Inputs :')
        print('********')

        # Print inputs
        nl = model.config.num_layers
        for layer in range(nl):

            print('Layer : {:d}'.format(layer))

            points = inputs[layer]
            neighbors = inputs[nl + layer]
            pools = inputs[2*nl + layer]
            upsamples = inputs[3*nl + layer]

            nan_percentage = 100 * np.sum(np.isnan(points)) / np.prod(points.shape)
            print('Points =>', points.shape, '{:.1f}% NaN'.format(nan_percentage))
            nan_percentage = 100 * np.sum(np.isnan(neighbors)) / np.prod(neighbors.shape)
            print('neighbors =>', neighbors.shape, '{:.1f}% NaN'.format(nan_percentage))
            nan_percentage = 100 * np.sum(np.isnan(pools)) / np.prod(pools.shape)
            print('pools =>', pools.shape, '{:.1f}% NaN'.format(nan_percentage))
            nan_percentage = 100 * np.sum(np.isnan(upsamples)) / np.prod(upsamples.shape)
            print('upsamples =>', upsamples.shape, '{:.1f}% NaN'.format(nan_percentage))

        ind = 4 * nl
        features = inputs[ind]
        nan_percentage = 100 * np.sum(np.isnan(features)) / np.prod(features.shape)
        print('features =>', features.shape, '{:.1f}% NaN'.format(nan_percentage))
        ind += 1
        batch_weights = inputs[ind]
        ind += 1
        in_batches = inputs[ind]
        max_b = np.max(in_batches)
        print(in_batches.shape)
        in_b_sizes = np.sum(in_batches < max_b - 0.5, axis=-1)
        print('in_batch_sizes =>', in_b_sizes)
        ind += 1
        out_batches = inputs[ind]
        max_b = np.max(out_batches)
        print(out_batches.shape)
        out_b_sizes = np.sum(out_batches < max_b - 0.5, axis=-1)
        print('out_batch_sizes =>', out_b_sizes)
        ind += 1
        point_labels = inputs[ind]
        print('point labels, ', point_labels.shape, ', values : ', np.unique(point_labels))
        print(np.array([int(100 * np.sum(point_labels == l) / len(point_labels)) for l in np.unique(point_labels)]))

        ind += 1
        if model.config.dataset.startswith('ShapeNetPart_multi'):
            object_labels = inputs[ind]
            nan_percentage = 100 * np.sum(np.isnan(object_labels)) / np.prod(object_labels.shape)
            print('object_labels =>', object_labels.shape, '{:.1f}% NaN'.format(nan_percentage))
            ind += 1
        augment_scales = inputs[ind]
        ind += 1
        augment_rotations = inputs[ind]
        ind += 1

        print('\npoolings and upsamples nums :\n')

        # Print neighbor/pool/upsample counts per layer
        for layer in range(nl):

            print('\nLayer : {:d}'.format(layer))

            neighbors = inputs[nl + layer]
            pools = inputs[2*nl + layer]
            upsamples = inputs[3*nl + layer]

            max_n = np.max(neighbors)
            nums = np.sum(neighbors < max_n - 0.5, axis=-1)
            print('min neighbors =>', np.min(nums))

            if np.prod(pools.shape) > 0:
                max_n = np.max(pools)
                nums = np.sum(pools < max_n - 0.5, axis=-1)
                print('min pools =>', np.min(nums))
            else:
                print('pools empty')

            if np.prod(upsamples.shape) > 0:
                max_n = np.max(upsamples)
                nums = np.sum(upsamples < max_n - 0.5, axis=-1)
                print('min upsamples =>', np.min(nums))
            else:
                print('upsamples empty')

        print('\nFinished\n\n')
        time.sleep(0.5)