Add training/testing/demo codes
This commit is contained in:
parent
f1581cf0f5
commit
4e6795a4d9
113
.gitignore
vendored
113
.gitignore
vendored
|
@ -1,108 +1,23 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*__pycache__
|
||||
.idea/
|
||||
.DS_Store
|
||||
*/.DS_Store
|
||||
*/*/.DS_Store
|
||||
*.pyc
|
||||
data
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
scratch.py
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
*.m
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
checkpoints
|
||||
runs
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
metrics/structural_losses/*.so
|
||||
metrics/structural_losses/*.cu.o
|
||||
metrics/structural_losses/makefile
|
||||
PyMesh
|
||||
checkpoint
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
torchdiffeq/
|
||||
demo/
|
||||
|
|
92
README.md
92
README.md
|
@ -2,7 +2,7 @@
|
|||
|
||||
This repository contains a PyTorch implementation of the paper:
|
||||
|
||||
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](https://arxiv.org/abs/1906.12320).
|
||||
[PointFlow : 3D Point Cloud Generation with Continuous Normalizing Flows](www.arxiv.com).
|
||||
|
||||
[Guandao Yang*](http://www.guandaoyang.com),
|
||||
[Xun Huang*](http://www.cs.cornell.edu/~xhuang/),
|
||||
|
@ -11,9 +11,6 @@ This repository contains a PyTorch implementation of the paper:
|
|||
[Serge Belongie](http://blogs.cornell.edu/techfaculty/serge-belongie/),
|
||||
[Bharath Hariharan](http://home.bharathh.info/)
|
||||
|
||||
**The code will be available soon!**
|
||||
|
||||
[\[Project page\]](https://www.guandaoyang.com/PointFlow/) [\[Video\]](https://www.youtube.com/watch?v=jqBiv77xC0M)
|
||||
|
||||
## Introduction
|
||||
|
||||
|
@ -23,3 +20,90 @@ As 3D point clouds become the representation of choice for multiple vision and g
|
|||
<p float="left">
|
||||
<img src="docs/assets/teaser.gif" height="256"/>
|
||||
</p>
|
||||
|
||||
## Dependencies
|
||||
* Python 3.6
|
||||
* CUDA 10.0.
|
||||
* G++ or GCC 5.
|
||||
* [PyTorch](http://pytorch.org/). Codes are tested with version 1.0.1
|
||||
* [torchdiffeq](https://github.com/rtqichen/torchdiffeq).
|
||||
* (Optional) [Tensorboard](https://www.tensorflow.org/) for visualization of training process.
|
||||
|
||||
Following is the suggested way to install these dependencies:
|
||||
```bash
|
||||
# Create a new conda environment
|
||||
conda create -n PointFlow python=3.6
|
||||
conda activate PointFlow
|
||||
|
||||
# Install pytorch (please refer to the command on the official website)
|
||||
conda install pytorch=1.0.1 torchvision cudatoolkit=10.0 -c pytorch -y
|
||||
|
||||
# Install other dependencies such as torchdiffeq, structural losses, etc.
|
||||
./install.sh
|
||||
```
|
||||
|
||||
## Dataset
|
||||
|
||||
The point clouds are uniformly sampled from the meshes of the ShapeNetCore dataset (version 2), using the official split.
|
||||
Please use this [link](https://drive.google.com/drive/folders/1G0rf-6HSHoTll6aH7voh-dXj6hCRhSAQ?usp=sharing) to download the ShapeNet point cloud.
|
||||
The point cloud should be placed into `data` directory.
|
||||
```bash
|
||||
mv ShapeNetCore.v2.PC15k.zip data/
|
||||
cd data
|
||||
unzip ShapeNetCore.v2.PC15k.zip
|
||||
```
|
||||
|
||||
Please contact us if you need point clouds for ModelNet dataset.
|
||||
|
||||
## Training
|
||||
|
||||
Example training scripts can be found in `scripts/` folder.
|
||||
```bash
|
||||
# Train auto-encoder (no latent CNF)
|
||||
./scripts/shapenet_airplane_ae.sh # Train with single GPU, about 7-8 GB GPU memory
|
||||
./scripts/shapenet_airplane_ae_dist.sh # Train with multiple GPUs
|
||||
|
||||
# Train generative model
|
||||
./scripts/shapenet_airplane_gen.sh # Train with single GPU, about 7-8 GB GPU memory
|
||||
./scripts/shapenet_airplane_gen_dist.sh # Train with multiple GPUs
|
||||
```
|
||||
|
||||
## Pre-trained models and test
|
||||
|
||||
Pretrained models can be downloaded from this [link](https://drive.google.com/file/d/1dcxjuuKiAXZxhiyWD_o_7Owx8Y3FbRHG/view?usp=sharing).
|
||||
Following is the suggested way to evaluate the performance of the pre-trained models.
|
||||
```bash
|
||||
unzip pretrained_models.zip; # This will create a folder named pretrained_models
|
||||
|
||||
# Evaluate the reconstruction performance of an AE trained on the airplane category
|
||||
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_ae_test.sh;
|
||||
|
||||
# Evaluate the reconstruction performance of an AE trained with the whole ShapeNet
|
||||
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_all_ae_test.sh;
|
||||
|
||||
# Evaluate the generative performance of PointFlow trained on the airplane category.
|
||||
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_gen_test.sh
|
||||
```
|
||||
|
||||
## Demo
|
||||
|
||||
The demo relies on [Open3D](http://www.open3d.org/). Following is the suggested way to install it:
|
||||
```bash
|
||||
conda install -c open3d-admin open3d
|
||||
```
|
||||
The demo will sample shapes from a pre-trained model, save those shapes under the `demo` folder, and visualize those point clouds.
|
||||
Once this dependency is in place, you can use the following script to use the demo for the pre-trained model for airplanes:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 ./scripts/shapenet_airplane_demo.py
|
||||
```
|
||||
|
||||
## Cite
|
||||
Please cite our work if you find it useful:
|
||||
```latex
|
||||
@article{pointflow,
|
||||
title={PointFlow: 3D Point Cloud Generation with Continuous Normalizing Flows},
|
||||
author={Yang, Guandao and Huang, Xun and Hao, Zekun and Liu, Ming-Yu and Belongie, Serge and Hariharan, Bharath},
|
||||
journal={arXiv},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
|
|
163
args.py
Normal file
163
args.py
Normal file
|
@ -0,0 +1,163 @@
|
|||
import argparse
|
||||
|
||||
# Allowed values for the --nonlinearity, --solver and --layer_type options.
NONLINEARITIES = ["tanh", "relu", "softplus", "elu", "swish", "square", "identity"]
SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams']
LAYERS = ["ignore", "concat", "concat_v2", "squash", "concatsquash", "scale", "concatscale"]


def add_args(parser):
    """Register every command line option of the experiment on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``) that receives the options.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    # NOTE(security): the boolean options below use ``type=eval`` so that the
    # strings "True"/"False" become real booleans. ``eval`` executes arbitrary
    # expressions, so these options must never be fed untrusted input.

    # model architecture options
    parser.add_argument('--input_dim', type=int, default=3,
                        help='Number of input dimensions (3 for 3D point clouds)')
    parser.add_argument('--dims', type=str, default='256')
    parser.add_argument('--latent_dims', type=str, default='256')
    parser.add_argument("--num_blocks", type=int, default=1,
                        help='Number of stacked CNFs.')
    parser.add_argument("--latent_num_blocks", type=int, default=1,
                        help='Number of stacked CNFs.')
    parser.add_argument("--layer_type", type=str, default="concatsquash", choices=LAYERS)
    parser.add_argument('--time_length', type=float, default=0.5)
    parser.add_argument('--train_T', type=eval, default=True, choices=[True, False])
    parser.add_argument("--nonlinearity", type=str, default="tanh", choices=NONLINEARITIES)
    parser.add_argument('--use_adjoint', type=eval, default=True, choices=[True, False])
    parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
    parser.add_argument('--atol', type=float, default=1e-5)
    parser.add_argument('--rtol', type=float, default=1e-5)
    parser.add_argument('--batch_norm', type=eval, default=True, choices=[True, False])
    parser.add_argument('--sync_bn', type=eval, default=False, choices=[True, False])
    parser.add_argument('--bn_lag', type=float, default=0)

    # training options
    parser.add_argument('--use_latent_flow', action='store_true',
                        help='Whether to use the latent flow to model the prior.')
    parser.add_argument('--use_deterministic_encoder', action='store_true',
                        help='Whether to use a deterministic encoder.')
    parser.add_argument('--zdim', type=int, default=128,
                        help='Dimension of the shape code')
    parser.add_argument('--optimizer', type=str, default='adam',
                        help='Optimizer to use', choices=['adam', 'adamax', 'sgd'])
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size (of datasets) for training')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate for the Adam optimizer.')
    parser.add_argument('--beta1', type=float, default=0.9,
                        help='Beta1 for Adam.')
    parser.add_argument('--beta2', type=float, default=0.999,
                        help='Beta2 for Adam.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='Momentum for SGD')
    parser.add_argument('--weight_decay', type=float, default=0.,
                        help='Weight decay for the optimizer.')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs for training (default: 100)')
    parser.add_argument('--seed', type=int, default=None,
                        help='Seed for initializing training. ')
    parser.add_argument('--recon_weight', type=float, default=1.,
                        help='Weight for the reconstruction loss.')
    parser.add_argument('--prior_weight', type=float, default=1.,
                        help='Weight for the prior loss.')
    parser.add_argument('--entropy_weight', type=float, default=1.,
                        help='Weight for the entropy loss.')
    parser.add_argument('--scheduler', type=str, default='linear',
                        help='Type of learning rate schedule')
    parser.add_argument('--exp_decay', type=float, default=1.,
                        help='Learning rate schedule exponential decay rate')
    parser.add_argument('--exp_decay_freq', type=int, default=1,
                        help='Learning rate exponential decay frequency')

    # data options
    parser.add_argument('--dataset_type', type=str, default="shapenet15k",
                        help="Dataset types.", choices=['shapenet15k', 'modelnet40_15k', 'modelnet10_15k'])
    parser.add_argument('--cates', type=str, nargs='+', default=["airplane"],
                        help="Categories to be trained (useful only if 'shapenet' is selected)")
    parser.add_argument('--data_dir', type=str, default="data/ShapeNetCore.v2.PC15k",
                        help="Path to the training data")
    parser.add_argument('--mn40_data_dir', type=str, default="data/ModelNet40.PC15k",
                        help="Path to ModelNet40")
    parser.add_argument('--mn10_data_dir', type=str, default="data/ModelNet10.PC15k",
                        help="Path to ModelNet10")
    parser.add_argument('--dataset_scale', type=float, default=1.,
                        help='Scale of the dataset (x,y,z * scale = real output, default=1).')
    parser.add_argument('--random_rotate', action='store_true',
                        help='Whether to randomly rotate each shape.')
    parser.add_argument('--normalize_per_shape', action='store_true',
                        help='Whether to perform normalization per shape.')
    parser.add_argument('--normalize_std_per_axis', action='store_true',
                        help='Whether to perform normalization per axis.')
    parser.add_argument("--tr_max_sample_points", type=int, default=2048,
                        help='Max number of sampled points (train)')
    parser.add_argument("--te_max_sample_points", type=int, default=2048,
                        help='Max number of sampled points (test)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of data loading threads')

    # logging and saving frequency
    parser.add_argument('--log_name', type=str, default=None, help="Name for the log dir")
    parser.add_argument('--viz_freq', type=int, default=10)
    parser.add_argument('--val_freq', type=int, default=10)
    parser.add_argument('--log_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=10)

    # validation options
    parser.add_argument('--no_validation', action='store_true',
                        help='Whether to disable validation altogether.')
    parser.add_argument('--save_val_results', action='store_true',
                        help='Whether to save the validation results.')
    parser.add_argument('--eval_classification', action='store_true',
                        help='Whether to evaluate classification accuracy on MN40 and MN10.')
    # The flag is a negation, so the help text must say "disable".
    parser.add_argument('--no_eval_sampling', action='store_true',
                        help='Whether to disable the evaluation of sampling.')
    parser.add_argument('--max_validate_shapes', type=int, default=None,
                        help='Max number of shapes used for validation pass.')

    # resuming
    parser.add_argument('--resume_checkpoint', type=str, default=None,
                        help='Path to the checkpoint to be loaded.')
    parser.add_argument('--resume_optimizer', action='store_true',
                        help='Whether to resume the optimizer when resumed training.')
    parser.add_argument('--resume_non_strict', action='store_true',
                        help='Whether to resume in non-strict mode.')
    parser.add_argument('--resume_dataset_mean', type=str, default=None,
                        help='Path to the file storing the dataset mean.')
    parser.add_argument('--resume_dataset_std', type=str, default=None,
                        help='Path to the file storing the dataset std.')

    # distributed training
    parser.add_argument('--world_size', default=1, type=int,
                        help='Number of distributed nodes.')
    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use. None means using all available GPUs.')

    # Evaluation options
    parser.add_argument('--evaluate_recon', default=False, action='store_true',
                        help='Whether set to the evaluation for reconstruction.')
    parser.add_argument('--num_sample_shapes', default=10, type=int,
                        help='Number of shapes to be sampled (for demo.py).')
    parser.add_argument('--num_sample_points', default=2048, type=int,
                        help='Number of points (per-shape) to be sampled (for demo.py).')

    return parser
|
||||
|
||||
|
||||
def get_parser():
    """Create the experiment's argument parser with all options registered."""
    return add_args(argparse.ArgumentParser(
        description='Flow-based Point Cloud Generation Experiment'))
|
||||
|
||||
|
||||
def get_args():
    """Parse the process command line and return the resulting namespace."""
    return get_parser().parse_args()
|
389
datasets.py
Normal file
389
datasets.py
Normal file
|
@ -0,0 +1,389 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils import data
|
||||
import random
|
||||
|
||||
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
|
||||
# Maps a ShapeNet synset id (directory name) to a human readable category.
# Insertion order matters: it defines the category order used for 'all'.
synsetid_to_cate = {
    '02691156': 'airplane',
    '02773838': 'bag',
    '02801938': 'basket',
    '02808440': 'bathtub',
    '02818832': 'bed',
    '02828884': 'bench',
    '02876657': 'bottle',
    '02880940': 'bowl',
    '02924116': 'bus',
    '02933112': 'cabinet',
    '02747177': 'can',
    '02942699': 'camera',
    '02954340': 'cap',
    '02958343': 'car',
    '03001627': 'chair',
    '03046257': 'clock',
    '03207941': 'dishwasher',
    '03211117': 'monitor',
    '04379243': 'table',
    '04401088': 'telephone',
    '02946921': 'tin_can',
    '04460130': 'tower',
    '04468005': 'train',
    '03085013': 'keyboard',
    '03261776': 'earphone',
    '03325088': 'faucet',
    '03337140': 'file',
    '03467517': 'guitar',
    '03513137': 'helmet',
    '03593526': 'jar',
    '03624134': 'knife',
    '03636649': 'lamp',
    '03642806': 'laptop',
    '03691459': 'speaker',
    '03710193': 'mailbox',
    '03759954': 'microphone',
    '03761084': 'microwave',
    '03790512': 'motorcycle',
    '03797390': 'mug',
    '03928116': 'piano',
    '03938244': 'pillow',
    '03948459': 'pistol',
    '03991062': 'pot',
    '04004475': 'printer',
    '04074963': 'remote_control',
    '04090263': 'rifle',
    '04099429': 'rocket',
    '04225987': 'skateboard',
    '04256520': 'sofa',
    '04330267': 'stove',
    '04530566': 'vessel',
    '04554684': 'washer',
    '02992529': 'cellphone',
    '02843684': 'birdhouse',
    '02871439': 'bookshelf',
    # '02858304': 'boat', no boat in our dataset, merged into vessels
    # '02834778': 'bicycle', not in our taxonomy
}
# Reverse lookup: category name -> synset id (same order as above).
cate_to_synsetid = dict(zip(synsetid_to_cate.values(), synsetid_to_cate.keys()))
|
||||
|
||||
|
||||
class Uniform15KPC(Dataset):
    """In-memory dataset of point clouds stored as ``.npy`` files.

    Files are read from ``<root_dir>/<subdir>/<split>/*.npy``; each file must
    hold an array whose first dimension is 15000. After loading, the shapes
    are shuffled with a fixed seed, normalized, and every shape is split into
    its first 10000 points (``train_points``) and the remaining points
    (``test_points``).

    Args:
        root_dir: dataset root directory.
        subdirs: sub-directory names under ``root_dir``; the position of a
            name in this sequence becomes the ``cate_idx`` of its shapes.
        tr_sample_size: points returned per shape from the train part
            (capped at 10000).
        te_sample_size: points returned per shape from the test part
            (capped at 5000).
        split: name of the split directory inside each sub-directory.
        scale: must be 1 (other values are deprecated).
        normalize_per_shape: compute mean/std separately for every shape.
        random_subsample: draw random point indices (with replacement) in
            ``__getitem__`` instead of taking the leading points.
        normalize_std_per_axis: compute one std per coordinate axis instead
            of a single scalar std.
        all_points_mean, all_points_std: precomputed dataset-level statistics
            to use instead of computing them; ignored when
            ``normalize_per_shape`` is set (see the comment in ``__init__``).
        input_dim: size of the last axis of every point cloud.

    Raises:
        ValueError: if no point cloud could be loaded.
    """

    def __init__(self, root_dir, subdirs, tr_sample_size=10000,
                 te_sample_size=10000, split='train', scale=1.,
                 normalize_per_shape=False, random_subsample=False,
                 normalize_std_per_axis=False,
                 all_points_mean=None, all_points_std=None,
                 input_dim=3):
        self.root_dir = root_dir
        self.split = split
        self.in_tr_sample_size = tr_sample_size
        self.in_te_sample_size = te_sample_size
        self.subdirs = subdirs
        self.scale = scale
        self.random_subsample = random_subsample
        self.input_dim = input_dim

        self.all_cate_mids = []
        self.cate_idx_lst = []
        self.all_points = []
        for cate_idx, subd in enumerate(self.subdirs):
            # NOTE: [subd] here is synset id
            sub_path = os.path.join(root_dir, subd, self.split)
            if not os.path.isdir(sub_path):
                print("Directory missing : %s" % sub_path)
                continue

            # NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
            # NOTE(review): os.listdir order is platform dependent, so the
            # "deterministic" shuffle below is only reproducible on one machine.
            all_mids = [
                os.path.join(self.split, x[:-len('.npy')])
                for x in os.listdir(sub_path)
                if x.endswith('.npy')
            ]

            for mid in all_mids:
                obj_fname = os.path.join(root_dir, subd, mid + ".npy")
                try:
                    point_cloud = np.load(obj_fname)  # first dim must be 15000
                except Exception as err:
                    # Best effort: skip unreadable files, but say so, and do
                    # not swallow KeyboardInterrupt/SystemExit.
                    print("Failed to load %s (%s); skipping" % (obj_fname, err))
                    continue

                assert point_cloud.shape[0] == 15000
                self.all_points.append(point_cloud[np.newaxis, ...])
                self.cate_idx_lst.append(cate_idx)
                self.all_cate_mids.append((subd, mid))

        if not self.all_points:
            # np.concatenate would raise an opaque ValueError on an empty list.
            raise ValueError(
                "No point clouds found under %s for split '%s' (subdirs: %s)"
                % (root_dir, self.split, list(self.subdirs)))

        # Shuffle the index deterministically (based on the number of examples)
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
        self.all_points = [self.all_points[i] for i in self.shuffle_idx]
        self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]

        # Normalization
        self.all_points = np.concatenate(self.all_points)  # (N, 15000, input_dim)
        self.normalize_per_shape = normalize_per_shape
        self.normalize_std_per_axis = normalize_std_per_axis
        if self.normalize_per_shape:  # per shape normalization
            # Per-shape statistics belong to the shapes of THIS split. Stats
            # handed over from another split have one row per shape of that
            # split, so they would either fail to broadcast or be silently
            # mis-applied; always recompute them here.
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1)
        elif all_points_mean is not None and all_points_std is not None:  # using loaded dataset stats
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        else:  # normalize across the dataset
            self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

        self.tr_sample_size = min(10000, tr_sample_size)
        self.te_sample_size = min(5000, te_sample_size)
        print("Total number of data:%d" % len(self.train_points))
        print("Min number of points: (train)%d (test)%d"
              % (self.tr_sample_size, self.te_sample_size))
        assert self.scale == 1, "Scale (!= 1) is deprecated"

    def get_pc_stats(self, idx):
        """Return the ``(mean, std)`` arrays used to normalize shape ``idx``."""
        if self.normalize_per_shape:
            m = self.all_points_mean[idx].reshape(1, self.input_dim)
            s = self.all_points_std[idx].reshape(1, -1)
            return m, s

        return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)

    def renormalize(self, mean, std):
        """Undo the current normalization and re-apply it with ``mean``/``std``."""
        self.all_points = self.all_points * self.all_points_std + self.all_points_mean
        self.all_points_mean = mean
        self.all_points_std = std
        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

    def __len__(self):
        """Return the number of loaded shapes."""
        return len(self.train_points)

    def __getitem__(self, idx):
        """Return a dict with the sampled points and metadata of shape ``idx``.

        Keys: ``idx``, ``train_points``, ``test_points`` (float tensors),
        ``mean``, ``std``, ``cate_idx``, ``sid`` and ``mid``.
        """
        tr_out = self.train_points[idx]
        if self.random_subsample:
            tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
        else:
            tr_idxs = np.arange(self.tr_sample_size)
        tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()

        te_out = self.test_points[idx]
        if self.random_subsample:
            te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
        else:
            te_idxs = np.arange(self.te_sample_size)
        te_out = torch.from_numpy(te_out[te_idxs, :]).float()

        m, s = self.get_pc_stats(idx)
        cate_idx = self.cate_idx_lst[idx]
        sid, mid = self.all_cate_mids[idx]

        return {
            'idx': idx,
            'train_points': tr_out,
            'test_points': te_out,
            'mean': m, 'std': s, 'cate_idx': cate_idx,
            'sid': sid, 'mid': mid
        }
|
||||
|
||||
|
||||
class ModelNet40PointClouds(Uniform15KPC):
    """ModelNet40 variant of ``Uniform15KPC``.

    Every sub-directory of ``root_dir`` that contains both a ``train`` and a
    ``test`` directory is treated as a category; exactly 40 are required.
    """

    def __init__(self, root_dir="data/ModelNet40.PC15k",
                 tr_sample_size=10000, te_sample_size=2048,
                 split='train', scale=1., normalize_per_shape=False,
                 normalize_std_per_axis=False,
                 random_subsample=False,
                 all_points_mean=None, all_points_std=None):
        self.root_dir = root_dir
        self.split = split
        assert self.split in ['train', 'test']
        self.sample_size = tr_sample_size

        def _is_category(name):
            base = os.path.join(root_dir, name)
            return (os.path.isdir(base)
                    and os.path.isdir(os.path.join(root_dir, name, 'train'))
                    and os.path.isdir(os.path.join(root_dir, name, 'test')))

        self.cates = [name for name in os.listdir(root_dir) if _is_category(name)]
        assert len(self.cates) == 40, "%s %s" % (len(self.cates), self.cates)

        # For non-aligned MN
        # self.gravity_axis = 0
        # self.display_axis_order = [0,1,2]

        # Aligned MN has same axis-order as SN
        self.gravity_axis = 1
        self.display_axis_order = [0, 2, 1]

        base_kwargs = dict(
            tr_sample_size=tr_sample_size,
            te_sample_size=te_sample_size,
            split=split,
            scale=scale,
            normalize_per_shape=normalize_per_shape,
            normalize_std_per_axis=normalize_std_per_axis,
            random_subsample=random_subsample,
            all_points_mean=all_points_mean,
            all_points_std=all_points_std,
            input_dim=3,
        )
        super(ModelNet40PointClouds, self).__init__(root_dir, self.cates, **base_kwargs)
|
||||
|
||||
|
||||
class ModelNet10PointClouds(Uniform15KPC):
    """ModelNet10 variant of ``Uniform15KPC``.

    Every sub-directory of ``root_dir`` that contains both a ``train`` and a
    ``test`` directory is treated as a category; exactly 10 are required.
    """

    def __init__(self, root_dir="data/ModelNet10.PC15k",
                 tr_sample_size=10000, te_sample_size=2048,
                 split='train', scale=1., normalize_per_shape=False,
                 normalize_std_per_axis=False,
                 random_subsample=False,
                 all_points_mean=None, all_points_std=None):
        self.root_dir = root_dir
        self.split = split
        assert self.split in ['train', 'test']

        def _is_category(name):
            base = os.path.join(root_dir, name)
            return (os.path.isdir(base)
                    and os.path.isdir(os.path.join(root_dir, name, 'train'))
                    and os.path.isdir(os.path.join(root_dir, name, 'test')))

        self.cates = [name for name in os.listdir(root_dir) if _is_category(name)]
        assert len(self.cates) == 10

        # That's prealigned MN
        # self.gravity_axis = 0
        # self.display_axis_order = [0,1,2]

        # Aligned MN has same axis-order as SN
        self.gravity_axis = 1
        self.display_axis_order = [0, 2, 1]

        base_kwargs = dict(
            tr_sample_size=tr_sample_size,
            te_sample_size=te_sample_size,
            split=split,
            scale=scale,
            normalize_per_shape=normalize_per_shape,
            normalize_std_per_axis=normalize_std_per_axis,
            random_subsample=random_subsample,
            all_points_mean=all_points_mean,
            all_points_std=all_points_std,
            input_dim=3,
        )
        super(ModelNet10PointClouds, self).__init__(root_dir, self.cates, **base_kwargs)
|
||||
|
||||
|
||||
class ShapeNet15kPointClouds(Uniform15KPC):
    """ShapeNet variant of ``Uniform15KPC`` selected by category name.

    Args:
        categories: category names to load (keys of ``cate_to_synsetid``).
            If the collection contains ``'all'``, every known category is
            loaded. ``None`` (the default) means ``['airplane']``.
        split: one of ``'train'``, ``'test'`` or ``'val'``.

    The remaining arguments are forwarded unchanged to ``Uniform15KPC``.

    Raises:
        KeyError: if a requested category is not in ``cate_to_synsetid``.
    """

    def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
                 categories=None, tr_sample_size=10000, te_sample_size=2048,
                 split='train', scale=1., normalize_per_shape=False,
                 normalize_std_per_axis=False,
                 random_subsample=False,
                 all_points_mean=None, all_points_std=None):
        # Avoid a mutable default argument shared between all instances.
        if categories is None:
            categories = ['airplane']

        self.root_dir = root_dir
        self.split = split
        assert self.split in ['train', 'test', 'val']
        self.tr_sample_size = tr_sample_size
        self.te_sample_size = te_sample_size
        self.cates = categories
        if 'all' in categories:
            self.synset_ids = list(cate_to_synsetid.values())
        else:
            unknown = [c for c in self.cates if c not in cate_to_synsetid]
            if unknown:
                # Same exception type as the plain dict lookup, but actionable.
                raise KeyError(
                    "Unknown ShapeNet categories %s; valid names are %s or 'all'"
                    % (unknown, sorted(cate_to_synsetid)))
            self.synset_ids = [cate_to_synsetid[c] for c in self.cates]

        # assert 'v2' in root_dir, "Only supporting v2 right now."
        self.gravity_axis = 1
        self.display_axis_order = [0, 2, 1]

        super(ShapeNet15kPointClouds, self).__init__(
            root_dir, self.synset_ids,
            tr_sample_size=tr_sample_size,
            te_sample_size=te_sample_size,
            split=split, scale=scale,
            normalize_per_shape=normalize_per_shape,
            normalize_std_per_axis=normalize_std_per_axis,
            random_subsample=random_subsample,
            all_points_mean=all_points_mean, all_points_std=all_points_std,
            input_dim=3)
|
||||
|
||||
|
||||
def init_np_seed(worker_id):
    """Seed NumPy's global RNG from torch's current initial seed.

    ``worker_id`` is accepted for the ``worker_init_fn`` signature but unused.
    The torch seed is reduced modulo 2**32 before being handed to NumPy.
    """
    np.random.seed(torch.initial_seed() % (2 ** 32))
|
||||
|
||||
|
||||
def _get_MN40_datasets_(args, data_dir=None):
    """Build the ModelNet40 ``(train, test)`` dataset pair.

    ``data_dir`` overrides ``args.data_dir`` when given. The test dataset
    receives the mean/std of the train dataset; only the train dataset uses
    random sub-sampling.
    """
    shared = dict(
        tr_sample_size=args.tr_max_sample_points,
        te_sample_size=args.te_max_sample_points,
        root_dir=(data_dir if data_dir is not None else args.data_dir),
        normalize_per_shape=args.normalize_per_shape,
        normalize_std_per_axis=args.normalize_std_per_axis,
    )
    tr_dataset = ModelNet40PointClouds(split='train', random_subsample=True, **shared)
    te_dataset = ModelNet40PointClouds(
        split='test',
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        **shared)

    return tr_dataset, te_dataset
|
||||
|
||||
|
||||
def _get_MN10_datasets_(args, data_dir=None):
    """Build the ModelNet10 ``(train, test)`` dataset pair.

    ``data_dir`` overrides ``args.data_dir`` when given. The test dataset
    receives the mean/std of the train dataset; only the train dataset uses
    random sub-sampling.
    """
    shared = dict(
        tr_sample_size=args.tr_max_sample_points,
        te_sample_size=args.te_max_sample_points,
        root_dir=(data_dir if data_dir is not None else args.data_dir),
        normalize_per_shape=args.normalize_per_shape,
        normalize_std_per_axis=args.normalize_std_per_axis,
    )
    tr_dataset = ModelNet10PointClouds(split='train', random_subsample=True, **shared)
    te_dataset = ModelNet10PointClouds(
        split='test',
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        **shared)
    return tr_dataset, te_dataset
|
||||
|
||||
|
||||
def get_datasets(args):
    """Return the ``(train, test)`` datasets selected by ``args.dataset_type``.

    For ``'shapenet15k'`` the second dataset is built from the ``'val'``
    split and reuses the mean/std of the train dataset.

    Raises:
        Exception: if ``args.dataset_type`` is not a supported value.
    """
    kind = args.dataset_type
    if kind == 'modelnet40_15k':
        return _get_MN40_datasets_(args)
    if kind == 'modelnet10_15k':
        return _get_MN10_datasets_(args)
    if kind != 'shapenet15k':
        raise Exception("Invalid dataset type:%s" % args.dataset_type)

    shared = dict(
        categories=args.cates,
        tr_sample_size=args.tr_max_sample_points,
        te_sample_size=args.te_max_sample_points,
        scale=args.dataset_scale,
        root_dir=args.data_dir,
        normalize_per_shape=args.normalize_per_shape,
        normalize_std_per_axis=args.normalize_std_per_axis,
    )
    tr_dataset = ShapeNet15kPointClouds(split='train', random_subsample=True, **shared)
    te_dataset = ShapeNet15kPointClouds(
        split='val',
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        **shared)

    return tr_dataset, te_dataset
|
||||
|
||||
|
||||
def get_clf_datasets(args):
    """Return the ModelNet40/ModelNet10 dataset pairs keyed by 'MN40'/'MN10'."""
    mn40_pair = _get_MN40_datasets_(args, data_dir=args.mn40_data_dir)
    mn10_pair = _get_MN10_datasets_(args, data_dir=args.mn10_data_dir)
    return {'MN40': mn40_pair, 'MN10': mn10_pair}
|
||||
|
||||
|
||||
def get_data_loaders(args):
    """Create the three ``DataLoader`` objects built from ``get_datasets(args)``.

    Returns a dict with the keys ``"test_loader"``, ``'train_loader'`` and
    ``'train_unshuffle_loader'``.  Both training loaders drop the last
    incomplete batch; only ``train_loader`` shuffles.  The test loader keeps
    every sample and does not shuffle.
    """
    tr_dataset, te_dataset = get_datasets(args)

    # Settings common to all loaders.
    common = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        worker_init_fn=init_np_seed,
    )

    train_loader = data.DataLoader(
        dataset=tr_dataset, shuffle=True, drop_last=True, **common)
    train_unshuffle_loader = data.DataLoader(
        dataset=tr_dataset, shuffle=False, drop_last=True, **common)
    test_loader = data.DataLoader(
        dataset=te_dataset, shuffle=False, drop_last=False, **common)

    return {
        "test_loader": test_loader,
        'train_loader': train_loader,
        'train_unshuffle_loader': train_unshuffle_loader,
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: fetch the first item of the airplane validation split and
    # print the shape of each of its two parts.
    val_airplanes = ShapeNet15kPointClouds(split='val', categories=['airplane'])
    first_item = next(iter(val_airplanes))
    x_tr, x_te = first_item
    for part in (x_tr, x_te):
        print(part.shape)
|
||||
|
60
demo.py
Normal file
60
demo.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
import open3d as o3d
|
||||
from datasets import get_datasets
|
||||
from args import get_args
|
||||
from models.networks import PointFlow
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def main(args):
    """Sample point clouds from a trained PointFlow checkpoint and display them.

    Steps:
      1. Build the model, move it to the GPU, wrap it with ``nn.DataParallel``
         and load ``args.resume_checkpoint``.
      2. Obtain the dataset mean / std used to de-normalise the samples
         (optionally overridden by ``args.resume_dataset_mean`` /
         ``args.resume_dataset_std``).
      3. Draw ``args.num_sample_shapes`` clouds of ``args.num_sample_points``
         points in batches of ``args.batch_size``.
      4. Save them to ``demo/model_out_smp.npy`` and show each one with Open3D.
    """
    model = PointFlow(args).cuda()
    model.multi_gpu_wrapper(lambda module: nn.DataParallel(module))

    print("Resume Path:%s" % args.resume_checkpoint)
    state = torch.load(args.resume_checkpoint)
    model.load_state_dict(state)
    model.eval()

    # Only the evaluation split is needed, for its normalisation statistics.
    te_dataset = get_datasets(args)[1]
    override_stats = (args.resume_dataset_mean is not None
                      and args.resume_dataset_std is not None)
    if override_stats:
        mean = np.load(args.resume_dataset_mean)
        std = np.load(args.resume_dataset_std)
        te_dataset.renormalize(mean, std)
    ds_mean = torch.from_numpy(te_dataset.all_points_mean).cuda()
    ds_std = torch.from_numpy(te_dataset.all_points_std).cuda()

    total = args.num_sample_shapes
    points_per_shape = args.num_sample_points
    batches = []
    with torch.no_grad():
        for first in range(0, total, args.batch_size):
            # The final batch may be smaller than args.batch_size.
            last = min(first + args.batch_size, total)
            _, generated = model.sample(last - first, points_per_shape)
            batches.append(generated * ds_std + ds_mean)

    sample_pcs = torch.cat(batches, dim=0).cpu().detach().numpy()
    print("Generation sample size:(%s, %s, %s)" % sample_pcs.shape)

    # Save the generative output
    os.makedirs("demo", exist_ok=True)
    np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs)

    # Visualize the demo, one cloud per viewer window
    pcl = o3d.geometry.PointCloud()
    n_clouds = sample_pcs.shape[0]
    for i, cloud in enumerate(sample_pcs):
        print("Visualizing: %03d/%03d" % (i, n_clouds))
        pcl.points = o3d.utility.Vector3dVector(cloud.reshape(-1, 3))
        o3d.visualization.draw_geometries([pcl])
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Parse the command-line arguments and run the demo.
    main(get_args())
|
19
install.sh
Executable file
19
install.sh
Executable file
|
@ -0,0 +1,19 @@
|
|||
#! /bin/bash
# Installs the Python dependencies, compiles the CUDA kernels used for the
# CD/EMD losses, and installs torchdiffeq in editable mode.
# Run from the repository root.

# Quoted so that a checkout path containing spaces still works.
root="$(pwd)"

# Install dependencies
conda install numpy matplotlib pillow scipy tqdm scikit-learn -y
pip install tensorflow-gpu==1.13.1
pip install tensorboardX==1.7

# Compile CUDA kernel for CD/EMD loss
# Abort if the directory is missing instead of running make somewhere else.
cd metrics/pytorch_structural_losses/ || exit 1
make clean
make
cd "$root" || exit 1

# install torchdiffeq
git clone https://github.com/rtqichen/torchdiffeq.git
# Abort if the clone is missing: otherwise "pip install -e ." would run
# against the repository root.
cd torchdiffeq || exit 1
pip install -e .
|
1
metrics/.gitignore
vendored
Normal file
1
metrics/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
StructuralLosses
|
0
metrics/__init__.py
Normal file
0
metrics/__init__.py
Normal file
336
metrics/evaluation_metrics.py
Normal file
336
metrics/evaluation_metrics.py
Normal file
|
@ -0,0 +1,336 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
from scipy.stats import entropy
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
from numpy.linalg import norm
|
||||
|
||||
# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
|
||||
from .StructuralLosses.match_cost import match_cost
|
||||
from .StructuralLosses.nn_distance import nn_distance
|
||||
|
||||
|
||||
# # Import CUDA version of CD, borrowed from https://github.com/ThibaultGROUEIX/AtlasNet
|
||||
# try:
|
||||
# from . chamfer_distance_ext.dist_chamfer import chamferDist
|
||||
# CD = chamferDist()
|
||||
# def distChamferCUDA(x,y):
|
||||
# return CD(x,y,gpu)
|
||||
# except:
|
||||
|
||||
|
||||
def distChamferCUDA(x, y):
    """Forward ``x`` and ``y`` to ``nn_distance`` and return its result unchanged.

    ``nn_distance`` is imported from ``.StructuralLosses.nn_distance``; this
    wrapper only gives it the same call shape as ``distChamfer``.
    """
    nearest = nn_distance(x, y)
    return nearest
|
||||
|
||||
|
||||
def emd_approx(sample, ref):
    """Return ``match_cost(sample, ref)`` divided by the points per cloud.

    Both inputs must have the same size along dimension 1 (number of points);
    otherwise an ``AssertionError`` is raised.  Per the original comments the
    result has one entry per batch element, i.e. shape ``(B,)``.
    """
    n_points = sample.size(1)
    n_ref_points = ref.size(1)
    assert n_points == n_ref_points, "Not sure what would EMD do in this case"
    total_cost = match_cost(sample, ref)  # (B,)
    return total_cost / float(n_points)  # (B,)
|
||||
|
||||
|
||||
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
|
||||
def distChamfer(a, b):
    """Nearest-neighbour squared distances between two batches of point sets.

    Args:
        a: tensor of shape ``(B, N, D)``.
        b: tensor of shape ``(B, M, D)``.  ``M`` may differ from ``N``.

    Returns:
        A pair ``(d_b_to_a, d_a_to_b)``:
        ``d_b_to_a`` has shape ``(B, M)`` and holds, for every point of ``b``,
        the squared distance to its closest point in ``a``;
        ``d_a_to_b`` has shape ``(B, N)`` and holds the same for every point
        of ``a`` against ``b``.

    The squared norms are computed directly instead of reading the diagonal
    of the full ``a @ a^T`` / ``b @ b^T`` products.  That avoids two N x N
    temporaries and removes the former requirement that both sets have the
    same number of points.
    """
    sq_norm_a = (a * a).sum(dim=2)  # (B, N)
    sq_norm_b = (b * b).sum(dim=2)  # (B, M)
    cross = torch.bmm(a, b.transpose(2, 1))  # (B, N, M)
    # |a_i - b_j|^2 = |a_i|^2 + |b_j|^2 - 2 <a_i, b_j>
    P = sq_norm_a.unsqueeze(2) + sq_norm_b.unsqueeze(1) - 2 * cross
    return P.min(1)[0], P.min(2)[0]
|
||||
|
||||
|
||||
def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
    """Chamfer distance and approximate EMD between index-aligned cloud sets.

    ``sample_pcs[i]`` is compared with ``ref_pcs[i]``; both sets must contain
    the same number of clouds.  The work is done in chunks of ``batch_size``.

    Args:
        accelerated_cd: use ``distChamferCUDA`` instead of ``distChamfer``.
        reduced: return the mean over all pairs instead of per-pair values.

    Returns:
        dict with the keys ``'MMD-CD'`` and ``'MMD-EMD'``.
    """
    n_sample = sample_pcs.shape[0]
    n_ref = ref_pcs.shape[0]
    assert n_sample == n_ref, "REF:%d SMP:%d" % (n_ref, n_sample)

    chamfer_fn = distChamferCUDA if accelerated_cd else distChamfer

    cd_chunks, emd_chunks = [], []
    for lo in range(0, n_sample, batch_size):
        hi = min(n_sample, lo + batch_size)
        smp, ref = sample_pcs[lo:hi], ref_pcs[lo:hi]

        dl, dr = chamfer_fn(smp, ref)
        cd_chunks.append(dl.mean(dim=1) + dr.mean(dim=1))
        emd_chunks.append(emd_approx(smp, ref))

    cd = torch.cat(cd_chunks)
    emd = torch.cat(emd_chunks)
    if reduced:
        cd, emd = cd.mean(), emd.mean()

    return {'MMD-CD': cd, 'MMD-EMD': emd}
|
||||
|
||||
|
||||
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
    """Compute CD and approximate EMD between every sample / reference pair.

    Each sample cloud is repeated to match a chunk of up to ``batch_size``
    reference clouds, so the batched distance routines can be reused.

    Returns:
        ``(all_cd, all_emd)``, each of shape ``(N_sample, N_ref)``.
    """
    n_sample, n_ref = sample_pcs.shape[0], ref_pcs.shape[0]
    chamfer_fn = distChamferCUDA if accelerated_cd else distChamfer

    cd_rows, emd_rows = [], []
    for s_idx in range(n_sample):
        one_sample = sample_pcs[s_idx]

        cd_parts, emd_parts = [], []
        for lo in range(0, n_ref, batch_size):
            refs = ref_pcs[lo:min(n_ref, lo + batch_size)]

            # Repeat the single sample once per reference cloud of the chunk.
            tiled = one_sample.view(1, -1, 3).expand(refs.size(0), -1, -1)
            tiled = tiled.contiguous()

            dl, dr = chamfer_fn(tiled, refs)
            cd_parts.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
            emd_parts.append(emd_approx(tiled, refs).view(1, -1))

        cd_rows.append(torch.cat(cd_parts, dim=1))
        emd_rows.append(torch.cat(emd_parts, dim=1))

    all_cd = torch.cat(cd_rows, dim=0)  # N_sample, N_ref
    all_emd = torch.cat(emd_rows, dim=0)  # N_sample, N_ref
    return all_cd, all_emd
|
||||
|
||||
|
||||
# Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py
|
||||
def knn(Mxx, Mxy, Myy, k, sqrt=False):
    """Leave-one-out k-NN classification statistics from distance matrices.

    The first set (``Mxx``) is labelled 1 and the second (``Myy``) 0.  Each
    element is classified by a majority vote of its ``k`` nearest neighbours
    in the union of both sets, itself excluded; ties count as label 1.

    Args:
        Mxx: ``(n0, n0)`` distances within the first set.
        Mxy: ``(n0, n1)`` distances between the two sets.
        Myy: ``(n1, n1)`` distances within the second set.
        k: number of neighbours that vote.
        sqrt: apply ``abs().sqrt()`` to the distances first.

    Returns:
        dict of 0-d tensors: ``tp``, ``fp``, ``fn``, ``tn``, ``precision``,
        ``recall``, ``acc_t``, ``acc_f`` and ``acc``.
    """
    n0, n1 = Mxx.size(0), Myy.size(0)
    total = n0 + n1
    label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx)

    # Assemble the symmetric (n0 + n1) x (n0 + n1) distance matrix.
    upper = torch.cat((Mxx, Mxy), 1)
    lower = torch.cat((Mxy.transpose(0, 1), Myy), 1)
    M = torch.cat((upper, lower), 0)
    if sqrt:
        M = M.abs().sqrt()

    # An infinite diagonal keeps an element from being its own neighbour.
    self_mask = torch.diag(float('inf') * torch.ones(total).to(Mxx))
    _, idx = (M + self_mask).topk(k, 0, False)

    # Number of label-1 neighbours among the k nearest of every element.
    votes = torch.zeros(total).to(Mxx) + label[idx].sum(dim=0)
    pred = (votes >= float(k) / 2).float()

    tp = (pred * label).sum()
    fp = (pred * (1 - label)).sum()
    fn = ((1 - pred) * label).sum()
    tn = ((1 - pred) * (1 - label)).sum()

    return {
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'tn': tn,
        'precision': tp / (tp + fp + 1e-10),
        'recall': tp / (tp + fn + 1e-10),
        'acc_t': tp / (tp + fn + 1e-10),
        'acc_f': tn / (tn + fp + 1e-10),
        'acc': torch.eq(label, pred).float().mean(),
    }
|
||||
|
||||
|
||||
def lgan_mmd_cov(all_dist):
    """Minimum-matching-distance and coverage from a distance matrix.

    Args:
        all_dist: ``(N_sample, N_ref)`` matrix of sample-to-reference distances.

    Returns:
        dict with
        ``'lgan_mmd'``: mean over references of the distance to the closest sample,
        ``'lgan_cov'``: fraction of references that are the nearest reference
        of at least one sample,
        ``'lgan_mmd_smp'``: mean over samples of the distance to the closest
        reference.
    """
    n_ref = all_dist.size(1)

    closest_ref_dist, closest_ref_idx = all_dist.min(dim=1)  # per sample
    closest_smp_dist = all_dist.min(dim=0)[0]  # per reference

    n_covered = closest_ref_idx.unique().numel()
    coverage = float(n_covered) / float(n_ref)

    return {
        'lgan_mmd': closest_smp_dist.mean(),
        'lgan_cov': torch.tensor(coverage).to(all_dist),
        'lgan_mmd_smp': closest_ref_dist.mean(),
    }
|
||||
|
||||
|
||||
def compute_all_metrics(sample_pcs, ref_pcs, batch_size, accelerated_cd=False):
    """Compute MMD / coverage and 1-NN accuracy for both CD and EMD.

    Returns a dict whose keys are, in order: the ``lgan_mmd_cov`` keys with
    a ``-CD`` suffix, the same keys with ``-EMD``, then ``1-NN-CD-<key>``
    and ``1-NN-EMD-<key>`` for every ``knn`` key containing ``'acc'``.
    """
    def pairwise(first, second):
        return _pairwise_EMD_CD_(
            first, second, batch_size, accelerated_cd=accelerated_cd)

    results = {}

    # Rows index reference clouds and columns index samples, so the
    # matrices are transposed before being handed to lgan_mmd_cov.
    M_rs_cd, M_rs_emd = pairwise(ref_pcs, sample_pcs)
    for suffix, dist in (("CD", M_rs_cd), ("EMD", M_rs_emd)):
        for key, value in lgan_mmd_cov(dist.t()).items():
            results["%s-%s" % (key, suffix)] = value

    M_rr_cd, M_rr_emd = pairwise(ref_pcs, ref_pcs)
    M_ss_cd, M_ss_emd = pairwise(sample_pcs, sample_pcs)

    # 1-NN results
    one_nn_inputs = (
        ("CD", M_rr_cd, M_rs_cd, M_ss_cd),
        ("EMD", M_rr_emd, M_rs_emd, M_ss_emd),
    )
    for tag, rr, rs, ss in one_nn_inputs:
        for key, value in knn(rr, rs, ss, 1, sqrt=False).items():
            if 'acc' in key:
                results["1-NN-%s-%s" % (tag, key)] = value

    return results
|
||||
|
||||
|
||||
#######################################################
|
||||
# JSD : from https://github.com/optas/latent_3d_points
|
||||
#######################################################
|
||||
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    """Return the cell-centre coordinates of a regular 3D grid in the unit cube.

    The grid has ``resolution`` points per axis, evenly spaced over
    ``[-0.5, 0.5]``.

    Args:
        resolution: number of grid points per axis.  Must be at least 2:
            the spacing is ``1 / (resolution - 1)``, so 1 raises
            ``ZeroDivisionError``.
        clip_sphere: if True, flatten the grid and drop the "corner" cells
            whose centre lies outside the sphere of radius 0.5.

    Returns:
        ``(grid, spacing)``.  ``grid`` is a float32 array of shape
        ``(resolution, resolution, resolution, 3)``, or ``(K, 3)`` when
        ``clip_sphere`` is True.  ``spacing`` is the distance between
        neighbouring grid points.
    """
    spacing = 1.0 / float(resolution - 1)

    # One vectorised pass instead of resolution**3 Python-level assignments.
    axis = np.arange(resolution) * spacing - 0.5
    mesh = np.meshgrid(axis, axis, axis, indexing='ij')
    grid = np.stack(mesh, axis=-1).astype(np.float32)

    if clip_sphere:
        grid = grid.reshape(-1, 3)
        grid = grid[norm(grid, axis=1) <= 0.5]

    return grid, spacing
|
||||
|
||||
|
||||
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
    """Computes the JSD between two sets of point-clouds, as introduced in the paper
    ```Learning Representations And Generative Models For 3D Point Clouds```.

    Args:
        sample_pcs: (np.ndarray S1xR1x3) S1 point-clouds, each of R1 points.
        ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
        resolution: (int) grid-resolution. Affects granularity of measurements.
    """
    # Element [1] of the helper's result is the per-cell point counter array;
    # the grid is always clipped to the unit sphere here.
    cell_counts = [
        entropy_of_occupancy_grid(clouds, resolution, True)[1]
        for clouds in (sample_pcs, ref_pcs)
    ]
    return jensen_shannon_divergence(cell_counts[0], cell_counts[1])
|
||||
|
||||
|
||||
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False):
    """Given a collection of point-clouds, estimate the entropy of the random variables
    corresponding to occupancy-grid activation patterns.

    Inputs:
        pclouds: (numpy array) #point-clouds x points per point-cloud x 3
        grid_resolution (int) size of occupancy grid that will be used.
        in_sphere (bool) use only the grid cells inside the sphere of radius 0.5.
        verbose (bool) warn when the clouds fall outside the unit cube / sphere.

    Returns:
        (mean entropy per grid cell, array with the number of points that
        fell into each grid cell).
    """
    # The range checks only ever produce warnings, so the full passes over
    # the data are skipped unless warnings were requested.
    if verbose:
        epsilon = 10e-4
        bound = 0.5 + epsilon
        if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
            warnings.warn('Point-clouds are not in unit cube.')
        if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
            warnings.warn('Point-clouds are not in unit sphere.')

    grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
    grid_coordinates = grid_coordinates.reshape(-1, 3)
    grid_counters = np.zeros(len(grid_coordinates))
    grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
    nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)

    for pc in pclouds:
        _, indices = nn.kneighbors(pc)
        # reshape(-1) keeps a 1-D array even for a single-point cloud, where
        # np.squeeze would produce a 0-d array that cannot be iterated.
        indices = indices.reshape(-1)
        # Every point counts; np.add.at accumulates repeated indices.
        np.add.at(grid_counters, indices, 1)
        # Each cloud activates a given cell at most once.
        grid_bernoulli_rvars[np.unique(indices)] += 1

    acc_entropy = 0.0
    n = float(len(pclouds))
    for g in grid_bernoulli_rvars:
        if g > 0:
            p = float(g) / n
            acc_entropy += entropy([p, 1.0 - p])

    return acc_entropy / len(grid_counters), grid_counters
|
||||
|
||||
|
||||
def jensen_shannon_divergence(P, Q):
    """Jensen-Shannon divergence (base 2) between two non-negative arrays.

    Both inputs are normalised to sum to one first.  The value is computed
    from entropies and cross-checked against ``_jsdiv``; a warning is issued
    when the two disagree by more than ``10e-5``.

    Raises:
        ValueError: if either input has a negative entry, or the lengths
            differ.
    """
    has_negative = np.any(P < 0) or np.any(Q < 0)
    if has_negative:
        raise ValueError('Negative values.')
    if len(P) != len(Q):
        raise ValueError('Non equal size.')

    # Ensure probabilities.
    p_norm = P / np.sum(P)
    q_norm = Q / np.sum(Q)
    mixture = (p_norm + q_norm) / 2.0

    h_p = entropy(p_norm, base=2)
    h_q = entropy(q_norm, base=2)
    h_mix = entropy(mixture, base=2)
    res = h_mix - ((h_p + h_q) / 2.0)

    cross_check = _jsdiv(p_norm, q_norm)
    agrees = np.allclose(res, cross_check, atol=10e-5, rtol=0)
    if not agrees:
        warnings.warn('Numerical values of two JSD methods don\'t agree.')

    return res
|
||||
|
||||
|
||||
def _jsdiv(P, Q):
    """Jensen-Shannon divergence (base 2) as the mean of two KL divergences.

    Each input is normalised to sum to one, then compared against the
    mixture ``M`` of the two: ``0.5 * (KL(P || M) + KL(Q || M))``.
    """

    def _kldiv(A, B):
        # Only positions where both arrays are positive contribute.
        both_positive = np.logical_and(A > 0, B > 0)
        a = A[both_positive]
        b = B[both_positive]
        return np.sum(a * np.log2(a / b))

    p_norm = P / np.sum(P)
    q_norm = Q / np.sum(Q)
    mixture = 0.5 * (p_norm + q_norm)

    kl_p = _kldiv(p_norm, mixture)
    kl_q = _kldiv(q_norm, mixture)
    return 0.5 * (kl_p + kl_q)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test for the nn_distance-backed Chamfer distance (needs a GPU).
    B, N = 2, 10
    x = torch.rand(B, N, 3)
    y = torch.rand(B, N, 3)

    # distChamferCUDA is a plain two-argument function, not a factory:
    # calling it with no arguments raised TypeError.
    min_l, min_r = distChamferCUDA(x.cuda(), y.cuda())
    print(min_l.shape)
    print(min_r.shape)

    l_dist = min_l.mean().cpu().detach().item()
    r_dist = min_r.mean().cpu().detach().item()
    print(l_dist, r_dist)
|
1
metrics/pytorch_structural_losses/.gitignore
vendored
Normal file
1
metrics/pytorch_structural_losses/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
PyTorchStructuralLosses.egg-info/
|
100
metrics/pytorch_structural_losses/Makefile
Normal file
100
metrics/pytorch_structural_losses/Makefile
Normal file
|
@ -0,0 +1,100 @@
|
|||
###############################################################################
|
||||
# Uncomment for debugging
|
||||
# DEBUG := 1
|
||||
# Pretty build
|
||||
# Q ?= @
|
||||
|
||||
CXX := g++
|
||||
PYTHON := python
|
||||
NVCC := /usr/local/cuda/bin/nvcc
|
||||
|
||||
# PYTHON Header path
|
||||
PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())')
|
||||
PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]')
|
||||
PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]')
|
||||
|
||||
# CUDA ROOT DIR that contains bin/ lib64/ and include/
|
||||
# CUDA_DIR := /usr/local/cuda
|
||||
CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())')
|
||||
|
||||
INCLUDE_DIRS := ./ $(CUDA_DIR)/include
|
||||
|
||||
INCLUDE_DIRS += $(PYTHON_HEADER_DIR)
|
||||
INCLUDE_DIRS += $(PYTORCH_INCLUDES)
|
||||
|
||||
# Custom (MKL/ATLAS/OpenBLAS) include and lib directories.
|
||||
# Leave commented to accept the defaults for your choice of BLAS
|
||||
# (which should work)!
|
||||
# BLAS_INCLUDE := /path/to/your/blas
|
||||
# BLAS_LIB := /path/to/your/blas
|
||||
|
||||
###############################################################################
|
||||
SRC_DIR := ./src
|
||||
OBJ_DIR := ./objs
|
||||
CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp)
|
||||
CU_SRCS := $(wildcard $(SRC_DIR)/*.cu)
|
||||
OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS))
|
||||
CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS))
|
||||
STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a
|
||||
|
||||
# CUDA architecture setting: only sm_52 and sm_61 (plus compute_61 PTX) are built; add -gencode entries for other GPUs.
|
||||
# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility.
|
||||
# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility.
|
||||
CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \
|
||||
-gencode arch=compute_61,code=compute_61 \
|
||||
-gencode arch=compute_52,code=sm_52
|
||||
|
||||
# We will also explicitly add stdc++ to the link target.
|
||||
LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu
|
||||
|
||||
# Debugging
|
||||
ifeq ($(DEBUG), 1)
|
||||
COMMON_FLAGS += -DDEBUG -g -O0
|
||||
# https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/
|
||||
NVCCFLAGS += -g -G # -rdc true
|
||||
else
|
||||
COMMON_FLAGS += -DNDEBUG -O3
|
||||
endif
|
||||
|
||||
WARNINGS := -Wall -Wno-sign-compare -Wcomment
|
||||
|
||||
INCLUDE_DIRS += $(BLAS_INCLUDE)
|
||||
|
||||
# Automatic dependency generation (nvcc is handled separately)
|
||||
CXXFLAGS += -MMD -MP
|
||||
|
||||
# Complete build flags.
|
||||
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \
|
||||
-DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0
|
||||
CXXFLAGS += -pthread -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS)
|
||||
NVCCFLAGS += -std=c++11 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
|
||||
|
||||
# Build the Python extension on top of the static library, move the built
# StructuralLosses package (and its compiled .so) one directory up, then
# remove the intermediate build artefacts.
# The build/lib.* glob avoids hard-coding the platform / Python-version
# directory name (previously lib.linux-x86_64-3.6), which differs between
# interpreters and setuptools releases.
all: $(STATIC_LIB)
	$(PYTHON) setup.py build
	@ mv build/lib.*/StructuralLosses ..
	@ mv build/lib.*/*.so ../StructuralLosses/
	@- $(RM) -rf $(OBJ_DIR) build objs
|
||||
|
||||
$(OBJ_DIR):
|
||||
@ mkdir -p $@
|
||||
@ mkdir -p $@/cuda
|
||||
|
||||
$(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR)
|
||||
@ echo CXX $<
|
||||
$(Q)$(CXX) $< $(CXXFLAGS) -c -o $@
|
||||
|
||||
$(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR)
|
||||
@ echo NVCC $<
|
||||
$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \
|
||||
-odir $(@D)
|
||||
$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@
|
||||
|
||||
$(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR)
|
||||
$(RM) -f $(STATIC_LIB)
|
||||
$(RM) -rf build dist
|
||||
@ echo LD -o $@
|
||||
ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS)
|
||||
|
||||
clean:
|
||||
@- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses
|
||||
|
6
metrics/pytorch_structural_losses/__init__.py
Normal file
6
metrics/pytorch_structural_losses/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
#import torch
|
||||
|
||||
#from MakePytorchBackend import AddGPU, Foo, ApproxMatch
|
||||
|
||||
#from Add import add_gpu, approx_match
|
||||
|
45
metrics/pytorch_structural_losses/match_cost.py
Normal file
45
metrics/pytorch_structural_losses/match_cost.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
import torch
|
||||
from torch.autograd import Function
|
||||
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad
|
||||
|
||||
# Inherit from Function
|
||||
class MatchCostFunction(Function):
    """Autograd function wrapping the ``ApproxMatch`` / ``MatchCost`` backend.

    ``forward`` computes a matching between the two point sets and returns
    its cost; ``backward`` scales the gradients from ``MatchCostGrad`` by the
    incoming gradient of the cost.
    """

    @staticmethod
    def forward(ctx, seta, setb):
        """Return the matching cost between ``seta`` and ``setb``.

        Per the original comments the inputs are
        ``batch_size * #dataset_points * 3`` and
        ``batch_size * #query_points * 3``, and the intermediate match is
        ``batch_size * #query_points * #dataset_points``.
        """
        ctx.save_for_backward(seta, setb)
        # ApproxMatch also returns a second value that is not needed here.
        match, _ = ApproxMatch(seta, setb)
        # The match is kept on the context so backward can reuse it.
        ctx.match = match
        return MatchCost(seta, setb, match)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients of the cost with respect to both point sets."""
        seta, setb = ctx.saved_tensors
        grad_a, grad_b = MatchCostGrad(seta, setb, ctx.match)
        # Two trailing axes are added so the cost gradient broadcasts over
        # the point-set gradients.
        scale = grad_output.unsqueeze(1).unsqueeze(2)
        return grad_a * scale, grad_b * scale
|
||||
|
||||
match_cost = MatchCostFunction.apply
|
||||
|
42
metrics/pytorch_structural_losses/nn_distance.py
Normal file
42
metrics/pytorch_structural_losses/nn_distance.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
import torch
|
||||
from torch.autograd import Function
|
||||
# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
|
||||
from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
|
||||
|
||||
# Inherit from Function
|
||||
class NNDistanceFunction(Function):
    """Autograd function wrapping the ``NNDistance`` backend.

    ``forward`` returns the two distance tensors produced by ``NNDistance``;
    the accompanying index tensors are kept on the context for ``backward``.
    """

    @staticmethod
    def forward(ctx, seta, setb):
        """Return ``(dist1, dist2)`` for the two point sets.

        Per the original comments the inputs are
        ``batch_size * #dataset_points * 3`` and
        ``batch_size * #query_points * 3``.
        """
        ctx.save_for_backward(seta, setb)
        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)
        ctx.idx1, ctx.idx2 = idx1, idx2
        return dist1, dist2

    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        """Return the gradients with respect to ``seta`` and ``setb``.

        ``forward`` has two outputs, so two incoming gradients are received.
        """
        seta, setb = ctx.saved_tensors
        grad_a, grad_b = NNDistanceGrad(
            seta, setb, ctx.idx1, ctx.idx2, grad_dist1, grad_dist2)
        return grad_a, grad_b
|
||||
|
||||
nn_distance = NNDistanceFunction.apply
|
||||
|
15
metrics/pytorch_structural_losses/pybind/bind.cpp
Normal file
15
metrics/pytorch_structural_losses/pybind/bind.cpp
Normal file
|
@ -0,0 +1,15 @@
|
|||
#include <string>
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "pybind/extern.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
// Python module definition: exposes the five functions declared in
// pybind/extern.hpp under the module name given by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
|
||||
// Matching, its cost, and the gradient of that cost.
m.def("ApproxMatch", &ApproxMatch);
|
||||
m.def("MatchCost", &MatchCost);
|
||||
m.def("MatchCostGrad", &MatchCostGrad);
|
||||
// Nearest-neighbour distance and its gradient.
m.def("NNDistance", &NNDistance);
|
||||
m.def("NNDistanceGrad", &NNDistanceGrad);
|
||||
}
|
6
metrics/pytorch_structural_losses/pybind/extern.hpp
Normal file
6
metrics/pytorch_structural_losses/pybind/extern.hpp
Normal file
|
@ -0,0 +1,6 @@
|
|||
// Declarations of the functions bound to Python in pybind/bind.cpp.
// NOTE(review): this header has no include guard and does not include
// <vector> or a torch header itself; it appears to rely on the includer
// (bind.cpp includes <torch/extension.h> first) -- confirm.

// Python callers unpack the result into two tensors (match, temp).
std::vector<at::Tensor> ApproxMatch(at::Tensor in_a, at::Tensor in_b);
|
||||
// Returns a single tensor: the cost of the given match.
at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
|
||||
// Python callers unpack the result into two gradient tensors.
std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
|
||||
|
||||
// Python callers unpack the result into (dist1, idx1, dist2, idx2).
std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q);
|
||||
// Python callers unpack the result into two gradient tensors.
std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2);
|
30
metrics/pytorch_structural_losses/setup.py
Normal file
30
metrics/pytorch_structural_losses/setup.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
|
||||
|
||||
# Python interface
|
||||
# The extension compiles only the pybind glue; the kernels themselves come
# from the static library "make_pytorch" that is looked up in objs/.
structural_losses_backend = CUDAExtension(
    name='StructuralLossesBackend',
    include_dirs=['./'],
    sources=['pybind/bind.cpp'],
    libraries=['make_pytorch'],
    library_dirs=['objs'],
    # extra_compile_args=['-g']
)

# Python interface
setup(
    name='PyTorchStructuralLosses',
    version='0.1.0',
    description='Tutorial for Pytorch C++ Extension with a Makefile',
    keywords='Pytorch C++ Extension',
    author='Christopher B. Choy',
    author_email='chrischoy@ai.stanford.edu',
    url='https://github.com/chrischoy/MakePytorchPlusPlus',
    install_requires=['torch'],
    packages=['StructuralLosses'],
    package_dir={'StructuralLosses': './'},
    ext_modules=[structural_losses_backend],
    cmdclass={'build_ext': BuildExtension},
    zip_safe=False,
)
|
326
metrics/pytorch_structural_losses/src/approxmatch.cu
Normal file
326
metrics/pytorch_structural_losses/src/approxmatch.cu
Normal file
|
@ -0,0 +1,326 @@
|
|||
#include "utils.hpp"
|
||||
|
||||
// Computes an approximate soft matching between two batched point sets.
//
//   b      number of batch elements
//   n, m   number of points per element in xyz1 / xyz2
//   xyz1   b*n*3 floats, xyz2  b*m*3 floats
//   match  b*m*n floats, written as match[i*n*m + l*n + k] for point k of
//          xyz1 and point l of xyz2
//   temp   scratch space; each block uses a private slice of 2*(n+m) floats
//          laid out as remainL[n], remainR[m], ratioL[n], ratioR[m]
//
// Each block handles batch elements blockIdx.x, blockIdx.x+gridDim.x, ...
// For every element the match is accumulated over the levels
// -4^7 ... -4^-1, which scale the squared distance inside the exponential.
//
// NOTE(review): n/m and m/n below are integer divisions; they are exact
// only when one count is a multiple of the other -- confirm intent.
__global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){
  float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
  float multiL,multiR;
  if (n>=m){
    multiL=1;
    multiR=n/m;
  }else{
    multiL=m/n;
    multiR=1;
  }
  const int Block=1024;
  // Shared staging buffer: up to Block points, 4 floats each (x, y, z, weight).
  __shared__ float buf[Block*4];
  for (int i=blockIdx.x;i<b;i+=gridDim.x){
    // Reset the match and the remaining capacity of every point.
    for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
      match[i*n*m+j]=0;
    for (int j=threadIdx.x;j<n;j+=blockDim.x)
      remainL[j]=multiL;
    for (int j=threadIdx.x;j<m;j+=blockDim.x)
      remainR[j]=multiR;
    __syncthreads();
    for (int j=7;j>-2;j--){
      float level=-powf(4.0f,j);
      // Unreachable with the current loop bound (j never reaches -2); kept
      // from the variant whose loop ran down to j==-2.
      if (j==-2){
        level=0;
      }
      // Phase 1: ratioL[k] = remainL[k] / sum_l exp(level*d(k,l)) * remainR[l].
      for (int k0=0;k0<n;k0+=blockDim.x){
        int k=k0+threadIdx.x;
        float x1=0,y1=0,z1=0;
        if (k<n){
          x1=xyz1[i*n*3+k*3+0];
          y1=xyz1[i*n*3+k*3+1];
          z1=xyz1[i*n*3+k*3+2];
        }
        float suml=1e-9f;
        for (int l0=0;l0<m;l0+=Block){
          int lend=min(m,l0+Block)-l0;
          for (int l=threadIdx.x;l<lend;l+=blockDim.x){
            float x2=xyz2[i*m*3+l0*3+l*3+0];
            float y2=xyz2[i*m*3+l0*3+l*3+1];
            float z2=xyz2[i*m*3+l0*3+l*3+2];
            buf[l*4+0]=x2;
            buf[l*4+1]=y2;
            buf[l*4+2]=z2;
            buf[l*4+3]=remainR[l0+l];
          }
          __syncthreads();
          for (int l=0;l<lend;l++){
            float x2=buf[l*4+0];
            float y2=buf[l*4+1];
            float z2=buf[l*4+2];
            float d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
            float w=__expf(d)*buf[l*4+3];
            suml+=w;
          }
          __syncthreads();
        }
        if (k<n)
          ratioL[k]=remainL[k]/suml;
      }
      __syncthreads();
      // Phase 2: for every xyz2 point, accumulate the demand from xyz1,
      // derive ratioR and reduce remainR accordingly.
      for (int l0=0;l0<m;l0+=blockDim.x){
        int l=l0+threadIdx.x;
        float x2=0,y2=0,z2=0;
        if (l<m){
          x2=xyz2[i*m*3+l*3+0];
          y2=xyz2[i*m*3+l*3+1];
          z2=xyz2[i*m*3+l*3+2];
        }
        float sumr=0;
        for (int k0=0;k0<n;k0+=Block){
          int kend=min(n,k0+Block)-k0;
          for (int k=threadIdx.x;k<kend;k+=blockDim.x){
            buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
            buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
            buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
            buf[k*4+3]=ratioL[k0+k];
          }
          __syncthreads();
          for (int k=0;k<kend;k++){
            float x1=buf[k*4+0];
            float y1=buf[k*4+1];
            float z1=buf[k*4+2];
            float w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
            sumr+=w;
          }
          __syncthreads();
        }
        if (l<m){
          sumr*=remainR[l];
          float consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
          ratioR[l]=consumption*remainR[l];
          remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
        }
      }
      __syncthreads();
      // Phase 3: add this level's contribution to match and reduce remainL.
      for (int k0=0;k0<n;k0+=blockDim.x){
        int k=k0+threadIdx.x;
        float x1=0,y1=0,z1=0;
        if (k<n){
          x1=xyz1[i*n*3+k*3+0];
          y1=xyz1[i*n*3+k*3+1];
          z1=xyz1[i*n*3+k*3+2];
        }
        float suml=0;
        for (int l0=0;l0<m;l0+=Block){
          int lend=min(m,l0+Block)-l0;
          for (int l=threadIdx.x;l<lend;l+=blockDim.x){
            buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
            buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
            buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
            buf[l*4+3]=ratioR[l0+l];
          }
          __syncthreads();
          if (k<n){
            // Read ratioL[k] only for valid k: padding threads (k>=n) would
            // otherwise index past this block's scratch slice.
            float rl=ratioL[k];
            for (int l=0;l<lend;l++){
              float x2=buf[l*4+0];
              float y2=buf[l*4+1];
              float z2=buf[l*4+2];
              float w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
              match[i*n*m+(l0+l)*n+k]+=w;
              suml+=w;
            }
          }
          __syncthreads();
        }
        if (k<n)
          remainL[k]=fmaxf(0.0f,remainL[k]-suml);
      }
      __syncthreads();
    }
  }
}
|
||||
|
||||
// Computes, for every batch element, the matching cost
//   out[i] = sum_{j<n} sum_{k<m} match[i][k][j] * ||xyz2[i][k] - xyz1[i][j]||
// where match is indexed as match[i*n*m + k*n + j].
//
// Each block walks batch elements blockIdx.x, blockIdx.x+gridDim.x, ...
// Points of xyz2 are staged through shared memory in tiles of `Block` points.
// The per-thread partial sums are reduced through `allsum`, which is sized for
// at most 512 threads per block.
__global__ void matchcostkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){
  const int Block=256;
  __shared__ float allsum[512];
  __shared__ float buf[Block*3];
  const int tid=threadIdx.x;
  for (int batch=blockIdx.x;batch<b;batch+=gridDim.x){
    const float * pts1=xyz1+batch*n*3;
    const float * pts2=xyz2+batch*m*3;
    const float * weights=match+batch*n*m;
    float partial=0;
    for (int tile=0;tile<m;tile+=Block){
      // Number of xyz2 points in this tile (the last tile may be short).
      const int count=min(m,tile+Block)-tile;
      for (int c=tid;c<count*3;c+=blockDim.x){
        buf[c]=pts2[tile*3+c];
      }
      __syncthreads();
      for (int j=tid;j<n;j+=blockDim.x){
        const float px=pts1[j*3+0];
        const float py=pts1[j*3+1];
        const float pz=pts1[j*3+2];
        for (int k=0;k<count;k++){
          const float dx=buf[k*3+0]-px;
          const float dy=buf[k*3+1]-py;
          const float dz=buf[k*3+2]-pz;
          const float dist=sqrtf(dx*dx+dy*dy+dz*dz);
          partial+=weights[(tile+k)*n+j]*dist;
        }
      }
      // The tile buffer is overwritten on the next iteration.
      __syncthreads();
    }
    allsum[tid]=partial;
    // Pairwise reduction; slot 0 ends up holding the block-wide sum.
    for (int stride=1;stride<blockDim.x;stride<<=1){
      __syncthreads();
      if ((tid&stride)==0 && tid+stride<blockDim.x){
        allsum[tid]+=allsum[tid+stride];
      }
    }
    if (tid==0)
      out[batch]=allsum[0];
    // allsum is reused for the next batch element.
    __syncthreads();
  }
}
|
||||
//void matchcostLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * out){
|
||||
// matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
|
||||
//}
|
||||
|
||||
// Gradient of the matching cost with respect to the points of xyz2:
//   grad2[i][k] = sum_{j<n} match[i][k][j] * (xyz2[i][k]-xyz1[i][j]) / ||xyz2[i][k]-xyz1[i][j]||
// The squared distance is clamped to 1e-20 before rsqrtf to avoid dividing by zero.
//
// blockIdx.x strides over batch elements; blockIdx.y selects a contiguous slice
// [m*y/gridDim.y, m*(y+1)/gridDim.y) of the xyz2 points. All threads of a block
// cooperate on one xyz2 point at a time and reduce through `sum_grad`, which is
// sized for at most 256 threads per block.
__global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){
  __shared__ float sum_grad[256*3];
  const int tid=threadIdx.x;
  for (int batch=blockIdx.x;batch<b;batch+=gridDim.x){
    const int first=m*blockIdx.y/gridDim.y;
    const int last=m*(blockIdx.y+1)/gridDim.y;
    for (int k=first;k<last;k++){
      const float qx=xyz2[(batch*m+k)*3+0];
      const float qy=xyz2[(batch*m+k)*3+1];
      const float qz=xyz2[(batch*m+k)*3+2];
      float accx=0,accy=0,accz=0;
      for (int j=tid;j<n;j+=blockDim.x){
        const float dx=qx-xyz1[(batch*n+j)*3+0];
        const float dy=qy-xyz1[(batch*n+j)*3+1];
        const float dz=qz-xyz1[(batch*n+j)*3+2];
        // match weight divided by the (clamped) Euclidean distance
        const float w=match[batch*n*m+k*n+j]*rsqrtf(fmaxf(dx*dx+dy*dy+dz*dz,1e-20f));
        accx+=dx*w;
        accy+=dy*w;
        accz+=dz*w;
      }
      sum_grad[tid*3+0]=accx;
      sum_grad[tid*3+1]=accy;
      sum_grad[tid*3+2]=accz;
      // Pairwise reduction; slot 0 ends up holding the block-wide sums.
      for (int stride=1;stride<blockDim.x;stride<<=1){
        __syncthreads();
        const int partner=tid+stride;
        if ((tid&stride)==0 && partner<blockDim.x){
          sum_grad[tid*3+0]+=sum_grad[partner*3+0];
          sum_grad[tid*3+1]+=sum_grad[partner*3+1];
          sum_grad[tid*3+2]+=sum_grad[partner*3+2];
        }
      }
      if (tid==0){
        grad2[(batch*m+k)*3+0]=sum_grad[0];
        grad2[(batch*m+k)*3+1]=sum_grad[1];
        grad2[(batch*m+k)*3+2]=sum_grad[2];
      }
      // sum_grad is reused for the next xyz2 point.
      __syncthreads();
    }
  }
}
|
||||
// Gradient of the matching cost with respect to the points of xyz1:
//   grad1[i][l] = sum_{k<m} match[i][k][l] * (xyz1[i][l]-xyz2[i][k]) / ||xyz1[i][l]-xyz2[i][k]||
// The squared distance is clamped to 1e-20 before rsqrtf to avoid dividing by zero.
//
// blockIdx.x strides over batch elements; each thread owns whole xyz1 points
// (threadIdx.x, threadIdx.x+blockDim.x, ...), so no reduction is needed.
__global__ void matchcostgrad1kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad1){
  for (int batch=blockIdx.x;batch<b;batch+=gridDim.x){
    for (int l=threadIdx.x;l<n;l+=blockDim.x){
      const int src=batch*n*3+l*3;
      const float px=xyz1[src+0];
      const float py=xyz1[src+1];
      const float pz=xyz1[src+2];
      float gx=0,gy=0,gz=0;
      for (int k=0;k<m;k++){
        const int dst=batch*m*3+k*3;
        const float ex=px-xyz2[dst+0];
        const float ey=py-xyz2[dst+1];
        const float ez=pz-xyz2[dst+2];
        // match weight divided by the (clamped) Euclidean distance
        const float w=match[batch*n*m+k*n+l]*rsqrtf(fmaxf(ex*ex+ey*ey+ez*ez,1e-20f));
        gx+=ex*w;
        gy+=ey*w;
        gz+=ez*w;
      }
      grad1[src+0]=gx;
      grad1[src+1]=gy;
      grad1[src+2]=gz;
    }
  }
}
|
||||
//void matchcostgradLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad2){
|
||||
// matchcostgrad<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
|
||||
//}
|
||||
|
||||
/*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
|
||||
cudaStream_t stream)*/
|
||||
// temp: TensorShape{b,(n+m)*2}
|
||||
// Launches approxmatchkernel (32 blocks x 512 threads) on `stream`.
//
// b, n, m : batch size, number of points in xyz1, number of points in xyz2
// match   : output buffer written by the kernel
// temp    : scratch buffer, TensorShape{b,(n+m)*2}
//
// Throws std::runtime_error if the launch is rejected. Only errors visible to
// cudaGetLastError() at this point are reported; the kernel itself runs
// asynchronously.
void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){
  approxmatchkernel
      <<<32, 512, 0, stream>>>(b,n,m,xyz1,xyz2,match,temp);

  const cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err)
    // Report the readable description as well as the numeric code.
    throw std::runtime_error(Formatter()
                             << "CUDA kernel failed : " << cudaGetErrorString(err)
                             << " (error " << static_cast<int>(err) << ")");
}
|
||||
|
||||
// Launches matchcostkernel (32 blocks x 512 threads) on `stream`; the kernel
// writes one cost value per batch element into out[0..b).
//
// b, n, m : batch size, number of points in xyz1, number of points in xyz2
// match   : matching weights, read as match[i*n*m + k*n + j]
//
// Throws std::runtime_error if the launch is rejected. Only errors visible to
// cudaGetLastError() at this point are reported; the kernel itself runs
// asynchronously.
void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){
  matchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out);

  const cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err)
    // Report the readable description as well as the numeric code.
    throw std::runtime_error(Formatter()
                             << "CUDA kernel failed : " << cudaGetErrorString(err)
                             << " (error " << static_cast<int>(err) << ")");
}
|
||||
|
||||
// Launches both gradient kernels of the matching cost on `stream`:
//   matchcostgrad1kernel (32 blocks x 512 threads)       -> grad1, b*n*3 floats
//   matchcostgrad2kernel (32x32 blocks x 256 threads)    -> grad2, b*m*3 floats
//
// Throws std::runtime_error if a launch is rejected. The error state is read
// once after both launches, so the message does not say which launch failed.
void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){
  matchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1);
  matchcostgrad2kernel<<<dim3(32,32),256,0,stream>>>(b,n,m,xyz1,xyz2,match,grad2);

  const cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err)
    // Report the readable description as well as the numeric code.
    throw std::runtime_error(Formatter()
                             << "CUDA kernel failed : " << cudaGetErrorString(err)
                             << " (error " << static_cast<int>(err) << ")");
}
|
8
metrics/pytorch_structural_losses/src/approxmatch.cuh
Normal file
8
metrics/pytorch_structural_losses/src/approxmatch.cuh
Normal file
|
@ -0,0 +1,8 @@
|
|||
/*
|
||||
template <typename Dtype>
|
||||
void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
|
||||
cudaStream_t stream);
|
||||
*/
|
||||
// Host-side launchers for the approximate-matching kernels.
// In all three: b = batch size, n = points per cloud in xyz1, m = points per
// cloud in xyz2; every pointer is a device pointer and work is queued on `stream`.

