PyTorchEMD/tests/test_correctness.py
2023-04-24 15:09:41 +02:00

81 lines
2.2 KiB
Python

import torch
from einops import repeat
from torch import Tensor
from torchemd import earth_mover_distance
def generate_pointclouds(
    *, device: "str | torch.device" = "cuda", batch_size: int = 3
) -> tuple[Tensor, Tensor]:
    """Build the two fixed point clouds used by the tests.

    Each cloud holds two 3-D points.  It is copied ``batch_size`` times along a
    new leading batch axis, giving tensors of shape ``(batch_size, 2, 3)``.
    Point 0 of the first cloud lies closest to point 1 of the second cloud,
    and point 1 lies closest to point 0.

    Args:
        device: Device to allocate the tensors on.  Defaults to ``"cuda"``,
            matching the original fixture; pass ``"cpu"`` on GPU-less hosts.
        batch_size: Number of identical copies along the batch axis.

    Returns:
        ``(pc1, pc2)``: two leaf tensors with ``requires_grad=True``, so
        ``.grad`` is populated on them after ``backward()``.
    """

    def batched(points: list[list[float]]) -> Tensor:
        # Materialise a (batch_size, n_points, 3) copy of the cloud.
        cloud = torch.tensor(points, device=device)
        cloud = cloud.unsqueeze(0).repeat(batch_size, 1, 1)
        # Enable grad only after batching so the returned tensor is itself a
        # leaf (no grad_fn) and therefore accumulates .grad.
        return cloud.requires_grad_()

    pc1 = batched([[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]])
    pc2 = batched([[0.4, 1.8, 0.2], [1.2, -0.2, 0.3]])
    return pc1, pc2
def dummy_loss(distance: Tensor) -> Tensor:
    """Reduce the first three distances to one scalar with distinct weights.

    The weights are 1/2, 2 and 1/3 for entries 0, 1 and 2 respectively.
    Because they differ, each entry receives a different upstream gradient,
    so a mix-up between batch entries would show up in the gradients.
    """
    halved = distance[0] / 2
    doubled = distance[1] * 2
    thirded = distance[2] / 3
    return halved + doubled + thirded
def manual_distance_computation(pc1: Tensor, pc2: Tensor) -> Tensor:
    """Return the squared-Euclidean cost of the fixed crossed matching.

    Point 0 of ``pc1`` is paired with point 1 of ``pc2``, and point 1 of
    ``pc1`` with point 0 of ``pc2``.  The squared coordinate differences of
    both pairs are summed.  This equals the earth mover distance only when
    that crossed pairing is the optimal one.

    Inputs are indexed as (batch, point, coordinate) and need at least two
    points; the result has one value per batch entry.
    """
    first_pair = pc1[:, 0] - pc2[:, 1]
    second_pair = pc1[:, 1] - pc2[:, 0]
    return (first_pair ** 2).sum(dim=1) + (second_pair ** 2).sum(dim=1)
def test_emd():
    """Check earth_mover_distance against a hand-computed reference.

    The fixture clouds are pushed through ``manual_distance_computation`` and
    through ``earth_mover_distance``.  Each result is reduced with
    ``dummy_loss`` and back-propagated.  Gradients, distances and losses from
    the two runs must all match.
    """

    def forward_backward(distance_fn):
        # generate_pointclouds builds new leaf tensors on every call, so each
        # run keeps its own .grad values.
        cloud_a, cloud_b = generate_pointclouds()
        dist = distance_fn(cloud_a, cloud_b)
        total = dummy_loss(dist)
        total.backward()
        return dist, total, cloud_a.grad, cloud_b.grad

    expected_dist, expected_loss, expected_grad_a, expected_grad_b = forward_backward(
        manual_distance_computation
    )
    actual_dist, actual_loss, actual_grad_a, actual_grad_b = forward_backward(
        lambda a, b: earth_mover_distance(a, b, transpose=False)
    )

    assert actual_grad_a.allclose(expected_grad_a)
    assert actual_grad_b.allclose(expected_grad_b)
    assert actual_dist.allclose(expected_dist)
    assert actual_loss.allclose(expected_loss)
def test_equality():
    """Check that the distance from a point cloud to itself is (close to) zero everywhere."""
    cloud, _unused = generate_pointclouds()
    self_distance = earth_mover_distance(cloud, cloud, transpose=False)
    zeros = torch.zeros_like(self_distance)
    assert self_distance.allclose(zeros)
def test_symetry():
    """Check that swapping the two point clouds leaves the distance unchanged."""
    clouds = generate_pointclouds()
    dist_ab = earth_mover_distance(*clouds, transpose=False)
    dist_ba = earth_mover_distance(*reversed(clouds), transpose=False)
    assert dist_ab.allclose(dist_ba)