Corrections
This commit is contained in:
parent
31cce85c95
commit
5eb4482209
|
@ -17,6 +17,61 @@
|
|||
from models.blocks import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
def p2p_fitting_regularizer(net):
|
||||
|
||||
fitting_loss = 0
|
||||
repulsive_loss = 0
|
||||
|
||||
for m in net.modules():
|
||||
|
||||
if isinstance(m, KPConv) and m.deformable:
|
||||
|
||||
########################
|
||||
# divide offset gradient
|
||||
########################
|
||||
#
|
||||
# The offset gradient comes from two different losses. The regularizer loss (fitting deformation to point
|
||||
# cloud) and the output loss (which should force deformations to help get a better score). The strength of
|
||||
# the regularizer loss is set with the parameter deform_fitting_power. Therefore, this hook control the
|
||||
# strength of the output loss. This strength can be set with the parameter deform_loss_power
|
||||
#
|
||||
|
||||
m.offset_features.register_hook(lambda grad: grad * net.deform_loss_power)
|
||||
# m.offset_features.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
|
||||
|
||||
##############
|
||||
# Fitting loss
|
||||
##############
|
||||
|
||||
# Get the distance to closest input point
|
||||
|
||||
KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
|
||||
|
||||
# Normalize KP locations to be independant from layers
|
||||
KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
|
||||
|
||||
# Loss will be the square distance to closest input point. We use L1 because dist is already squared
|
||||
fitting_loss += net.l1(KP_min_d2, torch.zeros_like(KP_min_d2))
|
||||
|
||||
################
|
||||
# Repulsive loss
|
||||
################
|
||||
|
||||
# Normalized KP locations
|
||||
KP_locs = m.deformed_KP / m.KP_extent
|
||||
|
||||
# Point should not be close to each other
|
||||
for i in range(net.K):
|
||||
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
|
||||
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
|
||||
rep_loss = torch.sum(torch.clamp_max(distances - 1.0, max=0.0) ** 2, dim=1)
|
||||
repulsive_loss += net.l1(rep_loss, torch.zeros_like(rep_loss)) / net.K
|
||||
|
||||
# The hook effectively affect both regularizer and output loss. So here we have to divide by deform_loss_power
|
||||
return (net.deform_fitting_power / net.deform_loss_power) * (fitting_loss + repulsive_loss)
|
||||
|
||||
|
||||
class KPCNN(nn.Module):
|
||||
"""
|
||||
Class defining KPCNN
|
||||
|
@ -86,8 +141,9 @@ class KPCNN(nn.Module):
|
|||
################
|
||||
|
||||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
self.offset_loss = config.offsets_loss
|
||||
self.offset_decay = config.offsets_decay
|
||||
self.deform_fitting_mode = config.deform_fitting_mode
|
||||
self.deform_fitting_power = config.deform_fitting_power
|
||||
self.deform_loss_power = config.deform_loss_power
|
||||
self.output_loss = 0
|
||||
self.reg_loss = 0
|
||||
self.l1 = nn.L1Loss()
|
||||
|
@ -121,7 +177,12 @@ class KPCNN(nn.Module):
|
|||
self.output_loss = self.criterion(outputs, labels)
|
||||
|
||||
# Regularization of deformable offsets
|
||||
self.reg_loss = self.offset_regularizer()
|
||||
if self.deform_fitting_mode == 'point2point':
|
||||
self.reg_loss = p2p_fitting_regularizer(self)
|
||||
elif self.deform_fitting_mode == 'point2plane':
|
||||
raise ValueError('point2plane fitting mode not implemented yet.')
|
||||
else:
|
||||
raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode)
|
||||
|
||||
# Combined loss
|
||||
return self.output_loss + self.reg_loss
|
||||
|
@ -141,57 +202,6 @@ class KPCNN(nn.Module):
|
|||
|
||||
return correct / total
|
||||
|
||||
def offset_regularizer(self):
|
||||
|
||||
fitting_loss = 0
|
||||
repulsive_loss = 0
|
||||
|
||||
for m in self.modules():
|
||||
|
||||
if isinstance(m, KPConv) and m.deformable:
|
||||
|
||||
##############################
|
||||
# divide offset gradient by 10
|
||||
##############################
|
||||
|
||||
m.unscaled_offsets.register_hook(lambda grad: grad * 0.1)
|
||||
#m.unscaled_offsets.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
|
||||
|
||||
##############
|
||||
# Fitting loss
|
||||
##############
|
||||
|
||||
# Get the distance to closest input point
|
||||
KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
|
||||
|
||||
# Normalize KP locations to be independant from layers
|
||||
KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
|
||||
|
||||
# Loss will be the square distance to closest input point. We use L1 because dist is already squared
|
||||
fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2))
|
||||
|
||||
################
|
||||
# Repulsive loss
|
||||
################
|
||||
|
||||
# Normalized KP locations
|
||||
KP_locs = m.deformed_KP / m.KP_extent
|
||||
|
||||
# Point should not be close to each other
|
||||
for i in range(self.K):
|
||||
|
||||
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
|
||||
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
|
||||
rep_loss = torch.sum(torch.clamp_max(distances - 1.5, max=0.0) ** 2, dim=1)
|
||||
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K
|
||||
|
||||
|
||||
|
||||
return self.offset_decay * (fitting_loss + repulsive_loss)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class KPFCNN(nn.Module):
|
||||
"""
|
||||
|
@ -316,8 +326,9 @@ class KPFCNN(nn.Module):
|
|||
self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1)
|
||||
else:
|
||||
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
|
||||
self.offset_loss = config.offsets_loss
|
||||
self.offset_decay = config.offsets_decay
|
||||
self.deform_fitting_mode = config.deform_fitting_mode
|
||||
self.deform_fitting_power = config.deform_fitting_power
|
||||
self.deform_loss_power = config.deform_loss_power
|
||||
self.output_loss = 0
|
||||
self.reg_loss = 0
|
||||
self.l1 = nn.L1Loss()
|
||||
|
@ -369,7 +380,12 @@ class KPFCNN(nn.Module):
|
|||
self.output_loss = self.criterion(outputs, target)
|
||||
|
||||
# Regularization of deformable offsets
|
||||
self.reg_loss = self.offset_regularizer()
|
||||
if self.deform_fitting_mode == 'point2point':
|
||||
self.reg_loss = p2p_fitting_regularizer(self)
|
||||
elif self.deform_fitting_mode == 'point2plane':
|
||||
raise ValueError('point2plane fitting mode not implemented yet.')
|
||||
else:
|
||||
raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode)
|
||||
|
||||
# Combined loss
|
||||
return self.output_loss + self.reg_loss
|
||||
|
@ -393,57 +409,6 @@ class KPFCNN(nn.Module):
|
|||
|
||||
return correct / total
|
||||
|
||||
def offset_regularizer(self):
|
||||
|
||||
fitting_loss = 0
|
||||
repulsive_loss = 0
|
||||
|
||||
for m in self.modules():
|
||||
|
||||
if isinstance(m, KPConv) and m.deformable:
|
||||
|
||||
##############################
|
||||
# divide offset gradient by 10
|
||||
##############################
|
||||
|
||||
m.unscaled_offsets.register_hook(lambda grad: grad * 0.1)
|
||||
#m.unscaled_offsets.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
|
||||
|
||||
##############
|
||||
# Fitting loss
|
||||
##############
|
||||
|
||||
# Get the distance to closest input point
|
||||
KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
|
||||
|
||||
# Normalize KP locations to be independant from layers
|
||||
KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
|
||||
|
||||
# Loss will be the square distance to closest input point. We use L1 because dist is already squared
|
||||
fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2))
|
||||
|
||||
################
|
||||
# Repulsive loss
|
||||
################
|
||||
|
||||
# Normalized KP locations
|
||||
KP_locs = m.deformed_KP / m.KP_extent
|
||||
|
||||
# Point should not be close to each other
|
||||
for i in range(self.K):
|
||||
|
||||
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
|
||||
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
|
||||
rep_loss = torch.sum(torch.clamp_max(distances - 0.5, max=0.0) ** 2, dim=1)
|
||||
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K
|
||||
|
||||
|
||||
return self.offset_decay * (fitting_loss + repulsive_loss)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -177,7 +177,7 @@ class KPConv(nn.Module):
|
|||
# Running variable containing deformed KP distance to input points. (used in regularization loss)
|
||||
self.deformed_d2 = None
|
||||
self.deformed_KP = None
|
||||
self.unscaled_offsets = None
|
||||
self.offset_features = None
|
||||
|
||||
# Initialize weights
|
||||
self.weights = Parameter(torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32),
|
||||
|
@ -241,27 +241,27 @@ class KPConv(nn.Module):
|
|||
###################
|
||||
|
||||
if self.deformable:
|
||||
offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
|
||||
self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
|
||||
|
||||
if self.modulated:
|
||||
|
||||
# Get offset (in normalized scale) from features
|
||||
offsets = offset_features[:, :self.p_dim * self.K]
|
||||
self.unscaled_offsets = offsets.view(-1, self.K, self.p_dim)
|
||||
unscaled_offsets = self.offset_features[:, :self.p_dim * self.K]
|
||||
unscaled_offsets = unscaled_offsets.view(-1, self.K, self.p_dim)
|
||||
|
||||
# Get modulations
|
||||
modulations = 2 * torch.sigmoid(offset_features[:, self.p_dim * self.K:])
|
||||
modulations = 2 * torch.sigmoid(self.offset_features[:, self.p_dim * self.K:])
|
||||
|
||||
else:
|
||||
|
||||
# Get offset (in normalized scale) from features
|
||||
self.unscaled_offsets = offset_features.view(-1, self.K, self.p_dim)
|
||||
unscaled_offsets = self.offset_features.view(-1, self.K, self.p_dim)
|
||||
|
||||
# No modulations
|
||||
modulations = None
|
||||
|
||||
# Rescale offset for this layer
|
||||
offsets = self.unscaled_offsets * self.KP_extent
|
||||
offsets = unscaled_offsets * self.KP_extent
|
||||
|
||||
else:
|
||||
offsets = None
|
||||
|
|
|
@ -1515,6 +1515,9 @@ def S3DIS_deform(old_result_limit):
|
|||
Debug S3DIS deformable.
|
||||
At checkpoint 50, the points seem to start fitting the shape, but then, they just get further away from each other
|
||||
and do not care about input points. The fitting loss seems broken?
|
||||
|
||||
10* fitting loss seems pretty good fitting the point cloud. It seems that the offset decay was a bit to low,
|
||||
because the same happens without the 0.1 hook. So we can try to keep a 0.5 hook and multiply offset decay by 2.
|
||||
"""
|
||||
|
||||
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
|
||||
|
@ -1537,6 +1540,10 @@ def S3DIS_deform(old_result_limit):
|
|||
'off_d=0.05_corrected',
|
||||
'off_d=0.05_norepulsive',
|
||||
'off_d=0.05_repulsive0.5',
|
||||
'off_d=0.05_10*fitting',
|
||||
'off_d=0.05_no_hook0.1',
|
||||
'NEWPARAMS_fit=0.05_loss=0.5_(=off_d=0.1_hook0.5)',
|
||||
'same_normal'
|
||||
'test']
|
||||
|
||||
logs_names = np.array(logs_names[:len(logs)])
|
||||
|
|
|
@ -78,11 +78,11 @@ class S3DISConfig(Config):
|
|||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb_deformable',
|
||||
'resnetb_deformable',
|
||||
'resnetb_deformable_strided',
|
||||
'resnetb_deformable',
|
||||
'resnetb_deformable',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'resnetb_strided',
|
||||
'resnetb',
|
||||
'resnetb',
|
||||
'nearest_upsample',
|
||||
'unary',
|
||||
'nearest_upsample',
|
||||
|
@ -132,10 +132,11 @@ class S3DISConfig(Config):
|
|||
batch_norm_momentum = 0.02
|
||||
|
||||
# Offset loss
|
||||
# 'permissive' only constrains offsets inside the deform radius (NOT implemented yet)
|
||||
# 'fitting' helps deformed kernels to adapt to the geometry by penalizing distance to input points
|
||||
offsets_loss = 'fitting'
|
||||
offsets_decay = 0.05
|
||||
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
||||
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet
|
||||
deform_fitting_mode = 'point2point'
|
||||
deform_fitting_power = 0.05
|
||||
deform_loss_power = 0.5
|
||||
|
||||
#####################
|
||||
# Training parameters
|
||||
|
@ -195,7 +196,7 @@ if __name__ == '__main__':
|
|||
############################
|
||||
|
||||
# Set which gpu is going to be used
|
||||
GPU_ID = '3'
|
||||
GPU_ID = '2'
|
||||
|
||||
# Set GPU visible device
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
|
||||
|
|
|
@ -159,9 +159,10 @@ class Config:
|
|||
# Choose weights for class (used in segmentation loss). Empty list for no weights
|
||||
class_w = []
|
||||
|
||||
# Offset regularization loss
|
||||
offsets_loss = 'permissive'
|
||||
offsets_decay = 1e-2
|
||||
# New offset regularization parameters
|
||||
deform_fitting_mode = 'point2point'
|
||||
deform_fitting_power = 0.05
|
||||
deform_loss_power = 0.5
|
||||
|
||||
# Number of batch
|
||||
batch_num = 10
|
||||
|
@ -259,7 +260,7 @@ class Config:
|
|||
elif line_info[0] == 'class_w':
|
||||
self.class_w = [float(w) for w in line_info[2:]]
|
||||
|
||||
else:
|
||||
elif hasattr(self, line_info[0]):
|
||||
attr_type = type(getattr(self, line_info[0]))
|
||||
if attr_type == bool:
|
||||
setattr(self, line_info[0], attr_type(int(line_info[2])))
|
||||
|
@ -362,8 +363,9 @@ class Config:
|
|||
for a in self.class_w:
|
||||
text_file.write(' {:.3f}'.format(a))
|
||||
text_file.write('\n')
|
||||
text_file.write('offsets_loss = {:s}\n'.format(self.offsets_loss))
|
||||
text_file.write('offsets_decay = {:f}\n'.format(self.offsets_decay))
|
||||
text_file.write('deform_fitting_mode = {:s}\n'.format(self.deform_fitting_mode))
|
||||
text_file.write('deform_fitting_power = {:f}\n'.format(self.deform_fitting_power))
|
||||
text_file.write('deform_loss_power = {:f}\n'.format(self.deform_loss_power))
|
||||
text_file.write('batch_num = {:d}\n'.format(self.batch_num))
|
||||
text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num))
|
||||
text_file.write('max_epoch = {:d}\n'.format(self.max_epoch))
|
||||
|
|
|
@ -79,7 +79,14 @@ class ModelVisualizer:
|
|||
##########################
|
||||
|
||||
checkpoint = torch.load(chkp_path)
|
||||
net.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
new_dict = {}
|
||||
for k, v in checkpoint['model_state_dict'].items():
|
||||
if 'blocs' in k:
|
||||
k = k.replace('blocs', 'blocks')
|
||||
new_dict[k] = v
|
||||
|
||||
net.load_state_dict(new_dict)
|
||||
self.epoch = checkpoint['epoch']
|
||||
net.eval()
|
||||
print("\nModel state restored from {:s}.".format(chkp_path))
|
||||
|
|
|
@ -96,7 +96,12 @@ if __name__ == '__main__':
|
|||
|
||||
# chosen_log = 'results/Log_2020-04-04_10-04-42' # => ModelNet40
|
||||
# chosen_log = 'results/Log_2020-04-22_11-53-45' # => S3DIS
|
||||
chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected
|
||||
# chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected
|
||||
# chosen_log = 'results/Log_2020-04-23_09-48-15' # => S3DIS no repulsive
|
||||
# chosen_log = 'results/Log_2020-04-23_09-49-49' # => S3DIS repulsive 0.5
|
||||
# chosen_log = 'results/Log_2020-04-23_19-41-12' # => S3DIS 10*fitting
|
||||
chosen_log = 'results/Log_2020-04-23_19-42-18' # => S3DIS no hook
|
||||
|
||||
|
||||
# You can also choose the index of the snapshot to load (last by default)
|
||||
chkp_idx = -1
|
||||
|
|
Loading…
Reference in a new issue