KPConv-PyTorch/models/architectures.py

467 lines
14 KiB
Python
Raw Normal View History

2020-03-31 19:42:35 +00:00
#
#
# 0=================================0
# | Kernel Point Convolutions |
# 0=================================0
#
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Define network architectures
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Hugues THOMAS - 06/03/2020
#
from models.blocks import *
import numpy as np
class KPCNN(nn.Module):
    """
    Class defining KPCNN, the KPConv classification network.

    Blocks of ``config.architecture`` are instantiated in order up to (not including) the
    first 'upsample' block, followed by a head of two UnaryBlocks whose last one maps to
    ``config.num_classes`` outputs.
    """

    # Normalized distance (in units of KP_extent) below which two deformed kernel points are
    # pushed apart by the repulsive term of offset_regularizer. Kept as a class attribute so
    # it can be overridden without editing the method.
    # NOTE(review): KPFCNN in this file uses 0.5 for the same term; confirm which is intended.
    repulse_extent = 1.5

    def __init__(self, config):
        super(KPCNN, self).__init__()

        #####################
        # Network operations
        #####################

        # Current layer index, convolution radius and feature dimensions
        layer = 0
        r = config.first_subsampling_dl * config.conv_radius
        in_dim = config.in_features_dim
        out_dim = config.first_features_dim
        self.K = config.num_kernel_points

        # Save all block operations in a list of modules
        self.block_ops = nn.ModuleList()

        # Loop over consecutive blocks
        for block in config.architecture:

            # Check equivariance
            if ('equivariant' in block) and (not out_dim % 3 == 0):
                raise ValueError('Equivariant block but features dimension is not a factor of 3')

            # The classification network only uses the encoder part of the architecture
            if 'upsample' in block:
                break

            # Instantiate the module implementing this block
            self.block_ops.append(block_decider(block, r, in_dim, out_dim, layer, config))

            # Update dimension of input from output.
            # NOTE(review): presumably 'simple' blocks output out_dim // 2 features; confirm in block_decider.
            if 'simple' in block:
                in_dim = out_dim // 2
            else:
                in_dim = out_dim

            # Detect change to a subsampled layer
            if 'pool' in block or 'strided' in block:
                # Double the radius and the feature dimension for the next layer
                layer += 1
                r *= 2
                out_dim *= 2

        # Classification head
        self.head_mlp = UnaryBlock(out_dim, 1024, False, 0)
        self.head_softmax = UnaryBlock(1024, config.num_classes, False, 0)

        ################
        # Network Losses
        ################

        self.criterion = torch.nn.CrossEntropyLoss()
        self.offset_loss = config.offsets_loss
        self.offset_decay = config.offsets_decay
        self.output_loss = 0
        self.reg_loss = 0
        self.l1 = nn.L1Loss()

    def forward(self, batch, config):
        """
        Run the network on a batch.

        :param batch: input batch; ``batch.features`` holds the input features and the whole
                      batch is passed to every block
        :param config: network configuration (unused here, kept for interface compatibility)
        :return: output of the classification head (logits)
        """
        # Copy of the input features, detached from any autograd graph
        x = batch.features.clone().detach()

        # Loop over consecutive blocks
        for block_op in self.block_ops:
            x = block_op(x, batch)

        # Head of network
        x = self.head_mlp(x, batch)
        x = self.head_softmax(x, batch)
        return x

    def loss(self, outputs, labels):
        """
        Runs the loss on outputs of the model.

        Stores the cross-entropy term in ``self.output_loss`` and the deformation
        regularization term in ``self.reg_loss``.

        :param outputs: logits
        :param labels: labels
        :return: cross-entropy loss + offset regularization
        """
        # Cross entropy loss
        self.output_loss = self.criterion(outputs, labels)

        # Regularization of deformable offsets
        self.reg_loss = self.offset_regularizer()

        # Combined loss
        return self.output_loss + self.reg_loss

    @staticmethod
    def accuracy(outputs, labels):
        """
        Computes accuracy of the current batch.

        :param outputs: logits predicted by the network, classes along dim 1
        :param labels: labels
        :return: fraction of elements whose argmax prediction equals the label
        """
        predicted = torch.argmax(outputs.data, dim=1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        return correct / total

    def offset_regularizer(self):
        """
        Regularization loss on the kernel-point deformations of every deformable KPConv.

        For each deformable KPConv module ``m`` this accumulates:
          * a fitting term: L1 (mean absolute value) of the minimum of ``m.deformed_d2`` along
            dim 1, normalized by ``m.KP_extent ** 2``;
          * a repulsive term: for each kernel point, the summed squared shortfall of its
            normalized distances to the other (detached) kernel points below
            ``self.repulse_extent``, averaged with L1 and divided by ``self.K``.

        Side effect: registers a hook that scales the gradient of ``m.unscaled_offsets`` by 0.1.
        NOTE(review): the hook is registered on every call, so calling this twice for the same
        forward pass would scale that gradient twice.

        :return: ``self.offset_decay * (fitting_loss + repulsive_loss)``
        """
        fitting_loss = 0
        repulsive_loss = 0

        for m in self.modules():
            if not (isinstance(m, KPConv) and m.deformable):
                continue

            ##############################
            # divide offset gradient by 10
            ##############################

            # register_hook raises on tensors that do not require grad (e.g. when the loss is
            # evaluated under torch.no_grad()), so only register it when a gradient will exist.
            if m.unscaled_offsets.requires_grad:
                m.unscaled_offsets.register_hook(lambda grad: grad * 0.1)

            ##############
            # Fitting loss
            ##############

            # Distance to the closest input point (values are already squared distances,
            # hence L1), normalized to be independent from the layer scale.
            # NOTE(review): assumes deformed_d2 has the neighbor axis on dim 1 — confirm in KPConv.
            KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
            KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
            fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2))

            ################
            # Repulsive loss
            ################

            # Kernel point locations normalized by the kernel extent (kernel points on dim 1)
            KP_locs = m.deformed_KP / m.KP_extent

            # Points should not be close to each other
            for i in range(self.K):
                # Other points are detached so this term only moves kernel point i
                other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
                distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
                rep_loss = torch.sum(torch.clamp_max(distances - self.repulse_extent, max=0.0) ** 2, dim=1)
                repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K

        return self.offset_decay * (fitting_loss + repulsive_loss)
