LION/third_party/PyTorchEMD/backend.py
2023-01-23 00:14:49 -05:00

22 lines
767 B
Python
Executable file

import os
import time
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
if not os.path.exists(os.path.join(_src_path, 'build_dynamic')):
os.makedirs(os.path.join(_src_path, 'build_dynamic'))
tic = time.time()
emd_cuda_dynamic = load(name='emd_ext',
extra_cflags=['-O3', '-std=c++17'],
## build_directory=os.path.join(_src_path, 'build_dynamic'),
verbose=True,
sources=[
os.path.join(_src_path, f) for f in [
'cuda/emd.cpp',
'cuda/emd_kernel.cu',
]
])
print('load emd_ext time: {:.3f}s'.format(time.time() - tic))
__all__ = ['emd_cuda_dynamic']