// Fills `match`; `temp` is scratch space of shape {b,(n+m)*2}.
void approxmatch(int b, int n, int m,
                 const float * xyz1, const float * xyz2,
                 float * match, float * temp,
                 cudaStream_t stream);

// Writes one matching cost per batch element into `out`.
void matchcost(int b, int n, int m,
               const float * xyz1, const float * xyz2,
               float * match, float * out,
               cudaStream_t stream);

// Writes the cost gradients w.r.t. xyz1 into `grad1` and w.r.t. xyz2 into `grad2`.
void matchcostgrad(int b, int n, int m,
                   const float * xyz1, const float * xyz2,
                   const float * match,
                   float * grad1, float * grad2,
                   cudaStream_t stream);
|
155
metrics/pytorch_structural_losses/src/nndistance.cu
Executable file
155
metrics/pytorch_structural_losses/src/nndistance.cu
Executable file
|
@ -0,0 +1,155 @@
|
|||
|
||||
// Offers the point stored at p[0..2] as a nearest-neighbour candidate for the
// query (x1,y1,z1). The candidate replaces the running (best,best_i) when
// `first` is set or when its squared distance is strictly smaller than `best`.
__device__ __forceinline__ void NmDistanceOffer(const float * p,float x1,float y1,float z1,int idx,bool first,float & best,int & best_i){
  float x2=p[0]-x1;
  float y2=p[1]-y1;
  float z2=p[2]-z1;
  float d=x2*x2+y2*y2+z2*z2;
  if (first || d<best){
    best=d;
    best_i=idx;
  }
}

