#
#
#      0=================================0
#      |    Kernel Point Convolutions    |
#      0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
#      Define network blocks
#
# ----------------------------------------------------------------------------------------------------------------------
#
#      Hugues THOMAS - 06/03/2020
#

import math

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn.init import kaiming_uniform_

from kernels.kernel_points import load_kernels


# ----------------------------------------------------------------------------------------------------------------------
#
#           Simple functions
#       \**********************/
#


def gather(x, idx, method=2):
    """
    Custom gather operation, written to get a faster backward pass than plain indexing.

    :param x: input with shape [N, D_1, ... D_d]
    :param idx: indexing with shape [n_1, ..., n_m]
    :param method: implementation choice:
        0 -> plain advanced indexing ``x[idx]``,
        1 -> expand + ``Tensor.gather`` (written for 2D ``x`` and 2D ``idx`` only),
        2 -> expand + ``Tensor.gather`` for any number of dimensions (default).
    :return: x[idx] with shape [n_1, ..., n_m, D_1, ... D_d]
    :raises ValueError: if ``method`` is not 0, 1 or 2
    """

    if method == 0:
        return x[idx]
    elif method == 1:
        # [N, D] -> [N, n_2, D] and [n_1, n_2] -> [n_1, n_2, D]
        x = x.unsqueeze(1)
        x = x.expand((-1, idx.shape[-1], -1))
        idx = idx.unsqueeze(2)
        idx = idx.expand((-1, -1, x.shape[-1]))
        return x.gather(0, idx)
    elif method == 2:
        # Insert the trailing index dimensions (n_2, ..., n_m) into x, right after its first axis
        for i, ni in enumerate(idx.size()[1:]):
            x = x.unsqueeze(i + 1)
            new_s = list(x.size())
            new_s[i + 1] = ni
            x = x.expand(new_s)

        # Append the feature dimensions (D_1, ..., D_d) to idx so both tensors have the same rank
        n = len(idx.size())
        for i, di in enumerate(x.size()[n:]):
            idx = idx.unsqueeze(i + n)
            new_s = list(idx.size())
            new_s[i + n] = di
            idx = idx.expand(new_s)

        return x.gather(0, idx)
    else:
        raise ValueError("Unkown method")


def radius_gaussian(sq_r, sig, eps=1e-9):
    """
    Compute a radius gaussian (gaussian of distance)

    :param sq_r: input squared radiuses [dn, ..., d1, d0]
    :param sig: extents of gaussians [d1, d0] or [d0] or float
    :param eps: small constant added to the denominator to avoid a division by zero
    :return: gaussian of sq_r [dn, ..., d1, d0]
    """
    return torch.exp(-sq_r / (2 * sig**2 + eps))


def closest_pool(x, inds):
    """
    Pools features from the closest neighbors.
    WARNING: this function assumes the neighbors are ordered.

    :param x: [n1, d] features matrix
    :param inds: [n2, max_num] Only the first column is used for pooling
    :return: [n2, d] pooled features matrix
    """

    # Add a last row of zero features, so that index n1 (shadow pools) selects a null feature
    x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)

    # Get features for each pooling location [n2, d]
    return gather(x, inds[:, 0])


def max_pool(x, inds):
    """
    Pools features with the maximum values.

    :param x: [n1, d] features matrix
    :param inds: [n2, max_num] pooling indices
    :return: [n2, d] pooled features matrix
    """

    # Add a last row of zero features, so that index n1 (shadow pools) selects a null feature.
    # NOTE: the padding value is zero, not a true minimum, so a pooled value is never below zero
    # for a location that has at least one shadow neighbor.
    x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)

    # Get all features for each pooling location [n2, max_num, d]
    pool_features = gather(x, inds)

    # Pool the maximum [n2, d]
    max_features, _ = torch.max(pool_features, 1)
    return max_features


def global_average(x, batch_lengths):
    """
    Block performing a global average over batch pooling

    :param x: [N, D] input features
    :param batch_lengths: [B] list of batch lengths
    :return: [B, D] averaged features
    """

    # Loop over the clouds of the batch, which are stored one after the other in x
    averaged_features = []
    i0 = 0
    for length in batch_lengths:
        # Average features for each batch cloud
        averaged_features.append(torch.mean(x[i0 : i0 + length], dim=0))

        # Increment for next cloud
        i0 += length

    # Average features in each batch
    return torch.stack(averaged_features)


