PVD/modules/functional/devoxelization.py

43 lines
1.4 KiB
Python
Raw Normal View History

2021-10-19 20:54:46 +00:00
from torch.autograd import Function
from modules.functional.backend import _backend
2023-04-11 09:12:58 +00:00
# Public surface of this module: only the functional autograd wrapper is exported.
__all__ = [
    "trilinear_devoxelize",
]
2021-10-19 20:54:46 +00:00
class TrilinearDevoxelization(Function):
    """Autograd function mapping voxel features back onto point coordinates.

    Given voxel features of shape [B, C, R, R, R] and point coordinates of
    shape [B, 3, N], produces per-point features of shape [B, C, N]. The
    numeric work is delegated to the compiled ``_backend`` extension; this
    class reshapes inputs, calls the forward/backward kernels, and stashes
    what the backward kernel consumes.
    """

    @staticmethod
    def forward(ctx, features, coords, resolution, is_training=True):
        """
        :param ctx: autograd context used to stash tensors for ``backward``
        :param features: FloatTensor[B, C, R, R, R]
        :param coords: the coordinates of points, FloatTensor[B, 3, N]
        :param resolution: int, the voxel resolution
        :param is_training: bool, training mode; when False nothing is saved
            for ``backward``, so gradients cannot be computed afterwards
        :return:
            FloatTensor[B, C, N]
        """
        B, C = features.shape[:2]
        # The backend receives the voxel grid flattened to [B, C, R*R*R].
        features = features.contiguous().view(B, C, -1)
        coords = coords.contiguous()
        outs, inds, wgts = _backend.trilinear_devoxelize_forward(resolution, is_training, coords, features)
        if is_training:
            # inds/wgts (and the resolution) are exactly what backward() hands
            # to the backward kernel; skipped at inference to save memory.
            ctx.save_for_backward(inds, wgts)
            ctx.r = resolution
        return outs

    @staticmethod
    def backward(ctx, grad_output):
        """
        :param ctx: autograd context populated by ``forward``
        :param grad_output: gradient of outputs, FloatTensor[B, C, N]
        :return:
            gradient of inputs, FloatTensor[B, C, R, R, R], followed by ``None``
            for ``coords``, ``resolution`` and ``is_training`` (not differentiable)
        :raises RuntimeError: if ``forward`` was run with ``is_training=False``
        """
        resolution = getattr(ctx, "r", None)
        if resolution is None:
            # forward() only saves state when is_training is True; fail with a
            # clear message instead of an opaque unpacking/attribute error.
            raise RuntimeError(
                "trilinear_devoxelize: backward called but forward ran with "
                "is_training=False, so no indices/weights were saved"
            )
        inds, wgts = ctx.saved_tensors
        grad_inputs = _backend.trilinear_devoxelize_backward(grad_output.contiguous(), inds, wgts, resolution)
        batch_size, channels = grad_output.size(0), grad_output.size(1)
        # Restore the flattened voxel axis to the original [R, R, R] grid.
        grad_features = grad_inputs.view(batch_size, channels, resolution, resolution, resolution)
        return grad_features, None, None, None
# Functional entry point: trilinear_devoxelize(features, coords, resolution, is_training=True).
# Going through Function.apply (rather than calling forward directly) lets autograd
# record the op so TrilinearDevoxelization.backward runs during backpropagation.
trilinear_devoxelize = TrilinearDevoxelization.apply