// For every point j of xyz (n points per batch element) finds the nearest point
// of xyz2 (m points per batch element) and stores the squared distance in
// result[i*n+j] and its index in result_i[i*n+j].
//
// blockIdx.x strides over batch elements; blockIdx.y together with threadIdx.x
// strides over the query points. xyz2 is staged through shared memory in tiles
// of `batch` points; after each tile the stored result is overwritten if the
// tile's best candidate is closer (or unconditionally for the first tile).
// Ties keep the lowest index because only strictly smaller distances replace
// the current best.
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
  const int batch=512;
  __shared__ float buf[batch*3];
  for (int i=blockIdx.x;i<b;i+=gridDim.x){
    for (int k2=0;k2<m;k2+=batch){
      // Number of xyz2 points in this tile (the last tile may be short).
      int end_k=min(m,k2+batch)-k2;
      for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
        buf[j]=xyz2[(i*m+k2)*3+j];
      }
      __syncthreads();
      for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
        float x1=xyz[(i*n+j)*3+0];
        float y1=xyz[(i*n+j)*3+1];
        float z1=xyz[(i*n+j)*3+2];
        int best_i=0;
        float best=0;
        // Largest multiple of 4 not exceeding end_k: the part handled 4 at a time.
        int end_ka=end_k-(end_k&3);
        if (end_ka==batch){
          // Full tile: the trip count is a compile-time constant.
          for (int k=0;k<batch;k+=4){
            NmDistanceOffer(buf+k*3+0,x1,y1,z1,k+k2,k==0,best,best_i);
            NmDistanceOffer(buf+k*3+3,x1,y1,z1,k+k2+1,false,best,best_i);
            NmDistanceOffer(buf+k*3+6,x1,y1,z1,k+k2+2,false,best,best_i);
            NmDistanceOffer(buf+k*3+9,x1,y1,z1,k+k2+3,false,best,best_i);
          }
        }else{
          for (int k=0;k<end_ka;k+=4){
            NmDistanceOffer(buf+k*3+0,x1,y1,z1,k+k2,k==0,best,best_i);
            NmDistanceOffer(buf+k*3+3,x1,y1,z1,k+k2+1,false,best,best_i);
            NmDistanceOffer(buf+k*3+6,x1,y1,z1,k+k2+2,false,best,best_i);
            NmDistanceOffer(buf+k*3+9,x1,y1,z1,k+k2+3,false,best,best_i);
          }
        }
        // Remaining 0..3 points of the tile.
        for (int k=end_ka;k<end_k;k++){
          NmDistanceOffer(buf+k*3,x1,y1,z1,k+k2,k==0,best,best_i);
        }
        if (k2==0 || result[(i*n+j)]>best){
          result[(i*n+j)]=best;
          result_i[(i*n+j)]=best_i;
        }
      }
      // The tile buffer is overwritten on the next iteration.
      __syncthreads();
    }
  }
}
|
||||
// Runs the nearest-neighbour search in both directions on `stream`:
//   xyz  -> xyz2 : result  / result_i   (b*n entries)
//   xyz2 -> xyz  : result2 / result2_i  (b*m entries)
// Both launches use a 32x16 grid of 512-thread blocks.
void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
  const dim3 grid(32,16,1);
  const int threads=512;
  NmDistanceKernel<<<grid,threads,0,stream>>>(b,n,xyz,m,xyz2,result,result_i);
  NmDistanceKernel<<<grid,threads,0,stream>>>(b,m,xyz2,n,xyz,result2,result2_i);
}
|
||||
// Back-propagates grad_dist1 through the squared nearest-neighbour distance.
// For every point j of xyz1 with nearest neighbour j2=idx1[i*n+j] in xyz2 it
// accumulates
//   grad_xyz1[i][j]  += 2*grad_dist1[i][j]*(xyz1[i][j]-xyz2[i][j2])
//   grad_xyz2[i][j2] -= 2*grad_dist1[i][j]*(xyz1[i][j]-xyz2[i][j2])
// Both outputs are updated with atomicAdd, so they must be initialised by the
// caller before launch.
//
// blockIdx.x strides over batch elements; blockIdx.y together with threadIdx.x
// strides over the points of xyz1.
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
  for (int i=blockIdx.x;i<b;i+=gridDim.x){
    for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
      const int src=(i*n+j)*3;
      const float px=xyz1[src+0];
      const float py=xyz1[src+1];
      const float pz=xyz1[src+2];
      const int nearest=idx1[i*n+j];
      const int dst=(i*m+nearest)*3;
      const float qx=xyz2[dst+0];
      const float qy=xyz2[dst+1];
      const float qz=xyz2[dst+2];
      // Derivative of the squared distance contributes the factor 2.
      const float scale=grad_dist1[i*n+j]*2;
      const float gx=scale*(px-qx);
      const float gy=scale*(py-qy);
      const float gz=scale*(pz-qz);
      atomicAdd(&(grad_xyz1[src+0]),gx);
      atomicAdd(&(grad_xyz1[src+1]),gy);
      atomicAdd(&(grad_xyz1[src+2]),gz);
      atomicAdd(&(grad_xyz2[dst+0]),-gx);
      atomicAdd(&(grad_xyz2[dst+1]),-gy);
      atomicAdd(&(grad_xyz2[dst+2]),-gz);
    }
  }
}
|
||||
// Computes the gradients of both nearest-neighbour distance terms on `stream`.
//
// grad_xyz1 (b*n*3 floats) and grad_xyz2 (b*m*3 floats) are zeroed first and
// then accumulated into by two launches of NmDistanceGradKernel, one per
// direction (xyz1->xyz2 using grad_dist1/idx1, xyz2->xyz1 using grad_dist2/idx2).
void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
  // Byte counts are computed in size_t so large clouds cannot overflow int.
  const size_t bytes1=static_cast<size_t>(b)*n*3*sizeof(float);
  const size_t bytes2=static_cast<size_t>(b)*m*3*sizeof(float);
  // The clears must be queued on the same stream as the kernels: a plain
  // cudaMemset is not ordered with respect to work on a non-blocking stream,
  // so the kernels could accumulate into buffers that are zeroed afterwards.
  cudaMemsetAsync(grad_xyz1,0,bytes1,stream);
  cudaMemsetAsync(grad_xyz2,0,bytes2,stream);
  NmDistanceGradKernel<<<dim3(1,16,1),256, 0, stream>>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2);
  NmDistanceGradKernel<<<dim3(1,16,1),256, 0, stream>>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1);
}
|
||||
|
2
metrics/pytorch_structural_losses/src/nndistance.cuh
Executable file
2
metrics/pytorch_structural_losses/src/nndistance.cuh
Executable file
|
@ -0,0 +1,2 @@
|
|||
// Host-side launchers for the nearest-neighbour distance kernels.
// b = batch size, n = points per cloud in the first set, m = points per cloud
// in the second set; every pointer is a device pointer and work is queued on
// `stream`.

