Corrections

This commit is contained in:
HuguesTHOMAS 2020-04-24 12:00:11 -04:00
parent 31cce85c95
commit 5eb4482209
7 changed files with 120 additions and 133 deletions

View file

@ -17,6 +17,61 @@
from models.blocks import *
import numpy as np
def p2p_fitting_regularizer(net):
fitting_loss = 0
repulsive_loss = 0
for m in net.modules():
if isinstance(m, KPConv) and m.deformable:
########################
# divide offset gradient
########################
#
# The offset gradient comes from two different losses. The regularizer loss (fitting deformation to point
# cloud) and the output loss (which should force deformations to help get a better score). The strength of
# the regularizer loss is set with the parameter deform_fitting_power. Therefore, this hook control the
# strength of the output loss. This strength can be set with the parameter deform_loss_power
#
m.offset_features.register_hook(lambda grad: grad * net.deform_loss_power)
# m.offset_features.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
##############
# Fitting loss
##############
# Get the distance to closest input point
KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
# Normalize KP locations to be independant from layers
KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
# Loss will be the square distance to closest input point. We use L1 because dist is already squared
fitting_loss += net.l1(KP_min_d2, torch.zeros_like(KP_min_d2))
################
# Repulsive loss
################
# Normalized KP locations
KP_locs = m.deformed_KP / m.KP_extent
# Point should not be close to each other
for i in range(net.K):
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
rep_loss = torch.sum(torch.clamp_max(distances - 1.0, max=0.0) ** 2, dim=1)
repulsive_loss += net.l1(rep_loss, torch.zeros_like(rep_loss)) / net.K
# The hook effectively affect both regularizer and output loss. So here we have to divide by deform_loss_power
return (net.deform_fitting_power / net.deform_loss_power) * (fitting_loss + repulsive_loss)
class KPCNN(nn.Module):
"""
Class defining KPCNN
@ -86,8 +141,9 @@ class KPCNN(nn.Module):
################
self.criterion = torch.nn.CrossEntropyLoss()
self.offset_loss = config.offsets_loss
self.offset_decay = config.offsets_decay
self.deform_fitting_mode = config.deform_fitting_mode
self.deform_fitting_power = config.deform_fitting_power
self.deform_loss_power = config.deform_loss_power
self.output_loss = 0
self.reg_loss = 0
self.l1 = nn.L1Loss()
@ -121,7 +177,12 @@ class KPCNN(nn.Module):
self.output_loss = self.criterion(outputs, labels)
# Regularization of deformable offsets
self.reg_loss = self.offset_regularizer()
if self.deform_fitting_mode == 'point2point':
self.reg_loss = p2p_fitting_regularizer(self)
elif self.deform_fitting_mode == 'point2plane':
raise ValueError('point2plane fitting mode not implemented yet.')
else:
raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode)
# Combined loss
return self.output_loss + self.reg_loss
@ -141,57 +202,6 @@ class KPCNN(nn.Module):
return correct / total
def offset_regularizer(self):
fitting_loss = 0
repulsive_loss = 0
for m in self.modules():
if isinstance(m, KPConv) and m.deformable:
##############################
# divide offset gradient by 10
##############################
m.unscaled_offsets.register_hook(lambda grad: grad * 0.1)
#m.unscaled_offsets.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
##############
# Fitting loss
##############
# Get the distance to closest input point
KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
# Normalize KP locations to be independant from layers
KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
# Loss will be the square distance to closest input point. We use L1 because dist is already squared
fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2))
################
# Repulsive loss
################
# Normalized KP locations
KP_locs = m.deformed_KP / m.KP_extent
# Point should not be close to each other
for i in range(self.K):
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
rep_loss = torch.sum(torch.clamp_max(distances - 1.5, max=0.0) ** 2, dim=1)
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K
return self.offset_decay * (fitting_loss + repulsive_loss)
class KPFCNN(nn.Module):
"""
@ -316,8 +326,9 @@ class KPFCNN(nn.Module):
self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1)
else:
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
self.offset_loss = config.offsets_loss
self.offset_decay = config.offsets_decay
self.deform_fitting_mode = config.deform_fitting_mode
self.deform_fitting_power = config.deform_fitting_power
self.deform_loss_power = config.deform_loss_power
self.output_loss = 0
self.reg_loss = 0
self.l1 = nn.L1Loss()
@ -369,7 +380,12 @@ class KPFCNN(nn.Module):
self.output_loss = self.criterion(outputs, target)
# Regularization of deformable offsets
self.reg_loss = self.offset_regularizer()
if self.deform_fitting_mode == 'point2point':
self.reg_loss = p2p_fitting_regularizer(self)
elif self.deform_fitting_mode == 'point2plane':
raise ValueError('point2plane fitting mode not implemented yet.')
else:
raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode)
# Combined loss
return self.output_loss + self.reg_loss
@ -393,57 +409,6 @@ class KPFCNN(nn.Module):
return correct / total
def offset_regularizer(self):
fitting_loss = 0
repulsive_loss = 0
for m in self.modules():
if isinstance(m, KPConv) and m.deformable:
##############################
# divide offset gradient by 10
##############################
m.unscaled_offsets.register_hook(lambda grad: grad * 0.1)
#m.unscaled_offsets.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
##############
# Fitting loss
##############
# Get the distance to closest input point
KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
# Normalize KP locations to be independant from layers
KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
# Loss will be the square distance to closest input point. We use L1 because dist is already squared
fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2))
################
# Repulsive loss
################
# Normalized KP locations
KP_locs = m.deformed_KP / m.KP_extent
# Point should not be close to each other
for i in range(self.K):
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
rep_loss = torch.sum(torch.clamp_max(distances - 0.5, max=0.0) ** 2, dim=1)
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K
return self.offset_decay * (fitting_loss + repulsive_loss)

View file

@ -177,7 +177,7 @@ class KPConv(nn.Module):
# Running variable containing deformed KP distance to input points. (used in regularization loss)
self.deformed_d2 = None
self.deformed_KP = None
self.unscaled_offsets = None
self.offset_features = None
# Initialize weights
self.weights = Parameter(torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32),
@ -241,27 +241,27 @@ class KPConv(nn.Module):
###################
if self.deformable:
offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
if self.modulated:
# Get offset (in normalized scale) from features
offsets = offset_features[:, :self.p_dim * self.K]
self.unscaled_offsets = offsets.view(-1, self.K, self.p_dim)
unscaled_offsets = self.offset_features[:, :self.p_dim * self.K]
unscaled_offsets = unscaled_offsets.view(-1, self.K, self.p_dim)
# Get modulations
modulations = 2 * torch.sigmoid(offset_features[:, self.p_dim * self.K:])
modulations = 2 * torch.sigmoid(self.offset_features[:, self.p_dim * self.K:])
else:
# Get offset (in normalized scale) from features
self.unscaled_offsets = offset_features.view(-1, self.K, self.p_dim)
unscaled_offsets = self.offset_features.view(-1, self.K, self.p_dim)
# No modulations
modulations = None
# Rescale offset for this layer
offsets = self.unscaled_offsets * self.KP_extent
offsets = unscaled_offsets * self.KP_extent
else:
offsets = None

View file

