81 lines
2.2 KiB
Python
81 lines
2.2 KiB
Python
|
import torch
|
||
|
from einops import repeat
|
||
|
from torch import Tensor
|
||
|
|
||
|
from torchemd import earth_mover_distance
|
||
|
|
||
|
|
||
|
def generate_pointclouds(*, device: str = "cuda", batch_size: int = 3) -> tuple[Tensor, Tensor]:
    """Build two small batched point clouds whose cheapest matching is known.

    Each cloud holds two 3-D points. It is replicated ``batch_size`` times along a
    new leading axis, so both returned tensors have shape ``(batch_size, 2, 3)``.

    The points are placed so that pairing ``pc1[:, 0]`` with ``pc2[:, 1]`` and
    ``pc1[:, 1]`` with ``pc2[:, 0]`` is cheaper than the identity pairing:
    squared cost 0.76 vs 8.48. This is the matching that
    ``manual_distance_computation`` hard-codes.

    Args:
        device: Device on which to allocate the tensors. The default ``"cuda"``
            matches the original behaviour; pass ``"cpu"`` for GPU-less checks.
        batch_size: Number of identical copies along the batch axis. ``dummy_loss``
            reads indices 0..2, so keep this at least 3 when used with it.

    Returns:
        ``(pc1, pc2)``: float32 leaf tensors with ``requires_grad=True`` and
        contiguous memory, one independent copy per batch entry.
    """

    def _batched(points: list[list[float]]) -> Tensor:
        base = torch.tensor(points, dtype=torch.float32, device=device)
        # Tensor.repeat materialises real copies. einops' `repeat` may instead
        # return an expanded view with batch stride 0. NOTE(review): a custom CUDA
        # kernel reading raw memory could assume contiguous input; confirm whether
        # earth_mover_distance calls .contiguous() itself.
        # The result has no grad_fn, so it is a leaf and can require grad.
        return base.unsqueeze(0).repeat(batch_size, 1, 1).requires_grad_()

    # create first point cloud
    pc1 = _batched([[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]])
    # create second point cloud
    pc2 = _batched([[0.4, 1.8, 0.2], [1.2, -0.2, 0.3]])
    return pc1, pc2
|
||
|
|
||
|
|
||
|
def dummy_loss(distance: Tensor) -> Tensor:
    """Fold the first three batch distances into one scalar with unequal weights.

    The weights (1/2, 2, 1/3) differ per entry, so a gradient check also checks
    that each batch entry receives its own upstream gradient.
    """
    first_term = distance[0] / 2
    second_term = distance[1] * 2
    third_term = distance[2] / 3
    return first_term + second_term + third_term
|
||
|
|
||
|
|
||
|
def manual_distance_computation(pc1: Tensor, pc2: Tensor) -> Tensor:
|
||
|
return (pc1[:, 0] - pc2[:, 1]).pow(2).sum(dim=1) + (pc1[:, 1] - pc2[:, 0]).pow(2).sum(dim=1)
|
||
|
|
||
|
|
||
|
def test_emd():
    """Compare torchemd's EMD against the hand-computed matching.

    Both paths start from freshly generated point clouds, reduce the distances
    with ``dummy_loss`` and backpropagate. The input gradients, the per-batch
    distances and the scalar losses must then agree.
    """

    def run(distance_fn):
        # Fresh leaf tensors per run, so .grad is not accumulated across runs.
        source, target = generate_pointclouds()
        distance = distance_fn(source, target)
        loss_value = dummy_loss(distance)
        loss_value.backward()
        return distance, loss_value, source.grad, target.grad

    # reference values straight from the formula
    expected_distance, expected_loss, expected_grad_1, expected_grad_2 = run(
        manual_distance_computation
    )

    # values from the implementation under test
    # NOTE(review): transpose=False presumably means (batch, points, coords)
    # input layout; confirm against torchemd.
    actual_distance, actual_loss, actual_grad_1, actual_grad_2 = run(
        lambda a, b: earth_mover_distance(a, b, transpose=False)
    )

    # gradients first, then distances, then the scalar loss
    assert actual_grad_1.allclose(expected_grad_1)
    assert actual_grad_2.allclose(expected_grad_2)
    assert actual_distance.allclose(expected_distance)
    assert actual_loss.allclose(expected_loss)
|
||
|
|
||
|
|
||
|
def test_equality():
    """The EMD between a point cloud and itself must be zero in every batch entry."""
    cloud = generate_pointclouds()[0]
    self_distance = earth_mover_distance(cloud, cloud, transpose=False)
    expected = torch.zeros_like(self_distance)
    assert self_distance.allclose(expected)
|
||
|
|
||
|
|
||
|
def test_symetry():
    """The EMD must be symmetric: swapping the two clouds gives the same distances."""
    cloud_a, cloud_b = generate_pointclouds()
    ordered_pairs = ((cloud_a, cloud_b), (cloud_b, cloud_a))
    forward, reverse = [
        earth_mover_distance(lhs, rhs, transpose=False) for lhs, rhs in ordered_pairs
    ]
    assert forward.allclose(reverse)
|