Compare commits
No commits in common. "2b5ac1c436aec81893de415ec95b49ff9353a211" and "7e145fd312b211e26eafffb7e69b9d2e6d043710" have entirely different histories.
2b5ac1c436
...
7e145fd312
|
@ -1,15 +0,0 @@
|
|||
# EditorConfig is awesome: https://EditorConfig.org
|
||||
|
||||
# top-most EditorConfig file
|
||||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.{json,toml,yaml,yml}]
|
||||
indent_size = 2
|
75
.gitattributes
vendored
75
.gitattributes
vendored
|
@ -1,75 +0,0 @@
|
|||
####################################################################################################
|
||||
# Python .gitattributes template
|
||||
# https://github.com/alexkaratarakis/gitattributes/blob/master/Python.gitattributes
|
||||
####################################################################################################
|
||||
|
||||
# Source files
|
||||
# ============
|
||||
*.pxd text diff=python
|
||||
*.py text diff=python
|
||||
*.py3 text diff=python
|
||||
*.pyw text diff=python
|
||||
*.pyx text diff=python
|
||||
*.pyz text diff=python
|
||||
*.pyi text diff=python
|
||||
|
||||
# Binary files
|
||||
# ============
|
||||
*.db binary
|
||||
*.p binary
|
||||
*.pkl binary
|
||||
*.pickle binary
|
||||
*.pyc binary export-ignore
|
||||
*.pyo binary export-ignore
|
||||
*.pyd binary
|
||||
|
||||
# Jupyter notebook
|
||||
*.ipynb text
|
||||
|
||||
# Note: .db, .p, and .pkl files are associated
|
||||
# with the python modules ``pickle``, ``dbm.*``,
|
||||
# ``shelve``, ``marshal``, ``anydbm``, & ``bsddb``
|
||||
# (among others).
|
||||
|
||||
####################################################################################################
|
||||
# C++ .gitattributes template
|
||||
# https://github.com/alexkaratarakis/gitattributes/blob/master/C%2B%2B.gitattributes
|
||||
####################################################################################################
|
||||
|
||||
# Sources
|
||||
*.c text diff=cpp
|
||||
*.cc text diff=cpp
|
||||
*.cxx text diff=cpp
|
||||
*.cpp text diff=cpp
|
||||
*.cpi text diff=cpp
|
||||
*.c++ text diff=cpp
|
||||
*.hpp text diff=cpp
|
||||
*.h text diff=cpp
|
||||
*.h++ text diff=cpp
|
||||
*.hh text diff=cpp
|
||||
|
||||
# Compiled Object files
|
||||
*.slo binary
|
||||
*.lo binary
|
||||
*.o binary
|
||||
*.obj binary
|
||||
|
||||
# Precompiled Headers
|
||||
*.gch binary
|
||||
*.pch binary
|
||||
|
||||
# Compiled Dynamic libraries
|
||||
*.so binary
|
||||
*.dylib binary
|
||||
*.dll binary
|
||||
|
||||
# Compiled Static libraries
|
||||
*.lai binary
|
||||
*.la binary
|
||||
*.a binary
|
||||
*.lib binary
|
||||
|
||||
# Executables
|
||||
*.exe binary
|
||||
*.out binary
|
||||
*.app binary
|
28
.vscode/settings.json
vendored
28
.vscode/settings.json
vendored
|
@ -1,28 +0,0 @@
|
|||
{
|
||||
// good pratice settings
|
||||
"editor.formatOnSave": true,
|
||||
"python.linting.enabled": true,
|
||||
"python.linting.lintOnSave": true,
|
||||
"python.linting.mypyEnabled": true,
|
||||
"files.insertFinalNewline": true,
|
||||
"python.analysis.typeCheckingMode": "basic", // get ready to be annoyed
|
||||
"python.formatting.provider": "black",
|
||||
"[python]": {
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": true // isort
|
||||
}
|
||||
},
|
||||
"files.exclude": {
|
||||
// defaults
|
||||
"**/.git": true,
|
||||
"**/.svn": true,
|
||||
"**/.hg": true,
|
||||
"**/CVS": true,
|
||||
"**/.DS_Store": true,
|
||||
"**/Thumbs.db": true,
|
||||
// annoying
|
||||
"**/__pycache__": true,
|
||||
"**/.mypy_cache": true,
|
||||
"**/.pytest_cache": true,
|
||||
}
|
||||
}
|
24
README.md
24
README.md
|
@ -1,11 +1,31 @@
|
|||
# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)
|
||||
|
||||
## Dependency
|
||||
|
||||
The code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0.
|
||||
|
||||
## Usage
|
||||
|
||||
First compile using
|
||||
|
||||
python setup.py install
|
||||
|
||||
Then, copy the lib file out to the main directory,
|
||||
|
||||
cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so .
|
||||
|
||||
Then, you can use it by simply
|
||||
|
||||
from emd import earth_mover_distance
|
||||
d = earth_mover_distance(p1, p2, transpose=False) # p1: B x N1 x 3, p2: B x N2 x 3
|
||||
|
||||
Check `test_emd_loss.py` for example.
|
||||
|
||||
## Author
|
||||
|
||||
The cuda code is originally written by Haoqiang Fan.
|
||||
The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps.
|
||||
The cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps.
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
|
|
0
__init__.py
Executable file
0
__init__.py
Executable file
|
@ -4,7 +4,7 @@
|
|||
#include <vector>
|
||||
#include <torch/extension.h>
|
||||
|
||||
// CUDA declarations
|
||||
//CUDA declarations
|
||||
at::Tensor ApproxMatchForward(
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2);
|
||||
|
@ -20,11 +20,10 @@ std::vector<at::Tensor> MatchCostBackward(
|
|||
const at::Tensor xyz2,
|
||||
const at::Tensor match);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("approxmatch_forward", &ApproxMatchForward, "ApproxMatch forward (CUDA)");
|
||||
m.def("matchcost_forward", &MatchCostForward, "MatchCost forward (CUDA)");
|
||||
m.def("matchcost_backward", &MatchCostBackward, "MatchCost backward (CUDA)");
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)");
|
||||
m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)");
|
||||
m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)");
|
||||
}
|
||||
|
||||
#endif
|
399
cuda/emd_kernel.cu
Normal file
399
cuda/emd_kernel.cu
Normal file
|
@ -0,0 +1,399 @@
|
|||
/**********************************
|
||||
* Original Author: Haoqiang Fan
|
||||
* Modified by: Kaichun Mo
|
||||
*********************************/
|
||||
|
||||
#ifndef _EMD_KERNEL
|
||||
#define _EMD_KERNEL
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
|
||||
/********************************
|
||||
* Forward kernel for approxmatch
|
||||
*********************************/
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
|
||||
scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
|
||||
scalar_t multiL,multiR;
|
||||
if (n>=m){
|
||||
multiL=1;
|
||||
multiR=n/m;
|
||||
}else{
|
||||
multiL=m/n;
|
||||
multiR=1;
|
||||
}
|
||||
const int Block=1024;
|
||||
__shared__ scalar_t buf[Block*4];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
|
||||
match[i*n*m+j]=0;
|
||||
for (int j=threadIdx.x;j<n;j+=blockDim.x)
|
||||
remainL[j]=multiL;
|
||||
for (int j=threadIdx.x;j<m;j+=blockDim.x)
|
||||
remainR[j]=multiR;
|
||||
__syncthreads();
|
||||
for (int j=7;j>=-2;j--){
|
||||
scalar_t level=-powf(4.0f,j);
|
||||
if (j==-2){
|
||||
level=0;
|
||||
}
|
||||
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||
int k=k0+threadIdx.x;
|
||||
scalar_t x1=0,y1=0,z1=0;
|
||||
if (k<n){
|
||||
x1=xyz1[i*n*3+k*3+0];
|
||||
y1=xyz1[i*n*3+k*3+1];
|
||||
z1=xyz1[i*n*3+k*3+2];
|
||||
}
|
||||
scalar_t suml=1e-9f;
|
||||
for (int l0=0;l0<m;l0+=Block){
|
||||
int lend=min(m,l0+Block)-l0;
|
||||
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
|
||||
scalar_t x2=xyz2[i*m*3+l0*3+l*3+0];
|
||||
scalar_t y2=xyz2[i*m*3+l0*3+l*3+1];
|
||||
scalar_t z2=xyz2[i*m*3+l0*3+l*3+2];
|
||||
buf[l*4+0]=x2;
|
||||
buf[l*4+1]=y2;
|
||||
buf[l*4+2]=z2;
|
||||
buf[l*4+3]=remainR[l0+l];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int l=0;l<lend;l++){
|
||||
scalar_t x2=buf[l*4+0];
|
||||
scalar_t y2=buf[l*4+1];
|
||||
scalar_t z2=buf[l*4+2];
|
||||
scalar_t d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
|
||||
scalar_t w=__expf(d)*buf[l*4+3];
|
||||
suml+=w;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (k<n)
|
||||
ratioL[k]=remainL[k]/suml;
|
||||
}
|
||||
__syncthreads();
|
||||
for (int l0=0;l0<m;l0+=blockDim.x){
|
||||
int l=l0+threadIdx.x;
|
||||
scalar_t x2=0,y2=0,z2=0;
|
||||
if (l<m){
|
||||
x2=xyz2[i*m*3+l*3+0];
|
||||
y2=xyz2[i*m*3+l*3+1];
|
||||
z2=xyz2[i*m*3+l*3+2];
|
||||
}
|
||||
scalar_t sumr=0;
|
||||
for (int k0=0;k0<n;k0+=Block){
|
||||
int kend=min(n,k0+Block)-k0;
|
||||
for (int k=threadIdx.x;k<kend;k+=blockDim.x){
|
||||
buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
|
||||
buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
|
||||
buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
|
||||
buf[k*4+3]=ratioL[k0+k];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int k=0;k<kend;k++){
|
||||
scalar_t x1=buf[k*4+0];
|
||||
scalar_t y1=buf[k*4+1];
|
||||
scalar_t z1=buf[k*4+2];
|
||||
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
|
||||
sumr+=w;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (l<m){
|
||||
sumr*=remainR[l];
|
||||
scalar_t consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
|
||||
ratioR[l]=consumption*remainR[l];
|
||||
remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||
int k=k0+threadIdx.x;
|
||||
scalar_t x1=0,y1=0,z1=0;
|
||||
if (k<n){
|
||||
x1=xyz1[i*n*3+k*3+0];
|
||||
y1=xyz1[i*n*3+k*3+1];
|
||||
z1=xyz1[i*n*3+k*3+2];
|
||||
}
|
||||
scalar_t suml=0;
|
||||
for (int l0=0;l0<m;l0+=Block){
|
||||
int lend=min(m,l0+Block)-l0;
|
||||
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
|
||||
buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
|
||||
buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
|
||||
buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
|
||||
buf[l*4+3]=ratioR[l0+l];
|
||||
}
|
||||
__syncthreads();
|
||||
scalar_t rl=ratioL[k];
|
||||
if (k<n){
|
||||
for (int l=0;l<lend;l++){
|
||||
scalar_t x2=buf[l*4+0];
|
||||
scalar_t y2=buf[l*4+1];
|
||||
scalar_t z2=buf[l*4+2];
|
||||
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
|
||||
match[i*n*m+(l0+l)*n+k]+=w;
|
||||
suml+=w;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (k<n)
|
||||
remainL[k]=fmaxf(0.0f,remainL[k]-suml);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){
|
||||
// approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
|
||||
//}
|
||||
|
||||
/* ApproxMatch forward interface
|
||||
Input:
|
||||
xyz1: (B, N1, 3) # dataset_points
|
||||
xyz2: (B, N2, 3) # query_points
|
||||
Output:
|
||||
match: (B, N2, N1)
|
||||
*/
|
||||
at::Tensor ApproxMatchForward(
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2){
|
||||
const auto b = xyz1.size(0);
|
||||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
TORCH_CHECK_EQ (xyz2.size(0), b);
|
||||
TORCH_CHECK_EQ (xyz1.size(2), 3);
|
||||
TORCH_CHECK_EQ (xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
auto match = at::zeros({b, m, n}, xyz1.type());
|
||||
auto temp = at::zeros({b, (n+m)*2}, xyz1.type());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
||||
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
||||
}));
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return match;
|
||||
}
|
||||
|
||||
|
||||
/********************************
|
||||
* Forward kernel for matchcost
|
||||
*********************************/
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
|
||||
__shared__ scalar_t allsum[512];
|
||||
const int Block=1024;
|
||||
__shared__ scalar_t buf[Block*3];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
scalar_t subsum=0;
|
||||
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||
int k=k0+threadIdx.x;
|
||||
scalar_t x1=0,y1=0,z1=0;
|
||||
if (k<n){
|
||||
x1=xyz1[i*n*3+k*3+0];
|
||||
y1=xyz1[i*n*3+k*3+1];
|
||||
z1=xyz1[i*n*3+k*3+2];
|
||||
}
|
||||
for (int l0=0;l0<m;l0+=Block){
|
||||
int lend=min(m,l0+Block)-l0;
|
||||
for (int l=threadIdx.x;l<lend*3;l+=blockDim.x)
|
||||
buf[l]=xyz2[i*m*3+l0*3+l];
|
||||
__syncthreads();
|
||||
if (k<n){
|
||||
for (int l=0;l<lend;l++){
|
||||
scalar_t x2=buf[l*3+0];
|
||||
scalar_t y2=buf[l*3+1];
|
||||
scalar_t z2=buf[l*3+2];
|
||||
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
|
||||
subsum+=d*match[i*n*m+(l0+l)*n+k];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
allsum[threadIdx.x]=subsum;
|
||||
for (int j=1;j<blockDim.x;j<<=1){
|
||||
__syncthreads();
|
||||
if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){
|
||||
allsum[threadIdx.x]+=allsum[threadIdx.x+j];
|
||||
}
|
||||
}
|
||||
if (threadIdx.x==0)
|
||||
out[i]=allsum[0];
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){
|
||||
// matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
|
||||
//}
|
||||
|
||||
/* MatchCost forward interface
|
||||
Input:
|
||||
xyz1: (B, N1, 3) # dataset_points
|
||||
xyz2: (B, N2, 3) # query_points
|
||||
match: (B, N2, N1)
|
||||
Output:
|
||||
cost: (B)
|
||||
*/
|
||||
at::Tensor MatchCostForward(
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2,
|
||||
const at::Tensor match){
|
||||
const auto b = xyz1.size(0);
|
||||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
TORCH_CHECK_EQ (xyz2.size(0), b);
|
||||
TORCH_CHECK_EQ (xyz1.size(2), 3);
|
||||
TORCH_CHECK_EQ (xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
auto cost = at::zeros({b}, xyz1.type());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
||||
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
||||
}));
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
/********************************
|
||||
* matchcostgrad2 kernel
|
||||
*********************************/
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
|
||||
__shared__ scalar_t sum_grad[256*3];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
int kbeg=m*blockIdx.y/gridDim.y;
|
||||
int kend=m*(blockIdx.y+1)/gridDim.y;
|
||||
for (int k=kbeg;k<kend;k++){
|
||||
scalar_t x2=xyz2[(i*m+k)*3+0];
|
||||
scalar_t y2=xyz2[(i*m+k)*3+1];
|
||||
scalar_t z2=xyz2[(i*m+k)*3+2];
|
||||
scalar_t subsumx=0,subsumy=0,subsumz=0;
|
||||
for (int j=threadIdx.x;j<n;j+=blockDim.x){
|
||||
scalar_t x1=x2-xyz1[(i*n+j)*3+0];
|
||||
scalar_t y1=y2-xyz1[(i*n+j)*3+1];
|
||||
scalar_t z1=z2-xyz1[(i*n+j)*3+2];
|
||||
scalar_t d=match[i*n*m+k*n+j]*2;
|
||||
subsumx+=x1*d;
|
||||
subsumy+=y1*d;
|
||||
subsumz+=z1*d;
|
||||
}
|
||||
sum_grad[threadIdx.x*3+0]=subsumx;
|
||||
sum_grad[threadIdx.x*3+1]=subsumy;
|
||||
sum_grad[threadIdx.x*3+2]=subsumz;
|
||||
for (int j=1;j<blockDim.x;j<<=1){
|
||||
__syncthreads();
|
||||
int j1=threadIdx.x;
|
||||
int j2=threadIdx.x+j;
|
||||
if ((j1&j)==0 && j2<blockDim.x){
|
||||
sum_grad[j1*3+0]+=sum_grad[j2*3+0];
|
||||
sum_grad[j1*3+1]+=sum_grad[j2*3+1];
|
||||
sum_grad[j1*3+2]+=sum_grad[j2*3+2];
|
||||
}
|
||||
}
|
||||
if (threadIdx.x==0){
|
||||
grad2[(i*m+k)*3+0]=sum_grad[0]*grad_cost[i];
|
||||
grad2[(i*m+k)*3+1]=sum_grad[1]*grad_cost[i];
|
||||
grad2[(i*m+k)*3+2]=sum_grad[2]*grad_cost[i];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/********************************
|
||||
* matchcostgrad1 kernel
|
||||
*********************************/
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int l=threadIdx.x;l<n;l+=blockDim.x){
|
||||
scalar_t x1=xyz1[i*n*3+l*3+0];
|
||||
scalar_t y1=xyz1[i*n*3+l*3+1];
|
||||
scalar_t z1=xyz1[i*n*3+l*3+2];
|
||||
scalar_t dx=0,dy=0,dz=0;
|
||||
for (int k=0;k<m;k++){
|
||||
scalar_t x2=xyz2[i*m*3+k*3+0];
|
||||
scalar_t y2=xyz2[i*m*3+k*3+1];
|
||||
scalar_t z2=xyz2[i*m*3+k*3+2];
|
||||
scalar_t d=match[i*n*m+k*n+l]*2;
|
||||
dx+=(x1-x2)*d;
|
||||
dy+=(y1-y2)*d;
|
||||
dz+=(z1-z2)*d;
|
||||
}
|
||||
grad1[i*n*3+l*3+0]=dx*grad_cost[i];
|
||||
grad1[i*n*3+l*3+1]=dy*grad_cost[i];
|
||||
grad1[i*n*3+l*3+2]=dz*grad_cost[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){
|
||||
// matchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);
|
||||
// matchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
|
||||
//}
|
||||
|
||||
|
||||
/* MatchCost backward interface
|
||||
Input:
|
||||
grad_cost: (B) # gradients on cost
|
||||
xyz1: (B, N1, 3) # dataset_points
|
||||
xyz2: (B, N2, 3) # query_points
|
||||
match: (B, N2, N1)
|
||||
Output:
|
||||
grad1: (B, N1, 3)
|
||||
grad2: (B, N2, 3)
|
||||
*/
|
||||
std::vector<at::Tensor> MatchCostBackward(
|
||||
const at::Tensor grad_cost,
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2,
|
||||
const at::Tensor match){
|
||||
const auto b = xyz1.size(0);
|
||||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
TORCH_CHECK_EQ (xyz2.size(0), b);
|
||||
TORCH_CHECK_EQ (xyz1.size(2), 3);
|
||||
TORCH_CHECK_EQ (xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
|
||||
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
|
||||
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
||||
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
||||
}));
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return std::vector<at::Tensor>({grad1, grad2});
|
||||
}
|
||||
|
||||
#endif
|
46
emd.py
Executable file
46
emd.py
Executable file
|
@ -0,0 +1,46 @@
|
|||
import torch
|
||||
import emd_cuda
|
||||
|
||||
|
||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
xyz1 = xyz1.contiguous()
|
||||
xyz2 = xyz2.contiguous()
|
||||
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
|
||||
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
|
||||
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
|
||||
ctx.save_for_backward(xyz1, xyz2, match)
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_cost):
|
||||
xyz1, xyz2, match = ctx.saved_tensors
|
||||
grad_cost = grad_cost.contiguous()
|
||||
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
|
||||
return grad_xyz1, grad_xyz2
|
||||
|
||||
|
||||
def earth_mover_distance(xyz1, xyz2, transpose=True):
|
||||
"""Earth Mover Distance (Approx)
|
||||
|
||||
Args:
|
||||
xyz1 (torch.Tensor): (b, 3, n1)
|
||||
xyz2 (torch.Tensor): (b, 3, n1)
|
||||
transpose (bool): whether to transpose inputs as it might be BCN format.
|
||||
Extensions only support BNC format.
|
||||
|
||||
Returns:
|
||||
cost (torch.Tensor): (b)
|
||||
|
||||
"""
|
||||
if xyz1.dim() == 2:
|
||||
xyz1 = xyz1.unsqueeze(0)
|
||||
if xyz2.dim() == 2:
|
||||
xyz2 = xyz2.unsqueeze(0)
|
||||
if transpose:
|
||||
xyz1 = xyz1.transpose(1, 2)
|
||||
xyz2 = xyz2.transpose(1, 2)
|
||||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
||||
return cost
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
[project]
|
||||
name = "torchemd"
|
||||
version = "1.0.0"
|
||||
authors = [{ name = "Laurent Fainsin", email = "laurent@fainsin.bzh" }]
|
||||
description = "PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
|
||||
[build-system]
|
||||
requires = ["flit_core>=3.4"]
|
||||
build-backend = "flit_core.buildapi"
|
||||
|
||||
[tool.ruff]
|
||||
ignore-init-module-imports = true
|
||||
select = ["E", "F", "I"]
|
||||
line-length = 120
|
||||
|
||||
[tool.black]
|
||||
exclude = '''
|
||||
/(
|
||||
\.git
|
||||
\.venv
|
||||
)/
|
||||
'''
|
||||
include = '\.pyi?$'
|
||||
line-length = 120
|
||||
target-version = ["py310"]
|
||||
|
||||
[tool.isort]
|
||||
multi_line_output = 3
|
||||
profile = "black"
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
27
setup.py
Executable file
27
setup.py
Executable file
|
@ -0,0 +1,27 @@
|
|||
"""Setup extension
|
||||
|
||||
Notes:
|
||||
If extra_compile_args is provided, you need to provide different instances for different extensions.
|
||||
Refer to https://github.com/pytorch/pytorch/issues/20169
|
||||
|
||||
"""
|
||||
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
|
||||
setup(
|
||||
name='emd_ext',
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name='emd_cuda',
|
||||
sources=[
|
||||
'cuda/emd.cpp',
|
||||
'cuda/emd_kernel.cu',
|
||||
],
|
||||
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
44
test_emd_loss.py
Normal file
44
test_emd_loss.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
from emd import earth_mover_distance
|
||||
|
||||
# gt
|
||||
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p1 = p1.repeat(3, 1, 1)
|
||||
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p2 = p2.repeat(3, 1, 1)
|
||||
print(p1)
|
||||
print(p2)
|
||||
p1.requires_grad = True
|
||||
p2.requires_grad = True
|
||||
|
||||
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
|
||||
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
|
||||
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
|
||||
print('gt_dist: ', gt_dist)
|
||||
|
||||
gt_dist.backward()
|
||||
print(p1.grad)
|
||||
print(p2.grad)
|
||||
|
||||
# emd
|
||||
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p1 = p1.repeat(3, 1, 1)
|
||||
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p2 = p2.repeat(3, 1, 1)
|
||||
print(p1)
|
||||
print(p2)
|
||||
p1.requires_grad = True
|
||||
p2.requires_grad = True
|
||||
|
||||
d = earth_mover_distance(p1, p2, transpose=False)
|
||||
print(d)
|
||||
|
||||
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
|
||||
print(loss)
|
||||
|
||||
loss.backward()
|
||||
print(p1.grad)
|
||||
print(p2.grad)
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
import torch
|
||||
from einops import repeat
|
||||
from torch import Tensor
|
||||
|
||||
from torchemd import earth_mover_distance
|
||||
|
||||
|
||||
def generate_pointclouds() -> tuple[Tensor, Tensor]:
|
||||
# create first point cloud
|
||||
pc1 = torch.Tensor(
|
||||
[
|
||||
[1.7, -0.1, 0.1],
|
||||
[0.1, 1.2, 0.3],
|
||||
],
|
||||
).cuda()
|
||||
pc1 = repeat(pc1, "n c -> b n c", b=3)
|
||||
pc1.requires_grad = True
|
||||
|
||||
# create second point cloud
|
||||
pc2 = torch.Tensor(
|
||||
[
|
||||
[0.4, 1.8, 0.2],
|
||||
[1.2, -0.2, 0.3],
|
||||
],
|
||||
).cuda()
|
||||
pc2 = repeat(pc2, "n c -> b n c", b=3)
|
||||
pc2.requires_grad = True
|
||||
|
||||
return pc1, pc2
|
||||
|
||||
|
||||
def dummy_loss(distance: Tensor) -> Tensor:
|
||||
return distance[0] / 2 + distance[1] * 2 + distance[2] / 3
|
||||
|
||||
|
||||
def manual_distance_computation(pc1: Tensor, pc2: Tensor) -> Tensor:
|
||||
return (pc1[:, 0] - pc2[:, 1]).pow(2).sum(dim=1) + (pc1[:, 1] - pc2[:, 0]).pow(2).sum(dim=1)
|
||||
|
||||
|
||||
def test_emd():
|
||||
# compute earth mover distance directly from formula
|
||||
pc1, pc2 = generate_pointclouds()
|
||||
ground_truth_distance = manual_distance_computation(pc1, pc2)
|
||||
ground_truth_loss = dummy_loss(ground_truth_distance)
|
||||
# get gradients of point clouds
|
||||
ground_truth_loss.backward()
|
||||
pc1_gt_grad = pc1.grad
|
||||
pc2_gt_grad = pc2.grad
|
||||
|
||||
# compute earth mover distance directly from implementation
|
||||
pc1, pc2 = generate_pointclouds()
|
||||
computed_distance = earth_mover_distance(pc1, pc2, transpose=False)
|
||||
loss = dummy_loss(computed_distance)
|
||||
# get gradients of point clouds
|
||||
loss.backward()
|
||||
pc1_grad = pc1.grad
|
||||
pc2_grad = pc2.grad
|
||||
|
||||
# compare gradients
|
||||
assert pc1_grad.allclose(pc1_gt_grad)
|
||||
assert pc2_grad.allclose(pc2_gt_grad)
|
||||
|
||||
# compare distances
|
||||
assert computed_distance.allclose(ground_truth_distance)
|
||||
|
||||
# compare loss
|
||||
assert loss.allclose(ground_truth_loss)
|
||||
|
||||
|
||||
def test_equality():
|
||||
pc1, _ = generate_pointclouds()
|
||||
distance = earth_mover_distance(pc1, pc1, transpose=False)
|
||||
assert distance.allclose(torch.zeros_like(distance))
|
||||
|
||||
|
||||
def test_symetry():
|
||||
pc1, pc2 = generate_pointclouds()
|
||||
distance1 = earth_mover_distance(pc1, pc2, transpose=False)
|
||||
distance2 = earth_mover_distance(pc2, pc1, transpose=False)
|
||||
assert distance1.allclose(distance2)
|
|
@ -1 +0,0 @@
|
|||
from .emd import EarthMoverDistanceFunction, earth_mover_distance
|
|
@ -1,433 +0,0 @@
|
|||
/**********************************
|
||||
* Original Author: Haoqiang Fan
|
||||
* Modified by: Kaichun Mo
|
||||
*********************************/
|
||||
|
||||
#ifndef _EMD_KERNEL
|
||||
#define _EMD_KERNEL
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
/********************************
|
||||
* Forward kernel for approxmatch
|
||||
*********************************/
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void approxmatch(int b, int n, int m, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, scalar_t *__restrict__ match, scalar_t *temp)
|
||||
{
|
||||
scalar_t *remainL = temp + blockIdx.x * (n + m) * 2;
|
||||
scalar_t *remainR = temp + blockIdx.x * (n + m) * 2 + n;
|
||||
scalar_t *ratioL = temp + blockIdx.x * (n + m) * 2 + n + m;
|
||||
scalar_t *ratioR = temp + blockIdx.x * (n + m) * 2 + n + m + n;
|
||||
scalar_t multiL, multiR;
|
||||
|
||||
if (n >= m)
|
||||
{
|
||||
multiL = 1;
|
||||
multiR = n / m;
|
||||
}
|
||||
else
|
||||
{
|
||||
multiL = m / n;
|
||||
multiR = 1;
|
||||
}
|
||||
const int Block = 1024;
|
||||
__shared__ scalar_t buf[Block * 4];
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x)
|
||||
{
|
||||
for (int j = threadIdx.x; j < n * m; j += blockDim.x)
|
||||
match[i * n * m + j] = 0;
|
||||
for (int j = threadIdx.x; j < n; j += blockDim.x)
|
||||
remainL[j] = multiL;
|
||||
for (int j = threadIdx.x; j < m; j += blockDim.x)
|
||||
remainR[j] = multiR;
|
||||
__syncthreads();
|
||||
for (int j = 7; j >= -2; j--)
|
||||
{
|
||||
scalar_t level = -powf(4.0f, j);
|
||||
if (j == -2)
|
||||
{
|
||||
level = 0;
|
||||
}
|
||||
for (int k0 = 0; k0 < n; k0 += blockDim.x)
|
||||
{
|
||||
int k = k0 + threadIdx.x;
|
||||
scalar_t x1 = 0, y1 = 0, z1 = 0;
|
||||
if (k < n)
|
||||
{
|
||||
x1 = xyz1[i * n * 3 + k * 3 + 0];
|
||||
y1 = xyz1[i * n * 3 + k * 3 + 1];
|
||||
z1 = xyz1[i * n * 3 + k * 3 + 2];
|
||||
}
|
||||
scalar_t suml = 1e-9f;
|
||||
for (int l0 = 0; l0 < m; l0 += Block)
|
||||
{
|
||||
int lend = min(m, l0 + Block) - l0;
|
||||
for (int l = threadIdx.x; l < lend; l += blockDim.x)
|
||||
{
|
||||
scalar_t x2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
|
||||
scalar_t y2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
|
||||
scalar_t z2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
|
||||
buf[l * 4 + 0] = x2;
|
||||
buf[l * 4 + 1] = y2;
|
||||
buf[l * 4 + 2] = z2;
|
||||
buf[l * 4 + 3] = remainR[l0 + l];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int l = 0; l < lend; l++)
|
||||
{
|
||||
scalar_t x2 = buf[l * 4 + 0];
|
||||
scalar_t y2 = buf[l * 4 + 1];
|
||||
scalar_t z2 = buf[l * 4 + 2];
|
||||
scalar_t d = level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1));
|
||||
scalar_t w = __expf(d) * buf[l * 4 + 3];
|
||||
suml += w;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (k < n)
|
||||
ratioL[k] = remainL[k] / suml;
|
||||
}
|
||||
__syncthreads();
|
||||
for (int l0 = 0; l0 < m; l0 += blockDim.x)
|
||||
{
|
||||
int l = l0 + threadIdx.x;
|
||||
scalar_t x2 = 0, y2 = 0, z2 = 0;
|
||||
if (l < m)
|
||||
{
|
||||
x2 = xyz2[i * m * 3 + l * 3 + 0];
|
||||
y2 = xyz2[i * m * 3 + l * 3 + 1];
|
||||
z2 = xyz2[i * m * 3 + l * 3 + 2];
|
||||
}
|
||||
scalar_t sumr = 0;
|
||||
for (int k0 = 0; k0 < n; k0 += Block)
|
||||
{
|
||||
int kend = min(n, k0 + Block) - k0;
|
||||
for (int k = threadIdx.x; k < kend; k += blockDim.x)
|
||||
{
|
||||
buf[k * 4 + 0] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 0];
|
||||
buf[k * 4 + 1] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 1];
|
||||
buf[k * 4 + 2] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 2];
|
||||
buf[k * 4 + 3] = ratioL[k0 + k];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int k = 0; k < kend; k++)
|
||||
{
|
||||
scalar_t x1 = buf[k * 4 + 0];
|
||||
scalar_t y1 = buf[k * 4 + 1];
|
||||
scalar_t z1 = buf[k * 4 + 2];
|
||||
scalar_t w = __expf(level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1))) * buf[k * 4 + 3];
|
||||
sumr += w;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (l < m)
|
||||
{
|
||||
sumr *= remainR[l];
|
||||
scalar_t consumption = fminf(remainR[l] / (sumr + 1e-9f), 1.0f);
|
||||
ratioR[l] = consumption * remainR[l];
|
||||
remainR[l] = fmaxf(0.0f, remainR[l] - sumr);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
for (int k0 = 0; k0 < n; k0 += blockDim.x)
|
||||
{
|
||||
int k = k0 + threadIdx.x;
|
||||
scalar_t x1 = 0, y1 = 0, z1 = 0;
|
||||
if (k < n)
|
||||
{
|
||||
x1 = xyz1[i * n * 3 + k * 3 + 0];
|
||||
y1 = xyz1[i * n * 3 + k * 3 + 1];
|
||||
z1 = xyz1[i * n * 3 + k * 3 + 2];
|
||||
}
|
||||
scalar_t suml = 0;
|
||||
for (int l0 = 0; l0 < m; l0 += Block)
|
||||
{
|
||||
int lend = min(m, l0 + Block) - l0;
|
||||
for (int l = threadIdx.x; l < lend; l += blockDim.x)
|
||||
{
|
||||
buf[l * 4 + 0] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
|
||||
buf[l * 4 + 1] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
|
||||
buf[l * 4 + 2] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
|
||||
buf[l * 4 + 3] = ratioR[l0 + l];
|
||||
}
|
||||
__syncthreads();
|
||||
scalar_t rl = ratioL[k];
|
||||
if (k < n)
|
||||
{
|
||||
for (int l = 0; l < lend; l++)
|
||||
{
|
||||
scalar_t x2 = buf[l * 4 + 0];
|
||||
scalar_t y2 = buf[l * 4 + 1];
|
||||
scalar_t z2 = buf[l * 4 + 2];
|
||||
scalar_t w = __expf(level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1))) * rl * buf[l * 4 + 3];
|
||||
match[i * n * m + (l0 + l) * n + k] += w;
|
||||
suml += w;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (k < n)
|
||||
remainL[k] = fmaxf(0.0f, remainL[k] - suml);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ApproxMatch forward interface
Input:
  xyz1: (B, N1, 3)  # dataset_points
  xyz2: (B, N2, 3)  # query_points
Output:
  match: (B, N2, N1)
*/
at::Tensor ApproxMatchForward(
    const at::Tensor xyz1,
    const at::Tensor xyz2)
{
    const auto b = xyz1.size(0);
    const auto n = xyz1.size(1);
    const auto m = xyz2.size(1);

    TORCH_CHECK_EQ(xyz2.size(0), b);
    TORCH_CHECK_EQ(xyz1.size(2), 3);
    TORCH_CHECK_EQ(xyz2.size(2), 3);
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);

    // match[i, l, k]: mass matched between xyz2[i, l] and xyz1[i, k].
    auto match = at::zeros({b, m, n}, xyz1.options());
    // Kernel scratch space. NOTE(review): sized (n + m) * 2 scalars per batch,
    // which matches the kernel's remainL/ratioL (n) and remainR/ratioR (m)
    // buffers — confirm against the full approxmatch kernel's temp layout.
    auto temp = at::zeros({b, (n + m) * 2}, xyz1.options());

    // data_ptr<T>() replaces the deprecated data<T>() accessor; options()
    // replaces the deprecated type() when constructing tensors.
    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&]
        { approxmatch<scalar_t><<<32, 512>>>(b, n, m, xyz1.data_ptr<scalar_t>(), xyz2.data_ptr<scalar_t>(), match.data_ptr<scalar_t>(), temp.data_ptr<scalar_t>()); }));
    C10_CUDA_CHECK(cudaGetLastError());

    return match;
}
|
||||
|
||||
/********************************
* Forward kernel for matchcost
*********************************/

// Per batch element i, computes the total matching cost
//   out[i] = sum_{k,l} ||xyz1[i,k] - xyz2[i,l]||^2 * match[i,l,k]
// Blocks stride over the batch dimension; within a block, threads partition
// the xyz1 points and tiles of xyz2 are staged through shared memory.
template <typename scalar_t>
__global__ void matchcost(int b, int n, int m, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ out)
{
    // Per-thread partial sums for the final block-wide reduction.
    // NOTE(review): assumes blockDim.x <= 512 (the host launches 512 threads).
    __shared__ scalar_t allsum[512];
    const int Block = 1024;
    // Staging buffer for one tile of up to Block points of xyz2 (x,y,z interleaved).
    __shared__ scalar_t buf[Block * 3];
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        scalar_t subsum = 0;
        // Each thread owns one xyz1 point k per outer iteration.
        for (int k0 = 0; k0 < n; k0 += blockDim.x)
        {
            int k = k0 + threadIdx.x;
            scalar_t x1 = 0, y1 = 0, z1 = 0;
            if (k < n)
            {
                x1 = xyz1[i * n * 3 + k * 3 + 0];
                y1 = xyz1[i * n * 3 + k * 3 + 1];
                z1 = xyz1[i * n * 3 + k * 3 + 2];
            }
            // Tile over xyz2 so each tile is loaded from global memory once per block.
            for (int l0 = 0; l0 < m; l0 += Block)
            {
                int lend = min(m, l0 + Block) - l0;
                // Cooperative load of the tile (lend points, 3 coordinates each).
                for (int l = threadIdx.x; l < lend * 3; l += blockDim.x)
                    buf[l] = xyz2[i * m * 3 + l0 * 3 + l];
                __syncthreads();  // tile fully loaded before anyone reads it
                if (k < n)
                {
                    for (int l = 0; l < lend; l++)
                    {
                        scalar_t x2 = buf[l * 3 + 0];
                        scalar_t y2 = buf[l * 3 + 1];
                        scalar_t z2 = buf[l * 3 + 2];
                        // Squared Euclidean distance, weighted by the match entry.
                        scalar_t d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
                        subsum += d * match[i * n * m + (l0 + l) * n + k];
                    }
                }
                __syncthreads();  // tile fully consumed before it is overwritten
            }
        }
        allsum[threadIdx.x] = subsum;
        // Pairwise tree reduction of the per-thread partial sums into allsum[0].
        for (int j = 1; j < blockDim.x; j <<= 1)
        {
            __syncthreads();
            if ((threadIdx.x & j) == 0 && threadIdx.x + j < blockDim.x)
            {
                allsum[threadIdx.x] += allsum[threadIdx.x + j];
            }
        }
        if (threadIdx.x == 0)
            out[i] = allsum[0];
        __syncthreads();  // keep allsum stable until all threads move on
    }
}
|
||||
|
||||
/* MatchCost forward interface
Input:
  xyz1: (B, N1, 3)  # dataset_points
  xyz2: (B, N2, 3)  # query_points
  match: (B, N2, N1)
Output:
  cost: (B)
*/
at::Tensor MatchCostForward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match)
{
    const auto b = xyz1.size(0);
    const auto n = xyz1.size(1);
    const auto m = xyz2.size(1);

    TORCH_CHECK_EQ(xyz2.size(0), b);
    TORCH_CHECK_EQ(xyz1.size(2), 3);
    TORCH_CHECK_EQ(xyz2.size(2), 3);
    // Also validate match: the kernel indexes it as match[i * n * m + l * n + k],
    // i.e. a contiguous (B, N2, N1) tensor — the original never checked it.
    TORCH_CHECK_EQ(match.size(0), b);
    TORCH_CHECK_EQ(match.size(1), m);
    TORCH_CHECK_EQ(match.size(2), n);
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);
    CHECK_INPUT(match);

    // One scalar cost per batch element.
    auto cost = at::zeros({b}, xyz1.options());

    // data_ptr<T>() replaces the deprecated data<T>() accessor; options()
    // replaces the deprecated type() when constructing tensors.
    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&]
        { matchcost<scalar_t><<<32, 512>>>(b, n, m, xyz1.data_ptr<scalar_t>(), xyz2.data_ptr<scalar_t>(), match.data_ptr<scalar_t>(), cost.data_ptr<scalar_t>()); }));
    C10_CUDA_CHECK(cudaGetLastError());

    return cost;
}
|
||||
|
||||
/********************************
* matchcostgrad2 kernel
*********************************/

