2020-03-31 19:42:35 +00:00
|
|
|
#
#
#      0=================================0
#      |    Kernel Point Convolutions    |
#      0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
#      Define network architectures
#
# ----------------------------------------------------------------------------------------------------------------------
#
#      Hugues THOMAS - 06/03/2020
#
|
|
|
|
|
|
|
|
from models.blocks import *
|
|
|
|
import numpy as np
|
|
|
|
|
2020-04-24 16:00:11 +00:00
|
|
|
|
|
|
|
def p2p_fitting_regularizer(net):
    """
    Regularization loss for the deformable KPConv modules of a network.

    Every module of `net` that is a KPConv with `deformable` set contributes two terms:
      - a fitting term: the mean of `m.min_d2` (squared distance to the closest input point), divided by
        `m.KP_extent ** 2` so that it does not depend on the layer scale;
      - a repulsive term: for each kernel point, the squared shortfall of its distance to every other kernel point
        below `net.repulse_extent` (distances expressed in units of `m.KP_extent`).

    :param net: network exposing modules(), l1, K, repulse_extent and deform_fitting_power
    :return: net.deform_fitting_power * (2 * fitting_loss + repulsive_loss); this is 0 scaled by the power when the
             network has no deformable KPConv
    """

    # Added under the square root: the derivative of sqrt is infinite at 0, so two coinciding kernel points would
    # otherwise turn the gradient into NaN (inf * 0) during the backward pass.
    sqrt_eps = 1e-12

    fitting_loss = 0
    repulsive_loss = 0

    for m in net.modules():

        if not (isinstance(m, KPConv) and m.deformable):
            continue

        ##############
        # Fitting loss
        ##############

        # Get the distance to closest input point and normalize to be independent from layers
        KP_min_d2 = m.min_d2 / (m.KP_extent ** 2)

        # Loss will be the square distance to closest input point. We use L1 because dist is already squared
        fitting_loss += net.l1(KP_min_d2, torch.zeros_like(KP_min_d2))

        ################
        # Repulsive loss
        ################

        # Normalized KP locations
        KP_locs = m.deformed_KP / m.KP_extent

        # Points should not be close to each other
        for i in range(net.K):

            # All kernel points but the i-th one. They are detached, so only point i receives gradient here
            other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()

            # Distance from point i to each of the other points
            sq_distances = torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2)
            distances = torch.sqrt(sq_distances + sqrt_eps)

            # Only the pairs closer than the repulsion extent are penalized (the clamp zeroes the others)
            rep_loss = torch.sum(torch.clamp_max(distances - net.repulse_extent, max=0.0) ** 2, dim=1)
            repulsive_loss += net.l1(rep_loss, torch.zeros_like(rep_loss)) / net.K

    return net.deform_fitting_power * (2 * fitting_loss + repulsive_loss)
