70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
import torch, time
|
|
import chamfer2D.dist_chamfer_2D
|
|
import chamfer3D.dist_chamfer_3D
|
|
import chamfer5D.dist_chamfer_5D
|
|
import chamfer_python
|
|
|
|
cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
|
|
cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
|
|
cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
|
|
|
|
from torch.autograd import Variable
|
|
from fscore import fscore
|
|
|
|
def test_chamfer(distChamfer, dim):
|
|
points1 = torch.rand(4, 100, dim).cuda()
|
|
points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
|
|
dist1, dist2, idx1, idx2= distChamfer(points1, points2)
|
|
|
|
loss = torch.sum(dist1)
|
|
loss.backward()
|
|
|
|
mydist1, mydist2, myidx1, myidx2 = chamfer_python.distChamfer(points1, points2)
|
|
d1 = (dist1 - mydist1) ** 2
|
|
d2 = (dist2 - mydist2) ** 2
|
|
assert (
|
|
torch.mean(d1) + torch.mean(d2) < 0.00000001
|
|
), "chamfer cuda and chamfer normal are not giving the same results"
|
|
|
|
xd1 = idx1 - myidx1
|
|
xd2 = idx2 - myidx2
|
|
assert (
|
|
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
|
|
), "chamfer cuda and chamfer normal are not giving the same results"
|
|
print(f"fscore :", fscore(dist1, dist2))
|
|
print("Unit test passed")
|
|
|
|
|
|
def timings(distChamfer, dim):
|
|
p1 = torch.rand(32, 2000, dim).cuda()
|
|
p2 = torch.rand(32, 1000, dim).cuda()
|
|
print("Timings : Start CUDA version")
|
|
start = time.time()
|
|
num_it = 100
|
|
for i in range(num_it):
|
|
points1 = Variable(p1, requires_grad=True)
|
|
points2 = Variable(p2)
|
|
mydist1, mydist2, idx1, idx2 = distChamfer(points1, points2)
|
|
loss = torch.sum(mydist1)
|
|
loss.backward()
|
|
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
|
|
|
|
|
print("Timings : Start Pythonic version")
|
|
start = time.time()
|
|
for i in range(num_it):
|
|
points1 = Variable(p1, requires_grad=True)
|
|
points2 = Variable(p2)
|
|
mydist1, mydist2, idx1, idx2 = chamfer_python.distChamfer(points1, points2)
|
|
loss = torch.sum(mydist1)
|
|
loss.backward()
|
|
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
|
|
|
|
|
|
|
dims = [2,3,5]
|
|
for i,cham in enumerate([cham2D, cham3D, cham5D]):
|
|
print(f"testing Chamfer {dims[i]}D")
|
|
test_chamfer(cham, dims[i])
|
|
timings(cham, dims[i])
|