import torch
from einops import repeat
from torch import Tensor
from torchemd import earth_mover_distance

# Number of copies of each point cloud stacked along the batch axis.
# dummy_loss reads batch items 0, 1 and 2, so this must be at least 3.
_BATCH_SIZE = 3


def _batched_cloud(points: list[list[float]], device: str) -> Tensor:
    """Return ``points`` repeated ``_BATCH_SIZE`` times as a grad-requiring leaf tensor."""
    cloud = torch.tensor(points, dtype=torch.float32, device=device)
    # einops.repeat may return an expanded (stride-0) view; materialize it so
    # each batch item owns its memory, which hand-written CUDA kernels
    # commonly assume.
    batch = repeat(cloud, "n c -> b n c", b=_BATCH_SIZE).contiguous()
    # Enable grad only after repeating, so the batched tensor itself is the
    # autograd leaf whose ``.grad`` the tests inspect.
    batch.requires_grad_(True)
    return batch


def generate_pointclouds(*, device: str = "cuda") -> tuple[Tensor, Tensor]:
    """Create two small batched point clouds for the EMD tests.

    Each cloud holds two 3-D points, repeated ``_BATCH_SIZE`` times along a
    new leading batch axis, giving tensors of shape ``(_BATCH_SIZE, 2, 3)``.
    Both are fresh autograd leaves with ``requires_grad=True``, so every call
    returns tensors with no accumulated gradient.

    Args:
        device: Device the tensors are allocated on. Defaults to ``"cuda"``.

    Returns:
        The pair ``(pc1, pc2)``.
    """
    pc1 = _batched_cloud([[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]], device)
    pc2 = _batched_cloud([[0.4, 1.8, 0.2], [1.2, -0.2, 0.3]], device)
    return pc1, pc2


def dummy_loss(distance: Tensor) -> Tensor:
    """Reduce the first three per-batch distances to a scalar loss.

    Every batch item gets a different weight (1/2, 2 and 1/3). The gradients
    therefore differ per batch item, so an implementation that mixes up batch
    items produces mismatching gradients.
    """
    return distance[0] / 2 + distance[1] * 2 + distance[2] / 3


def manual_distance_computation(pc1: Tensor, pc2: Tensor) -> Tensor:
    """Compute the reference EMD for the clouds built by ``generate_pointclouds``.

    Uses the fixed matching pc1[:, 0] <-> pc2[:, 1] and pc1[:, 1] <-> pc2[:, 0].
    For those specific clouds this is the cheaper of the two possible
    matchings: its squared cost is 0.76, against 8.48 for the identity
    matching. The result is, per batch item, the sum of squared Euclidean
    distances between matched points. The tests therefore assume
    ``earth_mover_distance`` reports squared (not plain) distances.
    """
    first_pair = (pc1[:, 0] - pc2[:, 1]).pow(2).sum(dim=1)
    second_pair = (pc1[:, 1] - pc2[:, 0]).pow(2).sum(dim=1)
    return first_pair + second_pair


def test_emd():
    """Check earth_mover_distance's values and gradients against the reference."""
    # Reference: distance from the explicit formula, gradients via autograd.
    gt_pc1, gt_pc2 = generate_pointclouds()
    ground_truth_distance = manual_distance_computation(gt_pc1, gt_pc2)
    ground_truth_loss = dummy_loss(ground_truth_distance)
    ground_truth_loss.backward()

    # Implementation under test, run on fresh leaves so its gradients are not
    # accumulated into the reference tensors.
    # NOTE(review): transpose=False presumably means the inputs are already
    # laid out as (batch, points, coords); confirm against torchemd.
    pc1, pc2 = generate_pointclouds()
    computed_distance = earth_mover_distance(pc1, pc2, transpose=False)
    loss = dummy_loss(computed_distance)
    loss.backward()

    # Forward values first: a wrong forward pass also breaks the gradients.
    assert computed_distance.allclose(ground_truth_distance)
    assert loss.allclose(ground_truth_loss)

    # Gradients: fail with a clear message if backward never reached the inputs.
    assert pc1.grad is not None, "no gradient propagated to pc1"
    assert pc2.grad is not None, "no gradient propagated to pc2"
    assert pc1.grad.allclose(gt_pc1.grad)
    assert pc2.grad.allclose(gt_pc2.grad)


def test_equality():
    """The EMD between a point cloud and itself should be zero."""
    pc1, _ = generate_pointclouds()
    distance = earth_mover_distance(pc1, pc1, transpose=False)
    assert distance.allclose(torch.zeros_like(distance))


# NOTE: "symetry" is misspelled, but the name is kept so existing
# `pytest -k` selections keep matching.
def test_symetry():
    """The EMD should not depend on the order of its two point-cloud arguments."""
    pc1, pc2 = generate_pointclouds()
    distance1 = earth_mover_distance(pc1, pc2, transpose=False)
    distance2 = earth_mover_distance(pc2, pc1, transpose=False)
    assert distance1.allclose(distance2)