Compare commits
No commits in common. "fb4031bb335c5ba0c46b5f49c59bb24b471cb832" and "4dda09e2d09d8382f5007b4038198ee2134438f5" have entirely different histories.
fb4031bb33
...
4dda09e2d0
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -19,7 +19,5 @@ metrics/structural_losses/makefile
|
||||||
PyMesh
|
PyMesh
|
||||||
checkpoint
|
checkpoint
|
||||||
|
|
||||||
pretrained_models*
|
|
||||||
|
|
||||||
torchdiffeq/
|
torchdiffeq/
|
||||||
demo/
|
demo/
|
||||||
|
|
31
README.md
31
README.md
|
@ -2,17 +2,14 @@
|
||||||
|
|
||||||
This repository contains a PyTorch implementation of the paper:
|
This repository contains a PyTorch implementation of the paper:
|
||||||
|
|
||||||
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](https://arxiv.org/abs/1906.12320).
|
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](www.arxiv.com).
|
||||||
<br>
|
|
||||||
[Guandao Yang*](http://www.guandaoyang.com),
|
[Guandao Yang*](http://www.guandaoyang.com),
|
||||||
[Xun Huang*](http://www.cs.cornell.edu/~xhuang/),
|
[Xun Huang*](http://www.cs.cornell.edu/~xhuang/),
|
||||||
[Zekun Hao](http://www.cs.cornell.edu/~zekun/),
|
[Zekun Hao](http://www.cs.cornell.edu/~zekun/),
|
||||||
[Ming-Yu Liu](http://mingyuliu.net/),
|
[Ming-Yu Liu](http://mingyuliu.net/),
|
||||||
[Serge Belongie](http://blogs.cornell.edu/techfaculty/serge-belongie/),
|
[Serge Belongie](http://blogs.cornell.edu/techfaculty/serge-belongie/),
|
||||||
[Bharath Hariharan](http://home.bharathh.info/)
|
[Bharath Hariharan](http://home.bharathh.info/)
|
||||||
(* equal contribution)
|
|
||||||
<br>
|
|
||||||
ICCV 2019 (**Oral**)
|
|
||||||
|
|
||||||
|
|
||||||
## Introduction
|
## Introduction
|
||||||
|
@ -30,16 +27,18 @@ As 3D point clouds become the representation of choice for multiple vision and g
|
||||||
* G++ or GCC 5.
|
* G++ or GCC 5.
|
||||||
* [PyTorch](http://pytorch.org/). Codes are tested with version 1.0.1
|
* [PyTorch](http://pytorch.org/). Codes are tested with version 1.0.1
|
||||||
* [torchdiffeq](https://github.com/rtqichen/torchdiffeq).
|
* [torchdiffeq](https://github.com/rtqichen/torchdiffeq).
|
||||||
* (Optional) [Tensorboard](https://www.tensorflow.org/) for visualization of the training process.
|
* (Optional) [Tensorboard](https://www.tensorflow.org/) for visualization of training process.
|
||||||
|
|
||||||
Following is the suggested way to install these dependencies:
|
Following is the suggested way to install these dependencies:
|
||||||
```bash
|
```bash
|
||||||
# Create a new conda environment
|
# Create a new conda environment
|
||||||
conda env create -f environment.yml
|
conda create -n PointFlow python=3.6
|
||||||
conda activate PointFlow
|
conda activate PointFlow
|
||||||
|
|
||||||
# Compile structural losses, etc., but this step is not required
|
# Install pytorch (please refer to the commend in the official website)
|
||||||
# as there is a version of EMD/CD not requiring the CUDA kernel.
|
conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 -c pytorch -y
|
||||||
|
|
||||||
|
# Install other dependencies such as torchdiffeq, structural losses, etc.
|
||||||
./install.sh
|
./install.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -54,7 +53,7 @@ cd data
|
||||||
unzip ShapeNetCore.v2.PC15k.zip
|
unzip ShapeNetCore.v2.PC15k.zip
|
||||||
```
|
```
|
||||||
|
|
||||||
Please contact us if you need point clouds for the ModelNet dataset.
|
Please contact us if you need point clouds for ModelNet dataset.
|
||||||
|
|
||||||
## Training
|
## Training
|
||||||
|
|
||||||
|
@ -72,7 +71,7 @@ Example training scripts can be found in `scripts/` folder.
|
||||||
## Pre-trained models and test
|
## Pre-trained models and test
|
||||||
|
|
||||||
Pretrained models can be downloaded from this [link](https://drive.google.com/file/d/1dcxjuuKiAXZxhiyWD_o_7Owx8Y3FbRHG/view?usp=sharing).
|
Pretrained models can be downloaded from this [link](https://drive.google.com/file/d/1dcxjuuKiAXZxhiyWD_o_7Owx8Y3FbRHG/view?usp=sharing).
|
||||||
The following is the suggested way to evaluate the performance of the pre-trained models.
|
Following is the suggested way to evaluate the performance of the pre-trained models.
|
||||||
```bash
|
```bash
|
||||||
unzip pretrained_models.zip; # This will create a folder named pretrained_models
|
unzip pretrained_models.zip; # This will create a folder named pretrained_models
|
||||||
|
|
||||||
|
@ -88,20 +87,16 @@ CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_gen_test.sh
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
|
|
||||||
The demo relies on [Open3D](http://www.open3d.org/). The following is the suggested way to install it:
|
The demo relies on [Open3D](http://www.open3d.org/). Following is the suggested way to install it:
|
||||||
```bash
|
```bash
|
||||||
conda install -c open3d-admin open3d
|
conda install -c open3d-admin open3d
|
||||||
```
|
```
|
||||||
The demo will sample shapes from a pre-trained model, save those shapes under the `demo` folder, and visualize those point clouds.
|
The demo will sample shapes from a pre-trained model, save those shapes under the `demo` folder, and visualize those point clouds.
|
||||||
Once this dependency is in place, you can use the following script to use the demo for the pre-trained model for airplanes:
|
Once this dependency is in place, you can use the following script to use the demo for the pre-trained model for airplanes:
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_demo.sh
|
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_demo.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## Point cloud rendering
|
|
||||||
|
|
||||||
Please refer to the following github repository for our point cloud rendering code: https://github.com/zekunhao1995/PointFlowRenderer.
|
|
||||||
|
|
||||||
## Cite
|
## Cite
|
||||||
Please cite our work if you find it useful:
|
Please cite our work if you find it useful:
|
||||||
```latex
|
```latex
|
||||||
|
|
28
demo.py
28
demo.py
|
@ -1,4 +1,4 @@
|
||||||
# import open3d as o3d
|
import open3d as o3d
|
||||||
from datasets import get_datasets
|
from datasets import get_datasets
|
||||||
from args import get_args
|
from args import get_args
|
||||||
from models.networks import PointFlow
|
from models.networks import PointFlow
|
||||||
|
@ -47,32 +47,12 @@ def main(args):
|
||||||
np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs)
|
np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs)
|
||||||
|
|
||||||
# Visualize the demo
|
# Visualize the demo
|
||||||
# pcl = o3d.geometry.PointCloud()
|
pcl = o3d.geometry.PointCloud()
|
||||||
# for i in range(int(sample_pcs.shape[0])):
|
|
||||||
# print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
|
|
||||||
# pts = sample_pcs[i].reshape(-1, 3)
|
|
||||||
# pcl.points = o3d.utility.Vector3dVector(pts)
|
|
||||||
# o3d.visualization.draw_geometries([pcl])
|
|
||||||
|
|
||||||
# Visualize the demo using matplotlib, each point cloud in a different figure
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import matplotlib.cm as cm
|
|
||||||
import matplotlib
|
|
||||||
matplotlib.use('TkAgg')
|
|
||||||
|
|
||||||
for i in range(int(sample_pcs.shape[0])):
|
for i in range(int(sample_pcs.shape[0])):
|
||||||
print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
|
print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
|
||||||
pts = sample_pcs[i].reshape(-1, 3)
|
pts = sample_pcs[i].reshape(-1, 3)
|
||||||
fig = plt.figure()
|
pcl.points = o3d.utility.Vector3dVector(pts)
|
||||||
ax = fig.add_subplot(111, projection='3d')
|
o3d.visualization.draw_geometries([pcl])
|
||||||
ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c=pts[:, 2], cmap=cm.jet)
|
|
||||||
ax.set_aspect('equal')
|
|
||||||
ax.set_xlim(-1, 1)
|
|
||||||
ax.set_ylim(-1, 1)
|
|
||||||
ax.set_zlim(-1, 1)
|
|
||||||
plt.show()
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -40,7 +40,6 @@
|
||||||
<a class="col-md-6 col-xs-6" href="http://home.bharathh.info/"><span>Bharath Hariharan</span></a>
|
<a class="col-md-6 col-xs-6" href="http://home.bharathh.info/"><span>Bharath Hariharan</span></a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
<!-- affiliations -->
|
<!-- affiliations -->
|
||||||
<div class='row mt-1 mt-2' >
|
<div class='row mt-1 mt-2' >
|
||||||
<div class='col text-center'>
|
<div class='col text-center'>
|
||||||
|
@ -52,10 +51,6 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class='row text-center h5 font-weight-light pl-4 pr-4 mb-4'>
|
|
||||||
<span class="col-md-12 col-xs-12" style="color:#007bff">*Equal contribution</span>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class='row justify-content-center' style="position: relative; width: 100%;height: 0;padding-bottom: 33%; margin:0">
|
<div class='row justify-content-center' style="position: relative; width: 100%;height: 0;padding-bottom: 33%; margin:0">
|
||||||
<img
|
<img
|
||||||
class="img-fluid rounded mx-auto d-block"
|
class="img-fluid rounded mx-auto d-block"
|
||||||
|
|
109
environment.yml
109
environment.yml
|
@ -1,109 +0,0 @@
|
||||||
name: PointFlow
|
|
||||||
|
|
||||||
channels:
|
|
||||||
- nodefaults
|
|
||||||
- drti
|
|
||||||
- pytorch
|
|
||||||
- nvidia
|
|
||||||
- conda-forge
|
|
||||||
- pyg
|
|
||||||
|
|
||||||
dependencies:
|
|
||||||
- _libgcc_mutex
|
|
||||||
- _openmp_mutex
|
|
||||||
- blas
|
|
||||||
- brotli
|
|
||||||
- brotlipy
|
|
||||||
- bzip2
|
|
||||||
- ca-certificates
|
|
||||||
- certifi
|
|
||||||
- cffi
|
|
||||||
- charset-normalizer
|
|
||||||
- colorama
|
|
||||||
- cryptography
|
|
||||||
- cudatoolkit
|
|
||||||
- cycler
|
|
||||||
- dbus
|
|
||||||
- expat
|
|
||||||
- ffmpeg
|
|
||||||
- fontconfig
|
|
||||||
- fonttools
|
|
||||||
- freetype
|
|
||||||
- giflib
|
|
||||||
- glib
|
|
||||||
- gmp
|
|
||||||
- gnutls
|
|
||||||
- gst-plugins-base
|
|
||||||
- gstreamer
|
|
||||||
- icu
|
|
||||||
- idna
|
|
||||||
- joblib
|
|
||||||
- jpeg
|
|
||||||
- kiwisolver
|
|
||||||
- lame
|
|
||||||
- lcms2
|
|
||||||
- ld_impl_linux-64
|
|
||||||
- libffi
|
|
||||||
- libgcc-ng
|
|
||||||
- libgfortran-ng
|
|
||||||
- libgfortran4
|
|
||||||
- libgomp
|
|
||||||
- libiconv
|
|
||||||
- libidn2
|
|
||||||
- libpng
|
|
||||||
- libstdcxx-ng
|
|
||||||
- libtasn1
|
|
||||||
- libtiff
|
|
||||||
- libunistring
|
|
||||||
- libuuid
|
|
||||||
- libuv
|
|
||||||
- libwebp
|
|
||||||
- libwebp-base
|
|
||||||
- libxcb
|
|
||||||
- libxml2
|
|
||||||
- lz4-c
|
|
||||||
- matplotlib
|
|
||||||
- matplotlib-base
|
|
||||||
- mkl
|
|
||||||
- mkl-service
|
|
||||||
- mkl_fft
|
|
||||||
- mkl_random
|
|
||||||
- munkres
|
|
||||||
- ncurses
|
|
||||||
- nettle
|
|
||||||
- openh264
|
|
||||||
- openssl
|
|
||||||
- packaging
|
|
||||||
- pcre
|
|
||||||
- pip
|
|
||||||
- pycparser
|
|
||||||
- pyopenssl
|
|
||||||
- pyparsing
|
|
||||||
- pyqt
|
|
||||||
- pysocks
|
|
||||||
- python
|
|
||||||
- python-dateutil
|
|
||||||
- pytorch
|
|
||||||
- pytorch-mutex
|
|
||||||
- qt
|
|
||||||
- readline
|
|
||||||
- requests
|
|
||||||
- scikit-learn
|
|
||||||
- scipy
|
|
||||||
- setuptools
|
|
||||||
- sip
|
|
||||||
- six
|
|
||||||
- sqlite
|
|
||||||
- threadpoolctl
|
|
||||||
- tk
|
|
||||||
- torchaudio
|
|
||||||
- torchvision
|
|
||||||
- tornado
|
|
||||||
- tqdm
|
|
||||||
- typing_extensions
|
|
||||||
- urllib3
|
|
||||||
- wheel
|
|
||||||
- xz
|
|
||||||
- zlib
|
|
||||||
- zstd
|
|
||||||
- torchdiffeq
|
|
19
install.sh
19
install.sh
|
@ -1,18 +1,19 @@
|
||||||
#! /bin/bash
|
#! /bin/bash
|
||||||
|
|
||||||
|
root=`pwd`
|
||||||
|
|
||||||
# Install dependecies
|
# Install dependecies
|
||||||
# conda install pytorch==1.0.1 torchvision==0.2.2 cudatoolkit=10.0 -c pytorch
|
conda install numpy matplotlib pillow scipy tqdm scikit-learn -y
|
||||||
# conda install matplotlib tqdm scikit-learn -y
|
pip install tensorflow-gpu==1.13.1
|
||||||
# pip install pillow==5.0.0
|
pip install tensorboardX==1.7
|
||||||
# pip install scipy==1.0.1
|
|
||||||
# pip install numpy==1.16.4
|
|
||||||
# pip install tensorflow-gpu==1.13.1
|
|
||||||
# pip install tensorboardX==1.7
|
|
||||||
# pip install torchdiffeq==0.0.1
|
|
||||||
|
|
||||||
# Compile CUDA kernel for CD/EMD loss
|
# Compile CUDA kernel for CD/EMD loss
|
||||||
root=`pwd`
|
|
||||||
cd metrics/pytorch_structural_losses/
|
cd metrics/pytorch_structural_losses/
|
||||||
make clean
|
make clean
|
||||||
make
|
make
|
||||||
cd $root
|
cd $root
|
||||||
|
|
||||||
|
# install torchdiffeq
|
||||||
|
git clone https://github.com/rtqichen/torchdiffeq.git
|
||||||
|
cd torchdiffeq
|
||||||
|
pip install -e .
|
||||||
|
|
|
@ -4,7 +4,31 @@ import warnings
|
||||||
from scipy.stats import entropy
|
from scipy.stats import entropy
|
||||||
from sklearn.neighbors import NearestNeighbors
|
from sklearn.neighbors import NearestNeighbors
|
||||||
from numpy.linalg import norm
|
from numpy.linalg import norm
|
||||||
from scipy.optimize import linear_sum_assignment
|
|
||||||
|
# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
|
||||||
|
from .StructuralLosses.match_cost import match_cost
|
||||||
|
from .StructuralLosses.nn_distance import nn_distance
|
||||||
|
|
||||||
|
|
||||||
|
# # Import CUDA version of CD, borrowed from https://github.com/ThibaultGROUEIX/AtlasNet
|
||||||
|
# try:
|
||||||
|
# from . chamfer_distance_ext.dist_chamfer import chamferDist
|
||||||
|
# CD = chamferDist()
|
||||||
|
# def distChamferCUDA(x,y):
|
||||||
|
# return CD(x,y,gpu)
|
||||||
|
# except:
|
||||||
|
|
||||||
|
|
||||||
|
def distChamferCUDA(x, y):
|
||||||
|
return nn_distance(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
def emd_approx(sample, ref):
|
||||||
|
B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
|
||||||
|
assert N == N_ref, "Not sure what would EMD do in this case"
|
||||||
|
emd = match_cost(sample, ref) # (B,)
|
||||||
|
emd_norm = emd / float(N) # (B,)
|
||||||
|
return emd_norm
|
||||||
|
|
||||||
|
|
||||||
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
|
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
|
||||||
|
@ -20,53 +44,8 @@ def distChamfer(a, b):
|
||||||
P = (rx.transpose(2, 1) + ry - 2 * zz)
|
P = (rx.transpose(2, 1) + ry - 2 * zz)
|
||||||
return P.min(1)[0], P.min(2)[0]
|
return P.min(1)[0], P.min(2)[0]
|
||||||
|
|
||||||
# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
|
|
||||||
try:
|
|
||||||
from .StructuralLosses.nn_distance import nn_distance
|
|
||||||
def distChamferCUDA(x, y):
|
|
||||||
return nn_distance(x, y)
|
|
||||||
except:
|
|
||||||
print("distChamferCUDA not available; fall back to slower version.")
|
|
||||||
def distChamferCUDA(x, y):
|
|
||||||
return distChamfer(x, y)
|
|
||||||
|
|
||||||
|
def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
|
||||||
def emd_approx(x, y):
|
|
||||||
bs, npts, mpts, dim = x.size(0), x.size(1), y.size(1), x.size(2)
|
|
||||||
assert npts == mpts, "EMD only works if two point clouds are equal size"
|
|
||||||
dim = x.shape[-1]
|
|
||||||
x = x.reshape(bs, npts, 1, dim)
|
|
||||||
y = y.reshape(bs, 1, mpts, dim)
|
|
||||||
dist = (x - y).norm(dim=-1, keepdim=False) # (bs, npts, mpts)
|
|
||||||
|
|
||||||
emd_lst = []
|
|
||||||
dist_np = dist.cpu().detach().numpy()
|
|
||||||
for i in range(bs):
|
|
||||||
d_i = dist_np[i]
|
|
||||||
r_idx, c_idx = linear_sum_assignment(d_i)
|
|
||||||
emd_i = d_i[r_idx, c_idx].mean()
|
|
||||||
emd_lst.append(emd_i)
|
|
||||||
emd = np.stack(emd_lst).reshape(-1)
|
|
||||||
emd_torch = torch.from_numpy(emd).to(x)
|
|
||||||
return emd_torch
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .StructuralLosses.match_cost import match_cost
|
|
||||||
def emd_approx_cuda(sample, ref):
|
|
||||||
B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
|
|
||||||
assert N == N_ref, "Not sure what would EMD do in this case"
|
|
||||||
emd = match_cost(sample, ref) # (B,)
|
|
||||||
emd_norm = emd / float(N) # (B,)
|
|
||||||
return emd_norm
|
|
||||||
except:
|
|
||||||
print("emd_approx_cuda not available. Fall back to slower version.")
|
|
||||||
def emd_approx_cuda(sample, ref):
|
|
||||||
return emd_approx(sample, ref)
|
|
||||||
|
|
||||||
|
|
||||||
def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True,
|
|
||||||
accelerated_emd=False):
|
|
||||||
N_sample = sample_pcs.shape[0]
|
N_sample = sample_pcs.shape[0]
|
||||||
N_ref = ref_pcs.shape[0]
|
N_ref = ref_pcs.shape[0]
|
||||||
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
|
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
|
||||||
|
@ -86,10 +65,7 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True,
|
||||||
dl, dr = distChamfer(sample_batch, ref_batch)
|
dl, dr = distChamfer(sample_batch, ref_batch)
|
||||||
cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))
|
cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))
|
||||||
|
|
||||||
if accelerated_emd:
|
emd_batch = emd_approx(sample_batch, ref_batch)
|
||||||
emd_batch = emd_approx_cuda(sample_batch, ref_batch)
|
|
||||||
else:
|
|
||||||
emd_batch = emd_approx(sample_batch, ref_batch)
|
|
||||||
emd_lst.append(emd_batch)
|
emd_lst.append(emd_batch)
|
||||||
|
|
||||||
if reduced:
|
if reduced:
|
||||||
|
@ -106,8 +82,7 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True,
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True,
|
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
|
||||||
accelerated_emd=True):
|
|
||||||
N_sample = sample_pcs.shape[0]
|
N_sample = sample_pcs.shape[0]
|
||||||
N_ref = ref_pcs.shape[0]
|
N_ref = ref_pcs.shape[0]
|
||||||
all_cd = []
|
all_cd = []
|
||||||
|
@ -126,16 +101,13 @@ def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True,
|
||||||
sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
|
sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
|
||||||
sample_batch_exp = sample_batch_exp.contiguous()
|
sample_batch_exp = sample_batch_exp.contiguous()
|
||||||
|
|
||||||
if accelerated_cd and distChamferCUDA is not None:
|
if accelerated_cd:
|
||||||
dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)
|
dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)
|
||||||
else:
|
else:
|
||||||
dl, dr = distChamfer(sample_batch_exp, ref_batch)
|
dl, dr = distChamfer(sample_batch_exp, ref_batch)
|
||||||
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
|
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
|
||||||
|
|
||||||
if accelerated_emd:
|
emd_batch = emd_approx(sample_batch_exp, ref_batch)
|
||||||
emd_batch = emd_approx_cuda(sample_batch_exp, ref_batch)
|
|
||||||
else:
|
|
||||||
emd_batch = emd_approx(sample_batch_exp, ref_batch)
|
|
||||||
emd_lst.append(emd_batch.view(1, -1))
|
emd_lst.append(emd_batch.view(1, -1))
|
||||||
|
|
||||||
cd_lst = torch.cat(cd_lst, dim=1)
|
cd_lst = torch.cat(cd_lst, dim=1)
|
||||||
|
|
Loading…
Reference in a new issue