"""Approximate Earth Mover's Distance backed by a JIT-compiled CUDA extension."""

import pathlib

import torch
from torch.utils.cpp_extension import load

# Directory containing this file; the CUDA sources are resolved relative to it.
REALPATH = pathlib.Path(__file__).parent

# JIT-compile the CUDA extension at import time and bind the module that
# ``load`` returns. ``load`` documents that it returns the compiled extension
# as a Python module; a separate ``import torchemd_cuda`` statement would
# instead depend on ``load`` having registered the module in ``sys.modules``,
# which is not part of its documented contract.
torchemd_cuda = load(
    name="torchemd_cuda",
    sources=[
        str(REALPATH / rel_path)
        for rel_path in (
            "cuda/emd.cpp",
            "cuda/emd_kernel.cu",
        )
    ],
)


class EarthMoverDistanceFunction(torch.autograd.Function):
    """Autograd wrapper around the ``torchemd_cuda`` matching kernels.

    ``forward`` computes an approximate matching between the two point sets
    (``approxmatch_forward``) and the cost of that matching
    (``matchcost_forward``). ``backward`` reuses the saved matching and
    propagates the incoming gradient through ``matchcost_backward``.
    """

    @staticmethod
    def forward(ctx, xyz1: torch.Tensor, xyz2: torch.Tensor) -> torch.Tensor:
        """Return the approximate-matching cost between ``xyz1`` and ``xyz2``.

        Raises:
            AssertionError: if either input is not a CUDA tensor.
        """
        # Validate before ``.contiguous()`` so CPU inputs are rejected without
        # a needless copy. Raised explicitly (not via ``assert``) so the check
        # is not stripped under ``python -O``; the AssertionError type and the
        # message are kept for compatibility with the previous ``assert``.
        if not (xyz1.is_cuda and xyz2.is_cuda):
            raise AssertionError("Only support cuda currently.")
        # NOTE(review): contiguous memory is presumably required by the raw
        # CUDA kernels -- confirm against cuda/emd_kernel.cu.
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()
        match = torchemd_cuda.approxmatch_forward(xyz1, xyz2)
        cost = torchemd_cuda.matchcost_forward(xyz1, xyz2, match)
        # Save the matching so backward reuses it instead of recomputing it.
        ctx.save_for_backward(xyz1, xyz2, match)
        return cost

    @staticmethod
    def backward(ctx, grad_cost) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the gradients of the cost w.r.t. ``xyz1`` and ``xyz2``."""
        xyz1, xyz2, match = ctx.saved_tensors
        # Incoming gradients can be non-contiguous (e.g. the expanded gradient
        # produced by a downstream ``.sum()``); hand the kernel a dense buffer.
        grad_cost = grad_cost.contiguous()
        grad_xyz1, grad_xyz2 = torchemd_cuda.matchcost_backward(
            grad_cost, xyz1, xyz2, match
        )
        return grad_xyz1, grad_xyz2


def earth_mover_distance(xyz1, xyz2, transpose=True) -> torch.Tensor:
    """Earth Mover Distance (approximate).

    Args:
        xyz1 (torch.Tensor): (B, N, 3) or unbatched (N, 3) point set.
        xyz2 (torch.Tensor): (B, M, 3) or unbatched (M, 3) point set.
        transpose (bool, optional): If True (default), swap dims 1 and 2 of
            both inputs before calling the kernels, i.e. (B, N, 3) becomes
            (B, 3, N) and (B, M, 3) becomes (B, 3, M).

    Returns:
        torch.Tensor: the cost produced by ``matchcost_forward``. A leading
        batch dimension of size 1 is added to unbatched (2-D) inputs and is
        not squeezed away here.

    Raises:
        AssertionError: if either input is not a CUDA tensor.

    NOTE(review): the point layout the CUDA kernels expect ((B, N, 3) vs
    (B, 3, N)) is defined in cuda/emd_kernel.cu, not here -- confirm that the
    default ``transpose=True`` matches it for the shapes documented above.
    """
    # Promote unbatched point sets to a batch of one.
    if xyz1.dim() == 2:
        xyz1 = xyz1.unsqueeze(0)
    if xyz2.dim() == 2:
        xyz2 = xyz2.unsqueeze(0)
    if transpose:
        xyz1 = xyz1.transpose(1, 2)
        xyz2 = xyz2.transpose(1, 2)
    cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
    return cost