first commit
This commit is contained in:
parent
7fb2b53a02
commit
c4206265b3
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
__pycache__
|
||||||
|
build
|
||||||
|
dist
|
||||||
|
emd_ext.egg-info
|
||||||
|
*.so
|
|
@ -23,7 +23,7 @@ Check `test_emd_loss.py` for example.
|
||||||
|
|
||||||
## Author
|
## Author
|
||||||
|
|
||||||
The cuda code is originally written by Haoqiang Fan. The PyTorch version is modified by Kaichun Mo.
|
The cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|
0
__init__.py
Executable file
0
__init__.py
Executable file
29
cuda/emd.cpp
Executable file
29
cuda/emd.cpp
Executable file
|
@ -0,0 +1,29 @@
|
||||||
|
#ifndef _EMD
|
||||||
|
#define _EMD
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
//CUDA declarations
|
||||||
|
at::Tensor ApproxMatchForward(
|
||||||
|
const at::Tensor xyz1,
|
||||||
|
const at::Tensor xyz2);
|
||||||
|
|
||||||
|
at::Tensor MatchCostForward(
|
||||||
|
const at::Tensor xyz1,
|
||||||
|
const at::Tensor xyz2,
|
||||||
|
const at::Tensor match);
|
||||||
|
|
||||||
|
std::vector<at::Tensor> MatchCostBackward(
|
||||||
|
const at::Tensor grad_cost,
|
||||||
|
const at::Tensor xyz1,
|
||||||
|
const at::Tensor xyz2,
|
||||||
|
const at::Tensor match);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)");
|
||||||
|
m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)");
|
||||||
|
m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)");
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
400
cuda/emd_kernel.cu
Normal file
400
cuda/emd_kernel.cu
Normal file
|
@ -0,0 +1,400 @@
|
||||||
|
/**********************************
|
||||||
|
* Original Author: Haoqiang Fan
|
||||||
|
* Modified by: Kaichun Mo
|
||||||
|
*********************************/
|
||||||
|
|
||||||
|
#ifndef _EMD_KERNEL
|
||||||
|
#define _EMD_KERNEL
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
||||||
|
#include <THC/THC.h>
|
||||||
|
|
||||||
|
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||||
|
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||||
|
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||||
|
|
||||||
|
|
||||||
|
/********************************
|
||||||
|
* Forward kernel for approxmatch
|
||||||
|
*********************************/
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
|
||||||
|
scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
|
||||||
|
scalar_t multiL,multiR;
|
||||||
|
if (n>=m){
|
||||||
|
multiL=1;
|
||||||
|
multiR=n/m;
|
||||||
|
}else{
|
||||||
|
multiL=m/n;
|
||||||
|
multiR=1;
|
||||||
|
}
|
||||||
|
const int Block=1024;
|
||||||
|
__shared__ scalar_t buf[Block*4];
|
||||||
|
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||||
|
for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
|
||||||
|
match[i*n*m+j]=0;
|
||||||
|
for (int j=threadIdx.x;j<n;j+=blockDim.x)
|
||||||
|
remainL[j]=multiL;
|
||||||
|
for (int j=threadIdx.x;j<m;j+=blockDim.x)
|
||||||
|
remainR[j]=multiR;
|
||||||
|
__syncthreads();
|
||||||
|
for (int j=7;j>=-2;j--){
|
||||||
|
scalar_t level=-powf(4.0f,j);
|
||||||
|
if (j==-2){
|
||||||
|
level=0;
|
||||||
|
}
|
||||||
|
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||||
|
int k=k0+threadIdx.x;
|
||||||
|
scalar_t x1=0,y1=0,z1=0;
|
||||||
|
if (k<n){
|
||||||
|
x1=xyz1[i*n*3+k*3+0];
|
||||||
|
y1=xyz1[i*n*3+k*3+1];
|
||||||
|
z1=xyz1[i*n*3+k*3+2];
|
||||||
|
}
|
||||||
|
scalar_t suml=1e-9f;
|
||||||
|
for (int l0=0;l0<m;l0+=Block){
|
||||||
|
int lend=min(m,l0+Block)-l0;
|
||||||
|
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
|
||||||
|
scalar_t x2=xyz2[i*m*3+l0*3+l*3+0];
|
||||||
|
scalar_t y2=xyz2[i*m*3+l0*3+l*3+1];
|
||||||
|
scalar_t z2=xyz2[i*m*3+l0*3+l*3+2];
|
||||||
|
buf[l*4+0]=x2;
|
||||||
|
buf[l*4+1]=y2;
|
||||||
|
buf[l*4+2]=z2;
|
||||||
|
buf[l*4+3]=remainR[l0+l];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
for (int l=0;l<lend;l++){
|
||||||
|
scalar_t x2=buf[l*4+0];
|
||||||
|
scalar_t y2=buf[l*4+1];
|
||||||
|
scalar_t z2=buf[l*4+2];
|
||||||
|
scalar_t d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
|
||||||
|
scalar_t w=__expf(d)*buf[l*4+3];
|
||||||
|
suml+=w;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
if (k<n)
|
||||||
|
ratioL[k]=remainL[k]/suml;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
for (int l0=0;l0<m;l0+=blockDim.x){
|
||||||
|
int l=l0+threadIdx.x;
|
||||||
|
scalar_t x2=0,y2=0,z2=0;
|
||||||
|
if (l<m){
|
||||||
|
x2=xyz2[i*m*3+l*3+0];
|
||||||
|
y2=xyz2[i*m*3+l*3+1];
|
||||||
|
z2=xyz2[i*m*3+l*3+2];
|
||||||
|
}
|
||||||
|
scalar_t sumr=0;
|
||||||
|
for (int k0=0;k0<n;k0+=Block){
|
||||||
|
int kend=min(n,k0+Block)-k0;
|
||||||
|
for (int k=threadIdx.x;k<kend;k+=blockDim.x){
|
||||||
|
buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
|
||||||
|
buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
|
||||||
|
buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
|
||||||
|
buf[k*4+3]=ratioL[k0+k];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
for (int k=0;k<kend;k++){
|
||||||
|
scalar_t x1=buf[k*4+0];
|
||||||
|
scalar_t y1=buf[k*4+1];
|
||||||
|
scalar_t z1=buf[k*4+2];
|
||||||
|
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
|
||||||
|
sumr+=w;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
if (l<m){
|
||||||
|
sumr*=remainR[l];
|
||||||
|
scalar_t consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
|
||||||
|
ratioR[l]=consumption*remainR[l];
|
||||||
|
remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||||
|
int k=k0+threadIdx.x;
|
||||||
|
scalar_t x1=0,y1=0,z1=0;
|
||||||
|
if (k<n){
|
||||||
|
x1=xyz1[i*n*3+k*3+0];
|
||||||
|
y1=xyz1[i*n*3+k*3+1];
|
||||||
|
z1=xyz1[i*n*3+k*3+2];
|
||||||
|
}
|
||||||
|
scalar_t suml=0;
|
||||||
|
for (int l0=0;l0<m;l0+=Block){
|
||||||
|
int lend=min(m,l0+Block)-l0;
|
||||||
|
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
|
||||||
|
buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
|
||||||
|
buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
|
||||||
|
buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
|
||||||
|
buf[l*4+3]=ratioR[l0+l];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
scalar_t rl=ratioL[k];
|
||||||
|
if (k<n){
|
||||||
|
for (int l=0;l<lend;l++){
|
||||||
|
scalar_t x2=buf[l*4+0];
|
||||||
|
scalar_t y2=buf[l*4+1];
|
||||||
|
scalar_t z2=buf[l*4+2];
|
||||||
|
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
|
||||||
|
match[i*n*m+(l0+l)*n+k]+=w;
|
||||||
|
suml+=w;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
if (k<n)
|
||||||
|
remainL[k]=fmaxf(0.0f,remainL[k]-suml);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){
|
||||||
|
// approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
|
||||||
|
//}
|
||||||
|
|
||||||
|
/* ApproxMatch forward interface
|
||||||
|
Input:
|
||||||
|
xyz1: (B, N1, 3) # dataset_points
|
||||||
|
xyz2: (B, N2, 3) # query_points
|
||||||
|
Output:
|
||||||
|
match: (B, N2, N1)
|
||||||
|
*/
|
||||||
|
at::Tensor ApproxMatchForward(
|
||||||
|
const at::Tensor xyz1,
|
||||||
|
const at::Tensor xyz2){
|
||||||
|
const auto b = xyz1.size(0);
|
||||||
|
const auto n = xyz1.size(1);
|
||||||
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
|
CHECK_EQ(xyz2.size(0), b);
|
||||||
|
CHECK_EQ(xyz1.size(2), 3);
|
||||||
|
CHECK_EQ(xyz2.size(2), 3);
|
||||||
|
CHECK_INPUT(xyz1);
|
||||||
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
auto match = at::zeros({b, m, n}, xyz1.type());
|
||||||
|
auto temp = at::zeros({b, (n+m)*2}, xyz1.type());
|
||||||
|
|
||||||
|
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
||||||
|
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
||||||
|
}));
|
||||||
|
THCudaCheck(cudaGetLastError());
|
||||||
|
|
||||||
|
return match;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/********************************
|
||||||
|
* Forward kernel for matchcost
|
||||||
|
*********************************/
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
|
||||||
|
__shared__ scalar_t allsum[512];
|
||||||
|
const int Block=1024;
|
||||||
|
__shared__ scalar_t buf[Block*3];
|
||||||
|
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||||
|
scalar_t subsum=0;
|
||||||
|
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||||
|
int k=k0+threadIdx.x;
|
||||||
|
scalar_t x1=0,y1=0,z1=0;
|
||||||
|
if (k<n){
|
||||||
|
x1=xyz1[i*n*3+k*3+0];
|
||||||
|
y1=xyz1[i*n*3+k*3+1];
|
||||||
|
z1=xyz1[i*n*3+k*3+2];
|
||||||
|
}
|
||||||
|
for (int l0=0;l0<m;l0+=Block){
|
||||||
|
int lend=min(m,l0+Block)-l0;
|
||||||
|
for (int l=threadIdx.x;l<lend*3;l+=blockDim.x)
|
||||||
|
buf[l]=xyz2[i*m*3+l0*3+l];
|
||||||
|
__syncthreads();
|
||||||
|
if (k<n){
|
||||||
|
for (int l=0;l<lend;l++){
|
||||||
|
scalar_t x2=buf[l*3+0];
|
||||||
|
scalar_t y2=buf[l*3+1];
|
||||||
|
scalar_t z2=buf[l*3+2];
|
||||||
|
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
|
||||||
|
subsum+=d*match[i*n*m+(l0+l)*n+k];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
allsum[threadIdx.x]=subsum;
|
||||||
|
for (int j=1;j<blockDim.x;j<<=1){
|
||||||
|
__syncthreads();
|
||||||
|
if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){
|
||||||
|
allsum[threadIdx.x]+=allsum[threadIdx.x+j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (threadIdx.x==0)
|
||||||
|
out[i]=allsum[0];
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){
|
||||||
|
// matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
|
||||||
|
//}
|
||||||
|
|
||||||
|
/* MatchCost forward interface
|
||||||
|
Input:
|
||||||
|
xyz1: (B, N1, 3) # dataset_points
|
||||||
|
xyz2: (B, N2, 3) # query_points
|
||||||
|
match: (B, N2, N1)
|
||||||
|
Output:
|
||||||
|
cost: (B)
|
||||||
|
*/
|
||||||
|
at::Tensor MatchCostForward(
|
||||||
|
const at::Tensor xyz1,
|
||||||
|
const at::Tensor xyz2,
|
||||||
|
const at::Tensor match){
|
||||||
|
const auto b = xyz1.size(0);
|
||||||
|
const auto n = xyz1.size(1);
|
||||||
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
|
CHECK_EQ(xyz2.size(0), b);
|
||||||
|
CHECK_EQ(xyz1.size(2), 3);
|
||||||
|
CHECK_EQ(xyz2.size(2), 3);
|
||||||
|
CHECK_INPUT(xyz1);
|
||||||
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
auto cost = at::zeros({b}, xyz1.type());
|
||||||
|
|
||||||
|
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
||||||
|
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
||||||
|
}));
|
||||||
|
THCudaCheck(cudaGetLastError());
|
||||||
|
|
||||||
|
return cost;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/********************************
|
||||||
|
* matchcostgrad2 kernel
|
||||||
|
*********************************/
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
|
||||||
|
__shared__ scalar_t sum_grad[256*3];
|
||||||
|
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||||
|
int kbeg=m*blockIdx.y/gridDim.y;
|
||||||
|
int kend=m*(blockIdx.y+1)/gridDim.y;
|
||||||
|
for (int k=kbeg;k<kend;k++){
|
||||||
|
scalar_t x2=xyz2[(i*m+k)*3+0];
|
||||||
|
scalar_t y2=xyz2[(i*m+k)*3+1];
|
||||||
|
scalar_t z2=xyz2[(i*m+k)*3+2];
|
||||||
|
scalar_t subsumx=0,subsumy=0,subsumz=0;
|
||||||
|
for (int j=threadIdx.x;j<n;j+=blockDim.x){
|
||||||
|
scalar_t x1=x2-xyz1[(i*n+j)*3+0];
|
||||||
|
scalar_t y1=y2-xyz1[(i*n+j)*3+1];
|
||||||
|
scalar_t z1=z2-xyz1[(i*n+j)*3+2];
|
||||||
|
scalar_t d=match[i*n*m+k*n+j]*2;
|
||||||
|
subsumx+=x1*d;
|
||||||
|
subsumy+=y1*d;
|
||||||
|
subsumz+=z1*d;
|
||||||
|
}
|
||||||
|
sum_grad[threadIdx.x*3+0]=subsumx;
|
||||||
|
sum_grad[threadIdx.x*3+1]=subsumy;
|
||||||
|
sum_grad[threadIdx.x*3+2]=subsumz;
|
||||||
|
for (int j=1;j<blockDim.x;j<<=1){
|
||||||
|
__syncthreads();
|
||||||
|
int j1=threadIdx.x;
|
||||||
|
int j2=threadIdx.x+j;
|
||||||
|
if ((j1&j)==0 && j2<blockDim.x){
|
||||||
|
sum_grad[j1*3+0]+=sum_grad[j2*3+0];
|
||||||
|
sum_grad[j1*3+1]+=sum_grad[j2*3+1];
|
||||||
|
sum_grad[j1*3+2]+=sum_grad[j2*3+2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (threadIdx.x==0){
|
||||||
|
grad2[(i*m+k)*3+0]=sum_grad[0]*grad_cost[i];
|
||||||
|
grad2[(i*m+k)*3+1]=sum_grad[1]*grad_cost[i];
|
||||||
|
grad2[(i*m+k)*3+2]=sum_grad[2]*grad_cost[i];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/********************************
|
||||||
|
* matchcostgrad1 kernel
|
||||||
|
*********************************/
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
|
||||||
|
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||||
|
for (int l=threadIdx.x;l<n;l+=blockDim.x){
|
||||||
|
scalar_t x1=xyz1[i*n*3+l*3+0];
|
||||||
|
scalar_t y1=xyz1[i*n*3+l*3+1];
|
||||||
|
scalar_t z1=xyz1[i*n*3+l*3+2];
|
||||||
|
scalar_t dx=0,dy=0,dz=0;
|
||||||
|
for (int k=0;k<m;k++){
|
||||||
|
scalar_t x2=xyz2[i*m*3+k*3+0];
|
||||||
|
scalar_t y2=xyz2[i*m*3+k*3+1];
|
||||||
|
scalar_t z2=xyz2[i*m*3+k*3+2];
|
||||||
|
scalar_t d=match[i*n*m+k*n+l]*2;
|
||||||
|
dx+=(x1-x2)*d;
|
||||||
|
dy+=(y1-y2)*d;
|
||||||
|
dz+=(z1-z2)*d;
|
||||||
|
}
|
||||||
|
grad1[i*n*3+l*3+0]=dx*grad_cost[i];
|
||||||
|
grad1[i*n*3+l*3+1]=dy*grad_cost[i];
|
||||||
|
grad1[i*n*3+l*3+2]=dz*grad_cost[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){
|
||||||
|
// matchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);
|
||||||
|
// matchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
|
||||||
|
//}
|
||||||
|
|
||||||
|
|
||||||
|
/* MatchCost backward interface
|
||||||
|
Input:
|
||||||
|
grad_cost: (B) # gradients on cost
|
||||||
|
xyz1: (B, N1, 3) # dataset_points
|
||||||
|
xyz2: (B, N2, 3) # query_points
|
||||||
|
match: (B, N2, N1)
|
||||||
|
Output:
|
||||||
|
grad1: (B, N1, 3)
|
||||||
|
grad2: (B, N2, 3)
|
||||||
|
*/
|
||||||
|
std::vector<at::Tensor> MatchCostBackward(
|
||||||
|
const at::Tensor grad_cost,
|
||||||
|
const at::Tensor xyz1,
|
||||||
|
const at::Tensor xyz2,
|
||||||
|
const at::Tensor match){
|
||||||
|
const auto b = xyz1.size(0);
|
||||||
|
const auto n = xyz1.size(1);
|
||||||
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
|
CHECK_EQ(xyz2.size(0), b);
|
||||||
|
CHECK_EQ(xyz1.size(2), 3);
|
||||||
|
CHECK_EQ(xyz2.size(2), 3);
|
||||||
|
CHECK_INPUT(xyz1);
|
||||||
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
|
||||||
|
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
|
||||||
|
|
||||||
|
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
|
||||||
|
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
||||||
|
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
||||||
|
}));
|
||||||
|
THCudaCheck(cudaGetLastError());
|
||||||
|
|
||||||
|
return std::vector<at::Tensor>({grad1, grad2});
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
46
emd.py
Executable file
46
emd.py
Executable file
|
@ -0,0 +1,46 @@
|
||||||
|
import torch
|
||||||
|
import emd_cuda
|
||||||
|
|
||||||
|
|
||||||
|
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, xyz1, xyz2):
|
||||||
|
xyz1 = xyz1.contiguous()
|
||||||
|
xyz2 = xyz2.contiguous()
|
||||||
|
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
|
||||||
|
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
|
||||||
|
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
|
||||||
|
ctx.save_for_backward(xyz1, xyz2, match)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_cost):
|
||||||
|
xyz1, xyz2, match = ctx.saved_tensors
|
||||||
|
grad_cost = grad_cost.contiguous()
|
||||||
|
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
|
||||||
|
return grad_xyz1, grad_xyz2
|
||||||
|
|
||||||
|
|
||||||
|
def earth_mover_distance(xyz1, xyz2, transpose=True):
|
||||||
|
"""Earth Mover Distance (Approx)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xyz1 (torch.Tensor): (b, 3, n1)
|
||||||
|
xyz2 (torch.Tensor): (b, 3, n1)
|
||||||
|
transpose (bool): whether to transpose inputs as it might be BCN format.
|
||||||
|
Extensions only support BNC format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
cost (torch.Tensor): (b)
|
||||||
|
|
||||||
|
"""
|
||||||
|
if xyz1.dim() == 2:
|
||||||
|
xyz1 = xyz1.unsqueeze(0)
|
||||||
|
if xyz2.dim() == 2:
|
||||||
|
xyz2 = xyz2.unsqueeze(0)
|
||||||
|
if transpose:
|
||||||
|
xyz1 = xyz1.transpose(1, 2)
|
||||||
|
xyz2 = xyz2.transpose(1, 2)
|
||||||
|
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
||||||
|
return cost
|
||||||
|
|
27
setup.py
Executable file
27
setup.py
Executable file
|
@ -0,0 +1,27 @@
|
||||||
|
"""Setup extension
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
If extra_compile_args is provided, you need to provide different instances for different extensions.
|
||||||
|
Refer to https://github.com/pytorch/pytorch/issues/20169
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from setuptools import setup
|
||||||
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='emd_ext',
|
||||||
|
ext_modules=[
|
||||||
|
CUDAExtension(
|
||||||
|
name='emd_cuda',
|
||||||
|
sources=[
|
||||||
|
'cuda/emd.cpp',
|
||||||
|
'cuda/emd_kernel.cu',
|
||||||
|
],
|
||||||
|
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
|
||||||
|
),
|
||||||
|
],
|
||||||
|
cmdclass={
|
||||||
|
'build_ext': BuildExtension
|
||||||
|
})
|
44
test_emd_loss.py
Normal file
44
test_emd_loss.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
from emd import earth_mover_distance
|
||||||
|
|
||||||
|
# gt
|
||||||
|
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
||||||
|
p1 = p1.repeat(3, 1, 1)
|
||||||
|
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
||||||
|
p2 = p2.repeat(3, 1, 1)
|
||||||
|
print(p1)
|
||||||
|
print(p2)
|
||||||
|
p1.requires_grad = True
|
||||||
|
p2.requires_grad = True
|
||||||
|
|
||||||
|
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
|
||||||
|
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
|
||||||
|
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
|
||||||
|
print('gt_dist: ', gt_dist)
|
||||||
|
|
||||||
|
gt_dist.backward()
|
||||||
|
print(p1.grad)
|
||||||
|
print(p2.grad)
|
||||||
|
|
||||||
|
# emd
|
||||||
|
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
||||||
|
p1 = p1.repeat(3, 1, 1)
|
||||||
|
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
||||||
|
p2 = p2.repeat(3, 1, 1)
|
||||||
|
print(p1)
|
||||||
|
print(p2)
|
||||||
|
p1.requires_grad = True
|
||||||
|
p2.requires_grad = True
|
||||||
|
|
||||||
|
d = earth_mover_distance(p1, p2, transpose=False)
|
||||||
|
print(d)
|
||||||
|
|
||||||
|
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
|
||||||
|
print(loss)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
print(p1.grad)
|
||||||
|
print(p2.grad)
|
||||||
|
|
Loading…
Reference in a new issue