From c0d5c6754ba99b360a2e88836e6b2f258d2a7713 Mon Sep 17 00:00:00 2001 From: Laurent FAINSIN Date: Fri, 7 Apr 2023 10:15:59 +0200 Subject: [PATCH] fix: apply same patches as PVD --- third_party/PyTorchEMD/cuda/emd_kernel.cu | 25 +++++++++++------------ 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/third_party/PyTorchEMD/cuda/emd_kernel.cu b/third_party/PyTorchEMD/cuda/emd_kernel.cu index 8da5d88..db10493 100644 --- a/third_party/PyTorchEMD/cuda/emd_kernel.cu +++ b/third_party/PyTorchEMD/cuda/emd_kernel.cu @@ -11,7 +11,6 @@ #include #include // at::cuda::getApplyGrid -#include #define CHECK_INPUT(x) @@ -173,9 +172,9 @@ at::Tensor ApproxMatchForward( const auto n = xyz1.size(1); const auto m = xyz2.size(1); - CHECK_EQ(xyz2.size(0), b); - CHECK_EQ(xyz1.size(2), 3); - CHECK_EQ(xyz2.size(2), 3); + TORCH_CHECK_EQ(xyz2.size(0), b); + TORCH_CHECK_EQ(xyz1.size(2), 3); + TORCH_CHECK_EQ(xyz2.size(2), 3); CHECK_INPUT(xyz1); CHECK_INPUT(xyz2); @@ -185,7 +184,7 @@ at::Tensor ApproxMatchForward( AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] { approxmatch<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), match.data(), temp.data()); })); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return match; } @@ -260,9 +259,9 @@ at::Tensor MatchCostForward( const auto n = xyz1.size(1); const auto m = xyz2.size(1); - CHECK_EQ(xyz2.size(0), b); - CHECK_EQ(xyz1.size(2), 3); - CHECK_EQ(xyz2.size(2), 3); + TORCH_CHECK_EQ(xyz2.size(0), b); + TORCH_CHECK_EQ(xyz1.size(2), 3); + TORCH_CHECK_EQ(xyz2.size(2), 3); CHECK_INPUT(xyz1); CHECK_INPUT(xyz2); @@ -271,7 +270,7 @@ at::Tensor MatchCostForward( AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] { matchcost<<<32,512>>>(b, n, m, xyz1.data(), xyz2.data(), match.data(), cost.data()); })); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return cost; } @@ -377,9 +376,9 @@ std::vector MatchCostBackward( const auto n = xyz1.size(1); const auto m = xyz2.size(1); - CHECK_EQ(xyz2.size(0), b); - CHECK_EQ(xyz1.size(2), 3); - CHECK_EQ(xyz2.size(2), 3); + TORCH_CHECK_EQ(xyz2.size(0), b); + TORCH_CHECK_EQ(xyz1.size(2), 3); + TORCH_CHECK_EQ(xyz2.size(2), 3); CHECK_INPUT(xyz1); CHECK_INPUT(xyz2); @@ -390,7 +389,7 @@ std::vector MatchCostBackward( matchcostgrad1<<<32,512>>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad1.data()); matchcostgrad2<<>>(b, n, m, grad_cost.data(), xyz1.data(), xyz2.data(), match.data(), grad2.data()); })); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return std::vector({grad1, grad2}); }