KPConv-PyTorch/models/blocks.py
Laurent FAINSIN d0cdb8e4ee 🎨 black + ruff
2023-05-15 17:18:10 +02:00

728 lines
24 KiB
Python

#
#
# 0=================================0
# | Kernel Point Convolutions |
# 0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Define network blocks
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Hugues THOMAS - 06/03/2020
#
import math
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn.init import kaiming_uniform_
from kernels.kernel_points import load_kernels
# ----------------------------------------------------------------------------------------------------------------------
#
# Simple functions
# \**********************/
#
def gather(x, idx, method=2):
    """
    Custom gather operation returning the same values as ``x[idx]``, written
    with ``torch.gather`` so that the backward pass is faster.
    :param x: input with shape [N, D_1, ... D_d]
    :param idx: indexing with shape [n_1, ..., n_m]
    :param method: Choice of the method:
        0 -> plain advanced indexing ``x[idx]``,
        1 -> expand + ``torch.gather``; only valid for 2-D ``x`` and 2-D ``idx``,
        2 -> expand + ``torch.gather`` for any number of dimensions (default)
    :return: x[idx] with shape [n_1, ..., n_m, D_1, ... D_d]
    :raises ValueError: if ``method`` is not 0, 1 or 2
    """
    if method == 0:
        return x[idx]
    elif method == 1:
        # Broadcast x to [N, n_2, D] and idx to [n_1, n_2, D], then gather on dim 0
        x = x.unsqueeze(1)
        x = x.expand((-1, idx.shape[-1], -1))
        idx = idx.unsqueeze(2)
        idx = idx.expand((-1, -1, x.shape[-1]))
        return x.gather(0, idx)
    elif method == 2:
        # Insert the trailing index dimensions (n_2, ..., n_m) right after N in x
        for i, ni in enumerate(idx.size()[1:]):
            x = x.unsqueeze(i + 1)
            new_s = list(x.size())
            new_s[i + 1] = ni
            x = x.expand(new_s)
        # Append the feature dimensions (D_1, ..., D_d) at the end of idx
        n = len(idx.size())
        for i, di in enumerate(x.size()[n:]):
            idx = idx.unsqueeze(i + n)
            new_s = list(idx.size())
            new_s[i + n] = di
            idx = idx.expand(new_s)
        return x.gather(0, idx)
    else:
        raise ValueError("Unknown method")
def radius_gaussian(sq_r, sig, eps=1e-9):
    """
    Evaluate an (unnormalized) gaussian of a distance, given the squared distance.
    :param sq_r: input squared radiuses [dn, ..., d1, d0]
    :param sig: extents of gaussians [d1, d0] or [d0] or float
    :param eps: small constant added to the denominator to avoid a division by zero
    :return: gaussian of sq_r [dn, ..., d1, d0]
    """
    denominator = 2 * sig**2 + eps
    return torch.exp(-sq_r / denominator)
def closest_pool(x, inds):
    """
    Pool features by copying those of the closest neighbor.
    WARNING: neighbors are assumed to be ordered, closest first.
    :param x: [n1, d] features matrix
    :param inds: [n2, max_num] neighbor indices; only the first column is used
    :return: [n2, d] pooled features matrix
    """
    # Extra all-zero row at index n1, picked up by shadow indices
    shadow_row = torch.zeros_like(x[:1, :])
    padded = torch.cat((x, shadow_row), 0)

    # Index of the closest neighbor of every pooling location [n2]
    closest = inds[:, 0]
    return gather(padded, closest)
def max_pool(x, inds):
    """
    Pool features by taking the channel-wise maximum over each neighborhood.
    :param x: [n1, d] features matrix
    :param inds: [n2, max_num] pooling indices
    :return: [n2, d] pooled features matrix
    """
    # Extra all-zero row at index n1, picked up by shadow indices
    shadow_row = torch.zeros_like(x[:1, :])
    padded = torch.cat((x, shadow_row), 0)

    # Features of every neighbor of every pooling location [n2, max_num, d]
    candidates = gather(padded, inds)

    # Keep only the values of the maximum, drop the argmax [n2, d]
    return torch.max(candidates, 1)[0]
def global_average(x, batch_lengths):
    """
    Average the features of each cloud of a stacked batch.
    :param x: [N, D] input features, clouds stacked one after the other
    :param batch_lengths: [B] number of points of each cloud
    :return: [B, D] averaged features
    """
    per_cloud_means = []
    start = 0
    for length in batch_lengths:
        # Rows [start, stop) belong to the current cloud
        stop = start + length
        per_cloud_means.append(x[start:stop].mean(dim=0))
        start = stop

    # One averaged row per cloud
    return torch.stack(per_cloud_means)