# ----------------------------------------------------------------------------------------------------------------------
#
#           KPConv class
#       \******************/
#


class KPConv(nn.Module):
    def __init__(
        self,
        kernel_size,
        p_dim,
        in_channels,
        out_channels,
        KP_extent,
        radius,
        fixed_kernel_points="center",
        KP_influence="linear",
        aggregation_mode="sum",
        deformable=False,
        modulated=False,
    ):
        """
        Initialize parameters for KPConvDeformable.

        :param kernel_size: Number of kernel points.
        :param p_dim: dimension of the point space.
        :param in_channels: dimension of input features.
        :param out_channels: dimension of output features.
        :param KP_extent: influence radius of each kernel point.
        :param radius: radius used for kernel point init. Even for deformable, use the config.conv_radius
        :param fixed_kernel_points: fix position of certain kernel points ('none', 'center' or 'verticals').
        :param KP_influence: influence function of the kernel points ('constant', 'linear', 'gaussian').
        :param aggregation_mode: choose to sum influences, or only keep the closest ('closest', 'sum').
        :param deformable: choose deformable or not
        :param modulated: choose if kernel weights are modulated in addition to deformed
        """
        super().__init__()

        # Save parameters
        self.K = kernel_size
        self.p_dim = p_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.radius = radius
        self.KP_extent = KP_extent
        self.fixed_kernel_points = fixed_kernel_points
        self.KP_influence = KP_influence
        self.aggregation_mode = aggregation_mode
        self.deformable = deformable
        self.modulated = modulated

        # Running variables set by forward() for deformable layers: squared distances from deformed KP
        # to input points, deformed KP positions and raw offset features (used in regularization loss)
        self.min_d2 = None
        self.deformed_KP = None
        self.offset_features = None

        # Initialize weights (values are set in reset_parameters)
        self.weights = Parameter(
            torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32),
            requires_grad=True,
        )

        # Initiate weights for offsets
        if deformable:
            # One p_dim offset per kernel point, plus one modulation scalar per kernel point if modulated
            if modulated:
                self.offset_dim = (self.p_dim + 1) * self.K
            else:
                self.offset_dim = self.p_dim * self.K

            # The offsets are predicted by a rigid (non deformable) KPConv
            self.offset_conv = KPConv(
                self.K,
                self.p_dim,
                self.in_channels,
                self.offset_dim,
                KP_extent,
                radius,
                fixed_kernel_points=fixed_kernel_points,
                KP_influence=KP_influence,
                aggregation_mode=aggregation_mode,
            )
            self.offset_bias = Parameter(
                torch.zeros(self.offset_dim, dtype=torch.float32), requires_grad=True
            )
        else:
            self.offset_dim = None
            self.offset_conv = None
            self.offset_bias = None

        # Reset parameters
        self.reset_parameters()

        # Initialize kernel points
        self.kernel_points = self.init_KP()

    def reset_parameters(self):
        """
        Initialize the convolution weights (Kaiming uniform) and zero the offset bias if deformable.
        """
        kaiming_uniform_(self.weights, a=math.sqrt(5))
        if self.deformable:
            nn.init.zeros_(self.offset_bias)

    def init_KP(self):
        """
        Initialize the kernel point positions in a sphere

        :return: the tensor of kernel points, as a non-trainable Parameter
        """

        # Create one kernel disposition (as numpy array). Choose the KP distance to center thanks to the KP extent
        K_points_numpy = load_kernels(
            self.radius, self.K, dimension=self.p_dim, fixed=self.fixed_kernel_points
        )

        return Parameter(
            torch.tensor(K_points_numpy, dtype=torch.float32), requires_grad=False
        )

    def forward(self, q_pts, s_pts, neighb_inds, x):
        """
        Apply the kernel point convolution.

        :param q_pts: query points (convolution centers) [n_points, dim]
        :param s_pts: support points carrying the features; an index equal to their count designates
                      a shadow neighbor (a far away fake point is appended for it)
        :param neighb_inds: indices of the support points around each query point [n_points, n_neighbors]
        :param x: features of the support points, one row per support point
        :return: convolved features [n_points, out_fdim]
        :raises ValueError: if KP_influence or aggregation_mode has an unknown value
        """

        ###################
        # Offset generation
        ###################

        if self.deformable:
            # Get offsets with a KPConv that only takes part of the features
            self.offset_features = (
                self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
            )

            if self.modulated:
                # Get offset (in normalized scale) from features
                unscaled_offsets = self.offset_features[:, : self.p_dim * self.K]
                unscaled_offsets = unscaled_offsets.view(-1, self.K, self.p_dim)

                # Get modulations, in the range (0, 2)
                modulations = 2 * torch.sigmoid(
                    self.offset_features[:, self.p_dim * self.K :]
                )

            else:
                # Get offset (in normalized scale) from features
                unscaled_offsets = self.offset_features.view(-1, self.K, self.p_dim)

                # No modulations
                modulations = None

            # Rescale offset for this layer
            offsets = unscaled_offsets * self.KP_extent

        else:
            offsets = None
            modulations = None

        ######################
        # Deformed convolution
        ######################

        # Add a fake point in the last row for shadow neighbors (far away, so it is never in range)
        s_pts = torch.cat((s_pts, torch.zeros_like(s_pts[:1, :]) + 1e6), 0)

        # Get neighbor points [n_points, n_neighbors, dim]
        neighbors = s_pts[neighb_inds, :]

        # Center every neighborhood
        neighbors = neighbors - q_pts.unsqueeze(1)

        # Apply offsets to kernel points [n_points, n_kpoints, dim]
        if self.deformable:
            self.deformed_KP = offsets + self.kernel_points
            deformed_K_points = self.deformed_KP.unsqueeze(1)
        else:
            deformed_K_points = self.kernel_points

        # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim]
        neighbors = neighbors.unsqueeze(2)
        differences = neighbors - deformed_K_points

        # Get the square distances [n_points, n_neighbors, n_kpoints]
        sq_distances = torch.sum(differences**2, dim=3)

        # Optimization by ignoring points outside a deformed KP range
        if self.deformable:
            # Save distances for loss
            self.min_d2, _ = torch.min(sq_distances, dim=1)

            # Boolean of the neighbors in range of a kernel point [n_points, n_neighbors]
            in_range = torch.any(sq_distances < self.KP_extent**2, dim=2).type(
                torch.int32
            )

            # New value of max neighbors
            new_max_neighb = torch.max(torch.sum(in_range, dim=1))

            # For each row of neighbors, indices of the ones that are in range [n_points, new_max_neighb]
            # (topk on the 0/1 flags puts the in-range neighbors first)
            neighb_row_bool, neighb_row_inds = torch.topk(
                in_range, new_max_neighb.item(), dim=1
            )

            # Gather new neighbor indices [n_points, new_max_neighb]
            new_neighb_inds = neighb_inds.gather(1, neighb_row_inds, sparse_grad=False)

            # Gather new distances to KP [n_points, new_max_neighb, n_kpoints]
            neighb_row_inds = neighb_row_inds.unsqueeze(2).expand(-1, -1, self.K)
            sq_distances = sq_distances.gather(1, neighb_row_inds, sparse_grad=False)

            # New shadow neighbors have to point to the last shadow point:
            # out of range entries are zeroed, then set to the index of the fake point
            new_neighb_inds *= neighb_row_bool
            new_neighb_inds -= (neighb_row_bool.type(torch.int64) - 1) * int(
                s_pts.shape[0] - 1
            )
        else:
            new_neighb_inds = neighb_inds

        # Get Kernel point influences [n_points, n_kpoints, n_neighbors]
        if self.KP_influence == "constant":
            # Every point get an influence of 1.
            all_weights = torch.ones_like(sq_distances)
            all_weights = torch.transpose(all_weights, 1, 2)

        elif self.KP_influence == "linear":
            # Influence decrease linearly with the distance, and get to zero when d = KP_extent.
            all_weights = torch.clamp(
                1 - torch.sqrt(sq_distances) / self.KP_extent, min=0.0
            )
            all_weights = torch.transpose(all_weights, 1, 2)

        elif self.KP_influence == "gaussian":
            # Influence in gaussian of the distance.
            sigma = self.KP_extent * 0.3
            all_weights = radius_gaussian(sq_distances, sigma)
            all_weights = torch.transpose(all_weights, 1, 2)
        else:
            raise ValueError("Unknown influence function type (config.KP_influence)")

        # In case of closest mode, only the closest KP can influence each point
        if self.aggregation_mode == "closest":
            neighbors_1nn = torch.argmin(sq_distances, dim=2)
            # Out-of-place product: all_weights may be a view of a tensor that autograd saved for the
            # backward pass (the output of torch.exp in gaussian mode), which must not be overwritten
            all_weights = all_weights * torch.transpose(
                nn.functional.one_hot(neighbors_1nn, self.K), 1, 2
            )

        elif self.aggregation_mode != "sum":
            raise ValueError("Unknown convolution mode. Should be 'closest' or 'sum'")

        # Add a zero feature for shadow neighbors
        x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)

        # Get the features of each neighborhood [n_points, n_neighbors, in_fdim]
        neighb_x = gather(x, new_neighb_inds)

        # Apply distance weights [n_points, n_kpoints, in_fdim]
        weighted_features = torch.matmul(all_weights, neighb_x)

        # Apply modulations
        if self.deformable and self.modulated:
            weighted_features = weighted_features * modulations.unsqueeze(2)

        # Apply network weights [n_kpoints, n_points, out_fdim]
        weighted_features = weighted_features.permute((1, 0, 2))
        kernel_outputs = torch.matmul(weighted_features, self.weights)

        # Convolution sum [n_points, out_fdim]
        return torch.sum(kernel_outputs, dim=0)

    def __repr__(self):
        return "KPConv(radius: {:.2f}, in_feat: {:d}, out_feat: {:d})".format(
            self.radius, self.in_channels, self.out_channels
        )