@ -1515,6 +1515,9 @@ def S3DIS_deform(old_result_limit):
Debug S3DIS deformable.
At checkpoint 50, the points seem to start fitting the shape, but then, they just get further away from each other
and do not care about input points. The fitting loss seems broken?
10* fitting loss seems pretty good fitting the point cloud. It seems that the offset decay was a bit to low,
because the same happens without the 0.1 hook. So we can try to keep a 0.5 hook and multiply offset decay by 2.
"""
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
@ -1537,6 +1540,10 @@ def S3DIS_deform(old_result_limit):
'off_d=0.05_corrected',
'off_d=0.05_norepulsive',
'off_d=0.05_repulsive0.5',
'off_d=0.05_10*fitting',
'off_d=0.05_no_hook0.1',
'NEWPARAMS_fit=0.05_loss=0.5_(=off_d=0.1_hook0.5)',
'same_normal'
'test']
logs_names = np.array(logs_names[:len(logs)])

View file

@ -78,11 +78,11 @@ class S3DISConfig(Config):
'resnetb',
'resnetb',
'resnetb_strided',
'resnetb_deformable',
'resnetb_deformable',
'resnetb_deformable_strided',
'resnetb_deformable',
'resnetb_deformable',
'resnetb',
'resnetb',
'resnetb_strided',
'resnetb',
'resnetb',
'nearest_upsample',
'unary',
'nearest_upsample',
@ -132,10 +132,11 @@ class S3DISConfig(Config):
batch_norm_momentum = 0.02
# Offset loss
# 'permissive' only constrains offsets inside the deform radius (NOT implemented yet)
# 'fitting' helps deformed kernels to adapt to the geometry by penalizing distance to input points
offsets_loss = 'fitting'
offsets_decay = 0.05
# 'point2point' fitting geometry by penalizing distance from deform point to input points
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet
deform_fitting_mode = 'point2point'
deform_fitting_power = 0.05
deform_loss_power = 0.5
#####################
# Training parameters
@ -195,7 +196,7 @@ if __name__ == '__main__':
############################
# Set which gpu is going to be used
GPU_ID = '3'
GPU_ID = '2'
# Set GPU visible device
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

View file

@ -159,9 +159,10 @@ class Config:
# Choose weights for class (used in segmentation loss). Empty list for no weights
class_w = []
# Offset regularization loss
offsets_loss = 'permissive'
offsets_decay = 1e-2
# New offset regularization parameters
deform_fitting_mode = 'point2point'
deform_fitting_power = 0.05
deform_loss_power = 0.5
# Number of batch
batch_num = 10
@ -259,7 +260,7 @@ class Config:
elif line_info[0] == 'class_w':
self.class_w = [float(w) for w in line_info[2:]]
else:
elif hasattr(self, line_info[0]):
attr_type = type(getattr(self, line_info[0]))
if attr_type == bool:
setattr(self, line_info[0], attr_type(int(line_info[2])))
@ -362,8 +363,9 @@ class Config:
for a in self.class_w:
text_file.write(' {:.3f}'.format(a))
text_file.write('\n')
text_file.write('offsets_loss = {:s}\n'.format(self.offsets_loss))
text_file.write('offsets_decay = {:f}\n'.format(self.offsets_decay))
text_file.write('deform_fitting_mode = {:s}\n'.format(self.deform_fitting_mode))
text_file.write('deform_fitting_power = {:f}\n'.format(self.deform_fitting_power))
text_file.write('deform_loss_power = {:f}\n'.format(self.deform_loss_power))
text_file.write('batch_num = {:d}\n'.format(self.batch_num))
text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num))
text_file.write('max_epoch = {:d}\n'.format(self.max_epoch))

View file

@ -79,7 +79,14 @@ class ModelVisualizer:
##########################
checkpoint = torch.load(chkp_path)
net.load_state_dict(checkpoint['model_state_dict'])
new_dict = {}
for k, v in checkpoint['model_state_dict'].items():
if 'blocs' in k:
k = k.replace('blocs', 'blocks')
new_dict[k] = v
net.load_state_dict(new_dict)
self.epoch = checkpoint['epoch']
net.eval()
print("\nModel state restored from {:s}.".format(chkp_path))

View file

@ -96,7 +96,12 @@ if __name__ == '__main__':
# chosen_log = 'results/Log_2020-04-04_10-04-42' # => ModelNet40
# chosen_log = 'results/Log_2020-04-22_11-53-45' # => S3DIS
chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected
# chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected
# chosen_log = 'results/Log_2020-04-23_09-48-15' # => S3DIS no repulsive
# chosen_log = 'results/Log_2020-04-23_09-49-49' # => S3DIS repulsive 0.5
# chosen_log = 'results/Log_2020-04-23_19-41-12' # => S3DIS 10*fitting
chosen_log = 'results/Log_2020-04-23_19-42-18' # => S3DIS no hook
# You can also choose the index of the snapshot to load (last by default)
chkp_idx = -1