652 lines
22 KiB
Python
652 lines
22 KiB
Python
#
|
|
#
|
|
# 0=================================0
|
|
# | Kernel Point Convolutions |
|
|
# 0=================================0
|
|
#
|
|
#
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Define network blocks
|
|
#
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Hugues THOMAS - 06/03/2020
|
|
#
|
|
|
|
|
|
import time
|
|
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn.parameter import Parameter
|
|
from torch.nn.init import kaiming_uniform_
|
|
from kernels.kernel_points import load_kernels
|
|
|
|
from utils.ply import write_ply
|
|
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Simple functions
|
|
# \**********************/
|
|
#
|
|
|
|
|
|
def gather(x, idx, method=2):
    """
    Custom gather returning x[idx], implemented so the backward pass is faster than plain indexing.

    :param x: input with shape [N, D_1, ... D_d]
    :param idx: indexing with shape [n_1, ..., n_m]
    :param method: 0 -> plain advanced indexing;
                   1 -> torch.gather after expanding, only valid for a 2-D x and a 2-D idx;
                   2 -> torch.gather after broadcasting x and idx to a common shape (general case).
    :return: x[idx] with shape [n_1, ..., n_m, D_1, ... D_d]
    :raises ValueError: if method is not 0, 1 or 2.
    """

    if method == 0:
        return x[idx]

    if method == 1:
        # x: [N, D] -> [N, n_2, D]; idx: [n_1, n_2] -> [n_1, n_2, D]
        x_wide = x.unsqueeze(1).expand((-1, idx.shape[-1], -1))
        idx_wide = idx.unsqueeze(2).expand((-1, -1, x_wide.shape[-1]))
        return x_wide.gather(0, idx_wide)

    if method == 2:
        # Give x one broadcast axis per trailing index dimension: [N, n_2, ..., n_m, D_1, ..., D_d]
        for axis, size in enumerate(idx.shape[1:], start=1):
            x = x.unsqueeze(axis)
            target = list(x.shape)
            target[axis] = size
            x = x.expand(target)

        # Give idx one broadcast axis per feature dimension of x: [n_1, ..., n_m, D_1, ..., D_d]
        n_idx_dims = idx.dim()
        for axis, size in enumerate(x.shape[n_idx_dims:], start=n_idx_dims):
            idx = idx.unsqueeze(axis)
            target = list(idx.shape)
            target[axis] = size
            idx = idx.expand(target)

        return x.gather(0, idx)

    raise ValueError('Unkown method')
|
|
|
|
|
|
def radius_gaussian(sq_r, sig, eps=1e-9):
    """
    Evaluate a gaussian of the distance given squared distances: exp(-sq_r / (2 * sig^2 + eps)).

    :param sq_r: input radiuses [dn, ..., d1, d0]
    :param sig: extents of gaussians [d1, d0] or [d0] or float
    :param eps: small constant added to the denominator to avoid a division by zero when sig == 0
    :return: gaussian of sq_r [dn, ..., d1, d0]
    """
    denominator = 2 * sig ** 2 + eps
    return torch.exp(-sq_r / denominator)
|
|
|
|
|
|
def closest_pool(x, inds):
    """
    Pool features by taking, for each output location, the feature of its first listed neighbor.
    WARNING: this function assumes the neighbors are ordered (closest first).

    :param x: [n1, d] features matrix
    :param inds: [n2, max_num] Only the first column is used for pooling
    :return: [n2, d] pooled features matrix
    """

    # Append an all-zero row at index n1; presumably shadow (missing) neighbors use that index
    padded = torch.cat((x, torch.zeros_like(x[:1, :])), 0)

    # Feature of the first neighbor of each pooling location [n2, d]
    first_neighbor = inds[:, 0]
    return gather(padded, first_neighbor)
