🚚 move code to package

This commit is contained in:
Laurent FAINSIN 2023-04-24 15:07:12 +02:00
parent 892b79f5c5
commit 00fc30f5d5
8 changed files with 61 additions and 115 deletions

View file

45
emd.py
View file

@ -1,45 +0,0 @@
import emd_cuda
import torch
class EarthMoverDistanceFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
ctx.save_for_backward(xyz1, xyz2, match)
return cost
@staticmethod
def backward(ctx, grad_cost):
xyz1, xyz2, match = ctx.saved_tensors
grad_cost = grad_cost.contiguous()
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
return grad_xyz1, grad_xyz2
def earth_mover_distance(xyz1, xyz2, transpose=True):
"""Earth Mover Distance (Approx)
Args:
xyz1 (torch.Tensor): (b, 3, n1)
xyz2 (torch.Tensor): (b, 3, n1)
transpose (bool): whether to transpose inputs as it might be BCN format.
Extensions only support BNC format.
Returns:
cost (torch.Tensor): (b)
"""
if xyz1.dim() == 2:
xyz1 = xyz1.unsqueeze(0)
if xyz2.dim() == 2:
xyz2 = xyz2.unsqueeze(0)
if transpose:
xyz1 = xyz1.transpose(1, 2)
xyz2 = xyz2.transpose(1, 2)
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
return cost

View file

@ -1,25 +0,0 @@
"""Setup extension
Notes:
If extra_compile_args is provided, you need to provide different instances for different extensions.
Refer to https://github.com/pytorch/pytorch/issues/20169
"""
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="emd_ext",
ext_modules=[
CUDAExtension(
name="emd_cuda",
sources=[
"cuda/emd.cpp",
"cuda/emd_kernel.cu",
],
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
),
],
cmdclass={"build_ext": BuildExtension},
)

View file

@ -1,45 +0,0 @@
import numpy as np
import torch
from emd import earth_mover_distance
# gt
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
gt_dist = (
(((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
+ (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
+ (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
)
print("gt_dist: ", gt_dist)
gt_dist.backward()
print(p1.grad)
print(p2.grad)
# emd
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
d = earth_mover_distance(p1, p2, transpose=False)
print(d)
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
print(loss)
loss.backward()
print(p1.grad)
print(p2.grad)

1
torchemd/__init__.py Executable file
View file

@ -0,0 +1 @@
from .emd import EarthMoverDistanceFunction, earth_mover_distance

60
torchemd/emd.py Normal file
View file

@ -0,0 +1,60 @@
import pathlib
import torch
from torch.utils.cpp_extension import load
# get path to where this file is located
REALPATH = pathlib.Path(__file__).parent
# JIT compile CUDA extensions
load(
name="torchemd_cuda",
sources=[
str(REALPATH / file)
for file in [
"cuda/emd.cpp",
"cuda/emd_kernel.cu",
]
],
)
# import compiled CUDA extensions
import torchemd_cuda
class EarthMoverDistanceFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2) -> torch.Tensor:
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
match = torchemd_cuda.approxmatch_forward(xyz1, xyz2)
cost = torchemd_cuda.matchcost_forward(xyz1, xyz2, match)
ctx.save_for_backward(xyz1, xyz2, match)
return cost
@staticmethod
def backward(ctx, grad_cost) -> tuple[torch.Tensor, torch.Tensor]:
xyz1, xyz2, match = ctx.saved_tensors
grad_cost = grad_cost.contiguous()
grad_xyz1, grad_xyz2 = torchemd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
return grad_xyz1, grad_xyz2
def earth_mover_distance(xyz1, xyz2, transpose=True) -> torch.Tensor:
"""Earth Mover Distance (Approx)
Args:
xyz1 (torch.Tensor): (B, N, 3) or (N, 3)
xyz2 (torch.Tensor): (B, M, 3) or (M, 3)
transpose (bool, optional): If True, xyz1 and xyz2 will be transposed to (B, 3, N) and (B, 3, M).
"""
if xyz1.dim() == 2:
xyz1 = xyz1.unsqueeze(0)
if xyz2.dim() == 2:
xyz2 = xyz2.unsqueeze(0)
if transpose:
xyz1 = xyz1.transpose(1, 2)
xyz2 = xyz2.transpose(1, 2)
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
return cost