# PointMLP/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py

import warnings
from typing import *

import torch
import torch.nn as nn
from torch.autograd import Function

try:
    import pointnet2_ops._ext as _ext
except ImportError:
    import glob
    import os
    import os.path as osp

    from torch.utils.cpp_extension import load

    warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")

    _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
    _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
        osp.join(_ext_src_root, "src", "*.cu"),
    )
    _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))

    os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
    _ext = load(
        "_ext",
        sources=_ext_sources,
        extra_include_paths=[osp.join(_ext_src_root, "include")],
        extra_cflags=["-O3"],
        extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
        with_cuda=True,
    )

class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz, npoint):
        # type: (Any, torch.Tensor, int) -> torch.Tensor
        r"""Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance.

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int32
            number of features in the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the set
        """
        out = _ext.furthest_point_sampling(xyz, npoint)

        ctx.mark_non_differentiable(out)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        return ()


furthest_point_sample = FurthestPointSampling.apply
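

# Illustrative usage sketch (added for documentation; not part of the original
# module). It assumes the compiled CUDA extension loaded above and a CUDA
# device; the batch size and point counts are placeholders.
def _example_furthest_point_sample():
    xyz = torch.rand(2, 1024, 3, device="cuda")  # (B, N, 3) input point cloud
    # Select 128 point indices per batch element that are maximally spread out.
    idx = furthest_point_sample(xyz, 128)  # (B, 128) int32 indices into N
    return idx
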
class GatherOperation(Function):
    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor
        idx : torch.Tensor
            (B, npoint) tensor of the features to gather

        Returns
        -------
        torch.Tensor
            (B, C, npoint) tensor
        """
        ctx.save_for_backward(idx, features)

        return _ext.gather_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        idx, features = ctx.saved_tensors
        N = features.size(2)

        grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
        return grad_features, None


gather_operation = GatherOperation.apply
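

# Illustrative usage sketch (added for documentation; not part of the original
# module). Shapes are placeholders; a CUDA device is assumed.
def _example_gather_operation():
    xyz = torch.rand(2, 1024, 3, device="cuda")  # (B, N, 3)
    idx = furthest_point_sample(xyz, 128)  # (B, 128) sampled indices
    # gather_operation expects channel-first features, i.e. (B, C, N).
    new_xyz = gather_operation(xyz.transpose(1, 2).contiguous(), idx)  # (B, 3, 128)
    return new_xyz
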
class ThreeNN(Function):
    @staticmethod
    def forward(ctx, unknown, known):
        # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""Find the three nearest neighbors of unknown in known

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of unknown features (query points)
        known : torch.Tensor
            (B, m, 3) tensor of known features (reference points)

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) index of 3 nearest neighbors
        """
        dist2, idx = _ext.three_nn(unknown, known)
        dist = torch.sqrt(dist2)

        ctx.mark_non_differentiable(dist, idx)

        return dist, idx

    @staticmethod
    def backward(ctx, grad_dist, grad_idx):
        return ()


three_nn = ThreeNN.apply
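

# Illustrative usage sketch (added for documentation; not part of the original
# module). Shapes are placeholders; a CUDA device is assumed.
def _example_three_nn():
    unknown = torch.rand(2, 1024, 3, device="cuda")  # (B, n, 3) query points
    known = torch.rand(2, 256, 3, device="cuda")  # (B, m, 3) reference points
    dist, idx = three_nn(unknown, known)  # both outputs are (B, n, 3)
    return dist, idx
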
class ThreeInterpolate(Function):
    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""Performs weighted linear interpolation on 3 features

        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) Features descriptors to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
            (B, n, 3) weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        ctx.save_for_backward(idx, weight, features)

        return _ext.three_interpolate(features, idx, weight)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        r"""Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs

        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) tensor with gradients of features
        None
        None
        """
        idx, weight, features = ctx.saved_tensors
        m = features.size(2)

        grad_features = _ext.three_interpolate_grad(
            grad_out.contiguous(), idx, weight, m
        )

        return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)


three_interpolate = ThreeInterpolate.apply
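

# Illustrative usage sketch (added for documentation; not part of the original
# module). It shows the inverse-distance weighting commonly paired with
# three_nn in PointNet++-style feature propagation; shapes are placeholders
# and a CUDA device is assumed.
def _example_three_interpolate():
    unknown = torch.rand(2, 1024, 3, device="cuda")  # (B, n, 3) dense points
    known = torch.rand(2, 256, 3, device="cuda")  # (B, m, 3) sparse points
    feats = torch.rand(2, 64, 256, device="cuda")  # (B, c, m) features on sparse points
    dist, idx = three_nn(unknown, known)
    # Normalized inverse-distance weights over the three neighbors.
    dist_recip = 1.0 / (dist + 1e-8)
    weight = dist_recip / dist_recip.sum(dim=2, keepdim=True)
    return three_interpolate(feats, idx, weight)  # (B, c, n)
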
class GroupingOperation(Function):
    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
        """
        ctx.save_for_backward(idx, features)

        return _ext.group_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
        None
        """
        idx, features = ctx.saved_tensors
        N = features.size(2)

        grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)

        return grad_features, torch.zeros_like(idx)


grouping_operation = GroupingOperation.apply
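

# Illustrative usage sketch (added for documentation; not part of the original
# module). The random neighbor indices stand in for the output of ball_query
# (defined below); shapes are placeholders and a CUDA device is assumed.
def _example_grouping_operation():
    feats = torch.rand(2, 32, 1024, device="cuda")  # (B, C, N) per-point features
    # (B, npoint, nsample) int32 indices; in practice these come from ball_query.
    idx = torch.randint(0, 1024, (2, 128, 16), device="cuda", dtype=torch.int32)
    grouped = grouping_operation(feats, idx)  # (B, 32, 128, 16)
    return grouped
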
class BallQuery(Function):
    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        output = _ext.ball_query(new_xyz, xyz, radius, nsample)

        ctx.mark_non_differentiable(output)

        return output

    @staticmethod
    def backward(ctx, grad_out):
        return ()


ball_query = BallQuery.apply
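

# Illustrative usage sketch (added for documentation; not part of the original
# module). Radius, sample count, and shapes are placeholders; a CUDA device
# is assumed.
def _example_ball_query():
    xyz = torch.rand(2, 1024, 3, device="cuda")  # (B, N, 3) all points
    new_xyz = xyz[:, :128, :].contiguous()  # (B, npoint, 3) query centers, e.g. from FPS
    # For each center, gather up to 32 neighbor indices within radius 0.2.
    idx = ball_query(0.2, 32, xyz, new_xyz)  # (B, 128, 32) int32 indices
    return idx
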
class QueryAndGroup(nn.Module):
    r"""Groups points within a ball query of the given radius.

    Parameters
    ----------
    radius : float32
        Radius of ball
    nsample : int32
        Maximum number of features to gather in the ball
    use_xyz : bool
        Whether to concatenate the grouped xyz coordinates to the grouped features
    """

    def __init__(self, radius, nsample, use_xyz=True):
        # type: (QueryAndGroup, float, int, bool) -> None
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            centroids (B, npoint, 3)
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, npoint, nsample) tensor
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, "use_xyz must be True when no features are given"
            new_features = grouped_xyz

        return new_features
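

# Illustrative usage sketch (added for documentation; not part of the original
# module). Radius, sample count, and shapes are placeholders; a CUDA device
# is assumed.
def _example_query_and_group():
    grouper = QueryAndGroup(radius=0.2, nsample=32, use_xyz=True)
    xyz = torch.rand(2, 1024, 3, device="cuda")  # (B, N, 3)
    new_xyz = xyz[:, :128, :].contiguous()  # (B, npoint, 3) centroids, e.g. from FPS
    feats = torch.rand(2, 64, 1024, device="cuda")  # (B, C, N)
    # Output concatenates relative xyz and grouped features: (B, 3 + C, npoint, nsample).
    return grouper(xyz, new_xyz, feats)  # (B, 67, 128, 32)
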
class GroupAll(nn.Module):
    r"""Groups all features.

    Parameters
    ----------
    use_xyz : bool
        Whether to concatenate the xyz coordinates to the features
    """

    def __init__(self, use_xyz=True):
        # type: (GroupAll, bool) -> None
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, C + 3, 1, N) tensor
        """
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, 3 + C, 1, N)
            else:
                new_features = grouped_features
        else:
            new_features = grouped_xyz

        return new_features
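

# Illustrative usage sketch (added for documentation; not part of the original
# module). Shapes are placeholders; no CUDA device is required here since
# GroupAll only reshapes and concatenates tensors.
def _example_group_all():
    grouper = GroupAll(use_xyz=True)
    xyz = torch.rand(2, 1024, 3)  # (B, N, 3)
    feats = torch.rand(2, 64, 1024)  # (B, C, N)
    # new_xyz is ignored by GroupAll, so None is passed.
    return grouper(xyz, None, feats)  # (B, 3 + 64, 1, N)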