210 lines
6.4 KiB
Python
210 lines
6.4 KiB
Python
from typing import List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from pointnet2_ops import pointnet2_utils
|
|
|
|
|
|
def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
    """Stack 1x1 convolutions into a shared MLP.

    Parameters
    ----------
    mlp_spec : list of int
        Channel widths; each consecutive pair gives the input and output
        channels of one layer.
    bn : bool
        When true, every convolution is created without a bias and is
        followed by a ``BatchNorm2d`` over its output channels.

    Returns
    -------
    nn.Sequential
        ``Conv2d -> (BatchNorm2d) -> ReLU`` repeated once per consecutive
        pair of widths (empty when fewer than two widths are given).
    """
    stages = []
    for in_channels, out_channels in zip(mlp_spec[:-1], mlp_spec[1:]):
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=not bn)
        stages.append(conv)
        if bn:
            stages.append(nn.BatchNorm2d(out_channels))
        stages.append(nn.ReLU(True))

    return nn.Sequential(*stages)
|
|
|
|
|
|
class _PointnetSAModuleBase(nn.Module):
    """Shared forward pass for the Pointnet set abstraction layers.

    ``__init__`` only creates the ``npoint``, ``groupers`` and ``mlps``
    attributes (all ``None``); subclasses are expected to fill them in.
    ``groupers[i]`` and ``mlps[i]`` are used together, one pair per scale.
    """

    def __init__(self):
        super(_PointnetSAModuleBase, self).__init__()
        self.npoint = None
        self.groupers = None
        self.mlps = None

    def forward(
        self, xyz: torch.Tensor, features: Optional[torch.Tensor]
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor or None
            (B, C, N) tensor of the descriptors of the features; passed
            unchanged to every grouper

        Returns
        -------
        new_xyz : torch.Tensor or None
            (B, npoint, 3) tensor of the new features' xyz, or None when
            ``self.npoint`` is None (no sampling is performed)
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        """
        new_xyz = None
        if self.npoint is not None:
            # Pick npoint centroids by furthest point sampling and gather
            # their coordinates. The transposed copy of xyz is only needed
            # for this gather, so it is built here rather than
            # unconditionally.
            centroid_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            xyz_flipped = xyz.transpose(1, 2).contiguous()
            new_xyz = (
                pointnet2_utils.gather_operation(xyz_flipped, centroid_idx)
                .transpose(1, 2)
                .contiguous()
            )

        new_features_list = []
        for i, grouper in enumerate(self.groupers):
            grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            grouped = self.mlps[i](grouped)  # (B, mlp[-1], npoint, nsample)

            # Max-pool over the whole sample axis, then drop it.
            pooled = F.max_pool2d(
                grouped, kernel_size=[1, grouped.size(3)]
            )  # (B, mlp[-1], npoint, 1)
            new_features_list.append(pooled.squeeze(-1))  # (B, mlp[-1], npoint)

        # Concatenate the per-scale descriptors along the channel axis.
        return new_xyz, torch.cat(new_features_list, dim=1)
|
|
|
|
|
|
class PointnetSAModuleMSG(_PointnetSAModuleBase):
    r"""Pointnet set abstraction layer with multiscale grouping

    Parameters
    ----------
    npoint : int
        Number of features; when None a ``GroupAll`` grouper is used for
        every scale instead of a ball query
    radii : list of float32
        list of radii to group with
    nsamples : list of int32
        Number of samples in each ball query
    mlps : list of list of int32
        Spec of the pointnet before the global max_pool for each scale.
        The lists are not modified.
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to the groupers; when true the first width of every MLP
        spec is increased by 3 (presumably for the xyz channels the grouper
        adds to each group -- confirm against pointnet2_utils)
    """

    def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
        # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
        super(PointnetSAModuleMSG, self).__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            if npoint is not None:
                grouper = pointnet2_utils.QueryAndGroup(
                    radius, nsample, use_xyz=use_xyz
                )
            else:
                grouper = pointnet2_utils.GroupAll(use_xyz)
            self.groupers.append(grouper)

            # Work on a copy: bumping the input width in place would mutate
            # the caller's list, so building a second module from the same
            # spec would add the 3 extra channels twice.
            mlp_spec = list(spec)
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(build_shared_mlp(mlp_spec, bn))
|
|
|
|
|
|
class PointnetSAModule(PointnetSAModuleMSG):
    r"""Pointnet set abstraction layer with a single grouping scale

    Thin wrapper that wraps each per-scale argument in a one-element list
    and delegates to ``PointnetSAModuleMSG``.

    Parameters
    ----------
    mlp : list
        Spec of the pointnet before the global max_pool
    npoint : int
        Number of features
    radius : float
        Radius of ball
    nsample : int
        Number of samples in the ball query
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded unchanged to ``PointnetSAModuleMSG``
    """

    def __init__(
        self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
    ):
        # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
        # Positional order follows PointnetSAModuleMSG.__init__:
        # (npoint, radii, nsamples, mlps, bn, use_xyz).
        super().__init__(npoint, [radius], [nsample], [mlp], bn, use_xyz)
|
|
|
|
|
|
class PointnetFPModule(nn.Module):
    r"""Propagates the features of one set to another

    Parameters
    ----------
    mlp : list
        Pointnet module parameters
    bn : bool
        Use batchnorm
    """

    def __init__(self, mlp, bn=True):
        # type: (PointnetFPModule, List[int], bool) -> None
        super(PointnetFPModule, self).__init__()
        self.mlp = build_shared_mlp(mlp, bn=bn)

    def forward(self, unknown, known, unknow_feats, known_feats):
        # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor or None
            (B, m, 3) tensor of the xyz positions of the known features.
            When None, ``known_feats`` is expanded along its last axis to
            all n unknown points instead of being interpolated.
        unknow_feats : torch.Tensor or None
            (B, C1, n) tensor of the features to be propagated to; when
            None only the propagated features are fed to the MLP
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated (must be
            expandable to (B, C2, n) when ``known`` is None)

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if known is not None:
            # Inverse-distance weighting over the three nearest known points;
            # the epsilon guards against division by a zero distance.
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )
        else:
            # Pass the target sizes to expand() directly: known_feats.size()
            # is a torch.Size (a tuple), so concatenating it with a list, as
            # the previous code did, raises TypeError.
            batch, channels = known_feats.size()[0:2]
            interpolated_feats = known_feats.expand(
                batch, channels, unknown.size(1)
            )

        if unknow_feats is not None:
            new_features = torch.cat(
                [interpolated_feats, unknow_feats], dim=1
            )  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # The shared MLP is built from Conv2d layers, so add a trailing
        # singleton axis for it and remove it again afterwards.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)
|