// Nearest neighbour of each point in both directions: squared distances go to
// result / result2, neighbour indices to result_i / result2_i.
void nndistance(int b, int n, const float * xyz,
                int m, const float * xyz2,
                float * result, int * result_i,
                float * result2, int * result2_i,
                cudaStream_t stream);

// Gradients of both distance terms w.r.t. the two point sets.
void nndistancegrad(int b, int n, const float * xyz1,
                    int m, const float * xyz2,
                    const float * grad_dist1, const int * idx1,
                    const float * grad_dist2, const int * idx2,
                    float * grad_xyz1, float * grad_xyz2,
                    cudaStream_t stream);
|
125
metrics/pytorch_structural_losses/src/structural_loss.cpp
Normal file
125
metrics/pytorch_structural_losses/src/structural_loss.cpp
Normal file
|
@ -0,0 +1,125 @@
|
|||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "src/approxmatch.cuh"
|
||||
#include "src/nndistance.cuh"
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
/*
|
||||
input:
|
||||
set1 : batch_size * #dataset_points * 3
|
||||
set2 : batch_size * #query_points * 3
|
||||
returns:
|
||||
match : batch_size * #query_points * #dataset_points
|
||||
*/
|
||||
// temp: TensorShape{b,(n+m)*2}
|
||||
// Computes the approximate matching between two point sets.
//
//   set_d : (batch, n, 3) contiguous float CUDA tensor
//   set_q : (batch, m, 3) contiguous float CUDA tensor
// Returns {match, temp} where match is (batch, m, n) and temp is the
// (batch, (n+m)*2) scratch buffer handed to the kernel.
//
// Fails an AT_ASSERTM check if an input is not a contiguous CUDA tensor of the
// expected shape.
std::vector<at::Tensor> ApproxMatch(at::Tensor set_d, at::Tensor set_q) {
  // Validate before reading sizes or allocating: the kernel indexes both
  // inputs as dense (batch, points, 3) arrays and would read out of bounds
  // for any other layout.
  CHECK_INPUT(set_d);
  CHECK_INPUT(set_q);
  AT_ASSERTM(set_d.dim() == 3 && set_d.size(2) == 3, "set_d must have shape (batch, n, 3)");
  AT_ASSERTM(set_q.dim() == 3 && set_q.size(2) == 3, "set_q must have shape (batch, m, 3)");
  AT_ASSERTM(set_d.size(0) == set_q.size(0), "set_d and set_q must have the same batch size");

  const int64_t batch_size = set_d.size(0);
  const int64_t n_dataset_points = set_d.size(1); // n
  const int64_t n_query_points = set_q.size(1);   // m

  // Freshly allocated outputs are contiguous and live on set_d's device.
  const auto options = torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device());
  at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, options);
  at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, options);

  approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),temp.data<float>(), at::cuda::getCurrentCUDAStream());
  return {match, temp};
}
|
||||
|
||||
// Evaluates the matching cost of a given matching.
//
//   set_d : (batch, n, 3) contiguous float CUDA tensor
//   set_q : (batch, m, 3) contiguous float CUDA tensor
//   match : contiguous float CUDA tensor with batch*m*n elements
// Returns a (batch) float tensor with one cost per batch element.
//
// Fails an AT_ASSERTM check if an input is not a contiguous CUDA tensor of the
// expected size.
at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
  // Validate before reading sizes or allocating: the kernel indexes its
  // inputs as dense arrays and would read out of bounds otherwise.
  CHECK_INPUT(set_d);
  CHECK_INPUT(set_q);
  CHECK_INPUT(match);
  AT_ASSERTM(set_d.dim() == 3 && set_d.size(2) == 3, "set_d must have shape (batch, n, 3)");
  AT_ASSERTM(set_q.dim() == 3 && set_q.size(2) == 3, "set_q must have shape (batch, m, 3)");
  AT_ASSERTM(set_d.size(0) == set_q.size(0), "set_d and set_q must have the same batch size");

  const int64_t batch_size = set_d.size(0);
  const int64_t n_dataset_points = set_d.size(1); // n
  const int64_t n_query_points = set_q.size(1);   // m
  AT_ASSERTM(match.numel() == batch_size * n_query_points * n_dataset_points,
             "match must have batch * m * n elements");

  // A freshly allocated output is contiguous and lives on set_d's device.
  at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));

  matchcost(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),out.data<float>(),at::cuda::getCurrentCUDAStream());
  return out;
}
|
||||
|
||||
// Gradients of the matching cost with respect to both point sets.
//
//   set_d : (batch, n, 3) contiguous float CUDA tensor
//   set_q : (batch, m, 3) contiguous float CUDA tensor
//   match : contiguous float CUDA tensor with batch*m*n elements
// Returns {grad1, grad2} with shapes (batch, n, 3) and (batch, m, 3).
//
// Fails an AT_ASSERTM check if an input is not a contiguous CUDA tensor of the
// expected size.
std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
  // Validate before reading sizes or allocating: the kernels index their
  // inputs as dense arrays and would read out of bounds otherwise.
  CHECK_INPUT(set_d);
  CHECK_INPUT(set_q);
  CHECK_INPUT(match);
  AT_ASSERTM(set_d.dim() == 3 && set_d.size(2) == 3, "set_d must have shape (batch, n, 3)");
  AT_ASSERTM(set_q.dim() == 3 && set_q.size(2) == 3, "set_q must have shape (batch, m, 3)");
  AT_ASSERTM(set_d.size(0) == set_q.size(0), "set_d and set_q must have the same batch size");

  const int64_t batch_size = set_d.size(0);
  const int64_t n_dataset_points = set_d.size(1); // n
  const int64_t n_query_points = set_q.size(1);   // m
  AT_ASSERTM(match.numel() == batch_size * n_query_points * n_dataset_points,
             "match must have batch * m * n elements");

  // Freshly allocated outputs are contiguous and live on set_d's device.
  const auto options = torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device());
  at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, options);
  at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, options);

  matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),grad1.data<float>(),grad2.data<float>(),at::cuda::getCurrentCUDAStream());
  return {grad1, grad2};
}
|
||||
|
||||
|
||||
/*
|
||||
input:
|
||||
set_d : batch_size * #dataset_points * 3
|
||||
set_q : batch_size * #query_points * 3
|
||||
returns:
|
||||
dist1, idx1 : batch_size * #dataset_points
|
||||
dist2, idx2 : batch_size * #query_points
|
||||
*/
|
||||
// Nearest-neighbour search between two point sets, in both directions.
//
//   set_d : (batch, n, 3) contiguous float CUDA tensor
//   set_q : (batch, m, 3) contiguous float CUDA tensor
// Returns {dist1, idx1, dist2, idx2}: dist1/idx1 are (batch, n) and describe,
// for every point of set_d, its nearest point in set_q; dist2/idx2 are
// (batch, m) for the opposite direction. Distances are float32, indices int32.
//
// Fails an AT_ASSERTM check if an input is not a contiguous CUDA tensor of the
// expected shape.
std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q) {
  // Validate before reading sizes or allocating: the kernel indexes both
  // inputs as dense (batch, points, 3) arrays.
  CHECK_INPUT(set_d);
  CHECK_INPUT(set_q);
  AT_ASSERTM(set_d.dim() == 3 && set_d.size(2) == 3, "set_d must have shape (batch, n, 3)");
  AT_ASSERTM(set_q.dim() == 3 && set_q.size(2) == 3, "set_q must have shape (batch, m, 3)");
  AT_ASSERTM(set_d.size(0) == set_q.size(0), "set_d and set_q must have the same batch size");

  const int64_t batch_size = set_d.size(0);
  const int64_t n_dataset_points = set_d.size(1); // n
  const int64_t n_query_points = set_q.size(1);   // m

  // Freshly allocated outputs are contiguous and live on set_d's device.
  const auto float_opts = torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device());
  const auto int_opts = torch::TensorOptions().dtype(torch::kInt32).device(set_d.device());
  at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, float_opts);
  at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, int_opts);
  at::Tensor dist2 = torch::empty({batch_size, n_query_points}, float_opts);
  at::Tensor idx2 = torch::empty({batch_size, n_query_points}, int_opts);

  nndistance(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),dist1.data<float>(),idx1.data<int>(),dist2.data<float>(),idx2.data<int>(), at::cuda::getCurrentCUDAStream());
  return {dist1, idx1, dist2, idx2};
}
|
||||
|
||||
// Back-propagates through NNDistance.
//
//   set_d, set_q           : (batch, n, 3) / (batch, m, 3) contiguous float CUDA tensors
//   idx1, grad_dist1       : batch*n elements (int32 / float32)
//   idx2, grad_dist2       : batch*m elements (int32 / float32)
// Returns {grad1, grad2} with shapes (batch, n, 3) and (batch, m, 3). The
// launcher zeroes both outputs before accumulating into them.
//
// Fails an AT_ASSERTM check if an input is not a contiguous CUDA tensor of the
// expected size.
std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) {
  // Validate before reading sizes or allocating: the kernel indexes every
  // input as a dense array and would read out of bounds otherwise.
  CHECK_INPUT(set_d);
  CHECK_INPUT(set_q);
  CHECK_INPUT(idx1);
  CHECK_INPUT(idx2);
  CHECK_INPUT(grad_dist1);
  CHECK_INPUT(grad_dist2);
  AT_ASSERTM(set_d.dim() == 3 && set_d.size(2) == 3, "set_d must have shape (batch, n, 3)");
  AT_ASSERTM(set_q.dim() == 3 && set_q.size(2) == 3, "set_q must have shape (batch, m, 3)");
  AT_ASSERTM(set_d.size(0) == set_q.size(0), "set_d and set_q must have the same batch size");

  const int64_t batch_size = set_d.size(0);
  const int64_t n_dataset_points = set_d.size(1); // n
  const int64_t n_query_points = set_q.size(1);   // m
  AT_ASSERTM(idx1.numel() == batch_size * n_dataset_points, "idx1 must have batch * n elements");
  AT_ASSERTM(grad_dist1.numel() == batch_size * n_dataset_points, "grad_dist1 must have batch * n elements");
  AT_ASSERTM(idx2.numel() == batch_size * n_query_points, "idx2 must have batch * m elements");
  AT_ASSERTM(grad_dist2.numel() == batch_size * n_query_points, "grad_dist2 must have batch * m elements");

  // Freshly allocated outputs are contiguous and live on set_d's device.
  const auto options = torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device());
  at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, options);
  at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, options);

  nndistancegrad(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),
                 grad_dist1.data<float>(),idx1.data<int>(),
                 grad_dist2.data<float>(),idx2.data<int>(),
                 grad1.data<float>(),grad2.data<float>(),
                 at::cuda::getCurrentCUDAStream());
  return {grad1, grad2};
}
|
||||
|
26
metrics/pytorch_structural_losses/src/utils.hpp
Normal file
26
metrics/pytorch_structural_losses/src/utils.hpp
Normal file
|
@ -0,0 +1,26 @@
|
|||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
// Small helper for building a std::string with stream syntax in a single
// expression, e.g.
//   throw std::runtime_error(Formatter() << "value = " << 42);
// The result is obtained through the implicit conversion to std::string,
// through str(), or through `>> Formatter::to_str`.
class Formatter {
public:
  Formatter() {}
  ~Formatter() {}