# ----------------------------------------------------------------------------------------------------------------------
#
#           Complex blocks
#       \********************/
#


def block_decider(block_name, radius, in_dim, out_dim, layer_ind, config):
    """
    Build the network block designated by its name in the architecture definition.

    :param block_name: name of the block
    :param radius: current radius of convolution
    :param in_dim: dimension input features
    :param out_dim: dimension output features
    :param layer_ind: index of the layer the block operates on
    :param config: parameters
    :return: the instantiated block
    :raises ValueError: if the block name is unknown
    """

    if block_name == "unary":
        return UnaryBlock(
            in_dim, out_dim, config.use_batch_norm, config.batch_norm_momentum
        )

    elif block_name in [
        "simple",
        "simple_deformable",
        "simple_invariant",
        "simple_equivariant",
        "simple_strided",
        "simple_deformable_strided",
        "simple_invariant_strided",
        "simple_equivariant_strided",
    ]:
        return SimpleBlock(block_name, in_dim, out_dim, radius, layer_ind, config)

    elif block_name in [
        "resnetb",
        "resnetb_invariant",
        "resnetb_equivariant",
        "resnetb_deformable",
        "resnetb_strided",
        "resnetb_deformable_strided",
        "resnetb_equivariant_strided",
        "resnetb_invariant_strided",
    ]:
        return ResnetBottleneckBlock(
            block_name, in_dim, out_dim, radius, layer_ind, config
        )

    elif block_name in ("max_pool", "max_pool_wide"):
        return MaxPoolBlock(layer_ind)

    elif block_name == "global_average":
        return GlobalAverageBlock()

    elif block_name == "nearest_upsample":
        return NearestUpsampleBlock(layer_ind)

    else:
        raise ValueError(
            "Unknown block name in the architecture definition : " + block_name
        )


