1. Fix the bug in pull request #20. 2. Add CD/EMD without the CUDA kernel. 3. Add an env.yaml to create the conda environment.

This commit is contained in:
Grendel 2022-04-07 16:05:49 +02:00
parent fd00fce0da
commit b7a9216ffc
4 changed files with 200 additions and 47 deletions


@@ -35,13 +35,11 @@ As 3D point clouds become the representation of choice for multiple vision and g
Following is the suggested way to install these dependencies:
```bash
# Create a new conda environment
conda create -n PointFlow python=3.6
conda activate PointFlow
conda env create -f environment.yml
conda activate PointFlow
# Install pytorch (please refer to the command in the official website)
conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 -c pytorch -y
# Install other dependencies such as torchdiffeq, structural losses, etc.
# Compile structural losses, etc., but this step is not required
# as there is a version of EMD/CD not requiring the CUDA kernel.
./install.sh
```
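If the optional compile step is skipped or fails, the evaluation code falls back to a pure PyTorch/SciPy version of CD/EMD. A minimal check of which path will be used might look like the sketch below (the `metrics.StructuralLosses` import path is an assumption based on this repository's layout):
```python
# Availability check (sketch): if these imports fail, the metrics code falls back
# (with a printed warning) to the CD/EMD implementations that need no CUDA kernel.
try:
    from metrics.StructuralLosses.nn_distance import nn_distance  # compiled CD kernel
    from metrics.StructuralLosses.match_cost import match_cost    # compiled EMD kernel
    print("Compiled CUDA structural losses found.")
except ImportError:
    print("CUDA kernels not found; the pure PyTorch/SciPy CD/EMD will be used.")
```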

env.yaml (new file, 128 lines)

@@ -0,0 +1,128 @@
name: PointFlow
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- blas=1.0=mkl
- brotli=1.0.9=he6710b0_2
- brotlipy=0.7.0=py37h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.3.29=h06a4308_0
- certifi=2021.10.8=py37h06a4308_2
- cffi=1.15.0=py37hd667e15_1
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- colorama=0.4.4=pyhd3eb1b0_0
- cryptography=36.0.0=py37h9ce1e76_0
- cudatoolkit=10.2.89=hfd86e86_1
- cycler=0.11.0=pyhd3eb1b0_0
- dbus=1.13.18=hb2f20db_0
- expat=2.4.4=h295c915_0
- ffmpeg=4.3=hf484d3e_0
- fontconfig=2.13.1=h6c09931_0
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.11.0=h70c0345_0
- giflib=5.2.1=h7b6447c_0
- glib=2.69.1=h4ff587b_1
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- gst-plugins-base=1.14.0=h8213a91_2
- gstreamer=1.14.0=h28cd5cc_2
- icu=58.2=he6710b0_3
- idna=3.3=pyhd3eb1b0_0
- intel-openmp=2021.4.0=h06a4308_3561
- joblib=1.1.0=pyhd3eb1b0_0
- jpeg=9d=h7f8727e_0
- kiwisolver=1.3.2=py37h295c915_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgfortran-ng=7.5.0=ha8ba4b0_17
- libgfortran4=7.5.0=ha8ba4b0_17
- libgomp=9.3.0=h5101ec6_17
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.2=h7f8727e_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.2.0=h85742a9_0
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.0.3=h7f8727e_2
- libuv=1.40.0=h7b6447c_0
- libwebp=1.2.2=h55f646e_0
- libwebp-base=1.2.2=h7f8727e_0
- libxcb=1.14=h7b6447c_0
- libxml2=2.9.12=h03d6c58_0
- lz4-c=1.9.3=h295c915_1
- matplotlib=3.5.1=py37h06a4308_1
- matplotlib-base=3.5.1=py37ha18d171_1
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py37h7f8727e_0
- mkl_fft=1.3.1=py37hd3c417c_0
- mkl_random=1.2.2=py37h51133e4_0
- munkres=1.1.4=py_0
- ncurses=6.3=h7f8727e_2
- nettle=3.7.3=hbbd107a_1
- numpy-base=1.21.2=py37h79a1101_0
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1n=h7f8727e_0
- packaging=21.3=pyhd3eb1b0_0
- pcre=8.45=h295c915_0
- pip=21.2.2=py37h06a4308_0
- pycparser=2.21=pyhd3eb1b0_0
- pyopenssl=22.0.0=pyhd3eb1b0_0
- pyparsing=3.0.4=pyhd3eb1b0_0
- pyqt=5.9.2=py37h05f1152_2
- pysocks=1.7.1=py37_1
- python=3.7.13=h12debd9_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- pytorch=1.11.0=py3.7_cuda10.2_cudnn7.6.5_0
- pytorch-mutex=1.0=cuda
- qt=5.9.7=h5867ecd_1
- readline=8.1.2=h7f8727e_1
- requests=2.27.1=pyhd3eb1b0_0
- scikit-learn=1.0.2=py37h51133e4_1
- scipy=1.7.3=py37hc147768_0
- setuptools=58.0.4=py37h06a4308_0
- sip=4.19.8=py37hf484d3e_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.38.2=hc218d9a_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tk=8.6.11=h1ccaba5_0
- torchaudio=0.11.0=py37_cu102
- torchvision=0.12.0=py37_cu102
- tornado=6.1=py37h27cfd23_0
- tqdm=4.63.0=pyhd3eb1b0_0
- typing_extensions=4.1.1=pyh06a4308_0
- urllib3=1.26.8=pyhd3eb1b0_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7f8727e_4
- zstd=1.4.9=haebb681_0
- pip:
- absl-py==1.0.0
- astor==0.8.1
- cached-property==1.5.2
- gast==0.5.3
- grpcio==1.44.0
- h5py==3.6.0
- importlib-metadata==4.11.3
- keras-applications==1.0.8
- keras-preprocessing==1.1.2
- markdown==3.3.6
- mock==4.0.3
- numpy==1.16.4
- pillow==5.0.0
- protobuf==3.20.0
- tensorboard==1.13.1
- tensorboardx==1.7
- tensorflow-estimator==1.13.0
- tensorflow-gpu==1.13.1
- termcolor==1.1.0
- torchdiffeq==0.0.1
- werkzeug==2.1.1
- zipp==3.8.0
prefix: /home/grendelyang/anaconda3/envs/PointFlow


@@ -1,19 +1,18 @@
#! /bin/bash
root=`pwd`
# Install dependencies
conda install numpy matplotlib pillow scipy tqdm scikit-learn -y
pip install tensorflow-gpu==1.13.1
pip install tensorboardX==1.7
# conda install pytorch==1.0.1 torchvision==0.2.2 cudatoolkit=10.0 -c pytorch
# conda install matplotlib tqdm scikit-learn -y
# pip install pillow==5.0.0
# pip install scipy==1.0.1
# pip install numpy==1.16.4
# pip install tensorflow-gpu==1.13.1
# pip install tensorboardX==1.7
# pip install torchdiffeq==0.0.1
# Compile CUDA kernel for CD/EMD loss
root=`pwd`
cd metrics/pytorch_structural_losses/
make clean
make
cd $root
# install torchdiffeq
git clone https://github.com/rtqichen/torchdiffeq.git
cd torchdiffeq
pip install -e .
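Since the script still installs torchdiffeq from source, a quick sanity check of that step could look like the sketch below (the toy ODE is only for illustration):
```python
# Sketch: confirm the torchdiffeq checkout installed by install.sh is usable.
import torch
from torchdiffeq import odeint

# dy/dt = -y with y(0) = 1; the solution at t = 1 should be close to exp(-1) ≈ 0.368.
ys = odeint(lambda t, y: -y, torch.tensor([1.0]), torch.tensor([0.0, 1.0]))
print(ys[-1].item())
```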


@@ -4,31 +4,7 @@ import warnings
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
from .StructuralLosses.match_cost import match_cost
from .StructuralLosses.nn_distance import nn_distance
# # Import CUDA version of CD, borrowed from https://github.com/ThibaultGROUEIX/AtlasNet
# try:
#     from . chamfer_distance_ext.dist_chamfer import chamferDist
#     CD = chamferDist()
#     def distChamferCUDA(x,y):
#         return CD(x,y,gpu)
# except:
def distChamferCUDA(x, y):
    return nn_distance(x, y)

def emd_approx(sample, ref):
    B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
    assert N == N_ref, "Not sure what would EMD do in this case"
    emd = match_cost(sample, ref) # (B,)
    emd_norm = emd / float(N) # (B,)
    return emd_norm
from scipy.optimize import linear_sum_assignment
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
@@ -44,8 +20,53 @@ def distChamfer(a, b):
    P = (rx.transpose(2, 1) + ry - 2 * zz)
    return P.min(1)[0], P.min(2)[0]

# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
try:
    from .StructuralLosses.nn_distance import nn_distance

    def distChamferCUDA(x, y):
        return nn_distance(x, y)
except:
    print("distChamferCUDA not available; fall back to slower version.")

    def distChamferCUDA(x, y):
        return distChamfer(x, y)

def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):

def emd_approx(x, y):
    bs, npts, mpts, dim = x.size(0), x.size(1), y.size(1), x.size(2)
    assert npts == mpts, "EMD only works if two point clouds are equal size"
    dim = x.shape[-1]
    x = x.reshape(bs, npts, 1, dim)
    y = y.reshape(bs, 1, mpts, dim)
    dist = (x - y).norm(dim=-1, keepdim=False) # (bs, npts, mpts)

    emd_lst = []
    dist_np = dist.cpu().detach().numpy()
    for i in range(bs):
        d_i = dist_np[i]
        r_idx, c_idx = linear_sum_assignment(d_i)
        emd_i = d_i[r_idx, c_idx].mean()
        emd_lst.append(emd_i)
    emd = np.stack(emd_lst).reshape(-1)
    emd_torch = torch.from_numpy(emd).to(x)
    return emd_torch

try:
    from .StructuralLosses.match_cost import match_cost

    def emd_approx_cuda(sample, ref):
        B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
        assert N == N_ref, "Not sure what would EMD do in this case"
        emd = match_cost(sample, ref) # (B,)
        emd_norm = emd / float(N) # (B,)
        return emd_norm
except:
    print("emd_approx_cuda not available. Fall back to slower version.")

    def emd_approx_cuda(sample, ref):
        return emd_approx(sample, ref)

def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True,
           accelerated_emd=False):
    N_sample = sample_pcs.shape[0]
    N_ref = ref_pcs.shape[0]
    assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
@@ -65,7 +86,10 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
            dl, dr = distChamfer(sample_batch, ref_batch)
        cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))

        emd_batch = emd_approx(sample_batch, ref_batch)
        if accelerated_emd:
            emd_batch = emd_approx_cuda(sample_batch, ref_batch)
        else:
            emd_batch = emd_approx(sample_batch, ref_batch)
        emd_lst.append(emd_batch)

    if reduced:
@@ -82,7 +106,8 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
    return results

def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True,
                      accelerated_emd=True):
    N_sample = sample_pcs.shape[0]
    N_ref = ref_pcs.shape[0]
    all_cd = []
@@ -101,13 +126,16 @@ def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
            sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
            sample_batch_exp = sample_batch_exp.contiguous()

            if accelerated_cd:
            if accelerated_cd and distChamferCUDA is not None:
                dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)
            else:
                dl, dr = distChamfer(sample_batch_exp, ref_batch)
            cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))

            emd_batch = emd_approx(sample_batch_exp, ref_batch)
            if accelerated_emd:
                emd_batch = emd_approx_cuda(sample_batch_exp, ref_batch)
            else:
                emd_batch = emd_approx(sample_batch_exp, ref_batch)
            emd_lst.append(emd_batch.view(1, -1))

        cd_lst = torch.cat(cd_lst, dim=1)
@@ -172,7 +200,7 @@ def lgan_mmd_cov(all_dist):
def compute_all_metrics(sample_pcs, ref_pcs, batch_size, accelerated_cd=False):
    results = {}

    M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd)
    M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)

    res_cd = lgan_mmd_cov(M_rs_cd.t())
    results.update({
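With the changes above, the metrics can run without the compiled kernels: leaving `accelerated_cd`/`accelerated_emd` at `False` uses the pure-PyTorch `distChamfer` and the new `linear_sum_assignment`-based `emd_approx`. A usage sketch (the module path and the toy shapes are assumptions, not part of this diff):
```python
# Usage sketch for the updated metrics; no CUDA kernel is required on this path.
import torch
from metrics.evaluation_metrics import EMD_CD  # module path assumed from the repo layout

sample_pcs = torch.rand(8, 512, 3)  # (batch, points, xyz) -- toy data
ref_pcs = torch.rand(8, 512, 3)

results = EMD_CD(sample_pcs, ref_pcs, batch_size=4,
                 accelerated_cd=False, accelerated_emd=False)
for name, value in results.items():  # with reduced=True these are scalar tensors
    print(name, value.item())
```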