// Gradient of the matching cost with respect to xyz2:
//   grad2[i,k] = grad_cost[i] * sum_j 2 * (xyz2[i,k] - xyz1[i,j]) * match[i,k,j]
// blockIdx.x strides over batches; blockIdx.y splits the range of xyz2 points.
template <typename scalar_t>
__global__ void matchcostgrad2(int b, int n, int m, const scalar_t *__restrict__ grad_cost, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ grad2)
{
    // Per-thread partial (x, y, z) sums for the block reduction.
    // NOTE(review): assumes blockDim.x <= 256 (the host launches 256 threads).
    __shared__ scalar_t sum_grad[256 * 3];
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        // Contiguous slice [kbeg, kend) of xyz2 points handled by this blockIdx.y.
        int kbeg = m * blockIdx.y / gridDim.y;
        int kend = m * (blockIdx.y + 1) / gridDim.y;
        for (int k = kbeg; k < kend; k++)
        {
            scalar_t x2 = xyz2[(i * m + k) * 3 + 0];
            scalar_t y2 = xyz2[(i * m + k) * 3 + 1];
            scalar_t z2 = xyz2[(i * m + k) * 3 + 2];
            scalar_t subsumx = 0, subsumy = 0, subsumz = 0;
            // Threads partition the xyz1 points and accumulate partial gradients.
            for (int j = threadIdx.x; j < n; j += blockDim.x)
            {
                // Difference vector (xyz2[k] - xyz1[j]) — the derivative
                // direction of the squared distance w.r.t. xyz2.
                scalar_t x1 = x2 - xyz1[(i * n + j) * 3 + 0];
                scalar_t y1 = y2 - xyz1[(i * n + j) * 3 + 1];
                scalar_t z1 = z2 - xyz1[(i * n + j) * 3 + 2];
                scalar_t d = match[i * n * m + k * n + j] * 2;  // factor 2 from d/dx ||.||^2
                subsumx += x1 * d;
                subsumy += y1 * d;
                subsumz += z1 * d;
            }
            sum_grad[threadIdx.x * 3 + 0] = subsumx;
            sum_grad[threadIdx.x * 3 + 1] = subsumy;
            sum_grad[threadIdx.x * 3 + 2] = subsumz;
            // Pairwise tree reduction over threads; result lands in sum_grad[0..2].
            for (int j = 1; j < blockDim.x; j <<= 1)
            {
                __syncthreads();
                int j1 = threadIdx.x;
                int j2 = threadIdx.x + j;
                if ((j1 & j) == 0 && j2 < blockDim.x)
                {
                    sum_grad[j1 * 3 + 0] += sum_grad[j2 * 3 + 0];
                    sum_grad[j1 * 3 + 1] += sum_grad[j2 * 3 + 1];
                    sum_grad[j1 * 3 + 2] += sum_grad[j2 * 3 + 2];
                }
            }
            if (threadIdx.x == 0)
            {
                // Chain rule: scale by the upstream gradient of the scalar cost.
                grad2[(i * m + k) * 3 + 0] = sum_grad[0] * grad_cost[i];
                grad2[(i * m + k) * 3 + 1] = sum_grad[1] * grad_cost[i];
                grad2[(i * m + k) * 3 + 2] = sum_grad[2] * grad_cost[i];
            }
            __syncthreads();  // keep sum_grad stable until every thread passes the write
        }
    }
}
|
||||
|
||||
/********************************
* matchcostgrad1 kernel
*********************************/