|
|
|
|
|
|
def max_pool(x, inds):
    """
    Pool features by taking the element-wise maximum over each pooling neighborhood.

    :param x: [n1, d] features matrix
    :param inds: [n2, max_num] pooling indices
    :return: [n2, d] pooled features matrix
    """

    # Append an all-zero row at index n1; presumably shadow (missing) neighbors use that index.
    # NOTE(review): because this row is zeros (not a minimum value), any neighborhood that references it
    # has its max clamped below at 0 — confirm this is intended when features can be negative.
    padded = torch.cat((x, torch.zeros_like(x[:1, :])), 0)

    # All neighbor features of each pooling location [n2, max_num, d]
    neighborhood = gather(padded, inds)

    # Maximum over the neighbor axis [n2, d]
    return neighborhood.max(dim=1).values
|
|
|
|
|
|
def global_average(x, batch_lengths):
    """
    Average the features of each point cloud in a stacked batch.

    :param x: [N, D] input features, clouds stacked one after another along the first axis
    :param batch_lengths: [B] number of points of each cloud, in stacking order
    :return: [B, D] averaged features, one row per cloud
    """

    means = []
    start = 0
    for length in batch_lengths:
        stop = start + length
        means.append(torch.mean(x[start:stop], dim=0))
        start = stop

    return torch.stack(means)
