diff --git a/models/architectures.py b/models/architectures.py index 42d98e9..b01c8bc 100644 --- a/models/architectures.py +++ b/models/architectures.py @@ -17,6 +17,61 @@ from models.blocks import * import numpy as np + +def p2p_fitting_regularizer(net): + + fitting_loss = 0 + repulsive_loss = 0 + + for m in net.modules(): + + if isinstance(m, KPConv) and m.deformable: + + ######################## + # divide offset gradient + ######################## + # + # The offset gradient comes from two different losses. The regularizer loss (fitting deformation to point + # cloud) and the output loss (which should force deformations to help get a better score). The strength of + # the regularizer loss is set with the parameter deform_fitting_power. Therefore, this hook control the + # strength of the output loss. This strength can be set with the parameter deform_loss_power + # + + m.offset_features.register_hook(lambda grad: grad * net.deform_loss_power) + # m.offset_features.register_hook(lambda grad: print('GRAD2', grad[10, 5, :])) + + ############## + # Fitting loss + ############## + + # Get the distance to closest input point + + KP_min_d2, _ = torch.min(m.deformed_d2, dim=1) + + # Normalize KP locations to be independant from layers + KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2) + + # Loss will be the square distance to closest input point. We use L1 because dist is already squared + fitting_loss += net.l1(KP_min_d2, torch.zeros_like(KP_min_d2)) + + ################ + # Repulsive loss + ################ + + # Normalized KP locations + KP_locs = m.deformed_KP / m.KP_extent + + # Point should not be close to each other + for i in range(net.K): + other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach() + distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2)) + rep_loss = torch.sum(torch.clamp_max(distances - 1.0, max=0.0) ** 2, dim=1) + repulsive_loss += net.l1(rep_loss, torch.zeros_like(rep_loss)) / net.K + + # The hook effectively affect both regularizer and output loss. So here we have to divide by deform_loss_power + return (net.deform_fitting_power / net.deform_loss_power) * (fitting_loss + repulsive_loss) + + class KPCNN(nn.Module): """ Class defining KPCNN @@ -86,8 +141,9 @@ class KPCNN(nn.Module): ################ self.criterion = torch.nn.CrossEntropyLoss() - self.offset_loss = config.offsets_loss - self.offset_decay = config.offsets_decay + self.deform_fitting_mode = config.deform_fitting_mode + self.deform_fitting_power = config.deform_fitting_power + self.deform_loss_power = config.deform_loss_power self.output_loss = 0 self.reg_loss = 0 self.l1 = nn.L1Loss() @@ -121,7 +177,12 @@ class KPCNN(nn.Module): self.output_loss = self.criterion(outputs, labels) # Regularization of deformable offsets - self.reg_loss = self.offset_regularizer() + if self.deform_fitting_mode == 'point2point': + self.reg_loss = p2p_fitting_regularizer(self) + elif self.deform_fitting_mode == 'point2plane': + raise ValueError('point2plane fitting mode not implemented yet.') + else: + raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode) # Combined loss return self.output_loss + self.reg_loss @@ -141,57 +202,6 @@ class KPCNN(nn.Module): return correct / total - def offset_regularizer(self): - - fitting_loss = 0 - repulsive_loss = 0 - - for m in self.modules(): - - if isinstance(m, KPConv) and m.deformable: - - ############################## - # divide offset gradient by 10 - ############################## - - m.unscaled_offsets.register_hook(lambda grad: grad * 0.1) - #m.unscaled_offsets.register_hook(lambda grad: print('GRAD2', grad[10, 5, :])) - - ############## - # Fitting loss - ############## - - # Get the distance to closest input point - KP_min_d2, _ = torch.min(m.deformed_d2, dim=1) - - # Normalize KP locations to be independant from layers - KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2) - - # Loss will be the square distance to closest input point. We use L1 because dist is already squared - fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2)) - - ################ - # Repulsive loss - ################ - - # Normalized KP locations - KP_locs = m.deformed_KP / m.KP_extent - - # Point should not be close to each other - for i in range(self.K): - - other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach() - distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2)) - rep_loss = torch.sum(torch.clamp_max(distances - 1.5, max=0.0) ** 2, dim=1) - repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K - - - - return self.offset_decay * (fitting_loss + repulsive_loss) - - - - class KPFCNN(nn.Module): """ @@ -316,8 +326,9 @@ class KPFCNN(nn.Module): self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1) else: self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1) - self.offset_loss = config.offsets_loss - self.offset_decay = config.offsets_decay + self.deform_fitting_mode = config.deform_fitting_mode + self.deform_fitting_power = config.deform_fitting_power + self.deform_loss_power = config.deform_loss_power self.output_loss = 0 self.reg_loss = 0 self.l1 = nn.L1Loss() @@ -369,7 +380,12 @@ class KPFCNN(nn.Module): self.output_loss = self.criterion(outputs, target) # Regularization of deformable offsets - self.reg_loss = self.offset_regularizer() + if self.deform_fitting_mode == 'point2point': + self.reg_loss = p2p_fitting_regularizer(self) + elif self.deform_fitting_mode == 'point2plane': + raise ValueError('point2plane fitting mode not implemented yet.') + else: + raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode) # Combined loss return self.output_loss + self.reg_loss @@ -393,57 +409,6 @@ class KPFCNN(nn.Module): return correct / total - def offset_regularizer(self): - - fitting_loss = 0 - repulsive_loss = 0 - - for m in self.modules(): - - if isinstance(m, KPConv) and m.deformable: - - ############################## - # divide offset gradient by 10 - ############################## - - m.unscaled_offsets.register_hook(lambda grad: grad * 0.1) - #m.unscaled_offsets.register_hook(lambda grad: print('GRAD2', grad[10, 5, :])) - - ############## - # Fitting loss - ############## - - # Get the distance to closest input point - KP_min_d2, _ = torch.min(m.deformed_d2, dim=1) - - # Normalize KP locations to be independant from layers - KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2) - - # Loss will be the square distance to closest input point. We use L1 because dist is already squared - fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2)) - - ################ - # Repulsive loss - ################ - - # Normalized KP locations - KP_locs = m.deformed_KP / m.KP_extent - - # Point should not be close to each other - for i in range(self.K): - - other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach() - distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2)) - rep_loss = torch.sum(torch.clamp_max(distances - 0.5, max=0.0) ** 2, dim=1) - repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K - - - return self.offset_decay * (fitting_loss + repulsive_loss) - - - - - diff --git a/models/blocks.py b/models/blocks.py index f68f9d6..4a9a241 100644 --- a/models/blocks.py +++ b/models/blocks.py @@ -177,7 +177,7 @@ class KPConv(nn.Module): # Running variable containing deformed KP distance to input points. (used in regularization loss) self.deformed_d2 = None self.deformed_KP = None - self.unscaled_offsets = None + self.offset_features = None # Initialize weights self.weights = Parameter(torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32), @@ -241,27 +241,27 @@ class KPConv(nn.Module): ################### if self.deformable: - offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias + self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias if self.modulated: # Get offset (in normalized scale) from features - offsets = offset_features[:, :self.p_dim * self.K] - self.unscaled_offsets = offsets.view(-1, self.K, self.p_dim) + unscaled_offsets = self.offset_features[:, :self.p_dim * self.K] + unscaled_offsets = unscaled_offsets.view(-1, self.K, self.p_dim) # Get modulations - modulations = 2 * torch.sigmoid(offset_features[:, self.p_dim * self.K:]) + modulations = 2 * torch.sigmoid(self.offset_features[:, self.p_dim * self.K:]) else: # Get offset (in normalized scale) from features - self.unscaled_offsets = offset_features.view(-1, self.K, self.p_dim) + unscaled_offsets = self.offset_features.view(-1, self.K, self.p_dim) # No modulations modulations = None # Rescale offset for this layer - offsets = self.unscaled_offsets * self.KP_extent + offsets = unscaled_offsets * self.KP_extent else: offsets = None diff --git a/plot_convergence.py b/plot_convergence.py index 4370ca9..bc81afe 100644 --- a/plot_convergence.py +++ b/plot_convergence.py @@ -1515,6 +1515,9 @@ def S3DIS_deform(old_result_limit): Debug S3DIS deformable. At checkpoint 50, the points seem to start fitting the shape, but then, they just get further away from each other and do not care about input points. The fitting loss seems broken? + + 10* fitting loss seems pretty good fitting the point cloud. It seems that the offset decay was a bit to low, + because the same happens without the 0.1 hook. So we can try to keep a 0.5 hook and multiply offset decay by 2. """ # Using the dates of the logs, you can easily gather consecutive ones. All logs should be of the same dataset. @@ -1537,6 +1540,10 @@ def S3DIS_deform(old_result_limit): 'off_d=0.05_corrected', 'off_d=0.05_norepulsive', 'off_d=0.05_repulsive0.5', + 'off_d=0.05_10*fitting', + 'off_d=0.05_no_hook0.1', + 'NEWPARAMS_fit=0.05_loss=0.5_(=off_d=0.1_hook0.5)', + 'same_normal' 'test'] logs_names = np.array(logs_names[:len(logs)]) diff --git a/train_S3DIS.py b/train_S3DIS.py index 471607f..276fffd 100644 --- a/train_S3DIS.py +++ b/train_S3DIS.py @@ -78,11 +78,11 @@ class S3DISConfig(Config): 'resnetb', 'resnetb', 'resnetb_strided', - 'resnetb_deformable', - 'resnetb_deformable', - 'resnetb_deformable_strided', - 'resnetb_deformable', - 'resnetb_deformable', + 'resnetb', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', 'nearest_upsample', 'unary', 'nearest_upsample', @@ -132,10 +132,11 @@ class S3DISConfig(Config): batch_norm_momentum = 0.02 # Offset loss - # 'permissive' only constrains offsets inside the deform radius (NOT implemented yet) - # 'fitting' helps deformed kernels to adapt to the geometry by penalizing distance to input points - offsets_loss = 'fitting' - offsets_decay = 0.05 + # 'point2point' fitting geometry by penalizing distance from deform point to input points + # 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet + deform_fitting_mode = 'point2point' + deform_fitting_power = 0.05 + deform_loss_power = 0.5 ##################### # Training parameters @@ -195,7 +196,7 @@ if __name__ == '__main__': ############################ # Set which gpu is going to be used - GPU_ID = '3' + GPU_ID = '2' # Set GPU visible device os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID diff --git a/utils/config.py b/utils/config.py index 12f0bbd..c77083e 100644 --- a/utils/config.py +++ b/utils/config.py @@ -159,9 +159,10 @@ class Config: # Choose weights for class (used in segmentation loss). Empty list for no weights class_w = [] - # Offset regularization loss - offsets_loss = 'permissive' - offsets_decay = 1e-2 + # New offset regularization parameters + deform_fitting_mode = 'point2point' + deform_fitting_power = 0.05 + deform_loss_power = 0.5 # Number of batch batch_num = 10 @@ -259,7 +260,7 @@ class Config: elif line_info[0] == 'class_w': self.class_w = [float(w) for w in line_info[2:]] - else: + elif hasattr(self, line_info[0]): attr_type = type(getattr(self, line_info[0])) if attr_type == bool: setattr(self, line_info[0], attr_type(int(line_info[2]))) @@ -362,8 +363,9 @@ class Config: for a in self.class_w: text_file.write(' {:.3f}'.format(a)) text_file.write('\n') - text_file.write('offsets_loss = {:s}\n'.format(self.offsets_loss)) - text_file.write('offsets_decay = {:f}\n'.format(self.offsets_decay)) + text_file.write('deform_fitting_mode = {:s}\n'.format(self.deform_fitting_mode)) + text_file.write('deform_fitting_power = {:f}\n'.format(self.deform_fitting_power)) + text_file.write('deform_loss_power = {:f}\n'.format(self.deform_loss_power)) text_file.write('batch_num = {:d}\n'.format(self.batch_num)) text_file.write('val_batch_num = {:d}\n'.format(self.val_batch_num)) text_file.write('max_epoch = {:d}\n'.format(self.max_epoch)) diff --git a/utils/visualizer.py b/utils/visualizer.py index ef8b16a..f01f1b4 100644 --- a/utils/visualizer.py +++ b/utils/visualizer.py @@ -79,7 +79,14 @@ class ModelVisualizer: ########################## checkpoint = torch.load(chkp_path) - net.load_state_dict(checkpoint['model_state_dict']) + + new_dict = {} + for k, v in checkpoint['model_state_dict'].items(): + if 'blocs' in k: + k = k.replace('blocs', 'blocks') + new_dict[k] = v + + net.load_state_dict(new_dict) self.epoch = checkpoint['epoch'] net.eval() print("\nModel state restored from {:s}.".format(chkp_path)) diff --git a/visualize_deformations.py b/visualize_deformations.py index cb2ad0c..c4519b0 100644 --- a/visualize_deformations.py +++ b/visualize_deformations.py @@ -96,7 +96,12 @@ if __name__ == '__main__': # chosen_log = 'results/Log_2020-04-04_10-04-42' # => ModelNet40 # chosen_log = 'results/Log_2020-04-22_11-53-45' # => S3DIS - chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected + # chosen_log = 'results/Log_2020-04-22_12-28-37' # => S3DIS corrected + # chosen_log = 'results/Log_2020-04-23_09-48-15' # => S3DIS no repulsive + # chosen_log = 'results/Log_2020-04-23_09-49-49' # => S3DIS repulsive 0.5 + # chosen_log = 'results/Log_2020-04-23_19-41-12' # => S3DIS 10*fitting + chosen_log = 'results/Log_2020-04-23_19-42-18' # => S3DIS no hook + # You can also choose the index of the snapshot to load (last by default) chkp_idx = -1