61 lines
1.8 KiB
Python
61 lines
1.8 KiB
Python
import pathlib
|
|
|
|
import torch
|
|
from torch.utils.cpp_extension import load
|
|
|
|
# Directory containing this file; the extension sources are resolved relative
# to it so the build works regardless of the current working directory.
REALPATH = pathlib.Path(__file__).parent

# JIT-compile the CUDA extension and bind the module object that ``load``
# returns. Using the return value (rather than a follow-up
# ``import torchemd_cuda``) does not depend on the loader having registered
# the extension in ``sys.modules``, which is an implementation detail of how
# the extension module is initialised rather than part of ``load``'s contract.
torchemd_cuda = load(
    name="torchemd_cuda",
    sources=[
        str(REALPATH / src)
        for src in [
            "cuda/emd.cpp",
            "cuda/emd_kernel.cu",
        ]
    ],
)
|
|
|
|
|
|
class EarthMoverDistanceFunction(torch.autograd.Function):
    """Autograd bridge to the approximate-EMD kernels in ``torchemd_cuda``.

    ``forward`` obtains a matching from ``torchemd_cuda.approxmatch_forward``
    and the cost of that matching from ``torchemd_cuda.matchcost_forward``.
    ``backward`` propagates the incoming cost gradient to both point sets via
    ``torchemd_cuda.matchcost_backward``. The saved matching is treated as a
    constant there: no gradient is propagated through ``approxmatch_forward``.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2) -> torch.Tensor:
        """Compute the matching cost between two point sets.

        Args:
            ctx: autograd context; ``xyz1``, ``xyz2`` and the matching are
                saved on it for ``backward``.
            xyz1 (torch.Tensor): first point set; must live on a CUDA device.
            xyz2 (torch.Tensor): second point set; must live on a CUDA device.

        Returns:
            torch.Tensor: the cost tensor produced by
            ``torchemd_cuda.matchcost_forward``.

        Raises:
            AssertionError: if either input is not a CUDA tensor.
        """
        # Explicit raise instead of ``assert`` so the guard is not stripped
        # under ``python -O``; AssertionError (and the message) are kept so
        # existing callers that catch it keep working. Checked before the
        # contiguous copies so rejected inputs are never copied.
        if not (xyz1.is_cuda and xyz2.is_cuda):
            raise AssertionError("Only support cuda currently.")

        # Inputs may be non-contiguous (e.g. after a transpose); presumably
        # the kernels index raw memory and need a dense layout — confirm
        # against cuda/emd.cpp.
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        match = torchemd_cuda.approxmatch_forward(xyz1, xyz2)
        cost = torchemd_cuda.matchcost_forward(xyz1, xyz2, match)
        ctx.save_for_backward(xyz1, xyz2, match)
        return cost

    @staticmethod
    def backward(ctx, grad_cost) -> tuple[torch.Tensor, torch.Tensor]:
        """Back-propagate ``grad_cost`` to both inputs of ``forward``.

        Args:
            ctx: autograd context holding ``xyz1``, ``xyz2`` and ``match``.
            grad_cost (torch.Tensor): gradient of the loss w.r.t. the cost.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: gradients w.r.t. ``xyz1`` and
            ``xyz2`` (one per ``forward`` input), as returned by
            ``torchemd_cuda.matchcost_backward``.
        """
        xyz1, xyz2, match = ctx.saved_tensors
        # Upstream gradients are not guaranteed to be contiguous.
        grad_cost = grad_cost.contiguous()
        grad_xyz1, grad_xyz2 = torchemd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
        return grad_xyz1, grad_xyz2
|
|
|
|
|
|
def earth_mover_distance(xyz1, xyz2, transpose=True) -> torch.Tensor:
    """Approximate Earth Mover's Distance between two point sets.

    Each input is normalised independently. A 2-D tensor gets a leading batch
    dimension of size 1. Then, if ``transpose`` is true, dimensions 1 and 2
    are swapped. The results are handed to ``EarthMoverDistanceFunction``.

    Args:
        xyz1 (torch.Tensor): (B, N, 3) or (N, 3)
        xyz2 (torch.Tensor): (B, M, 3) or (M, 3)
        transpose (bool, optional): If True (the default), xyz1 and xyz2 are
            transposed to (B, 3, N) and (B, 3, M) before the kernel call.

    Returns:
        torch.Tensor: the cost produced by ``EarthMoverDistanceFunction``.

    NOTE(review): the layout the CUDA kernel actually expects ((B, N, 3) vs
    (B, 3, N)) is not visible here — confirm against cuda/emd.cpp that the
    default ``transpose=True`` matches the documented input shapes.
    """

    def _prepare(points):
        # Promote an unbatched point set to a batch of one.
        batched = points.unsqueeze(0) if points.dim() == 2 else points
        return batched.transpose(1, 2) if transpose else batched

    return EarthMoverDistanceFunction.apply(_prepare(xyz1), _prepare(xyz2))
|