diff --git a/.gitignore b/.gitignore index ab2e498..56dc157 100644 --- a/.gitignore +++ b/.gitignore @@ -1,108 +1,23 @@ -# Byte-compiled / optimized / DLL files __pycache__/ -*.py[cod] -*$py.class +*__pycache__ .idea/ -.DS_Store -*/.DS_Store -*/*/.DS_Store +*.pyc +data -# C extensions -*.so +scratch.py -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST +*.m -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook .ipynb_checkpoints -# pyenv -.python-version +checkpoints +runs -# celery beat schedule file -celerybeat-schedule +metrics/structural_losses/*.so +metrics/structural_losses/*.cu.o +metrics/structural_losses/makefile +PyMesh +checkpoint -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ +torchdiffeq/ +demo/ diff --git a/README.md b/README.md index 080eb06..d0d7bdc 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This repository contains a PyTorch implementation of the paper: -[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](https://arxiv.org/abs/1906.12320). +[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](https://arxiv.org/abs/1906.12320). [Guandao Yang*](http://www.guandaoyang.com), [Xun Huang*](http://www.cs.cornell.edu/~xhuang/), @@ -11,9 +11,6 @@ This repository contains a PyTorch implementation of the paper: [Serge Belongie](http://blogs.cornell.edu/techfaculty/serge-belongie/), [Bharath Hariharan](http://home.bharathh.info/) -**The code will be available soon!** - -[\[Project page\]](https://www.guandaoyang.com/PointFlow/) [\[Video\]](https://www.youtube.com/watch?v=jqBiv77xC0M) ## Introduction @@ -23,3 +20,90 @@ As 3D point clouds become the representation of choice for multiple vision and g

+ +## Dependencies +* Python 3.6 +* CUDA 10.0 +* G++ or GCC 5 +* [PyTorch](http://pytorch.org/). The code is tested with version 1.0.1. +* [torchdiffeq](https://github.com/rtqichen/torchdiffeq). +* (Optional) [Tensorboard](https://www.tensorflow.org/) for visualization of the training process. + +The following is the suggested way to install these dependencies: +```bash +# Create a new conda environment +conda create -n PointFlow python=3.6 +conda activate PointFlow + +# Install pytorch (please refer to the command on the official website) +conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 -c pytorch -y + +# Install other dependencies such as torchdiffeq, structural losses, etc. +./install.sh +``` + +## Dataset + +The point clouds are uniformly sampled from meshes in the ShapeNetCore dataset (version 2), using the official split. +Please use this [link](https://drive.google.com/drive/folders/1G0rf-6HSHoTll6aH7voh-dXj6hCRhSAQ?usp=sharing) to download the ShapeNet point clouds. +The point clouds should be placed in the `data` directory. +```bash +mv ShapeNetCore.v2.PC15k.zip data/ +cd data +unzip ShapeNetCore.v2.PC15k.zip +``` + +Please contact us if you need point clouds for the ModelNet dataset. + +## Training + +Example training scripts can be found in the `scripts/` folder. +```bash +# Train auto-encoder (no latent CNF) +./scripts/shapenet_airplane_ae.sh # Train with single GPU, about 7-8 GB GPU memory +./scripts/shapenet_airplane_ae_dist.sh # Train with multiple GPUs + +# Train generative model +./scripts/shapenet_airplane_gen.sh # Train with single GPU, about 7-8 GB GPU memory +./scripts/shapenet_airplane_gen_dist.sh # Train with multiple GPUs +``` + +## Pre-trained models and test + +Pretrained models can be downloaded from this [link](https://drive.google.com/file/d/1dcxjuuKiAXZxhiyWD_o_7Owx8Y3FbRHG/view?usp=sharing). +The following is the suggested way to evaluate the performance of the pre-trained models. +```bash +unzip pretrained_models.zip; # This will create a folder named pretrained_models + +# Evaluate the reconstruction performance of an AE trained on the airplane category +CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_ae_test.sh; + +# Evaluate the reconstruction performance of an AE trained with the whole ShapeNet +CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_all_ae_test.sh; + +# Evaluate the generative performance of PointFlow trained on the airplane category. +CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_gen_test.sh +``` + +## Demo + +The demo relies on [Open3D](http://www.open3d.org/). The following is the suggested way to install it: +```bash +conda install -c open3d-admin open3d +``` +The demo will sample shapes from a pre-trained model, save those shapes under the `demo` folder, and visualize those point clouds.
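If you want to reload and inspect the saved samples later without re-running the model, a minimal sketch along the following lines should work; it assumes the demo has already written `demo/model_out_smp.npy` (the path used by `demo.py`) and relies only on the Open3D calls already used there:
```python
import numpy as np
import open3d as o3d

# Load the samples saved by demo.py: an array of shape (num_shapes, num_points, 3).
samples = np.load("demo/model_out_smp.npy")

pcl = o3d.geometry.PointCloud()
for i, pts in enumerate(samples):
    print("Visualizing: %03d/%03d" % (i, samples.shape[0]))
    # Reuse the same point cloud object and swap in the points of the current shape.
    pcl.points = o3d.utility.Vector3dVector(pts.reshape(-1, 3))
    o3d.visualization.draw_geometries([pcl])
```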
+Once this dependency is in place, you can use the following script to run the demo with the pre-trained airplane model: +```bash +CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_demo.sh +``` + +## Cite +Please cite our work if you find it useful: +```latex +@article{pointflow, + title={PointFlow: 3D Point Cloud Generation with Continuous Normalizing Flows}, + author={Yang, Guandao and Huang, Xun and Hao, Zekun and Liu, Ming-Yu and Belongie, Serge and Hariharan, Bharath}, + journal={arXiv}, + year={2019} +} +``` diff --git a/args.py b/args.py new file mode 100644 index 0000000..8209dc6 --- /dev/null +++ b/args.py @@ -0,0 +1,163 @@ +import argparse + +NONLINEARITIES = ["tanh", "relu", "softplus", "elu", "swish", "square", "identity"] +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +LAYERS = ["ignore", "concat", "concat_v2", "squash", "concatsquash", "scale", "concatscale"] + + +def add_args(parser): + # model architecture options + parser.add_argument('--input_dim', type=int, default=3, + help='Number of input dimensions (3 for 3D point clouds)') + parser.add_argument('--dims', type=str, default='256') + parser.add_argument('--latent_dims', type=str, default='256') + parser.add_argument("--num_blocks", type=int, default=1, + help='Number of stacked CNFs.') + parser.add_argument("--latent_num_blocks", type=int, default=1, + help='Number of stacked CNFs in the latent flow.') + parser.add_argument("--layer_type", type=str, default="concatsquash", choices=LAYERS) + parser.add_argument('--time_length', type=float, default=0.5) + parser.add_argument('--train_T', type=eval, default=True, choices=[True, False]) + parser.add_argument("--nonlinearity", type=str, default="tanh", choices=NONLINEARITIES) + parser.add_argument('--use_adjoint', type=eval, default=True, choices=[True, False]) + parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) + parser.add_argument('--atol', type=float, default=1e-5) + parser.add_argument('--rtol', type=float, default=1e-5) + parser.add_argument('--batch_norm', type=eval, default=True, choices=[True, False]) + parser.add_argument('--sync_bn', type=eval, default=False, choices=[True, False]) + parser.add_argument('--bn_lag', type=float, default=0) + + # training options + parser.add_argument('--use_latent_flow', action='store_true', + help='Whether to use the latent flow to model the prior.') + parser.add_argument('--use_deterministic_encoder', action='store_true', + help='Whether to use a deterministic encoder.') + parser.add_argument('--zdim', type=int, default=128, + help='Dimension of the shape code') + parser.add_argument('--optimizer', type=str, default='adam', + help='Optimizer to use', choices=['adam', 'adamax', 'sgd']) + parser.add_argument('--batch_size', type=int, default=50, + help='Batch size (of datasets) for training') + parser.add_argument('--lr', type=float, default=1e-3, + help='Learning rate for the Adam optimizer.') + parser.add_argument('--beta1', type=float, default=0.9, + help='Beta1 for Adam.') + parser.add_argument('--beta2', type=float, default=0.999, + help='Beta2 for Adam.') + parser.add_argument('--momentum', type=float, default=0.9, + help='Momentum for SGD') + parser.add_argument('--weight_decay', type=float, default=0., + help='Weight decay for the optimizer.') + parser.add_argument('--epochs', type=int, default=100, + help='Number of epochs for training (default: 100)') + parser.add_argument('--seed', type=int, default=None, + help='Seed for initializing training. 
') + parser.add_argument('--recon_weight', type=float, default=1., + help='Weight for the reconstruction loss.') + parser.add_argument('--prior_weight', type=float, default=1., + help='Weight for the prior loss.') + parser.add_argument('--entropy_weight', type=float, default=1., + help='Weight for the entropy loss.') + parser.add_argument('--scheduler', type=str, default='linear', + help='Type of learning rate schedule') + parser.add_argument('--exp_decay', type=float, default=1., + help='Learning rate schedule exponential decay rate') + parser.add_argument('--exp_decay_freq', type=int, default=1, + help='Learning rate exponential decay frequency') + + # data options + parser.add_argument('--dataset_type', type=str, default="shapenet15k", + help="Dataset types.", choices=['shapenet15k', 'modelnet40_15k', 'modelnet10_15k']) + parser.add_argument('--cates', type=str, nargs='+', default=["airplane"], + help="Categories to be trained (useful only if 'shapenet' is selected)") + parser.add_argument('--data_dir', type=str, default="data/ShapeNetCore.v2.PC15k", + help="Path to the training data") + parser.add_argument('--mn40_data_dir', type=str, default="data/ModelNet40.PC15k", + help="Path to ModelNet40") + parser.add_argument('--mn10_data_dir', type=str, default="data/ModelNet10.PC15k", + help="Path to ModelNet10") + parser.add_argument('--dataset_scale', type=float, default=1., + help='Scale of the dataset (x,y,z * scale = real output, default=1).') + parser.add_argument('--random_rotate', action='store_true', + help='Whether to randomly rotate each shape.') + parser.add_argument('--normalize_per_shape', action='store_true', + help='Whether to perform normalization per shape.') + parser.add_argument('--normalize_std_per_axis', action='store_true', + help='Whether to perform normalization per axis.') + parser.add_argument("--tr_max_sample_points", type=int, default=2048, + help='Max number of sampled points (train)') + parser.add_argument("--te_max_sample_points", type=int, default=2048, + help='Max number of sampled points (test)') + parser.add_argument('--num_workers', type=int, default=4, + help='Number of data loading threads') + + # logging and saving frequency + parser.add_argument('--log_name', type=str, default=None, help="Name for the log dir") + parser.add_argument('--viz_freq', type=int, default=10) + parser.add_argument('--val_freq', type=int, default=10) + parser.add_argument('--log_freq', type=int, default=10) + parser.add_argument('--save_freq', type=int, default=10) + + # validation options + parser.add_argument('--no_validation', action='store_true', + help='Whether to disable validation altogether.') + parser.add_argument('--save_val_results', action='store_true', + help='Whether to save the validation results.') + parser.add_argument('--eval_classification', action='store_true', + help='Whether to evaluate classification accuracy on MN40 and MN10.') + parser.add_argument('--no_eval_sampling', action='store_true', + help='Whether to evaluate sampling.') + parser.add_argument('--max_validate_shapes', type=int, default=None, + help='Max number of shapes used for validation pass.') + + # resuming + parser.add_argument('--resume_checkpoint', type=str, default=None, + help='Path to the checkpoint to be loaded.') + parser.add_argument('--resume_optimizer', action='store_true', + help='Whether to resume the optimizer when resumed training.') + parser.add_argument('--resume_non_strict', action='store_true', + help='Whether to resume in none-strict mode.') + 
parser.add_argument('--resume_dataset_mean', type=str, default=None, + help='Path to the file storing the dataset mean.') + parser.add_argument('--resume_dataset_std', type=str, default=None, + help='Path to the file storing the dataset std.') + + # distributed training + parser.add_argument('--world_size', default=1, type=int, + help='Number of distributed nodes.') + parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str, + help='url used to set up distributed training') + parser.add_argument('--dist_backend', default='nccl', type=str, + help='distributed backend') + parser.add_argument('--distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + parser.add_argument('--rank', default=0, type=int, + help='node rank for distributed training') + parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use. None means using all available GPUs.') + + # Evaluation options + parser.add_argument('--evaluate_recon', default=False, action='store_true', + help='Whether set to the evaluation for reconstruction.') + parser.add_argument('--num_sample_shapes', default=10, type=int, + help='Number of shapes to be sampled (for demo.py).') + parser.add_argument('--num_sample_points', default=2048, type=int, + help='Number of points (per-shape) to be sampled (for demo.py).') + + return parser + + +def get_parser(): + # command line args + parser = argparse.ArgumentParser(description='Flow-based Point Cloud Generation Experiment') + parser = add_args(parser) + return parser + + +def get_args(): + parser = get_parser() + args = parser.parse_args() + return args diff --git a/datasets.py b/datasets.py new file mode 100644 index 0000000..b6cdd05 --- /dev/null +++ b/datasets.py @@ -0,0 +1,389 @@ +import os +import torch +import numpy as np +from torch.utils.data import Dataset +from torch.utils import data +import random + +# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py +synsetid_to_cate = { + '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', + '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', + '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', + '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', + '02954340': 'cap', '02958343': 'car', '03001627': 'chair', + '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', + '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', + '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', + '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', + '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', + '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', + '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', + '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', + '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', + '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', + '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', + '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', + '04554684': 'washer', '02992529': 'cellphone', + '02843684': 'birdhouse', '02871439': 'bookshelf', + # '02858304': 'boat', no boat in our dataset, merged into vessels + # '02834778': 'bicycle', not in 
our taxonomy +} +cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()} + + +class Uniform15KPC(Dataset): + def __init__(self, root_dir, subdirs, tr_sample_size=10000, + te_sample_size=10000, split='train', scale=1., + normalize_per_shape=False, random_subsample=False, + normalize_std_per_axis=False, + all_points_mean=None, all_points_std=None, + input_dim=3): + self.root_dir = root_dir + self.split = split + self.in_tr_sample_size = tr_sample_size + self.in_te_sample_size = te_sample_size + self.subdirs = subdirs + self.scale = scale + self.random_subsample = random_subsample + self.input_dim = input_dim + + self.all_cate_mids = [] + self.cate_idx_lst = [] + self.all_points = [] + for cate_idx, subd in enumerate(self.subdirs): + # NOTE: [subd] here is synset id + sub_path = os.path.join(root_dir, subd, self.split) + if not os.path.isdir(sub_path): + print("Directory missing : %s" % sub_path) + continue + + all_mids = [] + for x in os.listdir(sub_path): + if not x.endswith('.npy'): + continue + all_mids.append(os.path.join(self.split, x[:-len('.npy')])) + + # NOTE: [mid] contains the split: i.e. "train/" or "val/" or "test/" + for mid in all_mids: + # obj_fname = os.path.join(sub_path, x) + obj_fname = os.path.join(root_dir, subd, mid + ".npy") + try: + point_cloud = np.load(obj_fname) # (15k, 3) + except: + continue + + assert point_cloud.shape[0] == 15000 + self.all_points.append(point_cloud[np.newaxis, ...]) + self.cate_idx_lst.append(cate_idx) + self.all_cate_mids.append((subd, mid)) + + # Shuffle the index deterministically (based on the number of examples) + self.shuffle_idx = list(range(len(self.all_points))) + random.Random(38383).shuffle(self.shuffle_idx) + self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx] + self.all_points = [self.all_points[i] for i in self.shuffle_idx] + self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx] + + # Normalization + self.all_points = np.concatenate(self.all_points) # (N, 15000, 3) + self.normalize_per_shape = normalize_per_shape + self.normalize_std_per_axis = normalize_std_per_axis + if all_points_mean is not None and all_points_std is not None: # using loaded dataset stats + self.all_points_mean = all_points_mean + self.all_points_std = all_points_std + elif self.normalize_per_shape: # per shape normalization + B, N = self.all_points.shape[:2] + self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim) + if normalize_std_per_axis: + self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim) + else: + self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1) + else: # normalize across the dataset + self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim) + if normalize_std_per_axis: + self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim) + else: + self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1) + + self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std + self.train_points = self.all_points[:, :10000] + self.test_points = self.all_points[:, 10000:] + + self.tr_sample_size = min(10000, tr_sample_size) + self.te_sample_size = min(5000, te_sample_size) + print("Total number of data:%d" % len(self.train_points)) + print("Min number of points: (train)%d (test)%d" + % (self.tr_sample_size, self.te_sample_size)) + assert self.scale == 1, "Scale (!= 1) is deprecated" + + def get_pc_stats(self, 
idx): + if self.normalize_per_shape: + m = self.all_points_mean[idx].reshape(1, self.input_dim) + s = self.all_points_std[idx].reshape(1, -1) + return m, s + + return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1) + + def renormalize(self, mean, std): + self.all_points = self.all_points * self.all_points_std + self.all_points_mean + self.all_points_mean = mean + self.all_points_std = std + self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std + self.train_points = self.all_points[:, :10000] + self.test_points = self.all_points[:, 10000:] + + def __len__(self): + return len(self.train_points) + + def __getitem__(self, idx): + tr_out = self.train_points[idx] + if self.random_subsample: + tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size) + else: + tr_idxs = np.arange(self.tr_sample_size) + tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float() + + te_out = self.test_points[idx] + if self.random_subsample: + te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size) + else: + te_idxs = np.arange(self.te_sample_size) + te_out = torch.from_numpy(te_out[te_idxs, :]).float() + + m, s = self.get_pc_stats(idx) + cate_idx = self.cate_idx_lst[idx] + sid, mid = self.all_cate_mids[idx] + + return { + 'idx': idx, + 'train_points': tr_out, + 'test_points': te_out, + 'mean': m, 'std': s, 'cate_idx': cate_idx, + 'sid': sid, 'mid': mid + } + + +class ModelNet40PointClouds(Uniform15KPC): + def __init__(self, root_dir="data/ModelNet40.PC15k", + tr_sample_size=10000, te_sample_size=2048, + split='train', scale=1., normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=False, + all_points_mean=None, all_points_std=None): + self.root_dir = root_dir + self.split = split + assert self.split in ['train', 'test'] + self.sample_size = tr_sample_size + self.cates = [] + for cate in os.listdir(root_dir): + if os.path.isdir(os.path.join(root_dir, cate)) \ + and os.path.isdir(os.path.join(root_dir, cate, 'train')) \ + and os.path.isdir(os.path.join(root_dir, cate, 'test')): + self.cates.append(cate) + assert len(self.cates) == 40, "%s %s" % (len(self.cates), self.cates) + + # For non-aligned MN + # self.gravity_axis = 0 + # self.display_axis_order = [0,1,2] + + # Aligned MN has same axis-order as SN + self.gravity_axis = 1 + self.display_axis_order = [0, 2, 1] + + super(ModelNet40PointClouds, self).__init__( + root_dir, self.cates, tr_sample_size=tr_sample_size, + te_sample_size=te_sample_size, split=split, scale=scale, + normalize_per_shape=normalize_per_shape, + normalize_std_per_axis=normalize_std_per_axis, + random_subsample=random_subsample, + all_points_mean=all_points_mean, all_points_std=all_points_std, + input_dim=3) + + +class ModelNet10PointClouds(Uniform15KPC): + def __init__(self, root_dir="data/ModelNet10.PC15k", + tr_sample_size=10000, te_sample_size=2048, + split='train', scale=1., normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=False, + all_points_mean=None, all_points_std=None): + self.root_dir = root_dir + self.split = split + assert self.split in ['train', 'test'] + self.cates = [] + for cate in os.listdir(root_dir): + if os.path.isdir(os.path.join(root_dir, cate)) \ + and os.path.isdir(os.path.join(root_dir, cate, 'train')) \ + and os.path.isdir(os.path.join(root_dir, cate, 'test')): + self.cates.append(cate) + assert len(self.cates) == 10 + + # That's prealigned MN + # self.gravity_axis = 0 + # self.display_axis_order = [0,1,2] + + # Aligned MN has same axis-order as SN + 
self.gravity_axis = 1 + self.display_axis_order = [0, 2, 1] + + super(ModelNet10PointClouds, self).__init__( + root_dir, self.cates, tr_sample_size=tr_sample_size, + te_sample_size=te_sample_size, split=split, scale=scale, + normalize_per_shape=normalize_per_shape, + normalize_std_per_axis=normalize_std_per_axis, + random_subsample=random_subsample, + all_points_mean=all_points_mean, all_points_std=all_points_std, + input_dim=3) + + +class ShapeNet15kPointClouds(Uniform15KPC): + def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k", + categories=['airplane'], tr_sample_size=10000, te_sample_size=2048, + split='train', scale=1., normalize_per_shape=False, + normalize_std_per_axis=False, + random_subsample=False, + all_points_mean=None, all_points_std=None): + self.root_dir = root_dir + self.split = split + assert self.split in ['train', 'test', 'val'] + self.tr_sample_size = tr_sample_size + self.te_sample_size = te_sample_size + self.cates = categories + if 'all' in categories: + self.synset_ids = list(cate_to_synsetid.values()) + else: + self.synset_ids = [cate_to_synsetid[c] for c in self.cates] + + # assert 'v2' in root_dir, "Only supporting v2 right now." + self.gravity_axis = 1 + self.display_axis_order = [0, 2, 1] + + super(ShapeNet15kPointClouds, self).__init__( + root_dir, self.synset_ids, + tr_sample_size=tr_sample_size, + te_sample_size=te_sample_size, + split=split, scale=scale, + normalize_per_shape=normalize_per_shape, + normalize_std_per_axis=normalize_std_per_axis, + random_subsample=random_subsample, + all_points_mean=all_points_mean, all_points_std=all_points_std, + input_dim=3) + + +def init_np_seed(worker_id): + seed = torch.initial_seed() + np.random.seed(seed % 4294967296) + + +def _get_MN40_datasets_(args, data_dir=None): + tr_dataset = ModelNet40PointClouds( + split='train', + tr_sample_size=args.tr_max_sample_points, + te_sample_size=args.te_max_sample_points, + root_dir=(args.data_dir if data_dir is None else data_dir), + normalize_per_shape=args.normalize_per_shape, + normalize_std_per_axis=args.normalize_std_per_axis, + random_subsample=True) + te_dataset = ModelNet40PointClouds( + split='test', + tr_sample_size=args.tr_max_sample_points, + te_sample_size=args.te_max_sample_points, + root_dir=(args.data_dir if data_dir is None else data_dir), + normalize_per_shape=args.normalize_per_shape, + normalize_std_per_axis=args.normalize_std_per_axis, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + + return tr_dataset, te_dataset + + +def _get_MN10_datasets_(args, data_dir=None): + tr_dataset = ModelNet10PointClouds( + split='train', + tr_sample_size=args.tr_max_sample_points, + te_sample_size=args.te_max_sample_points, + root_dir=(args.data_dir if data_dir is None else data_dir), + normalize_per_shape=args.normalize_per_shape, + normalize_std_per_axis=args.normalize_std_per_axis, + random_subsample=True) + te_dataset = ModelNet10PointClouds( + split='test', + tr_sample_size=args.tr_max_sample_points, + te_sample_size=args.te_max_sample_points, + root_dir=(args.data_dir if data_dir is None else data_dir), + normalize_per_shape=args.normalize_per_shape, + normalize_std_per_axis=args.normalize_std_per_axis, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + return tr_dataset, te_dataset + + +def get_datasets(args): + if args.dataset_type == 'shapenet15k': + tr_dataset = ShapeNet15kPointClouds( + categories=args.cates, split='train', + tr_sample_size=args.tr_max_sample_points, + 
te_sample_size=args.te_max_sample_points, + scale=args.dataset_scale, root_dir=args.data_dir, + normalize_per_shape=args.normalize_per_shape, + normalize_std_per_axis=args.normalize_std_per_axis, + random_subsample=True) + te_dataset = ShapeNet15kPointClouds( + categories=args.cates, split='val', + tr_sample_size=args.tr_max_sample_points, + te_sample_size=args.te_max_sample_points, + scale=args.dataset_scale, root_dir=args.data_dir, + normalize_per_shape=args.normalize_per_shape, + normalize_std_per_axis=args.normalize_std_per_axis, + all_points_mean=tr_dataset.all_points_mean, + all_points_std=tr_dataset.all_points_std, + ) + elif args.dataset_type == 'modelnet40_15k': + tr_dataset, te_dataset = _get_MN40_datasets_(args) + elif args.dataset_type == 'modelnet10_15k': + tr_dataset, te_dataset = _get_MN10_datasets_(args) + else: + raise Exception("Invalid dataset type:%s" % args.dataset_type) + + return tr_dataset, te_dataset + + +def get_clf_datasets(args): + return { + 'MN40': _get_MN40_datasets_(args, data_dir=args.mn40_data_dir), + 'MN10': _get_MN10_datasets_(args, data_dir=args.mn10_data_dir), + } + + +def get_data_loaders(args): + tr_dataset, te_dataset = get_datasets(args) + train_loader = data.DataLoader( + dataset=tr_dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers, drop_last=True, + worker_init_fn=init_np_seed) + train_unshuffle_loader = data.DataLoader( + dataset=tr_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers, drop_last=True, + worker_init_fn=init_np_seed) + test_loader = data.DataLoader( + dataset=te_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers, drop_last=False, + worker_init_fn=init_np_seed) + + loaders = { + "test_loader": test_loader, + 'train_loader': train_loader, + 'train_unshuffle_loader': train_unshuffle_loader, + } + return loaders + + +if __name__ == "__main__": + shape_ds = ShapeNet15kPointClouds(categories=['airplane'], split='val') + x_tr, x_te = next(iter(shape_ds)) + print(x_tr.shape) + print(x_te.shape) + diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..8995249 --- /dev/null +++ b/demo.py @@ -0,0 +1,60 @@ +import open3d as o3d +from datasets import get_datasets +from args import get_args +from models.networks import PointFlow +import os +import torch +import numpy as np +import torch.nn as nn + + +def main(args): + model = PointFlow(args) + + def _transform_(m): + return nn.DataParallel(m) + + model = model.cuda() + model.multi_gpu_wrapper(_transform_) + + print("Resume Path:%s" % args.resume_checkpoint) + checkpoint = torch.load(args.resume_checkpoint) + model.load_state_dict(checkpoint) + model.eval() + + _, te_dataset = get_datasets(args) + if args.resume_dataset_mean is not None and args.resume_dataset_std is not None: + mean = np.load(args.resume_dataset_mean) + std = np.load(args.resume_dataset_std) + te_dataset.renormalize(mean, std) + ds_mean = torch.from_numpy(te_dataset.all_points_mean).cuda() + ds_std = torch.from_numpy(te_dataset.all_points_std).cuda() + + all_sample = [] + with torch.no_grad(): + for i in range(0, args.num_sample_shapes, args.batch_size): + B = len(range(i, min(i + args.batch_size, args.num_sample_shapes))) + N = args.num_sample_points + _, out_pc = model.sample(B, N) + out_pc = out_pc * ds_std + ds_mean + all_sample.append(out_pc) + + sample_pcs = torch.cat(all_sample, dim=0).cpu().detach().numpy() + print("Generation sample size:(%s, %s, %s)" % sample_pcs.shape) + + # Save the generative output + 
os.makedirs("demo", exist_ok=True) + np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs) + + # Visualize the demo + pcl = o3d.geometry.PointCloud() + for i in range(int(sample_pcs.shape[0])): + print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0])) + pts = sample_pcs[i].reshape(-1, 3) + pcl.points = o3d.utility.Vector3dVector(pts) + o3d.visualization.draw_geometries([pcl]) + + +if __name__ == '__main__': + args = get_args() + main(args) diff --git a/install.sh b/install.sh new file mode 100755 index 0000000..42605d5 --- /dev/null +++ b/install.sh @@ -0,0 +1,19 @@ +#! /bin/bash + +root=`pwd` + +# Install dependecies +conda install numpy matplotlib pillow scipy tqdm scikit-learn -y +pip install tensorflow-gpu==1.13.1 +pip install tensorboardX==1.7 + +# Compile CUDA kernel for CD/EMD loss +cd metrics/pytorch_structural_losses/ +make clean +make +cd $root + +# install torchdiffeq +git clone https://github.com/rtqichen/torchdiffeq.git +cd torchdiffeq +pip install -e . diff --git a/metrics/.gitignore b/metrics/.gitignore new file mode 100644 index 0000000..98c2a8b --- /dev/null +++ b/metrics/.gitignore @@ -0,0 +1 @@ +StructuralLosses diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/metrics/evaluation_metrics.py b/metrics/evaluation_metrics.py new file mode 100644 index 0000000..cfd794c --- /dev/null +++ b/metrics/evaluation_metrics.py @@ -0,0 +1,336 @@ +import torch +import numpy as np +import warnings +from scipy.stats import entropy +from sklearn.neighbors import NearestNeighbors +from numpy.linalg import norm + +# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/ +from .StructuralLosses.match_cost import match_cost +from .StructuralLosses.nn_distance import nn_distance + + +# # Import CUDA version of CD, borrowed from https://github.com/ThibaultGROUEIX/AtlasNet +# try: +# from . 
chamfer_distance_ext.dist_chamfer import chamferDist +# CD = chamferDist() +# def distChamferCUDA(x,y): +# return CD(x,y,gpu) +# except: + + +def distChamferCUDA(x, y): + return nn_distance(x, y) + + +def emd_approx(sample, ref): + B, N, N_ref = sample.size(0), sample.size(1), ref.size(1) + assert N == N_ref, "Not sure what would EMD do in this case" + emd = match_cost(sample, ref) # (B,) + emd_norm = emd / float(N) # (B,) + return emd_norm + + +# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet +def distChamfer(a, b): + x, y = a, b + bs, num_points, points_dim = x.size() + xx = torch.bmm(x, x.transpose(2, 1)) + yy = torch.bmm(y, y.transpose(2, 1)) + zz = torch.bmm(x, y.transpose(2, 1)) + diag_ind = torch.arange(0, num_points).to(a).long() + rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) + ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) + P = (rx.transpose(2, 1) + ry - 2 * zz) + return P.min(1)[0], P.min(2)[0] + + +def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True): + N_sample = sample_pcs.shape[0] + N_ref = ref_pcs.shape[0] + assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample) + + cd_lst = [] + emd_lst = [] + iterator = range(0, N_sample, batch_size) + + for b_start in iterator: + b_end = min(N_sample, b_start + batch_size) + sample_batch = sample_pcs[b_start:b_end] + ref_batch = ref_pcs[b_start:b_end] + + if accelerated_cd: + dl, dr = distChamferCUDA(sample_batch, ref_batch) + else: + dl, dr = distChamfer(sample_batch, ref_batch) + cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1)) + + emd_batch = emd_approx(sample_batch, ref_batch) + emd_lst.append(emd_batch) + + if reduced: + cd = torch.cat(cd_lst).mean() + emd = torch.cat(emd_lst).mean() + else: + cd = torch.cat(cd_lst) + emd = torch.cat(emd_lst) + + results = { + 'MMD-CD': cd, + 'MMD-EMD': emd, + } + return results + + +def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True): + N_sample = sample_pcs.shape[0] + N_ref = ref_pcs.shape[0] + all_cd = [] + all_emd = [] + iterator = range(N_sample) + for sample_b_start in iterator: + sample_batch = sample_pcs[sample_b_start] + + cd_lst = [] + emd_lst = [] + for ref_b_start in range(0, N_ref, batch_size): + ref_b_end = min(N_ref, ref_b_start + batch_size) + ref_batch = ref_pcs[ref_b_start:ref_b_end] + + batch_size_ref = ref_batch.size(0) + sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1) + sample_batch_exp = sample_batch_exp.contiguous() + + if accelerated_cd: + dl, dr = distChamferCUDA(sample_batch_exp, ref_batch) + else: + dl, dr = distChamfer(sample_batch_exp, ref_batch) + cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1)) + + emd_batch = emd_approx(sample_batch_exp, ref_batch) + emd_lst.append(emd_batch.view(1, -1)) + + cd_lst = torch.cat(cd_lst, dim=1) + emd_lst = torch.cat(emd_lst, dim=1) + all_cd.append(cd_lst) + all_emd.append(emd_lst) + + all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref + all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref + + return all_cd, all_emd + + +# Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py +def knn(Mxx, Mxy, Myy, k, sqrt=False): + n0 = Mxx.size(0) + n1 = Myy.size(0) + label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx) + M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0) + if sqrt: + M = M.abs().sqrt() + INFINITY = float('inf') + val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False) + + count = torch.zeros(n0 + 
n1).to(Mxx) + for i in range(0, k): + count = count + label.index_select(0, idx[i]) + pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float() + + s = { + 'tp': (pred * label).sum(), + 'fp': (pred * (1 - label)).sum(), + 'fn': ((1 - pred) * label).sum(), + 'tn': ((1 - pred) * (1 - label)).sum(), + } + + s.update({ + 'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10), + 'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10), + 'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10), + 'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10), + 'acc': torch.eq(label, pred).float().mean(), + }) + return s + + +def lgan_mmd_cov(all_dist): + N_sample, N_ref = all_dist.size(0), all_dist.size(1) + min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) + min_val, _ = torch.min(all_dist, dim=0) + mmd = min_val.mean() + mmd_smp = min_val_fromsmp.mean() + cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) + cov = torch.tensor(cov).to(all_dist) + return { + 'lgan_mmd': mmd, + 'lgan_cov': cov, + 'lgan_mmd_smp': mmd_smp, + } + + +def compute_all_metrics(sample_pcs, ref_pcs, batch_size, accelerated_cd=False): + results = {} + + M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd) + + res_cd = lgan_mmd_cov(M_rs_cd.t()) + results.update({ + "%s-CD" % k: v for k, v in res_cd.items() + }) + + res_emd = lgan_mmd_cov(M_rs_emd.t()) + results.update({ + "%s-EMD" % k: v for k, v in res_emd.items() + }) + + M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd) + M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd) + + # 1-NN results + one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False) + results.update({ + "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k + }) + one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False) + results.update({ + "1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k + }) + + return results + + +####################################################### +# JSD : from https://github.com/optas/latent_3d_points +####################################################### +def unit_cube_grid_point_cloud(resolution, clip_sphere=False): + """Returns the center coordinates of each cell of a 3D grid with resolution^3 cells, + that is placed in the unit-cube. + If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere. + """ + grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) + spacing = 1.0 / float(resolution - 1) + for i in range(resolution): + for j in range(resolution): + for k in range(resolution): + grid[i, j, k, 0] = i * spacing - 0.5 + grid[i, j, k, 1] = j * spacing - 0.5 + grid[i, j, k, 2] = k * spacing - 0.5 + + if clip_sphere: + grid = grid.reshape(-1, 3) + grid = grid[norm(grid, axis=1) <= 0.5] + + return grid, spacing + + +def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28): + """Computes the JSD between two sets of point-clouds, as introduced in the paper + ```Learning Representations And Generative Models For 3D Point Clouds```. + Args: + sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points. + ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points. + resolution: (int) grid-resolution. Affects granularity of measurements. 
+ """ + in_unit_sphere = True + sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1] + ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1] + return jensen_shannon_divergence(sample_grid_var, ref_grid_var) + + +def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False): + """Given a collection of point-clouds, estimate the entropy of the random variables + corresponding to occupancy-grid activation patterns. + Inputs: + pclouds: (numpy array) #point-clouds x points per point-cloud x 3 + grid_resolution (int) size of occupancy grid that will be used. + """ + epsilon = 10e-4 + bound = 0.5 + epsilon + if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: + if verbose: + warnings.warn('Point-clouds are not in unit cube.') + + if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound: + if verbose: + warnings.warn('Point-clouds are not in unit sphere.') + + grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere) + grid_coordinates = grid_coordinates.reshape(-1, 3) + grid_counters = np.zeros(len(grid_coordinates)) + grid_bernoulli_rvars = np.zeros(len(grid_coordinates)) + nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates) + + for pc in pclouds: + _, indices = nn.kneighbors(pc) + indices = np.squeeze(indices) + for i in indices: + grid_counters[i] += 1 + indices = np.unique(indices) + for i in indices: + grid_bernoulli_rvars[i] += 1 + + acc_entropy = 0.0 + n = float(len(pclouds)) + for g in grid_bernoulli_rvars: + if g > 0: + p = float(g) / n + acc_entropy += entropy([p, 1.0 - p]) + + return acc_entropy / len(grid_counters), grid_counters + + +def jensen_shannon_divergence(P, Q): + if np.any(P < 0) or np.any(Q < 0): + raise ValueError('Negative values.') + if len(P) != len(Q): + raise ValueError('Non equal size.') + + P_ = P / np.sum(P) # Ensure probabilities. 
+ Q_ = Q / np.sum(Q) + + e1 = entropy(P_, base=2) + e2 = entropy(Q_, base=2) + e_sum = entropy((P_ + Q_) / 2.0, base=2) + res = e_sum - ((e1 + e2) / 2.0) + + res2 = _jsdiv(P_, Q_) + + if not np.allclose(res, res2, atol=10e-5, rtol=0): + warnings.warn('Numerical values of two JSD methods don\'t agree.') + + return res + + +def _jsdiv(P, Q): + """another way of computing JSD""" + + def _kldiv(A, B): + a = A.copy() + b = B.copy() + idx = np.logical_and(a > 0, b > 0) + a = a[idx] + b = b[idx] + return np.sum([v for v in a * np.log2(a / b)]) + + P_ = P / np.sum(P) + Q_ = Q / np.sum(Q) + + M = 0.5 * (P_ + Q_) + + return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M)) + + +if __name__ == "__main__": + B, N = 2, 10 + x = torch.rand(B, N, 3) + y = torch.rand(B, N, 3) + + distChamfer = distChamferCUDA() + min_l, min_r = distChamfer(x.cuda(), y.cuda()) + print(min_l.shape) + print(min_r.shape) + + l_dist = min_l.mean().cpu().detach().item() + r_dist = min_r.mean().cpu().detach().item() + print(l_dist, r_dist) diff --git a/metrics/pytorch_structural_losses/.gitignore b/metrics/pytorch_structural_losses/.gitignore new file mode 100644 index 0000000..76ab6ce --- /dev/null +++ b/metrics/pytorch_structural_losses/.gitignore @@ -0,0 +1 @@ +PyTorchStructuralLosses.egg-info/ diff --git a/metrics/pytorch_structural_losses/Makefile b/metrics/pytorch_structural_losses/Makefile new file mode 100644 index 0000000..b826528 --- /dev/null +++ b/metrics/pytorch_structural_losses/Makefile @@ -0,0 +1,100 @@ +############################################################################### +# Uncomment for debugging +# DEBUG := 1 +# Pretty build +# Q ?= @ + +CXX := g++ +PYTHON := python +NVCC := /usr/local/cuda/bin/nvcc + +# PYTHON Header path +PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())') +PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]') +PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]') + +# CUDA ROOT DIR that contains bin/ lib64/ and include/ +# CUDA_DIR := /usr/local/cuda +CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())') + +INCLUDE_DIRS := ./ $(CUDA_DIR)/include + +INCLUDE_DIRS += $(PYTHON_HEADER_DIR) +INCLUDE_DIRS += $(PYTORCH_INCLUDES) + +# Custom (MKL/ATLAS/OpenBLAS) include and lib directories. +# Leave commented to accept the defaults for your choice of BLAS +# (which should work)! +# BLAS_INCLUDE := /path/to/your/blas +# BLAS_LIB := /path/to/your/blas + +############################################################################### +SRC_DIR := ./src +OBJ_DIR := ./objs +CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp) +CU_SRCS := $(wildcard $(SRC_DIR)/*.cu) +OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS)) +CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS)) +STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a + +# CUDA architecture setting: going with all of them. +# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility. +# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility. +CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \ + -gencode arch=compute_61,code=compute_61 \ + -gencode arch=compute_52,code=sm_52 + +# We will also explicitly add stdc++ to the link target. 
+LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu + +# Debugging +ifeq ($(DEBUG), 1) + COMMON_FLAGS += -DDEBUG -g -O0 + # https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/ + NVCCFLAGS += -g -G # -rdc true +else + COMMON_FLAGS += -DNDEBUG -O3 +endif + +WARNINGS := -Wall -Wno-sign-compare -Wcomment + +INCLUDE_DIRS += $(BLAS_INCLUDE) + +# Automatic dependency generation (nvcc is handled separately) +CXXFLAGS += -MMD -MP + +# Complete build flags. +COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \ + -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0 +CXXFLAGS += -pthread -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS) +NVCCFLAGS += -std=c++11 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS) + +all: $(STATIC_LIB) + $(PYTHON) setup.py build + @ mv build/lib.linux-x86_64-3.6/StructuralLosses .. + @ mv build/lib.linux-x86_64-3.6/*.so ../StructuralLosses/ + @- $(RM) -rf $(OBJ_DIR) build objs + +$(OBJ_DIR): + @ mkdir -p $@ + @ mkdir -p $@/cuda + +$(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR) + @ echo CXX $< + $(Q)$(CXX) $< $(CXXFLAGS) -c -o $@ + +$(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR) + @ echo NVCC $< + $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \ + -odir $(@D) + $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ + +$(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR) + $(RM) -f $(STATIC_LIB) + $(RM) -rf build dist + @ echo LD -o $@ + ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS) + +clean: + @- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses + diff --git a/metrics/pytorch_structural_losses/__init__.py b/metrics/pytorch_structural_losses/__init__.py new file mode 100644 index 0000000..1656c10 --- /dev/null +++ b/metrics/pytorch_structural_losses/__init__.py @@ -0,0 +1,6 @@ +#import torch + +#from MakePytorchBackend import AddGPU, Foo, ApproxMatch + +#from Add import add_gpu, approx_match + diff --git a/metrics/pytorch_structural_losses/match_cost.py b/metrics/pytorch_structural_losses/match_cost.py new file mode 100644 index 0000000..a919aef --- /dev/null +++ b/metrics/pytorch_structural_losses/match_cost.py @@ -0,0 +1,45 @@ +import torch +from torch.autograd import Function +from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad + +# Inherit from Function +class MatchCostFunction(Function): + # Note that both forward and backward are @staticmethods + @staticmethod + # bias is an optional argument + def forward(ctx, seta, setb): + #print("Match Cost Forward") + ctx.save_for_backward(seta, setb) + ''' + input: + set1 : batch_size * #dataset_points * 3 + set2 : batch_size * #query_points * 3 + returns: + match : batch_size * #query_points * #dataset_points + ''' + match, temp = ApproxMatch(seta, setb) + ctx.match = match + cost = MatchCost(seta, setb, match) + return cost + + """ + grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match) + return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None] + """ + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, grad_output): + #print("Match Cost Backward") + # This is a pattern that is very convenient - at the top of backward + # unpack saved_tensors and initialize all gradients w.r.t. inputs to + # None. Thanks to the fact that additional trailing Nones are + # ignored, the return statement is simple even when the function has + # optional inputs. 
+ seta, setb = ctx.saved_tensors + #grad_input = grad_weight = grad_bias = None + grada, gradb = MatchCostGrad(seta, setb, ctx.match) + grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2) + return grada*grad_output_expand, gradb*grad_output_expand + +match_cost = MatchCostFunction.apply + diff --git a/metrics/pytorch_structural_losses/nn_distance.py b/metrics/pytorch_structural_losses/nn_distance.py new file mode 100644 index 0000000..44a85bc --- /dev/null +++ b/metrics/pytorch_structural_losses/nn_distance.py @@ -0,0 +1,42 @@ +import torch +from torch.autograd import Function +# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad +from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad + +# Inherit from Function +class NNDistanceFunction(Function): + # Note that both forward and backward are @staticmethods + @staticmethod + # bias is an optional argument + def forward(ctx, seta, setb): + #print("Match Cost Forward") + ctx.save_for_backward(seta, setb) + ''' + input: + set1 : batch_size * #dataset_points * 3 + set2 : batch_size * #query_points * 3 + returns: + dist1, idx1, dist2, idx2 + ''' + dist1, idx1, dist2, idx2 = NNDistance(seta, setb) + ctx.idx1 = idx1 + ctx.idx2 = idx2 + return dist1, dist2 + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, grad_dist1, grad_dist2): + #print("Match Cost Backward") + # This is a pattern that is very convenient - at the top of backward + # unpack saved_tensors and initialize all gradients w.r.t. inputs to + # None. Thanks to the fact that additional trailing Nones are + # ignored, the return statement is simple even when the function has + # optional inputs. + seta, setb = ctx.saved_tensors + idx1 = ctx.idx1 + idx2 = ctx.idx2 + grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2) + return grada, gradb + +nn_distance = NNDistanceFunction.apply + diff --git a/metrics/pytorch_structural_losses/pybind/bind.cpp b/metrics/pytorch_structural_losses/pybind/bind.cpp new file mode 100644 index 0000000..3202e77 --- /dev/null +++ b/metrics/pytorch_structural_losses/pybind/bind.cpp @@ -0,0 +1,15 @@ +#include + +#include + +#include "pybind/extern.hpp" + +namespace py = pybind11; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ + m.def("ApproxMatch", &ApproxMatch); + m.def("MatchCost", &MatchCost); + m.def("MatchCostGrad", &MatchCostGrad); + m.def("NNDistance", &NNDistance); + m.def("NNDistanceGrad", &NNDistanceGrad); +} diff --git a/metrics/pytorch_structural_losses/pybind/extern.hpp b/metrics/pytorch_structural_losses/pybind/extern.hpp new file mode 100644 index 0000000..003877b --- /dev/null +++ b/metrics/pytorch_structural_losses/pybind/extern.hpp @@ -0,0 +1,6 @@ +std::vector ApproxMatch(at::Tensor in_a, at::Tensor in_b); +at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match); +std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match); + +std::vector NNDistance(at::Tensor set_d, at::Tensor set_q); +std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2); diff --git a/metrics/pytorch_structural_losses/setup.py b/metrics/pytorch_structural_losses/setup.py new file mode 100644 index 0000000..67f0e8c --- /dev/null +++ b/metrics/pytorch_structural_losses/setup.py @@ -0,0 +1,30 @@ +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension + +# 
Python interface +setup( + name='PyTorchStructuralLosses', + version='0.1.0', + install_requires=['torch'], + packages=['StructuralLosses'], + package_dir={'StructuralLosses': './'}, + ext_modules=[ + CUDAExtension( + name='StructuralLossesBackend', + include_dirs=['./'], + sources=[ + 'pybind/bind.cpp', + ], + libraries=['make_pytorch'], + library_dirs=['objs'], + # extra_compile_args=['-g'] + ) + ], + cmdclass={'build_ext': BuildExtension}, + author='Christopher B. Choy', + author_email='chrischoy@ai.stanford.edu', + description='Tutorial for Pytorch C++ Extension with a Makefile', + keywords='Pytorch C++ Extension', + url='https://github.com/chrischoy/MakePytorchPlusPlus', + zip_safe=False, +) diff --git a/metrics/pytorch_structural_losses/src/approxmatch.cu b/metrics/pytorch_structural_losses/src/approxmatch.cu new file mode 100644 index 0000000..42058be --- /dev/null +++ b/metrics/pytorch_structural_losses/src/approxmatch.cu @@ -0,0 +1,326 @@ +#include "utils.hpp" + +__global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){ + float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; + float multiL,multiR; + if (n>=m){ + multiL=1; + multiR=n/m; + }else{ + multiL=m/n; + multiR=1; + } + const int Block=1024; + __shared__ float buf[Block*4]; + for (int i=blockIdx.x;i=-2;j--){ + for (int j=7;j>-2;j--){ + float level=-powf(4.0f,j); + if (j==-2){ + level=0; + } + for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,out); +//} + +__global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){ + __shared__ float sum_grad[256*3]; + for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad2); +//} + +/*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, + cudaStream_t stream)*/ +// temp: TensorShape{b,(n+m)*2} +void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){ + approxmatchkernel + <<<32, 512, 0, stream>>>(b,n,m,xyz1,xyz2,match,temp); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error(Formatter() + << "CUDA kernel failed : " << std::to_string(err)); +} + +void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){ + matchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error(Formatter() + << "CUDA kernel failed : " << std::to_string(err)); +} + +void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){ + matchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1); + matchcostgrad2kernel<<>>(b,n,m,xyz1,xyz2,match,grad2); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error(Formatter() + << "CUDA kernel failed : " << std::to_string(err)); +} diff --git a/metrics/pytorch_structural_losses/src/approxmatch.cuh b/metrics/pytorch_structural_losses/src/approxmatch.cuh new file mode 100644 index 0000000..440d64d --- /dev/null +++ b/metrics/pytorch_structural_losses/src/approxmatch.cuh @@ -0,0 +1,8 @@ +/* +template +void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, + 
cudaStream_t stream); +*/ +void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream); +void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream); +void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream); diff --git a/metrics/pytorch_structural_losses/src/nndistance.cu b/metrics/pytorch_structural_losses/src/nndistance.cu new file mode 100755 index 0000000..bd13b8b --- /dev/null +++ b/metrics/pytorch_structural_losses/src/nndistance.cu @@ -0,0 +1,155 @@ + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*3]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ + NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); + NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); + NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); +} + diff --git a/metrics/pytorch_structural_losses/src/nndistance.cuh b/metrics/pytorch_structural_losses/src/nndistance.cuh new file mode 100755 index 0000000..e2b65c3 --- /dev/null +++ b/metrics/pytorch_structural_losses/src/nndistance.cuh @@ -0,0 +1,2 @@ +void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); +void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); diff --git a/metrics/pytorch_structural_losses/src/structural_loss.cpp b/metrics/pytorch_structural_losses/src/structural_loss.cpp new file mode 100644 index 0000000..f58702c --- /dev/null +++ b/metrics/pytorch_structural_losses/src/structural_loss.cpp @@ -0,0 +1,125 @@ +#include +#include + +#include "src/approxmatch.cuh" +#include "src/nndistance.cuh" + +#include +#include + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +/* +input: + set1 : batch_size * #dataset_points * 3 + set2 : batch_size * #query_points * 3 +returns: + match : batch_size * #query_points * #dataset_points +*/ +// temp: TensorShape{b,(n+m)*2} +std::vector ApproxMatch(at::Tensor set_d, at::Tensor set_q) { + //std::cout << "[ApproxMatch] Called." 
<< std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl; + at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(match); + CHECK_INPUT(temp); + + approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),temp.data(), at::cuda::getCurrentCUDAStream()); + return {match, temp}; +} + +at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { + //std::cout << "[MatchCost] Called." << std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[MatchCost] batch_size:" << batch_size << std::endl; + at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(match); + CHECK_INPUT(out); + matchcost(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),out.data(),at::cuda::getCurrentCUDAStream()); + return out; +} + +std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { + //std::cout << "[MatchCostGrad] Called." << std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl; + at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(match); + CHECK_INPUT(grad1); + CHECK_INPUT(grad2); + matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),grad1.data(),grad2.data(),at::cuda::getCurrentCUDAStream()); + return {grad1, grad2}; +} + + +/* +input: + set_d : batch_size * #dataset_points * 3 + set_q : batch_size * #query_points * 3 +returns: + dist1, idx1 : batch_size * #dataset_points + dist2, idx2 : batch_size * #query_points +*/ +std::vector NNDistance(at::Tensor set_d, at::Tensor set_q) { + //std::cout << "[NNDistance] Called." 
<< std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[NNDistance] batch_size:" << batch_size << std::endl; + at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); + at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(dist1); + CHECK_INPUT(idx1); + CHECK_INPUT(dist2); + CHECK_INPUT(idx2); + // void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); + nndistance(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(),dist1.data(),idx1.data(),dist2.data(),idx2.data(), at::cuda::getCurrentCUDAStream()); + return {dist1, idx1, dist2, idx2}; +} + +std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) { + //std::cout << "[NNDistanceGrad] Called." << std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl; + at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(idx1); + CHECK_INPUT(idx2); + CHECK_INPUT(grad_dist1); + CHECK_INPUT(grad_dist2); + CHECK_INPUT(grad1); + CHECK_INPUT(grad2); + //void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); + nndistancegrad(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(), + grad_dist1.data(),idx1.data(), + grad_dist2.data(),idx2.data(), + grad1.data(),grad2.data(), + at::cuda::getCurrentCUDAStream()); + return {grad1, grad2}; +} + diff --git a/metrics/pytorch_structural_losses/src/utils.hpp b/metrics/pytorch_structural_losses/src/utils.hpp new file mode 100644 index 0000000..d60fa2b --- /dev/null +++ b/metrics/pytorch_structural_losses/src/utils.hpp @@ -0,0 +1,26 @@ +#include +#include +#include + +class Formatter { +public: + Formatter() {} + ~Formatter() {} + + template Formatter &operator<<(const Type &value) { + stream_ << value; + return *this; + } + + std::string str() const { return stream_.str(); } + operator std::string() const { return stream_.str(); } + + enum ConvertToString { to_str }; + + std::string operator>>(ConvertToString) { return stream_.str(); } + +private: + std::stringstream stream_; + Formatter(const Formatter &); + Formatter &operator=(Formatter &); +}; diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/cnf.py b/models/cnf.py new file mode 100644 index 0000000..92ef7ae --- 
/dev/null +++ b/models/cnf.py @@ -0,0 +1,122 @@ +import torch +import torch.nn as nn +from torchdiffeq import odeint_adjoint +from torchdiffeq import odeint as odeint_normal + +__all__ = ["CNF", "SequentialFlow"] + + +class SequentialFlow(nn.Module): + """A generalized nn.Sequential container for normalizing flows.""" + + def __init__(self, layer_list): + super(SequentialFlow, self).__init__() + self.chain = nn.ModuleList(layer_list) + + def forward(self, x, context, logpx=None, reverse=False, inds=None, integration_times=None): + if inds is None: + if reverse: + inds = range(len(self.chain) - 1, -1, -1) + else: + inds = range(len(self.chain)) + + if logpx is None: + for i in inds: + x = self.chain[i](x, context, logpx, integration_times, reverse) + return x + else: + for i in inds: + x, logpx = self.chain[i](x, context, logpx, integration_times, reverse) + return x, logpx + + +class CNF(nn.Module): + def __init__(self, odefunc, conditional=True, T=1.0, train_T=False, regularization_fns=None, + solver='dopri5', atol=1e-5, rtol=1e-5, use_adjoint=True): + super(CNF, self).__init__() + self.train_T = train_T + self.T = T + if train_T: + self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T)))) + + if regularization_fns is not None and len(regularization_fns) > 0: + raise NotImplementedError("Regularization not supported") + self.use_adjoint = use_adjoint + self.odefunc = odefunc + self.solver = solver + self.atol = atol + self.rtol = rtol + self.test_solver = solver + self.test_atol = atol + self.test_rtol = rtol + self.solver_options = {} + self.conditional = conditional + + def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False): + if logpx is None: + _logpx = torch.zeros(*x.shape[:-1], 1).to(x) + else: + _logpx = logpx + + if self.conditional: + assert context is not None + states = (x, _logpx, context) + atol = [self.atol] * 3 + rtol = [self.rtol] * 3 + else: + states = (x, _logpx) + atol = [self.atol] * 2 + rtol = [self.rtol] * 2 + + if integration_times is None: + if self.train_T: + integration_times = torch.stack( + [torch.tensor(0.0).to(x), self.sqrt_end_time * self.sqrt_end_time] + ).to(x) + else: + integration_times = torch.tensor([0., self.T], requires_grad=False).to(x) + + if reverse: + integration_times = _flip(integration_times, 0) + + # Refresh the odefunc statistics. 
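+        # `before_odeint` below resets the ODE function's evaluation counter and clears
+        # its cached noise, so a fresh Hutchinson probe vector is drawn on the first
+        # function evaluation of this integration.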
+ self.odefunc.before_odeint() + odeint = odeint_adjoint if self.use_adjoint else odeint_normal + if self.training: + state_t = odeint( + self.odefunc, + states, + integration_times.to(x), + atol=atol, + rtol=rtol, + method=self.solver, + options=self.solver_options, + ) + else: + state_t = odeint( + self.odefunc, + states, + integration_times.to(x), + atol=self.test_atol, + rtol=self.test_rtol, + method=self.test_solver, + ) + + if len(integration_times) == 2: + state_t = tuple(s[1] for s in state_t) + + z_t, logpz_t = state_t[:2] + + if logpx is not None: + return z_t, logpz_t + else: + return z_t + + def num_evals(self): + return self.odefunc._num_evals.item() + + +def _flip(x, dim): + indices = [slice(None)] * x.dim() + indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) + return x[tuple(indices)] diff --git a/models/diffeq_layers.py b/models/diffeq_layers.py new file mode 100644 index 0000000..fae766c --- /dev/null +++ b/models/diffeq_layers.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1 or classname.find('Conv') != -1: + nn.init.constant_(m.weight, 0) + nn.init.normal_(m.bias, 0, 0.01) + + +class IgnoreLinear(nn.Module): + def __init__(self, dim_in, dim_out, dim_c): + super(IgnoreLinear, self).__init__() + self._layer = nn.Linear(dim_in, dim_out) + + def forward(self, context, x): + return self._layer(x) + + +class ConcatLinear(nn.Module): + def __init__(self, dim_in, dim_out, dim_c): + super(ConcatLinear, self).__init__() + self._layer = nn.Linear(dim_in + 1 + dim_c, dim_out) + + def forward(self, context, x, c): + if x.dim() == 3: + context = context.unsqueeze(1).expand(-1, x.size(1), -1) + x_context = torch.cat((x, context), dim=2) + return self._layer(x_context) + + +class ConcatLinear_v2(nn.Module): + def __init__(self, dim_in, dim_out, dim_c): + super(ConcatLinear_v2, self).__init__() + self._layer = nn.Linear(dim_in, dim_out) + self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False) + + def forward(self, context, x): + bias = self._hyper_bias(context) + if x.dim() == 3: + bias = bias.unsqueeze(1) + return self._layer(x) + bias + + +class SquashLinear(nn.Module): + def __init__(self, dim_in, dim_out, dim_c): + super(SquashLinear, self).__init__() + self._layer = nn.Linear(dim_in, dim_out) + self._hyper = nn.Linear(1 + dim_c, dim_out) + + def forward(self, context, x): + gate = torch.sigmoid(self._hyper(context)) + if x.dim() == 3: + gate = gate.unsqueeze(1) + return self._layer(x) * gate + + +class ScaleLinear(nn.Module): + def __init__(self, dim_in, dim_out, dim_c): + super(ScaleLinear, self).__init__() + self._layer = nn.Linear(dim_in, dim_out) + self._hyper = nn.Linear(1 + dim_c, dim_out) + + def forward(self, context, x): + gate = self._hyper(context) + if x.dim() == 3: + gate = gate.unsqueeze(1) + return self._layer(x) * gate + + +class ConcatSquashLinear(nn.Module): + def __init__(self, dim_in, dim_out, dim_c): + super(ConcatSquashLinear, self).__init__() + self._layer = nn.Linear(dim_in, dim_out) + self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False) + self._hyper_gate = nn.Linear(1 + dim_c, dim_out) + + def forward(self, context, x): + gate = torch.sigmoid(self._hyper_gate(context)) + bias = self._hyper_bias(context) + if x.dim() == 3: + gate = gate.unsqueeze(1) + bias = bias.unsqueeze(1) + ret = self._layer(x) * gate + bias + return ret + + +class ConcatScaleLinear(nn.Module): + def __init__(self, dim_in, dim_out, 
dim_c): + super(ConcatScaleLinear, self).__init__() + self._layer = nn.Linear(dim_in, dim_out) + self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False) + self._hyper_gate = nn.Linear(1 + dim_c, dim_out) + + def forward(self, context, x): + gate = self._hyper_gate(context) + bias = self._hyper_bias(context) + if x.dim() == 3: + gate = gate.unsqueeze(1) + bias = bias.unsqueeze(1) + ret = self._layer(x) * gate + bias + return ret diff --git a/models/flow.py b/models/flow.py new file mode 100644 index 0000000..6979314 --- /dev/null +++ b/models/flow.py @@ -0,0 +1,89 @@ +from .odefunc import ODEfunc, ODEnet +from .normalization import MovingBatchNorm1d +from .cnf import CNF, SequentialFlow + + +def count_nfe(model): + class AccNumEvals(object): + + def __init__(self): + self.num_evals = 0 + + def __call__(self, module): + if isinstance(module, CNF): + self.num_evals += module.num_evals() + + accumulator = AccNumEvals() + model.apply(accumulator) + return accumulator.num_evals + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def count_total_time(model): + class Accumulator(object): + + def __init__(self): + self.total_time = 0 + + def __call__(self, module): + if isinstance(module, CNF): + self.total_time = self.total_time + module.sqrt_end_time * module.sqrt_end_time + + accumulator = Accumulator() + model.apply(accumulator) + return accumulator.total_time + + +def build_model(args, input_dim, hidden_dims, context_dim, num_blocks, conditional): + def build_cnf(): + diffeq = ODEnet( + hidden_dims=hidden_dims, + input_shape=(input_dim,), + context_dim=context_dim, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = ODEfunc( + diffeq=diffeq, + ) + cnf = CNF( + odefunc=odefunc, + T=args.time_length, + train_T=args.train_T, + conditional=conditional, + solver=args.solver, + use_adjoint=args.use_adjoint, + atol=args.atol, + rtol=args.rtol, + ) + return cnf + + chain = [build_cnf() for _ in range(num_blocks)] + if args.batch_norm: + bn_layers = [MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag, sync=args.sync_bn) + for _ in range(num_blocks)] + bn_chain = [MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag, sync=args.sync_bn)] + for a, b in zip(chain, bn_layers): + bn_chain.append(a) + bn_chain.append(b) + chain = bn_chain + model = SequentialFlow(chain) + + return model + + +def get_point_cnf(args): + dims = tuple(map(int, args.dims.split("-"))) + model = build_model(args, args.input_dim, dims, args.zdim, args.num_blocks, True).cuda() + print("Number of trainable parameters of Point CNF: {}".format(count_parameters(model))) + return model + + +def get_latent_cnf(args): + dims = tuple(map(int, args.latent_dims.split("-"))) + model = build_model(args, args.zdim, dims, 0, args.latent_num_blocks, False).cuda() + print("Number of trainable parameters of Latent CNF: {}".format(count_parameters(model))) + return model diff --git a/models/networks.py b/models/networks.py new file mode 100644 index 0000000..88a4935 --- /dev/null +++ b/models/networks.py @@ -0,0 +1,224 @@ +import torch +import numpy as np +import torch.nn.functional as F +from torch import optim +from torch import nn +from models.flow import get_point_cnf +from models.flow import get_latent_cnf +from utils import truncated_normal, reduce_tensor, standard_normal_logprob + + +class Encoder(nn.Module): + def __init__(self, zdim, input_dim=3, use_deterministic_encoder=False): + super(Encoder, self).__init__() + self.use_deterministic_encoder = 
use_deterministic_encoder + self.zdim = zdim + self.conv1 = nn.Conv1d(input_dim, 128, 1) + self.conv2 = nn.Conv1d(128, 128, 1) + self.conv3 = nn.Conv1d(128, 256, 1) + self.conv4 = nn.Conv1d(256, 512, 1) + self.bn1 = nn.BatchNorm1d(128) + self.bn2 = nn.BatchNorm1d(128) + self.bn3 = nn.BatchNorm1d(256) + self.bn4 = nn.BatchNorm1d(512) + + if self.use_deterministic_encoder: + self.fc1 = nn.Linear(512, 256) + self.fc2 = nn.Linear(256, 128) + self.fc_bn1 = nn.BatchNorm1d(256) + self.fc_bn2 = nn.BatchNorm1d(128) + self.fc3 = nn.Linear(128, zdim) + else: + # Mapping to [c], cmean + self.fc1_m = nn.Linear(512, 256) + self.fc2_m = nn.Linear(256, 128) + self.fc3_m = nn.Linear(128, zdim) + self.fc_bn1_m = nn.BatchNorm1d(256) + self.fc_bn2_m = nn.BatchNorm1d(128) + + # Mapping to [c], cmean + self.fc1_v = nn.Linear(512, 256) + self.fc2_v = nn.Linear(256, 128) + self.fc3_v = nn.Linear(128, zdim) + self.fc_bn1_v = nn.BatchNorm1d(256) + self.fc_bn2_v = nn.BatchNorm1d(128) + + def forward(self, x): + x = x.transpose(1, 2) + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + x = F.relu(self.bn3(self.conv3(x))) + x = self.bn4(self.conv4(x)) + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, 512) + + ms = F.relu(self.fc_bn1(self.fc1(x))) + ms = F.relu(self.fc_bn2(self.fc2(ms))) + ms = self.fc3(ms) + if self.use_deterministic_encoder: + m, v = ms, 0 + else: + m = F.relu(self.fc_bn1_m(self.fc1_m(x))) + m = F.relu(self.fc_bn2_m(self.fc2_m(m))) + m = self.fc3_m(m) + v = F.relu(self.fc_bn1_v(self.fc1_v(x))) + v = F.relu(self.fc_bn2_v(self.fc2_v(v))) + v = self.fc3_v(v) + + return m, v + + +# Model +class PointFlow(nn.Module): + def __init__(self, args): + super(PointFlow, self).__init__() + self.input_dim = args.input_dim + self.zdim = args.zdim + self.use_latent_flow = args.use_latent_flow + self.use_deterministic_encoder = args.use_deterministic_encoder + self.prior_weight = args.prior_weight + self.recon_weight = args.recon_weight + self.entropy_weight = args.entropy_weight + self.distributed = args.distributed + self.truncate_std = None + self.encoder = Encoder( + zdim=args.zdim, input_dim=args.input_dim, + use_deterministic_encoder=args.use_deterministic_encoder) + self.point_cnf = get_point_cnf(args) + self.latent_cnf = get_latent_cnf(args) if args.use_latent_flow else nn.Sequential() + + @staticmethod + def sample_gaussian(size, truncate_std=None, gpu=None): + y = torch.randn(*size).float() + y = y if gpu is None else y.cuda(gpu) + if truncate_std is not None: + truncated_normal(y, mean=0, std=1, trunc_std=truncate_std) + return y + + @staticmethod + def reparameterize_gaussian(mean, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn(std.size()).to(mean) + return mean + std * eps + + @staticmethod + def gaussian_entropy(logvar): + const = 0.5 * float(logvar.size(1)) * (1. 
+ np.log(np.pi * 2)) + ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const + return ent + + def multi_gpu_wrapper(self, f): + self.encoder = f(self.encoder) + self.point_cnf = f(self.point_cnf) + self.latent_cnf = f(self.latent_cnf) + + def make_optimizer(self, args): + def _get_opt_(params): + if args.optimizer == 'adam': + optimizer = optim.Adam(params, lr=args.lr, betas=(args.beta1, args.beta2), + weight_decay=args.weight_decay) + elif args.optimizer == 'sgd': + optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum) + else: + assert 0, "args.optimizer should be either 'adam' or 'sgd'" + return optimizer + opt = _get_opt_(list(self.encoder.parameters()) + list(self.point_cnf.parameters()) + + list(list(self.latent_cnf.parameters()))) + return opt + + def forward(self, x, opt, step, writer=None): + opt.zero_grad() + batch_size = x.size(0) + num_points = x.size(1) + z_mu, z_sigma = self.encoder(x) + if self.use_deterministic_encoder: + z = z_mu + 0 * z_sigma + else: + z = self.reparameterize_gaussian(z_mu, z_sigma) + + # Compute H[Q(z|X)] + if self.use_deterministic_encoder: + entropy = torch.zeros(batch_size).to(z) + else: + entropy = self.gaussian_entropy(z_sigma) + + # Compute the prior probability P(z) + if self.use_latent_flow: + w, delta_log_pw = self.latent_cnf(z, None, torch.zeros(batch_size, 1).to(z)) + log_pw = standard_normal_logprob(w).view(batch_size, -1).sum(1, keepdim=True) + delta_log_pw = delta_log_pw.view(batch_size, 1) + log_pz = log_pw - delta_log_pw + else: + log_pz = torch.zeros(batch_size, 1).to(z) + + # Compute the reconstruction likelihood P(X|z) + z_new = z.view(*z.size()) + z_new = z_new + (log_pz * 0.).mean() + y, delta_log_py = self.point_cnf(x, z_new, torch.zeros(batch_size, num_points, 1).to(x)) + log_py = standard_normal_logprob(y).view(batch_size, -1).sum(1, keepdim=True) + delta_log_py = delta_log_py.view(batch_size, num_points, 1).sum(1) + log_px = log_py - delta_log_py + + # Loss + entropy_loss = -entropy.mean() * self.entropy_weight + recon_loss = -log_px.mean() * self.recon_weight + prior_loss = -log_pz.mean() * self.prior_weight + loss = entropy_loss + prior_loss + recon_loss + loss.backward() + opt.step() + + # LOGGING (after the training) + if self.distributed: + entropy_log = reduce_tensor(entropy.mean()) + recon = reduce_tensor(-log_px.mean()) + prior = reduce_tensor(-log_pz.mean()) + else: + entropy_log = entropy.mean() + recon = -log_px.mean() + prior = -log_pz.mean() + + recon_nats = recon / float(x.size(1) * x.size(2)) + prior_nats = prior / float(self.zdim) + + if writer is not None: + writer.add_scalar('train/entropy', entropy_log, step) + writer.add_scalar('train/prior', prior, step) + writer.add_scalar('train/prior(nats)', prior_nats, step) + writer.add_scalar('train/recon', recon, step) + writer.add_scalar('train/recon(nats)', recon_nats, step) + + return { + 'entropy': entropy_log.cpu().detach().item() + if not isinstance(entropy_log, float) else entropy_log, + 'prior_nats': prior_nats, + 'recon_nats': recon_nats, + } + + def encode(self, x): + z_mu, z_sigma = self.encoder(x) + if self.use_deterministic_encoder: + return z_mu + else: + return self.reparameterize_gaussian(z_mu, z_sigma) + + def decode(self, z, num_points, truncate_std=None): + # transform points from the prior to a point cloud, conditioned on a shape code + y = self.sample_gaussian((z.size(0), num_points, self.input_dim), truncate_std) + x = self.point_cnf(y, z, reverse=True).view(*y.size()) + return y, x + + def sample(self, batch_size, num_points, 
truncate_std=None, truncate_std_latent=None, gpu=None): + assert self.use_latent_flow, "Sampling requires `self.use_latent_flow` to be True." + # Generate the shape code from the prior + w = self.sample_gaussian((batch_size, self.zdim), truncate_std_latent, gpu=gpu) + z = self.latent_cnf(w, None, reverse=True).view(*w.size()) + # Sample points conditioned on the shape code + y = self.sample_gaussian((batch_size, num_points, self.input_dim), truncate_std, gpu=gpu) + x = self.point_cnf(y, z, reverse=True).view(*y.size()) + return z, x + + def reconstruct(self, x, num_points=None, truncate_std=None): + num_points = x.size(1) if num_points is None else num_points + z = self.encode(x) + _, x = self.decode(z, num_points, truncate_std) + return x diff --git a/models/normalization.py b/models/normalization.py new file mode 100644 index 0000000..5be7908 --- /dev/null +++ b/models/normalization.py @@ -0,0 +1,145 @@ +import torch +import torch.nn as nn +from torch.nn import Parameter +from utils import reduce_tensor + +__all__ = ['MovingBatchNorm1d'] + + +class MovingBatchNormNd(nn.Module): + def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True, sync=False): + super(MovingBatchNormNd, self).__init__() + self.num_features = num_features + self.sync = sync + self.affine = affine + self.eps = eps + self.decay = decay + self.bn_lag = bn_lag + self.register_buffer('step', torch.zeros(1)) + if self.affine: + self.weight = Parameter(torch.Tensor(num_features)) + self.bias = Parameter(torch.Tensor(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + @property + def shape(self): + raise NotImplementedError + + def reset_parameters(self): + self.running_mean.zero_() + self.running_var.fill_(1) + if self.affine: + self.weight.data.zero_() + self.bias.data.zero_() + + def forward(self, x, c=None, logpx=None, reverse=False): + if reverse: + return self._reverse(x, logpx) + else: + return self._forward(x, logpx) + + def _forward(self, x, logpx=None): + num_channels = x.size(-1) + used_mean = self.running_mean.clone().detach() + used_var = self.running_var.clone().detach() + + if self.training: + # compute batch statistics + x_t = x.transpose(0, 1).reshape(num_channels, -1) + batch_mean = torch.mean(x_t, dim=1) + + if self.sync: + batch_ex2 = torch.mean(x_t**2, dim=1) + batch_mean = reduce_tensor(batch_mean) + batch_ex2 = reduce_tensor(batch_ex2) + batch_var = batch_ex2 - batch_mean**2 + else: + batch_var = torch.var(x_t, dim=1) + + # moving average + if self.bn_lag > 0: + used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach()) + used_mean /= (1. - self.bn_lag**(self.step[0] + 1)) + used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach()) + used_var /= (1. 
- self.bn_lag**(self.step[0] + 1)) + + # update running estimates + self.running_mean -= self.decay * (self.running_mean - batch_mean.data) + self.running_var -= self.decay * (self.running_var - batch_var.data) + self.step += 1 + + # perform normalization + used_mean = used_mean.view(*self.shape).expand_as(x) + used_var = used_var.view(*self.shape).expand_as(x) + + y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps)) + + if self.affine: + weight = self.weight.view(*self.shape).expand_as(x) + bias = self.bias.view(*self.shape).expand_as(x) + y = y * torch.exp(weight) + bias + + if logpx is None: + return y + else: + return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True) + + def _reverse(self, y, logpy=None): + used_mean = self.running_mean + used_var = self.running_var + + if self.affine: + weight = self.weight.view(*self.shape).expand_as(y) + bias = self.bias.view(*self.shape).expand_as(y) + y = (y - bias) * torch.exp(-weight) + + used_mean = used_mean.view(*self.shape).expand_as(y) + used_var = used_var.view(*self.shape).expand_as(y) + x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean + + if logpy is None: + return x + else: + return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True) + + def _logdetgrad(self, x, used_var): + logdetgrad = -0.5 * torch.log(used_var + self.eps) + if self.affine: + weight = self.weight.view(*self.shape).expand(*x.size()) + logdetgrad += weight + return logdetgrad + + def __repr__(self): + return ( + '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},' + ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__) + ) + + +def stable_var(x, mean=None, dim=1): + if mean is None: + mean = x.mean(dim, keepdim=True) + mean = mean.view(-1, 1) + res = torch.pow(x - mean, 2) + max_sqr = torch.max(res, dim, keepdim=True)[0] + var = torch.mean(res / max_sqr, 1, keepdim=True) * max_sqr + var = var.view(-1) + # change nan to zero + var[var != var] = 0 + return var + + +class MovingBatchNorm1d(MovingBatchNormNd): + @property + def shape(self): + return [1, -1] + + def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False): + ret = super(MovingBatchNorm1d, self).forward( + x, context, logpx=logpx, reverse=reverse) + return ret diff --git a/models/odefunc.py b/models/odefunc.py new file mode 100644 index 0000000..76ccc69 --- /dev/null +++ b/models/odefunc.py @@ -0,0 +1,137 @@ +import copy +import torch +import torch.nn as nn +from . 
import diffeq_layers + +__all__ = ["ODEnet", "ODEfunc"] + + +def divergence_approx(f, y, e=None): + e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0] + e_dzdx_e = e_dzdx.mul(e) + + cnt = 0 + while not e_dzdx_e.requires_grad and cnt < 10: + # print("RequiresGrad:f=%s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt=%d" + # % (f.requires_grad, y.requires_grad, e_dzdx.requires_grad, + # e.requires_grad, e_dzdx_e.requires_grad, cnt)) + e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0] + e_dzdx_e = e_dzdx * e + cnt += 1 + + approx_tr_dzdx = e_dzdx_e.sum(dim=-1) + assert approx_tr_dzdx.requires_grad, \ + "(failed to add node to graph) f=%s %s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt:%s" \ + % ( + f.size(), f.requires_grad, y.requires_grad, e_dzdx.requires_grad, e.requires_grad, e_dzdx_e.requires_grad, cnt) + return approx_tr_dzdx + + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + self.beta = nn.Parameter(torch.tensor(1.0)) + + def forward(self, x): + return x * torch.sigmoid(self.beta * x) + + +class Lambda(nn.Module): + def __init__(self, f): + super(Lambda, self).__init__() + self.f = f + + def forward(self, x): + return self.f(x) + + +NONLINEARITIES = { + "tanh": nn.Tanh(), + "relu": nn.ReLU(), + "softplus": nn.Softplus(), + "elu": nn.ELU(), + "swish": Swish(), + "square": Lambda(lambda x: x ** 2), + "identity": Lambda(lambda x: x), +} + + +class ODEnet(nn.Module): + """ + Helper class to make neural nets for use in continuous normalizing flows + """ + + def __init__(self, hidden_dims, input_shape, context_dim, layer_type="concat", nonlinearity="softplus"): + super(ODEnet, self).__init__() + base_layer = { + "ignore": diffeq_layers.IgnoreLinear, + "squash": diffeq_layers.SquashLinear, + "scale": diffeq_layers.ScaleLinear, + "concat": diffeq_layers.ConcatLinear, + "concat_v2": diffeq_layers.ConcatLinear_v2, + "concatsquash": diffeq_layers.ConcatSquashLinear, + "concatscale": diffeq_layers.ConcatScaleLinear, + }[layer_type] + + # build models and add them + layers = [] + activation_fns = [] + hidden_shape = input_shape + + for dim_out in (hidden_dims + (input_shape[0],)): + layer_kwargs = {} + layer = base_layer(hidden_shape[0], dim_out, context_dim, **layer_kwargs) + layers.append(layer) + activation_fns.append(NONLINEARITIES[nonlinearity]) + + hidden_shape = list(copy.copy(hidden_shape)) + hidden_shape[0] = dim_out + + self.layers = nn.ModuleList(layers) + self.activation_fns = nn.ModuleList(activation_fns[:-1]) + + def forward(self, context, y): + dx = y + for l, layer in enumerate(self.layers): + dx = layer(context, dx) + # if not last layer, use nonlinearity + if l < len(self.layers) - 1: + dx = self.activation_fns[l](dx) + return dx + + +class ODEfunc(nn.Module): + def __init__(self, diffeq): + super(ODEfunc, self).__init__() + self.diffeq = diffeq + self.divergence_fn = divergence_approx + self.register_buffer("_num_evals", torch.tensor(0.)) + + def before_odeint(self, e=None): + self._e = e + self._num_evals.fill_(0) + + def forward(self, t, states): + y = states[0] + t = torch.ones(y.size(0), 1).to(y) * t.clone().detach().requires_grad_(True).type_as(y) + self._num_evals += 1 + for state in states: + state.requires_grad_(True) + + # Sample and fix the noise. 
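+        # `self._e` is the probe vector used by `divergence_approx` for the Hutchinson
+        # trace estimate E[e^T (df/dy) e] = tr(df/dy); sampling it once and reusing it
+        # for every solver call keeps the dynamics deterministic within a single solve.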
+ if self._e is None: + self._e = torch.randn_like(y, requires_grad=True).to(y) + + with torch.set_grad_enabled(True): + if len(states) == 3: # conditional CNF + c = states[2] + tc = torch.cat([t, c.view(y.size(0), -1)], dim=1) + dy = self.diffeq(tc, y) + divergence = self.divergence_fn(dy, y, e=self._e).unsqueeze(-1) + return dy, -divergence, torch.zeros_like(c).requires_grad_(True) + elif len(states) == 2: # unconditional CNF + dy = self.diffeq(t, y) + divergence = self.divergence_fn(dy, y, e=self._e).view(-1, 1) + return dy, -divergence + else: + assert 0, "`len(states)` should be 2 or 3" diff --git a/scripts/shapenet_airplane_ae.sh b/scripts/shapenet_airplane_ae.sh new file mode 100755 index 0000000..6768dbe --- /dev/null +++ b/scripts/shapenet_airplane_ae.sh @@ -0,0 +1,38 @@ +#! /bin/bash + +cate="airplane" +dims="512-512-512" +latent_dims="256-256" +num_blocks=1 +latent_num_blocks=1 +zdim=128 +batch_size=48 +lr=2e-3 +epochs=4000 +ds=shapenet15k +log_name="ae/${ds}-cate${cate}" +data_dir="data/ShapeNetCore.v2.PC15k" + +python train.py \ + --log_name ${log_name} \ + --lr ${lr} \ + --dataset_type ${ds} \ + --data_dir ${data_dir} \ + --cates ${cate} \ + --dims ${dims} \ + --latent_dims ${latent_dims} \ + --num_blocks ${num_blocks} \ + --latent_num_blocks ${latent_num_blocks} \ + --batch_size ${batch_size} \ + --zdim ${zdim} \ + --epochs ${epochs} \ + --save_freq 50 \ + --viz_freq 1 \ + --log_freq 1 \ + --val_freq 10 \ + --use_deterministic_encoder \ + --prior_weight 0 \ + --entropy_weight 0 + +echo "Done" +exit 0 diff --git a/scripts/shapenet_airplane_ae_dist.sh b/scripts/shapenet_airplane_ae_dist.sh new file mode 100755 index 0000000..4d23280 --- /dev/null +++ b/scripts/shapenet_airplane_ae_dist.sh @@ -0,0 +1,39 @@ +#! /bin/bash + +cate="airplane" +dims="512-512-512" +latent_dims="256-256" +num_blocks=1 +latent_num_blocks=1 +zdim=128 +batch_size=256 +lr=2e-3 +epochs=4000 +ds=shapenet15k +log_name="ae/${ds}-cate${cate}" +data_dir="data/ShapeNetCore.v2.PC15k" + +python train.py \ + --log_name ${log_name} \ + --lr ${lr} \ + --dataset_type ${ds} \ + --data_dir ${data_dir} \ + --cates ${cate} \ + --dims ${dims} \ + --latent_dims ${latent_dims} \ + --num_blocks ${num_blocks} \ + --latent_num_blocks ${latent_num_blocks} \ + --batch_size ${batch_size} \ + --zdim ${zdim} \ + --epochs ${epochs} \ + --save_freq 50 \ + --viz_freq 1 \ + --log_freq 1 \ + --val_freq 10 \ + --distributed \ + --use_deterministic_encoder \ + --prior_weight 0 \ + --entropy_weight 0 + +echo "Done" +exit 0 diff --git a/scripts/shapenet_airplane_ae_test.sh b/scripts/shapenet_airplane_ae_test.sh new file mode 100755 index 0000000..63636a0 --- /dev/null +++ b/scripts/shapenet_airplane_ae_test.sh @@ -0,0 +1,8 @@ +#! /bin/bash + +python test.py \ + --cates airplane \ + --resume_checkpoint pretrained_models/ae/airplane/checkpoint.pt \ + --dims 512-512-512 \ + --use_deterministic_encoder \ + --evaluate_recon diff --git a/scripts/shapenet_airplane_demo.sh b/scripts/shapenet_airplane_demo.sh new file mode 100755 index 0000000..8ece9ee --- /dev/null +++ b/scripts/shapenet_airplane_demo.sh @@ -0,0 +1,11 @@ +#! 
/bin/bash + +python demo.py \ + --cates airplane \ + --resume_checkpoint pretrained_models/gen/airplane/checkpoint.pt \ + --dims 512-512-512 \ + --latent_dims 256-256 \ + --use_latent_flow \ + --num_sample_shapes 20 \ + --num_sample_points 2048 + diff --git a/scripts/shapenet_airplane_gen.sh b/scripts/shapenet_airplane_gen.sh new file mode 100755 index 0000000..a5cf987 --- /dev/null +++ b/scripts/shapenet_airplane_gen.sh @@ -0,0 +1,36 @@ +#! /bin/bash + +cate="airplane" +dims="512-512-512" +latent_dims="256-256" +num_blocks=1 +latent_num_blocks=1 +zdim=128 +batch_size=16 +lr=2e-3 +epochs=4000 +ds=shapenet15k +log_name="gen/${ds}-cate${cate}" +data_dir="data/ShapeNetCore.v2.PC15k" + +python train.py \ + --log_name ${log_name} \ + --lr ${lr} \ + --dataset_type ${ds} \ + --data_dir ${data_dir} \ + --cates ${cate} \ + --dims ${dims} \ + --latent_dims ${latent_dims} \ + --num_blocks ${num_blocks} \ + --latent_num_blocks ${latent_num_blocks} \ + --batch_size ${batch_size} \ + --zdim ${zdim} \ + --epochs ${epochs} \ + --save_freq 50 \ + --viz_freq 1 \ + --log_freq 1 \ + --val_freq 10 \ + --use_latent_flow + +echo "Done" +exit 0 diff --git a/scripts/shapenet_airplane_gen_dist.sh b/scripts/shapenet_airplane_gen_dist.sh new file mode 100755 index 0000000..94704d5 --- /dev/null +++ b/scripts/shapenet_airplane_gen_dist.sh @@ -0,0 +1,38 @@ +#! /bin/bash + +cate="airplane" +dims="512-512-512" +latent_dims="256-256" +num_blocks=1 +latent_num_blocks=1 +zdim=128 +batch_size=256 +lr=2e-3 +epochs=4000 +ds=shapenet15k +log_name="gen/${ds}-cate${cate}-seqback" +data_dir="data/ShapeNetCore.v2.PC15k" + +python train.py \ + --log_name ${log_name} \ + --lr ${lr} \ + --train_T False \ + --dataset_type ${ds} \ + --data_dir ${data_dir} \ + --cates ${cate} \ + --dims ${dims} \ + --latent_dims ${latent_dims} \ + --num_blocks ${num_blocks} \ + --latent_num_blocks ${latent_num_blocks} \ + --batch_size ${batch_size} \ + --zdim ${zdim} \ + --epochs ${epochs} \ + --save_freq 50 \ + --viz_freq 1 \ + --log_freq 1 \ + --val_freq 10 \ + --distributed \ + --use_latent_flow + +echo "Done" +exit 0 diff --git a/scripts/shapenet_airplane_gen_test.sh b/scripts/shapenet_airplane_gen_test.sh new file mode 100755 index 0000000..50dd1e2 --- /dev/null +++ b/scripts/shapenet_airplane_gen_test.sh @@ -0,0 +1,10 @@ +#! /bin/bash + +python test.py \ + --cates airplane \ + --resume_checkpoint pretrained_models/gen/airplane/checkpoint.pt \ + --dims 512-512-512 \ + --latent_dims 256-256 \ + --use_latent_flow + + diff --git a/scripts/shapenet_all_ae_test.sh b/scripts/shapenet_all_ae_test.sh new file mode 100755 index 0000000..32dacda --- /dev/null +++ b/scripts/shapenet_all_ae_test.sh @@ -0,0 +1,11 @@ +#! 
/bin/bash + +python test.py \ + --cates all \ + --resume_checkpoint pretrained_models/ae/all/checkpoint.pt \ + --dims 512-512-512 \ + --use_deterministic_encoder \ + --evaluate_recon \ + --resume_dataset_mean pretrained_models/ae/all/train_set_mean.npy \ + --resume_dataset_std pretrained_models/ae/all/train_set_std.npy + diff --git a/test.py b/test.py new file mode 100644 index 0000000..4b50133 --- /dev/null +++ b/test.py @@ -0,0 +1,167 @@ +from datasets import get_datasets, synsetid_to_cate +from args import get_args +from pprint import pprint +from metrics.evaluation_metrics import EMD_CD +from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD +from metrics.evaluation_metrics import compute_all_metrics +from collections import defaultdict +from models.networks import PointFlow +import os +import torch +import numpy as np +import torch.nn as nn + + +def get_test_loader(args): + _, te_dataset = get_datasets(args) + if args.resume_dataset_mean is not None and args.resume_dataset_std is not None: + mean = np.load(args.resume_dataset_mean) + std = np.load(args.resume_dataset_std) + te_dataset.renormalize(mean, std) + loader = torch.utils.data.DataLoader( + dataset=te_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=0, pin_memory=True, drop_last=False) + return loader + + +def evaluate_recon(model, args): + # TODO: make this memory efficient + if 'all' in args.cates: + cates = list(synsetid_to_cate.values()) + else: + cates = args.cates + all_results = {} + cate_to_len = {} + save_dir = os.path.dirname(args.resume_checkpoint) + for cate in cates: + args.cates = [cate] + loader = get_test_loader(args) + + all_sample = [] + all_ref = [] + for data in loader: + idx_b, tr_pc, te_pc = data['idx'], data['train_points'], data['test_points'] + te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu) + tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu) + B, N = te_pc.size(0), te_pc.size(1) + out_pc = model.reconstruct(tr_pc, num_points=N) + m, s = data['mean'].float(), data['std'].float() + m = m.cuda() if args.gpu is None else m.cuda(args.gpu) + s = s.cuda() if args.gpu is None else s.cuda(args.gpu) + out_pc = out_pc * s + m + te_pc = te_pc * s + m + + all_sample.append(out_pc) + all_ref.append(te_pc) + + sample_pcs = torch.cat(all_sample, dim=0) + ref_pcs = torch.cat(all_ref, dim=0) + cate_to_len[cate] = int(sample_pcs.size(0)) + print("Cate=%s Total Sample size:%s Ref size: %s" + % (cate, sample_pcs.size(), ref_pcs.size())) + + # Save it + np.save(os.path.join(save_dir, "%s_out_smp.npy" % cate), + sample_pcs.cpu().detach().numpy()) + np.save(os.path.join(save_dir, "%s_out_ref.npy" % cate), + ref_pcs.cpu().detach().numpy()) + + results = EMD_CD(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True) + results = { + k: (v.cpu().detach().item() if not isinstance(v, float) else v) + for k, v in results.items()} + pprint(results) + all_results[cate] = results + + # Save final results + print("="*80) + print("All category results:") + print("="*80) + pprint(all_results) + save_path = os.path.join(save_dir, "percate_results.npy") + np.save(save_path, all_results) + + # Compute weighted performance + ttl_r, ttl_cnt = defaultdict(lambda: 0.), defaultdict(lambda: 0.) 
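+    # Average the per-category results weighted by the number of shapes in each category.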
+ for catename, l in cate_to_len.items(): + for k, v in all_results[catename].items(): + ttl_r[k] += v * float(l) + ttl_cnt[k] += float(l) + ttl_res = {k: (float(ttl_r[k]) / float(ttl_cnt[k])) for k in ttl_r.keys()} + print("="*80) + print("Averaged results:") + pprint(ttl_res) + print("="*80) + + save_path = os.path.join(save_dir, "results.npy") + np.save(save_path, all_results) + + +def evaluate_gen(model, args): + loader = get_test_loader(args) + all_sample = [] + all_ref = [] + for data in loader: + idx_b, te_pc = data['idx'], data['test_points'] + te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu) + B, N = te_pc.size(0), te_pc.size(1) + _, out_pc = model.sample(B, N) + + # denormalize + m, s = data['mean'].float(), data['std'].float() + m = m.cuda() if args.gpu is None else m.cuda(args.gpu) + s = s.cuda() if args.gpu is None else s.cuda(args.gpu) + out_pc = out_pc * s + m + te_pc = te_pc * s + m + + all_sample.append(out_pc) + all_ref.append(te_pc) + + sample_pcs = torch.cat(all_sample, dim=0) + ref_pcs = torch.cat(all_ref, dim=0) + print("Generation sample size:%s reference size: %s" + % (sample_pcs.size(), ref_pcs.size())) + + # Save the generative output + save_dir = os.path.dirname(args.resume_checkpoint) + np.save(os.path.join(save_dir, "model_out_smp.npy"), sample_pcs.cpu().detach().numpy()) + np.save(os.path.join(save_dir, "model_out_ref.npy"), ref_pcs.cpu().detach().numpy()) + + # Compute metrics + results = compute_all_metrics(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True) + results = {k: (v.cpu().detach().item() + if not isinstance(v, float) else v) for k, v in results.items()} + pprint(results) + + sample_pcl_npy = sample_pcs.cpu().detach().numpy() + ref_pcl_npy = ref_pcs.cpu().detach().numpy() + jsd = JSD(sample_pcl_npy, ref_pcl_npy) + print("JSD:%s" % jsd) + + +def main(args): + model = PointFlow(args) + + def _transform_(m): + return nn.DataParallel(m) + + model = model.cuda() + model.multi_gpu_wrapper(_transform_) + + print("Resume Path:%s" % args.resume_checkpoint) + checkpoint = torch.load(args.resume_checkpoint) + model.load_state_dict(checkpoint) + model.eval() + + with torch.no_grad(): + if args.evaluate_recon: + # Evaluate reconstruction + evaluate_recon(model, args) + else: + # Evaluate generation + evaluate_gen(model, args) + + +if __name__ == '__main__': + args = get_args() + main(args) diff --git a/train.py b/train.py new file mode 100644 index 0000000..650a0c3 --- /dev/null +++ b/train.py @@ -0,0 +1,272 @@ +import sys +import os +import torch +import torch.distributed as dist +import torch.nn as nn +import warnings +import torch.distributed +import numpy as np +import random +import faulthandler +import torch.multiprocessing as mp +import time +import scipy.misc +from models.networks import PointFlow +from torch import optim +from args import get_args +from torch.backends import cudnn +from utils import AverageValueMeter, set_random_seed, apply_random_rotation, save, resume, visualize_point_clouds +from tensorboardX import SummaryWriter +from datasets import get_datasets, init_np_seed + +faulthandler.enable() + + +def main_worker(gpu, save_dir, ngpus_per_node, args): + # basic setup + cudnn.benchmark = True + args.gpu = gpu + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.distributed: + args.rank = args.rank * ngpus_per_node + gpu + 
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + + if args.log_name is not None: + log_dir = "runs/%s" % args.log_name + else: + log_dir = "runs/time-%d" % time.time() + + if not args.distributed or (args.rank % ngpus_per_node == 0): + writer = SummaryWriter(logdir=log_dir) + else: + writer = None + + if not args.use_latent_flow: # auto-encoder only + args.prior_weight = 0 + args.entropy_weight = 0 + + # multi-GPU setup + model = PointFlow(args) + if args.distributed: # Multiple processes, single GPU per process + if args.gpu is not None: + def _transform_(m): + return nn.parallel.DistributedDataParallel( + m, device_ids=[args.gpu], output_device=args.gpu, check_reduction=True) + + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + model.multi_gpu_wrapper(_transform_) + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = 0 + else: + assert 0, "DistributedDataParallel constructor should always set the single device scope" + elif args.gpu is not None: # Single process, single GPU per process + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: # Single process, multiple GPUs per process + def _transform_(m): + return nn.DataParallel(m) + model = model.cuda() + model.multi_gpu_wrapper(_transform_) + + # resume checkpoints + start_epoch = 0 + optimizer = model.make_optimizer(args) + if args.resume_checkpoint is None and os.path.exists(os.path.join(save_dir, 'checkpoint-latest.pt')): + args.resume_checkpoint = os.path.join(save_dir, 'checkpoint-latest.pt') # use the latest checkpoint + if args.resume_checkpoint is not None: + if args.resume_optimizer: + model, optimizer, start_epoch = resume( + args.resume_checkpoint, model, optimizer, strict=(not args.resume_non_strict)) + else: + model, _, start_epoch = resume( + args.resume_checkpoint, model, optimizer=None, strict=(not args.resume_non_strict)) + print('Resumed from: ' + args.resume_checkpoint) + + # initialize datasets and loaders + tr_dataset, te_dataset = get_datasets(args) + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(tr_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + dataset=tr_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=0, pin_memory=True, sampler=train_sampler, drop_last=True, + worker_init_fn=init_np_seed) + test_loader = torch.utils.data.DataLoader( + dataset=te_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=0, pin_memory=True, drop_last=False, + worker_init_fn=init_np_seed) + + # save dataset statistics + if not args.distributed or (args.rank % ngpus_per_node == 0): + np.save(os.path.join(save_dir, "train_set_mean.npy"), tr_dataset.all_points_mean) + np.save(os.path.join(save_dir, "train_set_std.npy"), tr_dataset.all_points_std) + np.save(os.path.join(save_dir, "train_set_idx.npy"), np.array(tr_dataset.shuffle_idx)) + np.save(os.path.join(save_dir, "val_set_mean.npy"), te_dataset.all_points_mean) + np.save(os.path.join(save_dir, "val_set_std.npy"), te_dataset.all_points_std) + np.save(os.path.join(save_dir, "val_set_idx.npy"), np.array(te_dataset.shuffle_idx)) + + # load classification dataset if needed + if args.eval_classification: + from datasets import get_clf_datasets + + def _make_data_loader_(dataset): + return torch.utils.data.DataLoader( + dataset=dataset, batch_size=args.batch_size, shuffle=False, + num_workers=0, pin_memory=True, drop_last=False, + 
worker_init_fn=init_np_seed + ) + + clf_datasets = get_clf_datasets(args) + clf_loaders = { + k: [_make_data_loader_(ds) for ds in ds_lst] for k, ds_lst in clf_datasets.items() + } + else: + clf_loaders = None + + # initialize the learning rate scheduler + if args.scheduler == 'exponential': + scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay) + elif args.scheduler == 'step': + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs // 2, gamma=0.1) + elif args.scheduler == 'linear': + def lambda_rule(ep): + lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(0.5 * args.epochs) + return lr_l + scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + else: + assert 0, "args.schedulers should be either 'exponential' or 'linear'" + + # main training loop + start_time = time.time() + entropy_avg_meter = AverageValueMeter() + latent_nats_avg_meter = AverageValueMeter() + point_nats_avg_meter = AverageValueMeter() + if args.distributed: + print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size())) + + print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs)) + for epoch in range(start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + + # adjust the learning rate + if (epoch + 1) % args.exp_decay_freq == 0: + scheduler.step(epoch=epoch) + if writer is not None: + writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], epoch) + + # train for one epoch + for bidx, data in enumerate(train_loader): + idx_batch, tr_batch, te_batch = data['idx'], data['train_points'], data['test_points'] + step = bidx + len(train_loader) * epoch + model.train() + if args.random_rotate: + tr_batch, _, _ = apply_random_rotation( + tr_batch, rot_axis=train_loader.dataset.gravity_axis) + inputs = tr_batch.cuda(args.gpu, non_blocking=True) + out = model(inputs, optimizer, step, writer) + entropy, prior_nats, recon_nats = out['entropy'], out['prior_nats'], out['recon_nats'] + entropy_avg_meter.update(entropy) + point_nats_avg_meter.update(recon_nats) + latent_nats_avg_meter.update(prior_nats) + if step % args.log_freq == 0: + duration = time.time() - start_time + start_time = time.time() + print("[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f" + % (args.rank, epoch, bidx, len(train_loader), duration, entropy_avg_meter.avg, + latent_nats_avg_meter.avg, point_nats_avg_meter.avg)) + + # evaluate on the validation set + if not args.no_validation and (epoch + 1) % args.val_freq == 0: + from utils import validate + validate(test_loader, model, epoch, writer, save_dir, args, clf_loaders=clf_loaders) + + # save visualizations + if (epoch + 1) % args.viz_freq == 0: + # reconstructions + model.eval() + samples = model.reconstruct(inputs) + results = [] + for idx in range(min(10, inputs.size(0))): + res = visualize_point_clouds(samples[idx], inputs[idx], idx, + pert_order=train_loader.dataset.display_axis_order) + results.append(res) + res = np.concatenate(results, axis=1) + scipy.misc.imsave(os.path.join(save_dir, 'images', 'tr_vis_conditioned_epoch%d-gpu%s.png' % (epoch, args.gpu)), + res.transpose((1, 2, 0))) + if writer is not None: + writer.add_image('tr_vis/conditioned', torch.as_tensor(res), epoch) + + # samples + if args.use_latent_flow: + num_samples = min(10, inputs.size(0)) + num_points = inputs.size(1) + _, samples = model.sample(num_samples, num_points) + results = [] + for idx in range(num_samples): + res = visualize_point_clouds(samples[idx], inputs[idx], idx, + 
pert_order=train_loader.dataset.display_axis_order) + results.append(res) + res = np.concatenate(results, axis=1) + scipy.misc.imsave(os.path.join(save_dir, 'images', 'tr_vis_conditioned_epoch%d-gpu%s.png' % (epoch, args.gpu)), + res.transpose((1, 2, 0))) + if writer is not None: + writer.add_image('tr_vis/sampled', torch.as_tensor(res), epoch) + + # save checkpoints + if not args.distributed or (args.rank % ngpus_per_node == 0): + if (epoch + 1) % args.save_freq == 0: + save(model, optimizer, epoch + 1, + os.path.join(save_dir, 'checkpoint-%d.pt' % epoch)) + save(model, optimizer, epoch + 1, + os.path.join(save_dir, 'checkpoint-latest.pt')) + + +def main(): + # command line args + args = get_args() + save_dir = os.path.join("checkpoints", args.log_name) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + os.makedirs(os.path.join(save_dir, 'images')) + + with open(os.path.join(save_dir, 'command.sh'), 'w') as f: + f.write('python -X faulthandler ' + ' '.join(sys.argv)) + f.write('\n') + + if args.seed is None: + args.seed = random.randint(0, 1000000) + set_random_seed(args.seed) + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + if args.sync_bn: + assert args.distributed + + print("Arguments:") + print(args) + + ngpus_per_node = torch.cuda.device_count() + if args.distributed: + args.world_size = ngpus_per_node * args.world_size + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(save_dir, ngpus_per_node, args)) + else: + main_worker(args.gpu, save_dir, ngpus_per_node, args) + + +if __name__ == '__main__': + main() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..f35176c --- /dev/null +++ b/utils.py @@ -0,0 +1,378 @@ +from pprint import pprint +from sklearn.svm import LinearSVC +from math import log, pi +import os +import torch +import torch.distributed as dist +import random +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + + +class AverageValueMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0.0 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0.0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def gaussian_log_likelihood(x, mean, logvar, clip=True): + if clip: + logvar = torch.clamp(logvar, min=-4, max=3) + a = log(2 * pi) + b = logvar + c = (x - mean) ** 2 / torch.exp(logvar) + return -0.5 * torch.sum(a + b + c) + + +def bernoulli_log_likelihood(x, p, clip=True, eps=1e-6): + if clip: + p = torch.clamp(p, min=eps, max=1 - eps) + return torch.sum((x * torch.log(p)) + ((1 - x) * torch.log(1 - p))) + + +def kl_diagnormal_stdnormal(mean, logvar): + a = mean ** 2 + b = torch.exp(logvar) + c = -1 + d = -logvar + return 0.5 * torch.sum(a + b + c + d) + + +def kl_diagnormal_diagnormal(q_mean, q_logvar, p_mean, p_logvar): + # Ensure correct shapes since no numpy broadcasting yet + p_mean = p_mean.expand_as(q_mean) + p_logvar = p_logvar.expand_as(q_logvar) + + a = p_logvar + b = - 1 + c = - q_logvar + d = ((q_mean - p_mean) ** 2 + torch.exp(q_logvar)) / torch.exp(p_logvar) + return 0.5 * torch.sum(a + b + c + d) + + +# Taken from 
https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15 +def truncated_normal(tensor, mean=0, std=1, trunc_std=2): + size = tensor.shape + tmp = tensor.new_empty(size + (4,)).normal_() + valid = (tmp < trunc_std) & (tmp > -trunc_std) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + return tensor + + +def reduce_tensor(tensor, world_size=None): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + if world_size is None: + world_size = dist.get_world_size() + + rt /= world_size + return rt + + +def standard_normal_logprob(z): + dim = z.size(-1) + log_z = -0.5 * dim * log(2 * pi) + return log_z - z.pow(2) / 2 + + +def set_random_seed(seed): + """set random seed""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +# Visualization +def visualize_point_clouds(pts, gtr, idx, pert_order=[0, 1, 2]): + pts = pts.cpu().detach().numpy()[:, pert_order] + gtr = gtr.cpu().detach().numpy()[:, pert_order] + + fig = plt.figure(figsize=(6, 3)) + ax1 = fig.add_subplot(121, projection='3d') + ax1.set_title("Sample:%s" % idx) + ax1.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=5) + + ax2 = fig.add_subplot(122, projection='3d') + ax2.set_title("Ground Truth:%s" % idx) + ax2.scatter(gtr[:, 0], gtr[:, 1], gtr[:, 2], s=5) + + fig.canvas.draw() + + # grab the pixel buffer and dump it into a numpy array + res = np.array(fig.canvas.renderer._renderer) + res = np.transpose(res, (2, 0, 1)) + + plt.close() + return res + + +# Augmentation +def apply_random_rotation(pc, rot_axis=1): + B = pc.shape[0] + + theta = np.random.rand(B) * 2 * np.pi + zeros = np.zeros(B) + ones = np.ones(B) + cos = np.cos(theta) + sin = np.sin(theta) + + if rot_axis == 0: + rot = np.stack([ + cos, -sin, zeros, + sin, cos, zeros, + zeros, zeros, ones + ]).T.reshape(B, 3, 3) + elif rot_axis == 1: + rot = np.stack([ + cos, zeros, -sin, + zeros, ones, zeros, + sin, zeros, cos + ]).T.reshape(B, 3, 3) + elif rot_axis == 2: + rot = np.stack([ + ones, zeros, zeros, + zeros, cos, -sin, + zeros, sin, cos + ]).T.reshape(B, 3, 3) + else: + raise Exception("Invalid rotation axis") + rot = torch.from_numpy(rot).to(pc) + + # (B, N, 3) mul (B, 3, 3) -> (B, N, 3) + pc_rotated = torch.bmm(pc, rot) + return pc_rotated, rot, theta + + +def validate_classification(loaders, model, args): + train_loader, test_loader = loaders + + def _make_iter_(loader): + iterator = iter(loader) + return iterator + + tr_latent = [] + tr_label = [] + for data in _make_iter_(train_loader): + tr_pc = data['train_points'] + tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu) + latent = model.encode(tr_pc) + label = data['cate_idx'] + tr_latent.append(latent.cpu().detach().numpy()) + tr_label.append(label.cpu().detach().numpy()) + tr_label = np.concatenate(tr_label) + tr_latent = np.concatenate(tr_latent) + + te_latent = [] + te_label = [] + for data in _make_iter_(test_loader): + tr_pc = data['train_points'] + tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu) + latent = model.encode(tr_pc) + label = data['cate_idx'] + te_latent.append(latent.cpu().detach().numpy()) + te_label.append(label.cpu().detach().numpy()) + te_label = np.concatenate(te_label) + te_latent = np.concatenate(te_latent) + + clf = LinearSVC(random_state=0) + clf.fit(tr_latent, tr_label) + test_pred = clf.predict(te_latent) + test_gt = te_label.flatten() + acc = np.mean((test_pred == 
test_gt).astype(float)) * 100. + res = {'acc': acc} + print("Acc:%s" % acc) + return res + + +def validate_conditioned(loader, model, args, max_samples=None, save_dir=None): + from metrics.evaluation_metrics import EMD_CD + all_idx = [] + all_sample = [] + all_ref = [] + ttl_samples = 0 + iterator = iter(loader) + + for data in iterator: + # idx_b, tr_pc, te_pc = data[:3] + idx_b, tr_pc, te_pc = data['idx'], data['train_points'], data['test_points'] + tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu) + te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu) + + if tr_pc.size(1) > te_pc.size(1): + tr_pc = tr_pc[:, :te_pc.size(1), :] + out_pc = model.reconstruct(tr_pc, num_points=te_pc.size(1)) + + # denormalize + m, s = data['mean'].float(), data['std'].float() + m = m.cuda() if args.gpu is None else m.cuda(args.gpu) + s = s.cuda() if args.gpu is None else s.cuda(args.gpu) + out_pc = out_pc * s + m + te_pc = te_pc * s + m + + all_sample.append(out_pc) + all_ref.append(te_pc) + all_idx.append(idx_b) + + ttl_samples += int(te_pc.size(0)) + if max_samples is not None and ttl_samples >= max_samples: + break + + # Compute MMD and CD + sample_pcs = torch.cat(all_sample, dim=0) + ref_pcs = torch.cat(all_ref, dim=0) + print("[rank %s] Recon Sample size:%s Ref size: %s" % (args.rank, sample_pcs.size(), ref_pcs.size())) + + if save_dir is not None and args.save_val_results: + smp_pcs_save_name = os.path.join(save_dir, "smp_recon_pcls_gpu%s.npy" % args.gpu) + ref_pcs_save_name = os.path.join(save_dir, "ref_recon_pcls_gpu%s.npy" % args.gpu) + np.save(smp_pcs_save_name, sample_pcs.cpu().detach().numpy()) + np.save(ref_pcs_save_name, ref_pcs.cpu().detach().numpy()) + print("Saving file:%s %s" % (smp_pcs_save_name, ref_pcs_save_name)) + + res = EMD_CD(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True) + mmd_cd = res['MMD-CD'] if 'MMD-CD' in res else None + mmd_emd = res['MMD-EMD'] if 'MMD-EMD' in res else None + + print("MMD-CD :%s" % mmd_cd) + print("MMD-EMD :%s" % mmd_emd) + + return res + + +def validate_sample(loader, model, args, max_samples=None, save_dir=None): + from metrics.evaluation_metrics import compute_all_metrics, jsd_between_point_cloud_sets as JSD + all_sample = [] + all_ref = [] + ttl_samples = 0 + + iterator = iter(loader) + + for data in iterator: + idx_b, te_pc = data['idx'], data['test_points'] + te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu) + _, out_pc = model.sample(te_pc.size(0), te_pc.size(1), gpu=args.gpu) + + # denormalize + m, s = data['mean'].float(), data['std'].float() + m = m.cuda() if args.gpu is None else m.cuda(args.gpu) + s = s.cuda() if args.gpu is None else s.cuda(args.gpu) + out_pc = out_pc * s + m + te_pc = te_pc * s + m + + all_sample.append(out_pc) + all_ref.append(te_pc) + + ttl_samples += int(te_pc.size(0)) + if max_samples is not None and ttl_samples >= max_samples: + break + + sample_pcs = torch.cat(all_sample, dim=0) + ref_pcs = torch.cat(all_ref, dim=0) + print("[rank %s] Generation Sample size:%s Ref size: %s" + % (args.rank, sample_pcs.size(), ref_pcs.size())) + + if save_dir is not None and args.save_val_results: + smp_pcs_save_name = os.path.join(save_dir, "smp_syn_pcls_gpu%s.npy" % args.gpu) + ref_pcs_save_name = os.path.join(save_dir, "ref_syn_pcls_gpu%s.npy" % args.gpu) + np.save(smp_pcs_save_name, sample_pcs.cpu().detach().numpy()) + np.save(ref_pcs_save_name, ref_pcs.cpu().detach().numpy()) + print("Saving file:%s %s" % (smp_pcs_save_name, ref_pcs_save_name)) + + res = 
compute_all_metrics(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True) + pprint(res) + + sample_pcs = sample_pcs.cpu().detach().numpy() + ref_pcs = ref_pcs.cpu().detach().numpy() + jsd = JSD(sample_pcs, ref_pcs) + jsd = torch.tensor(jsd).cuda() if args.gpu is None else torch.tensor(jsd).cuda(args.gpu) + res.update({"JSD": jsd}) + print("JSD :%s" % jsd) + return res + + +def save(model, optimizer, epoch, path): + d = { + 'epoch': epoch, + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict() + } + torch.save(d, path) + + +def resume(path, model, optimizer=None, strict=True): + ckpt = torch.load(path) + model.load_state_dict(ckpt['model'], strict=strict) + start_epoch = ckpt['epoch'] + if optimizer is not None: + optimizer.load_state_dict(ckpt['optimizer']) + return model, optimizer, start_epoch + + +def validate(test_loader, model, epoch, writer, save_dir, args, clf_loaders=None): + model.eval() + + # Make epoch wise save directory + if writer is not None and args.save_val_results: + save_dir = os.path.join(save_dir, 'epoch-%d' % epoch) + if not os.path.isdir(save_dir): + os.makedirs(save_dir) + else: + save_dir = None + + # classification + if args.eval_classification and clf_loaders is not None: + for clf_expr, loaders in clf_loaders.items(): + with torch.no_grad(): + clf_val_res = validate_classification(loaders, model, args) + + for k, v in clf_val_res.items(): + if writer is not None and v is not None: + writer.add_scalar('val_%s/%s' % (clf_expr, k), v, epoch) + + # samples + if args.use_latent_flow: + with torch.no_grad(): + val_sample_res = validate_sample( + test_loader, model, args, max_samples=args.max_validate_shapes, + save_dir=save_dir) + + for k, v in val_sample_res.items(): + if not isinstance(v, float): + v = v.cpu().detach().item() + if writer is not None and v is not None: + writer.add_scalar('val_sample/%s' % k, v, epoch) + + # reconstructions + with torch.no_grad(): + val_res = validate_conditioned( + test_loader, model, args, max_samples=args.max_validate_shapes, + save_dir=save_dir) + for k, v in val_res.items(): + if not isinstance(v, float): + v = v.cpu().detach().item() + if writer is not None and v is not None: + writer.add_scalar('val_conditioned/%s' % k, v, epoch) +
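
The `save`/`resume` helpers added in `utils.py` write a plain dict with `epoch`, `model`, and `optimizer` keys, and the training loop stores it as `checkpoint-%d.pt` plus a rolling `checkpoint-latest.pt`. Below is a minimal, illustrative sketch of that round trip; the tiny `torch.nn.Linear` model and the `checkpoints/example_run` path are placeholders for this example only, standing in for the real PointFlow model and `args.log_name`.

```python
# Illustrative sketch only: exercise utils.save()/utils.resume() with a stand-in model.
import os
import torch

from utils import save, resume

model = torch.nn.Linear(3, 3)                      # placeholder for the real PointFlow model
optimizer = torch.optim.Adam(model.parameters())

ckpt_path = os.path.join("checkpoints", "example_run", "checkpoint-latest.pt")
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

# Write a checkpoint in the same {'epoch', 'model', 'optimizer'} layout used by the training loop.
save(model, optimizer, 10, ckpt_path)

# Restore model/optimizer state and the stored epoch counter (e.g. after a restart).
model, optimizer, start_epoch = resume(ckpt_path, model, optimizer=optimizer)
print("resuming from epoch", start_epoch)          # prints 10
```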
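
Similarly, `apply_random_rotation` expects point clouds in the `(B, N, 3)` layout used throughout the patch and returns the rotated clouds together with the per-sample rotation matrices and angles. A short shape check, using a random tensor purely as a stand-in for a real batch:

```python
# Illustrative sketch only: check the shapes returned by utils.apply_random_rotation().
import torch

from utils import apply_random_rotation

pc = torch.rand(4, 2048, 3)                        # fake batch of 4 clouds, 2048 points each
pc_rot, rot, theta = apply_random_rotation(pc, rot_axis=1)

print(pc_rot.shape)   # torch.Size([4, 2048, 3])   rotated point clouds
print(rot.shape)      # torch.Size([4, 3, 3])      per-sample rotation matrices
print(theta.shape)    # (4,)                       sampled angles in radians (numpy array)
```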