PointMLP/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py
2021-10-04 03:25:18 -04:00

210 lines
6.4 KiB
Python

from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from pointnet2_ops import pointnet2_utils
def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
    """Assemble a stack of 1x1 ``Conv2d`` layers described by ``mlp_spec``.

    Each consecutive pair ``(mlp_spec[k], mlp_spec[k + 1])`` becomes one stage:
    a 1x1 convolution, optionally followed by ``BatchNorm2d``, then an
    in-place ``ReLU``.

    Parameters
    ----------
    mlp_spec : list of int
        Channel widths; the first entry is the input width and every later
        entry is the output width of one stage.
    bn : bool
        When true a ``BatchNorm2d`` follows each convolution and the
        convolution is created without a bias term.

    Returns
    -------
    nn.Sequential
        The stages in order; empty when ``mlp_spec`` has fewer than two
        entries.
    """
    stages = []
    for in_channels, out_channels in zip(mlp_spec[:-1], mlp_spec[1:]):
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=not bn)
        stages.append(conv)
        if bn:
            stages.append(nn.BatchNorm2d(out_channels))
        stages.append(nn.ReLU(True))
    return nn.Sequential(*stages)
class _PointnetSAModuleBase(nn.Module):
    """Shared forward pass for the set-abstraction modules.

    The constructor only creates placeholders; ``forward`` expects
    ``npoint``, ``groupers`` and ``mlps`` to have been filled in afterwards
    (``groupers`` and ``mlps`` are indexed in lockstep).
    """

    def __init__(self):
        super(_PointnetSAModuleBase, self).__init__()
        # Placeholders that forward() reads; all start out as None.
        self.npoint = None
        self.groupers = None
        self.mlps = None

    def forward(
        self, xyz: torch.Tensor, features: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Sample centroids, group around them and pool per-scale features.

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the features

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz; ``None`` when
            ``self.npoint`` is ``None``
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        """
        # The channel-first copy is built unconditionally, before the
        # npoint check.
        xyz_flipped = xyz.transpose(1, 2).contiguous()

        if self.npoint is None:
            # No sampling requested: the groupers receive None as centroids.
            new_xyz = None
        else:
            sample_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            gathered = pointnet2_utils.gather_operation(xyz_flipped, sample_idx)
            new_xyz = gathered.transpose(1, 2).contiguous()

        pooled_per_scale = []
        for scale, grouper in enumerate(self.groupers):
            grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            grouped = self.mlps[scale](grouped)  # (B, mlp[-1], npoint, nsample)
            # Max over the whole sample axis, then drop that singleton axis.
            pooled = F.max_pool2d(grouped, kernel_size=[1, grouped.size(3)])
            pooled_per_scale.append(pooled.squeeze(-1))  # (B, mlp[-1], npoint)

        # Scales are concatenated along the channel dimension.
        return new_xyz, torch.cat(pooled_per_scale, dim=1)
class PointnetSAModuleMSG(_PointnetSAModuleBase):
    r"""Pointnet set abstraction layer with multiscale grouping

    Parameters
    ----------
    npoint : int
        Number of features; when ``None`` every scale uses ``GroupAll``
        instead of ``QueryAndGroup``
    radii : list of float32
        list of radii to group with
    nsamples : list of int32
        Number of samples in each ball query
    mlps : list of list of int32
        Spec of the pointnet before the global max_pool for each scale.
        The specs are copied, never modified in place.
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to the grouper; when true the first width of every spec
        is widened by 3 before the MLP is built
    """

    def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
        # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
        super(PointnetSAModuleMSG, self).__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            if npoint is not None:
                grouper = pointnet2_utils.QueryAndGroup(
                    radius, nsample, use_xyz=use_xyz
                )
            else:
                grouper = pointnet2_utils.GroupAll(use_xyz)
            self.groupers.append(grouper)

            # Work on a private copy: adding the xyz channels directly to the
            # caller's list would widen it again every time the same spec is
            # reused to build another module.
            mlp_spec = list(spec)
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(build_shared_mlp(mlp_spec, bn))
class PointnetSAModule(PointnetSAModuleMSG):
    r"""Single-scale Pointnet set abstraction layer

    Convenience wrapper that hands one radius / nsample / mlp to
    ``PointnetSAModuleMSG`` as one-element lists.

    Parameters
    ----------
    mlp : list
        Spec of the pointnet before the global max_pool
    npoint : int
        Number of features
    radius : float
        Radius of ball
    nsample : int
        Number of samples in the ball query
    bn : bool
        Use batchnorm
    use_xyz : bool
        Passed through unchanged to the multiscale constructor
    """

    def __init__(
        self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
    ):
        # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
        # One scale only: each per-scale argument becomes a singleton list.
        single_scale = dict(radii=[radius], nsamples=[nsample], mlps=[mlp])
        super().__init__(npoint=npoint, bn=bn, use_xyz=use_xyz, **single_scale)
class PointnetFPModule(nn.Module):
    r"""Propagates the features of one set to another

    Parameters
    ----------
    mlp : list
        Pointnet module parameters
    bn : bool
        Use batchnorm
    """

    def __init__(self, mlp, bn=True):
        # type: (PointnetFPModule, List[int], bool) -> None
        super(PointnetFPModule, self).__init__()
        self.mlp = build_shared_mlp(mlp, bn=bn)

    def forward(self, unknown, known, unknow_feats, known_feats):
        # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""Interpolate (or broadcast) known features onto the unknown points.

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor
            (B, m, 3) tensor of the xyz positions of the known features.
            When ``None``, ``known_feats`` is expanded along its last
            dimension to ``n`` instead of being interpolated.
        unknow_feats : torch.Tensor
            (B, C1, n) tensor of the features to be propagated to; may be
            ``None``, in which case nothing is concatenated
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if known is not None:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            # Inverse-distance weights, normalised to sum to one over dim 2;
            # the epsilon guards against division by a zero distance.
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )
        else:
            # Pass the sizes to expand() directly: torch.Size is a tuple
            # subclass, so adding a list to it raises TypeError.
            batch_size, channels = known_feats.size()[0:2]
            interpolated_feats = known_feats.expand(
                batch_size, channels, unknown.size(1)
            )

        if unknow_feats is not None:
            new_features = torch.cat(
                [interpolated_feats, unknow_feats], dim=1
            )  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # The shared MLP is built from Conv2d layers, so add a trailing
        # singleton axis for it and remove that axis afterwards.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)