"""Chamfer distance for 3D point clouds backed by a JIT-compiled CUDA extension.

Importing this module compiles (or re-uses a cached build of) the
``chamfer_3D`` extension from ``chamfer_cuda.cpp`` and ``chamfer3D.cu`` located
next to this file.  GPU tensors only.
"""
import importlib
import os

import torch
from torch import nn
from torch.autograd import Function
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
from torch.utils.cpp_extension import load

# Directory holding this file and the extension sources.
cur_path = os.path.dirname(os.path.abspath(__file__))
# Build artefacts go to a sibling path obtained by swapping 'chamfer3D' for
# 'tmp' in this file's directory path.
build_path = cur_path.replace('chamfer3D', 'tmp')
os.makedirs(build_path, exist_ok=True)

# os.path.join keeps the source paths valid regardless of the platform's path
# separator (the previous "/".join(...split('/')...) form only worked on POSIX).
chamfer_3D = load(
    name="chamfer_3D",
    sources=[
        os.path.join(cur_path, "chamfer_cuda.cpp"),
        os.path.join(cur_path, "chamfer3D.cu"),
    ],
    build_directory=build_path,
)

# Alternative loading strategy, kept for reference (disabled): use an already
# installed ``chamfer_3D`` package when available and only JIT-compile otherwise.
#
# chamfer_found = importlib.find_loader("chamfer_3D") is not None
# if not chamfer_found:
#     ## Cool trick from https://github.com/chrdiller
#     print("Jitting Chamfer 3D")
#     cur_path = os.path.dirname(os.path.abspath(__file__))
#     build_path = cur_path.replace('chamfer3D', 'tmp')
#     os.makedirs(build_path, exist_ok=True)
#
#     from torch.utils.cpp_extension import load
#     chamfer_3D = load(name="chamfer_3D",
#           sources=[
#               "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
#               "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
#               ], build_directory=build_path)
#     print("Loaded JIT 3D CUDA chamfer distance")
#
# else:
#     import chamfer_3D
#     print("Loaded compiled 3D CUDA chamfer distance")


def _chamfer_forward(xyz1, xyz2):
    """Allocate the output buffers and run the CUDA forward kernel.

    Shared by the differentiable and the no-grad autograd functions so both
    stay in sync.

    Args:
        xyz1: tensor whose ``size()`` unpacks to ``(batchsize, n, 3)``.
        xyz2: tensor whose ``size()`` unpacks to ``(_, m, 3)``.

    Returns:
        ``(dist1, dist2, idx1, idx2)`` as filled in by ``chamfer_3D.forward``:
        ``dist1``/``idx1`` have shape ``(batchsize, n)`` and ``dist2``/``idx2``
        have shape ``(batchsize, m)``; the ``idx*`` tensors are int32.
        Presumably per-point nearest-neighbour distances and indices -- confirm
        against the CUDA kernel.

    Raises:
        AssertionError: if the last dimension of either input is not 3.
    """
    batchsize, n, dim = xyz1.size()
    assert dim == 3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
    _, m, dim = xyz2.size()
    assert dim == 3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
    # NOTE(review): the batch size of xyz2 is not compared with xyz1's here --
    # confirm what the kernel expects.

    device = xyz1.device
    # Allocate directly on the target device instead of building CPU tensors
    # and copying them over; dtypes match the previous behaviour (default
    # float dtype for distances, int32 for indices).
    dist1 = torch.zeros(batchsize, n, device=device)
    dist2 = torch.zeros(batchsize, m, device=device)
    idx1 = torch.zeros(batchsize, n, dtype=torch.int32, device=device)
    idx2 = torch.zeros(batchsize, m, dtype=torch.int32, device=device)

    # Make the inputs' device the current CUDA device before calling the kernel.
    torch.cuda.set_device(device)
    chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
    return dist1, dist2, idx1, idx2


# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_3DFunction(Function):
    """Autograd function wrapping the ``chamfer_3D`` forward/backward kernels."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, xyz1, xyz2):
        """Run the forward kernel and save what ``backward`` needs.

        Under autocast the inputs are cast to float32 (see ``custom_fwd``).

        Returns:
            ``(dist1, dist2, idx1, idx2)``; see ``_chamfer_forward``.
        """
        dist1, dist2, idx1, idx2 = _chamfer_forward(xyz1, xyz2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    @custom_bwd
    def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
        """Compute gradients w.r.t. ``xyz1`` and ``xyz2``.

        ``gradidx1``/``gradidx2`` are accepted because ``forward`` returns four
        tensors, but they are not used.

        Returns:
            ``(gradxyz1, gradxyz2)`` with the same sizes as ``xyz1``/``xyz2``.
        """
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()

        device = graddist1.device
        # Gradient buffers are created on the incoming gradient's device and
        # filled in place by the kernel.
        gradxyz1 = torch.zeros(xyz1.size(), device=device)
        gradxyz2 = torch.zeros(xyz2.size(), device=device)
        chamfer_3D.backward(
            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
        )
        return gradxyz1, gradxyz2


class chamfer_3DDist(nn.Module):
    """``nn.Module`` front-end for the differentiable chamfer distance."""

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        """Return ``(dist1, dist2, idx1, idx2)`` for two point-cloud batches.

        Inputs are made contiguous before being handed to the extension.
        """
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        return chamfer_3DFunction.apply(input1, input2)


# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_3DFunction_noGrad(Function):
    """Forward-only variant: no ``backward`` is defined for this function."""

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        """Run the forward kernel without saving anything for backward.

        Returns:
            ``(dist1, dist2, idx1, idx2)``; see ``_chamfer_forward``.
        """
        dist1, dist2, idx1, idx2 = _chamfer_forward(xyz1, xyz2)
        # There is no backward for this function, so declare the outputs
        # non-differentiable; otherwise they would report requires_grad=True
        # for grad-requiring inputs and fail only later, inside .backward().
        ctx.mark_non_differentiable(dist1, dist2, idx1, idx2)
        return dist1, dist2, idx1, idx2


class chamfer_3DDist_nograd(nn.Module):
    """``nn.Module`` front-end for the forward-only chamfer distance."""

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        """Return ``(dist1, dist2, idx1, idx2)``; no gradients are produced.

        Inputs are made contiguous before being handed to the extension.
        """
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        return chamfer_3DFunction_noGrad.apply(input1, input2)