// Gradient of the matching cost with respect to xyz1:
//   grad1[i,j] = grad_cost[i] * sum_k 2 * (xyz1[i,j] - xyz2[i,k]) * match[i,k,j]
// Blocks stride over batches; threads partition the xyz1 points.
template <typename scalar_t>
__global__ void matchcostgrad1(int b, int n, int m, const scalar_t *__restrict__ grad_cost, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ grad1)
{
    for (int batch = blockIdx.x; batch < b; batch += gridDim.x)
    {
        // Hoist per-batch base offsets and the upstream gradient out of the loops.
        const scalar_t *pts1 = xyz1 + batch * n * 3;
        const scalar_t *pts2 = xyz2 + batch * m * 3;
        const scalar_t *mat = match + batch * n * m;
        scalar_t *out = grad1 + batch * n * 3;
        const scalar_t g = grad_cost[batch];
        for (int j = threadIdx.x; j < n; j += blockDim.x)
        {
            const scalar_t px = pts1[j * 3 + 0];
            const scalar_t py = pts1[j * 3 + 1];
            const scalar_t pz = pts1[j * 3 + 2];
            scalar_t gx = 0, gy = 0, gz = 0;
            // Accumulate over every xyz2 point this xyz1 point is matched to.
            for (int q = 0; q < m; q++)
            {
                const scalar_t w = mat[q * n + j] * 2;  // factor 2 from d/dx ||.||^2
                gx += (px - pts2[q * 3 + 0]) * w;
                gy += (py - pts2[q * 3 + 1]) * w;
                gz += (pz - pts2[q * 3 + 2]) * w;
            }
            // Chain rule: scale by the upstream gradient of the scalar cost.
            out[j * 3 + 0] = gx * g;
            out[j * 3 + 1] = gy * g;
            out[j * 3 + 2] = gz * g;
        }
    }
}
|
||||
|
||||
/* MatchCost backward interface
Input:
  grad_cost: (B)  # gradients on cost
  xyz1: (B, N1, 3)  # dataset_points
  xyz2: (B, N2, 3)  # query_points
  match: (B, N2, N1)
Output:
  grad1: (B, N1, 3)
  grad2: (B, N2, 3)
*/
std::vector<at::Tensor> MatchCostBackward(
    const at::Tensor grad_cost,
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match)
{
    const auto b = xyz1.size(0);
    const auto n = xyz1.size(1);
    const auto m = xyz2.size(1);

    TORCH_CHECK_EQ(xyz2.size(0), b);
    TORCH_CHECK_EQ(xyz1.size(2), 3);
    TORCH_CHECK_EQ(xyz2.size(2), 3);
    // Also validate the tensors the kernels read but the original never checked:
    // grad_cost must hold one scalar per batch, match must be (B, N2, N1).
    TORCH_CHECK_EQ(grad_cost.size(0), b);
    TORCH_CHECK_EQ(match.size(0), b);
    TORCH_CHECK_EQ(match.size(1), m);
    TORCH_CHECK_EQ(match.size(2), n);
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);
    CHECK_INPUT(grad_cost);
    CHECK_INPUT(match);

    auto grad1 = at::zeros({b, n, 3}, xyz1.options());
    auto grad2 = at::zeros({b, m, 3}, xyz1.options());

    // data_ptr<T>() replaces the deprecated data<T>() accessor; options()
    // replaces the deprecated type() when constructing tensors.
    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&]
        {
            // grad w.r.t. xyz1: one block per batch slice, threads over N1 points.
            matchcostgrad1<scalar_t><<<32, 512>>>(b, n, m, grad_cost.data_ptr<scalar_t>(), xyz1.data_ptr<scalar_t>(), xyz2.data_ptr<scalar_t>(), match.data_ptr<scalar_t>(), grad1.data_ptr<scalar_t>());
            // grad w.r.t. xyz2: 2-D grid (blockIdx.y splits the N2 range); 256
            // threads to match the kernel's 256-slot shared reduction buffer.
            matchcostgrad2<scalar_t><<<dim3(32, 32), 256>>>(b, n, m, grad_cost.data_ptr<scalar_t>(), xyz1.data_ptr<scalar_t>(), xyz2.data_ptr<scalar_t>(), match.data_ptr<scalar_t>(), grad2.data_ptr<scalar_t>()); }));
    C10_CUDA_CHECK(cudaGetLastError());

    return std::vector<at::Tensor>({grad1, grad2});
}
|
||||
|
||||
#endif
|
|
@ -1,60 +0,0 @@
|
|||
import pathlib

import torch
from torch.utils.cpp_extension import load

# Directory containing this file; used to resolve the CUDA source paths below.
REALPATH = pathlib.Path(__file__).parent

# JIT-compile the CUDA extension at import time. NOTE(review): this triggers a
# (cached) nvcc build on first import, so importing this module can be slow
# and requires a working CUDA toolchain — confirm that is acceptable here.
load(
    name="torchemd_cuda",
    sources=[
        str(REALPATH / file)
        for file in [
            "cuda/emd.cpp",
            "cuda/emd_kernel.cu",
        ]
    ],
)

# Import the extension module built by load() above.
import torchemd_cuda
|
||||
|
||||
|
||||
class EarthMoverDistanceFunction(torch.autograd.Function):
    """Autograd wrapper around the approximate-EMD CUDA kernels."""

    @staticmethod
    def forward(ctx, xyz1, xyz2) -> torch.Tensor:
        """Compute the approximate EMD cost between two point-cloud batches.

        Args:
            xyz1: (B, N, 3) CUDA tensor.
            xyz2: (B, M, 3) CUDA tensor.

        Returns:
            (B,) tensor of per-batch matching costs.

        Raises:
            ValueError: if either input is not on a CUDA device.
        """
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let CPU tensors reach the CUDA kernels.
        if not (xyz1.is_cuda and xyz2.is_cuda):
            raise ValueError("Only support cuda currently.")
        match = torchemd_cuda.approxmatch_forward(xyz1, xyz2)
        cost = torchemd_cuda.matchcost_forward(xyz1, xyz2, match)
        # The match matrix is needed to compute both input gradients.
        ctx.save_for_backward(xyz1, xyz2, match)
        return cost

    @staticmethod
    def backward(ctx, grad_cost) -> tuple[torch.Tensor, torch.Tensor]:
        """Propagate the cost gradient back to both input point clouds."""
        xyz1, xyz2, match = ctx.saved_tensors
        grad_cost = grad_cost.contiguous()
        grad_xyz1, grad_xyz2 = torchemd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
        return grad_xyz1, grad_xyz2