# ----------------------------------------------------------------------------------------------------------------------
#
# KPConv class
# \******************/
#
class KPConv(nn.Module):
    """
    Kernel Point Convolution layer, rigid or deformable (optionally modulated).
    """

    def __init__(
        self,
        kernel_size,
        p_dim,
        in_channels,
        out_channels,
        KP_extent,
        radius,
        fixed_kernel_points="center",
        KP_influence="linear",
        aggregation_mode="sum",
        deformable=False,
        modulated=False,
    ):
        """
        Initialize parameters for KPConv (rigid or deformable).
        :param kernel_size: Number of kernel points.
        :param p_dim: dimension of the point space.
        :param in_channels: dimension of input features.
        :param out_channels: dimension of output features.
        :param KP_extent: influence radius of each kernel point.
        :param radius: radius used for kernel point init. Even for deformable, use the config.conv_radius
        :param fixed_kernel_points: fix position of certain kernel points ('none', 'center' or 'verticals').
        :param KP_influence: influence function of the kernel points ('constant', 'linear', 'gaussian').
        :param aggregation_mode: choose to sum influences, or only keep the closest ('closest', 'sum').
        :param deformable: choose deformable or not
        :param modulated: choose if kernel weights are modulated in addition to deformed
        """
        super(KPConv, self).__init__()

        # Save parameters
        self.K = kernel_size
        self.p_dim = p_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.radius = radius
        self.KP_extent = KP_extent
        self.fixed_kernel_points = fixed_kernel_points
        self.KP_influence = KP_influence
        self.aggregation_mode = aggregation_mode
        self.deformable = deformable
        self.modulated = modulated

        # Running variables refreshed by every deformable forward pass
        # (min_d2 is the deformed KP distance to input points, used in regularization loss)
        self.min_d2 = None
        self.deformed_KP = None
        self.offset_features = None

        # Initialize weights
        self.weights = Parameter(
            torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32),
            requires_grad=True,
        )

        # Initiate weights for offsets
        if deformable:
            # p_dim offset coordinates per kernel point, plus one modulation if modulated
            if modulated:
                self.offset_dim = (self.p_dim + 1) * self.K
            else:
                self.offset_dim = self.p_dim * self.K
            # The offsets are themselves predicted by a rigid KPConv
            self.offset_conv = KPConv(
                self.K,
                self.p_dim,
                self.in_channels,
                self.offset_dim,
                KP_extent,
                radius,
                fixed_kernel_points=fixed_kernel_points,
                KP_influence=KP_influence,
                aggregation_mode=aggregation_mode,
            )
            self.offset_bias = Parameter(
                torch.zeros(self.offset_dim, dtype=torch.float32), requires_grad=True
            )
        else:
            self.offset_dim = None
            self.offset_conv = None
            self.offset_bias = None

        # Reset parameters
        self.reset_parameters()

        # Initialize kernel points
        self.kernel_points = self.init_KP()

    def reset_parameters(self):
        """
        Re-initialize the convolution weights (kaiming uniform) and, for a
        deformable layer, set the offset bias back to zero.
        """
        kaiming_uniform_(self.weights, a=math.sqrt(5))
        if self.deformable:
            nn.init.zeros_(self.offset_bias)

    def init_KP(self):
        """
        Initialize the kernel point positions in a sphere
        :return: the tensor of kernel points, as a Parameter excluded from gradient updates
        """
        # Create one kernel disposition (as numpy array). Choose the KP distance to center thanks to the KP extent
        K_points_numpy = load_kernels(
            self.radius, self.K, dimension=self.p_dim, fixed=self.fixed_kernel_points
        )
        return Parameter(
            torch.tensor(K_points_numpy, dtype=torch.float32), requires_grad=False
        )

    def forward(self, q_pts, s_pts, neighb_inds, x):
        """
        Apply the convolution.
        :param q_pts: query points, the locations where the convolution is evaluated
        :param s_pts: support points carrying the input features
        :param neighb_inds: [n_points, n_neighbors] indices in s_pts of the neighbors of each query
            point; an index equal to the number of support points designates a shadow neighbor
        :param x: input features of the support points
        :return: [n_points, out_fdim] output features
        :raises ValueError: if KP_influence or aggregation_mode holds an unknown value
        """
        ###################
        # Offset generation
        ###################

        if self.deformable:
            # Get offsets with a KPConv that only takes part of the features
            self.offset_features = (
                self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias
            )

            if self.modulated:
                # Get offset (in normalized scale) from features
                unscaled_offsets = self.offset_features[:, : self.p_dim * self.K]
                unscaled_offsets = unscaled_offsets.view(-1, self.K, self.p_dim)

                # Get modulations, in the range (0, 2)
                modulations = 2 * torch.sigmoid(
                    self.offset_features[:, self.p_dim * self.K :]
                )
            else:
                # Get offset (in normalized scale) from features
                unscaled_offsets = self.offset_features.view(-1, self.K, self.p_dim)

                # No modulations
                modulations = None

            # Rescale offset for this layer
            offsets = unscaled_offsets * self.KP_extent
        else:
            offsets = None
            modulations = None

        ######################
        # Deformed convolution
        ######################

        # Add a fake point in the last row for shadow neighbors
        s_pts = torch.cat((s_pts, torch.zeros_like(s_pts[:1, :]) + 1e6), 0)

        # Get neighbor points [n_points, n_neighbors, dim]
        neighbors = s_pts[neighb_inds, :]

        # Center every neighborhood
        neighbors = neighbors - q_pts.unsqueeze(1)

        # Apply offsets to kernel points [n_points, n_kpoints, dim]
        if self.deformable:
            self.deformed_KP = offsets + self.kernel_points
            deformed_K_points = self.deformed_KP.unsqueeze(1)
        else:
            deformed_K_points = self.kernel_points

        # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim]
        neighbors.unsqueeze_(2)
        differences = neighbors - deformed_K_points

        # Get the square distances [n_points, n_neighbors, n_kpoints]
        sq_distances = torch.sum(differences**2, dim=3)

        # Optimization by ignoring points outside a deformed KP range
        if self.deformable:
            # Save distances for loss
            self.min_d2, _ = torch.min(sq_distances, dim=1)

            # Boolean of the neighbors in range of a kernel point [n_points, n_neighbors]
            in_range = torch.any(sq_distances < self.KP_extent**2, dim=2).type(
                torch.int32
            )

            # New value of max neighbors
            new_max_neighb = torch.max(torch.sum(in_range, dim=1))

            # For each row of neighbors, indices of the ones that are in range [n_points, new_max_neighb]
            neighb_row_bool, neighb_row_inds = torch.topk(
                in_range, new_max_neighb.item(), dim=1
            )

            # Gather new neighbor indices [n_points, new_max_neighb]
            new_neighb_inds = neighb_inds.gather(1, neighb_row_inds, sparse_grad=False)

            # Gather new distances to KP [n_points, new_max_neighb, n_kpoints]
            neighb_row_inds.unsqueeze_(2)
            neighb_row_inds = neighb_row_inds.expand(-1, -1, self.K)
            sq_distances = sq_distances.gather(1, neighb_row_inds, sparse_grad=False)

            # New shadow neighbors have to point to the last shadow point
            new_neighb_inds *= neighb_row_bool
            new_neighb_inds -= (neighb_row_bool.type(torch.int64) - 1) * int(
                s_pts.shape[0] - 1
            )
        else:
            new_neighb_inds = neighb_inds

        # Get Kernel point influences [n_points, n_kpoints, n_neighbors]
        if self.KP_influence == "constant":
            # Every point get an influence of 1.
            all_weights = torch.ones_like(sq_distances)
            all_weights = torch.transpose(all_weights, 1, 2)
        elif self.KP_influence == "linear":
            # Influence decrease linearly with the distance, and get to zero when d = KP_extent.
            all_weights = torch.clamp(
                1 - torch.sqrt(sq_distances) / self.KP_extent, min=0.0
            )
            all_weights = torch.transpose(all_weights, 1, 2)
        elif self.KP_influence == "gaussian":
            # Influence in gaussian of the distance.
            sigma = self.KP_extent * 0.3
            all_weights = radius_gaussian(sq_distances, sigma)
            all_weights = torch.transpose(all_weights, 1, 2)
        else:
            raise ValueError("Unknown influence function type (config.KP_influence)")

        # In case of closest mode, only the closest KP can influence each point
        if self.aggregation_mode == "closest":
            neighbors_1nn = torch.argmin(sq_distances, dim=2)
            # Out-of-place product: all_weights is a transposed view of the influence
            # tensor, and with the gaussian influence that tensor is the output of
            # torch.exp, which autograd keeps for the backward pass. Multiplying in
            # place would invalidate it and make backward() fail when distances
            # require grad (deformable kernel points).
            all_weights = all_weights * torch.transpose(
                nn.functional.one_hot(neighbors_1nn, self.K), 1, 2
            )
        elif self.aggregation_mode != "sum":
            raise ValueError("Unknown convolution mode. Should be 'closest' or 'sum'")

        # Add a zero feature for shadow neighbors
        x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)

        # Get the features of each neighborhood [n_points, n_neighbors, in_fdim]
        neighb_x = gather(x, new_neighb_inds)

        # Apply distance weights [n_points, n_kpoints, in_fdim]
        weighted_features = torch.matmul(all_weights, neighb_x)

        # Apply modulations
        if self.deformable and self.modulated:
            weighted_features *= modulations.unsqueeze(2)

        # Apply network weights [n_kpoints, n_points, out_fdim]
        weighted_features = weighted_features.permute((1, 0, 2))
        kernel_outputs = torch.matmul(weighted_features, self.weights)

        # Convolution sum [n_points, out_fdim]
        return torch.sum(kernel_outputs, dim=0)

    def __repr__(self):
        return "KPConv(radius: {:.2f}, in_feat: {:d}, out_feat: {:d})".format(
            self.radius, self.in_channels, self.out_channels
        )
# ----------------------------------------------------------------------------------------------------------------------
#
# Complex blocks
# \********************/
#
def block_decider(block_name, radius, in_dim, out_dim, layer_ind, config):
    """
    Build the network block designated by a name of the architecture definition.
    :param block_name: name of the block in the architecture definition
    :param radius: current radius of convolution (convolution blocks only)
    :param in_dim: dimension of input features
    :param out_dim: dimension of output features
    :param layer_ind: index of the layer the block operates on
    :param config: parameters
    :return: the instantiated block
    :raises ValueError: if block_name is not a known block name
    """
    simple_names = (
        "simple",
        "simple_deformable",
        "simple_invariant",
        "simple_equivariant",
        "simple_strided",
        "simple_deformable_strided",
        "simple_invariant_strided",
        "simple_equivariant_strided",
    )
    resnet_names = (
        "resnetb",
        "resnetb_invariant",
        "resnetb_equivariant",
        "resnetb_deformable",
        "resnetb_strided",
        "resnetb_deformable_strided",
        "resnetb_equivariant_strided",
        "resnetb_invariant_strided",
    )

    if block_name == "unary":
        return UnaryBlock(
            in_dim, out_dim, config.use_batch_norm, config.batch_norm_momentum
        )

    if block_name in simple_names:
        return SimpleBlock(block_name, in_dim, out_dim, radius, layer_ind, config)

    if block_name in resnet_names:
        return ResnetBottleneckBlock(
            block_name, in_dim, out_dim, radius, layer_ind, config
        )

    if block_name in ("max_pool", "max_pool_wide"):
        return MaxPoolBlock(layer_ind)

    if block_name == "global_average":
        return GlobalAverageBlock()

    if block_name == "nearest_upsample":
        return NearestUpsampleBlock(layer_ind)

    raise ValueError(
        "Unknown block name in the architecture definition : " + block_name
    )
