🎨 autoformatting cpp/cu files
This commit is contained in:
parent
2375a23ca5
commit
892b79f5c5
|
@ -20,7 +20,8 @@ std::vector<at::Tensor> MatchCostBackward(
|
|||
const at::Tensor xyz2,
|
||||
const at::Tensor match);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("approxmatch_forward", &ApproxMatchForward, "ApproxMatch forward (CUDA)");
|
||||
m.def("matchcost_forward", &MatchCostForward, "MatchCost forward (CUDA)");
|
||||
m.def("matchcost_backward", &MatchCostBackward, "MatchCost backward (CUDA)");
|
||||
|
|
|
@ -14,27 +14,37 @@
|
|||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
/********************************
|
||||
* Forward kernel for approxmatch
|
||||
*********************************/
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
|
||||
scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
|
||||
__global__ void approxmatch(int b, int n, int m, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, scalar_t *__restrict__ match, scalar_t *temp)
|
||||
{
|
||||
scalar_t *remainL = temp + blockIdx.x * (n + m) * 2;
|
||||
scalar_t *remainR = temp + blockIdx.x * (n + m) * 2 + n;
|
||||
scalar_t *ratioL = temp + blockIdx.x * (n + m) * 2 + n + m;
|
||||
scalar_t *ratioR = temp + blockIdx.x * (n + m) * 2 + n + m + n;
|
||||
scalar_t multiL, multiR;
|
||||
if (n>=m){
|
||||
|
||||
if (n >= m)
|
||||
{
|
||||
multiL = 1;
|
||||
multiR = n / m;
|
||||
}else{
|
||||
}
|
||||
else
|
||||
{
|
||||
multiL = m / n;
|
||||
multiR = 1;
|
||||
}
|
||||
const int Block = 1024;
|
||||
__shared__ scalar_t buf[Block * 4];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x)
|
||||
{
|
||||
for (int j = threadIdx.x; j < n * m; j += blockDim.x)
|
||||
match[i * n * m + j] = 0;
|
||||
for (int j = threadIdx.x; j < n; j += blockDim.x)
|
||||
|
@ -42,23 +52,29 @@ __global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1
|
|||
for (int j = threadIdx.x; j < m; j += blockDim.x)
|
||||
remainR[j] = multiR;
|
||||
__syncthreads();
|
||||
for (int j=7;j>=-2;j--){
|
||||
for (int j = 7; j >= -2; j--)
|
||||
{
|
||||
scalar_t level = -powf(4.0f, j);
|
||||
if (j==-2){
|
||||
if (j == -2)
|
||||
{
|
||||
level = 0;
|
||||
}
|
||||
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||
for (int k0 = 0; k0 < n; k0 += blockDim.x)
|
||||
{
|
||||
int k = k0 + threadIdx.x;
|
||||
scalar_t x1 = 0, y1 = 0, z1 = 0;
|
||||
if (k<n){
|
||||
if (k < n)
|
||||
{
|
||||
x1 = xyz1[i * n * 3 + k * 3 + 0];
|
||||
y1 = xyz1[i * n * 3 + k * 3 + 1];
|
||||
z1 = xyz1[i * n * 3 + k * 3 + 2];
|
||||
}
|
||||
scalar_t suml = 1e-9f;
|
||||
for (int l0=0;l0<m;l0+=Block){
|
||||
for (int l0 = 0; l0 < m; l0 += Block)
|
||||
{
|
||||
int lend = min(m, l0 + Block) - l0;
|
||||
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
|
||||
for (int l = threadIdx.x; l < lend; l += blockDim.x)
|
||||
{
|
||||
scalar_t x2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
|
||||
scalar_t y2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
|
||||
scalar_t z2 = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
|
||||
|
@ -68,7 +84,8 @@ __global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1
|
|||
buf[l * 4 + 3] = remainR[l0 + l];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int l=0;l<lend;l++){
|
||||
for (int l = 0; l < lend; l++)
|
||||
{
|
||||
scalar_t x2 = buf[l * 4 + 0];
|
||||
scalar_t y2 = buf[l * 4 + 1];
|
||||
scalar_t z2 = buf[l * 4 + 2];
|
||||
|
@ -82,25 +99,30 @@ __global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1
|
|||
ratioL[k] = remainL[k] / suml;
|
||||
}
|
||||
__syncthreads();
|
||||
for (int l0=0;l0<m;l0+=blockDim.x){
|
||||
for (int l0 = 0; l0 < m; l0 += blockDim.x)
|
||||
{
|
||||
int l = l0 + threadIdx.x;
|
||||
scalar_t x2 = 0, y2 = 0, z2 = 0;
|
||||
if (l<m){
|
||||
if (l < m)
|
||||
{
|
||||
x2 = xyz2[i * m * 3 + l * 3 + 0];
|
||||
y2 = xyz2[i * m * 3 + l * 3 + 1];
|
||||
z2 = xyz2[i * m * 3 + l * 3 + 2];
|
||||
}
|
||||
scalar_t sumr = 0;
|
||||
for (int k0=0;k0<n;k0+=Block){
|
||||
for (int k0 = 0; k0 < n; k0 += Block)
|
||||
{
|
||||
int kend = min(n, k0 + Block) - k0;
|
||||
for (int k=threadIdx.x;k<kend;k+=blockDim.x){
|
||||
for (int k = threadIdx.x; k < kend; k += blockDim.x)
|
||||
{
|
||||
buf[k * 4 + 0] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 0];
|
||||
buf[k * 4 + 1] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 1];
|
||||
buf[k * 4 + 2] = xyz1[i * n * 3 + k0 * 3 + k * 3 + 2];
|
||||
buf[k * 4 + 3] = ratioL[k0 + k];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int k=0;k<kend;k++){
|
||||
for (int k = 0; k < kend; k++)
|
||||
{
|
||||
scalar_t x1 = buf[k * 4 + 0];
|
||||
scalar_t y1 = buf[k * 4 + 1];
|
||||
scalar_t z1 = buf[k * 4 + 2];
|
||||
|
@ -109,7 +131,8 @@ __global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1
|
|||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (l<m){
|
||||
if (l < m)
|
||||
{
|
||||
sumr *= remainR[l];
|
||||
scalar_t consumption = fminf(remainR[l] / (sumr + 1e-9f), 1.0f);
|
||||
ratioR[l] = consumption * remainR[l];
|
||||
|
@ -117,18 +140,22 @@ __global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1
|
|||
}
|
||||
}
|
||||
__syncthreads();
|
||||
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||
for (int k0 = 0; k0 < n; k0 += blockDim.x)
|
||||
{
|
||||
int k = k0 + threadIdx.x;
|
||||
scalar_t x1 = 0, y1 = 0, z1 = 0;
|
||||
if (k<n){
|
||||
if (k < n)
|
||||
{
|
||||
x1 = xyz1[i * n * 3 + k * 3 + 0];
|
||||
y1 = xyz1[i * n * 3 + k * 3 + 1];
|
||||
z1 = xyz1[i * n * 3 + k * 3 + 2];
|
||||
}
|
||||
scalar_t suml = 0;
|
||||
for (int l0=0;l0<m;l0+=Block){
|
||||
for (int l0 = 0; l0 < m; l0 += Block)
|
||||
{
|
||||
int lend = min(m, l0 + Block) - l0;
|
||||
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
|
||||
for (int l = threadIdx.x; l < lend; l += blockDim.x)
|
||||
{
|
||||
buf[l * 4 + 0] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 0];
|
||||
buf[l * 4 + 1] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 1];
|
||||
buf[l * 4 + 2] = xyz2[i * m * 3 + l0 * 3 + l * 3 + 2];
|
||||
|
@ -136,8 +163,10 @@ __global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1
|
|||
}
|
||||
__syncthreads();
|
||||
scalar_t rl = ratioL[k];
|
||||
if (k<n){
|
||||
for (int l=0;l<lend;l++){
|
||||
if (k < n)
|
||||
{
|
||||
for (int l = 0; l < lend; l++)
|
||||
{
|
||||
scalar_t x2 = buf[l * 4 + 0];
|
||||
scalar_t y2 = buf[l * 4 + 1];
|
||||
scalar_t z2 = buf[l * 4 + 2];
|
||||
|
@ -156,10 +185,6 @@ __global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1
|
|||
}
|
||||
}
|
||||
|
||||
//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){
|
||||
// approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
|
||||
//}
|
||||
|
||||
/* ApproxMatch forward interface
|
||||
Input:
|
||||
xyz1: (B, N1, 3) # dataset_points
|
||||
|
@ -169,7 +194,8 @@ Output:
|
|||
*/
|
||||
at::Tensor ApproxMatchForward(
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2){
|
||||
const at::Tensor xyz2)
|
||||
{
|
||||
const auto b = xyz1.size(0);
|
||||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
@ -183,41 +209,46 @@ at::Tensor ApproxMatchForward(
|
|||
auto match = at::zeros({b, m, n}, xyz1.type());
|
||||
auto temp = at::zeros({b, (n + m) * 2}, xyz1.type());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
||||
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
||||
}));
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&]
|
||||
{ approxmatch<scalar_t><<<32, 512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>()); }));
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return match;
|
||||
}
|
||||
|
||||
|
||||
/********************************
|
||||
* Forward kernel for matchcost
|
||||
*********************************/
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
|
||||
__global__ void matchcost(int b, int n, int m, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ out)
|
||||
{
|
||||
__shared__ scalar_t allsum[512];
|
||||
const int Block = 1024;
|
||||
__shared__ scalar_t buf[Block * 3];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x)
|
||||
{
|
||||
scalar_t subsum = 0;
|
||||
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||
for (int k0 = 0; k0 < n; k0 += blockDim.x)
|
||||
{
|
||||
int k = k0 + threadIdx.x;
|
||||
scalar_t x1 = 0, y1 = 0, z1 = 0;
|
||||
if (k<n){
|
||||
if (k < n)
|
||||
{
|
||||
x1 = xyz1[i * n * 3 + k * 3 + 0];
|
||||
y1 = xyz1[i * n * 3 + k * 3 + 1];
|
||||
z1 = xyz1[i * n * 3 + k * 3 + 2];
|
||||
}
|
||||
for (int l0=0;l0<m;l0+=Block){
|
||||
for (int l0 = 0; l0 < m; l0 += Block)
|
||||
{
|
||||
int lend = min(m, l0 + Block) - l0;
|
||||
for (int l = threadIdx.x; l < lend * 3; l += blockDim.x)
|
||||
buf[l] = xyz2[i * m * 3 + l0 * 3 + l];
|
||||
__syncthreads();
|
||||
if (k<n){
|
||||
for (int l=0;l<lend;l++){
|
||||
if (k < n)
|
||||
{
|
||||
for (int l = 0; l < lend; l++)
|
||||
{
|
||||
scalar_t x2 = buf[l * 3 + 0];
|
||||
scalar_t y2 = buf[l * 3 + 1];
|
||||
scalar_t z2 = buf[l * 3 + 2];
|
||||
|
@ -229,9 +260,11 @@ __global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,c
|
|||
}
|
||||
}
|
||||
allsum[threadIdx.x] = subsum;
|
||||
for (int j=1;j<blockDim.x;j<<=1){
|
||||
for (int j = 1; j < blockDim.x; j <<= 1)
|
||||
{
|
||||
__syncthreads();
|
||||
if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){
|
||||
if ((threadIdx.x & j) == 0 && threadIdx.x + j < blockDim.x)
|
||||
{
|
||||
allsum[threadIdx.x] += allsum[threadIdx.x + j];
|
||||
}
|
||||
}
|
||||
|
@ -241,10 +274,6 @@ __global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,c
|
|||
}
|
||||
}
|
||||
|
||||
//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){
|
||||
// matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
|
||||
//}
|
||||
|
||||
/* MatchCost forward interface
|
||||
Input:
|
||||
xyz1: (B, N1, 3) # dataset_points
|
||||
|
@ -256,7 +285,8 @@ Output:
|
|||
at::Tensor MatchCostForward(
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2,
|
||||
const at::Tensor match){
|
||||
const at::Tensor match)
|
||||
{
|
||||
const auto b = xyz1.size(0);
|
||||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
@ -269,31 +299,33 @@ at::Tensor MatchCostForward(
|
|||
|
||||
auto cost = at::zeros({b}, xyz1.type());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
||||
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
||||
}));
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&]
|
||||
{ matchcost<scalar_t><<<32, 512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>()); }));
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
/********************************
|
||||
* matchcostgrad2 kernel
|
||||
*********************************/
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
|
||||
__global__ void matchcostgrad2(int b, int n, int m, const scalar_t *__restrict__ grad_cost, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ grad2)
|
||||
{
|
||||
__shared__ scalar_t sum_grad[256 * 3];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x)
|
||||
{
|
||||
int kbeg = m * blockIdx.y / gridDim.y;
|
||||
int kend = m * (blockIdx.y + 1) / gridDim.y;
|
||||
for (int k=kbeg;k<kend;k++){
|
||||
for (int k = kbeg; k < kend; k++)
|
||||
{
|
||||
scalar_t x2 = xyz2[(i * m + k) * 3 + 0];
|
||||
scalar_t y2 = xyz2[(i * m + k) * 3 + 1];
|
||||
scalar_t z2 = xyz2[(i * m + k) * 3 + 2];
|
||||
scalar_t subsumx = 0, subsumy = 0, subsumz = 0;
|
||||
for (int j=threadIdx.x;j<n;j+=blockDim.x){
|
||||
for (int j = threadIdx.x; j < n; j += blockDim.x)
|
||||
{
|
||||
scalar_t x1 = x2 - xyz1[(i * n + j) * 3 + 0];
|
||||
scalar_t y1 = y2 - xyz1[(i * n + j) * 3 + 1];
|
||||
scalar_t z1 = z2 - xyz1[(i * n + j) * 3 + 2];
|
||||
|
@ -305,17 +337,20 @@ __global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ g
|
|||
sum_grad[threadIdx.x * 3 + 0] = subsumx;
|
||||
sum_grad[threadIdx.x * 3 + 1] = subsumy;
|
||||
sum_grad[threadIdx.x * 3 + 2] = subsumz;
|
||||
for (int j=1;j<blockDim.x;j<<=1){
|
||||
for (int j = 1; j < blockDim.x; j <<= 1)
|
||||
{
|
||||
__syncthreads();
|
||||
int j1 = threadIdx.x;
|
||||
int j2 = threadIdx.x + j;
|
||||
if ((j1&j)==0 && j2<blockDim.x){
|
||||
if ((j1 & j) == 0 && j2 < blockDim.x)
|
||||
{
|
||||
sum_grad[j1 * 3 + 0] += sum_grad[j2 * 3 + 0];
|
||||
sum_grad[j1 * 3 + 1] += sum_grad[j2 * 3 + 1];
|
||||
sum_grad[j1 * 3 + 2] += sum_grad[j2 * 3 + 2];
|
||||
}
|
||||
}
|
||||
if (threadIdx.x==0){
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
grad2[(i * m + k) * 3 + 0] = sum_grad[0] * grad_cost[i];
|
||||
grad2[(i * m + k) * 3 + 1] = sum_grad[1] * grad_cost[i];
|
||||
grad2[(i * m + k) * 3 + 2] = sum_grad[2] * grad_cost[i];
|
||||
|
@ -330,14 +365,18 @@ __global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ g
|
|||
*********************************/
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int l=threadIdx.x;l<n;l+=blockDim.x){
|
||||
__global__ void matchcostgrad1(int b, int n, int m, const scalar_t *__restrict__ grad_cost, const scalar_t *__restrict__ xyz1, const scalar_t *__restrict__ xyz2, const scalar_t *__restrict__ match, scalar_t *__restrict__ grad1)
|
||||
{
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x)
|
||||
{
|
||||
for (int l = threadIdx.x; l < n; l += blockDim.x)
|
||||
{
|
||||
scalar_t x1 = xyz1[i * n * 3 + l * 3 + 0];
|
||||
scalar_t y1 = xyz1[i * n * 3 + l * 3 + 1];
|
||||
scalar_t z1 = xyz1[i * n * 3 + l * 3 + 2];
|
||||
scalar_t dx = 0, dy = 0, dz = 0;
|
||||
for (int k=0;k<m;k++){
|
||||
for (int k = 0; k < m; k++)
|
||||
{
|
||||
scalar_t x2 = xyz2[i * m * 3 + k * 3 + 0];
|
||||
scalar_t y2 = xyz2[i * m * 3 + k * 3 + 1];
|
||||
scalar_t z2 = xyz2[i * m * 3 + k * 3 + 2];
|
||||
|
@ -353,12 +392,6 @@ __global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ g
|
|||
}
|
||||
}
|
||||
|
||||
//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){
|
||||
// matchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);
|
||||
// matchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
|
||||
//}
|
||||
|
||||
|
||||
/* MatchCost backward interface
|
||||
Input:
|
||||
grad_cost: (B) # gradients on cost
|
||||
|
@ -373,7 +406,8 @@ std::vector<at::Tensor> MatchCostBackward(
|
|||
const at::Tensor grad_cost,
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2,
|
||||
const at::Tensor match){
|
||||
const at::Tensor match)
|
||||
{
|
||||
const auto b = xyz1.size(0);
|
||||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
@ -387,10 +421,10 @@ std::vector<at::Tensor> MatchCostBackward(
|
|||
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
|
||||
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&]
|
||||
{
|
||||
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
||||
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
||||
}));
|
||||
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>()); }));
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return std::vector<at::Tensor>({grad1, grad2});
|
||||
|
|
Loading…
Reference in a new issue