Corrections
This commit is contained in:
parent
5b7ecb418c
commit
a101ed2e92
|
@ -27,19 +27,6 @@ def p2p_fitting_regularizer(net):
|
||||||
|
|
||||||
if isinstance(m, KPConv) and m.deformable:
|
if isinstance(m, KPConv) and m.deformable:
|
||||||
|
|
||||||
########################
|
|
||||||
# divide offset gradient
|
|
||||||
########################
|
|
||||||
#
|
|
||||||
# The offset gradient comes from two different losses. The regularizer loss (fitting deformation to point
|
|
||||||
# cloud) and the output loss (which should force deformations to help get a better score). The strength of
|
|
||||||
# the regularizer loss is set with the parameter deform_fitting_power. Therefore, this hook control the
|
|
||||||
# strength of the output loss. This strength can be set with the parameter deform_loss_power
|
|
||||||
#
|
|
||||||
|
|
||||||
m.offset_features.register_hook(lambda grad: grad * net.deform_loss_power)
|
|
||||||
# m.offset_features.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
|
|
||||||
|
|
||||||
##############
|
##############
|
||||||
# Fitting loss
|
# Fitting loss
|
||||||
##############
|
##############
|
||||||
|
@ -64,8 +51,7 @@ def p2p_fitting_regularizer(net):
|
||||||
rep_loss = torch.sum(torch.clamp_max(distances - net.repulse_extent, max=0.0) ** 2, dim=1)
|
rep_loss = torch.sum(torch.clamp_max(distances - net.repulse_extent, max=0.0) ** 2, dim=1)
|
||||||
repulsive_loss += net.l1(rep_loss, torch.zeros_like(rep_loss)) / net.K
|
repulsive_loss += net.l1(rep_loss, torch.zeros_like(rep_loss)) / net.K
|
||||||
|
|
||||||
# The hook effectively affect both regularizer and output loss. So here we have to divide by deform_loss_power
|
return net.deform_fitting_power * (2 * fitting_loss + repulsive_loss)
|
||||||
return (net.deform_fitting_power / net.deform_loss_power) * (2 * fitting_loss + repulsive_loss)
|
|
||||||
|
|
||||||
|
|
||||||
class KPCNN(nn.Module):
|
class KPCNN(nn.Module):
|
||||||
|
@ -139,7 +125,7 @@ class KPCNN(nn.Module):
|
||||||
self.criterion = torch.nn.CrossEntropyLoss()
|
self.criterion = torch.nn.CrossEntropyLoss()
|
||||||
self.deform_fitting_mode = config.deform_fitting_mode
|
self.deform_fitting_mode = config.deform_fitting_mode
|
||||||
self.deform_fitting_power = config.deform_fitting_power
|
self.deform_fitting_power = config.deform_fitting_power
|
||||||
self.deform_loss_power = config.deform_loss_power
|
self.deform_lr_factor = config.deform_lr_factor
|
||||||
self.repulse_extent = config.repulse_extent
|
self.repulse_extent = config.repulse_extent
|
||||||
self.output_loss = 0
|
self.output_loss = 0
|
||||||
self.reg_loss = 0
|
self.reg_loss = 0
|
||||||
|
@ -325,7 +311,7 @@ class KPFCNN(nn.Module):
|
||||||
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
|
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
|
||||||
self.deform_fitting_mode = config.deform_fitting_mode
|
self.deform_fitting_mode = config.deform_fitting_mode
|
||||||
self.deform_fitting_power = config.deform_fitting_power
|
self.deform_fitting_power = config.deform_fitting_power
|
||||||
self.deform_loss_power = config.deform_loss_power
|
self.deform_lr_factor = config.deform_lr_factor
|
||||||
self.repulse_extent = config.repulse_extent
|
self.repulse_extent = config.repulse_extent
|
||||||
self.output_loss = 0
|
self.output_loss = 0
|
||||||
self.reg_loss = 0
|
self.reg_loss = 0
|
||||||
|
|
|
@ -174,8 +174,6 @@ class KPConv(nn.Module):
|
||||||
self.deformable = deformable
|
self.deformable = deformable
|
||||||
self.modulated = modulated
|
self.modulated = modulated
|
||||||
|
|
||||||
self.in_offset_channels = in_channels
|
|
||||||
|
|
||||||
# Running variable containing deformed KP distance to input points. (used in regularization loss)
|
# Running variable containing deformed KP distance to input points. (used in regularization loss)
|
||||||
self.min_d2 = None
|
self.min_d2 = None
|
||||||
self.deformed_KP = None
|
self.deformed_KP = None
|
||||||
|
@ -193,7 +191,7 @@ class KPConv(nn.Module):
|
||||||
self.offset_dim = self.p_dim * self.K
|
self.offset_dim = self.p_dim * self.K
|
||||||
self.offset_conv = KPConv(self.K,
|
self.offset_conv = KPConv(self.K,
|
||||||
self.p_dim,
|
self.p_dim,
|
||||||
self.in_offset_channels,
|
self.in_channels,
|
||||||
self.offset_dim,
|
self.offset_dim,
|
||||||
KP_extent,
|
KP_extent,
|
||||||
radius,
|
radius,
|
||||||
|
@ -245,8 +243,7 @@ class KPConv(nn.Module):
|
||||||
if self.deformable:
|
if self.deformable:
|
||||||
|
|
||||||
# Get offsets with a KPConv that only takes part of the features
|
# Get offsets with a KPConv that only takes part of the features
|
||||||
x_offsets = x[:, :self.in_offset_channels]
|
self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
|
||||||
self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x_offsets) + self.offset_bias
|
|
||||||
|
|
||||||
if self.modulated:
|
if self.modulated:
|
||||||
|
|
||||||
|
|
|
@ -124,8 +124,8 @@ class Modelnet40Config(Config):
|
||||||
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
||||||
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
|
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
|
||||||
deform_fitting_mode = 'point2point'
|
deform_fitting_mode = 'point2point'
|
||||||
deform_fitting_power = 0.1 # Multiplier for the fitting/repulsive loss
|
deform_fitting_power = 1.0 # Multiplier for the fitting/repulsive loss
|
||||||
deform_loss_power = 0.1 # Multiplier for output loss applied to the deformations
|
deform_lr_factor = 0.1 # Multiplier for learning rate applied to the deformations
|
||||||
repulse_extent = 0.8 # Distance of repulsion for deformed kernel points
|
repulse_extent = 0.8 # Distance of repulsion for deformed kernel points
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
|
|
|
@ -132,9 +132,9 @@ class S3DISConfig(Config):
|
||||||
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
||||||
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
|
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
|
||||||
deform_fitting_mode = 'point2point'
|
deform_fitting_mode = 'point2point'
|
||||||
deform_fitting_power = 0.1 # Multiplier for the fitting/repulsive loss
|
deform_fitting_power = 1.0 # Multiplier for the fitting/repulsive loss
|
||||||
deform_loss_power = 0.1 # Multiplier for output loss applied to the deformations
|
deform_lr_factor = 0.1 # Multiplier for learning rate applied to the deformations
|
||||||
repulse_extent = 0.8 # Distance of repulsion for deformed kernel points
|
repulse_extent = 1.2 # Distance of repulsion for deformed kernel points
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
# Training parameters
|
# Training parameters
|
||||||
|
|
|
@ -146,8 +146,8 @@ class SemanticKittiConfig(Config):
|
||||||
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
||||||
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
|
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
|
||||||
deform_fitting_mode = 'point2point'
|
deform_fitting_mode = 'point2point'
|
||||||
deform_fitting_power = 0.1 # Multiplier for the fitting/repulsive loss
|
deform_fitting_power = 1.0 # Multiplier for the fitting/repulsive loss
|
||||||
deform_loss_power = 0.1 # Multiplier for output loss applied to the deformations
|
deform_lr_factor = 0.1 # Multiplier for learning rate applied to the deformations
|
||||||
repulse_extent = 0.8 # Distance of repulsion for deformed kernel points
|
repulse_extent = 0.8 # Distance of repulsion for deformed kernel points
|
||||||
|
|
||||||
#####################
|
#####################
|
||||||
|
|
|
@ -163,8 +163,8 @@ class Config:
|
||||||
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
# 'point2point' fitting geometry by penalizing distance from deform point to input points
|
||||||
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
|
# 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
|
||||||
deform_fitting_mode = 'point2point'
|
deform_fitting_mode = 'point2point'
|
||||||
deform_fitting_power = 0.1 # Multiplier for the fitting/repulsive loss
|
deform_fitting_power = 1.0 # Multiplier for the fitting/repulsive loss
|
||||||
deform_loss_power = 0.1 # Multiplier for output loss applied to the deformations
|
deform_lr_factor = 0.1 # Multiplier for learning rate applied to the deformations
|
||||||
repulse_extent = 1.0 # Distance of repulsion for deformed kernel points
|
repulse_extent = 1.0 # Distance of repulsion for deformed kernel points
|
||||||
|
|
||||||
# Number of batch
|
# Number of batch
|
||||||
|
@ -368,7 +368,7 @@ class Config:
|
||||||
text_file.write('\n')
|
text_file.write('\n')
|
||||||
text_file.write('deform_fitting_mode = {:s}\n'.format(self.deform_fitting_mode))
|
text_file.write('deform_fitting_mode = {:s}\n'.format(self.deform_fitting_mode))
|
||||||
text_file.write('deform_fitting_power = {:.6f}\n'.format(self.deform_fitting_power))
|
text_file.write('deform_fitting_power = {:.6f}\n'.format(self.deform_fitting_power))
|
||||||
text_file.write('deform_loss_power = {:.6f}\n'.format(self.deform_loss_power))
|
text_file.write('deform_lr_factor = {:.6f}\n'.format(self.deform_lr_factor))
|
||||||
text_file.write('repulse_extent = {:.6f}\n'.format(self.repulse_extent))
|
text_file.write('repulse_extent = {:.6f}\n'.format(self.repulse_extent))
|
||||||
text_file.write('batch_num = {:d}\n'.format(self.batch_num))
|
text_file.write('batch_num = {:d}\n'.format(self.batch_num))
|
||||||
text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num))
|
text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num))
|
||||||
|
|
|
@ -74,8 +74,12 @@ class ModelTrainer:
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
# Optimizer
|
# Optimizer with specific learning rate for deformable KPConv
|
||||||
self.optimizer = torch.optim.SGD(net.parameters(),
|
deform_params = [v for k, v in net.named_parameters() if 'offset' in k]
|
||||||
|
other_params = [v for k, v in net.named_parameters() if 'offset' not in k]
|
||||||
|
deform_lr = config.learning_rate * config.deform_lr_factor
|
||||||
|
self.optimizer = torch.optim.SGD([{'params': other_params},
|
||||||
|
{'params': deform_params, 'lr': deform_lr}],
|
||||||
lr=config.learning_rate,
|
lr=config.learning_rate,
|
||||||
momentum=config.momentum,
|
momentum=config.momentum,
|
||||||
weight_decay=config.weight_decay)
|
weight_decay=config.weight_decay)
|
||||||
|
|
Loading…
Reference in a new issue