💄 black autoformatting

This commit is contained in:
Laurent FAINSIN 2023-04-24 11:45:30 +02:00
parent 51289d8670
commit bb93f6ab47
4 changed files with 40 additions and 18 deletions

3
emd.py
View file

@ -1,5 +1,5 @@
import torch
import emd_cuda
import torch
class EarthMoverDistanceFunction(torch.autograd.Function):
@ -43,4 +43,3 @@ def earth_mover_distance(xyz1, xyz2, transpose=True):
xyz2 = xyz2.transpose(1, 2)
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
return cost

24
pyproject.toml Normal file
View file

@ -0,0 +1,24 @@
[tool.ruff]
ignore-init-module-imports = true
select = ["E", "F", "I"]
line-length = 120
[tool.black]
exclude = '''
/(
\.git
\.venv
)/
'''
include = '\.pyi?$'
line-length = 120
target-version = ["py310"]
[tool.isort]
multi_line_output = 3
profile = "black"
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true

View file

@ -9,19 +9,17 @@ Notes:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='emd_ext',
name="emd_ext",
ext_modules=[
CUDAExtension(
name='emd_cuda',
name="emd_cuda",
sources=[
'cuda/emd.cpp',
'cuda/emd_kernel.cu',
"cuda/emd.cpp",
"cuda/emd_kernel.cu",
],
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
),
],
cmdclass={
'build_ext': BuildExtension
})
cmdclass={"build_ext": BuildExtension},
)

View file

@ -1,6 +1,6 @@
import torch
import numpy as np
import time
import torch
from emd import earth_mover_distance
# gt
@ -13,10 +13,12 @@ print(p2)
p1.requires_grad = True
p2.requires_grad = True
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
print('gt_dist: ', gt_dist)
gt_dist = (
(((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
+ (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
+ (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
)
print("gt_dist: ", gt_dist)
gt_dist.backward()
print(p1.grad)
@ -41,4 +43,3 @@ print(loss)
loss.backward()
print(p1.grad)
print(p2.grad)