|
|
|
|
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# KPConv class
|
|
# \******************/
|
|
#
|
|
|
|
|
|
class KPConv(nn.Module):
    """
    Kernel Point Convolution layer, rigid or deformable (optionally modulated).

    For every query point, each neighbor feature is weighted by its influence on each of the K kernel
    points (a function of the neighbor-to-kernel-point distance) and summed per kernel point. The result
    is projected with one [in_channels, out_channels] weight matrix per kernel point and summed over kernel points.
    """

    def __init__(self, kernel_size, p_dim, in_channels, out_channels, KP_extent, radius,
                 fixed_kernel_points='center', KP_influence='linear', aggregation_mode='sum',
                 deformable=False, modulated=False):
        """
        Initialize parameters for KPConv.

        :param kernel_size: Number of kernel points.
        :param p_dim: dimension of the point space.
        :param in_channels: dimension of input features.
        :param out_channels: dimension of output features.
        :param KP_extent: influence radius of each kernel point.
        :param radius: radius used for kernel point init. Even for deformable, use the config.conv_radius
        :param fixed_kernel_points: fix position of certain kernel points ('none', 'center' or 'verticals').
        :param KP_influence: influence function of the kernel points ('constant', 'linear', 'gaussian').
        :param aggregation_mode: choose to sum influences, or only keep the closest ('closest', 'sum').
        :param deformable: choose deformable or not
        :param modulated: choose if kernel weights are modulated in addition to deformed
        """
        super(KPConv, self).__init__()

        # Save parameters
        self.K = kernel_size
        self.p_dim = p_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.radius = radius
        self.KP_extent = KP_extent
        self.fixed_kernel_points = fixed_kernel_points
        self.KP_influence = KP_influence
        self.aggregation_mode = aggregation_mode
        self.deformable = deformable
        self.modulated = modulated

        # Running variables set by forward() on deformable layers (used in regularization loss)
        self.deformed_d2 = None
        self.deformed_KP = None
        self.unscaled_offsets = None

        # One [in_channels, out_channels] weight matrix per kernel point
        self.weights = Parameter(torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32),
                                 requires_grad=True)

        # Deformable layers predict their kernel point offsets (and modulations) with a rigid KPConv
        if deformable:
            if modulated:
                # p_dim offset coordinates plus one modulation value per kernel point
                self.offset_dim = (self.p_dim + 1) * self.K
            else:
                self.offset_dim = self.p_dim * self.K
            self.offset_conv = KPConv(self.K,
                                      self.p_dim,
                                      in_channels,
                                      self.offset_dim,
                                      KP_extent,
                                      radius,
                                      fixed_kernel_points=fixed_kernel_points,
                                      KP_influence=KP_influence,
                                      aggregation_mode=aggregation_mode)
            self.offset_bias = Parameter(torch.zeros(self.offset_dim, dtype=torch.float32), requires_grad=True)

        else:
            self.offset_dim = None
            self.offset_conv = None
            self.offset_bias = None

        # Reset parameters
        self.reset_parameters()

        # Initialize kernel points
        self.kernel_points = self.init_KP()

        return

    def reset_parameters(self):
        """Kaiming-uniform init of the kernel weights; zero the offset bias of deformable layers."""
        kaiming_uniform_(self.weights, a=math.sqrt(5))
        if self.deformable:
            nn.init.zeros_(self.offset_bias)
        return

    def init_KP(self):
        """
        Initialize the kernel point positions in a sphere
        :return: the tensor of kernel points (non-trainable Parameter)
        """

        # Create one kernel disposition (as numpy array). Choose the KP distance to center thanks to the KP extent
        K_points_numpy = load_kernels(self.radius,
                                      self.K,
                                      dimension=self.p_dim,
                                      fixed=self.fixed_kernel_points)

        return Parameter(torch.tensor(K_points_numpy, dtype=torch.float32),
                         requires_grad=False)

    def forward(self, q_pts, s_pts, neighb_inds, x):
        """
        Apply the kernel point convolution.

        :param q_pts: query points [n_points, dim]
        :param s_pts: support points [n_support, dim]
        :param neighb_inds: indices of the neighbors of each query in s_pts [n_points, n_neighbors].
                            Index n_support selects the fake shadow point / zero feature appended below.
        :param x: support features [n_support, in_channels]
        :return: convolved features [n_points, out_channels]
        :raises ValueError: on an unknown KP_influence or aggregation_mode.
        """

        ###################
        # Offset generation
        ###################

        if self.deformable:
            offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias

            if self.modulated:

                # Get offset (in normalized scale) from features
                offsets = offset_features[:, :self.p_dim * self.K]
                self.unscaled_offsets = offsets.view(-1, self.K, self.p_dim)

                # Get modulations, in the range (0, 2)
                modulations = 2 * torch.sigmoid(offset_features[:, self.p_dim * self.K:])

            else:

                # Get offset (in normalized scale) from features
                self.unscaled_offsets = offset_features.view(-1, self.K, self.p_dim)

                # No modulations
                modulations = None

            # Rescale offset for this layer
            offsets = self.unscaled_offsets * self.KP_extent

        else:
            offsets = None
            modulations = None

        ######################
        # Deformed convolution
        ######################

        # Add a fake point far away (1e6) in the last row for shadow neighbors
        s_pts = torch.cat((s_pts, torch.zeros_like(s_pts[:1, :]) + 1e6), 0)

        # Get neighbor points [n_points, n_neighbors, dim]
        neighbors = s_pts[neighb_inds, :]

        # Center every neighborhood
        neighbors = neighbors - q_pts.unsqueeze(1)

        # Apply offsets to kernel points [n_points, n_kpoints, dim]
        if self.deformable:
            self.deformed_KP = offsets + self.kernel_points
            deformed_K_points = self.deformed_KP.unsqueeze(1)
        else:
            deformed_K_points = self.kernel_points

        # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim]
        neighbors = neighbors.unsqueeze(2)
        differences = neighbors - deformed_K_points

        # Get the square distances [n_points, n_neighbors, n_kpoints]
        sq_distances = torch.sum(differences ** 2, dim=3)

        # Save distances for loss
        if self.deformable:
            self.deformed_d2 = sq_distances

        # Get Kernel point influences [n_points, n_kpoints, n_neighbors]
        if self.KP_influence == 'constant':
            # Every point get an influence of 1.
            all_weights = torch.ones_like(sq_distances)
            all_weights = torch.transpose(all_weights, 1, 2)

        elif self.KP_influence == 'linear':
            # Influence decrease linearly with the distance, and get to zero when d = KP_extent.
            all_weights = torch.clamp(1 - torch.sqrt(sq_distances) / self.KP_extent, min=0.0)
            all_weights = torch.transpose(all_weights, 1, 2)

        elif self.KP_influence == 'gaussian':
            # Influence in gaussian of the distance.
            sigma = self.KP_extent * 0.3
            all_weights = radius_gaussian(sq_distances, sigma)
            all_weights = torch.transpose(all_weights, 1, 2)
        else:
            raise ValueError('Unknown influence function type (config.KP_influence)')

        # In case of closest mode, only the closest KP can influence each point
        if self.aggregation_mode == 'closest':
            neighbors_1nn = torch.argmin(sq_distances, dim=2)
            # Out-of-place on purpose: in the gaussian branch all_weights is a view of torch.exp's output,
            # which autograd saves for backward; an in-place multiply would make backward() fail.
            all_weights = all_weights * torch.transpose(nn.functional.one_hot(neighbors_1nn, self.K), 1, 2)

        elif self.aggregation_mode != 'sum':
            raise ValueError("Unknown convolution mode. Should be 'closest' or 'sum'")

        # Add a zero feature for shadow neighbors
        x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)

        # Get the features of each neighborhood [n_points, n_neighbors, in_fdim]
        neighb_x = gather(x, neighb_inds)

        # Apply distance weights [n_points, n_kpoints, in_fdim]
        weighted_features = torch.matmul(all_weights, neighb_x)

        # Apply modulations (one scalar per point and kernel point)
        if self.deformable and self.modulated:
            weighted_features = weighted_features * modulations.unsqueeze(2)

        # Apply network weights [n_kpoints, n_points, out_fdim]
        weighted_features = weighted_features.permute((1, 0, 2))
        kernel_outputs = torch.matmul(weighted_features, self.weights)

        # Convolution sum [n_points, out_fdim]
        return torch.sum(kernel_outputs, dim=0)
