From 00fc30f5d528377728d8b5e92861c597ce272ad4 Mon Sep 17 00:00:00 2001 From: Laurent FAINSIN Date: Mon, 24 Apr 2023 15:07:12 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=9A=20move=20code=20to=20package?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- __init__.py | 0 emd.py | 45 -------------------- setup.py | 25 ----------- test_emd_loss.py | 45 -------------------- torchemd/__init__.py | 1 + {cuda => torchemd/cuda}/emd.cpp | 0 {cuda => torchemd/cuda}/emd_kernel.cu | 0 torchemd/emd.py | 60 +++++++++++++++++++++++++++ 8 files changed, 61 insertions(+), 115 deletions(-) delete mode 100755 __init__.py delete mode 100755 emd.py delete mode 100755 setup.py delete mode 100644 test_emd_loss.py create mode 100755 torchemd/__init__.py rename {cuda => torchemd/cuda}/emd.cpp (100%) rename {cuda => torchemd/cuda}/emd_kernel.cu (100%) create mode 100644 torchemd/emd.py diff --git a/__init__.py b/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/emd.py b/emd.py deleted file mode 100755 index cb9e164..0000000 --- a/emd.py +++ /dev/null @@ -1,45 +0,0 @@ -import emd_cuda -import torch - - -class EarthMoverDistanceFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, xyz1, xyz2): - xyz1 = xyz1.contiguous() - xyz2 = xyz2.contiguous() - assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently." - match = emd_cuda.approxmatch_forward(xyz1, xyz2) - cost = emd_cuda.matchcost_forward(xyz1, xyz2, match) - ctx.save_for_backward(xyz1, xyz2, match) - return cost - - @staticmethod - def backward(ctx, grad_cost): - xyz1, xyz2, match = ctx.saved_tensors - grad_cost = grad_cost.contiguous() - grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match) - return grad_xyz1, grad_xyz2 - - -def earth_mover_distance(xyz1, xyz2, transpose=True): - """Earth Mover Distance (Approx) - - Args: - xyz1 (torch.Tensor): (b, 3, n1) - xyz2 (torch.Tensor): (b, 3, n1) - transpose (bool): whether to transpose inputs as it might be BCN format. - Extensions only support BNC format. - - Returns: - cost (torch.Tensor): (b) - - """ - if xyz1.dim() == 2: - xyz1 = xyz1.unsqueeze(0) - if xyz2.dim() == 2: - xyz2 = xyz2.unsqueeze(0) - if transpose: - xyz1 = xyz1.transpose(1, 2) - xyz2 = xyz2.transpose(1, 2) - cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) - return cost diff --git a/setup.py b/setup.py deleted file mode 100755 index 0c9be58..0000000 --- a/setup.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Setup extension - -Notes: - If extra_compile_args is provided, you need to provide different instances for different extensions. - Refer to https://github.com/pytorch/pytorch/issues/20169 - -""" - -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -setup( - name="emd_ext", - ext_modules=[ - CUDAExtension( - name="emd_cuda", - sources=[ - "cuda/emd.cpp", - "cuda/emd_kernel.cu", - ], - extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]}, - ), - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/test_emd_loss.py b/test_emd_loss.py deleted file mode 100644 index fd6b554..0000000 --- a/test_emd_loss.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np -import torch - -from emd import earth_mover_distance - -# gt -p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda() -p1 = p1.repeat(3, 1, 1) -p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda() -p2 = p2.repeat(3, 1, 1) -print(p1) -print(p2) -p1.requires_grad = True -p2.requires_grad = True - -gt_dist = ( - (((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2 - + (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2 - + (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3 -) -print("gt_dist: ", gt_dist) - -gt_dist.backward() -print(p1.grad) -print(p2.grad) - -# emd -p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda() -p1 = p1.repeat(3, 1, 1) -p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda() -p2 = p2.repeat(3, 1, 1) -print(p1) -print(p2) -p1.requires_grad = True -p2.requires_grad = True - -d = earth_mover_distance(p1, p2, transpose=False) -print(d) - -loss = d[0] / 2 + d[1] * 2 + d[2] / 3 -print(loss) - -loss.backward() -print(p1.grad) -print(p2.grad) diff --git a/torchemd/__init__.py b/torchemd/__init__.py new file mode 100755 index 0000000..03d08ef --- /dev/null +++ b/torchemd/__init__.py @@ -0,0 +1 @@ +from .emd import EarthMoverDistanceFunction, earth_mover_distance diff --git a/cuda/emd.cpp b/torchemd/cuda/emd.cpp similarity index 100% rename from cuda/emd.cpp rename to torchemd/cuda/emd.cpp diff --git a/cuda/emd_kernel.cu b/torchemd/cuda/emd_kernel.cu similarity index 100% rename from cuda/emd_kernel.cu rename to torchemd/cuda/emd_kernel.cu diff --git a/torchemd/emd.py b/torchemd/emd.py new file mode 100644 index 0000000..5976084 --- /dev/null +++ b/torchemd/emd.py @@ -0,0 +1,60 @@ +import pathlib + +import torch +from torch.utils.cpp_extension import load + +# get path to where this file is located +REALPATH = pathlib.Path(__file__).parent + +# JIT compile CUDA extensions +load( + name="torchemd_cuda", + sources=[ + str(REALPATH / file) + for file in [ + "cuda/emd.cpp", + "cuda/emd_kernel.cu", + ] + ], +) + +# import compiled CUDA extensions +import torchemd_cuda + + +class EarthMoverDistanceFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, xyz1, xyz2) -> torch.Tensor: + xyz1 = xyz1.contiguous() + xyz2 = xyz2.contiguous() + assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently." + match = torchemd_cuda.approxmatch_forward(xyz1, xyz2) + cost = torchemd_cuda.matchcost_forward(xyz1, xyz2, match) + ctx.save_for_backward(xyz1, xyz2, match) + return cost + + @staticmethod + def backward(ctx, grad_cost) -> tuple[torch.Tensor, torch.Tensor]: + xyz1, xyz2, match = ctx.saved_tensors + grad_cost = grad_cost.contiguous() + grad_xyz1, grad_xyz2 = torchemd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match) + return grad_xyz1, grad_xyz2 + + +def earth_mover_distance(xyz1, xyz2, transpose=True) -> torch.Tensor: + """Earth Mover Distance (Approx) + + Args: + xyz1 (torch.Tensor): (B, N, 3) or (N, 3) + xyz2 (torch.Tensor): (B, M, 3) or (M, 3) + transpose (bool, optional): If True, xyz1 and xyz2 will be transposed to (B, 3, N) and (B, 3, M). + """ + if xyz1.dim() == 2: + xyz1 = xyz1.unsqueeze(0) + if xyz2.dim() == 2: + xyz2 = xyz2.unsqueeze(0) + if transpose: + xyz1 = xyz1.transpose(1, 2) + xyz2 = xyz2.transpose(1, 2) + cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) + return cost