class BatchNormBlock(nn.Module):
    def __init__(self, in_dim, use_bn, bn_momentum):
        """
        Initialize a batch normalization block. If network does not use batch normalization, replace with biases.

        :param in_dim: dimension input features
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        """
        super().__init__()
        self.bn_momentum = bn_momentum
        self.use_bn = use_bn
        self.in_dim = in_dim
        if self.use_bn:
            self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum)
            # self.batch_norm = nn.InstanceNorm1d(in_dim, momentum=bn_momentum)
        else:
            self.bias = Parameter(
                torch.zeros(in_dim, dtype=torch.float32), requires_grad=True
            )

    def reset_parameters(self):
        """
        Reset the learnable state of the block.

        The bias only exists when batch norm is disabled; with batch norm the reset is delegated
        to the normalization layer instead of failing on the missing attribute.
        """
        if self.use_bn:
            self.batch_norm.reset_parameters()
        else:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        """
        Normalize the features (or add the bias if batch norm is disabled).

        :param x: [N, in_dim] input features
        :return: [N, in_dim] output features
        """
        if self.use_bn:
            # BatchNorm1d is applied on a [1, in_dim, N] tensor: the N points are its length axis
            x = x.unsqueeze(2)
            x = x.transpose(0, 2)
            x = self.batch_norm(x)
            x = x.transpose(0, 2)
            # Only remove the axis added above: a bare squeeze() would also drop the point or
            # feature axis whenever N == 1 or in_dim == 1
            return x.squeeze(2)
        else:
            return x + self.bias

    def __repr__(self):
        return (
            "BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})".format(
                self.in_dim, self.bn_momentum, str(not self.use_bn)
            )
        )


class UnaryBlock(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn, bn_momentum, no_relu=False):
        """
        Initialize a standard unary block with its ReLU and BatchNorm.

        :param in_dim: dimension input features
        :param out_dim: dimension output features
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        :param no_relu: if True, the block has no activation after the normalization
        """

        super().__init__()
        self.bn_momentum = bn_momentum
        self.use_bn = use_bn
        self.no_relu = no_relu
        self.in_dim = in_dim
        self.out_dim = out_dim

        # No bias in the linear layer: BatchNormBlock provides one (its own bias or the batch norm shift)
        self.mlp = nn.Linear(in_dim, out_dim, bias=False)
        self.batch_norm = BatchNormBlock(out_dim, self.use_bn, self.bn_momentum)
        if not no_relu:
            self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x, batch=None):
        """
        :param x: input features, with in_dim as last dimension
        :param batch: unused, only here to share the call signature of the other blocks
        :return: output features, with out_dim as last dimension
        """
        x = self.mlp(x)
        x = self.batch_norm(x)
        if not self.no_relu:
            x = self.leaky_relu(x)
        return x

    def __repr__(self):
        return "UnaryBlock(in_feat: {:d}, out_feat: {:d}, BN: {:s}, ReLU: {:s})".format(
            self.in_dim, self.out_dim, str(self.use_bn), str(not self.no_relu)
        )


class SimpleBlock(nn.Module):
    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, config):
        """
        Initialize a simple convolution block with its ReLU and BatchNorm.

        :param block_name: name of the block ('strided' and 'deform' in the name select the variants)
        :param in_dim: dimension input features
        :param out_dim: dimension output features of the layer; this block outputs out_dim // 2 features
        :param radius: current radius of convolution
        :param layer_ind: index of the layer the block operates on
        :param config: parameters
        """
        super().__init__()

        # get KP_extent from current radius
        current_extent = radius * config.KP_extent / config.conv_radius

        # Get other parameters
        self.bn_momentum = config.batch_norm_momentum
        self.use_bn = config.use_batch_norm
        self.layer_ind = layer_ind
        self.block_name = block_name
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Define the KPConv class
        self.KPConv = KPConv(
            config.num_kernel_points,
            config.in_points_dim,
            in_dim,
            out_dim // 2,
            current_extent,
            radius,
            fixed_kernel_points=config.fixed_kernel_points,
            KP_influence=config.KP_influence,
            aggregation_mode=config.aggregation_mode,
            deformable="deform" in block_name,
            modulated=config.modulated,
        )

        # Other operations
        self.batch_norm = BatchNormBlock(out_dim // 2, self.use_bn, self.bn_momentum)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x, batch):
        """
        :param x: input features of the points of layer layer_ind
        :param batch: batch object providing the points, neighbors and pools of every layer
        :return: output features (on the points of the next layer if the block is strided)
        """

        if "strided" in self.block_name:
            # Strided convolution: centers are the points of the next layer
            q_pts = batch.points[self.layer_ind + 1]
            s_pts = batch.points[self.layer_ind]
            neighb_inds = batch.pools[self.layer_ind]
        else:
            q_pts = batch.points[self.layer_ind]
            s_pts = batch.points[self.layer_ind]
            neighb_inds = batch.neighbors[self.layer_ind]

        x = self.KPConv(q_pts, s_pts, neighb_inds, x)
        return self.leaky_relu(self.batch_norm(x))


class ResnetBottleneckBlock(nn.Module):
    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, config):
        """
        Initialize a resnet bottleneck block.

        :param block_name: name of the block ('strided' and 'deform' in the name select the variants)
        :param in_dim: dimension input features
        :param out_dim: dimension output features
        :param radius: current radius of convolution
        :param layer_ind: index of the layer the block operates on
        :param config: parameters
        """
        super().__init__()

        # get KP_extent from current radius
        current_extent = radius * config.KP_extent / config.conv_radius

        # Get other parameters
        self.bn_momentum = config.batch_norm_momentum
        self.use_bn = config.use_batch_norm
        self.block_name = block_name
        self.layer_ind = layer_ind
        self.in_dim = in_dim
        self.out_dim = out_dim

        # First downscaling mlp
        if in_dim != out_dim // 4:
            self.unary1 = UnaryBlock(
                in_dim, out_dim // 4, self.use_bn, self.bn_momentum
            )
        else:
            self.unary1 = nn.Identity()

        # KPConv block
        self.KPConv = KPConv(
            config.num_kernel_points,
            config.in_points_dim,
            out_dim // 4,
            out_dim // 4,
            current_extent,
            radius,
            fixed_kernel_points=config.fixed_kernel_points,
            KP_influence=config.KP_influence,
            aggregation_mode=config.aggregation_mode,
            deformable="deform" in block_name,
            modulated=config.modulated,
        )
        self.batch_norm_conv = BatchNormBlock(
            out_dim // 4, self.use_bn, self.bn_momentum
        )

        # Second upscaling mlp
        self.unary2 = UnaryBlock(
            out_dim // 4, out_dim, self.use_bn, self.bn_momentum, no_relu=True
        )

        # Shortcut optional mlp
        if in_dim != out_dim:
            self.unary_shortcut = UnaryBlock(
                in_dim, out_dim, self.use_bn, self.bn_momentum, no_relu=True
            )
        else:
            self.unary_shortcut = nn.Identity()

        # Other operations
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, features, batch):
        """
        :param features: input features of the points of layer layer_ind
        :param batch: batch object providing the points, neighbors and pools of every layer
        :return: output features (on the points of the next layer if the block is strided)
        """

        if "strided" in self.block_name:
            # Strided convolution: centers are the points of the next layer
            q_pts = batch.points[self.layer_ind + 1]
            s_pts = batch.points[self.layer_ind]
            neighb_inds = batch.pools[self.layer_ind]
        else:
            q_pts = batch.points[self.layer_ind]
            s_pts = batch.points[self.layer_ind]
            neighb_inds = batch.neighbors[self.layer_ind]

        # First downscaling mlp
        x = self.unary1(features)

        # Convolution
        x = self.KPConv(q_pts, s_pts, neighb_inds, x)
        x = self.leaky_relu(self.batch_norm_conv(x))

        # Second upscaling mlp
        x = self.unary2(x)

        # Shortcut (pooled to the points of the next layer if the block is strided)
        if "strided" in self.block_name:
            shortcut = max_pool(features, neighb_inds)
        else:
            shortcut = features
        shortcut = self.unary_shortcut(shortcut)

        return self.leaky_relu(x + shortcut)


class GlobalAverageBlock(nn.Module):
    def __init__(self):
        """
        Initialize a global average block (no parameters).
        """
        super().__init__()

    def forward(self, x, batch):
        """
        :param x: features of all the points of the batch, stacked cloud after cloud
        :param batch: batch object; the last entry of batch.lengths gives the cloud lengths
        :return: one averaged feature vector per cloud
        """
        return global_average(x, batch.lengths[-1])


class NearestUpsampleBlock(nn.Module):
    def __init__(self, layer_ind):
        """
        Initialize a nearest upsampling block (no parameters).

        :param layer_ind: index of the layer the input features belong to
        """
        super().__init__()
        self.layer_ind = layer_ind

    def forward(self, x, batch):
        """
        :param x: features of the points of layer layer_ind
        :param batch: batch object providing the upsampling indices of every layer
        :return: features copied from the first upsampling neighbor, for layer layer_ind - 1
        """
        return closest_pool(x, batch.upsamples[self.layer_ind - 1])

    def __repr__(self):
        return "NearestUpsampleBlock(layer: {:d} -> {:d})".format(
            self.layer_ind, self.layer_ind - 1
        )


class MaxPoolBlock(nn.Module):
    def __init__(self, layer_ind):
        """
        Initialize a max pooling block (no parameters).

        :param layer_ind: index of the layer, used to select the pooling indices
        """
        super().__init__()
        self.layer_ind = layer_ind

    def forward(self, x, batch):
        """
        :param x: input features
        :param batch: batch object providing the pooling indices of every layer
        :return: features max-pooled with the indices batch.pools[layer_ind + 1]
        """
        return max_pool(x, batch.pools[self.layer_ind + 1])