diff --git a/cuda/emd_kernel.cu b/cuda/emd_kernel.cu index f63c344..de4db29 100644 --- a/cuda/emd_kernel.cu +++ b/cuda/emd_kernel.cu @@ -11,10 +11,9 @@ #include #include // at::cuda::getApplyGrid -#include -#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) @@ -175,9 +174,9 @@ at::Tensor ApproxMatchForward( const auto n = xyz1.size(1); const auto m = xyz2.size(1); - CHECK_EQ(xyz2.size(0), b); - CHECK_EQ(xyz1.size(2), 3); - CHECK_EQ(xyz2.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(0), b); + TORCH_CHECK_EQ (xyz1.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(2), 3); CHECK_INPUT(xyz1); CHECK_INPUT(xyz2); @@ -187,7 +186,7 @@ at::Tensor ApproxMatchForward( AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] { approxmatch<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), match.data(), temp.data()); })); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return match; } @@ -262,9 +261,9 @@ at::Tensor MatchCostForward( const auto n = xyz1.size(1); const auto m = xyz2.size(1); - CHECK_EQ(xyz2.size(0), b); - CHECK_EQ(xyz1.size(2), 3); - CHECK_EQ(xyz2.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(0), b); + TORCH_CHECK_EQ (xyz1.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(2), 3); CHECK_INPUT(xyz1); CHECK_INPUT(xyz2); @@ -273,7 +272,7 @@ at::Tensor MatchCostForward( AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] { matchcost<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), match.data(), cost.data()); })); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return cost; } @@ -379,9 +378,9 @@ std::vector MatchCostBackward( const auto n = xyz1.size(1); const auto m = xyz2.size(1); - CHECK_EQ(xyz2.size(0), b); - CHECK_EQ(xyz1.size(2), 3); - CHECK_EQ(xyz2.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(0), b); + TORCH_CHECK_EQ (xyz1.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(2), 3); CHECK_INPUT(xyz1); CHECK_INPUT(xyz2); @@ -392,7 +391,7 @@ std::vector MatchCostBackward( matchcostgrad1<<<32,512>>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad1.data()); matchcostgrad2<<>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad2.data()); })); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return std::vector({grad1, grad2}); }