diff --git a/tests/test_correctness.py b/tests/test_correctness.py new file mode 100644 index 0000000..c6a9071 --- /dev/null +++ b/tests/test_correctness.py @@ -0,0 +1,80 @@ +import torch +from einops import repeat +from torch import Tensor + +from torchemd import earth_mover_distance + + +def generate_pointclouds() -> tuple[Tensor, Tensor]: + # create first point cloud + pc1 = torch.Tensor( + [ + [1.7, -0.1, 0.1], + [0.1, 1.2, 0.3], + ], + ).cuda() + pc1 = repeat(pc1, "n c -> b n c", b=3) + pc1.requires_grad = True + + # create second point cloud + pc2 = torch.Tensor( + [ + [0.4, 1.8, 0.2], + [1.2, -0.2, 0.3], + ], + ).cuda() + pc2 = repeat(pc2, "n c -> b n c", b=3) + pc2.requires_grad = True + + return pc1, pc2 + + +def dummy_loss(distance: Tensor) -> Tensor: + return distance[0] / 2 + distance[1] * 2 + distance[2] / 3 + + +def manual_distance_computation(pc1: Tensor, pc2: Tensor) -> Tensor: + return (pc1[:, 0] - pc2[:, 1]).pow(2).sum(dim=1) + (pc1[:, 1] - pc2[:, 0]).pow(2).sum(dim=1) + + +def test_emd(): + # compute earth mover distance directly from formula + pc1, pc2 = generate_pointclouds() + ground_truth_distance = manual_distance_computation(pc1, pc2) + ground_truth_loss = dummy_loss(ground_truth_distance) + # get gradients of point clouds + ground_truth_loss.backward() + pc1_gt_grad = pc1.grad + pc2_gt_grad = pc2.grad + + # compute earth mover distance directly from implementation + pc1, pc2 = generate_pointclouds() + computed_distance = earth_mover_distance(pc1, pc2, transpose=False) + loss = dummy_loss(computed_distance) + # get gradients of point clouds + loss.backward() + pc1_grad = pc1.grad + pc2_grad = pc2.grad + + # compare gradients + assert pc1_grad.allclose(pc1_gt_grad) + assert pc2_grad.allclose(pc2_gt_grad) + + # compare distances + assert computed_distance.allclose(ground_truth_distance) + + # compare loss + assert loss.allclose(ground_truth_loss) + + +def test_equality(): + pc1, _ = generate_pointclouds() + distance = earth_mover_distance(pc1, pc1, transpose=False) + assert distance.allclose(torch.zeros_like(distance)) + + +def test_symetry(): + pc1, pc2 = generate_pointclouds() + distance1 = earth_mover_distance(pc1, pc2, transpose=False) + distance2 = earth_mover_distance(pc2, pc1, transpose=False) + assert distance1.allclose(distance2)