Corrections
This commit is contained in:
parent
31cce85c95
commit
5eb4482209
|
@ -17,6 +17,61 @@
|
||||||
from models.blocks import *
|
from models.blocks import *
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def p2p_fitting_regularizer(net):
|
||||||
|
|
||||||
|
fitting_loss = 0
|
||||||
|
repulsive_loss = 0
|
||||||
|
|
||||||
|
for m in net.modules():
|
||||||
|
|
||||||
|
if isinstance(m, KPConv) and m.deformable:
|
||||||
|
|
||||||
|
########################
|
||||||
|
# divide offset gradient
|
||||||
|
########################
|
||||||
|
#
|
||||||
|
# The offset gradient comes from two different losses. The regularizer loss (fitting deformation to point
|
||||||
|
# cloud) and the output loss (which should force deformations to help get a better score). The strength of
|
||||||
|
# the regularizer loss is set with the parameter deform_fitting_power. Therefore, this hook control the
|
||||||
|
# strength of the output loss. This strength can be set with the parameter deform_loss_power
|
||||||
|
#
|
||||||
|
|
||||||
|
m.offset_features.register_hook(lambda grad: grad * net.deform_loss_power)
|
||||||
|
# m.offset_features.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
|
||||||
|
|
||||||
|
##############
|
||||||
|
# Fitting loss
|
||||||
|
##############
|
||||||
|
|
||||||
|
# Get the distance to closest input point
|
||||||
|
|
||||||
|
KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
|
||||||
|
|
||||||
|
# Normalize KP locations to be independant from layers
|
||||||
|
KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
|
||||||
|
|
||||||
|
# Loss will be the square distance to closest input point. We use L1 because dist is already squared
|
||||||
|
fitting_loss += net.l1(KP_min_d2, torch.zeros_like(KP_min_d2))
|
||||||
|
|
||||||
|
################
|
||||||
|
# Repulsive loss
|
||||||
|
################
|
||||||
|
|
||||||
|
# Normalized KP locations
|
||||||
|
KP_locs = m.deformed_KP / m.KP_extent
|
||||||
|
|
||||||
|
# Point should not be close to each other
|
||||||
|
for i in range(net.K):
|
||||||
|
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
|
||||||
|
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
|
||||||
|
rep_loss = torch.sum(torch.clamp_max(distances - 1.0, max=0.0) ** 2, dim=1)
|
||||||
|
repulsive_loss += net.l1(rep_loss, torch.zeros_like(rep_loss)) / net.K
|
||||||
|
|
||||||
|
# The hook effectively affect both regularizer and output loss. So here we have to divide by deform_loss_power
|
||||||
|
return (net.deform_fitting_power / net.deform_loss_power) * (fitting_loss + repulsive_loss)
|
||||||
|
|
||||||
|
|
||||||
class KPCNN(nn.Module):
|
class KPCNN(nn.Module):
|
||||||
"""
|
"""
|
||||||
Class defining KPCNN
|
Class defining KPCNN
|
||||||
|
@ -86,8 +141,9 @@ class KPCNN(nn.Module):
|
||||||
################
|
################
|
||||||
|
|
||||||
self.criterion = torch.nn.CrossEntropyLoss()
|
self.criterion = torch.nn.CrossEntropyLoss()
|
||||||
self.offset_loss = config.offsets_loss
|
self.deform_fitting_mode = config.deform_fitting_mode
|
||||||
self.offset_decay = config.offsets_decay
|
self.deform_fitting_power = config.deform_fitting_power
|
||||||
|
self.deform_loss_power = config.deform_loss_power
|
||||||
self.output_loss = 0
|
self.output_loss = 0
|
||||||
self.reg_loss = 0
|
self.reg_loss = 0
|
||||||
self.l1 = nn.L1Loss()
|
self.l1 = nn.L1Loss()
|
||||||
|
@ -121,7 +177,12 @@ class KPCNN(nn.Module):
|
||||||
self.output_loss = self.criterion(outputs, labels)
|
self.output_loss = self.criterion(outputs, labels)
|
||||||
|
|
||||||
# Regularization of deformable offsets
|
# Regularization of deformable offsets
|
||||||
self.reg_loss = self.offset_regularizer()
|
if self.deform_fitting_mode == 'point2point':
|
||||||
|
self.reg_loss = p2p_fitting_regularizer(self)
|
||||||
|
elif self.deform_fitting_mode == 'point2plane':
|
||||||
|
raise ValueError('point2plane fitting mode not implemented yet.')
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode)
|
||||||
|
|
||||||
# Combined loss
|
# Combined loss
|
||||||
return self.output_loss + self.reg_loss
|
return self.output_loss + self.reg_loss
|
||||||
|
@ -141,57 +202,6 @@ class KPCNN(nn.Module):
|
||||||
|
|
||||||
return correct / total
|
return correct / total
|
||||||
|
|
||||||
def offset_regularizer(self):
|
|
||||||
|
|
||||||
fitting_loss = 0
|
|
||||||
repulsive_loss = 0
|
|
||||||
|
|
||||||
for m in self.modules():
|
|
||||||
|
|
||||||
if isinstance(m, KPConv) and m.deformable:
|
|
||||||
|
|
||||||
##############################
|
|
||||||
# divide offset gradient by 10
|
|
||||||
##############################
|
|
||||||
|
|
||||||
m.unscaled_offsets.register_hook(lambda grad: grad * 0.1)
|
|
||||||
#m.unscaled_offsets.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
|
|
||||||
|
|
||||||
##############
|
|
||||||
# Fitting loss
|
|
||||||
##############
|
|
||||||
|
|
||||||
# Get the distance to closest input point
|
|
||||||
KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
|
|
||||||
|
|
||||||
# Normalize KP locations to be independant from layers
|
|
||||||
KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
|
|
||||||
|
|
||||||
# Loss will be the square distance to closest input point. We use L1 because dist is already squared
|
|
||||||
fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2))
|
|
||||||
|
|
||||||
################
|
|
||||||
# Repulsive loss
|
|
||||||
################
|
|
||||||
|
|
||||||
# Normalized KP locations
|
|
||||||
KP_locs = m.deformed_KP / m.KP_extent
|
|
||||||
|
|
||||||
# Point should not be close to each other
|
|
||||||
for i in range(self.K):
|
|
||||||
|
|
||||||
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
|
|
||||||
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
|
|
||||||
rep_loss = torch.sum(torch.clamp_max(distances - 1.5, max=0.0) ** 2, dim=1)
|
|
||||||
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return self.offset_decay * (fitting_loss + repulsive_loss)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class KPFCNN(nn.Module):
|
class KPFCNN(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
@ -316,8 +326,9 @@ class KPFCNN(nn.Module):
|
||||||
self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1)
|
self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1)
|
||||||
else:
|
else:
|
||||||
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
|
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
|
||||||
self.offset_loss = config.offsets_loss
|
self.deform_fitting_mode = config.deform_fitting_mode
|
||||||
self.offset_decay = config.offsets_decay
|
self.deform_fitting_power = config.deform_fitting_power
|
||||||
|
self.deform_loss_power = config.deform_loss_power
|
||||||
self.output_loss = 0
|
self.output_loss = 0
|
||||||
self.reg_loss = 0
|
self.reg_loss = 0
|
||||||
self.l1 = nn.L1Loss()
|
self.l1 = nn.L1Loss()
|
||||||
|
@ -369,7 +380,12 @@ class KPFCNN(nn.Module):
|
||||||
self.output_loss = self.criterion(outputs, target)
|
self.output_loss = self.criterion(outputs, target)
|
||||||
|
|
||||||
# Regularization of deformable offsets
|
# Regularization of deformable offsets
|
||||||
self.reg_loss = self.offset_regularizer()
|
if self.deform_fitting_mode == 'point2point':
|
||||||
|
self.reg_loss = p2p_fitting_regularizer(self)
|
||||||
|
elif self.deform_fitting_mode == 'point2plane':
|
||||||
|
raise ValueError('point2plane fitting mode not implemented yet.')
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode)
|
||||||
|
|
||||||
# Combined loss
|
# Combined loss
|
||||||
return self.output_loss + self.reg_loss
|
return self.output_loss + self.reg_loss
|
||||||
|
@ -393,57 +409,6 @@ class KPFCNN(nn.Module):
|
||||||
|
|
||||||
return correct / total
|
return correct / total
|
||||||
|
|
||||||
def offset_regularizer(self):
|
|
||||||
|
|
||||||
fitting_loss = 0
|
|
||||||
repulsive_loss = 0
|
|
||||||
|
|
||||||
for m in self.modules():
|
|
||||||
|
|
||||||
if isinstance(m, KPConv) and m.deformable:
|
|
||||||
|
|
||||||
##############################
|
|
||||||
# divide offset gradient by 10
|
|
||||||
##############################
|
|
||||||
|
|
||||||
m.unscaled_offsets.register_hook(lambda grad: grad * 0.1)
|
|
||||||
#m.unscaled_offsets.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
|
|
||||||
|
|
||||||
##############
|
|
||||||
# Fitting loss
|
|
||||||
##############
|
|
||||||
|
|
||||||
# Get the distance to closest input point
|
|
||||||
KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
|
|
||||||
|
|
||||||
# Normalize KP locations to be independant from layers
|
|
||||||
KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
|
|
||||||
|
|
||||||
# Loss will be the square distance to closest input point. We use L1 because dist is already squared
|
|
||||||
fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2))
|
|
||||||
|
|
||||||
################
|
|
||||||
# Repulsive loss
|
|
||||||
################
|
|
||||||
|
|
||||||
# Normalized KP locations
|
|
||||||
KP_locs = m.deformed_KP / m.KP_extent
|
|
||||||
|
|
||||||
# Point should not be close to each other
|
|
||||||
for i in range(self.K):
|
|
||||||
|
|
||||||
other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
|
|
||||||
distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
|
|
||||||
rep_loss = torch.sum(torch.clamp_max(distances - 0.5, max=0.0) ** 2, dim=1)
|
|
||||||
repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K
|
|
||||||
|
|
||||||
|
|
||||||
return self.offset_decay * (fitting_loss + repulsive_loss)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -177,7 +177,7 @@ class KPConv(nn.Module):
|
||||||
# Running variable containing deformed KP distance to input points. (used in regularization loss)
|
# Running variable containing deformed KP distance to input points. (used in regularization loss)
|
||||||
self.deformed_d2 = None
|
self.deformed_d2 = None
|
||||||
self.deformed_KP = None
|
self.deformed_KP = None
|
||||||
self.unscaled_offsets = None
|
self.offset_features = None
|
||||||
|
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
self.weights = Parameter(torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32),
|
self.weights = Parameter(torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32),
|
||||||
|
@ -241,27 +241,27 @@ class KPConv(nn.Module):
|
||||||
###################
|
###################
|
||||||
|
|
||||||
if self.deformable:
|
if self.deformable:
|
||||||
offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
|
self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
|
||||||
|
|
||||||
if self.modulated:
|
if self.modulated:
|
||||||
|
|
||||||
# Get offset (in normalized scale) from features
|
# Get offset (in normalized scale) from features
|
||||||
offsets = offset_features[:, :self.p_dim * self.K]
|
unscaled_offsets = self.offset_features[:, :self.p_dim * self.K]
|
||||||
self.unscaled_offsets = offsets.view(-1, self.K, self.p_dim)
|
unscaled_offsets = unscaled_offsets.view(-1, self.K, self.p_dim)
|
||||||
|
|
||||||
# Get modulations
|
# Get modulations
|
||||||
modulations = 2 * torch.sigmoid(offset_features[:, self.p_dim * self.K:])
|
modulations = 2 * torch.sigmoid(self.offset_features[:, self.p_dim * self.K:])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# Get offset (in normalized scale) from features
|
# Get offset (in normalized scale) from features
|
||||||
self.unscaled_offsets = offset_features.view(-1, self.K, self.p_dim)
|
unscaled_offsets = self.offset_features.view(-1, self.K, self.p_dim)
|
||||||
|
|
||||||
# No modulations
|
# No modulations
|
||||||
modulations = None
|
modulations = None
|
||||||
|
|
||||||
# Rescale offset for this layer
|
# Rescale offset for this layer
|
||||||
offsets = self.unscaled_offsets * self.KP_extent
|
offsets = unscaled_offsets * self.KP_extent
|
||||||
|
|
||||||
else:
|
else:
|
||||||
offsets = None
|
offsets = None
|
||||||
|
|
|
@ -1515,6 +1515,9 @@ def S3DIS_deform(old_result_limit):
|
||||||
Debug S3DIS deformable.
|
Debug S3DIS deformable.
|
||||||
At checkpoint 50, the points seem to start fitting the shape, but then, they just get further away from each other
|
At checkpoint 50, the points seem to start fitting the shape, but then, they just get further away from each other
|
||||||
and do not care about input points. The fitting loss seems broken?
|
and do not care about input points. The fitting loss seems broken?
|
||||||
|
|
||||||
|
10* fitting loss seems pretty good fitting the point cloud. It seems that the offset decay was a bit to low,
|
||||||
|
because the same happens without the 0.1 hook. So we can try to keep a 0.5 hook and multiply offset decay by 2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
|
# Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset.
|
||||||
|
@ -1537,6 +1540,10 @@ def S3DIS_deform(old_result_limit):
|
||||||
'off_d=0.05_corrected',
|
'off_d=0.05_corrected',
|
||||||
'off_d=0.05_norepulsive',
|
'off_d=0.05_norepulsive',
|
||||||
'off_d=0.05_repulsive0.5',
|
'off_d=0.05_repulsive0.5',
|
||||||
|
'off_d=0.05_10*fitting',
|
||||||
|
'off_d=0.05_no_hook0.1',
|
||||||
|
'NEWPARAMS_fit=0.05_loss=0.5_(=off_d=0.1_hook0.5)',
|
||||||
|
'same_normal'
|
||||||
'test']
|
'test']
|
||||||
|
|
||||||
logs_names = np.array(logs_names[:len(logs)])
|
logs_names = np.array(logs_names[:len(logs)])
|
||||||
|
|
|
@ -78,11 +78,11 @@ class S3DISConfig(Config):
|
||||||
'resnetb',
|
'resnetb',
|
||||||
'resnetb',
|
'resnetb',
|
||||||
'resnetb_strided',
|
'resnetb_strided',
|
||||||
'resnetb_deformable',
|
'resnetb',
|
||||||
'resnetb_deformable',
|
'resnetb',
|
||||||
'resnetb_deformable_strided',
|
'resnetb_strided',
|
||||||
'resnetb_deformable',
|
'resnetb',
|
||||||
'resnetb_deformable',
|
'resnetb',
|
||||||
'nearest_upsample',
|
'nearest_upsample',
|
||||||
'unary',
|
'unary',
|
||||||
'nearest_upsample',
|
'nearest_upsample',
|
||||||
|
@ -132,10 +132,11 @@ class S3DISConfig(Config):
|
||||||
batch_norm_momentum = 0.02
|
batch_norm_momentum = 0.02
|
||||||
|
|
||||||
# Offset loss
|
# Offset loss
|
||||||
# 'permissive' only constrains offsets inside the deform radius (NOT implemented yet)
|
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
||||||
# 'fitting' helps deformed kernels to adapt to the geometry by penalizing distance to input points
|
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet
|
||||||
offsets_loss = 'fitting'
|
deform_fitting_mode = 'point2point'
|
||||||
offsets_decay = 0.05
|
deform_fitting_power = 0.05
|
||||||
|
deform_loss_power = 0.5
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
# Training parameters
|
# Training parameters
|
||||||
|
@ -195,7 +196,7 @@ if __name__ == '__main__':
|
||||||
############################
|
############################
|
||||||
|
|
||||||
# Set which gpu is going to be used
|
# Set which gpu is going to be used
|
||||||
GPU_ID = '3'
|
GPU_ID = '2'
|
||||||
|
|
||||||
# Set GPU visible device
|
# Set GPU visible device
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
|
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
|
||||||
|
|
|
@ -159,9 +159,10 @@ class Config:
|
||||||
# Choose weights for class (used in segmentation loss). Empty list for no weights
|
# Choose weights for class (used in segmentation loss). Empty list for no weights
|
||||||
class_w = []
|
class_w = []
|
||||||
|
|
||||||
# Offset regularization loss
|
# New offset regularization parameters
|
||||||
offsets_loss = 'permissive'
|
deform_fitting_mode = 'point2point'
|
||||||
offsets_decay = 1e-2
|
deform_fitting_power = 0.05
|
||||||
|
deform_loss_power = 0.5
|
||||||
|
|
||||||
# Number of batch
|
# Number of batch
|
||||||
batch_num = 10
|
batch_num = 10
|
||||||
|
@ -259,7 +260,7 @@ class Config:
|
||||||
elif line_info[0] == 'class_w':
|
elif line_info[0] == 'class_w':
|
||||||
self.class_w = [float(w) for w in line_info[2:]]
|
self.class_w = [float(w) for w in line_info[2:]]
|
||||||
|
|
||||||
else:
|
elif hasattr(self, line_info[0]):
|
||||||
attr_type = type(getattr(self, line_info[0]))
|
attr_type = type(getattr(self, line_info[0]))
|
||||||
if attr_type == bool:
|
if attr_type == bool:
|
||||||
setattr(self, line_info[0], attr_type(int(line_info[2])))
|
setattr(self, line_info[0], attr_type(int(line_info[2])))
|
||||||
|
@ -362,8 +363,9 @@ class Config:
|
||||||
for a in self.class_w:
|
for a in self.class_w:
|
||||||
text_file.write(' {:.3f}'.format(a))
|
text_file.write(' {:.3f}'.format(a))
|
||||||
text_file.write('\n')
|
text_file.write('\n')
|
||||||
text_file.write('offsets_loss = {:s}\n'.format(self.offsets_loss))
|
text_file.write('deform_fitting_mode = {:s}\n'.format(self.deform_fitting_mode))
|
||||||
text_file.write('offsets_decay = {:f}\n'.format(self.offsets_decay))
|
text_file.write('deform_fitting_power = {:f}\n'.format(self.deform_fitting_power))
|
||||||
|
text_file.write('deform_loss_power = {:f}\n'.format(self.deform_loss_power))
|
||||||
text_file.write('batch_num = {:d}\n'.format(self.batch_num))
|
text_file.write('batch_num = {:d}\n'.format(self.batch_num))
|
||||||
text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num))
|
text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num))
|
||||||
text_file.write('max_epoch = {:d}\n'.format(self.max_epoch))
|
text_file.write('max_epoch = {:d}\n'.format(self.max_epoch))
|
||||||
|
|
|
@ -79,7 +79,14 @@ class ModelVisualizer:
|
||||||
##########################
|
##########################
|
||||||
|
|
||||||
checkpoint = torch.load(chkp_path)
|
checkpoint = torch.load(chkp_path)
|
||||||
net.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
|
new_dict = {}
|
||||||
|
for k, v in checkpoint['model_state_dict'].items():
|
||||||
|
if 'blocs' in k:
|
||||||
|
k = k.replace('blocs', 'blocks')
|
||||||
|
new_dict[k] = v
|
||||||
|
|
||||||
|
net.load_state_dict(new_dict)
|
||||||
self.epoch = checkpoint['epoch']
|
self.epoch = checkpoint['epoch']
|
||||||
net.eval()
|
net.eval()
|
||||||
print("\nModel state restored from {:s}.".format(chkp_path))
|
print("\nModel state restored from {:s}.".format(chkp_path))
|
||||||
|
|
|
@ -96,7 +96,12 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# chosen_log = 'results/Log_2020-04-04_10-04-42' # => ModelNet40
|
# chosen_log = 'results/Log_2020-04-04_10-04-42' # => ModelNet40
|
||||||
# chosen_log = 'results/Log_2020-04-22_11-53-45' # => S3DIS
|
# chosen_log = 'results/Log_2020-04-22_11-53-45' # => S3DIS
|
||||||
chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected
|
# chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected
|
||||||
|
# chosen_log = 'results/Log_2020-04-23_09-48-15' # => S3DIS no repulsive
|
||||||
|
# chosen_log = 'results/Log_2020-04-23_09-49-49' # => S3DIS repulsive 0.5
|
||||||
|
# chosen_log = 'results/Log_2020-04-23_19-41-12' # => S3DIS 10*fitting
|
||||||
|
chosen_log = 'results/Log_2020-04-23_19-42-18' # => S3DIS no hook
|
||||||
|
|
||||||
|
|
||||||
# You can also choose the index of the snapshot to load (last by default)
|
# You can also choose the index of the snapshot to load (last by default)
|
||||||
chkp_idx = -1
|
chkp_idx = -1
|
||||||
|
|
Loading…
Reference in a new issue