Compare commits
10 commits
4dda09e2d0
...
fb4031bb33
Author | SHA1 | Date | |
---|---|---|---|
fb4031bb33 | |||
6d4f2cb9d3 | |||
8b3bceffd7 | |||
b7a9216ffc | |||
fd00fce0da | |||
33ce21ff8c | |||
e15a9ccd03 | |||
372215d427 | |||
470cfc4a77 | |||
8bedce0be6 |
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -19,5 +19,7 @@ metrics/structural_losses/makefile
|
|||
PyMesh
|
||||
checkpoint
|
||||
|
||||
pretrained_models*
|
||||
|
||||
torchdiffeq/
|
||||
demo/
|
||||
|
|
31
README.md
31
README.md
|
@ -2,14 +2,17 @@
|
|||
|
||||
This repository contains a PyTorch implementation of the paper:
|
||||
|
||||
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](www.arxiv.com).
|
||||
|
||||
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](https://arxiv.org/abs/1906.12320).
|
||||
<br>
|
||||
[Guandao Yang*](http://www.guandaoyang.com),
|
||||
[Xun Huang*](http://www.cs.cornell.edu/~xhuang/),
|
||||
[Zekun Hao](http://www.cs.cornell.edu/~zekun/),
|
||||
[Ming-Yu Liu](http://mingyuliu.net/),
|
||||
[Serge Belongie](http://blogs.cornell.edu/techfaculty/serge-belongie/),
|
||||
[Bharath Hariharan](http://home.bharathh.info/)
|
||||
(* equal contribution)
|
||||
<br>
|
||||
ICCV 2019 (**Oral**)
|
||||
|
||||
|
||||
## Introduction
|
||||
|
@ -27,18 +30,16 @@ As 3D point clouds become the representation of choice for multiple vision and g
|
|||
* G++ or GCC 5.
|
||||
* [PyTorch](http://pytorch.org/). Codes are tested with version 1.0.1
|
||||
* [torchdiffeq](https://github.com/rtqichen/torchdiffeq).
|
||||
* (Optional) [Tensorboard](https://www.tensorflow.org/) for visualization of training process.
|
||||
* (Optional) [Tensorboard](https://www.tensorflow.org/) for visualization of the training process.
|
||||
|
||||
Following is the suggested way to install these dependencies:
|
||||
```bash
|
||||
# Create a new conda environment
|
||||
conda create -n PointFlow python=3.6
|
||||
conda activate PointFlow
|
||||
conda env create -f environment.yml
|
||||
conda activate PointFlow
|
||||
|
||||
# Install pytorch (please refer to the commend in the official website)
|
||||
conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 -c pytorch -y
|
||||
|
||||
# Install other dependencies such as torchdiffeq, structural losses, etc.
|
||||
# Compile structural losses, etc., but this step is not required
|
||||
# as there is a version of EMD/CD not requiring the CUDA kernel.
|
||||
./install.sh
|
||||
```
|
||||
|
||||
|
@ -53,7 +54,7 @@ cd data
|
|||
unzip ShapeNetCore.v2.PC15k.zip
|
||||
```
|
||||
|
||||
Please contact us if you need point clouds for ModelNet dataset.
|
||||
Please contact us if you need point clouds for the ModelNet dataset.
|
||||
|
||||
## Training
|
||||
|
||||
|
@ -71,7 +72,7 @@ Example training scripts can be found in `scripts/` folder.
|
|||
## Pre-trained models and test
|
||||
|
||||
Pretrained models can be downloaded from this [link](https://drive.google.com/file/d/1dcxjuuKiAXZxhiyWD_o_7Owx8Y3FbRHG/view?usp=sharing).
|
||||
Following is the suggested way to evaluate the performance of the pre-trained models.
|
||||
The following is the suggested way to evaluate the performance of the pre-trained models.
|
||||
```bash
|
||||
unzip pretrained_models.zip; # This will create a folder named pretrained_models
|
||||
|
||||
|
@ -87,16 +88,20 @@ CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_gen_test.sh
|
|||
|
||||
## Demo
|
||||
|
||||
The demo relies on [Open3D](http://www.open3d.org/). Following is the suggested way to install it:
|
||||
The demo relies on [Open3D](http://www.open3d.org/). The following is the suggested way to install it:
|
||||
```bash
|
||||
conda install -c open3d-admin open3d
|
||||
```
|
||||
The demo will sample shapes from a pre-trained model, save those shapes under the `demo` folder, and visualize those point clouds.
|
||||
Once this dependency is in place, you can use the following script to use the demo for the pre-trained model for airplanes:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_demo.py
|
||||
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_demo.sh
|
||||
```
|
||||
|
||||
## Point cloud rendering
|
||||
|
||||
Please refer to the following github repository for our point cloud rendering code: https://github.com/zekunhao1995/PointFlowRenderer.
|
||||
|
||||
## Cite
|
||||
Please cite our work if you find it useful:
|
||||
```latex
|
||||
|
|
28
demo.py
28
demo.py
|
@ -1,4 +1,4 @@
|
|||
import open3d as o3d
|
||||
# import open3d as o3d
|
||||
from datasets import get_datasets
|
||||
from args import get_args
|
||||
from models.networks import PointFlow
|
||||
|
@ -47,12 +47,32 @@ def main(args):
|
|||
np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs)
|
||||
|
||||
# Visualize the demo
|
||||
pcl = o3d.geometry.PointCloud()
|
||||
# pcl = o3d.geometry.PointCloud()
|
||||
# for i in range(int(sample_pcs.shape[0])):
|
||||
# print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
|
||||
# pts = sample_pcs[i].reshape(-1, 3)
|
||||
# pcl.points = o3d.utility.Vector3dVector(pts)
|
||||
# o3d.visualization.draw_geometries([pcl])
|
||||
|
||||
# Visualize the demo using matplotlib, each point cloud in a different figure
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib
|
||||
matplotlib.use('TkAgg')
|
||||
|
||||
for i in range(int(sample_pcs.shape[0])):
|
||||
print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
|
||||
pts = sample_pcs[i].reshape(-1, 3)
|
||||
pcl.points = o3d.utility.Vector3dVector(pts)
|
||||
o3d.visualization.draw_geometries([pcl])
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c=pts[:, 2], cmap=cm.jet)
|
||||
ax.set_aspect('equal')
|
||||
ax.set_xlim(-1, 1)
|
||||
ax.set_ylim(-1, 1)
|
||||
ax.set_zlim(-1, 1)
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -40,6 +40,7 @@
|
|||
<a class="col-md-6 col-xs-6" href="http://home.bharathh.info/"><span>Bharath Hariharan</span></a>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- affiliations -->
|
||||
<div class='row mt-1 mt-2' >
|
||||
<div class='col text-center'>
|
||||
|
@ -51,6 +52,10 @@
|
|||
</div>
|
||||
</div>
|
||||
|
||||
<div class='row text-center h5 font-weight-light pl-4 pr-4 mb-4'>
|
||||
<span class="col-md-12 col-xs-12" style="color:#007bff">*Equal contribution</span>
|
||||
</div>
|
||||
|
||||
<div class='row justify-content-center' style="position: relative; width: 100%;height: 0;padding-bottom: 33%; margin:0">
|
||||
<img
|
||||
class="img-fluid rounded mx-auto d-block"
|
||||
|
|
109
environment.yml
Normal file
109
environment.yml
Normal file
|
@ -0,0 +1,109 @@
|
|||
name: PointFlow
|
||||
|
||||
channels:
|
||||
- nodefaults
|
||||
- drti
|
||||
- pytorch
|
||||
- nvidia
|
||||
- conda-forge
|
||||
- pyg
|
||||
|
||||
dependencies:
|
||||
- _libgcc_mutex
|
||||
- _openmp_mutex
|
||||
- blas
|
||||
- brotli
|
||||
- brotlipy
|
||||
- bzip2
|
||||
- ca-certificates
|
||||
- certifi
|
||||
- cffi
|
||||
- charset-normalizer
|
||||
- colorama
|
||||
- cryptography
|
||||
- cudatoolkit
|
||||
- cycler
|
||||
- dbus
|
||||
- expat
|
||||
- ffmpeg
|
||||
- fontconfig
|
||||
- fonttools
|
||||
- freetype
|
||||
- giflib
|
||||
- glib
|
||||
- gmp
|
||||
- gnutls
|
||||
- gst-plugins-base
|
||||
- gstreamer
|
||||
- icu
|
||||
- idna
|
||||
- joblib
|
||||
- jpeg
|
||||
- kiwisolver
|
||||
- lame
|
||||
- lcms2
|
||||
- ld_impl_linux-64
|
||||
- libffi
|
||||
- libgcc-ng
|
||||
- libgfortran-ng
|
||||
- libgfortran4
|
||||
- libgomp
|
||||
- libiconv
|
||||
- libidn2
|
||||
- libpng
|
||||
- libstdcxx-ng
|
||||
- libtasn1
|
||||
- libtiff
|
||||
- libunistring
|
||||
- libuuid
|
||||
- libuv
|
||||
- libwebp
|
||||
- libwebp-base
|
||||
- libxcb
|
||||
- libxml2
|
||||
- lz4-c
|
||||
- matplotlib
|
||||
- matplotlib-base
|
||||
- mkl
|
||||
- mkl-service
|
||||
- mkl_fft
|
||||
- mkl_random
|
||||
- munkres
|
||||
- ncurses
|
||||
- nettle
|
||||
- openh264
|
||||
- openssl
|
||||
- packaging
|
||||
- pcre
|
||||
- pip
|
||||
- pycparser
|
||||
- pyopenssl
|
||||
- pyparsing
|
||||
- pyqt
|
||||
- pysocks
|
||||
- python
|
||||
- python-dateutil
|
||||
- pytorch
|
||||
- pytorch-mutex
|
||||
- qt
|
||||
- readline
|
||||
- requests
|
||||
- scikit-learn
|
||||
- scipy
|
||||
- setuptools
|
||||
- sip
|
||||
- six
|
||||
- sqlite
|
||||
- threadpoolctl
|
||||
- tk
|
||||
- torchaudio
|
||||
- torchvision
|
||||
- tornado
|
||||
- tqdm
|
||||
- typing_extensions
|
||||
- urllib3
|
||||
- wheel
|
||||
- xz
|
||||
- zlib
|
||||
- zstd
|
||||
- torchdiffeq
|
19
install.sh
19
install.sh
|
@ -1,19 +1,18 @@
|
|||
#! /bin/bash
|
||||
|
||||
root=`pwd`
|
||||
|
||||
# Install dependecies
|
||||
conda install numpy matplotlib pillow scipy tqdm scikit-learn -y
|
||||
pip install tensorflow-gpu==1.13.1
|
||||
pip install tensorboardX==1.7
|
||||
# conda install pytorch==1.0.1 torchvision==0.2.2 cudatoolkit=10.0 -c pytorch
|
||||
# conda install matplotlib tqdm scikit-learn -y
|
||||
# pip install pillow==5.0.0
|
||||
# pip install scipy==1.0.1
|
||||
# pip install numpy==1.16.4
|
||||
# pip install tensorflow-gpu==1.13.1
|
||||
# pip install tensorboardX==1.7
|
||||
# pip install torchdiffeq==0.0.1
|
||||
|
||||
# Compile CUDA kernel for CD/EMD loss
|
||||
root=`pwd`
|
||||
cd metrics/pytorch_structural_losses/
|
||||
make clean
|
||||
make
|
||||
cd $root
|
||||
|
||||
# install torchdiffeq
|
||||
git clone https://github.com/rtqichen/torchdiffeq.git
|
||||
cd torchdiffeq
|
||||
pip install -e .
|
||||
|
|
|
@ -4,31 +4,7 @@ import warnings
|
|||
from scipy.stats import entropy
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
from numpy.linalg import norm
|
||||
|
||||
# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
|
||||
from .StructuralLosses.match_cost import match_cost
|
||||
from .StructuralLosses.nn_distance import nn_distance
|
||||
|
||||
|
||||
# # Import CUDA version of CD, borrowed from https://github.com/ThibaultGROUEIX/AtlasNet
|
||||
# try:
|
||||
# from . chamfer_distance_ext.dist_chamfer import chamferDist
|
||||
# CD = chamferDist()
|
||||
# def distChamferCUDA(x,y):
|
||||
# return CD(x,y,gpu)
|
||||
# except:
|
||||
|
||||
|
||||
def distChamferCUDA(x, y):
|
||||
return nn_distance(x, y)
|
||||
|
||||
|
||||
def emd_approx(sample, ref):
|
||||
B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
|
||||
assert N == N_ref, "Not sure what would EMD do in this case"
|
||||
emd = match_cost(sample, ref) # (B,)
|
||||
emd_norm = emd / float(N) # (B,)
|
||||
return emd_norm
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
|
||||
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
|
||||
|
@ -44,8 +20,53 @@ def distChamfer(a, b):
|
|||
P = (rx.transpose(2, 1) + ry - 2 * zz)
|
||||
return P.min(1)[0], P.min(2)[0]
|
||||
|
||||
# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
|
||||
try:
|
||||
from .StructuralLosses.nn_distance import nn_distance
|
||||
def distChamferCUDA(x, y):
|
||||
return nn_distance(x, y)
|
||||
except:
|
||||
print("distChamferCUDA not available; fall back to slower version.")
|
||||
def distChamferCUDA(x, y):
|
||||
return distChamfer(x, y)
|
||||
|
||||
def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
|
||||
|
||||
def emd_approx(x, y):
|
||||
bs, npts, mpts, dim = x.size(0), x.size(1), y.size(1), x.size(2)
|
||||
assert npts == mpts, "EMD only works if two point clouds are equal size"
|
||||
dim = x.shape[-1]
|
||||
x = x.reshape(bs, npts, 1, dim)
|
||||
y = y.reshape(bs, 1, mpts, dim)
|
||||
dist = (x - y).norm(dim=-1, keepdim=False) # (bs, npts, mpts)
|
||||
|
||||
emd_lst = []
|
||||
dist_np = dist.cpu().detach().numpy()
|
||||
for i in range(bs):
|
||||
d_i = dist_np[i]
|
||||
r_idx, c_idx = linear_sum_assignment(d_i)
|
||||
emd_i = d_i[r_idx, c_idx].mean()
|
||||
emd_lst.append(emd_i)
|
||||
emd = np.stack(emd_lst).reshape(-1)
|
||||
emd_torch = torch.from_numpy(emd).to(x)
|
||||
return emd_torch
|
||||
|
||||
|
||||
try:
|
||||
from .StructuralLosses.match_cost import match_cost
|
||||
def emd_approx_cuda(sample, ref):
|
||||
B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
|
||||
assert N == N_ref, "Not sure what would EMD do in this case"
|
||||
emd = match_cost(sample, ref) # (B,)
|
||||
emd_norm = emd / float(N) # (B,)
|
||||
return emd_norm
|
||||
except:
|
||||
print("emd_approx_cuda not available. Fall back to slower version.")
|
||||
def emd_approx_cuda(sample, ref):
|
||||
return emd_approx(sample, ref)
|
||||
|
||||
|
||||
def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True,
|
||||
accelerated_emd=False):
|
||||
N_sample = sample_pcs.shape[0]
|
||||
N_ref = ref_pcs.shape[0]
|
||||
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
|
||||
|
@ -65,7 +86,10 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
|
|||
dl, dr = distChamfer(sample_batch, ref_batch)
|
||||
cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))
|
||||
|
||||
emd_batch = emd_approx(sample_batch, ref_batch)
|
||||
if accelerated_emd:
|
||||
emd_batch = emd_approx_cuda(sample_batch, ref_batch)
|
||||
else:
|
||||
emd_batch = emd_approx(sample_batch, ref_batch)
|
||||
emd_lst.append(emd_batch)
|
||||
|
||||
if reduced:
|
||||
|
@ -82,7 +106,8 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
|
|||
return results
|
||||
|
||||
|
||||
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
|
||||
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True,
|
||||
accelerated_emd=True):
|
||||
N_sample = sample_pcs.shape[0]
|
||||
N_ref = ref_pcs.shape[0]
|
||||
all_cd = []
|
||||
|
@ -101,13 +126,16 @@ def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
|
|||
sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
|
||||
sample_batch_exp = sample_batch_exp.contiguous()
|
||||
|
||||
if accelerated_cd:
|
||||
if accelerated_cd and distChamferCUDA is not None:
|
||||
dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)
|
||||
else:
|
||||
dl, dr = distChamfer(sample_batch_exp, ref_batch)
|
||||
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
|
||||
|
||||
emd_batch = emd_approx(sample_batch_exp, ref_batch)
|
||||
if accelerated_emd:
|
||||
emd_batch = emd_approx_cuda(sample_batch_exp, ref_batch)
|
||||
else:
|
||||
emd_batch = emd_approx(sample_batch_exp, ref_batch)
|
||||
emd_lst.append(emd_batch.view(1, -1))
|
||||
|
||||
cd_lst = torch.cat(cd_lst, dim=1)
|
||||
|
|
Loading…
Reference in a new issue