|
|
|
|
|
|
# ----------------------------------------------------------------------------------------------------------------------
|
|
#
|
|
# Complex blocks
|
|
# \********************/
|
|
#
|
|
|
|
def block_decider(block_name,
                  radius,
                  in_dim,
                  out_dim,
                  layer_ind,
                  config):
    """
    Build the network block designated by ``block_name``.

    :param block_name: name of the block in the architecture definition
    :param radius: current radius of convolution
    :param in_dim: dimension of the input features
    :param out_dim: dimension of the output features
    :param layer_ind: index of the current layer
    :param config: parameters (reads use_batch_norm and batch_norm_momentum here; convolution blocks read more)
    :return: the constructed nn.Module
    :raises ValueError: if ``block_name`` is not a known block type
    """

    simple_names = ('simple',
                    'simple_deformable',
                    'simple_invariant',
                    'simple_equivariant',
                    'simple_strided',
                    'simple_deformable_strided',
                    'simple_invariant_strided',
                    'simple_equivariant_strided')

    resnet_names = ('resnetb',
                    'resnetb_invariant',
                    'resnetb_equivariant',
                    'resnetb_deformable',
                    'resnetb_strided',
                    'resnetb_deformable_strided',
                    'resnetb_equivariant_strided',
                    'resnetb_invariant_strided')

    if block_name == 'unary':
        return UnaryBlock(in_dim, out_dim, config.use_batch_norm, config.batch_norm_momentum)

    if block_name in simple_names:
        return SimpleBlock(block_name, in_dim, out_dim, radius, layer_ind, config)

    if block_name in resnet_names:
        return ResnetBottleneckBlock(block_name, in_dim, out_dim, radius, layer_ind, config)

    if block_name in ('max_pool', 'max_pool_wide'):
        return MaxPoolBlock(layer_ind)

    if block_name == 'global_average':
        return GlobalAverageBlock()

    if block_name == 'nearest_upsample':
        return NearestUpsampleBlock(layer_ind)

    raise ValueError('Unknown block name in the architecture definition : ' + block_name)
|
|
|
|
|
|
class BatchNormBlock(nn.Module):

    def __init__(self, in_dim, use_bn, bn_momentum):
        """
        Initialize a batch normalization block. If network does not use batch normalization, replace with biases.
        :param in_dim: dimension input features
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        """
        super(BatchNormBlock, self).__init__()
        self.bn_momentum = bn_momentum
        self.use_bn = use_bn
        if self.use_bn:
            self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum)
        else:
            self.bias = Parameter(torch.zeros(in_dim, dtype=torch.float32), requires_grad=True)
        return

    def reset_parameters(self):
        """Reset the learnable state: the batch-norm statistics and affine parameters, or the bias."""
        # Only one of `batch_norm` / `bias` exists, depending on use_bn
        if self.use_bn:
            self.batch_norm.reset_parameters()
        else:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        """
        :param x: [N, D] features
        :return: [N, D] normalized (or biased) features
        """
        if self.use_bn:

            # [N, D] -> [1, D, N] so BatchNorm1d normalizes each of the D channels over the N points
            x = x.unsqueeze(2)
            x = x.transpose(0, 2)
            x = self.batch_norm(x)
            x = x.transpose(0, 2)

            # Remove only the helper axis: a bare squeeze() would also drop N == 1 or D == 1
            return x.squeeze(2)

        else:
            return x + self.bias