  // Not copyable: the underlying std::stringstream cannot be copied. Deleting
  // the operations turns any misuse into a compile-time error.
  Formatter(const Formatter &) = delete;
  Formatter &operator=(const Formatter &) = delete;

  // Appends `value` using its operator<< for std::ostream; returns *this so
  // calls can be chained.
  template <typename Type> Formatter &operator<<(const Type &value) {
    stream_ << value;
    return *this;
  }

  // Text accumulated so far.
  std::string str() const { return stream_.str(); }
  operator std::string() const { return stream_.str(); }

  // Tag for the explicit conversion `formatter >> Formatter::to_str`.
  enum ConvertToString { to_str };

  std::string operator>>(ConvertToString) const { return stream_.str(); }

private:
  std::stringstream stream_;
};
|
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
122
models/cnf.py
Normal file
122
models/cnf.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchdiffeq import odeint_adjoint
|
||||
from torchdiffeq import odeint as odeint_normal
|
||||
|
||||
__all__ = ["CNF", "SequentialFlow"]
|
||||
|
||||
|
||||
class SequentialFlow(nn.Module):
    """Chain of flow layers applied one after another.

    Every layer is called as ``layer(x, context, logpx, integration_times,
    reverse)``. When ``logpx`` is given, each layer must return an
    ``(x, logpx)`` pair; otherwise it must return ``x`` alone.
    """

    def __init__(self, layer_list):
        super(SequentialFlow, self).__init__()
        self.chain = nn.ModuleList(layer_list)

    def forward(self, x, context, logpx=None, reverse=False, inds=None, integration_times=None):
        """Run ``x`` through the selected layers.

        ``inds`` lists the layer indices to apply, in order. When omitted, all
        layers are applied front-to-back, or back-to-front if ``reverse``.
        Returns ``x`` if ``logpx`` is None, else the pair ``(x, logpx)``.
        """
        if inds is None:
            depth = len(self.chain)
            inds = range(depth - 1, -1, -1) if reverse else range(depth)

        # Decided once up front: whether layers return a pair or a single tensor.
        track_density = logpx is not None
        for idx in inds:
            layer = self.chain[idx]
            if track_density:
                x, logpx = layer(x, context, logpx, integration_times, reverse)
            else:
                x = layer(x, context, logpx, integration_times, reverse)

        return (x, logpx) if track_density else x
|
||||
|
||||
|
||||
class CNF(nn.Module):
    """Continuous normalizing flow block that integrates ``odefunc`` with torchdiffeq.

    The integrated state is ``(x, logpx)`` or, when ``conditional`` is set,
    ``(x, logpx, context)``. Integration runs over ``[0, T]``; with
    ``train_T`` the end time is the square of the learnable parameter
    ``sqrt_end_time``.
    """

    def __init__(self, odefunc, conditional=True, T=1.0, train_T=False, regularization_fns=None,
                 solver='dopri5', atol=1e-5, rtol=1e-5, use_adjoint=True):
        super(CNF, self).__init__()
        self.train_T = train_T
        self.T = T
        if train_T:
            # The square root is stored; forward() squares it again.
            self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T))))

        if regularization_fns is not None and len(regularization_fns) > 0:
            raise NotImplementedError("Regularization not supported")

        self.use_adjoint = use_adjoint
        self.odefunc = odefunc
        self.conditional = conditional
        self.solver_options = {}

        # Settings used in training mode ...
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        # ... and in eval mode; they start out identical.
        self.test_solver = solver
        self.test_atol = atol
        self.test_rtol = rtol

    def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
        """Integrate the flow and return the final state.

        Returns ``z`` when ``logpx`` is None, otherwise ``(z, logpz)``. With
        ``reverse`` the integration times are flipped. If more than two
        integration times are given, the full solver output is returned
        rather than only the last time point.
        """
        want_logp = logpx is not None
        _logpx = logpx if want_logp else torch.zeros(*x.shape[:-1], 1).to(x)

        if self.conditional:
            assert context is not None
            states = (x, _logpx, context)
        else:
            states = (x, _logpx)
        # One tolerance entry per state component.
        atol = [self.atol] * len(states)
        rtol = [self.rtol] * len(states)

        if integration_times is None:
            if self.train_T:
                end_time = self.sqrt_end_time * self.sqrt_end_time
                integration_times = torch.stack([torch.tensor(0.0).to(x), end_time]).to(x)
            else:
                integration_times = torch.tensor([0., self.T], requires_grad=False).to(x)

        if reverse:
            integration_times = _flip(integration_times, 0)

        # Refresh the odefunc statistics.
        self.odefunc.before_odeint()
        odeint = odeint_adjoint if self.use_adjoint else odeint_normal
        if self.training:
            solver_kwargs = dict(atol=atol, rtol=rtol, method=self.solver, options=self.solver_options)
        else:
            # Eval mode passes scalar tolerances and no solver options.
            solver_kwargs = dict(atol=self.test_atol, rtol=self.test_rtol, method=self.test_solver)
        state_t = odeint(self.odefunc, states, integration_times.to(x), **solver_kwargs)

        if len(integration_times) == 2:
            # Keep only the state at the final time point.
            state_t = tuple(s[1] for s in state_t)

        z_t, logpz_t = state_t[:2]
        return (z_t, logpz_t) if want_logp else z_t

    def num_evals(self):
        """Return the evaluation counter kept by ``odefunc`` as a Python number."""
        return self.odefunc._num_evals.item()
|
||||
|
||||
|
||||
def _flip(x, dim):
|
||||
indices = [slice(None)] * x.dim()
|
||||
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
|
||||
return x[tuple(indices)]
|
103
models/diffeq_layers.py
Normal file
103
models/diffeq_layers.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def weights_init(m):
    """Initialise a Linear/Conv-like module: zero weight, N(0, 0.01) bias.

    Intended for use with ``module.apply(weights_init)``. Modules are selected
    by class name (any name containing "Linear" or "Conv"). Because that also
    matches wrapper classes that own no parameters themselves (for example the
    ``*Linear`` layers in this file) and layers built with ``bias=False``,
    a missing ``weight`` or ``bias`` is skipped instead of raising.
    """
    classname = m.__class__.__name__
    if 'Linear' not in classname and 'Conv' not in classname:
        return

    weight = getattr(m, 'weight', None)
    if weight is not None:
        nn.init.constant_(weight, 0)

    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.normal_(bias, 0, 0.01)
|
||||
|
||||
|
||||
class IgnoreLinear(nn.Module):
    """Plain linear layer that accepts a context argument but does not use it."""

    def __init__(self, dim_in, dim_out, dim_c):
        # dim_c is accepted for signature parity with the other layers; unused.
        super(IgnoreLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)

    def forward(self, context, x):
        """Return ``Linear(x)``; ``context`` is ignored."""
        linear = self._layer
        return linear(x)
|
||||
|
||||
|
||||
class ConcatLinear(nn.Module):
    """Linear layer applied to ``x`` concatenated with the context features.

    The inner layer expects ``dim_in + 1 + dim_c`` input features, so
    ``context`` must carry ``1 + dim_c`` features per sample.
    """

    def __init__(self, dim_in, dim_out, dim_c):
        super(ConcatLinear, self).__init__()
        self._layer = nn.Linear(dim_in + 1 + dim_c, dim_out)

    def forward(self, context, x, c=None):
        """Return ``Linear(cat(x, context))``.

        ``c`` is unused. It is kept, now optional, so that this layer can be
        called as ``layer(context, x)`` like the other layers in this file
        while existing three-argument calls keep working.
        """
        if x.dim() == 3:
            # Repeat the per-sample context for every point of the sample.
            context = context.unsqueeze(1).expand(-1, x.size(1), -1)
        # Concatenate along the feature axis; this also covers 2-D inputs.
        x_context = torch.cat((x, context), dim=-1)
        return self._layer(x_context)
|
||||
|
||||
|
||||
class ConcatLinear_v2(nn.Module):
    """Linear layer on ``x`` plus a bias predicted from the context.

    The bias network is a bias-free linear map from ``1 + dim_c`` context
    features to ``dim_out``.
    """

    def __init__(self, dim_in, dim_out, dim_c):
        super(ConcatLinear_v2, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False)

    def forward(self, context, x):
        """Return ``Linear(x) + hyper_bias(context)``."""
        shift = self._hyper_bias(context)
        # For 3-D input the per-sample bias is broadcast over dimension 1.
        shift = shift.unsqueeze(1) if x.dim() == 3 else shift
        return self._layer(x) + shift
