"""Sanity check for ``emd.earth_mover_distance`` against a hand-computed matching.

Two tiny point clouds (2 points each, 3-D coordinates) are repeated into a batch
of ``len(BATCH_WEIGHTS)`` identical items. A reference loss is built with plain
PyTorch ops from the point pairing in ``MATCHING``; the same per-item weights
are then applied to the output of ``earth_mover_distance``. Both losses and the
gradients they produce are printed, followed by their differences.
"""

import time  # NOTE(review): not used by this script; kept to leave the import list unchanged.

import numpy as np
import torch

from emd import earth_mover_distance

# Point coordinates for each cloud: one row per point, columns are x, y, z.
P1_POINTS = [[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]
P2_POINTS = [[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]

# Per-batch-item loss weights. Shared by the reference and EMD losses so the two
# (and their gradients) stay directly comparable; the batch size is derived from it.
BATCH_WEIGHTS = (1 / 2, 2.0, 1 / 3)

# Reference pairing: point i of p1 is matched to point MATCHING[i] of p2.
# Squared distances: cross pairing 0.30 + 0.41 = 0.71 vs. 5.58 + 3.17 = 8.75 for
# the identity pairing, so this is the optimal one-to-one assignment.
MATCHING = [1, 0]


def make_point_clouds(device: str = "cuda"):
    """Return fresh ``(p1, p2)`` leaf tensors of shape ``(len(BATCH_WEIGHTS), 2, 3)``.

    Both tensors are float32, live on *device*, and have ``requires_grad=True``.
    Each call creates new tensors, so their ``.grad`` attributes start out empty.
    """
    def build(points):
        single = torch.from_numpy(np.array([points], dtype=np.float32)).to(device)
        # The repeat() result tracks no history here, so it is a leaf tensor and
        # can be switched to requires_grad in place.
        return single.repeat(len(BATCH_WEIGHTS), 1, 1).requires_grad_()

    return build(P1_POINTS), build(P2_POINTS)


def reference_distance(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    """Return, per batch item, the summed squared distance under ``MATCHING``.

    Args:
        p1, p2: tensors of shape ``(batch, len(MATCHING), 3)``.

    Returns:
        Tensor of shape ``(batch,)``.
    """
    # Reorder p2's points so that row i lines up with its partner, row i of p1.
    paired_p2 = p2[:, MATCHING]
    return ((p1 - paired_p2) ** 2).sum(dim=(1, 2))


def weighted_loss(per_item) -> torch.Tensor:
    """Combine per-batch-item values with ``BATCH_WEIGHTS`` into one loss tensor.

    Each weight multiplies ``per_item[i]``, so any tensor whose first dimension
    is the batch works.
    """
    return sum(weight * per_item[i] for i, weight in enumerate(BATCH_WEIGHTS))


def main() -> None:
    """Print the reference loss/gradients, then the EMD loss/gradients, then their gaps."""
    # Reference loss from the known matching, using only differentiable torch ops.
    gt_p1, gt_p2 = make_point_clouds()
    print(gt_p1)
    print(gt_p2)
    gt_dist = weighted_loss(reference_distance(gt_p1, gt_p2))
    print('gt_dist: ', gt_dist)
    gt_dist.backward()
    print(gt_p1.grad)
    print(gt_p2.grad)

    # Same inputs and weights, with the per-item distance taken from the extension.
    # NOTE(review): the comparison assumes earth_mover_distance returns the summed
    # squared distance of its matching per batch item - confirm against emd.
    p1, p2 = make_point_clouds()
    print(p1)
    print(p2)
    # transpose=False presumably means the (batch, points, 3) layout is used
    # as-is - confirm against emd's signature.
    d = earth_mover_distance(p1, p2, transpose=False)
    print(d)
    loss = weighted_loss(d)
    print(loss)
    loss.backward()
    print(p1.grad)
    print(p2.grad)

    # Printed rather than asserted: no tolerance has been established for the
    # extension's result.
    print('loss abs diff: ', (loss - gt_dist).abs().max().item())
    print('p1 grad max abs diff: ', (p1.grad - gt_p1.grad).abs().max().item())
    print('p2 grad max abs diff: ', (p2.grad - gt_p2.grad).abs().max().item())


if __name__ == "__main__":
    main()