380 lines
10 KiB
Python
380 lines
10 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import warnings
|
|
from torch.autograd import Function
|
|
from typing import *
|
|
|
|
try:
    import pointnet2_ops._ext as _ext
except ImportError:
    # Prebuilt extension is unavailable: compile it on the fly from the
    # sources shipped next to this file.
    from torch.utils.cpp_extension import load
    import glob
    import os.path as osp
    import os

    warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")

    _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
    _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
        osp.join(_ext_src_root, "src", "*.cu")
    )
    _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))

    # Provide a default CUDA arch list only when the user has not chosen one.
    # Unconditionally overwriting it would silently discard a user-exported
    # TORCH_CUDA_ARCH_LIST (e.g. one targeting GPUs newer than those below)
    # and leak that change into every later JIT build in this process.
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5")
    _ext = load(
        "_ext",
        sources=_ext_sources,
        extra_include_paths=[osp.join(_ext_src_root, "include")],
        extra_cflags=["-O3"],
        extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
        with_cuda=True,
    )
|
|
|
|
|
|
class FurthestPointSampling(Function):
    """Autograd wrapper around the extension's furthest point sampling kernel."""

    @staticmethod
    def forward(ctx, xyz, npoint):
        # type: (Any, torch.Tensor, int) -> torch.Tensor
        r"""
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int
            number of features in the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the set
        """
        out = _ext.furthest_point_sampling(xyz, npoint)

        # The sampled set is treated as a non-differentiable output.
        ctx.mark_non_differentiable(out)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Autograd expects one gradient slot per forward input (xyz, npoint);
        # neither receives a gradient.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
|
|
|
|
|
|
class GatherOperation(Function):
    """Autograd wrapper around the extension's gather_points kernel."""

    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Gather the columns of ``features`` selected by ``idx``.

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor
        idx : torch.Tensor
            (B, npoint) tensor of the features to gather

        Returns
        -------
        torch.Tensor
            (B, C, npoint) tensor
        """
        # Backward only needs the length of the gathered axis. Keep it as a
        # plain int rather than saving ``features``: saving the tensor pins it
        # in memory and makes backward raise if the caller later modifies
        # ``features`` in place, even though its values are never used.
        ctx.save_for_backward(idx)
        ctx.N = features.size(2)

        return _ext.gather_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None]
        (idx,) = ctx.saved_tensors

        grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, ctx.N)
        # idx selects positions and receives no gradient.
        return grad_features, None


gather_operation = GatherOperation.apply
|
|
|
|
|
|
class ThreeNN(Function):
    """Autograd wrapper around the extension's three-nearest-neighbour search."""

    @staticmethod
    def forward(ctx, unknown, known):
        # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""
        Find the three nearest neighbors of unknown in known

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of query points whose neighbors are searched for
        known : torch.Tensor
            (B, m, 3) tensor of reference points to search among

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) index of 3 nearest neighbors
        """
        dist2, idx = _ext.three_nn(unknown, known)
        # The kernel's first output is square-rooted to give the l2 distance.
        dist = torch.sqrt(dist2)

        ctx.mark_non_differentiable(dist, idx)

        return dist, idx

    @staticmethod
    def backward(ctx, grad_dist, grad_idx):
        # One gradient slot per forward input (unknown, known); both outputs
        # are non-differentiable, so neither input receives a gradient.
        return None, None


three_nn = ThreeNN.apply
|
|
|
|
|
|
class ThreeInterpolate(Function):
    """Autograd wrapper around the extension's three-point interpolation kernel."""

    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Performs weight linear interpolation on 3 features

        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) Features descriptors to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
            (B, n, 3) weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        # Backward needs idx, weight and only the *size* of the features'
        # last axis; saving ``features`` itself would keep it alive and make
        # backward fail if the caller later modifies it in place.
        ctx.save_for_backward(idx, weight)
        ctx.m = features.size(2)

        return _ext.three_interpolate(features, idx, weight)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None, None]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs

        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) tensor with gradients of features
        None
        None
        """
        idx, weight = ctx.saved_tensors

        grad_features = _ext.three_interpolate_grad(
            grad_out.contiguous(), idx, weight, ctx.m
        )

        # No gradient is computed for idx or weight. Report that as None
        # rather than fabricating (and allocating) all-zero gradients.
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply
|
|
|
|
|
|
class GroupingOperation(Function):
    """Autograd wrapper around the extension's group_points kernel."""

    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Gather, for every group, the feature columns listed in ``idx``.

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
        """
        # Only the length of the grouped axis is needed in backward; store it
        # as an int instead of saving (and pinning) the whole feature tensor.
        ctx.save_for_backward(idx)
        ctx.N = features.size(2)

        return _ext.group_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
        None
        """
        (idx,) = ctx.saved_tensors

        grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, ctx.N)

        # idx is an index tensor; it receives no gradient.
        return grad_features, None


