Add training/testing/demo codes

This commit is contained in:
Guandao Yang 2019-07-13 21:32:26 -07:00
parent f1581cf0f5
commit 4e6795a4d9
41 changed files with 3785 additions and 103 deletions

113
.gitignore vendored
View file

@ -1,108 +1,23 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*__pycache__
.idea/
.DS_Store
*/.DS_Store
*/*/.DS_Store
*.pyc
data
# C extensions
*.so
scratch.py
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
*.m
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
checkpoints
runs
# celery beat schedule file
celerybeat-schedule
metrics/structural_losses/*.so
metrics/structural_losses/*.cu.o
metrics/structural_losses/makefile
PyMesh
checkpoint
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
torchdiffeq/
demo/

View file

@ -2,7 +2,7 @@
This repository contains a PyTorch implementation of the paper:
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](https://arxiv.org/abs/1906.12320).
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](www.arxiv.com).
[Guandao Yang*](http://www.guandaoyang.com),
[Xun Huang*](http://www.cs.cornell.edu/~xhuang/),
@ -11,9 +11,6 @@ This repository contains a PyTorch implementation of the paper:
[Serge Belongie](http://blogs.cornell.edu/techfaculty/serge-belongie/),
[Bharath Hariharan](http://home.bharathh.info/)
**The code will be available soon!**
[\[Project page\]](https://www.guandaoyang.com/PointFlow/) [\[Video\]](https://www.youtube.com/watch?v=jqBiv77xC0M)
## Introduction
@ -23,3 +20,90 @@ As 3D point clouds become the representation of choice for multiple vision and g
<p float="left">
<img src="docs/assets/teaser.gif" height="256"/>
</p>
## Dependencies
* Python 3.6
* CUDA 10.0.
* G++ or GCC 5.
* [PyTorch](http://pytorch.org/). Codes are tested with version 1.0.1
* [torchdiffeq](https://github.com/rtqichen/torchdiffeq).
* (Optional) [Tensorboard](https://www.tensorflow.org/) for visualization of training process.
Following is the suggested way to install these dependencies:
```bash
# Create a new conda environment
conda create -n PointFlow python=3.6
conda activate PointFlow
# Install pytorch (please refer to the commend in the official website)
conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 -c pytorch -y
# Install other dependencies such as torchdiffeq, structural losses, etc.
./install.sh
```
## Dataset
The point clouds are uniformly sampled from meshes from ShapeNetCore dataset (version 2) and use the official split.
Please use this [link](https://drive.google.com/drive/folders/1G0rf-6HSHoTll6aH7voh-dXj6hCRhSAQ?usp=sharing) to download the ShapeNet point cloud.
The point cloud should be placed into `data` directory.
```bash
mv ShapeNetCore.v2.PC15k.zip data/
cd data
unzip ShapeNetCore.v2.PC15k.zip
```
Please contact us if you need point clouds for ModelNet dataset.
## Training
Example training scripts can be found in `scripts/` folder.
```bash
# Train auto-encoder (no latent CNF)
./scripts/shapenet_airplane_ae.sh # Train with single GPU, about 7-8 GB GPU memory
./scripts/shapenet_airplane_ae_dist.sh # Train with multiple GPUs
# Train generative model
./scripts/shapenet_airplane_ae.sh # Train with single GPU, about 7-8 GB GPU memory
./scripts/shapenet_airplane_ae_dist.sh # Train with multiple GPUs
```
## Pre-trained models and test
Pretrained models can be downloaded from this [link](https://drive.google.com/file/d/1dcxjuuKiAXZxhiyWD_o_7Owx8Y3FbRHG/view?usp=sharing).
Following is the suggested way to evaluate the performance of the pre-trained models.
```bash
unzip pretrained_models.zip; # This will create a folder named pretrained_models
# Evaluate the reconstruction performance of an AE trained on the airplane category
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_ae_test.sh;
# Evaluate the reconstruction performance of an AE trained with the whole ShapeNet
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_all_ae_test.sh;
# Evaluate the generative performance of PointFlow trained on the airplane category.
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_gen_test.sh
```
## Demo
The demo relies on [Open3D](http://www.open3d.org/). Following is the suggested way to install it:
```bash
conda install -c open3d-admin open3d
```
The demo will sample shapes from a pre-trained model, save those shapes under the `demo` folder, and visualize those point clouds.
Once this dependency is in place, you can use the following script to use the demo for the pre-trained model for airplanes:
```bash
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_demo.py
```
## Cite
Please cite our work if you find it useful:
```latex
@article{pointflow,
title={PointFlow: 3D Point Cloud Generation with Continuous Normalizing Flows},
author={Yang, Guandao and Huang, Xun, and Hao, Zekun and Liu, Ming-Yu and Belongie, Serge and Hariharan, Bharath},
journal={arXiv},
year={2019}
}
```

163
args.py Normal file
View file

@ -0,0 +1,163 @@
import argparse
NONLINEARITIES = ["tanh", "relu", "softplus", "elu", "swish", "square", "identity"]
SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams']
LAYERS = ["ignore", "concat", "concat_v2", "squash", "concatsquash", "scale", "concatscale"]
def add_args(parser):
# model architecture options
parser.add_argument('--input_dim', type=int, default=3,
help='Number of input dimensions (3 for 3D point clouds)')
parser.add_argument('--dims', type=str, default='256')
parser.add_argument('--latent_dims', type=str, default='256')
parser.add_argument("--num_blocks", type=int, default=1,
help='Number of stacked CNFs.')
parser.add_argument("--latent_num_blocks", type=int, default=1,
help='Number of stacked CNFs.')
parser.add_argument("--layer_type", type=str, default="concatsquash", choices=LAYERS)
parser.add_argument('--time_length', type=float, default=0.5)
parser.add_argument('--train_T', type=eval, default=True, choices=[True, False])
parser.add_argument("--nonlinearity", type=str, default="tanh", choices=NONLINEARITIES)
parser.add_argument('--use_adjoint', type=eval, default=True, choices=[True, False])
parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
parser.add_argument('--atol', type=float, default=1e-5)
parser.add_argument('--rtol', type=float, default=1e-5)
parser.add_argument('--batch_norm', type=eval, default=True, choices=[True, False])
parser.add_argument('--sync_bn', type=eval, default=False, choices=[True, False])
parser.add_argument('--bn_lag', type=float, default=0)
# training options
parser.add_argument('--use_latent_flow', action='store_true',
help='Whether to use the latent flow to model the prior.')
parser.add_argument('--use_deterministic_encoder', action='store_true',
help='Whether to use a deterministic encoder.')
parser.add_argument('--zdim', type=int, default=128,
help='Dimension of the shape code')
parser.add_argument('--optimizer', type=str, default='adam',
help='Optimizer to use', choices=['adam', 'adamax', 'sgd'])
parser.add_argument('--batch_size', type=int, default=50,
help='Batch size (of datasets) for training')
parser.add_argument('--lr', type=float, default=1e-3,
help='Learning rate for the Adam optimizer.')
parser.add_argument('--beta1', type=float, default=0.9,
help='Beta1 for Adam.')
parser.add_argument('--beta2', type=float, default=0.999,
help='Beta2 for Adam.')
parser.add_argument('--momentum', type=float, default=0.9,
help='Momentum for SGD')
parser.add_argument('--weight_decay', type=float, default=0.,
help='Weight decay for the optimizer.')
parser.add_argument('--epochs', type=int, default=100,
help='Number of epochs for training (default: 100)')
parser.add_argument('--seed', type=int, default=None,
help='Seed for initializing training. ')
parser.add_argument('--recon_weight', type=float, default=1.,
help='Weight for the reconstruction loss.')
parser.add_argument('--prior_weight', type=float, default=1.,
help='Weight for the prior loss.')
parser.add_argument('--entropy_weight', type=float, default=1.,
help='Weight for the entropy loss.')
parser.add_argument('--scheduler', type=str, default='linear',
help='Type of learning rate schedule')
parser.add_argument('--exp_decay', type=float, default=1.,
help='Learning rate schedule exponential decay rate')
parser.add_argument('--exp_decay_freq', type=int, default=1,
help='Learning rate exponential decay frequency')
# data options
parser.add_argument('--dataset_type', type=str, default="shapenet15k",
help="Dataset types.", choices=['shapenet15k', 'modelnet40_15k', 'modelnet10_15k'])
parser.add_argument('--cates', type=str, nargs='+', default=["airplane"],
help="Categories to be trained (useful only if 'shapenet' is selected)")
parser.add_argument('--data_dir', type=str, default="data/ShapeNetCore.v2.PC15k",
help="Path to the training data")
parser.add_argument('--mn40_data_dir', type=str, default="data/ModelNet40.PC15k",
help="Path to ModelNet40")
parser.add_argument('--mn10_data_dir', type=str, default="data/ModelNet10.PC15k",
help="Path to ModelNet10")
parser.add_argument('--dataset_scale', type=float, default=1.,
help='Scale of the dataset (x,y,z * scale = real output, default=1).')
parser.add_argument('--random_rotate', action='store_true',
help='Whether to randomly rotate each shape.')
parser.add_argument('--normalize_per_shape', action='store_true',
help='Whether to perform normalization per shape.')
parser.add_argument('--normalize_std_per_axis', action='store_true',
help='Whether to perform normalization per axis.')
parser.add_argument("--tr_max_sample_points", type=int, default=2048,
help='Max number of sampled points (train)')
parser.add_argument("--te_max_sample_points", type=int, default=2048,
help='Max number of sampled points (test)')
parser.add_argument('--num_workers', type=int, default=4,
help='Number of data loading threads')
# logging and saving frequency
parser.add_argument('--log_name', type=str, default=None, help="Name for the log dir")
parser.add_argument('--viz_freq', type=int, default=10)
parser.add_argument('--val_freq', type=int, default=10)
parser.add_argument('--log_freq', type=int, default=10)
parser.add_argument('--save_freq', type=int, default=10)
# validation options
parser.add_argument('--no_validation', action='store_true',
help='Whether to disable validation altogether.')
parser.add_argument('--save_val_results', action='store_true',
help='Whether to save the validation results.')
parser.add_argument('--eval_classification', action='store_true',
help='Whether to evaluate classification accuracy on MN40 and MN10.')
parser.add_argument('--no_eval_sampling', action='store_true',
help='Whether to evaluate sampling.')
parser.add_argument('--max_validate_shapes', type=int, default=None,
help='Max number of shapes used for validation pass.')
# resuming
parser.add_argument('--resume_checkpoint', type=str, default=None,
help='Path to the checkpoint to be loaded.')
parser.add_argument('--resume_optimizer', action='store_true',
help='Whether to resume the optimizer when resumed training.')
parser.add_argument('--resume_non_strict', action='store_true',
help='Whether to resume in none-strict mode.')
parser.add_argument('--resume_dataset_mean', type=str, default=None,
help='Path to the file storing the dataset mean.')
parser.add_argument('--resume_dataset_std', type=str, default=None,
help='Path to the file storing the dataset std.')
# distributed training
parser.add_argument('--world_size', default=1, type=int,
help='Number of distributed nodes.')
parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--distributed', action='store_true',
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('--rank', default=0, type=int,
help='node rank for distributed training')
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use. None means using all available GPUs.')
# Evaluation options
parser.add_argument('--evaluate_recon', default=False, action='store_true',
help='Whether set to the evaluation for reconstruction.')
parser.add_argument('--num_sample_shapes', default=10, type=int,
help='Number of shapes to be sampled (for demo.py).')
parser.add_argument('--num_sample_points', default=2048, type=int,
help='Number of points (per-shape) to be sampled (for demo.py).')
return parser
def get_parser():
# command line args
parser = argparse.ArgumentParser(description='Flow-based Point Cloud Generation Experiment')
parser = add_args(parser)
return parser
def get_args():
parser = get_parser()
args = parser.parse_args()
return args

389
datasets.py Normal file
View file

@ -0,0 +1,389 @@
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils import data
import random
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
synsetid_to_cate = {
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
'04554684': 'washer', '02992529': 'cellphone',
'02843684': 'birdhouse', '02871439': 'bookshelf',
# '02858304': 'boat', no boat in our dataset, merged into vessels
# '02834778': 'bicycle', not in our taxonomy
}
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
class Uniform15KPC(Dataset):
def __init__(self, root_dir, subdirs, tr_sample_size=10000,
te_sample_size=10000, split='train', scale=1.,
normalize_per_shape=False, random_subsample=False,
normalize_std_per_axis=False,
all_points_mean=None, all_points_std=None,
input_dim=3):
self.root_dir = root_dir
self.split = split
self.in_tr_sample_size = tr_sample_size
self.in_te_sample_size = te_sample_size
self.subdirs = subdirs
self.scale = scale
self.random_subsample = random_subsample
self.input_dim = input_dim
self.all_cate_mids = []
self.cate_idx_lst = []
self.all_points = []
for cate_idx, subd in enumerate(self.subdirs):
# NOTE: [subd] here is synset id
sub_path = os.path.join(root_dir, subd, self.split)
if not os.path.isdir(sub_path):
print("Directory missing : %s" % sub_path)
continue
all_mids = []
for x in os.listdir(sub_path):
if not x.endswith('.npy'):
continue
all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
# NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
for mid in all_mids:
# obj_fname = os.path.join(sub_path, x)
obj_fname = os.path.join(root_dir, subd, mid + ".npy")
try:
point_cloud = np.load(obj_fname) # (15k, 3)
except:
continue
assert point_cloud.shape[0] == 15000
self.all_points.append(point_cloud[np.newaxis, ...])
self.cate_idx_lst.append(cate_idx)
self.all_cate_mids.append((subd, mid))
# Shuffle the index deterministically (based on the number of examples)
self.shuffle_idx = list(range(len(self.all_points)))
random.Random(38383).shuffle(self.shuffle_idx)
self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
self.all_points = [self.all_points[i] for i in self.shuffle_idx]
self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
# Normalization
self.all_points = np.concatenate(self.all_points) # (N, 15000, 3)
self.normalize_per_shape = normalize_per_shape
self.normalize_std_per_axis = normalize_std_per_axis
if all_points_mean is not None and all_points_std is not None: # using loaded dataset stats
self.all_points_mean = all_points_mean
self.all_points_std = all_points_std
elif self.normalize_per_shape: # per shape normalization
B, N = self.all_points.shape[:2]
self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1)
else: # normalize across the dataset
self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
self.tr_sample_size = min(10000, tr_sample_size)
self.te_sample_size = min(5000, te_sample_size)
print("Total number of data:%d" % len(self.train_points))
print("Min number of points: (train)%d (test)%d"
% (self.tr_sample_size, self.te_sample_size))
assert self.scale == 1, "Scale (!= 1) is deprecated"
def get_pc_stats(self, idx):
if self.normalize_per_shape:
m = self.all_points_mean[idx].reshape(1, self.input_dim)
s = self.all_points_std[idx].reshape(1, -1)
return m, s
return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)
def renormalize(self, mean, std):
self.all_points = self.all_points * self.all_points_std + self.all_points_mean
self.all_points_mean = mean
self.all_points_std = std
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
def __len__(self):
return len(self.train_points)
def __getitem__(self, idx):
tr_out = self.train_points[idx]
if self.random_subsample:
tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
else:
tr_idxs = np.arange(self.tr_sample_size)
tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
te_out = self.test_points[idx]
if self.random_subsample:
te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
else:
te_idxs = np.arange(self.te_sample_size)
te_out = torch.from_numpy(te_out[te_idxs, :]).float()
m, s = self.get_pc_stats(idx)
cate_idx = self.cate_idx_lst[idx]
sid, mid = self.all_cate_mids[idx]
return {
'idx': idx,
'train_points': tr_out,
'test_points': te_out,
'mean': m, 'std': s, 'cate_idx': cate_idx,
'sid': sid, 'mid': mid
}
class ModelNet40PointClouds(Uniform15KPC):
def __init__(self, root_dir="data/ModelNet40.PC15k",
tr_sample_size=10000, te_sample_size=2048,
split='train', scale=1., normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=False,
all_points_mean=None, all_points_std=None):
self.root_dir = root_dir
self.split = split
assert self.split in ['train', 'test']
self.sample_size = tr_sample_size
self.cates = []
for cate in os.listdir(root_dir):
if os.path.isdir(os.path.join(root_dir, cate)) \
and os.path.isdir(os.path.join(root_dir, cate, 'train')) \
and os.path.isdir(os.path.join(root_dir, cate, 'test')):
self.cates.append(cate)
assert len(self.cates) == 40, "%s %s" % (len(self.cates), self.cates)
# For non-aligned MN
# self.gravity_axis = 0
# self.display_axis_order = [0,1,2]
# Aligned MN has same axis-order as SN
self.gravity_axis = 1
self.display_axis_order = [0, 2, 1]
super(ModelNet40PointClouds, self).__init__(
root_dir, self.cates, tr_sample_size=tr_sample_size,
te_sample_size=te_sample_size, split=split, scale=scale,
normalize_per_shape=normalize_per_shape,
normalize_std_per_axis=normalize_std_per_axis,
random_subsample=random_subsample,
all_points_mean=all_points_mean, all_points_std=all_points_std,
input_dim=3)
class ModelNet10PointClouds(Uniform15KPC):
def __init__(self, root_dir="data/ModelNet10.PC15k",
tr_sample_size=10000, te_sample_size=2048,
split='train', scale=1., normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=False,
all_points_mean=None, all_points_std=None):
self.root_dir = root_dir
self.split = split
assert self.split in ['train', 'test']
self.cates = []
for cate in os.listdir(root_dir):
if os.path.isdir(os.path.join(root_dir, cate)) \
and os.path.isdir(os.path.join(root_dir, cate, 'train')) \
and os.path.isdir(os.path.join(root_dir, cate, 'test')):
self.cates.append(cate)
assert len(self.cates) == 10
# That's prealigned MN
# self.gravity_axis = 0
# self.display_axis_order = [0,1,2]
# Aligned MN has same axis-order as SN
self.gravity_axis = 1
self.display_axis_order = [0, 2, 1]
super(ModelNet10PointClouds, self).__init__(
root_dir, self.cates, tr_sample_size=tr_sample_size,
te_sample_size=te_sample_size, split=split, scale=scale,
normalize_per_shape=normalize_per_shape,
normalize_std_per_axis=normalize_std_per_axis,
random_subsample=random_subsample,
all_points_mean=all_points_mean, all_points_std=all_points_std,
input_dim=3)
class ShapeNet15kPointClouds(Uniform15KPC):
def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
split='train', scale=1., normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=False,
all_points_mean=None, all_points_std=None):
self.root_dir = root_dir
self.split = split
assert self.split in ['train', 'test', 'val']
self.tr_sample_size = tr_sample_size
self.te_sample_size = te_sample_size
self.cates = categories
if 'all' in categories:
self.synset_ids = list(cate_to_synsetid.values())
else:
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
# assert 'v2' in root_dir, "Only supporting v2 right now."
self.gravity_axis = 1
self.display_axis_order = [0, 2, 1]
super(ShapeNet15kPointClouds, self).__init__(
root_dir, self.synset_ids,
tr_sample_size=tr_sample_size,
te_sample_size=te_sample_size,
split=split, scale=scale,
normalize_per_shape=normalize_per_shape,
normalize_std_per_axis=normalize_std_per_axis,
random_subsample=random_subsample,
all_points_mean=all_points_mean, all_points_std=all_points_std,
input_dim=3)
def init_np_seed(worker_id):
seed = torch.initial_seed()
np.random.seed(seed % 4294967296)
def _get_MN40_datasets_(args, data_dir=None):
tr_dataset = ModelNet40PointClouds(
split='train',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
root_dir=(args.data_dir if data_dir is None else data_dir),
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
random_subsample=True)
te_dataset = ModelNet40PointClouds(
split='test',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
root_dir=(args.data_dir if data_dir is None else data_dir),
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return tr_dataset, te_dataset
def _get_MN10_datasets_(args, data_dir=None):
tr_dataset = ModelNet10PointClouds(
split='train',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
root_dir=(args.data_dir if data_dir is None else data_dir),
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
random_subsample=True)
te_dataset = ModelNet10PointClouds(
split='test',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
root_dir=(args.data_dir if data_dir is None else data_dir),
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return tr_dataset, te_dataset
def get_datasets(args):
if args.dataset_type == 'shapenet15k':
tr_dataset = ShapeNet15kPointClouds(
categories=args.cates, split='train',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
scale=args.dataset_scale, root_dir=args.data_dir,
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
random_subsample=True)
te_dataset = ShapeNet15kPointClouds(
categories=args.cates, split='val',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
scale=args.dataset_scale, root_dir=args.data_dir,
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
elif args.dataset_type == 'modelnet40_15k':
tr_dataset, te_dataset = _get_MN40_datasets_(args)
elif args.dataset_type == 'modelnet10_15k':
tr_dataset, te_dataset = _get_MN10_datasets_(args)
else:
raise Exception("Invalid dataset type:%s" % args.dataset_type)
return tr_dataset, te_dataset
def get_clf_datasets(args):
return {
'MN40': _get_MN40_datasets_(args, data_dir=args.mn40_data_dir),
'MN10': _get_MN10_datasets_(args, data_dir=args.mn10_data_dir),
}
def get_data_loaders(args):
tr_dataset, te_dataset = get_datasets(args)
train_loader = data.DataLoader(
dataset=tr_dataset, batch_size=args.batch_size,
shuffle=True, num_workers=args.num_workers, drop_last=True,
worker_init_fn=init_np_seed)
train_unshuffle_loader = data.DataLoader(
dataset=tr_dataset, batch_size=args.batch_size,
shuffle=False, num_workers=args.num_workers, drop_last=True,
worker_init_fn=init_np_seed)
test_loader = data.DataLoader(
dataset=te_dataset, batch_size=args.batch_size,
shuffle=False, num_workers=args.num_workers, drop_last=False,
worker_init_fn=init_np_seed)
loaders = {
"test_loader": test_loader,
'train_loader': train_loader,
'train_unshuffle_loader': train_unshuffle_loader,
}
return loaders
if __name__ == "__main__":
shape_ds = ShapeNet15kPointClouds(categories=['airplane'], split='val')
x_tr, x_te = next(iter(shape_ds))
print(x_tr.shape)
print(x_te.shape)

60
demo.py Normal file
View file

@ -0,0 +1,60 @@
import open3d as o3d
from datasets import get_datasets
from args import get_args
from models.networks import PointFlow
import os
import torch
import numpy as np
import torch.nn as nn
def main(args):
model = PointFlow(args)
def _transform_(m):
return nn.DataParallel(m)
model = model.cuda()
model.multi_gpu_wrapper(_transform_)
print("Resume Path:%s" % args.resume_checkpoint)
checkpoint = torch.load(args.resume_checkpoint)
model.load_state_dict(checkpoint)
model.eval()
_, te_dataset = get_datasets(args)
if args.resume_dataset_mean is not None and args.resume_dataset_std is not None:
mean = np.load(args.resume_dataset_mean)
std = np.load(args.resume_dataset_std)
te_dataset.renormalize(mean, std)
ds_mean = torch.from_numpy(te_dataset.all_points_mean).cuda()
ds_std = torch.from_numpy(te_dataset.all_points_std).cuda()
all_sample = []
with torch.no_grad():
for i in range(0, args.num_sample_shapes, args.batch_size):
B = len(range(i, min(i + args.batch_size, args.num_sample_shapes)))
N = args.num_sample_points
_, out_pc = model.sample(B, N)
out_pc = out_pc * ds_std + ds_mean
all_sample.append(out_pc)
sample_pcs = torch.cat(all_sample, dim=0).cpu().detach().numpy()
print("Generation sample size:(%s, %s, %s)" % sample_pcs.shape)
# Save the generative output
os.makedirs("demo", exist_ok=True)
np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs)
# Visualize the demo
pcl = o3d.geometry.PointCloud()
for i in range(int(sample_pcs.shape[0])):
print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
pts = sample_pcs[i].reshape(-1, 3)
pcl.points = o3d.utility.Vector3dVector(pts)
o3d.visualization.draw_geometries([pcl])
if __name__ == '__main__':
args = get_args()
main(args)

19
install.sh Executable file
View file

@ -0,0 +1,19 @@
#! /bin/bash
root=`pwd`
# Install dependecies
conda install numpy matplotlib pillow scipy tqdm scikit-learn -y
pip install tensorflow-gpu==1.13.1
pip install tensorboardX==1.7
# Compile CUDA kernel for CD/EMD loss
cd metrics/pytorch_structural_losses/
make clean
make
cd $root
# install torchdiffeq
git clone https://github.com/rtqichen/torchdiffeq.git
cd torchdiffeq
pip install -e .

1
metrics/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
StructuralLosses

0
metrics/__init__.py Normal file
View file

View file

@ -0,0 +1,336 @@
import torch
import numpy as np
import warnings
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
from .StructuralLosses.match_cost import match_cost
from .StructuralLosses.nn_distance import nn_distance
# # Import CUDA version of CD, borrowed from https://github.com/ThibaultGROUEIX/AtlasNet
# try:
# from . chamfer_distance_ext.dist_chamfer import chamferDist
# CD = chamferDist()
# def distChamferCUDA(x,y):
# return CD(x,y,gpu)
# except:
def distChamferCUDA(x, y):
return nn_distance(x, y)
def emd_approx(sample, ref):
B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
assert N == N_ref, "Not sure what would EMD do in this case"
emd = match_cost(sample, ref) # (B,)
emd_norm = emd / float(N) # (B,)
return emd_norm
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
def distChamfer(a, b):
x, y = a, b
bs, num_points, points_dim = x.size()
xx = torch.bmm(x, x.transpose(2, 1))
yy = torch.bmm(y, y.transpose(2, 1))
zz = torch.bmm(x, y.transpose(2, 1))
diag_ind = torch.arange(0, num_points).to(a).long()
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
P = (rx.transpose(2, 1) + ry - 2 * zz)
return P.min(1)[0], P.min(2)[0]
def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
cd_lst = []
emd_lst = []
iterator = range(0, N_sample, batch_size)
for b_start in iterator:
b_end = min(N_sample, b_start + batch_size)
sample_batch = sample_pcs[b_start:b_end]
ref_batch = ref_pcs[b_start:b_end]
if accelerated_cd:
dl, dr = distChamferCUDA(sample_batch, ref_batch)
else:
dl, dr = distChamfer(sample_batch, ref_batch)
cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))
emd_batch = emd_approx(sample_batch, ref_batch)
emd_lst.append(emd_batch)
if reduced:
cd = torch.cat(cd_lst).mean()
emd = torch.cat(emd_lst).mean()
else:
cd = torch.cat(cd_lst)
emd = torch.cat(emd_lst)
results = {
'MMD-CD': cd,
'MMD-EMD': emd,
}
return results
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
all_cd = []
all_emd = []
iterator = range(N_sample)
for sample_b_start in iterator:
sample_batch = sample_pcs[sample_b_start]
cd_lst = []
emd_lst = []
for ref_b_start in range(0, N_ref, batch_size):
ref_b_end = min(N_ref, ref_b_start + batch_size)
ref_batch = ref_pcs[ref_b_start:ref_b_end]
batch_size_ref = ref_batch.size(0)
sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
sample_batch_exp = sample_batch_exp.contiguous()
if accelerated_cd:
dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)
else:
dl, dr = distChamfer(sample_batch_exp, ref_batch)
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
emd_batch = emd_approx(sample_batch_exp, ref_batch)
emd_lst.append(emd_batch.view(1, -1))
cd_lst = torch.cat(cd_lst, dim=1)
emd_lst = torch.cat(emd_lst, dim=1)
all_cd.append(cd_lst)
all_emd.append(emd_lst)
all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref
all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref
return all_cd, all_emd
# Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py
def knn(Mxx, Mxy, Myy, k, sqrt=False):
n0 = Mxx.size(0)
n1 = Myy.size(0)
label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx)
M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
if sqrt:
M = M.abs().sqrt()
INFINITY = float('inf')
val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
count = torch.zeros(n0 + n1).to(Mxx)
for i in range(0, k):
count = count + label.index_select(0, idx[i])
pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
s = {
'tp': (pred * label).sum(),
'fp': (pred * (1 - label)).sum(),
'fn': ((1 - pred) * label).sum(),
'tn': ((1 - pred) * (1 - label)).sum(),
}
s.update({
'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
'acc': torch.eq(label, pred).float().mean(),
})
return s
def lgan_mmd_cov(all_dist):
N_sample, N_ref = all_dist.size(0), all_dist.size(1)
min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
min_val, _ = torch.min(all_dist, dim=0)
mmd = min_val.mean()
mmd_smp = min_val_fromsmp.mean()
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
cov = torch.tensor(cov).to(all_dist)
return {
'lgan_mmd': mmd,
'lgan_cov': cov,
'lgan_mmd_smp': mmd_smp,
}
def compute_all_metrics(sample_pcs, ref_pcs, batch_size, accelerated_cd=False):
results = {}
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)
res_cd = lgan_mmd_cov(M_rs_cd.t())
results.update({
"%s-CD" % k: v for k, v in res_cd.items()
})
res_emd = lgan_mmd_cov(M_rs_emd.t())
results.update({
"%s-EMD" % k: v for k, v in res_emd.items()
})
M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd)
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)
# 1-NN results
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
results.update({
"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
})
one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
results.update({
"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
})
return results
#######################################################
# JSD : from https://github.com/optas/latent_3d_points
#######################################################
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
"""Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
that is placed in the unit-cube.
If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere.
"""
grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
spacing = 1.0 / float(resolution - 1)
for i in range(resolution):
for j in range(resolution):
for k in range(resolution):
grid[i, j, k, 0] = i * spacing - 0.5
grid[i, j, k, 1] = j * spacing - 0.5
grid[i, j, k, 2] = k * spacing - 0.5
if clip_sphere:
grid = grid.reshape(-1, 3)
grid = grid[norm(grid, axis=1) <= 0.5]
return grid, spacing
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
"""Computes the JSD between two sets of point-clouds, as introduced in the paper
```Learning Representations And Generative Models For 3D Point Clouds```.
Args:
sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points.
ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
resolution: (int) grid-resolution. Affects granularity of measurements.
"""
in_unit_sphere = True
sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False):
"""Given a collection of point-clouds, estimate the entropy of the random variables
corresponding to occupancy-grid activation patterns.
Inputs:
pclouds: (numpy array) #point-clouds x points per point-cloud x 3
grid_resolution (int) size of occupancy grid that will be used.
"""
epsilon = 10e-4
bound = 0.5 + epsilon
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
if verbose:
warnings.warn('Point-clouds are not in unit cube.')
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
if verbose:
warnings.warn('Point-clouds are not in unit sphere.')
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
grid_coordinates = grid_coordinates.reshape(-1, 3)
grid_counters = np.zeros(len(grid_coordinates))
grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
for pc in pclouds:
_, indices = nn.kneighbors(pc)
indices = np.squeeze(indices)
for i in indices:
grid_counters[i] += 1
indices = np.unique(indices)
for i in indices:
grid_bernoulli_rvars[i] += 1
acc_entropy = 0.0
n = float(len(pclouds))
for g in grid_bernoulli_rvars:
if g > 0:
p = float(g) / n
acc_entropy += entropy([p, 1.0 - p])
return acc_entropy / len(grid_counters), grid_counters
def jensen_shannon_divergence(P, Q):
if np.any(P < 0) or np.any(Q < 0):
raise ValueError('Negative values.')
if len(P) != len(Q):
raise ValueError('Non equal size.')
P_ = P / np.sum(P) # Ensure probabilities.
Q_ = Q / np.sum(Q)
e1 = entropy(P_, base=2)
e2 = entropy(Q_, base=2)
e_sum = entropy((P_ + Q_) / 2.0, base=2)
res = e_sum - ((e1 + e2) / 2.0)
res2 = _jsdiv(P_, Q_)
if not np.allclose(res, res2, atol=10e-5, rtol=0):
warnings.warn('Numerical values of two JSD methods don\'t agree.')
return res
def _jsdiv(P, Q):
"""another way of computing JSD"""
def _kldiv(A, B):
a = A.copy()
b = B.copy()
idx = np.logical_and(a > 0, b > 0)
a = a[idx]
b = b[idx]
return np.sum([v for v in a * np.log2(a / b)])
P_ = P / np.sum(P)
Q_ = Q / np.sum(Q)
M = 0.5 * (P_ + Q_)
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
if __name__ == "__main__":
B, N = 2, 10
x = torch.rand(B, N, 3)
y = torch.rand(B, N, 3)
distChamfer = distChamferCUDA()
min_l, min_r = distChamfer(x.cuda(), y.cuda())
print(min_l.shape)
print(min_r.shape)
l_dist = min_l.mean().cpu().detach().item()
r_dist = min_r.mean().cpu().detach().item()
print(l_dist, r_dist)

View file

@ -0,0 +1 @@
PyTorchStructuralLosses.egg-info/

View file

@ -0,0 +1,100 @@
###############################################################################
# Uncomment for debugging
# DEBUG := 1
# Pretty build
# Q ?= @
CXX := g++
PYTHON := python
NVCC := /usr/local/cuda/bin/nvcc
# PYTHON Header path
PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())')
PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]')
PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]')
# CUDA ROOT DIR that contains bin/ lib64/ and include/
# CUDA_DIR := /usr/local/cuda
CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())')
INCLUDE_DIRS := ./ $(CUDA_DIR)/include
INCLUDE_DIRS += $(PYTHON_HEADER_DIR)
INCLUDE_DIRS += $(PYTORCH_INCLUDES)
# Custom (MKL/ATLAS/OpenBLAS) include and lib directories.
# Leave commented to accept the defaults for your choice of BLAS
# (which should work)!
# BLAS_INCLUDE := /path/to/your/blas
# BLAS_LIB := /path/to/your/blas
###############################################################################
SRC_DIR := ./src
OBJ_DIR := ./objs
CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp)
CU_SRCS := $(wildcard $(SRC_DIR)/*.cu)
OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS))
CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS))
STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a
# CUDA architecture setting: going with all of them.
# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility.
# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility.
CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \
-gencode arch=compute_61,code=compute_61 \
-gencode arch=compute_52,code=sm_52
# We will also explicitly add stdc++ to the link target.
LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu
# Debugging
ifeq ($(DEBUG), 1)
COMMON_FLAGS += -DDEBUG -g -O0
# https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/
NVCCFLAGS += -g -G # -rdc true
else
COMMON_FLAGS += -DNDEBUG -O3
endif
WARNINGS := -Wall -Wno-sign-compare -Wcomment
INCLUDE_DIRS += $(BLAS_INCLUDE)
# Automatic dependency generation (nvcc is handled separately)
CXXFLAGS += -MMD -MP
# Complete build flags.
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \
-DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0
CXXFLAGS += -pthread -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS += -std=c++11 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
all: $(STATIC_LIB)
$(PYTHON) setup.py build
@ mv build/lib.linux-x86_64-3.6/StructuralLosses ..
@ mv build/lib.linux-x86_64-3.6/*.so ../StructuralLosses/
@- $(RM) -rf $(OBJ_DIR) build objs
$(OBJ_DIR):
@ mkdir -p $@
@ mkdir -p $@/cuda
$(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR)
@ echo CXX $<
$(Q)$(CXX) $< $(CXXFLAGS) -c -o $@
$(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR)
@ echo NVCC $<
$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \
-odir $(@D)
$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@
$(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR)
$(RM) -f $(STATIC_LIB)
$(RM) -rf build dist
@ echo LD -o $@
ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS)
clean:
@- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses

View file

@ -0,0 +1,6 @@
#import torch
#from MakePytorchBackend import AddGPU, Foo, ApproxMatch
#from Add import add_gpu, approx_match

View file

@ -0,0 +1,45 @@
import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad
# Inherit from Function
class MatchCostFunction(Function):
# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument
def forward(ctx, seta, setb):
#print("Match Cost Forward")
ctx.save_for_backward(seta, setb)
'''
input:
set1 : batch_size * #dataset_points * 3
set2 : batch_size * #query_points * 3
returns:
match : batch_size * #query_points * #dataset_points
'''
match, temp = ApproxMatch(seta, setb)
ctx.match = match
cost = MatchCost(seta, setb, match)
return cost
"""
grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
"""
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
#print("Match Cost Backward")
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
seta, setb = ctx.saved_tensors
#grad_input = grad_weight = grad_bias = None
grada, gradb = MatchCostGrad(seta, setb, ctx.match)
grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)
return grada*grad_output_expand, gradb*grad_output_expand
match_cost = MatchCostFunction.apply

View file

@ -0,0 +1,42 @@
import torch
from torch.autograd import Function
# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
# Inherit from Function
class NNDistanceFunction(Function):
# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument
def forward(ctx, seta, setb):
#print("Match Cost Forward")
ctx.save_for_backward(seta, setb)
'''
input:
set1 : batch_size * #dataset_points * 3
set2 : batch_size * #query_points * 3
returns:
dist1, idx1, dist2, idx2
'''
dist1, idx1