|
||||
|
||||
|
||||
class SquashLinear(nn.Module):
    """Linear layer on ``x`` multiplied by a sigmoid gate predicted from the context.

    The gate network maps ``1 + dim_c`` context features to ``dim_out``.
    """

    def __init__(self, dim_in, dim_out, dim_c):
        super(SquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper = nn.Linear(1 + dim_c, dim_out)

    def forward(self, context, x):
        """Return ``Linear(x) * sigmoid(hyper(context))``."""
        squash = torch.sigmoid(self._hyper(context))
        # For 3-D input the per-sample gate is broadcast over dimension 1.
        squash = squash.unsqueeze(1) if x.dim() == 3 else squash
        return self._layer(x) * squash
|
||||
|
||||
|
||||
class ScaleLinear(nn.Module):
    """Linear layer on ``x`` multiplied by an unbounded scale predicted from the context.

    The scale network maps ``1 + dim_c`` context features to ``dim_out``; no
    squashing non-linearity is applied to it.
    """

    def __init__(self, dim_in, dim_out, dim_c):
        super(ScaleLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper = nn.Linear(1 + dim_c, dim_out)

    def forward(self, context, x):
        """Return ``Linear(x) * hyper(context)``."""
        scale = self._hyper(context)
        # For 3-D input the per-sample scale is broadcast over dimension 1.
        scale = scale.unsqueeze(1) if x.dim() == 3 else scale
        return self._layer(x) * scale
|
||||
|
||||
|
||||
class ConcatSquashLinear(nn.Module):
    """Linear layer on ``x`` with a context-predicted sigmoid gate and bias.

    Both hyper-networks map ``1 + dim_c`` context features to ``dim_out``;
    the bias network itself has no bias term.
    """

    def __init__(self, dim_in, dim_out, dim_c):
        super(ConcatSquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1 + dim_c, dim_out)

    def forward(self, context, x):
        """Return ``Linear(x) * sigmoid(hyper_gate(context)) + hyper_bias(context)``."""
        squash = torch.sigmoid(self._hyper_gate(context))
        shift = self._hyper_bias(context)
        if x.dim() == 3:
            # Broadcast the per-sample gate and bias over dimension 1.
            squash, shift = squash.unsqueeze(1), shift.unsqueeze(1)
        return self._layer(x) * squash + shift
|
||||
|
||||
|
||||
class ConcatScaleLinear(nn.Module):
    """Linear layer on ``x`` with a context-predicted unbounded scale and bias.

    Both hyper-networks map ``1 + dim_c`` context features to ``dim_out``;
    the bias network itself has no bias term, and the scale is not squashed.
    """

    def __init__(self, dim_in, dim_out, dim_c):
        super(ConcatScaleLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1 + dim_c, dim_out)

    def forward(self, context, x):
        """Return ``Linear(x) * hyper_gate(context) + hyper_bias(context)``."""
        scale = self._hyper_gate(context)
        shift = self._hyper_bias(context)
        if x.dim() == 3:
            # Broadcast the per-sample scale and bias over dimension 1.
            scale, shift = scale.unsqueeze(1), shift.unsqueeze(1)
        return self._layer(x) * scale + shift
|
89
models/flow.py
Normal file
89
models/flow.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
from .odefunc import ODEfunc, ODEnet
|
||||
from .normalization import MovingBatchNorm1d
|
||||
from .cnf import CNF, SequentialFlow
|
||||
|
||||
|
||||
def count_nfe(model):
    """Return the sum of ``num_evals()`` over every ``CNF`` visited by ``model.apply``."""
    evals = []

    def record(module):
        # Only CNF blocks keep an evaluation counter.
        if isinstance(module, CNF):
            evals.append(module.num_evals())

    model.apply(record)
    return sum(evals)
|
||||
|
||||
|
||||
def count_parameters(model):
    """Return the number of scalar entries in the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
||||
|
||||
|
||||
def count_total_time(model):
    """Return the summed integration end time of every ``CNF`` visited by ``model.apply``.

    A CNF built with ``train_T=True`` contributes ``sqrt_end_time ** 2`` (a
    tensor). A CNF with a fixed end time never registers ``sqrt_end_time``, so
    it contributes its constant ``T`` instead of raising ``AttributeError``.
    The result is 0 when the model contains no CNF.
    """
    end_times = []

    def record(module):
        if not isinstance(module, CNF):
            return
        if module.train_T:
            end_times.append(module.sqrt_end_time * module.sqrt_end_time)
        else:
            # Same end time that CNF.forward integrates to when T is fixed.
            end_times.append(module.T)

    model.apply(record)
    return sum(end_times)
|
||||
|
||||
|
||||
def build_model(args, input_dim, hidden_dims, context_dim, num_blocks, conditional):
    """Build a ``SequentialFlow`` of ``num_blocks`` CNF blocks.

    Each block wraps an ``ODEnet`` (configured from ``hidden_dims``,
    ``input_dim``, ``context_dim``, ``args.layer_type`` and
    ``args.nonlinearity``) in an ``ODEfunc`` and a ``CNF`` configured from
    ``args.time_length``, ``args.train_T``, ``args.solver``,
    ``args.use_adjoint``, ``args.atol`` and ``args.rtol``.

    With ``args.batch_norm`` the chain becomes
    ``[bn, cnf_1, bn_1, ..., cnf_k, bn_k]`` where every ``bn`` is a
    ``MovingBatchNorm1d`` built with ``args.bn_lag`` and ``args.sync_bn``.
    """

    def make_cnf():
        diffeq = ODEnet(
            hidden_dims=hidden_dims,
            input_shape=(input_dim,),
            context_dim=context_dim,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        return CNF(
            odefunc=ODEfunc(diffeq=diffeq),
            T=args.time_length,
            train_T=args.train_T,
            conditional=conditional,
            solver=args.solver,
            use_adjoint=args.use_adjoint,
            atol=args.atol,
            rtol=args.rtol,
        )

    def make_bn():
        return MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag, sync=args.sync_bn)

    layers = [make_cnf() for _ in range(num_blocks)]
    if args.batch_norm:
        # Construction order (CNFs, trailing norms, leading norm) is kept as is.
        trailing_norms = [make_bn() for _ in range(num_blocks)]
        interleaved = [make_bn()]
        for flow, norm in zip(layers, trailing_norms):
            interleaved.extend((flow, norm))
        layers = interleaved

    return SequentialFlow(layers)
|
||||
|
||||
|
||||
def get_point_cnf(args):
    """Build the conditional point flow on the GPU and print its parameter count.

    Hidden sizes come from the dash-separated string ``args.dims``; the flow
    acts on ``args.input_dim`` features with an ``args.zdim`` context and
    ``args.num_blocks`` blocks.
    """
    hidden_dims = tuple(int(width) for width in args.dims.split("-"))
    flow = build_model(args, args.input_dim, hidden_dims, args.zdim, args.num_blocks, True)
    flow = flow.cuda()
    print("Number of trainable parameters of Point CNF: {}".format(count_parameters(flow)))
    return flow
|
||||
|
||||
|
||||
def get_latent_cnf(args):
    """Build the unconditional latent flow on the GPU and print its parameter count.

    Hidden sizes come from the dash-separated string ``args.latent_dims``; the
    flow acts on ``args.zdim`` features with no context and
    ``args.latent_num_blocks`` blocks.
    """
    hidden_dims = tuple(int(width) for width in args.latent_dims.split("-"))
    flow = build_model(args, args.zdim, hidden_dims, 0, args.latent_num_blocks, False)
    flow = flow.cuda()
    print("Number of trainable parameters of Latent CNF: {}".format(count_parameters(flow)))
    return flow
|
224
models/networks.py
Normal file
224
models/networks.py
Normal file
|
@ -0,0 +1,224 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch import optim
|
||||
from torch import nn
|
||||
from models.flow import get_point_cnf
|
||||
from models.flow import get_latent_cnf
|
||||
from utils import truncated_normal, reduce_tensor, standard_normal_logprob
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Point-set encoder: shared per-point 1x1 convolutions, max-pool, MLP head.

    Args:
        zdim: size of the produced latent code.
        input_dim: number of coordinates per point (channels of the first conv).
        use_deterministic_encoder: when True a single MLP head produces the
            code; otherwise two independent heads produce the two outputs that
            ``PointFlow`` uses as the mean and the log-variance of the
            posterior.
    """

    def __init__(self, zdim, input_dim=3, use_deterministic_encoder=False):
        super(Encoder, self).__init__()
        self.use_deterministic_encoder = use_deterministic_encoder
        self.zdim = zdim

        # Per-point feature extractor (kernel size 1 => applied to each point).
        self.conv1 = nn.Conv1d(input_dim, 128, 1)
        self.conv2 = nn.Conv1d(128, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)
        self.conv4 = nn.Conv1d(256, 512, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        self.bn4 = nn.BatchNorm1d(512)

        if self.use_deterministic_encoder:
            # Single head: pooled feature -> code.
            self.fc1 = nn.Linear(512, 256)
            self.fc2 = nn.Linear(256, 128)
            self.fc_bn1 = nn.BatchNorm1d(256)
            self.fc_bn2 = nn.BatchNorm1d(128)
            self.fc3 = nn.Linear(128, zdim)
        else:
            # Head producing the first output (used as the posterior mean).
            self.fc1_m = nn.Linear(512, 256)
            self.fc2_m = nn.Linear(256, 128)
            self.fc3_m = nn.Linear(128, zdim)
            self.fc_bn1_m = nn.BatchNorm1d(256)
            self.fc_bn2_m = nn.BatchNorm1d(128)

            # Head producing the second output (used as the posterior log-variance).
            self.fc1_v = nn.Linear(512, 256)
            self.fc2_v = nn.Linear(256, 128)
            self.fc3_v = nn.Linear(128, zdim)
            self.fc_bn1_v = nn.BatchNorm1d(256)
            self.fc_bn2_v = nn.BatchNorm1d(128)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, points, input_dim)``.

        Returns:
            ``(m, v)``.  In deterministic mode ``m`` is the ``(batch, zdim)``
            code and ``v`` is the integer ``0``; otherwise both are
            ``(batch, zdim)`` tensors from the two heads.
        """
        # Conv1d expects channels first: (batch, input_dim, points).
        x = x.transpose(1, 2)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.bn4(self.conv4(x))
        # Max over the point axis gives one 512-d feature per set.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 512)

        if self.use_deterministic_encoder:
            # The fc*/fc_bn* layers only exist in this mode, so the head must
            # be evaluated inside this branch.
            ms = F.relu(self.fc_bn1(self.fc1(x)))
            ms = F.relu(self.fc_bn2(self.fc2(ms)))
            ms = self.fc3(ms)
            m, v = ms, 0
        else:
            m = F.relu(self.fc_bn1_m(self.fc1_m(x)))
            m = F.relu(self.fc_bn2_m(self.fc2_m(m)))
            m = self.fc3_m(m)
            v = F.relu(self.fc_bn1_v(self.fc1_v(x)))
            v = F.relu(self.fc_bn2_v(self.fc2_v(v)))
            v = self.fc3_v(v)

        return m, v
|
||||
|
||||
|
||||
# Model
|
||||
class PointFlow(nn.Module):
    """Encoder + point-level flow + optional latent-level flow.

    ``forward`` is a complete training step: it zeroes gradients, computes the
    loss, back-propagates and steps the optimizer it is given.  ``encode``,
    ``decode``, ``sample`` and ``reconstruct`` are the inference entry points.
    """

    def __init__(self, args):
        super(PointFlow, self).__init__()
        self.input_dim = args.input_dim
        self.zdim = args.zdim
        self.use_latent_flow = args.use_latent_flow
        self.use_deterministic_encoder = args.use_deterministic_encoder
        self.prior_weight = args.prior_weight
        self.recon_weight = args.recon_weight
        self.entropy_weight = args.entropy_weight
        self.distributed = args.distributed
        self.truncate_std = None
        self.encoder = Encoder(
            zdim=args.zdim, input_dim=args.input_dim,
            use_deterministic_encoder=args.use_deterministic_encoder)
        self.point_cnf = get_point_cnf(args)
        # An empty Sequential stands in when no latent flow is requested so the
        # attribute always exists (it contributes no parameters).
        self.latent_cnf = get_latent_cnf(args) if args.use_latent_flow else nn.Sequential()

    @staticmethod
    def sample_gaussian(size, truncate_std=None, gpu=None):
        """Draw a float tensor of shape ``size`` from N(0, 1).

        The tensor is moved with ``.cuda(gpu)`` when ``gpu`` is given.  When
        ``truncate_std`` is given, ``truncated_normal`` is called on the tensor
        (its return value is ignored, so it is relied on to modify it in place).
        """
        y = torch.randn(*size).float()
        y = y if gpu is None else y.cuda(gpu)
        if truncate_std is not None:
            truncated_normal(y, mean=0, std=1, trunc_std=truncate_std)
        return y

    @staticmethod
    def reparameterize_gaussian(mean, logvar):
        """Return ``mean + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn(std.size()).to(mean)
        return mean + std * eps

    @staticmethod
    def gaussian_entropy(logvar):
        """Entropy of a diagonal Gaussian, one value per row of ``logvar``.

        Computes ``0.5 * d * (1 + log(2*pi)) + 0.5 * sum(logvar, dim=1)`` where
        ``d = logvar.size(1)``.
        """
        const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2))
        ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const
        return ent

    def multi_gpu_wrapper(self, f):
        """Replace each of the three sub-modules by ``f(sub_module)``."""
        self.encoder = f(self.encoder)
        self.point_cnf = f(self.point_cnf)
        self.latent_cnf = f(self.latent_cnf)

    def make_optimizer(self, args):
        """Create one optimizer over the encoder and both flows.

        Raises:
            ValueError: if ``args.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
        """
        params = (list(self.encoder.parameters())
                  + list(self.point_cnf.parameters())
                  + list(self.latent_cnf.parameters()))
        if args.optimizer == 'adam':
            return optim.Adam(params, lr=args.lr, betas=(args.beta1, args.beta2),
                              weight_decay=args.weight_decay)
        if args.optimizer == 'sgd':
            return torch.optim.SGD(params, lr=args.lr, momentum=args.momentum)
        # Explicit raise: an `assert` would be stripped under `python -O` and
        # the function would then fail later with an unrelated error.
        raise ValueError("args.optimizer should be either 'adam' or 'sgd'")

    def forward(self, x, opt, step, writer=None):
        """Run one optimisation step on the batch ``x``.

        Args:
            x: tensor indexed as ``(batch, points, coords)``.
            opt: optimizer to zero and step.
            step: global step passed to ``writer.add_scalar``.
            writer: optional summary writer with an ``add_scalar`` method.

        Returns:
            dict with ``'entropy'`` (Python float), ``'prior_nats'`` and
            ``'recon_nats'`` (left as returned by the tensor arithmetic).
        """
        opt.zero_grad()
        batch_size = x.size(0)
        num_points = x.size(1)
        z_mu, z_sigma = self.encoder(x)
        if self.use_deterministic_encoder:
            z = z_mu + 0 * z_sigma
        else:
            # The encoder's second output is consumed as a log-variance.
            z = self.reparameterize_gaussian(z_mu, z_sigma)

        # Compute H[Q(z|X)]
        if self.use_deterministic_encoder:
            entropy = torch.zeros(batch_size).to(z)
        else:
            entropy = self.gaussian_entropy(z_sigma)

        # Compute the prior probability P(z)
        if self.use_latent_flow:
            w, delta_log_pw = self.latent_cnf(z, None, torch.zeros(batch_size, 1).to(z))
            log_pw = standard_normal_logprob(w).view(batch_size, -1).sum(1, keepdim=True)
            delta_log_pw = delta_log_pw.view(batch_size, 1)
            log_pz = log_pw - delta_log_pw
        else:
            log_pz = torch.zeros(batch_size, 1).to(z)

        # Compute the reconstruction likelihood P(X|z)
        z_new = z.view(*z.size())
        # Adds exactly zero, but keeps log_pz attached to the graph that feeds
        # the point flow.
        z_new = z_new + (log_pz * 0.).mean()
        y, delta_log_py = self.point_cnf(x, z_new, torch.zeros(batch_size, num_points, 1).to(x))
        log_py = standard_normal_logprob(y).view(batch_size, -1).sum(1, keepdim=True)
        delta_log_py = delta_log_py.view(batch_size, num_points, 1).sum(1)
        log_px = log_py - delta_log_py

        # Loss
        entropy_loss = -entropy.mean() * self.entropy_weight
        recon_loss = -log_px.mean() * self.recon_weight
        prior_loss = -log_pz.mean() * self.prior_weight
        loss = entropy_loss + prior_loss + recon_loss
        loss.backward()
        opt.step()

        # LOGGING (after the training)
        if self.distributed:
            entropy_log = reduce_tensor(entropy.mean())
            recon = reduce_tensor(-log_px.mean())
            prior = reduce_tensor(-log_pz.mean())
        else:
            entropy_log = entropy.mean()
            recon = -log_px.mean()
            prior = -log_pz.mean()

        # Normalise by the number of scalar values per sample.
        recon_nats = recon / float(x.size(1) * x.size(2))
        prior_nats = prior / float(self.zdim)

        if writer is not None:
            writer.add_scalar('train/entropy', entropy_log, step)
            writer.add_scalar('train/prior', prior, step)
            writer.add_scalar('train/prior(nats)', prior_nats, step)
            writer.add_scalar('train/recon', recon, step)
            writer.add_scalar('train/recon(nats)', recon_nats, step)

        return {
            'entropy': entropy_log.cpu().detach().item()
            if not isinstance(entropy_log, float) else entropy_log,
            'prior_nats': prior_nats,
            'recon_nats': recon_nats,
        }

    def encode(self, x):
        """Return a latent code for ``x`` (sampled unless the encoder is deterministic)."""
        z_mu, z_sigma = self.encoder(x)
        if self.use_deterministic_encoder:
            return z_mu
        else:
            return self.reparameterize_gaussian(z_mu, z_sigma)

    def decode(self, z, num_points, truncate_std=None):
        """Map prior samples to a point cloud conditioned on the shape code ``z``.

        Returns:
            ``(y, x)``: the Gaussian samples of shape
            ``(z.size(0), num_points, input_dim)`` and the decoded points,
            reshaped to the same size.
        """
        y = self.sample_gaussian((z.size(0), num_points, self.input_dim), truncate_std)
        # The samples are created on the CPU; place them next to the shape code
        # so decoding also works when the model lives on a GPU without a
        # DataParallel wrapper to scatter the inputs.
        y = y.to(z.device)
        x = self.point_cnf(y, z, reverse=True).view(*y.size())
        return y, x

    def sample(self, batch_size, num_points, truncate_std=None, truncate_std_latent=None, gpu=None):
        """Generate ``batch_size`` shapes of ``num_points`` points each.

        Returns:
            ``(z, x)``: the sampled shape codes and the generated points.
        """
        assert self.use_latent_flow, "Sampling requires `self.use_latent_flow` to be True."
        # Generate the shape code from the prior
        w = self.sample_gaussian((batch_size, self.zdim), truncate_std_latent, gpu=gpu)
        z = self.latent_cnf(w, None, reverse=True).view(*w.size())
        # Sample points conditioned on the shape code
        y = self.sample_gaussian((batch_size, num_points, self.input_dim), truncate_std, gpu=gpu)
        x = self.point_cnf(y, z, reverse=True).view(*y.size())
        return z, x

    def reconstruct(self, x, num_points=None, truncate_std=None):
        """Encode ``x`` and decode it again with ``num_points`` points (default: ``x.size(1)``)."""
        num_points = x.size(1) if num_points is None else num_points
        z = self.encode(x)
        _, x = self.decode(z, num_points, truncate_std)
        return x
|
145
models/normalization.py
Normal file
145
models/normalization.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Parameter
|
||||
from utils import reduce_tensor
|
||||
|
||||
__all__ = ['MovingBatchNorm1d']
|
||||
|
||||
|
||||
class MovingBatchNormNd(nn.Module):
    """Invertible batch normalisation driven by running statistics.

    The last dimension of the input is the channel dimension.  In training
    mode the running mean/variance are updated from the batch; normalisation
    itself uses the running estimates (optionally blended with the batch
    statistics when ``bn_lag > 0``).  Subclasses supply ``shape``, the view
    used to broadcast per-channel vectors against the input.

    Args:
        num_features: number of channels.
        eps: constant added to the variance before taking the logarithm.
        decay: update rate of the running estimates.
        bn_lag: blending factor between batch and running statistics.
        affine: whether to learn a per-channel log-scale and bias.
        sync: whether to average batch statistics with ``reduce_tensor``.
    """

    def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True, sync=False):
        super(MovingBatchNormNd, self).__init__()
        self.num_features = num_features
        self.sync = sync
        self.affine = affine
        self.eps = eps
        self.decay = decay
        self.bn_lag = bn_lag
        self.register_buffer('step', torch.zeros(1))
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    @property
    def shape(self):
        """View shape that broadcasts a per-channel vector against the input."""
        raise NotImplementedError

    def reset_parameters(self):
        """Reset running statistics to (0, 1) and the affine parameters to 0."""
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            # `weight` is a log-scale (applied through exp), so 0 means identity.
            self.weight.data.zero_()
            self.bias.data.zero_()

    def forward(self, x, c=None, logpx=None, reverse=False):
        """Normalise ``x`` (or invert the normalisation when ``reverse``).

        ``c`` is accepted for interface compatibility and ignored.  When
        ``logpx`` is given, a ``(result, updated_logpx)`` pair is returned.
        """
        if reverse:
            return self._reverse(x, logpx)
        else:
            return self._forward(x, logpx)

    def _forward(self, x, logpx=None):
        num_channels = x.size(-1)
        used_mean = self.running_mean.clone().detach()
        used_var = self.running_var.clone().detach()

        if self.training:
            # compute batch statistics
            # Flatten every leading dimension so each row of `x_t` holds all
            # values of one channel, whatever the rank of `x`.  (Transposing
            # the first two axes before reshaping only does that for 2-D input.)
            x_t = x.reshape(-1, num_channels).t()
            batch_mean = torch.mean(x_t, dim=1)

            if self.sync:
                batch_ex2 = torch.mean(x_t**2, dim=1)
                batch_mean = reduce_tensor(batch_mean)
                batch_ex2 = reduce_tensor(batch_ex2)
                batch_var = batch_ex2 - batch_mean**2
            else:
                batch_var = torch.var(x_t, dim=1)

            # moving average
            if self.bn_lag > 0:
                used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
                used_mean /= (1. - self.bn_lag**(self.step[0] + 1))
                used_var = batch_var - (1 - self.bn_lag) * (batch_var - used_var.detach())
                used_var /= (1. - self.bn_lag**(self.step[0] + 1))

            # update running estimates
            self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
            self.running_var -= self.decay * (self.running_var - batch_var.data)
            self.step += 1

        # perform normalization
        used_mean = used_mean.view(*self.shape).expand_as(x)
        used_var = used_var.view(*self.shape).expand_as(x)

        # exp(-0.5 * log(var + eps)) == 1 / sqrt(var + eps)
        y = (x - used_mean) * torch.exp(-0.5 * torch.log(used_var + self.eps))

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(x)
            bias = self.bias.view(*self.shape).expand_as(x)
            y = y * torch.exp(weight) + bias

        if logpx is None:
            return y
        else:
            return y, logpx - self._logdetgrad(x, used_var).sum(-1, keepdim=True)

    def _reverse(self, y, logpy=None):
        # The inverse always uses the stored running statistics.
        used_mean = self.running_mean
        used_var = self.running_var

        if self.affine:
            weight = self.weight.view(*self.shape).expand_as(y)
            bias = self.bias.view(*self.shape).expand_as(y)
            y = (y - bias) * torch.exp(-weight)

        used_mean = used_mean.view(*self.shape).expand_as(y)
        used_var = used_var.view(*self.shape).expand_as(y)
        x = y * torch.exp(0.5 * torch.log(used_var + self.eps)) + used_mean

        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad(x, used_var).sum(-1, keepdim=True)

    def _logdetgrad(self, x, used_var):
        """Element-wise log-derivative of the forward map: ``-0.5*log(var+eps) (+ weight)``."""
        logdetgrad = -0.5 * torch.log(used_var + self.eps)
        if self.affine:
            weight = self.weight.view(*self.shape).expand(*x.size())
            logdetgrad += weight
        return logdetgrad

    def __repr__(self):
        return (
            '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
            ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
        )
|
||||
|
||||
|
||||
def stable_var(x, mean=None, dim=1):
    """Population variance of ``x`` along ``dim``, rescaled for stability.

    Squared deviations are divided by their maximum along ``dim`` before
    averaging and multiplied back afterwards.  Slices with zero spread produce
    0/0 and are reported as 0.

    Args:
        x: input tensor; a supplied ``mean`` is only supported for 2-D ``x``.
        mean: optional pre-computed mean with one value per reduced slice
            (any shape with that many elements).  Computed from ``x`` if None.
        dim: dimension to reduce over.

    Returns:
        1-D tensor with one variance per slice.
    """
    if mean is None:
        mean = x.mean(dim, keepdim=True)
    else:
        # Put the size-1 axis at `dim` so the mean broadcasts against `x`
        # (for dim=1 this is the same (-1, 1) layout as before).
        mean = mean.reshape(-1).unsqueeze(dim)
    res = torch.pow(x - mean, 2)
    max_sqr = torch.max(res, dim, keepdim=True)[0]
    # Average along the requested dimension rather than a fixed axis.
    var = torch.mean(res / max_sqr, dim, keepdim=True) * max_sqr
    var = var.reshape(-1)
    # change nan to zero
    var[var != var] = 0
    return var
|
||||
|
||||
|
||||
class MovingBatchNorm1d(MovingBatchNormNd):
    """Moving batch norm whose per-channel vectors are viewed as ``(1, C)``.

    ``forward`` adds an ``integration_times`` argument, which is accepted and
    ignored, and passes everything else on to the base class.
    """

    @property
    def shape(self):
        # Leading axis of size 1, all channels on the last axis.
        return [1, -1]

    def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False):
        base_forward = super(MovingBatchNorm1d, self).forward
        return base_forward(x, context, logpx=logpx, reverse=reverse)
|
137
models/odefunc.py
Normal file
137
models/odefunc.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from . import diffeq_layers
|
||||
|
||||
__all__ = ["ODEnet", "ODEfunc"]
|
||||
|
||||
|
||||
def divergence_approx(f, y, e=None):
    """Estimate the divergence of ``f`` w.r.t. ``y`` as ``sum(e * (e^T df/dy))``.

    The vector-Jacobian product ``e^T df/dy`` is obtained with
    ``torch.autograd.grad(..., create_graph=True)``, multiplied element-wise
    by ``e`` and summed over the last dimension.  If the product is not part
    of the autograd graph the computation is retried up to 10 times, after
    which an assertion fails.
    """
    def _vjp_and_product():
        vjp = torch.autograd.grad(f, y, e, create_graph=True)[0]
        return vjp, vjp * e

    e_dzdx, e_dzdx_e = _vjp_and_product()

    cnt = 0
    while cnt < 10 and not e_dzdx_e.requires_grad:
        e_dzdx, e_dzdx_e = _vjp_and_product()
        cnt += 1

    approx_tr_dzdx = e_dzdx_e.sum(dim=-1)
    assert approx_tr_dzdx.requires_grad, \
        "(failed to add node to graph) f=%s %s, y(rgrad)=%s, e_dzdx:%s, e:%s, e_dzdx_e:%s cnt:%s" \
        % (
            f.size(), f.requires_grad, y.requires_grad, e_dzdx.requires_grad,
            e.requires_grad, e_dzdx_e.requires_grad, cnt)
    return approx_tr_dzdx
|
||||
|
||||
|
||||
class Swish(nn.Module):
    """``x * sigmoid(beta * x)`` with a learnable scalar ``beta`` (initialised to 1)."""

    def __init__(self):
        super(Swish, self).__init__()
        initial_beta = torch.tensor(1.0)
        self.beta = nn.Parameter(initial_beta)

    def forward(self, x):
        gate = torch.sigmoid(self.beta * x)
        return x * gate
|
||||
|
||||
|
||||
class Lambda(nn.Module):
    """Wrap an arbitrary callable ``f`` so it can be used as a module."""

    def __init__(self, f):
        super(Lambda, self).__init__()
        # Stored as a plain attribute; `forward` simply delegates to it.
        self.f = f

    def forward(self, x):
        fn = self.f
        return fn(x)
|
||||
|
||||
|
||||
# Activation modules selectable by name.  Each value is one module instance
# created at import time, so every lookup of a given name returns that same
# object (including the single learnable `beta` of the "swish" entry).
NONLINEARITIES = dict(
    tanh=nn.Tanh(),
    relu=nn.ReLU(),
    softplus=nn.Softplus(),
    elu=nn.ELU(),
    swish=Swish(),
    square=Lambda(lambda x: x ** 2),
    identity=Lambda(lambda x: x),
)
|
||||
|
||||
|
||||
class ODEnet(nn.Module):
    """
    Helper class to make neural nets for use in continuous normalizing flows

    A stack of ``diffeq_layers`` layers of the chosen ``layer_type`` whose
    widths follow ``hidden_dims`` and end at ``input_shape[0]``.  The named
    nonlinearity is applied after every layer except the last.
    """

    def __init__(self, hidden_dims, input_shape, context_dim, layer_type="concat", nonlinearity="softplus"):
        super(ODEnet, self).__init__()
        layer_table = {
            "ignore": diffeq_layers.IgnoreLinear,
            "squash": diffeq_layers.SquashLinear,
            "scale": diffeq_layers.ScaleLinear,
            "concat": diffeq_layers.ConcatLinear,
            "concat_v2": diffeq_layers.ConcatLinear_v2,
            "concatsquash": diffeq_layers.ConcatSquashLinear,
            "concatscale": diffeq_layers.ConcatScaleLinear,
        }
        base_layer = layer_table[layer_type]

        # Walk through the output widths, chaining each layer's output size
        # into the next layer's input size.
        layers, activation_fns = [], []
        dim_in = input_shape[0]
        for dim_out in (hidden_dims + (input_shape[0],)):
            layers.append(base_layer(dim_in, dim_out, context_dim))
            activation_fns.append(NONLINEARITIES[nonlinearity])
            dim_in = dim_out

        self.layers = nn.ModuleList(layers)
        # One activation fewer than layers: the output layer stays linear.
        self.activation_fns = nn.ModuleList(activation_fns[:-1])

    def forward(self, context, y):
        """Apply the layers to ``y``, passing ``context`` to each of them."""
        last = len(self.layers) - 1
        dx = y
        for idx, layer in enumerate(self.layers):
            dx = layer(context, dx)
            # if not last layer, use nonlinearity
            if idx != last:
                dx = self.activation_fns[idx](dx)
        return dx
|
||||
|
||||
|
||||
class ODEfunc(nn.Module):
    """Dynamics function returning the state derivative and negative divergence.

    ``states`` is ``(y, logp)`` for an unconditional flow or
    ``(y, logp, context)`` for a conditional one.  The noise vector used by the
    divergence estimator is fixed between calls to ``before_odeint``.
    """

    def __init__(self, diffeq):
        super(ODEfunc, self).__init__()
        self.diffeq = diffeq
        self.divergence_fn = divergence_approx
        self.register_buffer("_num_evals", torch.tensor(0.))

    def before_odeint(self, e=None):
        """Set (or clear) the noise vector and reset the evaluation counter."""
        self._e = e
        self._num_evals.fill_(0)

    def forward(self, t, states):
        y = states[0]
        # Broadcast the scalar time to one column per batch element.
        t_like_y = t.clone().detach().requires_grad_(True).type_as(y)
        t = torch.ones(y.size(0), 1).to(y) * t_like_y
        self._num_evals += 1
        for state in states:
            state.requires_grad_(True)

        # Sample and fix the noise.
        if self._e is None:
            self._e = torch.randn_like(y, requires_grad=True).to(y)

        with torch.set_grad_enabled(True):
            n_states = len(states)
            if n_states == 2:  # unconditional CNF
                dy = self.diffeq(t, y)
                divergence = self.divergence_fn(dy, y, e=self._e).view(-1, 1)
                return dy, -divergence
            elif n_states == 3:  # conditional CNF
                c = states[2]
                # Time and flattened context are concatenated per batch element.
                tc = torch.cat([t, c.view(y.size(0), -1)], dim=1)
                dy = self.diffeq(tc, y)
                divergence = self.divergence_fn(dy, y, e=self._e).unsqueeze(-1)
                # The context does not evolve: its derivative is zero.
                return dy, -divergence, torch.zeros_like(c).requires_grad_(True)
            else:
                assert 0, "`len(states)` should be 2 or 3"
