fix: apply same patches as PVD
This commit is contained in:
parent
231ff196c7
commit
c0d5c6754b
25
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
25
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
|
@ -11,7 +11,6 @@
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
||||||
#include <THC/THC.h>
|
|
||||||
|
|
||||||
#define CHECK_INPUT(x)
|
#define CHECK_INPUT(x)
|
||||||
|
|
||||||
|
@ -173,9 +172,9 @@ at::Tensor ApproxMatchForward(
|
||||||
const auto n = xyz1.size(1);
|
const auto n = xyz1.size(1);
|
||||||
const auto m = xyz2.size(1);
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
CHECK_EQ(xyz2.size(0), b);
|
TORCH_CHECK_EQ(xyz2.size(0), b);
|
||||||
CHECK_EQ(xyz1.size(2), 3);
|
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
||||||
CHECK_EQ(xyz2.size(2), 3);
|
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
||||||
CHECK_INPUT(xyz1);
|
CHECK_INPUT(xyz1);
|
||||||
CHECK_INPUT(xyz2);
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
@ -185,7 +184,7 @@ at::Tensor ApproxMatchForward(
|
||||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
||||||
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
||||||
}));
|
}));
|
||||||
THCudaCheck(cudaGetLastError());
|
C10_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
return match;
|
return match;
|
||||||
}
|
}
|
||||||
|
@ -260,9 +259,9 @@ at::Tensor MatchCostForward(
|
||||||
const auto n = xyz1.size(1);
|
const auto n = xyz1.size(1);
|
||||||
const auto m = xyz2.size(1);
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
CHECK_EQ(xyz2.size(0), b);
|
TORCH_CHECK_EQ(xyz2.size(0), b);
|
||||||
CHECK_EQ(xyz1.size(2), 3);
|
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
||||||
CHECK_EQ(xyz2.size(2), 3);
|
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
||||||
CHECK_INPUT(xyz1);
|
CHECK_INPUT(xyz1);
|
||||||
CHECK_INPUT(xyz2);
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
@ -271,7 +270,7 @@ at::Tensor MatchCostForward(
|
||||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
||||||
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
||||||
}));
|
}));
|
||||||
THCudaCheck(cudaGetLastError());
|
C10_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
return cost;
|
return cost;
|
||||||
}
|
}
|
||||||
|
@ -377,9 +376,9 @@ std::vector<at::Tensor> MatchCostBackward(
|
||||||
const auto n = xyz1.size(1);
|
const auto n = xyz1.size(1);
|
||||||
const auto m = xyz2.size(1);
|
const auto m = xyz2.size(1);
|
||||||
|
|
||||||
CHECK_EQ(xyz2.size(0), b);
|
TORCH_CHECK_EQ(xyz2.size(0), b);
|
||||||
CHECK_EQ(xyz1.size(2), 3);
|
TORCH_CHECK_EQ(xyz1.size(2), 3);
|
||||||
CHECK_EQ(xyz2.size(2), 3);
|
TORCH_CHECK_EQ(xyz2.size(2), 3);
|
||||||
CHECK_INPUT(xyz1);
|
CHECK_INPUT(xyz1);
|
||||||
CHECK_INPUT(xyz2);
|
CHECK_INPUT(xyz2);
|
||||||
|
|
||||||
|
@ -390,7 +389,7 @@ std::vector<at::Tensor> MatchCostBackward(
|
||||||
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
||||||
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
||||||
}));
|
}));
|
||||||
THCudaCheck(cudaGetLastError());
|
C10_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
return std::vector<at::Tensor>({grad1, grad2});
|
return std::vector<at::Tensor>({grad1, grad2});
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue