import os

import torch
from torch import nn
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
from torch.utils.cpp_extension import load

# Build the CUDA extension next to this file; compiled artifacts go to a
# parallel 'tmp' directory (the 'chamfer3D' path component is swapped for
# 'tmp') so the source tree stays clean.
cur_path = os.path.dirname(os.path.abspath(__file__))
build_path = cur_path.replace('chamfer3D', 'tmp')
os.makedirs(build_path, exist_ok=True)

# JIT-compile and load the extension (build trick from https://github.com/chrdiller).
chamfer_3D = load(
    name="chamfer_3D",
    sources=[
        os.path.join(cur_path, "chamfer_cuda.cpp"),
        os.path.join(cur_path, "chamfer3D.cu"),
    ],
    build_directory=build_path,
)
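# Note: the loaded extension is expected to expose forward(...) and
# backward(...) kernels that fill the preallocated buffers passed to them
# in place, matching the call signatures used in the Functions below.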


# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_3DFunction(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, xyz1, xyz2):
        batchsize, n, dim = xyz1.size()
        assert dim == 3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
        _, m, dim = xyz2.size()
        assert dim == 3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
        device = xyz1.device

        # Preallocate output buffers that the CUDA kernel fills in place.
        dist1 = torch.zeros(batchsize, n, device=device)
        dist2 = torch.zeros(batchsize, m, device=device)
        idx1 = torch.zeros(batchsize, n, dtype=torch.int32, device=device)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int32, device=device)
        torch.cuda.set_device(device)

        chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    @custom_bwd
    def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()
        device = graddist1.device

        # Gradient buffers, filled in place by the CUDA kernel.
        gradxyz1 = torch.zeros(xyz1.size(), device=device)
        gradxyz2 = torch.zeros(xyz2.size(), device=device)

        chamfer_3D.backward(
            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
        )
        return gradxyz1, gradxyz2

class chamfer_3DDist(nn.Module):
    def __init__(self):
        super(chamfer_3DDist, self).__init__()

    def forward(self, input1, input2):
        # The CUDA kernels require contiguous memory layouts.
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        return chamfer_3DFunction.apply(input1, input2)
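
# Illustrative usage (hypothetical tensor names; requires a CUDA device):
#   cd = chamfer_3DDist()
#   dist1, dist2, idx1, idx2 = cd(p, q)  # p: (B, N, 3), q: (B, M, 3) on GPU
#   loss = dist1.mean() + dist2.mean()   # symmetric chamfer loss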


# Chamfer's distance module @thibaultgroueix
# GPU tensors only; no-grad variant that skips saving tensors for backward,
# intended for inference-only evaluation.
class chamfer_3DFunction_noGrad(Function):
    @staticmethod
    def forward(ctx, xyz1, xyz2):
        batchsize, n, dim = xyz1.size()
        assert dim == 3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
        _, m, dim = xyz2.size()
        assert dim == 3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
        device = xyz1.device

        # Preallocate output buffers that the CUDA kernel fills in place.
        dist1 = torch.zeros(batchsize, n, device=device)
        dist2 = torch.zeros(batchsize, m, device=device)
        idx1 = torch.zeros(batchsize, n, dtype=torch.int32, device=device)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int32, device=device)
        torch.cuda.set_device(device)

        chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        return dist1, dist2, idx1, idx2

class chamfer_3DDist_nograd(nn.Module):
    def __init__(self):
        super(chamfer_3DDist_nograd, self).__init__()

    def forward(self, input1, input2):
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        return chamfer_3DFunction_noGrad.apply(input1, input2)
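

if __name__ == "__main__":
    # Minimal smoke test: a sketch that assumes the JIT build above succeeded
    # and that a CUDA device is available; batch and point counts are
    # illustrative only.
    if torch.cuda.is_available():
        p = torch.rand(2, 1024, 3, device="cuda", requires_grad=True)
        q = torch.rand(2, 2048, 3, device="cuda")

        # Differentiable path: gradients reach p through the custom backward.
        d1, d2, _, _ = chamfer_3DDist()(p, q)
        loss = d1.mean() + d2.mean()
        loss.backward()

        # Inference-only path: cheaper, as nothing is saved for backward.
        with torch.no_grad():
            d1_ng, d2_ng, _, _ = chamfer_3DDist_nograd()(p, q)

        print("loss:", loss.item(), "nograd mean dist:", d1_ng.mean().item())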