Compare commits

...

10 commits

Author SHA1 Message Date
Laurent FAINSIN 2b5ac1c436 🙈 add .pytest_cache to file exclusion 2023-04-24 15:21:09 +02:00
Laurent FAINSIN 481461f442 convert and add pytests 2023-04-24 15:09:41 +02:00
Laurent FAINSIN b72b16424a 💬 remove outdated instructions from README.md 2023-04-24 15:09:15 +02:00
Laurent FAINSIN ba68f223e6 🔧 add pyproject.toml 2023-04-24 15:07:34 +02:00
Laurent FAINSIN 00fc30f5d5 🚚 move code to package 2023-04-24 15:07:12 +02:00
Laurent FAINSIN 892b79f5c5 🎨 autoformatting cpp/cu files 2023-04-24 11:53:37 +02:00
Laurent FAINSIN 2375a23ca5 🔧 add settings.json vscode config 2023-04-24 11:46:04 +02:00
Laurent FAINSIN bb93f6ab47 💄 black autoformatting 2023-04-24 11:45:30 +02:00
Laurent FAINSIN 51289d8670 🔧 add .editorconfig 2023-04-24 11:43:35 +02:00
Laurent FAINSIN 015969c42e 🙈 add .gitattribute 2023-04-24 11:42:40 +02:00
15 changed files with 736 additions and 543 deletions

.editorconfig Normal file

@@ -0,0 +1,15 @@
# EditorConfig is awesome: https://EditorConfig.org
# top-most EditorConfig file
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.{json,toml,yaml,yml}]
indent_size = 2

.gitattributes vendored Normal file

@@ -0,0 +1,75 @@
####################################################################################################
# Python .gitattributes template
# https://github.com/alexkaratarakis/gitattributes/blob/master/Python.gitattributes
####################################################################################################
# Source files
# ============
*.pxd text diff=python
*.py text diff=python
*.py3 text diff=python
*.pyw text diff=python
*.pyx text diff=python
*.pyz text diff=python
*.pyi text diff=python
# Binary files
# ============
*.db binary
*.p binary
*.pkl binary
*.pickle binary
*.pyc binary export-ignore
*.pyo binary export-ignore
*.pyd binary
# Jupyter notebook
*.ipynb text
# Note: .db, .p, and .pkl files are associated
# with the python modules ``pickle``, ``dbm.*``,
# ``shelve``, ``marshal``, ``anydbm``, & ``bsddb``
# (among others).
####################################################################################################
# C++ .gitattributes template
# https://github.com/alexkaratarakis/gitattributes/blob/master/C%2B%2B.gitattributes
####################################################################################################
# Sources
*.c text diff=cpp
*.cc text diff=cpp
*.cxx text diff=cpp
*.cpp text diff=cpp
*.cpi text diff=cpp
*.c++ text diff=cpp
*.hpp text diff=cpp
*.h text diff=cpp
*.h++ text diff=cpp
*.hh text diff=cpp
# Compiled Object files
*.slo binary
*.lo binary
*.o binary
*.obj binary
# Precompiled Headers
*.gch binary
*.pch binary
# Compiled Dynamic libraries
*.so binary
*.dylib binary
*.dll binary
# Compiled Static libraries
*.lai binary
*.la binary
*.a binary
*.lib binary
# Executables
*.exe binary
*.out binary
*.app binary

.vscode/settings.json vendored Normal file

@@ -0,0 +1,28 @@
{
// good practice settings
"editor.formatOnSave": true,
"python.linting.enabled": true,
"python.linting.lintOnSave": true,
"python.linting.mypyEnabled": true,
"files.insertFinalNewline": true,
"python.analysis.typeCheckingMode": "basic", // get ready to be annoyed
"python.formatting.provider": "black",
"[python]": {
"editor.codeActionsOnSave": {
"source.organizeImports": true // isort
}
},
"files.exclude": {
// defaults
"**/.git": true,
"**/.svn": true,
"**/.hg": true,
"**/CVS": true,
"**/.DS_Store": true,
"**/Thumbs.db": true,
// annoying
"**/__pycache__": true,
"**/.mypy_cache": true,
"**/.pytest_cache": true,
}
}

README.md

@@ -1,31 +1,11 @@
# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)
## Dependency
The code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0.
## Usage
First compile using
python setup.py install
Then, copy the lib file out to the main directory,
cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so .
Then, you can use it by simply
from emd import earth_mover_distance
d = earth_mover_distance(p1, p2, transpose=False) # p1: B x N1 x 3, p2: B x N2 x 3
Check `test_emd_loss.py` for example.
## Author
The cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps.
The cuda code is originally written by Haoqiang Fan.
The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided help.
## License
MIT

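A minimal usage sketch of the call documented in the README above; this is an illustration, not part of the diff. Shapes follow the `p1: B x N1 x 3, p2: B x N2 x 3` comment, a CUDA device is assumed (the extension is CUDA-only), and the import path is the pre-move `emd` module.

import torch
from emd import earth_mover_distance

p1 = torch.rand(8, 1024, 3, device="cuda")  # B x N1 x 3
p2 = torch.rand(8, 2048, 3, device="cuda")  # B x N2 x 3
d = earth_mover_distance(p1, p2, transpose=False)  # one cost per batch element: shape (8,)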
View file

cuda/emd_kernel.cu

@@ -1,399 +0,0 @@
/**********************************
* Original Author: Haoqiang Fan
* Modified by: Kaichun Mo
*********************************/
#ifndef _EMD_KERNEL
#define _EMD_KERNEL
#include <cmath>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
/********************************
* Forward kernel for approxmatch
*********************************/
template<typename scalar_t>
__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
scalar_t multiL,multiR;
if (n>=m){
multiL=1;
multiR=n/m;
}else{
multiL=m/n;
multiR=1;
}
const int Block=1024;
__shared__ scalar_t buf[Block*4];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
match[i*n*m+j]=0;
for (int j=threadIdx.x;j<n;j+=blockDim.x)
remainL[j]=multiL;
for (int j=threadIdx.x;j<m;j+=blockDim.x)
remainR[j]=multiR;
__syncthreads();
for (int j=7;j>=-2;j--){
scalar_t level=-powf(4.0f,j);
if (j==-2){
level=0;
}
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
scalar_t x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
scalar_t suml=1e-9f;
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
scalar_t x2=xyz2[i*m*3+l0*3+l*3+0];
scalar_t y2=xyz2[i*m*3+l0*3+l*3+1];
scalar_t z2=xyz2[i*m*3+l0*3+l*3+2];
buf[l*4+0]=x2;
buf[l*4+1]=y2;
buf[l*4+2]=z2;
buf[l*4+3]=remainR[l0+l];
}
__syncthreads();
for (int l=0;l<lend;l++){
scalar_t x2=buf[l*4+0];
scalar_t y2=buf[l*4+1];
scalar_t z2=buf[l*4+2];
scalar_t d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
scalar_t w=__expf(d)*buf[l*4+3];
suml+=w;
}
__syncthreads();
}
if (k<n)
ratioL[k]=remainL[k]/suml;
}
__syncthreads();
for (int l0=0;l0<m;l0+=blockDim.x){
int l=l0+threadIdx.x;
scalar_t x2=0,y2=0,z2=0;
if (l<m){
x2=xyz2[i*m*3+l*3+0];
y2=xyz2[i*m*3+l*3+1];
z2=xyz2[i*m*3+l*3+2];
}
scalar_t sumr=0;
for (int k0=0;k0<n;k0+=Block){
int kend=min(n,k0+Block)-k0;
for (int k=threadIdx.x;k<kend;k+=blockDim.x){
buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
buf[k*4+3]=ratioL[k0+k];
}
__syncthreads();
for (int k=0;k<kend;k++){
scalar_t x1=buf[k*4+0];
scalar_t y1=buf[k*4+1];
scalar_t z1=buf[k*4+2];
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
sumr+=w;
}
__syncthreads();
}
if (l<m){
sumr*=remainR[l];
scalar_t consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
ratioR[l]=consumption*remainR[l];
remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
}
}
__syncthreads();
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
scalar_t x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
scalar_t suml=0;
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
buf[l*4+3]=ratioR[l0+l];
}
__syncthreads();
scalar_t rl=ratioL[k];
if (k<n){
for (int l=0;l<lend;l++){
scalar_t x2=buf[l*4+0];
scalar_t y2=buf[l*4+1];
scalar_t z2=buf[l*4+2];
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
match[i*n*m+(l0+l)*n+k]+=w;
suml+=w;
}
}
__syncthreads();
}
if (k<n)
remainL[k]=fmaxf(0.0f,remainL[k]-suml);
}
__syncthreads();
}
}
}
//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){
// approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
//}
/* ApproxMatch forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
Output:
match: (B, N2, N1)
*/
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
TORCH_CHECK_EQ (xyz2.size(0), b);
TORCH_CHECK_EQ (xyz1.size(2), 3);
TORCH_CHECK_EQ (xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto match = at::zeros({b, m, n}, xyz1.type());
auto temp = at::zeros({b, (n+m)*2}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
}));
C10_CUDA_CHECK(cudaGetLastError());
return match;
}
/********************************
* Forward kernel for matchcost
*********************************/
template<typename scalar_t>
__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
__shared__ scalar_t allsum[512];
const int Block=1024;
__shared__ scalar_t buf[Block*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
scalar_t subsum=0;
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
scalar_t x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend*3;l+=blockDim.x)
buf[l]=xyz2[i*m*3+l0*3+l];
__syncthreads();
if (k<n){
for (int l=0;l<lend;l++){
scalar_t x2=buf[l*3+0];
scalar_t y2=buf[l*3+1];
scalar_t z2=buf[l*3+2];
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
subsum+=d*match[i*n*m+(l0+l)*n+k];
}
}
__syncthreads();
}
}
allsum[threadIdx.x]=subsum;
for (int j=1;j<blockDim.x;j<<=1){
__syncthreads();
if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){
allsum[threadIdx.x]+=allsum[threadIdx.x+j];
}
}
if (threadIdx.x==0)
out[i]=allsum[0];
__syncthreads();
}
}
//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){
// matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
//}
/* MatchCost forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
cost: (B)
*/
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
TORCH_CHECK_EQ (xyz2.size(0), b);
TORCH_CHECK_EQ (xyz1.size(2), 3);
TORCH_CHECK_EQ (xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto cost = at::zeros({b}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
}));
C10_CUDA_CHECK(cudaGetLastError());
return cost;
}
/********************************
* matchcostgrad2 kernel
*********************************/
template<typename scalar_t>
__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
__shared__ scalar_t sum_grad[256*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
int kbeg=m*blockIdx.y/gridDim.y;
int kend=m*(blockIdx.y+1)/gridDim.y;
for (int k=kbeg;k<kend;k++){
scalar_t x2=xyz2[(i*m+k)*3+0];
scalar_t y2=xyz2[(i*m+k)*3+1];
scalar_t z2=xyz2[(i*m+k)*3+2];
scalar_t subsumx=0,subsumy=0,subsumz=0;
for (int j=threadIdx.x;j<n;j+=blockDim.x){
scalar_t x1=x2-xyz1[(i*n+j)*3+0];
scalar_t y1=y2-xyz1[(i*n+j)*3+1];
scalar_t z1=z2-xyz1[(i*n+j)*3+2];
scalar_t d=match[i*n*m+k*n+j]*2;
subsumx+=x1*d;
subsumy+=y1*d;
subsumz+=z1*d;
}
sum_grad[threadIdx.x*3+0]=subsumx;
sum_grad[threadIdx.x*3+1]=subsumy;
sum_grad[threadIdx.x*3+2]=subsumz;
for (int j=1;j<blockDim.x;j<<=1){
__syncthreads();
int j1=threadIdx.x;
int j2=threadIdx.x+j;
if ((j1&j)==0 && j2<blockDim.x){
sum_grad[j1*3+0]+=sum_grad[j2*3+0];
sum_grad[j1*3+1]+=sum_grad[j2*3+1];
sum_grad[j1*3+2]+=sum_grad[j2*3+2];
}
}
if (threadIdx.x==0){
grad2[(i*m+k)*3+0]=sum_grad[0]*grad_cost[i];
grad2[(i*m+k)*3+1]=sum_grad[1]*grad_cost[i];
grad2[(i*m+k)*3+2]=sum_grad[2]*grad_cost[i];
}
__syncthreads();
}
}
}
/********************************
* matchcostgrad1 kernel
*********************************/
template<typename scalar_t>
__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int l=threadIdx.x;l<n;l+=blockDim.x){
scalar_t x1=xyz1[i*n*3+l*3+0];
scalar_t y1=xyz1[i*n*3+l*3+1];
scalar_t z1=xyz1[i*n*3+l*3+2];
scalar_t dx=0,dy=0,dz=0;
for (int k=0;k<m;k++){
scalar_t x2=xyz2[i*m*3+k*3+0];
scalar_t y2=xyz2[i*m*3+k*3+1];
scalar_t z2=xyz2[i*m*3+k*3+2];
scalar_t d=match[i*n*m+k*n+l]*2;
dx+=(x1-x2)*d;
dy+=(y1-y2)*d;
dz+=(z1-z2)*d;
}
grad1[i*n*3+l*3+0]=dx*grad_cost[i];
grad1[i*n*3+l*3+1]=dy*grad_cost[i];
grad1[i*n*3+l*3+2]=dz*grad_cost[i];
}
}
}
//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){
// matchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);
// matchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
//}
/* MatchCost backward interface
Input:
grad_cost: (B) # gradients on cost
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
grad1: (B, N1, 3)
grad2: (B, N2, 3)
*/
std::vector<at::Tensor> MatchCostBackward(
const at::Tensor grad_cost,
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
TORCH_CHECK_EQ (xyz2.size(0), b);
TORCH_CHECK_EQ (xyz1.size(2), 3);
TORCH_CHECK_EQ (xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
}));
C10_CUDA_CHECK(cudaGetLastError());
return std::vector<at::Tensor>({grad1, grad2});
}
#endif

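To make the deleted kernel above easier to follow: `approxmatch` computes an annealed soft assignment between the two clouds, and `matchcost` contracts it against squared distances. Below is a rough NumPy sketch of that loop for a single cloud pair; it drops the CUDA block/shared-memory tiling and simplifies the kernel's integer n/m, m/n mass ratios to float division, with names mirroring the kernel.

import numpy as np

def approxmatch_numpy(xyz1, xyz2):
    """xyz1: (n, 3), xyz2: (m, 3) -> match: (m, n), as in match: (B, N2, N1)."""
    n, m = len(xyz1), len(xyz2)
    remainL = np.full(n, max(1.0, m / n))  # mass left on xyz1 points (multiL)
    remainR = np.full(m, max(1.0, n / m))  # mass left on xyz2 points (multiR)
    match = np.zeros((m, n))
    dist = ((xyz1[None, :, :] - xyz2[:, None, :]) ** 2).sum(-1)  # (m, n)
    for j in range(7, -3, -1):  # annealing: sharp (-4^7) down to flat (0)
        level = 0.0 if j == -2 else -(4.0 ** j)
        weight = np.exp(level * dist)                    # (m, n)
        ratioL = remainL / (weight.T @ remainR + 1e-9)   # per-xyz1-point ratio
        sumr = (weight @ ratioL) * remainR               # mass demanded from xyz2
        ratioR = np.minimum(remainR / (sumr + 1e-9), 1.0) * remainR
        w = weight * ratioR[:, None] * ratioL[None, :]   # mass transported this round
        match += w
        remainL = np.maximum(0.0, remainL - w.sum(axis=0))
        remainR = np.maximum(0.0, remainR - sumr)
    return match

def matchcost_numpy(xyz1, xyz2, match):
    """Mirrors the matchcost kernel: squared distances weighted by the match."""
    dist = ((xyz1[None, :, :] - xyz2[:, None, :]) ** 2).sum(-1)
    return (dist * match).sum()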
emd.py

@@ -1,46 +0,0 @@
import torch
import emd_cuda
class EarthMoverDistanceFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
assert xyz1.is_cuda and xyz2.is_cuda, "Only CUDA tensors are supported currently."
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
ctx.save_for_backward(xyz1, xyz2, match)
return cost
@staticmethod
def backward(ctx, grad_cost):
xyz1, xyz2, match = ctx.saved_tensors
grad_cost = grad_cost.contiguous()
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
return grad_xyz1, grad_xyz2
def earth_mover_distance(xyz1, xyz2, transpose=True):
"""Earth Mover Distance (Approx)
Args:
xyz1 (torch.Tensor): (b, 3, n1)
xyz2 (torch.Tensor): (b, 3, n2)
transpose (bool): whether to transpose the inputs, as they might be in BCN format.
The extension only supports BNC format.
Returns:
cost (torch.Tensor): (b)
"""
if xyz1.dim() == 2:
xyz1 = xyz1.unsqueeze(0)
if xyz2.dim() == 2:
xyz2 = xyz2.unsqueeze(0)
if transpose:
xyz1 = xyz1.transpose(1, 2)
xyz2 = xyz2.transpose(1, 2)
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
return cost

pyproject.toml Normal file

@@ -0,0 +1,36 @@
[project]
name = "torchemd"
version = "1.0.0"
authors = [{ name = "Laurent Fainsin", email = "laurent@fainsin.bzh" }]
description = "PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)"
readme = "README.md"
requires-python = ">=3.9"
[build-system]
requires = ["flit_core>=3.4"]
build-backend = "flit_core.buildapi"
[tool.ruff]
ignore-init-module-imports = true
select = ["E", "F", "I"]
line-length = 120
[tool.black]
exclude = '''
/(
\.git
\.venv
)/
'''
include = '\.pyi?$'
line-length = 120
target-version = ["py310"]
[tool.isort]
multi_line_output = 3
profile = "black"
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true

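With flit_core declared as the build backend above, a wheel can be built without setup.py. The snippet below is a sketch that drives the PEP 517 hook directly from the project root; normally a frontend such as pip or build would do this for you.

import os
from flit_core import buildapi  # backend named in [build-system]

os.makedirs("dist", exist_ok=True)
wheel_name = buildapi.build_wheel("dist")  # PEP 517 hook; returns the wheel filename
print(wheel_name)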
setup.py

@@ -1,27 +0,0 @@
"""Setup extension
Notes:
If extra_compile_args is provided, you need to provide different instances for different extensions.
Refer to https://github.com/pytorch/pytorch/issues/20169
"""
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='emd_ext',
ext_modules=[
CUDAExtension(
name='emd_cuda',
sources=[
'cuda/emd.cpp',
'cuda/emd_kernel.cu',
],
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
),
],
cmdclass={
'build_ext': BuildExtension
})

test_emd_loss.py

@@ -1,44 +0,0 @@
import torch
import numpy as np
import time
from emd import earth_mover_distance
# gt
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
print('gt_dist: ', gt_dist)
gt_dist.backward()
print(p1.grad)
print(p2.grad)
# emd
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
d = earth_mover_distance(p1, p2, transpose=False)
print(d)
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
print(loss)
loss.backward()
print(p1.grad)
print(p2.grad)

tests/test_correctness.py Normal file

@@ -0,0 +1,80 @@
import torch
from einops import repeat
from torch import Tensor
from torchemd import earth_mover_distance
def generate_pointclouds() -> tuple[Tensor, Tensor]:
# create first point cloud
pc1 = torch.Tensor(
[
[1.7, -0.1, 0.1],
[0.1, 1.2, 0.3],
],
).cuda()
pc1 = repeat(pc1, "n c -> b n c", b=3)
pc1.requires_grad = True
# create second point cloud
pc2 = torch.Tensor(
[
[0.4, 1.8, 0.2],
[1.2, -0.2, 0.3],
],
).cuda()
pc2 = repeat(pc2, "n c -> b n c", b=3)
pc2.requires_grad = True
return pc1, pc2
def dummy_loss(distance: Tensor) -> Tensor:
return distance[0] / 2 + distance[1] * 2 + distance[2] / 3
def manual_distance_computation(pc1: Tensor, pc2: Tensor) -> Tensor:
return (pc1[:, 0] - pc2[:, 1]).pow(2).sum(dim=1) + (pc1[:, 1] - pc2[:, 0]).pow(2).sum(dim=1)
def test_emd():
# compute earth mover distance directly from formula
pc1, pc2 = generate_pointclouds()
ground_truth_distance = manual_distance_computation(pc1, pc2)
ground_truth_loss = dummy_loss(ground_truth_distance)
# get gradients of point clouds
ground_truth_loss.backward()
pc1_gt_grad = pc1.grad
pc2_gt_grad = pc2.grad
# compute earth mover distance directly from implementation
pc1, pc2 = generate_pointclouds()
computed_distance = earth_mover_distance(pc1, pc2, transpose=False)
loss = dummy_loss(computed_distance)
# get gradients of point clouds
loss.backward()
pc1_grad = pc1.grad
pc2_grad = pc2.grad
# compare gradients
assert pc1_grad.allclose(pc1_gt_grad)
assert pc2_grad.allclose(pc2_gt_grad)
# compare distances
assert computed_distance.allclose(ground_truth_distance)
# compare loss
assert loss.allclose(ground_truth_loss)
def test_equality():
pc1, _ = generate_pointclouds()
distance = earth_mover_distance(pc1, pc1, transpose=False)
assert distance.allclose(torch.zeros_like(distance))
def test_symmetry():
pc1, pc2 = generate_pointclouds()
distance1 = earth_mover_distance(pc1, pc2, transpose=False)
distance2 = earth_mover_distance(pc2, pc1, transpose=False)
assert distance1.allclose(distance2)

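The new test module runs under plain pytest; as a sketch, the programmatic equivalent of `pytest tests/test_correctness.py -v` from the repository root is shown below. A CUDA device is required, since the point clouds are created with `.cuda()`.

import pytest

raise SystemExit(pytest.main(["tests/test_correctness.py", "-v"]))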
torchemd/__init__.py Executable file

@@ -0,0 +1 @@
from .emd import EarthMoverDistanceFunction, earth_mover_distance

torchemd/cuda/emd.cpp

@@ -4,7 +4,7 @@
#include <vector>
#include <torch/extension.h>
//CUDA declarations
// CUDA declarations
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2);
@@ -20,10 +20,11 @@ std::vector<at::Tensor> MatchCostBackward(
const at::Tensor xyz2,
const at::Tensor match);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)");
m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)");
m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)");
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("approxmatch_forward", &ApproxMatchForward, "ApproxMatch forward (CUDA)");
m.def("matchcost_forward", &MatchCostForward, "MatchCost forward (CUDA)");
m.def("matchcost_backward", &MatchCostBackward, "MatchCost backward (CUDA)");
}
#endif

torchemd/cuda/emd_kernel.cu Normal file

@@ -0,0 +1,433 @@
/**********************************
* Original Author: Haoqiang Fan
* Modified by: Kaichun Mo
*********************************/
#ifndef _EMD_KERNEL
#define _EMD_KERNEL
#include <cmath>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
/********************************
* Forward kernel for approxmatch
*********************************/
template <typename scalar_t>
__global__ void approxmatch(int b, int n, int m, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, scalar_t *__restrict__ match, scalar_t *temp)
{
scalar_t *remainL = temp + blockIdx.x * (n + m) * 2;
scalar_t *remainR = temp + blockIdx.x * (n + m) * 2 + n;
scalar_t *ratioL = temp + blockIdx.x * (n + m) * 2 + n + m;
scalar_t *ratioR = temp + blockIdx.x * (n + m) * 2 + n + m + n;
scalar_t multiL, multiR;
if (n >= m)
{
multiL = 1;
multiR = n / m;
}
else
{
multiL = m / n;
multiR = 1;
}
const int Block = 1024;
__shared__ scalar_t buf[Block * 4];
for (int i = blockIdx.x; i < b; i += gridDim.x)
{
for (int j = threadIdx.x; j < n * m; j += blockDim.x)
match[i * n * m + j] = 0;
for (int j = threadIdx.x; j < n; j += blockDim.x)
remainL[j] = multiL;
for (int j = threadIdx.x; j < m; j += blockDim.x)
remainR[j] = multiR;
__syncthreads();
for (int j = 7; j >= -2; j--)
{
scalar_t level = -powf(4.0f, j);
if (j == -2)
{
level = 0;
}
for (int k0 = 0; k0 < n; k0 += blockDim.x)
{
int k = k0 + threadIdx.x;
scalar_t x1 = 0, y1 = 0, z1 = 0;
if (k < n)
{
x1 = xyz1[i * n * 3 + k * 3 + 0];
y1 = xyz1[i * n * 3 + k * 3 + 1];
z1 = xyz1[i * n * 3 + k * 3 + 2];
}
scalar_t suml = 1e-9f;
for (int l0 = 0; l0 < m; l0 += Block)
{
int lend = min(m, l0 + Block) - l0;
for (int l = threadIdx.x; l < lend; l += blockDim.x)
{
scalar_t x2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
scalar_t y2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
scalar_t z2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
buf[l * 4 + 0] = x2;
buf[l * 4 + 1] = y2;
buf[l * 4 + 2] = z2;
buf[l * 4 + 3] = remainR[l0 + l];
}
__syncthreads();
for (int l = 0; l < lend; l++)
{
scalar_t x2 = buf[l * 4 + 0];
scalar_t y2 = buf[l * 4 + 1];
scalar_t z2 = buf[l * 4 + 2];
scalar_t d = level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1));
scalar_t w = __expf(d) * buf[l * 4 + 3];
suml += w;
}
__syncthreads();
}
if (k < n)
ratioL[k] = remainL[k] / suml;
}
__syncthreads();
for (int l0 = 0; l0 < m; l0 += blockDim.x)
{
int l = l0 + threadIdx.x;
scalar_t x2 = 0, y2 = 0, z2 = 0;
if (l < m)
{
x2 = xyz2[i * m * 3 + l * 3 + 0];
y2 = xyz2[i * m * 3 + l * 3 + 1];
z2 = xyz2[i * m * 3 + l * 3 + 2];
}
scalar_t sumr = 0;
for (int k0 = 0; k0 < n; k0 += Block)
{
int kend = min(n, k0 + Block) - k0;
for (int k = threadIdx.x; k < kend; k += blockDim.x)
{
buf[k * 4 + 0] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 0];
buf[k * 4 + 1] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 1];
buf[k * 4 + 2] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 2];
buf[k * 4 + 3] = ratioL[k0 + k];
}
__syncthreads();
for (int k = 0; k < kend; k++)
{
scalar_t x1 = buf[k * 4 + 0];
scalar_t y1 = buf[k * 4 + 1];
scalar_t z1 = buf[k * 4 + 2];
scalar_t w = __expf(level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1))) * buf[k * 4 + 3];
sumr += w;
}
__syncthreads();
}
if (l < m)
{
sumr *= remainR[l];
scalar_t consumption = fminf(remainR[l] / (sumr + 1e-9f), 1.0f);
ratioR[l] = consumption * remainR[l];
remainR[l] = fmaxf(0.0f, remainR[l] - sumr);
}
}
__syncthreads();
for (int k0 = 0; k0 < n; k0 += blockDim.x)
{
int k = k0 + threadIdx.x;
scalar_t x1 = 0, y1 = 0, z1 = 0;
if (k < n)
{
x1 = xyz1[i * n * 3 + k * 3 + 0];
y1 = xyz1[i * n * 3 + k * 3 + 1];
z1 = xyz1[i * n * 3 + k * 3 + 2];
}
scalar_t suml = 0;
for (int l0 = 0; l0 < m; l0 += Block)
{
int lend = min(m, l0 + Block) - l0;
for (int l = threadIdx.x; l < lend; l += blockDim.x)
{
buf[l * 4 + 0] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
buf[l * 4 + 1] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
buf[l * 4 + 2] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
buf[l * 4 + 3] = ratioR[l0 + l];
}
__syncthreads();
scalar_t rl = ratioL[k];
if (k < n)
{
for (int l = 0; l < lend; l++)
{
scalar_t x2 = buf[l * 4 + 0];
scalar_t y2 = buf[l * 4 + 1];
scalar_t z2 = buf[l * 4 + 2];
scalar_t w = __expf(level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1))) * rl * buf[l * 4 + 3];
match[i * n * m + (l0 + l) * n + k] += w;
suml += w;
}
}
__syncthreads();
}
if (k < n)
remainL[k] = fmaxf(0.0f, remainL[k] - suml);
}
__syncthreads();
}
}
}
/* ApproxMatch forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
Output:
match: (B, N2, N1)
*/
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2)
{
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
TORCH_CHECK_EQ(xyz2.size(0), b);
TORCH_CHECK_EQ(xyz1.size(2), 3);
TORCH_CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto match = at::zeros({b, m, n}, xyz1.type());
auto temp = at::zeros({b, (n + m) * 2}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&]
{ approxmatch<scalar_t><<<32, 512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>()); }));
C10_CUDA_CHECK(cudaGetLastError());
return match;
}
/********************************
* Forward kernel for matchcost
*********************************/
template <typename scalar_t>
__global__ void matchcost(int b, int n, int m, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ out)
{
__shared__ scalar_t allsum[512];
const int Block = 1024;
__shared__ scalar_t buf[Block * 3];
for (int i = blockIdx.x; i < b; i += gridDim.x)
{
scalar_t subsum = 0;
for (int k0 = 0; k0 < n; k0 += blockDim.x)
{
int k = k0 + threadIdx.x;
scalar_t x1 = 0, y1 = 0, z1 = 0;
if (k < n)
{
x1 = xyz1[i * n * 3 + k * 3 + 0];
y1 = xyz1[i * n * 3 + k * 3 + 1];
z1 = xyz1[i * n * 3 + k * 3 + 2];
}
for (int l0 = 0; l0 < m; l0 += Block)
{
int lend = min(m, l0 + Block) - l0;
for (int l = threadIdx.x; l < lend * 3; l += blockDim.x)
buf[l] = xyz2[i * m * 3 + l0 * 3 + l];
__syncthreads();
if (k < n)
{
for (int l = 0; l < lend; l++)
{
scalar_t x2 = buf[l * 3 + 0];
scalar_t y2 = buf[l * 3 + 1];
scalar_t z2 = buf[l * 3 + 2];
scalar_t d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
subsum += d * match[i * n * m + (l0 + l) * n + k];
}
}
__syncthreads();
}
}
allsum[threadIdx.x] = subsum;
for (int j = 1; j < blockDim.x; j <<= 1)
{
__syncthreads();
if ((threadIdx.x & j) == 0 && threadIdx.x + j < blockDim.x)
{
allsum[threadIdx.x] += allsum[threadIdx.x + j];
}
}
if (threadIdx.x == 0)
out[i] = allsum[0];
__syncthreads();
}
}
/* MatchCost forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
cost: (B)
*/
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match)
{
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
TORCH_CHECK_EQ(xyz2.size(0), b);
TORCH_CHECK_EQ(xyz1.size(2), 3);
TORCH_CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto cost = at::zeros({b}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&]
{ matchcost<scalar_t><<<32, 512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>()); }));
C10_CUDA_CHECK(cudaGetLastError());
return cost;
}
/********************************
* matchcostgrad2 kernel
*********************************/
template <typename scalar_t>
__global__ void matchcostgrad2(int b, int n, int m, const scalar_t *__restrict__ grad_cost, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ grad2)
{
__shared__ scalar_t sum_grad[256 * 3];
for (int i = blockIdx.x; i < b; i += gridDim.x)
{
int kbeg = m * blockIdx.y / gridDim.y;
int kend = m * (blockIdx.y + 1) / gridDim.y;
for (int k = kbeg; k < kend; k++)
{
scalar_t x2 = xyz2[(i * m + k) * 3 + 0];
scalar_t y2 = xyz2[(i * m + k) * 3 + 1];
scalar_t z2 = xyz2[(i * m + k) * 3 + 2];
scalar_t subsumx = 0, subsumy = 0, subsumz = 0;
for (int j = threadIdx.x; j < n; j += blockDim.x)
{
scalar_t x1 = x2 - xyz1[(i * n + j) * 3 + 0];
scalar_t y1 = y2 - xyz1[(i * n + j) * 3 + 1];
scalar_t z1 = z2 - xyz1[(i * n + j) * 3 + 2];
scalar_t d = match[i * n * m + k * n + j] * 2;
subsumx += x1 * d;
subsumy += y1 * d;
subsumz += z1 * d;
}
sum_grad[threadIdx.x * 3 + 0] = subsumx;
sum_grad[threadIdx.x * 3 + 1] = subsumy;
sum_grad[threadIdx.x * 3 + 2] = subsumz;
for (int j = 1; j < blockDim.x; j <<= 1)
{
__syncthreads();
int j1 = threadIdx.x;
int j2 = threadIdx.x + j;
if ((j1 & j) == 0 && j2 < blockDim.x)
{
sum_grad[j1 * 3 + 0] += sum_grad[j2 * 3 + 0];
sum_grad[j1 * 3 + 1] += sum_grad[j2 * 3 + 1];
sum_grad[j1 * 3 + 2] += sum_grad[j2 * 3 + 2];
}
}
if (threadIdx.x == 0)
{
grad2[(i * m + k) * 3 + 0] = sum_grad[0] * grad_cost[i];
grad2[(i * m + k) * 3 + 1] = sum_grad[1] * grad_cost[i];
grad2[(i * m + k) * 3 + 2] = sum_grad[2] * grad_cost[i];
}
__syncthreads();
}
}
}
/********************************
* matchcostgrad1 kernel
*********************************/
template <typename scalar_t>
__global__ void matchcostgrad1(int b, int n, int m, const scalar_t *__restrict__ grad_cost, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ grad1)
{
for (int i = blockIdx.x; i < b; i += gridDim.x)
{
for (int l = threadIdx.x; l < n; l += blockDim.x)
{
scalar_t x1 = xyz1[i * n * 3 + l * 3 + 0];
scalar_t y1 = xyz1[i * n * 3 + l * 3 + 1];
scalar_t z1 = xyz1[i * n * 3 + l * 3 + 2];
scalar_t dx = 0, dy = 0, dz = 0;
for (int k = 0; k < m; k++)
{
scalar_t x2 = xyz2[i * m * 3 + k * 3 + 0];
scalar_t y2 = xyz2[i * m * 3 + k * 3 + 1];
scalar_t z2 = xyz2[i * m * 3 + k * 3 + 2];
scalar_t d = match[i * n * m + k * n + l] * 2;
dx += (x1 - x2) * d;
dy += (y1 - y2) * d;
dz += (z1 - z2) * d;
}
grad1[i * n * 3 + l * 3 + 0] = dx * grad_cost[i];
grad1[i * n * 3 + l * 3 + 1] = dy * grad_cost[i];
grad1[i * n * 3 + l * 3 + 2] = dz * grad_cost[i];
}
}
}
/* MatchCost backward interface
Input:
grad_cost: (B) # gradients on cost
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
grad1: (B, N1, 3)
grad2: (B, N2, 3)
*/
std::vector<at::Tensor> MatchCostBackward(
const at::Tensor grad_cost,
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match)
{
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
TORCH_CHECK_EQ(xyz2.size(0), b);
TORCH_CHECK_EQ(xyz1.size(2), 3);
TORCH_CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&]
{
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>()); }));
C10_CUDA_CHECK(cudaGetLastError());
return std::vector<at::Tensor>({grad1, grad2});
}
#endif

torchemd/emd.py Normal file

@@ -0,0 +1,60 @@
import pathlib
import torch
from torch.utils.cpp_extension import load
# get path to where this file is located
REALPATH = pathlib.Path(__file__).parent
# JIT compile CUDA extensions
load(
name="torchemd_cuda",
sources=[
str(REALPATH / file)
for file in [
"cuda/emd.cpp",
"cuda/emd_kernel.cu",
]
],
)
# import compiled CUDA extensions
import torchemd_cuda
class EarthMoverDistanceFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2) -> torch.Tensor:
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
assert xyz1.is_cuda and xyz2.is_cuda, "Only CUDA tensors are supported currently."
match = torchemd_cuda.approxmatch_forward(xyz1, xyz2)
cost = torchemd_cuda.matchcost_forward(xyz1, xyz2, match)
ctx.save_for_backward(xyz1, xyz2, match)
return cost
@staticmethod
def backward(ctx, grad_cost) -> tuple[torch.Tensor, torch.Tensor]:
xyz1, xyz2, match = ctx.saved_tensors
grad_cost = grad_cost.contiguous()
grad_xyz1, grad_xyz2 = torchemd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
return grad_xyz1, grad_xyz2
def earth_mover_distance(xyz1, xyz2, transpose=True) -> torch.Tensor:
"""Earth Mover Distance (Approx)
Args:
xyz1 (torch.Tensor): (B, N, 3) or (N, 3); (B, 3, N) if transpose is True
xyz2 (torch.Tensor): (B, M, 3) or (M, 3); (B, 3, M) if transpose is True
transpose (bool, optional): If True, xyz1 and xyz2 are assumed to be in BCN format
and are transposed to the BNC format that the extension expects.
Returns:
cost (torch.Tensor): (B,)
"""
if xyz1.dim() == 2:
xyz1 = xyz1.unsqueeze(0)
if xyz2.dim() == 2:
xyz2 = xyz2.unsqueeze(0)
if transpose:
xyz1 = xyz1.transpose(1, 2)
xyz2 = xyz2.transpose(1, 2)
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
return cost
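
Putting the package together, a minimal usage sketch of the new torchemd API, assuming a working CUDA toolchain (the import triggers the JIT build above):

import torch
from torchemd import earth_mover_distance

pc1 = torch.rand(4, 1024, 3, device="cuda", requires_grad=True)
pc2 = torch.rand(4, 1024, 3, device="cuda", requires_grad=True)
cost = earth_mover_distance(pc1, pc2, transpose=False)  # shape (4,)
cost.sum().backward()  # gradients reach pc1 and pc2 via matchcost_backward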