/**********************************
 * Original Author: Haoqiang Fan
 * Modified by: Kaichun Mo
 *********************************/

#ifndef _EMD_KERNEL
#define _EMD_KERNEL

#include <cmath>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>  // at::cuda::getApplyGrid

#define CHECK_INPUT(x) \
    TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); \
    TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")


/********************************
* Forward kernel for approxmatch
*********************************/

template <typename scalar_t>
__global__ void approxmatch(int b, int n, int m,
                            const scalar_t * __restrict__ xyz1,
                            const scalar_t * __restrict__ xyz2,
                            scalar_t * __restrict__ match,
                            scalar_t * temp) {
    // Per-block scratch: remaining mass and bidding ratios for both point sets.
    scalar_t * remainL = temp + blockIdx.x * (n + m) * 2;
    scalar_t * remainR = remainL + n;
    scalar_t * ratioL = remainR + m;
    scalar_t * ratioR = ratioL + n;
    // Each point on the smaller side carries mass max(n,m)/min(n,m); 1 on the larger.
    scalar_t multiL, multiR;
    if (n >= m) {
        multiL = 1;
        multiR = n / m;
    } else {
        multiL = m / n;
        multiR = 1;
    }
    const int Block = 1024;
    __shared__ scalar_t buf[Block * 4];
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        // Reset the match matrix and the remaining mass on both sides.
        for (int j = threadIdx.x; j < n * m; j += blockDim.x)
            match[i * n * m + j] = 0;
        for (int j = threadIdx.x; j < n; j += blockDim.x)
            remainL[j] = multiL;
        for (int j = threadIdx.x; j < m; j += blockDim.x)
            remainR[j] = multiR;
        __syncthreads();
        // Anneal from a sharp to a flat matching kernel: level = -4^j, ending at 0.
        for (int j = 7; j >= -2; j--) {
            scalar_t level = -powf(4.0f, j);
            if (j == -2) {
                level = 0;
            }
            // Phase 1: each left point k computes its bidding ratio
            // ratioL[k] = remainL[k] / sum_l exp(level*d(k,l)) * remainR[l].
            for (int k0 = 0; k0 < n; k0 += blockDim.x) {
                int k = k0 + threadIdx.x;
                scalar_t x1 = 0, y1 = 0, z1 = 0;
                if (k < n) {
                    x1 = xyz1[i * n * 3 + k * 3 + 0];
                    y1 = xyz1[i * n * 3 + k * 3 + 1];
                    z1 = xyz1[i * n * 3 + k * 3 + 2];
                }
                scalar_t suml = 1e-9f;
                for (int l0 = 0; l0 < m; l0 += Block) {
                    int lend = min(m, l0 + Block) - l0;
                    // Stage a tile of xyz2 and its remaining mass in shared memory.
                    for (int l = threadIdx.x; l < lend; l += blockDim.x) {
                        buf[l * 4 + 0] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
                        buf[l * 4 + 1] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
                        buf[l * 4 + 2] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
                        buf[l * 4 + 3] = remainR[l0 + l];
                    }
                    __syncthreads();
                    for (int l = 0; l < lend; l++) {
                        scalar_t x2 = buf[l * 4 + 0];
                        scalar_t y2 = buf[l * 4 + 1];
                        scalar_t z2 = buf[l * 4 + 2];
                        scalar_t d = level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1));
                        suml += expf(d) * buf[l * 4 + 3];
                    }
                    __syncthreads();
                }
                if (k < n)
                    ratioL[k] = remainL[k] / suml;
            }
            __syncthreads();
            // Phase 2: each right point l grants mass in proportion to the bids,
            // capped by its remaining capacity.
            for (int l0 = 0; l0 < m; l0 += blockDim.x) {
                int l = l0 + threadIdx.x;
                scalar_t x2 = 0, y2 = 0, z2 = 0;
                if (l < m) {
                    x2 = xyz2[i * m * 3 + l * 3 + 0];
                    y2 = xyz2[i * m * 3 + l * 3 + 1];
                    z2 = xyz2[i * m * 3 + l * 3 + 2];
                }
                scalar_t sumr = 0;
                for (int k0 = 0; k0 < n; k0 += Block) {
                    int kend = min(n, k0 + Block) - k0;
                    for (int k = threadIdx.x; k < kend; k += blockDim.x) {
                        buf[k * 4 + 0] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 0];
                        buf[k * 4 + 1] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 1];
                        buf[k * 4 + 2] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 2];
                        buf[k * 4 + 3] = ratioL[k0 + k];
                    }
                    __syncthreads();
                    for (int k = 0; k < kend; k++) {
                        scalar_t x1 = buf[k * 4 + 0];
                        scalar_t y1 = buf[k * 4 + 1];
                        scalar_t z1 = buf[k * 4 + 2];
                        sumr += expf(level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1))) * buf[k * 4 + 3];
                    }
                    __syncthreads();
                }
                if (l < m) {
                    sumr *= remainR[l];
                    scalar_t consumption = fminf(remainR[l] / (sumr + 1e-9f), 1.0f);
                    ratioR[l] = consumption * remainR[l];
                    remainR[l] = fmaxf(0.0f, remainR[l] - sumr);
                }
            }
            __syncthreads();
            // Phase 3: commit the granted mass to the match matrix and update
            // the remaining mass of the left points.
            for (int k0 = 0; k0 < n; k0 += blockDim.x) {
                int k = k0 + threadIdx.x;
                scalar_t x1 = 0, y1 = 0, z1 = 0, rl = 0;
                if (k < n) {
                    x1 = xyz1[i * n * 3 + k * 3 + 0];
                    y1 = xyz1[i * n * 3 + k * 3 + 1];
                    z1 = xyz1[i * n * 3 + k * 3 + 2];
                    rl = ratioL[k];
                }
                scalar_t suml = 0;
                for (int l0 = 0; l0 < m; l0 += Block) {
                    int lend = min(m, l0 + Block) - l0;
                    for (int l = threadIdx.x; l < lend; l += blockDim.x) {
                        buf[l * 4 + 0] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
                        buf[l * 4 + 1] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
                        buf[l * 4 + 2] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
                        buf[l * 4 + 3] = ratioR[l0 + l];
                    }
                    __syncthreads();
                    if (k < n) {
                        for (int l = 0; l < lend; l++) {
                            scalar_t x2 = buf[l * 4 + 0];
                            scalar_t y2 = buf[l * 4 + 1];
                            scalar_t z2 = buf[l * 4 + 2];
                            scalar_t w = expf(level * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1))) * rl * buf[l * 4 + 3];
                            match[i * n * m + (l0 + l) * n + k] += w;
                            suml += w;
                        }
                    }
                    __syncthreads();
                }
                if (k < n)
                    remainL[k] = fmaxf(0.0f, remainL[k] - suml);
            }
            __syncthreads();
        }
    }
}

//void approxmatchLauncher(int b, int n, int m, const float * xyz1, const float * xyz2, float * match, float * temp){
//    approxmatch<<<32,512>>>(b, n, m, xyz1, xyz2, match, temp);
//}

/* ApproxMatch forward interface
Input:
  xyz1: (B, N1, 3)  # dataset_points
  xyz2: (B, N2, 3)  # query_points
Output:
  match: (B, N2, N1)
*/
at::Tensor ApproxMatchForward(
    const at::Tensor xyz1,
    const at::Tensor xyz2) {
    const auto b = xyz1.size(0);
    const auto n = xyz1.size(1);
    const auto m = xyz2.size(1);

    TORCH_CHECK_EQ(xyz2.size(0), b);
    TORCH_CHECK_EQ(xyz1.size(2), 3);
    TORCH_CHECK_EQ(xyz2.size(2), 3);
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);

    auto match = at::zeros({b, m, n}, xyz1.options());
    auto temp = at::zeros({b, (n + m) * 2}, xyz1.options());

    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
        approxmatch<scalar_t><<<32, 512>>>(
            b, n, m,
            xyz1.data_ptr<scalar_t>(), xyz2.data_ptr<scalar_t>(),
            match.data_ptr<scalar_t>(), temp.data_ptr<scalar_t>());
    }));
    C10_CUDA_CHECK(cudaGetLastError());

    return match;
}


/********************************
* Forward kernel for matchcost
*********************************/

template <typename scalar_t>
__global__ void matchcost(int b, int n, int m,
                          const scalar_t * __restrict__ xyz1,
                          const scalar_t * __restrict__ xyz2,
                          const scalar_t * __restrict__ match,
                          scalar_t * __restrict__ out) {
    __shared__ scalar_t allsum[512];
    const int Block = 1024;
    __shared__ scalar_t buf[Block * 3];
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        scalar_t subsum = 0;
        for (int k0 = 0; k0 < n; k0 += blockDim.x) {
            int k = k0 + threadIdx.x;
            scalar_t x1 = 0, y1 = 0, z1 = 0;
            if (k < n) {
                x1 = xyz1[i * n * 3 + k * 3 + 0];
                y1 = xyz1[i * n * 3 + k * 3 + 1];
                z1 = xyz1[i * n * 3 + k * 3 + 2];
            }
            for (int l0 = 0; l0 < m; l0 += Block) {
                int lend = min(m, l0 + Block) - l0;
                // Stage a tile of xyz2 in shared memory.
                for (int l = threadIdx.x; l < lend * 3; l += blockDim.x)
                    buf[l] = xyz2[i * m * 3 + l0 * 3 + l];
                __syncthreads();
                if (k < n) {
                    // Accumulate match-weighted squared distances.
                    for (int l = 0; l < lend; l++) {
                        scalar_t x2 = buf[l * 3 + 0];
                        scalar_t y2 = buf[l * 3 + 1];
                        scalar_t z2 = buf[l * 3 + 2];
                        scalar_t d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
                        subsum += d * match[i * n * m + (l0 + l) * n + k];
                    }
                }
                __syncthreads();
            }
        }
        // Tree-reduce the per-thread partial sums within the block.
        allsum[threadIdx.x] = subsum;
        for (int j = 1; j < blockDim.x; j <<= 1) {
            __syncthreads();
            if ((threadIdx.x & j) == 0 && threadIdx.x + j < blockDim.x) {
                allsum[threadIdx.x] += allsum[threadIdx.x + j];
            }
        }
        if (threadIdx.x == 0)
            out[i] = allsum[0];
        __syncthreads();
    }
}

//void matchcostLauncher(int b, int n, int m, const float * xyz1, const float * xyz2, const float * match, float * out){
//    matchcost<<<32,512>>>(b, n, m, xyz1, xyz2, match, out);
//}

/* MatchCost forward interface
Input:
  xyz1: (B, N1, 3)  # dataset_points
  xyz2: (B, N2, 3)  # query_points
  match: (B, N2, N1)
Output:
  cost: (B)
*/
at::Tensor MatchCostForward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match) {
    const auto b = xyz1.size(0);
    const auto n = xyz1.size(1);
    const auto m = xyz2.size(1);

    TORCH_CHECK_EQ(xyz2.size(0), b);
    TORCH_CHECK_EQ(xyz1.size(2), 3);
    TORCH_CHECK_EQ(xyz2.size(2), 3);
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);
    CHECK_INPUT(match);

    auto cost = at::zeros({b}, xyz1.options());

    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
        matchcost<scalar_t><<<32, 512>>>(
            b, n, m,
            xyz1.data_ptr<scalar_t>(), xyz2.data_ptr<scalar_t>(),
            match.data_ptr<scalar_t>(), cost.data_ptr<scalar_t>());
    }));
    C10_CUDA_CHECK(cudaGetLastError());

    return cost;
}
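
/* Usage sketch (not part of the original interface): the two forward passes
 * above compose into a single EMD estimate. `EarthMoverDistance` is a
 * hypothetical convenience wrapper added here for illustration only, and
 * dividing the summed cost by N1 is one common normalization convention,
 * not something this file prescribes.
 */
//at::Tensor EarthMoverDistance(const at::Tensor xyz1, const at::Tensor xyz2) {
//    auto match = ApproxMatchForward(xyz1, xyz2);      // (B, N2, N1) transport weights
//    auto cost = MatchCostForward(xyz1, xyz2, match);  // (B,) match-weighted squared distances
//    return cost / static_cast<double>(xyz1.size(1));  // hypothetical per-point normalization
//}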

/********************************
* matchcostgrad2 kernel
*********************************/

template <typename scalar_t>
__global__ void matchcostgrad2(int b, int n, int m,
                               const scalar_t * __restrict__ grad_cost,
                               const scalar_t * __restrict__ xyz1,
                               const scalar_t * __restrict__ xyz2,
                               const scalar_t * __restrict__ match,
                               scalar_t * __restrict__ grad2) {
    __shared__ scalar_t sum_grad[256 * 3];
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        // Each y-block handles a contiguous slice of the query points.
        int kbeg = m * blockIdx.y / gridDim.y;
        int kend = m * (blockIdx.y + 1) / gridDim.y;
        for (int k = kbeg; k < kend; k++) {
            scalar_t x2 = xyz2[(i * m + k) * 3 + 0];
            scalar_t y2 = xyz2[(i * m + k) * 3 + 1];
            scalar_t z2 = xyz2[(i * m + k) * 3 + 2];
            // d/dxyz2 of match * ||xyz2 - xyz1||^2 is 2 * match * (xyz2 - xyz1).
            scalar_t subsumx = 0, subsumy = 0, subsumz = 0;
            for (int j = threadIdx.x; j < n; j += blockDim.x) {
                scalar_t x1 = x2 - xyz1[(i * n + j) * 3 + 0];
                scalar_t y1 = y2 - xyz1[(i * n + j) * 3 + 1];
                scalar_t z1 = z2 - xyz1[(i * n + j) * 3 + 2];
                scalar_t d = match[i * n * m + k * n + j] * 2;
                subsumx += x1 * d;
                subsumy += y1 * d;
                subsumz += z1 * d;
            }
            sum_grad[threadIdx.x * 3 + 0] = subsumx;
            sum_grad[threadIdx.x * 3 + 1] = subsumy;
            sum_grad[threadIdx.x * 3 + 2] = subsumz;
            // Tree-reduce the per-thread partial gradients.
            for (int j = 1; j < blockDim.x; j <<= 1) {
                __syncthreads();
                int j1 = threadIdx.x;
                int j2 = threadIdx.x + j;
                if ((j1 & j) == 0 && j2 < blockDim.x) {
                    sum_grad[j1 * 3 + 0] += sum_grad[j2 * 3 + 0];
                    sum_grad[j1 * 3 + 1] += sum_grad[j2 * 3 + 1];
                    sum_grad[j1 * 3 + 2] += sum_grad[j2 * 3 + 2];
                }
            }
            if (threadIdx.x == 0) {
                grad2[(i * m + k) * 3 + 0] = sum_grad[0] * grad_cost[i];
                grad2[(i * m + k) * 3 + 1] = sum_grad[1] * grad_cost[i];
                grad2[(i * m + k) * 3 + 2] = sum_grad[2] * grad_cost[i];
            }
            __syncthreads();
        }
    }
}


/********************************
* matchcostgrad1 kernel
*********************************/

template <typename scalar_t>
__global__ void matchcostgrad1(int b, int n, int m,
                               const scalar_t * __restrict__ grad_cost,
                               const scalar_t * __restrict__ xyz1,
                               const scalar_t * __restrict__ xyz2,
                               const scalar_t * __restrict__ match,
                               scalar_t * __restrict__ grad1) {
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        // One thread per dataset point: accumulate 2 * match * (xyz1 - xyz2).
        for (int l = threadIdx.x; l < n; l += blockDim.x) {
            scalar_t x1 = xyz1[i * n * 3 + l * 3 + 0];
            scalar_t y1 = xyz1[i * n * 3 + l * 3 + 1];
            scalar_t z1 = xyz1[i * n * 3 + l * 3 + 2];
            scalar_t dx = 0, dy = 0, dz = 0;
            for (int k = 0; k < m; k++) {
                scalar_t x2 = xyz2[i * m * 3 + k * 3 + 0];
                scalar_t y2 = xyz2[i * m * 3 + k * 3 + 1];
                scalar_t z2 = xyz2[i * m * 3 + k * 3 + 2];
                scalar_t d = match[i * n * m + k * n + l] * 2;
                dx += (x1 - x2) * d;
                dy += (y1 - y2) * d;
                dz += (z1 - z2) * d;
            }
            grad1[i * n * 3 + l * 3 + 0] = dx * grad_cost[i];
            grad1[i * n * 3 + l * 3 + 1] = dy * grad_cost[i];
            grad1[i * n * 3 + l * 3 + 2] = dz * grad_cost[i];
        }
    }
}

//void matchcostgradLauncher(int b, int n, int m, const float * xyz1, const float * xyz2, const float * match, float * grad1, float * grad2){
//    matchcostgrad1<<<32,512>>>(b, n, m, xyz1, xyz2, match, grad1);
//    matchcostgrad2<<<dim3(32,32),256>>>(b, n, m, xyz1, xyz2, match, grad2);
//}

/* MatchCost backward interface
Input:
  grad_cost: (B)    # gradients on cost
  xyz1: (B, N1, 3)  # dataset_points
  xyz2: (B, N2, 3)  # query_points
  match: (B, N2, N1)
Output:
  grad1: (B, N1, 3)
  grad2: (B, N2, 3)
*/
std::vector<at::Tensor> MatchCostBackward(
    const at::Tensor grad_cost,
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match) {
    const auto b = xyz1.size(0);
    const auto n = xyz1.size(1);
    const auto m = xyz2.size(1);

    TORCH_CHECK_EQ(xyz2.size(0), b);
    TORCH_CHECK_EQ(xyz1.size(2), 3);
    TORCH_CHECK_EQ(xyz2.size(2), 3);
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);
    CHECK_INPUT(match);

    auto grad1 = at::zeros({b, n, 3}, xyz1.options());
    auto grad2 = at::zeros({b, m, 3}, xyz1.options());

    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
        matchcostgrad1<scalar_t><<<32, 512>>>(
            b, n, m,
            grad_cost.data_ptr<scalar_t>(),
            xyz1.data_ptr<scalar_t>(), xyz2.data_ptr<scalar_t>(),
            match.data_ptr<scalar_t>(), grad1.data_ptr<scalar_t>());
        matchcostgrad2<scalar_t><<<dim3(32, 32), 256>>>(
            b, n, m,
            grad_cost.data_ptr<scalar_t>(),
            xyz1.data_ptr<scalar_t>(), xyz2.data_ptr<scalar_t>(),
            match.data_ptr<scalar_t>(), grad2.data_ptr<scalar_t>());
    }));
    C10_CUDA_CHECK(cudaGetLastError());

    return std::vector<at::Tensor>({grad1, grad2});
}

#endif
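
/* Binding sketch (not part of this translation unit): interface functions
 * like these are typically exposed to Python from a companion .cpp file via
 * pybind11. The module and exported names below are assumptions for
 * illustration, not the bindings shipped with the original extension.
 */
//#include <torch/extension.h>
//
//PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//    m.def("approxmatch_forward", &ApproxMatchForward, "ApproxMatch forward (CUDA)");
//    m.def("matchcost_forward", &MatchCostForward, "MatchCost forward (CUDA)");
//    m.def("matchcost_backward", &MatchCostBackward, "MatchCost backward (CUDA)");
//}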