|
|
|
|
|
|
class UnaryBlock(nn.Module):

    def __init__(self, in_dim, out_dim, use_bn, bn_momentum, no_relu=False):
        """
        Point-wise linear layer (no bias), followed by a BatchNormBlock and, unless disabled, a LeakyReLU(0.1).

        :param in_dim: dimension of the input features
        :param out_dim: dimension of the output features
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        :param no_relu: if True, the final LeakyReLU is skipped (no leaky_relu submodule is created)
        """
        super(UnaryBlock, self).__init__()

        self.bn_momentum = bn_momentum
        self.use_bn = use_bn
        self.no_relu = no_relu

        # Submodules are created in this order: linear, normalization, activation
        self.mlp = nn.Linear(in_dim, out_dim, bias=False)
        self.batch_norm = BatchNormBlock(out_dim, self.use_bn, self.bn_momentum)
        if not no_relu:
            self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x, batch=None):
        # `batch` is unused here; presumably accepted so this block can be called like the others (x, batch)
        out = self.batch_norm(self.mlp(x))
        return out if self.no_relu else self.leaky_relu(out)
|
|
|
|
|
|
class SimpleBlock(nn.Module):

    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, config):
        """
        A single KPConv followed by a BatchNormBlock and a LeakyReLU(0.1).

        :param block_name: block type; contains 'deform' for a deformable KPConv, 'strided' for a subsampling block
        :param in_dim: dimension of the input features
        :param out_dim: dimension of the output features
        :param radius: current radius of convolution
        :param layer_ind: index of the layer whose points/neighbors/pools are read from the batch
        :param config: parameters
        """
        super(SimpleBlock, self).__init__()

        # The kernel point extent scales with the current radius
        extent = radius * config.KP_extent / config.conv_radius

        self.bn_momentum = config.batch_norm_momentum
        self.use_bn = config.use_batch_norm
        self.layer_ind = layer_ind
        self.block_name = block_name

        self.KPConv = KPConv(config.num_kernel_points,
                             config.in_points_dim,
                             in_dim,
                             out_dim,
                             extent,
                             radius,
                             fixed_kernel_points=config.fixed_kernel_points,
                             KP_influence=config.KP_influence,
                             aggregation_mode=config.aggregation_mode,
                             deformable='deform' in block_name,
                             modulated=config.modulated)

        self.batch_norm = BatchNormBlock(out_dim, self.use_bn, self.bn_momentum)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x, batch):
        layer = self.layer_ind

        if 'strided' not in self.block_name:
            # Same-resolution convolution: queries are the support points themselves
            query = support = batch.points[layer]
            neighbors = batch.neighbors[layer]
        else:
            # Strided convolution: queries come from the next (subsampled) layer
            query = batch.points[layer + 1]
            support = batch.points[layer]
            neighbors = batch.pools[layer]

        features = self.KPConv(query, support, neighbors, x)
        return self.leaky_relu(self.batch_norm(features))
