PointFlow/metrics/pytorch_structural_losses/pybind/extern.hpp

7 lines
469 B
C++
Raw Normal View History

2019-07-14 04:32:26 +00:00
/// Forward declaration of ApproxMatch; no body is given at this point.
///
/// @param in_a  First input tensor, passed by value.
/// @param in_b  Second input tensor, passed by value.
/// @return A vector of tensors.
///
/// NOTE(review): the name suggests an approximate matching between two point
/// sets, but the shapes/dtypes of the inputs and the number and meaning of the
/// returned tensors cannot be determined from this declaration -- confirm
/// against the implementation.
std::vector<at::Tensor> ApproxMatch(at::Tensor in_a, at::Tensor in_b);
/// Forward declaration of MatchCost; no body is given at this point.
///
/// @param set_d  First input tensor, passed by value.
/// @param set_q  Second input tensor, passed by value.
/// @param match  Third input tensor, passed by value.
/// @return A single tensor.
///
/// NOTE(review): presumably `match` is a matching between `set_d` and `set_q`
/// (possibly one produced by ApproxMatch) and the result is the cost of that
/// matching -- the expected shapes and the exact cost definition are not
/// visible here; confirm against the implementation.
at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
/// Forward declaration of MatchCostGrad; no body is given at this point.
/// Takes the same three parameters as MatchCost but returns a vector of
/// tensors instead of a single tensor.
///
/// @param set_d  First input tensor, passed by value.
/// @param set_q  Second input tensor, passed by value.
/// @param match  Third input tensor, passed by value.
/// @return A vector of tensors.
///
/// NOTE(review): the name suggests this is the backward pass of MatchCost and
/// that the returned tensors are gradients with respect to the inputs -- how
/// many tensors are returned and in which order is not visible here; confirm
/// against the implementation.
std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
/// Forward declaration of NNDistance; no body is given at this point.
///
/// @param set_d  First input tensor, passed by value.
/// @param set_q  Second input tensor, passed by value.
/// @return A vector of tensors.
///
/// NOTE(review): the name suggests nearest-neighbour distances between two
/// point sets; NNDistanceGrad accepts `idx1`/`idx2` arguments, so the result
/// presumably includes index tensors alongside distances -- the actual
/// contents and order of the returned vector are not visible here; confirm
/// against the implementation.
std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q);
/// Forward declaration of NNDistanceGrad; no body is given at this point.
///
/// @param set_d       First input tensor, passed by value.
/// @param set_q       Second input tensor, passed by value.
/// @param idx1        Tensor passed by value.
/// @param idx2        Tensor passed by value.
/// @param grad_dist1  Tensor passed by value.
/// @param grad_dist2  Tensor passed by value.
/// @return A vector of tensors.
///
/// NOTE(review): the name suggests this is the backward pass of NNDistance,
/// with `idx1`/`idx2` being nearest-neighbour indices from the forward pass
/// and `grad_dist1`/`grad_dist2` the incoming gradients for the two distance
/// outputs -- none of this is established by the declaration itself; confirm
/// against the implementation.
std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2);