diff --git a/models/architectures.py b/models/architectures.py
index 684ab02..e8cb254 100644
--- a/models/architectures.py
+++ b/models/architectures.py
@@ -27,19 +27,6 @@ def p2p_fitting_regularizer(net):
 
         if isinstance(m, KPConv) and m.deformable:
 
-            ########################
-            # divide offset gradient
-            ########################
-            #
-            # The offset gradient comes from two different losses. The regularizer loss (fitting deformation to point
-            # cloud) and the output loss (which should force deformations to help get a better score). The strength of
-            # the regularizer loss is set with the parameter deform_fitting_power. Therefore, this hook control the
-            # strength of the output loss. This strength can be set with the parameter deform_loss_power
-            #
-
-            m.offset_features.register_hook(lambda grad: grad * net.deform_loss_power)
-            # m.offset_features.register_hook(lambda grad: print('GRAD2', grad[10, 5, :]))
-
             ##############
             # Fitting loss
             ##############
@@ -64,8 +51,7 @@ def p2p_fitting_regularizer(net):
                 rep_loss = torch.sum(torch.clamp_max(distances - net.repulse_extent, max=0.0) ** 2, dim=1)
                 repulsive_loss += net.l1(rep_loss, torch.zeros_like(rep_loss)) / net.K
 
-    # The hook effectively affect both regularizer and output loss. So here we have to divide by deform_loss_power
-    return (net.deform_fitting_power / net.deform_loss_power) * (2 * fitting_loss + repulsive_loss)
+    return net.deform_fitting_power * (2 * fitting_loss + repulsive_loss)
 
 
 class KPCNN(nn.Module):
@@ -139,7 +125,7 @@ class KPCNN(nn.Module):
         self.criterion = torch.nn.CrossEntropyLoss()
         self.deform_fitting_mode = config.deform_fitting_mode
         self.deform_fitting_power = config.deform_fitting_power
-        self.deform_loss_power = config.deform_loss_power
+        self.deform_lr_factor = config.deform_lr_factor
         self.repulse_extent = config.repulse_extent
         self.output_loss = 0
         self.reg_loss = 0
@@ -325,7 +311,7 @@ class KPFCNN(nn.Module):
         self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
         self.deform_fitting_mode = config.deform_fitting_mode
         self.deform_fitting_power = config.deform_fitting_power
-        self.deform_loss_power = config.deform_loss_power
+        self.deform_lr_factor = config.deform_lr_factor
         self.repulse_extent = config.repulse_extent
         self.output_loss = 0
         self.reg_loss = 0
diff --git a/models/blocks.py b/models/blocks.py
index 918cd14..86b04a3 100644
--- a/models/blocks.py
+++ b/models/blocks.py
@@ -174,8 +174,6 @@ class KPConv(nn.Module):
         self.deformable = deformable
         self.modulated = modulated
 
-        self.in_offset_channels = in_channels
-
         # Running variable containing deformed KP distance to input points. (used in regularization loss)
         self.min_d2 = None
         self.deformed_KP = None
@@ -193,7 +191,7 @@ class KPConv(nn.Module):
                 self.offset_dim = self.p_dim * self.K
             self.offset_conv = KPConv(self.K,
                                       self.p_dim,
-                                      self.in_offset_channels,
+                                      self.in_channels,
                                       self.offset_dim,
                                       KP_extent,
                                       radius,
@@ -245,8 +243,7 @@ class KPConv(nn.Module):
         if self.deformable:
 
             # Get offsets with a KPConv that only takes part of the features
-            x_offsets = x[:, :self.in_offset_channels]
-            self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x_offsets) + self.offset_bias
+            self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
 
             if self.modulated:
 
diff --git a/train_ModelNet40.py b/train_ModelNet40.py
index c8cb1b2..3f5ef08 100644
--- a/train_ModelNet40.py
+++ b/train_ModelNet40.py
@@ -124,8 +124,8 @@ class Modelnet40Config(Config):
     # 'point2point' fitting geometry by penalizing distance from deform point to input points
     # 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
     deform_fitting_mode = 'point2point'
-    deform_fitting_power = 0.1              # Multiplier for the fitting/repulsive loss
-    deform_loss_power = 0.1                 # Multiplier for output loss applied to the deformations
+    deform_fitting_power = 1.0              # Multiplier for the fitting/repulsive loss
+    deform_lr_factor = 0.1                  # Multiplier for learning rate applied to the deformations
     repulse_extent = 0.8                    # Distance of repulsion for deformed kernel points
 
     #####################
diff --git a/train_S3DIS.py b/train_S3DIS.py
index 3da3146..674a92a 100644
--- a/train_S3DIS.py
+++ b/train_S3DIS.py
@@ -132,9 +132,9 @@ class S3DISConfig(Config):
     # 'point2point' fitting geometry by penalizing distance from deform point to input points
     # 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
     deform_fitting_mode = 'point2point'
-    deform_fitting_power = 0.1              # Multiplier for the fitting/repulsive loss
-    deform_loss_power = 0.1                 # Multiplier for output loss applied to the deformations
-    repulse_extent = 0.8                    # Distance of repulsion for deformed kernel points
+    deform_fitting_power = 1.0              # Multiplier for the fitting/repulsive loss
+    deform_lr_factor = 0.1                  # Multiplier for learning rate applied to the deformations
+    repulse_extent = 1.2                    # Distance of repulsion for deformed kernel points
 
     #####################
     # Training parameters
diff --git a/train_SemanticKitti.py b/train_SemanticKitti.py
index d4579c5..cd98834 100644
--- a/train_SemanticKitti.py
+++ b/train_SemanticKitti.py
@@ -146,8 +146,8 @@ class SemanticKittiConfig(Config):
     # 'point2point' fitting geometry by penalizing distance from deform point to input points
     # 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
    deform_fitting_mode = 'point2point'
-    deform_fitting_power = 0.1              # Multiplier for the fitting/repulsive loss
-    deform_loss_power = 0.1                 # Multiplier for output loss applied to the deformations
+    deform_fitting_power = 1.0              # Multiplier for the fitting/repulsive loss
+    deform_lr_factor = 0.1                  # Multiplier for learning rate applied to the deformations
     repulse_extent = 0.8                    # Distance of repulsion for deformed kernel points
 
     #####################
diff --git a/utils/config.py b/utils/config.py
index 688f735..094774b 100644
--- a/utils/config.py
+++ b/utils/config.py
@@ -163,8 +163,8 @@ class Config:
     # 'point2point' fitting geometry by penalizing distance from deform point to input points
     # 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented)
     deform_fitting_mode = 'point2point'
-    deform_fitting_power = 0.1              # Multiplier for the fitting/repulsive loss
-    deform_loss_power = 0.1                 # Multiplier for output loss applied to the deformations
+    deform_fitting_power = 1.0              # Multiplier for the fitting/repulsive loss
+    deform_lr_factor = 0.1                  # Multiplier for learning rate applied to the deformations
     repulse_extent = 1.0                    # Distance of repulsion for deformed kernel points
 
     # Number of batch
@@ -368,7 +368,7 @@ class Config:
             text_file.write('\n')
             text_file.write('deform_fitting_mode = {:s}\n'.format(self.deform_fitting_mode))
             text_file.write('deform_fitting_power = {:.6f}\n'.format(self.deform_fitting_power))
-            text_file.write('deform_loss_power = {:.6f}\n'.format(self.deform_loss_power))
+            text_file.write('deform_lr_factor = {:.6f}\n'.format(self.deform_lr_factor))
             text_file.write('repulse_extent = {:.6f}\n'.format(self.repulse_extent))
             text_file.write('batch_num = {:d}\n'.format(self.batch_num))
             text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num))
diff --git a/utils/trainer.py b/utils/trainer.py
index c4a255e..5b0aadb 100644
--- a/utils/trainer.py
+++ b/utils/trainer.py
@@ -74,8 +74,12 @@ class ModelTrainer:
         self.epoch = 0
         self.step = 0
 
-        # Optimizer
-        self.optimizer = torch.optim.SGD(net.parameters(),
+        # Optimizer with specific learning rate for deformable KPConv
+        deform_params = [v for k, v in net.named_parameters() if 'offset' in k]
+        other_params = [v for k, v in net.named_parameters() if 'offset' not in k]
+        deform_lr = config.learning_rate * config.deform_lr_factor
+        self.optimizer = torch.optim.SGD([{'params': other_params},
+                                          {'params': deform_params, 'lr': deform_lr}],
                                          lr=config.learning_rate,
                                          momentum=config.momentum,
                                          weight_decay=config.weight_decay)
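
Note for reviewers: the sketch below is not part of the patch. It is a minimal, self-contained illustration (ToyNet and its parameter names are made up for this note) of why giving the 'offset' parameters their own optimizer parameter group with lr = learning_rate * deform_lr_factor replaces the removed gradient hook: for plain SGD, scaling a parameter's learning rate scales its update exactly as scaling its gradient did, so the output loss still drives the deformations with reduced strength while p2p_fitting_regularizer no longer needs the deform_loss_power compensation. With momentum and weight decay enabled, as in the real trainer, the equivalence is only approximate.

import torch


class ToyNet(torch.nn.Module):
    # Stand-in for a network containing a deformable KPConv. Only the parameter
    # names matter here: anything with 'offset' in its name is treated as a
    # deformation parameter, mirroring the filtering added in utils/trainer.py.
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3))
        self.offset_bias = torch.nn.Parameter(torch.zeros(3))

    def forward(self):
        # Dummy scalar standing in for the output loss; it touches both parameters.
        return self.weight.sum() + self.offset_bias.sum()


net = ToyNet()
learning_rate = 1e-2
deform_lr_factor = 0.1   # plays the role of the new config value

# Same parameter split as in ModelTrainer.__init__
deform_params = [v for k, v in net.named_parameters() if 'offset' in k]
other_params = [v for k, v in net.named_parameters() if 'offset' not in k]

optimizer = torch.optim.SGD([{'params': other_params},
                             {'params': deform_params,
                              'lr': learning_rate * deform_lr_factor}],
                            lr=learning_rate)

# Plain SGD updates p <- p - lr * p.grad, so the 0.1x learning rate on
# offset_bias has the same effect as the old hook that multiplied its gradient
# by deform_loss_power = 0.1. The fitting/repulsive regularizer now sees the
# same 0.1x factor through the learning rate, which is why deform_fitting_power
# moves from 0.1 to 1.0 in the configs above.
net().backward()
optimizer.step()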