|
||||
|
||||
|
||||
def earth_mover_distance(xyz1, xyz2, transpose=True) -> torch.Tensor:
    """Earth Mover Distance (Approx).

    The CUDA kernels consume point-last layout, i.e. (B, N, 3).

    Args:
        xyz1 (torch.Tensor): First point cloud. With ``transpose=True`` the
            expected layout is channel-first, (B, 3, N) or (3, N); with
            ``transpose=False`` it is (B, N, 3) or (N, 3).
        xyz2 (torch.Tensor): Second point cloud, same layout convention,
            (B, 3, M)/(3, M) or (B, M, 3)/(M, 3).
        transpose (bool, optional): If True, transpose channel-first inputs
            to (B, N, 3)/(B, M, 3) before invoking the kernels. Defaults to
            True. (The original docstring stated the opposite direction.)

    Returns:
        torch.Tensor: (B,) per-batch approximate EMD cost.
    """
    # Promote unbatched inputs to a batch of one.
    if xyz1.dim() == 2:
        xyz1 = xyz1.unsqueeze(0)
    if xyz2.dim() == 2:
        xyz2 = xyz2.unsqueeze(0)
    if transpose:
        # (B, 3, N) -> (B, N, 3); the autograd function makes them contiguous.
        xyz1 = xyz1.transpose(1, 2)
        xyz2 = xyz2.transpose(1, 2)
    cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
    return cost
|
Loading…
Reference in a new issue