LION/third_party/PyTorchEMD/emd.py

53 lines
1.7 KiB
Python
Raw Permalink Normal View History

2023-01-23 05:14:49 +00:00
import torch
# from backend import emd_cuda_dynamic as emd_cuda # jit compiling
from third_party.PyTorchEMD.backend import emd_cuda_dynamic as emd_cuda
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
class EarthMoverDistanceFunction(torch.autograd.Function):
    """Autograd wrapper around the approximate-EMD CUDA extension.

    ``forward`` builds an approximate matching between the two point sets with
    ``emd_cuda.approxmatch_forward`` and returns the cost from
    ``emd_cuda.matchcost_forward``; ``backward`` reuses the saved matching to
    compute gradients with ``emd_cuda.matchcost_backward``. Inputs must be CUDA
    tensors; the public entry point ``earth_mover_distance`` arranges them in
    B,N,3 layout before calling ``apply``.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, xyz1, xyz2):
        """Return the approximate matching cost between ``xyz1`` and ``xyz2``.

        ``custom_fwd(cast_inputs=torch.float32)`` casts the inputs to float32
        when autocast is active, so the extension always sees float32 data.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            xyz1 (torch.Tensor): first point set (CUDA tensor).
            xyz2 (torch.Tensor): second point set (CUDA tensor).

        Returns:
            torch.Tensor: the cost produced by ``emd_cuda.matchcost_forward``.

        Raises:
            ValueError: if either input is not a CUDA tensor.
        """
        # Explicit check rather than ``assert`` so it is not stripped under
        # ``python -O``: CPU tensors must never reach the CUDA-only extension.
        if not (xyz1.is_cuda and xyz2.is_cuda):
            raise ValueError("Only support cuda currently.")
        # The extension presumably expects densely laid-out buffers.
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()
        match = emd_cuda.approxmatch_forward(xyz1, xyz2)
        cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
        # Keep the matching so backward does not have to recompute it.
        ctx.save_for_backward(xyz1, xyz2, match)
        return cost

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_cost):
        """Propagate ``grad_cost`` to both point sets through the saved matching.

        Args:
            ctx: autograd context holding ``(xyz1, xyz2, match)`` from forward.
            grad_cost (torch.Tensor): upstream gradient of the returned cost.

        Returns:
            tuple: gradients with respect to ``xyz1`` and ``xyz2``, in order.
        """
        xyz1, xyz2, match = ctx.saved_tensors
        grad_cost = grad_cost.contiguous()
        grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(
            grad_cost, xyz1, xyz2, match
        )
        return grad_xyz1, grad_xyz2
def earth_mover_distance(xyz1, xyz2, transpose=True):
    """Earth Mover Distance (Approx)

    Args:
        xyz1 (torch.Tensor): (b, 3, n1); a 2-D (3, n1) tensor is treated as a
            batch of one.
        xyz2 (torch.Tensor): (b, 3, n2); a 2-D (3, n2) tensor is treated as a
            batch of one. NOTE(review): the original docstring listed n1 here
            as well -- confirm whether the kernel requires n2 == n1.
        transpose (bool): whether to transpose inputs as it might be BCN format.
            Extensions only support BNC format. With ``transpose=False`` the
            inputs must already be (b, n, 3) (or (n, 3)).

    Returns:
        cost (torch.Tensor): (b), the matching cost divided by n1.

    Raises:
        ValueError: if, after the optional transpose, either input is not a
            3-D (B, N, 3) tensor, the batch sizes differ, or ``xyz1`` holds
            no points (the cost would be divided by zero).
    """
    # Promote a single unbatched cloud to a batch of one.
    if xyz1.dim() == 2:
        xyz1 = xyz1.unsqueeze(0)
    if xyz2.dim() == 2:
        xyz2 = xyz2.unsqueeze(0)
    if transpose:
        # B,3,N -> B,N,3: the extension only accepts BNC layout.
        xyz1 = xyz1.transpose(1, 2)
        xyz2 = xyz2.transpose(1, 2)
    # Explicit checks (not ``assert``) so validation survives ``python -O``;
    # both inputs go to the extension, so both must be B,N,3.
    if xyz1.dim() != 3 or xyz1.shape[-1] != 3:
        raise ValueError(f'require it to be B,N,3; get: {xyz1.shape}')
    if xyz2.dim() != 3 or xyz2.shape[-1] != 3:
        raise ValueError(f'require xyz2 to be B,N,3; get: {xyz2.shape}')
    if xyz1.shape[0] != xyz2.shape[0]:
        raise ValueError(
            f'batch sizes differ: {xyz1.shape[0]} vs {xyz2.shape[0]}'
        )
    # The cost is normalised by the number of points in the first cloud.
    num_points = xyz1.shape[1]
    if num_points == 0:
        raise ValueError('xyz1 must contain at least one point')
    return EarthMoverDistanceFunction.apply(xyz1, xyz2) / float(num_points)