"""2D Chamfer distance backed by the ``chamfer_2D`` CUDA extension.

At import time the module looks for an installed ``chamfer_2D`` extension and,
when none is found, JIT-compiles it from the C++/CUDA sources that sit next to
this file. GPU tensors only.
"""
from torch import nn
from torch.autograd import Function
import torch
import importlib
import importlib.util
import os

# ``importlib.find_loader`` was removed in Python 3.12; ``find_spec`` performs
# the same "is this module importable?" lookup without importing it.
chamfer_found = importlib.util.find_spec("chamfer_2D") is not None
if not chamfer_found:
    # JIT fallback. The original comment here read "Cool trick from" with the
    # attribution missing.
    print("Jitting Chamfer 2D")
    cur_path = os.path.dirname(os.path.abspath(__file__))
    build_path = cur_path.replace('chamfer2D', 'tmp')
    os.makedirs(build_path, exist_ok=True)

    from torch.utils.cpp_extension import load

    chamfer_2D = load(
        name="chamfer_2D",
        sources=[
            os.path.join(cur_path, "chamfer_cuda.cpp"),
            # NOTE(review): the kernel filename was empty in the source, which
            # would have passed the directory itself to the compiler.
            # "chamfer2D.cu" is inferred from the package directory name --
            # confirm it matches the file on disk.
            os.path.join(cur_path, "chamfer2D.cu"),
        ],
        build_directory=build_path,
    )
    print("Loaded JIT 2D CUDA chamfer distance")
else:
    import chamfer_2D
    print("Loaded compiled 2D CUDA chamfer distance")


# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_2DFunction(Function):
    """Autograd function wrapping ``chamfer_2D.forward`` / ``chamfer_2D.backward``."""

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        """Run the extension's forward kernel on two batches of 2D points.

        Args:
            ctx: autograd context; the inputs and index buffers are saved on it
                for ``backward``.
            xyz1: tensor of size ``(batchsize, n, 2)``.
            xyz2: tensor of size ``(_, m, 2)``; its batch size is not checked
                here.

        Returns:
            ``(dist1, dist2, idx1, idx2)`` with sizes ``(batchsize, n)``,
            ``(batchsize, m)``, ``(batchsize, n)`` and ``(batchsize, m)``.
            ``dist*`` use the default float dtype and ``idx*`` are int32.
            Presumably these are nearest-neighbour distances and indices in
            each direction -- the semantics are defined by the extension.

        Raises:
            AssertionError: if the last dimension of either input is not 2.
        """
        batchsize, n, dim = xyz1.size()
        assert dim == 2, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
        _, m, dim = xyz2.size()
        assert dim == 2, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
        device = xyz1.device

        # Allocate the output buffers directly on the input's device. The
        # previous code rebound all four names to the (None) return value of
        # ``torch.cuda.set_device`` and handed None to the kernel.
        dist1 = torch.zeros(batchsize, n, device=device)
        dist2 = torch.zeros(batchsize, m, device=device)
        idx1 = torch.zeros(batchsize, n, dtype=torch.int32, device=device)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int32, device=device)

        # Make the input's GPU current before calling into the extension.
        torch.cuda.set_device(device)

        # The kernel writes its results into the preallocated buffers.
        chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
        """Compute gradients with respect to ``xyz1`` and ``xyz2``.

        Args:
            ctx: autograd context populated by ``forward``.
            graddist1: gradient flowing into ``dist1``.
            graddist2: gradient flowing into ``dist2``.
            gradidx1: unused; present because ``forward`` returns ``idx1``.
            gradidx2: unused; present because ``forward`` returns ``idx2``.

        Returns:
            ``(gradxyz1, gradxyz2)``, sized like ``xyz1`` and ``xyz2``.
        """
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()
        device = graddist1.device

        # Gradient buffers live on the same device as the incoming gradients
        # and are filled in place by the kernel. The previous code left them
        # on the CPU and then overwrote both names with the kernel's return
        # value instead of returning the filled buffers.
        gradxyz1 = torch.zeros(xyz1.size(), device=device)
        gradxyz2 = torch.zeros(xyz2.size(), device=device)

        chamfer_2D.backward(
            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
        )
        return gradxyz1, gradxyz2


class chamfer_2DDist(nn.Module):
    """``nn.Module`` front end for :class:`chamfer_2DFunction`."""

    def __init__(self):
        super(chamfer_2DDist, self).__init__()

    def forward(self, input1, input2):
        """Apply ``chamfer_2DFunction`` to contiguous copies of the inputs.

        Args:
            input1: tensor of size ``(batchsize, n, 2)``.
            input2: tensor of size ``(_, m, 2)``.

        Returns:
            The ``(dist1, dist2, idx1, idx2)`` tuple produced by
            ``chamfer_2DFunction.apply``.
        """
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        return chamfer_2DFunction.apply(input1, input2)