grouping_operation = GroupingOperation.apply
|
|
|
|
|
|
class BallQuery(Function):
    """Autograd wrapper around the extension's ball query kernel."""

    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Find up to ``nsample`` points of ``xyz`` inside a ball around each centre.

        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        # Note the kernel takes the centres first, then the points.
        output = _ext.ball_query(new_xyz, xyz, radius, nsample)

        ctx.mark_non_differentiable(output)

        return output

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient slot per forward input (radius, nsample, xyz, new_xyz);
        # the output is non-differentiable, so none receive a gradient.
        return None, None, None, None


ball_query = BallQuery.apply
|
|
|
|
|
|
class QueryAndGroup(nn.Module):
    r"""
    Groups with a ball query of radius

    Parameters
    ---------
    radius : float32
        Radius of ball
    nsample : int32
        Maximum number of features to gather in the ball
    use_xyz : bool
        If True, prepend the centre-relative xyz coordinates of each grouped
        point to its features
    """

    def __init__(self, radius, nsample, use_xyz=True):
        # type: (QueryAndGroup, float, int, bool) -> None
        super(QueryAndGroup, self).__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (QueryAndGroup, torch.Tensor, torch.Tensor, Optional[torch.Tensor]) -> torch.Tensor
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            centroids (B, npoint, 3)
        features : torch.Tensor, optional
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, npoint, nsample) when use_xyz and features are given,
            (B, C, npoint, nsample) when use_xyz is False, and
            (B, 3, npoint, nsample) when features is None
        """
        # Reject the invalid configuration before launching any kernels
        # (previously this was only checked after the query and grouping ran).
        assert (
            features is not None or self.use_xyz
        ), "Cannot have not features and not use xyz as a feature!"

        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        # Express each grouped point relative to its ball centre.
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is None:
            return grouped_xyz

        grouped_features = grouping_operation(features, idx)
        if self.use_xyz:
            # (B, C + 3, npoint, nsample)
            return torch.cat([grouped_xyz, grouped_features], dim=1)
        return grouped_features
|
|
|
|
|
|
class GroupAll(nn.Module):
    r"""
    Treat the whole point cloud as a single group.

    Parameters
    ----------
    use_xyz : bool
        If True and features are given, prepend the xyz coordinates to them
    """

    def __init__(self, use_xyz=True):
        # type: (GroupAll, bool) -> None
        super(GroupAll, self).__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (GroupAll, torch.Tensor, torch.Tensor, Optional[torch.Tensor]) -> torch.Tensor
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor, optional
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, 1, N) when use_xyz and features are given,
            (B, C, 1, N) when use_xyz is False, and
            (B, 3, 1, N) when features is None
        """
        # Channels-first coordinates with a singleton "group" axis: (B, 3, 1, N).
        xyz_as_group = xyz.transpose(1, 2).unsqueeze(2)
        if features is None:
            return xyz_as_group

        # (B, C, 1, N)
        feats_as_group = features.unsqueeze(2)
        if not self.use_xyz:
            return feats_as_group

        return torch.cat([xyz_as_group, feats_as_group], dim=1)
|