class KPFCNN(nn.Module):
    """
    Class defining KPFCNN, the KPConv encoder-decoder segmentation network.

    Blocks of ``config.architecture`` before the first 'upsample' block form the encoder and
    the remaining blocks form the decoder. Encoder features are saved right before each
    'pool'/'strided'/'upsample'/'global' block and concatenated back into the decoder right
    after each 'upsample' block.

    :param config: network configuration
    :param lbl_values: all label values of the dataset
    :param ign_lbls: label values ignored by the loss and the accuracy
    """

    # Normalized distance (in units of KP_extent) below which two deformed kernel points are
    # pushed apart by the repulsive term of offset_regularizer. Kept as a class attribute so
    # it can be overridden without editing the method.
    # NOTE(review): KPCNN in this file uses 1.5 for the same term; confirm which is intended.
    repulse_extent = 0.5

    def __init__(self, config, lbl_values, ign_lbls):
        super(KPFCNN, self).__init__()

        ############
        # Parameters
        ############

        # Current layer index, convolution radius and feature dimensions
        layer = 0
        r = config.first_subsampling_dl * config.conv_radius
        in_dim = config.in_features_dim
        out_dim = config.first_features_dim
        self.K = config.num_kernel_points

        # Sorted list of valid labels (those not ignored in loss). The position of a label in
        # this array is the class index the network predicts for it.
        self.valid_labels = np.sort([c for c in lbl_values if c not in ign_lbls])

        # Number of output classes, derived from valid_labels so the head size always matches
        # the label remapping (even if ign_lbls contains values absent from lbl_values).
        self.C = len(self.valid_labels)

        #####################
        # List Encoder blocks
        #####################

        self.encoder_blocks = nn.ModuleList()
        self.encoder_skip_dims = []
        self.encoder_skips = []

        for block_i, block in enumerate(config.architecture):

            # Check equivariance
            if ('equivariant' in block) and (not out_dim % 3 == 0):
                raise ValueError('Equivariant block but features dimension is not a factor of 3')

            # Before a layer change, remember the block index (to save its input in forward)
            # and the feature dimension of the skip connection
            if any(tmp in block for tmp in ['pool', 'strided', 'upsample', 'global']):
                self.encoder_skips.append(block_i)
                self.encoder_skip_dims.append(in_dim)

            # The encoder stops at the first upsampling block
            if 'upsample' in block:
                break

            # Instantiate the module implementing this block
            self.encoder_blocks.append(block_decider(block, r, in_dim, out_dim, layer, config))

            # Update dimension of input from output.
            # NOTE(review): presumably 'simple' blocks output out_dim // 2 features; confirm in block_decider.
            if 'simple' in block:
                in_dim = out_dim // 2
            else:
                in_dim = out_dim

            # Detect change to a subsampled layer
            if 'pool' in block or 'strided' in block:
                # Double the radius and the feature dimension for the next layer
                layer += 1
                r *= 2
                out_dim *= 2

        #####################
        # List Decoder blocks
        #####################

        self.decoder_blocks = nn.ModuleList()
        self.decoder_concats = []

        # Index of the first upsampling block (0 if there is none)
        start_i = next((i for i, b in enumerate(config.architecture) if 'upsample' in b), 0)

        for block_i, block in enumerate(config.architecture[start_i:]):

            # The block right after an 'upsample' receives the upsampled features concatenated
            # with the encoder skip features of the current layer (encoder_skip_dims is indexed
            # by layer)
            if block_i > 0 and 'upsample' in config.architecture[start_i + block_i - 1]:
                in_dim += self.encoder_skip_dims[layer]
                self.decoder_concats.append(block_i)

            # Instantiate the module implementing this block
            self.decoder_blocks.append(block_decider(block, r, in_dim, out_dim, layer, config))

            # Update dimension of input from output
            in_dim = out_dim

            # Detect change to a less subsampled layer
            if 'upsample' in block:
                # Halve the radius and the feature dimension for the next layer
                layer -= 1
                r *= 0.5
                out_dim = out_dim // 2

        # Segmentation head: one output per valid class
        self.head_mlp = UnaryBlock(out_dim, config.first_features_dim, False, 0)
        self.head_softmax = UnaryBlock(config.first_features_dim, self.C, False, 0)

        ################
        # Network Losses
        ################

        # Choose segmentation loss; targets equal to -1 (ignored labels) are skipped
        if len(config.class_w) > 0:
            class_w = torch.from_numpy(np.array(config.class_w, dtype=np.float32))
            self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1)
        else:
            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
        self.offset_loss = config.offsets_loss
        self.offset_decay = config.offsets_decay
        self.output_loss = 0
        self.reg_loss = 0
        self.l1 = nn.L1Loss()

    def forward(self, batch, config):
        """
        Run the network on a batch.

        :param batch: input batch; ``batch.features`` holds the input features and the whole
                      batch is passed to every block
        :param config: network configuration (unused here, kept for interface compatibility)
        :return: output of the segmentation head (logits)
        """
        # Copy of the input features, detached from any autograd graph
        x = batch.features.clone().detach()

        # Encoder: save the input of every skip block for the decoder
        skip_x = []
        for block_i, block_op in enumerate(self.encoder_blocks):
            if block_i in self.encoder_skips:
                skip_x.append(x)
            x = block_op(x, batch)

        # Decoder: concatenate the most recent remaining skip features after each upsampling
        for block_i, block_op in enumerate(self.decoder_blocks):
            if block_i in self.decoder_concats:
                x = torch.cat([x, skip_x.pop()], dim=1)
            x = block_op(x, batch)

        # Head of network
        x = self.head_mlp(x, batch)
        x = self.head_softmax(x, batch)
        return x

    def _map_labels(self, labels):
        """Map dataset labels to class indices in [0, C-1]; labels not in valid_labels become -1."""
        target = - torch.ones_like(labels)
        for i, c in enumerate(self.valid_labels):
            target[labels == c] = i
        return target

    def loss(self, outputs, labels):
        """
        Runs the loss on outputs of the model.

        Stores the cross-entropy term in ``self.output_loss`` and the deformation
        regularization term in ``self.reg_loss``. Ignored labels do not contribute to the
        cross-entropy term.

        :param outputs: logits
        :param labels: labels (dataset label values, not class indices)
        :return: cross-entropy loss + offset regularization
        """
        # Set all ignored labels to -1 and correct the other labels to be in [0, C-1] range
        target = self._map_labels(labels)

        # Reshape to have a minibatch size of 1.
        # NOTE(review): presumably outputs is (N, C) -> (1, C, N) and target (N,) -> (1, N).
        outputs = torch.transpose(outputs, 0, 1)
        outputs = outputs.unsqueeze(0)
        target = target.unsqueeze(0)

        # Cross entropy loss
        self.output_loss = self.criterion(outputs, target)

        # Regularization of deformable offsets
        self.reg_loss = self.offset_regularizer()

        # Combined loss
        return self.output_loss + self.reg_loss

    def accuracy(self, outputs, labels):
        """
        Computes accuracy of the current batch over non-ignored points.

        :param outputs: logits predicted by the network, classes along dim 1
        :param labels: labels (dataset label values, not class indices)
        :return: fraction of non-ignored points whose argmax prediction matches their label,
                 or 0.0 if every point is ignored
        """
        target = self._map_labels(labels)
        predicted = torch.argmax(outputs.data, dim=1)

        # Ignored points (target == -1) can never match a prediction in [0, C-1]; exclude them
        # from the denominator instead of silently counting them as errors.
        total = int((target >= 0).sum().item())
        if total == 0:
            return 0.0
        correct = (predicted == target).sum().item()
        return correct / total

    def offset_regularizer(self):
        """
        Regularization loss on the kernel-point deformations of every deformable KPConv.

        For each deformable KPConv module ``m`` this accumulates:
          * a fitting term: L1 (mean absolute value) of the minimum of ``m.deformed_d2`` along
            dim 1, normalized by ``m.KP_extent ** 2``;
          * a repulsive term: for each kernel point, the summed squared shortfall of its
            normalized distances to the other (detached) kernel points below
            ``self.repulse_extent``, averaged with L1 and divided by ``self.K``.

        Side effect: registers a hook that scales the gradient of ``m.unscaled_offsets`` by 0.1.
        NOTE(review): the hook is registered on every call, so calling this twice for the same
        forward pass would scale that gradient twice.

        :return: ``self.offset_decay * (fitting_loss + repulsive_loss)``
        """
        fitting_loss = 0
        repulsive_loss = 0

        for m in self.modules():
            if not (isinstance(m, KPConv) and m.deformable):
                continue

            ##############################
            # divide offset gradient by 10
            ##############################

            # register_hook raises on tensors that do not require grad (e.g. when the loss is
            # evaluated under torch.no_grad()), so only register it when a gradient will exist.
            if m.unscaled_offsets.requires_grad:
                m.unscaled_offsets.register_hook(lambda grad: grad * 0.1)

            ##############
            # Fitting loss
            ##############

            # Distance to the closest input point (values are already squared distances,
            # hence L1), normalized to be independent from the layer scale.
            # NOTE(review): assumes deformed_d2 has the neighbor axis on dim 1 — confirm in KPConv.
            KP_min_d2, _ = torch.min(m.deformed_d2, dim=1)
            KP_min_d2 = KP_min_d2 / (m.KP_extent ** 2)
            fitting_loss += self.l1(KP_min_d2, torch.zeros_like(KP_min_d2))

            ################
            # Repulsive loss
            ################

            # Kernel point locations normalized by the kernel extent (kernel points on dim 1)
            KP_locs = m.deformed_KP / m.KP_extent

            # Points should not be close to each other
            for i in range(self.K):
                # Other points are detached so this term only moves kernel point i
                other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach()
                distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2))
                rep_loss = torch.sum(torch.clamp_max(distances - self.repulse_extent, max=0.0) ** 2, dim=1)
                repulsive_loss += self.l1(rep_loss, torch.zeros_like(rep_loss)) / self.K

        return self.offset_decay * (fitting_loss + repulsive_loss)