Add training/testing/demo codes

Guandao Yang 2019-07-13 21:32:26 -07:00
parent f1581cf0f5
commit 4e6795a4d9
41 changed files with 3785 additions and 103 deletions

113
.gitignore vendored

@@ -1,108 +1,23 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*__pycache__
.idea/
.DS_Store
*/.DS_Store
*/*/.DS_Store
*.pyc
data
# C extensions
*.so
scratch.py
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
*.m
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
checkpoints
runs
# celery beat schedule file
celerybeat-schedule
metrics/structural_losses/*.so
metrics/structural_losses/*.cu.o
metrics/structural_losses/makefile
PyMesh
checkpoint
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
torchdiffeq/
demo/

README.md

@@ -2,7 +2,7 @@
This repository contains a PyTorch implementation of the paper:
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](https://arxiv.org/abs/1906.12320).
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](www.arxiv.com).
[Guandao Yang*](http://www.guandaoyang.com),
[Xun Huang*](http://www.cs.cornell.edu/~xhuang/),
@@ -11,9 +11,6 @@ This repository contains a PyTorch implementation of the paper:
[Serge Belongie](http://blogs.cornell.edu/techfaculty/serge-belongie/),
[Bharath Hariharan](http://home.bharathh.info/)
**The code will be available soon!**
[\[Project page\]](https://www.guandaoyang.com/PointFlow/) [\[Video\]](https://www.youtube.com/watch?v=jqBiv77xC0M)
## Introduction
@@ -23,3 +20,90 @@ As 3D point clouds become the representation of choice for multiple vision and g
<p float="left">
<img src="docs/assets/teaser.gif" height="256"/>
</p>
## Dependencies
* Python 3.6
* CUDA 10.0
* G++ or GCC 5
* [PyTorch](http://pytorch.org/). The code is tested with version 1.0.1.
* [torchdiffeq](https://github.com/rtqichen/torchdiffeq).
* (Optional) [Tensorboard](https://www.tensorflow.org/) for visualization of training process.
Following is the suggested way to install these dependencies:
```bash
# Create a new conda environment
conda create -n PointFlow python=3.6
conda activate PointFlow
# Install pytorch (please refer to the command on the official website)
conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 -c pytorch -y
# Install other dependencies such as torchdiffeq, structural losses, etc.
./install.sh
```
## Dataset
The point clouds are uniformly sampled from meshes in the ShapeNetCore dataset (version 2), using the official split.
Please use this [link](https://drive.google.com/drive/folders/1G0rf-6HSHoTll6aH7voh-dXj6hCRhSAQ?usp=sharing) to download the ShapeNet point cloud.
The point cloud should be placed into `data` directory.
```bash
mv ShapeNetCore.v2.PC15k.zip data/
cd data
unzip ShapeNetCore.v2.PC15k.zip
```
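Each `.npy` file stores one shape as a 15,000 x 3 array of points uniformly sampled from the mesh surface (`datasets.py` below asserts this shape). A minimal sanity check after unzipping, assuming the airplane synset `02691156` is present:

```python
import glob
import numpy as np

# 02691156 is the ShapeNet synset id for "airplane"; pick an arbitrary training shape.
path = glob.glob("data/ShapeNetCore.v2.PC15k/02691156/train/*.npy")[0]
pc = np.load(path)
print(path, pc.shape)  # expected: (15000, 3)
```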
Please contact us if you need point clouds for the ModelNet dataset.
## Training
Example training scripts can be found in `scripts/` folder.
```bash
# Train auto-encoder (no latent CNF)
./scripts/shapenet_airplane_ae.sh # Train with single GPU, about 7-8 GB GPU memory
./scripts/shapenet_airplane_ae_dist.sh # Train with multiple GPUs
# Train generative model
./scripts/shapenet_airplane_gen.sh # Train with single GPU
./scripts/shapenet_airplane_gen_dist.sh # Train with multiple GPUs
```
## Pre-trained models and test
Pretrained models can be downloaded from this [link](https://drive.google.com/file/d/1dcxjuuKiAXZxhiyWD_o_7Owx8Y3FbRHG/view?usp=sharing).
Following is the suggested way to evaluate the performance of the pre-trained models.
```bash
unzip pretrained_models.zip; # This will create a folder named pretrained_models
# Evaluate the reconstruction performance of an AE trained on the airplane category
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_ae_test.sh;
# Evaluate the reconstruction performance of an AE trained with the whole ShapeNet
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_all_ae_test.sh;
# Evaluate the generative performance of PointFlow trained on the airplane category.
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_gen_test.sh
```
## Demo
The demo relies on [Open3D](http://www.open3d.org/). Following is the suggested way to install it:
```bash
conda install -c open3d-admin open3d
```
The demo samples shapes from a pre-trained model, saves them under the `demo` folder, and visualizes the resulting point clouds.
Once this dependency is in place, the following script runs the demo with the pre-trained airplane model:
```bash
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_demo.sh
```
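The samples are written to `demo/model_out_smp.npy` (see `demo.py` below). A minimal sketch for reloading and re-visualizing them later, using the same Open3D calls as the demo:

```python
import numpy as np
import open3d as o3d

# demo.py saves all samples as one (num_shapes, num_points, 3) array.
samples = np.load("demo/model_out_smp.npy")
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(samples[0].reshape(-1, 3))
o3d.visualization.draw_geometries([pcd])
```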
## Cite
Please cite our work if you find it useful:
```latex
@article{pointflow,
title={PointFlow: 3D Point Cloud Generation with Continuous Normalizing Flows},
author={Yang, Guandao and Huang, Xun and Hao, Zekun and Liu, Ming-Yu and Belongie, Serge and Hariharan, Bharath},
journal={arXiv},
year={2019}
}
```

163
args.py Normal file

@@ -0,0 +1,163 @@
import argparse
NONLINEARITIES = ["tanh", "relu", "softplus", "elu", "swish", "square", "identity"]
SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams']
LAYERS = ["ignore", "concat", "concat_v2", "squash", "concatsquash", "scale", "concatscale"]
def add_args(parser):
# model architecture options
parser.add_argument('--input_dim', type=int, default=3,
help='Number of input dimensions (3 for 3D point clouds)')
parser.add_argument('--dims', type=str, default='256')
parser.add_argument('--latent_dims', type=str, default='256')
parser.add_argument("--num_blocks", type=int, default=1,
help='Number of stacked CNFs.')
parser.add_argument("--latent_num_blocks", type=int, default=1,
help='Number of stacked CNFs in the latent flow.')
parser.add_argument("--layer_type", type=str, default="concatsquash", choices=LAYERS)
parser.add_argument('--time_length', type=float, default=0.5)
parser.add_argument('--train_T', type=eval, default=True, choices=[True, False])
parser.add_argument("--nonlinearity", type=str, default="tanh", choices=NONLINEARITIES)
parser.add_argument('--use_adjoint', type=eval, default=True, choices=[True, False])
parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
parser.add_argument('--atol', type=float, default=1e-5)
parser.add_argument('--rtol', type=float, default=1e-5)
parser.add_argument('--batch_norm', type=eval, default=True, choices=[True, False])
parser.add_argument('--sync_bn', type=eval, default=False, choices=[True, False])
parser.add_argument('--bn_lag', type=float, default=0)
# training options
parser.add_argument('--use_latent_flow', action='store_true',
help='Whether to use the latent flow to model the prior.')
parser.add_argument('--use_deterministic_encoder', action='store_true',
help='Whether to use a deterministic encoder.')
parser.add_argument('--zdim', type=int, default=128,
help='Dimension of the shape code')
parser.add_argument('--optimizer', type=str, default='adam',
help='Optimizer to use', choices=['adam', 'adamax', 'sgd'])
parser.add_argument('--batch_size', type=int, default=50,
help='Batch size (of datasets) for training')
parser.add_argument('--lr', type=float, default=1e-3,
help='Learning rate for the Adam optimizer.')
parser.add_argument('--beta1', type=float, default=0.9,
help='Beta1 for Adam.')
parser.add_argument('--beta2', type=float, default=0.999,
help='Beta2 for Adam.')
parser.add_argument('--momentum', type=float, default=0.9,
help='Momentum for SGD')
parser.add_argument('--weight_decay', type=float, default=0.,
help='Weight decay for the optimizer.')
parser.add_argument('--epochs', type=int, default=100,
help='Number of epochs for training (default: 100)')
parser.add_argument('--seed', type=int, default=None,
help='Seed for initializing training. ')
parser.add_argument('--recon_weight', type=float, default=1.,
help='Weight for the reconstruction loss.')
parser.add_argument('--prior_weight', type=float, default=1.,
help='Weight for the prior loss.')
parser.add_argument('--entropy_weight', type=float, default=1.,
help='Weight for the entropy loss.')
parser.add_argument('--scheduler', type=str, default='linear',
help='Type of learning rate schedule')
parser.add_argument('--exp_decay', type=float, default=1.,
help='Learning rate schedule exponential decay rate')
parser.add_argument('--exp_decay_freq', type=int, default=1,
help='Learning rate exponential decay frequency')
# data options
parser.add_argument('--dataset_type', type=str, default="shapenet15k",
help="Dataset types.", choices=['shapenet15k', 'modelnet40_15k', 'modelnet10_15k'])
parser.add_argument('--cates', type=str, nargs='+', default=["airplane"],
help="Categories to be trained (useful only if 'shapenet' is selected)")
parser.add_argument('--data_dir', type=str, default="data/ShapeNetCore.v2.PC15k",
help="Path to the training data")
parser.add_argument('--mn40_data_dir', type=str, default="data/ModelNet40.PC15k",
help="Path to ModelNet40")
parser.add_argument('--mn10_data_dir', type=str, default="data/ModelNet10.PC15k",
help="Path to ModelNet10")
parser.add_argument('--dataset_scale', type=float, default=1.,
help='Scale of the dataset (x,y,z * scale = real output, default=1).')
parser.add_argument('--random_rotate', action='store_true',
help='Whether to randomly rotate each shape.')
parser.add_argument('--normalize_per_shape', action='store_true',
help='Whether to perform normalization per shape.')
parser.add_argument('--normalize_std_per_axis', action='store_true',
help='Whether to perform normalization per axis.')
parser.add_argument("--tr_max_sample_points", type=int, default=2048,
help='Max number of sampled points (train)')
parser.add_argument("--te_max_sample_points", type=int, default=2048,
help='Max number of sampled points (test)')
parser.add_argument('--num_workers', type=int, default=4,
help='Number of data loading threads')
# logging and saving frequency
parser.add_argument('--log_name', type=str, default=None, help="Name for the log dir")
parser.add_argument('--viz_freq', type=int, default=10)
parser.add_argument('--val_freq', type=int, default=10)
parser.add_argument('--log_freq', type=int, default=10)
parser.add_argument('--save_freq', type=int, default=10)
# validation options
parser.add_argument('--no_validation', action='store_true',
help='Whether to disable validation altogether.')
parser.add_argument('--save_val_results', action='store_true',
help='Whether to save the validation results.')
parser.add_argument('--eval_classification', action='store_true',
help='Whether to evaluate classification accuracy on MN40 and MN10.')
parser.add_argument('--no_eval_sampling', action='store_true',
help='Whether to evaluate sampling.')
parser.add_argument('--max_validate_shapes', type=int, default=None,
help='Max number of shapes used for validation pass.')
# resuming
parser.add_argument('--resume_checkpoint', type=str, default=None,
help='Path to the checkpoint to be loaded.')
parser.add_argument('--resume_optimizer', action='store_true',
help='Whether to resume the optimizer when resumed training.')
parser.add_argument('--resume_non_strict', action='store_true',
help='Whether to resume in non-strict mode.')
parser.add_argument('--resume_dataset_mean', type=str, default=None,
help='Path to the file storing the dataset mean.')
parser.add_argument('--resume_dataset_std', type=str, default=None,
help='Path to the file storing the dataset std.')
# distributed training
parser.add_argument('--world_size', default=1, type=int,
help='Number of distributed nodes.')
parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--distributed', action='store_true',
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('--rank', default=0, type=int,
help='node rank for distributed training')
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use. None means using all available GPUs.')
# Evaluation options
parser.add_argument('--evaluate_recon', default=False, action='store_true',
help='Whether set to the evaluation for reconstruction.')
parser.add_argument('--num_sample_shapes', default=10, type=int,
help='Number of shapes to be sampled (for demo.py).')
parser.add_argument('--num_sample_points', default=2048, type=int,
help='Number of points (per-shape) to be sampled (for demo.py).')
return parser
def get_parser():
# command line args
parser = argparse.ArgumentParser(description='Flow-based Point Cloud Generation Experiment')
parser = add_args(parser)
return parser
def get_args():
parser = get_parser()
args = parser.parse_args()
return args
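A minimal sketch of driving this parser programmatically, e.g. from a notebook or test; the override values below are arbitrary examples:

```python
from args import get_parser

# Parse an explicit argument list instead of sys.argv.
parser = get_parser()
args = parser.parse_args(['--cates', 'airplane', '--batch_size', '16'])
print(args.zdim, args.batch_size, args.cates)  # 128 16 ['airplane']
```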

389
datasets.py Normal file

@@ -0,0 +1,389 @@
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils import data
import random
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
synsetid_to_cate = {
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
'04554684': 'washer', '02992529': 'cellphone',
'02843684': 'birdhouse', '02871439': 'bookshelf',
# '02858304': 'boat', no boat in our dataset, merged into vessels
# '02834778': 'bicycle', not in our taxonomy
}
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
class Uniform15KPC(Dataset):
def __init__(self, root_dir, subdirs, tr_sample_size=10000,
te_sample_size=10000, split='train', scale=1.,
normalize_per_shape=False, random_subsample=False,
normalize_std_per_axis=False,
all_points_mean=None, all_points_std=None,
input_dim=3):
self.root_dir = root_dir
self.split = split
self.in_tr_sample_size = tr_sample_size
self.in_te_sample_size = te_sample_size
self.subdirs = subdirs
self.scale = scale
self.random_subsample = random_subsample
self.input_dim = input_dim
self.all_cate_mids = []
self.cate_idx_lst = []
self.all_points = []
for cate_idx, subd in enumerate(self.subdirs):
# NOTE: [subd] here is synset id
sub_path = os.path.join(root_dir, subd, self.split)
if not os.path.isdir(sub_path):
print("Directory missing : %s" % sub_path)
continue
all_mids = []
for x in os.listdir(sub_path):
if not x.endswith('.npy'):
continue
all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
# NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
for mid in all_mids:
# obj_fname = os.path.join(sub_path, x)
obj_fname = os.path.join(root_dir, subd, mid + ".npy")
try:
point_cloud = np.load(obj_fname) # (15k, 3)
except:
continue
assert point_cloud.shape[0] == 15000
self.all_points.append(point_cloud[np.newaxis, ...])
self.cate_idx_lst.append(cate_idx)
self.all_cate_mids.append((subd, mid))
# Shuffle the index deterministically (based on the number of examples)
self.shuffle_idx = list(range(len(self.all_points)))
random.Random(38383).shuffle(self.shuffle_idx)
self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
self.all_points = [self.all_points[i] for i in self.shuffle_idx]
self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
# Normalization
self.all_points = np.concatenate(self.all_points) # (N, 15000, 3)
self.normalize_per_shape = normalize_per_shape
self.normalize_std_per_axis = normalize_std_per_axis
if all_points_mean is not None and all_points_std is not None: # using loaded dataset stats
self.all_points_mean = all_points_mean
self.all_points_std = all_points_std
elif self.normalize_per_shape: # per shape normalization
B, N = self.all_points.shape[:2]
self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1)
else: # normalize across the dataset
self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
self.tr_sample_size = min(10000, tr_sample_size)
self.te_sample_size = min(5000, te_sample_size)
print("Total number of data:%d" % len(self.train_points))
print("Min number of points: (train)%d (test)%d"
% (self.tr_sample_size, self.te_sample_size))
assert self.scale == 1, "Scale (!= 1) is deprecated"
def get_pc_stats(self, idx):
if self.normalize_per_shape:
m = self.all_points_mean[idx].reshape(1, self.input_dim)
s = self.all_points_std[idx].reshape(1, -1)
return m, s
return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)
def renormalize(self, mean, std):
self.all_points = self.all_points * self.all_points_std + self.all_points_mean
self.all_points_mean = mean
self.all_points_std = std
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
def __len__(self):
return len(self.train_points)
def __getitem__(self, idx):
tr_out = self.train_points[idx]
if self.random_subsample:
tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
else:
tr_idxs = np.arange(self.tr_sample_size)
tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
te_out = self.test_points[idx]
if self.random_subsample:
te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
else:
te_idxs = np.arange(self.te_sample_size)
te_out = torch.from_numpy(te_out[te_idxs, :]).float()
m, s = self.get_pc_stats(idx)
cate_idx = self.cate_idx_lst[idx]
sid, mid = self.all_cate_mids[idx]
return {
'idx': idx,
'train_points': tr_out,
'test_points': te_out,
'mean': m, 'std': s, 'cate_idx': cate_idx,
'sid': sid, 'mid': mid
}
class ModelNet40PointClouds(Uniform15KPC):
def __init__(self, root_dir="data/ModelNet40.PC15k",
tr_sample_size=10000, te_sample_size=2048,
split='train', scale=1., normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=False,
all_points_mean=None, all_points_std=None):
self.root_dir = root_dir
self.split = split
assert self.split in ['train', 'test']
self.sample_size = tr_sample_size
self.cates = []
for cate in os.listdir(root_dir):
if os.path.isdir(os.path.join(root_dir, cate)) \
and os.path.isdir(os.path.join(root_dir, cate, 'train')) \
and os.path.isdir(os.path.join(root_dir, cate, 'test')):
self.cates.append(cate)
assert len(self.cates) == 40, "%s %s" % (len(self.cates), self.cates)
# For non-aligned MN
# self.gravity_axis = 0
# self.display_axis_order = [0,1,2]
# Aligned MN has same axis-order as SN
self.gravity_axis = 1
self.display_axis_order = [0, 2, 1]
super(ModelNet40PointClouds, self).__init__(
root_dir, self.cates, tr_sample_size=tr_sample_size,
te_sample_size=te_sample_size, split=split, scale=scale,
normalize_per_shape=normalize_per_shape,
normalize_std_per_axis=normalize_std_per_axis,
random_subsample=random_subsample,
all_points_mean=all_points_mean, all_points_std=all_points_std,
input_dim=3)
class ModelNet10PointClouds(Uniform15KPC):
def __init__(self, root_dir="data/ModelNet10.PC15k",
tr_sample_size=10000, te_sample_size=2048,
split='train', scale=1., normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=False,
all_points_mean=None, all_points_std=None):
self.root_dir = root_dir
self.split = split
assert self.split in ['train', 'test']
self.cates = []
for cate in os.listdir(root_dir):
if os.path.isdir(os.path.join(root_dir, cate)) \
and os.path.isdir(os.path.join(root_dir, cate, 'train')) \
and os.path.isdir(os.path.join(root_dir, cate, 'test')):
self.cates.append(cate)
assert len(self.cates) == 10
# That's prealigned MN
# self.gravity_axis = 0
# self.display_axis_order = [0,1,2]
# Aligned MN has same axis-order as SN
self.gravity_axis = 1
self.display_axis_order = [0, 2, 1]
super(ModelNet10PointClouds, self).__init__(
root_dir, self.cates, tr_sample_size=tr_sample_size,
te_sample_size=te_sample_size, split=split, scale=scale,
normalize_per_shape=normalize_per_shape,
normalize_std_per_axis=normalize_std_per_axis,
random_subsample=random_subsample,
all_points_mean=all_points_mean, all_points_std=all_points_std,
input_dim=3)
class ShapeNet15kPointClouds(Uniform15KPC):
def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
split='train', scale=1., normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=False,
all_points_mean=None, all_points_std=None):
self.root_dir = root_dir
self.split = split
assert self.split in ['train', 'test', 'val']
self.tr_sample_size = tr_sample_size
self.te_sample_size = te_sample_size
self.cates = categories
if 'all' in categories:
self.synset_ids = list(cate_to_synsetid.values())
else:
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
# assert 'v2' in root_dir, "Only supporting v2 right now."
self.gravity_axis = 1
self.display_axis_order = [0, 2, 1]
super(ShapeNet15kPointClouds, self).__init__(
root_dir, self.synset_ids,
tr_sample_size=tr_sample_size,
te_sample_size=te_sample_size,
split=split, scale=scale,
normalize_per_shape=normalize_per_shape,
normalize_std_per_axis=normalize_std_per_axis,
random_subsample=random_subsample,
all_points_mean=all_points_mean, all_points_std=all_points_std,
input_dim=3)
def init_np_seed(worker_id):
seed = torch.initial_seed()
np.random.seed(seed % 4294967296)
def _get_MN40_datasets_(args, data_dir=None):
tr_dataset = ModelNet40PointClouds(
split='train',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
root_dir=(args.data_dir if data_dir is None else data_dir),
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
random_subsample=True)
te_dataset = ModelNet40PointClouds(
split='test',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
root_dir=(args.data_dir if data_dir is None else data_dir),
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return tr_dataset, te_dataset
def _get_MN10_datasets_(args, data_dir=None):
tr_dataset = ModelNet10PointClouds(
split='train',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
root_dir=(args.data_dir if data_dir is None else data_dir),
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
random_subsample=True)
te_dataset = ModelNet10PointClouds(
split='test',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
root_dir=(args.data_dir if data_dir is None else data_dir),
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return tr_dataset, te_dataset
def get_datasets(args):
if args.dataset_type == 'shapenet15k':
tr_dataset = ShapeNet15kPointClouds(
categories=args.cates, split='train',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
scale=args.dataset_scale, root_dir=args.data_dir,
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
random_subsample=True)
te_dataset = ShapeNet15kPointClouds(
categories=args.cates, split='val',
tr_sample_size=args.tr_max_sample_points,
te_sample_size=args.te_max_sample_points,
scale=args.dataset_scale, root_dir=args.data_dir,
normalize_per_shape=args.normalize_per_shape,
normalize_std_per_axis=args.normalize_std_per_axis,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
elif args.dataset_type == 'modelnet40_15k':
tr_dataset, te_dataset = _get_MN40_datasets_(args)
elif args.dataset_type == 'modelnet10_15k':
tr_dataset, te_dataset = _get_MN10_datasets_(args)
else:
raise Exception("Invalid dataset type:%s" % args.dataset_type)
return tr_dataset, te_dataset
def get_clf_datasets(args):
return {
'MN40': _get_MN40_datasets_(args, data_dir=args.mn40_data_dir),
'MN10': _get_MN10_datasets_(args, data_dir=args.mn10_data_dir),
}
def get_data_loaders(args):
tr_dataset, te_dataset = get_datasets(args)
train_loader = data.DataLoader(
dataset=tr_dataset, batch_size=args.batch_size,
shuffle=True, num_workers=args.num_workers, drop_last=True,
worker_init_fn=init_np_seed)
train_unshuffle_loader = data.DataLoader(
dataset=tr_dataset, batch_size=args.batch_size,
shuffle=False, num_workers=args.num_workers, drop_last=True,
worker_init_fn=init_np_seed)
test_loader = data.DataLoader(
dataset=te_dataset, batch_size=args.batch_size,
shuffle=False, num_workers=args.num_workers, drop_last=False,
worker_init_fn=init_np_seed)
loaders = {
"test_loader": test_loader,
'train_loader': train_loader,
'train_unshuffle_loader': train_unshuffle_loader,
}
return loaders
if __name__ == "__main__":
shape_ds = ShapeNet15kPointClouds(categories=['airplane'], split='val')
data = next(iter(shape_ds))  # __getitem__ returns a dict, not a tuple
print(data['train_points'].shape)
print(data['test_points'].shape)
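A sketch of the intended entry point for training code, wiring parsed arguments into the loaders defined above; the batch keys follow `__getitem__` of `Uniform15KPC`:

```python
from args import get_args
from datasets import get_data_loaders

# Assumes the ShapeNet point clouds are unpacked under args.data_dir.
args = get_args()
loaders = get_data_loaders(args)
batch = next(iter(loaders['train_loader']))
print(batch['train_points'].shape)  # (batch_size, tr_sample_size, 3)
```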

60
demo.py Normal file

@@ -0,0 +1,60 @@
import open3d as o3d
from datasets import get_datasets
from args import get_args
from models.networks import PointFlow
import os
import torch
import numpy as np
import torch.nn as nn
def main(args):
model = PointFlow(args)
def _transform_(m):
return nn.DataParallel(m)
model = model.cuda()
model.multi_gpu_wrapper(_transform_)
print("Resume Path:%s" % args.resume_checkpoint)
checkpoint = torch.load(args.resume_checkpoint)
model.load_state_dict(checkpoint)
model.eval()
_, te_dataset = get_datasets(args)
if args.resume_dataset_mean is not None and args.resume_dataset_std is not None:
mean = np.load(args.resume_dataset_mean)
std = np.load(args.resume_dataset_std)
te_dataset.renormalize(mean, std)
ds_mean = torch.from_numpy(te_dataset.all_points_mean).cuda()
ds_std = torch.from_numpy(te_dataset.all_points_std).cuda()
all_sample = []
with torch.no_grad():
for i in range(0, args.num_sample_shapes, args.batch_size):
B = len(range(i, min(i + args.batch_size, args.num_sample_shapes)))
N = args.num_sample_points
_, out_pc = model.sample(B, N)
out_pc = out_pc * ds_std + ds_mean
all_sample.append(out_pc)
sample_pcs = torch.cat(all_sample, dim=0).cpu().detach().numpy()
print("Generation sample size:(%s, %s, %s)" % sample_pcs.shape)
# Save the generative output
os.makedirs("demo", exist_ok=True)
np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs)
# Visualize the demo
pcl = o3d.geometry.PointCloud()
for i in range(int(sample_pcs.shape[0])):
print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
pts = sample_pcs[i].reshape(-1, 3)
pcl.points = o3d.utility.Vector3dVector(pts)
o3d.visualization.draw_geometries([pcl])
if __name__ == '__main__':
args = get_args()
main(args)

19
install.sh Executable file

@@ -0,0 +1,19 @@
#! /bin/bash
root=`pwd`
# Install dependencies
conda install numpy matplotlib pillow scipy tqdm scikit-learn -y
pip install tensorflow-gpu==1.13.1
pip install tensorboardX==1.7
# Compile CUDA kernel for CD/EMD loss
cd metrics/pytorch_structural_losses/
make clean
make
cd $root
# install torchdiffeq
git clone https://github.com/rtqichen/torchdiffeq.git
cd torchdiffeq
pip install -e .

1
metrics/.gitignore vendored Normal file

@@ -0,0 +1 @@
StructuralLosses

0
metrics/__init__.py Normal file

336
metrics/evaluation_metrics.py Normal file

@@ -0,0 +1,336 @@
import torch
import numpy as np
import warnings
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
from .StructuralLosses.match_cost import match_cost
from .StructuralLosses.nn_distance import nn_distance
# # Import CUDA version of CD, borrowed from https://github.com/ThibaultGROUEIX/AtlasNet
# try:
# from . chamfer_distance_ext.dist_chamfer import chamferDist
# CD = chamferDist()
# def distChamferCUDA(x,y):
# return CD(x,y,gpu)
# except:
def distChamferCUDA(x, y):
return nn_distance(x, y)
def emd_approx(sample, ref):
B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
assert N == N_ref, "Not sure what EMD would do in this case"
emd = match_cost(sample, ref) # (B,)
emd_norm = emd / float(N) # (B,)
return emd_norm
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
def distChamfer(a, b):
x, y = a, b
bs, num_points, points_dim = x.size()
xx = torch.bmm(x, x.transpose(2, 1))
yy = torch.bmm(y, y.transpose(2, 1))
zz = torch.bmm(x, y.transpose(2, 1))
diag_ind = torch.arange(0, num_points).to(a).long()
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
P = (rx.transpose(2, 1) + ry - 2 * zz)
return P.min(1)[0], P.min(2)[0]
def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
cd_lst = []
emd_lst = []
iterator = range(0, N_sample, batch_size)
for b_start in iterator:
b_end = min(N_sample, b_start + batch_size)
sample_batch = sample_pcs[b_start:b_end]
ref_batch = ref_pcs[b_start:b_end]
if accelerated_cd:
dl, dr = distChamferCUDA(sample_batch, ref_batch)
else:
dl, dr = distChamfer(sample_batch, ref_batch)
cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))
emd_batch = emd_approx(sample_batch, ref_batch)
emd_lst.append(emd_batch)
if reduced:
cd = torch.cat(cd_lst).mean()
emd = torch.cat(emd_lst).mean()
else:
cd = torch.cat(cd_lst)
emd = torch.cat(emd_lst)
results = {
'MMD-CD': cd,
'MMD-EMD': emd,
}
return results
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
all_cd = []
all_emd = []
iterator = range(N_sample)
for sample_b_start in iterator:
sample_batch = sample_pcs[sample_b_start]
cd_lst = []
emd_lst = []
for ref_b_start in range(0, N_ref, batch_size):
ref_b_end = min(N_ref, ref_b_start + batch_size)
ref_batch = ref_pcs[ref_b_start:ref_b_end]
batch_size_ref = ref_batch.size(0)
sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
sample_batch_exp = sample_batch_exp.contiguous()
if accelerated_cd:
dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)
else:
dl, dr = distChamfer(sample_batch_exp, ref_batch)
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
emd_batch = emd_approx(sample_batch_exp, ref_batch)
emd_lst.append(emd_batch.view(1, -1))
cd_lst = torch.cat(cd_lst, dim=1)
emd_lst = torch.cat(emd_lst, dim=1)
all_cd.append(cd_lst)
all_emd.append(emd_lst)
all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref
all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref
return all_cd, all_emd
# Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py
def knn(Mxx, Mxy, Myy, k, sqrt=False):
n0 = Mxx.size(0)
n1 = Myy.size(0)
label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx)
M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
if sqrt:
M = M.abs().sqrt()
INFINITY = float('inf')
val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
count = torch.zeros(n0 + n1).to(Mxx)
for i in range(0, k):
count = count + label.index_select(0, idx[i])
pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
s = {
'tp': (pred * label).sum(),
'fp': (pred * (1 - label)).sum(),
'fn': ((1 - pred) * label).sum(),
'tn': ((1 - pred) * (1 - label)).sum(),
}
s.update({
'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
'acc': torch.eq(label, pred).float().mean(),
})
return s
def lgan_mmd_cov(all_dist):
N_sample, N_ref = all_dist.size(0), all_dist.size(1)
min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
min_val, _ = torch.min(all_dist, dim=0)
mmd = min_val.mean()
mmd_smp = min_val_fromsmp.mean()
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
cov = torch.tensor(cov).to(all_dist)
return {
'lgan_mmd': mmd,
'lgan_cov': cov,
'lgan_mmd_smp': mmd_smp,
}
def compute_all_metrics(sample_pcs, ref_pcs, batch_size, accelerated_cd=False):
results = {}
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)
res_cd = lgan_mmd_cov(M_rs_cd.t())
results.update({
"%s-CD" % k: v for k, v in res_cd.items()
})
res_emd = lgan_mmd_cov(M_rs_emd.t())
results.update({
"%s-EMD" % k: v for k, v in res_emd.items()
})
M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd)
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)
# 1-NN results
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
results.update({
"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
})
one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
results.update({
"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
})
return results
#######################################################
# JSD : from https://github.com/optas/latent_3d_points
#######################################################
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
"""Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
that is placed in the unit-cube.
If clip_sphere is True it drops the "corner" cells that lie outside the unit-sphere.
"""
grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
spacing = 1.0 / float(resolution - 1)
for i in range(resolution):
for j in range(resolution):
for k in range(resolution):
grid[i, j, k, 0] = i * spacing - 0.5
grid[i, j, k, 1] = j * spacing - 0.5
grid[i, j, k, 2] = k * spacing - 0.5
if clip_sphere:
grid = grid.reshape(-1, 3)
grid = grid[norm(grid, axis=1) <= 0.5]
return grid, spacing
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
"""Computes the JSD between two sets of point-clouds, as introduced in the paper
```Learning Representations And Generative Models For 3D Point Clouds```.
Args:
sample_pcs: (np.ndarray S1xR1x3) S1 point-clouds, each of R1 points.
ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
resolution: (int) grid-resolution. Affects granularity of measurements.
"""
in_unit_sphere = True
sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False):
"""Given a collection of point-clouds, estimate the entropy of the random variables
corresponding to occupancy-grid activation patterns.
Inputs:
pclouds: (numpy array) #point-clouds x points per point-cloud x 3
grid_resolution (int) size of occupancy grid that will be used.
"""
epsilon = 10e-4
bound = 0.5 + epsilon
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
if verbose:
warnings.warn('Point-clouds are not in unit cube.')
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
if verbose:
warnings.warn('Point-clouds are not in unit sphere.')
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
grid_coordinates = grid_coordinates.reshape(-1, 3)
grid_counters = np.zeros(len(grid_coordinates))
grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
for pc in pclouds:
_, indices = nn.kneighbors(pc)
indices = np.squeeze(indices)
for i in indices:
grid_counters[i] += 1
indices = np.unique(indices)
for i in indices:
grid_bernoulli_rvars[i] += 1
acc_entropy = 0.0
n = float(len(pclouds))
for g in grid_bernoulli_rvars:
if g > 0:
p = float(g) / n
acc_entropy += entropy([p, 1.0 - p])
return acc_entropy / len(grid_counters), grid_counters
def jensen_shannon_divergence(P, Q):
if np.any(P < 0) or np.any(Q < 0):
raise ValueError('Negative values.')
if len(P) != len(Q):
raise ValueError('Non equal size.')
P_ = P / np.sum(P) # Ensure probabilities.
Q_ = Q / np.sum(Q)
e1 = entropy(P_, base=2)
e2 = entropy(Q_, base=2)
e_sum = entropy((P_ + Q_) / 2.0, base=2)
res = e_sum - ((e1 + e2) / 2.0)
res2 = _jsdiv(P_, Q_)
if not np.allclose(res, res2, atol=10e-5, rtol=0):
warnings.warn('Numerical values of two JSD methods don\'t agree.')
return res
def _jsdiv(P, Q):
"""another way of computing JSD"""
def _kldiv(A, B):
a = A.copy()
b = B.copy()
idx = np.logical_and(a > 0, b > 0)
a = a[idx]
b = b[idx]
return np.sum([v for v in a * np.log2(a / b)])
P_ = P / np.sum(P)
Q_ = Q / np.sum(Q)
M = 0.5 * (P_ + Q_)
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
if __name__ == "__main__":
B, N = 2, 10
x = torch.rand(B, N, 3)
y = torch.rand(B, N, 3)
min_l, min_r = distChamferCUDA(x.cuda(), y.cuda())
print(min_l.shape)
print(min_r.shape)
l_dist = min_l.mean().cpu().detach().item()
r_dist = min_r.mean().cpu().detach().item()
print(l_dist, r_dist)
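A sketch of how these metrics are meant to be called, assuming this module is importable as `metrics.evaluation_metrics` and the CUDA extension is built; the random tensors are toy stand-ins for generated and reference sets:

```python
import torch
from metrics.evaluation_metrics import compute_all_metrics, jsd_between_point_cloud_sets

# (num_shapes, num_points, 3) point-cloud batches on the GPU.
sample = torch.rand(16, 2048, 3).cuda()
ref = torch.rand(16, 2048, 3).cuda()
results = compute_all_metrics(sample, ref, batch_size=8, accelerated_cd=True)
print(results)  # lgan_mmd/lgan_cov (CD and EMD) plus 1-NN accuracies
print(jsd_between_point_cloud_sets(sample.cpu().numpy(), ref.cpu().numpy()))
```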

1
metrics/pytorch_structural_losses/.gitignore vendored Normal file

@@ -0,0 +1 @@
PyTorchStructuralLosses.egg-info/

100
metrics/pytorch_structural_losses/Makefile Normal file

@@ -0,0 +1,100 @@
###############################################################################
# Uncomment for debugging
# DEBUG := 1
# Pretty build
# Q ?= @
CXX := g++
PYTHON := python
NVCC := /usr/local/cuda/bin/nvcc
# PYTHON Header path
PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())')
PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]')
PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]')
# CUDA ROOT DIR that contains bin/ lib64/ and include/
# CUDA_DIR := /usr/local/cuda
CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())')
INCLUDE_DIRS := ./ $(CUDA_DIR)/include
INCLUDE_DIRS += $(PYTHON_HEADER_DIR)
INCLUDE_DIRS += $(PYTORCH_INCLUDES)
# Custom (MKL/ATLAS/OpenBLAS) include and lib directories.
# Leave commented to accept the defaults for your choice of BLAS
# (which should work)!
# BLAS_INCLUDE := /path/to/your/blas
# BLAS_LIB := /path/to/your/blas
###############################################################################
SRC_DIR := ./src
OBJ_DIR := ./objs
CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp)
CU_SRCS := $(wildcard $(SRC_DIR)/*.cu)
OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS))
CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS))
STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a
# CUDA architecture setting: going with all of them.
# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility.
# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility.
CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \
-gencode arch=compute_61,code=compute_61 \
-gencode arch=compute_52,code=sm_52
# We will also explicitly add stdc++ to the link target.
LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu
# Debugging
ifeq ($(DEBUG), 1)
COMMON_FLAGS += -DDEBUG -g -O0
# https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/
NVCCFLAGS += -g -G # -rdc true
else
COMMON_FLAGS += -DNDEBUG -O3
endif
WARNINGS := -Wall -Wno-sign-compare -Wcomment
INCLUDE_DIRS += $(BLAS_INCLUDE)
# Automatic dependency generation (nvcc is handled separately)
CXXFLAGS += -MMD -MP
# Complete build flags.
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \
-DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0
CXXFLAGS += -pthread -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS += -std=c++11 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
all: $(STATIC_LIB)
$(PYTHON) setup.py build
@ mv build/lib.linux-x86_64-3.6/StructuralLosses ..
@ mv build/lib.linux-x86_64-3.6/*.so ../StructuralLosses/
@- $(RM) -rf $(OBJ_DIR) build objs
$(OBJ_DIR):
@ mkdir -p $@
@ mkdir -p $@/cuda
$(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR)
@ echo CXX $<
$(Q)$(CXX) $< $(CXXFLAGS) -c -o $@
$(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR)
@ echo NVCC $<
$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \
-odir $(@D)
$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@
$(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR)
$(RM) -f $(STATIC_LIB)
$(RM) -rf build dist
@ echo LD -o $@
ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS)
clean:
@- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses

6
metrics/pytorch_structural_losses/__init__.py Normal file

@@ -0,0 +1,6 @@
#import torch
#from MakePytorchBackend import AddGPU, Foo, ApproxMatch
#from Add import add_gpu, approx_match

45
metrics/pytorch_structural_losses/match_cost.py Normal file

@@ -0,0 +1,45 @@
import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad
# Inherit from Function
class MatchCostFunction(Function):
# Note that both forward and backward are @staticmethods
@staticmethod
def forward(ctx, seta, setb):
#print("Match Cost Forward")
ctx.save_for_backward(seta, setb)
'''
input:
set1 : batch_size * #dataset_points * 3
set2 : batch_size * #query_points * 3
returns:
cost : batch_size (per-shape approximate EMD)
'''
match, temp = ApproxMatch(seta, setb)
ctx.match = match
cost = MatchCost(seta, setb, match)
return cost
"""
grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
"""
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
#print("Match Cost Backward")
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
seta, setb = ctx.saved_tensors
#grad_input = grad_weight = grad_bias = None
grada, gradb = MatchCostGrad(seta, setb, ctx.match)
grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)
return grada*grad_output_expand, gradb*grad_output_expand
match_cost = MatchCostFunction.apply
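A sketch of calling the resulting autograd op directly; the normalization by point count matches `emd_approx` in `evaluation_metrics.py` (requires the compiled extension and a GPU):

```python
import torch
from metrics.StructuralLosses.match_cost import match_cost

x = torch.rand(4, 1024, 3).cuda().requires_grad_()
y = torch.rand(4, 1024, 3).cuda()
emd = match_cost(x, y) / x.size(1)  # per-shape approximate EMD, normalized by point count
emd.sum().backward()                # gradients flow through MatchCostFunction.backward
```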

42
metrics/pytorch_structural_losses/nn_distance.py Normal file

@@ -0,0 +1,42 @@
import torch
from torch.autograd import Function
# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
# Inherit from Function
class NNDistanceFunction(Function):
# Note that both forward and backward are @staticmethods
@staticmethod
def forward(ctx, seta, setb):
#print("Match Cost Forward")
ctx.save_for_backward(seta, setb)
'''
input:
set1 : batch_size * #dataset_points * 3
set2 : batch_size * #query_points * 3
returns:
dist1, idx1, dist2, idx2
'''
dist1, idx1, dist2, idx2 = NNDistance(seta, setb)
ctx.idx1 = idx1
ctx.idx2 = idx2
return dist1, dist2
# This function returns two outputs (dist1, dist2), so backward receives two gradients
@staticmethod
def backward(ctx, grad_dist1, grad_dist2):
#print("Match Cost Backward")
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
seta, setb = ctx.saved_tensors
idx1 = ctx.idx1
idx2 = ctx.idx2
grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)
return grada, gradb
nn_distance = NNDistanceFunction.apply
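As above, a sketch of the op's interface; the two returned tensors are squared nearest-neighbor distances in each direction, which `distChamferCUDA` in `evaluation_metrics.py` combines into a Chamfer distance:

```python
import torch
from metrics.StructuralLosses.nn_distance import nn_distance

x = torch.rand(4, 1024, 3).cuda()
y = torch.rand(4, 1024, 3).cuda()
dist1, dist2 = nn_distance(x, y)            # (B, N) squared distances, both directions
cd = dist1.mean(dim=1) + dist2.mean(dim=1)  # per-shape Chamfer distance, as in EMD_CD
```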

15
metrics/pytorch_structural_losses/pybind/bind.cpp Normal file

@@ -0,0 +1,15 @@
#include <string>
#include <torch/extension.h>
#include "pybind/extern.hpp"
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.def("ApproxMatch", &ApproxMatch);
m.def("MatchCost", &MatchCost);
m.def("MatchCostGrad", &MatchCostGrad);
m.def("NNDistance", &NNDistance);
m.def("NNDistanceGrad", &NNDistanceGrad);
}
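The extension name declared in `setup.py` below (`StructuralLossesBackend`) becomes `TORCH_EXTENSION_NAME`, so the functions bound here surface in Python as:

```python
# Low-level bindings; the autograd wrappers match_cost/nn_distance are the intended entry points.
from metrics.StructuralLosses.StructuralLossesBackend import (
    ApproxMatch, MatchCost, MatchCostGrad, NNDistance, NNDistanceGrad)
```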

6
metrics/pytorch_structural_losses/pybind/extern.hpp Normal file

@@ -0,0 +1,6 @@
std::vector<at::Tensor> ApproxMatch(at::Tensor in_a, at::Tensor in_b);
at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q);
std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2);

30
metrics/pytorch_structural_losses/setup.py Normal file

@@ -0,0 +1,30 @@
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
# Python interface
setup(
name='PyTorchStructuralLosses',
version='0.1.0',
install_requires=['torch'],
packages=['StructuralLosses'],
package_dir={'StructuralLosses': './'},
ext_modules=[
CUDAExtension(
name='StructuralLossesBackend',
include_dirs=['./'],
sources=[
'pybind/bind.cpp',
],
libraries=['make_pytorch'],
library_dirs=['objs'],
# extra_compile_args=['-g']
)
],
cmdclass={'build_ext': BuildExtension},
author='Christopher B. Choy',
author_email='chrischoy@ai.stanford.edu',
description='Tutorial for Pytorch C++ Extension with a Makefile',
keywords='Pytorch C++ Extension',
url='https://github.com/chrischoy/MakePytorchPlusPlus',
zip_safe=False,
)

326
metrics/pytorch_structural_losses/src/approxmatch.cu Normal file

@@ -0,0 +1,326 @@
#include "utils.hpp"
__global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){
float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
float multiL,multiR;
if (n>=m){
multiL=1;
multiR=n/m;
}else{
multiL=m/n;
multiR=1;
}
const int Block=1024;
__shared__ float buf[Block*4];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
match[i*n*m+j]=0;
for (int j=threadIdx.x;j<n;j+=blockDim.x)
remainL[j]=multiL;
for (int j=threadIdx.x;j<m;j+=blockDim.x)
remainR[j]=multiR;
__syncthreads();
//for (int j=7;j>=-2;j--){
for (int j=7;j>-2;j--){
float level=-powf(4.0f,j);
if (j==-2){
level=0;
}
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
float x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
float suml=1e-9f;
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
float x2=xyz2[i*m*3+l0*3+l*3+0];
float y2=xyz2[i*m*3+l0*3+l*3+1];
float z2=xyz2[i*m*3+l0*3+l*3+2];
buf[l*4+0]=x2;
buf[l*4+1]=y2;
buf[l*4+2]=z2;
buf[l*4+3]=remainR[l0+l];
}
__syncthreads();
for (int l=0;l<lend;l++){
float x2=buf[l*4+0];
float y2=buf[l*4+1];
float z2=buf[l*4+2];
float d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
float w=__expf(d)*buf[l*4+3];
suml+=w;
}
__syncthreads();
}
if (k<n)
ratioL[k]=remainL[k]/suml;
}
/*for (int k=threadIdx.x;k<n;k+=gridDim.x){
float x1=xyz1[i*n*3+k*3+0];
float y1=xyz1[i*n*3+k*3+1];
float z1=xyz1[i*n*3+k*3+2];
float suml=1e-9f;
for (int l=0;l<m;l++){
float x2=xyz2[i*m*3+l*3+0];
float y2=xyz2[i*m*3+l*3+1];
float z2=xyz2[i*m*3+l*3+2];
float w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*remainR[l];
suml+=w;
}
ratioL[k]=remainL[k]/suml;
}*/
__syncthreads();
for (int l0=0;l0<m;l0+=blockDim.x){
int l=l0+threadIdx.x;
float x2=0,y2=0,z2=0;
if (l<m){
x2=xyz2[i*m*3+l*3+0];
y2=xyz2[i*m*3+l*3+1];
z2=xyz2[i*m*3+l*3+2];
}
float sumr=0;
for (int k0=0;k0<n;k0+=Block){
int kend=min(n,k0+Block)-k0;
for (int k=threadIdx.x;k<kend;k+=blockDim.x){
buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
buf[k*4+3]=ratioL[k0+k];
}
__syncthreads();
for (int k=0;k<kend;k++){
float x1=buf[k*4+0];
float y1=buf[k*4+1];
float z1=buf[k*4+2];
float w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
sumr+=w;
}
__syncthreads();
}
if (l<m){
sumr*=remainR[l];
float consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
ratioR[l]=consumption*remainR[l];
remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
}
}
/*for (int l=threadIdx.x;l<m;l+=blockDim.x){
float x2=xyz2[i*m*3+l*3+0];
float y2=xyz2[i*m*3+l*3+1];
float z2=xyz2[i*m*3+l*3+2];
float sumr=0;
for (int k=0;k<n;k++){
float x1=xyz1[i*n*3+k*3+0];
float y1=xyz1[i*n*3+k*3+1];
float z1=xyz1[i*n*3+k*3+2];
float w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*ratioL[k];
sumr+=w;
}
sumr*=remainR[l];
float consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
ratioR[l]=consumption*remainR[l];
remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
}*/
__syncthreads();
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
float x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
float suml=0;
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
buf[l*4+3]=ratioR[l0+l];
}
__syncthreads();
float rl=ratioL[k];
if (k<n){
for (int l=0;l<lend;l++){
float x2=buf[l*4+0];
float y2=buf[l*4+1];
float z2=buf[l*4+2];
float w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
match[i*n*m+(l0+l)*n+k]+=w;
suml+=w;
}
}
__syncthreads();
}
if (k<n)
remainL[k]=fmaxf(0.0f,remainL[k]-suml);
}
/*for (int k=threadIdx.x;k<n;k+=blockDim.x){
float x1=xyz1[i*n*3+k*3+0];
float y1=xyz1[i*n*3+k*3+1];
float z1=xyz1[i*n*3+k*3+2];
float suml=0;
for (int l=0;l<m;l++){
float x2=xyz2[i*m*3+l*3+0];
float y2=xyz2[i*m*3+l*3+1];
float z2=xyz2[i*m*3+l*3+2];
float w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*ratioL[k]*ratioR[l];
match[i*n*m+l*n+k]+=w;
suml+=w;
}
remainL[k]=fmaxf(0.0f,remainL[k]-suml);
}*/
__syncthreads();
}
}
}
__global__ void matchcostkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){
__shared__ float allsum[512];
const int Block=256;
__shared__ float buf[Block*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
float subsum=0;
for (int k0=0;k0<m;k0+=Block){
int endk=min(m,k0+Block);
for (int k=threadIdx.x;k<(endk-k0)*3;k+=blockDim.x){
buf[k]=xyz2[i*m*3+k0*3+k];
}
__syncthreads();
for (int j=threadIdx.x;j<n;j+=blockDim.x){
float x1=xyz1[(i*n+j)*3+0];
float y1=xyz1[(i*n+j)*3+1];
float z1=xyz1[(i*n+j)*3+2];
for (int k=0;k<endk-k0;k++){
//float x2=xyz2[(i*m+k)*3+0]-x1;
//float y2=xyz2[(i*m+k)*3+1]-y1;
//float z2=xyz2[(i*m+k)*3+2]-z1;
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=sqrtf(x2*x2+y2*y2+z2*z2);
subsum+=match[i*n*m+(k0+k)*n+j]*d;
}
}
__syncthreads();
}
allsum[threadIdx.x]=subsum;
for (int j=1;j<blockDim.x;j<<=1){
__syncthreads();
if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){
allsum[threadIdx.x]+=allsum[threadIdx.x+j];
}
}
if (threadIdx.x==0)
out[i]=allsum[0];
__syncthreads();
}
}
//void matchcostLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * out){
// matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
//}
__global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){
__shared__ float sum_grad[256*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
int kbeg=m*blockIdx.y/gridDim.y;
int kend=m*(blockIdx.y+1)/gridDim.y;
for (int k=kbeg;k<kend;k++){
float x2=xyz2[(i*m+k)*3+0];
float y2=xyz2[(i*m+k)*3+1];
float z2=xyz2[(i*m+k)*3+2];
float subsumx=0,subsumy=0,subsumz=0;
for (int j=threadIdx.x;j<n;j+=blockDim.x){
float x1=x2-xyz1[(i*n+j)*3+0];
float y1=y2-xyz1[(i*n+j)*3+1];
float z1=z2-xyz1[(i*n+j)*3+2];
float d=match[i*n*m+k*n+j]*rsqrtf(fmaxf(x1*x1+y1*y1+z1*z1,1e-20f));
subsumx+=x1*d;
subsumy+=y1*d;
subsumz+=z1*d;
}
sum_grad[threadIdx.x*3+0]=subsumx;
sum_grad[threadIdx.x*3+1]=subsumy;
sum_grad[threadIdx.x*3+2]=subsumz;
for (int j=1;j<blockDim.x;j<<=1){
__syncthreads();
int j1=threadIdx.x;
int j2=threadIdx.x+j;
if ((j1&j)==0 && j2<blockDim.x){
sum_grad[j1*3+0]+=sum_grad[j2*3+0];
sum_grad[j1*3+1]+=sum_grad[j2*3+1];
sum_grad[j1*3+2]+=sum_grad[j2*3+2];
}
}
if (threadIdx.x==0){
grad2[(i*m+k)*3+0]=sum_grad[0];
grad2[(i*m+k)*3+1]=sum_grad[1];
grad2[(i*m+k)*3+2]=sum_grad[2];
}
__syncthreads();
}
}
}
__global__ void matchcostgrad1kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad1){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int l=threadIdx.x;l<n;l+=blockDim.x){
float x1=xyz1[i*n*3+l*3+0];
float y1=xyz1[i*n*3+l*3+1];
float z1=xyz1[i*n*3+l*3+2];
float dx=0,dy=0,dz=0;
for (int k=0;k<m;k++){
float x2=xyz2[i*m*3+k*3+0];
float y2=xyz2[i*m*3+k*3+1];
float z2=xyz2[i*m*3+k*3+2];
float d=match[i*n*m+k*n+l]*rsqrtf(fmaxf((x1-x2)*(x1-x2)+(y1-y2)*(y1-y2)+(z1-z2)*(z1-z2),1e-20f));
dx+=(x1-x2)*d;
dy+=(y1-y2)*d;
dz+=(z1-z2)*d;
}
grad1[i*n*3+l*3+0]=dx;
grad1[i*n*3+l*3+1]=dy;
grad1[i*n*3+l*3+2]=dz;
}
}
}
//void matchcostgradLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad2){
// matchcostgrad<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
//}
/*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
cudaStream_t stream)*/
// temp: TensorShape{b,(n+m)*2}
void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){
approxmatchkernel
<<<32, 512, 0, stream>>>(b,n,m,xyz1,xyz2,match,temp);
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
throw std::runtime_error(Formatter()
<< "CUDA kernel failed : " << std::to_string(err));
}
void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){
matchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out);
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
throw std::runtime_error(Formatter()
<< "CUDA kernel failed : " << std::to_string(err));
}
void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){
matchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1);
matchcostgrad2kernel<<<dim3(32,32),256,0,stream>>>(b,n,m,xyz1,xyz2,match,grad2);
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
throw std::runtime_error(Formatter()
<< "CUDA kernel failed : " << std::to_string(err));
}

8
metrics/pytorch_structural_losses/src/approxmatch.cuh Normal file

@@ -0,0 +1,8 @@
/*
template <typename Dtype>
void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
cudaStream_t stream);
*/
void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream);
void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream);
void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream);

View file

@ -0,0 +1,155 @@
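// Overview (comment added for orientation): brute-force nearest-neighbor search for
// Chamfer distance. For each point in one cloud, the closest point in the other cloud
// is found by tiling the second cloud through shared memory 512 points at a time, with
// a four-way unrolled inner loop. All distances produced here are squared Euclidean.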
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
const int batch=512;
__shared__ float buf[batch*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int k2=0;k2<m;k2+=batch){
int end_k=min(m,k2+batch)-k2;
for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
buf[j]=xyz2[(i*m+k2)*3+j];
}
__syncthreads();
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz[(i*n+j)*3+0];
float y1=xyz[(i*n+j)*3+1];
float z1=xyz[(i*n+j)*3+2];
int best_i=0;
float best=0;
int end_ka=end_k-(end_k&3);
if (end_ka==batch){
for (int k=0;k<batch;k+=4){
{
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*3+3]-x1;
float y2=buf[k*3+4]-y1;
float z2=buf[k*3+5]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*3+6]-x1;
float y2=buf[k*3+7]-y1;
float z2=buf[k*3+8]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*3+9]-x1;
float y2=buf[k*3+10]-y1;
float z2=buf[k*3+11]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}else{
for (int k=0;k<end_ka;k+=4){
{
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*3+3]-x1;
float y2=buf[k*3+4]-y1;
float z2=buf[k*3+5]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*3+6]-x1;
float y2=buf[k*3+7]-y1;
float z2=buf[k*3+8]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*3+9]-x1;
float y2=buf[k*3+10]-y1;
float z2=buf[k*3+11]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}
for (int k=end_ka;k<end_k;k++){
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
if (k2==0 || result[(i*n+j)]>best){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
NmDistanceKernel<<<dim3(32,16,1),512, 0, stream>>>(b,n,xyz,m,xyz2,result,result_i);
NmDistanceKernel<<<dim3(32,16,1),512, 0, stream>>>(b,m,xyz2,n,xyz,result2,result2_i);
}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz1[(i*n+j)*3+0];
float y1=xyz1[(i*n+j)*3+1];
float z1=xyz1[(i*n+j)*3+2];
int j2=idx1[i*n+j];
float x2=xyz2[(i*m+j2)*3+0];
float y2=xyz2[(i*m+j2)*3+1];
float z2=xyz2[(i*m+j2)*3+2];
float g=grad_dist1[i*n+j]*2;
atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
}
}
}
void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
cudaMemset(grad_xyz1,0,b*n*3*sizeof(float));
cudaMemset(grad_xyz2,0,b*m*3*sizeof(float));
NmDistanceGradKernel<<<dim3(1,16,1),256, 0, stream>>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2);
NmDistanceGradKernel<<<dim3(1,16,1),256, 0, stream>>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1);
}

View file

@ -0,0 +1,2 @@
void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);

View file

@ -0,0 +1,125 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "src/approxmatch.cuh"
#include "src/nndistance.cuh"
#include <vector>
#include <iostream>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
/*
input:
set1 : batch_size * #dataset_points * 3
set2 : batch_size * #query_points * 3
returns:
match : batch_size * #query_points * #dataset_points
*/
// temp: TensorShape{b,(n+m)*2}
std::vector<at::Tensor> ApproxMatch(at::Tensor set_d, at::Tensor set_q) {
//std::cout << "[ApproxMatch] Called." << std::endl;
int64_t batch_size = set_d.size(0);
int64_t n_dataset_points = set_d.size(1); // n
int64_t n_query_points = set_q.size(1); // m
//std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl;
at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
CHECK_INPUT(set_d);
CHECK_INPUT(set_q);
CHECK_INPUT(match);
CHECK_INPUT(temp);
approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),temp.data<float>(), at::cuda::getCurrentCUDAStream());
return {match, temp};
}
at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
//std::cout << "[MatchCost] Called." << std::endl;
int64_t batch_size = set_d.size(0);
int64_t n_dataset_points = set_d.size(1); // n
int64_t n_query_points = set_q.size(1); // m
//std::cout << "[MatchCost] batch_size:" << batch_size << std::endl;
at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
CHECK_INPUT(set_d);
CHECK_INPUT(set_q);
CHECK_INPUT(match);
CHECK_INPUT(out);
matchcost(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),out.data<float>(),at::cuda::getCurrentCUDAStream());
return out;
}
std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
//std::cout << "[MatchCostGrad] Called." << std::endl;
int64_t batch_size = set_d.size(0);
int64_t n_dataset_points = set_d.size(1); // n
int64_t n_query_points = set_q.size(1); // m
//std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl;
at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
CHECK_INPUT(set_d);
CHECK_INPUT(set_q);
CHECK_INPUT(match);
CHECK_INPUT(grad1);
CHECK_INPUT(grad2);
matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),grad1.data<float>(),grad2.data<float>(),at::cuda::getCurrentCUDAStream());
return {grad1, grad2};
}
/*
input:
set_d : batch_size * #dataset_points * 3
set_q : batch_size * #query_points * 3
returns:
dist1, idx1 : batch_size * #dataset_points
dist2, idx2 : batch_size * #query_points
*/
std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q) {
//std::cout << "[NNDistance] Called." << std::endl;
int64_t batch_size = set_d.size(0);
int64_t n_dataset_points = set_d.size(1); // n
int64_t n_query_points = set_q.size(1); // m
//std::cout << "[NNDistance] batch_size:" << batch_size << std::endl;
at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
CHECK_INPUT(set_d);
CHECK_INPUT(set_q);
CHECK_INPUT(dist1);
CHECK_INPUT(idx1);
CHECK_INPUT(dist2);
CHECK_INPUT(idx2);
// void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
nndistance(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),dist1.data<float>(),idx1.data<int>(),dist2.data<float>(),idx2.data<int>(), at::cuda::getCurrentCUDAStream());
return {dist1, idx1, dist2, idx2};
}
std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) {
//std::cout << "[NNDistanceGrad] Called." << std::endl;
int64_t batch_size = set_d.size(0);
int64_t n_dataset_points = set_d.size(1); // n
int64_t n_query_points = set_q.size(1); // m
//std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl;
at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
CHECK_INPUT(set_d);
CHECK_INPUT(set_q);
CHECK_INPUT(idx1);
CHECK_INPUT(idx2);
CHECK_INPUT(grad_dist1);
CHECK_INPUT(grad_dist2);
CHECK_INPUT(grad1);
CHECK_INPUT(grad2);
//void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);
nndistancegrad(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),
grad_dist1.data<float>(),idx1.data<int>(),
grad_dist2.data<float>(),idx2.data<int>(),
grad1.data<float>(),grad2.data<float>(),
at::cuda::getCurrentCUDAStream());
return {grad1, grad2};
}
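For orientation: `NNDistance` returns *squared* nearest-neighbor distances in both directions, which is exactly what a Chamfer distance needs. A minimal Python-side sketch follows; the import path `metrics.structural_losses` and the exported name `NNDistance` are assumptions, since the pybind registration is not visible in this hunk:

```python
import torch

# Assumed import path and symbol name for the compiled extension above.
from metrics.structural_losses import NNDistance  # hypothetical binding


def chamfer_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Symmetric Chamfer distance between (B, N, 3) and (B, M, 3) CUDA clouds.

    The kernels return squared distances, so no square root is taken here.
    """
    dist1, _idx1, dist2, _idx2 = NNDistance(a.contiguous(), b.contiguous())
    return dist1.mean(dim=1) + dist2.mean(dim=1)
```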

View file

@ -0,0 +1,26 @@
#include <iostream>
#include <sstream>
#include <string>
class Formatter {
public:
Formatter() {}
~Formatter() {}
template <typename Type> Formatter &operator<<(const Type &value) {
stream_ << value;
return *this;
}
std::string str() const { return stream_.str(); }
operator std::string() const { return stream_.str(); }
enum ConvertToString { to_str };
std::string operator>>(ConvertToString) { return stream_.str(); }
private:
std::stringstream stream_;
Formatter(const Formatter &);
Formatter &operator=(Formatter &);
};

0
models/__init__.py Normal file
View file

122
models/cnf.py Normal file
View file

@ -0,0 +1,122 @@
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint
from torchdiffeq import odeint as odeint_normal
__all__ = ["CNF", "SequentialFlow"]
class SequentialFlow(nn.Module):
"""A generalized nn.Sequential container for normalizing flows."""
def __init__(self, layer_list):
super(SequentialFlow, self).__init__()
self.chain = nn.ModuleList(layer_list)
def forward(self, x, context, logpx=None, reverse=False, inds=None, integration_times=None):
if inds is None:
if reverse:
inds = range(len(self.chain) - 1, -1, -1)
else:
inds = range(len(self.chain))
if logpx is None:
for i in inds:
x = self.chain[i](x, context, logpx, integration_times, reverse)
return x
else:
for i in inds:
x, logpx = self.chain[i](x, context, logpx, integration_times, reverse)
return x, logpx
class CNF(nn.Module):
def __init__(self, odefunc, conditional=True, T=1.0, train_T=False, regularization_fns=None,
solver='dopri5', atol=1e-5, rtol=1e-5, use_adjoint=True):
super(CNF, self).__init__()
self.train_T = train_T
self.T = T
if train_T:
self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T))))
if regularization_fns is not None and len(regularization_fns) > 0:
raise NotImplementedError("Regularization not supported")
self.use_adjoint = use_adjoint
self.odefunc = odefunc
self.solver = solver
self.atol = atol
self.rtol = rtol
self.test_solver = solver
self.test_atol = atol
self.test_rtol = rtol
self.solver_options = {}
self.conditional = conditional
def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
if logpx is None:
_logpx = torch.zeros(*x.shape[:-1], 1).to(x)
else:
_logpx = logpx
if self.conditional:
assert context is not None
states = (x, _logpx, context)
atol = [self.atol] * 3
rtol = [self.rtol] * 3
else:
states = (x, _logpx)
atol = [self.atol] * 2
rtol = [self.rtol] * 2
if integration_times is None:
if self.train_T:
integration_times = torch.stack(
[torch.tensor(0.0).to(x), self.sqrt_end_time * self.sqrt_end_time]
).to(x)
else:
integration_times = torch.tensor([0., self.T], requires_grad=False).to(x)
if reverse:
integration_times = _flip(integration_times, 0)
# Refresh the odefunc statistics.
self.odefunc.before_odeint()
odeint = odeint_adjoint if self.use_adjoint else odeint_normal
if self.training:
state_t = odeint(
self.odefunc,
states,
integration_times.to(x),
atol=atol,
rtol=rtol,
method=self.solver,
options=self.solver_options,
)
else:
state_t = odeint(
self.odefunc,
states,
integration_times.to(x),
atol=self.test_atol,
rtol=self.test_rtol,
method=self.test_solver,
)
if len(integration_times) == 2:
state_t = tuple(s[1] for s in state_t)
z_t, logpz_t = state_t[:2]
if logpx is not None:
return z_t, logpz_t
else:
return z_t
def num_evals(self):
return self.odefunc._num_evals.item()
def _flip(x, dim):
indices = [slice(None)] * x.dim()
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
return x[tuple(indices)]
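A minimal usage sketch of the unconditional path through `CNF` (import paths as in this commit; the toy sizes and layer choices are assumptions):

```python
import torch

from models.cnf import CNF
from models.odefunc import ODEfunc, ODEnet

# Unconditional CNF over 3-D points: with context_dim=0 the hypernetwork
# layers condition on the integration time t only.
diffeq = ODEnet(hidden_dims=(64, 64), input_shape=(3,), context_dim=0,
                layer_type="concatsquash", nonlinearity="tanh")
cnf = CNF(ODEfunc(diffeq), conditional=False, T=1.0)

x = torch.randn(8, 3)
logpx = torch.zeros(8, 1)
z, delta_logp = cnf(x, context=None, logpx=logpx)  # data -> base distribution
x_rec = cnf(z, context=None, reverse=True)         # base -> data
```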

103
models/diffeq_layers.py Normal file
View file

@ -0,0 +1,103 @@
import torch
import torch.nn as nn
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1 or classname.find('Conv') != -1:
nn.init.constant_(m.weight, 0)
nn.init.normal_(m.bias, 0, 0.01)
class IgnoreLinear(nn.Module):
def __init__(self, dim_in, dim_out, dim_c):
super(IgnoreLinear, self).__init__()
self._layer = nn.Linear(dim_in, dim_out)
def forward(self, context, x):
return self._layer(x)
class ConcatLinear(nn.Module):
def __init__(self, dim_in, dim_out, dim_c):
super(ConcatLinear, self).__init__()
self._layer = nn.Linear(dim_in + 1 + dim_c, dim_out)
def forward(self, context, x):
if x.dim() == 3:
context = context.unsqueeze(1).expand(-1, x.size(1), -1)
x_context = torch.cat((x, context), dim=-1)
return self._layer(x_context)
class ConcatLinear_v2(nn.Module):
def __init__(self, dim_in, dim_out, dim_c):
super(ConcatLinear_v2, self).__init__()
self._layer = nn.Linear(dim_in, dim_out)
self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False)
def forward(self, context, x):
bias = self._hyper_bias(context)
if x.dim() == 3:
bias = bias.unsqueeze(1)
return self._layer(x) + bias
class SquashLinear(nn.Module):
def __init__(self, dim_in, dim_out, dim_c):
super(SquashLinear, self).__init__()
self._layer = nn.Linear(dim_in, dim_out)
self._hyper = nn.Linear(1 + dim_c, dim_out)
def forward(self, context, x):
gate = torch.sigmoid(self._hyper(context))
if x.dim() == 3:
gate = gate.unsqueeze(1)
return self._layer(x) * gate
class ScaleLinear(nn.Module):
def __init__(self, dim_in, dim_out, dim_c):
super(ScaleLinear, self).__init__()
self._layer = nn.Linear(dim_in, dim_out)
self._hyper = nn.Linear(1 + dim_c, dim_out)
def forward(self, context, x):
gate = self._hyper(context)
if x.dim() == 3:
gate = gate.unsqueeze(1)
return self._layer(x) * gate
class ConcatSquashLinear(nn.Module):
def __init__(self, dim_in, dim_out, dim_c):
super(ConcatSquashLinear, self).__init__()
self._layer = nn.Linear(dim_in, dim_out)
self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False)
self._hyper_gate = nn.Linear(1 + dim_c, dim_out)
def forward(self, context, x):
gate = torch.sigmoid(self._hyper_gate(context))
bias = self._hyper_bias(context)
if x.dim() == 3:
gate = gate.unsqueeze(1)
bias = bias.unsqueeze(1)
ret = self._layer(x) * gate + bias
return ret
class ConcatScaleLinear(nn.Module):
def __init__(self, dim_in, dim_out, dim_c):
super(ConcatScaleLinear, self).__init__()
self._layer = nn.Linear(dim_in, dim_out)
self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False)
self._hyper_gate = nn.Linear(1 + dim_c, dim_out)
def forward(self, context, x):
gate = self._hyper_gate(context)
bias = self._hyper_bias(context)
if x.dim() == 3:
gate = gate.unsqueeze(1)
bias = bias.unsqueeze(1)
ret = self._layer(x) * gate + bias
return ret
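All conditioning layers above share the signature `forward(context, x)`, where `context` packs the integration time (and any shape code) as `[t, c]` per sample. A shape sketch for the gated variant, with sizes chosen purely for illustration:

```python
import torch

from models.diffeq_layers import ConcatSquashLinear

layer = ConcatSquashLinear(dim_in=3, dim_out=64, dim_c=128)
ctx = torch.randn(8, 1 + 128)  # per-sample [t, shape code]
x = torch.randn(8, 2048, 3)    # (B, N, dim_in) point clouds
y = layer(ctx, x)              # (8, 2048, 64); gate and bias broadcast over N
```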

89
models/flow.py Normal file
View file

@ -0,0 +1,89 @@
from .odefunc import ODEfunc, ODEnet
from .normalization import MovingBatchNorm1d
from .cnf import CNF, SequentialFlow
def count_nfe(model):
class AccNumEvals(object):
def __init__(self):
self.num_evals = 0
def __call__(self, module):
if isinstance(module, CNF):
self.num_evals += module.num_evals()
accumulator = AccNumEvals()
model.apply(accumulator)
return accumulator.num_evals
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def count_total_time(model):
class Accumulator(object):
def __init__(self):
self.total_time = 0
def __call__(self, module):
if isinstance(module, CNF):
self.total_time = self.total_time + module.sqrt_end_time * module.sqrt_end_time
accumulator = Accumulator()
model.apply(accumulator)
return accumulator.total_time
def build_model(args, input_dim, hidden_dims, context_dim, num_blocks, conditional):
def build_cnf():
diffeq = ODEnet(
hidden_dims=hidden_dims,
input_shape=(input_dim,),
context_dim=context_dim,
layer_type=args.layer_type,
nonlinearity=args.nonlinearity,
)
odefunc = ODEfunc(
diffeq=diffeq,
)
cnf = CNF(
odefunc=odefunc,
T=args.time_length,
train_T=args.train_T,
conditional=conditional,
solver=args.solver,
use_adjoint=args.use_adjoint,
atol=args.atol,
rtol=args.rtol,
)
return cnf
chain = [build_cnf() for _ in range(num_blocks)]
if args.batch_norm:
bn_layers = [MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag, sync=args.sync_bn)
for _ in range(num_blocks)]
bn_chain = [MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag, sync=args.sync_bn)]
for a, b in zip(chain, bn_layers):
bn_chain.append(a)
bn_chain.append(b)
chain = bn_chain
model = SequentialFlow(chain)
return model
def get_point_cnf(args):
dims = tuple(map(int, args.dims.split("-")))
model = build_model(args, args.input_dim, dims, args.zdim, args.num_blocks, True).cuda()
print("Number of trainable parameters of Point CNF: {}".format(count_parameters(model)))
return model
def get_latent_cnf(args):
dims = tuple(map(int, args.latent_dims.split("-")))
model = build_model(args, args.zdim, dims, 0, args.latent_num_blocks, False).cuda()
print("Number of trainable parameters of Latent CNF: {}".format(count_parameters(model)))
return model
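`get_point_cnf` and `get_latent_cnf` only read a handful of attributes off `args`, so they can be driven without the full argument parser. A sketch with hypothetical values (every field below is read by the code above; `.cuda()` requires a GPU):

```python
from argparse import Namespace

from models.flow import get_point_cnf

args = Namespace(
    dims="512-512-512", zdim=128, input_dim=3, num_blocks=1,
    layer_type="concatsquash", nonlinearity="tanh",
    time_length=0.5, train_T=True, solver="dopri5", use_adjoint=True,
    atol=1e-5, rtol=1e-5, batch_norm=True, bn_lag=0.0, sync_bn=False,
)
point_cnf = get_point_cnf(args)  # a SequentialFlow on the current CUDA device
```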

224
models/networks.py Normal file
View file

@ -0,0 +1,224 @@
import torch
import numpy as np
import torch.nn.functional as F
from torch import optim
from torch import nn
from models.flow import get_point_cnf
from models.flow import get_latent_cnf
from utils import truncated_normal, reduce_tensor, standard_normal_logprob
class Encoder(nn.Module):
def __init__(self, zdim, input_dim=3, use_deterministic_encoder=False):
super(Encoder, self).__init__()
self.use_deterministic_encoder = use_deterministic_encoder
self.zdim = zdim
self.conv1 = nn.Conv1d(input_dim, 128, 1)
self.conv2 = nn.Conv1d(128, 128, 1)
self.conv3 = nn.Conv1d(128, 256, 1)
self.conv4 = nn.Conv1d(256, 512, 1)
self.bn1 = nn.BatchNorm1d(128)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(256)
self.bn4 = nn.BatchNorm1d(512)
if self.use_deterministic_encoder:
self.fc1 = nn.Linear(512, 256)
self.fc2 = nn.Linear(256, 128)
self.fc_bn1 = nn.BatchNorm1d(256)
self.fc_bn2 = nn.BatchNorm1d(128)
self.fc3 = nn.Linear(128, zdim)
else:
# Mapping to the mean of q(z|x)
self.fc1_m = nn.Linear(512, 256)
self.fc2_m = nn.Linear(256, 128)
self.fc3_m = nn.Linear(128, zdim)
self.fc_bn1_m = nn.BatchNorm1d(256)
self.fc_bn2_m = nn.BatchNorm1d(128)
# Mapping to the (log) variance of q(z|x)
self.fc1_v = nn.Linear(512, 256)
self.fc2_v = nn.Linear(256, 128)
self.fc3_v = nn.Linear(128, zdim)
self.fc_bn1_v = nn.BatchNorm1d(256)
self.fc_bn2_v = nn.BatchNorm1d(128)
def forward(self, x):
x = x.transpose(1, 2)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.bn4(self.conv4(x))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 512)
ms = F.relu(self.fc_bn1(self.fc1(x)))
ms = F.relu(self.fc_bn2(self.fc2(ms)))
ms = self.fc3(ms)
if self.use_deterministic_encoder:
m, v = ms, 0
else:
m = F.relu(self.fc_bn1_m(self.fc1_m(x)))
m = F.relu(self.fc_bn2_m(self.fc2_m(m)))
m = self.fc3_m(m)
v = F.relu(self.fc_bn1_v(self.fc1_v(x)))
v = F.relu(self.fc_bn2_v(self.fc2_v(v)))
v = self.fc3_v(v)
return m, v
# Model
class PointFlow(nn.Module):
def __init__(self, args):
super(PointFlow, self).__init__()
self.input_dim = args.input_dim
self.zdim = args.zdim
self.use_latent_flow = args.use_latent_flow
self.use_deterministic_encoder = args.use_deterministic_encoder
self.prior_weight = args.prior_weight
self.recon_weight = args.recon_weight
self.entropy_weight = args.entropy_weight
self.distributed = args.distributed
self.truncate_std = None
self.encoder = Encoder(
zdim=args.zdim, input_dim=args.input_dim,
use_deterministic_encoder=args.use_deterministic_encoder)
self.point_cnf = get_point_cnf(args)
self.latent_cnf = get_latent_cnf(args) if args.use_latent_flow else nn.Sequential()
@staticmethod
def sample_gaussian(size, truncate_std=None, gpu=None):
y = torch.randn(*size).float()
y = y if gpu is None else y.cuda(gpu)
if truncate_std is not None:
truncated_normal(y, mean=0, std=1, trunc_std=truncate_std)
return y
@staticmethod
def reparameterize_gaussian(mean, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn(std.size()).to(mean)
return mean + std * eps
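# Entropy of a diagonal Gaussian N(mu, diag(sigma^2)), with logvar = log sigma^2:
#   H = D/2 * (1 + log(2*pi)) + 1/2 * sum_i log sigma_i^2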
@staticmethod
def gaussian_entropy(logvar):
const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2))
ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const
return ent
def multi_gpu_wrapper(self, f):
self.encoder = f(self.encoder)
self.point_cnf = f(self.point_cnf)
self.latent_cnf = f(self.latent_cnf)
def make_optimizer(self, args):
def _get_opt_(params):
if args.optimizer == 'adam':
optimizer = optim.Adam(params, lr=args.lr, betas=(args.beta1, args.beta2),
weight_decay=args.weight_decay)
elif args.optimizer == 'sgd':
optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum)
else:
assert 0, "args.optimizer should be either 'adam' or 'sgd'"
return optimizer
opt = _get_opt_(list(self.encoder.parameters()) + list(self.point_cnf.parameters())
+ list(list(self.latent_cnf.parameters())))
return opt
def forward(self, x, opt, step, writer=None):
opt.zero_grad()
batch_size = x.size(0)
num_points = x.size(1)
z_mu, z_sigma = self.encoder(x)
if self.use_deterministic_encoder:
z = z_mu + 0 * z_sigma
else:
z = self.reparameterize_gaussian(z_mu, z_sigma)
# Compute H[Q(z|X)]
if self.use_deterministic_encoder:
entropy = torch.zeros(batch_size).to(z)
else:
entropy = self.gaussian_entropy(z_sigma)
# Compute the prior probability P(z)
if self.use_latent_flow:
w, delta_log_pw = self.latent_cnf(z, None, torch.zeros(batch_size, 1).to(z))
log_pw = standard_normal_logprob(w).view(batch_size, -1).sum(1, keepdim=True)
delta_log_pw = delta_log_pw.view(batch_size, 1)
log_pz = log_pw - delta_log_pw
else:
log_pz = torch.zeros(batch_size, 1).to(z)
# Compute the reconstruction likelihood P(X|z)
z_new = z.view(*z.size())
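# numerically a no-op; presumably keeps log_pz (and thus the latent CNF) attached to the graph of the reconstruction term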
z_new = z_new + (log_pz * 0.).mean()
y, delta_log_py = self.point_cnf(x, z_new, torch.zeros(batch_size, num_points, 1).to(x))
log_py = standard_normal_logprob(y).view(batch_size, -1).sum(1, keepdim=True)
delta_log_py = delta_log_py.view(batch_size, num_points, 1).sum(1)
log_px = log_py - delta_log_py
# Loss
entropy_loss = -entropy.mean() * self.entropy_weight
recon_loss = -log_px.mean() * self.recon_weight
prior_loss = -log_pz.mean() * self.prior_weight
loss = entropy_loss + prior_loss + recon_loss
loss.backward()
opt.step()
# Logging (after the optimizer step)
if self.distributed:
entropy_log = reduce_tensor(entropy.mean())
recon = reduce_tensor(-log_px.mean())
prior = reduce_tensor(-log_pz.mean())
else:
entropy_log = entropy.mean()
recon = -log_px.mean()
prior = -log_pz.mean()
recon_nats = recon / float(x.size(1) * x.size(2))
prior_nats = prior / float(self.zdim)
if writer is not None:
writer.add_scalar('train/entropy', entropy_log, step)
writer.add_scalar('train/prior', prior, step)
writer.add_scalar('train/prior(nats)', prior_nats, step)
writer.add_scalar('train/recon', recon, step)
writer.add_scalar('train/recon(nats)', recon_nats, step)
return {
'entropy': entropy_log.cpu().detach().item()
if not isinstance(entropy_log, float) else entropy_log,
'prior_nats': prior_nats,
'recon_nats': recon_nats,
}
def encode(self, x):
z_mu, z_sigma = self.encoder(x)
if self.use_deterministic_encoder:
return z_mu
else:
return self.reparameterize_gaussian(z_mu, z_sigma)
def decode(self, z, num_points, truncate_std=None):
# transform points from the prior to a point cloud, conditioned on a shape code
y = self.sample_gaussian((z.size(0), num_points, self.input_dim), truncate_std)
x = self.point_cnf(y, z, reverse=True).view(*y.size())
return y, x
def sample(self, batch_size, num_points, truncate_std=None, truncate_std_latent=None, gpu=None):
assert self.use_latent_flow, "Sampling requires `self.use_latent_flow` to be True."
# Generate the shape code from the prior
w = self.sample_gaussian((batch_size, self.zdim), truncate_std_latent, gpu=gpu)
z = self.latent_cnf(w, None, reverse=True).view(*w.size())
# Sample points conditioned on the shape code
y = self.sample_gaussian((batch_size, num_points, self.input_dim), truncate_std, gpu=gpu)
x = self.point_cnf(y, z, reverse=True).view(*y.size())
return z, x
def reconstruct(self, x, num_points=None, truncate_std=None):
num_points = x.size(1) if num_points is None else num_points
z = self.encode(x)
_, x = self.decode(z, num_points, truncate_std)
return x
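Sampling from a trained `PointFlow` is a two-stage flow: the latent CNF maps Gaussian noise to a shape code `z`, and the point CNF maps per-point noise to a cloud conditioned on `z`. A sketch, assuming `args` comes from the repository's `get_args()` and a checkpoint like the one the demo script points at:

```python
import torch

from args import get_args
from models.networks import PointFlow

args = get_args()  # must set --use_latent_flow and matching flow dimensions
model = PointFlow(args).cuda().eval()
checkpoint = torch.load("pretrained_models/gen/airplane/checkpoint.pt")
model.load_state_dict(checkpoint)  # test.py wraps submodules in DataParallel first
with torch.no_grad():
    z, x = model.sample(batch_size=4, num_points=2048)  # x: (4, 2048, 3)
```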

145
models/normalization.py Normal file
View file

@ -0,0 +1,145 @@
import torch
import torch.nn as nn
from torch.nn import Parameter
from utils import reduce_tensor
__all__ = ['MovingBatchNorm1d']
class MovingBatchNormNd(nn.Module):
def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True, sync=False):
super(MovingBatchNormNd, self).__init__()
self.num_features = num_features
self.sync = sync
self.affine = affine
self.eps = eps
self.decay = decay
self.bn_lag = bn_lag
self.register_buffer('step', torch.zeros(1))
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
@property
def shape(self):
raise NotImplementedError
def reset_parameters(self):
self.running_mean.zero_()
self.running_var.fill_(1)
if self.affine:
self.weight.data.zero_()
self.bias.data.zero_()
def forward(self, x, c=None, logpx=None, reverse=False):
if reverse:
return self._reverse(x, logpx)
else:
return self._forward(x, logpx)
def _forward(self, x, logpx=None):
num_channels = x.size(-1)
used_mean = self.running_mean.clone().detach()
used_var = self.running_var.clone().detach()
if self.training:
# compute batch statistics
x_t = x.transpose(0, 1).reshape(num_channels, -1)
batch_mean = torch.mean(x_t, dim=1)
if self.sync:
batch_ex2 = torch.mean(x_t**2, dim=1)
batch_mean = reduce_tensor(batch_mean)
batch_ex2 = reduce_tensor(batch_ex2)
batch_var = batch_ex2 - batch_mean**2
else:
batch_var = torch.var(x_t, dim=1)
# moving average
if self.bn_lag > 0:
used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
used_var /= (1. - self.bn_lag**(self.step[0] + 1))
# update running estimates
self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
self.running_var -= self.decay * (self.running_var - batch_var.data)
self.step += 1
# perform normalization
used_mean = used_mean.view(*self.shape).expand_as(x)
used_var = used_var.view(*self.shape).expand_as(x)
y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))
if self.affine:
weight = self.weight.view(*self.shape).expand_as(x)
bias = self.bias.view(*self.shape).expand_as(x)
y = y * torch.exp(weight) + bias
if logpx is None:
return y
else:
return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)
def _reverse(self, y, logpy=None):
used_mean = self.running_mean
used_var = self.running_var
if self.affine:
weight = self.weight.view(*self.shape).expand_as(y)
bias = self.bias.view(*self.shape).expand_as(y)
y = (y - bias) * torch.exp(-weight)
used_mean = used_mean.view(*self.shape).expand_as(y)
used_var = used_var.view(*self.shape).expand_as(y)
x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean
if logpy is None:
return x
else:
return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True)
def _logdetgrad(self, x, used_var):
logdetgrad = -0.5 * torch.log(used_var + self.eps)
if self.affine:
weight = self.weight.view(*self.shape).expand(*x.size())
logdetgrad += weight
return logdetgrad
def __repr__(self):
return (
'{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
)
def stable_var(x, mean=None, dim=1):
if mean is None:
mean = x.mean(dim, keepdim=True)
mean = mean.view(-1, 1)
res = torch.pow(x - mean, 2)
max_sqr = torch.max(res, dim, keepdim=True)[0]
var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr
var = var.view(-1)
# change nan to zero
var[var != var] = 0
return var
class MovingBatchNorm1d(MovingBatchNormNd):
@property
def shape(self):
return [1, -1]
def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
ret = super(MovingBatchNorm1d, self).forward(
x, context, logpx=logpx, reverse=reverse)
return ret
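As a flow layer, this normalization is an elementwise affine map, so its log-determinant decomposes per dimension. With running variance $\sigma_i^2$ and optional log-scale weight $w_i$, `_logdetgrad` computes

$$\log\left|\det\frac{\partial y}{\partial x}\right| = \sum_i \left(w_i - \tfrac{1}{2}\log(\sigma_i^2 + \epsilon)\right),$$

which `_forward` subtracts from the running log-density and `_reverse` adds back.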

137
models/odefunc.py Normal file
View file

@ -0,0 +1,137 @@
import copy
import torch
import torch.nn as nn
from . import diffeq_layers
__all__ = ["ODEnet", "ODEfunc"]
def divergence_approx(f, y, e=None):
e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
e_dzdx_e = e_dzdx.mul(e)
cnt = 0
while not e_dzdx_e.requires_grad and cnt < 10:
# print("RequiresGrad:f=%s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt=%d"
# % (f.requires_grad, y.requires_grad, e_dzdx.requires_grad,
# e.requires_grad, e_dzdx_e.requires_grad, cnt))
e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
e_dzdx_e = e_dzdx * e
cnt += 1
approx_tr_dzdx = e_dzdx_e.sum(dim=-1)
assert approx_tr_dzdx.requires_grad, \
"(failed to add node to graph) f=%s %s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt:%s" \
% (
f.size(), f.requires_grad, y.requires_grad, e_dzdx.requires_grad, e.requires_grad, e_dzdx_e.requires_grad, cnt)
return approx_tr_dzdx
class Swish(nn.Module):
def __init__(self):
super(Swish, self).__init__()
self.beta = nn.Parameter(torch.tensor(1.0))
def forward(self, x):
return x * torch.sigmoid(self.beta * x)
class Lambda(nn.Module):
def __init__(self, f):
super(Lambda, self).__init__()
self.f = f
def forward(self, x):
return self.f(x)
NONLINEARITIES = {
"tanh": nn.Tanh(),
"relu": nn.ReLU(),
"softplus": nn.Softplus(),
"elu": nn.ELU(),
"swish": Swish(),
"square": Lambda(lambda x: x ** 2),
"identity": Lambda(lambda x: x),
}
class ODEnet(nn.Module):
"""
Helper class to make neural nets for use in continuous normalizing flows
"""
def __init__(self, hidden_dims, input_shape, context_dim, layer_type="concat", nonlinearity="softplus"):
super(ODEnet, self).__init__()
base_layer = {
"ignore": diffeq_layers.IgnoreLinear,
"squash": diffeq_layers.SquashLinear,
"scale": diffeq_layers.ScaleLinear,
"concat": diffeq_layers.ConcatLinear,
"concat_v2": diffeq_layers.ConcatLinear_v2,
"concatsquash": diffeq_layers.ConcatSquashLinear,
"concatscale": diffeq_layers.ConcatScaleLinear,
}[layer_type]
# build models and add them
layers = []
activation_fns = []
hidden_shape = input_shape
for dim_out in (hidden_dims + (input_shape[0],)):
layer_kwargs = {}
layer = base_layer(hidden_shape[0], dim_out, context_dim, **layer_kwargs)
layers.append(layer)
activation_fns.append(NONLINEARITIES[nonlinearity])
hidden_shape = list(copy.copy(hidden_shape))
hidden_shape[0] = dim_out
self.layers = nn.ModuleList(layers)
self.activation_fns = nn.ModuleList(activation_fns[:-1])
def forward(self, context, y):
dx = y
for l, layer in enumerate(self.layers):
dx = layer(context, dx)
# if not last layer, use nonlinearity
if l < len(self.layers) - 1:
dx = self.activation_fns[l](dx)
return dx
class ODEfunc(nn.Module):
def __init__(self, diffeq):
super(ODEfunc, self).__init__()
self.diffeq = diffeq
self.divergence_fn = divergence_approx
self.register_buffer("_num_evals", torch.tensor(0.))
def before_odeint(self, e=None):
self._e = e
self._num_evals.fill_(0)
def forward(self, t, states):
y = states[0]
t = torch.ones(y.size(0), 1).to(y) * t.clone().detach().requires_grad_(True).type_as(y)
self._num_evals += 1
for state in states:
state.requires_grad_(True)
# Sample and fix the noise.
if self._e is None:
self._e = torch.randn_like(y, requires_grad=True).to(y)
with torch.set_grad_enabled(True):
if len(states) == 3: # conditional CNF
c = states[2]
tc = torch.cat([t, c.view(y.size(0), -1)], dim=1)
dy = self.diffeq(tc, y)
divergence = self.divergence_fn(dy, y, e=self._e).unsqueeze(-1)
return dy, -divergence, torch.zeros_like(c).requires_grad_(True)
elif len(states) == 2: # unconditional CNF
dy = self.diffeq(t, y)
divergence = self.divergence_fn(dy, y, e=self._e).view(-1, 1)
return dy, -divergence
else:
assert 0, "`len(states)` should be 2 or 3"
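`divergence_approx` above is Hutchinson's stochastic trace estimator. With a noise vector $\epsilon$ drawn once per `odeint` call (see `before_odeint` and the lazy sampling of `self._e`),

$$\operatorname{Tr}\left(\frac{\partial f}{\partial y}\right) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}\left[\epsilon^{\top} \frac{\partial f}{\partial y}\,\epsilon\right] \approx \epsilon^{\top} \frac{\partial f}{\partial y}\,\epsilon,$$

where the vector-Jacobian product $\epsilon^{\top} \partial f / \partial y$ is a single `torch.autograd.grad` call, so the divergence costs one backward pass instead of one per input dimension.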

38
scripts/shapenet_airplane_ae.sh Executable file
View file

@ -0,0 +1,38 @@
#! /bin/bash
cate="airplane"
dims="512-512-512"
latent_dims="256-256"
num_blocks=1
latent_num_blocks=1
zdim=128
batch_size=48
lr=2e-3
epochs=4000
ds=shapenet15k
log_name="ae/${ds}-cate${cate}"
data_dir="data/ShapeNetCore.v2.PC15k"
python train.py \
--log_name ${log_name} \
--lr ${lr} \
--dataset_type ${ds} \
--data_dir ${data_dir} \
--cates ${cate} \
--dims ${dims} \
--latent_dims ${latent_dims} \
--num_blocks ${num_blocks} \
--latent_num_blocks ${latent_num_blocks} \
--batch_size ${batch_size} \
--zdim ${zdim} \
--epochs ${epochs} \
--save_freq 50 \
--viz_freq 1 \
--log_freq 1 \
--val_freq 10 \
--use_deterministic_encoder \
--prior_weight 0 \
--entropy_weight 0
echo "Done"
exit 0

View file

@ -0,0 +1,39 @@
#! /bin/bash
cate="airplane"
dims="512-512-512"
latent_dims="256-256"
num_blocks=1
latent_num_blocks=1
zdim=128
batch_size=256
lr=2e-3
epochs=4000
ds=shapenet15k
log_name="ae/${ds}-cate${cate}"
data_dir="data/ShapeNetCore.v2.PC15k"
python train.py \
--log_name ${log_name} \
--lr ${lr} \
--dataset_type ${ds} \
--data_dir ${data_dir} \
--cates ${cate} \
--dims ${dims} \
--latent_dims ${latent_dims} \
--num_blocks ${num_blocks} \
--latent_num_blocks ${latent_num_blocks} \
--batch_size ${batch_size} \
--zdim ${zdim} \
--epochs ${epochs} \
--save_freq 50 \
--viz_freq 1 \
--log_freq 1 \
--val_freq 10 \
--distributed \
--use_deterministic_encoder \
--prior_weight 0 \
--entropy_weight 0
echo "Done"
exit 0

View file

@ -0,0 +1,8 @@
#! /bin/bash
python test.py \
--cates airplane \
--resume_checkpoint pretrained_models/ae/airplane/checkpoint.pt \
--dims 512-512-512 \
--use_deterministic_encoder \
--evaluate_recon

View file

@ -0,0 +1,11 @@
#! /bin/bash
python demo.py \
--cates airplane \
--resume_checkpoint pretrained_models/gen/airplane/checkpoint.pt \
--dims 512-512-512 \
--latent_dims 256-256 \
--use_latent_flow \
--num_sample_shapes 20 \
--num_sample_points 2048

View file

@ -0,0 +1,36 @@
#! /bin/bash
cate="airplane"
dims="512-512-512"
latent_dims="256-256"
num_blocks=1
latent_num_blocks=1
zdim=128
batch_size=16
lr=2e-3
epochs=4000
ds=shapenet15k
log_name="gen/${ds}-cate${cate}"
data_dir="data/ShapeNetCore.v2.PC15k"
python train.py \
--log_name ${log_name} \
--lr ${lr} \
--dataset_type ${ds} \
--data_dir ${data_dir} \
--cates ${cate} \
--dims ${dims} \
--latent_dims ${latent_dims} \
--num_blocks ${num_blocks} \
--latent_num_blocks ${latent_num_blocks} \
--batch_size ${batch_size} \
--zdim ${zdim} \
--epochs ${epochs} \
--save_freq 50 \
--viz_freq 1 \
--log_freq 1 \
--val_freq 10 \
--use_latent_flow
echo "Done"
exit 0

View file

@ -0,0 +1,38 @@
#! /bin/bash
cate="airplane"
dims="512-512-512"
latent_dims="256-256"
num_blocks=1
latent_num_blocks=1
zdim=128
batch_size=256
lr=2e-3
epochs=4000
ds=shapenet15k
log_name="gen/${ds}-cate${cate}-seqback"
data_dir="data/ShapeNetCore.v2.PC15k"
python train.py \
--log_name ${log_name} \
--lr ${lr} \
--train_T False \
--dataset_type ${ds} \
--data_dir ${data_dir} \
--cates ${cate} \
--dims ${dims} \
--latent_dims ${latent_dims} \
--num_blocks ${num_blocks} \
--latent_num_blocks ${latent_num_blocks} \
--batch_size ${batch_size} \
--zdim ${zdim} \
--epochs ${epochs} \
--save_freq 50 \
--viz_freq 1 \
--log_freq 1 \
--val_freq 10 \
--distributed \
--use_latent_flow
echo "Done"
exit 0

View file

@ -0,0 +1,10 @@
#! /bin/bash
python test.py \
--cates airplane \
--resume_checkpoint pretrained_models/gen/airplane/checkpoint.pt \
--dims 512-512-512 \
--latent_dims 256-256 \
--use_latent_flow

11
scripts/shapenet_all_ae_test.sh Executable file
View file

@ -0,0 +1,11 @@
#! /bin/bash
python test.py \
--cates all \
--resume_checkpoint pretrained_models/ae/all/checkpoint.pt \
--dims 512-512-512 \
--use_deterministic_encoder \
--evaluate_recon \
--resume_dataset_mean pretrained_models/ae/all/train_set_mean.npy \
--resume_dataset_std pretrained_models/ae/all/train_set_std.npy

167
test.py Normal file
View file

@ -0,0 +1,167 @@
from datasets import get_datasets, synsetid_to_cate
from args import get_args
from pprint import pprint
from metrics.evaluation_metrics import EMD_CD
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
from metrics.evaluation_metrics import compute_all_metrics
from collections import defaultdict
from models.networks import PointFlow
import os
import torch
import numpy as np
import torch.nn as nn
def get_test_loader(args):
_, te_dataset = get_datasets(args)
if args.resume_dataset_mean is not None and args.resume_dataset_std is not None:
mean = np.load(args.resume_dataset_mean)
std = np.load(args.resume_dataset_std)
te_dataset.renormalize(mean, std)
loader = torch.utils.data.DataLoader(
dataset=te_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=0, pin_memory=True, drop_last=False)
return loader
def evaluate_recon(model, args):
# TODO: make this memory efficient
if 'all' in args.cates:
cates = list(synsetid_to_cate.values())
else:
cates = args.cates
all_results = {}
cate_to_len = {}
save_dir = os.path.dirname(args.resume_checkpoint)
for cate in cates:
args.cates = [cate]
loader = get_test_loader(args)
all_sample = []
all_ref = []
for data in loader:
idx_b, tr_pc, te_pc = data['idx'], data['train_points'], data['test_points']
te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)
tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu)
B, N = te_pc.size(0), te_pc.size(1)
out_pc = model.reconstruct(tr_pc, num_points=N)
m, s = data['mean'].float(), data['std'].float()
m = m.cuda() if args.gpu is None else m.cuda(args.gpu)
s = s.cuda() if args.gpu is None else s.cuda(args.gpu)
out_pc = out_pc * s + m
te_pc = te_pc * s + m
all_sample.append(out_pc)
all_ref.append(te_pc)
sample_pcs = torch.cat(all_sample, dim=0)
ref_pcs = torch.cat(all_ref, dim=0)
cate_to_len[cate] = int(sample_pcs.size(0))
print("Cate=%s Total Sample size:%s Ref size: %s"
% (cate, sample_pcs.size(), ref_pcs.size()))
# Save it
np.save(os.path.join(save_dir, "%s_out_smp.npy" % cate),
sample_pcs.cpu().detach().numpy())
np.save(os.path.join(save_dir, "%s_out_ref.npy" % cate),
ref_pcs.cpu().detach().numpy())
results = EMD_CD(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
results = {
k: (v.cpu().detach().item() if not isinstance(v, float) else v)
for k, v in results.items()}
pprint(results)
all_results[cate] = results
# Save final results
print("="*80)
print("All category results:")
print("="*80)
pprint(all_results)
save_path = os.path.join(save_dir, "percate_results.npy")
np.save(save_path, all_results)
# Compute weighted performance
ttl_r, ttl_cnt = defaultdict(lambda: 0.), defaultdict(lambda: 0.)
for catename, l in cate_to_len.items():
for k, v in all_results[catename].items():
ttl_r[k] += v * float(l)
ttl_cnt[k] += float(l)
ttl_res = {k: (float(ttl_r[k]) / float(ttl_cnt[k])) for k in ttl_r.keys()}
print("="*80)
print("Averaged results:")
pprint(ttl_res)
print("="*80)
save_path = os.path.join(save_dir, "results.npy")
np.save(save_path, all_results)
def evaluate_gen(model, args):
loader = get_test_loader(args)
all_sample = []
all_ref = []
for data in loader:
idx_b, te_pc = data['idx'], data['test_points']
te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)
B, N = te_pc.size(0), te_pc.size(1)
_, out_pc = model.sample(B, N)
# denormalize
m, s = data['mean'].float(), data['std'].float()
m = m.cuda() if args.gpu is None else m.cuda(args.gpu)
s = s.cuda() if args.gpu is None else s.cuda(args.gpu)
out_pc = out_pc * s + m
te_pc = te_pc * s + m
all_sample.append(out_pc)
all_ref.append(te_pc)
sample_pcs = torch.cat(all_sample, dim=0)
ref_pcs = torch.cat(all_ref, dim=0)
print("Generation sample size:%s reference size: %s"
% (sample_pcs.size(), ref_pcs.size()))
# Save the generative output
save_dir = os.path.dirname(args.resume_checkpoint)
np.save(os.path.join(save_dir, "model_out_smp.npy"), sample_pcs.cpu().detach().numpy())
np.save(os.path.join(save_dir, "model_out_ref.npy"), ref_pcs.cpu().detach().numpy())
# Compute metrics
results = compute_all_metrics(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
results = {k: (v.cpu().detach().item()
if not isinstance(v, float) else v) for k, v in results.items()}
pprint(results)
sample_pcl_npy = sample_pcs.cpu().detach().numpy()
ref_pcl_npy = ref_pcs.cpu().detach().numpy()
jsd = JSD(sample_pcl_npy, ref_pcl_npy)
print("JSD:%s" % jsd)
def main(args):
model = PointFlow(args)
def _transform_(m):
return nn.DataParallel(m)
model = model.cuda()
model.multi_gpu_wrapper(_transform_)
print("Resume Path:%s" % args.resume_checkpoint)
checkpoint = torch.load(args.resume_checkpoint)
model.load_state_dict(checkpoint)
model.eval()
with torch.no_grad():
if args.evaluate_recon:
# Evaluate reconstruction
evaluate_recon(model, args)
else:
# Evaluate generation
evaluate_gen(model, args)
if __name__ == '__main__':
args = get_args()
main(args)

272
train.py Normal file
View file

@ -0,0 +1,272 @@
import sys
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import warnings
import numpy as np
import random
import faulthandler
import torch.multiprocessing as mp
import time
import scipy.misc
from models.networks import PointFlow
from torch import optim
from args import get_args
from torch.backends import cudnn
from utils import AverageValueMeter, set_random_seed, apply_random_rotation, save, resume, visualize_point_clouds
from tensorboardX import SummaryWriter
from datasets import get_datasets, init_np_seed
faulthandler.enable()
def main_worker(gpu, save_dir, ngpus_per_node, args):
# basic setup
cudnn.benchmark = True
args.gpu = gpu
if args.gpu is not None:
print("Use GPU: {} for training".format(args.gpu))
if args.distributed:
if args.dist_url == "env://" and args.rank == -1:
args.rank = int(os.environ["RANK"])
if args.distributed:
args.rank = args.rank * ngpus_per_node + gpu
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
if args.log_name is not None:
log_dir = "runs/%s" % args.log_name
else:
log_dir = "runs/time-%d" % time.time()
if not args.distributed or (args.rank % ngpus_per_node == 0):
writer = SummaryWriter(logdir=log_dir)
else:
writer = None
if not args.use_latent_flow: # auto-encoder only
args.prior_weight = 0
args.entropy_weight = 0
# multi-GPU setup
model = PointFlow(args)
if args.distributed: # Multiple processes, single GPU per process
if args.gpu is not None:
def _transform_(m):
return nn.parallel.DistributedDataParallel(
m, device_ids=[args.gpu], output_device=args.gpu, check_reduction=True)
torch.cuda.set_device(args.gpu)
model.cuda(args.gpu)
model.multi_gpu_wrapper(_transform_)
args.batch_size = int(args.batch_size / ngpus_per_node)
args.workers = 0
else:
assert 0, "DistributedDataParallel constructor should always set the single device scope"
elif args.gpu is not None: # Single process, single GPU per process
torch.cuda.set_device(args.gpu)
model = model.cuda(args.gpu)
else: # Single process, multiple GPUs per process
def _transform_(m):
return nn.DataParallel(m)
model = model.cuda()
model.multi_gpu_wrapper(_transform_)
# resume checkpoints
start_epoch = 0
optimizer = model.make_optimizer(args)
if args.resume_checkpoint is None and os.path.exists(os.path.join(save_dir, 'checkpoint-latest.pt')):
args.resume_checkpoint = os.path.join(save_dir, 'checkpoint-latest.pt') # use the latest checkpoint
if args.resume_checkpoint is not None:
if args.resume_optimizer:
model, optimizer, start_epoch = resume(
args.resume_checkpoint, model, optimizer, strict=(not args.resume_non_strict))
else:
model, _, start_epoch = resume(
args.resume_checkpoint, model, optimizer=None, strict=(not args.resume_non_strict))
print('Resumed from: ' + args.resume_checkpoint)
# initialize datasets and loaders
tr_dataset, te_dataset = get_datasets(args)
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(tr_dataset)
else:
train_sampler = None
train_loader = torch.utils.data.DataLoader(
dataset=tr_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=0, pin_memory=True, sampler=train_sampler, drop_last=True,
worker_init_fn=init_np_seed)
test_loader = torch.utils.data.DataLoader(
dataset=te_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=0, pin_memory=True, drop_last=False,
worker_init_fn=init_np_seed)
# save dataset statistics
if not args.distributed or (args.rank % ngpus_per_node == 0):
np.save(os.path.join(save_dir, "train_set_mean.npy"), tr_dataset.all_points_mean)
np.save(os.path.join(save_dir, "train_set_std.npy"), tr_dataset.all_points_std)
np.save(os.path.join(save_dir, "train_set_idx.npy"), np.array(tr_dataset.shuffle_idx))
np.save(os.path.join(save_dir, "val_set_mean.npy"), te_dataset.all_points_mean)
np.save(os.path.join(save_dir, "val_set_std.npy"), te_dataset.all_points_std)
np.save(os.path.join(save_dir, "val_set_idx.npy"), np.array(te_dataset.shuffle_idx))
# load classification dataset if needed
if args.eval_classification:
from datasets import get_clf_datasets
def _make_data_loader_(dataset):
return torch.utils.data.DataLoader(
dataset=dataset, batch_size=args.batch_size, shuffle=False,
num_workers=0, pin_memory=True, drop_last=False,
worker_init_fn=init_np_seed
)
clf_datasets = get_clf_datasets(args)
clf_loaders = {
k: [_make_data_loader_(ds) for ds in ds_lst] for k, ds_lst in clf_datasets.items()
}
else:
clf_loaders = None
# initialize the learning rate scheduler
if args.scheduler == 'exponential':
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
elif args.scheduler == 'step':
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs // 2, gamma=0.1)
elif args.scheduler == 'linear':
def lambda_rule(ep):
lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(0.5 * args.epochs)
return lr_l
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
else:
assert 0, "args.schedulers should be either 'exponential' or 'linear'"
# main training loop
start_time = time.time()
entropy_avg_meter = AverageValueMeter()
latent_nats_avg_meter = AverageValueMeter()
point_nats_avg_meter = AverageValueMeter()
if args.distributed:
print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size()))
print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
for epoch in range(start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
# adjust the learning rate
if (epoch + 1) % args.exp_decay_freq == 0:
scheduler.step(epoch=epoch)
if writer is not None:
writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], epoch)
# train for one epoch
for bidx, data in enumerate(train_loader):
idx_batch, tr_batch, te_batch = data['idx'], data['train_points'], data['test_points']
step = bidx + len(train_loader) * epoch
model.train()
if args.random_rotate:
tr_batch, _, _ = apply_random_rotation(
tr_batch, rot_axis=train_loader.dataset.gravity_axis)
inputs = tr_batch.cuda(args.gpu, non_blocking=True)
out = model(inputs, optimizer, step, writer)
entropy, prior_nats, recon_nats = out['entropy'], out['prior_nats'], out['recon_nats']
entropy_avg_meter.update(entropy)
point_nats_avg_meter.update(recon_nats)
latent_nats_avg_meter.update(prior_nats)
if step % args.log_freq == 0:
duration = time.time() - start_time
start_time = time.time()
print("[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f"
% (args.rank, epoch, bidx, len(train_loader), duration, entropy_avg_meter.avg,
latent_nats_avg_meter.avg, point_nats_avg_meter.avg))
# evaluate on the validation set
if not args.no_validation and (epoch + 1) % args.val_freq == 0:
from utils import validate
validate(test_loader, model, epoch, writer, save_dir, args, clf_loaders=clf_loaders)
# save visualizations
if (epoch + 1) % args.viz_freq == 0:
# reconstructions
model.eval()
samples = model.reconstruct(inputs)
results = []
for idx in range(min(10, inputs.size(0))):
res = visualize_point_clouds(samples[idx], inputs[idx], idx,
pert_order=train_loader.dataset.display_axis_order)
results.append(res)
res = np.concatenate(results, axis=1)
scipy.misc.imsave(os.path.join(save_dir, 'images', 'tr_vis_conditioned_epoch%d-gpu%s.png' % (epoch, args.gpu)),
res.transpose((1, 2, 0)))
if writer is not None:
writer.add_image('tr_vis/conditioned', torch.as_tensor(res), epoch)
# samples
if args.use_latent_flow:
num_samples = min(10, inputs.size(0))
num_points = inputs.size(1)
_, samples = model.sample(num_samples, num_points)
results = []
for idx in range(num_samples):
res = visualize_point_clouds(samples[idx], inputs[idx], idx,
pert_order=train_loader.dataset.display_axis_order)
results.append(res)
res = np.concatenate(results, axis=1)
scipy.misc.imsave(os.path.join(save_dir, 'images', 'tr_vis_sampled_epoch%d-gpu%s.png' % (epoch, args.gpu)),
res.transpose((1, 2, 0)))
if writer is not None:
writer.add_image('tr_vis/sampled', torch.as_tensor(res), epoch)
# save checkpoints
if not args.distributed or (args.rank % ngpus_per_node == 0):
if (epoch + 1) % args.save_freq == 0:
save(model, optimizer, epoch + 1,
os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
save(model, optimizer, epoch + 1,
os.path.join(save_dir, 'checkpoint-latest.pt'))
def main():
# command line args
args = get_args()
save_dir = os.path.join("checkpoints", args.log_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
os.makedirs(os.path.join(save_dir, 'images'))
with open(os.path.join(save_dir, 'command.sh'), 'w') as f:
f.write('python -X faulthandler ' + ' '.join(sys.argv))
f.write('\n')
if args.seed is None:
args.seed = random.randint(0, 1000000)
set_random_seed(args.seed)
if args.gpu is not None:
warnings.warn('You have chosen a specific GPU. This will completely '
'disable data parallelism.')
if args.dist_url == "env://" and args.world_size == -1:
args.world_size = int(os.environ["WORLD_SIZE"])
if args.sync_bn:
assert args.distributed
print("Arguments:")
print(args)
ngpus_per_node = torch.cuda.device_count()
if args.distributed:
args.world_size = ngpus_per_node * args.world_size
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(save_dir, ngpus_per_node, args))
else:
main_worker(args.gpu, save_dir, ngpus_per_node, args)
if __name__ == '__main__':
main()

378
utils.py Normal file
View file

@ -0,0 +1,378 @@
from pprint import pprint
from sklearn.svm import LinearSVC
from math import log, pi
import os
import torch
import torch.distributed as dist
import random
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
class AverageValueMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0.0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0.0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def gaussian_log_likelihood(x, mean, logvar, clip=True):
if clip:
logvar = torch.clamp(logvar, min=-4, max=3)
a = log(2 * pi)
b = logvar
c = (x - mean) ** 2 / torch.exp(logvar)
return -0.5 * torch.sum(a + b + c)
def bernoulli_log_likelihood(x, p, clip=True, eps=1e-6):
if clip:
p = torch.clamp(p, min=eps, max=1 - eps)
return torch.sum((x * torch.log(p)) + ((1 - x) * torch.log(1 - p)))
def kl_diagnormal_stdnormal(mean, logvar):
a = mean ** 2
b = torch.exp(logvar)
c = -1
d = -logvar
return 0.5 * torch.sum(a + b + c + d)
def kl_diagnormal_diagnormal(q_mean, q_logvar, p_mean, p_logvar):
# Ensure correct shapes since no numpy broadcasting yet
p_mean = p_mean.expand_as(q_mean)
p_logvar = p_logvar.expand_as(q_logvar)
a = p_logvar
b = - 1
c = - q_logvar
d = ((q_mean - p_mean) ** 2 + torch.exp(q_logvar)) / torch.exp(p_logvar)
return 0.5 * torch.sum(a + b + c + d)
# Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
def truncated_normal(tensor, mean=0, std=1, trunc_std=2):
size = tensor.shape
tmp = tensor.new_empty(size + (4,)).normal_()
valid = (tmp < trunc_std) & (tmp > -trunc_std)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)
return tensor
def reduce_tensor(tensor, world_size=None):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
if world_size is None:
world_size = dist.get_world_size()
rt /= world_size
return rt
def standard_normal_logprob(z):
dim = z.size(-1)
log_z = -0.5 * dim * log(2 * pi)
return log_z - z.pow(2) / 2
def set_random_seed(seed):
"""set random seed"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Visualization
def visualize_point_clouds(pts, gtr, idx, pert_order=[0, 1, 2]):
pts = pts.cpu().detach().numpy()[:, pert_order]
gtr = gtr.cpu().detach().numpy()[:, pert_order]
fig = plt.figure(figsize=(6, 3))
ax1 = fig.add_subplot(121, projection='3d')
ax1.set_title("Sample:%s" % idx)
ax1.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=5)
ax2 = fig.add_subplot(122, projection='3d')
ax2.set_title("Ground Truth:%s" % idx)
ax2.scatter(gtr[:, 0], gtr[:, 1], gtr[:, 2], s=5)
fig.canvas.draw()
# grab the pixel buffer and dump it into a numpy array
res = np.array(fig.canvas.renderer._renderer)
res = np.transpose(res, (2, 0, 1))
plt.close()
return res
# Augmentation
def apply_random_rotation(pc, rot_axis=1):
B = pc.shape[0]
theta = np.random.rand(B) * 2 * np.pi
zeros = np.zeros(B)
ones = np.ones(B)
cos = np.cos(theta)
sin = np.sin(theta)
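# note: the rot_axis == 0 branch leaves coordinate 2 fixed (an xy-plane rotation)
# and the rot_axis == 2 branch leaves coordinate 0 fixed; only rot_axis == 1
# rotates about the y axis, the gravity_axis the training loop passes in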
if rot_axis == 0:
rot = np.stack([
cos, -sin, zeros,
sin, cos, zeros,
zeros, zeros, ones
]).T.reshape(B, 3, 3)
elif rot_axis == 1:
rot = np.stack([
cos, zeros, -sin,
zeros, ones, zeros,
sin, zeros, cos
]).T.reshape(B, 3, 3)
elif rot_axis == 2:
rot = np.stack([
ones, zeros, zeros,
zeros, cos, -sin,
zeros, sin, cos
]).T.reshape(B, 3, 3)
else:
raise Exception("Invalid rotation axis")
rot = torch.from_numpy(rot).to(pc)
# (B, N, 3) mul (B, 3, 3) -> (B, N, 3)
pc_rotated = torch.bmm(pc, rot)
return pc_rotated, rot, theta
def validate_classification(loaders, model, args):
    """Fit a linear SVM on encoded latents of the train split and report test accuracy."""
    train_loader, test_loader = loaders

    tr_latent = []
    tr_label = []
    for data in train_loader:
        tr_pc = data['train_points']
        tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu)
        latent = model.encode(tr_pc)
        label = data['cate_idx']
        tr_latent.append(latent.cpu().detach().numpy())
        tr_label.append(label.cpu().detach().numpy())
    tr_label = np.concatenate(tr_label)
    tr_latent = np.concatenate(tr_latent)

    te_latent = []
    te_label = []
    for data in test_loader:
        te_pc = data['train_points']
        te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)
        latent = model.encode(te_pc)
        label = data['cate_idx']
        te_latent.append(latent.cpu().detach().numpy())
        te_label.append(label.cpu().detach().numpy())
    te_label = np.concatenate(te_label)
    te_latent = np.concatenate(te_latent)

    clf = LinearSVC(random_state=0)
    clf.fit(tr_latent, tr_label)
    test_pred = clf.predict(te_latent)
    test_gt = te_label.flatten()
    acc = np.mean((test_pred == test_gt).astype(float)) * 100.
    res = {'acc': acc}
    print("Acc:%s" % acc)
    return res

def validate_conditioned(loader, model, args, max_samples=None, save_dir=None):
    """Reconstruct point clouds from the loader and report MMD under CD and EMD."""
    from metrics.evaluation_metrics import EMD_CD
    all_idx = []
    all_sample = []
    all_ref = []
    ttl_samples = 0
    for data in loader:
        idx_b, tr_pc, te_pc = data['idx'], data['train_points'], data['test_points']
        tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu)
        te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)
        if tr_pc.size(1) > te_pc.size(1):
            tr_pc = tr_pc[:, :te_pc.size(1), :]
        out_pc = model.reconstruct(tr_pc, num_points=te_pc.size(1))
        # Denormalize back to the original scale before computing metrics.
        m, s = data['mean'].float(), data['std'].float()
        m = m.cuda() if args.gpu is None else m.cuda(args.gpu)
        s = s.cuda() if args.gpu is None else s.cuda(args.gpu)
        out_pc = out_pc * s + m
        te_pc = te_pc * s + m
        all_sample.append(out_pc)
        all_ref.append(te_pc)
        all_idx.append(idx_b)
        ttl_samples += int(te_pc.size(0))
        if max_samples is not None and ttl_samples >= max_samples:
            break

    # Compute MMD under both CD and EMD.
    sample_pcs = torch.cat(all_sample, dim=0)
    ref_pcs = torch.cat(all_ref, dim=0)
    print("[rank %s] Recon Sample size:%s Ref size: %s" % (args.rank, sample_pcs.size(), ref_pcs.size()))
    if save_dir is not None and args.save_val_results:
        smp_pcs_save_name = os.path.join(save_dir, "smp_recon_pcls_gpu%s.npy" % args.gpu)
        ref_pcs_save_name = os.path.join(save_dir, "ref_recon_pcls_gpu%s.npy" % args.gpu)
        np.save(smp_pcs_save_name, sample_pcs.cpu().detach().numpy())
        np.save(ref_pcs_save_name, ref_pcs.cpu().detach().numpy())
        print("Saving file:%s %s" % (smp_pcs_save_name, ref_pcs_save_name))
    res = EMD_CD(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
    mmd_cd = res['MMD-CD'] if 'MMD-CD' in res else None
    mmd_emd = res['MMD-EMD'] if 'MMD-EMD' in res else None
    print("MMD-CD :%s" % mmd_cd)
    print("MMD-EMD :%s" % mmd_emd)
    return res

def validate_sample(loader, model, args, max_samples=None, save_dir=None):
    """Sample point clouds from the model and evaluate them against the reference set."""
    from metrics.evaluation_metrics import compute_all_metrics, jsd_between_point_cloud_sets as JSD
    all_sample = []
    all_ref = []
    ttl_samples = 0
    for data in loader:
        idx_b, te_pc = data['idx'], data['test_points']
        te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)
        _, out_pc = model.sample(te_pc.size(0), te_pc.size(1), gpu=args.gpu)
        # Denormalize back to the original scale before computing metrics.
        m, s = data['mean'].float(), data['std'].float()
        m = m.cuda() if args.gpu is None else m.cuda(args.gpu)
        s = s.cuda() if args.gpu is None else s.cuda(args.gpu)
        out_pc = out_pc * s + m
        te_pc = te_pc * s + m
        all_sample.append(out_pc)
        all_ref.append(te_pc)
        ttl_samples += int(te_pc.size(0))
        if max_samples is not None and ttl_samples >= max_samples:
            break

    sample_pcs = torch.cat(all_sample, dim=0)
    ref_pcs = torch.cat(all_ref, dim=0)
    print("[rank %s] Generation Sample size:%s Ref size: %s"
          % (args.rank, sample_pcs.size(), ref_pcs.size()))
    if save_dir is not None and args.save_val_results:
        smp_pcs_save_name = os.path.join(save_dir, "smp_syn_pcls_gpu%s.npy" % args.gpu)
        ref_pcs_save_name = os.path.join(save_dir, "ref_syn_pcls_gpu%s.npy" % args.gpu)
        np.save(smp_pcs_save_name, sample_pcs.cpu().detach().numpy())
        np.save(ref_pcs_save_name, ref_pcs.cpu().detach().numpy())
        print("Saving file:%s %s" % (smp_pcs_save_name, ref_pcs_save_name))
    res = compute_all_metrics(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
    pprint(res)
    sample_pcs = sample_pcs.cpu().detach().numpy()
    ref_pcs = ref_pcs.cpu().detach().numpy()
    jsd = JSD(sample_pcs, ref_pcs)
    jsd = torch.tensor(jsd).cuda() if args.gpu is None else torch.tensor(jsd).cuda(args.gpu)
    res.update({"JSD": jsd})
    print("JSD :%s" % jsd)
    return res

def save(model, optimizer, epoch, path):
    """Serialize model and optimizer state together with the epoch counter."""
    d = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(d, path)


def resume(path, model, optimizer=None, strict=True):
    """Load a checkpoint written by `save`; optimizer state is restored when given."""
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['model'], strict=strict)
    start_epoch = ckpt['epoch']
    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, start_epoch

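# Usage sketch of the checkpoint round trip; the model, optimizer, and path
# below are placeholders, not fixtures of this repo:
#   save(model, optimizer, epoch=10, path='checkpoints/ckpt-10.pt')
#   ...
#   model, optimizer, start_epoch = resume(
#       'checkpoints/ckpt-10.pt', model, optimizer)
#   # training then continues from `start_epoch`
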
def validate(test_loader, model, epoch, writer, save_dir, args, clf_loaders=None):
    """Run all validation routines for one epoch and log results to tensorboard."""
    model.eval()

    # Make an epoch-wise save directory.
    if writer is not None and args.save_val_results:
        save_dir = os.path.join(save_dir, 'epoch-%d' % epoch)
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
    else:
        save_dir = None

    # Classification (linear SVM on latents).
    if args.eval_classification and clf_loaders is not None:
        for clf_expr, loaders in clf_loaders.items():
            with torch.no_grad():
                clf_val_res = validate_classification(loaders, model, args)
            for k, v in clf_val_res.items():
                if writer is not None and v is not None:
                    writer.add_scalar('val_%s/%s' % (clf_expr, k), v, epoch)

    # Samples (generation metrics).
    if args.use_latent_flow:
        with torch.no_grad():
            val_sample_res = validate_sample(
                test_loader, model, args, max_samples=args.max_validate_shapes,
                save_dir=save_dir)
        for k, v in val_sample_res.items():
            if not isinstance(v, float):
                v = v.cpu().detach().item()
            if writer is not None and v is not None:
                writer.add_scalar('val_sample/%s' % k, v, epoch)

    # Reconstructions.
    with torch.no_grad():
        val_res = validate_conditioned(
            test_loader, model, args, max_samples=args.max_validate_shapes,
            save_dir=save_dir)
    for k, v in val_res.items():
        if not isinstance(v, float):
            v = v.cpu().detach().item()
        if writer is not None and v is not None:
            writer.add_scalar('val_conditioned/%s' % k, v, epoch)