💄 black autoformatting

This commit is contained in:
Laurent FAINSIN 2023-04-24 11:45:30 +02:00
parent 51289d8670
commit bb93f6ab47
4 changed files with 40 additions and 18 deletions

3
emd.py
View file

@ -1,5 +1,5 @@
import torch
import emd_cuda import emd_cuda
import torch
class EarthMoverDistanceFunction(torch.autograd.Function): class EarthMoverDistanceFunction(torch.autograd.Function):
@ -43,4 +43,3 @@ def earth_mover_distance(xyz1, xyz2, transpose=True):
xyz2 = xyz2.transpose(1, 2) xyz2 = xyz2.transpose(1, 2)
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
return cost return cost

24
pyproject.toml Normal file
View file

@ -0,0 +1,24 @@
[tool.ruff]
ignore-init-module-imports = true
select = ["E", "F", "I"]
line-length = 120
[tool.black]
exclude = '''
/(
\.git
\.venv
)/
'''
include = '\.pyi?$'
line-length = 120
target-version = ["py310"]
[tool.isort]
multi_line_output = 3
profile = "black"
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true

View file

@ -9,19 +9,17 @@ Notes:
from setuptools import setup from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup( setup(
name='emd_ext', name="emd_ext",
ext_modules=[ ext_modules=[
CUDAExtension( CUDAExtension(
name='emd_cuda', name="emd_cuda",
sources=[ sources=[
'cuda/emd.cpp', "cuda/emd.cpp",
'cuda/emd_kernel.cu', "cuda/emd_kernel.cu",
], ],
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
), ),
], ],
cmdclass={ cmdclass={"build_ext": BuildExtension},
'build_ext': BuildExtension )
})

View file

@ -1,6 +1,6 @@
import torch
import numpy as np import numpy as np
import time import torch
from emd import earth_mover_distance from emd import earth_mover_distance
# gt # gt
@ -13,10 +13,12 @@ print(p2)
p1.requires_grad = True p1.requires_grad = True
p2.requires_grad = True p2.requires_grad = True
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \ gt_dist = (
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \ (((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3 + (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
print('gt_dist: ', gt_dist) + (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
)
print("gt_dist: ", gt_dist)
gt_dist.backward() gt_dist.backward()
print(p1.grad) print(p1.grad)
@ -41,4 +43,3 @@ print(loss)
loss.backward() loss.backward()
print(p1.grad) print(p1.grad)
print(p2.grad) print(p2.grad)