🗑️ replace deprecated cuda instructions
This commit is contained in:
parent
517dd44d0e
commit
7e145fd312
|
@ -11,10 +11,9 @@
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
||||
#include <THC/THC.h>
|
||||
|
||||
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
|
||||
|
@ -175,9 +174,9 @@ at::Tensor ApproxMatchForward(
|
|||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
CHECK_EQ(xyz2.size(0), b);
|
||||
CHECK_EQ(xyz1.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(2), 3);
|
||||
TORCH_CHECK_EQ (xyz2.size(0), b);
|
||||
TORCH_CHECK_EQ (xyz1.size(2), 3);
|
||||
TORCH_CHECK_EQ (xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
|
@ -187,7 +186,7 @@ at::Tensor ApproxMatchForward(
|
|||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
||||
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
||||
}));
|
||||
THCudaCheck(cudaGetLastError());
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return match;
|
||||
}
|
||||
|
@ -262,9 +261,9 @@ at::Tensor MatchCostForward(
|
|||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
CHECK_EQ(xyz2.size(0), b);
|
||||
CHECK_EQ(xyz1.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(2), 3);
|
||||
TORCH_CHECK_EQ (xyz2.size(0), b);
|
||||
TORCH_CHECK_EQ (xyz1.size(2), 3);
|
||||
TORCH_CHECK_EQ (xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
|
@ -273,7 +272,7 @@ at::Tensor MatchCostForward(
|
|||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
||||
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
||||
}));
|
||||
THCudaCheck(cudaGetLastError());
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return cost;
|
||||
}
|
||||
|
@ -379,9 +378,9 @@ std::vector<at::Tensor> MatchCostBackward(
|
|||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
CHECK_EQ(xyz2.size(0), b);
|
||||
CHECK_EQ(xyz1.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(2), 3);
|
||||
TORCH_CHECK_EQ (xyz2.size(0), b);
|
||||
TORCH_CHECK_EQ (xyz1.size(2), 3);
|
||||
TORCH_CHECK_EQ (xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
|
@ -392,7 +391,7 @@ std::vector<at::Tensor> MatchCostBackward(
|
|||
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
||||
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
||||
}));
|
||||
THCudaCheck(cudaGetLastError());
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
return std::vector<at::Tensor>({grad1, grad2});
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue