From c4206265b38f0b37690f4ffccd8ab3557d3ae9e1 Mon Sep 17 00:00:00 2001
From: Kaichun MO
Date: Tue, 10 Sep 2019 01:05:04 -0700
Subject: [PATCH] first commit

---
 .gitignore         |   5 +
 README.md          |   2 +-
 __init__.py        |   0
 cuda/emd.cpp       |  29 ++++
 cuda/emd_kernel.cu | 400 +++++++++++++++++++++++++++++++++++++++++++++
 emd.py             |  46 ++++++
 setup.py           |  27 +++
 test_emd_loss.py   |  44 +++++
 8 files changed, 552 insertions(+), 1 deletion(-)
 create mode 100644 .gitignore
 create mode 100755 __init__.py
 create mode 100755 cuda/emd.cpp
 create mode 100644 cuda/emd_kernel.cu
 create mode 100755 emd.py
 create mode 100755 setup.py
 create mode 100644 test_emd_loss.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..8400d00
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+__pycache__
+build
+dist
+emd_ext.egg-info
+*.so
diff --git a/README.md b/README.md
index d736740..c29f579 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ Check `test_emd_loss.py` for example.
 
 ## Author
 
-The cuda code is originally written by Haoqiang Fan. The PyTorch version is modified by Kaichun Mo.
+The cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Jiayuan Gu also provided help.
 
 ## License
diff --git a/__init__.py b/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/cuda/emd.cpp b/cuda/emd.cpp
new file mode 100755
index 0000000..b94db14
--- /dev/null
+++ b/cuda/emd.cpp
@@ -0,0 +1,29 @@
+#ifndef _EMD
+#define _EMD
+
+#include <vector>
+#include <torch/extension.h>
+
+//CUDA declarations
+at::Tensor ApproxMatchForward(
+    const at::Tensor xyz1,
+    const at::Tensor xyz2);
+
+at::Tensor MatchCostForward(
+    const at::Tensor xyz1,
+    const at::Tensor xyz2,
+    const at::Tensor match);
+
+std::vector<at::Tensor> MatchCostBackward(
+    const at::Tensor grad_cost,
+    const at::Tensor xyz1,
+    const at::Tensor xyz2,
+    const at::Tensor match);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("approxmatch_forward", &ApproxMatchForward, "ApproxMatch forward (CUDA)");
+    m.def("matchcost_forward", &MatchCostForward, "MatchCost forward (CUDA)");
+    m.def("matchcost_backward", &MatchCostBackward, "MatchCost backward (CUDA)");
+}
+
+#endif
diff --git a/cuda/emd_kernel.cu b/cuda/emd_kernel.cu
new file mode 100644
index 0000000..f63c344
--- /dev/null
+++ b/cuda/emd_kernel.cu
@@ -0,0 +1,400 @@
+/**********************************
+ * Original Author: Haoqiang Fan
+ * Modified by: Kaichun Mo
+ *********************************/
+
+#ifndef _EMD_KERNEL
+#define _EMD_KERNEL
+
+#include <cmath>
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>  // at::cuda::getApplyGrid
+#include <THC/THC.h>
+
+#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+
+/********************************
+* Forward kernel for approxmatch
+*********************************/
+
+template <typename scalar_t>
+__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
+    scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
+    scalar_t multiL,multiR;
+    if (n>=m){
+        multiL=1;
+        multiR=n/m;
+    }else{
+        multiL=m/n;
+        multiR=1;
+    }
+    const int Block=1024;
+    __shared__ scalar_t buf[Block*4];
+    for (int i=blockIdx.x;i<b;i+=gridDim.x){
+        // reset the match matrix and the remaining mass on both sides
+        for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
+            match[i*n*m+j]=0;
+        for (int j=threadIdx.x;j<n;j+=blockDim.x)
+            remainL[j]=multiL;
+        for (int j=threadIdx.x;j<m;j+=blockDim.x)
+            remainR[j]=multiR;
+        __syncthreads();
+        // anneal from a sharp to a flat matching temperature
+        for (int j=7;j>=-2;j--){
+            scalar_t level=-powf(4.0f,j);
+            if (j==-2){
+                level=0;
+            }
+            // phase 1: each xyz1 point offers its remaining mass to xyz2 points
+            for (int k0=0;k0<n;k0+=blockDim.x){
+                int k=k0+threadIdx.x;
+                scalar_t x1=0,y1=0,z1=0;
+                if (k<n){
+                    x1=xyz1[i*n*3+k*3+0];
+                    y1=xyz1[i*n*3+k*3+1];
+                    z1=xyz1[i*n*3+k*3+2];
+                }
+                scalar_t suml=1e-9f;
+                for (int l0=0;l0<m;l0+=Block){
+                    int lend=min(m,l0+Block)-l0;
+                    for (int l=threadIdx.x;l<lend;l+=blockDim.x){
+                        buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
+                        buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
+                        buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
+                        buf[l*4+3]=remainR[l0+l];
+                    }
+                    __syncthreads();
+                    for (int l=0;l<lend;l++){
+                        scalar_t x2=buf[l*4+0];
+                        scalar_t y2=buf[l*4+1];
+                        scalar_t z2=buf[l*4+2];
+                        scalar_t d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
+                        suml+=__expf(d)*buf[l*4+3];
+                    }
+                    __syncthreads();
+                }
+                if (k<n)
+                    ratioL[k]=remainL[k]/suml;
+            }
+            __syncthreads();
+            // phase 2: each xyz2 point accepts at most its remaining mass
+            for (int l0=0;l0<m;l0+=blockDim.x){
+                int l=l0+threadIdx.x;
+                scalar_t x2=0,y2=0,z2=0;
+                if (l<m){
+                    x2=xyz2[i*m*3+l*3+0];
+                    y2=xyz2[i*m*3+l*3+1];
+                    z2=xyz2[i*m*3+l*3+2];
+                }
+                scalar_t sumr=0;
+                for (int k0=0;k0<n;k0+=Block){
+                    int kend=min(n,k0+Block)-k0;
+                    for (int k=threadIdx.x;k<kend;k+=blockDim.x){
+                        buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
+                        buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
+                        buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
+                        buf[k*4+3]=ratioL[k0+k];
+                    }
+                    __syncthreads();
+                    for (int k=0;k<kend;k++){
+                        scalar_t x1=buf[k*4+0];
+                        scalar_t y1=buf[k*4+1];
+                        scalar_t z1=buf[k*4+2];
+                        sumr+=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
+                    }
+                    __syncthreads();
+                }
+                if (l<m){
+                    sumr*=remainR[l];
+                    scalar_t consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
+                    ratioR[l]=consumption*remainR[l];
+                    remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
+                }
+            }
+            __syncthreads();
+            // phase 3: commit the accepted mass to match and update the xyz1 side
+            for (int k0=0;k0<n;k0+=blockDim.x){
+                int k=k0+threadIdx.x;
+                scalar_t x1=0,y1=0,z1=0;
+                if (k<n){
+                    x1=xyz1[i*n*3+k*3+0];
+                    y1=xyz1[i*n*3+k*3+1];
+                    z1=xyz1[i*n*3+k*3+2];
+                }
+                scalar_t rl=0;
+                if (k<n)
+                    rl=ratioL[k];
+                scalar_t suml=0;
+                for (int l0=0;l0<m;l0+=Block){
+                    int lend=min(m,l0+Block)-l0;
+                    for (int l=threadIdx.x;l<lend;l+=blockDim.x){
+                        buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
+                        buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
+                        buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
+                        buf[l*4+3]=ratioR[l0+l];
+                    }
+                    __syncthreads();
+                    if (k<n){
+                        for (int l=0;l<lend;l++){
+                            scalar_t x2=buf[l*4+0];
+                            scalar_t y2=buf[l*4+1];
+                            scalar_t z2=buf[l*4+2];
+                            scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
+                            match[i*n*m+(l0+l)*n+k]+=w;
+                            suml+=w;
+                        }
+                    }
+                    __syncthreads();
+                }
+                if (k<n)
+                    remainL[k]=fmaxf(0.0f,remainL[k]-suml);
+            }
+            __syncthreads();
+        }
+    }
+}
+
+//void approxmatchLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp){
+//    approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
+//}
+
+/* ApproxMatch forward interface
+Input:
+  xyz1: (B, N1, 3)  # dataset_points
+  xyz2: (B, N2, 3)  # query_points
+Output:
+  match: (B, N2, N1)
+*/
+at::Tensor ApproxMatchForward(
+    const at::Tensor xyz1,
+    const at::Tensor xyz2){
+    const auto b = xyz1.size(0);
+    const auto n = xyz1.size(1);
+    const auto m = xyz2.size(1);
+
+    CHECK_EQ(xyz2.size(0), b);
+    CHECK_EQ(xyz1.size(2), 3);
+    CHECK_EQ(xyz2.size(2), 3);
+    CHECK_INPUT(xyz1);
+    CHECK_INPUT(xyz2);
+
+    auto match = at::zeros({b, m, n}, xyz1.type());
+    auto temp = at::zeros({b, (n+m)*2}, xyz1.type());
+
+    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
+        approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
+    }));
+    THCudaCheck(cudaGetLastError());
+
+    return match;
+}
+
+
+/********************************
+* Forward kernel for matchcost
+*********************************/
+
+template <typename scalar_t>
+__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
+    __shared__ scalar_t allsum[512];
+    const int Block=1024;
+    __shared__ scalar_t buf[Block*3];
+    for (int i=blockIdx.x;i<b;i+=gridDim.x){
+        scalar_t subsum=0;
+        for (int k0=0;k0<n;k0+=blockDim.x){
+            int k=k0+threadIdx.x;
+            scalar_t x1=0,y1=0,z1=0;
+            if (k<n){
+                x1=xyz1[i*n*3+k*3+0];
+                y1=xyz1[i*n*3+k*3+1];
+                z1=xyz1[i*n*3+k*3+2];
+            }
+            for (int l0=0;l0<m;l0+=Block){
+                int lend=min(m,l0+Block)-l0;
+                for (int l=threadIdx.x;l<lend*3;l+=blockDim.x)
+                    buf[l]=xyz2[i*m*3+l0*3+l];
+                __syncthreads();
+                if (k<n){
+                    for (int l=0;l<lend;l++){
+                        scalar_t x2=buf[l*3+0];
+                        scalar_t y2=buf[l*3+1];
+                        scalar_t z2=buf[l*3+2];
+                        // squared distance weighted by the match weight
+                        scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
+                        subsum+=d*match[i*n*m+(l0+l)*n+k];
+                    }
+                }
+                __syncthreads();
+            }
+        }
+        // block-wide tree reduction of the per-thread partial sums
+        allsum[threadIdx.x]=subsum;
+        for (int j=1;j<blockDim.x;j<<=1){
+            __syncthreads();
+            if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x)
+                allsum[threadIdx.x]+=allsum[threadIdx.x+j];
+        }
+        if (threadIdx.x==0)
+            out[i]=allsum[0];
+        __syncthreads();
+    }
+}
+
+//void matchcostLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * out){
+//    matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
+//}
+
+/* MatchCost forward interface
+Input:
+  xyz1: (B, N1, 3)  # dataset_points
+  xyz2: (B, N2, 3)  # query_points
+  match: (B, N2, N1)
+Output:
+  cost: (B)
+*/
+at::Tensor MatchCostForward(
+    const at::Tensor xyz1,
+    const at::Tensor xyz2,
+    const at::Tensor match){
+    const auto b = xyz1.size(0);
+    const auto n = xyz1.size(1);
+    const auto m = xyz2.size(1);
+
+    CHECK_EQ(xyz2.size(0), b);
+    CHECK_EQ(xyz1.size(2), 3);
+    CHECK_EQ(xyz2.size(2), 3);
+    CHECK_INPUT(xyz1);
+    CHECK_INPUT(xyz2);
+
+    auto cost = at::zeros({b}, xyz1.type());
+
+    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
+        matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
+    }));
+    THCudaCheck(cudaGetLastError());
+
+    return cost;
+}
+
+
+/********************************
+* matchcostgrad2 kernel
+*********************************/
+
+template <typename scalar_t>
+__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
+    __shared__ scalar_t sum_grad[256*3];
+    for (int i=blockIdx.x;i<b;i+=gridDim.x){
+        // each blockIdx.y strip handles a contiguous range of xyz2 points
+        int kbeg=m*blockIdx.y/gridDim.y;
+        int kend=m*(blockIdx.y+1)/gridDim.y;
+        for (int k=kbeg;k<kend;k++){
+            scalar_t x2=xyz2[(i*m+k)*3+0];
+            scalar_t y2=xyz2[(i*m+k)*3+1];
+            scalar_t z2=xyz2[(i*m+k)*3+2];
+            scalar_t subsumx=0,subsumy=0,subsumz=0;
+            for (int j=threadIdx.x;j<n;j+=blockDim.x){
+                scalar_t x1=x2-xyz1[(i*n+j)*3+0];
+                scalar_t y1=y2-xyz1[(i*n+j)*3+1];
+                scalar_t z1=z2-xyz1[(i*n+j)*3+2];
+                scalar_t d=match[i*n*m+k*n+j]*2;
+                subsumx+=x1*d;
+                subsumy+=y1*d;
+                subsumz+=z1*d;
+            }
+            sum_grad[threadIdx.x*3+0]=subsumx;
+            sum_grad[threadIdx.x*3+1]=subsumy;
+            sum_grad[threadIdx.x*3+2]=subsumz;
+            for (int j=1;j<blockDim.x;j<<=1){
+                __syncthreads();
+                int j1=threadIdx.x;
+                int j2=threadIdx.x+j;
+                if ((j1&j)==0 && j2<blockDim.x){
+                    sum_grad[j1*3+0]+=sum_grad[j2*3+0];
+                    sum_grad[j1*3+1]+=sum_grad[j2*3+1];
+                    sum_grad[j1*3+2]+=sum_grad[j2*3+2];
+                }
+            }
+            if (threadIdx.x==0){
+                grad2[(i*m+k)*3+0]=sum_grad[0]*grad_cost[i];
+                grad2[(i*m+k)*3+1]=sum_grad[1]*grad_cost[i];
+                grad2[(i*m+k)*3+2]=sum_grad[2]*grad_cost[i];
+            }
+            __syncthreads();
+        }
+    }
+}
+
+
+/********************************
+* matchcostgrad1 kernel
+*********************************/
+
+template <typename scalar_t>
+__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
+    for (int i=blockIdx.x;i<b;i+=gridDim.x){
+        for (int l=threadIdx.x;l<n;l+=blockDim.x){
+            scalar_t x1=xyz1[i*n*3+l*3+0];
+            scalar_t y1=xyz1[i*n*3+l*3+1];
+            scalar_t z1=xyz1[i*n*3+l*3+2];
+            scalar_t dx=0,dy=0,dz=0;
+            for (int k=0;k<m;k++){
+                scalar_t x2=xyz2[i*m*3+k*3+0];
+                scalar_t y2=xyz2[i*m*3+k*3+1];
+                scalar_t z2=xyz2[i*m*3+k*3+2];
+                scalar_t d=match[i*n*m+k*n+l]*2;
+                dx+=(x1-x2)*d;
+                dy+=(y1-y2)*d;
+                dz+=(z1-z2)*d;
+            }
+            grad1[i*n*3+l*3+0]=dx*grad_cost[i];
+            grad1[i*n*3+l*3+1]=dy*grad_cost[i];
+            grad1[i*n*3+l*3+2]=dz*grad_cost[i];
+        }
+    }
+}
+
+//void matchcostgradLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2){
+//    matchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);
+//    matchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
+//}
+
+
+/* MatchCost backward interface
+Input:
+  grad_cost: (B)  # gradients on cost
+  xyz1: (B, N1, 3)  # dataset_points
+  xyz2: (B, N2, 3)  # query_points
+  match: (B, N2, N1)
+Output:
+  grad1: (B, N1, 3)
+  grad2: (B, N2, 3)
+*/
+std::vector<at::Tensor> MatchCostBackward(
+    const at::Tensor grad_cost,
+    const at::Tensor xyz1,
+    const at::Tensor xyz2,
+    const at::Tensor match){
+    const auto b = xyz1.size(0);
+    const auto n = xyz1.size(1);
+    const auto m = xyz2.size(1);
+
+    CHECK_EQ(xyz2.size(0), b);
+    CHECK_EQ(xyz1.size(2), 3);
+    CHECK_EQ(xyz2.size(2), 3);
+    CHECK_INPUT(xyz1);
+    CHECK_INPUT(xyz2);
+
+    auto grad1 = at::zeros({b, n, 3}, xyz1.type());
+    auto grad2 = at::zeros({b, m, 3}, xyz1.type());
+
+    AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
+        matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
+        matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
+    }));
+    THCudaCheck(cudaGetLastError());
+
+    return std::vector<at::Tensor>({grad1, grad2});
+}
+
+#endif
diff --git a/emd.py b/emd.py
new file mode 100755
index 0000000..dc62714
--- /dev/null
+++ b/emd.py
@@ -0,0 +1,46 @@
+import torch
+import emd_cuda
+
+
+class EarthMoverDistanceFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, xyz1, xyz2):
+        xyz1 = xyz1.contiguous()
+        xyz2 = xyz2.contiguous()
+        assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
+        match = emd_cuda.approxmatch_forward(xyz1, xyz2)
+        cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
+        ctx.save_for_backward(xyz1, xyz2, match)
+        return cost
+
+    @staticmethod
+    def backward(ctx, grad_cost):
+        xyz1, xyz2, match = ctx.saved_tensors
+        grad_cost = grad_cost.contiguous()
+        grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
+        return grad_xyz1, grad_xyz2
+
+
+def earth_mover_distance(xyz1, xyz2, transpose=True):
+    """Earth Mover Distance (Approx)
+
+    Args:
+        xyz1 (torch.Tensor): (b, 3, n1)
+        xyz2 (torch.Tensor): (b, 3, n2)
+        transpose (bool): whether to transpose the inputs, which may be in
+            BCN format; the extension itself only supports BNC format.
+
+    Returns:
+        cost (torch.Tensor): (b)
+
+    """
+    if xyz1.dim() == 2:
+        xyz1 = xyz1.unsqueeze(0)
+    if xyz2.dim() == 2:
+        xyz2 = xyz2.unsqueeze(0)
+    if transpose:
+        xyz1 = xyz1.transpose(1, 2)
+        xyz2 = xyz2.transpose(1, 2)
+    cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
+    return cost
diff --git a/setup.py b/setup.py
new file mode 100755
index 0000000..f648c3e
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,27 @@
+"""Setup extension
+
+Notes:
+    If extra_compile_args is provided, you need to provide different instances for different extensions.
+    Refer to https://github.com/pytorch/pytorch/issues/20169
+
+"""
+
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+setup(
+    name='emd_ext',
+    ext_modules=[
+        CUDAExtension(
+            name='emd_cuda',
+            sources=[
+                'cuda/emd.cpp',
+                'cuda/emd_kernel.cu',
+            ],
+            extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
+        ),
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/test_emd_loss.py b/test_emd_loss.py
new file mode 100644
index 0000000..66aa33c
--- /dev/null
+++ b/test_emd_loss.py
@@ -0,0 +1,44 @@
+import torch
+import numpy as np
+import time
+from emd import earth_mover_distance
+
+# gt
+p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
+p1 = p1.repeat(3, 1, 1)
+p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
+p2 = p2.repeat(3, 1, 1)
+print(p1)
+print(p2)
+p1.requires_grad = True
+p2.requires_grad = True
+
+gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
+          (((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
+          (((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
+print('gt_dist: ', gt_dist)
+
+gt_dist.backward()
+print(p1.grad)
+print(p2.grad)
+
+# emd
+p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
+p1 = p1.repeat(3, 1, 1)
+p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
+p2 = p2.repeat(3, 1, 1)
+print(p1)
+print(p2)
+p1.requires_grad = True
+p2.requires_grad = True
+
+d = earth_mover_distance(p1, p2, transpose=False)
+print(d)
+
+loss = d[0] / 2 + d[1] * 2 + d[2] / 3
+print(loss)
+
+loss.backward()
+print(p1.grad)
+print(p2.grad)
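
Usage note: below is a minimal usage sketch of the `earth_mover_distance` wrapper added by this patch. It assumes the extension has been built and installed first (e.g. `python setup.py install`, so that `emd_cuda` is importable); the batch size and point counts are arbitrary illustration values, not values from the patch.

```python
import torch
from emd import earth_mover_distance

# two batches of point clouds in BNC layout: (batch, n_points, 3)
p1 = torch.rand(8, 1024, 3, device='cuda', requires_grad=True)
p2 = torch.rand(8, 1024, 3, device='cuda')

# inputs are already (B, N, 3), so the BCN -> BNC transpose is skipped
d = earth_mover_distance(p1, p2, transpose=False)  # per-sample costs, shape (8,)

# reduce to a scalar loss and backpropagate through the CUDA extension
loss = d.mean()
loss.backward()
print(loss.item(), p1.grad.shape)
```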
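Sanity-check note: the CUDA kernel computes an approximate matching, so a natural cross-check (complementing `test_emd_loss.py`) is the exact assignment cost on a tiny input. The sketch below is an illustration, not part of the package: it assumes `scipy` is available and uses the fact that, for two equal-sized point sets with uniform weights, the cost defined here reduces to a minimum-cost one-to-one matching over squared distances. Expect the approximate and exact values to be close, not necessarily identical.

```python
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from emd import earth_mover_distance

def exact_match_cost(a, b):
    # exact counterpart of matchcost: sum of squared distances
    # under the optimal one-to-one assignment (equal-sized sets)
    cost = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    row, col = linear_sum_assignment(cost)
    return cost[row, col].sum()

a = np.random.rand(16, 3).astype(np.float32)
b = np.random.rand(16, 3).astype(np.float32)

approx = earth_mover_distance(
    torch.from_numpy(a).cuda(), torch.from_numpy(b).cuda(), transpose=False)
print('approx:', approx.item(), 'exact:', exact_match_cost(a, b))
```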