from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from pointnet2_ops import pointnet2_utils


def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
    """Build a point-wise ("shared") MLP out of 1x1 ``Conv2d`` layers.

    Parameters
    ----------
    mlp_spec : list of int
        Channel widths ``[c_in, c_1, ..., c_out]``. One stage of
        Conv2d(kernel_size=1) [+ BatchNorm2d] + ReLU is created for every
        consecutive pair of widths.
    bn : bool
        Insert a ``BatchNorm2d`` after each convolution. The convolution
        bias is disabled in that case, since the batchnorm shift makes it
        redundant.

    Returns
    -------
    nn.Sequential
        The stacked layers (empty if ``mlp_spec`` has fewer than 2 entries).
    """
    layers = []
    for c_in, c_out in zip(mlp_spec[:-1], mlp_spec[1:]):
        layers.append(nn.Conv2d(c_in, c_out, kernel_size=1, bias=not bn))
        if bn:
            layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.ReLU(True))

    return nn.Sequential(*layers)


class _PointnetSAModuleBase(nn.Module):
    """Shared forward pass for set-abstraction layers.

    Subclasses must set ``npoint``, ``groupers`` and ``mlps``: one grouper
    and one shared MLP per scale, applied pairwise.
    """

    def __init__(self):
        super(_PointnetSAModuleBase, self).__init__()
        self.npoint = None
        self.groupers = None
        self.mlps = None

    def forward(
        self, xyz: torch.Tensor, features: Optional[torch.Tensor]
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the features

        Returns
        -------
        new_xyz : torch.Tensor or None
            (B, npoint, 3) tensor of the new features' xyz, or None when
            ``self.npoint`` is None (no sampling is performed)
        new_features : torch.Tensor
            (B,  \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        """
        new_xyz = None
        if self.npoint is not None:
            # Furthest point sampling picks npoint indices; gather_operation
            # expects channels-first coordinates, hence the transposes.
            xyz_flipped = xyz.transpose(1, 2).contiguous()
            sampled_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            new_xyz = (
                pointnet2_utils.gather_operation(xyz_flipped, sampled_idx)
                .transpose(1, 2)
                .contiguous()
            )

        new_features_list = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            grouped = mlp(grouped)  # (B, mlp[-1], npoint, nsample)
            # Max-pool over the neighbourhood (last) dimension.
            pooled = F.max_pool2d(
                grouped, kernel_size=[1, grouped.size(3)]
            )  # (B, mlp[-1], npoint, 1)
            new_features_list.append(pooled.squeeze(-1))  # (B, mlp[-1], npoint)

        # Concatenate the per-scale descriptors along the channel dimension.
        return new_xyz, torch.cat(new_features_list, dim=1)


class PointnetSAModuleMSG(_PointnetSAModuleBase):
    r"""Pointnet set abstraction layer with multiscale grouping

    Parameters
    ----------
    npoint : int or None
        Number of points to sample with furthest point sampling. If None,
        every scale uses ``pointnet2_utils.GroupAll`` instead of a ball query.
    radii : list of float
        List of radii to group with, one per scale
    nsamples : list of int
        Number of samples in each ball query, one per scale
    mlps : list of list of int
        Spec of the pointnet before the global max_pool for each scale.
        The given lists are not modified: when ``use_xyz`` is set, 3 is
        added to the input width of a *copy* of each spec.
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to the groupers; also widens each MLP input by 3 channels
    """

    def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
        # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
        super(PointnetSAModuleMSG, self).__init__()

        assert len(radii) == len(nsamples) == len(mlps), (
            "radii, nsamples and mlps must have the same length"
        )

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None
                else pointnet2_utils.GroupAll(use_xyz)
            )
            # Work on a copy: an in-place `+= 3` would corrupt the caller's
            # list and compound if the same spec is used to build another
            # module.
            mlp_spec = list(spec)
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(build_shared_mlp(mlp_spec, bn))


class PointnetSAModule(PointnetSAModuleMSG):
    r"""Pointnet set abstraction layer (single scale)

    Parameters
    ----------
    npoint : int or None
        Number of points to sample; None groups all points together
    radius : float
        Radius of ball
    nsample : int
        Number of samples in the ball query
    mlp : list
        Spec of the pointnet before the global max_pool (not modified)
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to the grouper; also widens the MLP input by 3 channels
    """

    def __init__(
        self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
    ):
        # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
        super(PointnetSAModule, self).__init__(
            mlps=[mlp],
            npoint=npoint,
            radii=[radius],
            nsamples=[nsample],
            bn=bn,
            use_xyz=use_xyz,
        )


class PointnetFPModule(nn.Module):
    r"""Propagates the features of one set to another

    Parameters
    ----------
    mlp : list
        Pointnet module parameters. ``mlp[0]`` must match the channel count
        fed to the MLP: C2 + C1 when ``unknow_feats`` is given, else C2.
    bn : bool
        Use batchnorm
    """

    def __init__(self, mlp, bn=True):
        # type: (PointnetFPModule, List[int], bool) -> None
        super(PointnetFPModule, self).__init__()
        self.mlp = build_shared_mlp(mlp, bn=bn)

    def forward(self, unknown, known, unknow_feats, known_feats):
        # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor or None
            (B, m, 3) tensor of the xyz positions of the known features.
            If None, ``known_feats`` is broadcast to all n points instead of
            being interpolated.
        unknow_feats : torch.Tensor or None
            (B, C1, n) tensor of the features to be propagated to
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated. When ``known`` is
            None its last dimension must be 1 (or n) so it can be expanded.

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if known is not None:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            # Inverse-distance weights, normalised to sum to 1 over the
            # neighbours (dim 2); 1e-8 guards against division by zero.
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )
        else:
            # Pass the sizes directly: `torch.Size + list` raises TypeError.
            interpolated_feats = known_feats.expand(
                *known_feats.size()[0:2], unknown.size(1)
            )

        if unknow_feats is not None:
            new_features = torch.cat(
                [interpolated_feats, unknow_feats], dim=1
            )  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # The shared MLP is built from Conv2d layers, so add a dummy trailing
        # dimension for the forward pass and drop it afterwards.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)