🎨 autoformatting cpp/cu files

Laurent FAINSIN 2023-04-24 11:53:37 +02:00
parent 2375a23ca5
commit 892b79f5c5
2 changed files with 368 additions and 333 deletions

File 1/2 (the .cpp binding file):

@@ -4,7 +4,7 @@
#include <vector>
#include <torch/extension.h>

// CUDA declarations
at::Tensor ApproxMatchForward(
    const at::Tensor xyz1,
    const at::Tensor xyz2);

@@ -20,10 +20,11 @@ std::vector<at::Tensor> MatchCostBackward(
    const at::Tensor xyz2,
    const at::Tensor match);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("approxmatch_forward", &ApproxMatchForward, "ApproxMatch forward (CUDA)");
    m.def("matchcost_forward", &MatchCostForward, "MatchCost forward (CUDA)");
    m.def("matchcost_backward", &MatchCostBackward, "MatchCost backward (CUDA)");
}

#endif
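
For orientation, a minimal, hedged sketch of calling these three entry points directly from C++ (not part of the commit). It assumes the .cu file below is compiled and linked together with libtorch; the shapes and values are illustrative.

// Hypothetical caller; assumes this binding file and the .cu file are linked with libtorch.
#include <torch/torch.h>
#include <iostream>
#include <vector>

at::Tensor ApproxMatchForward(const at::Tensor xyz1, const at::Tensor xyz2);
at::Tensor MatchCostForward(const at::Tensor xyz1, const at::Tensor xyz2, const at::Tensor match);
std::vector<at::Tensor> MatchCostBackward(const at::Tensor grad_cost, const at::Tensor xyz1, const at::Tensor xyz2, const at::Tensor match);

int main()
{
    // Two batches of random point clouds on the GPU (sizes are arbitrary).
    auto xyz1 = torch::rand({2, 1024, 3}, torch::kCUDA); // (B, N1, 3)
    auto xyz2 = torch::rand({2, 768, 3}, torch::kCUDA);  // (B, N2, 3)

    auto match = ApproxMatchForward(xyz1, xyz2);     // (B, N2, N1) soft assignment
    auto cost = MatchCostForward(xyz1, xyz2, match); // (B) matching cost
    auto grads = MatchCostBackward(torch::ones({2}, torch::kCUDA), xyz1, xyz2, match);

    std::cout << cost << "\n"; // one scalar per batch element
    return 0;
}
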

File 2/2 (the .cu kernel file):

@@ -10,390 +10,424 @@
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

/********************************
 * Forward kernel for approxmatch
 *********************************/
template <typename scalar_t>
__global__ void approxmatch(int b, int n, int m, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, scalar_t *__restrict__ match, scalar_t *temp)
{
    scalar_t *remainL = temp + blockIdx.x * (n + m) * 2;
    scalar_t *remainR = temp + blockIdx.x * (n + m) * 2 + n;
    scalar_t *ratioL = temp + blockIdx.x * (n + m) * 2 + n + m;
    scalar_t *ratioR = temp + blockIdx.x * (n + m) * 2 + n + m + n;
    scalar_t multiL, multiR;

    if (n >= m)
    {
        multiL = 1;
        multiR = n / m;
    }
    else
    {
        multiL = m / n;
        multiR = 1;
    }
    const int Block = 1024;
    __shared__ scalar_t buf[Block * 4];
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        for (int j = threadIdx.x; j < n * m; j += blockDim.x)
            match[i * n * m + j] = 0;
        for (int j = threadIdx.x; j < n; j += blockDim.x)
            remainL[j] = multiL;
        for (int j = threadIdx.x; j < m; j += blockDim.x)
            remainR[j] = multiR;
        __syncthreads();
        for (int j = 7; j >= -2; j--)
        {
            scalar_t level = -powf(4.0f, j);
            if (j == -2)
            {
                level = 0;
            }
            for (int k0 = 0; k0 < n; k0 += blockDim.x)
            {
                int k = k0 + threadIdx.x;
                scalar_t x1 = 0, y1 = 0, z1 = 0;
                if (k < n)
                {
                    x1 = xyz1[i * n * 3 + k * 3 + 0];
                    y1 = xyz1[i * n * 3 + k * 3 + 1];
                    z1 = xyz1[i * n * 3 + k * 3 + 2];
                }
                scalar_t suml = 1e-9f;
                for (int l0 = 0; l0 < m; l0 += Block)
                {
                    int lend = min(m, l0 + Block) - l0;
                    for (int l = threadIdx.x; l < lend; l += blockDim.x)
                    {
                        scalar_t x2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
                        scalar_t y2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
                        scalar_t z2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
                        buf[l * 4 + 0] = x2;
                        buf[l * 4 + 1] = y2;
                        buf[l * 4 + 2] = z2;
                        buf[l * 4 + 3] = remainR[l0 + l];
                    }
                    __syncthreads();
                    for (int l = 0; l < lend; l++)
                    {
                        scalar_t x2 = buf[l * 4 + 0];
                        scalar_t y2 = buf[l * 4 + 1];
                        scalar_t z2 = buf[l * 4 + 2];
                        scalar_t d = level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1));
                        scalar_t w = __expf(d) * buf[l * 4 + 3];
                        suml += w;
                    }
                    __syncthreads();
                }
                if (k < n)
                    ratioL[k] = remainL[k] / suml;
            }
            __syncthreads();
            for (int l0 = 0; l0 < m; l0 += blockDim.x)
            {
                int l = l0 + threadIdx.x;
                scalar_t x2 = 0, y2 = 0, z2 = 0;
                if (l < m)
                {
                    x2 = xyz2[i * m * 3 + l * 3 + 0];
                    y2 = xyz2[i * m * 3 + l * 3 + 1];
                    z2 = xyz2[i * m * 3 + l * 3 + 2];
                }
                scalar_t sumr = 0;
                for (int k0 = 0; k0 < n; k0 += Block)
                {
                    int kend = min(n, k0 + Block) - k0;
                    for (int k = threadIdx.x; k < kend; k += blockDim.x)
                    {
                        buf[k * 4 + 0] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 0];
                        buf[k * 4 + 1] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 1];
                        buf[k * 4 + 2] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 2];
                        buf[k * 4 + 3] = ratioL[k0 + k];
                    }
                    __syncthreads();
                    for (int k = 0; k < kend; k++)
                    {
                        scalar_t x1 = buf[k * 4 + 0];
                        scalar_t y1 = buf[k * 4 + 1];
                        scalar_t z1 = buf[k * 4 + 2];
                        scalar_t w = __expf(level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1))) * buf[k * 4 + 3];
                        sumr += w;
                    }
                    __syncthreads();
                }
                if (l < m)
                {
                    sumr *= remainR[l];
                    scalar_t consumption = fminf(remainR[l] / (sumr + 1e-9f), 1.0f);
                    ratioR[l] = consumption * remainR[l];
                    remainR[l] = fmaxf(0.0f, remainR[l] - sumr);
                }
            }
            __syncthreads();
            for (int k0 = 0; k0 < n; k0 += blockDim.x)
            {
                int k = k0 + threadIdx.x;
                scalar_t x1 = 0, y1 = 0, z1 = 0;
                if (k < n)
                {
                    x1 = xyz1[i * n * 3 + k * 3 + 0];
                    y1 = xyz1[i * n * 3 + k * 3 + 1];
                    z1 = xyz1[i * n * 3 + k * 3 + 2];
                }
                scalar_t suml = 0;
                for (int l0 = 0; l0 < m; l0 += Block)
                {
                    int lend = min(m, l0 + Block) - l0;
                    for (int l = threadIdx.x; l < lend; l += blockDim.x)
                    {
                        buf[l * 4 + 0] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
                        buf[l * 4 + 1] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
                        buf[l * 4 + 2] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
                        buf[l * 4 + 3] = ratioR[l0 + l];
                    }
                    __syncthreads();
                    scalar_t rl = ratioL[k];
                    if (k < n)
                    {
                        for (int l = 0; l < lend; l++)
                        {
                            scalar_t x2 = buf[l * 4 + 0];
                            scalar_t y2 = buf[l * 4 + 1];
                            scalar_t z2 = buf[l * 4 + 2];
                            scalar_t w = __expf(level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1))) * rl * buf[l * 4 + 3];
                            match[i * n * m + (l0 + l) * n + k] += w;
                            suml += w;
                        }
                    }
                    __syncthreads();
                }
                if (k < n)
                    remainL[k] = fmaxf(0.0f, remainL[k] - suml);
            }
            __syncthreads();
        }
    }
}

-//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){
-//  approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
-//}
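
The shared-memory staging above is the kernel's main performance trick and is easier to see stripped of the matching logic. A minimal, hedged sketch of the pattern (tile_demo and TILE are illustrative names, not part of the commit): each block repeatedly copies a fixed-size tile of points into shared memory, synchronizes, and then lets every thread sweep the whole tile.

#define TILE 1024

// Each block streams a long array of 3D points through one shared-memory tile:
// cooperative copy, barrier, then every thread reads the full tile cheaply.
__global__ void tile_demo(int m, const float *__restrict__ xyz2, float *__restrict__ out)
{
    __shared__ float buf[TILE * 3];
    float acc = 0.0f;
    for (int l0 = 0; l0 < m; l0 += TILE)
    {
        int lend = min(m, l0 + TILE) - l0;
        // Cooperative load: threads copy one tile of coordinates into shared memory.
        for (int l = threadIdx.x; l < lend * 3; l += blockDim.x)
            buf[l] = xyz2[l0 * 3 + l];
        __syncthreads();
        // Every thread loops over the staged tile instead of global memory.
        for (int l = 0; l < lend; l++)
            acc += buf[l * 3 + 0] + buf[l * 3 + 1] + buf[l * 3 + 2];
        __syncthreads(); // the tile is overwritten on the next pass
    }
    out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}
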
/* ApproxMatch forward interface
Input:
    xyz1: (B, N1, 3)  # dataset_points
    xyz2: (B, N2, 3)  # query_points
Output:
    match: (B, N2, N1)
*/
at::Tensor ApproxMatchForward(
    const at::Tensor xyz1,
    const at::Tensor xyz2)
{
    const auto b = xyz1.size(0);
    const auto n = xyz1.size(1);
    const auto m = xyz2.size(1);

    TORCH_CHECK_EQ(xyz2.size(0), b);
    TORCH_CHECK_EQ(xyz1.size(2), 3);
    TORCH_CHECK_EQ(xyz2.size(2), 3);
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);

    auto match = at::zeros({b, m, n}, xyz1.type());
    auto temp = at::zeros({b, (n + m) * 2}, xyz1.type());

    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&]
                                                                          { approxmatch<scalar_t><<<32, 512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>()); }));
    C10_CUDA_CHECK(cudaGetLastError());

    return match;
}
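
The fixed <<<32, 512>>> launch above is decoupled from the batch size because the kernel body walks batches with a grid-stride loop. A minimal sketch of that idiom (batch_stride_demo is an illustrative name, not from the commit):

// 32 blocks cover b batch elements, however large b is: block x handles
// elements x, x + gridDim.x, x + 2 * gridDim.x, ...
__global__ void batch_stride_demo(int b, float *__restrict__ per_batch_out)
{
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        if (threadIdx.x == 0)
            per_batch_out[i] = (float)i; // placeholder per-batch work
        __syncthreads(); // mirrors the per-batch barriers in the real kernels
    }
}
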
/********************************
 * Forward kernel for matchcost
 *********************************/
template <typename scalar_t>
__global__ void matchcost(int b, int n, int m, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ out)
{
    __shared__ scalar_t allsum[512];
    const int Block = 1024;
    __shared__ scalar_t buf[Block * 3];
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        scalar_t subsum = 0;
        for (int k0 = 0; k0 < n; k0 += blockDim.x)
        {
            int k = k0 + threadIdx.x;
            scalar_t x1 = 0, y1 = 0, z1 = 0;
            if (k < n)
            {
                x1 = xyz1[i * n * 3 + k * 3 + 0];
                y1 = xyz1[i * n * 3 + k * 3 + 1];
                z1 = xyz1[i * n * 3 + k * 3 + 2];
            }
            for (int l0 = 0; l0 < m; l0 += Block)
            {
                int lend = min(m, l0 + Block) - l0;
                for (int l = threadIdx.x; l < lend * 3; l += blockDim.x)
                    buf[l] = xyz2[i * m * 3 + l0 * 3 + l];
                __syncthreads();
                if (k < n)
                {
                    for (int l = 0; l < lend; l++)
                    {
                        scalar_t x2 = buf[l * 3 + 0];
                        scalar_t y2 = buf[l * 3 + 1];
                        scalar_t z2 = buf[l * 3 + 2];
                        scalar_t d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
                        subsum += d * match[i * n * m + (l0 + l) * n + k];
                    }
                }
                __syncthreads();
            }
        }
        allsum[threadIdx.x] = subsum;
        for (int j = 1; j < blockDim.x; j <<= 1)
        {
            __syncthreads();
            if ((threadIdx.x & j) == 0 && threadIdx.x + j < blockDim.x)
            {
                allsum[threadIdx.x] += allsum[threadIdx.x + j];
            }
        }
        if (threadIdx.x == 0)
            out[i] = allsum[0];
        __syncthreads();
    }
}

-//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){
-//  matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
-//}
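
The bitwise loop at the end of this kernel is an in-block tree reduction: partial sums in shared memory are folded pairwise until allsum[0] holds the block total. A hedged, standalone sketch of the same pattern (reduce_demo and N_THREADS are illustrative names; launch as reduce_demo<<<grid, N_THREADS>>>):

#define N_THREADS 512

__global__ void reduce_demo(const float *__restrict__ in, float *__restrict__ out)
{
    __shared__ float allsum[N_THREADS];
    allsum[threadIdx.x] = in[blockIdx.x * blockDim.x + threadIdx.x];
    for (int j = 1; j < blockDim.x; j <<= 1)
    {
        __syncthreads();
        // Thread t folds in its partner t+j when bit j of t is clear, so after
        // log2(blockDim.x) rounds the block's sum sits in allsum[0].
        if ((threadIdx.x & j) == 0 && threadIdx.x + j < blockDim.x)
            allsum[threadIdx.x] += allsum[threadIdx.x + j];
    }
    if (threadIdx.x == 0)
        out[blockIdx.x] = allsum[0];
}
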
/* MatchCost forward interface
Input:
    xyz1: (B, N1, 3)  # dataset_points
    xyz2: (B, N2, 3)  # query_points
    match: (B, N2, N1)
Output:
    cost: (B)
*/
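
Reading the kernel body (each thread accumulates d * match into subsum before the block-wide reduction), the value written to cost per batch element $i$ appears to be the match-weighted sum of squared distances:

$$\mathrm{cost}_i = \sum_{l=1}^{N_2} \sum_{k=1}^{N_1} \mathrm{match}[i, l, k] \,\bigl\lVert \mathrm{xyz1}[i, k] - \mathrm{xyz2}[i, l] \bigr\rVert^2$$
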
at::Tensor MatchCostForward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match)
{
    const auto b = xyz1.size(0);
    const auto n = xyz1.size(1);
    const auto m = xyz2.size(1);

    TORCH_CHECK_EQ(xyz2.size(0), b);
    TORCH_CHECK_EQ(xyz1.size(2), 3);
    TORCH_CHECK_EQ(xyz2.size(2), 3);
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);

    auto cost = at::zeros({b}, xyz1.type());

    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&]
                                                                        { matchcost<scalar_t><<<32, 512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>()); }));
    C10_CUDA_CHECK(cudaGetLastError());

    return cost;
}

/********************************
 * matchcostgrad2 kernel
 *********************************/

template <typename scalar_t>
__global__ void matchcostgrad2(int b, int n, int m, const scalar_t *__restrict__ grad_cost, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ grad2)
{
    __shared__ scalar_t sum_grad[256 * 3];
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        int kbeg = m * blockIdx.y / gridDim.y;
        int kend = m * (blockIdx.y + 1) / gridDim.y;
        for (int k = kbeg; k < kend; k++)
        {
            scalar_t x2 = xyz2[(i * m + k) * 3 + 0];
            scalar_t y2 = xyz2[(i * m + k) * 3 + 1];
            scalar_t z2 = xyz2[(i * m + k) * 3 + 2];
            scalar_t subsumx = 0, subsumy = 0, subsumz = 0;
            for (int j = threadIdx.x; j < n; j += blockDim.x)
            {
                scalar_t x1 = x2 - xyz1[(i * n + j) * 3 + 0];
                scalar_t y1 = y2 - xyz1[(i * n + j) * 3 + 1];
                scalar_t z1 = z2 - xyz1[(i * n + j) * 3 + 2];
                scalar_t d = match[i * n * m + k * n + j] * 2;
                subsumx += x1 * d;
                subsumy += y1 * d;
                subsumz += z1 * d;
            }
            sum_grad[threadIdx.x * 3 + 0] = subsumx;
            sum_grad[threadIdx.x * 3 + 1] = subsumy;
            sum_grad[threadIdx.x * 3 + 2] = subsumz;
            for (int j = 1; j < blockDim.x; j <<= 1)
            {
                __syncthreads();
                int j1 = threadIdx.x;
                int j2 = threadIdx.x + j;
                if ((j1 & j) == 0 && j2 < blockDim.x)
                {
                    sum_grad[j1 * 3 + 0] += sum_grad[j2 * 3 + 0];
                    sum_grad[j1 * 3 + 1] += sum_grad[j2 * 3 + 1];
                    sum_grad[j1 * 3 + 2] += sum_grad[j2 * 3 + 2];
                }
            }
            if (threadIdx.x == 0)
            {
                grad2[(i * m + k) * 3 + 0] = sum_grad[0] * grad_cost[i];
                grad2[(i * m + k) * 3 + 1] = sum_grad[1] * grad_cost[i];
                grad2[(i * m + k) * 3 + 2] = sum_grad[2] * grad_cost[i];
            }
            __syncthreads();
        }
    }
}

/********************************
 * matchcostgrad1 kernel
 *********************************/

template <typename scalar_t>
__global__ void matchcostgrad1(int b, int n, int m, const scalar_t *__restrict__ grad_cost, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ grad1)
{
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        for (int l = threadIdx.x; l < n; l += blockDim.x)
        {
            scalar_t x1 = xyz1[i * n * 3 + l * 3 + 0];
            scalar_t y1 = xyz1[i * n * 3 + l * 3 + 1];
            scalar_t z1 = xyz1[i * n * 3 + l * 3 + 2];
            scalar_t dx = 0, dy = 0, dz = 0;
            for (int k = 0; k < m; k++)
            {
                scalar_t x2 = xyz2[i * m * 3 + k * 3 + 0];
                scalar_t y2 = xyz2[i * m * 3 + k * 3 + 1];
                scalar_t z2 = xyz2[i * m * 3 + k * 3 + 2];
                scalar_t d = match[i * n * m + k * n + l] * 2;
                dx += (x1 - x2) * d;
                dy += (y1 - y2) * d;
                dz += (z1 - z2) * d;
            }
            grad1[i * n * 3 + l * 3 + 0] = dx * grad_cost[i];
            grad1[i * n * 3 + l * 3 + 1] = dy * grad_cost[i];
            grad1[i * n * 3 + l * 3 + 2] = dz * grad_cost[i];
        }
    }
}

-//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){
-//  matchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);
-//  matchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
-//}
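
matchcostgrad2 is launched further below with a two-dimensional grid, dim3(32, 32): blockIdx.x strides over batches while blockIdx.y carves the N2 points into contiguous slices. A hedged sketch of just that slicing (range_demo is an illustrative name, not from the commit):

// Grid (32, 32): x strides over batches, y splits the m points of xyz2 into
// near-equal contiguous ranges, one range per blockIdx.y.
__global__ void range_demo(int b, int m, float *__restrict__ out)
{
    for (int i = blockIdx.x; i < b; i += gridDim.x)
    {
        int kbeg = m * blockIdx.y / gridDim.y;       // first point of this slice
        int kend = m * (blockIdx.y + 1) / gridDim.y; // one past its last point
        for (int k = kbeg + threadIdx.x; k < kend; k += blockDim.x)
            out[i * m + k] = (float)k; // placeholder per-point work
    }
}
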
/* MatchCost backward interface
Input:
    grad_cost: (B)  # gradients on cost
    xyz1: (B, N1, 3)  # dataset_points
    xyz2: (B, N2, 3)  # query_points
    match: (B, N2, N1)
Output:
    grad1: (B, N1, 3)
    grad2: (B, N2, 3)
*/
std::vector<at::Tensor> MatchCostBackward(
    const at::Tensor grad_cost,
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match)
{
    const auto b = xyz1.size(0);
    const auto n = xyz1.size(1);
    const auto m = xyz2.size(1);

    TORCH_CHECK_EQ(xyz2.size(0), b);
    TORCH_CHECK_EQ(xyz1.size(2), 3);
    TORCH_CHECK_EQ(xyz2.size(2), 3);
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);

    auto grad1 = at::zeros({b, n, 3}, xyz1.type());
    auto grad2 = at::zeros({b, m, 3}, xyz1.type());

    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&]
                                                                         {
        matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
        matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>()); }));
    C10_CUDA_CHECK(cudaGetLastError());

    return std::vector<at::Tensor>({grad1, grad2});
}

#endif