# PyTorchEMD/torchemd/emd.py
# 2023-04-24 15:07:12 +02:00

# 61 lines
# 1.8 KiB
# Python

import pathlib
import torch
from torch.utils.cpp_extension import load
# Directory containing this file; the CUDA sources are resolved relative to it.
REALPATH = pathlib.Path(__file__).parent

# JIT-compile the CUDA extension when this module is imported.
# `load` returns the compiled extension as a Python module (its default is
# `is_python_module=True`), so bind that return value directly instead of
# depending on a follow-up `import torchemd_cuda` being resolvable by name,
# which only works if loading registered the module in `sys.modules`.
torchemd_cuda = load(
    name="torchemd_cuda",
    sources=[
        str(REALPATH / source)
        for source in [
            "cuda/emd.cpp",
            "cuda/emd_kernel.cu",
        ]
    ],
)
class EarthMoverDistanceFunction(torch.autograd.Function):
    """Autograd function wrapping the ``torchemd_cuda`` extension.

    ``forward`` asks the extension for an approximate matching between two
    point sets and for the cost of that matching. ``backward`` forwards the
    incoming gradient to the extension's ``matchcost_backward`` and returns
    the two tensors it produces as the gradients for ``xyz1`` and ``xyz2``.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2) -> torch.Tensor:
        """Compute the matching cost between ``xyz1`` and ``xyz2``.

        Args:
            ctx: Autograd context; used to save tensors for ``backward``.
            xyz1 (torch.Tensor): First point set. Must be a CUDA tensor.
            xyz2 (torch.Tensor): Second point set. Must be a CUDA tensor.
                NOTE(review): the memory layout the kernels expect is not
                visible from this file -- confirm against the CUDA sources.

        Returns:
            torch.Tensor: The result of ``torchemd_cuda.matchcost_forward``.

        Raises:
            AssertionError: If either input is not a CUDA tensor.
        """
        # Raise explicitly rather than using an `assert` statement: asserts
        # are stripped under `python -O`, which would let CPU tensors reach
        # the CUDA kernels. AssertionError and the message are kept so that
        # existing callers observe the same failure. The check runs first so
        # no contiguous copy is made for inputs that are rejected anyway.
        if not (xyz1.is_cuda and xyz2.is_cuda):
            raise AssertionError("Only support cuda currently.")

        # Hand the extension contiguous tensors (presumably the kernels index
        # raw memory -- TODO confirm against the CUDA sources).
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        match = torchemd_cuda.approxmatch_forward(xyz1, xyz2)
        cost = torchemd_cuda.matchcost_forward(xyz1, xyz2, match)

        # backward needs both inputs and the matching computed above.
        ctx.save_for_backward(xyz1, xyz2, match)
        return cost

    @staticmethod
    def backward(ctx, grad_cost) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the gradients of the cost with respect to both inputs.

        Args:
            ctx: Autograd context holding the tensors saved by ``forward``.
            grad_cost (torch.Tensor): Gradient flowing back into the cost
                returned by ``forward``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(grad_xyz1, grad_xyz2)`` as
            produced by ``torchemd_cuda.matchcost_backward``, one gradient
            per ``forward`` input.
        """
        xyz1, xyz2, match = ctx.saved_tensors
        grad_cost = grad_cost.contiguous()
        grad_xyz1, grad_xyz2 = torchemd_cuda.matchcost_backward(
            grad_cost, xyz1, xyz2, match
        )
        return grad_xyz1, grad_xyz2
def earth_mover_distance(xyz1, xyz2, transpose=True) -> torch.Tensor:
    """Approximate Earth Mover Distance between two point sets.

    Args:
        xyz1 (torch.Tensor): (B, N, 3) or (N, 3)
        xyz2 (torch.Tensor): (B, M, 3) or (M, 3)
        transpose (bool, optional): If True, dimensions 1 and 2 of both
            inputs are swapped before the distance is computed, i.e.
            (B, N, 3) -> (B, 3, N) and (B, M, 3) -> (B, 3, M).

    Returns:
        torch.Tensor: The value produced by
        ``EarthMoverDistanceFunction.apply`` (presumably one cost per batch
        element -- TODO confirm against the CUDA extension).
    """
    # Promote 2-D (unbatched) inputs to a batch of one.
    xyz1 = xyz1.unsqueeze(0) if xyz1.dim() == 2 else xyz1
    xyz2 = xyz2.unsqueeze(0) if xyz2.dim() == 2 else xyz2

    if transpose:
        xyz1, xyz2 = xyz1.transpose(1, 2), xyz2.transpose(1, 2)

    return EarthMoverDistanceFunction.apply(xyz1, xyz2)