|
|
|
|
|
|
class ResnetBottleneckBlock(nn.Module):

    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, config):
        """
        Residual bottleneck block with two branches, added together and passed through a LeakyReLU(0.1).

        Main branch: unary (in_dim -> out_dim // 2), KPConv at width out_dim // 2 with BatchNormBlock
        and LeakyReLU, then unary (out_dim // 2 -> out_dim) without activation.
        Shortcut branch: a linear projection when in_dim != out_dim, preceded by max pooling for strided blocks.

        :param block_name: block type; contains 'deform' for a deformable KPConv, 'strided' for a subsampling block
        :param in_dim: dimension of the input features
        :param out_dim: dimension of the output features
        :param radius: current radius of convolution
        :param layer_ind: index of the layer whose points/neighbors/pools are read from the batch
        :param config: parameters
        """
        super(ResnetBottleneckBlock, self).__init__()

        # The kernel point extent scales with the current radius
        extent = radius * config.KP_extent / config.conv_radius
        mid_dim = out_dim // 2

        self.bn_momentum = config.batch_norm_momentum
        self.use_bn = config.use_batch_norm
        self.block_name = block_name
        self.layer_ind = layer_ind

        # Input projection, skipped when the input already has the bottleneck width
        if in_dim == mid_dim:
            self.unary1 = nn.Identity()
        else:
            self.unary1 = UnaryBlock(in_dim, mid_dim, self.use_bn, self.bn_momentum)

        self.KPConv = KPConv(config.num_kernel_points,
                             config.in_points_dim,
                             mid_dim,
                             mid_dim,
                             extent,
                             radius,
                             fixed_kernel_points=config.fixed_kernel_points,
                             KP_influence=config.KP_influence,
                             aggregation_mode=config.aggregation_mode,
                             deformable='deform' in block_name,
                             modulated=config.modulated)
        self.batch_norm_conv = BatchNormBlock(mid_dim, self.use_bn, self.bn_momentum)

        # Output projection back to out_dim (activation applied after the residual sum)
        self.unary2 = UnaryBlock(mid_dim, out_dim, self.use_bn, self.bn_momentum, no_relu=True)

        # Shortcut projection, skipped when dimensions already match
        if in_dim == out_dim:
            self.unary_shortcut = nn.Identity()
        else:
            self.unary_shortcut = UnaryBlock(in_dim, out_dim, self.use_bn, self.bn_momentum, no_relu=True)

        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, features, batch):
        layer = self.layer_ind
        strided = 'strided' in self.block_name

        if strided:
            # Queries come from the next (subsampled) layer
            query = batch.points[layer + 1]
            support = batch.points[layer]
            neighbors = batch.pools[layer]
        else:
            query = support = batch.points[layer]
            neighbors = batch.neighbors[layer]

        # Main branch
        hidden = self.unary1(features)
        hidden = self.KPConv(query, support, neighbors, hidden)
        hidden = self.leaky_relu(self.batch_norm_conv(hidden))
        hidden = self.unary2(hidden)

        # Shortcut branch (pooled to the query resolution when strided)
        skip = max_pool(features, neighbors) if strided else features
        skip = self.unary_shortcut(skip)

        return self.leaky_relu(hidden + skip)
|
|
|
|
|
|
class GlobalAverageBlock(nn.Module):

    def __init__(self):
        """
        Parameter-free block that averages the features of each cloud of the batch.
        """
        super(GlobalAverageBlock, self).__init__()

    def forward(self, x, batch):
        # Cloud lengths are taken from the last entry of batch.lengths (presumably the deepest layer)
        cloud_lengths = batch.lengths[-1]
        return global_average(x, cloud_lengths)
|
|
|
|
|
|
class NearestUpsampleBlock(nn.Module):

    def __init__(self, layer_ind):
        """
        Parameter-free block that upsamples features by copying the feature of the first listed neighbor.

        :param layer_ind: index of the current layer; upsample indices are read at layer_ind - 1
        """
        super(NearestUpsampleBlock, self).__init__()
        self.layer_ind = layer_ind

    def forward(self, x, batch):
        upsample_inds = batch.upsamples[self.layer_ind - 1]
        return closest_pool(x, upsample_inds)
|
|
|
|
|
|
class MaxPoolBlock(nn.Module):

    def __init__(self, layer_ind):
        """
        Parameter-free block that max-pools features over pooling neighborhoods.

        :param layer_ind: index of the current layer; pooling indices are read at layer_ind + 1
        """
        super(MaxPoolBlock, self).__init__()
        self.layer_ind = layer_ind

    def forward(self, x, batch):
        # NOTE(review): strided convolutions in this file read batch.pools[layer_ind], while this block
        # reads batch.pools[layer_ind + 1] — confirm this offset matches how layer_ind is assigned upstream.
        pool_inds = batch.pools[self.layer_ind + 1]
        return max_pool(x, pool_inds)
|
|
|