55 lines
2.1 KiB
Python
55 lines
2.1 KiB
Python
|
from torch.autograd import Function
|
||
|
|
||
|
# from modules.functional.backend import _backend
|
||
|
from third_party.pvcnn.functional.backend import _backend
|
||
|
import torch
|
||
|
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
|
||
|
|
||
|
__all__ = ['nearest_neighbor_interpolate']
|
||
|
|
||
|
|
||
|
class NeighborInterpolation(Function):
|
||
|
@staticmethod
|
||
|
@custom_fwd(cast_inputs=torch.float32)
|
||
|
def forward(ctx, points_coords, centers_coords, centers_features):
|
||
|
"""
|
||
|
:param ctx:
|
||
|
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
|
||
|
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
|
||
|
:param centers_features: features of centers, FloatTensor[B, C, M]
|
||
|
:return:
|
||
|
points_features: features of points, FloatTensor[B, C, N]
|
||
|
"""
|
||
|
centers_coords = centers_coords[:,:3].contiguous()
|
||
|
points_coords = points_coords[:,:3].contiguous()
|
||
|
centers_features = centers_features.contiguous()
|
||
|
points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward(
|
||
|
points_coords, centers_coords, centers_features)
|
||
|
ctx.save_for_backward(indices, weights)
|
||
|
ctx.num_centers = centers_coords.size(-1)
|
||
|
return points_features
|
||
|
|
||
|
@staticmethod
|
||
|
@custom_bwd
|
||
|
def backward(ctx, grad_output):
|
||
|
indices, weights = ctx.saved_tensors
|
||
|
grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward(
|
||
|
grad_output.contiguous(), indices, weights, ctx.num_centers)
|
||
|
return None, None, grad_centers_features
|
||
|
|
||
|
|
||
|
nearest_neighbor_interpolate = NeighborInterpolation.apply
|
||
|
|
||
|
#def nearest_neighbor_interpolate(points_coords, centers_coords, centers_features):
|
||
|
# # points_coords: (B,6, 64)
|
||
|
# # centers_coords: (B,6, 16)
|
||
|
# # centers_features: (B,128,16)
|
||
|
# # interpolated_features: (B,128,64)
|
||
|
# B = points_coords.shape[0]
|
||
|
# D = centers_features.shape[1]
|
||
|
# N = points_coords.shape[2]
|
||
|
# output = torch.zeros(B,D,N).to(points_coords.shape)
|
||
|
# for b in range(B):
|
||
|
# for n in range(N):
|
||
|
# points_coords_cur = points_coords
|