🚚 move code to package
This commit is contained in:
parent
892b79f5c5
commit
00fc30f5d5
45
emd.py
45
emd.py
|
@ -1,45 +0,0 @@
|
|||
import emd_cuda
|
||||
import torch
|
||||
|
||||
|
||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
xyz1 = xyz1.contiguous()
|
||||
xyz2 = xyz2.contiguous()
|
||||
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
|
||||
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
|
||||
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
|
||||
ctx.save_for_backward(xyz1, xyz2, match)
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_cost):
|
||||
xyz1, xyz2, match = ctx.saved_tensors
|
||||
grad_cost = grad_cost.contiguous()
|
||||
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
|
||||
return grad_xyz1, grad_xyz2
|
||||
|
||||
|
||||
def earth_mover_distance(xyz1, xyz2, transpose=True):
|
||||
"""Earth Mover Distance (Approx)
|
||||
|
||||
Args:
|
||||
xyz1 (torch.Tensor): (b, 3, n1)
|
||||
xyz2 (torch.Tensor): (b, 3, n1)
|
||||
transpose (bool): whether to transpose inputs as it might be BCN format.
|
||||
Extensions only support BNC format.
|
||||
|
||||
Returns:
|
||||
cost (torch.Tensor): (b)
|
||||
|
||||
"""
|
||||
if xyz1.dim() == 2:
|
||||
xyz1 = xyz1.unsqueeze(0)
|
||||
if xyz2.dim() == 2:
|
||||
xyz2 = xyz2.unsqueeze(0)
|
||||
if transpose:
|
||||
xyz1 = xyz1.transpose(1, 2)
|
||||
xyz2 = xyz2.transpose(1, 2)
|
||||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
||||
return cost
|
25
setup.py
25
setup.py
|
@ -1,25 +0,0 @@
|
|||
"""Setup extension
|
||||
|
||||
Notes:
|
||||
If extra_compile_args is provided, you need to provide different instances for different extensions.
|
||||
Refer to https://github.com/pytorch/pytorch/issues/20169
|
||||
|
||||
"""
|
||||
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name="emd_ext",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name="emd_cuda",
|
||||
sources=[
|
||||
"cuda/emd.cpp",
|
||||
"cuda/emd_kernel.cu",
|
||||
],
|
||||
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
|
||||
),
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
|
@ -1,45 +0,0 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from emd import earth_mover_distance
|
||||
|
||||
# gt
|
||||
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p1 = p1.repeat(3, 1, 1)
|
||||
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p2 = p2.repeat(3, 1, 1)
|
||||
print(p1)
|
||||
print(p2)
|
||||
p1.requires_grad = True
|
||||
p2.requires_grad = True
|
||||
|
||||
gt_dist = (
|
||||
(((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
|
||||
+ (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
|
||||
+ (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
|
||||
)
|
||||
print("gt_dist: ", gt_dist)
|
||||
|
||||
gt_dist.backward()
|
||||
print(p1.grad)
|
||||
print(p2.grad)
|
||||
|
||||
# emd
|
||||
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p1 = p1.repeat(3, 1, 1)
|
||||
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p2 = p2.repeat(3, 1, 1)
|
||||
print(p1)
|
||||
print(p2)
|
||||
p1.requires_grad = True
|
||||
p2.requires_grad = True
|
||||
|
||||
d = earth_mover_distance(p1, p2, transpose=False)
|
||||
print(d)
|
||||
|
||||
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
|
||||
print(loss)
|
||||
|
||||
loss.backward()
|
||||
print(p1.grad)
|
||||
print(p2.grad)
|
1
torchemd/__init__.py
Executable file
1
torchemd/__init__.py
Executable file
|
@ -0,0 +1 @@
|
|||
from .emd import EarthMoverDistanceFunction, earth_mover_distance
|
60
torchemd/emd.py
Normal file
60
torchemd/emd.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
import pathlib
|
||||
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
# get path to where this file is located
|
||||
REALPATH = pathlib.Path(__file__).parent
|
||||
|
||||
# JIT compile CUDA extensions
|
||||
load(
|
||||
name="torchemd_cuda",
|
||||
sources=[
|
||||
str(REALPATH / file)
|
||||
for file in [
|
||||
"cuda/emd.cpp",
|
||||
"cuda/emd_kernel.cu",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
# import compiled CUDA extensions
|
||||
import torchemd_cuda
|
||||
|
||||
|
||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1, xyz2) -> torch.Tensor:
|
||||
xyz1 = xyz1.contiguous()
|
||||
xyz2 = xyz2.contiguous()
|
||||
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
|
||||
match = torchemd_cuda.approxmatch_forward(xyz1, xyz2)
|
||||
cost = torchemd_cuda.matchcost_forward(xyz1, xyz2, match)
|
||||
ctx.save_for_backward(xyz1, xyz2, match)
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_cost) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
xyz1, xyz2, match = ctx.saved_tensors
|
||||
grad_cost = grad_cost.contiguous()
|
||||
grad_xyz1, grad_xyz2 = torchemd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
|
||||
return grad_xyz1, grad_xyz2
|
||||
|
||||
|
||||
def earth_mover_distance(xyz1, xyz2, transpose=True) -> torch.Tensor:
|
||||
"""Earth Mover Distance (Approx)
|
||||
|
||||
Args:
|
||||
xyz1 (torch.Tensor): (B, N, 3) or (N, 3)
|
||||
xyz2 (torch.Tensor): (B, M, 3) or (M, 3)
|
||||
transpose (bool, optional): If True, xyz1 and xyz2 will be transposed to (B, 3, N) and (B, 3, M).
|
||||
"""
|
||||
if xyz1.dim() == 2:
|
||||
xyz1 = xyz1.unsqueeze(0)
|
||||
if xyz2.dim() == 2:
|
||||
xyz2 = xyz2.unsqueeze(0)
|
||||
if transpose:
|
||||
xyz1 = xyz1.transpose(1, 2)
|
||||
xyz2 = xyz2.transpose(1, 2)
|
||||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
||||
return cost
|
Loading…
Reference in a new issue