From 7e145fd312b211e26eafffb7e69b9d2e6d043710 Mon Sep 17 00:00:00 2001 From: Laurent FAINSIN Date: Mon, 24 Apr 2023 11:39:02 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=97=91=EF=B8=8F=20replace=20deprecated=20?= =?UTF-8?q?cuda=20instructions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cuda/emd_kernel.cu | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/cuda/emd_kernel.cu b/cuda/emd_kernel.cu index f63c344..de4db29 100644 --- a/cuda/emd_kernel.cu +++ b/cuda/emd_kernel.cu @@ -11,10 +11,9 @@ #include #include // at::cuda::getApplyGrid -#include -#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) @@ -175,9 +174,9 @@ at::Tensor ApproxMatchForward( const auto n = xyz1.size(1); const auto m = xyz2.size(1); - CHECK_EQ(xyz2.size(0), b); - CHECK_EQ(xyz1.size(2), 3); - CHECK_EQ(xyz2.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(0), b); + TORCH_CHECK_EQ (xyz1.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(2), 3); CHECK_INPUT(xyz1); CHECK_INPUT(xyz2); @@ -187,7 +186,7 @@ at::Tensor ApproxMatchForward( AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] { approxmatch<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), match.data(), temp.data()); })); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return match; } @@ -262,9 +261,9 @@ at::Tensor MatchCostForward( const auto n = xyz1.size(1); const auto m = xyz2.size(1); - CHECK_EQ(xyz2.size(0), b); - CHECK_EQ(xyz1.size(2), 3); - CHECK_EQ(xyz2.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(0), b); + TORCH_CHECK_EQ (xyz1.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(2), 3); CHECK_INPUT(xyz1); CHECK_INPUT(xyz2); @@ -273,7 +272,7 @@ at::Tensor MatchCostForward( AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] { matchcost<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), match.data(), cost.data()); })); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return cost; } @@ -379,9 +378,9 @@ std::vector MatchCostBackward( const auto n = xyz1.size(1); const auto m = xyz2.size(1); - CHECK_EQ(xyz2.size(0), b); - CHECK_EQ(xyz1.size(2), 3); - CHECK_EQ(xyz2.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(0), b); + TORCH_CHECK_EQ (xyz1.size(2), 3); + TORCH_CHECK_EQ (xyz2.size(2), 3); CHECK_INPUT(xyz1); CHECK_INPUT(xyz2); @@ -392,7 +391,7 @@ std::vector MatchCostBackward( matchcostgrad1<<<32,512>>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad1.data()); matchcostgrad2<<>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad2.data()); })); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return std::vector({grad1, grad2}); }