🚚 move code to package
This commit is contained in:
parent
892b79f5c5
commit
00fc30f5d5
45
emd.py
45
emd.py
|
@ -1,45 +0,0 @@
|
||||||
import emd_cuda
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, xyz1, xyz2):
|
|
||||||
xyz1 = xyz1.contiguous()
|
|
||||||
xyz2 = xyz2.contiguous()
|
|
||||||
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
|
|
||||||
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
|
|
||||||
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
|
|
||||||
ctx.save_for_backward(xyz1, xyz2, match)
|
|
||||||
return cost
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_cost):
|
|
||||||
xyz1, xyz2, match = ctx.saved_tensors
|
|
||||||
grad_cost = grad_cost.contiguous()
|
|
||||||
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
|
|
||||||
return grad_xyz1, grad_xyz2
|
|
||||||
|
|
||||||
|
|
||||||
def earth_mover_distance(xyz1, xyz2, transpose=True):
|
|
||||||
"""Earth Mover Distance (Approx)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
xyz1 (torch.Tensor): (b, 3, n1)
|
|
||||||
xyz2 (torch.Tensor): (b, 3, n1)
|
|
||||||
transpose (bool): whether to transpose inputs as it might be BCN format.
|
|
||||||
Extensions only support BNC format.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
cost (torch.Tensor): (b)
|
|
||||||
|
|
||||||
"""
|
|
||||||
if xyz1.dim() == 2:
|
|
||||||
xyz1 = xyz1.unsqueeze(0)
|
|
||||||
if xyz2.dim() == 2:
|
|
||||||
xyz2 = xyz2.unsqueeze(0)
|
|
||||||
if transpose:
|
|
||||||
xyz1 = xyz1.transpose(1, 2)
|
|
||||||
xyz2 = xyz2.transpose(1, 2)
|
|
||||||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
|
||||||
return cost
|
|
25
setup.py
25
setup.py
|
@ -1,25 +0,0 @@
|
||||||
"""Setup extension
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
If extra_compile_args is provided, you need to provide different instances for different extensions.
|
|
||||||
Refer to https://github.com/pytorch/pytorch/issues/20169
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from setuptools import setup
|
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name="emd_ext",
|
|
||||||
ext_modules=[
|
|
||||||
CUDAExtension(
|
|
||||||
name="emd_cuda",
|
|
||||||
sources=[
|
|
||||||
"cuda/emd.cpp",
|
|
||||||
"cuda/emd_kernel.cu",
|
|
||||||
],
|
|
||||||
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
cmdclass={"build_ext": BuildExtension},
|
|
||||||
)
|
|
|
@ -1,45 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from emd import earth_mover_distance
|
|
||||||
|
|
||||||
# gt
|
|
||||||
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
|
||||||
p1 = p1.repeat(3, 1, 1)
|
|
||||||
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
|
||||||
p2 = p2.repeat(3, 1, 1)
|
|
||||||
print(p1)
|
|
||||||
print(p2)
|
|
||||||
p1.requires_grad = True
|
|
||||||
p2.requires_grad = True
|
|
||||||
|
|
||||||
gt_dist = (
|
|
||||||
(((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
|
|
||||||
+ (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
|
|
||||||
+ (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
|
|
||||||
)
|
|
||||||
print("gt_dist: ", gt_dist)
|
|
||||||
|
|
||||||
gt_dist.backward()
|
|
||||||
print(p1.grad)
|
|
||||||
print(p2.grad)
|
|
||||||
|
|
||||||
# emd
|
|
||||||
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
|
||||||
p1 = p1.repeat(3, 1, 1)
|
|
||||||
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
|
||||||
p2 = p2.repeat(3, 1, 1)
|
|
||||||
print(p1)
|
|
||||||
print(p2)
|
|
||||||
p1.requires_grad = True
|
|
||||||
p2.requires_grad = True
|
|
||||||
|
|
||||||
d = earth_mover_distance(p1, p2, transpose=False)
|
|
||||||
print(d)
|
|
||||||
|
|
||||||
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
|
|
||||||
print(loss)
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
print(p1.grad)
|
|
||||||
print(p2.grad)
|
|
1
torchemd/__init__.py
Executable file
1
torchemd/__init__.py
Executable file
|
@ -0,0 +1 @@
|
||||||
|
from .emd import EarthMoverDistanceFunction, earth_mover_distance
|
60
torchemd/emd.py
Normal file
60
torchemd/emd.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.cpp_extension import load
|
||||||
|
|
||||||
|
# get path to where this file is located
|
||||||
|
REALPATH = pathlib.Path(__file__).parent
|
||||||
|
|
||||||
|
# JIT compile CUDA extensions
|
||||||
|
load(
|
||||||
|
name="torchemd_cuda",
|
||||||
|
sources=[
|
||||||
|
str(REALPATH / file)
|
||||||
|
for file in [
|
||||||
|
"cuda/emd.cpp",
|
||||||
|
"cuda/emd_kernel.cu",
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# import compiled CUDA extensions
|
||||||
|
import torchemd_cuda
|
||||||
|
|
||||||
|
|
||||||
|
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, xyz1, xyz2) -> torch.Tensor:
|
||||||
|
xyz1 = xyz1.contiguous()
|
||||||
|
xyz2 = xyz2.contiguous()
|
||||||
|
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
|
||||||
|
match = torchemd_cuda.approxmatch_forward(xyz1, xyz2)
|
||||||
|
cost = torchemd_cuda.matchcost_forward(xyz1, xyz2, match)
|
||||||
|
ctx.save_for_backward(xyz1, xyz2, match)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_cost) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
xyz1, xyz2, match = ctx.saved_tensors
|
||||||
|
grad_cost = grad_cost.contiguous()
|
||||||
|
grad_xyz1, grad_xyz2 = torchemd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
|
||||||
|
return grad_xyz1, grad_xyz2
|
||||||
|
|
||||||
|
|
||||||
|
def earth_mover_distance(xyz1, xyz2, transpose=True) -> torch.Tensor:
|
||||||
|
"""Earth Mover Distance (Approx)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xyz1 (torch.Tensor): (B, N, 3) or (N, 3)
|
||||||
|
xyz2 (torch.Tensor): (B, M, 3) or (M, 3)
|
||||||
|
transpose (bool, optional): If True, xyz1 and xyz2 will be transposed to (B, 3, N) and (B, 3, M).
|
||||||
|
"""
|
||||||
|
if xyz1.dim() == 2:
|
||||||
|
xyz1 = xyz1.unsqueeze(0)
|
||||||
|
if xyz2.dim() == 2:
|
||||||
|
xyz2 = xyz2.unsqueeze(0)
|
||||||
|
if transpose:
|
||||||
|
xyz1 = xyz1.transpose(1, 2)
|
||||||
|
xyz2 = xyz2.transpose(1, 2)
|
||||||
|
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
||||||
|
return cost
|
Loading…
Reference in a new issue