|
38
scripts/shapenet_airplane_ae.sh
Executable file
38
scripts/shapenet_airplane_ae.sh
Executable file
|
@ -0,0 +1,38 @@
|
|||
#! /bin/bash
# Train the auto-encoder configuration (deterministic encoder, prior and
# entropy weights set to 0) on the ShapeNet "airplane" category.
# Exits with train.py's status if training fails.

cate="airplane"
dims="512-512-512"
latent_dims="256-256"
num_blocks=1
latent_num_blocks=1
zdim=128
batch_size=48
lr=2e-3
epochs=4000
ds=shapenet15k
log_name="ae/${ds}-cate${cate}"
data_dir="data/ShapeNetCore.v2.PC15k"

python train.py \
    --log_name "${log_name}" \
    --lr "${lr}" \
    --dataset_type "${ds}" \
    --data_dir "${data_dir}" \
    --cates "${cate}" \
    --dims "${dims}" \
    --latent_dims "${latent_dims}" \
    --num_blocks "${num_blocks}" \
    --latent_num_blocks "${latent_num_blocks}" \
    --batch_size "${batch_size}" \
    --zdim "${zdim}" \
    --epochs "${epochs}" \
    --save_freq 50 \
    --viz_freq 1 \
    --log_freq 1 \
    --val_freq 10 \
    --use_deterministic_encoder \
    --prior_weight 0 \
    --entropy_weight 0
status=$?

# Do not report success when training failed.
if [ "${status}" -ne 0 ]; then
    echo "train.py exited with status ${status}" >&2
    exit "${status}"
fi

echo "Done"
exit 0
|
39
scripts/shapenet_airplane_ae_dist.sh
Executable file
39
scripts/shapenet_airplane_ae_dist.sh
Executable file
|
@ -0,0 +1,39 @@
|
|||
#! /bin/bash
# Train the auto-encoder configuration (deterministic encoder, prior and
# entropy weights set to 0) on the ShapeNet "airplane" category with
# --distributed enabled.
# Exits with train.py's status if training fails.

cate="airplane"
dims="512-512-512"
latent_dims="256-256"
num_blocks=1
latent_num_blocks=1
zdim=128
batch_size=256
lr=2e-3
epochs=4000
ds=shapenet15k
log_name="ae/${ds}-cate${cate}"
data_dir="data/ShapeNetCore.v2.PC15k"

python train.py \
    --log_name "${log_name}" \
    --lr "${lr}" \
    --dataset_type "${ds}" \
    --data_dir "${data_dir}" \
    --cates "${cate}" \
    --dims "${dims}" \
    --latent_dims "${latent_dims}" \
    --num_blocks "${num_blocks}" \
    --latent_num_blocks "${latent_num_blocks}" \
    --batch_size "${batch_size}" \
    --zdim "${zdim}" \
    --epochs "${epochs}" \
    --save_freq 50 \
    --viz_freq 1 \
    --log_freq 1 \
    --val_freq 10 \
    --distributed \
    --use_deterministic_encoder \
    --prior_weight 0 \
    --entropy_weight 0
status=$?

# Do not report success when training failed.
if [ "${status}" -ne 0 ]; then
    echo "train.py exited with status ${status}" >&2
    exit "${status}"
fi

echo "Done"
exit 0
|
8
scripts/shapenet_airplane_ae_test.sh
Executable file
8
scripts/shapenet_airplane_ae_test.sh
Executable file
|
@ -0,0 +1,8 @@
|
|||
#! /bin/bash
# Run test.py in reconstruction mode on the airplane auto-encoder checkpoint.

test_args=(
    --cates airplane
    --resume_checkpoint pretrained_models/ae/airplane/checkpoint.pt
    --dims 512-512-512
    --use_deterministic_encoder
    --evaluate_recon
)

python test.py "${test_args[@]}"
|
11
scripts/shapenet_airplane_demo.sh
Executable file
11
scripts/shapenet_airplane_demo.sh
Executable file
|
@ -0,0 +1,11 @@
|
|||
#! /bin/bash
# Run demo.py on the airplane generative checkpoint:
# 20 shapes with 2048 points each.

demo_args=(
    --cates airplane
    --resume_checkpoint pretrained_models/gen/airplane/checkpoint.pt
    --dims 512-512-512
    --latent_dims 256-256
    --use_latent_flow
    --num_sample_shapes 20
    --num_sample_points 2048
)

python demo.py "${demo_args[@]}"
|
||||
|
36
scripts/shapenet_airplane_gen.sh
Executable file
36
scripts/shapenet_airplane_gen.sh
Executable file
|
@ -0,0 +1,36 @@
|
|||
#! /bin/bash
# Train the generative configuration (--use_latent_flow) on the ShapeNet
# "airplane" category.
# Exits with train.py's status if training fails.

cate="airplane"
dims="512-512-512"
latent_dims="256-256"
num_blocks=1
latent_num_blocks=1
zdim=128
batch_size=16
lr=2e-3
epochs=4000
ds=shapenet15k
log_name="gen/${ds}-cate${cate}"
data_dir="data/ShapeNetCore.v2.PC15k"

python train.py \
    --log_name "${log_name}" \
    --lr "${lr}" \
    --dataset_type "${ds}" \
    --data_dir "${data_dir}" \
    --cates "${cate}" \
    --dims "${dims}" \
    --latent_dims "${latent_dims}" \
    --num_blocks "${num_blocks}" \
    --latent_num_blocks "${latent_num_blocks}" \
    --batch_size "${batch_size}" \
    --zdim "${zdim}" \
    --epochs "${epochs}" \
    --save_freq 50 \
    --viz_freq 1 \
    --log_freq 1 \
    --val_freq 10 \
    --use_latent_flow
status=$?

# Do not report success when training failed.
if [ "${status}" -ne 0 ]; then
    echo "train.py exited with status ${status}" >&2
    exit "${status}"
fi

echo "Done"
exit 0
|
38
scripts/shapenet_airplane_gen_dist.sh
Executable file
38
scripts/shapenet_airplane_gen_dist.sh
Executable file
|
@ -0,0 +1,38 @@
|
|||
#! /bin/bash
# Train the generative configuration (--use_latent_flow) on the ShapeNet
# "airplane" category with --distributed enabled and --train_T False.
# Exits with train.py's status if training fails.

cate="airplane"
dims="512-512-512"
latent_dims="256-256"
num_blocks=1
latent_num_blocks=1
zdim=128
batch_size=256
lr=2e-3
epochs=4000
ds=shapenet15k
log_name="gen/${ds}-cate${cate}-seqback"
data_dir="data/ShapeNetCore.v2.PC15k"

python train.py \
    --log_name "${log_name}" \
    --lr "${lr}" \
    --train_T False \
    --dataset_type "${ds}" \
    --data_dir "${data_dir}" \
    --cates "${cate}" \
    --dims "${dims}" \
    --latent_dims "${latent_dims}" \
    --num_blocks "${num_blocks}" \
    --latent_num_blocks "${latent_num_blocks}" \
    --batch_size "${batch_size}" \
    --zdim "${zdim}" \
    --epochs "${epochs}" \
    --save_freq 50 \
    --viz_freq 1 \
    --log_freq 1 \
    --val_freq 10 \
    --distributed \
    --use_latent_flow
status=$?

# Do not report success when training failed.
if [ "${status}" -ne 0 ]; then
    echo "train.py exited with status ${status}" >&2
    exit "${status}"
fi

echo "Done"
exit 0
|
10
scripts/shapenet_airplane_gen_test.sh
Executable file
10
scripts/shapenet_airplane_gen_test.sh
Executable file
|
@ -0,0 +1,10 @@
|
|||
#! /bin/bash
# Run test.py on the airplane generative checkpoint (no --evaluate_recon,
# so the generation branch of test.py is used).

test_args=(
    --cates airplane
    --resume_checkpoint pretrained_models/gen/airplane/checkpoint.pt
    --dims 512-512-512
    --latent_dims 256-256
    --use_latent_flow
)

python test.py "${test_args[@]}"
|
||||
|
||||
|
11
scripts/shapenet_all_ae_test.sh
Executable file
11
scripts/shapenet_all_ae_test.sh
Executable file
|
@ -0,0 +1,11 @@
|
|||
#! /bin/bash
# Run test.py in reconstruction mode on the all-category auto-encoder
# checkpoint, re-normalising the test set with the stored training-set
# mean and std.

test_args=(
    --cates all
    --resume_checkpoint pretrained_models/ae/all/checkpoint.pt
    --dims 512-512-512
    --use_deterministic_encoder
    --evaluate_recon
    --resume_dataset_mean pretrained_models/ae/all/train_set_mean.npy
    --resume_dataset_std pretrained_models/ae/all/train_set_std.npy
)

python test.py "${test_args[@]}"
|
||||
|
167
test.py
Normal file
167
test.py
Normal file
|
@ -0,0 +1,167 @@
|
|||
from datasets import get_datasets, synsetid_to_cate
|
||||
from args import get_args
|
||||
from pprint import pprint
|
||||
from metrics.evaluation_metrics import EMD_CD
|
||||
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
|
||||
from metrics.evaluation_metrics import compute_all_metrics
|
||||
from collections import defaultdict
|
||||
from models.networks import PointFlow
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_test_loader(args):
    """Return a non-shuffling DataLoader over the test split selected by ``args``.

    When both ``args.resume_dataset_mean`` and ``args.resume_dataset_std`` are
    set, the arrays stored at those paths are loaded and passed to the
    dataset's ``renormalize`` method first.
    """
    _, te_dataset = get_datasets(args)

    has_saved_stats = (args.resume_dataset_mean is not None
                       and args.resume_dataset_std is not None)
    if has_saved_stats:
        saved_mean = np.load(args.resume_dataset_mean)
        saved_std = np.load(args.resume_dataset_std)
        te_dataset.renormalize(saved_mean, saved_std)

    return torch.utils.data.DataLoader(
        dataset=te_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
|
||||
|
||||
|
||||
def evaluate_recon(model, args):
    """Evaluate reconstruction quality per category and on average.

    For every category, each test batch's ``train_points`` are reconstructed
    with as many points as ``test_points`` has, both clouds are de-normalised
    with the batch's ``mean``/``std`` and compared with ``EMD_CD``.

    Files written next to ``args.resume_checkpoint``:
        ``<cate>_out_smp.npy`` / ``<cate>_out_ref.npy``: reconstructions and
            references of each category.
        ``percate_results.npy``: dict of per-category metrics.
        ``results.npy``: dict of metrics averaged over categories, weighted by
            the number of evaluated shapes.
    Both result files hold pickled dicts.

    Note: ``args.cates`` is overwritten with a one-element list for each
    category and is left at the last category on return.
    """
    # TODO: make this memory efficient
    if 'all' in args.cates:
        cates = list(synsetid_to_cate.values())
    else:
        cates = args.cates

    def _to_gpu(tensor):
        return tensor.cuda() if args.gpu is None else tensor.cuda(args.gpu)

    all_results = {}
    cate_to_len = {}
    save_dir = os.path.dirname(args.resume_checkpoint)
    for cate in cates:
        args.cates = [cate]
        loader = get_test_loader(args)

        all_sample = []
        all_ref = []
        for data in loader:
            tr_pc, te_pc = data['train_points'], data['test_points']
            te_pc = _to_gpu(te_pc)
            tr_pc = _to_gpu(tr_pc)
            out_pc = model.reconstruct(tr_pc, num_points=te_pc.size(1))

            # Undo the dataset normalisation before comparing.
            m = _to_gpu(data['mean'].float())
            s = _to_gpu(data['std'].float())
            out_pc = out_pc * s + m
            te_pc = te_pc * s + m

            all_sample.append(out_pc)
            all_ref.append(te_pc)

        sample_pcs = torch.cat(all_sample, dim=0)
        ref_pcs = torch.cat(all_ref, dim=0)
        cate_to_len[cate] = int(sample_pcs.size(0))
        print("Cate=%s Total Sample size:%s Ref size: %s"
              % (cate, sample_pcs.size(), ref_pcs.size()))

        # Save it
        np.save(os.path.join(save_dir, "%s_out_smp.npy" % cate),
                sample_pcs.cpu().detach().numpy())
        np.save(os.path.join(save_dir, "%s_out_ref.npy" % cate),
                ref_pcs.cpu().detach().numpy())

        results = EMD_CD(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
        results = {
            k: (v.cpu().detach().item() if not isinstance(v, float) else v)
            for k, v in results.items()}
        pprint(results)
        all_results[cate] = results

    # Save final results
    print("="*80)
    print("All category results:")
    print("="*80)
    pprint(all_results)
    save_path = os.path.join(save_dir, "percate_results.npy")
    np.save(save_path, all_results)

    # Compute weighted performance
    ttl_r, ttl_cnt = defaultdict(lambda: 0.), defaultdict(lambda: 0.)
    for catename, l in cate_to_len.items():
        for k, v in all_results[catename].items():
            ttl_r[k] += v * float(l)
            ttl_cnt[k] += float(l)
    ttl_res = {k: (float(ttl_r[k]) / float(ttl_cnt[k])) for k in ttl_r.keys()}
    print("="*80)
    print("Averaged results:")
    pprint(ttl_res)
    print("="*80)

    # Store the averaged metrics; the per-category ones are already in
    # percate_results.npy.
    save_path = os.path.join(save_dir, "results.npy")
    np.save(save_path, ttl_res)
|
||||
|
||||
|
||||
def evaluate_gen(model, args):
    """Evaluate generation quality against the test set.

    For each test batch, the same number of shapes and points is drawn with
    ``model.sample``; generated and reference clouds are de-normalised with
    the batch's ``mean``/``std``.  Both sets are saved next to
    ``args.resume_checkpoint`` (``model_out_smp.npy`` / ``model_out_ref.npy``),
    then ``compute_all_metrics`` and the JSD are computed and printed.
    """
    def _on_gpu(tensor):
        return tensor.cuda() if args.gpu is None else tensor.cuda(args.gpu)

    generated, references = [], []
    for batch in get_test_loader(args):
        idx_b, te_pc = batch['idx'], batch['test_points']
        te_pc = _on_gpu(te_pc)
        n_shapes, n_points = te_pc.size(0), te_pc.size(1)
        _, out_pc = model.sample(n_shapes, n_points)

        # denormalize
        shift, scale = batch['mean'].float(), batch['std'].float()
        shift = _on_gpu(shift)
        scale = _on_gpu(scale)

        generated.append(out_pc * scale + shift)
        references.append(te_pc * scale + shift)

    sample_pcs = torch.cat(generated, dim=0)
    ref_pcs = torch.cat(references, dim=0)
    print("Generation sample size:%s reference size: %s"
          % (sample_pcs.size(), ref_pcs.size()))

    # Save the generative output
    out_dir = os.path.dirname(args.resume_checkpoint)
    sample_pcl_npy = sample_pcs.cpu().detach().numpy()
    ref_pcl_npy = ref_pcs.cpu().detach().numpy()
    np.save(os.path.join(out_dir, "model_out_smp.npy"), sample_pcl_npy)
    np.save(os.path.join(out_dir, "model_out_ref.npy"), ref_pcl_npy)

    # Compute metrics
    raw_results = compute_all_metrics(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
    results = {}
    for key, value in raw_results.items():
        # Tensors are converted to Python scalars; floats are kept as they are.
        results[key] = value if isinstance(value, float) else value.cpu().detach().item()
    pprint(results)

    jsd = JSD(sample_pcl_npy, ref_pcl_npy)
    print("JSD:%s" % jsd)
|
||||
|
||||
|
||||
def main(args):
    """Load the checkpoint at ``args.resume_checkpoint`` and run the evaluation.

    The model's sub-modules are wrapped in ``nn.DataParallel`` before the
    state dict is loaded.  Reconstruction is evaluated when
    ``args.evaluate_recon`` is set, generation otherwise; both run in eval
    mode with gradients disabled.
    """
    model = PointFlow(args).cuda()
    # nn.DataParallel is itself the one-argument wrapper the model expects.
    model.multi_gpu_wrapper(nn.DataParallel)

    print("Resume Path:%s" % args.resume_checkpoint)
    state = torch.load(args.resume_checkpoint)
    model.load_state_dict(state)
    model.eval()

    with torch.no_grad():
        # Reconstruction if requested, generation otherwise.
        evaluate = evaluate_recon if args.evaluate_recon else evaluate_gen
        evaluate(model, args)
|
||||
|
||||
|
||||
# Script entry point: parse the command line and run the evaluation.
if __name__ == "__main__":
    main(get_args())
|
272
train.py
Normal file
272
train.py
Normal file
|
@ -0,0 +1,272 @@
|
|||
import sys
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import warnings
|
||||
import torch.distributed
|
||||
import numpy as np
|
||||
import random
|
||||
import faulthandler
|
||||
import torch.multiprocessing as mp
|
||||
import time
|
||||
import scipy.misc
|
||||
from models.networks import PointFlow
|
||||
from torch import optim
|
||||
from args import get_args
|
||||
from torch.backends import cudnn
|
||||
from utils import AverageValueMeter, set_random_seed, apply_random_rotation, save, resume, visualize_point_clouds
|
||||
from tensorboardX import SummaryWriter
|
||||
from datasets import get_datasets, init_np_seed
|
||||
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
def _build_scheduler(optimizer, args):
    """Create the learning-rate scheduler named by ``args.scheduler``.

    Supported names:
        * ``'exponential'`` -- ``ExponentialLR`` with factor ``args.exp_decay``.
        * ``'step'`` -- ``StepLR`` multiplying the rate by 0.1 every
          ``args.epochs // 2`` scheduler steps.
        * ``'linear'`` -- factor 1.0 for the first half of ``args.epochs``,
          then a linear ramp down to 0.

    Raises:
        ValueError: if ``args.scheduler`` is not one of the names above.
    """
    if args.scheduler == 'exponential':
        return optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
    if args.scheduler == 'step':
        return optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs // 2, gamma=0.1)
    if args.scheduler == 'linear':
        def lambda_rule(ep):
            return 1.0 - max(0, ep - 0.5 * args.epochs) / float(0.5 * args.epochs)
        return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    # An explicit raise (instead of ``assert 0``) survives ``python -O``.
    raise ValueError(
        "args.scheduler should be 'exponential', 'step' or 'linear', got %r" % (args.scheduler,))


def _render_grid(samples, references, count, axis_order):
    """Render ``count`` sample/reference pairs and concatenate them along axis 1."""
    panels = [
        visualize_point_clouds(samples[i], references[i], i, pert_order=axis_order)
        for i in range(count)
    ]
    return np.concatenate(panels, axis=1)


def main_worker(gpu, save_dir, ngpus_per_node, args):
    """Run the full training loop in the current process.

    Args:
        gpu: index of the GPU this process should use, or ``None``. In
            distributed mode this is also the local process index used to
            derive the global rank.
        save_dir: directory receiving checkpoints, dataset statistics and the
            ``images`` sub-directory of visualizations (must already exist).
        ngpus_per_node: number of GPUs (= worker processes) on this node.
        args: parsed command-line namespace. It is mutated in place
            (``gpu``, ``rank``, ``batch_size``, ``workers``,
            ``resume_checkpoint``, ``prior_weight``, ``entropy_weight``).

    Raises:
        ValueError: if ``args.scheduler`` names an unknown scheduler.
    """
    # basic setup
    cudnn.benchmark = True
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        # Turn the node rank into a global rank: node_rank * gpus_per_node + local gpu.
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    if args.log_name is not None:
        log_dir = "runs/%s" % args.log_name
    else:
        log_dir = "runs/time-%d" % time.time()

    # Only the first process of each node writes logs, statistics and checkpoints.
    is_node_leader = not args.distributed or (args.rank % ngpus_per_node == 0)
    writer = SummaryWriter(logdir=log_dir) if is_node_leader else None

    if not args.use_latent_flow:  # auto-encoder only
        args.prior_weight = 0
        args.entropy_weight = 0

    # multi-GPU setup
    model = PointFlow(args)
    if args.distributed:  # Multiple processes, single GPU per process
        assert args.gpu is not None, \
            "DistributedDataParallel constructor should always set the single device scope"

        def _transform_(m):
            return nn.parallel.DistributedDataParallel(
                m, device_ids=[args.gpu], output_device=args.gpu, check_reduction=True)

        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        model.multi_gpu_wrapper(_transform_)
        # args.batch_size is the per-node batch; split it across this node's processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = 0
    elif args.gpu is not None:  # Single process, single GPU per process
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:  # Single process, multiple GPUs per process
        def _transform_(m):
            return nn.DataParallel(m)
        model = model.cuda()
        model.multi_gpu_wrapper(_transform_)

    # resume checkpoints
    start_epoch = 0
    optimizer = model.make_optimizer(args)
    latest_ckpt = os.path.join(save_dir, 'checkpoint-latest.pt')
    if args.resume_checkpoint is None and os.path.exists(latest_ckpt):
        args.resume_checkpoint = latest_ckpt  # use the latest checkpoint
    if args.resume_checkpoint is not None:
        if args.resume_optimizer:
            model, optimizer, start_epoch = resume(
                args.resume_checkpoint, model, optimizer, strict=(not args.resume_non_strict))
        else:
            # Restore weights only; keep the freshly created optimizer.
            model, _, start_epoch = resume(
                args.resume_checkpoint, model, optimizer=None, strict=(not args.resume_non_strict))
        print('Resumed from: ' + args.resume_checkpoint)

    # initialize datasets and loaders
    tr_dataset, te_dataset = get_datasets(args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(tr_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        dataset=tr_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=0, pin_memory=True, sampler=train_sampler, drop_last=True,
        worker_init_fn=init_np_seed)
    test_loader = torch.utils.data.DataLoader(
        dataset=te_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=0, pin_memory=True, drop_last=False,
        worker_init_fn=init_np_seed)

    # save dataset statistics
    if is_node_leader:
        np.save(os.path.join(save_dir, "train_set_mean.npy"), tr_dataset.all_points_mean)
        np.save(os.path.join(save_dir, "train_set_std.npy"), tr_dataset.all_points_std)
        np.save(os.path.join(save_dir, "train_set_idx.npy"), np.array(tr_dataset.shuffle_idx))
        np.save(os.path.join(save_dir, "val_set_mean.npy"), te_dataset.all_points_mean)
        np.save(os.path.join(save_dir, "val_set_std.npy"), te_dataset.all_points_std)
        np.save(os.path.join(save_dir, "val_set_idx.npy"), np.array(te_dataset.shuffle_idx))

    # load classification dataset if needed
    if args.eval_classification:
        from datasets import get_clf_datasets

        def _make_data_loader_(dataset):
            return torch.utils.data.DataLoader(
                dataset=dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=0, pin_memory=True, drop_last=False,
                worker_init_fn=init_np_seed
            )

        clf_datasets = get_clf_datasets(args)
        clf_loaders = {
            k: [_make_data_loader_(ds) for ds in ds_lst] for k, ds_lst in clf_datasets.items()
        }
    else:
        clf_loaders = None

    # initialize the learning rate scheduler
    scheduler = _build_scheduler(optimizer, args)

    # main training loop
    start_time = time.time()
    entropy_avg_meter = AverageValueMeter()
    latent_nats_avg_meter = AverageValueMeter()
    point_nats_avg_meter = AverageValueMeter()
    if args.distributed:
        print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size()))

    print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
    inputs = None  # last training batch on the GPU; reused for the visualizations
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # adjust the learning rate
        if (epoch + 1) % args.exp_decay_freq == 0:
            scheduler.step(epoch=epoch)
            if writer is not None:
                writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], epoch)

        # train for one epoch
        for bidx, data in enumerate(train_loader):
            tr_batch = data['train_points']
            step = bidx + len(train_loader) * epoch
            model.train()
            if args.random_rotate:
                tr_batch, _, _ = apply_random_rotation(
                    tr_batch, rot_axis=train_loader.dataset.gravity_axis)
            inputs = tr_batch.cuda(args.gpu, non_blocking=True)
            out = model(inputs, optimizer, step, writer)
            entropy, prior_nats, recon_nats = out['entropy'], out['prior_nats'], out['recon_nats']
            entropy_avg_meter.update(entropy)
            point_nats_avg_meter.update(recon_nats)
            latent_nats_avg_meter.update(prior_nats)
            if step % args.log_freq == 0:
                duration = time.time() - start_time
                start_time = time.time()
                print("[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f"
                      % (args.rank, epoch, bidx, len(train_loader), duration, entropy_avg_meter.avg,
                         latent_nats_avg_meter.avg, point_nats_avg_meter.avg))

        # evaluate on the validation set
        if not args.no_validation and (epoch + 1) % args.val_freq == 0:
            from utils import validate
            validate(test_loader, model, epoch, writer, save_dir, args, clf_loaders=clf_loaders)

        # save visualizations (skipped when the loader produced no batch at all,
        # which would otherwise fail on an unbound ``inputs``)
        if (epoch + 1) % args.viz_freq == 0 and inputs is not None:
            model.eval()
            image_dir = os.path.join(save_dir, 'images')
            axis_order = train_loader.dataset.display_axis_order
            num_shown = min(10, inputs.size(0))

            # reconstructions
            # NOTE(review): scipy.misc.imsave is no longer shipped by recent SciPy
            # releases -- confirm the pinned SciPy version still provides it.
            samples = model.reconstruct(inputs)
            res = _render_grid(samples, inputs, num_shown, axis_order)
            scipy.misc.imsave(
                os.path.join(image_dir, 'tr_vis_conditioned_epoch%d-gpu%s.png' % (epoch, args.gpu)),
                res.transpose((1, 2, 0)))
            if writer is not None:
                writer.add_image('tr_vis/conditioned', torch.as_tensor(res), epoch)

            # samples
            if args.use_latent_flow:
                num_points = inputs.size(1)
                _, samples = model.sample(num_shown, num_points)
                res = _render_grid(samples, inputs, num_shown, axis_order)
                # Use a dedicated file name: the previous 'tr_vis_conditioned_*' name
                # overwrote the reconstruction image written just above.
                scipy.misc.imsave(
                    os.path.join(image_dir, 'tr_vis_sampled_epoch%d-gpu%s.png' % (epoch, args.gpu)),
                    res.transpose((1, 2, 0)))
                if writer is not None:
                    writer.add_image('tr_vis/sampled', torch.as_tensor(res), epoch)

        # save checkpoints
        if is_node_leader and (epoch + 1) % args.save_freq == 0:
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-latest.pt'))
|
||||
|
||||
|
||||
def main():
    """Parse the command line, prepare the output directory and launch training.

    In distributed mode one ``main_worker`` process is spawned per visible
    GPU; otherwise ``main_worker`` runs in this process.

    Raises:
        ValueError: if ``args.sync_bn`` is set without ``args.distributed``.
    """
    # command line args
    args = get_args()
    if args.log_name is None:
        # main_worker already supports an unnamed run ("runs/time-<ts>"); give the
        # checkpoint directory the same kind of name instead of failing in os.path.join.
        args.log_name = "time-%d" % time.time()
    save_dir = os.path.join("checkpoints", args.log_name)
    # Also creates save_dir; exist_ok restores a missing images/ folder when
    # resuming into an existing run directory.
    os.makedirs(os.path.join(save_dir, 'images'), exist_ok=True)

    # Record the exact command for reproducibility.
    with open(os.path.join(save_dir, 'command.sh'), 'w') as f:
        f.write('python -X faulthandler ' + ' '.join(sys.argv))
        f.write('\n')

    if args.seed is None:
        args.seed = random.randint(0, 1000000)
    set_random_seed(args.seed)

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    if args.sync_bn and not args.distributed:
        # Explicit raise: an ``assert`` would be stripped under ``python -O``.
        raise ValueError("args.sync_bn requires args.distributed")

    print("Arguments:")
    print(args)

    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        # world_size arrives as the number of nodes; convert it to the number of processes.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(save_dir, ngpus_per_node, args))
    else:
        main_worker(args.gpu, save_dir, ngpus_per_node, args)