class BatchNormBlock(nn.Module):
    """
    Batch normalization over [N, D] point features, or a plain learned bias
    when batch normalization is disabled.
    """

    def __init__(self, in_dim, use_bn, bn_momentum):
        """
        Initialize a batch normalization block. If network does not use batch normalization, replace with biases.
        :param in_dim: dimension input features
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        """
        super(BatchNormBlock, self).__init__()
        self.bn_momentum = bn_momentum
        self.use_bn = use_bn
        self.in_dim = in_dim
        if self.use_bn:
            self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum)
        else:
            self.bias = Parameter(
                torch.zeros(in_dim, dtype=torch.float32), requires_grad=True
            )

    def reset_parameters(self):
        """
        Re-initialize the block: reset the batch norm layer when it is used,
        otherwise set the bias back to zero.
        """
        # self.bias only exists in bias-only mode, self.batch_norm only in BN mode
        if self.use_bn:
            self.batch_norm.reset_parameters()
        else:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        """
        :param x: [N, D] input features
        :return: [N, D] normalized (or biased) features
        """
        if self.use_bn:
            # Reshape [N, D] -> [1, D, N] so that BatchNorm1d sees one sample of N positions
            x = x.unsqueeze(2)
            x = x.transpose(0, 2)
            x = self.batch_norm(x)
            x = x.transpose(0, 2)
            # Only drop the axis added above: a bare squeeze() would also collapse
            # the point axis when N == 1 or the feature axis when D == 1
            return x.squeeze(2)
        else:
            return x + self.bias

    def __repr__(self):
        return (
            "BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})".format(
                self.in_dim, self.bn_momentum, str(not self.use_bn)
            )
        )
