2021-10-19 20:54:46 +00:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from torch.autograd import Function
|
|
|
|
|
|
|
|
from modules.functional.backend import _backend
|
|
|
|
|
2023-04-11 09:12:58 +00:00
|
|
|
__all__ = ["gather", "furthest_point_sample", "logits_mask"]
|
2021-10-19 20:54:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Gather(Function):
|
|
|
|
@staticmethod
|
|
|
|
def forward(ctx, features, indices):
|
|
|
|
"""
|
|
|
|
Gather
|
|
|
|
:param ctx:
|
|
|
|
:param features: features of points, FloatTensor[B, C, N]
|
|
|
|
:param indices: centers' indices in points, IntTensor[b, m]
|
|
|
|
:return:
|
|
|
|
centers_coords: coordinates of sampled centers, FloatTensor[B, C, M]
|
|
|
|
"""
|
|
|
|
features = features.contiguous()
|
|
|
|
indices = indices.int().contiguous()
|
|
|
|
ctx.save_for_backward(indices)
|
|
|
|
ctx.num_points = features.size(-1)
|
|
|
|
return _backend.gather_features_forward(features, indices)
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def backward(ctx, grad_output):
|
2023-04-11 09:12:58 +00:00
|
|
|
(indices,) = ctx.saved_tensors
|
2021-10-19 20:54:46 +00:00
|
|
|
grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
|
|
|
|
return grad_features, None
|
|
|
|
|
|
|
|
|
|
|
|
gather = Gather.apply
|
|
|
|
|
|
|
|
|
|
|
|
def furthest_point_sample(coords, num_samples):
|
|
|
|
"""
|
|
|
|
Uses iterative furthest point sampling to select a set of npoint features that have the largest
|
|
|
|
minimum distance to the sampled point set
|
|
|
|
:param coords: coordinates of points, FloatTensor[B, 3, N]
|
|
|
|
:param num_samples: int, M
|
|
|
|
:return:
|
|
|
|
centers_coords: coordinates of sampled centers, FloatTensor[B, 3, M]
|
|
|
|
"""
|
|
|
|
coords = coords.contiguous()
|
|
|
|
indices = _backend.furthest_point_sampling(coords, num_samples)
|
|
|
|
return gather(coords, indices)
|
|
|
|
|
|
|
|
|
|
|
|
def logits_mask(coords, logits, num_points_per_object):
|
|
|
|
"""
|
|
|
|
Use logits to sample points
|
|
|
|
:param coords: coords of points, FloatTensor[B, 3, N]
|
|
|
|
:param logits: binary classification logits, FloatTensor[B, 2, N]
|
|
|
|
:param num_points_per_object: M, #points per object after masking, int
|
|
|
|
:return:
|
|
|
|
selected_coords: FloatTensor[B, 3, M]
|
|
|
|
masked_coords_mean: mean coords of selected points, FloatTensor[B, 3]
|
|
|
|
mask: mask to select points, BoolTensor[B, N]
|
|
|
|
"""
|
|
|
|
batch_size, _, num_points = coords.shape
|
2023-04-11 09:12:58 +00:00
|
|
|
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
|
2021-10-19 20:54:46 +00:00
|
|
|
num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
|
|
|
|
masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
|
2023-04-11 09:12:58 +00:00
|
|
|
masked_coords_mean = (
|
|
|
|
torch.sum(masked_coords, dim=-1) / torch.max(num_candidates, torch.ones_like(num_candidates)).float()
|
|
|
|
) # [B, C]
|
2021-10-19 20:54:46 +00:00
|
|
|
selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
|
|
|
|
for i in range(batch_size):
|
|
|
|
current_mask = mask[i] # [N]
|
|
|
|
current_candidates = current_mask.nonzero().view(-1)
|
|
|
|
current_num_candidates = current_candidates.numel()
|
|
|
|
if current_num_candidates >= num_points_per_object:
|
|
|
|
choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
|
|
|
|
selected_indices[i] = current_candidates[choices]
|
|
|
|
elif current_num_candidates > 0:
|
2023-04-11 09:12:58 +00:00
|
|
|
choices = np.concatenate(
|
|
|
|
[
|
|
|
|
np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
|
|
|
|
np.random.choice(
|
|
|
|
current_num_candidates, num_points_per_object % current_num_candidates, replace=False
|
|
|
|
),
|
|
|
|
]
|
|
|
|
)
|
2021-10-19 20:54:46 +00:00
|
|
|
np.random.shuffle(choices)
|
|
|
|
selected_indices[i] = current_candidates[choices]
|
|
|
|
selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
|
|
|
|
return selected_coords, masked_coords_mean, mask
|