|
||||
|
||||
|
||||
# Script entry point: start training only when executed directly, not on import.
if __name__ == '__main__':
|
||||
main()
|
378
utils.py
Normal file
378
utils.py
Normal file
|
@ -0,0 +1,378 @@
|
|||
from pprint import pprint
|
||||
from sklearn.svm import LinearSVC
|
||||
from math import log, pi
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import random
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
|
||||
class AverageValueMeter(object):
    """Track the latest value and the weighted running mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Put every statistic back to its initial zero state."""
        self.val = self.avg = self.sum = 0
        self.count = 0.0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def gaussian_log_likelihood(x, mean, logvar, clip=True):
    """Sum of diagonal-Gaussian log-densities of ``x`` over all elements.

    With ``clip`` enabled the log-variance is first clamped to ``[-4, 3]``.
    Returns a scalar tensor.
    """
    log_variance = torch.clamp(logvar, min=-4, max=3) if clip else logvar
    squared_error = (x - mean) ** 2
    per_element = log(2 * pi) + log_variance + squared_error / torch.exp(log_variance)
    return -0.5 * torch.sum(per_element)
|
||||
|
||||
|
||||
def bernoulli_log_likelihood(x, p, clip=True, eps=1e-6):
    """Sum of Bernoulli log-likelihoods of targets ``x`` under probabilities ``p``.

    With ``clip`` enabled ``p`` is clamped to ``[eps, 1 - eps]`` so that the
    logarithms stay finite. Returns a scalar tensor.
    """
    prob = torch.clamp(p, min=eps, max=1 - eps) if clip else p
    positive_term = x * torch.log(prob)
    negative_term = (1 - x) * torch.log(1 - prob)
    return torch.sum(positive_term + negative_term)
|
||||
|
||||
|
||||
def kl_diagnormal_stdnormal(mean, logvar):
    """KL divergence from N(mean, exp(logvar)) to the standard normal, summed over all elements."""
    per_dim = mean ** 2 + torch.exp(logvar) - 1 - logvar
    return 0.5 * torch.sum(per_dim)
|
||||
|
||||
|
||||
def kl_diagnormal_diagnormal(q_mean, q_logvar, p_mean, p_logvar):
    """KL divergence KL(q || p) between two diagonal Gaussians, summed over all elements.

    The ``p`` parameters are explicitly expanded to the shapes of the ``q``
    parameters before the element-wise arithmetic.
    """
    prior_mean = p_mean.expand_as(q_mean)
    prior_logvar = p_logvar.expand_as(q_logvar)

    scaled = ((q_mean - prior_mean) ** 2 + torch.exp(q_logvar)) / torch.exp(prior_logvar)
    return 0.5 * torch.sum(prior_logvar - 1 - q_logvar + scaled)
|
||||
|
||||
|
||||
# Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
def truncated_normal(tensor, mean=0, std=1, trunc_std=2):
    """Fill ``tensor`` in place with approximately truncated-normal values.

    Four standard-normal candidates are drawn per element and one whose
    magnitude is below ``trunc_std`` is selected via the arg-max of the
    in-range mask; the result is then scaled by ``std`` and shifted by
    ``mean``. If none of the four candidates is in range, one of them is
    used anyway, so the truncation is not strictly guaranteed.

    Returns the same ``tensor`` object.
    """
    candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
    in_range = (candidates > -trunc_std) & (candidates < trunc_std)
    _, chosen = in_range.max(-1, keepdim=True)
    picked = candidates.gather(-1, chosen).squeeze(-1)
    tensor.data.copy_(picked)
    tensor.data.mul_(std).add_(mean)
    return tensor
|
||||
|
||||
|
||||
def reduce_tensor(tensor, world_size=None):
    """Return the sum of ``tensor`` over all processes divided by ``world_size``.

    The input is left untouched (a clone is all-reduced). When ``world_size``
    is ``None`` the size of the default process group is used.
    """
    total = tensor.clone()
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    divisor = dist.get_world_size() if world_size is None else world_size
    total /= divisor
    return total
|
||||
|
||||
|
||||
def standard_normal_logprob(z):
    """Element-wise log-density term of ``z`` under a standard normal.

    Every element receives the constant ``-0.5 * D * log(2 * pi)``, where
    ``D = z.size(-1)``, minus ``z ** 2 / 2``; the result has the shape of ``z``.

    NOTE(review): the constant already covers the whole last dimension, so
    summing the result over that dimension counts it ``D`` times -- confirm
    that callers expect this.
    """
    log_norm_const = -0.5 * z.size(-1) * log(2 * pi)
    return log_norm_const - z.pow(2) / 2
|
||||
|
||||
|
||||
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
|
||||
|
||||
|
||||
# Visualization
|
||||
def visualize_point_clouds(pts, gtr, idx, pert_order=[0, 1, 2]):
    """Render a sample and its ground truth side by side as 3D scatter plots.

    ``pts`` and ``gtr`` are tensors whose columns are re-ordered by
    ``pert_order`` before plotting; ``idx`` only appears in the panel titles.
    Returns the rendered figure as an array with the channel axis first.
    """
    sample_np = pts.cpu().detach().numpy()[:, pert_order]
    truth_np = gtr.cpu().detach().numpy()[:, pert_order]

    fig = plt.figure(figsize=(6, 3))
    panels = (
        (121, "Sample:%s", sample_np),
        (122, "Ground Truth:%s", truth_np),
    )
    for position, title, cloud in panels:
        axis = fig.add_subplot(position, projection='3d')
        axis.set_title(title % idx)
        axis.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], s=5)

    fig.canvas.draw()

    # Copy the rendered pixel buffer into an array and move channels to the front.
    image = np.transpose(np.array(fig.canvas.renderer._renderer), (2, 0, 1))

    plt.close()
    return image
|
||||
|
||||
|
||||
# Augmentation
|
||||
def apply_random_rotation(pc, rot_axis=1):
    """Rotate every cloud of a batch by its own random angle in [0, 2*pi).

    ``pc`` is multiplied from the right by a per-sample 3x3 matrix
    (``torch.bmm(pc, rot)``). The matrix selected by ``rot_axis`` leaves one
    coordinate untouched: ``rot_axis=0`` keeps coordinate 2, ``rot_axis=1``
    keeps coordinate 1 and ``rot_axis=2`` keeps coordinate 0.
    NOTE(review): for 0 and 2 the fixed coordinate is not ``rot_axis``
    itself -- confirm this matches the datasets' gravity-axis convention.

    Returns:
        ``(pc_rotated, rot, theta)``: the rotated batch, the matrices as a
        tensor with the dtype/device of ``pc``, and the NumPy angle array.

    Raises:
        Exception: if ``rot_axis`` is not 0, 1 or 2.
    """
    batch = pc.shape[0]

    theta = np.random.rand(batch) * 2 * np.pi
    c, s = np.cos(theta), np.sin(theta)
    zero, one = np.zeros(batch), np.ones(batch)

    # Row-major entries of the 3x3 matrix; each entry is a length-``batch`` vector.
    if rot_axis == 0:
        entries = (c, -s, zero,
                   s, c, zero,
                   zero, zero, one)
    elif rot_axis == 1:
        entries = (c, zero, -s,
                   zero, one, zero,
                   s, zero, c)
    elif rot_axis == 2:
        entries = (one, zero, zero,
                   zero, c, -s,
                   zero, s, c)
    else:
        raise Exception("Invalid rotation axis")

    matrices = np.stack(entries, axis=-1).reshape(batch, 3, 3)
    rot = torch.from_numpy(matrices).to(pc)

    # (B, N, 3) mul (B, 3, 3) -> (B, N, 3)
    return torch.bmm(pc, rot), rot, theta
|
||||
|
||||
|
||||
def validate_classification(loaders, model, args):
    """Fit a linear SVM on encoder latents and report test accuracy.

    ``loaders`` is a ``(train_loader, test_loader)`` pair. Each batch is read
    through its ``'train_points'`` and ``'cate_idx'`` entries, the points are
    encoded with ``model.encode`` on the GPU, and a ``LinearSVC`` fitted on
    the training latents is scored on the test latents.

    Returns ``{'acc': accuracy_in_percent}``.
    """
    train_loader, test_loader = loaders

    def _embed(loader):
        # Encode every batch of ``loader``; returns (latents, labels) as NumPy arrays.
        latents, labels = [], []
        for batch in loader:
            points = batch['train_points']
            points = points.cuda() if args.gpu is None else points.cuda(args.gpu)
            code = model.encode(points)
            latents.append(code.cpu().detach().numpy())
            labels.append(batch['cate_idx'].cpu().detach().numpy())
        return np.concatenate(latents), np.concatenate(labels)

    tr_latent, tr_label = _embed(train_loader)
    te_latent, te_label = _embed(test_loader)

    svm = LinearSVC(random_state=0)
    svm.fit(tr_latent, tr_label)
    predictions = svm.predict(te_latent)
    hits = (predictions == te_label.flatten()).astype(float)
    acc = np.mean(hits) * 100.
    print("Acc:%s" % acc)
    return {'acc': acc}
|
||||
|
||||
|
||||
def validate_conditioned(loader, model, args, max_samples=None, save_dir=None):
    """Evaluate reconstruction quality of ``model`` on ``loader``.

    For each batch the ``'train_points'`` cloud (cropped to the number of
    points of ``'test_points'`` when larger) is reconstructed, both the
    reconstruction and the reference are de-normalized with the batch's
    ``'mean'``/``'std'``, and the collected clouds are scored with ``EMD_CD``.

    Args:
        loader: iterable of batch dicts as described above.
        model: object providing ``reconstruct(points, num_points=...)``.
        args: namespace read for ``gpu``, ``rank``, ``batch_size`` and
            ``save_val_results``.
        max_samples: stop collecting once at least this many shapes were
            processed; ``None`` processes the whole loader.
        save_dir: when not ``None`` and ``args.save_val_results`` is set, the
            sample/reference clouds are dumped there as ``.npy`` files.

    Returns:
        The result dictionary produced by ``EMD_CD``.
    """
    from metrics.evaluation_metrics import EMD_CD

    def _to_gpu(t):
        # Same placement rule as the rest of this module: current device when
        # no GPU was chosen explicitly, otherwise the chosen one.
        return t.cuda() if args.gpu is None else t.cuda(args.gpu)

    all_sample = []
    all_ref = []
    ttl_samples = 0

    for data in loader:
        tr_pc = _to_gpu(data['train_points'])
        te_pc = _to_gpu(data['test_points'])

        if tr_pc.size(1) > te_pc.size(1):
            tr_pc = tr_pc[:, :te_pc.size(1), :]
        out_pc = model.reconstruct(tr_pc, num_points=te_pc.size(1))

        # denormalize
        m = _to_gpu(data['mean'].float())
        s = _to_gpu(data['std'].float())
        all_sample.append(out_pc * s + m)
        all_ref.append(te_pc * s + m)

        ttl_samples += int(te_pc.size(0))
        if max_samples is not None and ttl_samples >= max_samples:
            break

    # Compute MMD and CD
    sample_pcs = torch.cat(all_sample, dim=0)
    ref_pcs = torch.cat(all_ref, dim=0)
    print("[rank %s] Recon Sample size:%s Ref size: %s" % (args.rank, sample_pcs.size(), ref_pcs.size()))

    if save_dir is not None and args.save_val_results:
        smp_pcs_save_name = os.path.join(save_dir, "smp_recon_pcls_gpu%s.npy" % args.gpu)
        ref_pcs_save_name = os.path.join(save_dir, "ref_recon_pcls_gpu%s.npy" % args.gpu)
        np.save(smp_pcs_save_name, sample_pcs.cpu().detach().numpy())
        np.save(ref_pcs_save_name, ref_pcs.cpu().detach().numpy())
        print("Saving file:%s %s" % (smp_pcs_save_name, ref_pcs_save_name))

    res = EMD_CD(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
    mmd_cd = res['MMD-CD'] if 'MMD-CD' in res else None
    mmd_emd = res['MMD-EMD'] if 'MMD-EMD' in res else None

    print("MMD-CD :%s" % mmd_cd)
    print("MMD-EMD :%s" % mmd_emd)

    return res
|
||||
|
||||
|
||||
def validate_sample(loader, model, args, max_samples=None, save_dir=None):
    """Evaluate the generative quality of ``model`` against ``loader``.

    For each batch as many shapes are sampled from the model as the batch's
    ``'test_points'`` contains (with the same number of points), both sets
    are de-normalized with the batch's ``'mean'``/``'std'``, and the collected
    clouds are scored with ``compute_all_metrics`` plus the JSD.

    Args:
        loader: iterable of batch dicts as described above.
        model: object providing ``sample(num_shapes, num_points, gpu=...)``.
        args: namespace read for ``gpu``, ``rank``, ``batch_size`` and
            ``save_val_results``.
        max_samples: stop collecting once at least this many shapes were
            processed; ``None`` processes the whole loader.
        save_dir: when not ``None`` and ``args.save_val_results`` is set, the
            sample/reference clouds are dumped there as ``.npy`` files.

    Returns:
        The dictionary from ``compute_all_metrics`` extended with a ``"JSD"``
        entry holding a CUDA tensor.
    """
    from metrics.evaluation_metrics import compute_all_metrics, jsd_between_point_cloud_sets as JSD

    def _to_gpu(t):
        # Current device when no GPU was chosen explicitly, otherwise the chosen one.
        return t.cuda() if args.gpu is None else t.cuda(args.gpu)

    all_sample = []
    all_ref = []
    ttl_samples = 0

    for data in loader:
        te_pc = _to_gpu(data['test_points'])
        _, out_pc = model.sample(te_pc.size(0), te_pc.size(1), gpu=args.gpu)

        # denormalize
        m = _to_gpu(data['mean'].float())
        s = _to_gpu(data['std'].float())
        all_sample.append(out_pc * s + m)
        all_ref.append(te_pc * s + m)

        ttl_samples += int(te_pc.size(0))
        if max_samples is not None and ttl_samples >= max_samples:
            break

    sample_pcs = torch.cat(all_sample, dim=0)
    ref_pcs = torch.cat(all_ref, dim=0)
    print("[rank %s] Generation Sample size:%s Ref size: %s"
          % (args.rank, sample_pcs.size(), ref_pcs.size()))

    if save_dir is not None and args.save_val_results:
        smp_pcs_save_name = os.path.join(save_dir, "smp_syn_pcls_gpu%s.npy" % args.gpu)
        ref_pcs_save_name = os.path.join(save_dir, "ref_syn_pcls_gpu%s.npy" % args.gpu)
        np.save(smp_pcs_save_name, sample_pcs.cpu().detach().numpy())
        np.save(ref_pcs_save_name, ref_pcs.cpu().detach().numpy())
        print("Saving file:%s %s" % (smp_pcs_save_name, ref_pcs_save_name))

    res = compute_all_metrics(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
    pprint(res)

    # JSD is computed on NumPy arrays, then moved back to the GPU as a tensor.
    sample_pcs = sample_pcs.cpu().detach().numpy()
    ref_pcs = ref_pcs.cpu().detach().numpy()
    jsd = _to_gpu(torch.tensor(JSD(sample_pcs, ref_pcs)))
    res.update({"JSD": jsd})
    print("JSD :%s" % jsd)
    return res
|
||||
|
||||
|
||||
def save(model, optimizer, epoch, path):
    """Write a checkpoint with the epoch counter and the model/optimizer state dicts to ``path``."""
    checkpoint = dict(
        epoch=epoch,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(checkpoint, path)
|
||||
|
||||
|
||||
def resume(path, model, optimizer=None, strict=True, map_location='cpu'):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint file.

    Args:
        path: checkpoint containing the keys ``'model'``, ``'epoch'`` and
            ``'optimizer'``.
        model: module whose ``load_state_dict`` receives the stored weights.
        optimizer: optimizer to restore; ``None`` skips the optimizer state.
        strict: forwarded to ``model.load_state_dict``.
        map_location: forwarded to ``torch.load``. The default ``'cpu'``
            avoids materializing the checkpoint on the device it was saved
            from (which fails when that device is absent and makes every
            distributed worker allocate on the same GPU); the values are
            handed to ``load_state_dict`` afterwards.

    Returns:
        ``(model, optimizer, start_epoch)`` where ``start_epoch`` is the
        stored ``'epoch'`` value.
    """
    ckpt = torch.load(path, map_location=map_location)
    model.load_state_dict(ckpt['model'], strict=strict)
    start_epoch = ckpt['epoch']
    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, start_epoch
|
||||
|
||||
|
||||
def _log_scalars(writer, prefix, results, epoch):
    """Write every non-``None`` entry of ``results`` to ``writer`` as ``prefix/key``."""
    if writer is None:
        return
    for key, value in results.items():
        # Check for None *before* touching the value; tensors are reduced to
        # plain Python numbers, everything else is passed through unchanged.
        if value is None:
            continue
        if isinstance(value, torch.Tensor):
            value = value.cpu().detach().item()
        writer.add_scalar('%s/%s' % (prefix, key), value, epoch)


def validate(test_loader, model, epoch, writer, save_dir, args, clf_loaders=None):
    """Run all enabled validation passes for one epoch and log their metrics.

    The model is switched to eval mode. Depending on ``args`` this runs the
    latent-classification evaluation (``args.eval_classification`` with
    ``clf_loaders``), the sampling evaluation (``args.use_latent_flow``) and
    always the reconstruction evaluation.

    Args:
        test_loader: loader passed to the sampling/reconstruction passes.
        model: model under evaluation.
        epoch: epoch index used as the logging step and in the results folder name.
        writer: summary writer, or ``None`` to disable logging. Processes
            without a writer also do not save validation results.
        save_dir: base directory; results go to ``save_dir/epoch-<epoch>``
            when ``args.save_val_results`` is set and a writer exists.
        args: parsed command-line namespace.
        clf_loaders: mapping from experiment name to a
            ``(train_loader, test_loader)`` pair, or ``None``.
    """
    model.eval()

    # Make epoch wise save directory
    if writer is not None and args.save_val_results:
        save_dir = os.path.join(save_dir, 'epoch-%d' % epoch)
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
    else:
        save_dir = None

    # classification
    if args.eval_classification and clf_loaders is not None:
        for clf_expr, loaders in clf_loaders.items():
            with torch.no_grad():
                clf_val_res = validate_classification(loaders, model, args)
            _log_scalars(writer, 'val_%s' % clf_expr, clf_val_res, epoch)

    # samples
    if args.use_latent_flow:
        with torch.no_grad():
            val_sample_res = validate_sample(
                test_loader, model, args, max_samples=args.max_validate_shapes,
                save_dir=save_dir)
        _log_scalars(writer, 'val_sample', val_sample_res, epoch)

    # reconstructions
    with torch.no_grad():
        val_res = validate_conditioned(
            test_loader, model, args, max_samples=args.max_validate_shapes,
            save_dir=save_dir)
    _log_scalars(writer, 'val_conditioned', val_res, epoch)
|
||||
|
Loading…
Reference in a new issue