class UnaryBlock(nn.Module):
    """
    Point-wise linear layer followed by batch norm (or bias) and an optional leaky ReLU.
    """

    def __init__(self, in_dim, out_dim, use_bn, bn_momentum, no_relu=False):
        """
        Initialize a standard unary block with its ReLU and BatchNorm.
        :param in_dim: dimension of input features
        :param out_dim: dimension of output features
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        :param no_relu: if True, the block has no activation
        """
        super().__init__()

        # Hyper-parameters, kept for __repr__ and forward
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_bn = use_bn
        self.bn_momentum = bn_momentum
        self.no_relu = no_relu

        # Layers: no bias in the linear layer, the normalization block provides it
        self.mlp = nn.Linear(in_dim, out_dim, bias=False)
        self.batch_norm = BatchNormBlock(out_dim, self.use_bn, self.bn_momentum)
        if not no_relu:
            self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x, batch=None):
        """
        :param x: input features
        :param batch: unused, accepted so every block can be called the same way
        :return: transformed features
        """
        out = self.batch_norm(self.mlp(x))
        if self.no_relu:
            return out
        return self.leaky_relu(out)

    def __repr__(self):
        return (
            f"UnaryBlock(in_feat: {self.in_dim:d}, out_feat: {self.out_dim:d}, "
            f"BN: {str(self.use_bn):s}, ReLU: {str(not self.no_relu):s})"
        )
class SimpleBlock(nn.Module):
    """
    KPConv followed by batch norm (or bias) and a leaky ReLU.
    """

    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, config):
        """
        Initialize a simple convolution block with its ReLU and BatchNorm.
        :param block_name: name of the block; 'strided' and 'deform' substrings select the behavior
        :param in_dim: dimension of input features
        :param out_dim: nominal output dimension; the block actually outputs out_dim // 2 features
        :param radius: current radius of convolution
        :param layer_ind: index of the layer the block operates on
        :param config: parameters
        """
        super().__init__()

        # Kernel point extent scaled to the radius of this layer
        current_extent = radius * config.KP_extent / config.conv_radius

        # Keep the block description
        self.block_name = block_name
        self.layer_ind = layer_ind
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_bn = config.use_batch_norm
        self.bn_momentum = config.batch_norm_momentum

        # Convolution, producing half of out_dim
        half_dim = out_dim // 2
        self.KPConv = KPConv(
            config.num_kernel_points,
            config.in_points_dim,
            in_dim,
            half_dim,
            current_extent,
            radius,
            fixed_kernel_points=config.fixed_kernel_points,
            KP_influence=config.KP_influence,
            aggregation_mode=config.aggregation_mode,
            deformable="deform" in block_name,
            modulated=config.modulated,
        )

        # Other operations
        self.batch_norm = BatchNormBlock(half_dim, self.use_bn, self.bn_momentum)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x, batch):
        """
        :param x: input features of the points of layer layer_ind
        :param batch: batch object providing points, neighbors and pools for every layer
        :return: output features
        """
        layer = self.layer_ind
        strided = "strided" in self.block_name

        # A strided block is evaluated at the points of the next layer
        q_pts = batch.points[layer + 1] if strided else batch.points[layer]
        s_pts = batch.points[layer]
        neighb_inds = batch.pools[layer] if strided else batch.neighbors[layer]

        conv_out = self.KPConv(q_pts, s_pts, neighb_inds, x)
        normalized = self.batch_norm(conv_out)
        return self.leaky_relu(normalized)
