💄 black autoformatting
This commit is contained in:
parent
51289d8670
commit
bb93f6ab47
3
emd.py
3
emd.py
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
import emd_cuda
|
||||
import torch
|
||||
|
||||
|
||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||
|
@ -43,4 +43,3 @@ def earth_mover_distance(xyz1, xyz2, transpose=True):
|
|||
xyz2 = xyz2.transpose(1, 2)
|
||||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
||||
return cost
|
||||
|
||||
|
|
24
pyproject.toml
Normal file
24
pyproject.toml
Normal file
|
@ -0,0 +1,24 @@
|
|||
[tool.ruff]
|
||||
ignore-init-module-imports = true
|
||||
select = ["E", "F", "I"]
|
||||
line-length = 120
|
||||
|
||||
[tool.black]
|
||||
exclude = '''
|
||||
/(
|
||||
\.git
|
||||
\.venv
|
||||
)/
|
||||
'''
|
||||
include = '\.pyi?$'
|
||||
line-length = 120
|
||||
target-version = ["py310"]
|
||||
|
||||
[tool.isort]
|
||||
multi_line_output = 3
|
||||
profile = "black"
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
16
setup.py
16
setup.py
|
@ -9,19 +9,17 @@ Notes:
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
|
||||
setup(
|
||||
name='emd_ext',
|
||||
name="emd_ext",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name='emd_cuda',
|
||||
name="emd_cuda",
|
||||
sources=[
|
||||
'cuda/emd.cpp',
|
||||
'cuda/emd_kernel.cu',
|
||||
"cuda/emd.cpp",
|
||||
"cuda/emd_kernel.cu",
|
||||
],
|
||||
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
|
||||
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
|
||||
from emd import earth_mover_distance
|
||||
|
||||
# gt
|
||||
|
@ -13,10 +13,12 @@ print(p2)
|
|||
p1.requires_grad = True
|
||||
p2.requires_grad = True
|
||||
|
||||
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
|
||||
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
|
||||
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
|
||||
print('gt_dist: ', gt_dist)
|
||||
gt_dist = (
|
||||
(((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
|
||||
+ (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
|
||||
+ (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
|
||||
)
|
||||
print("gt_dist: ", gt_dist)
|
||||
|
||||
gt_dist.backward()
|
||||
print(p1.grad)
|
||||
|
@ -41,4 +43,3 @@ print(loss)
|
|||
loss.backward()
|
||||
print(p1.grad)
|
||||
print(p2.grad)
|
||||
|
||||
|
|
Loading…
Reference in a new issue