7 lines
469 B
C++
7 lines
469 B
C++
|
std::vector<at::Tensor> ApproxMatch(at::Tensor in_a, at::Tensor in_b);
|
||
|
at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
|
||
|
std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
|
||
|
|
||
|
std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q);
|
||
|
std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2);
|