class ResnetBottleneckBlock(nn.Module):
    """
    Residual bottleneck: unary reduction, KPConv, unary expansion, plus a shortcut.
    """

    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, config):
        """
        Initialize a resnet bottleneck block.
        :param block_name: name of the block; 'strided' and 'deform' substrings select the behavior
        :param in_dim: dimension of input features
        :param out_dim: dimension of output features
        :param radius: current radius of convolution
        :param layer_ind: index of the layer the block operates on
        :param config: parameters
        """
        super().__init__()

        # Kernel point extent scaled to the radius of this layer
        current_extent = radius * config.KP_extent / config.conv_radius

        # Keep the block description
        self.block_name = block_name
        self.layer_ind = layer_ind
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_bn = config.use_batch_norm
        self.bn_momentum = config.batch_norm_momentum

        # Width of the bottleneck
        mid_dim = out_dim // 4

        # First downscaling mlp, skipped when the input already has the bottleneck width
        if in_dim == mid_dim:
            self.unary1 = nn.Identity()
        else:
            self.unary1 = UnaryBlock(in_dim, mid_dim, self.use_bn, self.bn_momentum)

        # KPConv block
        self.KPConv = KPConv(
            config.num_kernel_points,
            config.in_points_dim,
            mid_dim,
            mid_dim,
            current_extent,
            radius,
            fixed_kernel_points=config.fixed_kernel_points,
            KP_influence=config.KP_influence,
            aggregation_mode=config.aggregation_mode,
            deformable="deform" in block_name,
            modulated=config.modulated,
        )
        self.batch_norm_conv = BatchNormBlock(mid_dim, self.use_bn, self.bn_momentum)

        # Second upscaling mlp, without activation (applied after the residual sum)
        self.unary2 = UnaryBlock(
            mid_dim, out_dim, self.use_bn, self.bn_momentum, no_relu=True
        )

        # Shortcut optional mlp, only needed when the dimensions differ
        if in_dim == out_dim:
            self.unary_shortcut = nn.Identity()
        else:
            self.unary_shortcut = UnaryBlock(
                in_dim, out_dim, self.use_bn, self.bn_momentum, no_relu=True
            )

        # Other operations
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, features, batch):
        """
        :param features: input features of the points of layer layer_ind
        :param batch: batch object providing points, neighbors and pools for every layer
        :return: output features
        """
        layer = self.layer_ind
        strided = "strided" in self.block_name

        # A strided block is evaluated at the points of the next layer
        q_pts = batch.points[layer + 1] if strided else batch.points[layer]
        s_pts = batch.points[layer]
        neighb_inds = batch.pools[layer] if strided else batch.neighbors[layer]

        # Main branch: downscaling mlp -> convolution -> upscaling mlp
        main = self.unary1(features)
        main = self.KPConv(q_pts, s_pts, neighb_inds, main)
        main = self.leaky_relu(self.batch_norm_conv(main))
        main = self.unary2(main)

        # Shortcut branch, max-pooled to the query points when strided
        shortcut = max_pool(features, neighb_inds) if strided else features
        shortcut = self.unary_shortcut(shortcut)

        return self.leaky_relu(main + shortcut)
class GlobalAverageBlock(nn.Module):
    """
    Parameter-free block averaging the features of each cloud of the batch.
    """

    def __init__(self):
        """
        Initialize a global average block (no parameters).
        """
        super().__init__()

    def forward(self, x, batch):
        """
        :param x: stacked features of the last layer
        :param batch: batch object; the last entry of batch.lengths gives the cloud lengths
        :return: one averaged feature row per cloud
        """
        last_layer_lengths = batch.lengths[-1]
        return global_average(x, last_layer_lengths)
class NearestUpsampleBlock(nn.Module):
    """
    Parameter-free block upsampling features from layer layer_ind to layer layer_ind - 1.
    """

    def __init__(self, layer_ind):
        """
        Initialize a nearest upsampling block.
        :param layer_ind: index of the layer the input features belong to
        """
        super().__init__()
        self.layer_ind = layer_ind

    def forward(self, x, batch):
        """
        :param x: features of the points of layer layer_ind
        :param batch: batch object providing the upsampling indices of every layer
        :return: features copied from the closest point, for the points of layer layer_ind - 1
        """
        upsample_inds = batch.upsamples[self.layer_ind - 1]
        return closest_pool(x, upsample_inds)

    def __repr__(self):
        source = self.layer_ind
        target = self.layer_ind - 1
        return f"NearestUpsampleBlock(layer: {source:d} -> {target:d})"
class MaxPoolBlock(nn.Module):
    """
    Parameter-free block max-pooling features with the pooling indices of layer layer_ind + 1.
    """

    def __init__(self, layer_ind):
        """
        Initialize a max pooling block.
        :param layer_ind: index of the layer the block operates on
        """
        super().__init__()
        self.layer_ind = layer_ind

    def forward(self, x, batch):
        """
        :param x: input features
        :param batch: batch object providing the pooling indices of every layer
        :return: max-pooled features
        """
        pool_inds = batch.pools[self.layer_ind + 1]
        return max_pool(x, pool_inds)