|
2020-04-24 16:00:11 +00:00
|
|
|
|
|
|
|
|
2020-03-31 19:42:35 +00:00
|
|
|
class KPCNN(nn.Module):
    """
    Class defining KPCNN

    The network is the sequence of blocks listed in config.architecture, cut at the first 'upsample' block, followed
    by a two-layer head. The class also carries its own loss (cross entropy plus the regularizer of deformable
    kernels) and an accuracy helper.
    """

    def __init__(self, config):
        """
        Build the blocks and the head, and store the loss parameters.
        :param config: configuration object. The attributes read here are first_subsampling_dl, conv_radius,
                       in_features_dim, first_features_dim, num_kernel_points, architecture, num_classes,
                       deform_fitting_mode, deform_fitting_power, deform_lr_factor and repulse_extent
        :raises ValueError: if an 'equivariant' block is requested while the feature dimension is not divisible by 3
        """
        super(KPCNN, self).__init__()

        ####################
        # Network operations
        ####################

        # Current radius of convolution and feature dimension
        layer = 0
        r = config.first_subsampling_dl * config.conv_radius
        in_dim = config.in_features_dim
        out_dim = config.first_features_dim
        self.K = config.num_kernel_points

        # Save all block operations in a list of modules
        self.block_ops = nn.ModuleList()

        # Loop over consecutive blocks
        for block in config.architecture:

            # Check equivariance
            if ('equivariant' in block) and (not out_dim % 3 == 0):
                raise ValueError('Equivariant block but features dimension is not a factor of 3')

            # Detect upsampling block to stop
            if 'upsample' in block:
                break

            # Build the module chosen by block_decider for this block name
            self.block_ops.append(block_decider(block, r, in_dim, out_dim, layer, config))

            # Update dimension of input from output (after a 'simple' block the next input is out_dim // 2)
            if 'simple' in block:
                in_dim = out_dim // 2
            else:
                in_dim = out_dim

            # Detect change to a subsampled layer
            if 'pool' in block or 'strided' in block:
                # Update radius and feature dimension for next layer
                layer += 1
                r *= 2
                out_dim *= 2

        self.head_mlp = UnaryBlock(out_dim, 1024, False, 0)
        self.head_softmax = UnaryBlock(1024, config.num_classes, False, 0)

        ################
        # Network Losses
        ################

        self.criterion = torch.nn.CrossEntropyLoss()
        self.deform_fitting_mode = config.deform_fitting_mode
        self.deform_fitting_power = config.deform_fitting_power
        self.deform_lr_factor = config.deform_lr_factor
        self.repulse_extent = config.repulse_extent

        # Last values computed by loss()
        self.output_loss = 0
        self.reg_loss = 0

        # L1 loss used by the regularizer of deformable kernels
        self.l1 = nn.L1Loss()

    def forward(self, batch, config):
        """
        Run the network on a batch.
        :param batch: batch object; its `features` attribute is the network input and the whole object is handed
                      to every block
        :param config: not used in this method
        :return: output of the last head layer
        """

        # Get input features (a detached copy, the batch itself is left untouched)
        x = batch.features.clone().detach()

        # Loop over consecutive blocks
        for block_op in self.block_ops:
            x = block_op(x, batch)

        # Head of network
        x = self.head_mlp(x, batch)
        x = self.head_softmax(x, batch)

        return x

    def loss(self, outputs, labels):
        """
        Runs the loss on outputs of the model
        :param outputs: logits
        :param labels: labels
        :return: loss
        :raises ValueError: if deform_fitting_mode is 'point2plane' (not implemented) or not a known mode
        """

        # Cross entropy loss
        self.output_loss = self.criterion(outputs, labels)

        # Regularization of deformable offsets
        if self.deform_fitting_mode == 'point2point':
            self.reg_loss = p2p_fitting_regularizer(self)
        elif self.deform_fitting_mode == 'point2plane':
            raise ValueError('point2plane fitting mode not implemented yet.')
        else:
            # str() keeps this a ValueError even when the configured mode is not a string (e.g. None)
            raise ValueError('Unknown fitting mode: ' + str(self.deform_fitting_mode))

        # Combined loss
        return self.output_loss + self.reg_loss

    @staticmethod
    def accuracy(outputs, labels):
        """
        Computes accuracy of the current batch
        :param outputs: logits predicted by the network
        :param labels: labels
        :return: accuracy value
        """

        # Predicted class is the index of the largest logit along dim 1
        predicted = torch.argmax(outputs.data, dim=1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()

        return correct / total
|
|
|
|
|
|
|
|
|
|
|
|
class KPFCNN(nn.Module):
    """
    Class defining KPFCNN

    Encoder-decoder network. The encoder is made of the blocks of config.architecture placed before the first
    'upsample' block, the decoder of the blocks from that one on. Encoder features saved before each change of
    layer are concatenated back in the decoder. The class also carries its own loss (cross entropy plus the
    regularizer of deformable kernels) and an accuracy helper, both working on remapped labels.
    """

    def __init__(self, config, lbl_values, ign_lbls):
        """
        Build encoder, decoder and head, and store the loss parameters.
        :param config: configuration object. The attributes read here are first_subsampling_dl, conv_radius,
                       in_features_dim, first_features_dim, num_kernel_points, architecture, class_w,
                       deform_fitting_mode, deform_fitting_power, deform_lr_factor and repulse_extent
        :param lbl_values: all the label values of the dataset
        :param ign_lbls: label values ignored in the loss
        :raises ValueError: if an 'equivariant' block is requested while the feature dimension is not divisible by 3
        """
        super(KPFCNN, self).__init__()

        ############
        # Parameters
        ############

        # Current radius of convolution and feature dimension
        layer = 0
        r = config.first_subsampling_dl * config.conv_radius
        in_dim = config.in_features_dim
        out_dim = config.first_features_dim
        self.K = config.num_kernel_points

        # Number of classes predicted by the network
        self.C = len(lbl_values) - len(ign_lbls)

        #####################
        # List Encoder blocks
        #####################

        # Save all block operations in a list of modules
        self.encoder_blocks = nn.ModuleList()

        # Indices of the blocks whose input is kept for a skip connection, and the dimension of these inputs
        self.encoder_skip_dims = []
        self.encoder_skips = []

        # Loop over consecutive blocks
        for block_i, block in enumerate(config.architecture):

            # Check equivariance
            if ('equivariant' in block) and (not out_dim % 3 == 0):
                raise ValueError('Equivariant block but features dimension is not a factor of 3')

            # Detect change to next layer for skip connection
            if any(tmp in block for tmp in ['pool', 'strided', 'upsample', 'global']):
                self.encoder_skips.append(block_i)
                self.encoder_skip_dims.append(in_dim)

            # Detect upsampling block to stop
            if 'upsample' in block:
                break

            # Build the module chosen by block_decider for this block name
            self.encoder_blocks.append(block_decider(block, r, in_dim, out_dim, layer, config))

            # Update dimension of input from output (after a 'simple' block the next input is out_dim // 2)
            if 'simple' in block:
                in_dim = out_dim // 2
            else:
                in_dim = out_dim

            # Detect change to a subsampled layer
            if 'pool' in block or 'strided' in block:
                # Update radius and feature dimension for next layer
                layer += 1
                r *= 2
                out_dim *= 2

        #####################
        # List Decoder blocks
        #####################

        # Save all block operations in a list of modules
        self.decoder_blocks = nn.ModuleList()

        # Indices (in the decoder) of the blocks whose input is concatenated with a skip connection
        self.decoder_concats = []

        # Find first upsampling block
        # NOTE(review): start_i stays 0 when the architecture has no 'upsample' block, in which case the decoder
        # loop below goes through the whole architecture again - confirm this case never happens
        start_i = 0
        for block_i, block in enumerate(config.architecture):
            if 'upsample' in block:
                start_i = block_i
                break

        # Loop over consecutive blocks
        for block_i, block in enumerate(config.architecture[start_i:]):

            # Add dimension of skip connection concat (for the block right after an upsampling)
            if block_i > 0 and 'upsample' in config.architecture[start_i + block_i - 1]:
                in_dim += self.encoder_skip_dims[layer]
                self.decoder_concats.append(block_i)

            # Build the module chosen by block_decider for this block name
            self.decoder_blocks.append(block_decider(block, r, in_dim, out_dim, layer, config))

            # Update dimension of input from output
            in_dim = out_dim

            # Detect change to an upsampled layer
            if 'upsample' in block:
                # Update radius and feature dimension for next layer
                layer -= 1
                r *= 0.5
                out_dim = out_dim // 2

        self.head_mlp = UnaryBlock(out_dim, config.first_features_dim, False, 0)
        self.head_softmax = UnaryBlock(config.first_features_dim, self.C, False, 0)

        ################
        # Network Losses
        ################

        # List of valid labels (those not ignored in loss)
        self.valid_labels = np.sort([c for c in lbl_values if c not in ign_lbls])

        # Choose segmentation loss (class weights are only used when config.class_w is not empty)
        if len(config.class_w) > 0:
            class_w = torch.from_numpy(np.array(config.class_w, dtype=np.float32))
            self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1)
        else:
            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
        self.deform_fitting_mode = config.deform_fitting_mode
        self.deform_fitting_power = config.deform_fitting_power
        self.deform_lr_factor = config.deform_lr_factor
        self.repulse_extent = config.repulse_extent

        # Last values computed by loss()
        self.output_loss = 0
        self.reg_loss = 0

        # L1 loss used by the regularizer of deformable kernels
        self.l1 = nn.L1Loss()

    def forward(self, batch, config):
        """
        Run the network on a batch.
        :param batch: batch object; its `features` attribute is the network input and the whole object is handed
                      to every block
        :param config: not used in this method
        :return: output of the last head layer
        """

        # Get input features (a detached copy, the batch itself is left untouched)
        x = batch.features.clone().detach()

        # Encoder: keep the input of the blocks listed in encoder_skips for the skip connections
        skip_x = []
        for block_i, block_op in enumerate(self.encoder_blocks):
            if block_i in self.encoder_skips:
                skip_x.append(x)
            x = block_op(x, batch)

        # Decoder: the saved features are consumed last in, first out
        for block_i, block_op in enumerate(self.decoder_blocks):
            if block_i in self.decoder_concats:
                x = torch.cat([x, skip_x.pop()], dim=1)
            x = block_op(x, batch)

        # Head of network
        x = self.head_mlp(x, batch)
        x = self.head_softmax(x, batch)

        return x

    def _remap_labels(self, labels):
        """
        Set all ignored labels to -1 and correct the other labels to be in [0, C-1] range
        :param labels: labels using the original label values
        :return: new tensor shaped like labels, holding for each label its index in self.valid_labels, or -1 when
                 the label is not a valid one
        """
        target = - torch.ones_like(labels)
        for i, c in enumerate(self.valid_labels):
            target[labels == c] = i
        return target

    def loss(self, outputs, labels):
        """
        Runs the loss on outputs of the model
        :param outputs: logits
        :param labels: labels
        :return: loss
        :raises ValueError: if deform_fitting_mode is 'point2plane' (not implemented) or not a known mode
        """

        # Set all ignored labels to -1 and correct the other label to be in [0, C-1] range
        target = self._remap_labels(labels)

        # Reshape to have a minibatch size of 1
        outputs = torch.transpose(outputs, 0, 1)
        outputs = outputs.unsqueeze(0)
        target = target.unsqueeze(0)

        # Cross entropy loss (the criterion ignores the targets equal to -1)
        self.output_loss = self.criterion(outputs, target)

        # Regularization of deformable offsets
        if self.deform_fitting_mode == 'point2point':
            self.reg_loss = p2p_fitting_regularizer(self)
        elif self.deform_fitting_mode == 'point2plane':
            raise ValueError('point2plane fitting mode not implemented yet.')
        else:
            # str() keeps this a ValueError even when the configured mode is not a string (e.g. None)
            raise ValueError('Unknown fitting mode: ' + str(self.deform_fitting_mode))

        # Combined loss
        return self.output_loss + self.reg_loss

    def accuracy(self, outputs, labels):
        """
        Computes accuracy of the current batch
        :param outputs: logits predicted by the network
        :param labels: labels
        :return: accuracy value
        """

        # Set all ignored labels to -1 and correct the other label to be in [0, C-1] range
        target = self._remap_labels(labels)

        # Points with an ignored label keep a target of -1, which argmax never returns: they are counted in
        # the total but can never be counted as correct
        predicted = torch.argmax(outputs.data, dim=1)
        total = target.size(0)
        correct = (predicted == target).sum().item()

        return correct / total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|