💄 black autoformatting
This commit is contained in:
parent
51289d8670
commit
bb93f6ab47
3
emd.py
3
emd.py
|
@ -1,5 +1,5 @@
|
||||||
import torch
|
|
||||||
import emd_cuda
|
import emd_cuda
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||||
|
@ -43,4 +43,3 @@ def earth_mover_distance(xyz1, xyz2, transpose=True):
|
||||||
xyz2 = xyz2.transpose(1, 2)
|
xyz2 = xyz2.transpose(1, 2)
|
||||||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
24
pyproject.toml
Normal file
24
pyproject.toml
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
[tool.ruff]
|
||||||
|
ignore-init-module-imports = true
|
||||||
|
select = ["E", "F", "I"]
|
||||||
|
line-length = 120
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
exclude = '''
|
||||||
|
/(
|
||||||
|
\.git
|
||||||
|
\.venv
|
||||||
|
)/
|
||||||
|
'''
|
||||||
|
include = '\.pyi?$'
|
||||||
|
line-length = 120
|
||||||
|
target-version = ["py310"]
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
multi_line_output = 3
|
||||||
|
profile = "black"
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.10"
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
16
setup.py
16
setup.py
|
@ -9,19 +9,17 @@ Notes:
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='emd_ext',
|
name="emd_ext",
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
CUDAExtension(
|
CUDAExtension(
|
||||||
name='emd_cuda',
|
name="emd_cuda",
|
||||||
sources=[
|
sources=[
|
||||||
'cuda/emd.cpp',
|
"cuda/emd.cpp",
|
||||||
'cuda/emd_kernel.cu',
|
"cuda/emd_kernel.cu",
|
||||||
],
|
],
|
||||||
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
|
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
cmdclass={
|
cmdclass={"build_ext": BuildExtension},
|
||||||
'build_ext': BuildExtension
|
)
|
||||||
})
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import torch
|
||||||
|
|
||||||
from emd import earth_mover_distance
|
from emd import earth_mover_distance
|
||||||
|
|
||||||
# gt
|
# gt
|
||||||
|
@ -13,10 +13,12 @@ print(p2)
|
||||||
p1.requires_grad = True
|
p1.requires_grad = True
|
||||||
p2.requires_grad = True
|
p2.requires_grad = True
|
||||||
|
|
||||||
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
|
gt_dist = (
|
||||||
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
|
(((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
|
||||||
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
|
+ (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
|
||||||
print('gt_dist: ', gt_dist)
|
+ (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
|
||||||
|
)
|
||||||
|
print("gt_dist: ", gt_dist)
|
||||||
|
|
||||||
gt_dist.backward()
|
gt_dist.backward()
|
||||||
print(p1.grad)
|
print(p1.grad)
|
||||||
|
@ -41,4 +43,3 @@ print(loss)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
print(p1.grad)
|
print(p1.grad)
|
||||||
print(p2.grad)
|
print(p2.grad)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue